├── demo ├── __init__.py └── mnist.py ├── requirements.txt ├── tests ├── __init__.py ├── test_conv_1d.py └── test_conv_2d.py ├── .github ├── stale.yml ├── main.workflow └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── MANIFEST.in ├── octave_model.png ├── publish.sh ├── requirements-dev.txt ├── keras_octave_conv ├── __init__.py ├── conv_util.py ├── conv_1d.py └── conv_2d.py ├── test.sh ├── LICENSE ├── setup.py ├── .gitignore └── README.md /demo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | daysUntilStale: 5 2 | daysUntilClose: 2 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | include octave_model.png 4 | -------------------------------------------------------------------------------- /octave_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-octave-conv/HEAD/octave_model.png -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf dist/* && python3 setup.py sdist && twine upload dist/* 3 | 
-------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools>=38.6.0 2 | twine>=1.11.0 3 | wheel>=0.31.0 4 | tensorflow 5 | nose 6 | pycodestyle 7 | coverage 8 | -------------------------------------------------------------------------------- /keras_octave_conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_1d import * 2 | from .conv_2d import * 3 | from .conv_util import * 4 | 5 | __version__ = '0.11.0' 6 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | pycodestyle --max-line-length=120 keras_octave_conv tests && \ 3 | nosetests --with-coverage --cover-erase --cover-html --cover-html-dir=htmlcov --cover-package=keras_octave_conv tests 4 | -------------------------------------------------------------------------------- /.github/main.workflow: -------------------------------------------------------------------------------- 1 | workflow "Code Style" { 2 | on = "push" 3 | resolves = ["lint-action"] 4 | } 5 | 6 | action "lint-action" { 7 | uses = "CyberZHG/github-action-python-lint@master" 8 | args = "--max-line-length=120 keras_octave_conv tests" 9 | } 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: CyberZHG 7 | 8 | --- 9 | 10 | **Describe the Bug** 11 | 12 | A clear and concise description of what the bug is. 
13 | 14 | **Version Info** 15 | 16 | * [ ] I'm using the latest version 17 | 18 | **Minimal Codes To Reproduce** 19 | 20 | ```python 21 | import keras_octave_conv 22 | 23 | pass 24 | ``` 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: CyberZHG 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Zhao HG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /keras_octave_conv/conv_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tensorflow import keras 3 | 4 | __all__ = ['octave_dual'] 5 | 6 | 7 | def octave_dual(layers, builder): 8 | """Apply layers for outputs of octave convolution. 9 | 10 | :param layers: The outputs of octave convolution. 11 | :param builder: A function that builds the layer or just a layer. 12 | :return: The output tensors. 
13 | """ 14 | if not isinstance(layers, (list, tuple)): 15 | layers = [layers] 16 | if isinstance(builder, keras.layers.Layer): 17 | intermediates = [builder] + [copy.copy(builder) for _ in range(len(layers) - 1)] 18 | else: 19 | intermediates = [builder() for _ in range(len(layers))] 20 | for i, name in enumerate(['H', 'L']): 21 | if i < len(intermediates): 22 | try: 23 | intermediates[i].name += '-' + name 24 | except AttributeError as e: 25 | config = intermediates[i].get_config() 26 | config['name'] += '-' + name 27 | re_spawn_layer = intermediates[i].__class__.from_config(config) 28 | re_spawn_layer.set_weights(intermediates[i].get_weights()) 29 | intermediates[i] = re_spawn_layer 30 | outputs = [intermediate(layers[i]) for i, intermediate in enumerate(intermediates)] 31 | if len(outputs) == 1: 32 | outputs = outputs[0] 33 | return outputs 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | from setuptools import setup, find_packages 5 | 6 | current_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read_file(*parts): 10 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 11 | return reader.read() 12 | 13 | 14 | def get_requirements(*parts): 15 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 16 | return list(map(lambda x: x.strip(), reader.readlines())) 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read_file(*file_paths) 21 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError('Unable to find version string.') 25 | 26 | 27 | setup( 28 | name='keras-octave-conv', 29 | version=find_version('keras_octave_conv', '__init__.py'), 30 | packages=find_packages(), 31 | 
url='https://github.com/CyberZHG/keras-octave-conv', 32 | license='MIT AND "Anti 996"', 33 | author='CyberZHG', 34 | author_email='CyberZHG@users.noreply.github.com', 35 | description='Octave convolution', 36 | long_description=read_file('README.md'), 37 | long_description_content_type='text/markdown', 38 | install_requires=get_requirements('requirements.txt'), 39 | classifiers=( 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: MIT License", 42 | "Operating System :: OS Independent", 43 | ), 44 | ) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # System thumbnail 107 | .DS_Store 108 | 109 | # IDE 110 | .idea 111 | 112 | # Images 113 | *.png 114 | 115 | # Models 116 | *.h5 117 | -------------------------------------------------------------------------------- /demo/mnist.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import keras.backend as K 3 | import numpy as np 4 | from keras.layers import Input, BatchNormalization, MaxPool2D, Conv2D, Dropout, Flatten, Dense 5 | from keras.models import Model 6 | from keras.datasets import fashion_mnist 7 | from keras_octave_conv import OctaveConv2D 8 | 9 | 10 | (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() 11 | 12 | x_train = np.expand_dims(x_train.astype(K.floatx()) / 255, axis=-1) 13 | x_test = np.expand_dims(x_test.astype(K.floatx()) / 255, axis=-1) 14 | 15 | 
y_train, y_test = np.expand_dims(y_train, axis=-1), np.expand_dims(y_test, axis=-1) 16 | 17 | train_num = round(x_train.shape[0] * 0.9) 18 | x_train, x_valid = x_train[:train_num, ...], x_train[train_num:, ...] 19 | y_train, y_valid = y_train[:train_num, ...], y_train[train_num:, ...] 20 | 21 | 22 | # Octave Conv 23 | inputs = Input(shape=(28, 28, 1)) 24 | normal = BatchNormalization()(inputs) 25 | high, low = OctaveConv2D(64, kernel_size=3)(normal) 26 | high, low = MaxPool2D()(high), MaxPool2D()(low) 27 | high, low = OctaveConv2D(32, kernel_size=3)([high, low]) 28 | conv = OctaveConv2D(16, kernel_size=3, ratio_out=0.0)([high, low]) 29 | pool = MaxPool2D()(conv) 30 | flatten = Flatten()(pool) 31 | normal = BatchNormalization()(flatten) 32 | dropout = Dropout(rate=0.4)(normal) 33 | outputs = Dense(units=10, activation='softmax')(dropout) 34 | model = Model(inputs=inputs, outputs=outputs) 35 | model.compile( 36 | optimizer='adam', 37 | loss='sparse_categorical_crossentropy', 38 | metrics=['accuracy'], 39 | ) 40 | 41 | model.summary() 42 | model.fit( 43 | x=x_train, 44 | y=y_train, 45 | epochs=10, 46 | validation_data=(x_valid, y_valid), 47 | callbacks=[keras.callbacks.EarlyStopping(monitor='val_acc', patience=2)] 48 | ) 49 | octave_score = model.evaluate(x_test, y_test) 50 | print('Accuracy of Octave: %.4f' % octave_score[1]) 51 | 52 | 53 | # Normal Conv 54 | inputs = Input(shape=(28, 28, 1)) 55 | normal = BatchNormalization()(inputs) 56 | conv = Conv2D(64, kernel_size=3, padding='same')(normal) 57 | pool = MaxPool2D()(conv) 58 | conv = Conv2D(32, kernel_size=3, padding='same')(pool) 59 | conv = Conv2D(16, kernel_size=3, padding='same')(conv) 60 | pool = MaxPool2D()(conv) 61 | flatten = Flatten()(pool) 62 | normal = BatchNormalization()(flatten) 63 | dropout = Dropout(rate=0.4)(normal) 64 | outputs = Dense(units=10, activation='softmax')(dropout) 65 | model = Model(inputs=inputs, outputs=outputs) 66 | model.compile( 67 | optimizer='adam', 68 | 
loss='sparse_categorical_crossentropy', 69 | metrics=['accuracy'], 70 | ) 71 | 72 | model.summary() 73 | model.fit( 74 | x=x_train, 75 | y=y_train, 76 | epochs=10, 77 | validation_data=(x_valid, y_valid), 78 | callbacks=[keras.callbacks.EarlyStopping(monitor='val_acc', patience=2)] 79 | ) 80 | normal_score = model.evaluate(x_test, y_test) 81 | print('Accuracy of Octave: %.4f' % octave_score[1]) 82 | print('Accuracy of normal: %.4f' % normal_score[1]) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras Octave Conv 2 | 3 | ![](https://img.shields.io/badge/license-MIT-blue.svg) 4 | 5 | Unofficial implementation of [Drop an Octave: Reducing Spatial Redundancy in 6 | Convolutional Neural Networks with Octave Convolution](https://arxiv.org/pdf/1904.05049.pdf). 7 | 8 | ## Install 9 | 10 | ```bash 11 | pip install keras-octave-conv 12 | ``` 13 | 14 | ## Usage 15 | 16 | The `OctaveConv2D` layer could be used just like the `Conv2D` layer, except the `padding` argument is forced to be `'same'`. 17 | 18 | ### First Octave 19 | 20 | Use a single input for the first octave layer: 21 | 22 | ```python 23 | from tensorflow.keras.layers import Input 24 | from keras_octave_conv import OctaveConv2D 25 | 26 | inputs = Input(shape=(32, 32, 3)) 27 | high, low = OctaveConv2D(filters=16, kernel_size=3, octave=2, ratio_out=0.125)(inputs) 28 | ``` 29 | 30 | The two outputs represent the results in higher and lower spatial resolutions. 31 | 32 | Special arguments: 33 | * `octave`: default is `2`. The division of the spatial dimensions. 34 | * `ratio_out`: default is `0.5`. The ratio of filters for lower spatial resolution. 
35 | 36 | ### Intermediate Octave 37 | 38 | The intermediate octave layers take two inputs and produce two outputs: 39 | 40 | ```python 41 | from tensorflow.keras.layers import Input, MaxPool2D 42 | from keras_octave_conv import OctaveConv2D 43 | 44 | inputs = Input(shape=(32, 32, 3)) 45 | high, low = OctaveConv2D(filters=16, kernel_size=3)(inputs) 46 | 47 | high, low = MaxPool2D()(high), MaxPool2D()(low) 48 | high, low = OctaveConv2D(filters=8, kernel_size=3)([high, low]) 49 | ``` 50 | 51 | Note that the same `octave` value should be used throughout the whole model. 52 | 53 | ### Last Octave 54 | 55 | Set `ratio_out` to `0.0` to get a single output for further processing: 56 | 57 | ```python 58 | from tensorflow.keras.layers import Input, MaxPool2D, Flatten, Dense 59 | from tensorflow.keras.models import Model 60 | from keras_octave_conv import OctaveConv2D 61 | 62 | inputs = Input(shape=(32, 32, 3)) 63 | high, low = OctaveConv2D(filters=16, kernel_size=3)(inputs) 64 | 65 | high, low = MaxPool2D()(high), MaxPool2D()(low) 66 | high, low = OctaveConv2D(filters=8, kernel_size=3)([high, low]) 67 | 68 | high, low = MaxPool2D()(high), MaxPool2D()(low) 69 | conv = OctaveConv2D(filters=4, kernel_size=3, ratio_out=0.0)([high, low]) 70 | 71 | flatten = Flatten()(conv) 72 | outputs = Dense(units=10, activation='softmax')(flatten) 73 | 74 | model = Model(inputs=inputs, outputs=outputs) 75 | model.summary() 76 | ``` 77 | 78 | ### Utility 79 | 80 | `octave_dual` helps to create dual layers for processing the outputs of octave convolutions: 81 | 82 | ```python 83 | from tensorflow.keras.layers import Input, MaxPool2D, Flatten, Dense 84 | from tensorflow.keras.models import Model 85 | from keras_octave_conv import OctaveConv2D, octave_dual 86 | 87 | inputs = Input(shape=(32, 32, 3)) 88 | conv = OctaveConv2D(filters=16, kernel_size=3)(inputs) 89 | 90 | pool = octave_dual(conv, MaxPool2D()) 91 | conv = OctaveConv2D(filters=8, kernel_size=3)(pool) 92 | 93 | pool = octave_dual(conv,
MaxPool2D()) 94 | conv = OctaveConv2D(filters=4, kernel_size=3, ratio_out=0.0)(pool) 95 | 96 | flatten = Flatten()(conv) 97 | outputs = Dense(units=10, activation='softmax')(flatten) 98 | 99 | model = Model(inputs=inputs, outputs=outputs) 100 | model.summary() 101 | ``` 102 | 103 | `octave_conv_2d` creates the octave structure with built-in Keras layers: 104 | 105 | ```python 106 | from tensorflow.keras.layers import Input, MaxPool2D, Flatten, Dense 107 | from tensorflow.keras.models import Model 108 | from tensorflow.keras.utils import plot_model 109 | from keras_octave_conv import octave_conv_2d, octave_dual 110 | 111 | inputs = Input(shape=(32, 32, 3), name='Input') 112 | conv = octave_conv_2d(inputs, filters=16, kernel_size=3, name='Octave-First') 113 | 114 | pool = octave_dual(conv, MaxPool2D(name='Pool-1')) 115 | conv = octave_conv_2d(pool, filters=8, kernel_size=3, name='Octave-Mid') 116 | 117 | pool = octave_dual(conv, MaxPool2D(name='Pool-2')) 118 | conv = octave_conv_2d(pool, filters=4, kernel_size=3, ratio_out=0.0, name='Octave-Last') 119 | 120 | flatten = Flatten(name='Flatten')(conv) 121 | outputs = Dense(units=10, activation='softmax', name='Output')(flatten) 122 | 123 | model = Model(inputs=inputs, outputs=outputs) 124 | model.summary() 125 | plot_model(model, to_file='octave_model.png') 126 | ``` 127 | 128 | ![](./octave_model.png) 129 | -------------------------------------------------------------------------------- /tests/test_conv_1d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest import TestCase 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_octave_conv import OctaveConv1D, octave_dual, octave_conv_1d 9 | 10 | 11 | class TestConv1D(TestCase): 12 | 13 | def _test_fit(self, model): 14 | data_size = 4096 15 | x = np.random.standard_normal((data_size, 32, 3)) 16 | y = np.random.randint(0, 1, data_size) 17 | model.fit(x, y, 
epochs=3) 18 | model_path = os.path.join(tempfile.gettempdir(), 'test_octave_conv_%f.h5' % np.random.random()) 19 | model.save(model_path) 20 | model = keras.models.load_model(model_path, custom_objects={'OctaveConv1D': OctaveConv1D}) 21 | predicted = model.predict(x).argmax(axis=-1) 22 | diff = np.sum(np.abs(y - predicted)) 23 | self.assertLess(diff, 100) 24 | 25 | def test_fit_default(self): 26 | inputs = keras.layers.Input(shape=(32, 3)) 27 | high, low = OctaveConv1D(13, kernel_size=3)(inputs) 28 | high, low = keras.layers.MaxPool1D()(high), keras.layers.MaxPool1D()(low) 29 | high, low = OctaveConv1D(7, kernel_size=3)([high, low]) 30 | high, low = keras.layers.MaxPool1D()(high), keras.layers.MaxPool1D()(low) 31 | conv = OctaveConv1D(5, kernel_size=3, ratio_out=0.0)([high, low]) 32 | flatten = keras.layers.Flatten()(conv) 33 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 34 | model = keras.models.Model(inputs=inputs, outputs=outputs) 35 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 36 | model.summary(line_length=200) 37 | self._test_fit(model) 38 | 39 | def test_fit_octave(self): 40 | inputs = keras.layers.Input(shape=(32, 3)) 41 | high, low = OctaveConv1D(13, kernel_size=3, octave=4)(inputs) 42 | high, low = keras.layers.MaxPool1D()(high), keras.layers.MaxPool1D()(low) 43 | conv = OctaveConv1D(5, kernel_size=3, octave=4, ratio_out=0.0)([high, low]) 44 | flatten = keras.layers.Flatten()(conv) 45 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 46 | model = keras.models.Model(inputs=inputs, outputs=outputs) 47 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 48 | model.summary(line_length=200) 49 | self._test_fit(model) 50 | 51 | def test_fit_lower_output(self): 52 | inputs = keras.layers.Input(shape=(32, 3)) 53 | high, low = OctaveConv1D(13, kernel_size=3)(inputs) 54 | high, low = keras.layers.MaxPool1D()(high), keras.layers.MaxPool1D()(low) 55 | high, low = 
OctaveConv1D(7, kernel_size=3)([high, low]) 56 | high, low = keras.layers.MaxPool1D()(high), keras.layers.MaxPool1D()(low) 57 | conv = OctaveConv1D(5, kernel_size=3, ratio_out=1.0)([high, low]) 58 | flatten = keras.layers.Flatten()(conv) 59 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 60 | model = keras.models.Model(inputs=inputs, outputs=outputs) 61 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 62 | model.summary(line_length=200) 63 | self._test_fit(model) 64 | 65 | def test_raise_dimension_specified(self): 66 | with self.assertRaises(ValueError): 67 | inputs = keras.layers.Input(shape=(32, None)) 68 | outputs = OctaveConv1D(13, kernel_size=3, ratio_out=0.0)(inputs) 69 | model = keras.models.Model(inputs=inputs, outputs=outputs) 70 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 71 | with self.assertRaises(ValueError): 72 | inputs_high = keras.layers.Input(shape=(32, 3)) 73 | inputs_low = keras.layers.Input(shape=(32, None)) 74 | outputs = OctaveConv1D(13, kernel_size=3, ratio_out=0.0)([inputs_high, inputs_low]) 75 | model = keras.models.Model(inputs=[inputs_high, inputs_low], outputs=outputs) 76 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 77 | 78 | def test_raise_octave_divisible(self): 79 | with self.assertRaises(ValueError): 80 | inputs = keras.layers.Input(shape=(32, 3)) 81 | outputs = OctaveConv1D(13, kernel_size=3, octave=5, ratio_out=0.0)(inputs) 82 | model = keras.models.Model(inputs=inputs, outputs=outputs) 83 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 84 | 85 | def test_make_dual_layer(self): 86 | inputs = keras.layers.Input(shape=(32, 3)) 87 | conv = OctaveConv1D(13, kernel_size=3)(inputs) 88 | pool = octave_dual(conv, keras.layers.MaxPool1D()) 89 | conv = OctaveConv1D(7, kernel_size=3)(pool) 90 | pool = octave_dual(conv, keras.layers.MaxPool1D()) 91 | conv = OctaveConv1D(5, kernel_size=3, ratio_out=0.0)(pool) 92 | 
flatten = octave_dual(conv, keras.layers.Flatten()) 93 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 94 | model = keras.models.Model(inputs=inputs, outputs=outputs) 95 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 96 | model.summary(line_length=200) 97 | self._test_fit(model) 98 | 99 | def test_fit_octave_conv_high(self): 100 | inputs = keras.layers.Input(shape=(32, 3)) 101 | conv = octave_conv_1d(inputs, filters=13, kernel_size=3) 102 | pool = octave_dual(conv, keras.layers.MaxPool1D()) 103 | conv = octave_conv_1d(pool, filters=7, kernel_size=3, name='Mid') 104 | pool = octave_dual(conv, keras.layers.MaxPool1D()) 105 | conv = octave_conv_1d(pool, filters=5, kernel_size=3, ratio_out=0.0) 106 | flatten = octave_dual(conv, keras.layers.Flatten()) 107 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 108 | model = keras.models.Model(inputs=inputs, outputs=outputs) 109 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 110 | model.summary(line_length=200) 111 | self._test_fit(model) 112 | 113 | def test_fit_octave_conv_low(self): 114 | inputs = keras.layers.Input(shape=(32, 3)) 115 | conv = octave_conv_1d(inputs, filters=13, kernel_size=3) 116 | pool = octave_dual(conv, keras.layers.MaxPool1D()) 117 | conv = octave_conv_1d(pool, filters=7, kernel_size=3, name='Mid') 118 | pool = octave_dual(conv, keras.layers.MaxPool1D()) 119 | conv = octave_conv_1d(pool, filters=5, kernel_size=3, ratio_out=1.0) 120 | flatten = octave_dual(conv, keras.layers.Flatten()) 121 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 122 | model = keras.models.Model(inputs=inputs, outputs=outputs) 123 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 124 | model.summary(line_length=200) 125 | self._test_fit(model) 126 | -------------------------------------------------------------------------------- /tests/test_conv_2d.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest import TestCase 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_octave_conv import OctaveConv2D, octave_conv_2d, octave_dual 9 | 10 | 11 | class TestConv2D(TestCase): 12 | 13 | def _test_fit(self, model, data_format='channels_last'): 14 | data_size = 4096 15 | if data_format == 'channels_last': 16 | x = np.random.standard_normal((data_size, 32, 32, 3)) 17 | else: 18 | x = np.random.standard_normal((data_size, 3, 32, 32)) 19 | y = np.random.randint(0, 1, data_size) 20 | model.fit(x, y, epochs=3) 21 | model_path = os.path.join(tempfile.gettempdir(), 'test_octave_conv_%f.h5' % np.random.random()) 22 | model.save(model_path) 23 | model = keras.models.load_model(model_path, custom_objects={'OctaveConv2D': OctaveConv2D}) 24 | predicted = model.predict(x).argmax(axis=-1) 25 | diff = np.sum(np.abs(y - predicted)) 26 | self.assertLess(diff, 100) 27 | 28 | def test_fit_default(self): 29 | inputs = keras.layers.Input(shape=(32, 32, 3)) 30 | high, low = OctaveConv2D(13, kernel_size=3)(inputs) 31 | high, low = keras.layers.MaxPool2D()(high), keras.layers.MaxPool2D()(low) 32 | high, low = OctaveConv2D(7, kernel_size=3)([high, low]) 33 | high, low = keras.layers.MaxPool2D()(high), keras.layers.MaxPool2D()(low) 34 | conv = OctaveConv2D(5, kernel_size=3, ratio_out=0.0)([high, low]) 35 | flatten = keras.layers.Flatten()(conv) 36 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 37 | model = keras.models.Model(inputs=inputs, outputs=outputs) 38 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 39 | model.summary(line_length=200) 40 | self._test_fit(model) 41 | 42 | def test_fit_channels_first(self): 43 | return 'The test needs GPU support' 44 | inputs = keras.layers.Input(shape=(3, 32, 32)) 45 | high, low = OctaveConv2D(13, kernel_size=3, data_format='channels_first')(inputs) 46 
| high = keras.layers.MaxPool2D(data_format='channels_first')(high) 47 | low = keras.layers.MaxPool2D(data_format='channels_first')(low) 48 | high, low = OctaveConv2D(7, kernel_size=3, data_format='channels_first')([high, low]) 49 | high = keras.layers.MaxPool2D(data_format='channels_first')(high) 50 | low = keras.layers.MaxPool2D(data_format='channels_first')(low) 51 | conv = OctaveConv2D(5, kernel_size=3, ratio_out=0.0, data_format='channels_first')([high, low]) 52 | flatten = keras.layers.Flatten()(conv) 53 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 54 | model = keras.models.Model(inputs=inputs, outputs=outputs) 55 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 56 | model.summary(line_length=200) 57 | self._test_fit(model, data_format='channels_first') 58 | 59 | def test_fit_octave(self): 60 | inputs = keras.layers.Input(shape=(32, 32, 3)) 61 | high, low = OctaveConv2D(13, kernel_size=3, octave=4)(inputs) 62 | high, low = keras.layers.MaxPool2D()(high), keras.layers.MaxPool2D()(low) 63 | conv = OctaveConv2D(5, kernel_size=3, octave=4, ratio_out=0.0)([high, low]) 64 | flatten = keras.layers.Flatten()(conv) 65 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 66 | model = keras.models.Model(inputs=inputs, outputs=outputs) 67 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 68 | model.summary(line_length=200) 69 | self._test_fit(model) 70 | 71 | def test_fit_lower_output(self): 72 | inputs = keras.layers.Input(shape=(32, 32, 3)) 73 | high, low = OctaveConv2D(13, kernel_size=3)(inputs) 74 | high, low = keras.layers.MaxPool2D()(high), keras.layers.MaxPool2D()(low) 75 | high, low = OctaveConv2D(7, kernel_size=3)([high, low]) 76 | high, low = keras.layers.MaxPool2D()(high), keras.layers.MaxPool2D()(low) 77 | conv = OctaveConv2D(5, kernel_size=3, ratio_out=1.0)([high, low]) 78 | flatten = keras.layers.Flatten()(conv) 79 | outputs = keras.layers.Dense(units=2, 
activation='softmax')(flatten) 80 | model = keras.models.Model(inputs=inputs, outputs=outputs) 81 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 82 | model.summary(line_length=200) 83 | self._test_fit(model) 84 | 85 | def test_raise_dimension_specified(self): 86 | with self.assertRaises(ValueError): 87 | inputs = keras.layers.Input(shape=(32, 32, None)) 88 | outputs = OctaveConv2D(13, kernel_size=3, ratio_out=0.0)(inputs) 89 | model = keras.models.Model(inputs=inputs, outputs=outputs) 90 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 91 | with self.assertRaises(ValueError): 92 | inputs_high = keras.layers.Input(shape=(32, 32, 3)) 93 | inputs_low = keras.layers.Input(shape=(32, 32, None)) 94 | outputs = OctaveConv2D(13, kernel_size=3, ratio_out=0.0)([inputs_high, inputs_low]) 95 | model = keras.models.Model(inputs=[inputs_high, inputs_low], outputs=outputs) 96 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 97 | 98 | def test_raise_octave_divisible(self): 99 | with self.assertRaises(ValueError): 100 | inputs = keras.layers.Input(shape=(32, 32, 3)) 101 | outputs = OctaveConv2D(13, kernel_size=3, octave=5, ratio_out=0.0)(inputs) 102 | model = keras.models.Model(inputs=inputs, outputs=outputs) 103 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 104 | 105 | def test_make_dual_lambda(self): 106 | inputs = keras.layers.Input(shape=(32, 32, 3)) 107 | conv = OctaveConv2D(13, kernel_size=3)(inputs) 108 | pool = octave_dual(conv, lambda: keras.layers.MaxPool2D()) 109 | conv = OctaveConv2D(7, kernel_size=3)(pool) 110 | pool = octave_dual(conv, lambda: keras.layers.MaxPool2D()) 111 | conv = OctaveConv2D(5, kernel_size=3, ratio_out=0.0)(pool) 112 | flatten = keras.layers.Flatten()(conv) 113 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 114 | model = keras.models.Model(inputs=inputs, outputs=outputs) 115 | model.compile(optimizer='adam', 
loss='sparse_categorical_crossentropy') 116 | model.summary(line_length=200) 117 | self._test_fit(model) 118 | 119 | def test_fit_octave_conv_high(self): 120 | inputs = keras.layers.Input(shape=(32, 32, 3)) 121 | conv = octave_conv_2d(inputs, filters=13, kernel_size=3) 122 | pool = octave_dual(conv, keras.layers.MaxPool2D()) 123 | conv = octave_conv_2d(pool, filters=7, kernel_size=3, name='Octave-Mid') 124 | pool = octave_dual(conv, keras.layers.MaxPool2D()) 125 | conv = octave_conv_2d(pool, filters=5, kernel_size=3, ratio_out=0.0) 126 | flatten = keras.layers.Flatten()(conv) 127 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 128 | model = keras.models.Model(inputs=inputs, outputs=outputs) 129 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 130 | model.summary(line_length=200) 131 | self._test_fit(model) 132 | 133 | def test_fit_octave_conv_low(self): 134 | inputs = keras.layers.Input(shape=(32, 32, 3)) 135 | conv = octave_conv_2d(inputs, filters=13, kernel_size=3) 136 | pool = octave_dual(conv, keras.layers.MaxPool2D()) 137 | conv = octave_conv_2d(pool, filters=7, kernel_size=3, name='Octave-Mid') 138 | pool = octave_dual(conv, keras.layers.MaxPool2D()) 139 | conv = octave_conv_2d(pool, filters=5, kernel_size=3, ratio_out=1.0) 140 | flatten = keras.layers.Flatten()(conv) 141 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 142 | model = keras.models.Model(inputs=inputs, outputs=outputs) 143 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 144 | model.summary(line_length=200) 145 | self._test_fit(model) 146 | 147 | def test_fit_stride(self): 148 | inputs = keras.layers.Input(shape=(32, 32, 3)) 149 | high, low = OctaveConv2D(13, kernel_size=3, strides=(1, 2))(inputs) 150 | high, low = keras.layers.MaxPool2D()(high), keras.layers.MaxPool2D()(low) 151 | conv = OctaveConv2D(5, kernel_size=3, ratio_out=0.0)([high, low]) 152 | flatten = keras.layers.Flatten()(conv) 153 | 
outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 154 | model = keras.models.Model(inputs=inputs, outputs=outputs) 155 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 156 | model.summary(line_length=200) 157 | self._test_fit(model) 158 | 159 | def test_fit_octave_conv_stride(self): 160 | inputs = keras.layers.Input(shape=(32, 32, 3)) 161 | conv = octave_conv_2d(inputs, filters=13, kernel_size=3, strides=(1, 2)) 162 | pool = octave_dual(conv, keras.layers.MaxPool2D()) 163 | conv = octave_conv_2d(pool, filters=5, kernel_size=3, ratio_out=1.0) 164 | flatten = keras.layers.Flatten()(conv) 165 | outputs = keras.layers.Dense(units=2, activation='softmax')(flatten) 166 | model = keras.models.Model(inputs=inputs, outputs=outputs) 167 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 168 | model.summary(line_length=200) 169 | self._test_fit(model) 170 | -------------------------------------------------------------------------------- /keras_octave_conv/conv_1d.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import layers, activations, initializers, regularizers, constraints 2 | from tensorflow.keras import backend as K 3 | 4 | __all__ = ['OctaveConv1D', 'octave_conv_1d'] 5 | 6 | 7 | class OctaveConv1D(layers.Layer): 8 | """Octave convolutions. 9 | 10 | # Arguments 11 | octave: The division of the spatial dimensions by a power of 2. 12 | ratio_out: The ratio of filters for lower spatial resolution. 
13 | 14 | # References 15 | - [Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution] 16 | (https://arxiv.org/pdf/1904.05049.pdf) 17 | """ 18 | 19 | def __init__(self, 20 | filters, 21 | kernel_size, 22 | octave=2, 23 | ratio_out=0.5, 24 | strides=1, 25 | dilation_rate=1, 26 | activation=None, 27 | use_bias=True, 28 | kernel_initializer='glorot_uniform', 29 | bias_initializer='zeros', 30 | kernel_regularizer=None, 31 | bias_regularizer=None, 32 | activity_regularizer=None, 33 | kernel_constraint=None, 34 | bias_constraint=None, 35 | **kwargs): 36 | super(OctaveConv1D, self).__init__(**kwargs) 37 | self.filters = filters 38 | self.kernel_size = kernel_size 39 | self.octave = octave 40 | self.ratio_out = ratio_out 41 | self.strides = strides 42 | self.dilation_rate = dilation_rate 43 | self.activation = activations.get(activation) 44 | self.use_bias = use_bias 45 | self.kernel_initializer = initializers.get(kernel_initializer) 46 | self.bias_initializer = initializers.get(bias_initializer) 47 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 48 | self.bias_regularizer = regularizers.get(bias_regularizer) 49 | self.activity_regularizer = regularizers.get(activity_regularizer) 50 | self.kernel_constraint = constraints.get(kernel_constraint) 51 | self.bias_constraint = constraints.get(bias_constraint) 52 | 53 | self.filters_low = int(filters * self.ratio_out) 54 | self.filters_high = filters - self.filters_low 55 | 56 | self.conv_high_to_high, self.conv_low_to_high = None, None 57 | if self.filters_high > 0: 58 | self.conv_high_to_high = self._init_conv(self.filters_high, name='{}-Conv1D-HH'.format(self.name)) 59 | self.conv_low_to_high = self._init_conv(self.filters_high, name='{}-Conv1D-LH'.format(self.name)) 60 | self.conv_low_to_low, self.conv_high_to_low = None, None 61 | if self.filters_low > 0: 62 | self.conv_low_to_low = self._init_conv(self.filters_low, name='{}-Conv1D-HL'.format(self.name)) 63 | 
self.conv_high_to_low = self._init_conv(self.filters_low, name='{}-Conv1D-LL'.format(self.name)) 64 | self.pooling = layers.AveragePooling1D( 65 | pool_size=self.octave, 66 | padding='valid', 67 | name='{}-AveragePooling1D'.format(self.name), 68 | ) 69 | self.up_sampling = layers.UpSampling1D( 70 | size=self.octave, 71 | name='{}-UpSampling1D'.format(self.name), 72 | ) 73 | 74 | def _init_conv(self, filters, name): 75 | return layers.Conv1D( 76 | filters=filters, 77 | kernel_size=self.kernel_size, 78 | strides=self.strides, 79 | padding='same', 80 | dilation_rate=self.dilation_rate, 81 | activation=self.activation, 82 | use_bias=self.use_bias, 83 | kernel_initializer=self.kernel_initializer, 84 | bias_initializer=self.bias_initializer, 85 | kernel_regularizer=self.kernel_regularizer, 86 | bias_regularizer=self.bias_regularizer, 87 | activity_regularizer=self.activity_regularizer, 88 | kernel_constraint=self.kernel_constraint, 89 | bias_constraint=self.bias_constraint, 90 | name=name, 91 | ) 92 | 93 | def build(self, input_shape): 94 | if isinstance(input_shape, list): 95 | input_shape_high, input_shape_low = input_shape 96 | else: 97 | input_shape_high, input_shape_low = input_shape, None 98 | if input_shape_high[-1] is None: 99 | raise ValueError('The channel dimension of the higher spatial inputs ' 100 | 'should be defined. Found `None`.') 101 | if input_shape_low is not None and input_shape_low[-1] is None: 102 | raise ValueError('The channel dimension of the lower spatial inputs ' 103 | 'should be defined. Found `None`.') 104 | if input_shape_high[-2] is not None and input_shape_high[-2] % self.octave != 0: 105 | raise ValueError('The length of the higher spatial inputs should be divisible by the octave. 
' 106 | 'Found {} and {}.'.format(input_shape_high, self.octave)) 107 | if input_shape_low is None: 108 | self.conv_low_to_high, self.conv_low_to_low = None, None 109 | 110 | if self.conv_high_to_high is not None: 111 | with K.name_scope(self.conv_high_to_high.name): 112 | self.conv_high_to_high.build(input_shape_high) 113 | if self.conv_low_to_high is not None: 114 | with K.name_scope(self.conv_low_to_high.name): 115 | self.conv_low_to_high.build(input_shape_low) 116 | if self.conv_high_to_low is not None: 117 | with K.name_scope(self.conv_high_to_low.name): 118 | self.conv_high_to_low.build(input_shape_high) 119 | if self.conv_low_to_low is not None: 120 | with K.name_scope(self.conv_low_to_low.name): 121 | self.conv_low_to_low.build(input_shape_low) 122 | super(OctaveConv1D, self).build(input_shape) 123 | 124 | @property 125 | def trainable_weights(self): 126 | weights = [] 127 | if self.conv_high_to_high is not None: 128 | weights += self.conv_high_to_high.trainable_weights 129 | if self.conv_low_to_high is not None: 130 | weights += self.conv_low_to_high.trainable_weights 131 | if self.conv_high_to_low is not None: 132 | weights += self.conv_high_to_low.trainable_weights 133 | if self.conv_low_to_low is not None: 134 | weights += self.conv_low_to_low.trainable_weights 135 | return weights 136 | 137 | @property 138 | def non_trainable_weights(self): 139 | weights = [] 140 | if self.conv_high_to_high is not None: 141 | weights += self.conv_high_to_high.non_trainable_weights 142 | if self.conv_low_to_high is not None: 143 | weights += self.conv_low_to_high.non_trainable_weights 144 | if self.conv_high_to_low is not None: 145 | weights += self.conv_high_to_low.non_trainable_weights 146 | if self.conv_low_to_low is not None: 147 | weights += self.conv_low_to_low.non_trainable_weights 148 | return weights 149 | 150 | def compute_output_shape(self, input_shape): 151 | if isinstance(input_shape, list): 152 | input_shape_high, input_shape_low = input_shape 153 | else: 
154 | input_shape_high, input_shape_low = input_shape, None 155 | 156 | output_shape_high = None 157 | if self.filters_high > 0: 158 | output_shape_high = self.conv_high_to_high.compute_output_shape(input_shape_high) 159 | output_shape_low = None 160 | if self.filters_low > 0: 161 | output_shape_low = self.conv_high_to_low.compute_output_shape( 162 | self.pooling.compute_output_shape(input_shape_high), 163 | ) 164 | 165 | if self.filters_low == 0: 166 | return output_shape_high 167 | if self.filters_high == 0: 168 | return output_shape_low 169 | return [output_shape_high, output_shape_low] 170 | 171 | def call(self, inputs, **kwargs): 172 | if isinstance(inputs, list): 173 | inputs_high, inputs_low = inputs 174 | else: 175 | inputs_high, inputs_low = inputs, None 176 | 177 | outputs_high_to_high, outputs_low_to_high = 0.0, 0.0 178 | if self.conv_high_to_high is not None: 179 | outputs_high_to_high = self.conv_high_to_high(inputs_high) 180 | if self.conv_low_to_high is not None: 181 | outputs_low_to_high = self.up_sampling(self.conv_low_to_high(inputs_low)) 182 | outputs_high = outputs_high_to_high + outputs_low_to_high 183 | 184 | outputs_low_to_low, outputs_high_to_low = 0.0, 0.0 185 | if self.conv_low_to_low is not None: 186 | outputs_low_to_low = self.conv_low_to_low(inputs_low) 187 | if self.conv_high_to_low is not None: 188 | outputs_high_to_low = self.conv_high_to_low(self.pooling(inputs_high)) 189 | outputs_low = outputs_low_to_low + outputs_high_to_low 190 | 191 | if self.filters_low == 0: 192 | return outputs_high 193 | if self.filters_high == 0: 194 | return outputs_low 195 | return [outputs_high, outputs_low] 196 | 197 | def get_config(self): 198 | config = { 199 | 'filters': self.filters, 200 | 'kernel_size': self.kernel_size, 201 | 'octave': self.octave, 202 | 'ratio_out': self.ratio_out, 203 | 'strides': self.strides, 204 | 'dilation_rate': self.dilation_rate, 205 | 'activation': activations.serialize(self.activation), 206 | 'use_bias': self.use_bias, 
207 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 208 | 'bias_initializer': initializers.serialize(self.bias_initializer), 209 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 210 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 211 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 212 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 213 | 'bias_constraint': constraints.serialize(self.bias_constraint) 214 | } 215 | base_config = super(OctaveConv1D, self).get_config() 216 | return dict(list(base_config.items()) + list(config.items())) 217 | 218 | 219 | def octave_conv_1d(inputs, 220 | filters, 221 | kernel_size, 222 | octave=2, 223 | ratio_out=0.5, 224 | strides=1, 225 | dilation_rate=1, 226 | activation=None, 227 | use_bias=True, 228 | kernel_initializer='glorot_uniform', 229 | bias_initializer='zeros', 230 | kernel_regularizer=None, 231 | bias_regularizer=None, 232 | activity_regularizer=None, 233 | kernel_constraint=None, 234 | bias_constraint=None, 235 | name=None, 236 | **kwargs): 237 | if isinstance(inputs, (list, tuple)): 238 | inputs_high, inputs_low = inputs 239 | else: 240 | inputs_high, inputs_low = inputs, None 241 | 242 | filters_low = int(filters * ratio_out) 243 | filters_high = filters - filters_low 244 | 245 | def _init_conv(conv_filters, conv_name_suffix): 246 | if name is None: 247 | conv_name = None 248 | else: 249 | conv_name = name + '-' + conv_name_suffix 250 | return layers.Conv1D( 251 | filters=conv_filters, 252 | kernel_size=kernel_size, 253 | strides=strides, 254 | padding='same', 255 | dilation_rate=dilation_rate, 256 | activation=activation, 257 | use_bias=use_bias, 258 | kernel_initializer=kernel_initializer, 259 | bias_initializer=bias_initializer, 260 | kernel_regularizer=kernel_regularizer, 261 | bias_regularizer=bias_regularizer, 262 | activity_regularizer=activity_regularizer, 263 | 
kernel_constraint=kernel_constraint, 264 | bias_constraint=bias_constraint, 265 | name=conv_name, 266 | **kwargs 267 | ) 268 | 269 | outputs_high = None 270 | if filters_high > 0: 271 | outputs_high = _init_conv(filters_high, 'HH')(inputs_high) 272 | if inputs_low is not None: 273 | if name is None: 274 | up_sampling_name, add_name = None, None 275 | else: 276 | up_sampling_name, add_name = name + '-UpSample', name + '-Add-H' 277 | outputs_high = layers.Add(name=add_name)([outputs_high, layers.UpSampling1D( 278 | size=octave, 279 | name=up_sampling_name, 280 | )(_init_conv(filters_high, 'LH')(inputs_low))]) 281 | 282 | outputs_low = None 283 | if filters_low > 0: 284 | if name is None: 285 | pooling_name, add_name = None, None 286 | else: 287 | pooling_name, add_name = name + '-Pool', name + '-Add-L' 288 | outputs_low = _init_conv(filters_low, 'HL')(layers.AveragePooling1D( 289 | pool_size=octave, 290 | padding='valid', 291 | name=pooling_name, 292 | )(inputs_high)) 293 | if inputs_low is not None: 294 | outputs_low = layers.Add(name=add_name)([_init_conv(filters_low, 'LL')(inputs_low), outputs_low]) 295 | 296 | if outputs_high is None: 297 | return outputs_low 298 | if outputs_low is None: 299 | return outputs_high 300 | return [outputs_high, outputs_low] 301 | -------------------------------------------------------------------------------- /keras_octave_conv/conv_2d.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import layers, activations, initializers, regularizers, constraints 2 | from tensorflow.keras import backend as K 3 | 4 | __all__ = ['OctaveConv2D', 'octave_conv_2d'] 5 | 6 | 7 | class OctaveConv2D(layers.Layer): 8 | """Octave convolutions. 9 | 10 | # Arguments 11 | octave: The division of the spatial dimensions by a power of 2. 12 | ratio_out: The ratio of filters for lower spatial resolution. 
13 | 14 | # References 15 | - [Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution] 16 | (https://arxiv.org/pdf/1904.05049.pdf) 17 | """ 18 | 19 | def __init__(self, 20 | filters, 21 | kernel_size, 22 | octave=2, 23 | ratio_out=0.5, 24 | strides=(1, 1), 25 | data_format=None, 26 | dilation_rate=(1, 1), 27 | activation=None, 28 | use_bias=True, 29 | kernel_initializer='glorot_uniform', 30 | bias_initializer='zeros', 31 | kernel_regularizer=None, 32 | bias_regularizer=None, 33 | activity_regularizer=None, 34 | kernel_constraint=None, 35 | bias_constraint=None, 36 | **kwargs): 37 | super(OctaveConv2D, self).__init__(**kwargs) 38 | self.filters = filters 39 | self.kernel_size = kernel_size 40 | self.octave = octave 41 | self.ratio_out = ratio_out 42 | self.strides = strides 43 | self.data_format = data_format 44 | self.dilation_rate = dilation_rate 45 | self.activation = activations.get(activation) 46 | self.use_bias = use_bias 47 | self.kernel_initializer = initializers.get(kernel_initializer) 48 | self.bias_initializer = initializers.get(bias_initializer) 49 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 50 | self.bias_regularizer = regularizers.get(bias_regularizer) 51 | self.activity_regularizer = regularizers.get(activity_regularizer) 52 | self.kernel_constraint = constraints.get(kernel_constraint) 53 | self.bias_constraint = constraints.get(bias_constraint) 54 | 55 | self.filters_low = int(filters * self.ratio_out) 56 | self.filters_high = filters - self.filters_low 57 | 58 | self.conv_high_to_high, self.conv_low_to_high = None, None 59 | if self.filters_high > 0: 60 | self.conv_high_to_high = self._init_conv(self.filters_high, name='{}-Conv2D-HH'.format(self.name)) 61 | self.conv_low_to_high = self._init_conv(self.filters_high, name='{}-Conv2D-LH'.format(self.name)) 62 | self.conv_low_to_low, self.conv_high_to_low = None, None 63 | if self.filters_low > 0: 64 | self.conv_low_to_low = 
self._init_conv(self.filters_low, name='{}-Conv2D-HL'.format(self.name)) 65 | self.conv_high_to_low = self._init_conv(self.filters_low, name='{}-Conv2D-LL'.format(self.name)) 66 | self.pooling = layers.AveragePooling2D( 67 | pool_size=self.octave, 68 | padding='valid', 69 | data_format=data_format, 70 | name='{}-AveragePooling2D'.format(self.name), 71 | ) 72 | self.up_sampling = layers.UpSampling2D( 73 | size=self.octave, 74 | data_format=data_format, 75 | interpolation='nearest', 76 | name='{}-UpSampling2D'.format(self.name), 77 | ) 78 | 79 | def _init_conv(self, filters, name): 80 | return layers.Conv2D( 81 | filters=filters, 82 | kernel_size=self.kernel_size, 83 | strides=self.strides, 84 | padding='same', 85 | data_format=self.data_format, 86 | dilation_rate=self.dilation_rate, 87 | activation=self.activation, 88 | use_bias=self.use_bias, 89 | kernel_initializer=self.kernel_initializer, 90 | bias_initializer=self.bias_initializer, 91 | kernel_regularizer=self.kernel_regularizer, 92 | bias_regularizer=self.bias_regularizer, 93 | activity_regularizer=self.activity_regularizer, 94 | kernel_constraint=self.kernel_constraint, 95 | bias_constraint=self.bias_constraint, 96 | name=name, 97 | ) 98 | 99 | def build(self, input_shape): 100 | if isinstance(input_shape, list): 101 | input_shape_high, input_shape_low = input_shape 102 | else: 103 | input_shape_high, input_shape_low = input_shape, None 104 | if self.data_format == 'channels_first': 105 | channel_axis, rows_axis, cols_axis = 1, 2, 3 106 | else: 107 | rows_axis, cols_axis, channel_axis = 1, 2, 3 108 | if input_shape_high[channel_axis] is None: 109 | raise ValueError('The channel dimension of the higher spatial inputs ' 110 | 'should be defined. Found `None`.') 111 | if input_shape_low is not None and input_shape_low[channel_axis] is None: 112 | raise ValueError('The channel dimension of the lower spatial inputs ' 113 | 'should be defined. 
Found `None`.') 114 | if input_shape_high[rows_axis] is not None and input_shape_high[rows_axis] % self.octave != 0 or \ 115 | input_shape_high[cols_axis] is not None and input_shape_high[cols_axis] % self.octave != 0: 116 | raise ValueError('The rows and columns of the higher spatial inputs should be divisible by the octave. ' 117 | 'Found {} and {}.'.format(input_shape_high, self.octave)) 118 | if input_shape_low is None: 119 | self.conv_low_to_high, self.conv_low_to_low = None, None 120 | 121 | if self.conv_high_to_high is not None: 122 | with K.name_scope(self.conv_high_to_high.name): 123 | self.conv_high_to_high.build(input_shape_high) 124 | if self.conv_low_to_high is not None: 125 | with K.name_scope(self.conv_low_to_high.name): 126 | self.conv_low_to_high.build(input_shape_low) 127 | if self.conv_high_to_low is not None: 128 | with K.name_scope(self.conv_high_to_low.name): 129 | self.conv_high_to_low.build(input_shape_high) 130 | if self.conv_low_to_low is not None: 131 | with K.name_scope(self.conv_low_to_low.name): 132 | self.conv_low_to_low.build(input_shape_low) 133 | super(OctaveConv2D, self).build(input_shape) 134 | 135 | @property 136 | def trainable_weights(self): 137 | weights = [] 138 | if self.conv_high_to_high is not None: 139 | weights += self.conv_high_to_high.trainable_weights 140 | if self.conv_low_to_high is not None: 141 | weights += self.conv_low_to_high.trainable_weights 142 | if self.conv_high_to_low is not None: 143 | weights += self.conv_high_to_low.trainable_weights 144 | if self.conv_low_to_low is not None: 145 | weights += self.conv_low_to_low.trainable_weights 146 | return weights 147 | 148 | @property 149 | def non_trainable_weights(self): 150 | weights = [] 151 | if self.conv_high_to_high is not None: 152 | weights += self.conv_high_to_high.non_trainable_weights 153 | if self.conv_low_to_high is not None: 154 | weights += self.conv_low_to_high.non_trainable_weights 155 | if self.conv_high_to_low is not None: 156 | weights += 
self.conv_high_to_low.non_trainable_weights 157 | if self.conv_low_to_low is not None: 158 | weights += self.conv_low_to_low.non_trainable_weights 159 | return weights 160 | 161 | def compute_output_shape(self, input_shape): 162 | if isinstance(input_shape, list): 163 | input_shape_high, input_shape_low = input_shape 164 | else: 165 | input_shape_high, input_shape_low = input_shape, None 166 | 167 | output_shape_high = None 168 | if self.filters_high > 0: 169 | output_shape_high = self.conv_high_to_high.compute_output_shape(input_shape_high) 170 | output_shape_low = None 171 | if self.filters_low > 0: 172 | output_shape_low = self.conv_high_to_low.compute_output_shape( 173 | self.pooling.compute_output_shape(input_shape_high), 174 | ) 175 | 176 | if self.filters_low == 0: 177 | return output_shape_high 178 | if self.filters_high == 0: 179 | return output_shape_low 180 | return [output_shape_high, output_shape_low] 181 | 182 | def call(self, inputs, **kwargs): 183 | if isinstance(inputs, list): 184 | inputs_high, inputs_low = inputs 185 | else: 186 | inputs_high, inputs_low = inputs, None 187 | 188 | outputs_high_to_high, outputs_low_to_high = 0.0, 0.0 189 | if self.conv_high_to_high is not None: 190 | outputs_high_to_high = self.conv_high_to_high(inputs_high) 191 | if self.conv_low_to_high is not None: 192 | outputs_low_to_high = self.up_sampling(self.conv_low_to_high(inputs_low)) 193 | outputs_high = outputs_high_to_high + outputs_low_to_high 194 | 195 | outputs_low_to_low, outputs_high_to_low = 0.0, 0.0 196 | if self.conv_low_to_low is not None: 197 | outputs_low_to_low = self.conv_low_to_low(inputs_low) 198 | if self.conv_high_to_low is not None: 199 | outputs_high_to_low = self.conv_high_to_low(self.pooling(inputs_high)) 200 | outputs_low = outputs_low_to_low + outputs_high_to_low 201 | 202 | if self.filters_low == 0: 203 | return outputs_high 204 | if self.filters_high == 0: 205 | return outputs_low 206 | return [outputs_high, outputs_low] 207 | 208 | def 
get_config(self): 209 | config = { 210 | 'filters': self.filters, 211 | 'kernel_size': self.kernel_size, 212 | 'octave': self.octave, 213 | 'ratio_out': self.ratio_out, 214 | 'strides': self.strides, 215 | 'data_format': self.data_format, 216 | 'dilation_rate': self.dilation_rate, 217 | 'activation': activations.serialize(self.activation), 218 | 'use_bias': self.use_bias, 219 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 220 | 'bias_initializer': initializers.serialize(self.bias_initializer), 221 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 222 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 223 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 224 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 225 | 'bias_constraint': constraints.serialize(self.bias_constraint) 226 | } 227 | base_config = super(OctaveConv2D, self).get_config() 228 | return dict(list(base_config.items()) + list(config.items())) 229 | 230 | 231 | def octave_conv_2d(inputs, 232 | filters, 233 | kernel_size, 234 | octave=2, 235 | ratio_out=0.5, 236 | strides=(1, 1), 237 | data_format=None, 238 | dilation_rate=(1, 1), 239 | activation=None, 240 | use_bias=True, 241 | kernel_initializer='glorot_uniform', 242 | bias_initializer='zeros', 243 | kernel_regularizer=None, 244 | bias_regularizer=None, 245 | activity_regularizer=None, 246 | kernel_constraint=None, 247 | bias_constraint=None, 248 | name=None, 249 | **kwargs): 250 | if isinstance(inputs, (list, tuple)): 251 | inputs_high, inputs_low = inputs 252 | else: 253 | inputs_high, inputs_low = inputs, None 254 | 255 | filters_low = int(filters * ratio_out) 256 | filters_high = filters - filters_low 257 | 258 | def _init_conv(conv_filters, conv_name_suffix): 259 | if name is None: 260 | conv_name = None 261 | else: 262 | conv_name = name + '-' + conv_name_suffix 263 | return layers.Conv2D( 264 | filters=conv_filters, 265 | 
kernel_size=kernel_size, 266 | strides=strides, 267 | padding='same', 268 | data_format=data_format, 269 | dilation_rate=dilation_rate, 270 | activation=activation, 271 | use_bias=use_bias, 272 | kernel_initializer=kernel_initializer, 273 | bias_initializer=bias_initializer, 274 | kernel_regularizer=kernel_regularizer, 275 | bias_regularizer=bias_regularizer, 276 | activity_regularizer=activity_regularizer, 277 | kernel_constraint=kernel_constraint, 278 | bias_constraint=bias_constraint, 279 | name=conv_name, 280 | **kwargs 281 | ) 282 | 283 | outputs_high = None 284 | if filters_high > 0: 285 | outputs_high = _init_conv(filters_high, 'HH')(inputs_high) 286 | if inputs_low is not None: 287 | if name is None: 288 | up_sampling_name, add_name = None, None 289 | else: 290 | up_sampling_name, add_name = name + '-UpSample', name + '-Add-H' 291 | outputs_high = layers.Add(name=add_name)([outputs_high, layers.UpSampling2D( 292 | size=octave, 293 | data_format=data_format, 294 | interpolation='nearest', 295 | name=up_sampling_name, 296 | )(_init_conv(filters_high, 'LH')(inputs_low))]) 297 | 298 | outputs_low = None 299 | if filters_low > 0: 300 | if name is None: 301 | pooling_name, add_name = None, None 302 | else: 303 | pooling_name, add_name = name + '-Pool', name + '-Add-L' 304 | outputs_low = _init_conv(filters_low, 'HL')(layers.AveragePooling2D( 305 | pool_size=octave, 306 | padding='valid', 307 | data_format=data_format, 308 | name=pooling_name, 309 | )(inputs_high)) 310 | if inputs_low is not None: 311 | outputs_low = layers.Add(name=add_name)([_init_conv(filters_low, 'LL')(inputs_low), outputs_low]) 312 | 313 | if outputs_high is None: 314 | return outputs_low 315 | if outputs_low is None: 316 | return outputs_high 317 | return [outputs_high, outputs_low] 318 | --------------------------------------------------------------------------------