├── .gitignore
├── HISTORY.rst
├── LICENSE
├── README.md
├── setup.cfg
├── setup.py
└── torchcomplex
    ├── __init__.py
    ├── nn
    │   ├── __init__.py
    │   ├── functional.py
    │   ├── init.py
    │   └── modules
    │       ├── __init__.py
    │       ├── activation.py
    │       ├── batchnorm.py
    │       ├── conv.py
    │       ├── dropout.py
    │       ├── linear.py
    │       ├── pooling.py
    │       └── upsampling.py
    └── utils
        ├── __init__.py
        ├── signaltools.py
        └── support_funcs.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .vscode/launch.json
131 | .vscode/settings.json
132 | tester.py
133 |
--------------------------------------------------------------------------------
/HISTORY.rst:
--------------------------------------------------------------------------------
1 | History
2 | =======
3 |
4 | 0.1.2 (2023-06-19)
5 | ------------------
6 |
7 | * Support Function "clamp" added
8 |
9 | 0.1.1 (2023-06-16)
10 | ------------------
11 |
12 | * modReLU implementation changed
13 | * New activation functions added: Hirose and modSigmoid
14 |
15 | 0.1.0 (2022-11-06)
16 | ------------------
17 |
18 | * Default behaviour of complex_weights has been changed to True
19 | * Basic max and avg functional callers added
20 | * Bug fix with the new pooling functionals
21 |
22 | 0.0.8 (2020-12-08)
23 | ------------------
24 |
25 | * Complex Weight Initializations added (a few basics, plus trabelsi_standard and trabelsi_independent)
26 |
27 | 0.0.1 (2020-11-06)
28 | ------------------
29 |
30 | * First release (of the package) on PyPI.
31 | * torchcomplex.nn -> Convolutions, Linears, Dropout, Pooling, BatchNorm, a few Activations
32 | * Untested first version, could be buggy
33 |
34 | 0.0.0 (2020-11-06)
35 | ------------------
36 |
37 | * First release (basic structure, not the actual package code) on PyPI.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020, Soumick Chatterjee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-complex
2 |
3 | Install it using pip:
4 |
5 |     pip install pytorch-complex
6 |
7 | Usage:
8 | Similar to PyTorch.
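9 | For instance, a complex-valued layer can be used just like its real-valued `torch.nn` counterpart (a minimal sketch; the shapes here are illustrative):
10 |
11 | ```python
12 | import torch
13 | import torchcomplex.nn as nn
14 |
15 | conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)   # complex-valued 2D convolution
16 | x = torch.randn(4, 1, 32, 32, dtype=torch.cfloat)  # complex-valued input batch
17 | y = conv(x)                                        # complex-valued output, shape (4, 8, 32, 32)
18 | ```
19 |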
20 | To use the complex features of this library, simply replace the regular torch imports with torchcomplex imports.
21 | For example:
22 | `import torchcomplex.nn as nn` instead of `import torch.nn as nn`
23 | Then `nn.Conv2d` gives a 2D convolution in both cases: real-valued with torch, complex-valued with torchcomplex.
24 |
25 | ## Credits
26 |
27 | If you like this repository, please click on Star!
28 |
29 | If you use this package or benefit from the codes of this repo, please cite the following in your publications:
30 |
31 | > [Soumick Chatterjee, Pavan Tummala, Oliver Speck, Andreas Nürnberger: Complex Network for Complex Problems: A comparative study of CNN and Complex-valued CNN (IEEE IPAS, Dec 2023)](https://arxiv.org/abs/2302.04584)
32 |
33 | BibTeX entry:
34 |
35 | ```bibtex
36 | @inproceedings{chatterjee2022complex,
37 |   title={Complex Network for Complex Problems: A comparative study of CNN and Complex-valued CNN},
38 |   author={Chatterjee, Soumick and Tummala, Pavan and Speck, Oliver and N{\"u}rnberger, Andreas},
39 |   booktitle={2022 IEEE 5th International Conference on Image Processing Applications and Systems (IPAS)},
40 |   pages={1--5},
41 |   year={2022},
42 |   organization={IEEE}
43 | }
44 | ```
45 | Thank you so much for your support.
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [egg_info]
2 | tag_build =
3 | tag_date = 0
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """The setup script."""
4 |
5 | import setuptools
6 |
7 | with open("README.md", "r", encoding='utf8') as fh:
8 |     readme = fh.read()
9 |
10 | with open('HISTORY.rst', "r", encoding='utf8') as history_file:
11 |     history = history_file.read()
12 |
13 | requirements = [
14 |
15 | ]
16 |
17 | setup_requirements = [
18 | ]
19 |
20 | tests_requirements = [
21 | ]
22 |
23 | setuptools.setup(
24 |     name="pytorch-complex",
25 |     version="0.1.2",
26 |     author="Soumick Chatterjee",
27 |     author_email="soumick.chatterjee@ovgu.de",
28 |     description="Complex Modules for PyTorch",
29 |     long_description=readme + '\n\n' + history,
30 |     long_description_content_type="text/markdown",
31 |     url="https://github.com/soumickmj/pytorch-complex",
32 |     packages=setuptools.find_packages(include=['torchcomplex', 'torchcomplex.*']),
33 |     classifiers=[
34 |         "Programming Language :: Python :: 3.6",
35 |         "Programming Language :: Python :: 3.7",
36 |         "Programming Language :: Python :: 3.8",
37 |         "Programming Language :: Python :: 3.9",
38 |         "Programming Language :: Python :: 3.10",
39 |         "Programming Language :: Python :: 3.11",
40 |         "License :: OSI Approved :: MIT License",
41 |         "Operating System :: OS Independent",
42 |     ],
43 |     python_requires='>=3.6',
44 |     install_requires=requirements,
45 |     setup_requires=setup_requirements,
46 |     tests_require=tests_requirements,
47 |     license='MIT license',
48 |     include_package_data=True,
49 | )
--------------------------------------------------------------------------------
/torchcomplex/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
--------------------------------------------------------------------------------
/torchcomplex/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from . import init
--------------------------------------------------------------------------------
/torchcomplex/nn/functional.py:
--------------------------------------------------------------------------------
1 | r"""Functional interface"""
2 | import warnings
3 | import math
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.nn import ParameterList
9 | from ..utils.signaltools import resample
10 | # from torch._C import _infer_size, _add_docstr
11 | # from torch.nn import _reduction as _Reduction
12 | # from torch.nn.modules import utils
13 | # from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
14 | # from torch.nn import grad  # noqa: F401
15 | # from torch import _VF
16 | # from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
17 | # from torch.overrides import has_torch_function, handle_torch_function
18 | # from torch._torch_docs import reproducibility_notes, tf32_notes
19 |
20 | Tensor = torch.Tensor
21 |
22 | def complex_fcaller(functional_handle, *args):
23 |     return torch.view_as_complex(torch.stack((functional_handle(args[0].real, *args[1:]), functional_handle(args[0].imag, *args[1:])), dim=-1))
24 |
25 | # Convolutions
26 | '''Following: https://arxiv.org/pdf/1705.09792.pdf
27 |
28 | '''
29 | def _fcaller(functional_handle, *args):
30 |     # For Convs: 0 input, 1 weight, 2 bias, 3 stride, 4 padding, 5 dilation, 6 groups
31 |     # For ConvTrans: 0 input, 1 weight, 2 bias, 3 stride, 4 padding, 5 output_padding, 6 groups, 7 dilation
32 |
33 |     # As the PyTorch functional API only supports computations on real-valued data, everything is converted to the real representation of complex numbers
34 |     if type(args[0]) is tuple:  # only in case of bilinear
35 |         inp1 = torch.view_as_real(args[0][0])
36 |         inp1_r = inp1[..., 0]
37 |         inp1_i = inp1[..., 1]
38 |         inp2 = torch.view_as_real(args[0][1])
39 |         inp2_r = inp2[..., 0]
40 |         inp2_i = inp2[..., 1]
41 |     else:
42 |         inp = torch.view_as_real(args[0])
43 |         inp_r = inp[..., 0]
44 |         inp_i = inp[..., 1]
45 |     if type(args[1]) is ParameterList:
46 |         w_r = args[1][0]
47 |         w_i = args[1][1]
48 |         if args[2] is not None:
49 |             b_r = args[2][0]
50 |             b_i = args[2][1]
51 |         else:
52 |             b_r = None
53 |             b_i = None
54 |     else:
55 |         w = torch.view_as_real(args[1])
56 |         w_r = w[..., 0]
57 |         w_i = w[..., 1]
58 |         if args[2] is not None:
59 |             b = torch.view_as_real(args[2])
60 |             b_r = b[..., 0]
61 |             b_i = b[..., 1]
62 |         else:
63 |             b_r = None
64 |             b_i = None
65 |
66 |     # Perform complex valued convolution
67 |     if type(args[0]) is tuple:  # only in case of bilinear
68 |         MrKr = functional_handle(inp1_r, inp2_r, w_r, b_r, *args[3:])  # Real Feature Maps *(conv) Real Kernels
69 |         MiKi = functional_handle(inp1_i, inp2_i, w_i, b_i, *args[3:])  # Imaginary Feature Maps * Imaginary Kernels
70 |         MrKi = functional_handle(inp1_r, inp2_r, w_i, b_i, *args[3:])  # Real Feature Maps * Imaginary Kernels
71 |         MiKr = functional_handle(inp1_i, inp2_i, w_r, b_r, *args[3:])  # Imaginary Feature Maps * Real Kernels
72 |     else:
73 |         MrKr = functional_handle(inp_r, w_r, b_r, *args[3:])  # Real Feature Maps *(conv) Real Kernels
74 |         MiKi = functional_handle(inp_i, w_i, b_i, *args[3:])  # Imaginary Feature Maps * Imaginary Kernels
75 |         MrKi = functional_handle(inp_r, w_i, b_i, *args[3:])  # Real Feature Maps * Imaginary Kernels
76 |         MiKr = functional_handle(inp_i, w_r, b_r, *args[3:])  # Imaginary Feature Maps * Real Kernels
77 |     real = MrKr - MiKi
78 |     imag = MrKi + MiKr
79 |     out = torch.view_as_complex(torch.stack((real, imag), dim=-1))
80 |
81 |     return out
82 |
83 | # Convolutions
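# The wrappers below forward to the corresponding torch.nn.functional ops via
# _fcaller, which computes the complex product (Wr + iWi)(xr + ixi)
# = (Wr xr - Wi xi) + i(Wr xi + Wi xr) out of four real-valued calls
# (the MrKr, MiKi, MrKi and MiKr terms above).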
84 |
85 | def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
86 |     return _fcaller(F.conv1d, input, weight, bias, stride, padding, dilation, groups)
87 |
88 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
89 |     return _fcaller(F.conv2d, input, weight, bias, stride, padding, dilation, groups)
90 |
91 | def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
92 |     return _fcaller(F.conv3d, input, weight, bias, stride, padding, dilation, groups)
93 |
94 | def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor:
95 |     return _fcaller(F.conv_transpose1d, input, weight, bias, stride, padding, output_padding, groups, dilation)
96 |
97 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor:
98 |     return _fcaller(F.conv_transpose2d, input, weight, bias, stride, padding, output_padding, groups, dilation)
99 |
100 | def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor:
101 |     return _fcaller(F.conv_transpose3d, input, weight, bias, stride, padding, output_padding, groups, dilation)
102 |
103 | # Poolings
104 | def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) -> Tensor:
105 |     return complex_fcaller(F.max_pool1d, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
106 |
107 | def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) -> Tensor:
108 |     return complex_fcaller(F.max_pool2d, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
109 |
110 | def max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) -> Tensor:
111 |     return complex_fcaller(F.max_pool3d, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
112 |
113 | def avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor:
114 |     return complex_fcaller(F.avg_pool1d, input, kernel_size, stride, padding, ceil_mode, count_include_pad)  # F.avg_pool1d takes no divisor_override
115 |
116 | def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor:
117 |     return complex_fcaller(F.avg_pool2d, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
118 |
119 | def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor:
120 |     return complex_fcaller(F.avg_pool3d, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
121 |
122 | # Linear
123 | def linear(input, weight, bias=None):
124 |     return _fcaller(F.linear, input, weight, bias)
125 |
126 | def bilinear(input1, input2, weight, bias=None):
127 |     return _fcaller(F.bilinear, (input1, input2), weight, bias)
128 |
129 |
130 | # Batch Normalization
131 | def _whiten2x2(tensor, training=True, running_mean=None, running_cov=None,
132 |                momentum=0.1, nugget=1e-5):
133 |     r"""Solve R M R = I for R, given a 2x2 matrix M = [[a, b], [c, d]].
134 | 135 | Source: https://github.com/ivannz/cplxmodule/blob/master/cplxmodule/nn/modules/batchnorm.py 136 | 137 | Arguments 138 | --------- 139 | tensor : torch.tensor 140 | The input data expected to be at least 3d, with shape [2, B, F, ...], 141 | where `B` is the batch dimension, `F` -- the channels/features, 142 | `...` -- the spatial dimensions (if present). The leading dimension 143 | `2` represents real and imaginary components (stacked). 144 | 145 | training : bool, default=True 146 | Determines whether to update running feature statistics, if they are 147 | provided, or use them instead of batch computed statistics. If `False` 148 | then `running_mean` and `running_cov` MUST be provided. 149 | 150 | running_mean : torch.tensor, or None 151 | The tensor with running mean statistics having shape [2, F]. Ignored 152 | if explicitly `None`. 153 | 154 | running_cov : torch.tensor, or None 155 | The tensor with running real-imaginary covariance statistics having 156 | shape [2, 2, F]. Ignored if explicitly `None`. 157 | 158 | momentum : float, default=0.1 159 | The weight in the exponential moving average used to keep track of the 160 | running feature statistics. 161 | 162 | nugget : float, default=1e-05 163 | The ridge coefficient to stabilise the estimate of the real-imaginary 164 | covariance. 165 | 166 | Details 167 | ------- 168 | Using (tril) L L^T = V seems to 'favour' the first dimension (re), so 169 | Trabelsi et al. (2018) used explicit 2x2 root of M: such R that M = RR. 170 | 171 | For M = [[a, b], [c, d]] we have the following facts: 172 | (1) inv M = \frac1{ad - bc} [[d, -b], [-c, a]] 173 | (2) \sqrt{M} = \frac1{t} [[a + s, b], [c, d + s]] 174 | for s = \sqrt{ad - bc}, t = \sqrt{a + d + 2 s} 175 | det \sqrt{M} = t^{-2} (ad + s(d + a) + s^2 - bc) = s 176 | 177 | Therefore `inv \sqrt{M} = [[p, q], [r, s]]`, where 178 | [[p, q], [r, s]] = \frac1{t s} [[d + s, -b], [-c, a + s]] 179 | """ 180 | # assume tensor is 2 x B x F x ... 181 | 182 | # tail shape for broadcasting ? x 1 x F x [*1] 183 | tail = 1, tensor.shape[2], *([1] * (tensor.dim() - 3)) 184 | axes = 1, *range(3, tensor.dim()) 185 | 186 | # 1. compute batch mean [2 x F] and center the batch 187 | if training: 188 | mean = tensor.mean(dim=axes) 189 | if running_mean is not None: 190 | running_mean += momentum * (mean.data - running_mean) 191 | 192 | else: 193 | mean = running_mean 194 | 195 | tensor = tensor - mean.reshape(2, *tail) 196 | 197 | # 2. per feature real-imaginary 2x2 covariance matrix 198 | if training: 199 | # faster than doing mul and then mean. Stabilize by a small ridge. 200 | var = tensor.var(dim=axes, unbiased=False) + nugget 201 | cov_uu, cov_vv = var[0], var[1] 202 | 203 | # has to mul-mean here anyway (naïve) : reduction axes shifted left. 204 | cov_vu = cov_uv = (tensor[0] * tensor[1]).mean([a - 1 for a in axes]) 205 | if running_cov is not None: 206 | cov = torch.stack([ 207 | cov_uu.data, cov_uv.data, 208 | cov_vu.data, cov_vv.data, 209 | ], dim=0).reshape(2, 2, -1) 210 | running_cov += momentum * (cov - running_cov) 211 | 212 | else: 213 | cov_uu, cov_uv = running_cov[0, 0], running_cov[0, 1] 214 | cov_vu, cov_vv = running_cov[1, 0], running_cov[1, 1] 215 | 216 | # 3. get R = [[p, q], [r, s]], with E R c c^T R^T = R M R = I 217 | # (unsure if intentional, but the inv-root in Trabelsi et al. (2018) uses 218 | # numpy `np.sqrt` instead of `K.sqrt` so grads are not passed through 219 | # properly, i.e. constants, [complex_standardization](bn.py#L56-57). 
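    # Mapping onto the docstring's M = [[a, b], [c, d]]: a = cov_uu, b = cov_uv,
    # c = cov_vu, d = cov_vv; with s = sqrt(det M) and t = sqrt(trace M + 2s),
    # inv sqrt M = [[d + s, -b], [-c, a + s]] / (s * t) yields p, q, r, s below.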
220 |     sqrdet = torch.sqrt(cov_uu * cov_vv - cov_uv * cov_vu)
221 |     # torch.det uses svd, so may yield -ve machine zero
222 |
223 |     denom = sqrdet * torch.sqrt(cov_uu + 2 * sqrdet + cov_vv)
224 |     p, q = (cov_vv + sqrdet) / denom, -cov_uv / denom
225 |     r, s = -cov_vu / denom, (cov_uu + sqrdet) / denom
226 |
227 |     # 4. apply Q to x (manually)
228 |     out = torch.stack([
229 |         tensor[0] * p.reshape(tail) + tensor[1] * r.reshape(tail),
230 |         tensor[0] * q.reshape(tail) + tensor[1] * s.reshape(tail),
231 |     ], dim=0)
232 |     return out  # , torch.cat([p, q, r, s], dim=0).reshape(2, 2, -1)
233 |
234 | def batch_norm(input, running_mean, running_var, weight=None, bias=None,
235 |                training=False, momentum=0.1, eps=1e-5, naive=False):
236 |
237 |     """
238 |     Source: https://github.com/ivannz/cplxmodule/blob/master/cplxmodule/nn/modules/batchnorm.py
239 |     """
240 |     complex_weight = not isinstance(weight, torch.nn.ParameterList)
241 |     if naive:
242 |         real = F.batch_norm(input.real,
243 |                             running_mean[0] if running_mean is not None else None,
244 |                             running_var[0] if running_var is not None else None,
245 |                             weight.real if complex_weight else weight[0], bias.real if complex_weight else bias[0], training, momentum, eps)
246 |         imag = F.batch_norm(input.imag,
247 |                             running_mean[1] if running_mean is not None else None,
248 |                             running_var[1] if running_var is not None else None,
249 |                             weight.imag if complex_weight else weight[1], bias.imag if complex_weight else bias[1], training, momentum, eps)
250 |         return torch.view_as_complex(torch.stack((real, imag), dim=-1))
251 |     else:
252 |         # stack along the first axis
253 |         x = torch.stack([input.real, input.imag], dim=0)
254 |
255 |         # whiten and apply affine transformation
256 |         z = _whiten2x2(x, training=training, running_mean=running_mean,
257 |                        running_cov=running_var, momentum=momentum, nugget=eps)
258 |
259 |         if weight is not None and bias is not None:
260 |             shape = 1, x.shape[2], *([1] * (x.dim() - 3))
261 |             weight = weight.reshape(2, 2, *shape)
262 |             z = torch.stack([
263 |                 z[0] * weight[0, 0] + z[1] * weight[0, 1],
264 |                 z[0] * weight[1, 0] + z[1] * weight[1, 1],
265 |             ], dim=0) + bias.reshape(2, *shape)
266 |
267 |         return torch.view_as_complex(torch.stack((z[0], z[1]), dim=-1))
268 |
269 |
270 | # Activations
271 |
272 | def crelu(input: Tensor, inplace: bool = False) -> Tensor:
273 |     '''
274 |     Eq.(4)
275 |     https://arxiv.org/pdf/1705.09792.pdf
276 |     '''
277 |     if input.is_complex():
278 |         return torch.view_as_complex(torch.stack((F.relu(input.real), F.relu(input.imag)), dim=-1))
279 |     else:
280 |         return F.relu(input, inplace=inplace)
281 |
282 | def zrelu(input: Tensor, inplace: bool = False) -> Tensor:
283 |     '''
284 |     Guberman ReLU:
285 |     Nitzan Guberman. On complex valued convolutional neural networks. arXiv preprint arXiv:1602.09046, 2016
286 |     Eq.(5)
287 |     https://arxiv.org/pdf/1705.09792.pdf
288 |     '''
289 |     if input.is_complex():
290 |         return input * ((0 < input.angle()) * (input.angle() < math.pi/2)).float()
291 |     else:
292 |         return F.relu(input, inplace=inplace)
293 |
294 | def modrelu(input: Tensor, bias: Tensor, inplace: bool = False) -> Tensor:
295 |     '''
296 |     Martin Arjovsky, Amar Shah, and Yoshua Bengio. Unitary evolution recurrent neural networks. arXiv preprint arXiv:1511.06464, 2015.
297 |     Notice that |z| (z.magnitude) is always positive, so if b > 0 then |z| + b >= 0 always.
298 |     In order to have any non-linearity effect, b must be smaller than 0 (b < 0).
299 |     Update: The implementation has been updated following: \\operatorname{ReLU}(|z|+b) \\frac{z}{|z|}
300 |     '''
301 |     if input.is_complex():
302 |         z_mag = torch.abs(input)
303 |         return F.relu(z_mag + bias) * (input / z_mag)
304 |     else:
305 |         return F.relu(input, inplace=inplace)
306 |
307 | def cmodrelu(input: Tensor, threshold: Tensor, inplace: bool = False):
308 |     r"""Compute the Complex modulus relu of the complex tensor in re-im pair.
309 |     As proposed in: https://arxiv.org/pdf/1802.08026.pdf
310 |     Source: https://github.com/ivannz/cplxmodule"""
311 |     if input.is_complex():
312 |         modulus = torch.clamp(torch.abs(input), min=1e-5)
313 |         _tmp_newshape = (1, len(threshold)) + (1,)*len(input.shape[2:])
314 |         return input * F.relu(1. - threshold.view(_tmp_newshape) / modulus)
315 |     else:
316 |         return F.relu(input, inplace=inplace)
317 |
318 | def softmax(input, dim=None, _stacklevel=3, dtype=None):
319 |     '''
320 |     Complex-valued Neural Networks with Non-parametric Activation Functions
321 |     (Eq. 36)
322 |     https://arxiv.org/pdf/1802.08026.pdf
323 |     '''
324 |     if input.is_complex():
325 |         return F.softmax(torch.abs(input), dim=dim, _stacklevel=_stacklevel, dtype=dtype)
326 |     else:
327 |         return F.softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
328 |
329 | def tanh(input: Tensor):
330 |     if input.is_complex():
331 |         a, b = input.real, input.imag
332 |         denominator = torch.cosh(2*a) + torch.cos(2*b)
333 |         real = torch.sinh(2 * a) / denominator
334 |         imag = torch.sin(2 * b) / denominator
335 |         return torch.view_as_complex(torch.stack((real, imag), dim=-1))
336 |     else:
337 |         return torch.tanh(input)
338 |
339 | def hirose(input: Tensor, m_square: float = 1):
340 |     '''
341 |     A. Hirose. Complex-valued neural networks: Advances and applications. John Wiley & Sons, 2013.
342 |     and
343 |     Wolter and Yao. Complex Gated Recurrent Neural Networks. NeurIPS 2018. (Eq. 5) https://papers.nips.cc/paper_files/paper/2018/file/652cf38361a209088302ba2b8b7f51e0-Paper.pdf
344 |     '''
345 |     mag_input = torch.abs(input)
346 |     return torch.tanh(mag_input / m_square) * (input / mag_input)
347 |
348 | def modsigmoid(input: Tensor, alpha: float = 0.5):
349 |     '''
350 |     Wolter and Yao. Complex Gated Recurrent Neural Networks. NeurIPS 2018. (Eq. 13) https://papers.nips.cc/paper_files/paper/2018/file/652cf38361a209088302ba2b8b7f51e0-Paper.pdf
351 |     and
352 |     Xie et al. Complex Recurrent Variational Autoencoder with Application to Speech Enhancement. 2023. arXiv:2204.02195v2
353 |     '''
354 |     return torch.sigmoid(alpha * input.real + (1 - alpha) * input.imag)
355 |
356 | def sigmoid(input: Tensor):
357 |     if input.is_complex():
358 |         a, b = input.real, input.imag
359 |         denominator = 1 + 2 * torch.exp(-a) * torch.cos(b) + torch.exp(-2 * a)
360 |         real = (1 + torch.exp(-a) * torch.cos(b)) / denominator
361 |         imag = torch.exp(-a) * torch.sin(b) / denominator
362 |         return torch.view_as_complex(torch.stack((real, imag), dim=-1))
363 |     else:
364 |         return torch.sigmoid(input)
365 |
366 | def _sinc_interpolate(input, size):
367 |     axes = np.argwhere(~np.equal(input.shape[2:], size)).squeeze(1)  # 2 dims for batch and channel
368 |     out_shape = [size[i] for i in axes]
369 |     return resample(input, out_shape, axis=axes+2)  # 2 dims for batch and channel
370 |
371 | def interpolate(input, size=None, scale_factor=None, mode='sinc', align_corners=None, recompute_scale_factor=None):
372 |     if mode in ('nearest', 'area', 'sinc'):
373 |         if align_corners is not None:
374 |             raise ValueError("align_corners option can only be set with the "
375 |                              "interpolating modes: linear | bilinear | bicubic | trilinear")
376 |
377 |     dim = input.dim() - 2  # Number of spatial dimensions.
378 |
379 |     # Process size and scale_factor. Validate that exactly one is set.
380 |     # Validate its length if it is a list, or expand it if it is a scalar.
381 |     # After this block, exactly one of output_size and scale_factors will
382 |     # be non-None, and it will be a list (or tuple).
383 |     if size is not None and scale_factor is not None:
384 |         raise ValueError('only one of size or scale_factor should be defined')
385 |     elif size is not None:
386 |         assert scale_factor is None
387 |         scale_factors = None
388 |         if isinstance(size, (list, tuple)):
389 |             if len(size) != dim:
390 |                 raise ValueError('size shape must match input shape. '
391 |                                  'Input is {}D, size is {}'.format(dim, len(size)))
392 |             output_size = size
393 |         else:
394 |             output_size = [size for _ in range(dim)]
395 |     elif scale_factor is not None:
396 |         assert size is None
397 |         output_size = None
398 |         if isinstance(scale_factor, (list, tuple)):
399 |             if len(scale_factor) != dim:
400 |                 raise ValueError('scale_factor shape must match input shape. '
401 |                                  'Input is {}D, scale_factor is {}'.format(dim, len(scale_factor)))
402 |             scale_factors = scale_factor
403 |         else:
404 |             scale_factors = [scale_factor for _ in range(dim)]
405 |     else:
406 |         raise ValueError('either size or scale_factor should be defined')
407 |
408 |     if recompute_scale_factor is None:
409 |         # only warn when the scales have floating values since
410 |         # the result for ints is the same with/without recompute_scale_factor
411 |         if scale_factors is not None:
412 |             for scale in scale_factors:
413 |                 if math.floor(scale) != scale:
414 |                     warnings.warn("The default behavior for interpolate/upsample with float scale_factor changed "
415 |                                   "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, "
416 |                                   "instead of relying on the computed output size. "
417 |                                   "If you wish to restore the old behavior, please set recompute_scale_factor=True. "
418 |                                   "See the documentation of nn.Upsample for details. ")
419 |                     break
420 |     elif recompute_scale_factor and size is not None:
421 |         raise ValueError("recompute_scale_factor is not meaningful with an explicit size.")
422 |
423 |     # "area" and "sinc" modes always require an explicit size rather than scale factor.
424 |     # Re-use the recompute_scale_factor code path.
425 | if (mode == "area" or mode == "sinc") and output_size is None: 426 | recompute_scale_factor = True 427 | 428 | if recompute_scale_factor is not None and recompute_scale_factor: 429 | # We compute output_size here, then un-set scale_factors. 430 | # The C++ code will recompute it based on the (integer) output size. 431 | if not torch.jit.is_scripting() and torch._C._get_tracing_state(): 432 | # make scale_factor a tensor in tracing so constant doesn't get baked in 433 | output_size = [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], 434 | dtype=torch.float32)).float())) for i in range(dim)] 435 | else: 436 | assert scale_factors is not None 437 | output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] 438 | scale_factors = None 439 | 440 | if mode == "sinc": 441 | return _sinc_interpolate(input, output_size) 442 | else: 443 | return complex_fcaller(F.interpolate, input, output_size, scale_factors, mode, align_corners) -------------------------------------------------------------------------------- /torchcomplex/nn/init.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from torch.nn.parameter import Parameter 8 | from torch.nn import ParameterList 9 | 10 | class _tensorprocessor(): 11 | @classmethod 12 | def _preprocess(cls, tensor): 13 | if type(tensor) is ParameterList: 14 | cls.complex_weight=False 15 | return tensor 16 | else: 17 | cls.complex_weight=True 18 | return ParameterList([Parameter(tensor.real), Parameter(tensor.imag)]) 19 | @classmethod 20 | def _postprocess(cls, tensor): 21 | if cls.complex_weight: 22 | return Parameter(tensor[0] + 1j*tensor[1]) 23 | else: 24 | if type(tensor) is ParameterList: 25 | return tensor 26 | else: 27 | return ParameterList(tensor) 28 | 29 | # These no_grad_* functions are necessary as wrappers around the parts of these 30 | # functions that use `with torch.no_grad()`. The JIT doesn't support context 31 | # managers, so these need to be implemented as builtins. Using these wrappers 32 | # lets us keep those builtins small and re-usable. 33 | def _no_grad_uniform_(tensor, a, b): 34 | with torch.no_grad(): 35 | return (tensor[0].uniform_(a, b), tensor[1].uniform_(a, b)) 36 | 37 | 38 | def _no_grad_normal_(tensor, mean, std): 39 | with torch.no_grad(): 40 | return (tensor[0].normal_(mean, std), tensor[1].normal_(mean, std)) 41 | 42 | 43 | # def _no_grad_trunc_normal_(tensor, mean, std, a, b): 44 | # # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 45 | # def norm_cdf(x): 46 | # # Computes standard normal cumulative distribution function 47 | # return (1. + math.erf(x / math.sqrt(2.))) / 2. 48 | 49 | # if (mean < a - 2 * std) or (mean > b + 2 * std): 50 | # warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 51 | # "The distribution of values may be incorrect.", 52 | # stacklevel=2) 53 | 54 | # with torch.no_grad(): 55 | # # Values are generated by using a truncated uniform distribution and 56 | # # then using the inverse CDF for the normal distribution. 57 | # # Get upper and lower cdf values 58 | # l = norm_cdf((a - mean) / std) 59 | # u = norm_cdf((b - mean) / std) 60 | 61 | # # Uniformly fill tensor with values from [l, u], then translate to 62 | # # [2l-1, 2u-1]. 
63 | # tensor.uniform_(2 * l - 1, 2 * u - 1) 64 | 65 | # # Use inverse cdf transform for normal distribution to get truncated 66 | # # standard normal 67 | # tensor.erfinv_() 68 | 69 | # # Transform to proper mean, std 70 | # tensor.mul_(std * math.sqrt(2.)) 71 | # tensor.add_(mean) 72 | 73 | # # Clamp to ensure it's in the proper range 74 | # tensor.clamp_(min=a, max=b) 75 | # return tensor 76 | 77 | 78 | def _no_grad_fill_(tensor, val): 79 | with torch.no_grad(): 80 | return (tensor[0].fill_(val), tensor[1].fill_(val)) 81 | 82 | 83 | def _no_grad_zero_(tensor): 84 | with torch.no_grad(): 85 | return (tensor[0].zero_(), tensor[1].zero_()) 86 | 87 | ##TODO: implement 88 | # def calculate_gain(nonlinearity, param=None): 89 | # r"""Return the recommended gain value for the given nonlinearity function. 90 | # The values are as follows: 91 | 92 | # ================= ==================================================== 93 | # nonlinearity gain 94 | # ================= ==================================================== 95 | # Linear / Identity :math:`1` 96 | # Conv{1,2,3}D :math:`1` 97 | # Sigmoid :math:`1` 98 | # Tanh :math:`\frac{5}{3}` 99 | # ReLU :math:`\sqrt{2}` 100 | # Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` 101 | # ================= ==================================================== 102 | 103 | # Args: 104 | # nonlinearity: the non-linear function (`nn.functional` name) 105 | # param: optional parameter for the non-linear function 106 | 107 | # Examples: 108 | # >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 109 | # """ 110 | # linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] 111 | # if nonlinearity in linear_fns or nonlinearity == 'sigmoid': 112 | # return 1 113 | # elif nonlinearity == 'tanh': 114 | # return 5.0 / 3 115 | # elif nonlinearity == 'relu': 116 | # return math.sqrt(2.0) 117 | # elif nonlinearity == 'leaky_relu': 118 | # if param is None: 119 | # negative_slope = 0.01 120 | # elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): 121 | # # True/False are instances of int, hence check above 122 | # negative_slope = param 123 | # else: 124 | # raise ValueError("negative_slope {} not a valid number".format(param)) 125 | # return math.sqrt(2.0 / (1 + negative_slope ** 2)) 126 | # else: 127 | # raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) 128 | 129 | 130 | def uniform_(tensor, a=0., b=1.): 131 | # type: (Tensor, float, float) -> Tensor 132 | r"""Fills the input Tensor with values drawn from the uniform 133 | distribution :math:`\mathcal{U}(a, b)`. 134 | 135 | Args: 136 | tensor: an n-dimensional `torch.Tensor` 137 | a: the lower bound of the uniform distribution 138 | b: the upper bound of the uniform distribution 139 | 140 | Examples: 141 | >>> w = torch.empty(3, 5) 142 | >>> nn.init.uniform_(w) 143 | """ 144 | tensor = _tensorprocessor._preprocess(tensor) 145 | return _tensorprocessor._postprocess(_no_grad_uniform_(tensor, a, b)) 146 | 147 | 148 | def normal_(tensor, mean=0., std=1.): 149 | # type: (Tensor, float, float) -> Tensor 150 | r"""Fills the input Tensor with values drawn from the normal 151 | distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. 
152 | 153 | Args: 154 | tensor: an n-dimensional `torch.Tensor` 155 | mean: the mean of the normal distribution 156 | std: the standard deviation of the normal distribution 157 | 158 | Examples: 159 | >>> w = torch.empty(3, 5) 160 | >>> nn.init.normal_(w) 161 | """ 162 | tensor = _tensorprocessor._preprocess(tensor) 163 | return _tensorprocessor._postprocess(_no_grad_normal_(tensor, mean, std)) 164 | 165 | # def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 166 | # # type: (Tensor, float, float, float, float) -> Tensor 167 | # r"""Fills the input Tensor with values drawn from a truncated 168 | # normal distribution. The values are effectively drawn from the 169 | # normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 170 | # with values outside :math:`[a, b]` redrawn until they are within 171 | # the bounds. The method used for generating the random values works 172 | # best when :math:`a \leq \text{mean} \leq b`. 173 | 174 | # Args: 175 | # tensor: an n-dimensional `torch.Tensor` 176 | # mean: the mean of the normal distribution 177 | # std: the standard deviation of the normal distribution 178 | # a: the minimum cutoff value 179 | # b: the maximum cutoff value 180 | 181 | # Examples: 182 | # >>> w = torch.empty(3, 5) 183 | # >>> nn.init.trunc_normal_(w) 184 | # """ 185 | # return _no_grad_trunc_normal_(tensor, mean, std, a, b) 186 | 187 | 188 | def constant_(tensor, val): 189 | # type: (Tensor, float) -> Tensor 190 | r"""Fills the input Tensor with the value :math:`\text{val}`. 191 | 192 | Args: 193 | tensor: an n-dimensional `torch.Tensor` 194 | val: the value to fill the tensor with 195 | 196 | Examples: 197 | >>> w = torch.empty(3, 5) 198 | >>> nn.init.constant_(w, 0.3) 199 | """ 200 | tensor = _tensorprocessor._preprocess(tensor) 201 | return _tensorprocessor._postprocess(_no_grad_fill_(tensor, val)) 202 | 203 | def ones_(tensor): 204 | # type: (Tensor) -> Tensor 205 | r"""Fills the input Tensor with the scalar value `1`. 206 | 207 | Args: 208 | tensor: an n-dimensional `torch.Tensor` 209 | 210 | Examples: 211 | >>> w = torch.empty(3, 5) 212 | >>> nn.init.ones_(w) 213 | """ 214 | tensor = _tensorprocessor._preprocess(tensor) 215 | return _tensorprocessor._postprocess(_no_grad_fill_(tensor, 1.)) 216 | 217 | def zeros_(tensor): 218 | # type: (Tensor) -> Tensor 219 | r"""Fills the input Tensor with the scalar value `0`. 220 | 221 | Args: 222 | tensor: an n-dimensional `torch.Tensor` 223 | 224 | Examples: 225 | >>> w = torch.empty(3, 5) 226 | >>> nn.init.zeros_(w) 227 | """ 228 | tensor = _tensorprocessor._preprocess(tensor) 229 | return _tensorprocessor._postprocess(_no_grad_zero_(tensor)) 230 | 231 | def eye_(tensor): 232 | r"""Fills the 2-dimensional input `Tensor` with the identity 233 | matrix. Preserves the identity of the inputs in `Linear` layers, where as 234 | many inputs are preserved as possible. 235 | 236 | Args: 237 | tensor: a 2-dimensional `torch.Tensor` 238 | 239 | Examples: 240 | >>> w = torch.empty(3, 5) 241 | >>> nn.init.eye_(w) 242 | """ 243 | tensor = _tensorprocessor._preprocess(tensor) 244 | torch.nn.init.eye_(tensor[0]) 245 | torch.nn.init.eye_(tensor[1]) 246 | return _tensorprocessor._postprocess(tensor) 247 | 248 | def dirac_(tensor, groups=1): 249 | r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac 250 | delta function. Preserves the identity of the inputs in `Convolutional` 251 | layers, where as many input channels are preserved as possible. 
In case 252 | of groups>1, each group of channels preserves identity 253 | 254 | Args: 255 | tensor: a {3, 4, 5}-dimensional `torch.Tensor` 256 | groups (optional): number of groups in the conv layer (default: 1) 257 | Examples: 258 | >>> w = torch.empty(3, 16, 5, 5) 259 | >>> nn.init.dirac_(w) 260 | >>> w = torch.empty(3, 24, 5, 5) 261 | >>> nn.init.dirac_(w, 3) 262 | """ 263 | tensor = _tensorprocessor._preprocess(tensor) 264 | torch.nn.init.dirac_(tensor[0], groups=groups) 265 | torch.nn.init.dirac_(tensor[1], groups=groups) 266 | return _tensorprocessor._postprocess(tensor) 267 | 268 | def _calculate_fan_in_and_fan_out(tensor): 269 | dimensions = tensor.dim() 270 | if dimensions < 2: 271 | raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") 272 | 273 | num_input_fmaps = tensor.size(1) 274 | num_output_fmaps = tensor.size(0) 275 | receptive_field_size = 1 276 | if tensor.dim() > 2: 277 | receptive_field_size = tensor[0][0].numel() 278 | fan_in = num_input_fmaps * receptive_field_size 279 | fan_out = num_output_fmaps * receptive_field_size 280 | 281 | return fan_in, fan_out 282 | 283 | def xavier_uniform_(tensor, gain=1.): 284 | # type: (Tensor, float) -> Tensor 285 | r"""Fills the input `Tensor` with values according to the method 286 | described in `Understanding the difficulty of training deep feedforward 287 | neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform 288 | distribution. The resulting tensor will have values sampled from 289 | :math:`\mathcal{U}(-a, a)` where 290 | 291 | .. math:: 292 | a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} 293 | 294 | Also known as Glorot initialization. 295 | 296 | Args: 297 | tensor: an n-dimensional `torch.Tensor` 298 | gain: an optional scaling factor 299 | 300 | Examples: 301 | >>> w = torch.empty(3, 5) 302 | >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) 303 | """ 304 | tensor = _tensorprocessor._preprocess(tensor) 305 | torch.nn.init.xavier_uniform_(tensor[0], gain=gain/math.sqrt(2)) 306 | torch.nn.init.xavier_uniform_(tensor[1], gain=gain/math.sqrt(2)) 307 | return _tensorprocessor._postprocess(tensor) 308 | 309 | def xavier_normal_(tensor, gain=1.): 310 | # type: (Tensor, float) -> Tensor 311 | r"""Fills the input `Tensor` with values according to the method 312 | described in `Understanding the difficulty of training deep feedforward 313 | neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal 314 | distribution. The resulting tensor will have values sampled from 315 | :math:`\mathcal{N}(0, \text{std}^2)` where 316 | 317 | .. math:: 318 | \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} 319 | 320 | Also known as Glorot initialization. 
321 | 322 | Args: 323 | tensor: an n-dimensional `torch.Tensor` 324 | gain: an optional scaling factor 325 | 326 | Examples: 327 | >>> w = torch.empty(3, 5) 328 | >>> nn.init.xavier_normal_(w) 329 | """ 330 | tensor = _tensorprocessor._preprocess(tensor) 331 | torch.nn.init.xavier_normal_(tensor[0], gain=gain/math.sqrt(2)) 332 | torch.nn.init.xavier_normal_(tensor[1], gain=gain/math.sqrt(2)) 333 | return _tensorprocessor._postprocess(tensor) 334 | 335 | def _calculate_correct_fan(tensor, mode): 336 | mode = mode.lower() 337 | valid_modes = ['fan_in', 'fan_out'] 338 | if mode not in valid_modes: 339 | raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) 340 | 341 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 342 | return fan_in if mode == 'fan_in' else fan_out 343 | 344 | def kaiming_uniform_(tensor, a=0.0, mode='fan_in', nonlinearity='leaky_relu'): 345 | r"""Fills the input `Tensor` with values according to the method 346 | described in `Delving deep into rectifiers: Surpassing human-level 347 | performance on ImageNet classification` - He, K. et al. (2015), using a 348 | uniform distribution. The resulting tensor will have values sampled from 349 | :math:`\mathcal{U}(-\text{bound}, \text{bound})` where 350 | 351 | .. math:: 352 | \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} 353 | 354 | Also known as He initialization. 355 | 356 | Args: 357 | tensor: an n-dimensional `torch.Tensor` 358 | a: the negative slope of the rectifier used after this layer (only 359 | used with ``'leaky_relu'``) 360 | mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` 361 | preserves the magnitude of the variance of the weights in the 362 | forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the 363 | backwards pass. 364 | nonlinearity: the non-linear function (`nn.functional` name), 365 | recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). 366 | 367 | Examples: 368 | >>> w = torch.empty(3, 5) 369 | >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') 370 | """ 371 | a = math.sqrt(1 + 2 * a * a) 372 | tensor = _tensorprocessor._preprocess(tensor) 373 | torch.nn.init.kaiming_uniform_(tensor[0], a=a, mode=mode, nonlinearity=nonlinearity) 374 | torch.nn.init.kaiming_uniform_(tensor[1], a=a, mode=mode, nonlinearity=nonlinearity) 375 | return _tensorprocessor._postprocess(tensor) 376 | 377 | def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): 378 | r"""Fills the input `Tensor` with values according to the method 379 | described in `Delving deep into rectifiers: Surpassing human-level 380 | performance on ImageNet classification` - He, K. et al. (2015), using a 381 | normal distribution. The resulting tensor will have values sampled from 382 | :math:`\mathcal{N}(0, \text{std}^2)` where 383 | 384 | .. math:: 385 | \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} 386 | 387 | Also known as He initialization. 388 | 389 | Args: 390 | tensor: an n-dimensional `torch.Tensor` 391 | a: the negative slope of the rectifier used after this layer (only 392 | used with ``'leaky_relu'``) 393 | mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` 394 | preserves the magnitude of the variance of the weights in the 395 | forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the 396 | backwards pass. 
397 | nonlinearity: the non-linear function (`nn.functional` name), 398 | recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). 399 | 400 | Examples: 401 | >>> w = torch.empty(3, 5) 402 | >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') 403 | """ 404 | a = math.sqrt(1 + 2 * a * a) 405 | tensor = _tensorprocessor._preprocess(tensor) 406 | torch.nn.init.kaiming_normal_(tensor[0], a=a, mode=mode, nonlinearity=nonlinearity) 407 | torch.nn.init.kaiming_normal_(tensor[1], a=a, mode=mode, nonlinearity=nonlinearity) 408 | return _tensorprocessor._postprocess(tensor) 409 | 410 | # def orthogonal_(tensor, gain=1): 411 | # r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as 412 | # described in `Exact solutions to the nonlinear dynamics of learning in deep 413 | # linear neural networks` - Saxe, A. et al. (2013). The input tensor must have 414 | # at least 2 dimensions, and for tensors with more than 2 dimensions the 415 | # trailing dimensions are flattened. 416 | 417 | # Args: 418 | # tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2` 419 | # gain: optional scaling factor 420 | 421 | # Examples: 422 | # >>> w = torch.empty(3, 5) 423 | # >>> nn.init.orthogonal_(w) 424 | # """ 425 | # if tensor.ndimension() < 2: 426 | # raise ValueError("Only tensors with 2 or more dimensions are supported") 427 | 428 | # rows = tensor.size(0) 429 | # cols = tensor.numel() // rows 430 | # flattened = tensor.new(rows, cols).normal_(0, 1) 431 | 432 | # if rows < cols: 433 | # flattened.t_() 434 | 435 | # # Compute the qr factorization 436 | # q, r = torch.qr(flattened) 437 | # # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf 438 | # d = torch.diag(r, 0) 439 | # ph = d.sign() 440 | # q *= ph 441 | 442 | # if rows < cols: 443 | # q.t_() 444 | 445 | # with torch.no_grad(): 446 | # tensor.view_as(q).copy_(q) 447 | # tensor.mul_(gain) 448 | # return tensor 449 | 450 | # def sparse_(tensor, sparsity, std=0.01): 451 | # r"""Fills the 2D input `Tensor` as a sparse matrix, where the 452 | # non-zero elements will be drawn from the normal distribution 453 | # :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via 454 | # Hessian-free optimization` - Martens, J. (2010). 455 | 456 | # Args: 457 | # tensor: an n-dimensional `torch.Tensor` 458 | # sparsity: The fraction of elements in each column to be set to zero 459 | # std: the standard deviation of the normal distribution used to generate 460 | # the non-zero values 461 | 462 | # Examples: 463 | # >>> w = torch.empty(3, 5) 464 | # >>> nn.init.sparse_(w, sparsity=0.1) 465 | # """ 466 | # if tensor.ndimension() != 2: 467 | # raise ValueError("Only tensors with 2 dimensions are supported") 468 | 469 | # rows, cols = tensor.shape 470 | # num_zeros = int(math.ceil(sparsity * rows)) 471 | 472 | # with torch.no_grad(): 473 | # tensor.normal_(0, std) 474 | # for col_idx in range(cols): 475 | # row_indices = torch.randperm(rows) 476 | # zero_indices = row_indices[:num_zeros] 477 | # tensor[zero_indices, col_idx] = 0 478 | # return tensor 479 | 480 | def trabelsi_standard_(tensor, kind="glorot"): 481 | """Standard complex initialization proposed in Trabelsi et al. 
(2018)."""
482 |     kind = kind.lower()
483 |     assert kind in ("glorot", "xavier", "kaiming", "he")
484 |
485 |     tensor = _tensorprocessor._preprocess(tensor)
486 |
487 |     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor[0])
488 |     if kind == "glorot" or kind == "xavier":
489 |         scale = 1 / math.sqrt(fan_in + fan_out)
490 |     else:
491 |         scale = 1 / math.sqrt(fan_in)
492 |
493 |     # Rayleigh(\sigma / \sqrt2) x uniform[-\pi, +\pi] on p. 7
494 |     rho = np.random.rayleigh(scale, size=tensor[0].shape)
495 |     theta = np.random.uniform(-np.pi, +np.pi, size=tensor[0].shape)
496 |
497 |     # eq. (8) on p. 6
498 |     with torch.no_grad():
499 |         tensor[0].copy_(torch.from_numpy(np.cos(theta) * rho))
500 |         tensor[1].copy_(torch.from_numpy(np.sin(theta) * rho))
501 |
502 |     return _tensorprocessor._postprocess(tensor)
503 |
504 | def trabelsi_independent_(tensor, kind="glorot"):
505 |     """Orthogonal complex initialization proposed in Trabelsi et al. (2018)."""
506 |     kind = kind.lower()
507 |     assert kind in ("glorot", "xavier", "kaiming", "he")
508 |
509 |     tensor = _tensorprocessor._preprocess(tensor)
510 |
511 |     ndim = tensor[0].dim()
512 |     if ndim == 2:
513 |         shape = tensor[0].shape
514 |     else:
515 |         shape = int(np.prod(tensor[0].shape[:2])), int(np.prod(tensor[0].shape[2:]))
516 |
517 |     # generate a semi-unitary (orthogonal) matrix from a random matrix
518 |     # M = U V is semi-unitary: V^H U^H U V = I_k
519 |     Z = np.random.rand(*shape) + 1j * np.random.rand(*shape)
520 |
521 |     # Z is n x m, so u is n x n and vh is m x m
522 |     u, _, vh = np.linalg.svd(Z, compute_uv=True, full_matrices=True, hermitian=False)
523 |     k = min(*shape)
524 |     M = np.dot(u[:, :k], vh[:, :k].conjugate().T)
525 |
526 |     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor[0])
527 |     if kind == "glorot" or kind == "xavier":
528 |         scale = 1 / math.sqrt(fan_in + fan_out)
529 |     else:
530 |         scale = 1 / math.sqrt(fan_in)
531 |
532 |     M /= M.std() / scale
533 |     M = M.reshape(tensor[0].shape)
534 |
535 |     with torch.no_grad():
536 |         tensor[0].copy_(torch.from_numpy(M.real))
537 |         tensor[1].copy_(torch.from_numpy(M.imag))
538 |
539 |     return _tensorprocessor._postprocess(tensor)
--------------------------------------------------------------------------------
/torchcomplex/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | ###############
2 | # These classes are used directly from torch; they are re-imported here to simplify imports
3 | # module
4 | from torch.nn import Module
5 |
6 | # containers
7 | from torch.nn import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
8 |
9 | # padding
10 | from torch.nn import ReflectionPad1d, ReflectionPad2d, ReplicationPad1d, ReplicationPad2d, \
11 |     ReplicationPad3d, ZeroPad2d, ConstantPad1d, ConstantPad2d, ConstantPad3d
12 | ###############
13 |
14 | from .linear import Identity, Linear, Bilinear
15 | from .conv import Conv1d, Conv2d, Conv3d, \
16 |     ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
17 | from .activation import GenericComplexActivation, Sigmoid, Tanh, \
18 |     Softmax, Softmax2d, CReLU, zReLU, modReLU, CmodReLU, AdaptiveCmodReLU, Hirose, modSigmoid
19 | from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
20 |     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
21 |     AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
22 | from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
23 | from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
24 | from .upsampling import Upsample
--------------------------------------------------------------------------------
/torchcomplex/nn/modules/activation.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import Tensor
4 | from torch.nn.parameter import Parameter
5 | from .. import functional as cF
6 | from torch.nn.modules import Module
7 | from typing import Optional, List, Tuple, Union
8 |
9 | class GenericComplexActivation(Module):
10 |     def __init__(self, activation, use_phase: bool = False):
11 |         '''
12 |         activation can be either a function from nn.functional or an object of nn.Module if the activation has learnable parameters
13 |         Original idea from: https://github.com/albanD
14 |         '''
15 |         super(GenericComplexActivation, self).__init__()
16 |         self.activation = activation
17 |         self.use_phase = use_phase
18 |
19 |     def forward(self, input: Tensor):
20 |         if self.use_phase:
21 |             return self.activation(torch.abs(input)) * torch.exp(1.j * torch.angle(input))
22 |         else:
23 |             return self.activation(input.real) + 1.j * self.activation(input.imag)
24 |
25 | class CReLU(Module):
26 |     '''
27 |     Eq.(4)
28 |     https://arxiv.org/pdf/1705.09792.pdf
29 |     '''
30 |     __constants__ = ['inplace']
31 |     inplace: bool
32 |
33 |     def __init__(self, inplace: bool = False):
34 |         super(CReLU, self).__init__()
35 |         self.inplace = inplace
36 |
37 |     def forward(self, input: Tensor) -> Tensor:
38 |         return cF.crelu(input, inplace=self.inplace)
39 |
40 |     def extra_repr(self) -> str:
41 |         inplace_str = 'inplace=True' if self.inplace else ''
42 |         return inplace_str
43 |
44 | class zReLU(Module):
45 |     '''
46 |     Guberman ReLU:
47 |     Nitzan Guberman. On complex valued convolutional neural networks. arXiv preprint arXiv:1602.09046, 2016
48 |     Eq.(5)
49 |     https://arxiv.org/pdf/1705.09792.pdf
50 |
51 |     Warning:
52 |     Inplace will only be used if the input is real (i.e. while using the default relu of PyTorch)
53 |     '''
54 |     __constants__ = ['inplace']
55 |     inplace: bool
56 |
57 |     def __init__(self, inplace: bool = False):
58 |         super(zReLU, self).__init__()
59 |         self.inplace = inplace
60 |
61 |     def forward(self, input: Tensor) -> Tensor:
62 |         return cF.zrelu(input, inplace=self.inplace)
63 |
64 |     def extra_repr(self) -> str:
65 |         inplace_str = 'inplace=True' if self.inplace else ''
66 |         return inplace_str
67 |
68 | class modReLU(Module):
69 |     '''
70 |     Martin Arjovsky, Amar Shah, and Yoshua Bengio. Unitary evolution recurrent neural networks. arXiv preprint arXiv:1511.06464, 2015.
71 |     Notice that |z| (z.magnitude) is always positive, so if b > 0 then |z| + b >= 0 always.
72 |     In order to have any non-linearity effect, b must be smaller than 0 (b < 0).
73 |     Update: The implementation has been updated following: \\operatorname{ReLU}(|z|+b) \\frac{z}{|z|}
74 |
75 |     Warning:
76 |     Inplace will only be used if the input is real (i.e. while using the default relu of PyTorch)
77 |     '''
78 |     __constants__ = ['inplace']
79 |     inplace: bool
80 |
81 |     def __init__(self, inplace: bool = False):
82 |         super(modReLU, self).__init__()
83 |         self.inplace = inplace
84 |         self.bias = Parameter(torch.rand(1) * 0.25)
85 |
86 |     def forward(self, input: Tensor) -> Tensor:
87 |         return cF.modrelu(input, bias=self.bias, inplace=self.inplace)
88 |
89 |     def extra_repr(self) -> str:
90 |         inplace_str = 'inplace=True' if self.inplace else ''
91 |         return inplace_str
92 |
93 | class CmodReLU(Module):
94 |     '''Compute the Complex modulus relu of the complex tensor in re-im pair.
95 |     As proposed in: https://arxiv.org/pdf/1802.08026.pdf
96 |
97 |     If threshold=None then it becomes a learnable parameter.
98 |     '''
99 |     __constants__ = ['inplace']
100 |     inplace: bool
101 |
102 |     def __init__(self, threshold: Optional[float] = None, inplace: bool = False):
103 |         super(CmodReLU, self).__init__()
104 |         self.inplace = inplace
105 |         if not isinstance(threshold, float):
106 |             threshold = Parameter(torch.rand(1) * 0.25)
107 |         self.threshold = threshold
108 |
109 |     def forward(self, input: Tensor) -> Tensor:
110 |         return cF.cmodrelu(input, threshold=self.threshold, inplace=self.inplace)
111 |
112 |     def extra_repr(self) -> str:
113 |         inplace_str = 'inplace=True' if self.inplace else ''
114 |         return inplace_str
115 |
116 | class AdaptiveCmodReLU(Module):
117 |     '''Compute the Complex modulus relu of the complex tensor in re-im pair.
118 |     As proposed in: https://arxiv.org/pdf/1802.08026.pdf
119 |
120 |     AdaptiveCmodReLU(1) learns one common threshold for all features, AdaptiveCmodReLU(d) learns separate ones for each dimension
121 |     '''
122 |     __constants__ = ['inplace']
123 |     inplace: bool
124 |
125 |     def __init__(self, *dim, inplace: bool = False):
126 |         super(AdaptiveCmodReLU, self).__init__()
127 |         self.inplace = inplace
128 |         self.dim = dim if dim else (1,)
129 |         self.threshold = Parameter(torch.randn(*self.dim) * 0.02)
130 |
131 |     def forward(self, input: Tensor) -> Tensor:
132 |         return cF.cmodrelu(input, threshold=self.threshold, inplace=self.inplace)
133 |
134 |     def extra_repr(self) -> str:
135 |         inplace_str = 'inplace=True' if self.inplace else ''
136 |         return inplace_str
137 |
138 | class Softmax(Module):
139 |     __constants__ = ['dim']
140 |     dim: Optional[int]
141 |
142 |     def __init__(self, dim: Optional[int] = None) -> None:
143 |         super(Softmax, self).__init__()
144 |         self.dim = dim
145 |
146 |     def __setstate__(self, state):
147 |         self.__dict__.update(state)
148 |         if not hasattr(self, 'dim'):
149 |             self.dim = None
150 |
151 |     def forward(self, input: Tensor) -> Tensor:
152 |         return cF.softmax(input, self.dim, _stacklevel=5)
153 |
154 |     def extra_repr(self) -> str:
155 |         return 'dim={dim}'.format(dim=self.dim)
156 |
157 | class Softmax2d(Module):
158 |     def forward(self, input: Tensor) -> Tensor:
159 |         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
160 |         return cF.softmax(input, 1, _stacklevel=5)
161 |
162 | class Tanh(Module):
163 |     def forward(self, input: Tensor) -> Tensor:
164 |         return cF.tanh(input)
165 |
166 | class Hirose(Module):
167 |     '''
168 |     A. Hirose. Complex-valued neural networks: Advances and applications. John Wiley & Sons, 2013.
169 |     and
170 |     Wolter and Yao. Complex Gated Recurrent Neural Networks. NeurIPS 2018. (Eq. 5) https://papers.nips.cc/paper_files/paper/2018/file/652cf38361a209088302ba2b8b7f51e0-Paper.pdf
5) https://papers.nips.cc/paper_files/paper/2018/file/652cf38361a209088302ba2b8b7f51e0-Paper.pdf 170 | ''' 171 | __constants__ = ['m_sqaure'] 172 | m_sqaure: float 173 | 174 | def __init__(self, m_sqaure: float = 1.0): 175 | super(Hirose, self).__init__() 176 | self.m_sqaure = m_sqaure 177 | 178 | def forward(self, input: Tensor) -> Tensor: 179 | return cF.hirose(input, m_sqaure=self.m_sqaure) 180 | 181 | class Sigmoid(Module): 182 | def forward(self, input: Tensor) -> Tensor: 183 | return cF.sigmoid(input) 184 | 185 | class modSigmoid(Module): 186 | ''' 187 | Wolter and Yao. Complex Gated Recurrent Neural Networks. NeurIPS 2018. (Eq. 13) https://papers.nips.cc/paper_files/paper/2018/file/652cf38361a209088302ba2b8b7f51e0-Paper.pdf 188 | and 189 | Xie et al. Complex Recurrent Variational Autoencoder with Application to Speech Enhancement. 2023. arXiv:2204.02195v2 190 | ''' 191 | __constants__ = ['alpha'] 192 | alpha: float 193 | 194 | def __init__(self, alpha: float = 0.5): 195 | super(modSigmoid, self).__init__() 196 | assert 0.0 <= alpha <= 1.0, "alpha must be between 0 and 1" 197 | self.alpha = alpha 198 | 199 | def forward(self, input: Tensor) -> Tensor: 200 | return cF.modsigmoid(input, alpha=self.alpha) -------------------------------------------------------------------------------- /torchcomplex/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.modules import Module 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import ParameterList 7 | from .. import functional as cF 8 | from torch.nn import init 9 | 10 | class _NormBase(Module): 11 | """Common base of _InstanceNorm and _BatchNorm 12 | Some of the parameters are from: https://github.com/ivannz/cplxmodule/blob/master/cplxmodule/nn/modules/batchnorm.py 13 | """ 14 | _version = 2 15 | __constants__ = ['track_running_stats', 'momentum', 'eps', 16 | 'num_features', 'affine'] 17 | num_features: int 18 | eps: float 19 | momentum: float 20 | affine: bool 21 | track_running_stats: bool 22 | 23 | def __init__( 24 | self, 25 | num_features: int, 26 | eps: float = 1e-5, 27 | momentum: float = 0.1, 28 | affine: bool = True, 29 | track_running_stats: bool = True, 30 | naive = False, 31 | complex_weights = True 32 | ) -> None: 33 | super(_NormBase, self).__init__() 34 | self.num_features = num_features 35 | self.eps = eps 36 | self.momentum = momentum 37 | self.affine = affine 38 | self.track_running_stats = track_running_stats 39 | self.naive = naive 40 | self.complex_weights = complex_weights 41 | if naive: 42 | if self.affine: 43 | if complex_weights: 44 | self.weight = Parameter(torch.Tensor(num_features).to(torch.cfloat)) 45 | self.bias = Parameter(torch.Tensor(num_features).to(torch.cfloat)) 46 | else: 47 | self.weight = ParameterList([Parameter(torch.Tensor(num_features)), Parameter(torch.Tensor(num_features))]) 48 | self.bias = ParameterList([Parameter(torch.Tensor(num_features)), Parameter(torch.Tensor(num_features))]) 49 | else: 50 | self.register_parameter('weight', None) 51 | self.register_parameter('bias', None) 52 | if self.track_running_stats: 53 | self.register_buffer('running_mean', torch.zeros(2, num_features)) 54 | self.register_buffer('running_var', torch.ones(2, num_features)) 55 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 56 | else: 57 | self.register_parameter('running_mean', None) 58 | self.register_parameter('running_var', None)
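# (when running stats are not tracked, the counter below is likewise registered as None so the attribute still exists)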
59 | self.register_parameter('num_batches_tracked', None) 60 | else: 61 | if self.affine: 62 | self.weight = Parameter(torch.empty(2, 2, num_features)) 63 | self.bias = Parameter(torch.empty(2, num_features)) 64 | else: 65 | self.register_parameter('weight', None) 66 | self.register_parameter('bias', None) 67 | if self.track_running_stats: 68 | self.register_buffer('running_mean', torch.empty(2, num_features)) 69 | self.register_buffer('running_var', torch.empty(2, 2, num_features)) 70 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 71 | else: 72 | self.register_parameter('running_mean', None) 73 | self.register_parameter('running_var', None) 74 | self.register_parameter('num_batches_tracked', None) 75 | self.reset_parameters() 76 | 77 | def reset_running_stats(self) -> None: 78 | if self.naive: 79 | if self.track_running_stats: 80 | # running_mean/running_var/num_batches... are registered at runtime depending 81 | # on whether self.track_running_stats is on 82 | self.running_mean.zero_() # type: ignore[operator] 83 | self.running_var.fill_(1) # type: ignore[operator] 84 | self.num_batches_tracked.zero_() # type: ignore[operator] 85 | else: 86 | if self.track_running_stats: 87 | # running_mean/running_var/num_batches... are registered at runtime depending 88 | # on whether self.track_running_stats is on 89 | self.running_mean.zero_() # type: ignore[operator] 90 | self.running_var[0, 0].fill_(1) 91 | self.running_var[1, 0].zero_() 92 | self.running_var[0, 1].zero_() 93 | self.running_var[1, 1].fill_(1) 94 | self.num_batches_tracked.zero_() # type: ignore[operator] 95 | 96 | def reset_parameters(self) -> None: 97 | self.reset_running_stats() 98 | if self.naive: 99 | if self.affine: 100 | if self.complex_weights: 101 | init.ones_(self.weight) 102 | init.zeros_(self.bias) 103 | else: 104 | init.ones_(self.weight[0]) 105 | init.zeros_(self.bias[0]) 106 | init.ones_(self.weight[1]) 107 | init.zeros_(self.bias[1]) 108 | else: 109 | if self.affine: 110 | init.ones_(self.weight[0, 0]) 111 | init.zeros_(self.weight[1, 0]) 112 | init.zeros_(self.weight[0, 1]) 113 | init.ones_(self.weight[1, 1]) 114 | init.zeros_(self.bias) 115 | 116 | def _check_input_dim(self, input): 117 | raise NotImplementedError 118 | 119 | def extra_repr(self): 120 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 121 | 'track_running_stats={track_running_stats}'.format(**vars(self)) 122 | 123 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 124 | missing_keys, unexpected_keys, error_msgs): 125 | version = local_metadata.get('version', None) 126 | 127 | if (version is None or version < 2) and self.track_running_stats: 128 | # at version 2: added num_batches_tracked buffer 129 | # this should have a default value of 0 130 | num_batches_tracked_key = prefix + 'num_batches_tracked' 131 | if num_batches_tracked_key not in state_dict: 132 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 133 | 134 | super(_NormBase, self)._load_from_state_dict( 135 | state_dict, prefix, local_metadata, strict, 136 | missing_keys, unexpected_keys, error_msgs) 137 | 138 | 139 | class _BatchNorm(_NormBase): 140 | 141 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 142 | track_running_stats=True, naive=False, complex_weights=True): 143 | super(_BatchNorm, self).__init__( 144 | num_features, eps, momentum, affine, track_running_stats, naive, complex_weights) 145 | 146 | def forward(self, input: Tensor) -> Tensor:
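# the dimensionality check below is implemented by the concrete subclasses (BatchNorm1d/2d/3d)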
self._check_input_dim(input) 148 | 149 | # exponential_average_factor is set to self.momentum 150 | # (when it is available) only so that it gets updated 151 | # in ONNX graph when this node is exported to ONNX. 152 | if self.momentum is None: 153 | exponential_average_factor = 0.0 154 | else: 155 | exponential_average_factor = self.momentum 156 | 157 | if self.training and self.track_running_stats: 158 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 159 | if self.num_batches_tracked is not None: # type: ignore 160 | self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore 161 | if self.momentum is None: # use cumulative moving average 162 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 163 | else: # use exponential moving average 164 | exponential_average_factor = self.momentum 165 | 166 | r""" 167 | Decide whether the mini-batch stats should be used for normalization rather than the buffers. 168 | Mini-batch stats are used in training mode, and in eval mode when buffers are None. 169 | """ 170 | if self.training: 171 | bn_training = True 172 | else: 173 | bn_training = (self.running_mean is None) and (self.running_var is None) 174 | 175 | r""" 176 | Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be 177 | passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are 178 | used for normalization (i.e. in eval mode when buffers are not None). 179 | """ 180 | assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) 181 | assert self.running_var is None or isinstance(self.running_var, torch.Tensor) 182 | return cF.batch_norm(input, 183 | # If buffers are not to be tracked, ensure that they won't be updated 184 | self.running_mean if not self.training or self.track_running_stats else None, 185 | self.running_var if not self.training or self.track_running_stats else None, 186 | self.weight, self.bias, bn_training, exponential_average_factor, self.eps, self.naive) 187 | 188 | 189 | class BatchNorm1d(_BatchNorm): 190 | r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D 191 | inputs with optional additional channel dimension) as described in the paper 192 | `Batch Normalization: Accelerating Deep Network Training by Reducing 193 | Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . 194 | 195 | .. math:: 196 | 197 | y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 198 | 199 | The mean and standard-deviation are calculated per-dimension over 200 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 201 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 202 | to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated 203 | via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. 204 | 205 | Also by default, during training this layer keeps running estimates of its 206 | computed mean and variance, which are then used for normalization during 207 | evaluation. The running estimates are kept with a default :attr:`momentum` 208 | of 0.1. 209 | 210 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 211 | keep running estimates, and batch statistics are instead used during 212 | evaluation time as well. 213 | 214 | ..
note:: 215 | This :attr:`momentum` argument is different from one used in optimizer 216 | classes and the conventional notion of momentum. Mathematically, the 217 | update rule for running statistics here is 218 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 219 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 220 | new observed value. 221 | 222 | Because the Batch Normalization is done over the `C` dimension, computing statistics 223 | on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. 224 | 225 | Args: 226 | num_features: :math:`C` from an expected input of size 227 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 228 | eps: a value added to the denominator for numerical stability. 229 | Default: 1e-5 230 | momentum: the value used for the running_mean and running_var 231 | computation. Can be set to ``None`` for cumulative moving average 232 | (i.e. simple average). Default: 0.1 233 | affine: a boolean value that when set to ``True``, this module has 234 | learnable affine parameters. Default: ``True`` 235 | track_running_stats: a boolean value that when set to ``True``, this 236 | module tracks the running mean and variance, and when set to ``False``, 237 | this module does not track such statistics, and initializes statistics 238 | buffers :attr:`running_mean` and :attr:`running_var` as ``None``. 239 | When these buffers are ``None``, this module always uses batch statistics 240 | in both training and eval modes. Default: ``True`` 241 | 242 | Shape: 243 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 244 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 245 | 246 | Examples:: 247 | 248 | >>> # With Learnable Parameters 249 | >>> m = nn.BatchNorm1d(100) 250 | >>> # Without Learnable Parameters 251 | >>> m = nn.BatchNorm1d(100, affine=False) 252 | >>> input = torch.randn(20, 100) 253 | >>> output = m(input) 254 | """ 255 | 256 | def _check_input_dim(self, input): 257 | if input.dim() != 2 and input.dim() != 3: 258 | raise ValueError('expected 2D or 3D input (got {}D input)' 259 | .format(input.dim())) 260 | 261 | 262 | class BatchNorm2d(_BatchNorm): 263 | r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs 264 | with additional channel dimension) as described in the paper 265 | `Batch Normalization: Accelerating Deep Network Training by Reducing 266 | Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . 267 | 268 | .. math:: 269 | 270 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 271 | 272 | The mean and standard-deviation are calculated per-dimension over 273 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 274 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 275 | to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated 276 | via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. 277 | 278 | Also by default, during training this layer keeps running estimates of its 279 | computed mean and variance, which are then used for normalization during 280 | evaluation. The running estimates are kept with a default :attr:`momentum` 281 | of 0.1. 282 | 283 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 284 | keep running estimates, and batch statistics are instead used during 285 | evaluation time as well. 286 | 287 | ..
note:: 288 | This :attr:`momentum` argument is different from one used in optimizer 289 | classes and the conventional notion of momentum. Mathematically, the 290 | update rule for running statistics here is 291 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 292 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 293 | new observed value. 294 | 295 | Because the Batch Normalization is done over the `C` dimension, computing statistics 296 | on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. 297 | 298 | Args: 299 | num_features: :math:`C` from an expected input of size 300 | :math:`(N, C, H, W)` 301 | eps: a value added to the denominator for numerical stability. 302 | Default: 1e-5 303 | momentum: the value used for the running_mean and running_var 304 | computation. Can be set to ``None`` for cumulative moving average 305 | (i.e. simple average). Default: 0.1 306 | affine: a boolean value that when set to ``True``, this module has 307 | learnable affine parameters. Default: ``True`` 308 | track_running_stats: a boolean value that when set to ``True``, this 309 | module tracks the running mean and variance, and when set to ``False``, 310 | this module does not track such statistics, and initializes statistics 311 | buffers :attr:`running_mean` and :attr:`running_var` as ``None``. 312 | When these buffers are ``None``, this module always uses batch statistics 313 | in both training and eval modes. Default: ``True`` 314 | 315 | Shape: 316 | - Input: :math:`(N, C, H, W)` 317 | - Output: :math:`(N, C, H, W)` (same shape as input) 318 | 319 | Examples:: 320 | 321 | >>> # With Learnable Parameters 322 | >>> m = nn.BatchNorm2d(100) 323 | >>> # Without Learnable Parameters 324 | >>> m = nn.BatchNorm2d(100, affine=False) 325 | >>> input = torch.randn(20, 100, 35, 45) 326 | >>> output = m(input) 327 | """ 328 | 329 | def _check_input_dim(self, input): 330 | if input.dim() != 4: 331 | raise ValueError('expected 4D input (got {}D input)' 332 | .format(input.dim())) 333 | 334 | 335 | class BatchNorm3d(_BatchNorm): 336 | r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs 337 | with additional channel dimension) as described in the paper 338 | `Batch Normalization: Accelerating Deep Network Training by Reducing 339 | Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . 340 | 341 | .. math:: 342 | 343 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 344 | 345 | The mean and standard-deviation are calculated per-dimension over 346 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 347 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 348 | to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated 349 | via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. 350 | 351 | Also by default, during training this layer keeps running estimates of its 352 | computed mean and variance, which are then used for normalization during 353 | evaluation. The running estimates are kept with a default :attr:`momentum` 354 | of 0.1. 355 | 356 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 357 | keep running estimates, and batch statistics are instead used during 358 | evaluation time as well. 359 | 360 | ..
note:: 361 | This :attr:`momentum` argument is different from one used in optimizer 362 | classes and the conventional notion of momentum. Mathematically, the 363 | update rule for running statistics here is 364 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 365 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 366 | new observed value. 367 | 368 | Because the Batch Normalization is done over the `C` dimension, computing statistics 369 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization 370 | or Spatio-temporal Batch Normalization. 371 | 372 | Args: 373 | num_features: :math:`C` from an expected input of size 374 | :math:`(N, C, D, H, W)` 375 | eps: a value added to the denominator for numerical stability. 376 | Default: 1e-5 377 | momentum: the value used for the running_mean and running_var 378 | computation. Can be set to ``None`` for cumulative moving average 379 | (i.e. simple average). Default: 0.1 380 | affine: a boolean value that when set to ``True``, this module has 381 | learnable affine parameters. Default: ``True`` 382 | track_running_stats: a boolean value that when set to ``True``, this 383 | module tracks the running mean and variance, and when set to ``False``, 384 | this module does not track such statistics, and initializes statistics 385 | buffers :attr:`running_mean` and :attr:`running_var` as ``None``. 386 | When these buffers are ``None``, this module always uses batch statistics 387 | in both training and eval modes. Default: ``True`` 388 | 389 | Shape: 390 | - Input: :math:`(N, C, D, H, W)` 391 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 392 | 393 | Examples:: 394 | 395 | >>> # With Learnable Parameters 396 | >>> m = nn.BatchNorm3d(100) 397 | >>> # Without Learnable Parameters 398 | >>> m = nn.BatchNorm3d(100, affine=False) 399 | >>> input = torch.randn(20, 100, 35, 45, 10) 400 | >>> output = m(input) 401 | """ 402 | 403 | def _check_input_dim(self, input): 404 | if input.dim() != 5: 405 | raise ValueError('expected 5D input (got {}D input)' 406 | .format(input.dim())) -------------------------------------------------------------------------------- /torchcomplex/nn/modules/conv.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import math 3 | import warnings 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn.parameter import Parameter 8 | from torch.nn import ParameterList 9 | from torch.nn import functional as F 10 | from .. import functional as cF 11 | from torch.nn import init 12 | from torch.nn.modules import Module 13 | from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple 14 | # from ..._torch_docs import reproducibility_notes 15 | 16 | from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t 17 | from typing import Optional, List, Tuple, Union 18 | 19 | convolution_notes = \ 20 | {"groups_note": """* :attr:`groups` controls the connections between inputs and outputs. 21 | :attr:`in_channels` and :attr:`out_channels` must both be divisible by 22 | :attr:`groups`. For example, 23 | 24 | * At groups=1, all inputs are convolved to all outputs. 25 | * At groups=2, the operation becomes equivalent to having two conv 26 | layers side by side, each seeing half the input channels 27 | and producing half the output channels, and both subsequently 28 | concatenated.
29 | * At groups= :attr:`in_channels`, each input channel is convolved with 30 | its own set of filters (of size 31 | :math:`\\frac{\\text{out\_channels}}{\\text{in\_channels}}`).""", # noqa: W605 32 | 33 | "depthwise_separable_note": """When `groups == in_channels` and `out_channels == K * in_channels`, 34 | where `K` is a positive integer, this operation is also known as a "depthwise convolution". 35 | 36 | In other words, for an input of size :math:`(N, C_{in}, L_{in})`, 37 | a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments 38 | :math:`(C_\\text{in}=C_\\text{in}, C_\\text{out}=C_\\text{in} \\times \\text{K}, ..., \\text{groups}=C_\\text{in})`."""} # noqa: W605 39 | 40 | 41 | 42 | class _ConvNd(Module): 43 | 44 | __constants__ = ['stride', 'padding', 'dilation', 'groups', 45 | 'padding_mode', 'output_padding', 'in_channels', 46 | 'out_channels', 'kernel_size'] 47 | __annotations__ = {'bias': Optional[torch.Tensor]} 48 | 49 | _in_channels: int 50 | out_channels: int 51 | kernel_size: Tuple[int, ...] 52 | stride: Tuple[int, ...] 53 | padding: Tuple[int, ...] 54 | dilation: Tuple[int, ...] 55 | transposed: bool 56 | output_padding: Tuple[int, ...] 57 | groups: int 58 | padding_mode: str 59 | complex_weights: bool 60 | weight: Union[Tensor, Tuple[Tensor, Tensor]] 61 | bias: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] 62 | 63 | def __init__(self, 64 | in_channels: int, 65 | out_channels: int, 66 | kernel_size: _size_1_t, 67 | stride: _size_1_t, 68 | padding: _size_1_t, 69 | dilation: _size_1_t, 70 | transposed: bool, 71 | output_padding: _size_1_t, 72 | groups: int, 73 | bias: bool, 74 | padding_mode: str, 75 | complex_weights=True) -> None: 76 | super(_ConvNd, self).__init__() 77 | if in_channels % groups != 0: 78 | raise ValueError('in_channels must be divisible by groups') 79 | if out_channels % groups != 0: 80 | raise ValueError('out_channels must be divisible by groups') 81 | valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} 82 | if padding_mode not in valid_padding_modes: 83 | raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format( 84 | valid_padding_modes, padding_mode)) 85 | self.in_channels = in_channels 86 | self.out_channels = out_channels 87 | self.kernel_size = kernel_size 88 | self.stride = stride 89 | self.padding = padding 90 | self.dilation = dilation 91 | self.transposed = transposed 92 | self.output_padding = output_padding 93 | self.groups = groups 94 | self.padding_mode = padding_mode 95 | self.complex_weights = complex_weights 96 | # `_reversed_padding_repeated_twice` is the padding to be passed to 97 | # `F.pad` if needed (e.g., for non-zero padding types that are 98 | # implemented as two ops: padding + conv). `F.pad` accepts paddings in 99 | # the reverse order of the dimensions.
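# e.g., 2D padding (padH, padW) becomes (padW, padW, padH, padH), matching F.pad's (left, right, top, bottom) ordering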
100 | self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) 101 | 102 | 103 | if complex_weights: 104 | if transposed: 105 | self.weight = Parameter(torch.Tensor( 106 | in_channels, out_channels // groups, *kernel_size).to(torch.cfloat)) 107 | else: 108 | self.weight = Parameter(torch.Tensor( 109 | out_channels, in_channels // groups, *kernel_size).to(torch.cfloat)) 110 | else: 111 | if transposed: 112 | weight_real = Parameter(torch.Tensor( 113 | in_channels, out_channels // groups, *kernel_size)) 114 | weight_imag = Parameter(torch.Tensor( 115 | in_channels, out_channels // groups, *kernel_size)) 116 | else: 117 | weight_real = Parameter(torch.Tensor( 118 | out_channels, in_channels // groups, *kernel_size)) 119 | weight_imag = Parameter(torch.Tensor( 120 | out_channels, in_channels // groups, *kernel_size)) 121 | self.weight = ParameterList([weight_real, weight_imag]) 122 | 123 | 124 | if bias: 125 | if complex_weights: 126 | self.bias = Parameter(torch.Tensor(out_channels).to(torch.cfloat)) 127 | else: 128 | bias_real = Parameter(torch.Tensor(out_channels)) 129 | bias_imag = Parameter(torch.Tensor(out_channels)) 130 | self.bias = ParameterList([bias_real, bias_imag]) 131 | else: 132 | self.register_parameter('bias', None) 133 | self.reset_parameters() 134 | 135 | def _reset_parameters(self, weight, bias) -> None: 136 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 137 | if bias is not None: 138 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight) 139 | bound = 1 / math.sqrt(fan_in) 140 | init.uniform_(bias, -bound, bound) 141 | 142 | def reset_parameters(self) -> None: 143 | if type(self.weight) is ParameterList: 144 | self._reset_parameters(self.weight[0], None if self.bias is None else self.bias[0]) 145 | self._reset_parameters(self.weight[1], None if self.bias is None else self.bias[1]) 146 | else: 147 | self._reset_parameters(self.weight, self.bias) 148 | 149 | def extra_repr(self): 150 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 151 | ', stride={stride}') 152 | if self.padding != (0,) * len(self.padding): 153 | s += ', padding={padding}' 154 | if self.dilation != (1,) * len(self.dilation): 155 | s += ', dilation={dilation}' 156 | if self.output_padding != (0,) * len(self.output_padding): 157 | s += ', output_padding={output_padding}' 158 | if self.groups != 1: 159 | s += ', groups={groups}' 160 | if self.bias is None: 161 | s += ', bias=False' 162 | if self.padding_mode != 'zeros': 163 | s += ', padding_mode={padding_mode}' 164 | return s.format(**self.__dict__) 165 | 166 | def __setstate__(self, state): 167 | super(_ConvNd, self).__setstate__(state) 168 | if not hasattr(self, 'padding_mode'): 169 | self.padding_mode = 'zeros' 170 | 171 | 172 | class Conv1d(_ConvNd): 173 | __doc__ = r"""Applies a 1D convolution over an input signal composed of several input 174 | planes. 175 | 176 | In the simplest case, the output value of the layer with input size 177 | :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be 178 | precisely described as: 179 | 180 | .. math:: 181 | \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + 182 | \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) 183 | \star \text{input}(N_i, k) 184 | 185 | where :math:`\star` is the valid `cross-correlation`_ operator, 186 | :math:`N` is a batch size, :math:`C` denotes a number of channels, 187 | :math:`L` is a length of signal sequence. 188 | """ + r""" 189 | 190 | This module supports :ref:`TensorFloat32`. 
191 | 192 | * :attr:`stride` controls the stride for the cross-correlation, a single 193 | number or a one-element tuple. 194 | 195 | * :attr:`padding` controls the amount of implicit padding on both sides 196 | for :attr:`padding` number of points. 197 | 198 | * :attr:`dilation` controls the spacing between the kernel points; also 199 | known as the à trous algorithm. It is harder to describe, but this `link`_ 200 | has a nice visualization of what :attr:`dilation` does. 201 | 202 | {groups_note} 203 | 204 | Note: 205 | {depthwise_separable_note} 206 | Note: 207 | {cudnn_reproducibility_note} 208 | 209 | Args: 210 | in_channels (int): Number of channels in the input image 211 | out_channels (int): Number of channels produced by the convolution 212 | kernel_size (int or tuple): Size of the convolving kernel 213 | stride (int or tuple, optional): Stride of the convolution. Default: 1 214 | padding (int or tuple, optional): Zero-padding added to both sides of 215 | the input. Default: 0 216 | padding_mode (string, optional): ``'zeros'``, ``'reflect'``, 217 | ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` 218 | dilation (int or tuple, optional): Spacing between kernel 219 | elements. Default: 1 220 | groups (int, optional): Number of blocked connections from input 221 | channels to output channels. Default: 1 222 | bias (bool, optional): If ``True``, adds a learnable bias to the 223 | output. Default: ``True`` 224 | 225 | 226 | 227 | Shape: 228 | - Input: :math:`(N, C_{in}, L_{in})` 229 | - Output: :math:`(N, C_{out}, L_{out})` where 230 | 231 | .. math:: 232 | L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} 233 | \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor 234 | 235 | Attributes: 236 | weight (Tensor): the learnable weights of the module of shape 237 | :math:`(\text{out\_channels}, 238 | \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. 239 | The values of these weights are sampled from 240 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 241 | :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` 242 | bias (Tensor): the learnable bias of the module of shape 243 | (out_channels). If :attr:`bias` is ``True``, then the values of these weights are 244 | sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 245 | :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` 246 | 247 | Examples:: 248 | 249 | >>> m = nn.Conv1d(16, 33, 3, stride=2) 250 | >>> input = torch.randn(20, 16, 50) 251 | >>> output = m(input) 252 | 253 | .. _cross-correlation: 254 | https://en.wikipedia.org/wiki/Cross-correlation 255 | 256 | .. 
_link: 257 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 258 | """ 259 | 260 | def __init__( 261 | self, 262 | in_channels: int, 263 | out_channels: int, 264 | kernel_size: _size_1_t, 265 | stride: _size_1_t = 1, 266 | padding: _size_1_t = 0, 267 | dilation: _size_1_t = 1, 268 | groups: int = 1, 269 | bias: bool = True, 270 | padding_mode: str = 'zeros', 271 | complex_weights = True 272 | ): 273 | kernel_size = _single(kernel_size) 274 | stride = _single(stride) 275 | padding = _single(padding) 276 | dilation = _single(dilation) 277 | super(Conv1d, self).__init__( 278 | in_channels, out_channels, kernel_size, stride, padding, dilation, 279 | False, _single(0), groups, bias, padding_mode, complex_weights) 280 | 281 | def _conv_forward(self, input, weight): 282 | if self.padding_mode != 'zeros': 283 | return cF.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 284 | weight, self.bias, self.stride, 285 | _single(0), self.dilation, self.groups) 286 | return cF.conv1d(input, weight, self.bias, self.stride, 287 | self.padding, self.dilation, self.groups) 288 | 289 | def forward(self, input: Tensor) -> Tensor: 290 | return self._conv_forward(input, self.weight) 291 | 292 | 293 | class Conv2d(_ConvNd): 294 | __doc__ = r"""Applies a 2D convolution over an input signal composed of several input 295 | planes. 296 | 297 | In the simplest case, the output value of the layer with input size 298 | :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` 299 | can be precisely described as: 300 | 301 | .. math:: 302 | \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + 303 | \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) 304 | 305 | 306 | where :math:`\star` is the valid 2D `cross-correlation`_ operator, 307 | :math:`N` is a batch size, :math:`C` denotes a number of channels, 308 | :math:`H` is a height of input planes in pixels, and :math:`W` is 309 | width in pixels. 310 | """ + r""" 311 | 312 | This module supports :ref:`TensorFloat32`. 313 | 314 | * :attr:`stride` controls the stride for the cross-correlation, a single 315 | number or a tuple. 316 | 317 | * :attr:`padding` controls the amount of implicit padding on both 318 | sides for :attr:`padding` number of points for each dimension. 319 | 320 | * :attr:`dilation` controls the spacing between the kernel points; also 321 | known as the à trous algorithm. It is harder to describe, but this `link`_ 322 | has a nice visualization of what :attr:`dilation` does. 323 | 324 | {groups_note} 325 | 326 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: 327 | 328 | - a single ``int`` -- in which case the same value is used for the height and width dimension 329 | - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, 330 | and the second `int` for the width dimension 331 | 332 | Note: 333 | {depthwise_separable_note} 334 | 335 | Note: 336 | {cudnn_reproducibility_note} 337 | 338 | Args: 339 | in_channels (int): Number of channels in the input image 340 | out_channels (int): Number of channels produced by the convolution 341 | kernel_size (int or tuple): Size of the convolving kernel 342 | stride (int or tuple, optional): Stride of the convolution. Default: 1 343 | padding (int or tuple, optional): Zero-padding added to both sides of 344 | the input. 
Default: 0 345 | padding_mode (string, optional): ``'zeros'``, ``'reflect'``, 346 | ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` 347 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 348 | groups (int, optional): Number of blocked connections from input 349 | channels to output channels. Default: 1 350 | bias (bool, optional): If ``True``, adds a learnable bias to the 351 | output. Default: ``True`` 352 | 353 | 354 | Shape: 355 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 356 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where 357 | 358 | .. math:: 359 | H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] 360 | \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor 361 | 362 | .. math:: 363 | W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] 364 | \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor 365 | 366 | Attributes: 367 | weight (Tensor): the learnable weights of the module of shape 368 | :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` 369 | :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. 370 | The values of these weights are sampled from 371 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 372 | :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` 373 | bias (Tensor): the learnable bias of the module of shape 374 | (out_channels). If :attr:`bias` is ``True``, 375 | then the values of these weights are 376 | sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 377 | :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` 378 | 379 | Examples: 380 | 381 | >>> # With square kernels and equal stride 382 | >>> m = nn.Conv2d(16, 33, 3, stride=2) 383 | >>> # non-square kernels and unequal stride and with padding 384 | >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) 385 | >>> # non-square kernels and unequal stride and with padding and dilation 386 | >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) 387 | >>> input = torch.randn(20, 16, 50, 100) 388 | >>> output = m(input) 389 | 390 | .. _cross-correlation: 391 | https://en.wikipedia.org/wiki/Cross-correlation 392 | 393 | .. 
_link: 394 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 395 | """ 396 | 397 | def __init__( 398 | self, 399 | in_channels: int, 400 | out_channels: int, 401 | kernel_size: _size_2_t, 402 | stride: _size_2_t = 1, 403 | padding: _size_2_t = 0, 404 | dilation: _size_2_t = 1, 405 | groups: int = 1, 406 | bias: bool = True, 407 | padding_mode: str = 'zeros', 408 | complex_weights = True 409 | ): 410 | kernel_size = _pair(kernel_size) 411 | stride = _pair(stride) 412 | padding = _pair(padding) 413 | dilation = _pair(dilation) 414 | super(Conv2d, self).__init__( 415 | in_channels, out_channels, kernel_size, stride, padding, dilation, 416 | False, _pair(0), groups, bias, padding_mode, complex_weights) 417 | 418 | def _conv_forward(self, input, weight): 419 | if self.padding_mode != 'zeros': 420 | return cF.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 421 | weight, self.bias, self.stride, 422 | _pair(0), self.dilation, self.groups) 423 | return cF.conv2d(input, weight, self.bias, self.stride, 424 | self.padding, self.dilation, self.groups) 425 | 426 | def forward(self, input: Tensor) -> Tensor: 427 | return self._conv_forward(input, self.weight) 428 | 429 | class Conv3d(_ConvNd): 430 | __doc__ = r"""Applies a 3D convolution over an input signal composed of several input 431 | planes. 432 | 433 | In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` 434 | and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: 435 | 436 | .. math:: 437 | out(N_i, C_{out_j}) = bias(C_{out_j}) + 438 | \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) 439 | 440 | where :math:`\star` is the valid 3D `cross-correlation`_ operator 441 | """ + r""" 442 | 443 | This module supports :ref:`TensorFloat32`. 444 | 445 | * :attr:`stride` controls the stride for the cross-correlation. 446 | 447 | * :attr:`padding` controls the amount of implicit padding on both 448 | sides for :attr:`padding` number of points for each dimension. 449 | 450 | * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. 451 | It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. 452 | 453 | {groups_note} 454 | 455 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: 456 | 457 | - a single ``int`` -- in which case the same value is used for the depth, height and width dimension 458 | - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, 459 | the second `int` for the height dimension and the third `int` for the width dimension 460 | 461 | Note: 462 | {depthwise_separable_note} 463 | 464 | Note: 465 | {cudnn_reproducibility_note} 466 | 467 | Args: 468 | in_channels (int): Number of channels in the input image 469 | out_channels (int): Number of channels produced by the convolution 470 | kernel_size (int or tuple): Size of the convolving kernel 471 | stride (int or tuple, optional): Stride of the convolution. Default: 1 472 | padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 473 | padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` 474 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 475 | groups (int, optional): Number of blocked connections from input channels to output channels. 
Default: 1 476 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 477 | 478 | 479 | Shape: 480 | - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` 481 | - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where 482 | 483 | .. math:: 484 | D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] 485 | \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor 486 | 487 | .. math:: 488 | H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] 489 | \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor 490 | 491 | .. math:: 492 | W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] 493 | \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor 494 | 495 | Attributes: 496 | weight (Tensor): the learnable weights of the module of shape 497 | :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` 498 | :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. 499 | The values of these weights are sampled from 500 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 501 | :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` 502 | bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, 503 | then the values of these weights are 504 | sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 505 | :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` 506 | 507 | Examples:: 508 | 509 | >>> # With square kernels and equal stride 510 | >>> m = nn.Conv3d(16, 33, 3, stride=2) 511 | >>> # non-square kernels and unequal stride and with padding 512 | >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) 513 | >>> input = torch.randn(20, 16, 10, 50, 100) 514 | >>> output = m(input) 515 | 516 | .. _cross-correlation: 517 | https://en.wikipedia.org/wiki/Cross-correlation 518 | 519 | .. 
_link: 520 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 521 | """ 522 | 523 | def __init__( 524 | self, 525 | in_channels: int, 526 | out_channels: int, 527 | kernel_size: _size_3_t, 528 | stride: _size_3_t = 1, 529 | padding: _size_3_t = 0, 530 | dilation: _size_3_t = 1, 531 | groups: int = 1, 532 | bias: bool = True, 533 | padding_mode: str = 'zeros', 534 | complex_weights = True 535 | ): 536 | kernel_size = _triple(kernel_size) 537 | stride = _triple(stride) 538 | padding = _triple(padding) 539 | dilation = _triple(dilation) 540 | super(Conv3d, self).__init__( 541 | in_channels, out_channels, kernel_size, stride, padding, dilation, 542 | False, _triple(0), groups, bias, padding_mode, complex_weights) 543 | 544 | def forward(self, input: Tensor) -> Tensor: 545 | if self.padding_mode != 'zeros': 546 | return cF.conv3d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 547 | self.weight, self.bias, self.stride, _triple(0), 548 | self.dilation, self.groups) 549 | return cF.conv3d(input, self.weight, self.bias, self.stride, 550 | self.padding, self.dilation, self.groups) 551 | 552 | 553 | class _ConvTransposeNd(_ConvNd): 554 | def __init__(self, in_channels, out_channels, kernel_size, stride, 555 | padding, dilation, transposed, output_padding, 556 | groups, bias, padding_mode, complex_weights): 557 | if padding_mode != 'zeros': 558 | raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__)) 559 | 560 | super(_ConvTransposeNd, self).__init__( 561 | in_channels, out_channels, kernel_size, stride, 562 | padding, dilation, transposed, output_padding, 563 | groups, bias, padding_mode, complex_weights) 564 | 565 | # dilation being an optional parameter is for backwards 566 | # compatibility 567 | def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None): 568 | # type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int] 569 | if output_size is None: 570 | ret = _single(self.output_padding) # convert to a list if it was not one already 571 | else: 572 | k = input.dim() - 2 573 | if len(output_size) == k + 2: 574 | output_size = output_size[2:] 575 | if len(output_size) != k: 576 | raise ValueError( 577 | "output_size must have {} or {} elements (got {})" 578 | .format(k, k + 2, len(output_size))) 579 | 580 | min_sizes = torch.jit.annotate(List[int], []) 581 | max_sizes = torch.jit.annotate(List[int], []) 582 | for d in range(k): 583 | dim_size = ((input.size(d + 2) - 1) * stride[d] - 584 | 2 * padding[d] + 585 | (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1) 586 | min_sizes.append(dim_size) 587 | max_sizes.append(min_sizes[d] + stride[d] - 1) 588 | 589 | for i in range(len(output_size)): 590 | size = output_size[i] 591 | min_size = min_sizes[i] 592 | max_size = max_sizes[i] 593 | if size < min_size or size > max_size: 594 | raise ValueError(( 595 | "requested an output size of {}, but valid sizes range " 596 | "from {} to {} (for an input of {})").format( 597 | output_size, min_sizes, max_sizes, input.size()[2:])) 598 | 599 | res = torch.jit.annotate(List[int], []) 600 | for d in range(k): 601 | res.append(output_size[d] - min_sizes[d]) 602 | 603 | ret = res 604 | return ret 605 | 606 | 607 | class ConvTranspose1d(_ConvTransposeNd): 608 | __doc__ = r"""Applies a 1D transposed convolution operator over an input image 609 | composed of several input planes.
610 | 611 | This module can be seen as the gradient of Conv1d with respect to its input. 612 | It is also known as a fractionally-strided convolution or 613 | a deconvolution (although it is not an actual deconvolution operation). 614 | 615 | This module supports :ref:`TensorFloat32`. 616 | 617 | * :attr:`stride` controls the stride for the cross-correlation. 618 | 619 | * :attr:`padding` controls the amount of implicit zero padding on both 620 | sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note 621 | below for details. 622 | 623 | * :attr:`output_padding` controls the additional size added to one side 624 | of the output shape. See note below for details. 625 | 626 | * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. 627 | It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. 628 | 629 | {groups_note} 630 | 631 | Note: 632 | The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` 633 | amount of zero padding to both sides of the input. This is set so that 634 | when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d` 635 | are initialized with the same parameters, they are inverses of each other in 636 | regard to the input and output shapes. However, when ``stride > 1``, 637 | :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output 638 | shape. :attr:`output_padding` is provided to resolve this ambiguity by 639 | effectively increasing the calculated output shape on one side. Note 640 | that :attr:`output_padding` is only used to find output shape, but does 641 | not actually add zero-padding to output. 642 | 643 | Note: 644 | In some circumstances when using the CUDA backend with CuDNN, this operator 645 | may select a nondeterministic algorithm to increase performance. If this is 646 | undesirable, you can try to make the operation deterministic (potentially at 647 | a performance cost) by setting ``torch.backends.cudnn.deterministic = 648 | True``. 649 | Please see the notes on :doc:`/notes/randomness` for background. 650 | 651 | 652 | Args: 653 | in_channels (int): Number of channels in the input image 654 | out_channels (int): Number of channels produced by the convolution 655 | kernel_size (int or tuple): Size of the convolving kernel 656 | stride (int or tuple, optional): Stride of the convolution. Default: 1 657 | padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding 658 | will be added to both sides of the input. Default: 0 659 | output_padding (int or tuple, optional): Additional size added to one side 660 | of the output shape. Default: 0 661 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 662 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 663 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 664 | 665 | 666 | Shape: 667 | - Input: :math:`(N, C_{in}, L_{in})` 668 | - Output: :math:`(N, C_{out}, L_{out})` where 669 | 670 | .. math:: 671 | L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} 672 | \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1 673 | 674 | Attributes: 675 | weight (Tensor): the learnable weights of the module of shape 676 | :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` 677 | :math:`\text{kernel\_size})`.
678 | The values of these weights are sampled from 679 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 680 | :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` 681 | bias (Tensor): the learnable bias of the module of shape (out_channels). 682 | If :attr:`bias` is ``True``, then the values of these weights are 683 | sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 684 | :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` 685 | 686 | .. _cross-correlation: 687 | https://en.wikipedia.org/wiki/Cross-correlation 688 | 689 | .. _link: 690 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 691 | """ 692 | 693 | def __init__( 694 | self, 695 | in_channels: int, 696 | out_channels: int, 697 | kernel_size: _size_1_t, 698 | stride: _size_1_t = 1, 699 | padding: _size_1_t = 0, 700 | output_padding: _size_1_t = 0, 701 | groups: int = 1, 702 | bias: bool = True, 703 | dilation: _size_1_t = 1, 704 | padding_mode: str = 'zeros', 705 | complex_weights = True 706 | ): 707 | kernel_size = _single(kernel_size) 708 | stride = _single(stride) 709 | padding = _single(padding) 710 | dilation = _single(dilation) 711 | output_padding = _single(output_padding) 712 | super(ConvTranspose1d, self).__init__( 713 | in_channels, out_channels, kernel_size, stride, padding, dilation, 714 | True, output_padding, groups, bias, padding_mode, complex_weights) 715 | 716 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 717 | if self.padding_mode != 'zeros': 718 | raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') 719 | 720 | output_padding = self._output_padding( 721 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) 722 | return cF.conv_transpose1d( 723 | input, self.weight, self.bias, self.stride, self.padding, 724 | output_padding, self.groups, self.dilation) 725 | 726 | 727 | class ConvTranspose2d(_ConvTransposeNd): 728 | __doc__ = r"""Applies a 2D transposed convolution operator over an input image 729 | composed of several input planes. 730 | 731 | This module can be seen as the gradient of Conv2d with respect to its input. 732 | It is also known as a fractionally-strided convolution or 733 | a deconvolution (although it is not an actual deconvolution operation). 734 | 735 | This module supports :ref:`TensorFloat32`. 736 | 737 | * :attr:`stride` controls the stride for the cross-correlation. 738 | 739 | * :attr:`padding` controls the amount of implicit zero padding on both 740 | sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note 741 | below for details. 742 | 743 | * :attr:`output_padding` controls the additional size added to one side 744 | of the output shape. See note below for details. 745 | 746 | * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. 747 | It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. 
748 | 749 | {groups_note} 750 | 751 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` 752 | can either be: 753 | 754 | - a single ``int`` -- in which case the same value is used for the height and width dimensions 755 | - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, 756 | and the second `int` for the width dimension 757 | 758 | Note: 759 | The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` 760 | amount of zero padding to both sides of the input. This is set so that 761 | when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d` 762 | are initialized with the same parameters, they are inverses of each other in 763 | regard to the input and output shapes. However, when ``stride > 1``, 764 | :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output 765 | shape. :attr:`output_padding` is provided to resolve this ambiguity by 766 | effectively increasing the calculated output shape on one side. Note 767 | that :attr:`output_padding` is only used to find output shape, but does 768 | not actually add zero-padding to output. 769 | 770 | Note: 771 | {cudnn_reproducibility_note} 772 | 773 | Args: 774 | in_channels (int): Number of channels in the input image 775 | out_channels (int): Number of channels produced by the convolution 776 | kernel_size (int or tuple): Size of the convolving kernel 777 | stride (int or tuple, optional): Stride of the convolution. Default: 1 778 | padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding 779 | will be added to both sides of each dimension in the input. Default: 0 780 | output_padding (int or tuple, optional): Additional size added to one side 781 | of each dimension in the output shape. Default: 0 782 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 783 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 784 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 785 | 786 | 787 | Shape: 788 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 789 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where 790 | 791 | .. math:: 792 | H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] 793 | \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 794 | .. math:: 795 | W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] 796 | \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 797 | 798 | Attributes: 799 | weight (Tensor): the learnable weights of the module of shape 800 | :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` 801 | :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
802 | The values of these weights are sampled from 803 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 804 | :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` 805 | bias (Tensor): the learnable bias of the module of shape (out_channels) 806 | If :attr:`bias` is ``True``, then the values of these weights are 807 | sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 808 | :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` 809 | 810 | Examples:: 811 | 812 | >>> # With square kernels and equal stride 813 | >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2) 814 | >>> # non-square kernels and unequal stride and with padding 815 | >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) 816 | >>> input = torch.randn(20, 16, 50, 100) 817 | >>> output = m(input) 818 | >>> # exact output size can also be specified as an argument 819 | >>> input = torch.randn(1, 16, 12, 12) 820 | >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1) 821 | >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) 822 | >>> h = downsample(input) 823 | >>> h.size() 824 | torch.Size([1, 16, 6, 6]) 825 | >>> output = upsample(h, output_size=input.size()) 826 | >>> output.size() 827 | torch.Size([1, 16, 12, 12]) 828 | 829 | .. _cross-correlation: 830 | https://en.wikipedia.org/wiki/Cross-correlation 831 | 832 | .. _link: 833 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 834 | """ 835 | 836 | def __init__( 837 | self, 838 | in_channels: int, 839 | out_channels: int, 840 | kernel_size: _size_2_t, 841 | stride: _size_2_t = 1, 842 | padding: _size_2_t = 0, 843 | output_padding: _size_2_t = 0, 844 | groups: int = 1, 845 | bias: bool = True, 846 | dilation: int = 1, 847 | padding_mode: str = 'zeros', 848 | complex_weights = True 849 | ): 850 | kernel_size = _pair(kernel_size) 851 | stride = _pair(stride) 852 | padding = _pair(padding) 853 | dilation = _pair(dilation) 854 | output_padding = _pair(output_padding) 855 | super(ConvTranspose2d, self).__init__( 856 | in_channels, out_channels, kernel_size, stride, padding, dilation, 857 | True, output_padding, groups, bias, padding_mode, complex_weights) 858 | 859 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 860 | if self.padding_mode != 'zeros': 861 | raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') 862 | 863 | output_padding = self._output_padding( 864 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) 865 | 866 | return cF.conv_transpose2d( 867 | input, self.weight, self.bias, self.stride, self.padding, 868 | output_padding, self.groups, self.dilation) 869 | 870 | 871 | class ConvTranspose3d(_ConvTransposeNd): 872 | __doc__ = r"""Applies a 3D transposed convolution operator over an input image composed of several input 873 | planes. 874 | The transposed convolution operator multiplies each input value element-wise by a learnable kernel, 875 | and sums over the outputs from all input feature planes. 876 | 877 | This module can be seen as the gradient of Conv3d with respect to its input. 878 | It is also known as a fractionally-strided convolution or 879 | a deconvolution (although it is not an actual deconvolution operation). 880 | 881 | This module supports :ref:`TensorFloat32`. 882 | 883 | * :attr:`stride` controls the stride for the cross-correlation.
884 |
885 | * :attr:`padding` controls the amount of implicit zero padding on both
886 | sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
887 | below for details.
888 |
889 | * :attr:`output_padding` controls the additional size added to one side
890 | of the output shape. See note below for details.
891 |
892 | * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
893 | It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
894 |
895 | {groups_note}
896 |
897 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
898 | can either be:
899 |
900 | - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
901 | - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
902 | the second `int` for the height dimension and the third `int` for the width dimension
903 |
904 | Note:
905 | The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
906 | amount of zero padding to both sides of the input. This is set so that
907 | when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
908 | are initialized with the same parameters, they are inverses of each other in
909 | regard to the input and output shapes. However, when ``stride > 1``,
910 | :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
911 | shape. :attr:`output_padding` is provided to resolve this ambiguity by
912 | effectively increasing the calculated output shape on one side. Note
913 | that :attr:`output_padding` is only used to find output shape, but does
914 | not actually add zero-padding to output.
915 |
916 | Note:
917 | {cudnn_reproducibility_note}
918 |
919 | Args:
920 | in_channels (int): Number of channels in the input image
921 | out_channels (int): Number of channels produced by the convolution
922 | kernel_size (int or tuple): Size of the convolving kernel
923 | stride (int or tuple, optional): Stride of the convolution. Default: 1
924 | padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
925 | will be added to both sides of each dimension in the input. Default: 0
926 | output_padding (int or tuple, optional): Additional size added to one side
927 | of each dimension in the output shape. Default: 0
928 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
929 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
930 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
931 |
932 |
933 | Shape:
934 | - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
935 | - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
936 |
937 | .. math::
938 | D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
939 | \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
940 | .. math::
941 | H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
942 | \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
943 | ..
math:: 944 | W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] 945 | \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1 946 | 947 | 948 | Attributes: 949 | weight (Tensor): the learnable weights of the module of shape 950 | :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` 951 | :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. 952 | The values of these weights are sampled from 953 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 954 | :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` 955 | bias (Tensor): the learnable bias of the module of shape (out_channels) 956 | If :attr:`bias` is ``True``, then the values of these weights are 957 | sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 958 | :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` 959 | 960 | Examples:: 961 | 962 | >>> # With square kernels and equal stride 963 | >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) 964 | >>> # non-square kernels and unequal stride and with padding 965 | >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) 966 | >>> input = torch.randn(20, 16, 10, 50, 100) 967 | >>> output = m(input) 968 | 969 | .. _cross-correlation: 970 | https://en.wikipedia.org/wiki/Cross-correlation 971 | 972 | .. _link: 973 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 974 | """ 975 | 976 | def __init__( 977 | self, 978 | in_channels: int, 979 | out_channels: int, 980 | kernel_size: _size_3_t, 981 | stride: _size_3_t = 1, 982 | padding: _size_3_t = 0, 983 | output_padding: _size_3_t = 0, 984 | groups: int = 1, 985 | bias: bool = True, 986 | dilation: _size_3_t = 1, 987 | padding_mode: str = 'zeros', 988 | complex_weights = True 989 | ): 990 | kernel_size = _triple(kernel_size) 991 | stride = _triple(stride) 992 | padding = _triple(padding) 993 | dilation = _triple(dilation) 994 | output_padding = _triple(output_padding) 995 | super(ConvTranspose3d, self).__init__( 996 | in_channels, out_channels, kernel_size, stride, padding, dilation, 997 | True, output_padding, groups, bias, padding_mode, complex_weights) 998 | 999 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 1000 | if self.padding_mode != 'zeros': 1001 | raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d') 1002 | 1003 | output_padding = self._output_padding( 1004 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) 1005 | 1006 | return cF.conv_transpose3d( 1007 | input, self.weight, self.bias, self.stride, self.padding, 1008 | output_padding, self.groups, self.dilation) -------------------------------------------------------------------------------- /torchcomplex/nn/modules/dropout.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import Module 2 | import torch.nn.functional as F 3 | from .. 
import functional as cF 4 | 5 | from torch import Tensor 6 | 7 | 8 | class _DropoutNd(Module): 9 | __constants__ = ['p', 'inplace'] 10 | p: float 11 | inplace: bool 12 | 13 | def __init__(self, p: float = 0.5, inplace: bool = False) -> None: 14 | super(_DropoutNd, self).__init__() 15 | if p < 0 or p > 1: 16 | raise ValueError("dropout probability has to be between 0 and 1, " 17 | "but got {}".format(p)) 18 | self.p = p 19 | self.inplace = inplace 20 | 21 | def extra_repr(self) -> str: 22 | return 'p={}, inplace={}'.format(self.p, self.inplace) 23 | 24 | 25 | class Dropout(_DropoutNd): 26 | r"""During training, randomly zeroes some of the elements of the input 27 | tensor with probability :attr:`p` using samples from a Bernoulli 28 | distribution. Each channel will be zeroed out independently on every forward 29 | call. 30 | 31 | This has proven to be an effective technique for regularization and 32 | preventing the co-adaptation of neurons as described in the paper 33 | `Improving neural networks by preventing co-adaptation of feature 34 | detectors`_ . 35 | 36 | Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during 37 | training. This means that during evaluation the module simply computes an 38 | identity function. 39 | 40 | Args: 41 | p: probability of an element to be zeroed. Default: 0.5 42 | inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 43 | 44 | Shape: 45 | - Input: :math:`(*)`. Input can be of any shape 46 | - Output: :math:`(*)`. Output is of the same shape as input 47 | 48 | Examples:: 49 | 50 | >>> m = nn.Dropout(p=0.2) 51 | >>> input = torch.randn(20, 16) 52 | >>> output = m(input) 53 | 54 | .. _Improving neural networks by preventing co-adaptation of feature 55 | detectors: https://arxiv.org/abs/1207.0580 56 | """ 57 | 58 | def forward(self, input: Tensor) -> Tensor: 59 | return cF.complex_fcaller(F.dropout, input, self.p, self.training, self.inplace) 60 | 61 | 62 | class Dropout2d(_DropoutNd): 63 | r"""Randomly zero out entire channels (a channel is a 2D feature map, 64 | e.g., the :math:`j`-th channel of the :math:`i`-th sample in the 65 | batched input is a 2D tensor :math:`\text{input}[i, j]`). 66 | Each channel will be zeroed out independently on every forward call with 67 | probability :attr:`p` using samples from a Bernoulli distribution. 68 | 69 | Usually the input comes from :class:`nn.Conv2d` modules. 70 | 71 | As described in the paper 72 | `Efficient Object Localization Using Convolutional Networks`_ , 73 | if adjacent pixels within feature maps are strongly correlated 74 | (as is normally the case in early convolution layers) then i.i.d. dropout 75 | will not regularize the activations and will otherwise just result 76 | in an effective learning rate decrease. 77 | 78 | In this case, :func:`nn.Dropout2d` will help promote independence between 79 | feature maps and should be used instead. 80 | 81 | Args: 82 | p (float, optional): probability of an element to be zero-ed. 83 | inplace (bool, optional): If set to ``True``, will do this operation 84 | in-place 85 | 86 | Shape: 87 | - Input: :math:`(N, C, H, W)` 88 | - Output: :math:`(N, C, H, W)` (same shape as input) 89 | 90 | Examples:: 91 | 92 | >>> m = nn.Dropout2d(p=0.2) 93 | >>> input = torch.randn(20, 16, 32, 32) 94 | >>> output = m(input) 95 | 96 | .. 
_Efficient Object Localization Using Convolutional Networks:
97 | https://arxiv.org/abs/1411.4280
98 | """
99 |
100 | def forward(self, input: Tensor) -> Tensor:
101 | return cF.complex_fcaller(F.dropout2d, input, self.p, self.training, self.inplace)
102 |
103 |
104 | class Dropout3d(_DropoutNd):
105 | r"""Randomly zero out entire channels (a channel is a 3D feature map,
106 | e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
107 | batched input is a 3D tensor :math:`\text{input}[i, j]`).
108 | Each channel will be zeroed out independently on every forward call with
109 | probability :attr:`p` using samples from a Bernoulli distribution.
110 |
111 | Usually the input comes from :class:`nn.Conv3d` modules.
112 |
113 | As described in the paper
114 | `Efficient Object Localization Using Convolutional Networks`_ ,
115 | if adjacent pixels within feature maps are strongly correlated
116 | (as is normally the case in early convolution layers) then i.i.d. dropout
117 | will not regularize the activations and will otherwise just result
118 | in an effective learning rate decrease.
119 |
120 | In this case, :func:`nn.Dropout3d` will help promote independence between
121 | feature maps and should be used instead.
122 |
123 | Args:
124 | p (float, optional): probability of an element to be zeroed.
125 | inplace (bool, optional): If set to ``True``, will do this operation
126 | in-place
127 |
128 | Shape:
129 | - Input: :math:`(N, C, D, H, W)`
130 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
131 |
132 | Examples::
133 |
134 | >>> m = nn.Dropout3d(p=0.2)
135 | >>> input = torch.randn(20, 16, 4, 32, 32)
136 | >>> output = m(input)
137 |
138 | .. _Efficient Object Localization Using Convolutional Networks:
139 | https://arxiv.org/abs/1411.4280
140 | """
141 |
142 | def forward(self, input: Tensor) -> Tensor:
143 | return cF.complex_fcaller(F.dropout3d, input, self.p, self.training, self.inplace)
144 |
145 |
146 | class AlphaDropout(_DropoutNd):
147 | r"""Applies Alpha Dropout over the input.
148 |
149 | Alpha Dropout is a type of Dropout that maintains the self-normalizing
150 | property.
151 | For an input with zero mean and unit standard deviation, the output of
152 | Alpha Dropout maintains the original mean and standard deviation of the
153 | input.
154 | Alpha Dropout goes hand-in-hand with the SELU activation function, which ensures
155 | that the outputs have zero mean and unit standard deviation.
156 |
157 | During training, it randomly masks some of the elements of the input
158 | tensor with probability *p* using samples from a Bernoulli distribution.
159 | The elements to be masked are randomized on every forward call, and scaled
160 | and shifted to maintain zero mean and unit standard deviation.
161 |
162 | During evaluation the module simply computes an identity function.
163 |
164 | More details can be found in the paper `Self-Normalizing Neural Networks`_ .
165 |
166 | Args:
167 | p (float): probability of an element to be dropped. Default: 0.5
168 | inplace (bool, optional): If set to ``True``, will do this operation
169 | in-place
170 |
171 | Shape:
172 | - Input: :math:`(*)`. Input can be of any shape
173 | - Output: :math:`(*)`. Output is of the same shape as input
174 |
175 | Examples::
176 |
177 | >>> m = nn.AlphaDropout(p=0.2)
178 | >>> input = torch.randn(20, 16)
179 | >>> output = m(input)
180 |
181 | .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
182 | """
183 |
184 | def forward(self, input: Tensor) -> Tensor:
185 | return cF.complex_fcaller(F.alpha_dropout, input, self.p, self.training)
186 |
187 |
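# A usage sketch for the dropout family above -- illustrative only (the
# ``torchcomplex.nn`` import alias is an assumption): every forward in this
# file routes the corresponding real torch.nn.functional op through
# cF.complex_fcaller, so complex inputs are passed in directly.
#
#   import torch
#   import torchcomplex.nn as cnn
#
#   drop = cnn.Dropout(p=0.2)
#   x = torch.randn(20, 16, dtype=torch.cfloat)
#   y = drop(x)   # same shape as x; in eval() mode the module is the identity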
188 | class FeatureAlphaDropout(_DropoutNd):
189 | r"""Randomly masks out entire channels (a channel is a feature map,
190 | e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
191 | is a tensor :math:`\text{input}[i, j]`) of the input tensor. Instead of
192 | setting activations to zero, as in regular Dropout, the activations are set
193 | to the negative saturation value of the SELU activation function. More details
194 | can be found in the paper `Self-Normalizing Neural Networks`_ .
195 |
196 | Each element will be masked independently for each sample on every forward
197 | call with probability :attr:`p` using samples from a Bernoulli distribution.
198 | The elements to be masked are randomized on every forward call, and scaled
199 | and shifted to maintain zero mean and unit variance.
200 |
201 | Usually the input comes from :class:`nn.AlphaDropout` modules.
202 |
203 | As described in the paper
204 | `Efficient Object Localization Using Convolutional Networks`_ ,
205 | if adjacent pixels within feature maps are strongly correlated
206 | (as is normally the case in early convolution layers) then i.i.d. dropout
207 | will not regularize the activations and will otherwise just result
208 | in an effective learning rate decrease.
209 |
210 | In this case, :func:`nn.AlphaDropout` will help promote independence between
211 | feature maps and should be used instead.
212 |
213 | Args:
214 | p (float, optional): probability of an element to be zeroed. Default: 0.5
215 | inplace (bool, optional): If set to ``True``, will do this operation
216 | in-place
217 |
218 | Shape:
219 | - Input: :math:`(N, C, D, H, W)`
220 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
221 |
222 | Examples::
223 |
224 | >>> m = nn.FeatureAlphaDropout(p=0.2)
225 | >>> input = torch.randn(20, 16, 4, 32, 32)
226 | >>> output = m(input)
227 |
228 | .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
229 | .. _Efficient Object Localization Using Convolutional Networks:
230 | https://arxiv.org/abs/1411.4280
231 | """
232 |
233 | def forward(self, input: Tensor) -> Tensor:
234 | return cF.complex_fcaller(F.feature_alpha_dropout, input, self.p, self.training)
235 |
--------------------------------------------------------------------------------
/torchcomplex/nn/modules/linear.py:
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn.parameter import Parameter#, UninitializedParameter
6 | from torch.nn import ParameterList
7 | from torch.nn import functional as F
8 | from .. import functional as cF
9 | from torch.nn import init
10 | from torch.nn.modules import Module
11 | # from torch.nn.modules.lazy import LazyModuleMixin
12 | from typing import Optional, List, Tuple, Union
13 |
14 | from torch.nn import Identity # just to use torch's version of Identity, to help with imports
15 |
16 | class Linear(Module):
17 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
18 |
19 | This module supports :ref:`TensorFloat32`.
20 |
21 | Args:
22 | in_features: size of each input sample
23 | out_features: size of each output sample
24 | bias: If set to ``False``, the layer will not learn an additive bias.
25 | Default: ``True`` 26 | 27 | Shape: 28 | - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of 29 | additional dimensions and :math:`H_{in} = \text{in\_features}` 30 | - Output: :math:`(N, *, H_{out})` where all but the last dimension 31 | are the same shape as the input and :math:`H_{out} = \text{out\_features}`. 32 | 33 | Attributes: 34 | weight: the learnable weights of the module of shape 35 | :math:`(\text{out\_features}, \text{in\_features})`. The values are 36 | initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where 37 | :math:`k = \frac{1}{\text{in\_features}}` 38 | bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. 39 | If :attr:`bias` is ``True``, the values are initialized from 40 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 41 | :math:`k = \frac{1}{\text{in\_features}}` 42 | 43 | Examples:: 44 | 45 | >>> m = nn.Linear(20, 30) 46 | >>> input = torch.randn(128, 20) 47 | >>> output = m(input) 48 | >>> print(output.size()) 49 | torch.Size([128, 30]) 50 | """ 51 | __constants__ = ['in_features', 'out_features'] 52 | in_features: int 53 | out_features: int 54 | complex_weights: bool 55 | weight: Union[Tensor, Tuple[Tensor, Tensor]] 56 | bias: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] 57 | 58 | def __init__(self, in_features: int, out_features: int, bias: bool = True, complex_weights: bool = True) -> None: 59 | super(Linear, self).__init__() 60 | self.in_features = in_features 61 | self.out_features = out_features 62 | self.complex_weights = complex_weights 63 | 64 | if complex_weights: 65 | self.weight = Parameter(torch.Tensor(out_features, in_features).to(torch.cfloat)) 66 | else: 67 | weight_real = Parameter(torch.Tensor(out_features, in_features)) 68 | weight_imag = Parameter(torch.Tensor(out_features, in_features)) 69 | self.weight = ParameterList([weight_real, weight_imag]) 70 | 71 | if bias: 72 | if complex_weights: 73 | self.bias = Parameter(torch.Tensor(out_features).to(torch.cfloat)) 74 | else: 75 | bias_real = Parameter(torch.Tensor(out_features)) 76 | bias_imag = Parameter(torch.Tensor(out_features)) 77 | self.bias = ParameterList([bias_real, bias_imag]) 78 | else: 79 | self.register_parameter('bias', None) 80 | self.reset_parameters() 81 | 82 | def _reset_parameters(self, weight, bias) -> None: 83 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 84 | if bias is not None: 85 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight) 86 | bound = 1 / math.sqrt(fan_in) 87 | init.uniform_(bias, -bound, bound) 88 | 89 | def reset_parameters(self) -> None: 90 | if type(self.weight) is ParameterList: 91 | self._reset_parameters(self.weight[0], None if self.bias is None else self.bias[0]) 92 | self._reset_parameters(self.weight[1], None if self.bias is None else self.bias[1]) 93 | else: 94 | self._reset_parameters(self.weight, self.bias) 95 | 96 | def forward(self, input: Tensor) -> Tensor: 97 | return cF.linear(input, self.weight, self.bias) 98 | 99 | def extra_repr(self) -> str: 100 | return 'in_features={}, out_features={}, bias={}'.format( 101 | self.in_features, self.out_features, self.bias is not None 102 | ) 103 | 104 | 105 | # This class exists solely for Transformer; it has an annotation stating 106 | # that bias is never None, which appeases TorchScript 107 | class _LinearWithBias(Linear): 108 | bias: Tensor # type: ignore 109 | 110 | def __init__(self, in_features: int, out_features: int) -> None: 111 | super().__init__(in_features, out_features, bias=True) # type: ignore 112 | 113 | 114 | class 
Bilinear(Module):
115 | r"""Applies a bilinear transformation to the incoming data:
116 | :math:`y = x_1^T A x_2 + b`
117 |
118 | Args:
119 | in1_features: size of each first input sample
120 | in2_features: size of each second input sample
121 | out_features: size of each output sample
122 | bias: If set to ``False``, the layer will not learn an additive bias.
123 | Default: ``True``
124 |
125 | Shape:
126 | - Input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
127 | :math:`*` means any number of additional dimensions. All but the last dimension
128 | of the inputs should be the same.
129 | - Input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
130 | - Output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
131 | and all but the last dimension are the same shape as the input.
132 |
133 | Attributes:
134 | weight: the learnable weights of the module of shape
135 | :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
136 | The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
137 | :math:`k = \frac{1}{\text{in1\_features}}`
138 | bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
139 | If :attr:`bias` is ``True``, the values are initialized from
140 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
141 | :math:`k = \frac{1}{\text{in1\_features}}`
142 |
143 | Examples::
144 |
145 | >>> m = nn.Bilinear(20, 30, 40)
146 | >>> input1 = torch.randn(128, 20)
147 | >>> input2 = torch.randn(128, 30)
148 | >>> output = m(input1, input2)
149 | >>> print(output.size())
150 | torch.Size([128, 40])
151 | """
152 | __constants__ = ['in1_features', 'in2_features', 'out_features']
153 | in1_features: int
154 | in2_features: int
155 | out_features: int
156 | complex_weights: bool
157 | weight: Union[Tensor, Tuple[Tensor, Tensor]]
158 | bias: Optional[Union[Tensor, Tuple[Tensor, Tensor]]]
159 |
160 | def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True, complex_weights: bool = True) -> None:
161 | super(Bilinear, self).__init__()
162 | self.in1_features = in1_features
163 | self.in2_features = in2_features
164 | self.out_features = out_features
165 | self.complex_weights = complex_weights
166 |
167 | if complex_weights:
168 | self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features).to(torch.cfloat))
169 | else:
170 | weight_real = Parameter(torch.Tensor(out_features, in1_features, in2_features))
171 | weight_imag = Parameter(torch.Tensor(out_features, in1_features, in2_features))
172 | self.weight = ParameterList([weight_real, weight_imag])
173 |
174 | if bias:
175 | if complex_weights:
176 | self.bias = Parameter(torch.Tensor(out_features).to(torch.cfloat))
177 | else:
178 | bias_real = Parameter(torch.Tensor(out_features))
179 | bias_imag = Parameter(torch.Tensor(out_features))
180 | self.bias = ParameterList([bias_real, bias_imag])
181 | else:
182 | self.register_parameter('bias', None)
183 | self.reset_parameters()
184 |
185 | def _reset_parameters(self, weight, bias) -> None:
186 | bound = 1 / math.sqrt(weight.size(1))
187 | init.uniform_(weight, -bound, bound)
188 | if bias is not None:
189 | init.uniform_(bias, -bound, bound)
190 |
191 | def reset_parameters(self) -> None:
192 | if type(self.weight) is ParameterList:
193 | self._reset_parameters(self.weight[0], None if self.bias is None else self.bias[0])
194 | self._reset_parameters(self.weight[1], None if self.bias is None else self.bias[1])
195 | else:
196 | self._reset_parameters(self.weight, self.bias)
197 |
198 | def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
199 | return cF.bilinear(input1, input2, self.weight, self.bias)
200 |
201 | def extra_repr(self) -> str:
202 | return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
203 | self.in1_features, self.in2_features, self.out_features, self.bias is not None
204 | )
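# A sketch of the two weight layouts offered by Linear/Bilinear above --
# illustrative only; the ``torchcomplex.nn`` import alias is an assumption:
#
#   import torch
#   import torchcomplex.nn as cnn
#
#   fc = cnn.Linear(20, 30)                            # complex_weights=True (default)
#   print(fc.weight.dtype)                             # torch.complex64, one Parameter
#
#   fc2 = cnn.Linear(20, 30, complex_weights=False)    # real/imag parts kept apart
#   print(type(fc2.weight).__name__)                   # 'ParameterList' of [real, imag]
#
#   x = torch.randn(128, 20, dtype=torch.cfloat)
#   print(fc(x).size())                                # torch.Size([128, 30])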
--------------------------------------------------------------------------------
/torchcomplex/nn/modules/pooling.py:
1 | from typing import List, Optional
2 |
3 | from torch import Tensor
4 | from torch.nn.modules import Module
5 | from torch.nn.modules.utils import _single, _pair, _triple
6 | import torch.nn.functional as F
7 | from .. import functional as cF
8 |
9 | from torch.nn.common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t, _ratio_3_t, _ratio_2_t
10 |
11 |
12 | class _MaxPoolNd(Module):
13 | __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
14 | 'return_indices', 'ceil_mode']
15 | return_indices: bool
16 | ceil_mode: bool
17 |
18 | def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,
19 | padding: _size_any_t = 0, dilation: _size_any_t = 1,
20 | return_indices: bool = False, ceil_mode: bool = False) -> None:
21 | super(_MaxPoolNd, self).__init__()
22 | self.kernel_size = kernel_size
23 | self.stride = stride if (stride is not None) else kernel_size
24 | self.padding = padding
25 | self.dilation = dilation
26 | self.return_indices = return_indices
27 | self.ceil_mode = ceil_mode
28 |
29 | def extra_repr(self) -> str:
30 | return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
31 | ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
32 |
33 |
34 | class MaxPool1d(_MaxPoolNd):
35 | r"""Applies a 1D max pooling over an input signal composed of several input
36 | planes.
37 |
38 | In the simplest case, the output value of the layer with input size :math:`(N, C, L)`
39 | and output :math:`(N, C, L_{out})` can be precisely described as:
40 |
41 | .. math::
42 | out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
43 | input(N_i, C_j, stride \times k + m)
44 |
45 | If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
46 | for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
47 | sliding window. This `link`_ has a nice visualization of the pooling parameters.
48 |
49 | Note:
50 | When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
51 | or the input. Sliding windows that would start in the right padded region are ignored.
52 |
53 | Args:
54 | kernel_size: The size of the sliding window, must be > 0.
55 | stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
56 | padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
57 | dilation: The stride between elements within a sliding window, must be > 0.
58 | return_indices: If ``True``, will return the argmax along with the max values.
59 | Useful for :class:`torch.nn.MaxUnpool1d` later
60 | ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
61 | ensures that every element in the input tensor is covered by a sliding window.
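Note:
    This complex-valued version routes the real ``torch.nn.functional.max_pool1d``
    through ``complex_fcaller`` from ``torchcomplex.nn.functional`` (see
    ``forward`` below), so the module is called on complex inputs directly.
    A sketch -- it differs from the example further down only by the dtype,
    with ``nn`` denoting this package's ``nn`` (an assumption of the sketch):

    >>> m = nn.MaxPool1d(3, stride=2)
    >>> input = torch.randn(20, 16, 50, dtype=torch.cfloat)
    >>> output = m(input)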
62 | 63 | Shape: 64 | - Input: :math:`(N, C, L_{in})` 65 | - Output: :math:`(N, C, L_{out})`, where 66 | 67 | .. math:: 68 | L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} 69 | \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor 70 | 71 | Examples:: 72 | 73 | >>> # pool of size=3, stride=2 74 | >>> m = nn.MaxPool1d(3, stride=2) 75 | >>> input = torch.randn(20, 16, 50) 76 | >>> output = m(input) 77 | 78 | .. _link: 79 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 80 | """ 81 | 82 | kernel_size: _size_1_t 83 | stride: _size_1_t 84 | padding: _size_1_t 85 | dilation: _size_1_t 86 | 87 | def forward(self, input: Tensor) -> Tensor: 88 | return cF.complex_fcaller(F.max_pool1d, input, self.kernel_size, self.stride, 89 | self.padding, self.dilation, self.ceil_mode, 90 | self.return_indices) 91 | 92 | 93 | class MaxPool2d(_MaxPoolNd): 94 | r"""Applies a 2D max pooling over an input signal composed of several input 95 | planes. 96 | 97 | In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, 98 | output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` 99 | can be precisely described as: 100 | 101 | .. math:: 102 | \begin{aligned} 103 | out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ 104 | & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, 105 | \text{stride[1]} \times w + n) 106 | \end{aligned} 107 | 108 | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides 109 | for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. 110 | It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. 111 | 112 | Note: 113 | When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding 114 | or the input. Sliding windows that would start in the right padded region are ignored. 115 | 116 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: 117 | 118 | - a single ``int`` -- in which case the same value is used for the height and width dimension 119 | - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, 120 | and the second `int` for the width dimension 121 | 122 | Args: 123 | kernel_size: the size of the window to take a max over 124 | stride: the stride of the window. Default value is :attr:`kernel_size` 125 | padding: implicit zero padding to be added on both sides 126 | dilation: a parameter that controls the stride of elements in the window 127 | return_indices: if ``True``, will return the max indices along with the outputs. 128 | Useful for :class:`torch.nn.MaxUnpool2d` later 129 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 130 | 131 | Shape: 132 | - Input: :math:`(N, C, H_{in}, W_{in})` 133 | - Output: :math:`(N, C, H_{out}, W_{out})`, where 134 | 135 | .. math:: 136 | H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]} 137 | \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor 138 | 139 | .. 
math:: 140 | W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]} 141 | \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor 142 | 143 | Examples:: 144 | 145 | >>> # pool of square window of size=3, stride=2 146 | >>> m = nn.MaxPool2d(3, stride=2) 147 | >>> # pool of non-square window 148 | >>> m = nn.MaxPool2d((3, 2), stride=(2, 1)) 149 | >>> input = torch.randn(20, 16, 50, 32) 150 | >>> output = m(input) 151 | 152 | .. _link: 153 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 154 | """ 155 | 156 | kernel_size: _size_2_t 157 | stride: _size_2_t 158 | padding: _size_2_t 159 | dilation: _size_2_t 160 | 161 | def forward(self, input: Tensor) -> Tensor: 162 | return cF.complex_fcaller(F.max_pool2d, input, self.kernel_size, self.stride, 163 | self.padding, self.dilation, self.ceil_mode, 164 | self.return_indices) 165 | 166 | 167 | class MaxPool3d(_MaxPoolNd): 168 | r"""Applies a 3D max pooling over an input signal composed of several input 169 | planes. 170 | 171 | In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, 172 | output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` 173 | can be precisely described as: 174 | 175 | .. math:: 176 | \begin{aligned} 177 | \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ 178 | & \text{input}(N_i, C_j, \text{stride[0]} \times d + k, 179 | \text{stride[1]} \times h + m, \text{stride[2]} \times w + n) 180 | \end{aligned} 181 | 182 | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides 183 | for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. 184 | It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. 185 | 186 | Note: 187 | When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding 188 | or the input. Sliding windows that would start in the right padded region are ignored. 189 | 190 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: 191 | 192 | - a single ``int`` -- in which case the same value is used for the depth, height and width dimension 193 | - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, 194 | the second `int` for the height dimension and the third `int` for the width dimension 195 | 196 | Args: 197 | kernel_size: the size of the window to take a max over 198 | stride: the stride of the window. Default value is :attr:`kernel_size` 199 | padding: implicit zero padding to be added on all three sides 200 | dilation: a parameter that controls the stride of elements in the window 201 | return_indices: if ``True``, will return the max indices along with the outputs. 202 | Useful for :class:`torch.nn.MaxUnpool3d` later 203 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 204 | 205 | Shape: 206 | - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` 207 | - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where 208 | 209 | .. math:: 210 | D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times 211 | (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor 212 | 213 | .. 
math:: 214 | H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times 215 | (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor 216 | 217 | .. math:: 218 | W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times 219 | (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor 220 | 221 | Examples:: 222 | 223 | >>> # pool of square window of size=3, stride=2 224 | >>> m = nn.MaxPool3d(3, stride=2) 225 | >>> # pool of non-square window 226 | >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) 227 | >>> input = torch.randn(20, 16, 50,44, 31) 228 | >>> output = m(input) 229 | 230 | .. _link: 231 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 232 | """ # noqa: E501 233 | 234 | kernel_size: _size_3_t 235 | stride: _size_3_t 236 | padding: _size_3_t 237 | dilation: _size_3_t 238 | 239 | def forward(self, input: Tensor) -> Tensor: 240 | return cF.complex_fcaller(F.max_pool3d, input, self.kernel_size, self.stride, 241 | self.padding, self.dilation, self.ceil_mode, 242 | self.return_indices) 243 | 244 | 245 | class _MaxUnpoolNd(Module): 246 | 247 | def extra_repr(self) -> str: 248 | return 'kernel_size={}, stride={}, padding={}'.format( 249 | self.kernel_size, self.stride, self.padding 250 | ) 251 | 252 | 253 | class MaxUnpool1d(_MaxUnpoolNd): 254 | r"""Computes a partial inverse of :class:`MaxPool1d`. 255 | 256 | :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost. 257 | 258 | :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d` 259 | including the indices of the maximal values and computes a partial inverse 260 | in which all non-maximal values are set to zero. 261 | 262 | .. note:: :class:`MaxPool1d` can map several input sizes to the same output 263 | sizes. Hence, the inversion process can get ambiguous. 264 | To accommodate this, you can provide the needed output size 265 | as an additional argument :attr:`output_size` in the forward call. 266 | See the Inputs and Example below. 267 | 268 | Args: 269 | kernel_size (int or tuple): Size of the max pooling window. 270 | stride (int or tuple): Stride of the max pooling window. 271 | It is set to :attr:`kernel_size` by default. 272 | padding (int or tuple): Padding that was added to the input 273 | 274 | Inputs: 275 | - `input`: the input Tensor to invert 276 | - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d` 277 | - `output_size` (optional): the targeted output size 278 | 279 | Shape: 280 | - Input: :math:`(N, C, H_{in})` 281 | - Output: :math:`(N, C, H_{out})`, where 282 | 283 | .. 
math:: 284 | H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0] 285 | 286 | or as given by :attr:`output_size` in the call operator 287 | 288 | Example:: 289 | 290 | >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True) 291 | >>> unpool = nn.MaxUnpool1d(2, stride=2) 292 | >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]]) 293 | >>> output, indices = pool(input) 294 | >>> unpool(output, indices) 295 | tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]]) 296 | 297 | >>> # Example showcasing the use of output_size 298 | >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]]) 299 | >>> output, indices = pool(input) 300 | >>> unpool(output, indices, output_size=input.size()) 301 | tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]]) 302 | 303 | >>> unpool(output, indices) 304 | tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]]) 305 | """ 306 | 307 | kernel_size: _size_1_t 308 | stride: _size_1_t 309 | padding: _size_1_t 310 | 311 | def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0) -> None: 312 | super(MaxUnpool1d, self).__init__() 313 | self.kernel_size = _single(kernel_size) 314 | self.stride = _single(stride if (stride is not None) else kernel_size) 315 | self.padding = _single(padding) 316 | 317 | def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 318 | return cF.complex_fcaller(F.max_unpool1d, input, indices, self.kernel_size, self.stride, 319 | self.padding, output_size) 320 | 321 | 322 | class MaxUnpool2d(_MaxUnpoolNd): 323 | r"""Computes a partial inverse of :class:`MaxPool2d`. 324 | 325 | :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost. 326 | 327 | :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d` 328 | including the indices of the maximal values and computes a partial inverse 329 | in which all non-maximal values are set to zero. 330 | 331 | .. note:: :class:`MaxPool2d` can map several input sizes to the same output 332 | sizes. Hence, the inversion process can get ambiguous. 333 | To accommodate this, you can provide the needed output size 334 | as an additional argument :attr:`output_size` in the forward call. 335 | See the Inputs and Example below. 336 | 337 | Args: 338 | kernel_size (int or tuple): Size of the max pooling window. 339 | stride (int or tuple): Stride of the max pooling window. 340 | It is set to :attr:`kernel_size` by default. 341 | padding (int or tuple): Padding that was added to the input 342 | 343 | Inputs: 344 | - `input`: the input Tensor to invert 345 | - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d` 346 | - `output_size` (optional): the targeted output size 347 | 348 | Shape: 349 | - Input: :math:`(N, C, H_{in}, W_{in})` 350 | - Output: :math:`(N, C, H_{out}, W_{out})`, where 351 | 352 | .. math:: 353 | H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} 354 | 355 | .. 
math:: 356 | W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} 357 | 358 | or as given by :attr:`output_size` in the call operator 359 | 360 | Example:: 361 | 362 | >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True) 363 | >>> unpool = nn.MaxUnpool2d(2, stride=2) 364 | >>> input = torch.tensor([[[[ 1., 2, 3, 4], 365 | [ 5, 6, 7, 8], 366 | [ 9, 10, 11, 12], 367 | [13, 14, 15, 16]]]]) 368 | >>> output, indices = pool(input) 369 | >>> unpool(output, indices) 370 | tensor([[[[ 0., 0., 0., 0.], 371 | [ 0., 6., 0., 8.], 372 | [ 0., 0., 0., 0.], 373 | [ 0., 14., 0., 16.]]]]) 374 | 375 | >>> # specify a different output size than input size 376 | >>> unpool(output, indices, output_size=torch.Size([1, 1, 5, 5])) 377 | tensor([[[[ 0., 0., 0., 0., 0.], 378 | [ 6., 0., 8., 0., 0.], 379 | [ 0., 0., 0., 14., 0.], 380 | [ 16., 0., 0., 0., 0.], 381 | [ 0., 0., 0., 0., 0.]]]]) 382 | """ 383 | 384 | kernel_size: _size_2_t 385 | stride: _size_2_t 386 | padding: _size_2_t 387 | 388 | def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0) -> None: 389 | super(MaxUnpool2d, self).__init__() 390 | self.kernel_size = _pair(kernel_size) 391 | self.stride = _pair(stride if (stride is not None) else kernel_size) 392 | self.padding = _pair(padding) 393 | 394 | def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 395 | return cF.complex_fcaller(F.max_unpool2d, input, indices, self.kernel_size, self.stride, 396 | self.padding, output_size) 397 | 398 | 399 | class MaxUnpool3d(_MaxUnpoolNd): 400 | r"""Computes a partial inverse of :class:`MaxPool3d`. 401 | 402 | :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost. 403 | :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d` 404 | including the indices of the maximal values and computes a partial inverse 405 | in which all non-maximal values are set to zero. 406 | 407 | .. note:: :class:`MaxPool3d` can map several input sizes to the same output 408 | sizes. Hence, the inversion process can get ambiguous. 409 | To accommodate this, you can provide the needed output size 410 | as an additional argument :attr:`output_size` in the forward call. 411 | See the Inputs section below. 412 | 413 | Args: 414 | kernel_size (int or tuple): Size of the max pooling window. 415 | stride (int or tuple): Stride of the max pooling window. 416 | It is set to :attr:`kernel_size` by default. 417 | padding (int or tuple): Padding that was added to the input 418 | 419 | Inputs: 420 | - `input`: the input Tensor to invert 421 | - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d` 422 | - `output_size` (optional): the targeted output size 423 | 424 | Shape: 425 | - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` 426 | - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where 427 | 428 | .. math:: 429 | D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} 430 | 431 | .. math:: 432 | H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} 433 | 434 | .. 
math:: 435 | W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} 436 | 437 | or as given by :attr:`output_size` in the call operator 438 | 439 | Example:: 440 | 441 | >>> # pool of square window of size=3, stride=2 442 | >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True) 443 | >>> unpool = nn.MaxUnpool3d(3, stride=2) 444 | >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15)) 445 | >>> unpooled_output = unpool(output, indices) 446 | >>> unpooled_output.size() 447 | torch.Size([20, 16, 51, 33, 15]) 448 | """ 449 | 450 | kernel_size: _size_3_t 451 | stride: _size_3_t 452 | padding: _size_3_t 453 | 454 | def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0) -> None: 455 | super(MaxUnpool3d, self).__init__() 456 | self.kernel_size = _triple(kernel_size) 457 | self.stride = _triple(stride if (stride is not None) else kernel_size) 458 | self.padding = _triple(padding) 459 | 460 | def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 461 | return cF.complex_fcaller(F.max_unpool3d, input, indices, self.kernel_size, self.stride, 462 | self.padding, output_size) 463 | 464 | 465 | class _AvgPoolNd(Module): 466 | __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] 467 | 468 | def extra_repr(self) -> str: 469 | return 'kernel_size={}, stride={}, padding={}'.format( 470 | self.kernel_size, self.stride, self.padding 471 | ) 472 | 473 | 474 | class AvgPool1d(_AvgPoolNd): 475 | r"""Applies a 1D average pooling over an input signal composed of several 476 | input planes. 477 | 478 | In the simplest case, the output value of the layer with input size :math:`(N, C, L)`, 479 | output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k` 480 | can be precisely described as: 481 | 482 | .. math:: 483 | 484 | \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1} 485 | \text{input}(N_i, C_j, \text{stride} \times l + m) 486 | 487 | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides 488 | for :attr:`padding` number of points. 489 | 490 | Note: 491 | When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding 492 | or the input. Sliding windows that would start in the right padded region are ignored. 493 | 494 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be 495 | an ``int`` or a one-element tuple. 496 | 497 | Args: 498 | kernel_size: the size of the window 499 | stride: the stride of the window. Default value is :attr:`kernel_size` 500 | padding: implicit zero padding to be added on both sides 501 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 502 | count_include_pad: when True, will include the zero-padding in the averaging calculation 503 | 504 | Shape: 505 | - Input: :math:`(N, C, L_{in})` 506 | - Output: :math:`(N, C, L_{out})`, where 507 | 508 | .. 
math:: 509 | L_{out} = \left\lfloor \frac{L_{in} + 510 | 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor 511 | 512 | Examples:: 513 | 514 | >>> # pool with window of size=3, stride=2 515 | >>> m = nn.AvgPool1d(3, stride=2) 516 | >>> m(torch.tensor([[[1.,2,3,4,5,6,7]]])) 517 | tensor([[[ 2., 4., 6.]]]) 518 | """ 519 | 520 | kernel_size: _size_1_t 521 | stride: _size_1_t 522 | padding: _size_1_t 523 | ceil_mode: bool 524 | count_include_pad: bool 525 | 526 | def __init__(self, kernel_size: _size_1_t, stride: _size_1_t = None, padding: _size_1_t = 0, ceil_mode: bool = False, 527 | count_include_pad: bool = True) -> None: 528 | super(AvgPool1d, self).__init__() 529 | self.kernel_size = _single(kernel_size) 530 | self.stride = _single(stride if stride is not None else kernel_size) 531 | self.padding = _single(padding) 532 | self.ceil_mode = ceil_mode 533 | self.count_include_pad = count_include_pad 534 | 535 | def forward(self, input: Tensor) -> Tensor: 536 | return cF.complex_fcaller(F.avg_pool1d, input, self.kernel_size, self.stride, self.padding, self.ceil_mode, 537 | self.count_include_pad) 538 | 539 | 540 | class AvgPool2d(_AvgPoolNd): 541 | r"""Applies a 2D average pooling over an input signal composed of several input 542 | planes. 543 | 544 | In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, 545 | output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` 546 | can be precisely described as: 547 | 548 | .. math:: 549 | 550 | out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} 551 | input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) 552 | 553 | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides 554 | for :attr:`padding` number of points. 555 | 556 | Note: 557 | When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding 558 | or the input. Sliding windows that would start in the right padded region are ignored. 559 | 560 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: 561 | 562 | - a single ``int`` -- in which case the same value is used for the height and width dimension 563 | - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, 564 | and the second `int` for the width dimension 565 | 566 | Args: 567 | kernel_size: the size of the window 568 | stride: the stride of the window. Default value is :attr:`kernel_size` 569 | padding: implicit zero padding to be added on both sides 570 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 571 | count_include_pad: when True, will include the zero-padding in the averaging calculation 572 | divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used 573 | 574 | Shape: 575 | - Input: :math:`(N, C, H_{in}, W_{in})` 576 | - Output: :math:`(N, C, H_{out}, W_{out})`, where 577 | 578 | .. math:: 579 | H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - 580 | \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor 581 | 582 | .. 
math::
583 | W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
584 | \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
585 |
586 | Examples::
587 |
588 | >>> # pool of square window of size=3, stride=2
589 | >>> m = nn.AvgPool2d(3, stride=2)
590 | >>> # pool of non-square window
591 | >>> m = nn.AvgPool2d((3, 2), stride=(2, 1))
592 | >>> input = torch.randn(20, 16, 50, 32)
593 | >>> output = m(input)
594 | """
595 | __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override']
596 |
597 | kernel_size: _size_2_t
598 | stride: _size_2_t
599 | padding: _size_2_t
600 | ceil_mode: bool
601 | count_include_pad: bool
602 |
603 | def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0,
604 | ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None:
605 | super(AvgPool2d, self).__init__()
606 | self.kernel_size = kernel_size
607 | self.stride = stride if (stride is not None) else kernel_size
608 | self.padding = padding
609 | self.ceil_mode = ceil_mode
610 | self.count_include_pad = count_include_pad
611 | self.divisor_override = divisor_override
612 |
613 | def forward(self, input: Tensor) -> Tensor:
614 | return cF.complex_fcaller(F.avg_pool2d, input, self.kernel_size, self.stride,
615 | self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override)
616 |
617 |
618 | class AvgPool3d(_AvgPoolNd):
619 | r"""Applies a 3D average pooling over an input signal composed of several input
620 | planes.
621 |
622 | In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
623 | output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
624 | can be precisely described as:
625 |
626 | .. math::
627 | \begin{aligned}
628 | \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
629 | & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k,
630 | \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)}
631 | {kD \times kH \times kW}
632 | \end{aligned}
633 |
634 | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
635 | for :attr:`padding` number of points.
636 |
637 | Note:
638 | When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
639 | or the input. Sliding windows that would start in the right padded region are ignored.
640 |
641 | The parameters :attr:`kernel_size`, :attr:`stride` can either be:
642 |
643 | - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
644 | - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
645 | the second `int` for the height dimension and the third `int` for the width dimension
646 |
647 | Args:
648 | kernel_size: the size of the window
649 | stride: the stride of the window.
Default value is :attr:`kernel_size` 650 | padding: implicit zero padding to be added on all three sides 651 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 652 | count_include_pad: when True, will include the zero-padding in the averaging calculation 653 | divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used 654 | 655 | Shape: 656 | - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` 657 | - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where 658 | 659 | .. math:: 660 | D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - 661 | \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor 662 | 663 | .. math:: 664 | H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - 665 | \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor 666 | 667 | .. math:: 668 | W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - 669 | \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor 670 | 671 | Examples:: 672 | 673 | >>> # pool of square window of size=3, stride=2 674 | >>> m = nn.AvgPool3d(3, stride=2) 675 | >>> # pool of non-square window 676 | >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2)) 677 | >>> input = torch.randn(20, 16, 50,44, 31) 678 | >>> output = m(input) 679 | """ 680 | __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] 681 | 682 | kernel_size: _size_3_t 683 | stride: _size_3_t 684 | padding: _size_3_t 685 | ceil_mode: bool 686 | count_include_pad: bool 687 | 688 | def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, 689 | ceil_mode: bool = False, count_include_pad: bool = True, divisor_override=None) -> None: 690 | super(AvgPool3d, self).__init__() 691 | self.kernel_size = kernel_size 692 | self.stride = stride if (stride is not None) else kernel_size 693 | self.padding = padding 694 | self.ceil_mode = ceil_mode 695 | self.count_include_pad = count_include_pad 696 | self.divisor_override = divisor_override 697 | 698 | def forward(self, input: Tensor) -> Tensor: 699 | return cF.complex_fcaller(F.avg_pool3d, input, self.kernel_size, self.stride, 700 | self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override) 701 | 702 | def __setstate__(self, d): 703 | super(AvgPool3d, self).__setstate__(d) 704 | self.__dict__.setdefault('padding', 0) 705 | self.__dict__.setdefault('ceil_mode', False) 706 | self.__dict__.setdefault('count_include_pad', True) 707 | 708 | 709 | class FractionalMaxPool2d(Module): 710 | r"""Applies a 2D fractional max pooling over an input signal composed of several input planes. 711 | 712 | Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham 713 | 714 | The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic 715 | step size determined by the target output size. 716 | The number of output features is equal to the number of input planes. 717 | 718 | Args: 719 | kernel_size: the size of the window to take a max over. 720 | Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)` 721 | output_size: the target output size of the image of the form `oH x oW`. 722 | Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH` 723 | output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. 
724 | This has to be a number or tuple in the range (0, 1)
725 | return_indices: if ``True``, will return the indices along with the outputs.
726 | Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``
727 |
728 | Examples:
729 | >>> # pool of square window of size=3, and target output size 13x12
730 | >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
731 | >>> # pool of square window and target output size being half of input image size
732 | >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
733 | >>> input = torch.randn(20, 16, 50, 32)
734 | >>> output = m(input)
735 |
736 | .. _Fractional MaxPooling:
737 | https://arxiv.org/abs/1412.6071
738 | """
739 | __constants__ = ['kernel_size', 'return_indices', 'output_size',
740 | 'output_ratio']
741 |
742 | kernel_size: _size_2_t
743 | return_indices: bool
744 | output_size: _size_2_t
745 | output_ratio: _ratio_2_t
746 |
747 | def __init__(self, kernel_size: _size_2_t, output_size: Optional[_size_2_t] = None,
748 | output_ratio: Optional[_ratio_2_t] = None,
749 | return_indices: bool = False, _random_samples=None) -> None:
750 | super(FractionalMaxPool2d, self).__init__()
751 | self.kernel_size = _pair(kernel_size)
752 | self.return_indices = return_indices
753 | self.register_buffer('_random_samples', _random_samples)
754 | self.output_size = _pair(output_size) if output_size is not None else None
755 | self.output_ratio = _pair(output_ratio) if output_ratio is not None else None
756 | if output_size is None and output_ratio is None:
757 | raise ValueError("FractionalMaxPool2d requires specifying either "
758 | "an output size, or a pooling ratio")
759 | if output_size is not None and output_ratio is not None:
760 | raise ValueError("only one of output_size and output_ratio may be specified")
761 | if self.output_ratio is not None:
762 | if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
763 | raise ValueError("output_ratio must be between 0 and 1 (got {})"
764 | .format(output_ratio))
765 |
766 | def forward(self, input: Tensor) -> Tensor:
767 | return cF.complex_fcaller(F.fractional_max_pool2d, input, self.kernel_size, self.output_size, self.output_ratio,
768 | self.return_indices, _random_samples=self._random_samples)
769 |
770 |
771 | class FractionalMaxPool3d(Module):
772 | r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
773 |
774 | Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
775 |
776 | The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
777 | step size determined by the target output size.
778 | The number of output features is equal to the number of input planes.
779 |
780 | Args:
781 | kernel_size: the size of the window to take a max over.
782 | Can be a single number k (for a cubic kernel of k x k x k) or a tuple `(kT, kH, kW)`
783 | output_size: the target output size of the image of the form `oT x oH x oW`.
784 | Can be a tuple `(oT, oH, oW)` or a single number oH for a cubic image `oH x oH x oH`
785 | output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
786 | This has to be a number or tuple in the range (0, 1)
787 | return_indices: if ``True``, will return the indices along with the outputs.
788 | Useful to pass to :meth:`nn.MaxUnpool3d`.
Default: ``False`` 789 | 790 | Examples: 791 | >>> # pool of cubic window of size=3, and target output size 13x12x11 792 | >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11)) 793 | >>> # pool of cubic window and target output size being half of input size 794 | >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)) 795 | >>> input = torch.randn(20, 16, 50, 32, 16) 796 | >>> output = m(input) 797 | 798 | .. _Fractional MaxPooling: 799 | https://arxiv.org/abs/1412.6071 800 | """ 801 | __constants__ = ['kernel_size', 'return_indices', 'output_size', 802 | 'output_ratio'] 803 | kernel_size: _size_3_t 804 | return_indices: bool 805 | output_size: _size_3_t 806 | output_ratio: _ratio_3_t 807 | 808 | def __init__(self, kernel_size: _size_3_t, output_size: Optional[_size_3_t] = None, 809 | output_ratio: Optional[_ratio_3_t] = None, 810 | return_indices: bool = False, _random_samples=None) -> None: 811 | super(FractionalMaxPool3d, self).__init__() 812 | self.kernel_size = _triple(kernel_size) 813 | self.return_indices = return_indices 814 | self.register_buffer('_random_samples', _random_samples) 815 | self.output_size = _triple(output_size) if output_size is not None else None 816 | self.output_ratio = _triple(output_ratio) if output_ratio is not None else None 817 | if output_size is None and output_ratio is None: 818 | raise ValueError("FractionalMaxPool3d requires specifying either " 819 | "an output size, or a pooling ratio") 820 | if output_size is not None and output_ratio is not None: 821 | raise ValueError("only one of output_size and output_ratio may be specified") 822 | if self.output_ratio is not None: 823 | if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1 and 0 < self.output_ratio[2] < 1): 824 | raise ValueError("output_ratio must be between 0 and 1 (got {})" 825 | .format(output_ratio)) 826 | 827 | def forward(self, input: Tensor) -> Tensor: 828 | return cF.complex_fcaller(F.fractional_max_pool3d, 829 | input, self.kernel_size, self.output_size, self.output_ratio, 830 | self.return_indices, 831 | _random_samples=self._random_samples) 832 | 833 | 834 | class _LPPoolNd(Module): 835 | __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode'] 836 | 837 | norm_type: float 838 | ceil_mode: bool 839 | 840 | def __init__(self, norm_type: float, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None, 841 | ceil_mode: bool = False) -> None: 842 | super(_LPPoolNd, self).__init__() 843 | self.norm_type = norm_type 844 | self.kernel_size = kernel_size 845 | self.stride = stride 846 | self.ceil_mode = ceil_mode 847 | 848 | def extra_repr(self) -> str: 849 | return 'norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, ' \ 850 | 'ceil_mode={ceil_mode}'.format(**self.__dict__) 851 | 852 | 853 | class LPPool1d(_LPPoolNd): 854 | r"""Applies a 1D power-average pooling over an input signal composed of several input 855 | planes. 856 | 857 | On each window, the function computed is: 858 | 859 | .. math:: 860 | f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} 861 | 862 | - At p = :math:`\infty`, one gets Max Pooling 863 | - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling) 864 | 865 | .. note:: If the sum to the power of `p` is zero, the gradient of this function is 866 | not defined. This implementation will set the gradient to zero in this case. 867 | 868 | Args: 869 | kernel_size: a single int, the size of the window 870 | stride: a single int, the stride of the window. 
Default value is :attr:`kernel_size` 871 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 872 | 873 | Shape: 874 | - Input: :math:`(N, C, L_{in})` 875 | - Output: :math:`(N, C, L_{out})`, where 876 | 877 | .. math:: 878 | L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor 879 | 880 | Examples:: 881 | >>> # power-2 pool of window of length 3, with stride 2. 882 | >>> m = nn.LPPool1d(2, 3, stride=2) 883 | >>> input = torch.randn(20, 16, 50) 884 | >>> output = m(input) 885 | """ 886 | 887 | kernel_size: _size_1_t 888 | stride: _size_1_t 889 | 890 | def forward(self, input: Tensor) -> Tensor: 891 | return cF.complex_fcaller(F.lp_pool1d, input, float(self.norm_type), self.kernel_size, 892 | self.stride, self.ceil_mode) 893 | 894 | 895 | class LPPool2d(_LPPoolNd): 896 | r"""Applies a 2D power-average pooling over an input signal composed of several input 897 | planes. 898 | 899 | On each window, the function computed is: 900 | 901 | .. math:: 902 | f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} 903 | 904 | - At p = :math:`\infty`, one gets Max Pooling 905 | - At p = 1, one gets Sum Pooling (which is proportional to average pooling) 906 | 907 | The parameters :attr:`kernel_size`, :attr:`stride` can either be: 908 | 909 | - a single ``int`` -- in which case the same value is used for the height and width dimension 910 | - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, 911 | and the second `int` for the width dimension 912 | 913 | .. note:: If the sum to the power of `p` is zero, the gradient of this function is 914 | not defined. This implementation will set the gradient to zero in this case. 915 | 916 | Args: 917 | kernel_size: the size of the window 918 | stride: the stride of the window. Default value is :attr:`kernel_size` 919 | ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape 920 | 921 | Shape: 922 | - Input: :math:`(N, C, H_{in}, W_{in})` 923 | - Output: :math:`(N, C, H_{out}, W_{out})`, where 924 | 925 | .. math:: 926 | H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor 927 | 928 | .. math:: 929 | W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor 930 | 931 | Examples:: 932 | 933 | >>> # power-2 pool of square window of size=3, stride=2 934 | >>> m = nn.LPPool2d(2, 3, stride=2) 935 | >>> # pool of non-square window of power 1.2 936 | >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1)) 937 | >>> input = torch.randn(20, 16, 50, 32) 938 | >>> output = m(input) 939 | 940 | """ 941 | 942 | kernel_size: _size_2_t 943 | stride: _size_2_t 944 | 945 | def forward(self, input: Tensor) -> Tensor: 946 | return cF.complex_fcaller(F.lp_pool2d, input, float(self.norm_type), self.kernel_size, 947 | self.stride, self.ceil_mode) 948 | 949 | 950 | class _AdaptiveMaxPoolNd(Module): 951 | __constants__ = ['output_size', 'return_indices'] 952 | return_indices: bool 953 | 954 | def __init__(self, output_size: _size_any_t, return_indices: bool = False) -> None: 955 | super(_AdaptiveMaxPoolNd, self).__init__() 956 | self.output_size = output_size 957 | self.return_indices = return_indices 958 | 959 | def extra_repr(self) -> str: 960 | return 'output_size={}'.format(self.output_size) 961 | 962 | # FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and 963 | # output shapes are, and how the operation computes output. 
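# --- Illustrative aside (editor's addition, not part of the original file) ---
# The adaptive pooling modules below route their forward pass through
# cF.complex_fcaller, which lets real-valued torch.nn.functional ops act on a
# complex input (in this package, seemingly by handling the real and imaginary
# parts separately and recombining them). A minimal usage sketch; the `cnn`
# alias and the assumption that these classes are re-exported from
# torchcomplex.nn are hypothetical:
#
#   import torch
#   import torchcomplex.nn as cnn
#
#   z = torch.randn(1, 64, 8, dtype=torch.cfloat)  # complex-valued input
#   pool = cnn.AdaptiveMaxPool1d(5)                # defined just below
#   out = pool(z)                                  # expected: complex, shape (1, 64, 5)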
964 | 965 | 966 | class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): 967 | r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. 968 | 969 | The output size is H, for any input size. 970 | The number of output features is equal to the number of input planes. 971 | 972 | Args: 973 | output_size: the target output size H 974 | return_indices: if ``True``, will return the indices along with the outputs. 975 | Useful to pass to nn.MaxUnpool1d. Default: ``False`` 976 | 977 | Examples: 978 | >>> # target output size of 5 979 | >>> m = nn.AdaptiveMaxPool1d(5) 980 | >>> input = torch.randn(1, 64, 8) 981 | >>> output = m(input) 982 | 983 | """ 984 | 985 | output_size: _size_1_t 986 | 987 | def forward(self, input: Tensor) -> Tensor: 988 | return cF.complex_fcaller(F.adaptive_max_pool1d, input, self.output_size, self.return_indices) 989 | 990 | 991 | class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): 992 | r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. 993 | 994 | The output is of size H x W, for any input size. 995 | The number of output features is equal to the number of input planes. 996 | 997 | Args: 998 | output_size: the target output size of the image of the form H x W. 999 | Can be a tuple (H, W) or a single H for a square image H x H. 1000 | H and W can be either an ``int``, or ``None`` which means the size will 1001 | be the same as that of the input. 1002 | return_indices: if ``True``, will return the indices along with the outputs. 1003 | Useful to pass to nn.MaxUnpool2d. Default: ``False`` 1004 | 1005 | Examples: 1006 | >>> # target output size of 5x7 1007 | >>> m = nn.AdaptiveMaxPool2d((5,7)) 1008 | >>> input = torch.randn(1, 64, 8, 9) 1009 | >>> output = m(input) 1010 | >>> # target output size of 7x7 (square) 1011 | >>> m = nn.AdaptiveMaxPool2d(7) 1012 | >>> input = torch.randn(1, 64, 10, 9) 1013 | >>> output = m(input) 1014 | >>> # target output size of 10x7 1015 | >>> m = nn.AdaptiveMaxPool2d((None, 7)) 1016 | >>> input = torch.randn(1, 64, 10, 9) 1017 | >>> output = m(input) 1018 | 1019 | """ 1020 | 1021 | output_size: _size_2_t 1022 | 1023 | def forward(self, input: Tensor) -> Tensor: 1024 | return cF.complex_fcaller(F.adaptive_max_pool2d, input, self.output_size, self.return_indices) 1025 | 1026 | 1027 | class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): 1028 | r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. 1029 | 1030 | The output is of size D x H x W, for any input size. 1031 | The number of output features is equal to the number of input planes. 1032 | 1033 | Args: 1034 | output_size: the target output size of the image of the form D x H x W. 1035 | Can be a tuple (D, H, W) or a single D for a cube D x D x D. 1036 | D, H and W can be either an ``int``, or ``None`` which means the size will 1037 | be the same as that of the input. 1038 | 1039 | return_indices: if ``True``, will return the indices along with the outputs. 1040 | Useful to pass to nn.MaxUnpool3d. 
Default: ``False`` 1041 | 1042 | Examples: 1043 | >>> # target output size of 5x7x9 1044 | >>> m = nn.AdaptiveMaxPool3d((5,7,9)) 1045 | >>> input = torch.randn(1, 64, 8, 9, 10) 1046 | >>> output = m(input) 1047 | >>> # target output size of 7x7x7 (cube) 1048 | >>> m = nn.AdaptiveMaxPool3d(7) 1049 | >>> input = torch.randn(1, 64, 10, 9, 8) 1050 | >>> output = m(input) 1051 | >>> # target output size of 7x9x8 1052 | >>> m = nn.AdaptiveMaxPool3d((7, None, None)) 1053 | >>> input = torch.randn(1, 64, 10, 9, 8) 1054 | >>> output = m(input) 1055 | 1056 | """ 1057 | 1058 | output_size: _size_3_t 1059 | 1060 | def forward(self, input: Tensor) -> Tensor: 1061 | return cF.complex_fcaller(F.adaptive_max_pool3d, input, self.output_size, self.return_indices) 1062 | 1063 | 1064 | class _AdaptiveAvgPoolNd(Module): 1065 | __constants__ = ['output_size'] 1066 | 1067 | def __init__(self, output_size: _size_any_t) -> None: 1068 | super(_AdaptiveAvgPoolNd, self).__init__() 1069 | self.output_size = output_size 1070 | 1071 | def extra_repr(self) -> str: 1072 | return 'output_size={}'.format(self.output_size) 1073 | 1074 | 1075 | class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): 1076 | r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. 1077 | 1078 | The output size is H, for any input size. 1079 | The number of output features is equal to the number of input planes. 1080 | 1081 | Args: 1082 | output_size: the target output size H 1083 | 1084 | Examples: 1085 | >>> # target output size of 5 1086 | >>> m = nn.AdaptiveAvgPool1d(5) 1087 | >>> input = torch.randn(1, 64, 8) 1088 | >>> output = m(input) 1089 | 1090 | """ 1091 | 1092 | output_size: _size_1_t 1093 | 1094 | def forward(self, input: Tensor) -> Tensor: 1095 | return cF.complex_fcaller(F.adaptive_avg_pool1d, input, self.output_size) 1096 | 1097 | 1098 | class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): 1099 | r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. 1100 | 1101 | The output is of size H x W, for any input size. 1102 | The number of output features is equal to the number of input planes. 1103 | 1104 | Args: 1105 | output_size: the target output size of the image of the form H x W. 1106 | Can be a tuple (H, W) or a single H for a square image H x H. 1107 | H and W can be either an ``int``, or ``None`` which means the size will 1108 | be the same as that of the input. 1109 | 1110 | Examples: 1111 | >>> # target output size of 5x7 1112 | >>> m = nn.AdaptiveAvgPool2d((5,7)) 1113 | >>> input = torch.randn(1, 64, 8, 9) 1114 | >>> output = m(input) 1115 | >>> # target output size of 7x7 (square) 1116 | >>> m = nn.AdaptiveAvgPool2d(7) 1117 | >>> input = torch.randn(1, 64, 10, 9) 1118 | >>> output = m(input) 1119 | >>> # target output size of 10x7 1120 | >>> m = nn.AdaptiveAvgPool2d((None, 7)) 1121 | >>> input = torch.randn(1, 64, 10, 9) 1122 | >>> output = m(input) 1123 | 1124 | """ 1125 | 1126 | output_size: _size_2_t 1127 | 1128 | def forward(self, input: Tensor) -> Tensor: 1129 | return cF.complex_fcaller(F.adaptive_avg_pool2d, input, self.output_size) 1130 | 1131 | 1132 | class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): 1133 | r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes. 1134 | 1135 | The output is of size D x H x W, for any input size. 1136 | The number of output features is equal to the number of input planes. 1137 | 1138 | Args: 1139 | output_size: the target output size of the form D x H x W. 
1140 | Can be a tuple (D, H, W) or a single number D for a cube D x D x D. 1141 | D, H and W can be either an ``int``, or ``None`` which means the size will 1142 | be the same as that of the input. 1143 | 1144 | Examples: 1145 | >>> # target output size of 5x7x9 1146 | >>> m = nn.AdaptiveAvgPool3d((5,7,9)) 1147 | >>> input = torch.randn(1, 64, 8, 9, 10) 1148 | >>> output = m(input) 1149 | >>> # target output size of 7x7x7 (cube) 1150 | >>> m = nn.AdaptiveAvgPool3d(7) 1151 | >>> input = torch.randn(1, 64, 10, 9, 8) 1152 | >>> output = m(input) 1153 | >>> # target output size of 7x9x8 1154 | >>> m = nn.AdaptiveAvgPool3d((7, None, None)) 1155 | >>> input = torch.randn(1, 64, 10, 9, 8) 1156 | >>> output = m(input) 1157 | 1158 | """ 1159 | 1160 | output_size: _size_3_t 1161 | 1162 | def forward(self, input: Tensor) -> Tensor: 1163 | return cF.complex_fcaller(F.adaptive_avg_pool3d, input, self.output_size) -------------------------------------------------------------------------------- /torchcomplex/nn/modules/upsampling.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import Module 2 | from torch.nn import functional as F 3 | from .. import functional as cF 4 | 5 | from torch import Tensor 6 | from typing import Optional 7 | from torch.nn.common_types import _size_any_t, _ratio_any_t 8 | 9 | 10 | class Upsample(Module): 11 | r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. 12 | 13 | The input data is assumed to be of the form 14 | `minibatch x channels x [optional depth] x [optional height] x width`. 15 | Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. 16 | 17 | The algorithms available for upsampling are nearest neighbor and linear, 18 | bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, 19 | respectively, plus a Fourier-domain ``'sinc'`` mode (the default here). 20 | 21 | One can either give a :attr:`scale_factor` or the target output :attr:`size` to 22 | calculate the output size. (You cannot give both, as it is ambiguous) 23 | 24 | Args: 25 | size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): 26 | output spatial sizes 27 | scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): 28 | multiplier for spatial size. Has to match input size if it is a tuple. 29 | mode (str, optional): the upsampling algorithm: one of ``'sinc'``, ``'nearest'``, 30 | ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. 31 | Default: ``'sinc'`` 32 | align_corners (bool, optional): if ``True``, the corner pixels of the input 33 | and output tensors are aligned, thus preserving the values at 34 | those pixels. This only has effect when :attr:`mode` is 35 | ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False`` 36 | 37 | Shape: 38 | - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` 39 | - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})` 40 | or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where 41 | 42 | .. math:: 43 | D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor 44 | 45 | .. math:: 46 | H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor 47 | 48 | .. math:: 49 | W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor 50 | 51 | .. 
warning:: 52 | With ``align_corners = True``, the linearly interpolating modes 53 | (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally 54 | align the output and input pixels, and thus the output values can depend 55 | on the input size. This was the default behavior for these modes up to 56 | version 0.3.1. Since then, the default behavior is 57 | ``align_corners = False``. See below for concrete examples on how this 58 | affects the outputs. 59 | 60 | .. note:: 61 | If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`. 62 | 63 | Examples:: 64 | 65 | >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) 66 | >>> input 67 | tensor([[[[ 1., 2.], 68 | [ 3., 4.]]]]) 69 | 70 | >>> m = nn.Upsample(scale_factor=2, mode='nearest') 71 | >>> m(input) 72 | tensor([[[[ 1., 1., 2., 2.], 73 | [ 1., 1., 2., 2.], 74 | [ 3., 3., 4., 4.], 75 | [ 3., 3., 4., 4.]]]]) 76 | 77 | >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False 78 | >>> m(input) 79 | tensor([[[[ 1.0000, 1.2500, 1.7500, 2.0000], 80 | [ 1.5000, 1.7500, 2.2500, 2.5000], 81 | [ 2.5000, 2.7500, 3.2500, 3.5000], 82 | [ 3.0000, 3.2500, 3.7500, 4.0000]]]]) 83 | 84 | >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 85 | >>> m(input) 86 | tensor([[[[ 1.0000, 1.3333, 1.6667, 2.0000], 87 | [ 1.6667, 2.0000, 2.3333, 2.6667], 88 | [ 2.3333, 2.6667, 3.0000, 3.3333], 89 | [ 3.0000, 3.3333, 3.6667, 4.0000]]]]) 90 | 91 | >>> # Try scaling the same data in a larger tensor 92 | >>> 93 | >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3) 94 | >>> input_3x3[:, :, :2, :2].copy_(input) 95 | tensor([[[[ 1., 2.], 96 | [ 3., 4.]]]]) 97 | >>> input_3x3 98 | tensor([[[[ 1., 2., 0.], 99 | [ 3., 4., 0.], 100 | [ 0., 0., 0.]]]]) 101 | 102 | >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False 103 | >>> # Notice that values in top left corner are the same with the small input (except at boundary) 104 | >>> m(input_3x3) 105 | tensor([[[[ 1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000], 106 | [ 1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000], 107 | [ 2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000], 108 | [ 2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000], 109 | [ 0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000], 110 | [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) 111 | 112 | >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 113 | >>> # Notice that values in top left corner are now changed 114 | >>> m(input_3x3) 115 | tensor([[[[ 1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000], 116 | [ 1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000], 117 | [ 2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000], 118 | [ 2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000], 119 | [ 1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000], 120 | [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) 121 | """ 122 | __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name'] 123 | name: str 124 | size: Optional[_size_any_t] 125 | scale_factor: Optional[_ratio_any_t] 126 | mode: str 127 | align_corners: Optional[bool] 128 | 129 | def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None, 130 | mode: str = 'sinc', align_corners: Optional[bool] = None) -> None: 131 | super(Upsample, self).__init__() 132 | self.name = type(self).__name__ 133 | self.size = size 134 | if isinstance(scale_factor, tuple): 135 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 136 | else: 137 | 
self.scale_factor = float(scale_factor) if scale_factor else None 138 | self.mode = mode 139 | self.align_corners = align_corners 140 | 141 | def forward(self, input: Tensor) -> Tensor: 142 | return cF.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners) 143 | 144 | def extra_repr(self) -> str: 145 | if self.scale_factor is not None: 146 | info = 'scale_factor=' + str(self.scale_factor) 147 | else: 148 | info = 'size=' + str(self.size) 149 | info += ', mode=' + self.mode 150 | return info -------------------------------------------------------------------------------- /torchcomplex/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .support_funcs import * -------------------------------------------------------------------------------- /torchcomplex/utils/signaltools.py: -------------------------------------------------------------------------------- 1 | """ 2 | signaltools.py (Only a few functions) of Scipy's Signal processing package, implemented for PyTorch 3 | Currently implemented: resample 4 | 5 | """ 6 | 7 | import sys 8 | import torch 9 | import torch.fft 10 | 11 | __author__ = "Soumick Chatterjee" 12 | __copyright__ = "Copyright 2020, Soumick Chatterjee & OvGU:ESF:MEMoRIAL" 13 | __credits__ = ["Soumick Chatterjee"] 14 | 15 | __license__ = "GPL" 16 | __version__ = "0.0.1" 17 | __email__ = "soumick.chatterjee@ovgu.de" 18 | __status__ = "Only x, num and axis of the resample function have been tested" 19 | 20 | def resample(x, num, t=None, axis=0, window=None, domain='time'): 21 | """ 22 | Resample `x` to `num` samples using Fourier method along the given axis. 23 | 24 | The resampled signal starts at the same value as `x` but is sampled 25 | with a spacing of ``len(x) / num * (spacing of x)``. Because a 26 | Fourier method is used, the signal is assumed to be periodic. 27 | 28 | Parameters 29 | ---------- 30 | x : array_like 31 | The data to be resampled. 32 | num : int or array_like 33 | The number of samples in the resampled signal. 34 | If array_like is supplied, then the resample function will be 35 | called recursively for each element of num. 36 | t : array_like, optional 37 | If `t` is given, it is assumed to be the equally spaced sample 38 | positions associated with the signal data in `x`. 39 | axis : int or array_like, optional 40 | The axis of `x` that is resampled. Default is 0. 41 | If num is array_like, then axis also has to be supplied and has to be array_like. 42 | Each element of axis should have a one-to-one mapping with num. 43 | If num is an int but axis is array_like, then num will be repeated into a 44 | list with the same number of elements as axis; both then proceed as array_like. 45 | window : array_like, callable, string, float, or tuple, optional 46 | Specifies the window applied to the signal in the Fourier 47 | domain. See below for details. 48 | domain : string, optional 49 | A string indicating the domain of the input `x`: 50 | ``time`` Consider the input `x` as time-domain (Default), 51 | ``freq`` Consider the input `x` as frequency-domain. 52 | 53 | Returns 54 | ------- 55 | resampled_x or (resampled_x, resampled_t) 56 | Either the resampled array, or, if `t` was given, a tuple 57 | containing the resampled array and the corresponding resampled 58 | positions. 59 | 60 | See Also 61 | -------- 62 | decimate : Downsample the signal after applying an FIR or IIR filter. 63 | resample_poly : Resample using polyphase filtering and an FIR filter. 
64 | 65 | Notes 66 | ----- 67 | The argument `window` controls a Fourier-domain window that tapers 68 | the Fourier spectrum before zero-padding to alleviate ringing in 69 | the resampled values for sampled signals you didn't intend to be 70 | interpreted as band-limited. 71 | 72 | If `window` is a function, then it is called with a vector of inputs 73 | indicating the frequency bins (i.e. fftfreq(x.shape[axis]) ). 74 | 75 | If `window` is an array of the same length as `x.shape[axis]` it is 76 | assumed to be the window to be applied directly in the Fourier 77 | domain (with dc and low-frequency first). 78 | 79 | For any other type of `window`, the function `scipy.signal.get_window` 80 | is called to generate the window. 81 | 82 | The first sample of the returned vector is the same as the first 83 | sample of the input vector. The spacing between samples is changed 84 | from ``dx`` to ``dx * len(x) / num``. 85 | 86 | If `t` is not None, then it is used solely to calculate the resampled 87 | positions `resampled_t` 88 | 89 | As noted, `resample` uses FFT transformations, which can be very 90 | slow if the number of input or output samples is large and prime; 91 | see `scipy.fft.fft`. 92 | 93 | Examples 94 | -------- 95 | Note that the end of the resampled data rises to meet the first 96 | sample of the next cycle: 97 | 98 | >>> from scipy import signal 99 | 100 | >>> x = np.linspace(0, 10, 20, endpoint=False) 101 | >>> y = np.cos(-x**2/6.0) 102 | >>> f = signal.resample(y, 100) 103 | >>> xnew = np.linspace(0, 10, 100, endpoint=False) 104 | 105 | >>> import matplotlib.pyplot as plt 106 | >>> plt.plot(x, y, 'go-', xnew, f, '.-', 10, y[0], 'ro') 107 | >>> plt.legend(['data', 'resampled'], loc='best') 108 | >>> plt.show() 109 | """ 110 | 111 | if domain not in ('time', 'freq'): 112 | raise ValueError("Acceptable domain flags are 'time' or" 113 | " 'freq', not domain={}".format(domain)) 114 | 115 | if hasattr(axis, "__len__") and not hasattr(num, "__len__"): 116 | num = [num]*len(axis) 117 | 118 | if hasattr(num, "__len__"): 119 | if hasattr(axis, "__len__") and len(num)==len(axis): 120 | _temp = x 121 | _t_list = [] 122 | for i in range(len(num)): 123 | _num = num[i] 124 | _axis = axis[i] 125 | if t is None: 126 | _temp = resample(_temp, _num, t, _axis, window, domain) 127 | else: 128 | _temp, _t = resample(_temp, _num, t, _axis, window, domain) 129 | _t_list.append(_t) 130 | if t is None: 131 | return _temp 132 | else: 133 | return _temp, torch.stack(_t_list) 134 | else: 135 | raise ValueError("if num is array_like, then axis also has to be array_like and of the same length") 136 | 137 | Nx = x.shape[axis] 138 | 139 | # Check if we can use faster real FFT 140 | real_input = not x.is_complex() 141 | 142 | if domain == 'time': 143 | # Forward transform 144 | if real_input: 145 | X = torch.fft.rfft(x, dim=axis) 146 | else: # Full complex FFT 147 | X = torch.fft.fft(x, dim=axis) 148 | else: # domain == 'freq' 149 | X = x 150 | 151 | # Apply window to spectrum 152 | if window is not None: 153 | if callable(window): 154 | W = window(torch.fft.fftfreq(Nx)) 155 | elif isinstance(window, torch.Tensor): 156 | if window.shape != (Nx,): 157 | raise ValueError('window must have the same length as data') 158 | W = window 159 | else: 160 | raise NotImplementedError("Window can only be either a function or a Tensor. 
Window generation with the get_window function of scipy.signal hasn't been implemented yet.") 161 | # W = torch.fft.ifftshift(get_window(window, Nx))  # unreachable: would require scipy.signal.get_window 162 | 163 | newshape_W = [1] * x.ndim 164 | newshape_W[axis] = X.shape[axis] 165 | if real_input: 166 | # Fold the window back on itself to mimic complex behavior 167 | W_real = W.clone() 168 | W_real[1:] += W_real[-1:0:-1] 169 | W_real[1:] *= 0.5 170 | X *= W_real[:newshape_W[axis]].reshape(newshape_W) 171 | else: 172 | X *= W.reshape(newshape_W) 173 | 174 | # Copy each half of the original spectrum to the output spectrum, either 175 | # truncating high frequencies (downsampling) or zero-padding them 176 | # (upsampling) 177 | 178 | # Placeholder array for output spectrum 179 | newshape = list(x.shape) 180 | if real_input: 181 | newshape[axis] = num // 2 + 1 182 | else: 183 | newshape[axis] = num 184 | Y = torch.zeros(newshape, dtype=X.dtype, device=x.device) 185 | 186 | # Copy positive frequency components (and Nyquist, if present) 187 | N = min(num, Nx) 188 | nyq = N // 2 + 1 # Slice index that includes Nyquist if present 189 | sl = [slice(None)] * x.ndim 190 | sl[axis] = slice(0, nyq) 191 | Y[tuple(sl)] = X[tuple(sl)] 192 | if not real_input: 193 | # Copy negative frequency components 194 | if N > 2: # (slice expression doesn't collapse to empty array) 195 | sl[axis] = slice(nyq - N, None) 196 | Y[tuple(sl)] = X[tuple(sl)] 197 | 198 | # Split/join Nyquist component(s) if present 199 | # So far we have set Y[+N/2]=X[+N/2] 200 | if N % 2 == 0: 201 | if num < Nx: # downsampling 202 | if real_input: 203 | sl[axis] = slice(N//2, N//2 + 1) 204 | Y[tuple(sl)] *= 2. 205 | else: 206 | # select the component of Y at frequency +N/2, 207 | # add the component of X at -N/2 208 | sl[axis] = slice(-N//2, -N//2 + 1) 209 | Y[tuple(sl)] += X[tuple(sl)] 210 | elif Nx < num: # upsampling 211 | # select the component at frequency +N/2 and halve it 212 | sl[axis] = slice(N//2, N//2 + 1) 213 | Y[tuple(sl)] *= 0.5 214 | if not real_input: 215 | temp = Y[tuple(sl)] 216 | # set the component at -N/2 equal to the component at +N/2 217 | sl[axis] = slice(num-N//2, num-N//2 + 1) 218 | Y[tuple(sl)] = temp 219 | 220 | # Inverse transform 221 | if real_input: 222 | y = torch.fft.irfft(Y, num, dim=axis) 223 | else: 224 | y = torch.fft.ifft(Y, dim=axis)  # PyTorch's ifft has no overwrite_x param, unlike SciPy's 225 | 226 | y *= (float(num) / float(Nx)) 227 | 228 | if t is None: 229 | return y 230 | else: 231 | new_t = torch.arange(0, num) * (t[1] - t[0]) * Nx / float(num) + t[0] 232 | return y, new_t -------------------------------------------------------------------------------- /torchcomplex/utils/support_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def clamp(input, min=None, max=None, *, out=None): 4 | real = torch.clamp(input.real, min=min if type(min) is not complex else min.real, max=max if type(max) is not complex else max.real)  # a complex min/max bounds the real part by its real component 5 | imag = torch.clamp(input.imag, min=min if type(min) is not complex else min.imag, max=max if type(max) is not complex else max.imag)  # ...and the imaginary part by its imaginary component 6 | return torch.complex(real, imag, out=out)  # write into `out` once here, rather than reusing it for both component clamps 7 | 8 | def complex_clamp(input, min=None, max=None): 9 | # convert to polar coordinates 10 | magnitude = torch.abs(input) 11 | angle = torch.angle(input) 12 | 13 | # clamp the magnitude 14 | magnitude = torch.clamp(magnitude, min=min, max=max) 15 | 16 | # convert back to Cartesian coordinates 17 | clamped_real = magnitude * torch.cos(angle) 18 | clamped_imag = magnitude * 
torch.sin(angle) 19 | 20 | return torch.complex(clamped_real, clamped_imag) 21 | --------------------------------------------------------------------------------
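Usage sketch (editor's addition, illustrative only): the two helpers above clamp a complex tensor in different senses. `clamp` bounds the real and imaginary components independently, with a complex `min`/`max` supplying the per-component limits, while `complex_clamp` switches to polar form, clamps the magnitude, and keeps the phase. A doctest-style example (outputs shown approximately, assuming default tensor print options):

>>> import torch
>>> from torchcomplex.utils.support_funcs import clamp, complex_clamp
>>> z = torch.tensor([3 + 4j, 0.5 + 0.5j])
>>> clamp(z, min=0 + 0j, max=1 + 1j)        # per-component clamp
tensor([1.0000+1.0000j, 0.5000+0.5000j])
>>> complex_clamp(z, max=1.0).abs()         # magnitude clamp preserves phase
tensor([1.0000, 0.7071])

Here |3+4j| = 5 is clamped down to 1 while |0.5+0.5j| ≈ 0.707 is left untouched, so only the first element moves (to 0.6+0.8j, the unit vector at the same angle).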