├── .github
    └── stale.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── README.zh-CN.md
├── keras_self_attention
    ├── __init__.py
    ├── real_former.py
    ├── scaled_dot_attention.py
    ├── seq_self_attention.py
    └── seq_weighted_attention.py
├── publish.sh
├── requirements-dev.txt
├── requirements.txt
├── setup.py
├── test.sh
└── tests
    ├── __init__.py
    ├── scaled_dot_attention
        ├── __init__.py
        ├── test_history.py
        ├── test_real_former.py
        ├── test_sample.py
        └── test_save_load.py
    ├── seq_self_attention
        ├── __init__.py
        ├── test_activation.py
        ├── test_bias.py
        ├── test_history.py
        ├── test_local.py
        ├── test_loss.py
        ├── test_mask.py
        ├── test_mul.py
        ├── test_save_load.py
        └── util.py
    └── seq_weighted_attention
        ├── __init__.py
        └── test_save_load.py
/.github/stale.yml: -------------------------------------------------------------------------------- 1 | daysUntilStale: 5 2 | daysUntilClose: 2 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # System thumbnail 107 | .DS_Store 108 | 109 | # IDE 110 | .idea 111 | 112 | # Temporary README 113 | README.rst 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 PoW 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject
to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras Self-Attention 2 | 3 | [![Version](https://img.shields.io/pypi/v/keras-self-attention.svg)](https://pypi.org/project/keras-self-attention/) 4 | ![License](https://img.shields.io/pypi/l/keras-self-attention.svg) 5 | 6 | \[[中文](https://github.com/CyberZHG/keras-self-attention/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-self-attention/blob/master/README.md)\] 7 | 8 | Attention mechanism for processing sequential data that considers the context for each timestamp. 9 | 10 | * ![](https://user-images.githubusercontent.com/853842/44248592-1fbd0500-a21e-11e8-9fe0-52a1e4a48329.gif) 11 | * ![](https://user-images.githubusercontent.com/853842/44248591-1e8bd800-a21e-11e8-9ca8-9198c2725108.gif) 12 | * ![](https://user-images.githubusercontent.com/853842/44248590-1df34180-a21e-11e8-8ff1-268217f466ba.gif) 13 | * ![](https://user-images.githubusercontent.com/853842/44249018-8ba06d00-a220-11e8-80e3-802677b658ed.gif) 14 | 15 | ## Install 16 | 17 | ```bash 18 | pip install keras-self-attention 19 | ``` 20 | 21 | ## Usage 22 | 23 | ### Basic 24 | 25 | By default, the attention layer uses additive attention and considers the whole context while calculating the relevance. The following code creates an attention layer that follows the equations in the first section (`attention_activation` is the activation function of `e_{t, t'}`): 26 | 27 | ```python 28 | from tensorflow import keras 29 | from keras_self_attention import SeqSelfAttention 30 | 31 | 32 | model = keras.models.Sequential() 33 | model.add(keras.layers.Embedding(input_dim=10000, 34 | output_dim=300, 35 | mask_zero=True)) 36 | model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128, 37 | return_sequences=True))) 38 | model.add(SeqSelfAttention(attention_activation='sigmoid')) 39 | model.add(keras.layers.Dense(units=5)) 40 | model.compile( 41 | optimizer='adam', 42 | loss='categorical_crossentropy', 43 | metrics=['categorical_accuracy'], 44 | ) 45 | model.summary() 46 | ``` 47 | 48 | ### Local Attention 49 | 50 | The global context may be too broad for one piece of data. 
The parameter `attention_width` controls the width of the local context: 51 | 52 | ```python 53 | from keras_self_attention import SeqSelfAttention 54 | 55 | SeqSelfAttention( 56 | attention_width=15, 57 | attention_activation='sigmoid', 58 | name='Attention', 59 | ) 60 | ``` 61 | 62 | ### Multiplicative Attention 63 | 64 | You can use multiplicative attention by setting `attention_type`: 65 | 66 | ![](https://user-images.githubusercontent.com/853842/44253887-a03a3080-a233-11e8-9d49-3fd7e622a0f7.gif) 67 | 68 | ```python 69 | from keras_self_attention import SeqSelfAttention 70 | 71 | SeqSelfAttention( 72 | attention_width=15, 73 | attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 74 | attention_activation=None, 75 | kernel_regularizer=keras.regularizers.l2(1e-6), 76 | use_attention_bias=False, 77 | name='Attention', 78 | ) 79 | ``` 80 | 81 | ### Regularizer 82 | 83 | ![](https://user-images.githubusercontent.com/853842/44250188-f99b6300-a225-11e8-8fab-8dcf0d99616e.gif) 84 | 85 | To use the regularizer, set `attention_regularizer_weight` to a positive number: 86 | 87 | ```python 88 | from tensorflow import keras 89 | from keras_self_attention import SeqSelfAttention 90 | 91 | inputs = keras.layers.Input(shape=(None,)) 92 | embd = keras.layers.Embedding(input_dim=32, 93 | output_dim=16, 94 | mask_zero=True)(inputs) 95 | lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16, 96 | return_sequences=True))(embd) 97 | att = SeqSelfAttention(attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 98 | kernel_regularizer=keras.regularizers.l2(1e-4), 99 | bias_regularizer=keras.regularizers.l1(1e-4), 100 | attention_regularizer_weight=1e-4, 101 | name='Attention')(lstm) 102 | dense = keras.layers.Dense(units=5, name='Dense')(att) 103 | model = keras.models.Model(inputs=inputs, outputs=[dense]) 104 | model.compile( 105 | optimizer='adam', 106 | loss={'Dense': 'sparse_categorical_crossentropy'}, 107 | metrics={'Dense': 'categorical_accuracy'}, 108 | ) 109 | model.summary(line_length=100) 110 | ``` 111 | 112 | ### Load the Model 113 | 114 | Make sure to add `SeqSelfAttention` to custom objects: 115 | 116 | ```python 117 | from tensorflow import keras 118 | 119 | keras.models.load_model(model_path, custom_objects=SeqSelfAttention.get_custom_objects()) 120 | ``` 121 | 122 | ### History Only 123 | 124 | Set `history_only` to `True` when only historical data could be used: 125 | 126 | ```python 127 | SeqSelfAttention( 128 | attention_width=3, 129 | history_only=True, 130 | name='Attention', 131 | ) 132 | ``` 133 | 134 | ### Multi-Head 135 | 136 | Please refer to [keras-multi-head](https://github.com/CyberZHG/keras-multi-head). 137 | -------------------------------------------------------------------------------- /README.zh-CN.md: -------------------------------------------------------------------------------- 1 | # Keras自注意力 2 | 3 | [![Version](https://img.shields.io/pypi/v/keras-self-attention.svg)](https://pypi.org/project/keras-self-attention/) 4 | ![License](https://img.shields.io/pypi/l/keras-self-attention.svg) 5 | 6 | \[[中文](https://github.com/CyberZHG/keras-self-attention/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-self-attention/blob/master/README.md)\] 7 | 8 | Attention mechanism for processing sequential data that considers the context for each timestamp. 
9 | 10 | * ![](https://user-images.githubusercontent.com/853842/44248592-1fbd0500-a21e-11e8-9fe0-52a1e4a48329.gif) 11 | * ![](https://user-images.githubusercontent.com/853842/44248591-1e8bd800-a21e-11e8-9ca8-9198c2725108.gif) 12 | * ![](https://user-images.githubusercontent.com/853842/44248590-1df34180-a21e-11e8-8ff1-268217f466ba.gif) 13 | * ![](https://user-images.githubusercontent.com/853842/44249018-8ba06d00-a220-11e8-80e3-802677b658ed.gif) 14 | 15 | ## 安装 16 | 17 | ```bash 18 | pip install keras-self-attention 19 | ``` 20 | 21 | ## 使用 22 | 23 | ### 基本 24 | 25 | 默认情况下,注意力层使用加性注意力机制,并使用全部上下文进行计算。下面的代码根据页首的公式创建了一个注意力层(`attention_activation`是注意力权重`e_{t, t'}`): 26 | 27 | ```python 28 | from tensorflow import keras 29 | from keras_self_attention import SeqSelfAttention 30 | 31 | 32 | model = keras.models.Sequential() 33 | model.add(keras.layers.Embedding(input_dim=10000, 34 | output_dim=300, 35 | mask_zero=True)) 36 | model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128, 37 | return_sequences=True))) 38 | model.add(SeqSelfAttention(attention_activation='sigmoid')) 39 | model.add(keras.layers.Dense(units=5)) 40 | model.compile( 41 | optimizer='adam', 42 | loss='categorical_crossentropy', 43 | metrics=['categorical_accuracy'], 44 | ) 45 | model.summary() 46 | ``` 47 | 48 | ### 局部注意力 49 | 50 | 参数`attention_width`控制着局部注意力的宽度: 51 | 52 | ```python 53 | from keras_self_attention import SeqSelfAttention 54 | 55 | SeqSelfAttention( 56 | attention_width=15, 57 | attention_activation='sigmoid', 58 | name='Attention', 59 | ) 60 | ``` 61 | 62 | ### 乘性注意力 63 | 64 | 用`attention_type`来改变注意力机制的计算方法: 65 | 66 | ![](https://user-images.githubusercontent.com/853842/44253887-a03a3080-a233-11e8-9d49-3fd7e622a0f7.gif) 67 | 68 | ```python 69 | from keras_self_attention import SeqSelfAttention 70 | 71 | SeqSelfAttention( 72 | attention_width=15, 73 | attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 74 | attention_activation=None, 75 | kernel_regularizer=keras.regularizers.l2(1e-6), 76 | use_attention_bias=False, 77 | name='Attention', 78 | ) 79 | ``` 80 | 81 | ### 正则化 82 | 83 | ![](https://user-images.githubusercontent.com/853842/44250188-f99b6300-a225-11e8-8fab-8dcf0d99616e.gif) 84 | 85 | 通过将`attention_regularizer_weight`设置为一个正数来使用正则化: 86 | 87 | ```python 88 | from tensorflow import keras 89 | from keras_self_attention import SeqSelfAttention 90 | 91 | inputs = keras.layers.Input(shape=(None,)) 92 | embd = keras.layers.Embedding(input_dim=32, 93 | output_dim=16, 94 | mask_zero=True)(inputs) 95 | lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16, 96 | return_sequences=True))(embd) 97 | att = SeqSelfAttention(attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 98 | kernel_regularizer=keras.regularizers.l2(1e-4), 99 | bias_regularizer=keras.regularizers.l1(1e-4), 100 | attention_regularizer_weight=1e-4, 101 | name='Attention')(lstm) 102 | dense = keras.layers.Dense(units=5, name='Dense')(att) 103 | model = keras.models.Model(inputs=inputs, outputs=[dense]) 104 | model.compile( 105 | optimizer='adam', 106 | loss={'Dense': 'sparse_categorical_crossentropy'}, 107 | metrics={'Dense': 'categorical_accuracy'}, 108 | ) 109 | model.summary(line_length=100) 110 | ``` 111 | 112 | ### 加载模型 113 | 114 | Make sure to add `SeqSelfAttention` to custom objects: 115 | 116 | ```python 117 | from tensorflow import keras 118 | 119 | keras.models.load_model(model_path, custom_objects=SeqSelfAttention.get_custom_objects()) 120 | ``` 121 | 122 | ### 只使用历史进行计算 123 | 124 | 对于decoder等场景,为了保持输出固定只能使用上文的信息: 125 | 126 | 
```python 127 | SeqSelfAttention( 128 | attention_width=3, 129 | history_only=True, 130 | name='Attention', 131 | ) 132 | ``` 133 | 134 | ### 多头注意力 135 | 136 | 参考[keras-multi-head](https://github.com/CyberZHG/keras-multi-head)。 137 | -------------------------------------------------------------------------------- /keras_self_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .seq_self_attention import SeqSelfAttention 2 | from .seq_weighted_attention import SeqWeightedAttention 3 | from .scaled_dot_attention import ScaledDotProductAttention 4 | from .real_former import ResidualScaledDotProductAttention 5 | 6 | __version__ = '0.51.0' 7 | -------------------------------------------------------------------------------- /keras_self_attention/real_former.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class ResidualScaledDotProductAttention(keras.layers.Layer): 6 | r"""The attention layer that takes three inputs representing queries, keys and values. 7 | 8 | \text{Attention}(Q, K, V, Prev) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}} + Prev) V 9 | 10 | See: https://arxiv.org/pdf/1706.03762.pdf 11 | """ 12 | 13 | def __init__(self, 14 | return_attention=False, 15 | history_only=False, 16 | **kwargs): 17 | """Initialize the layer. 18 | 19 | :param return_attention: Whether to return attention weights. 20 | :param history_only: Whether to only use history data. 21 | :param kwargs: Arguments for parent class. 22 | """ 23 | super().__init__(**kwargs) 24 | self.supports_masking = True 25 | self.return_attention = return_attention 26 | self.history_only = history_only 27 | self.intensity = self.attention = None 28 | 29 | def get_config(self): 30 | config = { 31 | 'return_attention': self.return_attention, 32 | 'history_only': self.history_only, 33 | } 34 | base_config = super(ResidualScaledDotProductAttention, self).get_config() 35 | return dict(list(base_config.items()) + list(config.items())) 36 | 37 | def compute_mask(self, inputs, mask=None): 38 | mask = mask[0] 39 | if self.return_attention: 40 | mask = [mask, mask[-1], None] 41 | return [mask, mask[-1]] 42 | 43 | def call(self, inputs, mask=None, **kwargs): 44 | if len(inputs) == 4: 45 | query, key, value, prev = inputs 46 | mask = mask[1] 47 | else: 48 | query = key = value = inputs[0] 49 | prev = inputs[1] 50 | mask = mask[0] 51 | feature_dim = K.shape(query)[-1] 52 | e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx())) 53 | new_prev = e = e + prev 54 | if self.history_only: 55 | query_len, key_len = K.shape(query)[1], K.shape(key)[1] 56 | indices = K.expand_dims(K.arange(0, key_len), axis=0) 57 | upper = K.expand_dims(K.arange(0, query_len), axis=-1) 58 | e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0) 59 | if mask is not None: 60 | e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx())) 61 | self.intensity = e 62 | e = K.exp(e - K.max(e, axis=-1, keepdims=True)) 63 | self.attention = e / K.sum(e, axis=-1, keepdims=True) 64 | v = K.batch_dot(self.attention, value) 65 | output = [v, new_prev] 66 | if self.return_attention: 67 | output.append(self.attention) 68 | return output 69 | -------------------------------------------------------------------------------- /keras_self_attention/scaled_dot_attention.py: -------------------------------------------------------------------------------- 1 
| from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class ScaledDotProductAttention(keras.layers.Layer): 6 | r"""The attention layer that takes three inputs representing queries, keys and values. 7 | 8 | \text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}}) V 9 | 10 | See: https://arxiv.org/pdf/1706.03762.pdf 11 | """ 12 | 13 | def __init__(self, 14 | return_attention=False, 15 | history_only=False, 16 | **kwargs): 17 | """Initialize the layer. 18 | 19 | :param return_attention: Whether to return attention weights. 20 | :param history_only: Whether to only use history data. 21 | :param kwargs: Arguments for parent class. 22 | """ 23 | super(ScaledDotProductAttention, self).__init__(**kwargs) 24 | self.supports_masking = True 25 | self.return_attention = return_attention 26 | self.history_only = history_only 27 | self.intensity = self.attention = None 28 | 29 | def get_config(self): 30 | config = { 31 | 'return_attention': self.return_attention, 32 | 'history_only': self.history_only, 33 | } 34 | base_config = super(ScaledDotProductAttention, self).get_config() 35 | return dict(list(base_config.items()) + list(config.items())) 36 | 37 | def compute_output_shape(self, input_shape): 38 | if isinstance(input_shape, list): 39 | query_shape, key_shape, value_shape = input_shape 40 | else: 41 | query_shape = key_shape = value_shape = input_shape 42 | output_shape = query_shape[:-1] + value_shape[-1:] 43 | if self.return_attention: 44 | attention_shape = query_shape[:2] + (key_shape[1],) 45 | return [output_shape, attention_shape] 46 | return output_shape 47 | 48 | def compute_mask(self, inputs, mask=None): 49 | if isinstance(mask, list): 50 | mask = mask[0] 51 | if self.return_attention: 52 | return [mask, None] 53 | return mask 54 | 55 | def call(self, inputs, mask=None, **kwargs): 56 | if isinstance(inputs, list): 57 | query, key, value = inputs 58 | else: 59 | query = key = value = inputs 60 | if isinstance(mask, list): 61 | mask = mask[1] 62 | feature_dim = K.shape(query)[-1] 63 | e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx())) 64 | if self.history_only: 65 | query_len, key_len = K.shape(query)[1], K.shape(key)[1] 66 | indices = K.expand_dims(K.arange(0, key_len), axis=0) 67 | upper = K.expand_dims(K.arange(0, query_len), axis=-1) 68 | e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0) 69 | if mask is not None: 70 | e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx())) 71 | self.intensity = e 72 | e = K.exp(e - K.max(e, axis=-1, keepdims=True)) 73 | self.attention = e / K.sum(e, axis=-1, keepdims=True) 74 | v = K.batch_dot(self.attention, value) 75 | if self.return_attention: 76 | return [v, self.attention] 77 | return v 78 | -------------------------------------------------------------------------------- /keras_self_attention/seq_self_attention.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class SeqSelfAttention(keras.layers.Layer): 6 | 7 | ATTENTION_TYPE_ADD = 'additive' 8 | ATTENTION_TYPE_MUL = 'multiplicative' 9 | 10 | def __init__(self, 11 | units=32, 12 | attention_width=None, 13 | attention_type=ATTENTION_TYPE_ADD, 14 | return_attention=False, 15 | history_only=False, 16 | kernel_initializer='glorot_normal', 17 | bias_initializer='zeros', 18 | kernel_regularizer=None, 19 | bias_regularizer=None, 20 | kernel_constraint=None, 21 | 
bias_constraint=None, 22 | use_additive_bias=True, 23 | use_attention_bias=True, 24 | attention_activation=None, 25 | attention_regularizer_weight=0.0, 26 | **kwargs): 27 | """Layer initialization. 28 | 29 | For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf 30 | 31 | :param units: The dimension of the vectors that used to calculate the attention weights. 32 | :param attention_width: The width of local attention. 33 | :param attention_type: 'additive' or 'multiplicative'. 34 | :param return_attention: Whether to return the attention weights for visualization. 35 | :param history_only: Only use historical pieces of data. 36 | :param kernel_initializer: The initializer for weight matrices. 37 | :param bias_initializer: The initializer for biases. 38 | :param kernel_regularizer: The regularization for weight matrices. 39 | :param bias_regularizer: The regularization for biases. 40 | :param kernel_constraint: The constraint for weight matrices. 41 | :param bias_constraint: The constraint for biases. 42 | :param use_additive_bias: Whether to use bias while calculating the relevance of inputs features 43 | in additive mode. 44 | :param use_attention_bias: Whether to use bias while calculating the weights of attention. 45 | :param attention_activation: The activation used for calculating the weights of attention. 46 | :param attention_regularizer_weight: The weights of attention regularizer. 47 | :param kwargs: Parameters for parent class. 48 | """ 49 | super(SeqSelfAttention, self).__init__(**kwargs) 50 | self.supports_masking = True 51 | self.units = units 52 | self.attention_width = attention_width 53 | self.attention_type = attention_type 54 | self.return_attention = return_attention 55 | self.history_only = history_only 56 | if history_only and attention_width is None: 57 | self.attention_width = int(1e9) 58 | 59 | self.use_additive_bias = use_additive_bias 60 | self.use_attention_bias = use_attention_bias 61 | self.kernel_initializer = keras.initializers.get(kernel_initializer) 62 | self.bias_initializer = keras.initializers.get(bias_initializer) 63 | self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) 64 | self.bias_regularizer = keras.regularizers.get(bias_regularizer) 65 | self.kernel_constraint = keras.constraints.get(kernel_constraint) 66 | self.bias_constraint = keras.constraints.get(bias_constraint) 67 | self.attention_activation = keras.activations.get(attention_activation) 68 | self.attention_regularizer_weight = attention_regularizer_weight 69 | self._backend = keras.backend.backend() 70 | 71 | if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD: 72 | self.Wx, self.Wt, self.bh = None, None, None 73 | self.Wa, self.ba = None, None 74 | elif attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL: 75 | self.Wa, self.ba = None, None 76 | else: 77 | raise NotImplementedError('No implementation for attention type : ' + attention_type) 78 | 79 | def get_config(self): 80 | config = { 81 | 'units': self.units, 82 | 'attention_width': self.attention_width, 83 | 'attention_type': self.attention_type, 84 | 'return_attention': self.return_attention, 85 | 'history_only': self.history_only, 86 | 'use_additive_bias': self.use_additive_bias, 87 | 'use_attention_bias': self.use_attention_bias, 88 | 'kernel_initializer': keras.initializers.serialize(self.kernel_initializer), 89 | 'bias_initializer': keras.initializers.serialize(self.bias_initializer), 90 | 'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer), 91 | 'bias_regularizer': 
keras.regularizers.serialize(self.bias_regularizer), 92 | 'kernel_constraint': keras.constraints.serialize(self.kernel_constraint), 93 | 'bias_constraint': keras.constraints.serialize(self.bias_constraint), 94 | 'attention_activation': keras.activations.serialize(self.attention_activation), 95 | 'attention_regularizer_weight': self.attention_regularizer_weight, 96 | } 97 | base_config = super(SeqSelfAttention, self).get_config() 98 | return dict(list(base_config.items()) + list(config.items())) 99 | 100 | def build(self, input_shape): 101 | if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD: 102 | self._build_additive_attention(input_shape) 103 | elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL: 104 | self._build_multiplicative_attention(input_shape) 105 | super(SeqSelfAttention, self).build(input_shape) 106 | 107 | def _build_additive_attention(self, input_shape): 108 | feature_dim = int(input_shape[2]) 109 | 110 | self.Wt = self.add_weight(shape=(feature_dim, self.units), 111 | name='{}_Add_Wt'.format(self.name), 112 | initializer=self.kernel_initializer, 113 | regularizer=self.kernel_regularizer, 114 | constraint=self.kernel_constraint) 115 | self.Wx = self.add_weight(shape=(feature_dim, self.units), 116 | name='{}_Add_Wx'.format(self.name), 117 | initializer=self.kernel_initializer, 118 | regularizer=self.kernel_regularizer, 119 | constraint=self.kernel_constraint) 120 | if self.use_additive_bias: 121 | self.bh = self.add_weight(shape=(self.units,), 122 | name='{}_Add_bh'.format(self.name), 123 | initializer=self.bias_initializer, 124 | regularizer=self.bias_regularizer, 125 | constraint=self.bias_constraint) 126 | 127 | self.Wa = self.add_weight(shape=(self.units, 1), 128 | name='{}_Add_Wa'.format(self.name), 129 | initializer=self.kernel_initializer, 130 | regularizer=self.kernel_regularizer, 131 | constraint=self.kernel_constraint) 132 | if self.use_attention_bias: 133 | self.ba = self.add_weight(shape=(1,), 134 | name='{}_Add_ba'.format(self.name), 135 | initializer=self.bias_initializer, 136 | regularizer=self.bias_regularizer, 137 | constraint=self.bias_constraint) 138 | 139 | def _build_multiplicative_attention(self, input_shape): 140 | feature_dim = int(input_shape[2]) 141 | 142 | self.Wa = self.add_weight(shape=(feature_dim, feature_dim), 143 | name='{}_Mul_Wa'.format(self.name), 144 | initializer=self.kernel_initializer, 145 | regularizer=self.kernel_regularizer, 146 | constraint=self.kernel_constraint) 147 | if self.use_attention_bias: 148 | self.ba = self.add_weight(shape=(1,), 149 | name='{}_Mul_ba'.format(self.name), 150 | initializer=self.bias_initializer, 151 | regularizer=self.bias_regularizer, 152 | constraint=self.bias_constraint) 153 | 154 | def call(self, inputs, mask=None, **kwargs): 155 | input_len = K.shape(inputs)[1] 156 | 157 | if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD: 158 | e = self._call_additive_emission(inputs) 159 | elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL: 160 | e = self._call_multiplicative_emission(inputs) 161 | 162 | if self.attention_activation is not None: 163 | e = self.attention_activation(e) 164 | if self.attention_width is not None: 165 | if self.history_only: 166 | lower = K.arange(0, input_len) - (self.attention_width - 1) 167 | else: 168 | lower = K.arange(0, input_len) - self.attention_width // 2 169 | lower = K.expand_dims(lower, axis=-1) 170 | upper = lower + self.attention_width 171 | indices = K.expand_dims(K.arange(0, input_len), axis=0) 172 | e -= 10000.0 * (1.0 - 
K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx())) 173 | if mask is not None: 174 | mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1) 175 | e -= 10000.0 * ((1.0 - mask) * (1.0 - K.permute_dimensions(mask, (0, 2, 1)))) 176 | 177 | # a_{t} = \text{softmax}(e_t) 178 | e = K.exp(e - K.max(e, axis=-1, keepdims=True)) 179 | a = e / K.sum(e, axis=-1, keepdims=True) 180 | 181 | # l_t = \sum_{t'} a_{t, t'} x_{t'} 182 | v = K.batch_dot(a, inputs) 183 | if self.attention_regularizer_weight > 0.0: 184 | self.add_loss(self._attention_regularizer(a)) 185 | 186 | if self.return_attention: 187 | return [v, a] 188 | return v 189 | 190 | def _call_additive_emission(self, inputs): 191 | input_shape = K.shape(inputs) 192 | batch_size, input_len = input_shape[0], input_shape[1] 193 | 194 | # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h) 195 | q = K.expand_dims(K.dot(inputs, self.Wt), 2) 196 | k = K.expand_dims(K.dot(inputs, self.Wx), 1) 197 | if self.use_additive_bias: 198 | h = K.tanh(q + k + self.bh) 199 | else: 200 | h = K.tanh(q + k) 201 | 202 | # e_{t, t'} = W_a h_{t, t'} + b_a 203 | if self.use_attention_bias: 204 | e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len)) 205 | else: 206 | e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len)) 207 | return e 208 | 209 | def _call_multiplicative_emission(self, inputs): 210 | # e_{t, t'} = x_t^T W_a x_{t'} + b_a 211 | e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1))) 212 | if self.use_attention_bias: 213 | e += self.ba[0] 214 | return e 215 | 216 | def compute_output_shape(self, input_shape): 217 | output_shape = input_shape 218 | if self.return_attention: 219 | attention_shape = (input_shape[0], output_shape[1], input_shape[1]) 220 | return [output_shape, attention_shape] 221 | return output_shape 222 | 223 | def compute_mask(self, inputs, mask=None): 224 | if self.return_attention: 225 | return [mask, None] 226 | return mask 227 | 228 | def _attention_regularizer(self, attention): 229 | batch_size = K.cast(K.shape(attention)[0], K.floatx()) 230 | input_len = K.shape(attention)[-1] 231 | indices = K.expand_dims(K.arange(0, input_len), axis=0) 232 | diagonal = K.expand_dims(K.arange(0, input_len), axis=-1) 233 | eye = K.cast(K.equal(indices, diagonal), K.floatx()) 234 | return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot( 235 | attention, 236 | K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size 237 | 238 | @staticmethod 239 | def get_custom_objects(): 240 | return {'SeqSelfAttention': SeqSelfAttention} 241 | -------------------------------------------------------------------------------- /keras_self_attention/seq_weighted_attention.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class SeqWeightedAttention(keras.layers.Layer): 6 | r"""Y = \text{softmax}(XW + b) X 7 | 8 | See: https://arxiv.org/pdf/1708.00524.pdf 9 | """ 10 | 11 | def __init__(self, use_bias=True, return_attention=False, **kwargs): 12 | super(SeqWeightedAttention, self).__init__(**kwargs) 13 | self.supports_masking = True 14 | self.use_bias = use_bias 15 | self.return_attention = return_attention 16 | self.W, self.b = None, None 17 | 18 | def get_config(self): 19 | config = { 20 | 'use_bias': self.use_bias, 21 | 'return_attention': self.return_attention, 22 | } 23 | base_config = super(SeqWeightedAttention, self).get_config() 24 | 
return dict(list(base_config.items()) + list(config.items())) 25 | 26 | def build(self, input_shape): 27 | self.W = self.add_weight(shape=(int(input_shape[2]), 1), 28 | name='{}_W'.format(self.name), 29 | initializer=keras.initializers.get('uniform')) 30 | if self.use_bias: 31 | self.b = self.add_weight(shape=(1,), 32 | name='{}_b'.format(self.name), 33 | initializer=keras.initializers.get('zeros')) 34 | super(SeqWeightedAttention, self).build(input_shape) 35 | 36 | def call(self, x, mask=None): 37 | logits = K.dot(x, self.W) 38 | if self.use_bias: 39 | logits += self.b 40 | x_shape = K.shape(x) 41 | logits = K.reshape(logits, (x_shape[0], x_shape[1])) 42 | if mask is not None: 43 | mask = K.cast(mask, K.floatx()) 44 | logits -= 10000.0 * (1.0 - mask) 45 | ai = K.exp(logits - K.max(logits, axis=-1, keepdims=True)) 46 | att_weights = ai / (K.sum(ai, axis=1, keepdims=True) + K.epsilon()) 47 | weighted_input = x * K.expand_dims(att_weights) 48 | result = K.sum(weighted_input, axis=1) 49 | if self.return_attention: 50 | return [result, att_weights] 51 | return result 52 | 53 | def compute_output_shape(self, input_shape): 54 | output_len = input_shape[2] 55 | if self.return_attention: 56 | return [(input_shape[0], output_len), (input_shape[0], input_shape[1])] 57 | return input_shape[0], output_len 58 | 59 | def compute_mask(self, _, input_mask=None): 60 | if self.return_attention: 61 | return [None, None] 62 | return None 63 | 64 | @staticmethod 65 | def get_custom_objects(): 66 | return {'SeqWeightedAttention': SeqWeightedAttention} 67 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -f dist/* && python3 setup.py sdist && twine upload dist/* 3 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools>=38.6.0 2 | twine>=1.11.0 3 | wheel>=0.31.0 4 | nose 5 | tensorflow 6 | pycodestyle 7 | coverage 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | from setuptools import setup, find_packages 5 | 6 | current_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read_file(*parts): 10 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 11 | return reader.read() 12 | 13 | 14 | def get_requirements(*parts): 15 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 16 | return list(map(lambda x: x.strip(), reader.readlines())) 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read_file(*file_paths) 21 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError('Unable to find version string.') 25 | 26 | 27 | setup( 28 | name='keras-self-attention', 29 | version=find_version('keras_self_attention', '__init__.py'), 30 | packages=find_packages(), 31 | url='https://github.com/CyberZHG/keras-self-attention', 32 | license='MIT', 33 | author='CyberZHG', 34 | 
author_email='CyberZHG@users.noreply.github.com', 35 | description='Attention mechanism for processing sequential data that considers the context for each timestamp', 36 | long_description=read_file('README.md'), 37 | long_description_content_type='text/markdown', 38 | install_requires=get_requirements('requirements.txt'), 39 | classifiers=( 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: MIT License", 42 | "Operating System :: OS Independent", 43 | ), 44 | ) 45 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | pycodestyle --max-line-length=120 keras_self_attention tests && 3 | nosetests --nocapture --with-coverage --cover-erase --cover-html --cover-html-dir=htmlcov --cover-package=keras_self_attention tests 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-self-attention/f3bf21dbb1f3251b5417a8bb254dd91807b1aec5/tests/__init__.py -------------------------------------------------------------------------------- /tests/scaled_dot_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-self-attention/f3bf21dbb1f3251b5417a8bb254dd91807b1aec5/tests/scaled_dot_attention/__init__.py -------------------------------------------------------------------------------- /tests/scaled_dot_attention/test_history.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | from keras_self_attention import ScaledDotProductAttention 7 | 8 | 9 | class TestHistory(unittest.TestCase): 10 | 11 | def test_history(self): 12 | input_layer = keras.layers.Input( 13 | shape=(5,), 14 | name='Input', 15 | ) 16 | embed_layer = keras.layers.Embedding( 17 | input_dim=4, 18 | output_dim=5, 19 | mask_zero=True, 20 | weights=[ 21 | np.asarray([ 22 | [0.1, 0.2, 0.3, 0.4, 0.5], 23 | [0.2, 0.3, 0.4, 0.6, 0.5], 24 | [0.4, 0.7, 0.2, 0.6, 0.9], 25 | [0.3, 0.5, 0.8, 0.9, 0.1], 26 | ]), 27 | ], 28 | name='Embedding', 29 | )(input_layer) 30 | att_layer, att_weights = ScaledDotProductAttention( 31 | history_only=True, 32 | return_attention=True, 33 | name='Attention', 34 | )([embed_layer, embed_layer, embed_layer]) 35 | model = keras.models.Model(inputs=input_layer, outputs=[att_layer, att_weights]) 36 | model.compile(optimizer='adam', loss='mse') 37 | model.summary() 38 | inputs = np.array([[1, 2, 3, 1, 0]]) 39 | predicts = model.predict(inputs) 40 | results, weights = predicts[0][0], predicts[1][0] 41 | self.assertFalse(np.allclose(results[0], results[3])) 42 | self.assertTrue(np.allclose( 43 | np.asarray([0.2, 0.3, 0.4, 0.6, 0.5]), 44 | results[0], 45 | ), results[0]) 46 | for i in range(4): 47 | for j in range(5): 48 | if j > i: 49 | self.assertEqual(0.0, weights[i][j]) 50 | else: 51 | self.assertLess(0.0, weights[i][j]) 52 | -------------------------------------------------------------------------------- /tests/scaled_dot_attention/test_real_former.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from 
keras_self_attention import ResidualScaledDotProductAttention 9 | 10 | 11 | class TestResidualScaledDotProductAttention(unittest.TestCase): 12 | 13 | def test_history(self): 14 | input_layer = keras.layers.Input( 15 | shape=(5,), 16 | name='Input', 17 | ) 18 | prev_layer = keras.layers.Input( 19 | shape=(5, 5), 20 | name='Prev', 21 | ) 22 | embed_layer = keras.layers.Embedding( 23 | input_dim=4, 24 | output_dim=5, 25 | mask_zero=True, 26 | weights=[ 27 | np.asarray([ 28 | [0.1, 0.2, 0.3, 0.4, 0.5], 29 | [0.2, 0.3, 0.4, 0.6, 0.5], 30 | [0.4, 0.7, 0.2, 0.6, 0.9], 31 | [0.3, 0.5, 0.8, 0.9, 0.1], 32 | ]), 33 | ], 34 | name='Embedding', 35 | )(input_layer) 36 | att_layer, _, att_weights = ResidualScaledDotProductAttention( 37 | history_only=True, 38 | return_attention=True, 39 | name='Attention', 40 | )([embed_layer, embed_layer, embed_layer, prev_layer]) 41 | model = keras.models.Model(inputs=[input_layer, prev_layer], outputs=[att_layer, att_weights]) 42 | model.compile(optimizer='adam', loss='mse') 43 | model_path = os.path.join(tempfile.gettempdir(), 'keras_self_att_test_sl_%f.h5' % np.random.random()) 44 | model.save(model_path) 45 | model = keras.models.load_model( 46 | model_path, 47 | custom_objects={ 48 | 'ResidualScaledDotProductAttention': ResidualScaledDotProductAttention, 49 | }, 50 | ) 51 | inputs = np.array([[1, 2, 3, 1, 0]]) 52 | prev = np.zeros((1, 5, 5)) 53 | predicts = model.predict([inputs, prev]) 54 | results, weights = predicts[0][0], predicts[1][0] 55 | self.assertFalse(np.allclose(results[0], results[3])) 56 | self.assertTrue(np.allclose( 57 | np.asarray([0.2, 0.3, 0.4, 0.6, 0.5]), 58 | results[0], 59 | ), results[0]) 60 | for i in range(4): 61 | for j in range(5): 62 | if j > i: 63 | self.assertEqual(0.0, weights[i][j]) 64 | else: 65 | self.assertLess(0.0, weights[i][j]) 66 | 67 | def test_sample(self): 68 | input_layer = keras.layers.Input( 69 | shape=(5,), 70 | name='Input', 71 | ) 72 | prev_layer = keras.layers.Input( 73 | shape=(5, 5), 74 | name='Prev', 75 | ) 76 | embed_layer = keras.layers.Embedding( 77 | input_dim=4, 78 | output_dim=5, 79 | mask_zero=True, 80 | weights=[ 81 | np.array([ 82 | [0.1, 0.2, 0.3, 0.4, 0.5], 83 | [0.2, 0.3, 0.4, 0.6, 0.5], 84 | [0.4, 0.7, 0.2, 0.6, 0.9], 85 | [0.3, 0.5, 0.8, 0.9, 0.1], 86 | ]), 87 | ], 88 | name='Embedding', 89 | )(input_layer) 90 | att_layer, _ = ResidualScaledDotProductAttention(name='Attention')([embed_layer, prev_layer]) 91 | model = keras.models.Model(inputs=[input_layer, prev_layer], outputs=att_layer) 92 | model.compile(optimizer='adam', loss='mse') 93 | inputs = np.array([[1, 2, 3, 1, 0]]) 94 | prev = np.zeros((1, 5, 5)) 95 | predict = model.predict([inputs, prev])[0] 96 | self.assertTrue(np.allclose(predict[0], predict[3])) 97 | self.assertTrue(np.allclose( 98 | np.asarray([0.27883747, 0.45767492, 0.47448885, 0.69199574, 0.47368336]), 99 | predict[2], 100 | ), predict[2]) 101 | -------------------------------------------------------------------------------- /tests/scaled_dot_attention/test_sample.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | from keras_self_attention import ScaledDotProductAttention 7 | 8 | 9 | class TestAttention(unittest.TestCase): 10 | 11 | def test_sample(self): 12 | input_layer = keras.layers.Input( 13 | shape=(5,), 14 | name='Input', 15 | ) 16 | embed_layer = keras.layers.Embedding( 17 | input_dim=4, 18 | output_dim=5, 19 | mask_zero=True, 20 | weights=[ 21 | 
np.array([ 22 | [0.1, 0.2, 0.3, 0.4, 0.5], 23 | [0.2, 0.3, 0.4, 0.6, 0.5], 24 | [0.4, 0.7, 0.2, 0.6, 0.9], 25 | [0.3, 0.5, 0.8, 0.9, 0.1], 26 | ]), 27 | ], 28 | name='Embedding', 29 | )(input_layer) 30 | att_layer = ScaledDotProductAttention(name='Attention')(embed_layer) 31 | model = keras.models.Model(inputs=input_layer, outputs=att_layer) 32 | model.compile(optimizer='adam', loss='mse') 33 | model.summary() 34 | inputs = np.array([[1, 2, 3, 1, 0]]) 35 | predict = model.predict(inputs)[0] 36 | self.assertTrue(np.allclose(predict[0], predict[3])) 37 | self.assertTrue(np.allclose( 38 | np.asarray([0.27883747, 0.45767492, 0.47448885, 0.69199574, 0.47368336]), 39 | predict[2], 40 | ), predict[2]) 41 | -------------------------------------------------------------------------------- /tests/scaled_dot_attention/test_save_load.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import tempfile 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_self_attention import ScaledDotProductAttention 9 | 10 | 11 | class TestSaveLoad(unittest.TestCase): 12 | 13 | def test_save_load(self): 14 | input_q = keras.layers.Input(shape=(5, 3), name='Input-Q') 15 | input_k = keras.layers.Input(shape=(4, 3), name='Input-K') 16 | input_v = keras.layers.Input(shape=(4, 6), name='Input-V') 17 | attention, weights = ScaledDotProductAttention( 18 | return_attention=True, 19 | history_only=True, 20 | name='Attention', 21 | )([input_q, input_k, input_v]) 22 | model = keras.models.Model(inputs=[input_q, input_k, input_v], outputs=[attention, weights]) 23 | model.compile(optimizer='adam', loss='mse') 24 | model_path = os.path.join(tempfile.gettempdir(), 'keras_self_att_test_sl_%f.h5' % np.random.random()) 25 | model.save(model_path) 26 | model = keras.models.load_model( 27 | model_path, 28 | custom_objects={ 29 | 'ScaledDotProductAttention': ScaledDotProductAttention, 30 | }, 31 | ) 32 | model.summary(line_length=120) 33 | self.assertTrue(model is not None) 34 | -------------------------------------------------------------------------------- /tests/seq_self_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-self-attention/f3bf21dbb1f3251b5417a8bb254dd91807b1aec5/tests/seq_self_attention/__init__.py -------------------------------------------------------------------------------- /tests/seq_self_attention/test_activation.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | from keras_self_attention import SeqSelfAttention 4 | from .util import TestMaskShape 5 | 6 | 7 | class TestActivation(TestMaskShape): 8 | 9 | def test_attention_activation(self): 10 | attention = SeqSelfAttention(return_attention=True, 11 | attention_width=3, 12 | kernel_regularizer=keras.regularizers.l2(1e-4), 13 | bias_regularizer=keras.regularizers.l1(1e-4), 14 | attention_activation='sigmoid', 15 | name='Attention') 16 | self.check_mask_shape(attention) 17 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_bias.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | from keras_self_attention import SeqSelfAttention 4 | from .util import TestMaskShape 5 | 6 | 7 | class TestBias(TestMaskShape): 8 | 9 | def test_no_bias(self): 10 | attention = 
SeqSelfAttention(return_attention=True, 11 | attention_width=3, 12 | kernel_regularizer=keras.regularizers.l2(1e-4), 13 | bias_regularizer=keras.regularizers.l1(1e-4), 14 | use_additive_bias=False, 15 | use_attention_bias=False, 16 | attention_activation='relu', 17 | name='Attention') 18 | self.check_mask_shape(attention) 19 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_history.py: -------------------------------------------------------------------------------- 1 | from keras_self_attention import SeqSelfAttention 2 | from .util import TestMaskShape 3 | 4 | 5 | class TestHistory(TestMaskShape): 6 | 7 | def test_history(self): 8 | attention = SeqSelfAttention(return_attention=True, 9 | attention_width=3, 10 | history_only=True, 11 | name='Attention') 12 | self.check_mask_shape(attention) 13 | 14 | def test_infinite_history(self): 15 | attention = SeqSelfAttention(return_attention=True, 16 | history_only=True, 17 | name='Attention') 18 | self.check_mask_shape(attention) 19 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_local.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | from keras_self_attention import SeqSelfAttention 4 | from .util import TestMaskShape 5 | 6 | 7 | class TestLocal(TestMaskShape): 8 | 9 | def check_local_range(self, attention_type): 10 | attention = SeqSelfAttention(return_attention=True, 11 | attention_width=5, 12 | attention_type=attention_type, 13 | kernel_regularizer=keras.regularizers.l2(1e-4), 14 | bias_regularizer=keras.regularizers.l1(1e-4), 15 | name='Attention') 16 | self.check_mask_shape(attention) 17 | 18 | def test_add(self): 19 | self.check_local_range(attention_type=SeqSelfAttention.ATTENTION_TYPE_ADD) 20 | 21 | def test_mul(self): 22 | self.check_local_range(attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL) 23 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow import keras 3 | 4 | from keras_self_attention import SeqSelfAttention 5 | from .util import TestMaskShape 6 | 7 | 8 | class TestLoss(TestMaskShape): 9 | 10 | def test_loss(self): 11 | attention = SeqSelfAttention(return_attention=False, 12 | attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 13 | kernel_regularizer=keras.regularizers.l2(1e-6), 14 | bias_regularizer=keras.regularizers.l1(1e-6), 15 | attention_regularizer_weight=1e-4, 16 | name='Attention') 17 | sentences, input_data, token_dict = self.get_input_data() 18 | model = self.get_model(attention, token_dict) 19 | sentence_len = input_data.shape[1] 20 | model.fit( 21 | x=input_data, 22 | y=np.zeros((len(sentences), sentence_len, 1)), 23 | epochs=10, 24 | ) 25 | self.assertTrue(model is not None) 26 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_mask.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | from keras_self_attention import SeqSelfAttention 4 | from .util import TestMaskShape 5 | 6 | 7 | class TestMask(TestMaskShape): 8 | 9 | def test_return_attention(self): 10 | attention = SeqSelfAttention(return_attention=True, 11 | kernel_regularizer=keras.regularizers.l2(1e-4), 12 | 
bias_regularizer=keras.regularizers.l1(1e-4), 13 | name='Attention') 14 | self.check_mask_shape(attention) 15 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_mul.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | from keras_self_attention import SeqSelfAttention 4 | from .util import TestMaskShape 5 | 6 | 7 | class TestMul(TestMaskShape): 8 | 9 | def test_multiplicative(self): 10 | attention = SeqSelfAttention(return_attention=True, 11 | attention_width=15, 12 | attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 13 | kernel_regularizer=keras.regularizers.l2(1e-4), 14 | bias_regularizer=keras.regularizers.l1(1e-4), 15 | name='Attention') 16 | self.check_mask_shape(attention) 17 | 18 | def test_not_implemented(self): 19 | with self.assertRaises(NotImplementedError): 20 | SeqSelfAttention(return_attention=True, 21 | attention_type='random') 22 | -------------------------------------------------------------------------------- /tests/seq_self_attention/test_save_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import numpy as np 5 | from tensorflow import keras 6 | 7 | from keras_self_attention import SeqSelfAttention 8 | from .util import TestMaskShape 9 | 10 | 11 | class TestSaveLoad(TestMaskShape): 12 | 13 | def test_save_load(self): 14 | _, _, token_dict = self.get_input_data() 15 | model = self.get_model(SeqSelfAttention(name='Attention'), token_dict) 16 | model_path = os.path.join(tempfile.gettempdir(), 'keras_self_att_test_save_load_%f.h5' % np.random.random()) 17 | model.save(model_path) 18 | model = keras.models.load_model(model_path, custom_objects={'SeqSelfAttention': SeqSelfAttention}) 19 | model.summary() 20 | self.assertTrue(model is not None) 21 | 22 | def test_save_load_with_loss(self): 23 | attention = SeqSelfAttention(return_attention=True, 24 | attention_width=7, 25 | attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL, 26 | kernel_regularizer=keras.regularizers.l2(1e-4), 27 | bias_regularizer=keras.regularizers.l1(1e-4), 28 | attention_regularizer_weight=1e-3, 29 | name='Attention') 30 | _, _, token_dict = self.get_input_data() 31 | model = self.get_model(attention, token_dict) 32 | model_path = os.path.join(tempfile.gettempdir(), 'keras_self_att_test_sl_with_loss_%f.h5' % np.random.random()) 33 | model.save(model_path) 34 | model = keras.models.load_model(model_path, custom_objects=SeqSelfAttention.get_custom_objects()) 35 | model.summary() 36 | self.assertTrue(model is not None) 37 | -------------------------------------------------------------------------------- /tests/seq_self_attention/util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | 7 | class TestMaskShape(unittest.TestCase): 8 | 9 | @staticmethod 10 | def get_input_data(): 11 | sentences = [ 12 | ['All', 'work', 'and', 'no', 'play'], 13 | ['makes', 'Jack', 'a', 'dull', 'boy', '.'], 14 | ['From', 'that', 'day', 'forth', 'my', 'arm', 'changed'], 15 | ] 16 | token_dict = { 17 | '': 0, 18 | '<UNK>': 1, 19 | } 20 | sentence_len = max(map(len, sentences)) 21 | input_data = [[0] * sentence_len for _ in range(len(sentences))] 22 | for i, sentence in enumerate(sentences): 23 | for j, token in enumerate(sentence): 24 | if token in token_dict: 25 | input_data[i][j] = token_dict[token] 26
| elif np.random.randint(0, 5) == 0: 27 | input_data[i][j] = token_dict['<UNK>'] 28 | else: 29 | input_data[i][j] = len(token_dict) 30 | token_dict[token] = len(token_dict) 31 | return sentences, np.asarray(input_data), token_dict 32 | 33 | @staticmethod 34 | def get_model(attention, token_dict): 35 | inputs = keras.layers.Input(shape=(None,), name='Input') 36 | embd = keras.layers.Embedding(input_dim=len(token_dict), 37 | output_dim=16, 38 | mask_zero=True, 39 | name='Embedding')(inputs) 40 | lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16, 41 | return_sequences=True), 42 | name='Bi-LSTM')(embd) 43 | if attention.return_attention: 44 | att, weights = attention(lstm) 45 | else: 46 | att = attention(lstm) 47 | dense = keras.layers.Dense(units=5, name='Dense')(att) 48 | loss = {'Dense': 'sparse_categorical_crossentropy'} 49 | if attention.return_attention: 50 | model = keras.models.Model(inputs=inputs, outputs=[dense, weights]) 51 | loss[attention.name] = 'mse' 52 | else: 53 | model = keras.models.Model(inputs=inputs, outputs=dense) 54 | model.compile(optimizer='adam', loss=loss) 55 | model.summary(line_length=100) 56 | return model 57 | 58 | def check_mask_shape(self, attention): 59 | sentences, input_data, token_dict = self.get_input_data() 60 | model = self.get_model(attention, token_dict) 61 | outputs = model.predict(input_data) 62 | if attention.attention_width is None: 63 | attention_width = 1e9 64 | else: 65 | attention_width = attention.attention_width 66 | history_only = attention.history_only 67 | attention_output = outputs[1] 68 | for i, sentence in enumerate(sentences): 69 | for j in range(len(sentence)): 70 | for k in range(len(sentence)): 71 | if history_only and 0 <= j - k < attention_width: 72 | self.assertGreater(attention_output[i][j][k], 0.0) 73 | elif not history_only and abs(j - k) <= attention_width // 2: 74 | self.assertGreater(attention_output[i][j][k], 0.0) 75 | else: 76 | self.assertEqual(attention_output[i][j][k], 0.0) 77 | self.assertTrue(abs(np.sum(attention_output[i][j]) - 1.0) < 1e-6) 78 | -------------------------------------------------------------------------------- /tests/seq_weighted_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-self-attention/f3bf21dbb1f3251b5417a8bb254dd91807b1aec5/tests/seq_weighted_attention/__init__.py -------------------------------------------------------------------------------- /tests/seq_weighted_attention/test_save_load.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import tempfile 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_self_attention import SeqWeightedAttention as Attention 9 | 10 | 11 | class TestSaveLoad(unittest.TestCase): 12 | 13 | def _test_save_load(self, attention): 14 | inputs = keras.layers.Input(shape=(None,), name='Input') 15 | embd = keras.layers.Embedding(input_dim=3, 16 | output_dim=5, 17 | mask_zero=True, 18 | name='Embedding')(inputs) 19 | lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=7, 20 | return_sequences=True), 21 | name='Bi-LSTM')(embd) 22 | if attention.return_attention: 23 | layer, weights = attention(lstm) 24 | else: 25 | layer = attention(lstm) 26 | dense = keras.layers.Dense(units=2, activation='softmax', name='Softmax')(layer) 27 | loss = {'Softmax': 'sparse_categorical_crossentropy'} 28 | if attention.return_attention: 29 | outputs = [dense, weights] 30 |
loss[attention.name] = 'mse' 31 | else: 32 | outputs = dense 33 | model = keras.models.Model(inputs=inputs, outputs=outputs) 34 | model.compile(optimizer='adam', loss=loss) 35 | model_path = os.path.join(tempfile.gettempdir(), 'keras_weighted_att_test_sl_%f.h5' % np.random.random()) 36 | model.save(model_path) 37 | model = keras.models.load_model(model_path, custom_objects=Attention.get_custom_objects()) 38 | model.summary(line_length=100) 39 | if attention.return_attention: 40 | self.assertEqual(2, len(model.outputs)) 41 | else: 42 | self.assertEqual(1, len(model.outputs)) 43 | 44 | def test_default(self): 45 | self._test_save_load(Attention(name='Attention')) 46 | 47 | def test_return_attention(self): 48 | self._test_save_load(Attention(return_attention=True, use_bias=False, name='Attention')) 49 | --------------------------------------------------------------------------------