├── CHANGELOG.md ├── Example.ipynb ├── LICENSE ├── README.md ├── complexPyTorch ├── __init__.py ├── complexFunctions.py └── complexLayers.py └── setup.py /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.4.1 2 | 3 | ### Added 4 | 5 | * GRU Cell and BN-GRU Cell layers (#11) 6 | * Sigmoid functions/layers (#11) 7 | * Tanh functions/layers (#11) 8 | 9 | ## 0.4 10 | 11 | ### Fixed 12 | 13 | * Corrected BatchNorm1d tensor size issue 14 | 15 | ## 0.3 16 | 17 | ## 0.2.1 18 | 19 | ### Fixed 20 | * Correct bug causing ComplexBatchNorm to fail in eval mode 21 | * Correct behaviour of ComplexBatchNorm for track_running_stats=False 22 | 23 | ### Added 24 | * ComplexAvgPool2d 25 | 26 | ## 0.2 27 | 28 | Requires Pytorch version >= 1.7 29 | 30 | ### Changed 31 | * Use complex64 tensors now supported by Pytorch version >= 1.7 32 | 33 | 34 | ## 0.1 35 | 36 | Initial release 37 | 38 | ### Fixed 39 | * Correct memory leak with torch.nograd() 40 | -------------------------------------------------------------------------------- /Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "from torch.utils.data import Subset\n", 13 | "from torchvision import datasets, transforms\n", 14 | "from complexPyTorch.complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear\n", 15 | "from complexPyTorch.complexLayers import ComplexDropout2d, NaiveComplexBatchNorm2d\n", 16 | "from complexPyTorch.complexLayers import ComplexBatchNorm1d\n", 17 | "from complexPyTorch.complexFunctions import complex_relu, complex_max_pool2d" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 3, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | 
"batch_size = 64\n", 27 | "n_train = 1000\n", 28 | "n_test = 100\n", 29 | "trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n", 30 | "train_set = datasets.MNIST('../data', train=True, transform=trans, download=True)\n", 31 | "train_set = Subset(train_set, torch.arange(n_train))\n", 32 | "test_set = datasets.MNIST('../data', train=False, transform=trans, download=True)\n", 33 | "test_set = Subset(test_set, torch.arange(n_test))\n", 34 | "\n", 35 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True)\n", 36 | "test_loader = torch.utils.data.DataLoader(test_set, batch_size= batch_size, shuffle=True)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 6, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "class ComplexNet(nn.Module):\n", 46 | " \n", 47 | " def __init__(self):\n", 48 | " super(ComplexNet, self).__init__()\n", 49 | " self.conv1 = ComplexConv2d(1, 10, 5, 1)\n", 50 | " self.bn2d = ComplexBatchNorm2d(10, track_running_stats = False)\n", 51 | " self.conv2 = ComplexConv2d(10, 20, 5, 1)\n", 52 | " self.fc1 = ComplexLinear(4*4*20, 500)\n", 53 | " self.dropout = ComplexDropout2d(p = 0.3)\n", 54 | " self.bn1d = ComplexBatchNorm1d(500, track_running_stats = False)\n", 55 | " self.fc2 = ComplexLinear(500, 10)\n", 56 | " \n", 57 | " def forward(self,x):\n", 58 | " x = self.conv1(x)\n", 59 | " x = complex_relu(x)\n", 60 | " x = complex_max_pool2d(x, 2, 2)\n", 61 | " x = self.bn2d(x)\n", 62 | " x = self.conv2(x)\n", 63 | " x = complex_relu(x)\n", 64 | " x = complex_max_pool2d(x, 2, 2)\n", 65 | " x = x.view(-1,4*4*20)\n", 66 | " x = self.fc1(x)\n", 67 | " x = self.dropout(x)\n", 68 | " x = complex_relu(x)\n", 69 | " x = self.bn1d(x)\n", 70 | " x = self.fc2(x)\n", 71 | " x = x.abs()\n", 72 | " x = F.log_softmax(x, dim=1)\n", 73 | " return x\n", 74 | " \n", 75 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 76 | "model = 
ComplexNet().to(device)\n", 77 | "optimizer = torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9)\n", 78 | "\n", 79 | "def train(model, device, train_loader, optimizer, epoch):\n", 80 | " model.train()\n", 81 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 82 | " data, target =data.to(device).type(torch.complex64), target.to(device)\n", 83 | " optimizer.zero_grad()\n", 84 | " output = model(data)\n", 85 | " loss = F.nll_loss(output, target)\n", 86 | " loss.backward()\n", 87 | " optimizer.step()\n", 88 | " if batch_idx % 100 == 0:\n", 89 | " print('Train\\t Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\\tLoss: {:.6f}'.format(\n", 90 | " epoch,\n", 91 | " batch_idx * len(data), \n", 92 | " len(train_loader.dataset),\n", 93 | " 100. * batch_idx / len(train_loader), \n", 94 | " loss.item())\n", 95 | " )\n", 96 | " \n", 97 | "def test(model, device, test_loader, optimizer, epoch):\n", 98 | " model.eval()\n", 99 | " for batch_idx, (data, target) in enumerate(test_loader):\n", 100 | " data, target = data.to(device).type(torch.complex64), target.to(device)\n", 101 | " output = model(data)\n", 102 | " loss = F.nll_loss(output, target)\n", 103 | " if batch_idx % 100 == 0:\n", 104 | " print('Test\\t Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\\tLoss: {:.6f}'.format(\n", 105 | " epoch,\n", 106 | " batch_idx * len(data), \n", 107 | " len(test_loader.dataset),\n", 108 | " 100.
* batch_idx / len(test_loader), \n", 109 | " loss.item())\n", 110 | " )" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 7, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stderr", 120 | "output_type": "stream", 121 | "text": [ 122 | "/opt/miniconda/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py:130: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at /opt/conda/conda-bld/pytorch_1607370124688/work/aten/src/ATen/native/Copy.cpp:162.)\n", 123 | " Variable._execution_engine.run_backward(\n" 124 | ] 125 | }, 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "Train\t Epoch: 0 [ 0/ 1000 ( 0%)]\tLoss: 2.575082\n", 131 | "Test\t Epoch: 0 [ 0/ 1000 ( 0%)]\tLoss: 0.868310\n", 132 | "Train\t Epoch: 1 [ 0/ 1000 ( 0%)]\tLoss: 0.803982\n", 133 | "Test\t Epoch: 1 [ 0/ 1000 ( 0%)]\tLoss: 0.194764\n", 134 | "Train\t Epoch: 2 [ 0/ 1000 ( 0%)]\tLoss: 0.366340\n", 135 | "Test\t Epoch: 2 [ 0/ 1000 ( 0%)]\tLoss: 0.160019\n", 136 | "Train\t Epoch: 3 [ 0/ 1000 ( 0%)]\tLoss: 0.411020\n", 137 | "Test\t Epoch: 3 [ 0/ 1000 ( 0%)]\tLoss: 0.088335\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "# Run training on 4 epochs\n", 143 | "for epoch in range(4):\n", 144 | " train(model, device, train_loader, optimizer, epoch)\n", 145 | " test(model, device, test_loader, optimizer, epoch)" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.8.5" 166 | }, 167 | "toc": { 168 | "base_numbering": 1, 169 | "nav_menu": {}, 170 | "number_sections": true, 171 | "sideBar": 
true, 172 | "skip_h1_title": true, 173 | "title_cell": "Table of Contents", 174 | "title_sidebar": "Contents", 175 | "toc_cell": false, 176 | "toc_position": {}, 177 | "toc_section_display": true, 178 | "toc_window_display": false 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 4 183 | } 184 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sébastien M. P. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # complexPyTorch 2 | 3 | A high-level toolbox for using complex valued neural networks in PyTorch. 
4 | 5 | Before version 1.7 of PyTorch, complex tensors were not supported. 6 | The initial version of **complexPyTorch** represented complex tensors using two tensors, one for the real and one for the imaginary part. 7 | Since version 1.7, complex tensors of type `torch.complex64` are allowed, but only a limited number of operations are supported. 8 | The current version of **complexPyTorch** uses complex tensors (hence requires PyTorch version >= 1.7) and adds support for various operations and layers. 9 | 10 | ## Installation 11 | ```bash 12 | pip install complexPyTorch 13 | ``` 14 | 15 | ## Complex Valued Networks with PyTorch 16 | 17 | Artificial neural networks are mainly used for treating data encoded in real values, such as digitized images or sounds. 18 | In such systems, using complex-valued tensors would be quite useless. 19 | However, for physics-related topics, in particular when dealing with wave propagation, using complex values is interesting as the physics typically has linear, hence simpler, behavior when considering complex fields. 20 | complexPyTorch is a simple implementation of complex-valued functions and modules using the high-level API of PyTorch. 21 | Following [[C.
Trabelsi et al., International Conference on Learning Representations, (2018)](https://openreview.net/forum?id=H1T2hmZAb)], it allows the following layers and functions to be used with complex values: 22 | * Linear 23 | * Conv2d 24 | * ConvTranspose2d 25 | * MaxPool2d 26 | * AvgPool2d 27 | * Relu (ℂRelu) 28 | * Sigmoid 29 | * Tanh 30 | * Dropout2d 31 | * BatchNorm1d (Naive and Covariance approach) 32 | * BatchNorm2d (Naive and Covariance approach) 33 | * GRU/BN-GRU Cell 34 | 35 | ## Citing the code 36 | 37 | If the code was helpful to your work, please consider citing it: 38 | 39 | [![DOI](https://img.shields.io/badge/DOI-10.1103%2FPhysRevX.11.021060-blue)](https://doi.org/10.1103/PhysRevX.11.021060) 40 | 41 | 42 | ## Syntax and usage 43 | 44 | The syntax mirrors that of the standard real-valued functions and modules from PyTorch. 45 | The names are the same as in `nn.modules` and `nn.functional` except that they start with `Complex` for Modules, e.g. `ComplexReLU`, `ComplexMaxPool2d` or `complex_` for functions, e.g. `complex_relu`, `complex_max_pool2d`. 46 | The only usage difference is that inputs and outputs are complex tensors (e.g. of type `torch.complex64`) instead of real ones. 47 | 48 | ## BatchNorm 49 | 50 | For all layers other than BatchNorm, using the recommendation of [[C. Trabelsi et al., International Conference on Learning Representations, (2018)](https://openreview.net/forum?id=H1T2hmZAb)], the calculation can be done in a straightforward manner using functions and modules from `nn.modules` and `nn.functional`. 51 | For instance, the function `complex_relu` in `complexFunctions`, or its associated module `ComplexReLU` in `complexLayers`, simply applies `relu` separately to the real and imaginary parts and recombines the results into a complex tensor. 52 | The complex BatchNorm proposed in [[C.
Trabelsi et al., International Conference on Learning Representations, (2018)](https://openreview.net/forum?id=H1T2hmZAb)] requires the calculation of the inverse square root of the covariance matrix. 53 | This is implemented in `ComplexBatchNorm1d` and `ComplexBatchNorm2d` using only the high-level PyTorch API, which is quite slow. 54 | The gain of this approach, however, can be experimentally marginal compared to the naive approach, which consists of simply performing BatchNorm independently on the real and imaginary parts and is available as `NaiveComplexBatchNorm1d` or `NaiveComplexBatchNorm2d`. 55 | 56 | 57 | ## Example 58 | 59 | For illustration, here is a small example of a complex model. 60 | Note that in this example complex values are not particularly useful; it only shows how one can handle complex ANNs. 61 | 62 | ```python 63 | import torch 64 | import torch.nn as nn 65 | import torch.nn.functional as F 66 | from torchvision import datasets, transforms 67 | from complexPyTorch.complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear 68 | from complexPyTorch.complexFunctions import complex_relu, complex_max_pool2d 69 | 70 | batch_size = 64 71 | trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) 72 | train_set = datasets.MNIST('../data', train=True, transform=trans, download=True) 73 | test_set = datasets.MNIST('../data', train=False, transform=trans, download=True) 74 | 75 | train_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True) 76 | test_loader = torch.utils.data.DataLoader(test_set, batch_size= batch_size, shuffle=True) 77 | 78 | class ComplexNet(nn.Module): 79 | 80 | def __init__(self): 81 | super(ComplexNet, self).__init__() 82 | self.conv1 = ComplexConv2d(1, 10, 5, 1) 83 | self.bn = ComplexBatchNorm2d(10) 84 | self.conv2 = ComplexConv2d(10, 20, 5, 1) 85 | self.fc1 = ComplexLinear(4*4*20, 500) 86 | self.fc2 = ComplexLinear(500, 10) 87 | 88 | def
forward(self,x): 89 | x = self.conv1(x) 90 | x = complex_relu(x) 91 | x = complex_max_pool2d(x, 2, 2) 92 | x = self.bn(x) 93 | x = self.conv2(x) 94 | x = complex_relu(x) 95 | x = complex_max_pool2d(x, 2, 2) 96 | x = x.view(-1,4*4*20) 97 | x = self.fc1(x) 98 | x = complex_relu(x) 99 | x = self.fc2(x) 100 | x = x.abs() 101 | x = F.log_softmax(x, dim=1) 102 | return x 103 | 104 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 105 | model = ComplexNet().to(device) 106 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) 107 | 108 | def train(model, device, train_loader, optimizer, epoch): 109 | model.train() 110 | for batch_idx, (data, target) in enumerate(train_loader): 111 | data, target = data.to(device).type(torch.complex64), target.to(device) 112 | optimizer.zero_grad() 113 | output = model(data) 114 | loss = F.nll_loss(output, target) 115 | loss.backward() 116 | optimizer.step() 117 | if batch_idx % 100 == 0: 118 | print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format( 119 | epoch, 120 | batch_idx * len(data), 121 | len(train_loader.dataset), 122 | 100. * batch_idx / len(train_loader), 123 | loss.item()) 124 | ) 125 | 126 | # Run training on 50 epochs 127 | for epoch in range(50): 128 | train(model, device, train_loader, optimizer, epoch) 129 | ``` 130 | 131 | 132 | ## Acknowledgments 133 | 134 | I want to thank Piotr Bialecki for his invaluable help on the PyTorch forum. 
# 135 | -------------------------------------------------------------------------------- /complexPyTorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wavefrontshaping/complexPyTorch/636e1cc545be662f41e10808ff86f6daa2d18ce6/complexPyTorch/__init__.py -------------------------------------------------------------------------------- /complexPyTorch/complexFunctions.py: --------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Functional building blocks for complex-valued networks.

