├── tests ├── __init__.py ├── torch_complex_test.py └── pytorch_complex_tensor_test.py ├── pytorch_complex_tensor ├── __init__.py ├── complex_scalar.py ├── torch_complex.py ├── complex_grad.py └── complex_tensor.py ├── update.sh ├── requirements.txt ├── setup.py ├── LICENSE ├── .circleci └── config.yml ├── .gitignore ├── README.md └── Complex demo example.ipynb /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_complex_tensor/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_complex_tensor.complex_scalar import ComplexScalar 2 | from pytorch_complex_tensor.complex_grad import ComplexGrad 3 | from pytorch_complex_tensor.complex_tensor import ComplexTensor 4 | -------------------------------------------------------------------------------- /update.sh: --------------------------------------------------------------------------------
#!/bin/bash
# Release helper: commit, tag, push and upload the given version to PyPI.
# Usage: ./update.sh <version>
#
# Stop at the first failing command (and on unset variables) so that a failed
# commit/tag/push never goes on to upload a package to PyPI.
# Note: this also stops the release when there is nothing to commit.
set -euo pipefail

# An explicit version is required; an empty one would create a "release v"
# commit and a malformed tag.
version="${1:?usage: $0 <version>}"

git commit -am "release v$version"
git tag "$version" -m "pytorch_complex_tensor v$version"
git push --tags origin master

