├── src
│   ├── __init__.py
│   ├── utils.py
│   ├── Conv2d.py
│   └── ConvTranspose2d.py
├── tests
│   ├── __init__.py
│   ├── test_Conv2d.py
│   └── test_ConvTranspose2d.py
├── pics
│   ├── convolution.jpeg
│   └── transposed_convolution.jpeg
├── .gitignore
├── LICENSE
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/pics/convolution.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loeweX/Custom-ConvLayers-Pytorch/HEAD/pics/convolution.jpeg
--------------------------------------------------------------------------------
/pics/transposed_convolution.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loeweX/Custom-ConvLayers-Pytorch/HEAD/pics/transposed_convolution.jpeg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.idea/

.DS_Store
*/.DS_Store
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import collections.abc
from itertools import repeat
from typing import Union, Tuple, Iterable


def pair(x: Union[int, Iterable[int]]) -> Tuple[int, int]:
    """
    If the input is an iterable (e.g., a list or tuple) of length 2, return it as a tuple. If the input is a
    single integer, duplicate it and return the pair as a tuple.

    Arguments:
        x: Either an iterable of length 2 or a single integer.

    Returns:
        A tuple of length 2.
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, 2))
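
# Example behavior (illustrative):
#   pair(3)      -> (3, 3)
#   pair((1, 2)) -> (1, 2)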
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Sindy Löwe

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/Conv2d.py:
--------------------------------------------------------------------------------
import math
from typing import Union, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from src import utils


class Conv2d(nn.Module):
    """
    Step-by-step implementation of a 2D convolutional layer.

    Arguments:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolutional kernel.
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
    ):
        super(Conv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = utils.pair(kernel_size)
        self.stride = utils.pair(stride)
        self.padding = utils.pair(padding)
        self.dilation = utils.pair(dilation)

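        # Each output channel's kernel is stored flattened to a row of length
        # in_channels * kernel_height * kernel_width, matching the patch columns produced by unfold.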
        self.weight = nn.Parameter(
            torch.empty(
                (
                    self.out_channels,
                    self.in_channels * self.kernel_size[0] * self.kernel_size[1],
                )
            )
        )
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_channels))

    def get_output_size(self, input_size: int, idx: int) -> int:
        """
        Calculates the output size (i.e. feature map size) of the convolutional layer.

        Arguments:
            input_size: Height or width of input tensor.
            idx: Index to choose between height and width values (0 = height, 1 = width).

        Returns:
            Output size of the tensor after performing convolution given the input tensor.
        """
        return (
            input_size
            + 2 * self.padding[idx]
            - self.dilation[idx] * (self.kernel_size[idx] - 1)
            - 1
        ) // self.stride[idx] + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines forward pass of the convolutional layer.

        Arguments:
            x: Input tensor.

        Returns:
            The tensor obtained after applying the convolutional layer.
        """
        height, width = x.shape[-2:]

        # Patchify input according to hyperparameters of convolutional layer.
        x = torch.nn.functional.unfold(
            x,
            self.kernel_size,
            padding=self.padding,
            stride=self.stride,
            dilation=self.dilation,
        )
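        # x now has shape (batch, in_channels * kernel_height * kernel_width, num_patches).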

        # Apply weight matrix.
        output = torch.einsum("b i p, o i -> b o p", x, self.weight)

        # Rearrange output to (b, c, h, w).
        output_height = self.get_output_size(height, 0)
        output_width = self.get_output_size(width, 1)
        output = rearrange(
            output, "b o (h w) -> b o h w", h=output_height, w=output_width
        )

        if hasattr(self, "bias"):
            return output + self.bias[None, :, None, None]

        return output
--------------------------------------------------------------------------------
/tests/test_Conv2d.py:
--------------------------------------------------------------------------------
import pytest
import torch
import torch.nn as nn
from einops import rearrange

from src import Conv2d


@pytest.fixture
def conv_params():
    """
    Fixture for creating a predefined set of parameters for a convolutional layer.

    Returns:
        A dictionary of parameters for a convolutional layer.
    """
    return {
        "batch_size": 8,
        "in_channels": 32,
        "height": 16,
        "width": 16,
        "out_channels": 64,
        "kernel_size": (3, 2),
        "stride": (2, 3),
        "padding": (1, 0),
        "dilation": (3, 1),
    }


def test_Conv2d(conv_params):
    """
    Test to ensure that the output of the custom Conv2d class matches the
    output of the PyTorch Conv2d implementation.

    Arguments:
        conv_params: Parameters for the convolutional layers.
    """
    x = torch.rand(
        conv_params["batch_size"],
        conv_params["in_channels"],
        conv_params["height"],
        conv_params["width"],
    )

    pytorch_conv = torch.nn.Conv2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["dilation"],
    )
    output_pytorch = pytorch_conv(x)

    custom_conv = Conv2d.Conv2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["dilation"],
    )

    # Set the weight parameter of the custom Conv2d to match the PyTorch Conv2d.
    custom_conv.weight = nn.Parameter(
        rearrange(pytorch_conv.weight, "o i h w -> o (i h w)")
    )
    custom_conv.bias = nn.Parameter(pytorch_conv.bias)

    output_custom = custom_conv(x)

    # Assert that both outputs are similar (allowing some tolerance).
    torch.testing.assert_close(output_custom, output_pytorch, rtol=1e-4, atol=1e-4)


def test_Conv2d_gradients(conv_params):
    """
    Test the gradients calculated by the custom Conv2d layer against the native
    PyTorch implementation.

    Arguments:
        conv_params: Parameters for the convolutional layers.
    """
    x = torch.rand(
        conv_params["batch_size"],
        conv_params["in_channels"],
        conv_params["height"],
        conv_params["width"],
        dtype=torch.float,
        requires_grad=True,
    )

    pytorch_conv = torch.nn.Conv2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["dilation"],
    )
    output_pytorch = pytorch_conv(x)

    target = torch.rand_like(output_pytorch)
    loss_pytorch = (output_pytorch - target).sum()
    loss_pytorch.backward()

    custom_conv = Conv2d.Conv2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["dilation"],
    )

    # Set the weight parameter of the custom Conv2d to match the PyTorch Conv2d.
    custom_conv.weight = nn.Parameter(
        rearrange(pytorch_conv.weight, "o i h w -> o (i h w)")
    )
    custom_conv.bias = nn.Parameter(pytorch_conv.bias)

    output_custom = custom_conv(x)
    loss_custom = (output_custom - target).sum()
    loss_custom.backward()

    # Compare gradients.
    torch.testing.assert_close(
        pytorch_conv.weight.grad.reshape(conv_params["out_channels"], -1),
        custom_conv.weight.grad,
    )
    torch.testing.assert_close(pytorch_conv.bias.grad, custom_conv.bias.grad)
--------------------------------------------------------------------------------
/src/ConvTranspose2d.py:
--------------------------------------------------------------------------------
import math
from typing import Union, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from src import utils


class ConvTranspose2d(nn.Module):
    """
    Step-by-step implementation of a 2D transposed convolutional layer.

    Arguments:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolutional kernel.
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to the input. Default: 0
        output_padding (int or tuple, optional): Additional padding added to the output. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        output_padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
    ):
        super(ConvTranspose2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = utils.pair(kernel_size)
        self.stride = utils.pair(stride)
        self.padding = utils.pair(padding)
        self.output_padding = utils.pair(output_padding)
        self.dilation = utils.pair(dilation)

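        # Note the transposed layout relative to Conv2d: one row per *input* channel,
        # flattened to length out_channels * kernel_height * kernel_width, as expected by fold.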
        self.weight = nn.Parameter(
            torch.empty(
                (
                    self.in_channels,
                    self.out_channels * self.kernel_size[0] * self.kernel_size[1],
                )
            )
        )
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_channels))

    def get_output_size(self, input_size: int, idx: int) -> int:
        """
        Calculates the output size (i.e. feature map size) of the transposed convolutional layer.

        Arguments:
            input_size: Height or width of input tensor.
            idx: Index to choose between height and width values (0 = height, 1 = width).

        Returns:
            Output size of the tensor after performing transposed convolution given the input tensor.
        """
        return (
            (input_size - 1) * self.stride[idx]
            - 2 * self.padding[idx]
            + self.dilation[idx] * (self.kernel_size[idx] - 1)
            + self.output_padding[idx]
            + 1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines forward pass of the transposed convolutional layer.

        Arguments:
            x: Input tensor.

        Returns:
            The tensor obtained after applying the transposed convolutional layer.
        """
        height, width = x.shape[-2:]

        # Rearrange x to (b, c, h*w).
        x = rearrange(x, "b c h w -> b c (h w)")

        # Apply weight matrix.
        output = torch.einsum("b i p, i o -> b o p", x, self.weight)
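        # output now has shape (b, out_channels * kernel_height * kernel_width, h * w),
        # i.e., one kernel-sized column per input position, as expected by fold.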

        # "Depatchify": Combine patches by summing overlapping values.
        output_size = (
            self.get_output_size(height, 0),
            self.get_output_size(width, 1),
        )
        output = torch.nn.functional.fold(
            output,
            output_size,
            stride=self.stride,
            kernel_size=self.kernel_size,
            padding=self.padding,
            dilation=self.dilation,
        )

        if hasattr(self, "bias"):
            return output + self.bias[None, :, None, None]

        return output
--------------------------------------------------------------------------------
/tests/test_ConvTranspose2d.py:
--------------------------------------------------------------------------------
import pytest
import torch
import torch.nn as nn
from einops import rearrange

from src import ConvTranspose2d


@pytest.fixture
def conv_params():
    """
    Fixture for creating a predefined set of parameters for transposed convolutional layers.

    Returns:
        A dictionary of parameters for transposed convolutional layers.
    """
    return {
        "batch_size": 8,
        "in_channels": 32,
        "height": 16,
        "width": 16,
        "out_channels": 64,
        "kernel_size": (3, 2),
        "stride": (2, 3),
        "padding": (1, 0),
        "dilation": (3, 1),
        "output_padding": (1, 2),
    }


def test_ConvTranspose2d(conv_params):
    """
    Test to ensure that the output of the custom ConvTranspose2d class matches the
    output of the PyTorch ConvTranspose2d implementation.

    Arguments:
        conv_params: Parameters for transposed convolutional layers.
    """
    x = torch.rand(
        conv_params["batch_size"],
        conv_params["in_channels"],
        conv_params["height"],
        conv_params["width"],
    )

    pytorch_conv_tran = torch.nn.ConvTranspose2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["output_padding"],
        dilation=conv_params["dilation"],
    )
    output_pytorch = pytorch_conv_tran(x)

    custom_conv_tran = ConvTranspose2d.ConvTranspose2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["output_padding"],
        dilation=conv_params["dilation"],
    )

    # Set the weight parameter of the custom ConvTranspose2d to match the PyTorch ConvTranspose2d.
    custom_conv_tran.weight = nn.Parameter(
        rearrange(pytorch_conv_tran.weight, "i o h w -> i (o h w)")
    )
    custom_conv_tran.bias = nn.Parameter(pytorch_conv_tran.bias)

    output_custom = custom_conv_tran(x)

    torch.testing.assert_close(output_custom, output_pytorch, rtol=1e-4, atol=1e-4)


def test_ConvTranspose2d_gradients(conv_params):
    """
    Test the gradients calculated by the custom ConvTranspose2d layer against the native
    PyTorch implementation.

    Arguments:
        conv_params: Parameters for transposed convolutional layers.
    """
    x = torch.rand(
        conv_params["batch_size"],
        conv_params["in_channels"],
        conv_params["height"],
        conv_params["width"],
        dtype=torch.float,
        requires_grad=True,
    )

    pytorch_conv_tran = torch.nn.ConvTranspose2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["output_padding"],
        dilation=conv_params["dilation"],
    )
    output_pytorch = pytorch_conv_tran(x)

    target = torch.rand_like(output_pytorch)
    loss_pytorch = (output_pytorch - target).sum()
    loss_pytorch.backward()

    custom_conv_tran = ConvTranspose2d.ConvTranspose2d(
        conv_params["in_channels"],
        conv_params["out_channels"],
        conv_params["kernel_size"],
        conv_params["stride"],
        conv_params["padding"],
        conv_params["output_padding"],
        dilation=conv_params["dilation"],
    )

    # Set the weight parameter of the custom ConvTranspose2d to match the PyTorch ConvTranspose2d.
    custom_conv_tran.weight = nn.Parameter(
        rearrange(pytorch_conv_tran.weight, "i o h w -> i (o h w)")
    )
    custom_conv_tran.bias = nn.Parameter(pytorch_conv_tran.bias)

    output_custom = custom_conv_tran(x)
    loss_custom = (output_custom - target).sum()
    loss_custom.backward()

    torch.testing.assert_close(
        pytorch_conv_tran.weight.grad.reshape(conv_params["in_channels"], -1),
        custom_conv_tran.weight.grad,
    )
    torch.testing.assert_close(pytorch_conv_tran.bias.grad, custom_conv_tran.bias.grad)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch 2D Convolutional and Transposed Convolutional Layers Reimplementation

In this repository, you'll find a custom-built reimplementation of the 2D convolutional and transposed convolutional
layers in PyTorch using the `torch.nn.functional.fold` and `torch.nn.functional.unfold` functions.

By implementing these layers step by step, we can better understand their inner workings and modify them
more easily.
For instance, with this reimplementation, it is possible to implement variations where kernels with different weights
are applied to each spatial location.

For ease of use, the custom layers have the same interface as the PyTorch layers.


## Custom Conv2d Layer

The custom `Conv2d` class in this repository is a reimplementation of PyTorch's built-in
2D convolutional layer (`torch.nn.Conv2d`).

### Usage

```python
import torch
from src import Conv2d

# Initialize convolutional layer using the custom implementation
# (example hyperparameters; all arguments accepted by torch.nn.Conv2d work).
reimplemented_conv = Conv2d.Conv2d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
)

# Apply the layer to an input tensor of shape (batch_size, in_channels, height, width).
input_tensor = torch.rand(8, 32, 16, 16)
output_tensor = reimplemented_conv(input_tensor)  # shape: (8, 64, 16, 16)
```


### Custom Conv2d Layer Behavior

The custom 2D convolution layer implements the following steps (see the code sketch below for a minimal, self-contained version):

1. Apply the PyTorch function `torch.nn.functional.unfold` to generate local blocks by patchifying the input tensor based on the hyperparameters.

2. Multiply each element of the kernel by its corresponding element in each patch and sum up the results. The summed results over all spatial locations constitute the output feature map.

3. Add a bias, if provided, to the obtained result to produce the final output.


For example, in the illustration below, the unfold operation extracts local blocks of size 2x2 (kernel_size=2)
shifted by one (stride=1), without padding (padding=0) and without spacing between elements (dilation=1).
Then, each element of the kernel (0,1,2,3) is applied to each element of the current spatial block
(highlighted in blue; 3,4,6,7) and the products are summed to calculate the output value at the
current location: 0x3 + 1x4 + 2x6 + 3x7 = 37.

![Convolution](pics/convolution.jpeg)

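The same computation, written as a minimal, self-contained sketch (illustrative only, with arbitrary example shapes
and no bias), and checked against PyTorch's built-in convolution:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.rand(1, 3, 8, 8)        # example input: (batch, channels, height, width)
weight = torch.rand(16, 3, 2, 2)  # example kernels: (out_channels, in_channels, kH, kW)

# Step 1: extract 2x2 patches; shape becomes (batch, 3 * 2 * 2, num_patches).
patches = F.unfold(x, kernel_size=2, stride=1, padding=0, dilation=1)

# Step 2: multiply each flattened kernel with each patch and sum the products.
flat_weight = rearrange(weight, "o i h w -> o (i h w)")
out = torch.einsum("b i p, o i -> b o p", patches, flat_weight)

# Restore the spatial layout; the output size is (8 + 0 - 1*(2 - 1) - 1) // 1 + 1 = 7.
out = rearrange(out, "b o (h w) -> b o h w", h=7, w=7)

torch.testing.assert_close(out, F.conv2d(x, weight))
```

Because the weights enter only through the `einsum`, replacing them with a different matrix per patch location is
exactly the kind of variation this formulation makes easy.
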
63 |
64 | The `ConvTranspose2d` class is a reimplemented version of the PyTorch's built-in 2D transposed convolutional layer (`torch.nn.ConvTranspose2d`).
65 |
66 |
67 | ### Usage
68 |
```python
import torch
from src import ConvTranspose2d

# Initialize transposed convolutional layer using the custom implementation
# (example hyperparameters; all arguments accepted by torch.nn.ConvTranspose2d work).
reimplemented_conv_tran = ConvTranspose2d.ConvTranspose2d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=1,
    dilation=1,
)

# Apply the layer to an input tensor of shape (batch_size, in_channels, height, width).
input_tensor = torch.rand(8, 32, 16, 16)
output_tensor = reimplemented_conv_tran(input_tensor)  # shape: (8, 64, 32, 32)
```


### Custom ConvTranspose2d Layer Behavior

The custom 2D transposed convolution layer operates in the reverse order of the Conv2d layer (again, see the sketch below for a minimal version):

1. Each element of the input tensor is multiplied by all elements of the kernel, creating a spatial output per input value.

2. The `torch.nn.functional.fold` function is applied to the results of the previous operation, summing up overlapping values to create the output tensor while considering the layer's hyperparameters.

3. If provided, a bias is added to this result to produce the final output.

For example, highlighted in blue in the illustration below, we first apply each value of the kernel (0,1,2,3)
to the current input value (2), to get four output values (0x2=0, 1x2=2, 2x2=4, 3x2=6). Thus, the outputs of this
operation are of size 2x2 (kernel_size=2) for each input value. Then, these results are rearranged and recombined by
spatially overlapping them by one row/column (stride=1), applying no padding or output_padding (padding=0,
output_padding=0), and introducing no spacing between elements (dilation=1). The center value in the output, for example,
is the result of summing the values of the four overlapping outputs of the previous operation: 0 + 2 + 2 + 0 = 4.

![Transposed convolution](pics/transposed_convolution.jpeg)

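Written as a minimal, self-contained sketch (illustrative only, with arbitrary example shapes and no bias), and
checked against PyTorch's built-in transposed convolution:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.rand(1, 3, 8, 8)        # example input: (batch, channels, height, width)
weight = torch.rand(3, 16, 2, 2)  # example kernels: (in_channels, out_channels, kH, kW)

# Step 1: multiply every input value by the full kernel; each of the 8 * 8 input
# positions yields a column of length 16 * 2 * 2.
flat_weight = rearrange(weight, "i o h w -> i (o h w)")
cols = torch.einsum("b i p, i o -> b o p", rearrange(x, "b c h w -> b c (h w)"), flat_weight)

# Step 2: sum the overlapping 2x2 blocks; the output size is (8 - 1) * 1 - 0 + 1*(2 - 1) + 0 + 1 = 9.
out = F.fold(cols, output_size=(9, 9), kernel_size=2)

torch.testing.assert_close(out, F.conv_transpose2d(x, weight))
```
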
## Dependencies

The project requires Python along with the following libraries: PyTorch, einops, and pytest.
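
They can be installed with pip, for example (assuming the standard PyPI package names):

```bash
pip install torch einops pytest
```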


## Testing

This repository includes tests which ensure that the reimplemented layers achieve the same results
as the original PyTorch layers during the forward and backward passes.

To run these tests, use the following command from the root of the repository:

```bash
pytest tests/
```
--------------------------------------------------------------------------------