Each function reads ``inp.real`` / ``inp.imag`` of a complex tensor and
(unless stated otherwise) returns a ``torch.complex64`` tensor.

@author: spopoff
"""

import torch
from torch.nn.functional import (
    avg_pool2d,
    dropout,
    dropout2d,
    interpolate,
    max_pool2d,
    relu,
    sigmoid,
    tanh,
)


from torch.nn.functional import max_pool2d, avg_pool2d, dropout, dropout2d, interpolate
from torch import tanh, relu, sigmoid


def complex_matmul(A, B):
    """
    Performs the matrix product between two complex matrices.

    The product is expanded into four real ``torch.matmul`` calls:
    (Ar + i*Ai)(Br + i*Bi) = (Ar@Br - Ai@Bi) + i*(Ar@Bi + Ai@Br).
    """
    outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag)
    outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real)

    return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)


def complex_avg_pool2d(inp, *args, **kwargs):
    """
    Perform complex average pooling.

    Averaging is linear, so pooling the real and imaginary parts separately
    equals averaging the complex values. Extra positional/keyword arguments
    are forwarded unchanged to ``torch.nn.functional.avg_pool2d``.
    """
    pooled_real = avg_pool2d(inp.real, *args, **kwargs)
    pooled_imag = avg_pool2d(inp.imag, *args, **kwargs)

    return pooled_real.type(torch.complex64) + 1j * pooled_imag.type(
        torch.complex64
    )


def complex_normalize(inp):
    """
    Perform complex normalization.

    The real and imaginary parts are standardised independently, each using
    the mean and standard deviation taken over the whole tensor.
    """
    real_value, imag_value = inp.real, inp.imag
    real_norm = (real_value - real_value.mean()) / real_value.std()
    imag_norm = (imag_value - imag_value.mean()) / imag_value.std()
    return real_norm.type(torch.complex64) + 1j * imag_norm.type(torch.complex64)


def complex_relu(inp):
    """Apply ReLU separately to the real and imaginary parts (CReLU)."""
    return relu(inp.real).type(torch.complex64) + 1j * relu(inp.imag).type(
        torch.complex64
    )


def complex_sigmoid(inp):
    """Apply the sigmoid separately to the real and imaginary parts."""
    return sigmoid(inp.real).type(torch.complex64) + 1j * sigmoid(inp.imag).type(
        torch.complex64
    )


def complex_tanh(inp):
    """Apply tanh separately to the real and imaginary parts."""
    return tanh(inp.real).type(torch.complex64) + 1j * tanh(inp.imag).type(
        torch.complex64
    )


def complex_opposite(inp):
    """Return ``-inp`` as a complex64 tensor."""
    return -inp.real.type(torch.complex64) + 1j * (-inp.imag.type(torch.complex64))


def complex_stack(inp, dim):
    """Stack a sequence of complex tensors along ``dim`` (like ``torch.stack``)."""
    inp_real = [x.real for x in inp]
    inp_imag = [x.imag for x in inp]
    return torch.stack(inp_real, dim).type(torch.complex64) + 1j * torch.stack(
        inp_imag, dim
    ).type(torch.complex64)


def _retrieve_elements_from_indices(tensor, indices):
    """
    Gather the entries of ``tensor`` selected by pooling ``indices``.

    ``indices`` index the flattened last two (spatial) dimensions, as returned
    by ``max_pool2d(..., return_indices=True)``. The result has the shape of
    ``indices``.
    """
    flattened_tensor = tensor.flatten(start_dim=-2)
    output = flattened_tensor.gather(
        dim=-1, index=indices.flatten(start_dim=-2)
    ).view_as(indices)
    return output


def complex_upsample(
    inp,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
):
    """
    Performs upsampling by separately interpolating the real and imaginary part and recombining.

    All keyword arguments are forwarded to ``torch.nn.functional.interpolate``.
    """
    outp_real = interpolate(
        inp.real,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )
    outp_imag = interpolate(
        inp.imag,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )

    return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)


def complex_upsample2(
    inp,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
):
    """
    Performs upsampling by separately interpolating the amplitude and phase part and recombining.

    Note: with non-``nearest`` modes the raw phase (from ``atan2``, in
    [-pi, pi]) is interpolated linearly, so its 2*pi wrap-around is not
    accounted for.
    """
    outp_abs = interpolate(
        inp.abs(),
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )
    angle = torch.atan2(inp.imag, inp.real)
    outp_angle = interpolate(
        angle,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )

    return outp_abs * (
        torch.cos(outp_angle).type(torch.complex64)
        + 1j * torch.sin(outp_angle).type(torch.complex64)
    )


def complex_max_pool2d(
    inp,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
    return_indices=False,
):
    """
    Perform complex max pooling by selecting on the absolute value of the complex values.

    In every pooling window the element with the largest modulus is kept,
    together with its phase.

    Returns:
        The pooled complex64 tensor or, when ``return_indices`` is True, the
        tuple ``(pooled, indices)`` where ``indices`` are the flat spatial
        positions chosen by ``max_pool2d`` (same convention as PyTorch).
    """
    # performs the selection on the absolute values
    absolute_value, indices = max_pool2d(
        inp.abs(),
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
        return_indices=True,
    )
    absolute_value = absolute_value.type(torch.complex64)
    # retrieve the corresponding phase value using the indices
    # unfortunately, the derivative for 'angle' is not implemented
    angle = torch.atan2(inp.imag, inp.real)
    # get only the phase values selected by max pool
    angle = _retrieve_elements_from_indices(angle, indices)
    output = absolute_value * (
        torch.cos(angle).type(torch.complex64)
        + 1j * torch.sin(angle).type(torch.complex64)
    )
    if return_indices:
        return output, indices
    return output


def complex_dropout(inp, p=0.5, training=True):
    """
    Dropout for complex tensors, sharing one mask between real and imaginary parts.

    A real mask of ones is passed through ``torch.nn.functional.dropout``.
    That function zeroes entries with probability ``p`` and already rescales
    the survivors by ``1 / (1 - p)`` in training mode. It is the identity when
    ``training`` is False. Multiplying ``inp`` by this mask therefore drops the
    real and imaginary parts of an element together.
    """
    mask = torch.ones(inp.shape, dtype=torch.float32, device=inp.device)
    # No extra 1 / (1 - p) factor: dropout() has already applied it. The old
    # extra factor double-scaled in training, rescaled in eval and gave NaN for p=1.
    mask = dropout(mask, p, training)
    return mask * inp


def complex_dropout2d(inp, p=0.5, training=True):
    """
    Channel-wise dropout for complex tensors, sharing one mask between both parts.

    The mask is built with ``torch.nn.functional.dropout2d``, which zeroes
    whole channels and already rescales kept channels by ``1 / (1 - p)`` in
    training mode. It is the identity when ``training`` is False.
    """
    mask = torch.ones(inp.shape, dtype=torch.float32, device=inp.device)
    # dropout2d() has already applied the 1 / (1 - p) rescaling
    mask = dropout2d(mask, p, training)
    return mask * inp
# 212 | -------------------------------------------------------------------------------- /complexPyTorch/complexLayers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Mar 19 10:30:02 2019 5 | 6 | @author: Sebastien M.
# Popoff 7 | 8 | 9 | Based on https://openreview.net/forum?id=H1T2hmZAb 10 | """