# push to pypi
rm -rf ./dist/*
python3 setup.py sdist
twine upload dist/*

# to update docs
# cd to root dir
# mkdocs gh-deploy
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atomicwrites==1.3.0 2 | attrs==18.2.0 3 | certifi==2018.11.29 4 | cffi==1.11.5 5 | chardet==3.0.4 6 | codecov==2.0.15 7 | coverage==4.5.2 8 | idna==2.8 9 | more-itertools==5.0.0 10 | numpy==1.15.4 11 | olefile==0.46 12 | Pillow==6.2.0 13 | pluggy==0.8.1 14 | py==1.7.0 15 | pycparser==2.19 16 | pytest==4.2.0 17 | requests==2.21.0 18 | six==1.12.0 19 | torch==1.0.1 20 | torchvision==0.2.1 21 | urllib3==1.24.2
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='pytorch-complex-tensor', 7 | version='0.0.134', 8 | description='Pytorch complex tensor', 9 | author='William Falcon', 10 | author_email='waf2107@columbia.edu', 11 | url='https://github.com/williamFalcon/pytorch-complex-tensor', 12 | install_requires=[ 13 | 'numpy>=1.15.4', 14 | 'torch>=1.0', 15 | 'torchvision>=0.2.1' 16 | ], 17 | packages=find_packages() 18 | ) 19 | -------------------------------------------------------------------------------- /pytorch_complex_tensor/complex_scalar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Thin wrapper for complex scalar. 3 | Main contribution is to use only real part for backward 4 | """ 5 | 6 | 7 | class ComplexScalar(object): 8 | 9 | def __init__(self, real, imag): 10 | self._real = real 11 | self._imag = imag 12 | 13 | @property 14 | def real(self): 15 | return self._real 16 | 17 | @property 18 | def imag(self): 19 | return self._imag 20 | 21 | def backward(self): 22 | self._real.backward() 23 | 24 | def __repr__(self): 25 | return str(complex(self.real.item(), self.imag.item())) 26 | 27 | def __str__(self): 28 | return self.__repr__() 29 | -------------------------------------------------------------------------------- /pytorch_complex_tensor/torch_complex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_complex_tensor import ComplexTensor 3 | 4 | 5 | def __graph_copy__(real, imag): 6 | # return tensor copy but maintain graph connections 7 | # force the result to be a ComplexTensor 8 | result = torch.cat([real, imag], dim=-2) 9 | result.__class__ = ComplexTensor 10 | return result 11 | 12 | 13 | def __apply_fx_to_parts(items, 
fx, *args, **kwargs): 14 | r = [x.real for x in items] 15 | r = fx(r, *args, **kwargs) 16 | 17 | i = [x.imag for x in items] 18 | i = fx(i, *args, **kwargs) 19 | 20 | return __graph_copy__(r, i) 21 | 22 | 23 | def stack(items, *args, **kwargs): 24 | return __apply_fx_to_parts(items, torch.stack, *args, **kwargs) 25 | 26 | 27 | def cat(items, *args, **kwargs): 28 | return __apply_fx_to_parts(items, torch.cat, *args, **kwargs) 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 William Falcon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/torch_complex_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from pytorch_complex_tensor import ComplexTensor 5 | from pytorch_complex_tensor import torch_complex 6 | 7 | 8 | def __test_torch_op(complex_op, torch_op): 9 | a = ComplexTensor(torch.zeros(4, 3)) + 1 10 | b = ComplexTensor(torch.zeros(4, 3)) + 2 11 | c = ComplexTensor(torch.zeros(4, 3)) + 3 12 | 13 | d = complex_op([a, b, c], dim=0) 14 | size = list(d.size()) 15 | 16 | # double second to last axis bc we always half it when generating tensors 17 | size[-2] *= 2 18 | 19 | # compare against regular torch implementation 20 | r_a = torch.zeros(4, 3) 21 | r_b = torch.zeros(4, 3) 22 | r_c = torch.zeros(4, 3) 23 | r_d = torch_op([r_a, r_b, r_c], dim=0) 24 | t_size = r_d.size() 25 | 26 | for i in range(len(size)): 27 | assert size[i] == t_size[i] 28 | 29 | 30 | def test_stack(): 31 | __test_torch_op(torch_complex.stack, torch.stack) 32 | 33 | 34 | def test_cat(): 35 | __test_torch_op(torch_complex.cat, torch.cat) 36 | 37 | 38 | if __name__ == '__main__': 39 | pytest.main([__file__]) 40 | 41 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | # specify the version you desire here 10 | # use `-browsers` prefix for selenium tests, e.g. 
`3.6.1-browsers` 11 | - image: circleci/python:3.6.1 12 | 13 | # Specify service dependencies here if necessary 14 | # CircleCI maintains a library of pre-built images 15 | # documented at https://circleci.com/docs/2.0/circleci-images/ 16 | # - image: circleci/postgres:9.4 17 | 18 | working_directory: ~/repo 19 | 20 | steps: 21 | - checkout 22 | 23 | # Download and cache dependencies 24 | - restore_cache: 25 | keys: 26 | - v1-dependencies-{{ checksum "requirements.txt" }} 27 | # fallback to using the latest cache if no exact match is found 28 | - v1-dependencies- 29 | 30 | - run: 31 | name: install dependencies 32 | command: | 33 | python3 -m venv venv 34 | . venv/bin/activate 35 | pip install -r requirements.txt 36 | pip install -e . 37 | 38 | - save_cache: 39 | paths: 40 | - ./venv 41 | key: v1-dependencies-{{ checksum "requirements.txt" }} 42 | 43 | # run tests! 44 | # this example uses Django's built-in test-runner 45 | # other common Python testing frameworks include pytest and nose 46 | # https://pytest.org 47 | # https://nose.readthedocs.io 48 | - run: 49 | name: run tests 50 | command: | 51 | . 
venv/bin/activate 52 | pytest 53 | codecov 54 | 55 | - store_artifacts: 56 | path: test-reports 57 | destination: test-reports 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # project 2 | .DS_Store 3 | dev.py 4 | app/models/ 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | example.py 10 | timit_data/ 11 | LJSpeech-1.1/ 12 | 13 | # C extensions 14 | *.so 15 | 16 | .idea/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | env/ 21 | ide_layouts/ 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # dotenv 93 | .env 94 | 95 | # virtualenv 96 | .venv 97 | venv/ 98 | ENV/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | -------------------------------------------------------------------------------- /pytorch_complex_tensor/complex_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | 5 | """ 6 | Does nothing except pretty print complex grad info 7 | """ 8 | 9 | class ComplexGrad(torch.Tensor): 10 | 11 | def __deepcopy__(self, memo): 12 | if not self.is_leaf: 13 | raise RuntimeError("Only Tensors created explicitly by the user " 14 | "(graph leaves) support the deepcopy protocol at the moment") 15 | if id(self) in memo: 16 | return memo[id(self)] 17 | with torch.no_grad(): 18 | if self.is_sparse: 19 | new_tensor = self.clone() 20 | 21 | # hack tensor to cast as complex 22 | new_tensor.__class__ = ComplexGrad 23 | else: 24 | new_storage = self.storage().__deepcopy__(memo) 25 | new_tensor = self.new() 26 | 27 | # hack tensor 
to cast as complex 28 | new_tensor.__class__ = ComplexGrad 29 | new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) 30 | memo[id(self)] = new_tensor 31 | new_tensor.requires_grad = self.requires_grad 32 | return new_tensor 33 | 34 | def __repr__(self): 35 | size = self.size() 36 | split_i = size[0] // 2 37 | real = self[:split_i] 38 | imag = self[split_i:] 39 | size_r = real.size() 40 | 41 | real = real.view(-1) 42 | imag = imag.view(-1) 43 | 44 | strings = np.asarray([f'({a}{"+" if b > 0 else "-"}{abs(b)}j)' for a, b in zip(real, imag)]) 45 | strings = strings.reshape(*size_r) 46 | strings = f'tensor({strings.__str__()})' 47 | strings = re.sub('\n', ',\n ', strings) 48 | return strings 49 | 50 | def __str__(self): 51 | return self.__repr__() 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 5 |

6 |

7 | Pytorch Complex Tensor 8 |

9 |

10 | Unofficial complex Tensor support for Pytorch 11 |

12 |

13 | PyPI version 14 | 15 | 16 | 17 |

18 | 19 | ### How it works 20 | 21 | Treats first half of tensor as real, second as imaginary. A few arithmetic operations are implemented to emulate complex arithmetic. Supports gradients. 22 | 23 | ### Installation 24 | ```bash 25 | pip install pytorch-complex-tensor 26 | ``` 27 | 28 | ### Example: 29 | Easy import 30 | 31 | ```python 32 | from pytorch_complex_tensor import ComplexTensor 33 | ``` 34 | 35 | Init tensor 36 | 37 | ```python 38 | # equivalent to: 39 | # np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64) 40 | C = ComplexTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]) 41 | C.requires_grad = True 42 | ``` 43 | 44 | Pretty printing 45 | 46 | ```python 47 | print(C) 48 | # tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'], 49 | # ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']]) 50 | ``` 51 | 52 | handles absolute value properly for complex tensors 53 | 54 | ```python 55 | # complex absolute value implementation 56 | print(C.abs()) 57 | # tensor([[3.1623, 3.1623, 3.1623], 58 | # [4.4721, 4.4721, 4.4721]], grad_fn=) 59 | ``` 60 | 61 | 62 | prints correct sizing treating first half of matrix as real, second as imag 63 | ```python 64 | print(C.size()) 65 | # torch.Size([2, 3]) 66 | ``` 67 | 68 | multiplies both complex and real tensors 69 | ```python 70 | # show matrix multiply with real tensor 71 | # also works with complex tensor 72 | x = torch.Tensor([[3, 3], [4, 4], [2, 2]]) 73 | xy = C.mm(x) 74 | print(xy) 75 | # tensor([['(9.0+27.0j)' '(9.0+27.0j)'], 76 | # ['(18.0+36.0j)' '(18.0+36.0j)']]) 77 | ``` 78 | 79 | reduce ops return ComplexScalar 80 | ```python 81 | xy = xy.sum() 82 | 83 | # this is now a complex scalar (thin wrapper with .real, .imag) 84 | print(type(xy)) 85 | # pytorch_complex_tensor.complex_scalar.ComplexScalar 86 | 87 | print(xy) 88 | # (54+126j) 89 | ``` 90 | 91 | which can be used for gradients without breaking anything... 
(differentiates wrt the real part) 92 | ```python 93 | # calculate dxy / dC 94 | # for complex scalars, grad is wrt the real part 95 | xy.backward() 96 | print(C.grad) 97 | # tensor([['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)'], 98 | # ['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)']]) 99 | ``` 100 | 101 | supports all section ops... 102 | ```python 103 | print(C[-1]) 104 | print(C[0, 0:-2, ...]) 105 | print(C[0, ..., 0]) 106 | ``` 107 | 108 | 109 | ### Supported ops: 110 | | Operation | complex tensor | real tensor | complex scalar | real scalar | 111 | | ----------| :-------------:|:-----------:|:--------------:|:-----------:| 112 | | addition | Y | Y | Y | Y | 113 | | subtraction | Y | Y | Y | Y | 114 | | multiply | Y | Y | Y | Y | 115 | | mm | Y | Y | Y | Y | 116 | | abs | Y | - | - | - | 117 | | t | Y | - | - | - | 118 | | grads | Y | Y | Y | Y | 119 | 120 | 121 | -------------------------------------------------------------------------------- /tests/pytorch_complex_tensor_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from pytorch_complex_tensor import ComplexTensor 5 | 6 | 7 | # ------------------ 8 | # GRAD TESTS 9 | # ------------------- 10 | def test_grad(): 11 | """ 12 | Grad calculated first with tensorflow 13 | 14 | :return: 15 | """ 16 | 17 | c = ComplexTensor([[1, 3, 5], [7, 9, 11], [2, 4, 6], [8, 10, 12]]) 18 | c.requires_grad = True 19 | 20 | # simulate some ops 21 | out = c + 4 22 | out = out.mm(c.t()) 23 | 24 | # calc grad 25 | out = out.sum() 26 | out.backward() 27 | 28 | # d_out/dc 29 | g = c.grad.view(-1).data.numpy() 30 | 31 | # solution (as provided by running same ops in tensorflow) 32 | """ 33 | tf_c2 = tf.constant([[1+2j, 3+4j, 5+6j], [7+8j,9+10j,11+12j]], dtype=tf.complex64) 34 | 35 | with tf.GradientTape() as t: 36 | t.watch(tf_c2) 37 | tf_out = tf_c2 + 4 38 | tf_out = tf.matmul(tf_out, tf.transpose(tf_c2, perm=[1,0])) 39 | 40 | tf_y = 
tf.reduce_sum(tf_out) 41 | dy_dc2 = t.gradient(tf_y, tf_c2) 42 | 43 | # solution 44 | print(dy_dc2) 45 | """ 46 | # 47 | sol = np.asarray([24, 32, 40, 24, 32, 40, -20, -28, -36, -20, -28, -36]) 48 | assert np.array_equal(g, sol) 49 | 50 | 51 | def test_size(): 52 | # test sizing when init with tensor 53 | c = ComplexTensor(torch.zeros(4, 3)) 54 | size = c.size() 55 | n, m = size[-2:] 56 | assert n == 2 57 | assert m == 3 58 | 59 | # test sizing when init with dim spec 60 | c = ComplexTensor(12, 8) 61 | size = c.size() 62 | n, m = size[-2:] 63 | assert n == 12 64 | assert m == 8 65 | 66 | 67 | def test_shape(): 68 | # test sizing when init with tensor 69 | c = ComplexTensor(torch.zeros(4, 3)) 70 | size = c.shape 71 | n, m = size[-2:] 72 | assert n == 2 73 | assert m == 3 74 | 75 | # test sizing when init with dim spec 76 | c = ComplexTensor(12, 8) 77 | size = c.shape 78 | n, m = size[-2:] 79 | assert n == 12 80 | assert m == 8 81 | 82 | 83 | # ------------------ 84 | # REDUCE FX TESTS 85 | # ------------------ 86 | def test_abs(): 87 | c = ComplexTensor(torch.zeros(4, 3)) + 2 88 | c = (4+3j) * c 89 | c = c.abs() 90 | c = c.view(-1).data.numpy() 91 | 92 | # do the same in numpy 93 | sol = np.zeros((2, 3)).astype(np.complex64) + 2 94 | sol = (4+3j) * sol 95 | sol = np.abs(sol) 96 | 97 | sol = sol.flatten() 98 | sol = list(sol.real) 99 | 100 | assert np.array_equal(c, sol) 101 | 102 | 103 | # ------------------ 104 | # SUM TESTS 105 | # ------------------ 106 | def test_real_scalar_sum(): 107 | 108 | c = ComplexTensor(torch.zeros(4, 3)) 109 | c = c + 4 110 | c = c.view(-1).data.numpy() 111 | 112 | # do the same in numpy 113 | sol = np.zeros((2, 3)).astype(np.complex64) 114 | sol = sol + 4 115 | sol = sol.flatten() 116 | sol = list(sol.real) + list(sol.imag) 117 | 118 | assert np.array_equal(c, sol) 119 | 120 | 121 | def test_complex_scalar_sum(): 122 | c = ComplexTensor(torch.zeros(4, 3)) 123 | c = c + (4+3j) 124 | c = c.view(-1).data.numpy() 125 | 126 | # do the same 
in numpy 127 | sol = np.zeros((2, 3)).astype(np.complex64) 128 | sol = sol + (4+3j) 129 | sol = sol.flatten() 130 | sol = list(sol.real) + list(sol.imag) 131 | 132 | assert np.array_equal(c, sol) 133 | 134 | 135 | def test_real_matrix_sum(): 136 | 137 | c = ComplexTensor(torch.zeros(4, 3)) 138 | r = torch.ones(2, 3) 139 | c = c + r 140 | c = c.view(-1).data.numpy() 141 | 142 | # do the same in numpy 143 | sol = np.zeros((2, 3)).astype(np.complex64) 144 | sol_r = np.ones((2, 3)) 145 | sol = sol + sol_r 146 | sol = sol.flatten() 147 | sol = list(sol.real) + list(sol.imag) 148 | 149 | assert np.array_equal(c, sol) 150 | 151 | 152 | def test_complex_matrix_sum(): 153 | 154 | c = ComplexTensor(torch.zeros(4, 3)) 155 | cc = c + c 156 | cc = cc.view(-1).data.numpy() 157 | 158 | # do the same in numpy 159 | sol = np.zeros((2, 3)).astype(np.complex64) 160 | sol = sol + sol 161 | sol = sol.flatten() 162 | sol = list(sol.real) + list(sol.imag) 163 | 164 | assert np.array_equal(cc, sol) 165 | 166 | 167 | # ------------------ 168 | # MULT TESTS 169 | # ------------------ 170 | def test_scalar_mult(): 171 | c = ComplexTensor(torch.zeros(4, 3)) + 1 172 | c = c * 4 173 | c = c.view(-1).data.numpy() 174 | 175 | # do the same in numpy 176 | sol = np.zeros((2, 3)).astype(np.complex64) + 1 177 | sol = sol * 4 178 | sol = sol.flatten() 179 | sol = list(sol.real) + list(sol.imag) 180 | 181 | assert np.array_equal(c, sol) 182 | 183 | 184 | def test_scalar_rmult(): 185 | c = ComplexTensor(torch.zeros(4, 3)) + 1 186 | c = 4 * c 187 | c = c.view(-1).data.numpy() 188 | 189 | # do the same in numpy 190 | sol = np.zeros((2, 3)).astype(np.complex64) + 1 191 | sol = 4 * sol 192 | sol = sol.flatten() 193 | sol = list(sol.real) + list(sol.imag) 194 | 195 | assert np.array_equal(c, sol) 196 | 197 | 198 | def test_complex_mult(): 199 | c = ComplexTensor(torch.zeros(4, 3)) + 1 200 | c = c * (4+3j) 201 | c = c.view(-1).data.numpy() 202 | 203 | # do the same in numpy 204 | sol = np.zeros((2, 
3)).astype(np.complex64) + 1 205 | sol = sol * (4+3j) 206 | sol = sol.flatten() 207 | sol = list(sol.real) + list(sol.imag) 208 | 209 | assert np.array_equal(c, sol) 210 | 211 | 212 | def test_complex_rmult(): 213 | c = ComplexTensor(torch.zeros(4, 3)) + 1 214 | c = (4+3j) * c 215 | c = c.view(-1).data.numpy() 216 | 217 | # do the same in numpy 218 | sol = np.zeros((2, 3)).astype(np.complex64) + 1 219 | sol = (4+3j) * sol 220 | sol = sol.flatten() 221 | sol = list(sol.real) + list(sol.imag) 222 | 223 | assert np.array_equal(c, sol) 224 | 225 | 226 | def test_complex_complex_ele_mult(): 227 | """ 228 | Complex mtx x complex mtx elementwise multiply 229 | :return: 230 | """ 231 | c = ComplexTensor(torch.zeros(4, 3)) + 1 232 | c = c * c 233 | c = c.view(-1).data.numpy() 234 | 235 | # do the same in numpy 236 | sol = np.zeros((2, 3)).astype(np.complex64) + 1 237 | sol = sol * sol 238 | sol = sol.flatten() 239 | sol = list(sol.real) + list(sol.imag) 240 | 241 | assert np.array_equal(c, sol) 242 | 243 | 244 | def test_complex_real_ele_mult(): 245 | """ 246 | Complex mtx x real mtx elementwise multiply 247 | :return: 248 | """ 249 | c = ComplexTensor(torch.zeros(4, 3)) + 1 250 | r = torch.ones(2, 3) * 2 + 3 251 | cr = c * r 252 | cr = cr.view(-1).data.numpy() 253 | 254 | # do the same in numpy 255 | np_c = np.ones((2, 3)).astype(np.complex64) 256 | np_r = np.ones((2, 3)) * 2 + 3 257 | np_cr = np_c * np_r 258 | 259 | # compare 260 | np_cr = np_cr.flatten() 261 | np_cr = list(np_cr.real) + list(np_cr.imag) 262 | 263 | assert np.array_equal(np_cr, cr) 264 | 265 | 266 | # ------------------ 267 | # MM TESTS 268 | # ------------------ 269 | def test_complex_real_mm(): 270 | """ 271 | Complex mtx x real mtx matrix multiply 272 | :return: 273 | """ 274 | c = ComplexTensor(torch.zeros(4, 3)) + 1 275 | r = torch.ones(2, 3) * 2 + 3 276 | cr = c.mm(r.t()) 277 | cr = cr.view(-1).data.numpy() 278 | 279 | # do the same in numpy 280 | np_c = np.ones((2, 3)).astype(np.complex64) 281 | 
np_r = np.ones((2, 3)) * 2 + 3 282 | np_cr = np.matmul(np_c, np_r.T) 283 | 284 | # compare 285 | np_cr = np_cr.flatten() 286 | np_cr = list(np_cr.real) + list(np_cr.imag) 287 | 288 | assert np.array_equal(np_cr, cr) 289 | 290 | 291 | def test_complex_complex_mm(): 292 | """ 293 | Complex mtx x complex mtx matrix multiply 294 | :return: 295 | """ 296 | c = ComplexTensor(torch.zeros(4, 3)) + 1 297 | cc = c.mm(c.t()) 298 | cc = cc.view(-1).data.numpy() 299 | 300 | # do the same in numpy 301 | np_c = np.ones((2, 3)).astype(np.complex64) 302 | np_cc = np.matmul(np_c, np_c.T) 303 | 304 | # compare 305 | np_cc = np_cc.flatten() 306 | np_cc = list(np_cc.real) + list(np_cc.imag) 307 | 308 | assert np.array_equal(np_cc, cc) 309 | 310 | 311 | def test_get_item(): 312 | # init random complex numpy and ct tensors 313 | a = np.random.randint(0, 10, (3, 2, 3)) 314 | a = a * (1+5j) 315 | ct = ComplexTensor(a) 316 | 317 | # match dim 0 318 | __assert_tensors_equal(ct[0], a[0]) 319 | __assert_tensors_equal(ct[-1], a[-1]) 320 | 321 | # match dim 1 322 | __assert_tensors_equal(ct[:, 0], a[:, 0]) 323 | __assert_tensors_equal(ct[:, -1], a[:, -1]) 324 | 325 | # match dim 2 326 | __assert_tensors_equal(ct[:, :, 0], a[:, :, 0]) 327 | __assert_tensors_equal(ct[:, :, -1], a[:, :, -1]) 328 | 329 | # match ranges 330 | __assert_tensors_equal(ct[0:1, 0, -2:], a[0:1, 0, -2:]) 331 | __assert_tensors_equal(ct[-1:, -1:, -2:], a[-1:, -1:, -2:]) 332 | __assert_tensors_equal(ct[:-1, :-1, :-2], a[:-1, :-1, :-2]) 333 | 334 | 335 | def __assert_tensors_equal(ct_tensor, np_tensor): 336 | # assert we have complexTensor 337 | assert type(ct_tensor) is ComplexTensor 338 | 339 | # assert values are same 340 | np_tensor = np_tensor.flatten() 341 | np_tensor = list(np_tensor.real) + list(np_tensor.imag) 342 | ct_tensor = ct_tensor.flatten().data.numpy() 343 | assert np.array_equal(np_tensor, ct_tensor) 344 | 345 | 346 | if __name__ == '__main__': 347 | pytest.main([__file__]) 348 | 
-------------------------------------------------------------------------------- /pytorch_complex_tensor/complex_tensor.py: -------------------------------------------------------------------------------- 1 | from pytorch_complex_tensor import ComplexScalar, ComplexGrad 2 | 3 | import inspect 4 | import numpy as np 5 | import torch 6 | import re 7 | 8 | 9 | """ 10 | Complex tensor support for PyTorch. 11 | 12 | Uses a regular tensor where the first half are the real numbers and second are the imaginary. 13 | 14 | Supports only some basic operations without breaking the gradients for complex math. 15 | 16 | Supported ops: 17 | 1. addition 18 | - (tensor, scalar). Both complex and real. 19 | 2. subtraction 20 | - (tensor, scalar). Both complex and real. 21 | 3. multiply 22 | - (tensor, scalar). Both complex and real. 23 | 4. mm (matrix multiply) 24 | - (tensor). Both complex and real. 25 | 5. abs (absolute value) 26 | 6. all indexing ops. 27 | 7. t (transpose) 28 | 29 | >> c = ComplexTensor(10, 20) 30 | 31 | >> # do regular tensor ops now 32 | >> c = c * 4 33 | >> c = c.mm(c.t()) 34 | >> print(c.shape, c.size(0)) 35 | >> print(c) 36 | >> print(c[0:1, 1:-1]) 37 | """ 38 | 39 | 40 | class ComplexTensor(torch.Tensor): 41 | 42 | @staticmethod 43 | def __new__(cls, x, *args, **kwargs): 44 | # requested to init with dim list 45 | # double the second to last dim (..., 1, 3, 2) -> (..., 1, 6, 2) 46 | 47 | # reformat complex numpy arrays so we can init with them 48 | if isinstance(x, np.ndarray) and 'complex' in str(x.dtype): 49 | # collapse second to last dim 50 | r = x.real 51 | i = x.imag 52 | x = np.concatenate([r, i], axis=-2) 53 | 54 | # x is the second to last dim in this case 55 | if type(x) is int and len(args) == 1: 56 | x = x * 2 57 | 58 | elif len(args) >= 2: 59 | size_args = list(args) 60 | size_args[-2] *= 2 61 | args = tuple(size_args) 62 | 63 | else: 64 | if isinstance(x, torch.Tensor): 65 | s = x.size()[-2] 66 | elif isinstance(x, list): 67 | s = len(x) 68 | 
elif isinstance(x, np.ndarray): 69 | s = x.shape[-2] 70 | if not (s % 2 == 0): raise Exception('second to last dim must be even. ComplexTensor is 2 real matrices under the hood') 71 | 72 | # init new t 73 | new_t = super().__new__(cls, x, *args, **kwargs) 74 | return new_t 75 | 76 | def __deepcopy__(self, memo): 77 | if not self.is_leaf: 78 | raise RuntimeError("Only Tensors created explicitly by the user " 79 | "(graph leaves) support the deepcopy protocol at the moment") 80 | if id(self) in memo: 81 | return memo[id(self)] 82 | with torch.no_grad(): 83 | if self.is_sparse: 84 | new_tensor = self.clone() 85 | 86 | # hack tensor to cast as complex 87 | new_tensor.__class__ = ComplexTensor 88 | else: 89 | new_storage = self.storage().__deepcopy__(memo) 90 | new_tensor = self.new() 91 | 92 | # hack tensor to cast as complex 93 | new_tensor.__class__ = ComplexTensor 94 | new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) 95 | memo[id(self)] = new_tensor 96 | new_tensor.requires_grad = self.requires_grad 97 | return new_tensor 98 | 99 | @property 100 | def real(self): 101 | size = self.size() 102 | n = size[-2] 103 | n = n * 2 104 | result = self[..., :n//2, :] 105 | return result 106 | 107 | @property 108 | def imag(self): 109 | size = self.size() 110 | n = size[-2] 111 | n = n * 2 112 | result = self[..., n//2:, :] 113 | return result 114 | 115 | def __graph_copy__(self, real, imag): 116 | # return tensor copy but maintain graph connections 117 | # force the result to be a ComplexTensor 118 | result = torch.cat([real, imag], dim=0) 119 | result.__class__ = ComplexTensor 120 | return result 121 | 122 | def __graph_copy_scalar__(self, real, imag): 123 | # return tensor copy but maintain graph connections 124 | # force the result to be a ComplexTensor 125 | result = torch.stack([real, imag], dim=-2) 126 | result.__class__ = ComplexScalar 127 | return result 128 | 129 | def __add__(self, other): 130 | """ 131 | Handles scalar (real, complex) 
and tensor (real, complex) addition 132 | :param other: 133 | :return: 134 | """ 135 | real = self.real 136 | imag = self.imag 137 | 138 | # given a real tensor 139 | if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor: 140 | real = real + other 141 | 142 | # given a complex tensor 143 | elif type(other) is ComplexTensor: 144 | real = real + other.real 145 | imag = imag + other.imag 146 | 147 | # given a real scalar 148 | elif np.isreal(other): 149 | real = real + other 150 | 151 | # given a complex scalar 152 | else: 153 | real = real + other.real 154 | imag = imag + other.imag 155 | 156 | return self.__graph_copy__(real, imag) 157 | 158 | def __radd__(self, other): 159 | return self.__add__(other) 160 | 161 | def __sub__(self, other): 162 | """ 163 | Handles scalar (real, complex) and tensor (real, complex) addition 164 | :param other: 165 | :return: 166 | """ 167 | real = self.real 168 | imag = self.imag 169 | 170 | # given a real tensor 171 | if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor: 172 | real = real - other 173 | 174 | # given a complex tensor 175 | elif type(other) is ComplexTensor: 176 | real = real - other.real 177 | imag = imag - other.imag 178 | 179 | # given a real scalar 180 | elif np.isreal(other): 181 | real = real - other 182 | 183 | # given a complex scalar 184 | else: 185 | real = real - other.real 186 | imag = imag - other.imag 187 | 188 | return self.__graph_copy__(real, imag) 189 | 190 | def __rsub__(self, other): 191 | return self.__sub__(other) 192 | 193 | def __mul__(self, other): 194 | """ 195 | Handles scalar (real, complex) and tensor (real, complex) multiplication 196 | :param other: 197 | :return: 198 | """ 199 | real = self.real.clone() 200 | imag = self.imag.clone() 201 | 202 | # given a real tensor 203 | if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor: 204 | real = real * other 205 | imag = imag * other 206 | 207 | # given a complex tensor 208 | elif 
type(other) is ComplexTensor: 209 | ac = real * other.real 210 | bd = imag * other.imag 211 | ad = real * other.imag 212 | bc = imag * other.real 213 | real = ac - bd 214 | imag = ad + bc 215 | 216 | # given a real scalar 217 | elif np.isreal(other): 218 | real = real * other 219 | imag = imag * other 220 | 221 | # given a complex scalar 222 | else: 223 | ac = real * other.real 224 | bd = imag * other.imag 225 | ad = real * other.imag 226 | bc = imag * other.real 227 | real = ac - bd 228 | imag = ad + bc 229 | 230 | return self.__graph_copy__(real, imag) 231 | 232 | def __truediv__(self, other): 233 | real = self.real.clone() 234 | imag = self.imag.clone() 235 | 236 | # given a real tensor 237 | if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor: 238 | raise NotImplementedError 239 | 240 | # given a complex tensor 241 | elif type(other) is ComplexTensor: 242 | raise NotImplementedError 243 | 244 | # given a real scalar 245 | elif np.isreal(other): 246 | real = real / other 247 | imag = imag / other 248 | 249 | # given a complex scalar 250 | else: 251 | raise NotImplementedError 252 | 253 | return self.__graph_copy__(real, imag) 254 | 255 | def __rmul__(self, other): 256 | return self.__mul__(other) 257 | 258 | def __neg__(self): 259 | return self.__mul__(-1) 260 | 261 | def mm(self, other): 262 | """ 263 | Handles tensor (real, complex) matrix multiply 264 | :param other: 265 | :return: 266 | """ 267 | real = self.real.clone() 268 | imag = self.imag.clone() 269 | 270 | # given a real tensor 271 | if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor: 272 | real = real.mm(other) 273 | imag = imag.mm(other) 274 | 275 | # given a complex tensor 276 | elif type(other) is ComplexTensor: 277 | ac = real.mm(other.real) 278 | bd = imag.mm(other.imag) 279 | ad = real.mm(other.imag) 280 | bc = imag.mm(other.real) 281 | real = ac - bd 282 | imag = ad + bc 283 | 284 | return self.__graph_copy__(real, imag) 285 | 286 | def t(self): 287 | 
real = self.real.t() 288 | imag = self.imag.t() 289 | 290 | return self.__graph_copy__(real, imag) 291 | 292 | def abs(self): 293 | result = torch.sqrt(self.real**2 + self.imag**2) 294 | return result 295 | 296 | def sum(self, *args): 297 | real_sum = self.real.sum(*args) 298 | imag_sum = self.imag.sum(*args) 299 | return ComplexScalar(real_sum, imag_sum) 300 | 301 | def mean(self, *args): 302 | real_mean = self.real.mean(*args) 303 | imag_mean = self.imag.mean(*args) 304 | return ComplexScalar(real_mean, imag_mean) 305 | 306 | @property 307 | def grad(self): 308 | g = self._grad 309 | g.__class__ = ComplexGrad 310 | 311 | return g 312 | 313 | def cuda(self): 314 | real = self.real.cuda() 315 | imag = self.imag.cuda() 316 | 317 | return self.__graph_copy__(real, imag) 318 | 319 | def __repr__(self): 320 | real = self.real.flatten() 321 | imag = self.imag.flatten() 322 | 323 | # use numpy to print for us 324 | # strings = np.asarray([f'({a}{"+" if b > 0 else "-"}{abs(b)}j)' for a, b in zip(real, imag)]) 325 | strings = np.asarray([complex(a,b) for a, b in zip(real, imag)]).astype(np.complex64) 326 | strings = strings.__repr__() 327 | strings = re.sub('array', 'tensor', strings) 328 | return strings 329 | 330 | def __str__(self): 331 | return self.__repr__() 332 | 333 | def is_complex(self): 334 | return True 335 | 336 | def size(self, *args): 337 | size = self.data.size(*args) 338 | size = list(size) 339 | size[-2] = size[-2] // 2 340 | size = torch.Size(size) 341 | return size 342 | 343 | @property 344 | def shape(self): 345 | size = self.data.shape 346 | size = list(size) 347 | size[-2] = size[-2] // 2 348 | size = torch.Size(size) 349 | return size 350 | 351 | def __getitem__(self, item): 352 | 353 | # when real or imag is the caller return regular tensor 354 | curframe = inspect.currentframe() 355 | calframe = inspect.getouterframes(curframe, 2) 356 | caller = calframe[1][3] 357 | 358 | if caller == 'real' or caller == 'imag': 359 | return super(ComplexTensor, 
self).__getitem__(item) 360 | 361 | # this is a regular index op, select the requested pairs then form a new ComplexTensor 362 | r = self.real[item] 363 | c = self.imag[item] 364 | 365 | return self.__graph_copy__(r, c) 366 | 367 | -------------------------------------------------------------------------------- /Complex demo example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## PT complex tensor examples" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import torch\n", 18 | "from pytorch_complex_tensor import ComplexTensor\n", 19 | "import tensorflow as tf\n", 20 | "tf.enable_eager_execution()" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "--- \n", 28 | "Creation" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "array([[1.+3.j, 1.+3.j, 1.+3.j],\n", 40 | " [2.+4.j, 2.+4.j, 2.+4.j]], dtype=complex64)" 41 | ] 42 | }, 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "# numpy complex tensor\n", 50 | "np_c = np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)\n", 51 | "np_c" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],\n", 64 | " ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "# torch equivalent\n", 70 | "pt_c = ComplexTensor([[1, 1, 1], [2,2,2], [3,3,3], [4,4,4]])\n", 71 | "print(pt_c)" 72 | ] 73 | }, 74 | { 75 | 
"cell_type": "code", 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "[[1. 1. 1.]\n", 84 | " [2. 2. 2.]]\n", 85 | "tensor([[1., 1., 1.],\n", 86 | " [2., 2., 2.]])\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "# verify reals match\n", 92 | "print(np_c.real)\n", 93 | "print(pt_c.real)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "[[3. 3. 3.]\n", 106 | " [4. 4. 4.]]\n", 107 | "tensor([[3., 3., 3.],\n", 108 | " [4., 4., 4.]])\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "# verify imag match\n", 114 | "print(np_c.imag)\n", 115 | "print(pt_c.imag)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "--- \n", 123 | "Verify complex addition" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/plain": [ 134 | "array([[4.+5.j, 4.+5.j, 4.+5.j],\n", 135 | " [5.+6.j, 5.+6.j, 5.+6.j]], dtype=complex64)" 136 | ] 137 | }, 138 | "execution_count": 6, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "np_c + (3+2j)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "tensor([['(4.0+5.0j)' '(4.0+5.0j)' '(4.0+5.0j)'],\n", 156 | " ['(5.0+6.0j)' '(5.0+6.0j)' '(5.0+6.0j)']])" 157 | ] 158 | }, 159 | "execution_count": 7, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "pt_c + (3 + 2j)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "--- \n", 173 | "verify abs" 174 | ] 175 | }, 176 | { 177 | 
"cell_type": "code", 178 | "execution_count": 8, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/plain": [ 184 | "array([[3.1622777, 3.1622777, 3.1622777],\n", 185 | " [4.472136 , 4.472136 , 4.472136 ]], dtype=float32)" 186 | ] 187 | }, 188 | "execution_count": 8, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "np.abs(np_c)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 9, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "tensor([[3.1623, 3.1623, 3.1623],\n", 206 | " [4.4721, 4.4721, 4.4721]])" 207 | ] 208 | }, 209 | "execution_count": 9, 210 | "metadata": {}, 211 | "output_type": "execute_result" 212 | } 213 | ], 214 | "source": [ 215 | "pt_c.abs()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "--- \n", 223 | "verify complex vs real matrix multiply" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 10, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "[[3 3]\n", 236 | " [4 4]\n", 237 | " [2 2]]\n", 238 | "tensor([[3., 3.],\n", 239 | " [4., 4.],\n", 240 | " [2., 2.]])\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "np_x = np.asarray([[3, 3], [4, 4], [2, 2]])\n", 246 | "pt_x = torch.Tensor([[3, 3], [4, 4], [2, 2]])\n", 247 | "\n", 248 | "print(np_x)\n", 249 | "print(pt_x)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 11, 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "array([[ 9.+27.j, 9.+27.j],\n", 261 | " [18.+36.j, 18.+36.j]])" 262 | ] 263 | }, 264 | "execution_count": 11, 265 | "metadata": {}, 266 | "output_type": "execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "np_mm_out = np.matmul(np_c, np_x)\n", 271 | "np_mm_out" 272 | ] 273 | }, 274 | { 275 | 
"cell_type": "code", 276 | "execution_count": 12, 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "data": { 281 | "text/plain": [ 282 | "tensor([['(9.0+27.0j)' '(9.0+27.0j)'],\n", 283 | " ['(18.0+36.0j)' '(18.0+36.0j)']])" 284 | ] 285 | }, 286 | "execution_count": 12, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "pt_mm_out = pt_c.mm(pt_x)\n", 293 | "pt_mm_out" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 13, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "[[ 9. 9.]\n", 306 | " [18. 18.]]\n", 307 | "tensor([[ 9., 9.],\n", 308 | " [18., 18.]])\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "# verify reals\n", 314 | "print(np_mm_out.real)\n", 315 | "print(pt_mm_out.real)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 14, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "[[27. 27.]\n", 328 | " [36. 
36.]]\n", 329 | "tensor([[27., 27.],\n", 330 | " [36., 36.]])\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "# verify imags\n", 336 | "print(np_mm_out.imag)\n", 337 | "print(pt_mm_out.imag)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": {}, 343 | "source": [ 344 | "--- \n", 345 | "verify transpose" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 15, 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "data": { 355 | "text/plain": [ 356 | "array([[1.+3.j, 2.+4.j],\n", 357 | " [1.+3.j, 2.+4.j],\n", 358 | " [1.+3.j, 2.+4.j]], dtype=complex64)" 359 | ] 360 | }, 361 | "execution_count": 15, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "np_c.T" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 16, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "data": { 377 | "text/plain": [ 378 | "tensor([['(1.0+3.0j)' '(2.0+4.0j)'],\n", 379 | " ['(1.0+3.0j)' '(2.0+4.0j)'],\n", 380 | " ['(1.0+3.0j)' '(2.0+4.0j)']])" 381 | ] 382 | }, 383 | "execution_count": 16, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "pt_c.t()" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "metadata": {}, 395 | "source": [ 396 | "--- \n", 397 | "## wfalcon/pytorch-complex-tensor\n" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 17, 403 | "metadata": {}, 404 | "outputs": [ 405 | { 406 | "name": "stdout", 407 | "output_type": "stream", 408 | "text": [ 409 | "tensor([['(1.0+2.0j)' '(3.0+4.0j)' '(5.0+6.0j)'],\n", 410 | " ['(7.0+8.0j)' '(9.0+10.0j)' '(11.0+12.0j)']])\n" 411 | ] 412 | } 413 | ], 414 | "source": [ 415 | "pt_c2 = ComplexTensor([[1, 3, 5], [7,9,11], [2,4,6], [8,10,12]])\n", 416 | "print(pt_c2)\n", 417 | "pt_c2.requires_grad = True" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 18, 423 | "metadata": {}, 424 | "outputs": [ 
425 | { 426 | "name": "stdout", 427 | "output_type": "stream", 428 | "text": [ 429 | "tensor([['(15.0+136.0j)' '(69.0+334.0j)'],\n", 430 | " ['(-3.0+262.0j)' '(51.0+676.0j)']])\n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "out = pt_c2 + 4\n", 436 | "out = out.mm(pt_c2.t())\n", 437 | "print(out)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 19, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "tensor(132., grad_fn=)\n", 450 | "tensor(1408., grad_fn=)\n" 451 | ] 452 | } 453 | ], 454 | "source": [ 455 | "real_sum = out.real.sum()\n", 456 | "print(real_sum)\n", 457 | "out_imag = out.imag.sum()\n", 458 | "print(out_imag)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 20, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "real_sum.backward()" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 21, 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "data": { 477 | "text/plain": [ 478 | "tensor([['(24.0-20.0j)' '(32.0-28.0j)' '(40.0-36.0j)'],\n", 479 | " ['(24.0-20.0j)' '(32.0-28.0j)' '(40.0-36.0j)']])" 480 | ] 481 | }, 482 | "execution_count": 21, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "pt_c2.grad" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "--- \n", 496 | "### Use Tensorflow to compute grads\n" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 6, 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "tf_c2 = tf.constant([[1+2j, 3+4j, 5+6j], [7+8j,9+10j,11+12j]], dtype=tf.complex64)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 7, 511 | "metadata": {}, 512 | "outputs": [ 513 | { 514 | "name": "stdout", 515 | "output_type": "stream", 516 | "text": [ 517 | "tf.Tensor(\n", 518 | "[[15.+136.j 
69.+334.j]\n", 519 | " [-3.+262.j 51.+676.j]], shape=(2, 2), dtype=complex64)\n", 520 | "tf.Tensor((132+1408j), shape=(), dtype=complex64)\n" 521 | ] 522 | } 523 | ], 524 | "source": [ 525 | "with tf.GradientTape() as t:\n", 526 | " t.watch(tf_c2)\n", 527 | " tf_out = tf_c2 + 4\n", 528 | " tf_out = tf.matmul(tf_out, tf.transpose(tf_c2, perm=[1,0]))\n", 529 | " print(tf_out)\n", 530 | " \n", 531 | " tf_y = tf.reduce_sum(tf_out)\n", 532 | " print(tf_y)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 8, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "dy_dc2 = t.gradient(tf_y, tf_c2)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 9, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "data": { 551 | "text/plain": [ 552 | "" 555 | ] 556 | }, 557 | "execution_count": 9, 558 | "metadata": {}, 559 | "output_type": "execute_result" 560 | } 561 | ], 562 | "source": [ 563 | "dy_dc2" 564 | ] 565 | } 566 | ], 567 | "metadata": { 568 | "kernelspec": { 569 | "display_name": "Python [conda env:organics]", 570 | "language": "python", 571 | "name": "conda-env-organics-py" 572 | }, 573 | "language_info": { 574 | "codemirror_mode": { 575 | "name": "ipython", 576 | "version": 3 577 | }, 578 | "file_extension": ".py", 579 | "mimetype": "text/x-python", 580 | "name": "python", 581 | "nbconvert_exporter": "python", 582 | "pygments_lexer": "ipython3", 583 | "version": "3.6.6" 584 | } 585 | }, 586 | "nbformat": 4, 587 | "nbformat_minor": 2 588 | } 589 | --------------------------------------------------------------------------------