├── tests
├── __init__.py
├── torch_complex_test.py
└── pytorch_complex_tensor_test.py
├── pytorch_complex_tensor
├── __init__.py
├── complex_scalar.py
├── torch_complex.py
├── complex_grad.py
└── complex_tensor.py
├── update.sh
├── requirements.txt
├── setup.py
├── LICENSE
├── .circleci
└── config.yml
├── .gitignore
├── README.md
└── Complex demo example.ipynb
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pytorch_complex_tensor/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_complex_tensor.complex_scalar import ComplexScalar
2 | from pytorch_complex_tensor.complex_grad import ComplexGrad
3 | from pytorch_complex_tensor.complex_tensor import ComplexTensor
4 |
--------------------------------------------------------------------------------
/update.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Release helper: commit, tag and push a version, then upload the sdist to PyPI.
# Usage: ./update.sh <version>
#
# Stop on the first failing step so a broken commit/tag never reaches PyPI.
set -euo pipefail

if [ "$#" -ne 1 ] || [ -z "$1" ]; then
    echo "usage: $0 <version>" >&2
    exit 1
fi

version=$1

git commit -am "release v$version"
git tag "$version" -m "pytorch_complex_tensor v$version"
git push --tags origin master

# push to pypi
rm -rf ./dist/*
python3 setup.py sdist
twine upload dist/*

# to update docs
# cd to root dir
# mkdocs gh-deploy
19 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | atomicwrites==1.3.0
2 | attrs==18.2.0
3 | certifi==2018.11.29
4 | cffi==1.11.5
5 | chardet==3.0.4
6 | codecov==2.0.15
7 | coverage==4.5.2
8 | idna==2.8
9 | more-itertools==5.0.0
10 | numpy==1.15.4
11 | olefile==0.46
12 | Pillow==6.2.0
13 | pluggy==0.8.1
14 | py==1.7.0
15 | pycparser==2.19
16 | pytest==4.2.0
17 | requests==2.21.0
18 | six==1.12.0
19 | torch==1.0.1
20 | torchvision==0.2.1
21 | urllib3==1.24.2
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

from setuptools import setup, find_packages

setup(
    name='pytorch-complex-tensor',
    version='0.0.134',
    description='Pytorch complex tensor',
    author='William Falcon',
    author_email='waf2107@columbia.edu',
    url='https://github.com/williamFalcon/pytorch-complex-tensor',
    install_requires=[
        'numpy>=1.15.4',
        'torch>=1.0',
        'torchvision>=0.2.1'
    ],
    # The repo's `tests` directory is a package (it has an __init__.py), so a bare
    # find_packages() would install it as a top-level `tests` module; ship only the library.
    packages=find_packages(exclude=['tests', 'tests.*'])
)
19 |
--------------------------------------------------------------------------------
/pytorch_complex_tensor/complex_scalar.py:
--------------------------------------------------------------------------------
1 | """
2 | Thin wrapper for complex scalar.
3 | Main contribution is to use only real part for backward
4 | """
5 |
6 |
class ComplexScalar(object):
    """Thin pair of tensors holding the real and imaginary part of a reduction result.

    Returned by ``ComplexTensor.sum`` / ``ComplexTensor.mean``. ``backward`` only
    propagates through the real part.
    """

    def __init__(self, real, imag):
        self._real = real
        self._imag = imag

    @property
    def real(self):
        """Real part (as passed to the constructor)."""
        return self._real

    @property
    def imag(self):
        """Imaginary part (as passed to the constructor)."""
        return self._imag

    def backward(self, *args, **kwargs):
        """Backpropagate from the real part only.

        Any arguments (e.g. ``gradient``, ``retain_graph``) are forwarded to the
        real part's ``backward``; previously they were rejected.
        """
        self._real.backward(*args, **kwargs)

    def __repr__(self):
        """Render as a Python complex number, or a numpy complex array if the parts are not scalars."""
        real, imag = self._real, self._imag
        if real.numel() == 1:
            return str(complex(real.item(), imag.item()))
        # reductions over a dim yield multi-element parts; .item() would raise there
        values = real.detach().cpu().numpy() + 1j * imag.detach().cpu().numpy()
        return str(values)

    def __str__(self):
        return self.__repr__()
29 |
--------------------------------------------------------------------------------
/pytorch_complex_tensor/torch_complex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_complex_tensor import ComplexTensor
3 |
4 |
def __graph_copy__(real, imag):
    """Join real and imaginary parts into a single ComplexTensor.

    The parts are concatenated along the second-to-last dim with ``torch.cat``,
    which keeps their autograd history. The result is then re-labelled as a
    ComplexTensor by swapping its ``__class__``.
    """
    joined = torch.cat((real, imag), dim=-2)
    joined.__class__ = ComplexTensor
    return joined
11 |
12 |
def __apply_fx_to_parts(items, fx, *args, **kwargs):
    """Apply a list-consuming torch function to the real and imaginary parts separately.

    :param items: sequence whose elements expose ``.real`` / ``.imag``
        (presumably ComplexTensors — callers pass them).
    :param fx: function taking a list of tensors, e.g. ``torch.stack`` or ``torch.cat``.
    :return: ComplexTensor built from ``fx(reals)`` and ``fx(imags)``.
    """
    real_out = fx([item.real for item in items], *args, **kwargs)
    imag_out = fx([item.imag for item in items], *args, **kwargs)
    return __graph_copy__(real_out, imag_out)
21 |
22 |
def stack(items, *args, **kwargs):
    """Complex analogue of ``torch.stack``: stacks real and imaginary parts separately.

    Extra arguments (e.g. ``dim``) are forwarded to ``torch.stack``.
    """
    op = torch.stack
    return __apply_fx_to_parts(items, op, *args, **kwargs)
25 |
26 |
def cat(items, *args, **kwargs):
    """Complex analogue of ``torch.cat``: concatenates real and imaginary parts separately.

    Extra arguments (e.g. ``dim``) are forwarded to ``torch.cat``.
    """
    op = torch.cat
    return __apply_fx_to_parts(items, op, *args, **kwargs)
29 |
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 William Falcon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/torch_complex_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from pytorch_complex_tensor import ComplexTensor
5 | from pytorch_complex_tensor import torch_complex
6 |
7 |
def __test_torch_op(complex_op, torch_op):
    """Check that ``complex_op`` yields the same storage shape as ``torch_op`` on real tensors.

    :param complex_op: complex list-op under test (e.g. ``torch_complex.stack``).
    :param torch_op: matching torch op (e.g. ``torch.stack``).
    """
    a = ComplexTensor(torch.zeros(4, 3)) + 1
    b = ComplexTensor(torch.zeros(4, 3)) + 2
    c = ComplexTensor(torch.zeros(4, 3)) + 3

    d = complex_op([a, b, c], dim=0)
    size = list(d.size())

    # double second to last axis bc we always half it when generating tensors
    size[-2] *= 2

    # compare against regular torch implementation
    r_a = torch.zeros(4, 3)
    r_b = torch.zeros(4, 3)
    r_c = torch.zeros(4, 3)
    r_d = torch_op([r_a, r_b, r_c], dim=0)

    # compare whole shapes: a per-index loop raised IndexError or skipped dims on rank mismatch
    assert size == list(r_d.size())
28 |
29 |
def test_stack():
    """torch_complex.stack must match torch.stack on the stacked storage shape."""
    __test_torch_op(complex_op=torch_complex.stack, torch_op=torch.stack)
32 |
33 |
def test_cat():
    """torch_complex.cat must match torch.cat on the stacked storage shape."""
    __test_torch_op(complex_op=torch_complex.cat, torch_op=torch.cat)
36 |
37 |
if __name__ == '__main__':
    # allow running this file directly: delegate to pytest for just this module
    pytest.main(args=[__file__])
40 |
41 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | # Python CircleCI 2.0 configuration file
2 | #
3 | # Check https://circleci.com/docs/2.0/language-python/ for more details
4 | #
5 | version: 2
6 | jobs:
7 | build:
8 | docker:
9 | # specify the version you desire here
10 | # use `-browsers` prefix for selenium tests, e.g. `3.6.1-browsers`
11 | - image: circleci/python:3.6.1
12 |
13 | # Specify service dependencies here if necessary
14 | # CircleCI maintains a library of pre-built images
15 | # documented at https://circleci.com/docs/2.0/circleci-images/
16 | # - image: circleci/postgres:9.4
17 |
18 | working_directory: ~/repo
19 |
20 | steps:
21 | - checkout
22 |
23 | # Download and cache dependencies
24 | - restore_cache:
25 | keys:
26 | - v1-dependencies-{{ checksum "requirements.txt" }}
27 | # fallback to using the latest cache if no exact match is found
28 | - v1-dependencies-
29 |
30 | - run:
31 | name: install dependencies
32 | command: |
33 | python3 -m venv venv
34 | . venv/bin/activate
35 | pip install -r requirements.txt
36 | pip install -e .
37 |
38 | - save_cache:
39 | paths:
40 | - ./venv
41 | key: v1-dependencies-{{ checksum "requirements.txt" }}
42 |
43 | # run tests!
44 | # this project runs its tests with pytest, then uploads coverage via codecov
45 | # other common Python testing frameworks include pytest and nose
46 | # https://pytest.org
47 | # https://nose.readthedocs.io
48 | - run:
49 | name: run tests
50 | command: |
51 | . venv/bin/activate
52 | pytest
53 | codecov
54 |
55 | - store_artifacts:
56 | path: test-reports
57 | destination: test-reports
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # project
2 | .DS_Store
3 | dev.py
4 | app/models/
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | example.py
10 | timit_data/
11 | LJSpeech-1.1/
12 |
13 | # C extensions
14 | *.so
15 |
16 | .idea/
17 |
18 | # Distribution / packaging
19 | .Python
20 | env/
21 | ide_layouts/
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | .hypothesis/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # celery beat schedule file
87 | celerybeat-schedule
88 |
89 | # SageMath parsed files
90 | *.sage.py
91 |
92 | # dotenv
93 | .env
94 |
95 | # virtualenv
96 | .venv
97 | venv/
98 | ENV/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 |
--------------------------------------------------------------------------------
/pytorch_complex_tensor/complex_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import re
4 |
5 | """
6 | Does nothing except pretty print complex grad info
7 | """
8 |
class ComplexGrad(torch.Tensor):
    """Gradient of a ComplexTensor; only changes how the gradient is printed.

    Data uses the ComplexTensor layout: along the second-to-last dim the first
    half holds the real parts and the second half the imaginary parts.
    """

    def __deepcopy__(self, memo):
        """Deep-copy support mirroring ``torch.Tensor.__deepcopy__`` but keeping the ComplexGrad class.

        :raises RuntimeError: if ``self`` is not a graph leaf.
        """
        if not self.is_leaf:
            raise RuntimeError("Only Tensors created explicitly by the user "
                               "(graph leaves) support the deepcopy protocol at the moment")
        if id(self) in memo:
            return memo[id(self)]
        with torch.no_grad():
            if self.is_sparse:
                new_tensor = self.clone()

                # hack tensor to cast as complex
                new_tensor.__class__ = ComplexGrad
            else:
                new_storage = self.storage().__deepcopy__(memo)
                new_tensor = self.new()

                # hack tensor to cast as complex
                new_tensor.__class__ = ComplexGrad
                new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
            memo[id(self)] = new_tensor
            new_tensor.requires_grad = self.requires_grad
            return new_tensor

    def __repr__(self):
        """Pretty-print as a numpy array of '(a+bj)' strings in the complex shape."""
        # Split on the second-to-last dim, the axis ComplexTensor stacks real/imag on
        # (splitting dim 0 only matched for 2-D gradients); 1-D data keeps the dim-0 split.
        axis = -2 if self.dim() >= 2 else 0
        total = self.size(axis)
        split_i = total // 2
        real = self.narrow(axis, 0, split_i)
        imag = self.narrow(axis, split_i, total - split_i)
        size_r = real.size()

        # plain Python floats: formatting 0-d tensor subclasses can recurse into __repr__
        real_vals = real.reshape(-1).tolist()
        imag_vals = imag.reshape(-1).tolist()

        # `>= 0` so a zero imaginary part prints as '+0.0j', not '-0.0j'
        strings = np.asarray([f'({a}{"+" if b >= 0 else "-"}{abs(b)}j)' for a, b in zip(real_vals, imag_vals)])
        strings = strings.reshape(*size_r)
        strings = f'tensor({strings.__str__()})'
        strings = re.sub('\n', ',\n ', strings)
        return strings

    def __str__(self):
        return self.__repr__()
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 | Pytorch Complex Tensor
8 |
9 |
10 | Unofficial complex Tensor support for Pytorch
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ### How it works
20 |
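21 | Stores a complex tensor as one real tensor: along the second-to-last dimension, the first half holds the real parts and the second half the imaginary parts. A few arithmetic operations are implemented to emulate complex arithmetic, and gradients are supported.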
22 |
23 | ### Installation
24 | ```bash
25 | pip install pytorch-complex-tensor
26 | ```
27 |
28 | ### Example:
29 | Easy import
30 |
31 | ```python
32 | from pytorch_complex_tensor import ComplexTensor
33 | ```
34 |
35 | Init tensor
36 |
37 | ```python
38 | # equivalent to:
39 | # np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)
40 | C = ComplexTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
41 | C.requires_grad = True
42 | ```
43 |
44 | Pretty printing
45 |
46 | ```python
47 | print(C)
48 | # tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],
49 | # ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])
50 | ```
51 |
52 | handles absolute value properly for complex tensors
53 |
54 | ```python
55 | # complex absolute value implementation
56 | print(C.abs())
57 | # tensor([[3.1623, 3.1623, 3.1623],
58 | # [4.4721, 4.4721, 4.4721]], grad_fn=<SqrtBackward>)
59 | ```
60 |
61 |
62 | prints correct sizing treating first half of matrix as real, second as imag
63 | ```python
64 | print(C.size())
65 | # torch.Size([2, 3])
66 | ```
67 |
68 | multiplies both complex and real tensors
69 | ```python
70 | # show matrix multiply with real tensor
71 | # also works with complex tensor
72 | x = torch.Tensor([[3, 3], [4, 4], [2, 2]])
73 | xy = C.mm(x)
74 | print(xy)
75 | # tensor([['(9.0+27.0j)' '(9.0+27.0j)'],
76 | # ['(18.0+36.0j)' '(18.0+36.0j)']])
77 | ```
78 |
79 | reduce ops return ComplexScalar
80 | ```python
81 | xy = xy.sum()
82 |
83 | # this is now a complex scalar (thin wrapper with .real, .imag)
84 | print(type(xy))
85 | # pytorch_complex_tensor.complex_scalar.ComplexScalar
86 |
87 | print(xy)
88 | # (54+126j)
89 | ```
90 |
91 | which can be used for gradients without breaking anything... (differentiates wrt the real part)
92 | ```python
93 | # calculate dxy / dC
94 | # for complex scalars, grad is wrt the real part
95 | xy.backward()
96 | print(C.grad)
97 | # tensor([['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)'],
98 | # ['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)']])
99 | ```
100 |
101 | supports indexing and slicing ops...
102 | ```python
103 | print(C[-1])
104 | print(C[0, 0:-2, ...])
105 | print(C[0, ..., 0])
106 | ```
107 |
108 |
109 | ### Supported ops:
110 | | Operation | complex tensor | real tensor | complex scalar | real scalar |
111 | | ----------| :-------------:|:-----------:|:--------------:|:-----------:|
112 | | addition | Y | Y | Y | Y |
113 | | subtraction | Y | Y | Y | Y |
114 | | multiply | Y | Y | Y | Y |
115 | | mm | Y | Y | - | - |
116 | | abs | Y | - | - | - |
117 | | t | Y | - | - | - |
118 | | grads | Y | Y | Y | Y |
119 |
120 |
121 |
--------------------------------------------------------------------------------
/tests/pytorch_complex_tensor_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from pytorch_complex_tensor import ComplexTensor
5 |
6 |
7 | # ------------------
8 | # GRAD TESTS
9 | # -------------------
def test_grad():
    """d sum((c + 4) @ c^T) / dc must match the TensorFlow reference gradient.

    Reference produced with:

        tf_c2 = tf.constant([[1+2j, 3+4j, 5+6j], [7+8j,9+10j,11+12j]], dtype=tf.complex64)
        with tf.GradientTape() as t:
            t.watch(tf_c2)
            tf_out = tf_c2 + 4
            tf_out = tf.matmul(tf_out, tf.transpose(tf_c2, perm=[1,0]))
            tf_y = tf.reduce_sum(tf_out)
        dy_dc2 = t.gradient(tf_y, tf_c2)
    """
    c = ComplexTensor([[1, 3, 5], [7, 9, 11], [2, 4, 6], [8, 10, 12]])
    c.requires_grad = True

    # shift, multiply by own transpose, reduce to a complex scalar, backprop its real part
    loss = (c + 4).mm(c.t()).sum()
    loss.backward()

    grad = c.grad.view(-1).data.numpy()
    expected = np.asarray([24, 32, 40, 24, 32, 40, -20, -28, -36, -20, -28, -36])
    assert np.array_equal(grad, expected)
49 |
50 |
def test_size():
    # a (4, 3) real buffer holds a (2, 3) complex matrix
    rows, cols = ComplexTensor(torch.zeros(4, 3)).size()[-2:]
    assert rows == 2
    assert cols == 3

    # integer dims passed to the constructor describe the complex shape directly
    rows, cols = ComplexTensor(12, 8).size()[-2:]
    assert rows == 12
    assert cols == 8
65 |
66 |
def test_shape():
    # a (4, 3) real buffer holds a (2, 3) complex matrix
    rows, cols = ComplexTensor(torch.zeros(4, 3)).shape[-2:]
    assert rows == 2
    assert cols == 3

    # integer dims passed to the constructor describe the complex shape directly
    rows, cols = ComplexTensor(12, 8).shape[-2:]
    assert rows == 12
    assert cols == 8
81 |
82 |
83 | # ------------------
84 | # REDUCE FX TESTS
85 | # ------------------
def test_abs():
    # |(4+3j) * 2| == 10 everywhere; compare against numpy's complex abs
    scaled = (4+3j) * (ComplexTensor(torch.zeros(4, 3)) + 2)
    result = scaled.abs().view(-1).data.numpy()

    ref = np.abs((4+3j) * (np.zeros((2, 3)).astype(np.complex64) + 2))
    expected = list(ref.flatten().real)

    assert np.array_equal(result, expected)
101 |
102 |
103 | # ------------------
104 | # SUM TESTS
105 | # ------------------
def test_real_scalar_sum():
    # adding a real scalar only shifts the real part
    result = (ComplexTensor(torch.zeros(4, 3)) + 4).view(-1).data.numpy()

    ref = (np.zeros((2, 3)).astype(np.complex64) + 4).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
119 |
120 |
def test_complex_scalar_sum():
    # adding a complex scalar shifts both parts
    result = (ComplexTensor(torch.zeros(4, 3)) + (4+3j)).view(-1).data.numpy()

    ref = (np.zeros((2, 3)).astype(np.complex64) + (4+3j)).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
133 |
134 |
def test_real_matrix_sum():
    # complex matrix + real matrix
    zeros = ComplexTensor(torch.zeros(4, 3))
    result = (zeros + torch.ones(2, 3)).view(-1).data.numpy()

    ref = (np.zeros((2, 3)).astype(np.complex64) + np.ones((2, 3))).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
150 |
151 |
def test_complex_matrix_sum():
    # complex matrix + complex matrix
    zeros = ComplexTensor(torch.zeros(4, 3))
    result = (zeros + zeros).view(-1).data.numpy()

    ref = np.zeros((2, 3)).astype(np.complex64)
    ref = (ref + ref).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
165 |
166 |
167 | # ------------------
168 | # MULT TESTS
169 | # ------------------
def test_scalar_mult():
    # complex matrix * real scalar (scalar on the right)
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    result = (ones * 4).view(-1).data.numpy()

    ref = ((np.zeros((2, 3)).astype(np.complex64) + 1) * 4).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
182 |
183 |
def test_scalar_rmult():
    # real scalar * complex matrix (exercises __rmul__)
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    result = (4 * ones).view(-1).data.numpy()

    ref = (4 * (np.zeros((2, 3)).astype(np.complex64) + 1)).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
196 |
197 |
def test_complex_mult():
    # complex matrix * complex scalar (scalar on the right)
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    result = (ones * (4+3j)).view(-1).data.numpy()

    ref = ((np.zeros((2, 3)).astype(np.complex64) + 1) * (4+3j)).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
210 |
211 |
def test_complex_rmult():
    # complex scalar * complex matrix (exercises __rmul__)
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    result = ((4+3j) * ones).view(-1).data.numpy()

    ref = ((4+3j) * (np.zeros((2, 3)).astype(np.complex64) + 1)).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
224 |
225 |
def test_complex_complex_ele_mult():
    """Elementwise product of two complex matrices."""
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    result = (ones * ones).view(-1).data.numpy()

    ref = np.zeros((2, 3)).astype(np.complex64) + 1
    ref = (ref * ref).flatten()
    assert np.array_equal(result, list(ref.real) + list(ref.imag))
242 |
243 |
def test_complex_real_ele_mult():
    """Elementwise product of a complex matrix and a real matrix."""
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    weights = torch.ones(2, 3) * 2 + 3
    result = (ones * weights).view(-1).data.numpy()

    np_weights = np.ones((2, 3)) * 2 + 3
    ref = (np.ones((2, 3)).astype(np.complex64) * np_weights).flatten()
    assert np.array_equal(list(ref.real) + list(ref.imag), result)
264 |
265 |
266 | # ------------------
267 | # MM TESTS
268 | # ------------------
def test_complex_real_mm():
    """Matrix product of a complex matrix with a real matrix."""
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    weights = torch.ones(2, 3) * 2 + 3
    result = ones.mm(weights.t()).view(-1).data.numpy()

    np_weights = np.ones((2, 3)) * 2 + 3
    ref = np.matmul(np.ones((2, 3)).astype(np.complex64), np_weights.T).flatten()
    assert np.array_equal(list(ref.real) + list(ref.imag), result)
289 |
290 |
def test_complex_complex_mm():
    """Matrix product of a complex matrix with its own transpose."""
    ones = ComplexTensor(torch.zeros(4, 3)) + 1
    result = ones.mm(ones.t()).view(-1).data.numpy()

    np_ones = np.ones((2, 3)).astype(np.complex64)
    ref = np.matmul(np_ones, np_ones.T).flatten()
    assert np.array_equal(list(ref.real) + list(ref.imag), result)
309 |
310 |
def test_get_item():
    """Indexing a ComplexTensor must match indexing the equivalent numpy complex array."""
    # fixed seed so any failure is reproducible (the global RNG was unseeded)
    rng = np.random.RandomState(0)
    a = rng.randint(0, 10, (3, 2, 3))
    a = a * (1+5j)
    ct = ComplexTensor(a)

    # match dim 0
    __assert_tensors_equal(ct[0], a[0])
    __assert_tensors_equal(ct[-1], a[-1])

    # match dim 1
    __assert_tensors_equal(ct[:, 0], a[:, 0])
    __assert_tensors_equal(ct[:, -1], a[:, -1])

    # match dim 2
    __assert_tensors_equal(ct[:, :, 0], a[:, :, 0])
    __assert_tensors_equal(ct[:, :, -1], a[:, :, -1])

    # match ranges
    __assert_tensors_equal(ct[0:1, 0, -2:], a[0:1, 0, -2:])
    __assert_tensors_equal(ct[-1:, -1:, -2:], a[-1:, -1:, -2:])
    __assert_tensors_equal(ct[:-1, :-1, :-2], a[:-1, :-1, :-2])
333 |
334 |
def __assert_tensors_equal(ct_tensor, np_tensor):
    """Assert ``ct_tensor`` is a ComplexTensor whose flattened storage equals
    all real parts of ``np_tensor`` followed by all its imaginary parts."""
    assert type(ct_tensor) is ComplexTensor

    flat = np_tensor.flatten()
    expected = [*flat.real, *flat.imag]
    actual = ct_tensor.flatten().data.numpy()
    assert np.array_equal(expected, actual)
344 |
345 |
if __name__ == '__main__':
    # allow running this file directly: delegate to pytest for just this module
    pytest.main(args=[__file__])
348 |
--------------------------------------------------------------------------------
/pytorch_complex_tensor/complex_tensor.py:
--------------------------------------------------------------------------------
1 | from pytorch_complex_tensor import ComplexScalar, ComplexGrad
2 |
3 | import inspect
4 | import numpy as np
5 | import torch
6 | import re
7 |
8 |
9 | """
10 | Complex tensor support for PyTorch.
11 |
12 | Uses a regular tensor where the first half are the real numbers and second are the imaginary.
13 |
14 | Supports only some basic operations without breaking the gradients for complex math.
15 |
16 | Supported ops:
17 | 1. addition
18 | - (tensor, scalar). Both complex and real.
19 | 2. subtraction
20 | - (tensor, scalar). Both complex and real.
21 | 3. multiply
22 | - (tensor, scalar). Both complex and real.
23 | 4. mm (matrix multiply)
24 | - (tensor). Both complex and real.
25 | 5. abs (absolute value)
26 | 6. all indexing ops.
27 | 7. t (transpose)
28 |
29 | >> c = ComplexTensor(10, 20)
30 |
31 | >> # do regular tensor ops now
32 | >> c = c * 4
33 | >> c = c.mm(c.t())
34 | >> print(c.shape, c.size(0))
35 | >> print(c)
36 | >> print(c[0:1, 1:-1])
37 | """
38 |
39 |
40 | class ComplexTensor(torch.Tensor):
41 |
@staticmethod
def __new__(cls, x, *args, **kwargs):
    """Build a ComplexTensor from sizes, stacked real data, or a complex numpy array.

    Storage layout: along the second-to-last dim, the first half holds the real
    parts and the second half the imaginary parts.

    * ``ComplexTensor(n, m)`` / ``ComplexTensor(..., n, m)``: sizes of the complex
      tensor; the second-to-last size is doubled for the storage.
    * complex ``np.ndarray``: real and imaginary parts are concatenated along the
      second-to-last axis.
    * tensor / (nested) list / real ndarray: taken as already stacked; its
      second-to-last dim must be even.

    :raises ValueError: if the stacked second-to-last dim is odd, or the data has
        no dims. ``ValueError`` subclasses ``Exception`` (raised previously).
    """
    # reformat complex numpy arrays so we can init with them
    if isinstance(x, np.ndarray) and 'complex' in str(x.dtype):
        x = np.concatenate([x.real, x.imag], axis=-2)

    if type(x) is int and len(args) == 1:
        # ComplexTensor(n, m): x is the (complex) second to last dim
        x = x * 2

    elif len(args) >= 2:
        # ComplexTensor(..., n, m): double the second to last size
        size_args = list(args)
        size_args[-2] *= 2
        args = tuple(size_args)

    else:
        # Data given directly: validate the stacked dim. Previously inputs other than
        # tensor/list/ndarray left the checked size unbound (UnboundLocalError), and
        # nested lists were checked on dim 0 via len(x) instead of dim -2.
        if isinstance(x, torch.Tensor):
            shape = tuple(x.size())
        else:
            shape = np.shape(x)
        if not shape:
            raise ValueError('ComplexTensor needs at least one dim; the second to last dim holds real and imag parts')
        rows = shape[-2] if len(shape) >= 2 else shape[0]
        if rows % 2 != 0:
            raise ValueError('second to last dim must be even. ComplexTensor is 2 real matrices under the hood')

    # init new t
    return super().__new__(cls, x, *args, **kwargs)
75 |
def __deepcopy__(self, memo):
    """Deep-copy support (``copy.deepcopy``), adapted from ``torch.Tensor.__deepcopy__``.

    The storage view uses ``torch.Tensor.size``: ``ComplexTensor.size`` reports the
    halved (complex) second-to-last dim, so passing ``self.size()`` to ``set_``
    produced a copy that silently dropped the imaginary half.

    :param memo: ``copy.deepcopy`` memo dict mapping ``id`` -> copy.
    :return: new ComplexTensor with copied storage and the same ``requires_grad``.
    :raises RuntimeError: if ``self`` is not a graph leaf.
    """
    if not self.is_leaf:
        raise RuntimeError("Only Tensors created explicitly by the user "
                           "(graph leaves) support the deepcopy protocol at the moment")
    if id(self) in memo:
        return memo[id(self)]
    with torch.no_grad():
        if self.is_sparse:
            new_tensor = self.clone()

            # hack tensor to cast as complex
            new_tensor.__class__ = ComplexTensor
        else:
            new_storage = self.storage().__deepcopy__(memo)
            new_tensor = self.new()

            # hack tensor to cast as complex
            new_tensor.__class__ = ComplexTensor
            # full stacked size (real + imag rows), not the overridden complex size
            new_tensor.set_(new_storage, self.storage_offset(), torch.Tensor.size(self), self.stride())
        memo[id(self)] = new_tensor
        new_tensor.requires_grad = self.requires_grad
        return new_tensor
98 |
@property
def real(self):
    """Real part: the leading half of the storage along the second-to-last dim."""
    # self.size() already reports the halved (complex) second-to-last dim,
    # which is exactly where the real rows end in the stacked storage
    half = self.size()[-2]
    return self[..., :half, :]
106 |
@property
def imag(self):
    """Imaginary part: the trailing half of the storage along the second-to-last dim."""
    # self.size() already reports the halved (complex) second-to-last dim,
    # which is where the imaginary rows start in the stacked storage
    half = self.size()[-2]
    return self[..., half:, :]
114 |
def __graph_copy__(self, real, imag):
    """Stack real and imaginary parts back into a single ComplexTensor.

    Parts are concatenated along the second-to-last dim: the axis the constructor
    stacks complex numpy input on and that ``real``/``imag`` slice. Concatenating on
    dim 0 only coincided with that layout for 2-D parts. ``torch.cat`` keeps the
    autograd graph.
    """
    result = torch.cat([real, imag], dim=-2)
    # re-label the tensor returned by torch.cat as a ComplexTensor
    result.__class__ = ComplexTensor
    return result
121 |
def __graph_copy_scalar__(self, real, imag):
    """Wrap reduced real/imaginary parts in a ComplexScalar (as ``sum``/``mean`` do).

    ComplexScalar is a plain Python object, not a Tensor. The previous
    ``result.__class__ = ComplexScalar`` on a stacked tensor therefore always
    raised TypeError. The parts are kept as-is, so their autograd graph is preserved.
    """
    return ComplexScalar(real, imag)
128 |
def __add__(self, other):
    """Add a real/complex scalar or a real/complex tensor.

    :param other: ComplexTensor, real ``torch.Tensor``, real scalar or complex scalar.
    :return: new ComplexTensor ``self + other`` (autograd graph preserved).
    """
    real = self.real
    imag = self.imag

    # given a real tensor
    if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor:
        real = real + other

    # given a complex tensor
    elif type(other) is ComplexTensor:
        real = real + other.real
        imag = imag + other.imag

    # given a real scalar. np.isreal is also True for complex values with a zero
    # imaginary part (e.g. 4+0j), so add only their real component.
    elif np.isreal(other):
        real = real + (other.real if np.iscomplexobj(other) else other)

    # given a complex scalar
    else:
        real = real + other.real
        imag = imag + other.imag

    return self.__graph_copy__(real, imag)
157 |
def __radd__(self, other):
    """Reflected addition (``other + self``); addition commutes, so defer to ``__add__``."""
    add = self.__add__
    return add(other)
160 |
def __sub__(self, other):
    """Subtract a real/complex scalar or a real/complex tensor.

    :param other: ComplexTensor, real ``torch.Tensor``, real scalar or complex scalar.
    :return: new ComplexTensor ``self - other`` (autograd graph preserved).
    """
    real = self.real
    imag = self.imag

    # given a real tensor
    if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor:
        real = real - other

    # given a complex tensor
    elif type(other) is ComplexTensor:
        real = real - other.real
        imag = imag - other.imag

    # given a real scalar. np.isreal is also True for complex values with a zero
    # imaginary part (e.g. 4+0j), so subtract only their real component.
    elif np.isreal(other):
        real = real - (other.real if np.iscomplexobj(other) else other)

    # given a complex scalar
    else:
        real = real - other.real
        imag = imag - other.imag

    return self.__graph_copy__(real, imag)
189 |
def __rsub__(self, other):
    """Reflected subtraction: compute ``other - self`` (e.g. ``5 - c``).

    Subtraction does not commute. Delegating to ``self.__sub__(other)`` returned
    ``self - other``, i.e. the negated result, so compute ``(-self) + other``.
    """
    return self.__neg__().__add__(other)
192 |
def __mul__(self, other):
    """Elementwise multiply by a real/complex scalar or a real/complex tensor.

    Complex operands use (a+bi)(c+di) = (ac-bd) + (ad+bc)i.

    :param other: ComplexTensor, real ``torch.Tensor``, real scalar or complex scalar.
    :return: new ComplexTensor (autograd graph preserved).
    """
    # every branch below builds new tensors out-of-place, so no defensive clone is needed
    real = self.real
    imag = self.imag

    # given a real tensor
    if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor:
        real = real * other
        imag = imag * other

    # given a complex tensor
    elif type(other) is ComplexTensor:
        ac = real * other.real
        bd = imag * other.imag
        ad = real * other.imag
        bc = imag * other.real
        real = ac - bd
        imag = ad + bc

    # given a real scalar. np.isreal is also True for complex values with a zero
    # imaginary part (e.g. 4+0j), so scale by their real component only.
    elif np.isreal(other):
        factor = other.real if np.iscomplexobj(other) else other
        real = real * factor
        imag = imag * factor

    # given a complex scalar
    else:
        ac = real * other.real
        bd = imag * other.imag
        ad = real * other.imag
        bc = imag * other.real
        real = ac - bd
        imag = ad + bc

    return self.__graph_copy__(real, imag)
231 |
def __truediv__(self, other):
    """Divide by a real/complex scalar or a real/complex tensor.

    Complex divisors use (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c^2+d^2).
    Real-tensor and complex divisors previously raised NotImplementedError.

    :param other: ComplexTensor, real ``torch.Tensor``, real scalar or complex scalar.
    :return: new ComplexTensor ``self / other`` (autograd graph preserved).
    """
    real = self.real
    imag = self.imag

    # given a real tensor: divide both parts elementwise
    if isinstance(other, torch.Tensor) and type(other) is not ComplexTensor:
        real = real / other
        imag = imag / other

    # given a complex tensor or a complex scalar
    elif type(other) is ComplexTensor or not np.isreal(other):
        c = other.real
        d = other.imag
        denom = c * c + d * d
        real, imag = (real * c + imag * d) / denom, (imag * c - real * d) / denom

    # given a real scalar (np.isreal is also True for e.g. 4+0j: use its real part)
    else:
        divisor = other.real if np.iscomplexobj(other) else other
        real = real / divisor
        imag = imag / divisor

    return self.__graph_copy__(real, imag)
254 |
def __rmul__(self, other):
    """Reflected multiplication (``other * self``); elementwise products commute."""
    mul = self.__mul__
    return mul(other)
257 |
def __neg__(self):
    """Unary minus: scale both real and imaginary parts by -1."""
    minus_one = -1
    return self.__mul__(minus_one)
260 |
def mm(self, other):
    """Matrix-multiply by a real or complex tensor.

    Complex operands use (a+bi)(c+di) = (ac-bd) + (ad+bc)i on the parts.

    :param other: ComplexTensor or real ``torch.Tensor``.
    :return: new ComplexTensor (autograd graph preserved).
    :raises TypeError: for any other operand. Previously such input was silently
        ignored and an unchanged copy of ``self`` was returned.
    """
    real = self.real
    imag = self.imag

    # given a complex tensor
    if type(other) is ComplexTensor:
        ac = real.mm(other.real)
        bd = imag.mm(other.imag)
        ad = real.mm(other.imag)
        bc = imag.mm(other.real)
        real = ac - bd
        imag = ad + bc

    # given a real tensor
    elif isinstance(other, torch.Tensor):
        real = real.mm(other)
        imag = imag.mm(other)

    else:
        raise TypeError(f'mm expects a torch.Tensor or ComplexTensor, got {type(other).__name__}')

    return self.__graph_copy__(real, imag)
285 |
def t(self):
    """
    Transpose the complex matrix.

    Each half is transposed with ``torch.Tensor.t`` (real first, then
    imaginary) and the pair is wrapped via ``__graph_copy__``.
    """
    return self.__graph_copy__(self.real.t(), self.imag.t())
291 |
def abs(self):
    """
    Element-wise modulus ``sqrt(real**2 + imag**2)``.

    :return: real-valued tensor with the magnitude of each element
    """
    squared_magnitude = self.real ** 2 + self.imag ** 2
    return torch.sqrt(squared_magnitude)
295 |
def sum(self, *args, **kwargs):
    """
    Sum the real and imaginary parts independently.

    All positional and keyword arguments (e.g. ``dim``, ``keepdim``) are
    forwarded unchanged to ``torch.Tensor.sum`` for both parts, so
    ``ct.sum(dim=0)`` works like ``ct.sum(0)``.
    :return: ComplexScalar(real_sum, imag_sum)
    """
    # NOTE(review): with a ``dim`` the sums are not 0-dim, yet they are still
    # wrapped in ComplexScalar (unchanged behavior) - confirm that is intended.
    real_sum = self.real.sum(*args, **kwargs)
    imag_sum = self.imag.sum(*args, **kwargs)
    return ComplexScalar(real_sum, imag_sum)
300 |
def mean(self, *args, **kwargs):
    """
    Average the real and imaginary parts independently.

    All positional and keyword arguments (e.g. ``dim``, ``keepdim``) are
    forwarded unchanged to ``torch.Tensor.mean`` for both parts, so
    ``ct.mean(dim=0)`` works like ``ct.mean(0)``.
    :return: ComplexScalar(real_mean, imag_mean)
    """
    # NOTE(review): with a ``dim`` the means are not 0-dim, yet they are still
    # wrapped in ComplexScalar (unchanged behavior) - confirm that is intended.
    real_mean = self.real.mean(*args, **kwargs)
    imag_mean = self.imag.mean(*args, **kwargs)
    return ComplexScalar(real_mean, imag_mean)
305 |
@property
def grad(self):
    """
    Gradient of this tensor, re-typed as a ComplexGrad.

    Returns None when no gradient has been accumulated yet (``_grad`` is
    None); previously this raised TypeError because ``__class__`` cannot be
    assigned on None.
    Note: the class of the stored gradient tensor itself is changed in place.
    """
    g = self._grad
    if g is None:
        return None
    g.__class__ = ComplexGrad

    return g
312 |
def cuda(self, *args, **kwargs):
    """
    Move both parts to a CUDA device.

    Arguments (e.g. ``device``, ``non_blocking``) are forwarded unchanged to
    ``torch.Tensor.cuda`` for the real and the imaginary part, so
    ``ct.cuda(0)`` no longer raises TypeError.
    :return: new ComplexTensor built from the moved parts via ``__graph_copy__``
    """
    real = self.real.cuda(*args, **kwargs)
    imag = self.imag.cuda(*args, **kwargs)

    return self.__graph_copy__(real, imag)
318 |
def __repr__(self):
    """
    Render the values as a complex64 numpy array relabelled as ``tensor``.

    The real and imaginary parts are combined element-wise while keeping the
    tensor's logical shape (the previous version flattened everything to 1-D).
    Parts are assigned directly instead of computing ``real + 1j * imag``, which
    would turn an infinite imaginary part into a NaN real part.
    """
    real = self.real.detach().cpu().numpy()
    imag = self.imag.detach().cpu().numpy()

    # let numpy do the formatting; complex64 matches the previous output dtype
    values = np.empty(real.shape, dtype=np.complex64)
    values.real = real
    values.imag = imag
    return repr(values).replace('array', 'tensor', 1)
329 |
def __str__(self):
    """Human-readable form; identical to ``repr`` for this class."""
    text = self.__repr__()
    return text
332 |
def is_complex(self):
    """Report that this tensor holds complex values (always True)."""
    complex_valued = True
    return complex_valued
335 |
def size(self, *args):
    """
    Logical size of the complex tensor.

    The underlying ``self.data`` appears to stack the real and imaginary halves
    along dim -2, so that dimension is halved. With no argument the full
    ``torch.Size`` is returned. With a dimension index (as in
    ``torch.Tensor.size(dim)``) only that dimension's length is returned;
    previously this crashed because ``list(int)`` raises TypeError.
    :param args: optional single dimension index (negative values allowed)
    :return: torch.Size, or int when a dimension is given
    :raises TypeError: if more than one dimension argument is given
    :raises IndexError: if the dimension is out of range
    """
    if len(args) > 1:
        raise TypeError('size() takes at most 1 argument ({} given)'.format(len(args)))

    dims = list(self.data.size())
    dims[-2] = dims[-2] // 2
    full = torch.Size(dims)

    if not args:
        return full
    return full[args[0]]
342 |
@property
def shape(self):
    """
    Logical shape: the underlying data's shape with dim -2 halved.

    :return: torch.Size of the complex view
    """
    dims = list(self.data.shape)
    dims[-2] //= 2
    return torch.Size(dims)
350 |
def __getitem__(self, item):
    """
    Index the complex tensor.

    When the immediate caller is a function named ``real`` or ``imag``
    (presumably this class's part accessors), plain ``torch.Tensor`` indexing
    of the underlying storage is used. Otherwise both parts are indexed with
    ``item`` and the selection is wrapped in a new ComplexTensor.
    """
    # Only the immediate caller's function name is needed. Reading it from
    # f_back gives the same value as inspect.getouterframes(frame)[1][3]
    # without building FrameInfo records (with source context) for the
    # whole stack on every indexing operation.
    frame = inspect.currentframe()
    caller_frame = frame.f_back if frame is not None else None
    caller = caller_frame.f_code.co_name if caller_frame is not None else None
    # release frame references promptly to avoid reference cycles
    del frame, caller_frame

    if caller == 'real' or caller == 'imag':
        return super(ComplexTensor, self).__getitem__(item)

    # this is a regular index op, select the requested pairs then form a new ComplexTensor
    r = self.real[item]
    c = self.imag[item]

    return self.__graph_copy__(r, c)
366 |
367 |
--------------------------------------------------------------------------------
/Complex demo example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## PyTorch complex tensor examples"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "import torch\n",
18 | "from pytorch_complex_tensor import ComplexTensor\n",
19 | "import tensorflow as tf\n",
20 | "tf.enable_eager_execution()"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "--- \n",
28 | "Creation"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "data": {
38 | "text/plain": [
39 | "array([[1.+3.j, 1.+3.j, 1.+3.j],\n",
40 | " [2.+4.j, 2.+4.j, 2.+4.j]], dtype=complex64)"
41 | ]
42 | },
43 | "execution_count": 2,
44 | "metadata": {},
45 | "output_type": "execute_result"
46 | }
47 | ],
48 | "source": [
49 | "# numpy complex tensor\n",
50 | "np_c = np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)\n",
51 | "np_c"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],\n",
64 | " ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "# torch equivalent\n",
70 | "pt_c = ComplexTensor([[1, 1, 1], [2,2,2], [3,3,3], [4,4,4]])\n",
71 | "print(pt_c)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "metadata": {},
78 | "outputs": [
79 | {
80 | "name": "stdout",
81 | "output_type": "stream",
82 | "text": [
83 | "[[1. 1. 1.]\n",
84 | " [2. 2. 2.]]\n",
85 | "tensor([[1., 1., 1.],\n",
86 | " [2., 2., 2.]])\n"
87 | ]
88 | }
89 | ],
90 | "source": [
91 | "# verify reals match\n",
92 | "print(np_c.real)\n",
93 | "print(pt_c.real)"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 5,
99 | "metadata": {},
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "[[3. 3. 3.]\n",
106 | " [4. 4. 4.]]\n",
107 | "tensor([[3., 3., 3.],\n",
108 | " [4., 4., 4.]])\n"
109 | ]
110 | }
111 | ],
112 | "source": [
113 | "# verify imag match\n",
114 | "print(np_c.imag)\n",
115 | "print(pt_c.imag)"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "--- \n",
123 | "Verify complex addition"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "data": {
133 | "text/plain": [
134 | "array([[4.+5.j, 4.+5.j, 4.+5.j],\n",
135 | " [5.+6.j, 5.+6.j, 5.+6.j]], dtype=complex64)"
136 | ]
137 | },
138 | "execution_count": 6,
139 | "metadata": {},
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "np_c + (3+2j)"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 7,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "data": {
154 | "text/plain": [
155 | "tensor([['(4.0+5.0j)' '(4.0+5.0j)' '(4.0+5.0j)'],\n",
156 | " ['(5.0+6.0j)' '(5.0+6.0j)' '(5.0+6.0j)']])"
157 | ]
158 | },
159 | "execution_count": 7,
160 | "metadata": {},
161 | "output_type": "execute_result"
162 | }
163 | ],
164 | "source": [
165 | "pt_c + (3 + 2j)"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {},
171 | "source": [
172 | "--- \n",
173 | "verify abs"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 8,
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "data": {
183 | "text/plain": [
184 | "array([[3.1622777, 3.1622777, 3.1622777],\n",
185 | " [4.472136 , 4.472136 , 4.472136 ]], dtype=float32)"
186 | ]
187 | },
188 | "execution_count": 8,
189 | "metadata": {},
190 | "output_type": "execute_result"
191 | }
192 | ],
193 | "source": [
194 | "np.abs(np_c)"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 9,
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/plain": [
205 | "tensor([[3.1623, 3.1623, 3.1623],\n",
206 | " [4.4721, 4.4721, 4.4721]])"
207 | ]
208 | },
209 | "execution_count": 9,
210 | "metadata": {},
211 | "output_type": "execute_result"
212 | }
213 | ],
214 | "source": [
215 | "pt_c.abs()"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {},
221 | "source": [
222 | "--- \n",
223 | "verify complex vs real matrix multiply"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 10,
229 | "metadata": {},
230 | "outputs": [
231 | {
232 | "name": "stdout",
233 | "output_type": "stream",
234 | "text": [
235 | "[[3 3]\n",
236 | " [4 4]\n",
237 | " [2 2]]\n",
238 | "tensor([[3., 3.],\n",
239 | " [4., 4.],\n",
240 | " [2., 2.]])\n"
241 | ]
242 | }
243 | ],
244 | "source": [
245 | "np_x = np.asarray([[3, 3], [4, 4], [2, 2]])\n",
246 | "pt_x = torch.Tensor([[3, 3], [4, 4], [2, 2]])\n",
247 | "\n",
248 | "print(np_x)\n",
249 | "print(pt_x)"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 11,
255 | "metadata": {},
256 | "outputs": [
257 | {
258 | "data": {
259 | "text/plain": [
260 | "array([[ 9.+27.j, 9.+27.j],\n",
261 | " [18.+36.j, 18.+36.j]])"
262 | ]
263 | },
264 | "execution_count": 11,
265 | "metadata": {},
266 | "output_type": "execute_result"
267 | }
268 | ],
269 | "source": [
270 | "np_mm_out = np.matmul(np_c, np_x)\n",
271 | "np_mm_out"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 12,
277 | "metadata": {},
278 | "outputs": [
279 | {
280 | "data": {
281 | "text/plain": [
282 | "tensor([['(9.0+27.0j)' '(9.0+27.0j)'],\n",
283 | " ['(18.0+36.0j)' '(18.0+36.0j)']])"
284 | ]
285 | },
286 | "execution_count": 12,
287 | "metadata": {},
288 | "output_type": "execute_result"
289 | }
290 | ],
291 | "source": [
292 | "pt_mm_out = pt_c.mm(pt_x)\n",
293 | "pt_mm_out"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 13,
299 | "metadata": {},
300 | "outputs": [
301 | {
302 | "name": "stdout",
303 | "output_type": "stream",
304 | "text": [
305 | "[[ 9. 9.]\n",
306 | " [18. 18.]]\n",
307 | "tensor([[ 9., 9.],\n",
308 | " [18., 18.]])\n"
309 | ]
310 | }
311 | ],
312 | "source": [
313 | "# verify reals\n",
314 | "print(np_mm_out.real)\n",
315 | "print(pt_mm_out.real)"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 14,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "[[27. 27.]\n",
328 | " [36. 36.]]\n",
329 | "tensor([[27., 27.],\n",
330 | " [36., 36.]])\n"
331 | ]
332 | }
333 | ],
334 | "source": [
335 | "# verify imags\n",
336 | "print(np_mm_out.imag)\n",
337 | "print(pt_mm_out.imag)"
338 | ]
339 | },
340 | {
341 | "cell_type": "markdown",
342 | "metadata": {},
343 | "source": [
344 | "--- \n",
345 | "verify transpose"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 15,
351 | "metadata": {},
352 | "outputs": [
353 | {
354 | "data": {
355 | "text/plain": [
356 | "array([[1.+3.j, 2.+4.j],\n",
357 | " [1.+3.j, 2.+4.j],\n",
358 | " [1.+3.j, 2.+4.j]], dtype=complex64)"
359 | ]
360 | },
361 | "execution_count": 15,
362 | "metadata": {},
363 | "output_type": "execute_result"
364 | }
365 | ],
366 | "source": [
367 | "np_c.T"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 16,
373 | "metadata": {},
374 | "outputs": [
375 | {
376 | "data": {
377 | "text/plain": [
378 | "tensor([['(1.0+3.0j)' '(2.0+4.0j)'],\n",
379 | " ['(1.0+3.0j)' '(2.0+4.0j)'],\n",
380 | " ['(1.0+3.0j)' '(2.0+4.0j)']])"
381 | ]
382 | },
383 | "execution_count": 16,
384 | "metadata": {},
385 | "output_type": "execute_result"
386 | }
387 | ],
388 | "source": [
389 | "pt_c.t()"
390 | ]
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "metadata": {},
395 | "source": [
396 | "--- \n",
397 | "## wfalcon/pytorch-complex-tensor\n"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": 17,
403 | "metadata": {},
404 | "outputs": [
405 | {
406 | "name": "stdout",
407 | "output_type": "stream",
408 | "text": [
409 | "tensor([['(1.0+2.0j)' '(3.0+4.0j)' '(5.0+6.0j)'],\n",
410 | " ['(7.0+8.0j)' '(9.0+10.0j)' '(11.0+12.0j)']])\n"
411 | ]
412 | }
413 | ],
414 | "source": [
415 | "pt_c2 = ComplexTensor([[1, 3, 5], [7,9,11], [2,4,6], [8,10,12]])\n",
416 | "print(pt_c2)\n",
417 | "pt_c2.requires_grad = True"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 18,
423 | "metadata": {},
424 | "outputs": [
425 | {
426 | "name": "stdout",
427 | "output_type": "stream",
428 | "text": [
429 | "tensor([['(15.0+136.0j)' '(69.0+334.0j)'],\n",
430 | " ['(-3.0+262.0j)' '(51.0+676.0j)']])\n"
431 | ]
432 | }
433 | ],
434 | "source": [
435 | "out = pt_c2 + 4\n",
436 | "out = out.mm(pt_c2.t())\n",
437 | "print(out)"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": 19,
443 | "metadata": {},
444 | "outputs": [
445 | {
446 | "name": "stdout",
447 | "output_type": "stream",
448 | "text": [
449 | "tensor(132., grad_fn=<SumBackward0>)\n",
450 | "tensor(1408., grad_fn=<SumBackward0>)\n"
451 | ]
452 | }
453 | ],
454 | "source": [
455 | "real_sum = out.real.sum()\n",
456 | "print(real_sum)\n",
457 | "out_imag = out.imag.sum()\n",
458 | "print(out_imag)"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": 20,
464 | "metadata": {},
465 | "outputs": [],
466 | "source": [
467 | "real_sum.backward()"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": 21,
473 | "metadata": {},
474 | "outputs": [
475 | {
476 | "data": {
477 | "text/plain": [
478 | "tensor([['(24.0-20.0j)' '(32.0-28.0j)' '(40.0-36.0j)'],\n",
479 | " ['(24.0-20.0j)' '(32.0-28.0j)' '(40.0-36.0j)']])"
480 | ]
481 | },
482 | "execution_count": 21,
483 | "metadata": {},
484 | "output_type": "execute_result"
485 | }
486 | ],
487 | "source": [
488 | "pt_c2.grad"
489 | ]
490 | },
491 | {
492 | "cell_type": "markdown",
493 | "metadata": {},
494 | "source": [
495 | "--- \n",
496 | "### Use TensorFlow to compute gradients\n"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": 6,
502 | "metadata": {},
503 | "outputs": [],
504 | "source": [
505 | "tf_c2 = tf.constant([[1+2j, 3+4j, 5+6j], [7+8j,9+10j,11+12j]], dtype=tf.complex64)"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 7,
511 | "metadata": {},
512 | "outputs": [
513 | {
514 | "name": "stdout",
515 | "output_type": "stream",
516 | "text": [
517 | "tf.Tensor(\n",
518 | "[[15.+136.j 69.+334.j]\n",
519 | " [-3.+262.j 51.+676.j]], shape=(2, 2), dtype=complex64)\n",
520 | "tf.Tensor((132+1408j), shape=(), dtype=complex64)\n"
521 | ]
522 | }
523 | ],
524 | "source": [
525 | "with tf.GradientTape() as t:\n",
526 | " t.watch(tf_c2)\n",
527 | " tf_out = tf_c2 + 4\n",
528 | " tf_out = tf.matmul(tf_out, tf.transpose(tf_c2, perm=[1,0]))\n",
529 | " print(tf_out)\n",
530 | " \n",
531 | " tf_y = tf.reduce_sum(tf_out)\n",
532 | " print(tf_y)"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 8,
538 | "metadata": {},
539 | "outputs": [],
540 | "source": [
541 | "dy_dc2 = t.gradient(tf_y, tf_c2)"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": 9,
547 | "metadata": {},
548 | "outputs": [
549 | {
550 | "data": {
551 | "text/plain": [
552 | ""
555 | ]
556 | },
557 | "execution_count": 9,
558 | "metadata": {},
559 | "output_type": "execute_result"
560 | }
561 | ],
562 | "source": [
563 | "dy_dc2"
564 | ]
565 | }
566 | ],
567 | "metadata": {
568 | "kernelspec": {
569 | "display_name": "Python [conda env:organics]",
570 | "language": "python",
571 | "name": "conda-env-organics-py"
572 | },
573 | "language_info": {
574 | "codemirror_mode": {
575 | "name": "ipython",
576 | "version": 3
577 | },
578 | "file_extension": ".py",
579 | "mimetype": "text/x-python",
580 | "name": "python",
581 | "nbconvert_exporter": "python",
582 | "pygments_lexer": "ipython3",
583 | "version": "3.6.6"
584 | }
585 | },
586 | "nbformat": 4,
587 | "nbformat_minor": 2
588 | }
589 |
--------------------------------------------------------------------------------