├── .github ├── FUNDING.yml └── workflows │ ├── codeql.yml │ ├── publish.yml │ └── test_python.yml ├── .gitignore ├── .readthedocs.yml ├── CITATION.cff ├── LICENSE.md ├── README.md ├── complexnn ├── __init__.py ├── bn.py ├── conv.py ├── dense.py ├── fft.py ├── init.py ├── norm.py ├── pool.py └── utils.py ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── bib.bib │ ├── cite.md │ ├── conf.py │ ├── contrib.md │ ├── figures │ └── complex_nn.png │ ├── index.rst │ ├── install.md │ ├── intro.rst │ ├── license.rst │ └── math.rst ├── pyproject.toml ├── requirements.txt └── tests ├── __init__.py ├── test_conv.py ├── test_dense.py ├── test_readme.py └── test_train.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: https://dramsch.net -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '27 20 * * 0' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 
65 | # modify them (or add more) to build your code if your project requires it; please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI.org 2 | on: 3 | workflow_dispatch: 4 | release: 5 | types: [published] 6 | jobs: 7 | pypi: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout 11 | uses: actions/checkout@v3 12 | with: 13 | fetch-depth: 0 14 | - run: python3 -m pip install --upgrade build && python3 -m build 15 | - name: Publish package 16 | uses: pypa/gh-action-pypi-publish@release/v1 17 | with: 18 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/test_python.yml: -------------------------------------------------------------------------------- 1 | name: Test Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install flake8 pytest 23 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 24 | - name: Lint with flake8 25 | run: | 26 | # stop the build if there are Python syntax errors or undefined names 27 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 28 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 29 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 30 | - name: Test with pytest 31 | run: | 32 | pytest -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Global ignore patterns. They cause the files with the following names and 2 | # extensions to be ignored in general by Git. 3 | # 4 | # We ignore these files because they're usually binary and/or very large, which 5 | # greatly bloats the repository and its clone time. In this category are the 6 | # files .h5/.hdf5/.txt/.log and files named *log_file*. Some of these logs are 7 | # enormous. 8 | # 9 | # We also ignore files that can be easily reproduced one way or another by 10 | # compilers, like .pyc/.o/.so/.a/.lib. 11 | # 12 | # Lastly we ignore files and directories that signal their temporary character 13 | # with a name that contains "tmp". 14 | # 15 | # **************************************************************************** 16 | # ** DO NOT put ignore patterns here that are private to your clone of this ** 17 | # ** repository. Those belong in your clone's .git/info/exclude file.
** 18 | # **************************************************************************** 19 | # 20 | **/*.h5 21 | **/*.hdf5 22 | **/*.txt 23 | **/*log_file* 24 | **/*.log 25 | **/*.out 26 | **/*.pyc 27 | **/*.o 28 | **/*.so 29 | **/*.a 30 | **/*.lib 31 | **/*.dll 32 | **/*.exe 33 | **/*tmp* 34 | **/*.yaml 35 | **/*.npy 36 | **/slurm-*.out 37 | keras_complex.egg-info/ 38 | build/ 39 | 40 | # Ignore virtual environment 41 | venv 42 | .venv 43 | 44 | # IDE 45 | .vscode 46 | .idea 47 | 48 | dist/* 49 | 50 | __pycache__/* -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.10" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | formats: 24 | - pdf 25 | 26 | python: 27 | install: 28 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: '1.2.0' 2 | message: 'Please cite the following works when using this software.' 3 | authors: 4 | - family-names: 'Dramsch' 5 | given-names: 'Jesper Sören' 6 | orcid: '0000-0001-8273-905X' 7 | doi: '10.6084/m9.figshare.9783773' 8 | identifiers: 9 | - type: 'doi' 10 | value: '10.6084/m9.figshare.9783773' 11 | - type: 'url' 12 | value: 'https://figshare.com/articles/Complex-Valued_Neural_Networks_in_Keras_with_Tensorflow/9783773/1' 13 | title: 'Complex-Valued Neural Networks in Keras with Tensorflow' 14 | url: 'https://figshare.com/articles/Complex-Valued_Neural_Networks_in_Keras_with_Tensorflow/9783773/1' 15 | date-published: 2019-01-01 16 | year: 2019 17 | publisher: 18 | name: 'figshare' 19 | type: 'software' 20 | preferred-citation: 21 | abstract: 'Deep learning has become an area of interest in most scientific areas, including physical sciences. Modern networks apply real-valued transformations on the data. Particularly, convolutions in convolutional neural networks discard phase information entirely. Many deterministic signals, such as seismic data or electrical signals, contain significant information in the phase of the signal. We explore complex-valued deep convolutional networks to leverage non-linear feature maps. Seismic data commonly has a lowcut filter applied, to attenuate noise from ocean waves and similar long wavelength contributions. In non-stationary data, the phase content can stabilize training and improve the generalizability of neural networks. While it has been shown that phase content can be restored in deep neural networks, we show how including phase information in feature maps improves both training and inference from deterministic physical data. Furthermore, we show that smaller complex networks outperform larger real-valued networks.' 
22 | authors: 23 | - family-names: 'Dramsch' 24 | given-names: 'Jesper Sören' 25 | orcid: '0000-0001-8273-905X' 26 | - family-names: 'Lüthje' 27 | given-names: 'Mikael' 28 | orcid: "0000-0003-2715-1653" 29 | - family-names: 'Christensen' 30 | given-names: 'Anders Nymark' 31 | orcid: "0000-0002-3668-3128" 32 | doi: 'https://doi.org/10.1016/j.cageo.2020.104643' 33 | identifiers: 34 | - type: 'doi' 35 | value: 'https://doi.org/10.1016/j.cageo.2020.104643' 36 | - type: 'url' 37 | value: 'https://www.sciencedirect.com/science/article/pii/S0098300420306208' 38 | - type: 'other' 39 | value: 'urn:issn:0098-3004' 40 | title: 'Complex-valued neural networks for machine learning on non-stationary physical data' 41 | url: 'https://www.sciencedirect.com/science/article/pii/S0098300420306208' 42 | references: 43 | - authors: 44 | - family-names: 'Trabelsi' 45 | given-names: 'Chiheb' 46 | title: 'Deep Complex Networks' 47 | date-published: 2017-01-01 48 | year: 2017 49 | journal: 'arXiv preprint arXiv:1705.09792' 50 | type: 'article' 51 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Tensorflow Port: MIT License 2 | 3 | Copyright (c) 2019 Jesper Sören Dramsch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | Deep Complex Networks: MIT License 24 | 25 | Copyright (c) 2017 Chiheb Trabelsi 26 | 27 | Copyright (c) 2017 Olexa Bilaniuk 28 | 29 | Copyright (c) 2017 Ying Zhang 30 | 31 | Copyright (c) 2017 Dmitriy Serdyuk 32 | 33 | Copyright (c) 2017 Sandeep Subramanian 34 | 35 | Copyright (c) 2017 João Felipe Santos 36 | 37 | Copyright (c) 2017 Soroush Mehri 38 | 39 | Copyright (c) 2017 Negar Rostamzadeh 40 | 41 | Copyright (c) 2017 Yoshua Bengio 42 | 43 | Copyright (c) 2017 Christopher J Pal 44 | 45 | Permission is hereby granted, free of charge, to any person obtaining a copy 46 | of this software and associated documentation files (the "Software"), to deal 47 | in the Software without restriction, including without limitation the rights 48 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 49 | copies of the Software, and to permit persons to whom the Software is 50 | furnished to do so, subject to the following conditions: 51 | 52 | The above copyright notice and this permission notice shall be included in all 53 | copies or substantial portions of the Software. 
54 | 55 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 56 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 57 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 58 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 59 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 60 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 61 | SOFTWARE. 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Complex-Valued Neural Networks in Keras with Tensorflow 2 | [![Documentation](https://readthedocs.org/projects/keras-complex/badge/?version=latest)](https://keras-complex.readthedocs.io/) [![Build Status](https://github.com/JesperDramsch/keras-complex/actions/workflows/test_python.yml/badge.svg)](https://github.com/JesperDramsch/keras-complex/actions/) [![PyPI Versions](https://img.shields.io/pypi/pyversions/keras-complex.svg)](https://pypi.python.org/pypi/keras-complex) ![Tensorflow 2+](https://img.shields.io/badge/tensorflow-%3E2.0-orange) ![PyPI - Downloads](https://img.shields.io/pypi/dm/keras-complex) [![PyPI Status](https://img.shields.io/pypi/status/keras-complex.svg)](https://pypi.python.org/pypi/keras-complex) [![PyPI License](https://img.shields.io/badge/License-MIT-green)](LICENSE.md) 3 | 4 | [Complex-valued convolutions](https://en.wikipedia.org/wiki/Convolution#Domain_of_definition) can provide interesting results in signal processing-based deep learning. A simple(-ish) idea is to include explicit phase information of time series in neural networks. This code enables complex-valued convolution in convolutional neural networks in [Keras](https://keras.io) with the [TensorFlow](https://tensorflow.org/) backend. This makes the network modular and interoperable with standard Keras layers and operations. 5 | 6 | This code is very much in **alpha**. Please consider contributing improvements so we can advance the code together. This repository is based on the code that reproduces the experiments presented in the paper [Deep Complex Networks](https://arxiv.org/abs/1705.09792). It is a port to Keras with the TensorFlow backend. 7 | 8 | Requirements 9 | ------------ 10 | 11 | - numpy 12 | - scipy 13 | - scikit-learn 14 | - tensorflow 2.X 15 | 16 | Install the requirements for the computer vision experiments with pip: 17 | 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | Depending on your Python installation, you might want to use Anaconda, venv, or similar environment tools. 23 | 24 | 25 | Installation 26 | ------------ 27 | 28 | ``` 29 | pip install keras-complex 30 | ``` 31 | 32 | Usage 33 | ----- 34 | Build your neural networks with the help of Keras. 35 | 36 | ``` 37 | import complexnn 38 | 39 | from tensorflow import keras 40 | from tensorflow.keras import models 41 | from tensorflow.keras import layers 42 | from tensorflow.keras import optimizers 43 | 44 | model = models.Sequential() 45 | 46 | model.add(complexnn.conv.ComplexConv2D(32, (3, 3), activation='relu', padding='same', input_shape=(28, 28, 2))) 47 | model.add(complexnn.bn.ComplexBatchNormalization()) 48 | model.add(layers.MaxPooling2D((2, 2), padding='same')) 49 | 50 | model.compile(optimizer=optimizers.Adam(), loss='mse') 51 | 52 | ``` 53 | 54 | A working example implementation of an autoencoder can be found [here](https://github.com/JesperDramsch/Complex-CNN-Seismic/).
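The complex layers expect the real and imaginary parts of the input to be stacked along the channel axis (see the next section for the exact format). Below is a minimal sketch of preparing such an input for the model above; the array `z` is a hypothetical complex-valued batch of shape `(batch, 28, 28)`:

```
import numpy as np

# z is a stand-in for your complex-valued data, e.g. an analytic signal
z = np.random.randn(16, 28, 28) + 1j * np.random.randn(16, 28, 28)

# Stack the real and imaginary parts along the last (channel) axis,
# matching the input_shape=(28, 28, 2) used in the model above
x = np.stack([z.real, z.imag], axis=-1).astype("float32")

predictions = model.predict(x)
```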
55 | 56 | Complex Format of Tensors 57 | ------------------------- 58 | 59 | This library assumes that a complex value is stored as two real-valued parts: the real component and its imaginary complement, stacked along the channel axis, as also shown [in the Docs](https://keras-complex.readthedocs.io/math.html). 60 | 61 | A 3x3 2D complex tensor then looks like: 62 | 63 | ``` 64 | [[[r, r, r], 65 | [r, r, r], 66 | [r, r, r]], 67 | [[i, i, i], 68 | [i, i, i], 69 | [i, i, i]]] 70 | ``` 71 | 72 | Multiple complex channels should then be arranged along the channel axis as `[r,r,r,i,i,i]`, which is also documented [in the Docs](https://keras-complex.readthedocs.io/math.html#implementation). 73 | 74 | Citation 75 | -------- 76 | 77 | Find the [CITATION file](/CITATION.cff) or cite this software version as: 78 | ``` 79 | @misc{dramsch2019complex, 80 | title = {Complex-Valued Neural Networks in Keras with Tensorflow}, 81 | url = {https://figshare.com/articles/Complex-Valued_Neural_Networks_in_Keras_with_Tensorflow/9783773/1}, 82 | DOI = {10.6084/m9.figshare.9783773}, 83 | publisher = {figshare}, 84 | author = {Dramsch, Jesper S{\"o}ren and Contributors}, 85 | year = {2019} 86 | } 87 | ``` 88 | 89 | Please cite the original work as: 90 | 91 | ``` 92 | @ARTICLE {Trabelsi2017, 93 | author = "Chiheb Trabelsi and Olexa Bilaniuk and Ying Zhang and Dmitriy Serdyuk and Sandeep Subramanian and João Felipe Santos and Soroush Mehri and Negar Rostamzadeh and Yoshua Bengio and Christopher J Pal", 94 | title = "Deep Complex Networks", 95 | journal = "arXiv preprint arXiv:1705.09792", 96 | year = "2017" 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /complexnn/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from . import bn, conv, dense, init, norm, pool 5 | 6 | # from .
import fft 7 | 8 | from .bn import ComplexBatchNormalization as ComplexBN 9 | from .conv import ( 10 | ComplexConv, 11 | ComplexConv1D, 12 | ComplexConv2D, 13 | ComplexConv3D, 14 | WeightNorm_Conv, 15 | ) 16 | from .dense import ComplexDense 17 | 18 | # from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2) 19 | from .init import ( 20 | ComplexIndependentFilters, 21 | IndependentFilters, 22 | ComplexInit, 23 | SqrtInit, 24 | ) 25 | from .norm import LayerNormalization, ComplexLayerNorm 26 | from .pool import SpectralPooling1D, SpectralPooling2D 27 | from .utils import ( 28 | get_realpart, 29 | get_imagpart, 30 | getpart_output_shape, 31 | GetImag, 32 | GetReal, 33 | GetAbs, 34 | ) 35 | -------------------------------------------------------------------------------- /complexnn/bn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: The implementation of complex Batchnorm is based on 5 | # the Keras implementation of batch Normalization 6 | # available here: 7 | # https://github.com/fchollet/keras/blob/master/keras/layers/normalization.py 8 | 9 | import numpy as np 10 | from tensorflow.keras.layers import Layer, InputSpec 11 | from tensorflow.keras import initializers, regularizers, constraints 12 | import tensorflow.keras.backend as K 13 | 14 | 15 | def sqrt_init(shape, dtype=None): 16 | value = (1 / np.sqrt(2)) * K.ones(shape) 17 | return value 18 | 19 | 20 | def sanitizedInitGet(init): 21 | if init in ["sqrt_init"]: 22 | return sqrt_init 23 | else: 24 | return initializers.get(init) 25 | 26 | 27 | def sanitizedInitSer(init): 28 | if init in [sqrt_init]: 29 | return "sqrt_init" 30 | else: 31 | return initializers.serialize(init) 32 | 33 | 34 | def complex_standardization(input_centred, Vrr, Vii, Vri, layernorm=False, axis=-1): 35 | """Complex Standardization of input 36 | 37 | Arguments: 38 | input_centred -- Input Tensor 39 | Vrr -- Real component of covariance matrix V 40 | Vii -- Imaginary component of covariance matrix V 41 | Vri -- Non-diagonal component of covariance matrix V 42 | 43 | Keyword Arguments: 44 | layernorm {bool} -- Use layer normalization statistics instead of batch statistics (default: {False}) 45 | axis {int} -- Axis for Standardization (default: {-1}) 46 | 47 | Raises: 48 | ValueError: Mismatched dimensions 49 | 50 | Returns: 51 | Complex standardized input 52 | """ 53 | 54 | ndim = K.ndim(input_centred) 55 | input_dim = K.shape(input_centred)[axis] // 2 56 | variances_broadcast = [1] * ndim 57 | variances_broadcast[axis] = input_dim 58 | if layernorm: 59 | variances_broadcast[0] = K.shape(input_centred)[0] 60 | 61 | # We require the covariance matrix's inverse square root. That first 62 | # requires square rooting, followed by inversion (I do this in that order 63 | # because during the computation of square root we compute the determinant 64 | # we'll need for inversion as well). 65 | 66 | # tau = Vrr + Vii = Trace. Guaranteed >= 0 because SPD 67 | tau = Vrr + Vii 68 | # delta = (Vrr * Vii) - (Vri ** 2) = Determinant. Guaranteed >= 0 because 69 | # SPD 70 | delta = (Vrr * Vii) - (Vri**2) 71 | 72 | s = K.sqrt(delta) # Determinant of square root matrix 73 | t = K.sqrt(tau + 2 * s) 74 | 75 | # The square root matrix could now be explicitly formed as 76 | # [ Vrr+s Vri ] 77 | # (1/t) [ Vir Vii+s ] 78 | # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix 79 | # but we don't need to do this immediately since we can also simultaneously 80 | # invert.
We can do this because we've already computed the determinant of 81 | # the square root matrix, and can thus invert it using the analytical 82 | # solution for 2x2 matrices 83 | # [ A B ] [ D -B ] 84 | # inv( [ C D ] ) = (1/det) [ -C A ] 85 | # http://mathworld.wolfram.com/MatrixInverse.html 86 | # Thus giving us 87 | # [ Vii+s -Vri ] 88 | # (1/s)(1/t)[ -Vir Vrr+s ] 89 | # So we proceed as follows: 90 | 91 | inverse_st = 1.0 / (s * t) 92 | Wrr = (Vii + s) * inverse_st 93 | Wii = (Vrr + s) * inverse_st 94 | Wri = -Vri * inverse_st 95 | 96 | # And we have computed the inverse square root matrix W = sqrt(V)! 97 | # Normalization. We multiply, x_normalized = W.x. 98 | 99 | # The returned result will be a complex standardized input 100 | # where the real and imaginary parts are obtained as follows: 101 | # x_real_normed = Wrr * x_real_centred + Wri * x_imag_centred 102 | # x_imag_normed = Wri * x_real_centred + Wii * x_imag_centred 103 | 104 | broadcast_Wrr = K.reshape(Wrr, variances_broadcast) 105 | broadcast_Wri = K.reshape(Wri, variances_broadcast) 106 | broadcast_Wii = K.reshape(Wii, variances_broadcast) 107 | 108 | cat_W_4_real = K.concatenate([broadcast_Wrr, broadcast_Wii], axis=axis) 109 | cat_W_4_imag = K.concatenate([broadcast_Wri, broadcast_Wri], axis=axis) 110 | 111 | if (axis == 1 and ndim != 3) or ndim == 2: 112 | centred_real = input_centred[:, :input_dim] 113 | centred_imag = input_centred[:, input_dim:] 114 | elif ndim == 3: 115 | centred_real = input_centred[:, :, :input_dim] 116 | centred_imag = input_centred[:, :, input_dim:] 117 | elif axis == -1 and ndim == 4: 118 | centred_real = input_centred[:, :, :, :input_dim] 119 | centred_imag = input_centred[:, :, :, input_dim:] 120 | elif axis == -1 and ndim == 5: 121 | centred_real = input_centred[:, :, :, :, :input_dim] 122 | centred_imag = input_centred[:, :, :, :, input_dim:] 123 | else: 124 | raise ValueError( 125 | "Incorrect Batchnorm combination of axis and dimensions. axis " 126 | "should be either 1 or -1. " 127 | "axis: " + str(axis) + "; ndim: " + str(ndim) + "." 
128 | ) 129 | rolled_input = K.concatenate([centred_imag, centred_real], axis=axis) 130 | 131 | output = cat_W_4_real * input_centred + cat_W_4_imag * rolled_input 132 | 133 | # Wrr * x_real_centered | Wii * x_imag_centered 134 | # + Wri * x_imag_centered | Wri * x_real_centered 135 | # ----------------------------------------------- 136 | # = output 137 | 138 | return output 139 | 140 | 141 | def ComplexBN( 142 | input_centred, Vrr, Vii, Vri, beta, gamma_rr, gamma_ri, gamma_ii, scale=True, center=True, layernorm=False, axis=-1 143 | ): 144 | """Complex Batch Normalization 145 | 146 | Arguments: 147 | input_centred -- input data 148 | Vrr -- Real component of covariance matrix V 149 | Vii -- Imaginary component of covariance matrix V 150 | Vri -- Non-diagonal component of covariance matrix V 151 | beta -- Learnable shift parameter beta 152 | gamma_rr -- Scaling parameter gamma - rr component of 2x2 matrix 153 | gamma_ri -- Scaling parameter gamma - ri component of 2x2 matrix 154 | gamma_ii -- Scaling parameter gamma - ii component of 2x2 matrix 155 | 156 | Keyword Arguments: 157 | scale {bool} -- Standardization of input (default: {True}) 158 | center {bool} -- Mean-shift correction (default: {True}) 159 | layernorm {bool} -- Normalization (default: {False}) 160 | axis {int} -- Axis for Standardization (default: {-1}) 161 | 162 | Raises: 163 | ValueError: Dimensional mismatch 164 | 165 | Returns: 166 | Batch-Normalized Input 167 | """ 168 | 169 | ndim = K.ndim(input_centred) 170 | input_dim = K.shape(input_centred)[axis] // 2 171 | if scale: 172 | gamma_broadcast_shape = [1] * ndim 173 | gamma_broadcast_shape[axis] = input_dim 174 | if center: 175 | broadcast_beta_shape = [1] * ndim 176 | broadcast_beta_shape[axis] = input_dim * 2 177 | 178 | if scale: 179 | standardized_output = complex_standardization(input_centred, Vrr, Vii, Vri, layernorm, axis=axis) 180 | 181 | # Now we perform the scaling and shifting of the normalized x using 182 | # the scaling parameter 183 | # [ gamma_rr gamma_ri ] 184 | # Gamma = [ gamma_ri gamma_ii ] 185 | # and the shifting parameter 186 | # Beta = [beta_real beta_imag].T 187 | # where: 188 | # x_real_BN = gamma_rr * x_real_normed + 189 | # gamma_ri * x_imag_normed + beta_real 190 | # x_imag_BN = gamma_ri * x_real_normed + 191 | # gamma_ii * x_imag_normed + beta_imag 192 | 193 | broadcast_gamma_rr = K.reshape(gamma_rr, gamma_broadcast_shape) 194 | broadcast_gamma_ri = K.reshape(gamma_ri, gamma_broadcast_shape) 195 | broadcast_gamma_ii = K.reshape(gamma_ii, gamma_broadcast_shape) 196 | 197 | cat_gamma_4_real = K.concatenate([broadcast_gamma_rr, broadcast_gamma_ii], axis=axis) 198 | cat_gamma_4_imag = K.concatenate([broadcast_gamma_ri, broadcast_gamma_ri], axis=axis) 199 | if (axis == 1 and ndim != 3) or ndim == 2: 200 | centred_real = standardized_output[:, :input_dim] 201 | centred_imag = standardized_output[:, input_dim:] 202 | elif ndim == 3: 203 | centred_real = standardized_output[:, :, :input_dim] 204 | centred_imag = standardized_output[:, :, input_dim:] 205 | elif axis == -1 and ndim == 4: 206 | centred_real = standardized_output[:, :, :, :input_dim] 207 | centred_imag = standardized_output[:, :, :, input_dim:] 208 | elif axis == -1 and ndim == 5: 209 | centred_real = standardized_output[:, :, :, :, :input_dim] 210 | centred_imag = standardized_output[:, :, :, :, input_dim:] 211 | else: 212 | raise ValueError( 213 | "Incorrect Batchnorm combination of axis and dimensions. axis" 214 | " should be either 1 or -1.
" 215 | "axis: " + str(axis) + "; ndim: " + str(ndim) + "." 216 | ) 217 | rolled_standardized_output = K.concatenate([centred_imag, centred_real], axis=axis) 218 | if center: 219 | broadcast_beta = K.reshape(beta, broadcast_beta_shape) 220 | return ( 221 | cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output + broadcast_beta 222 | ) 223 | else: 224 | return cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output 225 | else: 226 | if center: 227 | broadcast_beta = K.reshape(beta, broadcast_beta_shape) 228 | return input_centred + broadcast_beta 229 | else: 230 | return input_centred 231 | 232 | 233 | class ComplexBatchNormalization(Layer): 234 | """Complex version of the real domain 235 | Batch normalization layer (Ioffe and Szegedy, 2014). 236 | Normalize the activations of the previous complex layer at each batch, 237 | i.e. applies a transformation that maintains the mean of a complex unit 238 | close to the null vector, the 2 by 2 covariance matrix of a complex unit close to identity 239 | and the 2 by 2 relation matrix, also called pseudo-covariance, close to the 240 | null matrix. 241 | # Arguments 242 | axis: Integer, the axis that should be normalized 243 | (typically the features axis). 244 | For instance, after a `Conv2D` layer with 245 | `data_format="channels_first"`, 246 | set `axis=2` in `ComplexBatchNormalization`. 247 | momentum: Momentum for the moving statistics related to the real and 248 | imaginary parts. 249 | epsilon: Small float added to each of the variances related to the 250 | real and imaginary parts in order to avoid dividing by zero. 251 | center: If True, add offset of `beta` to complex normalized tensor. 252 | If False, `beta` is ignored. 253 | (beta is formed by real_beta and imag_beta) 254 | scale: If True, multiply by the `gamma` matrix. 255 | If False, `gamma` is not used. 256 | beta_initializer: Initializer for the real_beta and the imag_beta weight. 257 | gamma_diag_initializer: Initializer for the diagonal elements of the gamma matrix. 258 | which are the variances of the real part and the imaginary part. 259 | gamma_off_initializer: Initializer for the off-diagonal elements of the gamma matrix. 260 | moving_mean_initializer: Initializer for the moving means. 261 | moving_variance_initializer: Initializer for the moving variances. 262 | moving_covariance_initializer: Initializer for the moving covariance of 263 | the real and imaginary parts. 264 | beta_regularizer: Optional regularizer for the beta weights. 265 | gamma_regularizer: Optional regularizer for the gamma weights. 266 | beta_constraint: Optional constraint for the beta weights. 267 | gamma_constraint: Optional constraint for the gamma weights. 268 | # Input shape 269 | Arbitrary. Use the keyword argument `input_shape` 270 | (tuple of integers, does not include the samples axis) 271 | when using this layer as the first layer in a model. 272 | # Output shape 273 | Same shape as input. 
274 | # References 275 | - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167) 276 | """ 277 | 278 | def __init__( 279 | self, 280 | axis=-1, 281 | momentum=0.9, 282 | epsilon=1e-4, 283 | center=True, 284 | scale=True, 285 | beta_initializer="zeros", 286 | gamma_diag_initializer="sqrt_init", 287 | gamma_off_initializer="zeros", 288 | moving_mean_initializer="zeros", 289 | moving_variance_initializer="sqrt_init", 290 | moving_covariance_initializer="zeros", 291 | beta_regularizer=None, 292 | gamma_diag_regularizer=None, 293 | gamma_off_regularizer=None, 294 | beta_constraint=None, 295 | gamma_diag_constraint=None, 296 | gamma_off_constraint=None, 297 | **kwargs 298 | ): 299 | super(ComplexBatchNormalization, self).__init__(**kwargs) 300 | self.supports_masking = True 301 | self.axis = axis 302 | self.momentum = momentum 303 | self.epsilon = epsilon 304 | self.center = center 305 | self.scale = scale 306 | self.beta_initializer = sanitizedInitGet(beta_initializer) 307 | self.gamma_diag_initializer = sanitizedInitGet(gamma_diag_initializer) 308 | self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer) 309 | self.moving_mean_initializer = sanitizedInitGet(moving_mean_initializer) 310 | self.moving_variance_initializer = sanitizedInitGet(moving_variance_initializer) 311 | self.moving_covariance_initializer = sanitizedInitGet(moving_covariance_initializer) 312 | self.beta_regularizer = regularizers.get(beta_regularizer) 313 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 314 | self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer) 315 | self.beta_constraint = constraints.get(beta_constraint) 316 | self.gamma_diag_constraint = constraints.get(gamma_diag_constraint) 317 | self.gamma_off_constraint = constraints.get(gamma_off_constraint) 318 | 319 | def build(self, input_shape): 320 | 321 | ndim = len(input_shape) 322 | 323 | dim = input_shape[self.axis] 324 | if dim is None: 325 | raise ValueError( 326 | "Axis " + str(self.axis) + " of " 327 | "input tensor should have a defined dimension " 328 | "but the layer received an input with shape " + str(input_shape) + "." 
329 | ) 330 | self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim}) 331 | 332 | param_shape = (input_shape[self.axis] // 2,) 333 | 334 | if self.scale: 335 | self.gamma_rr = self.add_weight( 336 | shape=param_shape, 337 | name="gamma_rr", 338 | initializer=self.gamma_diag_initializer, 339 | regularizer=self.gamma_diag_regularizer, 340 | constraint=self.gamma_diag_constraint, 341 | ) 342 | self.gamma_ii = self.add_weight( 343 | shape=param_shape, 344 | name="gamma_ii", 345 | initializer=self.gamma_diag_initializer, 346 | regularizer=self.gamma_diag_regularizer, 347 | constraint=self.gamma_diag_constraint, 348 | ) 349 | self.gamma_ri = self.add_weight( 350 | shape=param_shape, 351 | name="gamma_ri", 352 | initializer=self.gamma_off_initializer, 353 | regularizer=self.gamma_off_regularizer, 354 | constraint=self.gamma_off_constraint, 355 | ) 356 | self.moving_Vrr = self.add_weight( 357 | shape=param_shape, initializer=self.moving_variance_initializer, name="moving_Vrr", trainable=False 358 | ) 359 | self.moving_Vii = self.add_weight( 360 | shape=param_shape, initializer=self.moving_variance_initializer, name="moving_Vii", trainable=False 361 | ) 362 | self.moving_Vri = self.add_weight( 363 | shape=param_shape, initializer=self.moving_covariance_initializer, name="moving_Vri", trainable=False 364 | ) 365 | else: 366 | self.gamma_rr = None 367 | self.gamma_ii = None 368 | self.gamma_ri = None 369 | self.moving_Vrr = None 370 | self.moving_Vii = None 371 | self.moving_Vri = None 372 | 373 | if self.center: 374 | self.beta = self.add_weight( 375 | shape=(input_shape[self.axis],), 376 | name="beta", 377 | initializer=self.beta_initializer, 378 | regularizer=self.beta_regularizer, 379 | constraint=self.beta_constraint, 380 | ) 381 | self.moving_mean = self.add_weight( 382 | shape=(input_shape[self.axis],), 383 | initializer=self.moving_mean_initializer, 384 | name="moving_mean", 385 | trainable=False, 386 | ) 387 | else: 388 | self.beta = None 389 | self.moving_mean = None 390 | 391 | self.built = True 392 | 393 | def call(self, inputs, training=None): 394 | input_shape = K.int_shape(inputs) 395 | ndim = len(input_shape) 396 | reduction_axes = list(range(ndim)) 397 | del reduction_axes[self.axis] 398 | input_dim = input_shape[self.axis] // 2 399 | mu = K.mean(inputs, axis=reduction_axes) 400 | broadcast_mu_shape = [1] * len(input_shape) 401 | broadcast_mu_shape[self.axis] = input_shape[self.axis] 402 | broadcast_mu = K.reshape(mu, broadcast_mu_shape) 403 | if self.center: 404 | input_centred = inputs - broadcast_mu 405 | else: 406 | input_centred = inputs 407 | centred_squared = input_centred**2 408 | if (self.axis == 1 and ndim != 3) or ndim == 2: 409 | centred_squared_real = centred_squared[:, :input_dim] 410 | centred_squared_imag = centred_squared[:, input_dim:] 411 | centred_real = input_centred[:, :input_dim] 412 | centred_imag = input_centred[:, input_dim:] 413 | elif ndim == 3: 414 | centred_squared_real = centred_squared[:, :, :input_dim] 415 | centred_squared_imag = centred_squared[:, :, input_dim:] 416 | centred_real = input_centred[:, :, :input_dim] 417 | centred_imag = input_centred[:, :, input_dim:] 418 | elif self.axis == -1 and ndim == 4: 419 | centred_squared_real = centred_squared[:, :, :, :input_dim] 420 | centred_squared_imag = centred_squared[:, :, :, input_dim:] 421 | centred_real = input_centred[:, :, :, :input_dim] 422 | centred_imag = input_centred[:, :, :, input_dim:] 423 | elif self.axis == -1 and ndim == 5: 424 | centred_squared_real = 
centred_squared[:, :, :, :, :input_dim] 425 | centred_squared_imag = centred_squared[:, :, :, :, input_dim:] 426 | centred_real = input_centred[:, :, :, :, :input_dim] 427 | centred_imag = input_centred[:, :, :, :, input_dim:] 428 | else: 429 | raise ValueError( 430 | "Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. " 431 | "axis: " + str(self.axis) + "; ndim: " + str(ndim) + "." 432 | ) 433 | if self.scale: 434 | Vrr = K.mean(centred_squared_real, axis=reduction_axes) + self.epsilon 435 | Vii = K.mean(centred_squared_imag, axis=reduction_axes) + self.epsilon 436 | # Vri contains the real and imaginary covariance for each feature map. 437 | Vri = K.mean( 438 | centred_real * centred_imag, 439 | axis=reduction_axes, 440 | ) 441 | elif self.center: 442 | Vrr = None 443 | Vii = None 444 | Vri = None 445 | else: 446 | raise ValueError("Error. Both scale and center in batchnorm are set to False.") 447 | 448 | input_bn = ComplexBN( 449 | input_centred, 450 | Vrr, 451 | Vii, 452 | Vri, 453 | self.beta, 454 | self.gamma_rr, 455 | self.gamma_ri, 456 | self.gamma_ii, 457 | self.scale, 458 | self.center, 459 | axis=self.axis, 460 | ) 461 | if training in {0, False}: 462 | return input_bn 463 | else: 464 | update_list = [] 465 | if self.center: 466 | update_list.append(K.moving_average_update(self.moving_mean, mu, self.momentum)) 467 | if self.scale: 468 | update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum)) 469 | update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum)) 470 | update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum)) 471 | self.add_update(update_list) 472 | 473 | def normalize_inference(): 474 | if self.center: 475 | inference_centred = inputs - K.reshape(self.moving_mean, broadcast_mu_shape) 476 | else: 477 | inference_centred = inputs 478 | return ComplexBN( 479 | inference_centred, 480 | self.moving_Vrr, 481 | self.moving_Vii, 482 | self.moving_Vri, 483 | self.beta, 484 | self.gamma_rr, 485 | self.gamma_ri, 486 | self.gamma_ii, 487 | self.scale, 488 | self.center, 489 | axis=self.axis, 490 | ) 491 | 492 | # Pick the normalized form corresponding to the training phase. 
493 | return K.in_train_phase(input_bn, normalize_inference, training=training) 494 | 495 | def get_config(self): 496 | config = { 497 | "axis": self.axis, 498 | "momentum": self.momentum, 499 | "epsilon": self.epsilon, 500 | "center": self.center, 501 | "scale": self.scale, 502 | "beta_initializer": sanitizedInitSer(self.beta_initializer), 503 | "gamma_diag_initializer": sanitizedInitSer(self.gamma_diag_initializer), 504 | "gamma_off_initializer": sanitizedInitSer(self.gamma_off_initializer), 505 | "moving_mean_initializer": sanitizedInitSer(self.moving_mean_initializer), 506 | "moving_variance_initializer": sanitizedInitSer(self.moving_variance_initializer), 507 | "moving_covariance_initializer": sanitizedInitSer(self.moving_covariance_initializer), 508 | "beta_regularizer": regularizers.serialize(self.beta_regularizer), 509 | "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer), 510 | "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer), 511 | "beta_constraint": constraints.serialize(self.beta_constraint), 512 | "gamma_diag_constraint": constraints.serialize(self.gamma_diag_constraint), 513 | "gamma_off_constraint": constraints.serialize(self.gamma_off_constraint), 514 | } 515 | base_config = super(ComplexBatchNormalization, self).get_config() 516 | return dict(list(base_config.items()) + list(config.items())) 517 | -------------------------------------------------------------------------------- /complexnn/conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras import activations, initializers, regularizers, constraints 7 | from tensorflow.keras.layers import ( 8 | Layer, 9 | InputSpec, 10 | ) 11 | from tensorflow.python.keras.layers.convolutional import Conv 12 | from tensorflow.python.keras.utils import conv_utils 13 | import numpy as np 14 | from .fft import fft, ifft, fft2, ifft2 15 | from .bn import ComplexBN as complex_normalization 16 | from .bn import sqrt_init 17 | from .init import ComplexInit, ComplexIndependentFilters 18 | 19 | 20 | def conv1d_transpose( 21 | inputs, 22 | filter, 23 | kernel_size=None, 24 | filters=None, 25 | strides=(1,), 26 | padding="SAME", 27 | output_padding=None, 28 | data_format="channels_last", 29 | ): 30 | """Compatibility layer for nn.conv1d_transpose 31 | 32 | Take a filter defined for forward convolution and adjusts it for a 33 | transposed convolution.""" 34 | if isinstance(kernel_size, tuple): 35 | kernel_size = kernel_size[0] 36 | input_shape = inputs.shape 37 | batch_size = input_shape[0] 38 | if data_format == "channels_first": 39 | w_axis = 2 40 | d_format = "NCW" 41 | else: 42 | w_axis = 1 43 | d_format = "NWC" 44 | 45 | width = input_shape[w_axis] 46 | 47 | # Infer the dynamic output shape: 48 | out_width = conv_utils.deconv_output_length( 49 | input_length=width, 50 | filter_size=kernel_size, 51 | padding=padding, 52 | output_padding=output_padding, 53 | stride=strides, 54 | ) 55 | 56 | if data_format == "channels_first": 57 | output_shape = (batch_size, filters, out_width) 58 | else: 59 | output_shape = (batch_size, out_width, filters) 60 | 61 | filter = K.permute_dimensions(filter, (0, 2, 1)) 62 | return tf.nn.conv1d_transpose( 63 | inputs, 64 | filter, 65 | output_shape, 66 | strides, 67 | padding=padding.upper(), 68 | data_format=d_format, 69 | ) 70 | 71 | 72 | def conv2d_transpose( 73 | inputs, 74 
| filter, 75 | kernel_size=None, 76 | filters=None, 77 | strides=(1, 1), 78 | padding="SAME", 79 | output_padding=None, 80 | data_format="channels_last", 81 | ): 82 | """Compatibility layer for K.conv2d_transpose 83 | 84 | Take a filter defined for forward convolution and adjusts it for a 85 | transposed convolution.""" 86 | input_shape = inputs.shape 87 | batch_size = input_shape[0] 88 | if data_format == "channels_first": 89 | h_axis, w_axis = 2, 3 90 | else: 91 | h_axis, w_axis = 1, 2 92 | 93 | height, width = input_shape[h_axis], input_shape[w_axis] 94 | kernel_h, kernel_w = kernel_size 95 | stride_h, stride_w = strides 96 | 97 | # Infer the dynamic output shape: 98 | out_height = conv_utils.deconv_output_length( 99 | input_length=height, 100 | filter_size=kernel_h, 101 | padding=padding, 102 | output_padding=output_padding, 103 | stride=stride_h, 104 | ) 105 | out_width = conv_utils.deconv_output_length( 106 | input_length=width, 107 | filter_size=kernel_w, 108 | padding=padding, 109 | output_padding=output_padding, 110 | stride=stride_w, 111 | ) 112 | 113 | if data_format == "channels_first": 114 | output_shape = (batch_size, filters, out_height, out_width) 115 | else: 116 | output_shape = (batch_size, out_height, out_width, filters) 117 | 118 | filter = K.permute_dimensions(filter, (0, 1, 3, 2)) 119 | return K.conv2d_transpose(inputs, filter, output_shape, strides, padding=padding, data_format=data_format) 120 | 121 | 122 | def ifft(f): 123 | """Stub""" 124 | raise NotImplementedError(str(f)) 125 | 126 | 127 | def ifft2(f): 128 | """Stub""" 129 | raise NotImplementedError(str(f)) 130 | 131 | 132 | def conv_transpose_output_length(input_length, filter_size, padding, stride, dilation=1, output_padding=None): 133 | """Rearrange arguments for compatibility with conv_output_length.""" 134 | if dilation != 1: 135 | msg = f"Dilation must be 1 for transposed convolution. " 136 | msg += f"Got dilation = {dilation}" 137 | raise ValueError(msg) 138 | # return conv_utils.deconv_length( 139 | # input_length, # dim_size 140 | # stride, # stride_size 141 | # filter_size, # kernel_size 142 | # padding, # padding 143 | # output_padding, # output_padding 144 | # ) 145 | return conv_utils.deconv_output_length( 146 | input_length, 147 | filter_size, 148 | padding, 149 | output_padding=output_padding, 150 | stride=stride, 151 | dilation=dilation, 152 | ) 153 | 154 | 155 | def sanitizedInitGet(init): 156 | """sanitizedInitGet""" 157 | if init in ["sqrt_init"]: 158 | return sqrt_init 159 | elif init in ["complex", "complex_independent", "glorot_complex", "he_complex"]: 160 | return init 161 | else: 162 | return initializers.get(init) 163 | 164 | 165 | def sanitizedInitSer(init): 166 | """sanitizedInitSer""" 167 | if init in [sqrt_init]: 168 | return "sqrt_init" 169 | elif init == "complex" or isinstance(init, ComplexInit): 170 | return "complex" 171 | elif init == "complex_independent" or isinstance(init, ComplexIndependentFilters): 172 | return "complex_independent" 173 | else: 174 | return initializers.serialize(init) 175 | 176 | 177 | class ComplexConv(Layer): 178 | """Abstract nD complex convolution layer. 179 | 180 | This layer creates a complex convolution kernel that is convolved with the 181 | layer input to produce a tensor of outputs. If `use_bias` is True, a bias 182 | vector is created and added to the outputs. Finally, if `activation` is not 183 | `None`, it is applied to the outputs as well. 
184 | 185 | Arguments: 186 | rank: Integer, the rank of the convolution, e.g., "2" for 2D 187 | convolution. 188 | filters: Integer, the dimensionality of the output space, i.e., the 189 | number of complex feature maps. It is also the effective number of 190 | feature maps for each of the real and imaginary parts. (I.e., the 191 | number of complex filters in the convolution) The total effective 192 | number of filters is 2 x filters. 193 | kernel_size: An integer or tuple/list of n integers, specifying the 194 | dimensions of the convolution window. 195 | strides: An integer or tuple/list of n integers, specifying the strides 196 | of the convolution. Specifying any stride value != 1 is 197 | incompatible with specifying any `dilation_rate` value != 1. 198 | padding: One of `"valid"` or `"same"` (case-insensitive). 199 | data_format: A string, one of `channels_last` (default) or 200 | `channels_first`. The ordering of the dimensions in the inputs. 201 | `channels_last` corresponds to inputs with shape 202 | `(batch, ..., channels)` while `channels_first` corresponds to 203 | inputs with shape `(batch, channels, ...)`. It defaults to the 204 | `image_data_format` value found in your Keras config file at 205 | `~/.keras/keras.json`. If you never set it, then it will be 206 | "channels_last". 207 | dilation_rate: An integer or tuple/list of n integers, specifying 208 | the dilation rate to use for dilated convolution. Currently, 209 | specifying any `dilation_rate` value != 1 is incompatible with 210 | specifying any `strides` value != 1. 211 | activation: Activation function to use (see keras.activations). If you 212 | don't specify anything, no activation is applied (i.e., "linear" 213 | activation: `a(x) = x`). 214 | use_bias: Boolean, whether the layer uses a bias vector. 215 | normalize_weight: Boolean, whether the layer normalizes its complex 216 | weights before convolving the complex input. The complex 217 | normalization performed is similar to the one for the batchnorm. 218 | Each of the complex kernels is centred and multiplied by the 219 | inverse square root of the covariance matrix. Then a complex 220 | multiplication is performed as the normalized weights are 221 | multiplied by the complex scaling factor gamma. 222 | kernel_initializer: Initializer for the complex `kernel` weights 223 | matrix. By default it is 'complex'. The 'complex_independent' 224 | and the usual initializers could also be used. (See 225 | keras.initializers and init.py). 226 | bias_initializer: Initializer for the bias vector 227 | (see keras.initializers). 228 | kernel_regularizer: Regularizer function applied to the `kernel` 229 | weights matrix (see keras.regularizers). 230 | bias_regularizer: Regularizer function applied to the bias vector 231 | (see keras.regularizers). 232 | activity_regularizer: Regularizer function applied to the output of the 233 | layer (its "activation"). (See keras.regularizers). 234 | kernel_constraint: Constraint function applied to the kernel matrix 235 | (see keras.constraints). 236 | bias_constraint: Constraint function applied to the bias vector 237 | (see keras.constraints). 238 | spectral_parametrization: Boolean, whether or not to use a spectral 239 | parametrization of the parameters. 
240 | transposed: Boolean, whether or not to use transposed convolution 241 | """ 242 | 243 | def __init__( 244 | self, 245 | rank, 246 | filters, 247 | kernel_size, 248 | strides=1, 249 | padding="valid", 250 | data_format=None, 251 | dilation_rate=1, 252 | activation=None, 253 | use_bias=True, 254 | normalize_weight=False, 255 | kernel_initializer="complex", 256 | bias_initializer="zeros", 257 | gamma_diag_initializer=sqrt_init, 258 | gamma_off_initializer="zeros", 259 | kernel_regularizer=None, 260 | bias_regularizer=None, 261 | gamma_diag_regularizer=None, 262 | gamma_off_regularizer=None, 263 | activity_regularizer=None, 264 | kernel_constraint=None, 265 | bias_constraint=None, 266 | gamma_diag_constraint=None, 267 | gamma_off_constraint=None, 268 | init_criterion="he", 269 | seed=None, 270 | spectral_parametrization=False, 271 | transposed=False, 272 | epsilon=1e-7, 273 | **kwargs, 274 | ): 275 | super(ComplexConv, self).__init__(**kwargs) 276 | self.rank = rank 277 | self.filters = filters 278 | self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, "kernel_size") 279 | self.strides = conv_utils.normalize_tuple(strides, rank, "strides") 280 | self.padding = conv_utils.normalize_padding(padding) 281 | self.data_format = "channels_last" if rank == 1 else conv_utils.normalize_data_format(data_format) 282 | self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, "dilation_rate") 283 | self.activation = activations.get(activation) 284 | self.use_bias = use_bias 285 | self.normalize_weight = normalize_weight 286 | self.init_criterion = init_criterion 287 | self.spectral_parametrization = spectral_parametrization 288 | self.transposed = transposed 289 | self.epsilon = epsilon 290 | self.kernel_initializer = sanitizedInitGet(kernel_initializer) 291 | self.bias_initializer = sanitizedInitGet(bias_initializer) 292 | self.gamma_diag_initializer = sanitizedInitGet(gamma_diag_initializer) 293 | self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer) 294 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 295 | self.bias_regularizer = regularizers.get(bias_regularizer) 296 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 297 | self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer) 298 | self.activity_regularizer = regularizers.get(activity_regularizer) 299 | self.kernel_constraint = constraints.get(kernel_constraint) 300 | self.bias_constraint = constraints.get(bias_constraint) 301 | self.gamma_diag_constraint = constraints.get(gamma_diag_constraint) 302 | self.gamma_off_constraint = constraints.get(gamma_off_constraint) 303 | if seed is None: 304 | self.seed = np.random.randint(1, 10e6) 305 | else: 306 | self.seed = seed 307 | self.input_spec = InputSpec(ndim=self.rank + 2) 308 | 309 | # The following are initialized later 310 | self.kernel_shape = None 311 | self.kernel = None 312 | self.gamma_rr = None 313 | self.gamma_ii = None 314 | self.gamma_ri = None 315 | self.bias = None 316 | 317 | def build(self, input_shape): 318 | """build""" 319 | if self.data_format == "channels_first": 320 | channel_axis = 1 321 | else: 322 | channel_axis = -1 323 | if input_shape[channel_axis] is None: 324 | raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.") 325 | # Divide by 2 for real and complex input. 
326 | input_dim = input_shape[channel_axis] // 2 327 | if False and self.transposed: 328 | self.kernel_shape = self.kernel_size + (self.filters, input_dim) 329 | else: 330 | self.kernel_shape = self.kernel_size + (input_dim, self.filters) 331 | # The kernel shape here is a complex kernel shape: 332 | # nb of complex feature maps = input_dim; 333 | # nb of output complex feature maps = self.filters; 334 | # imaginary kernel size = real kernel size 335 | # = self.kernel_size 336 | # = complex kernel size 337 | if self.kernel_initializer in {"complex", "complex_independent"}: 338 | kls = { 339 | "complex": ComplexInit, 340 | "complex_independent": ComplexIndependentFilters, 341 | }[self.kernel_initializer] 342 | kern_init = kls( 343 | kernel_size=self.kernel_size, 344 | input_dim=input_dim, 345 | weight_dim=self.rank, 346 | nb_filters=self.filters, 347 | criterion=self.init_criterion, 348 | ) 349 | else: 350 | kern_init = self.kernel_initializer 351 | 352 | # Fix for 'ValueError: The initial value's shape ((3, 3, 1, 8)) is not compatible with the explicitly supplied `shape` argument' 353 | actual_kernel_shape = list(self.kernel_shape) 354 | actual_kernel_shape[-1] *= 2 355 | 356 | self.kernel = self.add_weight( 357 | "kernel", 358 | shape=tuple(actual_kernel_shape), 359 | initializer=kern_init, 360 | regularizer=self.kernel_regularizer, 361 | constraint=self.kernel_constraint, 362 | ) 363 | 364 | if self.normalize_weight: 365 | gamma_shape = (input_dim * self.filters,) 366 | self.gamma_rr = self.add_weight( 367 | shape=gamma_shape, 368 | name="gamma_rr", 369 | initializer=self.gamma_diag_initializer, 370 | regularizer=self.gamma_diag_regularizer, 371 | constraint=self.gamma_diag_constraint, 372 | ) 373 | self.gamma_ii = self.add_weight( 374 | shape=gamma_shape, 375 | name="gamma_ii", 376 | initializer=self.gamma_diag_initializer, 377 | regularizer=self.gamma_diag_regularizer, 378 | constraint=self.gamma_diag_constraint, 379 | ) 380 | self.gamma_ri = self.add_weight( 381 | shape=gamma_shape, 382 | name="gamma_ri", 383 | initializer=self.gamma_off_initializer, 384 | regularizer=self.gamma_off_regularizer, 385 | constraint=self.gamma_off_constraint, 386 | ) 387 | else: 388 | self.gamma_rr = None 389 | self.gamma_ii = None 390 | self.gamma_ri = None 391 | 392 | if self.use_bias: 393 | bias_shape = (2 * self.filters,) 394 | self.bias = self.add_weight( 395 | "bias", 396 | bias_shape, 397 | initializer=self.bias_initializer, 398 | regularizer=self.bias_regularizer, 399 | constraint=self.bias_constraint, 400 | ) 401 | 402 | else: 403 | self.bias = None 404 | 405 | # Set input spec. 
406 | self.input_spec = InputSpec(ndim=self.rank + 2, axes={channel_axis: input_dim * 2}) 407 | self.built = True 408 | 409 | def call(self, inputs, **kwargs): 410 | if self.data_format == "channels_first": 411 | channel_axis = 1 412 | else: 413 | channel_axis = -1 414 | input_dim = K.shape(inputs)[channel_axis] // 2 415 | if False and self.transposed: 416 | if self.rank == 1: 417 | f_real = self.kernel[:, : self.filters, :] 418 | f_imag = self.kernel[:, self.filters :, :] 419 | elif self.rank == 2: 420 | f_real = self.kernel[:, :, : self.filters, :] 421 | f_imag = self.kernel[:, :, self.filters :, :] 422 | elif self.rank == 3: 423 | f_real = self.kernel[:, :, :, : self.filters, :] 424 | f_imag = self.kernel[:, :, :, self.filters :, :] 425 | else: 426 | if self.rank == 1: 427 | f_real = self.kernel[:, :, : self.filters] 428 | f_imag = self.kernel[:, :, self.filters :] 429 | elif self.rank == 2: 430 | f_real = self.kernel[:, :, :, : self.filters] 431 | f_imag = self.kernel[:, :, :, self.filters :] 432 | elif self.rank == 3: 433 | f_real = self.kernel[:, :, :, :, : self.filters] 434 | f_imag = self.kernel[:, :, :, :, self.filters :] 435 | 436 | convArgs = { 437 | "strides": self.strides[0] if self.rank == 1 else self.strides, 438 | "padding": self.padding, 439 | "data_format": self.data_format, 440 | "dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate, 441 | } 442 | if self.transposed: 443 | convArgs.pop("dilation_rate", None) 444 | convArgs["kernel_size"] = self.kernel_size 445 | convArgs["filters"] = 2 * self.filters 446 | convFunc = {1: conv1d_transpose, 2: conv2d_transpose}[self.rank] 447 | else: 448 | convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank] 449 | 450 | # processing if the weights are assumed to be represented in the 451 | # spectral domain 452 | 453 | if self.spectral_parametrization: 454 | if self.rank == 1: 455 | f_real = K.permute_dimensions(f_real, (2, 1, 0)) 456 | f_imag = K.permute_dimensions(f_imag, (2, 1, 0)) 457 | f = K.concatenate([f_real, f_imag], axis=0) 458 | fshape = K.shape(f) 459 | f = K.reshape(f, (fshape[0] * fshape[1], fshape[2])) 460 | f = ifft(f) 461 | f = K.reshape(f, fshape) 462 | f_real = f[: fshape[0] // 2] 463 | f_imag = f[fshape[0] // 2 :] 464 | f_real = K.permute_dimensions(f_real, (2, 1, 0)) 465 | f_imag = K.permute_dimensions(f_imag, (2, 1, 0)) 466 | elif self.rank == 2: 467 | f_real = K.permute_dimensions(f_real, (3, 2, 0, 1)) 468 | f_imag = K.permute_dimensions(f_imag, (3, 2, 0, 1)) 469 | f = K.concatenate([f_real, f_imag], axis=0) 470 | fshape = K.shape(f) 471 | f = K.reshape(f, (fshape[0] * fshape[1], fshape[2], fshape[3])) 472 | f = ifft2(f) 473 | f = K.reshape(f, fshape) 474 | f_real = f[: fshape[0] // 2] 475 | f_imag = f[fshape[0] // 2 :] 476 | f_real = K.permute_dimensions(f_real, (2, 3, 1, 0)) 477 | f_imag = K.permute_dimensions(f_imag, (2, 3, 1, 0)) 478 | 479 | # In case of weight normalization, real and imaginary weights are 480 | # normalized 481 | 482 | if self.normalize_weight: 483 | ker_shape = self.kernel_shape 484 | nb_kernels = ker_shape[-2] * ker_shape[-1] 485 | kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels) 486 | reshaped_f_real = K.reshape(f_real, kernel_shape_4_norm) 487 | reshaped_f_imag = K.reshape(f_imag, kernel_shape_4_norm) 488 | reduction_axes = list(range(2)) 489 | del reduction_axes[-1] 490 | mu_real = K.mean(reshaped_f_real, axis=reduction_axes) 491 | mu_imag = K.mean(reshaped_f_imag, axis=reduction_axes) 492 | 493 | broadcast_mu_shape = [1] * 2 494 | 
broadcast_mu_shape[-1] = nb_kernels 495 | broadcast_mu_real = K.reshape(mu_real, broadcast_mu_shape) 496 | broadcast_mu_imag = K.reshape(mu_imag, broadcast_mu_shape) 497 | reshaped_f_real_centred = reshaped_f_real - broadcast_mu_real 498 | reshaped_f_imag_centred = reshaped_f_imag - broadcast_mu_imag 499 | Vrr = K.mean(reshaped_f_real_centred**2, axis=reduction_axes) + self.epsilon 500 | Vii = K.mean(reshaped_f_imag_centred**2, axis=reduction_axes) + self.epsilon 501 | Vri = ( 502 | K.mean( 503 | reshaped_f_real_centred * reshaped_f_imag_centred, 504 | axis=reduction_axes, 505 | ) 506 | + self.epsilon 507 | ) 508 | 509 | normalized_weight = complex_normalization( 510 | K.concatenate([reshaped_f_real, reshaped_f_imag], axis=-1), 511 | Vrr, 512 | Vii, 513 | Vri, 514 | beta=None, 515 | gamma_rr=self.gamma_rr, 516 | gamma_ri=self.gamma_ri, 517 | gamma_ii=self.gamma_ii, 518 | scale=True, 519 | center=False, 520 | axis=-1, 521 | ) 522 | 523 | normalized_real = normalized_weight[:, :nb_kernels] 524 | normalized_imag = normalized_weight[:, nb_kernels:] 525 | f_real = K.reshape(normalized_real, self.kernel_shape) 526 | f_imag = K.reshape(normalized_imag, self.kernel_shape) 527 | 528 | # Performing complex convolution 529 | 530 | f_real._keras_shape = self.kernel_shape 531 | f_imag._keras_shape = self.kernel_shape 532 | 533 | cat_kernels_4_real = K.concatenate([f_real, -f_imag], axis=-2) 534 | cat_kernels_4_imag = K.concatenate([f_imag, f_real], axis=-2) 535 | cat_kernels_4_complex = K.concatenate([cat_kernels_4_real, cat_kernels_4_imag], axis=-1) 536 | if False and self.transposed: 537 | cat_kernels_4_complex._keras_shape = self.kernel_size + ( 538 | 2 * self.filters, 539 | 2 * input_dim, 540 | ) 541 | else: 542 | cat_kernels_4_complex._keras_shape = ( 543 | self.kernel_size + 2 * input_dim, 544 | 2 * self.filters, 545 | ) 546 | 547 | output = convFunc(inputs, cat_kernels_4_complex, **convArgs) 548 | 549 | if self.use_bias: 550 | output = K.bias_add(output, self.bias, data_format=self.data_format) 551 | 552 | if self.activation is not None: 553 | output = self.activation(output) 554 | 555 | return output 556 | 557 | def compute_output_shape(self, input_shape): 558 | if self.transposed: 559 | outputLengthFunc = conv_transpose_output_length 560 | else: 561 | outputLengthFunc = conv_utils.conv_output_length 562 | if self.data_format == "channels_last": 563 | space = input_shape[1:-1] 564 | new_space = [] 565 | for i in range(len(space)): 566 | new_dim = outputLengthFunc( 567 | space[i], 568 | self.kernel_size[i], 569 | padding=self.padding, 570 | stride=self.strides[i], 571 | dilation=self.dilation_rate[i], 572 | ) 573 | new_space.append(new_dim) 574 | return (input_shape[0],) + tuple(new_space) + (2 * self.filters,) 575 | if self.data_format == "channels_first": 576 | space = input_shape[2:] 577 | new_space = [] 578 | for i in range(len(space)): 579 | new_dim = outputLengthFunc( 580 | space[i], 581 | self.kernel_size[i], 582 | padding=self.padding, 583 | stride=self.strides[i], 584 | dilation=self.dilation_rate[i], 585 | ) 586 | new_space.append(new_dim) 587 | return (input_shape[0],) + (2 * self.filters,) + tuple(new_space) 588 | 589 | def get_config(self): 590 | config = { 591 | "rank": self.rank, 592 | "filters": self.filters, 593 | "kernel_size": self.kernel_size, 594 | "strides": self.strides, 595 | "padding": self.padding, 596 | "data_format": self.data_format, 597 | "dilation_rate": self.dilation_rate, 598 | "activation": activations.serialize(self.activation), 599 | "use_bias": self.use_bias, 
600 | "normalize_weight": self.normalize_weight, 601 | "kernel_initializer": sanitizedInitSer(self.kernel_initializer), 602 | "bias_initializer": sanitizedInitSer(self.bias_initializer), 603 | "gamma_diag_initializer": sanitizedInitSer(self.gamma_diag_initializer), 604 | "gamma_off_initializer": sanitizedInitSer(self.gamma_off_initializer), 605 | "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), 606 | "bias_regularizer": regularizers.serialize(self.bias_regularizer), 607 | "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer), 608 | "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer), 609 | "activity_regularizer": regularizers.serialize(self.activity_regularizer), 610 | "kernel_constraint": constraints.serialize(self.kernel_constraint), 611 | "bias_constraint": constraints.serialize(self.bias_constraint), 612 | "gamma_diag_constraint": constraints.serialize(self.gamma_diag_constraint), 613 | "gamma_off_constraint": constraints.serialize(self.gamma_off_constraint), 614 | "init_criterion": self.init_criterion, 615 | "spectral_parametrization": self.spectral_parametrization, 616 | "transposed": self.transposed, 617 | } 618 | base_config = super(ComplexConv, self).get_config() 619 | return dict(list(base_config.items()) + list(config.items())) 620 | 621 | 622 | class ComplexConv1D(ComplexConv): 623 | """1D complex convolution layer. 624 | This layer creates a complex convolution kernel that is convolved 625 | with a complex input layer over a single complex spatial (or temporal) 626 | dimension 627 | to produce a complex output tensor. 628 | If `use_bias` is True, a bias vector is created and added to the complex 629 | output. 630 | Finally, if `activation` is not `None`, 631 | it is applied each of the real and imaginary parts of the output. 632 | When using this layer as the first layer in a model, 633 | provide an `input_shape` argument 634 | (tuple of integers or `None`, e.g. 635 | `(10, 128)` for sequences of 10 vectors of 128-dimensional vectors, 636 | or `(None, 128)` for variable-length sequences of 128-dimensional vectors. 637 | # Arguments 638 | filters: Integer, the dimensionality of the output space, i.e, 639 | the number of complex feature maps. It is also the effective number 640 | of feature maps for each of the real and imaginary parts. 641 | (i.e. the number of complex filters in the convolution) 642 | The total effective number of filters is 2 x filters. 643 | kernel_size: An integer or tuple/list of n integers, specifying the 644 | dimensions of the convolution window. 645 | strides: An integer or tuple/list of a single integer, 646 | specifying the stride length of the convolution. 647 | Specifying any stride value != 1 is incompatible with specifying 648 | any `dilation_rate` value != 1. 649 | padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive). 650 | `"causal"` results in causal (dilated) convolutions, e.g. output[t] 651 | does not depend on input[t+1:]. Useful when modeling temporal data 652 | where the model should not violate the temporal order. 653 | See [WaveNet: A Generative Model for Raw Audio, section 2.1] 654 | (https://arxiv.org/abs/1609.03499). 655 | dilation_rate: an integer or tuple/list of a single integer, specifying 656 | the dilation rate to use for dilated convolution. 657 | Currently, specifying any `dilation_rate` value != 1 is 658 | incompatible with specifying any `strides` value != 1. 659 | activation: Activation function to use 660 | (see keras.activations). 
661 | If you don't specify anything, no activation is applied 662 | (ie. "linear" activation: `a(x) = x`). 663 | use_bias: Boolean, whether the layer uses a bias vector. 664 | normalize_weight: Boolean, whether the layer normalizes its complex 665 | weights before convolving the complex input. 666 | The complex normalization performed is similar to the one 667 | for the batchnorm. Each of the complex kernels are centred and 668 | multiplied by 669 | the inverse square root of covariance matrix. 670 | Then, a complex multiplication is perfromed as the normalized 671 | weights are 672 | multiplied by the complex scaling factor gamma. 673 | kernel_initializer: Initializer for the complex `kernel` weights 674 | matrix. 675 | By default it is 'complex'. The 'complex_independent' 676 | and the usual initializers could also be used. 677 | (see keras.initializers and init.py). 678 | bias_initializer: Initializer for the bias vector 679 | (see keras.initializers). 680 | kernel_regularizer: Regularizer function applied to 681 | the `kernel` weights matrix 682 | (see keras.regularizers). 683 | bias_regularizer: Regularizer function applied to the bias vector 684 | (see keras.regularizers). 685 | activity_regularizer: Regularizer function applied to 686 | the output of the layer (its "activation"). 687 | (see keras.regularizers). 688 | kernel_constraint: Constraint function applied to the kernel matrix 689 | (see keras.constraints). 690 | bias_constraint: Constraint function applied to the bias vector 691 | (see keras.constraints). 692 | spectral_parametrization: Whether or not to use a spectral 693 | parametrization of the parameters. 694 | transposed: Boolean, whether or not to use transposed convolution 695 | # Input shape 696 | 3D tensor with shape: `(batch_size, steps, input_dim)` 697 | # Output shape 698 | 3D tensor with shape: `(batch_size, new_steps, 2 x filters)` 699 | `steps` value might have changed due to padding or strides. 700 | """ 701 | 702 | def __init__( 703 | self, 704 | filters, 705 | kernel_size, 706 | strides=1, 707 | padding="valid", 708 | dilation_rate=1, 709 | activation=None, 710 | use_bias=True, 711 | kernel_initializer="complex", 712 | bias_initializer="zeros", 713 | kernel_regularizer=None, 714 | bias_regularizer=None, 715 | activity_regularizer=None, 716 | kernel_constraint=None, 717 | bias_constraint=None, 718 | seed=None, 719 | init_criterion="he", 720 | spectral_parametrization=False, 721 | transposed=False, 722 | **kwargs, 723 | ): 724 | super(ComplexConv1D, self).__init__( 725 | rank=1, 726 | filters=filters, 727 | kernel_size=kernel_size, 728 | strides=strides, 729 | padding=padding, 730 | data_format="channels_last", 731 | dilation_rate=dilation_rate, 732 | activation=activation, 733 | use_bias=use_bias, 734 | kernel_initializer=kernel_initializer, 735 | bias_initializer=bias_initializer, 736 | kernel_regularizer=kernel_regularizer, 737 | bias_regularizer=bias_regularizer, 738 | activity_regularizer=activity_regularizer, 739 | kernel_constraint=kernel_constraint, 740 | bias_constraint=bias_constraint, 741 | init_criterion=init_criterion, 742 | spectral_parametrization=spectral_parametrization, 743 | transposed=transposed, 744 | **kwargs, 745 | ) 746 | 747 | def get_config(self): 748 | config = super(ComplexConv1D, self).get_config() 749 | config.pop("rank") 750 | return config 751 | 752 | 753 | class ComplexConv2D(ComplexConv): 754 | """2D Complex convolution layer (e.g. spatial convolution over images). 
755 | This layer creates a complex convolution kernel that is convolved 756 | with a complex input layer to produce a complex output tensor. If 757 | `use_bias` 758 | is True, a complex bias vector is created and added to the outputs. 759 | Finally, if `activation` is not `None`, it is applied to both the 760 | real and imaginary parts of the output. 761 | When using this layer as the first layer in a model, 762 | provide the keyword argument `input_shape` 763 | (tuple of integers, does not include the sample axis), 764 | e.g. `input_shape=(128, 128, 3)` for 128x128 RGB pictures 765 | in `data_format="channels_last"`. 766 | # Arguments 767 | filters: Integer, the dimensionality of the complex output space 768 | (i.e, the number complex feature maps in the convolution). The 769 | total effective number of filters or feature maps is 2 x filters. 770 | kernel_size: An integer or tuple/list of 2 integers, specifying the 771 | width and height of the 2D convolution window. 772 | Can be a single integer to specify the same value for 773 | all spatial dimensions. 774 | strides: An integer or tuple/list of 2 integers, 775 | specifying the strides of the convolution along the width and 776 | height. 777 | Can be a single integer to specify the same value for 778 | all spatial dimensions. 779 | Specifying any stride value != 1 is incompatible with specifying 780 | any `dilation_rate` value != 1. 781 | padding: one of `"valid"` or `"same"` (case-insensitive). 782 | data_format: A string, 783 | one of `channels_last` (default) or `channels_first`. 784 | The ordering of the dimensions in the inputs. 785 | `channels_last` corresponds to inputs with shape 786 | `(batch, height, width, channels)` while `channels_first` 787 | corresponds to inputs with shape 788 | `(batch, channels, height, width)`. 789 | It defaults to the `image_data_format` value found in your 790 | Keras config file at `~/.keras/keras.json`. 791 | If you never set it, then it will be "channels_last". 792 | dilation_rate: an integer or tuple/list of 2 integers, specifying 793 | the dilation rate to use for dilated convolution. 794 | Can be a single integer to specify the same value for 795 | all spatial dimensions. 796 | Currently, specifying any `dilation_rate` value != 1 is 797 | incompatible with specifying any stride value != 1. 798 | activation: Activation function to use 799 | (see keras.activations). 800 | If you don't specify anything, no activation is applied 801 | (ie. "linear" activation: `a(x) = x`). 802 | use_bias: Boolean, whether the layer uses a bias vector. 803 | normalize_weight: Boolean, whether the layer normalizes its complex 804 | weights before convolving the complex input. 805 | The complex normalization performed is similar to the one 806 | for the batchnorm. Each of the complex kernels are centred and 807 | multiplied by 808 | the inverse square root of covariance matrix. 809 | Then, a complex multiplication is perfromed as the normalized 810 | weights are 811 | multiplied by the complex scaling factor gamma. 812 | kernel_initializer: Initializer for the complex `kernel` weights 813 | matrix. 814 | By default it is 'complex'. The 'complex_independent' 815 | and the usual initializers could also be used. 816 | (see keras.initializers and init.py). 817 | bias_initializer: Initializer for the bias vector 818 | (see keras.initializers). 819 | kernel_regularizer: Regularizer function applied to 820 | the `kernel` weights matrix 821 | (see keras.regularizers). 
822 | bias_regularizer: Regularizer function applied to the bias vector 823 | (see keras.regularizers). 824 | activity_regularizer: Regularizer function applied to 825 | the output of the layer (its "activation"). 826 | (see keras.regularizers). 827 | kernel_constraint: Constraint function applied to the kernel matrix 828 | (see keras.constraints). 829 | bias_constraint: Constraint function applied to the bias vector 830 | (see keras.constraints). 831 | spectral_parametrization: Whether or not to use a spectral 832 | parametrization of the parameters. 833 | transposed: Boolean, whether or not to use transposed convolution 834 | # Input shape 835 | 4D tensor with shape: 836 | `(samples, channels, rows, cols)` if data_format='channels_first' 837 | or 4D tensor with shape: 838 | `(samples, rows, cols, channels)` if data_format='channels_last'. 839 | # Output shape 840 | 4D tensor with shape: 841 | `(samples, 2 x filters, new_rows, new_cols)` if 842 | data_format='channels_first' or 4D tensor with shape: 843 | `(samples, new_rows, new_cols, 2 x filters)` if 844 | data_format='channels_last'. `rows` and `cols` values might have 845 | changed due to padding. 846 | """ 847 | 848 | def __init__( 849 | self, 850 | filters, 851 | kernel_size, 852 | strides=(1, 1), 853 | padding="valid", 854 | data_format=None, 855 | dilation_rate=(1, 1), 856 | activation=None, 857 | use_bias=True, 858 | kernel_initializer="complex", 859 | bias_initializer="zeros", 860 | kernel_regularizer=None, 861 | bias_regularizer=None, 862 | activity_regularizer=None, 863 | kernel_constraint=None, 864 | bias_constraint=None, 865 | seed=None, 866 | init_criterion="he", 867 | spectral_parametrization=False, 868 | transposed=False, 869 | **kwargs, 870 | ): 871 | super(ComplexConv2D, self).__init__( 872 | rank=2, 873 | filters=filters, 874 | kernel_size=kernel_size, 875 | strides=strides, 876 | padding=padding, 877 | data_format=data_format, 878 | dilation_rate=dilation_rate, 879 | activation=activation, 880 | use_bias=use_bias, 881 | kernel_initializer=kernel_initializer, 882 | bias_initializer=bias_initializer, 883 | kernel_regularizer=kernel_regularizer, 884 | bias_regularizer=bias_regularizer, 885 | activity_regularizer=activity_regularizer, 886 | kernel_constraint=kernel_constraint, 887 | bias_constraint=bias_constraint, 888 | init_criterion=init_criterion, 889 | spectral_parametrization=spectral_parametrization, 890 | transposed=transposed, 891 | **kwargs, 892 | ) 893 | 894 | def get_config(self): 895 | config = super(ComplexConv2D, self).get_config() 896 | config.pop("rank") 897 | return config 898 | 899 | 900 | class ComplexConv3D(ComplexConv): 901 | """3D convolution layer (e.g. spatial convolution over volumes). 902 | This layer creates a complex convolution kernel that is convolved 903 | with a complex layer input to produce a complex output tensor. 904 | If `use_bias` is True, 905 | a complex bias vector is created and added to the outputs. Finally, if 906 | `activation` is not `None`, it is applied to each of the real and imaginary 907 | parts of the output. 908 | When using this layer as the first layer in a model, 909 | provide the keyword argument `input_shape` 910 | (tuple of integers, does not include the sample axis), 911 | e.g. `input_shape=(2, 128, 128, 128, 3)` for 128x128x128 volumes 912 | with 3 channels, 913 | in `data_format="channels_last"`. 914 | # Arguments 915 | filters: Integer, the dimensionality of the complex output space 916 | (i.e, the number complex feature maps in the convolution). 
The 917 | total effective number of filters or feature maps is 2 x filters. 918 | kernel_size: An integer or tuple/list of 3 integers, specifying the 919 | width and height of the 3D convolution window. 920 | Can be a single integer to specify the same value for 921 | all spatial dimensions. 922 | strides: An integer or tuple/list of 3 integers, specifying 923 | the strides of the convolution along each spatial dimension. 924 | Can be a single integer to specify the same value for 925 | all spatial dimensions. 926 | Specifying any stride value != 1 is incompatible with specifying 927 | any `dilation_rate` value != 1. 928 | padding: one of `"valid"` or `"same"` (case-insensitive). 929 | data_format: A string, 930 | one of `channels_last` (default) or `channels_first`. 931 | The ordering of the dimensions in the inputs. 932 | `channels_last` corresponds to inputs with shape 933 | `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 934 | while `channels_first` corresponds to inputs with shape 935 | `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 936 | It defaults to the `image_data_format` value found in your 937 | Keras config file at `~/.keras/keras.json`. 938 | If you never set it, then it will be "channels_last". 939 | dilation_rate: an integer or tuple/list of 3 integers, specifying 940 | the dilation rate to use for dilated convolution. 941 | Can be a single integer to specify the same value for 942 | all spatial dimensions. 943 | Currently, specifying any `dilation_rate` value != 1 is 944 | incompatible with specifying any stride value != 1. 945 | activation: Activation function to use 946 | (see keras.activations). 947 | If you don't specify anything, no activation is applied 948 | (ie. "linear" activation: `a(x) = x`). 949 | use_bias: Boolean, whether the layer uses a bias vector. 950 | normalize_weight: Boolean, whether the layer normalizes its complex 951 | weights before convolving the complex input. 952 | The complex normalization performed is similar to the one 953 | for the batchnorm. Each of the complex kernels are centred and 954 | multiplied by 955 | the inverse square root of covariance matrix. 956 | Then, a complex multiplication is perfromed as the normalized 957 | weights are 958 | multiplied by the complex scaling factor gamma. 959 | kernel_initializer: Initializer for the complex `kernel` weights 960 | matrix. 961 | By default it is 'complex'. The 'complex_independent' 962 | and the usual initializers could also be used. 963 | (see keras.initializers and init.py). 964 | bias_initializer: Initializer for the bias vector 965 | (see keras.initializers). 966 | kernel_regularizer: Regularizer function applied to 967 | the `kernel` weights matrix 968 | (see keras.regularizers). 969 | bias_regularizer: Regularizer function applied to the bias vector 970 | (see keras.regularizers). 971 | activity_regularizer: Regularizer function applied to 972 | the output of the layer (its "activation"). 973 | (see keras.regularizers). 974 | kernel_constraint: Constraint function applied to the kernel matrix 975 | (see keras.constraints). 976 | bias_constraint: Constraint function applied to the bias vector 977 | (see keras.constraints). 978 | spectral_parametrization: Whether or not to use a spectral 979 | parametrization of the parameters. 
980 | transposed: Boolean, whether or not to use transposed convolution 981 | # Input shape 982 | 5D tensor with shape: 983 | `(samples, channels, conv_dim1, conv_dim2, conv_dim3)` if 984 | data_format='channels_first' 985 | or 5D tensor with shape: 986 | `(samples, conv_dim1, conv_dim2, conv_dim3, channels)` if 987 | data_format='channels_last'. 988 | # Output shape 989 | 5D tensor with shape: 990 | `(samples, 2 x filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` 991 | if data_format='channels_first' 992 | or 5D tensor with shape: 993 | `(samples, new_conv_dim1, new_conv_dim2, new_conv_dim3, 2 x filters)` 994 | if data_format='channels_last'. 995 | `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have 996 | changed due to padding. 997 | """ 998 | 999 | def __init__( 1000 | self, 1001 | filters, 1002 | kernel_size, 1003 | strides=(1, 1, 1), 1004 | padding="valid", 1005 | data_format=None, 1006 | dilation_rate=(1, 1, 1), 1007 | activation=None, 1008 | use_bias=True, 1009 | kernel_initializer="complex", 1010 | bias_initializer="zeros", 1011 | kernel_regularizer=None, 1012 | bias_regularizer=None, 1013 | activity_regularizer=None, 1014 | kernel_constraint=None, 1015 | bias_constraint=None, 1016 | seed=None, 1017 | init_criterion="he", 1018 | spectral_parametrization=False, 1019 | transposed=False, 1020 | **kwargs, 1021 | ): 1022 | super(ComplexConv3D, self).__init__( 1023 | rank=3, 1024 | filters=filters, 1025 | kernel_size=kernel_size, 1026 | strides=strides, 1027 | padding=padding, 1028 | data_format=data_format, 1029 | dilation_rate=dilation_rate, 1030 | activation=activation, 1031 | use_bias=use_bias, 1032 | kernel_initializer=kernel_initializer, 1033 | bias_initializer=bias_initializer, 1034 | kernel_regularizer=kernel_regularizer, 1035 | bias_regularizer=bias_regularizer, 1036 | activity_regularizer=activity_regularizer, 1037 | kernel_constraint=kernel_constraint, 1038 | bias_constraint=bias_constraint, 1039 | init_criterion=init_criterion, 1040 | spectral_parametrization=spectral_parametrization, 1041 | transposed=transposed, 1042 | **kwargs, 1043 | ) 1044 | 1045 | def get_config(self): 1046 | config = super(ComplexConv3D, self).get_config() 1047 | config.pop("rank") 1048 | return config 1049 | 1050 | 1051 | class WeightNorm_Conv(Conv): 1052 | """WeightNorm_Conv""" 1053 | 1054 | # Real-valued Convolutional Layer that normalizes its weights 1055 | # before convolving the input. 
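    # A minimal sketch of that normalization as call() applies it below
    # (illustrative NumPy only; the sizes are assumptions, not taken from
    # this file):
    #
    #     import numpy as np
    #     v = np.random.randn(9, 16)   # (prod(kernel_size), input_dim * filters)
    #     g = np.ones(16)              # the `gamma` weight created in build()
    #     w = g * v / np.linalg.norm(v, axis=0, keepdims=True)   # w = g * v / ||v||
    #
    # Each kernel column is rescaled to unit L2 norm over its spatial taps and
    # then multiplied by a learned per-column scale gamma.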
1056 |     # The weight Normalization performed is the one
1057 |     # described in the following paper:
1058 |     # Weight Normalization: A Simple Reparameterization to Accelerate Training
1059 |     # of Deep Neural Networks
1060 |     # (see https://arxiv.org/abs/1602.07868)
1061 | 
1062 |     def __init__(
1063 |         self,
1064 |         gamma_initializer="ones",
1065 |         gamma_regularizer=None,
1066 |         gamma_constraint=None,
1067 |         epsilon=1e-07,
1068 |         **kwargs,
1069 |     ):
1070 |         super(WeightNorm_Conv, self).__init__(**kwargs)
1071 |         if self.rank == 1:
1072 |             self.data_format = "channels_last"
1073 |         self.gamma_initializer = sanitizedInitGet(gamma_initializer)
1074 |         self.gamma_regularizer = regularizers.get(gamma_regularizer)
1075 |         self.gamma_constraint = constraints.get(gamma_constraint)
1076 |         self.epsilon = epsilon
1077 |         self.gamma = None
1078 | 
1079 |     def build(self, input_shape):
1080 |         super(WeightNorm_Conv, self).build(input_shape)
1081 |         if self.data_format == "channels_first":
1082 |             channel_axis = 1
1083 |         else:
1084 |             channel_axis = -1
1085 |         if input_shape[channel_axis] is None:
1086 |             raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
1087 |         input_dim = input_shape[channel_axis]
1088 |         gamma_shape = (input_dim * self.filters,)
1089 |         self.gamma = self.add_weight(
1090 |             shape=gamma_shape,
1091 |             name="gamma",
1092 |             initializer=self.gamma_initializer,
1093 |             regularizer=self.gamma_regularizer,
1094 |             constraint=self.gamma_constraint,
1095 |         )
1096 | 
1097 |     def call(self, inputs):
1098 |         input_shape = K.shape(inputs)
1099 |         if self.data_format == "channels_first":
1100 |             channel_axis = 1
1101 |         else:
1102 |             channel_axis = -1
1103 |         if input_shape[channel_axis] is None:
1104 |             raise ValueError("The channel dimension of the inputs " "should be defined. 
Found `None`.") 1105 | input_dim = input_shape[channel_axis] 1106 | ker_shape = self.kernel_size + (input_dim, self.filters) 1107 | nb_kernels = ker_shape[-2] * ker_shape[-1] 1108 | kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels) 1109 | reshaped_kernel = K.reshape(self.kernel, kernel_shape_4_norm) 1110 | normalized_weight = K.l2_normalize(reshaped_kernel, axis=0, epsilon=self.epsilon) 1111 | normalized_weight = K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1])) * normalized_weight 1112 | shaped_kernel = K.reshape(normalized_weight, ker_shape) 1113 | shaped_kernel._keras_shape = ker_shape 1114 | 1115 | convArgs = { 1116 | "strides": self.strides[0] if self.rank == 1 else self.strides, 1117 | "padding": self.padding, 1118 | "data_format": self.data_format, 1119 | "dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate, 1120 | } 1121 | convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank] 1122 | output = convFunc(inputs, shaped_kernel, **convArgs) 1123 | 1124 | if self.use_bias: 1125 | output = K.bias_add(output, self.bias, data_format=self.data_format) 1126 | 1127 | if self.activation is not None: 1128 | output = self.activation(output) 1129 | 1130 | return output 1131 | 1132 | def get_config(self): 1133 | config = { 1134 | "gamma_initializer": sanitizedInitSer(self.gamma_initializer), 1135 | "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), 1136 | "gamma_constraint": constraints.serialize(self.gamma_constraint), 1137 | "epsilon": self.epsilon, 1138 | } 1139 | base_config = super(WeightNorm_Conv, self).get_config() 1140 | return dict(list(base_config.items()) + list(config.items())) 1141 | 1142 | 1143 | # Aliases 1144 | ComplexConvolution1D = ComplexConv1D 1145 | ComplexConvolution2D = ComplexConv2D 1146 | ComplexConvolution3D = ComplexConv3D 1147 | -------------------------------------------------------------------------------- /complexnn/dense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from tensorflow.keras import backend as K 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras import activations, initializers, regularizers, constraints 7 | from tensorflow.keras.layers import Layer, InputSpec 8 | import numpy as np 9 | from numpy.random import RandomState 10 | from .utils import _compute_fans 11 | 12 | 13 | class ComplexDense(Layer): 14 | """Regular complex densely-connected NN layer. 15 | `Dense` implements the operation: 16 | `real_preact = dot(real_input, real_kernel) - dot(imag_input, imag_kernel)` 17 | `imag_preact = dot(real_input, imag_kernel) + dot(imag_input, real_kernel)` 18 | `output = activation(K.concatenate([real_preact, imag_preact]) + bias)` 19 | where `activation` is the element-wise activation function 20 | passed as the `activation` argument, `kernel` is a weights matrix 21 | created by the layer, and `bias` is a bias vector created by the layer 22 | (only applicable if `use_bias` is `True`). 23 | Note: if the input to the layer has a rank greater than 2, then 24 | AN ERROR MESSAGE IS PRINTED. 25 | # Arguments 26 | units: Positive integer, dimensionality of each of the real part 27 | and the imaginary part. It is actualy the number of complex units. 28 | activation: Activation function to use 29 | (see keras.activations). 30 | If you don't specify anything, no activation is applied 31 | (ie. "linear" activation: `a(x) = x`). 
32 | use_bias: Boolean, whether the layer uses a bias vector. 33 | kernel_initializer: Initializer for the complex `kernel` weights matrix. 34 | By default it is 'complex'. 35 | and the usual initializers could also be used. 36 | (see keras.initializers and init.py). 37 | bias_initializer: Initializer for the bias vector 38 | (see keras.initializers). 39 | kernel_regularizer: Regularizer function applied to 40 | the `kernel` weights matrix 41 | (see keras.regularizers). 42 | bias_regularizer: Regularizer function applied to the bias vector 43 | (see keras.regularizers). 44 | activity_regularizer: Regularizer function applied to 45 | the output of the layer (its "activation"). 46 | (see keras.regularizers). 47 | kernel_constraint: Constraint function applied to the kernel matrix 48 | (see keras.constraints). 49 | bias_constraint: Constraint function applied to the bias vector 50 | (see keras.constraints). 51 | # Input shape 52 | a 2D input with shape `(batch_size, input_dim)`. 53 | # Output shape 54 | For a 2D input with shape `(batch_size, input_dim)`, 55 | the output would have shape `(batch_size, units)`. 56 | """ 57 | 58 | def __init__( 59 | self, 60 | units, 61 | activation=None, 62 | use_bias=True, 63 | init_criterion="he", 64 | kernel_initializer="complex", 65 | bias_initializer="zeros", 66 | kernel_regularizer=None, 67 | bias_regularizer=None, 68 | activity_regularizer=None, 69 | kernel_constraint=None, 70 | bias_constraint=None, 71 | seed=None, 72 | **kwargs 73 | ): 74 | if "input_shape" not in kwargs and "input_dim" in kwargs: 75 | kwargs["input_shape"] = (kwargs.pop("input_dim"),) 76 | super(ComplexDense, self).__init__(**kwargs) 77 | self.units = units 78 | self.activation = activations.get(activation) 79 | self.use_bias = use_bias 80 | self.init_criterion = init_criterion 81 | if kernel_initializer in {"complex"}: 82 | self.kernel_initializer = kernel_initializer 83 | else: 84 | self.kernel_initializer = initializers.get(kernel_initializer) 85 | self.bias_initializer = initializers.get(bias_initializer) 86 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 87 | self.bias_regularizer = regularizers.get(bias_regularizer) 88 | self.activity_regularizer = regularizers.get(activity_regularizer) 89 | self.kernel_constraint = constraints.get(kernel_constraint) 90 | self.bias_constraint = constraints.get(bias_constraint) 91 | if seed is None: 92 | self.seed = np.random.randint(1, 10e6) 93 | else: 94 | self.seed = seed 95 | self.input_spec = InputSpec(ndim=2) 96 | self.supports_masking = True 97 | 98 | def build(self, input_shape): 99 | assert len(input_shape) == 2 100 | assert input_shape[-1] % 2 == 0 101 | input_dim = input_shape[-1] // 2 102 | data_format = K.image_data_format() 103 | kernel_shape = (input_dim, self.units) 104 | fan_in, fan_out = _compute_fans(kernel_shape, data_format=data_format) 105 | if self.init_criterion == "he": 106 | s = np.sqrt(1.0 / fan_in) 107 | elif self.init_criterion == "glorot": 108 | s = np.sqrt(1.0 / (fan_in + fan_out)) 109 | rng = RandomState(seed=self.seed) 110 | 111 | # Equivalent initialization using amplitude phase representation: 112 | """modulus = rng.rayleigh(scale=s, size=kernel_shape) 113 | phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) 114 | def init_w_real(shape, dtype=None): 115 | return modulus * K.cos(phase) 116 | def init_w_imag(shape, dtype=None): 117 | return modulus * K.sin(phase)""" 118 | 119 | # Initialization using euclidean representation: 120 | def init_w_real(shape, dtype=None): 121 | return 
rng.normal( 122 | size=kernel_shape, 123 | loc=0, 124 | scale=s, 125 | ) # .astype(dtype) 126 | 127 | def init_w_imag(shape, dtype=None): 128 | return rng.normal(size=kernel_shape, loc=0, scale=s) # .astype(dtype) 129 | 130 | if self.kernel_initializer in {"complex"}: 131 | real_init = init_w_real 132 | imag_init = init_w_imag 133 | else: 134 | real_init = self.kernel_initializer 135 | imag_init = self.kernel_initializer 136 | 137 | self.real_kernel = self.add_weight( 138 | shape=kernel_shape, 139 | initializer=real_init, 140 | name="real_kernel", 141 | regularizer=self.kernel_regularizer, 142 | constraint=self.kernel_constraint, 143 | ) 144 | self.imag_kernel = self.add_weight( 145 | shape=kernel_shape, 146 | initializer=imag_init, 147 | name="imag_kernel", 148 | regularizer=self.kernel_regularizer, 149 | constraint=self.kernel_constraint, 150 | ) 151 | 152 | if self.use_bias: 153 | self.bias = self.add_weight( 154 | shape=(2 * self.units,), 155 | initializer=self.bias_initializer, 156 | name="bias", 157 | regularizer=self.bias_regularizer, 158 | constraint=self.bias_constraint, 159 | ) 160 | else: 161 | self.bias = None 162 | 163 | self.input_spec = InputSpec(ndim=2, axes={-1: 2 * input_dim}) 164 | self.built = True 165 | 166 | def call(self, inputs, **kwargs): 167 | input_shape = K.shape(inputs) 168 | input_dim = input_shape[-1] // 2 169 | real_input = inputs[:, :input_dim] 170 | imag_input = inputs[:, input_dim:] 171 | 172 | cat_kernels_4_real = K.concatenate([self.real_kernel, -self.imag_kernel], axis=-1) 173 | cat_kernels_4_imag = K.concatenate([self.imag_kernel, self.real_kernel], axis=-1) 174 | cat_kernels_4_complex = K.concatenate([cat_kernels_4_real, cat_kernels_4_imag], axis=0) 175 | 176 | output = K.dot(inputs, cat_kernels_4_complex) 177 | 178 | if self.use_bias: 179 | output = K.bias_add(output, self.bias) 180 | if self.activation is not None: 181 | output = self.activation(output) 182 | 183 | return output 184 | 185 | def compute_output_shape(self, input_shape): 186 | assert input_shape and len(input_shape) == 2 187 | assert input_shape[-1] 188 | output_shape = list(input_shape) 189 | output_shape[-1] = 2 * self.units 190 | return tuple(output_shape) 191 | 192 | def get_config(self): 193 | if self.kernel_initializer in {"complex"}: 194 | ki = self.kernel_initializer 195 | else: 196 | ki = initializers.serialize(self.kernel_initializer) 197 | config = { 198 | "units": self.units, 199 | "activation": activations.serialize(self.activation), 200 | "use_bias": self.use_bias, 201 | "init_criterion": self.init_criterion, 202 | "kernel_initializer": ki, 203 | "bias_initializer": initializers.serialize(self.bias_initializer), 204 | "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), 205 | "bias_regularizer": regularizers.serialize(self.bias_regularizer), 206 | "activity_regularizer": regularizers.serialize(self.activity_regularizer), 207 | "kernel_constraint": constraints.serialize(self.kernel_constraint), 208 | "bias_constraint": constraints.serialize(self.bias_constraint), 209 | "seed": self.seed, 210 | } 211 | base_config = super(ComplexDense, self).get_config() 212 | return dict(list(base_config.items()) + list(config.items())) 213 | -------------------------------------------------------------------------------- /complexnn/fft.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # import tensorflow.keras.engine as KE 5 | import tensorflow.keras.backend as KB 6 | import 
tensorflow.keras.layers as KL 7 | import tensorflow.keras.optimizers as KO 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | # 13 | # FFT functions: 14 | # 15 | # fft(): Batched 1-D FFT (Input: (Batch, TimeSamples)) 16 | # ifft(): Batched 1-D IFFT (Input: (Batch, FreqSamples)) 17 | # fft2(): Batched 2-D FFT (Input: (Batch, TimeSamplesH, TimeSamplesW)) 18 | # ifft2(): Batched 2-D IFFT (Input: (Batch, FreqSamplesH, FreqSamplesW)) 19 | # 20 | 21 | 22 | def fft(z): 23 | B = z.shape[0] // 2 24 | L = z.shape[1] 25 | C = tf.Variable(np.asarray([[[1, -1]]], dtype=tf.float32)) 26 | Zr, Zi = tf.signal.rfft(z[:B]), tf.signal.rfft(z[B:]) 27 | isOdd = tf.equal(L % 2, 1) 28 | Zr = tf.cond( 29 | isOdd, tf.concat([Zr, C * Zr[:, 1:][:, ::-1]], axis=1), tf.concat([Zr, C * Zr[:, 1:-1][:, ::-1]], axis=1) 30 | ) 31 | Zi = tf.cond( 32 | isOdd, tf.concat([Zi, C * Zi[:, 1:][:, ::-1]], axis=1), tf.concat([Zi, C * Zi[:, 1:-1][:, ::-1]], axis=1) 33 | ) 34 | Zi = (C * Zi)[:, :, ::-1] # Zi * i 35 | Z = Zr + Zi 36 | return tf.concat([Z[:, :, 0], Z[:, :, 1]], axis=0) 37 | 38 | 39 | def ifft(z): 40 | B = z.shape[0] // 2 41 | L = z.shape[1] 42 | C = tf.Variable(np.asarray([[[1, -1]]], dtype=tf.float32)) 43 | Zr, Zi = tf.signal.rfft(z[:B]), tf.signal.rfft(z[B:] * -1) 44 | isOdd = tf.equal(L % 2, 1) 45 | Zr = tf.cond( 46 | isOdd, tf.concat([Zr, C * Zr[:, 1:][:, ::-1]], axis=1), tf.concat([Zr, C * Zr[:, 1:-1][:, ::-1]], axis=1) 47 | ) 48 | Zi = tf.cond( 49 | isOdd, tf.concat([Zi, C * Zi[:, 1:][:, ::-1]], axis=1), tf.concat([Zi, C * Zi[:, 1:-1][:, ::-1]], axis=1) 50 | ) 51 | Zi = (C * Zi)[:, :, ::-1] # Zi * i 52 | Z = Zr + Zi 53 | return tf.concat([Z[:, :, 0], Z[:, :, 1] * -1], axis=0) 54 | 55 | 56 | def fft2(x): 57 | tt = x 58 | tt = KB.reshape(tt, (x.shape[0] * x.shape[1], x.shape[2])) 59 | tf = fft(tt) 60 | tf = KB.reshape(tf, (x.shape[0], x.shape[1], x.shape[2])) 61 | tf = KB.permute_dimensions(tf, (0, 2, 1)) 62 | tf = KB.reshape(tf, (x.shape[0] * x.shape[2], x.shape[1])) 63 | ff = fft(tf) 64 | ff = KB.reshape(ff, (x.shape[0], x.shape[2], x.shape[1])) 65 | ff = KB.permute_dimensions(ff, (0, 2, 1)) 66 | return ff 67 | 68 | 69 | def ifft2(x): 70 | ff = x 71 | ff = KB.permute_dimensions(ff, (0, 2, 1)) 72 | ff = KB.reshape(ff, (x.shape[0] * x.shape[2], x.shape[1])) 73 | tf = ifft(ff) 74 | tf = KB.reshape(tf, (x.shape[0], x.shape[2], x.shape[1])) 75 | tf = KB.permute_dimensions(tf, (0, 2, 1)) 76 | tf = KB.reshape(tf, (x.shape[0] * x.shape[1], x.shape[2])) 77 | tt = ifft(tf) 78 | tt = KB.reshape(tt, (x.shape[0], x.shape[1], x.shape[2])) 79 | return tt 80 | 81 | 82 | # 83 | # FFT Layers: 84 | # 85 | # FFT: Batched 1-D FFT (Input: (Batch, FeatureMaps, TimeSamples)) 86 | # IFFT: Batched 1-D IFFT (Input: (Batch, FeatureMaps, FreqSamples)) 87 | # FFT2: Batched 2-D FFT (Input: (Batch, FeatureMaps, TimeSamplesH, TimeSamplesW)) 88 | # IFFT2: Batched 2-D IFFT (Input: (Batch, FeatureMaps, FreqSamplesH, FreqSamplesW)) 89 | # 90 | 91 | 92 | class FFT(KL.Layer): 93 | def call(self, x, mask=None): 94 | a = KB.permute_dimensions(x, (1, 0, 2)) 95 | a = KB.reshape(a, (x.shape[1] * x.shape[0], x.shape[2])) 96 | a = fft(a) 97 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2])) 98 | return KB.permute_dimensions(a, (1, 0, 2)) 99 | 100 | 101 | class IFFT(KL.Layer): 102 | def call(self, x, mask=None): 103 | a = KB.permute_dimensions(x, (1, 0, 2)) 104 | a = KB.reshape(a, (x.shape[1] * x.shape[0], x.shape[2])) 105 | a = ifft(a) 106 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2])) 107 | return KB.permute_dimensions(a, (1, 
0, 2)) 108 | 109 | 110 | class FFT2(KL.Layer): 111 | def call(self, x, mask=None): 112 | a = KB.permute_dimensions(x, (1, 0, 2, 3)) 113 | a = KB.reshape(a, (x.shape[1] * x.shape[0], x.shape[2], x.shape[3])) 114 | a = fft2(a) 115 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2], x.shape[3])) 116 | return KB.permute_dimensions(a, (1, 0, 2, 3)) 117 | 118 | 119 | class IFFT2(KL.Layer): 120 | def call(self, x, mask=None): 121 | a = KB.permute_dimensions(x, (1, 0, 2, 3)) 122 | a = KB.reshape(a, (x.shape[1] * x.shape[0], x.shape[2], x.shape[3])) 123 | a = ifft2(a) 124 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2], x.shape[3])) 125 | return KB.permute_dimensions(a, (1, 0, 2, 3)) 126 | -------------------------------------------------------------------------------- /complexnn/init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | from numpy.random import RandomState 6 | import tensorflow.keras.backend as K 7 | from tensorflow.keras.initializers import Initializer 8 | from tensorflow.python.keras.utils.generic_utils import serialize_keras_object, deserialize_keras_object 9 | from .utils import _compute_fans 10 | 11 | 12 | class IndependentFilters(Initializer): 13 | # This initialization constructs real-valued kernels 14 | # that are independent as much as possible from each other 15 | # while respecting either the He or the Glorot criterion. 16 | def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion="glorot", seed=None): 17 | 18 | # `weight_dim` is used as a parameter for sanity check 19 | # as we should not pass an integer as kernel_size when 20 | # the weight dimension is >= 2. 21 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 22 | # then in such a case, weight_dim = 2. 23 | # (in case of 2D input): 24 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 25 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 26 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 27 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 28 | 29 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 30 | self.nb_filters = nb_filters 31 | self.kernel_size = kernel_size 32 | self.input_dim = input_dim 33 | self.weight_dim = weight_dim 34 | self.criterion = criterion 35 | self.seed = 1337 if seed is None else seed 36 | 37 | def __call__(self, shape, dtype=None): 38 | 39 | if self.nb_filters is not None: 40 | num_rows = self.nb_filters * self.input_dim 41 | num_cols = np.prod(self.kernel_size) 42 | else: 43 | # in case it is the kernel is a matrix and not a filter 44 | # which is the case of 2D input (No feature maps). 
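        # A minimal sketch of the orthogonalisation performed below
        # (illustrative NumPy; 8 and 9 stand in for num_rows and num_cols and
        # are not taken from this file):
        #
        #     import numpy as np
        #     rng = np.random.RandomState(1337)
        #     x = rng.uniform(size=(8, 9))
        #     u, _, v = np.linalg.svd(x)
        #     orth = u.dot(np.eye(8, 9)).dot(v.T)   # rows are orthonormal
        #     desired_var = 2.0 / (8 + 9)           # 'glorot'; 2.0 / fan_in for 'he'
        #     w = np.sqrt(desired_var / orth.var()) * orth
        #
        # A random matrix is replaced by an orthogonal one obtained from its
        # SVD and then rescaled so its element variance matches the criterion.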
45 | num_rows = self.input_dim 46 | num_cols = self.kernel_size[-1] 47 | 48 | flat_shape = (num_rows, num_cols) 49 | rng = RandomState(self.seed) 50 | x = rng.uniform(size=flat_shape) 51 | u, _, v = np.linalg.svd(x) 52 | orthogonal_x = np.dot(u, np.dot(np.eye(num_rows, num_cols), v.T)) 53 | if self.nb_filters is not None: 54 | independent_filters = np.reshape(orthogonal_x, (num_rows,) + tuple(self.kernel_size)) 55 | fan_in, fan_out = _compute_fans(tuple(self.kernel_size) + (self.input_dim, self.nb_filters)) 56 | else: 57 | independent_filters = orthogonal_x 58 | fan_in, fan_out = (self.input_dim, self.kernel_size[-1]) 59 | 60 | if self.criterion == "glorot": 61 | desired_var = 2.0 / (fan_in + fan_out) 62 | elif self.criterion == "he": 63 | desired_var = 2.0 / fan_in 64 | else: 65 | raise ValueError("Invalid criterion: " + self.criterion) 66 | 67 | multip_constant = np.sqrt(desired_var / np.var(independent_filters)) 68 | scaled_indep = multip_constant * independent_filters 69 | 70 | 71 | if self.weight_dim == 2 and self.nb_filters is None: 72 | weight = scaled_indep 73 | else: 74 | kernel_shape = tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 75 | if self.weight_dim == 1: 76 | transpose_shape = (1, 0) 77 | elif self.weight_dim == 2 and self.nb_filters is not None: 78 | transpose_shape = (1, 2, 0) 79 | elif self.weight_dim == 3 and self.nb_filters is not None: 80 | transpose_shape = (1, 2, 3, 0) 81 | weight = np.transpose(scaled_indep, transpose_shape) 82 | weight = np.reshape(weight, kernel_shape) 83 | 84 | return weight 85 | 86 | def get_config(self): 87 | return { 88 | "nb_filters": self.nb_filters, 89 | "kernel_size": self.kernel_size, 90 | "input_dim": self.input_dim, 91 | "weight_dim": self.weight_dim, 92 | "criterion": self.criterion, 93 | "seed": self.seed, 94 | } 95 | 96 | 97 | class ComplexIndependentFilters(Initializer): 98 | # This initialization constructs complex-valued kernels 99 | # that are independent as much as possible from each other 100 | # while respecting either the He or the Glorot criterion. 101 | def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion="glorot", seed=None): 102 | 103 | # `weight_dim` is used as a parameter for sanity check 104 | # as we should not pass an integer as kernel_size when 105 | # the weight dimension is >= 2. 106 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 107 | # then in such a case, weight_dim = 2. 108 | # (in case of 2D input): 109 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 110 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 111 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 112 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 113 | 114 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 115 | self.nb_filters = nb_filters 116 | self.kernel_size = kernel_size 117 | self.input_dim = input_dim 118 | self.weight_dim = weight_dim 119 | self.criterion = criterion 120 | self.seed = 1337 if seed is None else seed 121 | 122 | def __call__(self, shape, dtype=None): 123 | 124 | if self.nb_filters is not None: 125 | num_rows = self.nb_filters * self.input_dim 126 | num_cols = np.prod(self.kernel_size) 127 | else: 128 | # in case it is the kernel is a matrix and not a filter 129 | # which is the case of 2D input (No feature maps). 
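        # Same idea as in IndependentFilters, but sketched for the complex
        # case (illustrative NumPy; the sizes are made up): the random matrix
        # is complex, it is made (semi-)unitary via its SVD, and its real and
        # imaginary parts become the two halves of the kernel, each rescaled
        # towards 1.0 / (fan_in + fan_out) ('glorot') or 1.0 / fan_in ('he')
        # variance below.
        #
        #     rng = np.random.RandomState(1337)
        #     z = rng.uniform(size=(8, 9)) + 1j * rng.uniform(size=(8, 9))
        #     u, _, v = np.linalg.svd(z)
        #     unitary = u.dot(np.eye(8, 9)).dot(np.conjugate(v).T)
        #     re, im = unitary.real, unitary.imag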
130 | num_rows = self.input_dim 131 | num_cols = self.kernel_size[-1] 132 | 133 | flat_shape = (int(num_rows), int(num_cols)) 134 | rng = RandomState(self.seed) 135 | r = rng.uniform(size=flat_shape) 136 | i = rng.uniform(size=flat_shape) 137 | z = r + 1j * i 138 | u, _, v = np.linalg.svd(z) 139 | unitary_z = np.dot(u, np.dot(np.eye(int(num_rows), int(num_cols)), np.conjugate(v).T)) 140 | real_unitary = unitary_z.real 141 | imag_unitary = unitary_z.imag 142 | if self.nb_filters is not None: 143 | indep_real = np.reshape(real_unitary, (num_rows,) + tuple(self.kernel_size)) 144 | indep_imag = np.reshape(imag_unitary, (num_rows,) + tuple(self.kernel_size)) 145 | fan_in, fan_out = _compute_fans(tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)) 146 | else: 147 | indep_real = real_unitary 148 | indep_imag = imag_unitary 149 | fan_in, fan_out = (int(self.input_dim), self.kernel_size[-1]) 150 | 151 | if self.criterion == "glorot": 152 | desired_var = 1.0 / (fan_in + fan_out) 153 | elif self.criterion == "he": 154 | desired_var = 1.0 / (fan_in) 155 | else: 156 | raise ValueError("Invalid criterion: " + self.criterion) 157 | 158 | multip_real = np.sqrt(desired_var / np.var(indep_real)) 159 | multip_imag = np.sqrt(desired_var / np.var(indep_imag)) 160 | scaled_real = multip_real * indep_real 161 | scaled_imag = multip_imag * indep_imag 162 | 163 | if self.weight_dim == 2 and self.nb_filters is None: 164 | weight_real = scaled_real 165 | weight_imag = scaled_imag 166 | else: 167 | kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) 168 | if self.weight_dim == 1: 169 | transpose_shape = (1, 0) 170 | elif self.weight_dim == 2 and self.nb_filters is not None: 171 | transpose_shape = (1, 2, 0) 172 | elif self.weight_dim == 3 and self.nb_filters is not None: 173 | transpose_shape = (1, 2, 3, 0) 174 | 175 | weight_real = np.transpose(scaled_real, transpose_shape) 176 | weight_imag = np.transpose(scaled_imag, transpose_shape) 177 | weight_real = np.reshape(weight_real, kernel_shape) 178 | weight_imag = np.reshape(weight_imag, kernel_shape) 179 | weight = np.concatenate([weight_real, weight_imag], axis=-1) 180 | 181 | return weight 182 | 183 | def get_config(self): 184 | return { 185 | "nb_filters": self.nb_filters, 186 | "kernel_size": self.kernel_size, 187 | "input_dim": self.input_dim, 188 | "weight_dim": self.weight_dim, 189 | "criterion": self.criterion, 190 | "seed": self.seed, 191 | } 192 | 193 | 194 | class ComplexInit(Initializer): 195 | # The standard complex initialization using 196 | # either the He or the Glorot criterion. 197 | def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion="glorot", seed=None): 198 | 199 | # `weight_dim` is used as a parameter for sanity check 200 | # as we should not pass an integer as kernel_size when 201 | # the weight dimension is >= 2. 202 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 203 | # then in such a case, weight_dim = 2. 
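        # A minimal sketch of what __call__ below actually samples
        # (illustrative NumPy; `fan_in` and `half_shape` are stand-in names,
        # not taken from this file):
        #
        #     s = 1.0 / fan_in              # 'he'; 1.0 / (fan_in + fan_out) for 'glorot'
        #     modulus = rng.rayleigh(scale=s, size=half_shape)
        #     phase = rng.uniform(low=-np.pi, high=np.pi, size=half_shape)
        #     kernel = np.concatenate([modulus * np.cos(phase),
        #                              modulus * np.sin(phase)], axis=-1)
        #
        # Each complex weight thus has a Rayleigh-distributed modulus and a
        # uniform phase; real and imaginary parts sit side by side on the
        # last axis.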
204 | # (in case of 2D input): 205 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 206 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 207 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 208 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 209 | 210 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 211 | self.nb_filters = nb_filters 212 | self.kernel_size = kernel_size 213 | self.input_dim = input_dim 214 | self.weight_dim = weight_dim 215 | self.criterion = criterion 216 | self.seed = 1337 if seed is None else seed 217 | 218 | def __call__(self, shape, dtype=None): 219 | 220 | if self.nb_filters is not None: 221 | kernel_shape = shape 222 | # kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), 223 | # self.nb_filters) 224 | else: 225 | kernel_shape = (int(self.input_dim), self.kernel_size[-1]) 226 | 227 | fan_in, fan_out = _compute_fans( 228 | # tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 229 | kernel_shape 230 | ) 231 | 232 | # fix for ValueError: The initial value's shape (...) is not compatible with the explicitly supplied `shape` argument 233 | reim_shape = list(kernel_shape) 234 | reim_shape[-1] //= 2 235 | reim_shape = tuple(reim_shape) 236 | 237 | if self.criterion == "glorot": 238 | s = 1.0 / (fan_in + fan_out) 239 | elif self.criterion == "he": 240 | s = 1.0 / fan_in 241 | else: 242 | raise ValueError("Invalid criterion: " + self.criterion) 243 | rng = RandomState(self.seed) 244 | modulus = rng.rayleigh(scale=s, size=reim_shape) 245 | phase = rng.uniform(low=-np.pi, high=np.pi, size=reim_shape) 246 | weight_real = modulus * np.cos(phase) 247 | weight_imag = modulus * np.sin(phase) 248 | weight = np.concatenate([weight_real, weight_imag], axis=-1) 249 | 250 | return weight 251 | 252 | 253 | class SqrtInit(Initializer): 254 | def __call__(self, shape, dtype=None): 255 | return K.constant(1 / K.sqrt(2), shape=shape, dtype=dtype) 256 | 257 | 258 | # Aliases: 259 | sqrt_init = SqrtInit 260 | independent_filters = IndependentFilters 261 | complex_init = ComplexInit 262 | -------------------------------------------------------------------------------- /complexnn/norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Implementation of Layer Normalization and Complex Layer Normalization 5 | 6 | 7 | import numpy as np 8 | from tensorflow.keras.layers import Layer, InputSpec 9 | from tensorflow.keras import initializers, regularizers, constraints 10 | import tensorflow.keras.backend as K 11 | from .bn import ComplexBN as complex_normalization 12 | from .bn import sqrt_init 13 | 14 | 15 | def layernorm(x, axis, epsilon, gamma, beta): 16 | # assert self.built, 'Layer must be built before being called' 17 | input_shape = K.shape(x) 18 | reduction_axes = list(range(K.ndim(x))) 19 | del reduction_axes[axis] 20 | del reduction_axes[0] 21 | broadcast_shape = [1] * K.ndim(x) 22 | broadcast_shape[axis] = input_shape[axis] 23 | broadcast_shape[0] = K.shape(x)[0] 24 | 25 | # Perform normalization: centering and reduction 26 | 27 | mean = K.mean(x, axis=reduction_axes) 28 | broadcast_mean = K.reshape(mean, broadcast_shape) 29 | x_centred = x - broadcast_mean 30 | variance = K.mean(x_centred**2, axis=reduction_axes) + epsilon 31 | broadcast_variance = K.reshape(variance, broadcast_shape) 32 | 33 | x_normed = x_centred / K.sqrt(broadcast_variance) 34 | 35 | # Perform scaling and shifting 36 | 37 | broadcast_shape_params = 
[1] * K.ndim(x) 38 | broadcast_shape_params[axis] = K.shape(x)[axis] 39 | broadcast_gamma = K.reshape(gamma, broadcast_shape_params) 40 | broadcast_beta = K.reshape(beta, broadcast_shape_params) 41 | 42 | x_LN = broadcast_gamma * x_normed + broadcast_beta 43 | 44 | return x_LN 45 | 46 | 47 | class LayerNormalization(Layer): 48 | def __init__( 49 | self, 50 | epsilon=1e-4, 51 | axis=-1, 52 | beta_init="zeros", 53 | gamma_init="ones", 54 | gamma_regularizer=None, 55 | beta_regularizer=None, 56 | **kwargs 57 | ): 58 | 59 | self.supports_masking = True 60 | self.beta_init = initializers.get(beta_init) 61 | self.gamma_init = initializers.get(gamma_init) 62 | self.epsilon = epsilon 63 | self.axis = axis 64 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 65 | self.beta_regularizer = regularizers.get(beta_regularizer) 66 | 67 | super(LayerNormalization, self).__init__(**kwargs) 68 | 69 | def build(self, input_shape): 70 | self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: input_shape[self.axis]}) 71 | shape = (input_shape[self.axis],) 72 | 73 | self.gamma = self.add_weight( 74 | shape, initializer=self.gamma_init, regularizer=self.gamma_regularizer, name="{}_gamma".format(self.name) 75 | ) 76 | self.beta = self.add_weight( 77 | shape, initializer=self.beta_init, regularizer=self.beta_regularizer, name="{}_beta".format(self.name) 78 | ) 79 | 80 | self.built = True 81 | 82 | def call(self, x, mask=None): 83 | assert self.built, "Layer must be built before being called" 84 | return layernorm(x, self.axis, self.epsilon, self.gamma, self.beta) 85 | 86 | def get_config(self): 87 | config = { 88 | "epsilon": self.epsilon, 89 | "axis": self.axis, 90 | "gamma_regularizer": self.gamma_regularizer.get_config() if self.gamma_regularizer else None, 91 | "beta_regularizer": self.beta_regularizer.get_config() if self.beta_regularizer else None, 92 | } 93 | base_config = super(LayerNormalization, self).get_config() 94 | return dict(list(base_config.items()) + list(config.items())) 95 | 96 | 97 | class ComplexLayerNorm(Layer): 98 | def __init__( 99 | self, 100 | epsilon=1e-4, 101 | axis=-1, 102 | center=True, 103 | scale=True, 104 | beta_initializer="zeros", 105 | gamma_diag_initializer=sqrt_init, 106 | gamma_off_initializer="zeros", 107 | beta_regularizer=None, 108 | gamma_diag_regularizer=None, 109 | gamma_off_regularizer=None, 110 | beta_constraint=None, 111 | gamma_diag_constraint=None, 112 | gamma_off_constraint=None, 113 | **kwargs 114 | ): 115 | 116 | self.supports_masking = True 117 | self.epsilon = epsilon 118 | self.axis = axis 119 | self.center = center 120 | self.scale = scale 121 | self.beta_initializer = initializers.get(beta_initializer) 122 | self.gamma_diag_initializer = initializers.get(gamma_diag_initializer) 123 | self.gamma_off_initializer = initializers.get(gamma_off_initializer) 124 | self.beta_regularizer = regularizers.get(beta_regularizer) 125 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 126 | self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer) 127 | self.beta_constraint = constraints.get(beta_constraint) 128 | self.gamma_diag_constraint = constraints.get(gamma_diag_constraint) 129 | self.gamma_off_constraint = constraints.get(gamma_off_constraint) 130 | super(ComplexLayerNorm, self).__init__(**kwargs) 131 | 132 | def build(self, input_shape): 133 | 134 | ndim = len(input_shape) 135 | dim = input_shape[self.axis] 136 | if dim is None: 137 | raise ValueError( 138 | "Axis " + str(self.axis) + " of " 139 | "input 
tensor should have a defined dimension " 140 | "but the layer received an input with shape " + str(input_shape) + "." 141 | ) 142 | self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim}) 143 | 144 | gamma_shape = (input_shape[self.axis] // 2,) 145 | if self.scale: 146 | self.gamma_rr = self.add_weight( 147 | shape=gamma_shape, 148 | name="gamma_rr", 149 | initializer=self.gamma_diag_initializer, 150 | regularizer=self.gamma_diag_regularizer, 151 | constraint=self.gamma_diag_constraint, 152 | ) 153 | self.gamma_ii = self.add_weight( 154 | shape=gamma_shape, 155 | name="gamma_ii", 156 | initializer=self.gamma_diag_initializer, 157 | regularizer=self.gamma_diag_regularizer, 158 | constraint=self.gamma_diag_constraint, 159 | ) 160 | self.gamma_ri = self.add_weight( 161 | shape=gamma_shape, 162 | name="gamma_ri", 163 | initializer=self.gamma_off_initializer, 164 | regularizer=self.gamma_off_regularizer, 165 | constraint=self.gamma_off_constraint, 166 | ) 167 | else: 168 | self.gamma_rr = None 169 | self.gamma_ii = None 170 | self.gamma_ri = None 171 | 172 | if self.center: 173 | self.beta = self.add_weight( 174 | shape=(input_shape[self.axis],), 175 | name="beta", 176 | initializer=self.beta_initializer, 177 | regularizer=self.beta_regularizer, 178 | constraint=self.beta_constraint, 179 | ) 180 | else: 181 | self.beta = None 182 | 183 | self.built = True 184 | 185 | def call(self, inputs, **kwargs): 186 | input_shape = K.shape(inputs) 187 | ndim = K.ndim(inputs) 188 | reduction_axes = list(range(ndim)) 189 | del reduction_axes[self.axis] 190 | del reduction_axes[0] 191 | input_dim = input_shape[self.axis] // 2 192 | mu = K.mean(inputs, axis=reduction_axes) 193 | broadcast_mu_shape = [1] * ndim 194 | broadcast_mu_shape[self.axis] = input_shape[self.axis] 195 | broadcast_mu_shape[0] = K.shape(inputs)[0] 196 | broadcast_mu = K.reshape(mu, broadcast_mu_shape) 197 | if self.center: 198 | input_centred = inputs - broadcast_mu 199 | else: 200 | input_centred = inputs 201 | centred_squared = input_centred**2 202 | if (self.axis == 1 and ndim != 3) or ndim == 2: 203 | centred_squared_real = centred_squared[:, :input_dim] 204 | centred_squared_imag = centred_squared[:, input_dim:] 205 | centred_real = input_centred[:, :input_dim] 206 | centred_imag = input_centred[:, input_dim:] 207 | elif ndim == 3: 208 | centred_squared_real = centred_squared[:, :, :input_dim] 209 | centred_squared_imag = centred_squared[:, :, input_dim:] 210 | centred_real = input_centred[:, :, :input_dim] 211 | centred_imag = input_centred[:, :, input_dim:] 212 | elif self.axis == -1 and ndim == 4: 213 | centred_squared_real = centred_squared[:, :, :, :input_dim] 214 | centred_squared_imag = centred_squared[:, :, :, input_dim:] 215 | centred_real = input_centred[:, :, :, :input_dim] 216 | centred_imag = input_centred[:, :, :, input_dim:] 217 | elif self.axis == -1 and ndim == 5: 218 | centred_squared_real = centred_squared[:, :, :, :, :input_dim] 219 | centred_squared_imag = centred_squared[:, :, :, :, input_dim:] 220 | centred_real = input_centred[:, :, :, :, :input_dim] 221 | centred_imag = input_centred[:, :, :, :, input_dim:] 222 | else: 223 | raise ValueError( 224 | "Incorrect Layernorm combination of axis and dimensions. axis should be either 1 or -1. " 225 | "axis: " + str(self.axis) + "; ndim: " + str(ndim) + "." 
226 | ) 227 | if self.scale: 228 | Vrr = K.mean(centred_squared_real, axis=reduction_axes) + self.epsilon 229 | Vii = K.mean(centred_squared_imag, axis=reduction_axes) + self.epsilon 230 | # Vri contains the real and imaginary covariance for each feature map. 231 | Vri = K.mean( 232 | centred_real * centred_imag, 233 | axis=reduction_axes, 234 | ) 235 | elif self.center: 236 | Vrr = None 237 | Vii = None 238 | Vri = None 239 | else: 240 | raise ValueError("Error. Both scale and center in batchnorm are set to False.") 241 | 242 | return complex_normalization( 243 | input_centred, 244 | Vrr, 245 | Vii, 246 | Vri, 247 | self.beta, 248 | self.gamma_rr, 249 | self.gamma_ri, 250 | self.gamma_ii, 251 | self.scale, 252 | self.center, 253 | layernorm=True, 254 | axis=self.axis, 255 | ) 256 | 257 | def get_config(self): 258 | config = { 259 | "axis": self.axis, 260 | "epsilon": self.epsilon, 261 | "center": self.center, 262 | "scale": self.scale, 263 | "beta_initializer": initializers.serialize(self.beta_initializer), 264 | "gamma_diag_initializer": initializers.serialize(self.gamma_diag_initializer), 265 | "gamma_off_initializer": initializers.serialize(self.gamma_off_initializer), 266 | "beta_regularizer": regularizers.serialize(self.beta_regularizer), 267 | "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer), 268 | "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer), 269 | "beta_constraint": constraints.serialize(self.beta_constraint), 270 | "gamma_diag_constraint": constraints.serialize(self.gamma_diag_constraint), 271 | "gamma_off_constraint": constraints.serialize(self.gamma_off_constraint), 272 | } 273 | base_config = super(ComplexLayerNorm, self).get_config() 274 | return dict(list(base_config.items()) + list(config.items())) 275 | -------------------------------------------------------------------------------- /complexnn/pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow.keras.backend as KB 5 | import tensorflow.keras as KE 6 | import tensorflow.keras.layers as KL 7 | import tensorflow.keras.optimizers as KO 8 | import numpy as np 9 | 10 | 11 | # 12 | # Spectral Pooling Layer 13 | # 14 | 15 | 16 | class SpectralPooling1D(KL.Layer): 17 | def __init__(self, topf=(0,), **kwargs): 18 | super(SpectralPooling1D, self).__init__() 19 | if "topf" in kwargs: 20 | self.topf = (int(kwargs["topf"][0]),) 21 | self.topf = (self.topf[0] // 2,) 22 | elif "gamma" in kwargs: 23 | self.gamma = (float(kwargs["gamma"][0]),) 24 | self.gamma = (self.gamma[0] / 2,) 25 | else: 26 | raise RuntimeError("Must provide either topf= or gamma= !") 27 | 28 | def call(self, x, mask=None): 29 | xshape = x._keras_shape 30 | if hasattr(self, "topf"): 31 | topf = self.topf 32 | else: 33 | if KB.image_data_format() == "channels_first": 34 | topf = (int(self.gamma[0] * xshape[2]),) 35 | else: 36 | topf = (int(self.gamma[0] * xshape[1]),) 37 | 38 | if KB.image_data_format() == "channels_first": 39 | if topf[0] > 0 and xshape[2] >= 2 * topf[0]: 40 | mask = [1] * (topf[0]) + [0] * (xshape[2] - 2 * topf[0]) + [1] * (topf[0]) 41 | mask = [[mask]] 42 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 2)) 43 | mask = KB.constant(mask) 44 | x *= mask 45 | else: 46 | if topf[0] > 0 and xshape[1] >= 2 * topf[0]: 47 | mask = [1] * (topf[0]) + [0] * (xshape[1] - 2 * topf[0]) + [1] * (topf[0]) 48 | mask = [[mask]] 49 | mask = np.asarray(mask, 
dtype=KB.floatx()).transpose((0, 2, 1)) 50 | mask = KB.constant(mask) 51 | x *= mask 52 | 53 | return x 54 | 55 | 56 | class SpectralPooling2D(KL.Layer): 57 | def __init__(self, **kwargs): 58 | super(SpectralPooling2D, self).__init__() 59 | if "topf" in kwargs: 60 | self.topf = (int(kwargs["topf"][0]), int(kwargs["topf"][1])) 61 | self.topf = (self.topf[0] // 2, self.topf[1] // 2) 62 | elif "gamma" in kwargs: 63 | self.gamma = (float(kwargs["gamma"][0]), float(kwargs["gamma"][1])) 64 | self.gamma = (self.gamma[0] / 2, self.gamma[1] / 2) 65 | else: 66 | raise RuntimeError("Must provide either topf= or gamma= !") 67 | 68 | def call(self, x, mask=None): 69 | xshape = x._keras_shape 70 | if hasattr(self, "topf"): 71 | topf = self.topf 72 | else: 73 | if KB.image_data_format() == "channels_first": 74 | topf = (int(self.gamma[0] * xshape[2]), int(self.gamma[1] * xshape[3])) 75 | else: 76 | topf = (int(self.gamma[0] * xshape[1]), int(self.gamma[1] * xshape[2])) 77 | 78 | if KB.image_data_format() == "channels_first": 79 | if topf[0] > 0 and xshape[2] >= 2 * topf[0]: 80 | mask = [1] * (topf[0]) + [0] * (xshape[2] - 2 * topf[0]) + [1] * (topf[0]) 81 | mask = [[[mask]]] 82 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 3, 2)) 83 | mask = KB.constant(mask) 84 | x *= mask 85 | if topf[1] > 0 and xshape[3] >= 2 * topf[1]: 86 | mask = [1] * (topf[1]) + [0] * (xshape[3] - 2 * topf[1]) + [1] * (topf[1]) 87 | mask = [[[mask]]] 88 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 2, 3)) 89 | mask = KB.constant(mask) 90 | x *= mask 91 | else: 92 | if topf[0] > 0 and xshape[1] >= 2 * topf[0]: 93 | mask = [1] * (topf[0]) + [0] * (xshape[1] - 2 * topf[0]) + [1] * (topf[0]) 94 | mask = [[[mask]]] 95 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 3, 1, 2)) 96 | mask = KB.constant(mask) 97 | x *= mask 98 | if topf[1] > 0 and xshape[2] >= 2 * topf[1]: 99 | mask = [1] * (topf[1]) + [0] * (xshape[2] - 2 * topf[1]) + [1] * (topf[1]) 100 | mask = [[[mask]]] 101 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 3, 2)) 102 | mask = KB.constant(mask) 103 | x *= mask 104 | 105 | return x 106 | 107 | 108 | if __name__ == "__main__": 109 | import cv2, sys 110 | import __main__ as SP 111 | import fft as CF 112 | 113 | # Build Model 114 | x = i = KL.Input(shape=(6, 512, 512)) 115 | f = CF.FFT2()(x) 116 | p = SP.SpectralPooling2D(gamma=[0.15, 0.15])(f) 117 | o = CF.IFFT2()(p) 118 | 119 | model = KE.Model([i], [f, p, o]) 120 | model.compile("sgd", "mse") 121 | 122 | # Use it 123 | img = cv2.imread(sys.argv[1]) 124 | imgBatch = img[np.newaxis, ...].transpose((0, 3, 1, 2)) 125 | imgBatch = np.concatenate([imgBatch, np.zeros_like(imgBatch)], axis=1) 126 | f, p, o = model.predict(imgBatch) 127 | ffted = np.sqrt(np.sum(f[:, :3] ** 2 + f[:, 3:] ** 2, axis=1)) 128 | ffted = ffted.transpose((1, 2, 0)) / 255 129 | pooled = np.sqrt(np.sum(p[:, :3] ** 2 + p[:, 3:] ** 2, axis=1)) 130 | pooled = pooled.transpose((1, 2, 0)) / 255 131 | filtered = np.clip(o, 0, 255).transpose((0, 2, 3, 1))[0, :, :, :3].astype("uint8") 132 | 133 | # Display it 134 | cv2.imshow("Original", img) 135 | cv2.imshow("FFT", ffted) 136 | cv2.imshow("Pooled", pooled) 137 | cv2.imshow("Filtered", filtered) 138 | cv2.waitKey(0) 139 | -------------------------------------------------------------------------------- /complexnn/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow.keras.backend as K 5 | from 
tensorflow.keras.layers import Layer, Lambda 6 | import numpy as np 7 | 8 | 9 | # 10 | # GetReal/GetImag Lambda layer Implementation 11 | # 12 | 13 | 14 | def get_realpart(x): 15 | image_format = K.image_data_format() 16 | ndim = K.ndim(x) 17 | input_shape = K.shape(x) 18 | 19 | if (image_format == "channels_first" and ndim != 3) or ndim == 2: 20 | input_dim = input_shape[1] // 2 21 | return x[:, :input_dim] 22 | 23 | input_dim = input_shape[-1] // 2 24 | if ndim == 3: 25 | return x[:, :, :input_dim] 26 | elif ndim == 4: 27 | return x[:, :, :, :input_dim] 28 | elif ndim == 5: 29 | return x[:, :, :, :, :input_dim] 30 | 31 | 32 | def get_imagpart(x): 33 | image_format = K.image_data_format() 34 | ndim = K.ndim(x) 35 | input_shape = K.shape(x) 36 | 37 | if (image_format == "channels_first" and ndim != 3) or ndim == 2: 38 | input_dim = input_shape[1] // 2 39 | return x[:, input_dim:] 40 | 41 | input_dim = input_shape[-1] // 2 42 | if ndim == 3: 43 | return x[:, :, input_dim:] 44 | elif ndim == 4: 45 | return x[:, :, :, input_dim:] 46 | elif ndim == 5: 47 | return x[:, :, :, :, input_dim:] 48 | 49 | 50 | def get_abs(x): 51 | real = get_realpart(x) 52 | imag = get_imagpart(x) 53 | 54 | return K.sqrt(real * real + imag * imag) 55 | 56 | 57 | def getpart_output_shape(input_shape): 58 | returned_shape = list(input_shape[:]) 59 | image_format = K.image_data_format() 60 | ndim = len(returned_shape) 61 | 62 | if (image_format == "channels_first" and ndim != 3) or ndim == 2: 63 | axis = 1 64 | else: 65 | axis = -1 66 | 67 | returned_shape[axis] = returned_shape[axis] // 2 68 | 69 | return tuple(returned_shape) 70 | 71 | 72 | # _compute_fans is different in keras-2 keras.initializers and tensorflow.python.ops.init_ops 73 | # this is the implementation copied from keras-2: 74 | def _compute_fans(shape, data_format="channels_last"): 75 | """Computes the number of input and output units for a weight shape. 76 | # Arguments 77 | 78 | shape: Integer shape tuple. 79 | data_format: Image data format to use for convolution kernels. 80 | Note that all kernels in Keras are standardized on the 81 | `channels_last` ordering (even when inputs are set 82 | to `channels_first`). 83 | # Returns 84 | A tuple of scalars, `(fan_in, fan_out)`. 85 | # Raises 86 | ValueError: in case of invalid `data_format` argument. 87 | """ 88 | if len(shape) == 2: 89 | fan_in = shape[0] 90 | fan_out = shape[1] 91 | elif len(shape) in {3, 4, 5}: 92 | # Assuming convolution kernels (1D, 2D or 3D). 93 | # TH kernel shape: (depth, input_depth, ...) 94 | # TF kernel shape: (..., input_depth, depth) 95 | if data_format == "channels_first": 96 | receptive_field_size = np.prod(shape[2:]) 97 | fan_in = shape[1] * receptive_field_size 98 | fan_out = shape[0] * receptive_field_size 99 | elif data_format == "channels_last": 100 | receptive_field_size = np.prod(shape[:2]) 101 | fan_in = shape[-2] * receptive_field_size 102 | fan_out = shape[-1] * receptive_field_size 103 | else: 104 | raise ValueError("Invalid data_format: " + data_format) 105 | else: 106 | # No specific assumptions. 
107 | fan_in = np.sqrt(np.prod(shape)) 108 | fan_out = np.sqrt(np.prod(shape)) 109 | return fan_in, fan_out 110 | 111 | 112 | class GetReal(Layer): 113 | def call(self, inputs): 114 | return get_realpart(inputs) 115 | 116 | def compute_output_shape(self, input_shape): 117 | return getpart_output_shape(input_shape) 118 | 119 | 120 | class GetImag(Layer): 121 | def call(self, inputs): 122 | return get_imagpart(inputs) 123 | 124 | def compute_output_shape(self, input_shape): 125 | return getpart_output_shape(input_shape) 126 | 127 | 128 | class GetAbs(Layer): 129 | def call(self, inputs): 130 | return get_abs(inputs) 131 | 132 | def compute_output_shape(self, input_shape): 133 | return getpart_output_shape(input_shape) 134 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | sphinx>=1.8 3 | sphinxcontrib-apidoc 4 | sphinxcontrib-bibtex -------------------------------------------------------------------------------- /docs/source/bib.bib: -------------------------------------------------------------------------------- 1 | @article{trabelsi2017deep, 2 | title={Deep complex networks}, 3 | author={Trabelsi, Chiheb and Bilaniuk, Olexa and Zhang, Ying and Serdyuk, Dmitriy and Subramanian, Sandeep and Santos, Jo{\~a}o Felipe and Mehri, Soroush and Rostamzadeh, Negar and Bengio, Yoshua and Pal, Christopher J}, 4 | journal={arXiv preprint arXiv:1705.09792}, 5 | year={2017} 6 | } 7 | 8 | @misc{dramsch2019complexsoftware, 9 | title = {Complex-Valued Neural Networks in Keras with Tensorflow}, 10 | url = {https://figshare.com/articles/Complex-Valued_Neural_Networks_in_Keras_with_Tensorflow/9783773/1}, 11 | DOI = {10.6084/m9.figshare.9783773}, 12 | publisher = {figshare}, 13 | author = {Dramsch, Jesper Soeren and Contributors}, 14 | year = {2019} 15 | } 16 | 17 | 18 | @article{dramsch2019complex, 19 | title={Complex-valued neural networks for machine learning on non-stationary physical data}, 20 | author={Dramsch, Jesper S{\"o}ren and L{\"u}thje, Mikael and Christensen, Anders Nymark}, 21 | journal={arXiv preprint arXiv:1905.12321}, 22 | year={2019} 23 | } 24 | 25 | @article{Sarroff2015, 26 | author = {Andy M. Sarroff and 27 | Victor Shepardson and 28 | Michael A. 
Casey}, 29 | title = {Learning Representations Using Complex-Valued Nets}, 30 | journal = {CoRR}, 31 | volume = {abs/1511.06351}, 32 | year = {2015}, 33 | url = {http://arxiv.org/abs/1511.06351}, 34 | archivePrefix = {arXiv}, 35 | eprint = {1511.06351}, 36 | timestamp = {Mon, 13 Aug 2018 16:46:29 +0200}, 37 | biburl = {https://dblp.org/rec/bib/journals/corr/SarroffSC15}, 38 | bibsource = {dblp computer science bibliography, https://dblp.org} 39 | } 40 | 41 | @article{Hirose2012, 42 | year = {2012}, 43 | author={Hirose, Akira and Yoshida, Shotaro}, 44 | title = {Generalization Characteristics of Complex-Valued Feedforward Neural Networks in Relation to Signal Coherence}, 45 | journal = {IEEE Transactions on Neural Networks and Learning Systems} 46 | } -------------------------------------------------------------------------------- /docs/source/cite.md: -------------------------------------------------------------------------------- 1 | Citation 2 | -------- 3 | 4 | Find the CITATION file called CITATION.cff on Github or cite this software version as: 5 | 6 | ``` 7 | @misc{dramsch2019complex, 8 | title = {Complex-Valued Neural Networks in Keras with Tensorflow}, 9 | url = {https://figshare.com/articles/Complex-Valued_Neural_Networks_in_Keras_with_Tensorflow/9783773/1}, 10 | DOI = {10.6084/m9.figshare.9783773}, 11 | publisher = {figshare}, 12 | author = {Dramsch, Jesper S{\"o}ren and Contributors}, 13 | year = {2019} 14 | } 15 | ``` 16 | 17 | Please cite the original work as: 18 | 19 | ``` 20 | @ARTICLE {Trabelsi2017, 21 | author = "Chiheb Trabelsi, Olexa Bilaniuk, Ying Zhang, Dmitriy Serdyuk, Sandeep Subramanian, João Felipe Santos, Soroush Mehri, Negar Rostamzadeh, Yoshua Bengio, Christopher J Pal", 22 | title = "Deep Complex Networks", 23 | journal = "arXiv preprint arXiv:1705.09792", 24 | year = "2017" 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "Keras Complex" 21 | copyright = "2019, Keras-Complex Contributors" 22 | author = "Jesper Dramsch, Chiheb Trabelsi, Olexa Bilaniuk, Bruce Sharpe, Ying Zhang, Dmitriy Serdyuk, Sandeep Subramanian, João Felipe Santos, Soroush Mehri, Negar Rostamzadeh, Yoshua Bengio, Christopher J Pal" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 
30 | extensions = ["sphinxcontrib.apidoc", "sphinx.ext.autodoc", "recommonmark", "sphinxcontrib.bibtex"] 31 | 32 | master_doc = "index" # Needed by RTD 33 | 34 | apidoc_module_dir = "../../complexnn" 35 | apidoc_excluded_paths = ["tests"] 36 | apidoc_toc_file = "api_toc" 37 | apidoc_separate_modules = True 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ["_templates"] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = [] 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | 50 | # The theme to use for HTML and HTML Help pages. See the documentation for 51 | # a list of builtin themes. 52 | # 53 | html_theme = "sphinx_rtd_theme" 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | html_static_path = ["_static"] 59 | 60 | # Bibtex configuration 61 | bibtex_bibfiles = ["bib.bib"] 62 | -------------------------------------------------------------------------------- /docs/source/contrib.md: -------------------------------------------------------------------------------- 1 | How to Contribute 2 | ================= 3 | 4 | You can add a [Pull Request](https://github.com/JesperDramsch/keras-complex/pulls/) on Github. 5 | 6 | Test 7 | ---- 8 | 9 | Make sure the tests pass and new features have at least unittests to cover the new functions. 10 | 11 | These tests should run with `pytest`. 12 | 13 | Documentation 14 | ------------- 15 | 16 | New features should be documented in the docs/ folder, which will be automatically generated on readthedocs.org. -------------------------------------------------------------------------------- /docs/source/figures/complex_nn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JesperDramsch/keras-complex/e723e96e6b309fd22545f542e535d2caa16ffd7b/docs/source/figures/complex_nn.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Keras Complex documentation master file, created by 2 | sphinx-quickstart on Mon Oct 14 23:45:44 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Documentation for Complex-valued Keras 7 | ========================================= 8 | 9 | Complex-valued convolutions could provide some interesting results in signal processing-based deep learning. A simple(-ish) idea is including explicit phase information of time series in neural networks. This code enables complex-valued convolution in convolutional neural networks in keras with the TensorFlow backend. This makes the network modular and interoperable with standard keras layers and operations. 10 | 11 | Contents 12 | ======== 13 | .. 
toctree:: 14 | :maxdepth: 2 15 | :caption: Table of Contents 16 | 17 | Introduction 18 | Installation 19 | API 20 | Contributing 21 | math 22 | cite 23 | 24 | Indices and tables 25 | ================== 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | -------------------------------------------------------------------------------- /docs/source/install.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Installation is as easy as 4 | 5 | ``` 6 | pip install keras-complex 7 | ``` 8 | 9 | The requirements are: 10 | 11 | ``` 12 | tensorflow >= 2 13 | numpy 14 | scipy 15 | scikit-learn 16 | ``` 17 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | Complex-valued convolutions could provide some interesting results in signal processing-based deep learning. A simple(-ish) idea is including explicit phase information of time series in neural networks. This code enables complex-valued convolution in convolutional neural networks in keras with the TensorFlow backend. This makes the network modular and interoperable with standard keras layers and operations. -------------------------------------------------------------------------------- /docs/source/license.rst: -------------------------------------------------------------------------------- 1 | License 2 | ======= 3 | 4 | .. include:: ../../LICENSE.md -------------------------------------------------------------------------------- /docs/source/math.rst: -------------------------------------------------------------------------------- 1 | Implementation and Math 2 | ========================= 3 | Complex convolutional networks provide the benefit of explicitly modelling the phase space of physical systems :cite:`trabelsi2017deep`. 4 | The complex convolution introduced can be explicitly implemented as convolutions of the real and complex components of both kernels and the data. 5 | A complex-valued data matrix in cartesian notation is defined as :math:`\textbf{M} = M_\Re + i M_\Im` and equally, the complex-valued convolutional kernel is defined as :math:`\textbf{K} = K_\Re + i K_\Im`. 6 | The individual coefficients :math:`(M_\Re, M_\Im, K_\Re, K_\Im)` are real-valued matrices, considering vectors are special cases of matrices with one of two dimensions being one. 7 | 8 | 9 | Complex Convolution Math 10 | --------------------------- 11 | The math for complex convolutional networks is similar to real-valued convolutions, with real-valued convolutions being: 12 | 13 | .. math:: 14 | \int f(y)\cdot g(x-y) \, dy 15 | 16 | which generalizes to complex-valued function on :math:`\mathbf{R}^d`: 17 | 18 | .. math:: 19 | (f * g )(x) = \int_{\mathbf{R}^d} f(y)g(x-y)\,dy = \int_{\mathbf{R}^d} f(x-y)g(y)\,dy, 20 | 21 | in order for the integral to exist, f and g need to decay sufficiently rapidly at infinity [`CC-BY-SA Wiki `_]. 22 | 23 | 24 | Implementation 25 | ----------------- 26 | Solving the convolution of, implemented by :cite:`trabelsi2017deep`, translated to keras in :cite:`dramsch2019complexsoftware` 27 | 28 | .. figure:: figures/complex_nn.png 29 | :width: 50% 30 | :align: center 31 | :alt: Complex Convolutions (Trabelsi 2017) 32 | :figclass: align-center 33 | 34 | Complex Convolution implementation (CC-BY :cite:`trabelsi2017deep`) 35 | 36 | .. 
math:: 37 | M' = K * M = (M_\Re + i M_\Im) * (K_\Re + i K_\Im), 38 | 39 | we can apply the distributivity of convolutions to obtain 40 | 41 | .. math:: 42 | M' = \{M_\Re * K_\Re - M_\Im * K_\Im\} + i \{ M_\Re * K_\Im + M_\Im * K_\Re\}, 43 | 44 | where K is the kernel and M is the data matrix. 45 | 46 | Considerations 47 | ----------------- 48 | Complex convolutional neural networks learn by back-propagation. 49 | :cite:`Sarroff2015` state that the activation functions, as well as the loss function, must be complex differentiable (holomorphic). 50 | :cite:`trabelsi2017deep` suggest that employing complex losses and activation functions is valid for speed; however, they point to :cite:`Hirose2012`, who show that complex-valued networks can also be optimized with real-valued loss functions and piecewise real-valued activations. 51 | We reimplement the code of :cite:`trabelsi2017deep` in keras with tensorflow, which offers convenience functions for a multitude of real-valued loss functions and activations. 52 | 53 | [CC-BY :cite:`dramsch2019complex`] 54 | 55 | .. bibliography:: bib.bib 56 | :cited: -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools_scm] 6 | 7 | [project] 8 | name = "keras-complex" 9 | maintainers = [ 10 | {name = "Jesper Dramsch", email = "jesper@dramsch.net"} 11 | ] 12 | description = "Complex values in Keras" 13 | requires-python = ">=3.8" 14 | keywords = ["Machine Learning", "Deep Learning", "Complex Numbers", "Keras", "TensorFlow"] 15 | license = {file = "LICENSE.md"} 16 | classifiers = [ 17 | "Development Status :: 3 - Alpha", 18 | "Intended Audience :: Education", 19 | "Intended Audience :: Science/Research", 20 | "License :: OSI Approved :: MIT License", 21 | "Programming Language :: Python :: 3 :: Only", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Topic :: Software Development :: Libraries", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Topic :: Scientific/Engineering :: Mathematics", 30 | "Topic :: Documentation :: Sphinx", 31 | "Natural Language :: English", 32 | ] 33 | dependencies = [ 34 | 'tensorflow>=2.0.0', 35 | 'numpy', 36 | 'scipy', 37 | 'scikit-learn', 38 | ] 39 | dynamic = ["version"] 40 | 41 | [project.urls] 42 | homepage = "https://github.com/JesperDramsch/keras-complex/" 43 | documentation = "https://keras-complex.readthedocs.org" 44 | repository = "https://github.com/JesperDramsch/keras-complex" 45 | changelog = "https://github.com/JesperDramsch/keras-complex/releases" 46 | bugtracker = "https://github.com/JesperDramsch/keras-complex/issues" 47 | 48 | [tool.setuptools.packages.find] 49 | exclude = ["tests*", "docs*"] 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | scipy 4 | tensorflow -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JesperDramsch/keras-complex/e723e96e6b309fd22545f542e535d2caa16ffd7b/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_conv.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow.keras.layers import Input, MaxPooling2D, Dense 4 | from tensorflow.keras.models import Model, Sequential 5 | import tensorflow as tf 6 | 7 | import numpy as np 8 | import complexnn as conn 9 | 10 | 11 | class TestConvMethods(unittest.TestCase): 12 | """Test Conv methods class""" 13 | 14 | def test_conv_outputs_forward(self): 15 | """Test computed shape of forward convolution output""" 16 | layer = conn.ComplexConv2D(filters=4, kernel_size=3, strides=2, padding="same", transposed=False) 17 | input_shape = (None, 128, 128, 2) 18 | true = (None, 64, 64, 8) 19 | calc = layer.compute_output_shape(input_shape) 20 | self.assertEqual(true, calc) 21 | 22 | def test_outputs_transpose(self): 23 | """Test computed shape of transposed convolution output""" 24 | layer = conn.ComplexConv2D(filters=2, kernel_size=3, strides=2, padding="same", transposed=True) 25 | input_shape = (None, 64, 64, 4) 26 | true = (None, 128, 128, 4) 27 | calc = layer.compute_output_shape(input_shape) 28 | self.assertEqual(true, calc) 29 | 30 | def test_conv2D_forward(self): 31 | """Test shape of model output, forward""" 32 | inputs = Input(shape=(128, 128, 2)) 33 | outputs = conn.ComplexConv2D(filters=4, kernel_size=3, strides=2, padding="same", transposed=False)(inputs) 34 | model = Model(inputs=inputs, outputs=outputs) 35 | true = (None, 64, 64, 8) 36 | calc = model.output_shape 37 | self.assertEqual(true, calc) 38 | 39 | def test_conv2Dtranspose(self): 40 | """Test shape of model output, transposed""" 41 | inputs = Input(shape=(64, 64, 20)) # = 10 CDN filters 42 | outputs = conn.ComplexConv2D( 43 | filters=2, kernel_size=3, strides=2, padding="same", transposed=True # = 4 Keras filters 44 | )(inputs) 45 | model = Model(inputs=inputs, outputs=outputs) 46 | true = (None, 128, 128, 4) 47 | calc = model.output_shape 48 | self.assertEqual(true, calc) 49 | -------------------------------------------------------------------------------- /tests/test_dense.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow.keras.layers import Input, MaxPooling2D, Dense 4 | from tensorflow.keras.models import Model, Sequential 5 | import tensorflow as tf 6 | 7 | import numpy as np 8 | import complexnn as conn 9 | 10 | 11 | class TestDenseMethods(unittest.TestCase): 12 | """Test Dense layer""" 13 | 14 | def test_outputs_dense(self): 15 | """Test computed shape of dense layer output""" 16 | layer = conn.ComplexDense(units=16, activation="relu") 17 | input_shape = (None, 8) 18 | true = (None, 16 * 2) 19 | calc = layer.compute_output_shape(input_shape) 20 | self.assertEqual(true, calc) 21 | 22 | def test_outputs_dense(self): 23 | """Test computed shape of dense layer output""" 24 | layer = conn.ComplexDense(units=16, activation="relu") 25 | input_shape = (None, 8) 26 | true = (None, 16 * 2) 27 | calc = layer.compute_output_shape(input_shape) 28 | self.assertEqual(true, calc) 29 | -------------------------------------------------------------------------------- /tests/test_readme.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow.keras.layers import Input, 
MaxPooling2D, Dense 4 | from tensorflow.keras.models import Model, Sequential 5 | import tensorflow as tf 6 | 7 | import numpy as np 8 | import complexnn as conn 9 | 10 | 11 | class TestDNCMethods(unittest.TestCase): 12 | """Unit test class""" 13 | 14 | def test_github_example(self): 15 | # example from repository https://github.com/JesperDramsch/keras-complex/blob/master/README.md page 16 | model = tf.keras.models.Sequential() 17 | model.add(conn.conv.ComplexConv2D(32, (3, 3), activation="relu", padding="same", input_shape=(28, 28, 2))) 18 | model.add(conn.bn.ComplexBatchNormalization()) 19 | model.add(MaxPooling2D((2, 2), padding="same")) 20 | model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse") 21 | model.summary() -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow.keras.layers import Input, MaxPooling2D, Dense 4 | from tensorflow.keras.models import Model, Sequential 5 | import tensorflow as tf 6 | 7 | import numpy as np 8 | import complexnn as conn 9 | 10 | 11 | class TestTrainingRuns(unittest.TestCase): 12 | """Unit test class""" 13 | 14 | def test_train_transpose(self): 15 | """Train using Conv2DTranspose""" 16 | x = np.random.randn(64 * 64).reshape((64, 64)) 17 | y = np.random.randn(64 * 64).reshape((64, 64)) 18 | X = np.stack((x, y), -1) 19 | X = np.expand_dims(X, 0) 20 | Y = X 21 | inputs = Input(shape=(64, 64, 2)) 22 | conv1 = conn.ComplexConv2D( 23 | filters=2, kernel_size=3, strides=2, padding="same", transposed=False # = 4 Keras filters 24 | )(inputs) 25 | outputs = conn.ComplexConv2D( 26 | filters=1, kernel_size=3, strides=2, padding="same", transposed=True # = 2 Keras filters => 1 complex layer 27 | )(conv1) 28 | model = Model(inputs=inputs, outputs=outputs) 29 | model.compile(optimizer="adam", loss="mean_squared_error", metrics=["accuracy"]) 30 | model.fit(X, Y, batch_size=1, epochs=10) 31 | 32 | def test_train_dense(self): 33 | inputs = 28 34 | outputs = 128 35 | # build a sequential complex dense model 36 | model = Sequential(name="complex") 37 | model.add(conn.ComplexDense(32, activation="relu", input_shape=(inputs * 2,))) 38 | model.add(conn.ComplexBN()) 39 | model.add(conn.ComplexDense(64, activation="relu")) 40 | model.add(conn.ComplexBN()) 41 | model.add(Dense(128, activation="sigmoid")) 42 | model.compile(optimizer="adam", loss="mse") 43 | model.summary() 44 | # create some random data 45 | re = np.random.randn(inputs) 46 | im = np.random.randn(inputs) 47 | X = np.expand_dims(np.concatenate((re, im), -1), 0) 48 | Y = np.expand_dims(np.random.randn(outputs), 0) 49 | model.fit(X, Y, batch_size=1, epochs=10) 50 | --------------------------------------------------------------------------------
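The tests above cover the convolution and dense layers; the sketch below (not a file from the repository) additionally shows how the magnitude helper from complexnn/utils.py composes with ComplexConv2D. It is a minimal, illustrative example only: it assumes TensorFlow 2.x, the default channels_last image data format, that the package is installed as keras-complex, and it follows the shape conventions used in the tests above.

import numpy as np
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

import complexnn as conn
from complexnn.utils import GetAbs  # magnitude layer defined in complexnn/utils.py

# One complex input channel is stored as two real channels: [real, imaginary].
inputs = Input(shape=(32, 32, 2))
# Four complex filters yield eight real output channels (real parts first, then imaginary parts).
x = conn.ComplexConv2D(filters=4, kernel_size=3, strides=1, padding="same", transposed=False)(inputs)
# GetAbs halves the channel axis again by taking sqrt(real**2 + imag**2) per complex feature map.
magnitude = GetAbs()(x)

model = Model(inputs=inputs, outputs=magnitude)
model.compile(optimizer="adam", loss="mse")

batch = np.random.randn(2, 32, 32, 2).astype("float32")
print(model.predict(batch).shape)  # expected: (2, 32, 32, 4)

The channel bookkeeping (two real channels per complex channel on input, half as many channels after GetAbs) mirrors the slicing in get_realpart, get_imagpart and getpart_output_shape above.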