from typing import Optional

import torch
from torch.nn import (
    Module, Parameter, init,
    Conv2d, ConvTranspose2d, Linear, LSTM, GRU,
    BatchNorm1d, BatchNorm2d,
    PReLU
)

from .complexFunctions import (
    complex_relu,
    complex_tanh,
    complex_sigmoid,
    complex_max_pool2d,
    complex_avg_pool2d,
    complex_dropout,
    complex_dropout2d,
    complex_opposite,
)


def apply_complex(fr, fi, input, dtype=torch.complex64):
    """
    Apply a complex linear operator built from two real modules.

    With W = Wr + i*Wi (``fr`` implements Wr, ``fi`` implements Wi):
    W(x) = (fr(x.real) - fi(x.imag)) + i*(fr(x.imag) + fi(x.real)).
    If ``fr``/``fi`` carry biases br/bi, the effective complex bias is
    (br - bi) + i*(br + bi).
    """
    return (fr(input.real)-fi(input.imag)).type(dtype) \
        + 1j*(fr(input.imag)+fi(input.real)).type(dtype)


class ComplexDropout(Module):
    """Element-wise complex dropout (active only in training mode)."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, input):
        if self.training:
            return complex_dropout(input, self.p)
        else:
            return input


class ComplexDropout2d(Module):
    """Channel-wise complex dropout (active only in training mode)."""

    def __init__(self, p=0.5):
        super(ComplexDropout2d, self).__init__()
        self.p = p

    def forward(self, inp):
        if self.training:
            return complex_dropout2d(inp, self.p)
        else:
            return inp


class ComplexMaxPool2d(Module):
    """Module wrapper around ``complex_max_pool2d`` (selection by modulus)."""

    def __init__(
        self,
        kernel_size,
        stride=None,
        padding=0,
        dilation=1,
        return_indices=False,
        ceil_mode=False,
    ):
        super(ComplexMaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.ceil_mode = ceil_mode
        self.return_indices = return_indices

    def forward(self, inp):
        # every option, including return_indices, is forwarded to the functional form
        return complex_max_pool2d(
            inp,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )


class ComplexAvgPool2d(torch.nn.Module):
    """Module wrapper around ``complex_avg_pool2d``; options mirror ``nn.AvgPool2d``."""

    def __init__(self, kernel_size, stride=None, padding=0,
                 ceil_mode=False, count_include_pad=True, divisor_override=None):
        super(ComplexAvgPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.divisor_override = divisor_override

    def forward(self, inp):
        return complex_avg_pool2d(inp, kernel_size=self.kernel_size,
                                  stride=self.stride, padding=self.padding,
                                  ceil_mode=self.ceil_mode, count_include_pad=self.count_include_pad,
                                  divisor_override=self.divisor_override)


class ComplexReLU(Module):
    """ReLU applied separately to the real and imaginary parts."""

    @staticmethod
    def forward(inp):
        return complex_relu(inp)


class ComplexSigmoid(Module):
    """Sigmoid applied separately to the real and imaginary parts."""

    @staticmethod
    def forward(inp):
        return complex_sigmoid(inp)


class ComplexPReLU(Module):
    """
    PReLU applied separately to the real and imaginary parts.

    Each part has its own learnable slope (two independent ``PReLU`` modules).
    """

    def __init__(self):
        super().__init__()
        self.r_prelu = PReLU()
        self.i_prelu = PReLU()

    # Regular (not static) method: it needs ``self`` to reach the two PReLU
    # sub-modules. As a @staticmethod, ``module(inp)`` bound ``inp`` to ``self``
    # and raised a TypeError for the missing argument.
    def forward(self, inp):
        return self.r_prelu(inp.real) + 1j*self.i_prelu(inp.imag)


class ComplexTanh(Module):
    """Tanh applied separately to the real and imaginary parts."""

    @staticmethod
    def forward(inp):
        return complex_tanh(inp)


class ComplexConvTranspose2d(Module):
    """Complex transposed 2-D convolution built from two real ``ConvTranspose2d`` layers."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
    ):

        super().__init__()

        self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                           output_padding, groups, bias, dilation, padding_mode)
        self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                           output_padding, groups, bias, dilation, padding_mode)

    def forward(self, inp):
        return apply_complex(self.conv_tran_r, self.conv_tran_i, inp)


