├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── keras_layer_normalization ├── __init__.py └── layer_normalization.py ├── publish.sh ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── test.sh └── tests ├── __init__.py └── test_layer_normalization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # System thumbnail 107 | .DS_Store 108 | 109 | # IDE 110 | .idea 111 | 112 | # Temporary README 113 | README.rst 114 | 115 | # Images 116 | *.png 117 | 118 | # Models 119 | *.h5 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zhao HG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras Layer Normalization 2 | 3 | [![Version](https://img.shields.io/pypi/v/keras-layer-normalization.svg)](https://pypi.org/project/keras-layer-normalization/) 4 | ![License](https://img.shields.io/pypi/l/keras-layer-normalization.svg) 5 | 6 | Implementation of the paper: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 7 | 8 | ## Install 9 | 10 | ```bash 11 | pip install keras-layer-normalization 12 | ``` 13 | 14 | ## Usage 15 | 16 | ```python 17 | from tensorflow import keras 18 | from keras_layer_normalization import LayerNormalization 19 | 20 | 21 | input_layer = keras.layers.Input(shape=(2, 3)) 22 | norm_layer = LayerNormalization()(input_layer) 23 | model = keras.models.Model(inputs=input_layer, outputs=norm_layer) 24 | model.compile(optimizer='adam', loss='mse', metrics={},) 25 | model.summary() 26 | ``` 27 | -------------------------------------------------------------------------------- /keras_layer_normalization/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer_normalization import LayerNormalization 2 | 3 | __version__ = '0.16.0' 4 | -------------------------------------------------------------------------------- /keras_layer_normalization/layer_normalization.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | __all__ = ['LayerNormalization'] 5 | 6 | 7 | class LayerNormalization(keras.layers.Layer): 8 | 9 | def __init__(self, 10 | center=True, 11 | scale=True, 12 | epsilon=None, 13 | gamma_initializer='ones', 14 | beta_initializer='zeros', 15 | gamma_regularizer=None, 16 | beta_regularizer=None, 17 | gamma_constraint=None, 18 | beta_constraint=None, 19 | **kwargs): 20 | """Layer normalization layer 21 | 22 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 23 | 24 | :param center: Add an offset parameter if it is True. 25 | :param scale: Add a scale parameter if it is True. 26 | :param epsilon: Epsilon for calculating variance. 27 | :param gamma_initializer: Initializer for the gamma weight. 28 | :param beta_initializer: Initializer for the beta weight. 29 | :param gamma_regularizer: Optional regularizer for the gamma weight. 30 | :param beta_regularizer: Optional regularizer for the beta weight. 31 | :param gamma_constraint: Optional constraint for the gamma weight. 32 | :param beta_constraint: Optional constraint for the beta weight. 33 | :param kwargs: 34 | """ 35 | super(LayerNormalization, self).__init__(**kwargs) 36 | self.supports_masking = True 37 | self.center = center 38 | self.scale = scale 39 | if epsilon is None: 40 | epsilon = K.epsilon() * K.epsilon() 41 | self.epsilon = epsilon 42 | self.gamma_initializer = keras.initializers.get(gamma_initializer) 43 | self.beta_initializer = keras.initializers.get(beta_initializer) 44 | self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) 45 | self.beta_regularizer = keras.regularizers.get(beta_regularizer) 46 | self.gamma_constraint = keras.constraints.get(gamma_constraint) 47 | self.beta_constraint = keras.constraints.get(beta_constraint) 48 | self.gamma, self.beta = None, None 49 | 50 | def get_config(self): 51 | config = { 52 | 'center': self.center, 53 | 'scale': self.scale, 54 | 'epsilon': self.epsilon, 55 | 'gamma_initializer': keras.initializers.serialize(self.gamma_initializer), 56 | 'beta_initializer': keras.initializers.serialize(self.beta_initializer), 57 | 'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer), 58 | 'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer), 59 | 'gamma_constraint': keras.constraints.serialize(self.gamma_constraint), 60 | 'beta_constraint': keras.constraints.serialize(self.beta_constraint), 61 | } 62 | base_config = super(LayerNormalization, self).get_config() 63 | return dict(list(base_config.items()) + list(config.items())) 64 | 65 | def compute_output_shape(self, input_shape): 66 | return input_shape 67 | 68 | def compute_mask(self, inputs, input_mask=None): 69 | return input_mask 70 | 71 | def build(self, input_shape): 72 | shape = input_shape[-1:] 73 | if self.scale: 74 | self.gamma = self.add_weight( 75 | shape=shape, 76 | initializer=self.gamma_initializer, 77 | regularizer=self.gamma_regularizer, 78 | constraint=self.gamma_constraint, 79 | name='gamma', 80 | ) 81 | if self.center: 82 | self.beta = self.add_weight( 83 | shape=shape, 84 | initializer=self.beta_initializer, 85 | regularizer=self.beta_regularizer, 86 | constraint=self.beta_constraint, 87 | name='beta', 88 | ) 89 | super(LayerNormalization, self).build(input_shape) 90 | 91 | def call(self, inputs, training=None): 92 | mean = K.mean(inputs, axis=-1, keepdims=True) 93 | variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True) 94 | std = K.sqrt(variance + self.epsilon) 95 | outputs = (inputs - mean) / std 96 | if self.scale: 97 | outputs *= self.gamma 98 | if self.center: 99 | outputs += self.beta 100 | return outputs 101 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf dist/* && python3 setup.py sdist && twine upload dist/* 3 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools>=38.6.0 2 | twine>=1.11.0 3 | wheel>=0.31.0 4 | nose 5 | tensorflow 6 | pycodestyle 7 | coverage 8 | keras-multi-head==0.29.0 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | from setuptools import setup, find_packages 5 | 6 | current_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read_file(*parts): 10 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 11 | return reader.read() 12 | 13 | 14 | def get_requirements(*parts): 15 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 16 | return list(map(lambda x: x.strip(), reader.readlines())) 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read_file(*file_paths) 21 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError('Unable to find version string.') 25 | 26 | 27 | setup( 28 | name='keras-layer-normalization', 29 | version=find_version('keras_layer_normalization', '__init__.py'), 30 | packages=find_packages(), 31 | url='https://github.com/CyberZHG/keras-layer-normalization', 32 | license='MIT', 33 | author='CyberZHG', 34 | author_email='CyberZHG@users.noreply.github.com', 35 | description='Layer normalization implemented in Keras', 36 | long_description=read_file('README.md'), 37 | long_description_content_type='text/markdown', 38 | install_requires=get_requirements('requirements.txt'), 39 | classifiers=( 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: MIT License", 42 | "Operating System :: OS Independent", 43 | ), 44 | ) 45 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | pycodestyle --max-line-length=120 keras_layer_normalization tests && \ 3 | nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-package=keras_layer_normalization tests -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-layer-normalization/b57dd4d9b2161738f85b657c125b314554b7d614/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_layer_normalization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_multi_head import MultiHeadAttention 9 | from keras_layer_normalization import LayerNormalization 10 | 11 | 12 | class TestLayerNormalization(unittest.TestCase): 13 | 14 | def test_sample(self): 15 | input_layer = keras.layers.Input( 16 | shape=(2, 3), 17 | name='Input', 18 | ) 19 | norm_layer = LayerNormalization( 20 | name='Layer-Normalization', 21 | )(input_layer) 22 | model = keras.models.Model( 23 | inputs=input_layer, 24 | outputs=norm_layer, 25 | ) 26 | model.compile( 27 | optimizer='adam', 28 | loss='mse', 29 | metrics={}, 30 | ) 31 | model.summary() 32 | inputs = np.array([[ 33 | [0.2, 0.1, 0.3], 34 | [0.5, 0.1, 0.1], 35 | ]]) 36 | predict = model.predict(inputs) 37 | expected = np.asarray([[ 38 | [0.0, -1.22474487, 1.22474487], 39 | [1.41421356, -0.707106781, -0.707106781], 40 | ]]) 41 | self.assertTrue(np.allclose(expected, predict), predict) 42 | 43 | input_layer = keras.layers.Input( 44 | shape=(10, 256), 45 | name='Input', 46 | ) 47 | norm_layer = LayerNormalization( 48 | name='Layer-Normalization', 49 | beta_initializer='ones', 50 | )(input_layer) 51 | model = keras.models.Model( 52 | inputs=input_layer, 53 | outputs=norm_layer, 54 | ) 55 | model.compile( 56 | optimizer='adam', 57 | loss='mse', 58 | metrics={}, 59 | ) 60 | model.summary() 61 | inputs = np.zeros((2, 10, 256)) 62 | predict = model.predict(inputs) 63 | expected = np.ones((2, 10, 256)) 64 | self.assertTrue(np.allclose(expected, predict)) 65 | 66 | def test_fit(self): 67 | def _leaky_relu(x): 68 | return keras.activations.relu(x, alpha=0.01) 69 | 70 | input_layer = keras.layers.Input( 71 | shape=(2, 3), 72 | name='Input', 73 | ) 74 | norm_layer = LayerNormalization( 75 | name='Layer-Normalization-1', 76 | trainable=False, 77 | )(input_layer) 78 | att_layer = MultiHeadAttention( 79 | head_num=3, 80 | activation=_leaky_relu, 81 | name='Multi-Head-Attentions' 82 | )(norm_layer) 83 | dense_layer = keras.layers.Dense(units=3, name='Dense-1')(att_layer) 84 | norm_layer = LayerNormalization( 85 | name='Layer-Normalization-2', 86 | trainable=False, 87 | )(dense_layer) 88 | dense_layer = keras.layers.Dense(units=3, name='Dense-2')(norm_layer) 89 | model = keras.models.Model( 90 | inputs=input_layer, 91 | outputs=dense_layer, 92 | ) 93 | model.compile( 94 | optimizer=keras.optimizers.Adam(lr=1e-3), 95 | loss='mse', 96 | metrics={}, 97 | ) 98 | model.summary() 99 | 100 | def _generator(batch_size=32): 101 | while True: 102 | batch_inputs = np.random.random((batch_size, 2, 3)) 103 | batch_outputs = np.asarray([[[0.0, -0.1, 0.2]] * 2] * batch_size) 104 | yield batch_inputs, batch_outputs 105 | 106 | model.fit_generator( 107 | generator=_generator(), 108 | steps_per_epoch=100, 109 | epochs=100, 110 | validation_data=_generator(), 111 | validation_steps=100, 112 | callbacks=[ 113 | keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, min_delta=1e-4) 114 | ], 115 | ) 116 | model_path = os.path.join(tempfile.gettempdir(), 'test_layer_normalization_%f.h5' % np.random.random()) 117 | model.save(model_path) 118 | model = keras.models.load_model(model_path, custom_objects={ 119 | '_leaky_relu': _leaky_relu, 120 | 'MultiHeadAttention': MultiHeadAttention, 121 | 'LayerNormalization': LayerNormalization, 122 | }) 123 | for inputs, _ in _generator(batch_size=3): 124 | predicts = model.predict(inputs) 125 | expect = np.round(np.asarray([[[0.0, -0.1, 0.2]] * 2] * 3), decimals=1) 126 | actual = np.round(predicts, decimals=1) 127 | self.assertTrue(np.allclose(expect, actual), (expect, actual)) 128 | break 129 | 130 | def test_fit_zeros(self): 131 | def _leaky_relu(x): 132 | return keras.activations.relu(x, alpha=0.01) 133 | 134 | input_layer = keras.layers.Input( 135 | shape=(2, 3), 136 | name='Input', 137 | ) 138 | norm_layer = LayerNormalization( 139 | name='Layer-Normalization-1', 140 | trainable=False, 141 | )(input_layer) 142 | att_layer = MultiHeadAttention( 143 | head_num=3, 144 | activation=_leaky_relu, 145 | name='Multi-Head-Attentions' 146 | )(norm_layer) 147 | dense_layer = keras.layers.Dense(units=3, name='Dense-1')(att_layer) 148 | norm_layer = LayerNormalization( 149 | name='Layer-Normalization-2', 150 | trainable=False, 151 | )(dense_layer) 152 | dense_layer = keras.layers.Dense(units=3, name='Dense-2')(norm_layer) 153 | model = keras.models.Model( 154 | inputs=input_layer, 155 | outputs=dense_layer, 156 | ) 157 | model.compile( 158 | optimizer=keras.optimizers.Adam(lr=1e-3), 159 | loss='mse', 160 | metrics={}, 161 | ) 162 | model.summary() 163 | 164 | def _generator_zeros(batch_size=32): 165 | while True: 166 | batch_inputs = np.zeros((batch_size, 2, 3)) 167 | batch_outputs = np.asarray([[[0.0, -0.1, 0.2]] * 2] * batch_size) 168 | yield batch_inputs, batch_outputs 169 | 170 | model.fit_generator( 171 | generator=_generator_zeros(), 172 | steps_per_epoch=100, 173 | epochs=100, 174 | validation_data=_generator_zeros(), 175 | validation_steps=100, 176 | callbacks=[ 177 | keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, min_delta=1e-4) 178 | ], 179 | ) 180 | for inputs, _ in _generator_zeros(batch_size=3): 181 | predicts = model.predict(inputs) 182 | expect = np.round(np.asarray([[[0.0, -0.1, 0.2]] * 2] * 3), decimals=1) 183 | actual = np.round(predicts, decimals=1) 184 | self.assertTrue(np.allclose(expect, actual), (expect, actual)) 185 | break 186 | 187 | def test_save_load_json(self): 188 | model = keras.models.Sequential() 189 | model.add(LayerNormalization(input_shape=(2, 3))) 190 | model.compile(optimizer='adam', loss='mse') 191 | encoded = model.to_json() 192 | model = keras.models.model_from_json(encoded, custom_objects={'LayerNormalization': LayerNormalization}) 193 | model.summary() 194 | --------------------------------------------------------------------------------