├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md └── stale.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── README.zh-CN.md ├── keras_transformer ├── __init__.py ├── gelu.py └── transformer.py ├── publish.sh ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── test.sh └── tests ├── __init__.py ├── test_decode.py ├── test_gelu.py ├── test_get_decoders.py ├── test_get_encoders.py ├── test_get_model.py ├── test_suffix_repeat.py └── test_translate.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: CyberZHG 7 | 8 | --- 9 | 10 | **Describe the Bug** 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | **Version Info** 15 | 16 | * [ ] I'm using the latest version 17 | 18 | **Minimal Codes To Reproduce** 19 | 20 | ```python 21 | import keras_bert 22 | 23 | pass 24 | ``` 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: CyberZHG 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask questions about the repo 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | daysUntilStale: 5 2 | daysUntilClose: 2 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # System thumbnail 107 | .DS_Store 108 | 109 | # IDE 110 | .idea 111 | 112 | # Temporary README 113 | README.rst 114 | 115 | # Images 116 | *.png 117 | 118 | # Models 119 | *.h5 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Zhao HG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include README.zh-CN.md
3 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keras Transformer
2 | 
3 | [![Version](https://img.shields.io/pypi/v/keras-transformer.svg)](https://pypi.org/project/keras-transformer/)
4 | ![License](https://img.shields.io/pypi/l/keras-transformer.svg)
5 | 
6 | \[[中文](https://github.com/CyberZHG/keras-transformer/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-transformer/blob/master/README.md)\]
7 | 
8 | Implementation of the [transformer](https://arxiv.org/pdf/1706.03762.pdf) for seq2seq tasks.
9 | 
10 | ## Install
11 | 
12 | ```bash
13 | pip install keras-transformer
14 | ```
15 | 
16 | ## Usage
17 | 
18 | ### Train
19 | 
20 | ```python
21 | import numpy as np
22 | from keras_transformer import get_model
23 | 
24 | # Build a small toy token dictionary
25 | tokens = 'all work and no play makes jack a dull boy'.split(' ')
26 | token_dict = {
27 |     '<PAD>': 0,
28 |     '<START>': 1,
29 |     '<END>': 2,
30 | }
31 | for token in tokens:
32 |     if token not in token_dict:
33 |         token_dict[token] = len(token_dict)
34 | 
35 | # Generate toy data
36 | encoder_inputs_no_padding = []
37 | encoder_inputs, decoder_inputs, decoder_outputs = [], [], []
38 | for i in range(1, len(tokens) - 1):
39 |     encode_tokens, decode_tokens = tokens[:i], tokens[i:]
40 |     encode_tokens = ['<START>'] + encode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(encode_tokens))
41 |     output_tokens = decode_tokens + ['<END>', '<PAD>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
42 |     decode_tokens = ['<START>'] + decode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
43 |     encode_tokens = list(map(lambda x: token_dict[x], encode_tokens))
44 |     decode_tokens = list(map(lambda x: token_dict[x], decode_tokens))
45 |     output_tokens = list(map(lambda x: [token_dict[x]], output_tokens))
46 |     encoder_inputs_no_padding.append(encode_tokens[:i + 2])
47 |     encoder_inputs.append(encode_tokens)
48 |     decoder_inputs.append(decode_tokens)
49 |     decoder_outputs.append(output_tokens)
50 | 
51 | # Build the model
52 | model = get_model(
53 |     token_num=len(token_dict),
54 |     embed_dim=30,
55 |     encoder_num=3,
56 |     decoder_num=2,
57 |     head_num=3,
58 |     hidden_dim=120,
59 |     attention_activation='relu',
60 |     feed_forward_activation='relu',
61 |     dropout_rate=0.05,
62 |     embed_weights=np.random.random((13, 30)),
63 | )
64 | model.compile(
65 |     optimizer='adam',
66 |     loss='sparse_categorical_crossentropy',
67 | )
68 | model.summary()
69 | 
70 | # Train the model
71 | model.fit(
72 |     x=[np.asarray(encoder_inputs * 1000), np.asarray(decoder_inputs * 1000)],
73 |     y=np.asarray(decoder_outputs * 1000),
74 |     epochs=5,
75 | )
76 | ```
77 | 
78 | ### Predict
79 | 
80 | ```python
81 | from keras_transformer import decode
82 | 
83 | decoded = decode(
84 |     model,
85 |     encoder_inputs_no_padding,
86 |     start_token=token_dict['<START>'],
87 |     end_token=token_dict['<END>'],
88 |     pad_token=token_dict['<PAD>'],
89 |     max_len=100,
90 | )
91 | token_dict_rev = {v: k for k, v in token_dict.items()}
92 | for i in range(len(decoded)):
93 |     print(' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])))
94 | ```
95 | 
96 | ### Translation
97 | 
98 | ```python
99 | import numpy as np
100 | from keras_transformer import get_model, decode
101 | 
102 | source_tokens = [
103 |     'i need more power'.split(' '),
104 |     'eat jujube and pill'.split(' '),
105 | ]
106 | target_tokens = [
107 |     list('我要更多的抛瓦'),
108 |     list('吃枣💊'),
109 | ]
110 | 
111 | # Generate dictionaries
112 | def build_token_dict(token_list):
113 |     token_dict = {
114 |         '<PAD>': 0,
115 |         '<START>': 1,
116 |         '<END>': 2,
117 |     }
118 |     for tokens in token_list:
119 |         for token in tokens:
120 |             if token not in token_dict:
121 |                 token_dict[token] = len(token_dict)
122 |     return token_dict
123 | 
124 | source_token_dict = build_token_dict(source_tokens)
125 | target_token_dict = build_token_dict(target_tokens)
126 | target_token_dict_inv = {v: k for k, v in target_token_dict.items()}
127 | 
128 | # Add special tokens
129 | encode_tokens = [['<START>'] + tokens + ['<END>'] for tokens in source_tokens]
130 | decode_tokens = [['<START>'] + tokens + ['<END>'] for tokens in target_tokens]
131 | output_tokens = [tokens + ['<END>', '<PAD>'] for tokens in target_tokens]
132 | 
133 | # Padding
134 | source_max_len = max(map(len, encode_tokens))
135 | target_max_len = max(map(len, decode_tokens))
136 | 
137 | encode_tokens = [tokens + ['<PAD>'] * (source_max_len - len(tokens)) for tokens in encode_tokens]
138 | decode_tokens = [tokens + ['<PAD>'] * (target_max_len - len(tokens)) for tokens in decode_tokens]
139 | output_tokens = [tokens + ['<PAD>'] * (target_max_len - len(tokens)) for tokens in output_tokens]
140 | 
141 | encode_input = [list(map(lambda x: source_token_dict[x], tokens)) for tokens in encode_tokens]
142 | decode_input = [list(map(lambda x: target_token_dict[x], tokens)) for tokens in decode_tokens]
143 | decode_output = [list(map(lambda x: [target_token_dict[x]], tokens)) for tokens in output_tokens]
144 | 
145 | # Build & fit model
146 | model = get_model(
147 |     token_num=max(len(source_token_dict), len(target_token_dict)),
148 |     embed_dim=32,
149 |     encoder_num=2,
150 |     decoder_num=2,
151 |     head_num=4,
152 |     hidden_dim=128,
153 |     dropout_rate=0.05,
154 |     use_same_embed=False,  # Use different embeddings for different languages
155 | )
156 | model.compile('adam', 'sparse_categorical_crossentropy')
157 | model.summary()
158 | 
159 | model.fit(
160 |     x=[np.array(encode_input * 1024), np.array(decode_input * 1024)],
161 |     y=np.array(decode_output * 1024),
162 |     epochs=10,
163 |     batch_size=32,
164 | )
165 | 
166 | # Predict
167 | decoded = decode(
168 |     model,
169 |     encode_input,
170 |     start_token=target_token_dict['<START>'],
171 |     end_token=target_token_dict['<END>'],
172 |     pad_token=target_token_dict['<PAD>'],
173 | )
174 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
175 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
176 | ```
177 | 
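Because `use_same_embed=False` builds separate embedding layers for the two languages, `token_num` can also be passed as a two-element list with the source and target vocabulary sizes, as `tests/test_get_model.py` does. A minimal sketch of that alternative call, reusing the dictionaries built above:

```python
model = get_model(
    token_num=[len(source_token_dict), len(target_token_dict)],  # per-language vocabulary sizes
    embed_dim=32,
    encoder_num=2,
    decoder_num=2,
    head_num=4,
    hidden_dim=128,
    dropout_rate=0.05,
    use_same_embed=False,
)
```

This keeps each embedding matrix sized to its own vocabulary instead of the larger of the two.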
178 | ### Decode
179 | 
180 | In `decode`, the token with the highest probability is selected as the prediction at each step by default. You can add randomness by setting `top_k` and `temperature`:
181 | 
182 | ```python
183 | decoded = decode(
184 |     model,
185 |     encode_input,
186 |     start_token=target_token_dict['<START>'],
187 |     end_token=target_token_dict['<END>'],
188 |     pad_token=target_token_dict['<PAD>'],
189 |     top_k=10,
190 |     temperature=1.0,
191 | )
192 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
193 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
194 | ```
195 | 
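### Load a Trained Model

The network is built from custom layers (`MultiHeadAttention`, `TrigPosEmbedding`, and so on), so a model saved with `model.save` has to be reloaded with the custom objects exported by this package. A minimal sketch following the pattern used in the test suite; `'translation.h5'` is only a placeholder path:

```python
from tensorflow import keras
from keras_transformer import get_custom_objects

model.save('translation.h5')  # placeholder file name, any writable path works
model = keras.models.load_model('translation.h5', custom_objects=get_custom_objects())
```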
--------------------------------------------------------------------------------
/README.zh-CN.md:
--------------------------------------------------------------------------------
1 | # Keras Transformer
2 | 
3 | [![Version](https://img.shields.io/pypi/v/keras-transformer.svg)](https://pypi.org/project/keras-transformer/)
4 | ![License](https://img.shields.io/pypi/l/keras-transformer.svg)
5 | 
6 | \[[中文](https://github.com/CyberZHG/keras-transformer/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-transformer/blob/master/README.md)\]
7 | 
8 | [Transformer](https://arxiv.org/pdf/1706.03762.pdf)的实现。
9 | 
10 | ## 安装
11 | 
12 | ```bash
13 | pip install keras-transformer
14 | ```
15 | 
16 | ## 使用
17 | 
18 | ### 训练
19 | 
20 | ```python
21 | import numpy as np
22 | from keras_transformer import get_model
23 | 
24 | # 构建一个toy词典
25 | tokens = 'all work and no play makes jack a dull boy'.split(' ')
26 | token_dict = {
27 |     '<PAD>': 0,
28 |     '<START>': 1,
29 |     '<END>': 2,
30 | }
31 | for token in tokens:
32 |     if token not in token_dict:
33 |         token_dict[token] = len(token_dict)
34 | 
35 | # 生成toy数据
36 | encoder_inputs_no_padding = []
37 | encoder_inputs, decoder_inputs, decoder_outputs = [], [], []
38 | for i in range(1, len(tokens) - 1):
39 |     encode_tokens, decode_tokens = tokens[:i], tokens[i:]
40 |     encode_tokens = ['<START>'] + encode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(encode_tokens))
41 |     output_tokens = decode_tokens + ['<END>', '<PAD>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
42 |     decode_tokens = ['<START>'] + decode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
43 |     encode_tokens = list(map(lambda x: token_dict[x], encode_tokens))
44 |     decode_tokens = list(map(lambda x: token_dict[x], decode_tokens))
45 |     output_tokens = list(map(lambda x: [token_dict[x]], output_tokens))
46 |     encoder_inputs_no_padding.append(encode_tokens[:i + 2])
47 |     encoder_inputs.append(encode_tokens)
48 |     decoder_inputs.append(decode_tokens)
49 |     decoder_outputs.append(output_tokens)
50 | 
51 | # 构建模型
52 | model = get_model(
53 |     token_num=len(token_dict),
54 |     embed_dim=30,
55 |     encoder_num=3,
56 |     decoder_num=2,
57 |     head_num=3,
58 |     hidden_dim=120,
59 |     attention_activation='relu',
60 |     feed_forward_activation='relu',
61 |     dropout_rate=0.05,
62 |     embed_weights=np.random.random((13, 30)),
63 | )
64 | model.compile(
65 |     optimizer='adam',
66 |     loss='sparse_categorical_crossentropy',
67 | )
68 | model.summary()
69 | 
70 | # 训练模型
71 | model.fit(
72 |     x=[np.asarray(encoder_inputs * 1000), np.asarray(decoder_inputs * 1000)],
73 |     y=np.asarray(decoder_outputs * 1000),
74 |     epochs=5,
75 | )
76 | ```
77 | 
78 | ### 预测
79 | 
80 | ```python
81 | from keras_transformer import decode
82 | 
83 | decoded = decode(
84 |     model,
85 |     encoder_inputs_no_padding,
86 |     start_token=token_dict['<START>'],
87 |     end_token=token_dict['<END>'],
88 |     pad_token=token_dict['<PAD>'],
89 |     max_len=100,
90 | )
91 | token_dict_rev = {v: k for k, v in token_dict.items()}
92 | for i in range(len(decoded)):
93 |     print(' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])))
94 | ```
95 | 
96 | ### 翻译
97 | 
98 | ```python
99 | import numpy as np
100 | from keras_transformer import get_model, decode
101 | 
102 | source_tokens = [
103 |     'i need more power'.split(' '),
104 |     'eat jujube and pill'.split(' '),
105 | ]
106 | target_tokens = [
107 |     list('我要更多的抛瓦'),
108 |     list('吃枣💊'),
109 | ]
110 | 
111 | # 生成不同语言的词典
112 | def build_token_dict(token_list):
113 |     token_dict = {
114 |         '<PAD>': 0,
115 |         '<START>': 1,
116 |         '<END>': 2,
117 |     }
118 |     for tokens in token_list:
119 |         for token in tokens:
120 |             if token not in token_dict:
121 |                 token_dict[token] = len(token_dict)
122 |     return token_dict
123 | 
124 | source_token_dict = build_token_dict(source_tokens)
125 | target_token_dict = build_token_dict(target_tokens)
126 | target_token_dict_inv = {v: k for k, v in target_token_dict.items()}
127 | 
128 | # 添加特殊符号
129 | encode_tokens = [['<START>'] + tokens + ['<END>'] for tokens in source_tokens]
130 | decode_tokens = [['<START>'] + tokens + ['<END>'] for tokens in target_tokens]
131 | output_tokens = [tokens + ['<END>', '<PAD>'] for tokens in target_tokens]
132 | 
133 | # 补齐长度
134 | source_max_len = max(map(len, encode_tokens))
135 | target_max_len = max(map(len, decode_tokens))
136 | 
137 | encode_tokens = [tokens + ['<PAD>'] * (source_max_len - len(tokens)) for tokens in encode_tokens]
138 | decode_tokens = [tokens + ['<PAD>'] * (target_max_len - len(tokens)) for tokens in decode_tokens]
139 | output_tokens = [tokens + ['<PAD>'] * (target_max_len - len(tokens)) for tokens in output_tokens]
140 | 
141 | encode_input = [list(map(lambda x: source_token_dict[x], tokens)) for tokens in encode_tokens]
142 | decode_input = [list(map(lambda x: target_token_dict[x], tokens)) for tokens in decode_tokens]
143 | decode_output = [list(map(lambda x: [target_token_dict[x]], tokens)) for tokens in output_tokens]
144 | 
145 | # 构建和训练模型
146 | model = get_model(
147 |     token_num=max(len(source_token_dict), len(target_token_dict)),
148 |     embed_dim=32,
149 |     encoder_num=2,
150 |     decoder_num=2,
151 |     head_num=4,
152 |     hidden_dim=128,
153 |     dropout_rate=0.05,
154 |     use_same_embed=False,  # 不同语言需要使用不同的词嵌入
155 | )
156 | model.compile('adam', 'sparse_categorical_crossentropy')
157 | model.summary()
158 | 
159 | model.fit(
160 |     x=[np.array(encode_input * 1024), np.array(decode_input * 1024)],
161 |     y=np.array(decode_output * 1024),
162 |     epochs=10,
163 |     batch_size=32,
164 | )
165 | 
166 | # 预测过程
167 | decoded = decode(
168 |     model,
169 |     encode_input,
170 |     start_token=target_token_dict['<START>'],
171 |     end_token=target_token_dict['<END>'],
172 |     pad_token=target_token_dict['<PAD>'],
173 | )
174 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
175 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
176 | ```
177 | 
178 | ### 解码
179 | 
180 | 默认参数下,`decode`只使用概率最高的词作为结果。通过调整`top_k`和`temperature`可以在候选词中加入随机采样,较高的温度会使每个词被选中的概率更为平均,而极为接近零的温度相当于`top_k`为1的结果:
181 | 
182 | ```python
183 | decoded = decode(
184 |     model,
185 |     encode_input,
186 |     start_token=target_token_dict['<START>'],
187 |     end_token=target_token_dict['<END>'],
188 |     pad_token=target_token_dict['<PAD>'],
189 |     top_k=10,
190 |     temperature=1.0,
191 | )
192 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
193 | print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
194 | ```
195 | 
--------------------------------------------------------------------------------
/keras_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .gelu import gelu
2 | from .transformer import *
3 | 
4 | __version__ = '0.40.0'
5 | 
-------------------------------------------------------------------------------- /keras_transformer/gelu.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from tensorflow.keras import backend as K 4 | 5 | 6 | def gelu(x): 7 | """An approximation of gelu. 8 | 9 | See: https://arxiv.org/pdf/1606.08415.pdf 10 | """ 11 | return 0.5 * x * (1.0 + K.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x))) 12 | -------------------------------------------------------------------------------- /keras_transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow import keras 3 | 4 | from keras_layer_normalization import LayerNormalization 5 | from keras_multi_head import MultiHeadAttention 6 | from keras_position_wise_feed_forward import FeedForward 7 | from keras_pos_embd import TrigPosEmbedding 8 | from keras_embed_sim import EmbeddingRet, EmbeddingSim 9 | from .gelu import gelu 10 | 11 | 12 | __all__ = [ 13 | 'get_custom_objects', 'get_encoders', 'get_decoders', 'get_model', 'decode', 14 | 'attention_builder', 'feed_forward_builder', 'get_encoder_component', 'get_decoder_component', 15 | ] 16 | 17 | 18 | def get_custom_objects(): 19 | return { 20 | 'gelu': gelu, 21 | 'LayerNormalization': LayerNormalization, 22 | 'MultiHeadAttention': MultiHeadAttention, 23 | 'FeedForward': FeedForward, 24 | 'TrigPosEmbedding': TrigPosEmbedding, 25 | 'EmbeddingRet': EmbeddingRet, 26 | 'EmbeddingSim': EmbeddingSim, 27 | } 28 | 29 | 30 | def _wrap_layer(name, 31 | input_layer, 32 | build_func, 33 | dropout_rate=0.0, 34 | trainable=True): 35 | """Wrap layers with residual, normalization and dropout. 36 | 37 | :param name: Prefix of names for internal layers. 38 | :param input_layer: Input layer. 39 | :param build_func: A callable that takes the input tensor and generates the output tensor. 40 | :param dropout_rate: Dropout rate. 41 | :param trainable: Whether the layers are trainable. 42 | :return: Output layer. 43 | """ 44 | build_output = build_func(input_layer) 45 | if dropout_rate > 0.0: 46 | dropout_layer = keras.layers.Dropout( 47 | rate=dropout_rate, 48 | name='%s-Dropout' % name, 49 | )(build_output) 50 | else: 51 | dropout_layer = build_output 52 | if isinstance(input_layer, list): 53 | input_layer = input_layer[0] 54 | add_layer = keras.layers.Add(name='%s-Add' % name)([input_layer, dropout_layer]) 55 | normal_layer = LayerNormalization( 56 | trainable=trainable, 57 | name='%s-Norm' % name, 58 | )(add_layer) 59 | return normal_layer 60 | 61 | 62 | def attention_builder(name, 63 | head_num, 64 | activation, 65 | history_only, 66 | trainable=True): 67 | """Get multi-head self-attention builder. 68 | 69 | :param name: Prefix of names for internal layers. 70 | :param head_num: Number of heads in multi-head self-attention. 71 | :param activation: Activation for multi-head self-attention. 72 | :param history_only: Only use history data. 73 | :param trainable: Whether the layer is trainable. 74 | :return: 75 | """ 76 | def _attention_builder(x): 77 | return MultiHeadAttention( 78 | head_num=head_num, 79 | activation=activation, 80 | history_only=history_only, 81 | trainable=trainable, 82 | name=name, 83 | )(x) 84 | return _attention_builder 85 | 86 | 87 | def feed_forward_builder(name, 88 | hidden_dim, 89 | activation, 90 | trainable=True): 91 | """Get position-wise feed-forward layer builder. 92 | 93 | :param name: Prefix of names for internal layers. 
94 | :param hidden_dim: Hidden dimension of feed forward layer. 95 | :param activation: Activation for feed-forward layer. 96 | :param trainable: Whether the layer is trainable. 97 | :return: 98 | """ 99 | def _feed_forward_builder(x): 100 | return FeedForward( 101 | units=hidden_dim, 102 | activation=activation, 103 | trainable=trainable, 104 | name=name, 105 | )(x) 106 | return _feed_forward_builder 107 | 108 | 109 | def get_encoder_component(name, 110 | input_layer, 111 | head_num, 112 | hidden_dim, 113 | attention_activation=None, 114 | feed_forward_activation=gelu, 115 | dropout_rate=0.0, 116 | trainable=True,): 117 | """Multi-head self-attention and feed-forward layer. 118 | 119 | :param name: Prefix of names for internal layers. 120 | :param input_layer: Input layer. 121 | :param head_num: Number of heads in multi-head self-attention. 122 | :param hidden_dim: Hidden dimension of feed forward layer. 123 | :param attention_activation: Activation for multi-head self-attention. 124 | :param feed_forward_activation: Activation for feed-forward layer. 125 | :param dropout_rate: Dropout rate. 126 | :param trainable: Whether the layers are trainable. 127 | :return: Output layer. 128 | """ 129 | attention_name = '%s-MultiHeadSelfAttention' % name 130 | feed_forward_name = '%s-FeedForward' % name 131 | attention_layer = _wrap_layer( 132 | name=attention_name, 133 | input_layer=input_layer, 134 | build_func=attention_builder( 135 | name=attention_name, 136 | head_num=head_num, 137 | activation=attention_activation, 138 | history_only=False, 139 | trainable=trainable, 140 | ), 141 | dropout_rate=dropout_rate, 142 | trainable=trainable, 143 | ) 144 | feed_forward_layer = _wrap_layer( 145 | name=feed_forward_name, 146 | input_layer=attention_layer, 147 | build_func=feed_forward_builder( 148 | name=feed_forward_name, 149 | hidden_dim=hidden_dim, 150 | activation=feed_forward_activation, 151 | trainable=trainable, 152 | ), 153 | dropout_rate=dropout_rate, 154 | trainable=trainable, 155 | ) 156 | return feed_forward_layer 157 | 158 | 159 | def get_decoder_component(name, 160 | input_layer, 161 | encoded_layer, 162 | head_num, 163 | hidden_dim, 164 | attention_activation=None, 165 | feed_forward_activation=gelu, 166 | dropout_rate=0.0, 167 | trainable=True): 168 | """Multi-head self-attention, multi-head query attention and feed-forward layer. 169 | 170 | :param name: Prefix of names for internal layers. 171 | :param input_layer: Input layer. 172 | :param encoded_layer: Encoded layer from encoder. 173 | :param head_num: Number of heads in multi-head self-attention. 174 | :param hidden_dim: Hidden dimension of feed forward layer. 175 | :param attention_activation: Activation for multi-head self-attention. 176 | :param feed_forward_activation: Activation for feed-forward layer. 177 | :param dropout_rate: Dropout rate. 178 | :param trainable: Whether the layers are trainable. 179 | :return: Output layer. 
180 | """ 181 | self_attention_name = '%s-MultiHeadSelfAttention' % name 182 | query_attention_name = '%s-MultiHeadQueryAttention' % name 183 | feed_forward_name = '%s-FeedForward' % name 184 | self_attention_layer = _wrap_layer( 185 | name=self_attention_name, 186 | input_layer=input_layer, 187 | build_func=attention_builder( 188 | name=self_attention_name, 189 | head_num=head_num, 190 | activation=attention_activation, 191 | history_only=True, 192 | trainable=trainable, 193 | ), 194 | dropout_rate=dropout_rate, 195 | trainable=trainable, 196 | ) 197 | query_attention_layer = _wrap_layer( 198 | name=query_attention_name, 199 | input_layer=[self_attention_layer, encoded_layer, encoded_layer], 200 | build_func=attention_builder( 201 | name=query_attention_name, 202 | head_num=head_num, 203 | activation=attention_activation, 204 | history_only=False, 205 | trainable=trainable, 206 | ), 207 | dropout_rate=dropout_rate, 208 | trainable=trainable, 209 | ) 210 | feed_forward_layer = _wrap_layer( 211 | name=feed_forward_name, 212 | input_layer=query_attention_layer, 213 | build_func=feed_forward_builder( 214 | name=feed_forward_name, 215 | hidden_dim=hidden_dim, 216 | activation=feed_forward_activation, 217 | trainable=trainable, 218 | ), 219 | dropout_rate=dropout_rate, 220 | trainable=trainable, 221 | ) 222 | return feed_forward_layer 223 | 224 | 225 | def get_encoders(encoder_num, 226 | input_layer, 227 | head_num, 228 | hidden_dim, 229 | attention_activation=None, 230 | feed_forward_activation=gelu, 231 | dropout_rate=0.0, 232 | trainable=True): 233 | """Get encoders. 234 | 235 | :param encoder_num: Number of encoder components. 236 | :param input_layer: Input layer. 237 | :param head_num: Number of heads in multi-head self-attention. 238 | :param hidden_dim: Hidden dimension of feed forward layer. 239 | :param attention_activation: Activation for multi-head self-attention. 240 | :param feed_forward_activation: Activation for feed-forward layer. 241 | :param dropout_rate: Dropout rate. 242 | :param trainable: Whether the layers are trainable. 243 | :return: Output layer. 244 | """ 245 | last_layer = input_layer 246 | for i in range(encoder_num): 247 | last_layer = get_encoder_component( 248 | name='Encoder-%d' % (i + 1), 249 | input_layer=last_layer, 250 | head_num=head_num, 251 | hidden_dim=hidden_dim, 252 | attention_activation=attention_activation, 253 | feed_forward_activation=feed_forward_activation, 254 | dropout_rate=dropout_rate, 255 | trainable=trainable, 256 | ) 257 | return last_layer 258 | 259 | 260 | def get_decoders(decoder_num, 261 | input_layer, 262 | encoded_layer, 263 | head_num, 264 | hidden_dim, 265 | attention_activation=None, 266 | feed_forward_activation=gelu, 267 | dropout_rate=0.0, 268 | trainable=True): 269 | """Get decoders. 270 | 271 | :param decoder_num: Number of decoder components. 272 | :param input_layer: Input layer. 273 | :param encoded_layer: Encoded layer from encoder. 274 | :param head_num: Number of heads in multi-head self-attention. 275 | :param hidden_dim: Hidden dimension of feed forward layer. 276 | :param attention_activation: Activation for multi-head self-attention. 277 | :param feed_forward_activation: Activation for feed-forward layer. 278 | :param dropout_rate: Dropout rate. 279 | :param trainable: Whether the layers are trainable. 280 | :return: Output layer. 
281 | """ 282 | last_layer = input_layer 283 | for i in range(decoder_num): 284 | last_layer = get_decoder_component( 285 | name='Decoder-%d' % (i + 1), 286 | input_layer=last_layer, 287 | encoded_layer=encoded_layer, 288 | head_num=head_num, 289 | hidden_dim=hidden_dim, 290 | attention_activation=attention_activation, 291 | feed_forward_activation=feed_forward_activation, 292 | dropout_rate=dropout_rate, 293 | trainable=trainable, 294 | ) 295 | return last_layer 296 | 297 | 298 | def get_model(token_num, 299 | embed_dim, 300 | encoder_num, 301 | decoder_num, 302 | head_num, 303 | hidden_dim, 304 | attention_activation=None, 305 | feed_forward_activation=gelu, 306 | dropout_rate=0.0, 307 | use_same_embed=True, 308 | embed_weights=None, 309 | embed_trainable=None, 310 | trainable=True): 311 | """Get full model without compilation. 312 | 313 | :param token_num: Number of distinct tokens. 314 | :param embed_dim: Dimension of token embedding. 315 | :param encoder_num: Number of encoder components. 316 | :param decoder_num: Number of decoder components. 317 | :param head_num: Number of heads in multi-head self-attention. 318 | :param hidden_dim: Hidden dimension of feed forward layer. 319 | :param attention_activation: Activation for multi-head self-attention. 320 | :param feed_forward_activation: Activation for feed-forward layer. 321 | :param dropout_rate: Dropout rate. 322 | :param use_same_embed: Whether to use the same token embedding layer. `token_num`, `embed_weights` and 323 | `embed_trainable` should be lists of two elements if it is False. 324 | :param embed_weights: Initial weights of token embedding. 325 | :param embed_trainable: Whether the token embedding is trainable. It will automatically set to False if the given 326 | value is None when embedding weights has been provided. 327 | :param trainable: Whether the layers are trainable. 328 | :return: Keras model. 
329 | """ 330 | if not isinstance(token_num, list): 331 | token_num = [token_num, token_num] 332 | encoder_token_num, decoder_token_num = token_num 333 | 334 | if not isinstance(embed_weights, list): 335 | embed_weights = [embed_weights, embed_weights] 336 | encoder_embed_weights, decoder_embed_weights = embed_weights 337 | if encoder_embed_weights is not None: 338 | encoder_embed_weights = [encoder_embed_weights] 339 | if decoder_embed_weights is not None: 340 | decoder_embed_weights = [decoder_embed_weights] 341 | 342 | if not isinstance(embed_trainable, list): 343 | embed_trainable = [embed_trainable, embed_trainable] 344 | encoder_embed_trainable, decoder_embed_trainable = embed_trainable 345 | if encoder_embed_trainable is None: 346 | encoder_embed_trainable = encoder_embed_weights is None 347 | if decoder_embed_trainable is None: 348 | decoder_embed_trainable = decoder_embed_weights is None 349 | 350 | if use_same_embed: 351 | encoder_embed_layer = decoder_embed_layer = EmbeddingRet( 352 | input_dim=encoder_token_num, 353 | output_dim=embed_dim, 354 | mask_zero=True, 355 | weights=encoder_embed_weights, 356 | trainable=encoder_embed_trainable, 357 | name='Token-Embedding', 358 | ) 359 | else: 360 | encoder_embed_layer = EmbeddingRet( 361 | input_dim=encoder_token_num, 362 | output_dim=embed_dim, 363 | mask_zero=True, 364 | weights=encoder_embed_weights, 365 | trainable=encoder_embed_trainable, 366 | name='Encoder-Token-Embedding', 367 | ) 368 | decoder_embed_layer = EmbeddingRet( 369 | input_dim=decoder_token_num, 370 | output_dim=embed_dim, 371 | mask_zero=True, 372 | weights=decoder_embed_weights, 373 | trainable=decoder_embed_trainable, 374 | name='Decoder-Token-Embedding', 375 | ) 376 | encoder_input = keras.layers.Input(shape=(None,), name='Encoder-Input') 377 | encoder_embed = TrigPosEmbedding( 378 | mode=TrigPosEmbedding.MODE_ADD, 379 | name='Encoder-Embedding', 380 | )(encoder_embed_layer(encoder_input)[0]) 381 | encoded_layer = get_encoders( 382 | encoder_num=encoder_num, 383 | input_layer=encoder_embed, 384 | head_num=head_num, 385 | hidden_dim=hidden_dim, 386 | attention_activation=attention_activation, 387 | feed_forward_activation=feed_forward_activation, 388 | dropout_rate=dropout_rate, 389 | trainable=trainable, 390 | ) 391 | decoder_input = keras.layers.Input(shape=(None,), name='Decoder-Input') 392 | decoder_embed, decoder_embed_weights = decoder_embed_layer(decoder_input) 393 | decoder_embed = TrigPosEmbedding( 394 | mode=TrigPosEmbedding.MODE_ADD, 395 | name='Decoder-Embedding', 396 | )(decoder_embed) 397 | decoded_layer = get_decoders( 398 | decoder_num=decoder_num, 399 | input_layer=decoder_embed, 400 | encoded_layer=encoded_layer, 401 | head_num=head_num, 402 | hidden_dim=hidden_dim, 403 | attention_activation=attention_activation, 404 | feed_forward_activation=feed_forward_activation, 405 | dropout_rate=dropout_rate, 406 | trainable=trainable, 407 | ) 408 | output_layer = EmbeddingSim( 409 | trainable=trainable, 410 | name='Decoder-Output', 411 | )([decoded_layer, decoder_embed_weights]) 412 | return keras.models.Model(inputs=[encoder_input, decoder_input], outputs=output_layer) 413 | 414 | 415 | def _get_max_suffix_repeat_times(tokens, max_len): 416 | detect_len = min(max_len, len(tokens)) 417 | next = [-1] * detect_len 418 | k = -1 419 | for i in range(1, detect_len): 420 | while k >= 0 and tokens[len(tokens) - i - 1] != tokens[len(tokens) - k - 2]: 421 | k = next[k] 422 | if tokens[len(tokens) - i - 1] == tokens[len(tokens) - k - 2]: 423 | k += 1 424 | 
next[i] = k 425 | max_repeat = 1 426 | for i in range(2, detect_len): 427 | if next[i] >= 0 and (i + 1) % (i - next[i]) == 0: 428 | max_repeat = max(max_repeat, (i + 1) // (i - next[i])) 429 | return max_repeat 430 | 431 | 432 | def decode(model, 433 | tokens, 434 | start_token, 435 | end_token, 436 | pad_token, 437 | top_k=1, 438 | temperature=1.0, 439 | max_len=10000, 440 | max_repeat=10, 441 | max_repeat_block=10): 442 | """Decode with the given model and input tokens. 443 | 444 | :param model: The trained model. 445 | :param tokens: The input tokens of encoder. 446 | :param start_token: The token that represents the start of a sentence. 447 | :param end_token: The token that represents the end of a sentence. 448 | :param pad_token: The token that represents padding. 449 | :param top_k: Choose the last token from top K. 450 | :param temperature: Randomness in boltzmann distribution. 451 | :param max_len: Maximum length of decoded list. 452 | :param max_repeat: Maximum number of repeating blocks. 453 | :param max_repeat_block: Maximum length of the repeating block. 454 | :return: Decoded tokens. 455 | """ 456 | is_single = not isinstance(tokens[0], list) 457 | if is_single: 458 | tokens = [tokens] 459 | batch_size = len(tokens) 460 | decoder_inputs = [[start_token] for _ in range(batch_size)] 461 | outputs = [None for _ in range(batch_size)] 462 | output_len = 1 463 | while len(list(filter(lambda x: x is None, outputs))) > 0: 464 | output_len += 1 465 | batch_inputs, batch_outputs = [], [] 466 | max_input_len = 0 467 | index_map = {} 468 | for i in range(batch_size): 469 | if outputs[i] is None: 470 | index_map[len(batch_inputs)] = i 471 | batch_inputs.append(tokens[i][:]) 472 | batch_outputs.append(decoder_inputs[i]) 473 | max_input_len = max(max_input_len, len(tokens[i])) 474 | for i in range(len(batch_inputs)): 475 | batch_inputs[i] += [pad_token] * (max_input_len - len(batch_inputs[i])) 476 | predicts = model.predict([np.array(batch_inputs), np.array(batch_outputs)]) 477 | for i in range(len(predicts)): 478 | if top_k == 1: 479 | last_token = predicts[i][-1].argmax(axis=-1) 480 | else: 481 | probs = [(prob, j) for j, prob in enumerate(predicts[i][-1])] 482 | probs.sort(reverse=True) 483 | probs = probs[:top_k] 484 | indices, probs = list(map(lambda x: x[1], probs)), list(map(lambda x: x[0], probs)) 485 | probs = np.array(probs) / temperature 486 | probs = probs - np.max(probs) 487 | probs = np.exp(probs) 488 | probs = probs / np.sum(probs) 489 | last_token = np.random.choice(indices, p=probs) 490 | decoder_inputs[index_map[i]].append(last_token) 491 | if last_token == end_token or\ 492 | (max_len is not None and output_len >= max_len) or\ 493 | _get_max_suffix_repeat_times(decoder_inputs[index_map[i]], 494 | max_repeat * max_repeat_block) >= max_repeat: 495 | outputs[index_map[i]] = decoder_inputs[index_map[i]] 496 | if is_single: 497 | outputs = outputs[0] 498 | return outputs 499 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf dist/* && python3 setup.py sdist && twine upload dist/* 3 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools>=38.6.0 2 | twine>=1.11.0 3 | wheel>=0.31.0 4 | nose 5 | tensorflow 6 | pycodestyle 7 | coverage 8 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | keras-pos-embd==0.13.0 2 | keras-multi-head==0.29.0 3 | keras-layer-normalization==0.16.0 4 | keras-position-wise-feed-forward==0.8.0 5 | keras-embed-sim==0.10.0 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | from setuptools import setup, find_packages 5 | 6 | current_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read_file(*parts): 10 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 11 | return reader.read() 12 | 13 | 14 | def get_requirements(*parts): 15 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 16 | return list(map(lambda x: x.strip(), reader.readlines())) 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read_file(*file_paths) 21 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError('Unable to find version string.') 25 | 26 | 27 | setup( 28 | name='keras-transformer', 29 | version=find_version('keras_transformer', '__init__.py'), 30 | packages=find_packages(), 31 | url='https://github.com/CyberZHG/keras-transformer', 32 | license='MIT', 33 | author='CyberZHG', 34 | author_email='CyberZHG@users.noreply.github.com', 35 | description='Transformer implemented in Keras', 36 | long_description=read_file('README.md'), 37 | long_description_content_type='text/markdown', 38 | install_requires=get_requirements('requirements.txt'), 39 | classifiers=( 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: MIT License", 42 | "Operating System :: OS Independent", 43 | ), 44 | ) 45 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | pycodestyle --max-line-length=120 keras_transformer tests && \ 3 | nosetests --with-coverage --cover-erase --cover-html --cover-html-dir=htmlcov --cover-package=keras_transformer tests -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-transformer/3edee5a22027298dd7e4c5b3b52c31d07ca490fb/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_decode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | from tensorflow import keras 6 | 7 | from keras_transformer import get_custom_objects, get_model, decode 8 | 9 | 10 | class TestDecode(unittest.TestCase): 11 | 12 | def test_decode(self): 13 | tokens = 'all work and no play makes jack a dull boy'.split(' ') 14 | token_dict = { 15 | '': 0, 16 | '': 1, 17 | '': 2, 18 | } 19 | for token in tokens: 20 | if token not in token_dict: 21 | token_dict[token] = len(token_dict) 22 | model = get_model( 23 | token_num=len(token_dict), 24 | embed_dim=32, 25 | encoder_num=3, 26 | decoder_num=2, 27 | head_num=4, 28 | hidden_dim=128, 29 | dropout_rate=0.05, 30 | ) 31 | 
model.compile( 32 | optimizer='adam', 33 | loss='sparse_categorical_crossentropy', 34 | ) 35 | model.summary() 36 | encoder_inputs_no_padding = [] 37 | encoder_inputs, decoder_inputs, decoder_outputs = [], [], [] 38 | for i in range(1, len(tokens)): 39 | encode_tokens, decode_tokens = tokens[:i], tokens[i:] 40 | encode_tokens = [''] + encode_tokens + [''] + [''] * (len(tokens) - len(encode_tokens)) 41 | output_tokens = decode_tokens + ['', ''] + [''] * (len(tokens) - len(decode_tokens)) 42 | decode_tokens = [''] + decode_tokens + [''] + [''] * (len(tokens) - len(decode_tokens)) 43 | encode_tokens = list(map(lambda x: token_dict[x], encode_tokens)) 44 | decode_tokens = list(map(lambda x: token_dict[x], decode_tokens)) 45 | output_tokens = list(map(lambda x: [token_dict[x]], output_tokens)) 46 | encoder_inputs_no_padding.append(encode_tokens[:i + 2]) 47 | encoder_inputs.append(encode_tokens) 48 | decoder_inputs.append(decode_tokens) 49 | decoder_outputs.append(output_tokens) 50 | current_path = os.path.dirname(os.path.abspath(__file__)) 51 | model_path = os.path.join(current_path, 'test_transformer.h5') 52 | if os.path.exists(model_path): 53 | model.load_weights(model_path, by_name=True) 54 | else: 55 | model.fit( 56 | x=[np.asarray(encoder_inputs * 2048), np.asarray(decoder_inputs * 2048)], 57 | y=np.asarray(decoder_outputs * 2048), 58 | epochs=10, 59 | batch_size=128, 60 | ) 61 | model.save(model_path) 62 | model = keras.models.load_model(model_path, custom_objects=get_custom_objects()) 63 | decoded = decode( 64 | model, 65 | encoder_inputs_no_padding * 2, 66 | start_token=token_dict[''], 67 | end_token=token_dict[''], 68 | pad_token=token_dict[''], 69 | ) 70 | token_dict_rev = {v: k for k, v in token_dict.items()} 71 | for i in range(len(decoded)): 72 | print(' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1]))) 73 | for i in range(len(decoded)): 74 | for j in range(len(decoded[i])): 75 | self.assertEqual(decoder_inputs[i % len(decoder_inputs)][j], decoded[i][j]) 76 | 77 | decoded = decode( 78 | model, 79 | encoder_inputs_no_padding[2] + [0] * 5, 80 | start_token=token_dict[''], 81 | end_token=token_dict[''], 82 | pad_token=token_dict[''], 83 | ) 84 | for j in range(len(decoded)): 85 | self.assertEqual(decoder_inputs[2][j], decoded[j], decoded) 86 | 87 | decoded = decode( 88 | model, 89 | encoder_inputs_no_padding, 90 | start_token=token_dict[''], 91 | end_token=token_dict[''], 92 | pad_token=token_dict[''], 93 | max_len=4, 94 | ) 95 | token_dict_rev = {v: k for k, v in token_dict.items()} 96 | for i in range(len(decoded)): 97 | print(' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1]))) 98 | for i in range(len(decoded)): 99 | self.assertTrue(len(decoded[i]) <= 4, decoded[i]) 100 | for j in range(len(decoded[i])): 101 | self.assertEqual(decoder_inputs[i][j], decoded[i][j], decoded) 102 | 103 | decoded_top_5 = decode( 104 | model, 105 | encoder_inputs_no_padding, 106 | start_token=token_dict[''], 107 | end_token=token_dict[''], 108 | pad_token=token_dict[''], 109 | max_len=4, 110 | top_k=5, 111 | temperature=1e-10, 112 | ) 113 | has_diff = False 114 | for i in range(len(decoded)): 115 | s1 = ' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])) 116 | s5 = ' '.join(map(lambda x: token_dict_rev[x], decoded_top_5[i][1:-1])) 117 | if s1 != s5: 118 | has_diff = True 119 | self.assertFalse(has_diff) 120 | 121 | decoded_top_5 = decode( 122 | model, 123 | encoder_inputs_no_padding, 124 | start_token=token_dict[''], 125 | end_token=token_dict[''], 126 | 
pad_token=token_dict[''], 127 | max_len=4, 128 | top_k=5, 129 | ) 130 | has_diff = False 131 | for i in range(len(decoded)): 132 | s1 = ' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])) 133 | s5 = ' '.join(map(lambda x: token_dict_rev[x], decoded_top_5[i][1:-1])) 134 | if s1 != s5: 135 | has_diff = True 136 | self.assertTrue(has_diff) 137 | -------------------------------------------------------------------------------- /tests/test_gelu.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow.keras import backend as K 4 | 5 | from keras_transformer import gelu 6 | 7 | 8 | class TestGelu(unittest.TestCase): 9 | 10 | def test_sample(self): 11 | try: 12 | results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval(session=K.get_session()) 13 | except Exception as e: 14 | try: 15 | results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval() 16 | except Exception as e: 17 | results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).numpy() 18 | self.assertEqual(0.0, results[0]) 19 | self.assertGreater(0.0, results[1]) 20 | self.assertLess(-1.0, results[1]) 21 | self.assertEqual(0.0, results[2]) 22 | self.assertGreater(1.0, results[3]) 23 | self.assertLess(0.0, results[3]) 24 | self.assertEqual(30.0, results[4]) 25 | -------------------------------------------------------------------------------- /tests/test_get_decoders.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow import keras 4 | 5 | from keras_transformer import get_encoders, get_decoders 6 | 7 | 8 | class TestGetDecoderComponent(unittest.TestCase): 9 | 10 | def test_sample(self): 11 | encoder_input_layer = keras.layers.Input(shape=(512, 768), name='Encoder-Input') 12 | decoder_input_layer = keras.layers.Input(shape=(512, 768), name='Decoder-Input') 13 | encoded_layer = get_encoders( 14 | encoder_num=2, 15 | input_layer=encoder_input_layer, 16 | head_num=12, 17 | hidden_dim=3072, 18 | dropout_rate=0.0, 19 | ) 20 | output_layer = get_decoders( 21 | decoder_num=2, 22 | input_layer=decoder_input_layer, 23 | encoded_layer=encoded_layer, 24 | head_num=12, 25 | hidden_dim=3072, 26 | dropout_rate=0.0, 27 | ) 28 | model = keras.models.Model(inputs=[encoder_input_layer, decoder_input_layer], outputs=output_layer) 29 | model.compile(optimizer='adam', loss='mse', metrics={}) 30 | model.summary(line_length=160) 31 | 32 | output_layer = get_decoders( 33 | decoder_num=2, 34 | input_layer=decoder_input_layer, 35 | encoded_layer=encoded_layer, 36 | head_num=12, 37 | hidden_dim=3072, 38 | dropout_rate=0.1, 39 | ) 40 | model = keras.models.Model(inputs=[encoder_input_layer, decoder_input_layer], outputs=output_layer) 41 | model.compile(optimizer='adam', loss='mse', metrics={}) 42 | model.summary(line_length=160) 43 | self.assertIsNotNone(model) 44 | -------------------------------------------------------------------------------- /tests/test_get_encoders.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow import keras 4 | 5 | from keras_transformer import get_encoders 6 | 7 | 8 | class TestGetEncoderComponent(unittest.TestCase): 9 | 10 | def test_sample(self): 11 | input_layer = keras.layers.Input(shape=(512, 768), name='Input') 12 | output_layer = get_encoders( 13 | encoder_num=2, 14 | input_layer=input_layer, 15 | head_num=12, 16 | hidden_dim=3072, 17 | dropout_rate=0.0, 18 | ) 19 | model = 
keras.models.Model(inputs=input_layer, outputs=output_layer) 20 | model.compile(optimizer='adam', loss='mse', metrics={}) 21 | model.summary(line_length=160) 22 | 23 | output_layer = get_encoders( 24 | encoder_num=2, 25 | input_layer=input_layer, 26 | head_num=12, 27 | hidden_dim=3072, 28 | dropout_rate=0.1, 29 | ) 30 | model = keras.models.Model(inputs=input_layer, outputs=output_layer) 31 | model.compile(optimizer='adam', loss='mse', metrics={}) 32 | model.summary(line_length=160) 33 | self.assertIsNotNone(model) 34 | -------------------------------------------------------------------------------- /tests/test_get_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_transformer import get_custom_objects, get_model 9 | 10 | 11 | class TestGetModel(unittest.TestCase): 12 | 13 | def test_get_same(self): 14 | model = get_model( 15 | token_num=13, 16 | embed_dim=30, 17 | encoder_num=3, 18 | decoder_num=2, 19 | head_num=3, 20 | hidden_dim=120, 21 | attention_activation=None, 22 | feed_forward_activation='relu', 23 | dropout_rate=0.05, 24 | use_same_embed=True, 25 | embed_weights=np.random.random((13, 30)), 26 | trainable=False, 27 | ) 28 | model.compile( 29 | optimizer=keras.optimizers.Adam(), 30 | loss=keras.losses.categorical_crossentropy, 31 | metrics={}, 32 | ) 33 | model_path = os.path.join(tempfile.gettempdir(), 'test_transformer_%f.h5' % np.random.random()) 34 | model.save(model_path) 35 | model = keras.models.load_model(model_path, custom_objects=get_custom_objects()) 36 | model.summary() 37 | try: 38 | keras.utils.plot_model(model, 'transformer_same.png') 39 | except Exception as e: 40 | print(e) 41 | self.assertIsNotNone(model) 42 | 43 | def test_get_diff(self): 44 | model = get_model( 45 | token_num=[13, 14], 46 | embed_dim=30, 47 | encoder_num=3, 48 | decoder_num=2, 49 | head_num=3, 50 | hidden_dim=120, 51 | attention_activation=None, 52 | feed_forward_activation='relu', 53 | dropout_rate=0.05, 54 | use_same_embed=False, 55 | ) 56 | model.compile( 57 | optimizer=keras.optimizers.Adam(), 58 | loss=keras.losses.categorical_crossentropy, 59 | metrics={}, 60 | ) 61 | model_path = os.path.join(tempfile.gettempdir(), 'test_transformer_%f.h5' % np.random.random()) 62 | model.save(model_path) 63 | model = keras.models.load_model(model_path, custom_objects=get_custom_objects()) 64 | model.summary() 65 | try: 66 | keras.utils.plot_model(model, 'transformer_diff.png') 67 | except Exception as e: 68 | print(e) 69 | self.assertIsNotNone(model) 70 | -------------------------------------------------------------------------------- /tests/test_suffix_repeat.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from keras_transformer.transformer import _get_max_suffix_repeat_times 3 | 4 | 5 | class TestSuffixRepeat(TestCase): 6 | 7 | def test_abcd(self): 8 | self.assertEqual(1, _get_max_suffix_repeat_times('abcdabcdabcd', max_len=3)) 9 | self.assertEqual(1, _get_max_suffix_repeat_times('abcdabcdabcd', max_len=6)) 10 | self.assertEqual(2, _get_max_suffix_repeat_times('abcdabcdabcd', max_len=11)) 11 | self.assertEqual(3, _get_max_suffix_repeat_times('abcdabcdabcd', max_len=12)) 12 | self.assertEqual(3, _get_max_suffix_repeat_times('abcdabcdabcd', max_len=16)) 13 | self.assertEqual(2, _get_max_suffix_repeat_times('bcdabcdabcd', max_len=16)) 14 | 
-------------------------------------------------------------------------------- /tests/test_translate.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from __future__ import unicode_literals 3 | 4 | import unittest 5 | import numpy as np 6 | from keras_transformer import get_model, decode 7 | 8 | 9 | class TestTranslate(unittest.TestCase): 10 | 11 | @staticmethod 12 | def _build_token_dict(token_list): 13 | token_dict = { 14 | '': 0, 15 | '': 1, 16 | '': 2, 17 | } 18 | for tokens in token_list: 19 | for token in tokens: 20 | if token not in token_dict: 21 | token_dict[token] = len(token_dict) 22 | return token_dict 23 | 24 | def test_translate(self): 25 | source_tokens = [ 26 | 'i need more power'.split(' '), 27 | 'eat jujube and pill'.split(' '), 28 | ] 29 | target_tokens = [ 30 | list('我要更多的抛瓦'), 31 | list('吃枣💊'), 32 | ] 33 | 34 | # Generate dictionaries 35 | source_token_dict = self._build_token_dict(source_tokens) 36 | target_token_dict = self._build_token_dict(target_tokens) 37 | target_token_dict_inv = {v: k for k, v in target_token_dict.items()} 38 | 39 | # Add special tokens 40 | encode_tokens = [[''] + tokens + [''] for tokens in source_tokens] 41 | decode_tokens = [[''] + tokens + [''] for tokens in target_tokens] 42 | output_tokens = [tokens + ['', ''] for tokens in target_tokens] 43 | 44 | # Padding 45 | source_max_len = max(map(len, encode_tokens)) 46 | target_max_len = max(map(len, decode_tokens)) 47 | 48 | encode_tokens = [tokens + [''] * (source_max_len - len(tokens)) for tokens in encode_tokens] 49 | decode_tokens = [tokens + [''] * (target_max_len - len(tokens)) for tokens in decode_tokens] 50 | output_tokens = [tokens + [''] * (target_max_len - len(tokens)) for tokens in output_tokens] 51 | 52 | encode_input = [list(map(lambda x: source_token_dict[x], tokens)) for tokens in encode_tokens] 53 | decode_input = [list(map(lambda x: target_token_dict[x], tokens)) for tokens in decode_tokens] 54 | decode_output = [list(map(lambda x: [target_token_dict[x]], tokens)) for tokens in output_tokens] 55 | 56 | # Build & fit model 57 | model = get_model( 58 | token_num=max(len(source_token_dict), len(target_token_dict)), 59 | embed_dim=32, 60 | encoder_num=2, 61 | decoder_num=2, 62 | head_num=4, 63 | hidden_dim=128, 64 | dropout_rate=0.05, 65 | use_same_embed=False, # Use different embeddings for different languages 66 | ) 67 | model.compile('adam', 'sparse_categorical_crossentropy') 68 | model.summary() 69 | model.fit( 70 | x=[np.array(encode_input * 1024), np.array(decode_input * 1024)], 71 | y=np.array(decode_output * 1024), 72 | epochs=10, 73 | batch_size=32, 74 | ) 75 | 76 | # Predict 77 | decoded = decode( 78 | model, 79 | encode_input, 80 | start_token=target_token_dict[''], 81 | end_token=target_token_dict[''], 82 | pad_token=target_token_dict[''], 83 | ) 84 | for i in range(len(encode_input)): 85 | predicted = ''.join(map(lambda x: target_token_dict_inv[x], decoded[i][1:-1])) 86 | self.assertEqual(''.join(target_tokens[i]), predicted) 87 | --------------------------------------------------------------------------------