class ComplexConv2d(Module):
    """Complex 2-D convolution built from two real ``Conv2d`` layers (real/imag kernels)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(ComplexConv2d, self).__init__()
        self.conv_r = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.conv_i = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, inp):
        return apply_complex(self.conv_r, self.conv_i, inp)


class ComplexLinear(Module):
    """Complex fully-connected layer built from two real ``Linear`` layers."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc_r = Linear(in_features, out_features)
        self.fc_i = Linear(in_features, out_features)

    def forward(self, inp):
        return apply_complex(self.fc_r, self.fc_i, inp)


class NaiveComplexBatchNorm1d(Module):
    """
    Naive approach to complex batch norm, perform batch norm independently on real and imaginary part.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(NaiveComplexBatchNorm1d, self).__init__()
        self.bn_r = BatchNorm1d(
            num_features, eps, momentum, affine, track_running_stats
        )
        self.bn_i = BatchNorm1d(
            num_features, eps, momentum, affine, track_running_stats
        )

    def forward(self, inp):
        return self.bn_r(inp.real).type(torch.complex64) + 1j * self.bn_i(
            inp.imag
        ).type(torch.complex64)


class NaiveComplexBatchNorm2d(Module):
    """
    Naive approach to complex batch norm, perform batch norm independently on real and imaginary part.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(NaiveComplexBatchNorm2d, self).__init__()
        self.bn_r = BatchNorm2d(
            num_features, eps, momentum, affine, track_running_stats
        )
        self.bn_i = BatchNorm2d(
            num_features, eps, momentum, affine, track_running_stats
        )

    def forward(self, inp):
        return self.bn_r(inp.real).type(torch.complex64) + 1j * self.bn_i(
            inp.imag
        ).type(torch.complex64)


class _ComplexBatchNorm(Module):
    """
    Shared parameters/buffers for the covariance-based complex batch norm.

    Layout, as consumed by the subclasses' forward():
      * weight[:, 0], weight[:, 1], weight[:, 2]: gamma_rr, gamma_ii, gamma_ri
      * bias[:, 0], bias[:, 1]: beta for the real and imaginary outputs
      * running_covar[:, 0], [:, 1], [:, 2]: running Crr, Cii, Cri
    """

    running_mean: Optional[torch.Tensor]

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(_ComplexBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features, 3))
            self.bias = Parameter(torch.Tensor(num_features, 2))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros(
                    num_features, dtype=torch.complex64)
            )
            self.register_buffer("running_covar", torch.zeros(num_features, 3))
            self.running_covar[:, 0] = 1.4142135623730951
            self.running_covar[:, 1] = 1.4142135623730951
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_covar", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset the running mean/covariance and the batch counter."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_covar.zero_()
            self.running_covar[:, 0] = 1.4142135623730951
            self.running_covar[:, 1] = 1.4142135623730951
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics and (if affine) gamma/beta."""
        self.reset_running_stats()
        if self.affine:
            # NOTE(review): Trabelsi et al. (ICLR 2018) suggest 1/sqrt(2) for
            # gamma_rr/gamma_ii; this uses sqrt(2) -- confirm intent.
            init.constant_(self.weight[:, :2], 1.4142135623730951)
            init.zeros_(self.weight[:, 2])
            init.zeros_(self.bias)


class ComplexBatchNorm2d(_ComplexBatchNorm):
    """
    Covariance-based complex batch norm over (N, C, H, W) inputs.

    Statistics are taken over dims (0, 2, 3). Each channel is whitened with
    the inverse square root of its 2x2 real/imag covariance matrix, then
    optionally transformed by gamma/beta.
    """

    def forward(self, inp):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / \
                        float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training or (not self.track_running_stats):
            # calculate mean of real and imaginary part
            # mean does not support automatic differentiation for outputs with complex dtype.
            mean_r = inp.real.mean([0, 2, 3]).type(torch.complex64)
            mean_i = inp.imag.mean([0, 2, 3]).type(torch.complex64)
            mean = mean_r + 1j * mean_i
        else:
            mean = self.running_mean

        if self.training and self.track_running_stats:
            # update running mean
            with torch.no_grad():
                self.running_mean = (
                    exponential_average_factor * mean
                    + (1 - exponential_average_factor) * self.running_mean
                )

        inp = inp - mean[None, :, None, None]

        if self.training or (not self.track_running_stats):
            # Elements of the covariance matrix (biased for train)
            n = inp.numel() / inp.size(1)
            Crr = 1.0 / n * inp.real.pow(2).sum(dim=[0, 2, 3]) + self.eps
            Cii = 1.0 / n * inp.imag.pow(2).sum(dim=[0, 2, 3]) + self.eps
            Cri = (inp.real.mul(inp.imag)).mean(dim=[0, 2, 3])
        else:
            Crr = self.running_covar[:, 0] + self.eps
            Cii = self.running_covar[:, 1] + self.eps
            Cri = self.running_covar[:, 2]  # +self.eps

        if self.training and self.track_running_stats:
            # running covariance stores the unbiased (n / (n - 1)) estimate
            with torch.no_grad():
                self.running_covar[:, 0] = (
                    exponential_average_factor * Crr * n / (n - 1)  #
                    + (1 - exponential_average_factor) *
                    self.running_covar[:, 0]
                )

                self.running_covar[:, 1] = (
                    exponential_average_factor * Cii * n / (n - 1)
                    + (1 - exponential_average_factor) *
                    self.running_covar[:, 1]
                )

                self.running_covar[:, 2] = (
                    exponential_average_factor * Cri * n / (n - 1)
                    + (1 - exponential_average_factor) *
                    self.running_covar[:, 2]
                )

        # calculate the inverse square root the covariance matrix
        # closed form for a 2x2 SPD matrix V = [[Crr, Cri], [Cri, Cii]]:
        # V^(-1/2) = [[Cii + s, -Cri], [-Cri, Crr + s]] / (s * t),
        # with s = sqrt(det V) and t = sqrt(trace V + 2 s)
        det = Crr * Cii - Cri.pow(2)
        s = torch.sqrt(det)
        t = torch.sqrt(Cii + Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st

        inp = (
            Rrr[None, :, None, None] * inp.real +
            Rri[None, :, None, None] * inp.imag
        ).type(torch.complex64) + 1j * (
            Rii[None, :, None, None] * inp.imag +
            Rri[None, :, None, None] * inp.real
        ).type(
            torch.complex64
        )

        if self.affine:
            inp = (
                self.weight[None, :, 0, None, None] * inp.real
                + self.weight[None, :, 2, None, None] * inp.imag
                + self.bias[None, :, 0, None, None]
            ).type(torch.complex64) + 1j * (
                self.weight[None, :, 2, None, None] * inp.real
                + self.weight[None, :, 1, None, None] * inp.imag
                + self.bias[None, :, 1, None, None]
            ).type(
                torch.complex64
            )
        return inp


class ComplexBatchNorm1d(_ComplexBatchNorm):
    def forward(self, inp):

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / \
                        float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training or (not self.track_running_stats):
            # calculate mean of real and imaginary part
            mean_r = inp.real.mean(dim=0).type(torch.complex64)
            mean_i = inp.imag.mean(dim=0).type(torch.complex64)
            mean = mean_r + 1j * mean_i
        else:
            mean = self.running_mean

        if self.training and self.track_running_stats:
            # update running mean
            with torch.no_grad():
                self.running_mean = (
                    exponential_average_factor * mean
                    + (1 - exponential_average_factor) * self.running_mean
                )

        inp = inp - mean[None, ...]

        if self.training or (not self.track_running_stats):
            # Elements of the covariance matrix (biased for train)
            n = inp.numel() / inp.size(1)
            Crr = inp.real.var(dim=0, unbiased=False) + self.eps
            Cii = inp.imag.var(dim=0, unbiased=False) + self.eps
            Cri = (inp.real.mul(inp.imag)).mean(dim=0)
        else:
            Crr = self.running_covar[:, 0] + self.eps
            Cii = self.running_covar[:, 1] + self.eps
            Cri = self.running_covar[:, 2]

        if self.training and self.track_running_stats:
            with torch.no_grad():
                self.running_covar[:, 0] = (
                    exponential_average_factor * Crr * n / (n - 1)
                    + (1 - exponential_average_factor) *
                    self.running_covar[:, 0]
                )

                self.running_covar[:, 1] = (
                    exponential_average_factor * Cii * n / (n - 1)
                    + (1 - exponential_average_factor) *
                    self.running_covar[:, 1]
                )

                self.running_covar[:, 2] = (
                    exponential_average_factor * Cri * n / (n - 1)
                    + (1 - exponential_average_factor) *
                    self.running_covar[:, 2]
                )

        # calculate the inverse square root the covariance matrix
        det = Crr * Cii - Cri.pow(2)
        s = torch.sqrt(det)
        t = torch.sqrt(Cii + Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st

        inp = (Rrr[None, :] * inp.real + Rri[None, :] * inp.imag).type(
            torch.complex64
        ) + 1j * (Rii[None, :] * inp.imag + Rri[None, :] * inp.real).type(
            torch.complex64
        )

        if self.affine:
            inp = (
                self.weight[None, :, 0] * inp.real
                + self.weight[None, :, 2] * inp.imag
                + self.bias[None, :, 0]
            ).type(torch.complex64) + 1j * (
                self.weight[None, :, 2] * inp.real
                + self.weight[None, :, 1] * inp.imag
                + self.bias[None, :, 1]
            ).type(
                torch.complex64
            )

        del Crr, Cri, Cii, Rrr, Rii, Rri, det, s, t
| return inp 514 | 515 | 516 | class ComplexGRUCell(Module): 517 | """ 518 | A GRU cell for complex-valued inputs 519 | """ 520 | 521 | def __init__(self, input_length, hidden_length): 522 | super().__init__() 523 | self.input_length = input_length 524 | self.hidden_length = hidden_length 525 | 526 | # reset gate components 527 | self.linear_reset_w1 = ComplexLinear( 528 | self.input_length, self.hidden_length) 529 | self.linear_reset_r1 = ComplexLinear( 530 | self.hidden_length, self.hidden_length) 531 | 532 | self.linear_reset_w2 = ComplexLinear( 533 | self.input_length, self.hidden_length) 534 | self.linear_reset_r2 = ComplexLinear( 535 | self.hidden_length, self.hidden_length) 536 | 537 | # update gate components 538 | self.linear_gate_w3 = ComplexLinear( 539 | self.input_length, self.hidden_length) 540 | self.linear_gate_r3 = ComplexLinear( 541 | self.hidden_length, self.hidden_length) 542 | 543 | self.activation_gate = ComplexSigmoid() 544 | self.activation_candidate = ComplexTanh() 545 | 546 | def reset_gate(self, x, h): 547 | x_1 = self.linear_reset_w1(x) 548 | h_1 = self.linear_reset_r1(h) 549 | # gate update 550 | reset = self.activation_gate(x_1 + h_1) 551 | return reset 552 | 553 | def update_gate(self, x, h): 554 | x_2 = self.linear_reset_w2(x) 555 | h_2 = self.linear_reset_r2(h) 556 | z = self.activation_gate(h_2 + x_2) 557 | return z 558 | 559 | def update_component(self, x, h, r): 560 | x_3 = self.linear_gate_w3(x) 561 | h_3 = r * self.linear_gate_r3(h) # element-wise multiplication 562 | gate_update = self.activation_candidate(x_3 + h_3) 563 | return gate_update 564 | 565 | def forward(self, x, h): 566 | # Equation 1. 
reset gate vector 567 | r = self.reset_gate(x, h) 568 | 569 | # Equation 2: the update gate - the shared update gate vector z 570 | z = self.update_gate(x, h) 571 | 572 | # Equation 3: The almost output component 573 | n = self.update_component(x, h, r) 574 | 575 | # Equation 4: the new hidden state 576 | h_new = (1 + complex_opposite(z)) * n + \ 577 | z * h # element-wise multiplication 578 | return h_new 579 | 580 | 581 | class ComplexBNGRUCell(Module): 582 | """ 583 | A BN-GRU cell for complex-valued inputs 584 | """ 585 | 586 | def __init__(self, input_length=10, hidden_length=20): 587 | super().__init__() 588 | self.input_length = input_length 589 | self.hidden_length = hidden_length 590 | 591 | # reset gate components 592 | self.linear_reset_w1 = ComplexLinear( 593 | self.input_length, self.hidden_length) 594 | self.linear_reset_r1 = ComplexLinear( 595 | self.hidden_length, self.hidden_length) 596 | 597 | self.linear_reset_w2 = ComplexLinear( 598 | self.input_length, self.hidden_length) 599 | self.linear_reset_r2 = ComplexLinear( 600 | self.hidden_length, self.hidden_length) 601 | 602 | # update gate components 603 | self.linear_gate_w3 = ComplexLinear( 604 | self.input_length, self.hidden_length) 605 | self.linear_gate_r3 = ComplexLinear( 606 | self.hidden_length, self.hidden_length) 607 | 608 | self.activation_gate = ComplexSigmoid() 609 | self.activation_candidate = ComplexTanh() 610 | 611 | self.bn = ComplexBatchNorm2d(1) 612 | 613 | def reset_gate(self, x, h): 614 | x_1 = self.linear_reset_w1(x) 615 | h_1 = self.linear_reset_r1(h) 616 | # gate update 617 | reset = self.activation_gate(self.bn(x_1) + self.bn(h_1)) 618 | return reset 619 | 620 | def update_gate(self, x, h): 621 | x_2 = self.linear_reset_w2(x) 622 | h_2 = self.linear_reset_r2(h) 623 | z = self.activation_gate(self.bn(h_2) + self.bn(x_2)) 624 | return z 625 | 626 | def update_component(self, x, h, r): 627 | x_3 = self.linear_gate_w3(x) 628 | # element-wise multiplication 629 | h_3 = r * 
self.bn(self.linear_gate_r3(h)) 630 | gate_update = self.activation_candidate(self.bn(self.bn(x_3) + h_3)) 631 | return gate_update 632 | 633 | def forward(self, x, h): 634 | # Equation 1. reset gate vector 635 | r = self.reset_gate(x, h) 636 | 637 | # Equation 2: the update gate - the shared update gate vector z 638 | z = self.update_gate(x, h) 639 | 640 | # Equation 3: The almost output component 641 | n = self.update_component(x, h, r) 642 | 643 | # Equation 4: the new hidden state 644 | 645 | 646 | class ComplexGRU(Module): 647 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, 648 | batch_first=False, dropout=0, bidirectional=False): 649 | super().__init__() 650 | 651 | self.gru_re = GRU(input_size=input_size, hidden_size=hidden_size, 652 | num_layers=num_layers, bias=bias, 653 | batch_first=batch_first, dropout=dropout, 654 | bidirectional=bidirectional) 655 | self.gru_im = GRU(input_size=input_size, hidden_size=hidden_size, 656 | num_layers=num_layers, bias=bias, 657 | batch_first=batch_first, dropout=dropout, 658 | bidirectional=bidirectional) 659 | 660 | def forward(self, x): 661 | real, state_real = self._forward_real(x) 662 | imaginary, state_imag = self._forward_imaginary(x) 663 | 664 | output = torch.complex(real, imaginary) 665 | state = torch.complex(state_real, state_imag) 666 | 667 | return output, state 668 | 669 | def forward(self, x): 670 | r2r_out = self.gru_re(x.real)[0] 671 | r2i_out = self.gru_im(x.real)[0] 672 | i2r_out = self.gru_re(x.imag)[0] 673 | i2i_out = self.gru_im(x.imag)[0] 674 | real_out = r2r_out - i2i_out 675 | imag_out = i2r_out + r2i_out 676 | 677 | return torch.complex(real_out, imag_out), None 678 | 679 | def _forward_real(self, x): 680 | real_real, h_real = self.gru_re(x.real) 681 | imag_imag, h_imag = self.gru_im(x.imag) 682 | real = real_real - imag_imag 683 | 684 | return real, torch.complex(h_real, h_imag) 685 | 686 | def _forward_imaginary(self, x): 687 | imag_real, h_real = self.gru_re(x.imag) 688 
| real_imag, h_imag = self.gru_im(x.real) 689 | imaginary = imag_real + real_imag 690 | 691 | return imaginary, torch.complex(h_real, h_imag) 692 | 693 | 694 | class ComplexLSTM(Module): 695 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, 696 | batch_first=False, dropout=0, bidirectional=False): 697 | super().__init__() 698 | self.num_layer = num_layers 699 | self.hidden_size = hidden_size 700 | self.batch_dim = 0 if batch_first else 1 701 | self.bidirectional = bidirectional 702 | 703 | self.lstm_re = LSTM(input_size=input_size, hidden_size=hidden_size, 704 | num_layers=num_layers, bias=bias, 705 | batch_first=batch_first, dropout=dropout, 706 | bidirectional=bidirectional) 707 | self.lstm_im = LSTM(input_size=input_size, hidden_size=hidden_size, 708 | num_layers=num_layers, bias=bias, 709 | batch_first=batch_first, dropout=dropout, 710 | bidirectional=bidirectional) 711 | 712 | def forward(self, x): 713 | real, state_real = self._forward_real(x) 714 | imaginary, state_imag = self._forward_imaginary(x) 715 | 716 | output = torch.complex(real, imaginary) 717 | 718 | return output, (state_real, state_imag) 719 | 720 | def _forward_real(self, x): 721 | h_real, h_imag, c_real, c_imag = self._init_state( 722 | self._get_batch_size(x), x.is_cuda) 723 | real_real, (h_real, c_real) = self.lstm_re(x.real, (h_real, c_real)) 724 | imag_imag, (h_imag, c_imag) = self.lstm_im(x.imag, (h_imag, c_imag)) 725 | real = real_real - imag_imag 726 | return real, ((h_real, c_real), (h_imag, c_imag)) 727 | 728 | def _forward_imaginary(self, x): 729 | h_real, h_imag, c_real, c_imag = self._init_state( 730 | self._get_batch_size(x), x.is_cuda) 731 | imag_real, (h_real, c_real) = self.lstm_re(x.imag, (h_real, c_real)) 732 | real_imag, (h_imag, c_imag) = self.lstm_im(x.real, (h_imag, c_imag)) 733 | imaginary = imag_real + real_imag 734 | 735 | return imaginary, ((h_real, c_real), (h_imag, c_imag)) 736 | 737 | def _init_state(self, batch_size, to_gpu=False): 738 | dim_0 
= 2 if self.bidirectional else 1 739 | dims = (dim_0, batch_size, self.hidden_size) 740 | 741 | h_real, h_imag, c_real, c_imag = [ 742 | torch.zeros(dims) for i in range(4)] 743 | 744 | if to_gpu: 745 | h_real, h_imag, c_real, c_imag = [ 746 | t.cuda() for t in [h_real, h_imag, c_real, c_imag]] 747 | 748 | return h_real, h_imag, c_real, c_imag 749 | 750 | def _get_batch_size(self, x): 751 | return x.size(self.batch_dim) 752 | h_new = (1 + complex_opposite(z)) * n + \ 753 | z * h # element-wise multiplication 754 | 755 | return h_new 756 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="complexPyTorch", 5 | version="0.4.1", 6 | description="A high-level toolbox for using complex valued neural networks in PyTorch.", 7 | long_description=open("README.md").read().strip(), 8 | long_description_content_type="text/markdown", 9 | author="Sebastien M. Popoff", 10 | author_email="sebastien.popoff@espci.psl.eu", 11 | url="https://gitlab.institut-langevin.espci.fr/spopoff/complexPyTorch", 12 | packages=find_packages(), 13 | install_requires=["torch"], 14 | python_requires=">=3.6", 15 | license="MIT License", 16 | zip_safe=False, 17 | keywords="pytorch, deep learning, complex values", 18 | classifiers=[""], 19 | ) 20 | --------------------------------------------------------------------------------