├── tests ├── __init__.py ├── datasets │ ├── __init__.py │ └── test_get_pretrained.py ├── layers │ ├── __init__.py │ ├── test_inputs.py │ ├── test_task_embed.py │ ├── test_extract.py │ ├── test_embedding.py │ ├── test_conv.py │ ├── test_pooling.py │ └── test_masked.py ├── optimizers │ ├── __init__.py │ └── test_warmup.py ├── test_bert_fit.h5 ├── test_checkpoint │ ├── bert_model.ckpt.index │ ├── bert_model.ckpt.meta │ ├── bert_model.ckpt.data-00000-of-00001 │ ├── vocab.txt │ └── bert_config.json ├── test_loader.py ├── test_tokenizer.py ├── test_util.py └── test_bert.py ├── demo ├── tune │ ├── __init__.py │ └── keras_bert_classification_tpu.ipynb ├── load_model │ ├── __init__.py │ ├── load_and_pool.py │ ├── load_and_get_attention_map.py │ ├── load_and_predict.py │ ├── load_and_extract.py │ ├── keras_bert_load_and_extract_tpu.ipynb │ └── keras_bert_load_and_extract.ipynb └── visualization │ ├── __init__.py │ └── vis.py ├── keras_bert ├── optimizers │ ├── __init__.py │ ├── util.py │ └── warmup_v2.py ├── datasets │ ├── __init__.py │ └── pretrained.py ├── __init__.py ├── layers │ ├── __init__.py │ ├── inputs.py │ ├── pooling.py │ ├── conv.py │ ├── extract.py │ ├── masked.py │ ├── task_embed.py │ └── embedding.py ├── util.py ├── loader.py ├── tokenizer.py └── bert.py ├── requirements.txt ├── .github ├── stale.yml └── ISSUE_TEMPLATE │ ├── question.md │ ├── bug_report.md │ └── feature_request.md ├── publish.sh ├── MANIFEST.in ├── requirements-dev.txt ├── test.sh ├── LICENSE ├── CHANGELOG.md ├── setup.py ├── .gitignore ├── README.md └── README.zh-CN.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/tune/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/load_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_bert/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | keras-transformer==0.40.0 3 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | daysUntilStale: 5 2 | daysUntilClose: 2 3 | 
-------------------------------------------------------------------------------- /keras_bert/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .pretrained import * 2 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf dist/* && python3 setup.py sdist && twine upload dist/* 3 | -------------------------------------------------------------------------------- /tests/test_bert_fit.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-bert/HEAD/tests/test_bert_fit.h5 -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include README-zh-CN.md 3 | include CHANGELOG.md 4 | include requirements.txt 5 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools>=38.6.0 2 | twine>=1.11.0 3 | wheel>=0.31.0 4 | tensorflow 5 | nose 6 | pycodestyle 7 | coverage -------------------------------------------------------------------------------- /tests/test_checkpoint/bert_model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-bert/HEAD/tests/test_checkpoint/bert_model.ckpt.index -------------------------------------------------------------------------------- /tests/test_checkpoint/bert_model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-bert/HEAD/tests/test_checkpoint/bert_model.ckpt.meta -------------------------------------------------------------------------------- /tests/test_checkpoint/bert_model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyberZHG/keras-bert/HEAD/tests/test_checkpoint/bert_model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /tests/test_checkpoint/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | [CLS] 4 | [SEP] 5 | [MASK] 6 | all 7 | work 8 | and 9 | no 10 | play 11 | makes 12 | jack 13 | a 14 | dull 15 | boy -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask questions about the repo 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /keras_bert/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import * 2 | from .loader import * 3 | from .tokenizer import Tokenizer 4 | from .optimizers import * 5 | from .util import * 6 | from .datasets import * 7 | 8 | __version__ = '0.89.0' 9 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 
| #!/usr/bin/env bash 2 | pycodestyle --max-line-length=120 keras_bert tests && \ 3 | nosetests --nocapture --with-coverage --cover-erase --cover-html --cover-html-dir=htmlcov --cover-package=keras_bert --with-doctest -------------------------------------------------------------------------------- /keras_bert/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .inputs import get_inputs 2 | from .embedding import get_embedding, TokenEmbedding, EmbeddingSimilarity 3 | from .masked import Masked 4 | from .extract import Extract 5 | from .pooling import MaskedGlobalMaxPool1D 6 | from .conv import MaskedConv1D 7 | from .task_embed import TaskEmbedding 8 | -------------------------------------------------------------------------------- /tests/layers/test_inputs.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from keras_bert.layers import get_inputs 3 | 4 | 5 | class TestInputs(unittest.TestCase): 6 | 7 | def test_name(self): 8 | inputs = get_inputs(seq_len=512) 9 | self.assertEqual(3, len(inputs)) 10 | self.assertTrue('Segment' in inputs[1].name) 11 | -------------------------------------------------------------------------------- /tests/datasets/test_get_pretrained.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import TestCase 3 | from keras_bert.datasets import get_pretrained, PretrainedList 4 | 5 | 6 | class TestGetPretrained(TestCase): 7 | 8 | def test_get_pretrained(self): 9 | path = get_pretrained(PretrainedList.__test__) 10 | self.assertTrue(os.path.exists(os.path.join(path, 'README.md'))) 11 | -------------------------------------------------------------------------------- /keras_bert/layers/inputs.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | 4 | def get_inputs(seq_len): 5 | """Get input layers. 6 | 7 | See: https://arxiv.org/pdf/1810.04805.pdf 8 | 9 | :param seq_len: Length of the sequence or None. 10 | """ 11 | names = ['Token', 'Segment', 'Masked'] 12 | return [keras.layers.Input( 13 | shape=(seq_len,), 14 | name='Input-%s' % name, 15 | ) for name in names] 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: CyberZHG 7 | 8 | --- 9 | 10 | **Describe the Bug** 11 | 12 | A clear and concise description of what the bug is. 
13 | 14 | **Version Info** 15 | 16 | * [ ] I'm using the latest version 17 | 18 | **Minimal Codes To Reproduce** 19 | 20 | ```python 21 | import keras_bert 22 | 23 | pass 24 | ``` 25 | -------------------------------------------------------------------------------- /tests/test_checkpoint/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 4, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 16, 9 | "max_position_embeddings": 16, 10 | "num_attention_heads": 4, 11 | "num_hidden_layers": 2, 12 | "pooler_fc_size": 4, 13 | "pooler_num_attention_heads": 4, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 16, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 15 19 | } 20 | -------------------------------------------------------------------------------- /keras_bert/layers/pooling.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class MaskedGlobalMaxPool1D(keras.layers.Layer): 6 | 7 | def __init__(self, **kwargs): 8 | super(MaskedGlobalMaxPool1D, self).__init__(**kwargs) 9 | self.supports_masking = True 10 | 11 | def compute_mask(self, inputs, mask=None): 12 | return None 13 | 14 | def call(self, inputs, mask=None): 15 | if mask is not None: 16 | mask = K.cast(mask, K.floatx()) 17 | inputs -= K.expand_dims((1.0 - mask) * 1e6, axis=-1) 18 | return K.max(inputs, axis=-2) 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: CyberZHG 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /keras_bert/layers/conv.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class MaskedConv1D(keras.layers.Conv1D): 6 | 7 | def __init__(self, **kwargs): 8 | super(MaskedConv1D, self).__init__(**kwargs) 9 | self.supports_masking = True 10 | 11 | def compute_mask(self, inputs, mask=None): 12 | if mask is not None and self.padding == 'valid': 13 | mask = mask[:, self.kernel_size[0] // 2 * self.dilation_rate[0] * 2:] 14 | return mask 15 | 16 | def call(self, inputs, mask=None): 17 | if mask is not None: 18 | mask = K.cast(mask, K.floatx()) 19 | inputs *= K.expand_dims(mask, axis=-1) 20 | return super(MaskedConv1D, self).call(inputs) 21 | -------------------------------------------------------------------------------- /keras_bert/layers/extract.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | 4 | class Extract(keras.layers.Layer): 5 | """Extract from index. 6 | 7 | See: https://arxiv.org/pdf/1810.04805.pdf 8 | """ 9 | 10 | def __init__(self, index, **kwargs): 11 | super(Extract, self).__init__(**kwargs) 12 | self.index = index 13 | self.supports_masking = True 14 | 15 | def get_config(self): 16 | config = { 17 | 'index': self.index, 18 | } 19 | base_config = super(Extract, self).get_config() 20 | return dict(list(base_config.items()) + list(config.items())) 21 | 22 | def compute_mask(self, inputs, mask=None): 23 | return None 24 | 25 | def call(self, x, mask=None): 26 | return x[:, self.index] 27 | -------------------------------------------------------------------------------- /keras_bert/optimizers/util.py: -------------------------------------------------------------------------------- 1 | from .warmup_v2 import AdamWarmup 2 | 3 | __all__ = ['AdamWarmup', 'calc_train_steps'] 4 | 5 | 6 | def calc_train_steps(num_example, batch_size, epochs, warmup_proportion=0.1): 7 | """Calculate the number of total and warmup steps. 8 | 9 | >>> calc_train_steps(num_example=1024, batch_size=32, epochs=10, warmup_proportion=0.1) 10 | (320, 32) 11 | 12 | :param num_example: Number of examples in one epoch. 13 | :param batch_size: Batch size. 14 | :param epochs: Number of epochs. 15 | :param warmup_proportion: The proportion of warmup steps. 16 | :return: Total steps and warmup steps. 
17 | """ 18 | steps = (num_example + batch_size - 1) // batch_size 19 | total = steps * epochs 20 | warmup = int(total * warmup_proportion) 21 | return total, warmup 22 | -------------------------------------------------------------------------------- /tests/layers/test_task_embed.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | from tensorflow.keras import backend as K 6 | 7 | from keras_bert.layers import TaskEmbedding 8 | 9 | 10 | class TestTaskEmbedding(unittest.TestCase): 11 | 12 | def test_mask_zero(self): 13 | embed_input = keras.layers.Input(shape=(5, 4)) 14 | task_input = keras.layers.Input(shape=(1,)) 15 | task_embed = TaskEmbedding(input_dim=2, output_dim=4, mask_zero=True)([embed_input, task_input]) 16 | func = K.function([embed_input, task_input], [task_embed]) 17 | embed, task = np.random.random((2, 5, 4)), np.array([[0], [1]]) 18 | output = func([embed, task])[0] 19 | self.assertTrue(np.allclose(embed[0], output[0])) 20 | self.assertFalse(np.allclose(embed[1], output[1])) 21 | -------------------------------------------------------------------------------- /tests/layers/test_extract.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | from keras_bert.layers import Extract 7 | 8 | 9 | class TestExtract(unittest.TestCase): 10 | 11 | def test_sample(self): 12 | input_layer = keras.layers.Input( 13 | shape=(3, 4), 14 | name='Input', 15 | ) 16 | extract_layer = Extract( 17 | index=0, 18 | name='Extract' 19 | )(input_layer) 20 | model = keras.models.Model( 21 | inputs=input_layer, 22 | outputs=extract_layer, 23 | ) 24 | model.compile( 25 | optimizer='adam', 26 | loss='mse', 27 | metrics={}, 28 | ) 29 | model.summary() 30 | inputs = np.asarray([[ 31 | [0.1, 0.2, 0.3, 0.4], 32 | [-0.1, 0.2, -0.3, 0.4], 33 | [0.1, -0.2, 0.3, -0.4], 34 | ]]) 35 | predict = model.predict(inputs) 36 | expected = np.asarray([[0.1, 0.2, 0.3, 0.4]]) 37 | self.assertTrue(np.allclose(expected, predict), predict) 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Zhao HG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/layers/test_embedding.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tensorflow import keras 4 | 5 | from keras_bert.layers import get_inputs, get_embedding 6 | 7 | 8 | class TestEmbedding(unittest.TestCase): 9 | 10 | def test_sample(self): 11 | inputs = get_inputs(seq_len=512) 12 | embed_layer = get_embedding(inputs, token_num=12, embed_dim=768, pos_num=512) 13 | model = keras.models.Model(inputs=inputs, outputs=embed_layer) 14 | model.compile( 15 | optimizer='adam', 16 | loss='mse', 17 | metrics={}, 18 | ) 19 | model.summary() 20 | self.assertEqual((None, 512, 768), model.layers[-1].output_shape) 21 | 22 | def test_no_dropout(self): 23 | inputs = get_inputs(seq_len=512) 24 | embed_layer = get_embedding(inputs, token_num=12, embed_dim=768, pos_num=512, dropout_rate=0.0) 25 | model = keras.models.Model(inputs=inputs, outputs=embed_layer) 26 | model.compile( 27 | optimizer='adam', 28 | loss='mse', 29 | metrics={}, 30 | ) 31 | model.summary() 32 | self.assertEqual((None, 512, 768), model.layers[-1].output_shape) 33 | -------------------------------------------------------------------------------- /tests/layers/test_conv.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | from keras_bert.layers import MaskedConv1D, MaskedGlobalMaxPool1D 7 | 8 | 9 | class TestConv(TestCase): 10 | 11 | def test_masked_conv_1d_fit(self): 12 | input_layer = keras.layers.Input(shape=(7,)) 13 | embed_layer = keras.layers.Embedding( 14 | input_dim=11, 15 | output_dim=13, 16 | mask_zero=True, 17 | )(input_layer) 18 | conv_layer = MaskedConv1D(filters=7, kernel_size=3)(embed_layer) 19 | pool_layer = MaskedGlobalMaxPool1D()(conv_layer) 20 | dense_layer = keras.layers.Dense(units=2, activation='softmax')(pool_layer) 21 | model = keras.models.Model(inputs=input_layer, outputs=dense_layer) 22 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 23 | model.summary() 24 | x = np.array(np.random.randint(0, 11, (32, 7)).tolist() * 100) 25 | y = np.array(np.random.randint(0, 2, (32,)).tolist() * 100) 26 | model.fit(x, y, epochs=10) 27 | y_hat = model.predict(x).argmax(axis=-1) 28 | self.assertEqual(y.tolist(), y_hat.tolist()) 29 | -------------------------------------------------------------------------------- /demo/load_model/load_and_pool.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import keras 4 | from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths 5 | from keras_bert.layers import MaskedGlobalMaxPool1D 6 | 7 | print('This demo demonstrates how to load the pre-trained model and extract the sentence embedding with pooling.') 8 | 9 | if len(sys.argv) == 2: 10 | model_path = sys.argv[1] 11 | else: 12 | from keras_bert.datasets import get_pretrained, PretrainedList 13 | model_path = get_pretrained(PretrainedList.chinese_base) 14 | 15 | paths = get_checkpoint_paths(model_path) 16 | 17 | model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10) 18 | pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output) 19 | model = keras.models.Model(inputs=model.inputs, outputs=pool_layer) 20 | model.summary(line_length=120) 21 | 22 | token_dict = 
load_vocabulary(paths.vocab) 23 | 24 | tokenizer = Tokenizer(token_dict) 25 | text = '语言模型' 26 | tokens = tokenizer.tokenize(text) 27 | print('Tokens:', tokens) 28 | indices, segments = tokenizer.encode(first=text, max_len=10) 29 | 30 | predicts = model.predict([np.array([indices]), np.array([segments])])[0] 31 | print('Pooled:', predicts.tolist()[:5]) 32 | -------------------------------------------------------------------------------- /demo/load_model/load_and_get_attention_map.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths 4 | from keras_bert.backend import backend as K 5 | 6 | print('This demo demonstrates how to load the pre-trained model and extract the attention map') 7 | 8 | if len(sys.argv) == 2: 9 | model_path = sys.argv[1] 10 | else: 11 | from keras_bert.datasets import get_pretrained, PretrainedList 12 | model_path = get_pretrained(PretrainedList.chinese_base) 13 | 14 | paths = get_checkpoint_paths(model_path) 15 | 16 | model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10) 17 | attention_layer = model.get_layer('Encoder-1-MultiHeadSelfAttention') 18 | model = K.function(model.inputs, attention_layer.attention) 19 | 20 | token_dict = load_vocabulary(paths.vocab) 21 | 22 | tokenizer = Tokenizer(token_dict) 23 | text = '语言模型' 24 | tokens = tokenizer.tokenize(text) 25 | print('Tokens:', tokens) 26 | indices, segments = tokenizer.encode(first=text, max_len=10) 27 | 28 | predicts = model([np.array([indices]), np.array([segments])])[0] 29 | for i, token in enumerate(tokens): 30 | print(token) 31 | for head_index in range(12): 32 | print(predicts[i][head_index, :len(text) + 2].tolist()) 33 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [Unreleased] 4 | 5 | ## [0.85.0] - 2020-07-09 6 | 7 | ### Fixed 8 | 9 | - Compatible with Keras 2.4.3 10 | 11 | ## [0.82.0] - 2020-06-02 12 | 13 | ### Removed 14 | 15 | - Adapter 16 | 17 | ## [0.78.0] - 2019-09-17 18 | 19 | ### Fixed 20 | 21 | - Compatible with Keras 2.3.0 22 | 23 | ## [0.70.0] - 2019-07-16 24 | 25 | ### Added 26 | 27 | - Try to find the indices of tokens in the original text. 
28 | 29 | ## [0.69.0] - 2019-07-16 30 | 31 | ### Added 32 | 33 | - [Adapter](https://arxiv.org/pdf/1902.00751.pdf) 34 | 35 | ## [0.60.0] - 2019-06-10 36 | 37 | ### Added 38 | 39 | - `trainable` can be a list of prefixes of layer names 40 | 41 | ## [0.58.0] - 2019-06-10 42 | 43 | ### Fixed 44 | 45 | - Use `math_ops` for tensorflow backend 46 | - Assign names to variables in warmup optimizer 47 | 48 | ## [0.56.0] - 2019-06-04 49 | 50 | ### Changed 51 | 52 | - Docs about `training` and `trainable` 53 | 54 | ### Fixed 55 | 56 | - Missing `trainable=False` when `training=True` 57 | 58 | ## [0.54.0] - 2019-05-29 59 | 60 | ### Added 61 | 62 | - Support eager mode with tensorflow backend 63 | 64 | ## [0.43.0] - 2019-05-12 65 | 66 | ### Added 67 | 68 | - Support `tf.keras` 69 | 70 | ## [0.40.0] - 2019-04-29 71 | 72 | ### Added 73 | 74 | - Warmup optimizer 75 | 76 | ## [Older Versions] 77 | 78 | ### Added 79 | 80 | - BERT implementation 81 | - Load official model 82 | - Tokenizer 83 | - Demos 84 | -------------------------------------------------------------------------------- /demo/visualization/vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import keras 3 | from keras_bert import get_model 4 | 5 | 6 | model = get_model( 7 | token_num=30000, 8 | pos_num=512, 9 | transformer_num=12, 10 | head_num=12, 11 | embed_dim=768, 12 | feed_forward_dim=768 * 4, 13 | ) 14 | model.summary(line_length=120) 15 | current_path = os.path.dirname(os.path.abspath(__file__)) 16 | output_path = os.path.join(current_path, 'bert_small.png') 17 | keras.utils.plot_model(model, show_shapes=True, to_file=output_path) 18 | 19 | model = get_model( 20 | token_num=30000, 21 | pos_num=512, 22 | transformer_num=24, 23 | head_num=16, 24 | embed_dim=1024, 25 | feed_forward_dim=1024 * 4, 26 | ) 27 | model.summary(line_length=120) 28 | output_path = os.path.join(current_path, 'bert_big.png') 29 | keras.utils.plot_model(model, show_shapes=True, to_file=output_path) 30 | 31 | inputs, outputs = get_model( 32 | token_num=30000, 33 | pos_num=512, 34 | transformer_num=12, 35 | head_num=12, 36 | embed_dim=768, 37 | feed_forward_dim=768 * 4, 38 | training=False, 39 | ) 40 | model = keras.models.Model(inputs=inputs, outputs=outputs) 41 | model.compile(optimizer='adam', loss='mse', metrics={}) 42 | model.summary(line_length=120) 43 | current_path = os.path.dirname(os.path.abspath(__file__)) 44 | output_path = os.path.join(current_path, 'bert_trained.png') 45 | keras.utils.plot_model(model, show_shapes=True, to_file=output_path) 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | from setuptools import setup, find_packages 5 | 6 | current_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read_file(*parts): 10 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 11 | return reader.read() 12 | 13 | 14 | def get_requirements(*parts): 15 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 16 | return list(map(lambda x: x.strip(), reader.readlines())) 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read_file(*file_paths) 21 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError('Unable to find version string.') 25 | 26 | 27 
| setup( 28 | name='keras-bert', 29 | version=find_version('keras_bert', '__init__.py'), 30 | packages=find_packages(), 31 | url='https://github.com/CyberZHG/keras-bert', 32 | license='MIT', 33 | author='CyberZHG', 34 | author_email='CyberZHG@users.noreply.github.com', 35 | description='BERT implemented in Keras', 36 | long_description=read_file('README.md'), 37 | long_description_content_type='text/markdown', 38 | install_requires=get_requirements('requirements.txt'), 39 | classifiers=( 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: MIT License", 42 | "Operating System :: OS Independent", 43 | ), 44 | ) 45 | -------------------------------------------------------------------------------- /keras_bert/layers/masked.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class Masked(keras.layers.Layer): 6 | """Generate output mask based on the given mask. 7 | 8 | The inputs for the layer is the original input layer and the masked locations. 9 | 10 | See: https://arxiv.org/pdf/1810.04805.pdf 11 | """ 12 | 13 | def __init__(self, 14 | return_masked=False, 15 | **kwargs): 16 | """Initialize the layer. 17 | 18 | :param return_masked: Whether to return the merged mask. 19 | :param kwargs: Arguments for parent class. 20 | """ 21 | super(Masked, self).__init__(**kwargs) 22 | self.supports_masking = True 23 | self.return_masked = return_masked 24 | 25 | def get_config(self): 26 | config = { 27 | 'return_masked': self.return_masked, 28 | } 29 | base_config = super(Masked, self).get_config() 30 | return dict(list(base_config.items()) + list(config.items())) 31 | 32 | def compute_mask(self, inputs, mask=None): 33 | token_mask = K.not_equal(inputs[1], 0) 34 | masked = K.all(K.stack([token_mask, mask[0]], axis=0), axis=0) 35 | if self.return_masked: 36 | return [masked, None] 37 | return masked 38 | 39 | def call(self, inputs, mask=None, **kwargs): 40 | output = inputs[0] + 0 41 | if self.return_masked: 42 | return [output, K.cast(self.compute_mask(inputs, mask)[0], K.floatx())] 43 | return output 44 | -------------------------------------------------------------------------------- /tests/layers/test_pooling.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | 6 | from keras_bert.layers import MaskedGlobalMaxPool1D 7 | 8 | 9 | class TestPooling(TestCase): 10 | 11 | def test_masked_global_max_pool_1d_predict(self): 12 | input_layer = keras.layers.Input(shape=(None,)) 13 | embed_layer = keras.layers.Embedding( 14 | input_dim=5, 15 | output_dim=6, 16 | mask_zero=True, 17 | name='Embed' 18 | )(input_layer) 19 | pool_layer = MaskedGlobalMaxPool1D()(embed_layer) 20 | model = keras.models.Model(inputs=input_layer, outputs=pool_layer) 21 | model.compile(optimizer='adam', loss='mse') 22 | x = np.array([[1, 2, 0, 0], [2, 3, 4, 0]]) 23 | y = model.predict(x) 24 | embed = model.get_layer('Embed').get_weights()[0] 25 | expected = np.max(embed[1:3], axis=0) 26 | self.assertTrue(np.allclose(expected, y[0]), (expected, y[0])) 27 | expected = np.max(embed[2:5], axis=0) 28 | self.assertTrue(np.allclose(expected, y[1]), (expected, y[1])) 29 | 30 | def test_masked_global_max_pool_1d_fit(self): 31 | input_layer = keras.layers.Input(shape=(None,)) 32 | embed_layer = keras.layers.Embedding( 33 | input_dim=11, 34 | output_dim=13, 35 | mask_zero=False, 
36 | )(input_layer) 37 | pool_layer = MaskedGlobalMaxPool1D()(embed_layer) 38 | dense_layer = keras.layers.Dense(units=2, activation='softmax')(pool_layer) 39 | model = keras.models.Model(inputs=input_layer, outputs=dense_layer) 40 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') 41 | model.summary() 42 | x = np.random.randint(0, 11, (32, 7)) 43 | y = np.random.randint(0, 2, (32,)) 44 | model.fit(x, y) 45 | -------------------------------------------------------------------------------- /keras_bert/datasets/pretrained.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from collections import namedtuple 4 | from tensorflow.keras.utils import get_file 5 | 6 | __all__ = ['PretrainedInfo', 'PretrainedList', 'get_pretrained'] 7 | 8 | 9 | PretrainedInfo = namedtuple('PretrainedInfo', ['url', 'extract_name', 'target_name']) 10 | 11 | 12 | class PretrainedList(object): 13 | 14 | __test__ = PretrainedInfo( 15 | 'https://github.com/CyberZHG/keras-bert/archive/master.zip', 16 | 'keras-bert-master', 17 | 'keras-bert', 18 | ) 19 | 20 | multi_cased_base = 'https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip' 21 | chinese_base = 'https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip' 22 | wwm_uncased_large = 'https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip' 23 | wwm_cased_large = 'https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip' 24 | chinese_wwm_base = PretrainedInfo( 25 | 'https://storage.googleapis.com/hfl-rc/chinese-bert/chinese_wwm_L-12_H-768_A-12.zip', 26 | 'publish', 27 | 'chinese_wwm_L-12_H-768_A-12', 28 | ) 29 | 30 | 31 | def get_pretrained(info): 32 | path = info 33 | if isinstance(info, PretrainedInfo): 34 | path = info.url 35 | path = get_file(fname=os.path.split(path)[-1], origin=path, extract=True) 36 | base_part, file_part = os.path.split(path) 37 | file_part = file_part.split('.')[0] 38 | if isinstance(info, PretrainedInfo): 39 | extract_path = os.path.join(base_part, info.extract_name) 40 | target_path = os.path.join(base_part, info.target_name) 41 | if not os.path.exists(target_path): 42 | shutil.move(extract_path, target_path) 43 | file_part = info.target_name 44 | return os.path.join(base_part, file_part) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # System thumbnail 107 | .DS_Store 108 | 109 | # IDE 110 | .idea 111 | 112 | # Images 113 | *.png 114 | 115 | # Models 116 | *.h5 117 | -------------------------------------------------------------------------------- /demo/load_model/load_and_predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths 4 | 5 | print('This demo demonstrates how to load the pre-trained model and check whether the two sentences are continuous') 6 | 7 | if len(sys.argv) == 2: 8 | model_path = sys.argv[1] 9 | else: 10 | from keras_bert.datasets import get_pretrained, PretrainedList 11 | model_path = get_pretrained(PretrainedList.chinese_base) 12 | 13 | paths = get_checkpoint_paths(model_path) 14 | 15 | model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, training=True, seq_len=None) 16 | model.summary(line_length=120) 17 | 18 | token_dict = load_vocabulary(paths.vocab) 19 | token_dict_inv = {v: k for k, v in token_dict.items()} 20 | 21 | tokenizer = Tokenizer(token_dict) 22 | text = '数学是利用符号语言研究数量、结构、变化以及空间等概念的一门学科' 23 | tokens = tokenizer.tokenize(text) 24 | tokens[1] = tokens[2] = '[MASK]' 25 | print('Tokens:', tokens) 26 | 27 | indices = np.array([[token_dict[token] for token in tokens]]) 28 | segments = np.array([[0] * len(tokens)]) 29 | masks = np.array([[0, 1, 1] + [0] * (len(tokens) - 3)]) 30 | 31 | predicts = model.predict([indices, segments, masks])[0].argmax(axis=-1).tolist() 32 | print('Fill with: ', list(map(lambda x: token_dict_inv[x], predicts[0][1:3]))) 33 | 34 | 35 | sentence_1 = '数学是利用符号语言研究數量、结构、变化以及空间等概念的一門学科。' 36 | sentence_2 = '从某种角度看屬於形式科學的一種。' 37 | print('Tokens:', tokenizer.tokenize(first=sentence_1, second=sentence_2)) 38 | indices, segments = tokenizer.encode(first=sentence_1, second=sentence_2) 39 | masks = np.array([[0] * len(indices)]) 40 | 41 | predicts = model.predict([np.array([indices]), np.array([segments]), masks])[1] 42 | print('%s is random next: ' % sentence_2, bool(np.argmax(predicts, axis=-1)[0])) 43 | 44 | sentence_2 = '任何一个希尔伯特空间都有一族标准正交基。' 45 | print('Tokens:', tokenizer.tokenize(first=sentence_1, second=sentence_2)) 46 | indices, segments = tokenizer.encode(first=sentence_1, second=sentence_2) 47 | masks = np.array([[0] * len(indices)]) 48 | 49 
| predicts = model.predict([np.array([indices]), np.array([segments]), masks])[1] 50 | print('%s is random next: ' % sentence_2, bool(np.argmax(predicts, axis=-1)[0])) 51 | -------------------------------------------------------------------------------- /demo/load_model/load_and_extract.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths 4 | 5 | print('This demo demonstrates how to load the pre-trained model and extract word embeddings') 6 | 7 | if len(sys.argv) == 2: 8 | model_path = sys.argv[1] 9 | else: 10 | from keras_bert.datasets import get_pretrained, PretrainedList 11 | model_path = get_pretrained(PretrainedList.chinese_base) 12 | 13 | paths = get_checkpoint_paths(model_path) 14 | 15 | model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10) 16 | model.summary(line_length=120) 17 | 18 | token_dict = load_vocabulary(paths.vocab) 19 | 20 | tokenizer = Tokenizer(token_dict) 21 | text = '语言模型' 22 | tokens = tokenizer.tokenize(text) 23 | print('Tokens:', tokens) 24 | indices, segments = tokenizer.encode(first=text, max_len=10) 25 | 26 | predicts = model.predict([np.array([indices]), np.array([segments])])[0] 27 | for i, token in enumerate(tokens): 28 | print(token, predicts[i].tolist()[:5]) 29 | 30 | """Official outputs: 31 | { 32 | "linex_index": 0, 33 | "features": [ 34 | { 35 | "token": "[CLS]", 36 | "layers": [ 37 | { 38 | "index": -1, 39 | "values": [-0.63251, 0.203023, 0.079366, -0.032843, 0.566809, ...] 40 | } 41 | ] 42 | }, 43 | { 44 | "token": "语", 45 | "layers": [ 46 | { 47 | "index": -1, 48 | "values": [-0.758835, 0.096518, 1.071875, 0.005038, 0.688799, ...] 49 | } 50 | ] 51 | }, 52 | { 53 | "token": "言", 54 | "layers": [ 55 | { 56 | "index": -1, 57 | "values": [0.547702, -0.792117, 0.444354, -0.711265, 1.20489, ...] 58 | } 59 | ] 60 | }, 61 | { 62 | "token": "模", 63 | "layers": [ 64 | { 65 | "index": -1, 66 | "values": [-0.292423, 0.605271, 0.499686, -0.42458, 0.428554, ...] 67 | } 68 | ] 69 | }, 70 | { 71 | "token": "型", 72 | "layers": [ 73 | { 74 | "index": -1, 75 | "values": [ -0.747346, 0.494315, 0.718516, -0.872353, 0.83496, ...] 76 | } 77 | ] 78 | }, 79 | { 80 | "token": "[SEP]", 81 | "layers": [ 82 | { 83 | "index": -1, 84 | "values": [-0.874138, -0.216504, 1.338839, -0.105871, 0.39609, ...] 
85 | } 86 | ] 87 | } 88 | ] 89 | } 90 | """ 91 | -------------------------------------------------------------------------------- /tests/test_loader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from keras_bert import load_trained_model_from_checkpoint, load_vocabulary 4 | 5 | 6 | class TestLoader(unittest.TestCase): 7 | 8 | def test_load_trained(self): 9 | current_path = os.path.dirname(os.path.abspath(__file__)) 10 | config_path = os.path.join(current_path, 'test_checkpoint', 'bert_config.json') 11 | model_path = os.path.join(current_path, 'test_checkpoint', 'bert_model.ckpt') 12 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False) 13 | model.summary() 14 | 15 | def test_load_trained_shorter(self): 16 | current_path = os.path.dirname(os.path.abspath(__file__)) 17 | config_path = os.path.join(current_path, 'test_checkpoint', 'bert_config.json') 18 | model_path = os.path.join(current_path, 'test_checkpoint', 'bert_model.ckpt') 19 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, seq_len=8) 20 | model.summary() 21 | 22 | def test_load_training(self): 23 | current_path = os.path.dirname(os.path.abspath(__file__)) 24 | config_path = os.path.join(current_path, 'test_checkpoint', 'bert_config.json') 25 | model_path = os.path.join(current_path, 'test_checkpoint', 'bert_model.ckpt') 26 | model = load_trained_model_from_checkpoint(config_path, model_path, training=True) 27 | model.summary() 28 | 29 | def test_load_output_layer_num(self): 30 | current_path = os.path.dirname(os.path.abspath(__file__)) 31 | config_path = os.path.join(current_path, 'test_checkpoint', 'bert_config.json') 32 | model_path = os.path.join(current_path, 'test_checkpoint', 'bert_model.ckpt') 33 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, output_layer_num=4) 34 | model.summary() 35 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, output_layer_num=[0]) 36 | model.summary() 37 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, output_layer_num=[1]) 38 | model.summary() 39 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, output_layer_num=[-1]) 40 | model.summary() 41 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, output_layer_num=[-2]) 42 | model.summary() 43 | model = load_trained_model_from_checkpoint(config_path, model_path, training=False, output_layer_num=[0, -1]) 44 | model.summary() 45 | 46 | def test_load_with_trainable_prefixes(self): 47 | current_path = os.path.dirname(os.path.abspath(__file__)) 48 | config_path = os.path.join(current_path, 'test_checkpoint', 'bert_config.json') 49 | model_path = os.path.join(current_path, 'test_checkpoint', 'bert_model.ckpt') 50 | model = load_trained_model_from_checkpoint( 51 | config_path, 52 | model_path, 53 | training=False, 54 | trainable=['Encoder'], 55 | ) 56 | model.summary() 57 | 58 | def test_load_vocabulary(self): 59 | current_path = os.path.dirname(os.path.abspath(__file__)) 60 | vocab_path = os.path.join(current_path, 'test_checkpoint', 'vocab.txt') 61 | token_dict = load_vocabulary(vocab_path) 62 | self.assertEqual(15, len(token_dict)) 63 | -------------------------------------------------------------------------------- /tests/optimizers/test_warmup.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import tempfile 3 | from unittest import TestCase 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | 8 | from keras_bert import AdamWarmup 9 | 10 | 11 | class TestWarmup(TestCase): 12 | 13 | def _test_fit(self, optimizer): 14 | x = np.random.standard_normal((1000, 5)) 15 | y = np.dot(x, np.random.standard_normal((5, 2))).argmax(axis=-1) 16 | model = keras.models.Sequential() 17 | model.add(keras.layers.Dense( 18 | units=2, 19 | input_shape=(5,), 20 | kernel_constraint=keras.constraints.MaxNorm(1000.0), 21 | activation='softmax', 22 | )) 23 | model.compile( 24 | optimizer=optimizer, 25 | loss='sparse_categorical_crossentropy', 26 | ) 27 | model.fit( 28 | x, y, 29 | batch_size=10, 30 | epochs=110, 31 | callbacks=[keras.callbacks.EarlyStopping(monitor='loss', min_delta=1e-4, patience=3)], 32 | ) 33 | 34 | model_path = os.path.join(tempfile.gettempdir(), 'keras_warmup_%f.h5' % np.random.random()) 35 | model.save(model_path) 36 | 37 | from tensorflow.python.keras.utils.generic_utils import CustomObjectScope 38 | with CustomObjectScope({'AdamWarmup': AdamWarmup}): # Workaround for incorrect global variable used in keras 39 | model = keras.models.load_model(model_path, custom_objects={'AdamWarmup': AdamWarmup}) 40 | 41 | results = model.predict(x).argmax(axis=-1) 42 | diff = np.sum(np.abs(y - results)) 43 | self.assertLess(diff, 100) 44 | 45 | def test_fit(self): 46 | self._test_fit(AdamWarmup( 47 | decay_steps=10000, 48 | warmup_steps=5000, 49 | learning_rate=1e-3, 50 | min_lr=1e-4, 51 | amsgrad=False, 52 | weight_decay=1e-3, 53 | )) 54 | 55 | def test_fit_amsgrad(self): 56 | self._test_fit(AdamWarmup( 57 | decay_steps=10000, 58 | warmup_steps=5000, 59 | learning_rate=1e-3, 60 | min_lr=1e-4, 61 | amsgrad=True, 62 | weight_decay=1e-3, 63 | )) 64 | 65 | def test_fit_embed(self): 66 | model = keras.models.Sequential() 67 | model.add(keras.layers.Embedding( 68 | input_shape=(None,), 69 | input_dim=5, 70 | output_dim=16, 71 | mask_zero=True, 72 | )) 73 | model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=8))) 74 | model.add(keras.layers.Dense(units=2, activation='softmax')) 75 | model.compile(AdamWarmup( 76 | decay_steps=10000, 77 | warmup_steps=5000, 78 | learning_rate=1e-3, 79 | min_lr=1e-4, 80 | amsgrad=True, 81 | weight_decay=1e-3, 82 | ), loss='sparse_categorical_crossentropy') 83 | 84 | x = np.random.randint(0, 5, (1024, 15)) 85 | y = (x[:, 1] > 2).astype('int32') 86 | model.fit(x, y, epochs=10, verbose=1) 87 | 88 | model_path = os.path.join(tempfile.gettempdir(), 'test_warmup_%f.h5' % np.random.random()) 89 | model.save(model_path) 90 | from tensorflow.python.keras.utils.generic_utils import CustomObjectScope 91 | with CustomObjectScope({'AdamWarmup': AdamWarmup}): # Workaround for incorrect global variable used in keras 92 | keras.models.load_model(model_path, custom_objects={'AdamWarmup': AdamWarmup}) 93 | -------------------------------------------------------------------------------- /keras_bert/layers/task_embed.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | 4 | __all__ = ['TaskEmbedding'] 5 | 6 | 7 | class TaskEmbedding(keras.layers.Layer): 8 | """Embedding for tasks. 9 | 10 | # Arguments 11 | input_dim: int > 0. Number of the tasks. Plus 1 if `mask_zero` is enabled. 12 | output_dim: int >= 0. Dimension of the dense embedding. 13 | embeddings_initializer: Initializer for the `embeddings` matrix.
14 | embeddings_regularizer: Regularizer function applied to the `embeddings` matrix. 15 | embeddings_constraint: Constraint function applied to the `embeddings` matrix. 16 | mask_zero: Generate zeros for 0 index if it is `True`. 17 | 18 | # Input shape 19 | Previous embeddings, 3D tensor with shape: `(batch_size, sequence_length, output_dim)`. 20 | Task IDs, 2D tensor with shape: `(batch_size, 1)`. 21 | 22 | # Output shape 23 | 3D tensor with shape: `(batch_size, sequence_length, output_dim)`. 24 | """ 25 | 26 | def __init__(self, 27 | input_dim, 28 | output_dim, 29 | embeddings_initializer='uniform', 30 | embeddings_regularizer=None, 31 | embeddings_constraint=None, 32 | mask_zero=False, 33 | **kwargs): 34 | super(TaskEmbedding, self).__init__(**kwargs) 35 | self.supports_masking = True 36 | self.input_dim = input_dim 37 | self.output_dim = output_dim 38 | self.embeddings_initializer = keras.initializers.get(embeddings_initializer) 39 | self.embeddings_regularizer = keras.regularizers.get(embeddings_regularizer) 40 | self.embeddings_constraint = keras.constraints.get(embeddings_constraint) 41 | self.mask_zero = mask_zero 42 | 43 | self.embeddings = None 44 | 45 | def build(self, input_shape): 46 | self.embeddings = self.add_weight( 47 | shape=(self.input_dim, self.output_dim), 48 | initializer=self.embeddings_initializer, 49 | regularizer=self.embeddings_regularizer, 50 | constraint=self.embeddings_constraint, 51 | name='embeddings', 52 | ) 53 | super(TaskEmbedding, self).build(input_shape) 54 | 55 | def compute_mask(self, inputs, mask=None): 56 | output_mask = None 57 | if mask is not None: 58 | output_mask = mask[0] 59 | return output_mask 60 | 61 | def call(self, inputs, **kwargs): 62 | inputs, tasks = inputs 63 | if K.dtype(tasks) != 'int32': 64 | tasks = K.cast(tasks, 'int32') 65 | task_embed = K.gather(self.embeddings, tasks) 66 | if self.mask_zero: 67 | task_embed = task_embed * K.expand_dims(K.cast(K.not_equal(tasks, 0), K.floatx()), axis=-1) 68 | return inputs + task_embed 69 | 70 | def get_config(self): 71 | config = { 72 | 'input_dim': self.input_dim, 73 | 'output_dim': self.output_dim, 74 | 'embeddings_initializer': keras.initializers.serialize(self.embeddings_initializer), 75 | 'embeddings_regularizer': keras.regularizers.serialize(self.embeddings_regularizer), 76 | 'embeddings_constraint': keras.constraints.serialize(self.embeddings_constraint), 77 | 'mask_zero': self.mask_zero, 78 | } 79 | base_config = super(TaskEmbedding, self).get_config() 80 | return dict(list(base_config.items()) + list(config.items())) 81 | -------------------------------------------------------------------------------- /tests/layers/test_masked.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorflow import keras 5 | from tensorflow.keras import backend as K 6 | 7 | from keras_transformer import gelu, get_encoders 8 | from keras_bert.layers import get_inputs, get_embedding, Masked 9 | 10 | 11 | class TestMasked(unittest.TestCase): 12 | 13 | def test_sample(self): 14 | inputs = get_inputs(seq_len=512) 15 | embed_layer, _ = get_embedding(inputs, token_num=12, embed_dim=768, pos_num=512) 16 | masked_layer = Masked(name='Masked')([embed_layer, inputs[-1]]) 17 | model = keras.models.Model(inputs=inputs, outputs=masked_layer) 18 | model.compile( 19 | optimizer='adam', 20 | loss='mse', 21 | metrics={}, 22 | ) 23 | model.summary() 24 | model.predict([ 25 | np.asarray([[1] + [0] * 511]), 26 | np.asarray([[0] * 
512]), 27 | np.asarray([[1] + [0] * 511]), 28 | ]) 29 | self.assertEqual((None, 512, 768), model.layers[-1].output_shape) 30 | 31 | def test_mask_result(self): 32 | input_layer = keras.layers.Input( 33 | shape=(None,), 34 | name='Input', 35 | ) 36 | embed_layer = keras.layers.Embedding( 37 | input_dim=12, 38 | output_dim=9, 39 | mask_zero=True, 40 | name='Embedding', 41 | )(input_layer) 42 | transformer_layer = get_encoders( 43 | encoder_num=1, 44 | input_layer=embed_layer, 45 | head_num=1, 46 | hidden_dim=12, 47 | attention_activation=None, 48 | feed_forward_activation=gelu, 49 | dropout_rate=0.1, 50 | ) 51 | dense_layer = keras.layers.Dense( 52 | units=12, 53 | activation='softmax', 54 | name='Dense', 55 | )(transformer_layer) 56 | mask_layer = keras.layers.Input( 57 | shape=(None,), 58 | name='Mask', 59 | ) 60 | masked_layer, mask_result = Masked( 61 | return_masked=True, 62 | name='Masked', 63 | )([dense_layer, mask_layer]) 64 | model = keras.models.Model( 65 | inputs=[input_layer, mask_layer], 66 | outputs=[masked_layer, mask_result], 67 | ) 68 | model.compile( 69 | optimizer='adam', 70 | loss='mse', 71 | ) 72 | model.summary() 73 | predicts = model.predict([ 74 | np.asarray([ 75 | [1, 2, 3, 4, 5, 6, 7, 8, 0, 0], 76 | [1, 2, 3, 4, 0, 0, 0, 0, 0, 0], 77 | ]), 78 | np.asarray([ 79 | [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], 80 | [0, 1, 0, 1, 0, 0, 0, 0, 0, 0], 81 | ]), 82 | ]) 83 | expect = np.asarray([ 84 | [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], 85 | [0, 1, 0, 1, 0, 0, 0, 0, 0, 0], 86 | ]) 87 | self.assertTrue(np.allclose(expect, predicts[1])) 88 | 89 | def test_mask_loss(self): 90 | def _loss(y_true, _): 91 | return K.sum(y_true, axis=-1) 92 | 93 | inputs = [keras.layers.Input((5,)), keras.layers.Input((5,))] 94 | embed = keras.layers.Embedding(input_dim=2, output_dim=3, mask_zero=True)(inputs[0]) 95 | masked = Masked()([embed, inputs[1]]) 96 | 97 | model = keras.models.Model(inputs, masked) 98 | model.compile( 99 | optimizer='sgd', 100 | loss=_loss, 101 | ) 102 | 103 | token_input = np.array([ 104 | [1, 1, 1, 0, 0], 105 | [1, 1, 1, 1, 0], 106 | ]) 107 | mask_input = np.array([ 108 | [0, 1, 0, 0, 0], 109 | [1, 0, 0, 0, 0], 110 | ]) 111 | outputs = np.arange(30, dtype=K.floatx()).reshape((2, 5, 3)) 112 | actual = model.evaluate([token_input, mask_input], outputs) 113 | self.assertTrue(np.abs(actual - 6.0) < 1e-6 or np.abs(actual - 30.0) < 1e-6, actual) 114 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from keras_bert import Tokenizer 3 | 4 | 5 | class TestTokenizer(TestCase): 6 | 7 | def test_uncased(self): 8 | tokens = [ 9 | '[PAD]', '[UNK]', '[CLS]', '[SEP]', 'want', '##want', 10 | '##ed', 'wa', 'un', 'runn', '##ing', ',', 11 | '\u535A', '\u63A8', 12 | ] 13 | token_dict = {token: i for i, token in enumerate(tokens)} 14 | tokenizer = Tokenizer(token_dict) 15 | text = u"UNwant\u00E9d, running \nah\u535A\u63A8zzz\u00AD" 16 | tokens = tokenizer.tokenize(text) 17 | expected = [ 18 | '[CLS]', 'un', '##want', '##ed', ',', 'runn', '##ing', 19 | 'a', '##h', '\u535A', '\u63A8', 'z', '##z', '##z', 20 | '[SEP]', 21 | ] 22 | self.assertEqual(expected, tokens) 23 | indices, segments = tokenizer.encode(text) 24 | expected = [2, 8, 5, 6, 11, 9, 10, 1, 1, 12, 13, 1, 1, 1, 3] 25 | self.assertEqual(expected, indices) 26 | expected = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 27 | self.assertEqual(expected, segments) 28 | 29 | decoded = 
tokenizer.decode(indices) 30 | expected = [ 31 | 'un', '##want', '##ed', ',', 'runn', '##ing', 32 | '[UNK]', '[UNK]', '\u535A', '\u63A8', '[UNK]', '[UNK]', '[UNK]', 33 | ] 34 | self.assertEqual(expected, decoded) 35 | 36 | def test_padding(self): 37 | tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] 38 | token_dict = {token: i for i, token in enumerate(tokens)} 39 | tokenizer = Tokenizer(token_dict) 40 | text = '\u535A\u63A8' 41 | 42 | # single 43 | indices, segments = tokenizer.encode(first=text, max_len=100) 44 | expected = [2, 1, 1, 3] + [0] * 96 45 | self.assertEqual(expected, indices) 46 | expected = [0] * 100 47 | self.assertEqual(expected, segments) 48 | decoded = tokenizer.decode(indices) 49 | self.assertEqual(['[UNK]', '[UNK]'], decoded) 50 | indices, segments = tokenizer.encode(first=text, max_len=3) 51 | self.assertEqual([2, 1, 3], indices) 52 | self.assertEqual([0, 0, 0], segments) 53 | 54 | # paired 55 | indices, segments = tokenizer.encode(first=text, second=text, max_len=100) 56 | expected = [2, 1, 1, 3, 1, 1, 3] + [0] * 93 57 | self.assertEqual(expected, indices) 58 | expected = [0, 0, 0, 0, 1, 1, 1] + [0] * 93 59 | self.assertEqual(expected, segments) 60 | decoded = tokenizer.decode(indices) 61 | self.assertEqual((['[UNK]', '[UNK]'], ['[UNK]', '[UNK]']), decoded) 62 | indices, segments = tokenizer.encode(first=text, second=text, max_len=4) 63 | self.assertEqual([2, 1, 3, 3], indices) 64 | self.assertEqual([0, 0, 0, 1], segments) 65 | 66 | def test_empty(self): 67 | tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] 68 | token_dict = {token: i for i, token in enumerate(tokens)} 69 | tokenizer = Tokenizer(token_dict) 70 | text = u'' 71 | self.assertEqual(['[CLS]', '[SEP]'], tokenizer.tokenize(text)) 72 | indices, segments = tokenizer.encode(text) 73 | self.assertEqual([2, 3], indices) 74 | self.assertEqual([0, 0], segments) 75 | decoded = tokenizer.decode(indices) 76 | self.assertEqual([], decoded) 77 | 78 | def test_cased(self): 79 | tokens = [ 80 | '[UNK]', u'[CLS]', '[SEP]', 'want', '##want', 81 | '##\u00E9d', 'wa', 'UN', 'runn', '##ing', ',', 82 | ] 83 | token_dict = {token: i for i, token in enumerate(tokens)} 84 | tokenizer = Tokenizer(token_dict, cased=True) 85 | text = "UNwant\u00E9d, running" 86 | tokens = tokenizer.tokenize(text) 87 | expected = ['[CLS]', 'UN', '##want', '##\u00E9d', ',', 'runn', '##ing', '[SEP]'] 88 | self.assertEqual(expected, tokens) 89 | indices, segments = tokenizer.encode(text) 90 | expected = [1, 7, 4, 5, 10, 8, 9, 2] 91 | self.assertEqual(expected, indices) 92 | expected = [0, 0, 0, 0, 0, 0, 0, 0] 93 | self.assertEqual(expected, segments) 94 | -------------------------------------------------------------------------------- /keras_bert/layers/embedding.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from tensorflow.keras import backend as K 3 | from keras_pos_embd import PositionEmbedding 4 | 5 | 6 | class TokenEmbedding(keras.layers.Embedding): 7 | """Embedding layer with weights returned.""" 8 | 9 | def compute_mask(self, inputs, mask=None): 10 | return [super(TokenEmbedding, self).compute_mask(inputs, mask), None] 11 | 12 | def call(self, inputs): 13 | return [super(TokenEmbedding, self).call(inputs), self.embeddings + 0] 14 | 15 | 16 | def get_embedding(inputs, token_num, pos_num, embed_dim, dropout_rate=0.1, trainable=True): 17 | """Get embedding layer. 18 | 19 | See: https://arxiv.org/pdf/1810.04805.pdf 20 | 21 | :param inputs: Input layers. 
22 | :param token_num: Number of tokens. 23 | :param pos_num: Maximum position. 24 | :param embed_dim: The dimension of all embedding layers. 25 | :param dropout_rate: Dropout rate. 26 | :param trainable: Whether the layers are trainable. 27 | :return: The merged embedding layer and weights of token embedding. 28 | """ 29 | embeddings = [ 30 | TokenEmbedding( 31 | input_dim=token_num, 32 | output_dim=embed_dim, 33 | mask_zero=True, 34 | trainable=trainable, 35 | name='Embedding-Token', 36 | )(inputs[0]), 37 | keras.layers.Embedding( 38 | input_dim=2, 39 | output_dim=embed_dim, 40 | trainable=trainable, 41 | name='Embedding-Segment', 42 | )(inputs[1]), 43 | ] 44 | embeddings[0], embed_weights = embeddings[0] 45 | embed_layer = keras.layers.Add(name='Embedding-Token-Segment')(embeddings) 46 | embed_layer = PositionEmbedding( 47 | input_dim=pos_num, 48 | output_dim=embed_dim, 49 | mode=PositionEmbedding.MODE_ADD, 50 | trainable=trainable, 51 | name='Embedding-Position', 52 | )(embed_layer) 53 | return embed_layer, embed_weights 54 | 55 | 56 | class EmbeddingSimilarity(keras.layers.Layer): 57 | """Calculate similarity between features and token embeddings with bias term.""" 58 | 59 | def __init__(self, 60 | initializer='zeros', 61 | regularizer=None, 62 | constraint=None, 63 | **kwargs): 64 | """Initialize the layer. 65 | 66 | :param output_dim: Same as embedding output dimension. 67 | :param initializer: Initializer for bias. 68 | :param regularizer: Regularizer for bias. 69 | :param constraint: Constraint for bias. 70 | :param kwargs: Arguments for parent class. 71 | """ 72 | super(EmbeddingSimilarity, self).__init__(**kwargs) 73 | self.supports_masking = True 74 | self.initializer = keras.initializers.get(initializer) 75 | self.regularizer = keras.regularizers.get(regularizer) 76 | self.constraint = keras.constraints.get(constraint) 77 | self.bias = None 78 | 79 | def get_config(self): 80 | config = { 81 | 'initializer': keras.initializers.serialize(self.initializer), 82 | 'regularizer': keras.regularizers.serialize(self.regularizer), 83 | 'constraint': keras.constraints.serialize(self.constraint), 84 | } 85 | base_config = super(EmbeddingSimilarity, self).get_config() 86 | return dict(list(base_config.items()) + list(config.items())) 87 | 88 | def build(self, input_shape): 89 | self.bias = self.add_weight( 90 | shape=(int(input_shape[1][0]),), 91 | initializer=self.initializer, 92 | regularizer=self.regularizer, 93 | constraint=self.constraint, 94 | name='bias', 95 | ) 96 | super(EmbeddingSimilarity, self).build(input_shape) 97 | 98 | def compute_mask(self, inputs, mask=None): 99 | return mask[0] 100 | 101 | def call(self, inputs, mask=None, **kwargs): 102 | inputs, embeddings = inputs 103 | outputs = K.bias_add(K.dot(inputs, K.transpose(embeddings)), self.bias) 104 | return keras.activations.softmax(outputs) 105 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import codecs 4 | 5 | from tensorflow import keras 6 | 7 | from keras_bert import get_model, POOL_NSP, POOL_MAX, POOL_AVE, extract_embeddings 8 | 9 | 10 | class TestUtil(unittest.TestCase): 11 | 12 | def setUp(self): 13 | current_path = os.path.dirname(os.path.abspath(__file__)) 14 | self.model_path = os.path.join(current_path, 'test_checkpoint') 15 | 16 | def test_extract_embeddings_default(self): 17 | embeddings = extract_embeddings(self.model_path, ['all work 
and no play', 'makes jack a dull boy~']) 18 | self.assertEqual(2, len(embeddings)) 19 | self.assertEqual((7, 4), embeddings[0].shape) 20 | self.assertEqual((8, 4), embeddings[1].shape) 21 | 22 | def test_extract_embeddings_batch_size_1(self): 23 | embeddings = extract_embeddings( 24 | self.model_path, 25 | ['all work and no play', 'makes jack a dull boy~'], 26 | batch_size=1, 27 | ) 28 | self.assertEqual(2, len(embeddings)) 29 | self.assertEqual((7, 4), embeddings[0].shape) 30 | self.assertEqual((8, 4), embeddings[1].shape) 31 | 32 | def test_extract_embeddings_pair(self): 33 | embeddings = extract_embeddings( 34 | self.model_path, 35 | [ 36 | ('all work and no play', 'makes jack a dull boy'), 37 | ('makes jack a dull boy', 'all work and no play'), 38 | ], 39 | ) 40 | self.assertEqual(2, len(embeddings)) 41 | self.assertEqual((13, 4), embeddings[0].shape) 42 | 43 | def test_extract_embeddings_single_pooling(self): 44 | embeddings = extract_embeddings( 45 | self.model_path, 46 | [ 47 | ('all work and no play', 'makes jack a dull boy'), 48 | ('makes jack a dull boy', 'all work and no play'), 49 | ], 50 | poolings=POOL_NSP, 51 | ) 52 | self.assertEqual(2, len(embeddings)) 53 | self.assertEqual((4,), embeddings[0].shape) 54 | 55 | def test_extract_embeddings_multi_pooling(self): 56 | embeddings = extract_embeddings( 57 | self.model_path, 58 | [ 59 | ('all work and no play', 'makes jack a dull boy'), 60 | ('makes jack a dull boy', 'all work and no play'), 61 | ], 62 | poolings=[POOL_NSP, POOL_MAX, POOL_AVE], 63 | output_layer_num=2, 64 | ) 65 | self.assertEqual(2, len(embeddings)) 66 | self.assertEqual((24,), embeddings[0].shape) 67 | 68 | def test_extract_embeddings_invalid_pooling(self): 69 | with self.assertRaises(ValueError): 70 | extract_embeddings( 71 | self.model_path, 72 | [ 73 | ('all work and no play', 'makes jack a dull boy'), 74 | ('makes jack a dull boy', 'all work and no play'), 75 | ], 76 | poolings=['invalid'], 77 | ) 78 | 79 | def test_extract_embeddings_variable_lengths(self): 80 | tokens = [ 81 | '[PAD]', '[UNK]', '[CLS]', '[SEP]', 82 | 'all', 'work', 'and', 'no', 'play', 83 | 'makes', 'jack', 'a', 'dull', 'boy', '~', 84 | ] 85 | token_dict = {token: i for i, token in enumerate(tokens)} 86 | inputs, outputs = get_model( 87 | token_num=len(tokens), 88 | pos_num=20, 89 | seq_len=None, 90 | embed_dim=13, 91 | transformer_num=1, 92 | feed_forward_dim=17, 93 | head_num=1, 94 | training=False, 95 | ) 96 | model = keras.models.Model(inputs, outputs) 97 | embeddings = extract_embeddings( 98 | model, 99 | [ 100 | ('all work and no play', 'makes jack'), 101 | ('a dull boy', 'all work and no play and no play'), 102 | ], 103 | vocabs=token_dict, 104 | batch_size=2, 105 | ) 106 | self.assertEqual(2, len(embeddings)) 107 | self.assertEqual((10, 13), embeddings[0].shape) 108 | self.assertEqual((14, 13), embeddings[1].shape) 109 | 110 | def test_extract_embeddings_from_file(self): 111 | with codecs.open(os.path.join(self.model_path, 'vocab.txt'), 'r', 'utf8') as reader: 112 | texts = map(lambda x: x.strip(), reader) 113 | embeddings = extract_embeddings(self.model_path, texts) 114 | self.assertEqual(15, len(embeddings)) 115 | -------------------------------------------------------------------------------- /tests/test_bert.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import tempfile 4 | 5 | import numpy as np 6 | from tensorflow import keras 7 | from tensorflow.keras import backend as K 8 | 9 | from keras_bert 
import (get_model, compile_model, get_base_dict, gen_batch_inputs, get_token_embedding, 10 | get_custom_objects) 11 | 12 | 13 | class TestBERT(unittest.TestCase): 14 | 15 | def test_sample(self): 16 | model = get_model( 17 | token_num=200, 18 | head_num=3, 19 | transformer_num=2, 20 | ) 21 | model_path = os.path.join(tempfile.gettempdir(), 'keras_bert_%f.h5' % np.random.random()) 22 | model.save(model_path) 23 | from tensorflow.python.keras.utils.generic_utils import CustomObjectScope 24 | with CustomObjectScope(get_custom_objects()): # Workaround for incorrect global variable used in keras 25 | model = keras.models.load_model( 26 | model_path, 27 | custom_objects=get_custom_objects(), 28 | ) 29 | model.summary(line_length=200) 30 | 31 | def test_task_embed(self): 32 | inputs, outputs = get_model( 33 | token_num=20, 34 | embed_dim=12, 35 | head_num=3, 36 | transformer_num=2, 37 | use_task_embed=True, 38 | task_num=10, 39 | training=False, 40 | dropout_rate=0.0, 41 | ) 42 | model = keras.models.Model(inputs, outputs) 43 | model_path = os.path.join(tempfile.gettempdir(), 'keras_bert_%f.h5' % np.random.random()) 44 | model.save(model_path) 45 | from tensorflow.python.keras.utils.generic_utils import CustomObjectScope 46 | with CustomObjectScope(get_custom_objects()): # Workaround for incorrect global variable used in keras 47 | model = keras.models.load_model( 48 | model_path, 49 | custom_objects=get_custom_objects(), 50 | ) 51 | model.summary(line_length=200) 52 | 53 | def test_save_load_json(self): 54 | model = get_model( 55 | token_num=200, 56 | head_num=3, 57 | transformer_num=2, 58 | attention_activation='gelu', 59 | ) 60 | compile_model(model) 61 | data = model.to_json() 62 | model = keras.models.model_from_json(data, custom_objects=get_custom_objects()) 63 | model.summary() 64 | 65 | def test_get_token_embedding(self): 66 | model = get_model( 67 | token_num=200, 68 | head_num=3, 69 | transformer_num=2, 70 | attention_activation='gelu', 71 | ) 72 | embed = get_token_embedding(model) 73 | self.assertEqual((200, 768), K.int_shape(embed)) 74 | 75 | def test_fit(self): 76 | current_path = os.path.dirname(os.path.abspath(__file__)) 77 | model_path = os.path.join(current_path, 'test_bert_fit.h5') 78 | sentence_pairs = [ 79 | [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']], 80 | [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']], 81 | [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']], 82 | ] 83 | token_dict = get_base_dict() 84 | for pairs in sentence_pairs: 85 | for token in pairs[0] + pairs[1]: 86 | if token not in token_dict: 87 | token_dict[token] = len(token_dict) 88 | token_list = list(token_dict.keys()) 89 | if os.path.exists(model_path): 90 | steps_per_epoch = 10 91 | from tensorflow.python.keras.utils.generic_utils import CustomObjectScope 92 | with CustomObjectScope(get_custom_objects()): # Workaround for incorrect global variable used in keras 93 | model = keras.models.load_model( 94 | model_path, 95 | custom_objects=get_custom_objects(), 96 | ) 97 | else: 98 | steps_per_epoch = 1000 99 | model = get_model( 100 | token_num=len(token_dict), 101 | head_num=5, 102 | transformer_num=12, 103 | embed_dim=25, 104 | feed_forward_dim=100, 105 | seq_len=20, 106 | pos_num=20, 107 | dropout_rate=0.05, 108 | attention_activation='gelu', 109 | ) 110 | compile_model( 111 | model, 112 | learning_rate=1e-3, 113 | decay_steps=30000, 114 | warmup_steps=10000, 115 | weight_decay=1e-3, 116 | ) 117 | model.summary() 118 | 119 | def 
_generator(): 120 | while True: 121 | yield gen_batch_inputs( 122 | sentence_pairs, 123 | token_dict, 124 | token_list, 125 | seq_len=20, 126 | mask_rate=0.3, 127 | swap_sentence_rate=1.0, 128 | ) 129 | 130 | model.fit_generator( 131 | generator=_generator(), 132 | steps_per_epoch=steps_per_epoch, 133 | epochs=1, 134 | validation_data=_generator(), 135 | validation_steps=steps_per_epoch // 10, 136 | ) 137 | # model.save(model_path) 138 | for inputs, outputs in _generator(): 139 | predicts = model.predict(inputs) 140 | outputs = list(map(lambda x: np.squeeze(x, axis=-1), outputs)) 141 | predicts = list(map(lambda x: np.argmax(x, axis=-1), predicts)) 142 | batch_size, seq_len = inputs[-1].shape 143 | for i in range(batch_size): 144 | match, total = 0, 0 145 | for j in range(seq_len): 146 | if inputs[-1][i][j]: 147 | total += 1 148 | if outputs[0][i][j] == predicts[0][i][j]: 149 | match += 1 150 | self.assertGreater(match, total * 0.9) 151 | self.assertTrue(np.allclose(outputs[1], predicts[1])) 152 | break 153 | -------------------------------------------------------------------------------- /keras_bert/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | 4 | import numpy as np 5 | from tensorflow import keras 6 | from tensorflow.keras import backend as K 7 | 8 | from .layers import Extract, MaskedGlobalMaxPool1D 9 | from .loader import load_trained_model_from_checkpoint, load_vocabulary 10 | from .tokenizer import Tokenizer 11 | 12 | __all__ = [ 13 | 'POOL_NSP', 'POOL_MAX', 'POOL_AVE', 14 | 'get_checkpoint_paths', 'extract_embeddings_generator', 'extract_embeddings', 15 | ] 16 | 17 | 18 | POOL_NSP = 'POOL_NSP' 19 | POOL_MAX = 'POOL_MAX' 20 | POOL_AVE = 'POOL_AVE' 21 | 22 | 23 | def get_checkpoint_paths(model_path): 24 | CheckpointPaths = namedtuple('CheckpointPaths', ['config', 'checkpoint', 'vocab']) 25 | config_path = os.path.join(model_path, 'bert_config.json') 26 | checkpoint_path = os.path.join(model_path, 'bert_model.ckpt') 27 | vocab_path = os.path.join(model_path, 'vocab.txt') 28 | return CheckpointPaths(config_path, checkpoint_path, vocab_path) 29 | 30 | 31 | def extract_embeddings_generator(model, 32 | texts, 33 | poolings=None, 34 | vocabs=None, 35 | cased=False, 36 | batch_size=4, 37 | cut_embed=True, 38 | output_layer_num=1): 39 | """Extract embeddings from texts. 40 | 41 | :param model: Path to the checkpoint or built model without MLM and NSP. 42 | :param texts: Iterable texts. 43 | :param poolings: Pooling methods. Word embeddings will be returned if it is None. 44 | Otherwise concatenated pooled embeddings will be returned. 45 | :param vocabs: A dict should be provided if model is built. 46 | :param cased: Whether it is cased for tokenizer. 47 | :param batch_size: Batch size. 48 | :param cut_embed: The computed embeddings will be cut based on their input lengths. 49 | :param output_layer_num: The number of layers whose outputs will be concatenated as a single output. 50 | Only available when `model` is a path to checkpoint. 51 | :return: A list of numpy arrays representing the embeddings. 
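    Example — a minimal sketch showing that embeddings are yielded one text at a time
    (the checkpoint path is a placeholder)::

        for embedding in extract_embeddings_generator('xxx/uncased_L-12_H-768_A-12',
                                                      ['all work and no play']):
            print(embedding.shape)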
52 | """ 53 | if isinstance(model, (str, type(u''))): 54 | paths = get_checkpoint_paths(model) 55 | model = load_trained_model_from_checkpoint( 56 | config_file=paths.config, 57 | checkpoint_file=paths.checkpoint, 58 | output_layer_num=output_layer_num, 59 | ) 60 | vocabs = load_vocabulary(paths.vocab) 61 | 62 | seq_len = K.int_shape(model.outputs[0])[1] 63 | tokenizer = Tokenizer(vocabs, cased=cased) 64 | 65 | def _batch_generator(): 66 | tokens, segments = [], [] 67 | 68 | def _pad_inputs(): 69 | if seq_len is None: 70 | max_len = max(map(len, tokens)) 71 | for i in range(len(tokens)): 72 | tokens[i].extend([0] * (max_len - len(tokens[i]))) 73 | segments[i].extend([0] * (max_len - len(segments[i]))) 74 | return [np.array(tokens), np.array(segments)] 75 | 76 | for text in texts: 77 | if isinstance(text, (str, type(u''))): 78 | token, segment = tokenizer.encode(text, max_len=seq_len) 79 | else: 80 | token, segment = tokenizer.encode(text[0], text[1], max_len=seq_len) 81 | tokens.append(token) 82 | segments.append(segment) 83 | if len(tokens) == batch_size: 84 | yield _pad_inputs() 85 | tokens, segments = [], [] 86 | if len(tokens) > 0: 87 | yield _pad_inputs() 88 | 89 | if poolings is not None: 90 | if isinstance(poolings, (str, type(u''))): 91 | poolings = [poolings] 92 | outputs = [] 93 | for pooling in poolings: 94 | if pooling == POOL_NSP: 95 | outputs.append(Extract(index=0, name='Pool-NSP')(model.outputs[0])) 96 | elif pooling == POOL_MAX: 97 | outputs.append(MaskedGlobalMaxPool1D(name='Pool-Max')(model.outputs[0])) 98 | elif pooling == POOL_AVE: 99 | outputs.append(keras.layers.GlobalAvgPool1D(name='Pool-Ave')(model.outputs[0])) 100 | else: 101 | raise ValueError('Unknown pooling method: {}'.format(pooling)) 102 | if len(outputs) == 1: 103 | outputs = outputs[0] 104 | else: 105 | outputs = keras.layers.Concatenate(name='Concatenate')(outputs) 106 | model = keras.models.Model(inputs=model.inputs, outputs=outputs) 107 | 108 | for batch_inputs in _batch_generator(): 109 | outputs = model.predict(batch_inputs) 110 | for inputs, output in zip(batch_inputs[0], outputs): 111 | if poolings is None and cut_embed: 112 | length = 0 113 | for i in range(len(inputs) - 1, -1, -1): 114 | if inputs[i] != 0: 115 | length = i + 1 116 | break 117 | output = output[:length] 118 | yield output 119 | 120 | 121 | def extract_embeddings(model, 122 | texts, 123 | poolings=None, 124 | vocabs=None, 125 | cased=False, 126 | batch_size=4, 127 | cut_embed=True, 128 | output_layer_num=1): 129 | """Extract embeddings from texts. 130 | 131 | :param model: Path to the checkpoint or built model without MLM and NSP. 132 | :param texts: Iterable texts. 133 | :param poolings: Pooling methods. Word embeddings will be returned if it is None. 134 | Otherwise concatenated pooled embeddings will be returned. 135 | :param vocabs: A dict should be provided if model is built. 136 | :param cased: Whether it is cased for tokenizer. 137 | :param batch_size: Batch size. 138 | :param cut_embed: The computed embeddings will be cut based on their input lengths. 139 | :param output_layer_num: The number of layers whose outputs will be concatenated as a single output. 140 | Only available when `model` is a path to checkpoint. 141 | :return: A list of numpy arrays representing the embeddings. 
142 | """ 143 | return [embedding for embedding in extract_embeddings_generator( 144 | model, texts, poolings, vocabs, cased, batch_size, cut_embed, output_layer_num 145 | )] 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras BERT 2 | 3 | [![Version](https://img.shields.io/pypi/v/keras-bert.svg)](https://pypi.org/project/keras-bert/) 4 | ![License](https://img.shields.io/pypi/l/keras-bert.svg) 5 | 6 | \[[中文](https://github.com/CyberZHG/keras-bert/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-bert/blob/master/README.md)\] 7 | 8 | Implementation of the [BERT](https://arxiv.org/pdf/1810.04805.pdf). Official pre-trained models could be loaded for feature extraction and prediction. 9 | 10 | ## Install 11 | 12 | ```bash 13 | pip install keras-bert 14 | ``` 15 | 16 | ## Usage 17 | 18 | * [Load Official Pre-trained Models](#Load-Official-Pre-trained-Models) 19 | * [Tokenizer](#Tokenizer) 20 | * [Train & Use](#Train-&-Use) 21 | * [Use Warmup](#Use-Warmup) 22 | * [Download Pretrained Checkpoints](#Download-Pretrained-Checkpoints) 23 | * [Extract Features](#Extract-Features) 24 | 25 | ### External Links 26 | 27 | * [Kashgari is a Production-ready NLP Transfer learning framework for text-labeling and text-classification](https://github.com/BrikerMan/Kashgari) 28 | * [Keras ALBERT](https://github.com/TinkerMob/keras_albert_model) 29 | 30 | ### Load Official Pre-trained Models 31 | 32 | In [feature extraction demo](./demo/load_model/load_and_extract.py), you should be able to get the same extraction results as the official model `chinese_L-12_H-768_A-12`. And in [prediction demo](./demo/load_model/load_and_predict.py), the missing word in the sentence could be predicted. 33 | 34 | 35 | ### Run on TPU 36 | 37 | The [extraction demo](https://colab.research.google.com/github/CyberZHG/keras-bert/blob/master/demo/load_model/keras_bert_load_and_extract_tpu.ipynb) shows how to convert to a model that runs on TPU. 38 | 39 | The [classification demo](https://colab.research.google.com/github/CyberZHG/keras-bert/blob/master/demo/tune/keras_bert_classification_tpu.ipynb) shows how to apply the model to simple classification tasks. 
40 | 41 | ### Tokenizer 42 | 43 | The `Tokenizer` class is used for splitting texts and generating indices: 44 | 45 | ```python 46 | from keras_bert import Tokenizer 47 | 48 | token_dict = { 49 | '[CLS]': 0, 50 | '[SEP]': 1, 51 | 'un': 2, 52 | '##aff': 3, 53 | '##able': 4, 54 | '[UNK]': 5, 55 | } 56 | tokenizer = Tokenizer(token_dict) 57 | print(tokenizer.tokenize('unaffable')) # The result should be `['[CLS]', 'un', '##aff', '##able', '[SEP]']` 58 | indices, segments = tokenizer.encode('unaffable') 59 | print(indices) # Should be `[0, 2, 3, 4, 1]` 60 | print(segments) # Should be `[0, 0, 0, 0, 0]` 61 | 62 | print(tokenizer.tokenize(first='unaffable', second='钢')) 63 | # The result should be `['[CLS]', 'un', '##aff', '##able', '[SEP]', '钢', '[SEP]']` 64 | indices, segments = tokenizer.encode(first='unaffable', second='钢', max_len=10) 65 | print(indices) # Should be `[0, 2, 3, 4, 1, 5, 1, 0, 0, 0]` 66 | print(segments) # Should be `[0, 0, 0, 0, 0, 1, 1, 0, 0, 0]` 67 | ``` 68 | 69 | ### Train & Use 70 | 71 | ```python 72 | from tensorflow import keras 73 | from keras_bert import get_base_dict, get_model, compile_model, gen_batch_inputs 74 | 75 | 76 | # A toy input example 77 | sentence_pairs = [ 78 | [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']], 79 | [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']], 80 | [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']], 81 | ] 82 | 83 | 84 | # Build token dictionary 85 | token_dict = get_base_dict() # A dict that contains some special tokens 86 | for pairs in sentence_pairs: 87 | for token in pairs[0] + pairs[1]: 88 | if token not in token_dict: 89 | token_dict[token] = len(token_dict) 90 | token_list = list(token_dict.keys()) # Used for selecting a random word 91 | 92 | 93 | # Build & train the model 94 | model = get_model( 95 | token_num=len(token_dict), 96 | head_num=5, 97 | transformer_num=12, 98 | embed_dim=25, 99 | feed_forward_dim=100, 100 | seq_len=20, 101 | pos_num=20, 102 | dropout_rate=0.05, 103 | ) 104 | compile_model(model) 105 | model.summary() 106 | 107 | def _generator(): 108 | while True: 109 | yield gen_batch_inputs( 110 | sentence_pairs, 111 | token_dict, 112 | token_list, 113 | seq_len=20, 114 | mask_rate=0.3, 115 | swap_sentence_rate=1.0, 116 | ) 117 | 118 | model.fit_generator( 119 | generator=_generator(), 120 | steps_per_epoch=1000, 121 | epochs=100, 122 | validation_data=_generator(), 123 | validation_steps=100, 124 | callbacks=[ 125 | keras.callbacks.EarlyStopping(monitor='val_loss', patience=5) 126 | ], 127 | ) 128 | 129 | 130 | # Use the trained model 131 | inputs, output_layer = get_model( 132 | token_num=len(token_dict), 133 | head_num=5, 134 | transformer_num=12, 135 | embed_dim=25, 136 | feed_forward_dim=100, 137 | seq_len=20, 138 | pos_num=20, 139 | dropout_rate=0.05, 140 | training=False, # The input layers and output layer will be returned if `training` is `False` 141 | trainable=False, # Whether the model is trainable. The default value is the same as `training` 142 | output_layer_num=4, # The number of layers whose outputs will be concatenated as a single output. 143 | # Only available when `training` is `False`. 144 | ) 145 | ``` 146 | 147 | ### Use Warmup 148 | 149 | The `AdamWarmup` optimizer is provided for warmup and decay. The learning rate will reach `lr` in `warmup_steps` steps, and decay to `min_lr` in `decay_steps` steps.
A helper function, `calc_train_steps`, is provided for calculating these two step counts: 150 | 151 | ```python 152 | import numpy as np 153 | from keras_bert import AdamWarmup, calc_train_steps 154 | 155 | train_x = np.random.standard_normal((1024, 100)) 156 | 157 | total_steps, warmup_steps = calc_train_steps( 158 | num_example=train_x.shape[0], 159 | batch_size=32, 160 | epochs=10, 161 | warmup_proportion=0.1, 162 | ) 163 | 164 | optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5) 165 | ``` 166 | 167 | ### Download Pretrained Checkpoints 168 | 169 | Several download URLs have been added. You can get the downloaded and uncompressed path of a checkpoint by: 170 | 171 | ```python 172 | from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths 173 | 174 | model_path = get_pretrained(PretrainedList.multi_cased_base) 175 | paths = get_checkpoint_paths(model_path) 176 | print(paths.config, paths.checkpoint, paths.vocab) 177 | ``` 178 | 179 | ### Extract Features 180 | 181 | You can use the helper function `extract_embeddings` if the features of tokens or sentences (without further tuning) are what you need. To extract the features of all tokens: 182 | 183 | ```python 184 | from keras_bert import extract_embeddings 185 | 186 | model_path = 'xxx/yyy/uncased_L-12_H-768_A-12' 187 | texts = ['all work and no play', 'makes jack a dull boy~'] 188 | 189 | embeddings = extract_embeddings(model_path, texts) 190 | ``` 191 | 192 | The returned result is a list with the same length as `texts`. Each item in the list is a numpy array truncated by the length of the input. The shapes of the outputs in this example are `(7, 768)` and `(8, 768)`. 193 | 194 | When the inputs are paired sentences and you need the outputs of `NSP` and max-pooling of the last 4 layers: 195 | 196 | ```python 197 | from keras_bert import extract_embeddings, POOL_NSP, POOL_MAX 198 | 199 | model_path = 'xxx/yyy/uncased_L-12_H-768_A-12' 200 | texts = [ 201 | ('all work and no play', 'makes jack a dull boy'), 202 | ('makes jack a dull boy', 'all work and no play'), 203 | ] 204 | 205 | embeddings = extract_embeddings(model_path, texts, output_layer_num=4, poolings=[POOL_NSP, POOL_MAX]) 206 | ``` 207 | 208 | There are no token features in the results. The outputs of `NSP` and max-pooling will be concatenated with the final shape `(768 x 4 x 2,)`. 209 | 210 | The second argument of the helper function can also be a generator.
To extract features from a file: 211 | 212 | ```python 213 | import codecs 214 | from keras_bert import extract_embeddings 215 | 216 | model_path = 'xxx/yyy/uncased_L-12_H-768_A-12' 217 | 218 | with codecs.open('xxx.txt', 'r', 'utf8') as reader: 219 | texts = map(lambda x: x.strip(), reader) 220 | embeddings = extract_embeddings(model_path, texts) 221 | ``` 222 | -------------------------------------------------------------------------------- /keras_bert/optimizers/warmup_v2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.optimizers import Optimizer 3 | from tensorflow.python.ops import state_ops, control_flow_ops 4 | from tensorflow.python.keras import backend_config 5 | 6 | 7 | __all__ = ['AdamWarmup'] 8 | 9 | 10 | class AdamWarmup(Optimizer): 11 | """Adam optimizer with warmup.""" 12 | 13 | def __init__(self, 14 | decay_steps, 15 | warmup_steps, 16 | min_lr=0.0, 17 | learning_rate=0.001, 18 | beta_1=0.9, 19 | beta_2=0.999, 20 | epsilon=1e-7, 21 | weight_decay=0., 22 | weight_decay_pattern=None, 23 | amsgrad=False, 24 | name='AdamWarmup', 25 | **kwargs): 26 | r"""Construct a new Adam optimizer. 27 | 28 | Args: 29 | decay_steps: The learning rate will decay linearly to `min_lr` in `decay_steps` steps. 30 | warmup_steps: The learning rate will increase linearly to `learning_rate` in the first `warmup_steps` steps. 31 | min_lr: float >= 0. Minimum learning rate reached at the end of decay. 32 | learning_rate: float >= 0. Learning rate. 33 | beta_1: float, 0 < beta < 1. Generally close to 1. 34 | beta_2: float, 0 < beta < 1. Generally close to 1. 35 | epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 36 | weight_decay: float >= 0. Weight decay. 37 | weight_decay_pattern: A list of strings. The substrings of weight names to be decayed. 38 | All weights will be decayed if it is None. 39 | amsgrad: boolean. Whether to apply the AMSGrad variant of this 40 | algorithm from the paper "On the Convergence of Adam and Beyond".
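        Example:
            A minimal usage sketch; the step counts and rates below are illustrative,
            not values taken from the repository:

                optimizer = AdamWarmup(
                    decay_steps=10000,
                    warmup_steps=1000,
                    learning_rate=1e-3,
                    min_lr=1e-5,
                )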
41 | """ 42 | 43 | super(AdamWarmup, self).__init__(name, **kwargs) 44 | self._set_hyper('decay_steps', float(decay_steps)) 45 | self._set_hyper('warmup_steps', float(warmup_steps)) 46 | self._set_hyper('min_lr', min_lr) 47 | self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) 48 | self._set_hyper('decay', self._initial_decay) 49 | self._set_hyper('beta_1', beta_1) 50 | self._set_hyper('beta_2', beta_2) 51 | self._set_hyper('weight_decay', weight_decay) 52 | self.epsilon = epsilon or backend_config.epsilon() 53 | self.amsgrad = amsgrad 54 | self._initial_weight_decay = weight_decay 55 | self._weight_decay_pattern = weight_decay_pattern 56 | 57 | def _create_slots(self, var_list): 58 | for var in var_list: 59 | self.add_slot(var, 'm') 60 | for var in var_list: 61 | self.add_slot(var, 'v') 62 | if self.amsgrad: 63 | for var in var_list: 64 | self.add_slot(var, 'vhat') 65 | 66 | def set_weights(self, weights): 67 | params = self.weights 68 | num_vars = int((len(params) - 1) / 2) 69 | if len(weights) == 3 * num_vars + 1: 70 | weights = weights[:len(params)] 71 | super(AdamWarmup, self).set_weights(weights) 72 | 73 | def _resource_apply_dense(self, grad, var): 74 | var_dtype = var.dtype.base_dtype 75 | lr_t = self._decayed_lr(var_dtype) 76 | m = self.get_slot(var, 'm') 77 | v = self.get_slot(var, 'v') 78 | beta_1_t = self._get_hyper('beta_1', var_dtype) 79 | beta_2_t = self._get_hyper('beta_2', var_dtype) 80 | epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype) 81 | local_step = tf.cast(self.iterations + 1, var_dtype) 82 | beta_1_power = tf.pow(beta_1_t, local_step) 83 | beta_2_power = tf.pow(beta_2_t, local_step) 84 | 85 | decay_steps = self._get_hyper('decay_steps', var_dtype) 86 | warmup_steps = self._get_hyper('warmup_steps', var_dtype) 87 | min_lr = self._get_hyper('min_lr', var_dtype) 88 | lr_t = tf.where( 89 | local_step <= warmup_steps, 90 | lr_t * (local_step / warmup_steps), 91 | min_lr + (lr_t - min_lr) * (1.0 - tf.minimum(local_step, decay_steps) / decay_steps), 92 | ) 93 | lr_t = (lr_t * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)) 94 | 95 | m_t = state_ops.assign(m, 96 | beta_1_t * m + (1.0 - beta_1_t) * grad, 97 | use_locking=self._use_locking) 98 | 99 | v_t = state_ops.assign(v, 100 | beta_2_t * v + (1.0 - beta_2_t) * tf.square(grad), 101 | use_locking=self._use_locking) 102 | 103 | if self.amsgrad: 104 | v_hat = self.get_slot(var, 'vhat') 105 | v_hat_t = tf.maximum(v_hat, v_t) 106 | var_update = m_t / (tf.sqrt(v_hat_t) + epsilon_t) 107 | else: 108 | var_update = m_t / (tf.sqrt(v_t) + epsilon_t) 109 | 110 | if self._initial_weight_decay > 0.0: 111 | weight_decay = self._get_hyper('weight_decay', var_dtype) 112 | var_update += weight_decay * var 113 | var_update = state_ops.assign_sub(var, lr_t * var_update, use_locking=self._use_locking) 114 | 115 | updates = [var_update, m_t, v_t] 116 | if self.amsgrad: 117 | updates.append(v_hat_t) 118 | return control_flow_ops.group(*updates) 119 | 120 | def _resource_apply_sparse(self, grad, var, indices): 121 | var_dtype = var.dtype.base_dtype 122 | lr_t = self._decayed_lr(var_dtype) 123 | beta_1_t = self._get_hyper('beta_1', var_dtype) 124 | beta_2_t = self._get_hyper('beta_2', var_dtype) 125 | epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype) 126 | local_step = tf.cast(self.iterations + 1, var_dtype) 127 | beta_1_power = tf.pow(beta_1_t, local_step) 128 | beta_2_power = tf.pow(beta_2_t, local_step) 129 | 130 | decay_steps = self._get_hyper('decay_steps', var_dtype) 131 | warmup_steps = 
self._get_hyper('warmup_steps', var_dtype) 132 | min_lr = self._get_hyper('min_lr', var_dtype) 133 | lr_t = tf.where( 134 | local_step <= warmup_steps, 135 | lr_t * (local_step / warmup_steps), 136 | min_lr + (lr_t - min_lr) * (1.0 - tf.minimum(local_step, decay_steps) / decay_steps), 137 | ) 138 | lr_t = (lr_t * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)) 139 | 140 | m = self.get_slot(var, 'm') 141 | m_scaled_g_values = grad * (1 - beta_1_t) 142 | m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking) 143 | with tf.control_dependencies([m_t]): 144 | m_t = self._resource_scatter_add(m, indices, m_scaled_g_values) 145 | 146 | v = self.get_slot(var, 'v') 147 | v_scaled_g_values = (grad * grad) * (1 - beta_2_t) 148 | v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking) 149 | with tf.control_dependencies([v_t]): 150 | v_t = self._resource_scatter_add(v, indices, v_scaled_g_values) 151 | 152 | if self.amsgrad: 153 | v_hat = self.get_slot(var, 'vhat') 154 | v_hat_t = tf.maximum(v_hat, v_t) 155 | var_update = m_t / (tf.sqrt(v_hat_t) + epsilon_t) 156 | else: 157 | var_update = m_t / (tf.sqrt(v_t) + epsilon_t) 158 | 159 | if self._initial_weight_decay > 0.0: 160 | weight_decay = self._get_hyper('weight_decay', var_dtype) 161 | var_update += weight_decay * var 162 | var_update = state_ops.assign_sub(var, lr_t * var_update, use_locking=self._use_locking) 163 | 164 | updates = [var_update, m_t, v_t] 165 | if self.amsgrad: 166 | updates.append(v_hat_t) 167 | return control_flow_ops.group(*updates) 168 | 169 | def get_config(self): 170 | config = super(AdamWarmup, self).get_config() 171 | config.update({ 172 | 'decay_steps': self._serialize_hyperparameter('decay_steps'), 173 | 'warmup_steps': self._serialize_hyperparameter('warmup_steps'), 174 | 'min_lr': self._serialize_hyperparameter('min_lr'), 175 | 'learning_rate': self._serialize_hyperparameter('learning_rate'), 176 | 'decay': self._serialize_hyperparameter('decay'), 177 | 'beta_1': self._serialize_hyperparameter('beta_1'), 178 | 'beta_2': self._serialize_hyperparameter('beta_2'), 179 | 'weight_decay': self._serialize_hyperparameter('weight_decay'), 180 | 'epsilon': self.epsilon, 181 | 'amsgrad': self.amsgrad, 182 | }) 183 | return config 184 | -------------------------------------------------------------------------------- /keras_bert/loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import codecs 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow import keras 7 | 8 | from .bert import get_model 9 | 10 | __all__ = [ 11 | 'build_model_from_config', 12 | 'load_model_weights_from_checkpoint', 13 | 'load_trained_model_from_checkpoint', 14 | 'load_vocabulary', 15 | ] 16 | 17 | 18 | def checkpoint_loader(checkpoint_file): 19 | def _loader(name): 20 | return tf.train.load_variable(checkpoint_file, name) 21 | return _loader 22 | 23 | 24 | def build_model_from_config(config_file, 25 | training=False, 26 | trainable=None, 27 | output_layer_num=1, 28 | seq_len=int(1e9), 29 | **kwargs): 30 | """Build the model from config file. 31 | 32 | :param config_file: The path to the JSON configuration file. 33 | :param training: If training, the whole model will be returned. 34 | Otherwise, the MLM and NSP parts will be ignored. 35 | :param trainable: Whether the model is trainable. 36 | :param output_layer_num: The number of layers whose outputs will be concatenated as a single output. 37 | Only available when `training` is `False`. 
38 | :param seq_len: If it is not None and it is shorter than the value in the config file, the weights in 39 | position embeddings will be sliced to fit the new length. 40 | :return: model and config 41 | """ 42 | with open(config_file, 'r') as reader: 43 | config = json.loads(reader.read()) 44 | if seq_len is not None: 45 | config['max_position_embeddings'] = seq_len = min(seq_len, config['max_position_embeddings']) 46 | if trainable is None: 47 | trainable = training 48 | model = get_model( 49 | token_num=config['vocab_size'], 50 | pos_num=config['max_position_embeddings'], 51 | seq_len=seq_len, 52 | embed_dim=config['hidden_size'], 53 | transformer_num=config['num_hidden_layers'], 54 | head_num=config['num_attention_heads'], 55 | feed_forward_dim=config['intermediate_size'], 56 | feed_forward_activation=config['hidden_act'], 57 | training=training, 58 | trainable=trainable, 59 | output_layer_num=output_layer_num, 60 | **kwargs) 61 | if not training: 62 | inputs, outputs = model 63 | model = keras.models.Model(inputs=inputs, outputs=outputs) 64 | return model, config 65 | 66 | 67 | def load_model_weights_from_checkpoint(model, 68 | config, 69 | checkpoint_file, 70 | training=False): 71 | """Load trained official model from checkpoint. 72 | 73 | :param model: Built keras model. 74 | :param config: Loaded configuration file. 75 | :param checkpoint_file: The path to the checkpoint files, should end with '.ckpt'. 76 | :param training: If training, the whole model will be returned. 77 | Otherwise, the MLM and NSP parts will be ignored. 78 | """ 79 | loader = checkpoint_loader(checkpoint_file) 80 | 81 | model.get_layer(name='Embedding-Token').set_weights([ 82 | loader('bert/embeddings/word_embeddings'), 83 | ]) 84 | model.get_layer(name='Embedding-Position').set_weights([ 85 | loader('bert/embeddings/position_embeddings')[:config['max_position_embeddings'], :], 86 | ]) 87 | model.get_layer(name='Embedding-Segment').set_weights([ 88 | loader('bert/embeddings/token_type_embeddings'), 89 | ]) 90 | model.get_layer(name='Embedding-Norm').set_weights([ 91 | loader('bert/embeddings/LayerNorm/gamma'), 92 | loader('bert/embeddings/LayerNorm/beta'), 93 | ]) 94 | for i in range(config['num_hidden_layers']): 95 | try: 96 | model.get_layer(name='Encoder-%d-MultiHeadSelfAttention' % (i + 1)) 97 | except ValueError as e: 98 | continue 99 | model.get_layer(name='Encoder-%d-MultiHeadSelfAttention' % (i + 1)).set_weights([ 100 | loader('bert/encoder/layer_%d/attention/self/query/kernel' % i), 101 | loader('bert/encoder/layer_%d/attention/self/query/bias' % i), 102 | loader('bert/encoder/layer_%d/attention/self/key/kernel' % i), 103 | loader('bert/encoder/layer_%d/attention/self/key/bias' % i), 104 | loader('bert/encoder/layer_%d/attention/self/value/kernel' % i), 105 | loader('bert/encoder/layer_%d/attention/self/value/bias' % i), 106 | loader('bert/encoder/layer_%d/attention/output/dense/kernel' % i), 107 | loader('bert/encoder/layer_%d/attention/output/dense/bias' % i), 108 | ]) 109 | model.get_layer(name='Encoder-%d-MultiHeadSelfAttention-Norm' % (i + 1)).set_weights([ 110 | loader('bert/encoder/layer_%d/attention/output/LayerNorm/gamma' % i), 111 | loader('bert/encoder/layer_%d/attention/output/LayerNorm/beta' % i), 112 | ]) 113 | model.get_layer(name='Encoder-%d-FeedForward' % (i + 1)).set_weights([ 114 | loader('bert/encoder/layer_%d/intermediate/dense/kernel' % i), 115 | loader('bert/encoder/layer_%d/intermediate/dense/bias' % i), 116 | loader('bert/encoder/layer_%d/output/dense/kernel' % i), 117 | 
loader('bert/encoder/layer_%d/output/dense/bias' % i), 118 | ]) 119 | model.get_layer(name='Encoder-%d-FeedForward-Norm' % (i + 1)).set_weights([ 120 | loader('bert/encoder/layer_%d/output/LayerNorm/gamma' % i), 121 | loader('bert/encoder/layer_%d/output/LayerNorm/beta' % i), 122 | ]) 123 | if training: 124 | model.get_layer(name='MLM-Dense').set_weights([ 125 | loader('cls/predictions/transform/dense/kernel'), 126 | loader('cls/predictions/transform/dense/bias'), 127 | ]) 128 | model.get_layer(name='MLM-Norm').set_weights([ 129 | loader('cls/predictions/transform/LayerNorm/gamma'), 130 | loader('cls/predictions/transform/LayerNorm/beta'), 131 | ]) 132 | model.get_layer(name='MLM-Sim').set_weights([ 133 | loader('cls/predictions/output_bias'), 134 | ]) 135 | model.get_layer(name='NSP-Dense').set_weights([ 136 | loader('bert/pooler/dense/kernel'), 137 | loader('bert/pooler/dense/bias'), 138 | ]) 139 | model.get_layer(name='NSP').set_weights([ 140 | np.transpose(loader('cls/seq_relationship/output_weights')), 141 | loader('cls/seq_relationship/output_bias'), 142 | ]) 143 | 144 | 145 | def load_trained_model_from_checkpoint(config_file, 146 | checkpoint_file, 147 | training=False, 148 | trainable=None, 149 | output_layer_num=1, 150 | seq_len=int(1e9), 151 | **kwargs): 152 | """Load trained official model from checkpoint. 153 | 154 | :param config_file: The path to the JSON configuration file. 155 | :param checkpoint_file: The path to the checkpoint files, should end with '.ckpt'. 156 | :param training: If training, the whole model will be returned. 157 | Otherwise, the MLM and NSP parts will be ignored. 158 | :param trainable: Whether the model is trainable. The default value is the same with `training`. 159 | :param output_layer_num: The number of layers whose outputs will be concatenated as a single output. 160 | Only available when `training` is `False`. 161 | :param seq_len: If it is not None and it is shorter than the value in the config file, the weights in 162 | position embeddings will be sliced to fit the new length. 
163 | :return: model 164 | """ 165 | model, config = build_model_from_config( 166 | config_file, 167 | training=training, 168 | trainable=trainable, 169 | output_layer_num=output_layer_num, 170 | seq_len=seq_len, 171 | **kwargs) 172 | load_model_weights_from_checkpoint(model, config, checkpoint_file, training=training) 173 | return model 174 | 175 | 176 | def load_vocabulary(vocab_path): 177 | token_dict = {} 178 | with codecs.open(vocab_path, 'r', 'utf8') as reader: 179 | for line in reader: 180 | token = line.strip() 181 | token_dict[token] = len(token_dict) 182 | return token_dict 183 | -------------------------------------------------------------------------------- /README.zh-CN.md: -------------------------------------------------------------------------------- 1 | # Keras BERT 2 | 3 | [![Version](https://img.shields.io/pypi/v/keras-bert.svg)](https://pypi.org/project/keras-bert/) 4 | ![License](https://img.shields.io/pypi/l/keras-bert.svg) 5 | 6 | \[[中文](https://github.com/CyberZHG/keras-bert/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-bert/blob/master/README.md)\] 7 | 8 | [BERT](https://arxiv.org/pdf/1810.04805.pdf)的非官方实现,可以加载官方的预训练模型进行特征提取和预测。 9 | 10 | ## 安装 11 | 12 | ```bash 13 | pip install keras-bert 14 | ``` 15 | 16 | ## 使用 17 | 18 | * [使用官方模型](#使用官方模型) 19 | * [分词](#分词) 20 | * [训练和使用](#训练和使用) 21 | * [关于`training`和`trainable`](#关于training和trainable) 22 | * [使用Warmup](#使用Warmup) 23 | * [关于输入](#关于输入) 24 | * [下载预训练模型](#下载预训练模型) 25 | * [提取特征](#提取特征) 26 | * [模型存储与加载](#模型存储与加载) 27 | * [使用任务嵌入](#使用任务嵌入) 28 | * [使用`tf.keras`](#使用tensorflowpythonkeras) 29 | 30 | ### External Links 31 | 32 | * [Kashgari是一个极简且强大的 NLP 框架,可用于文本分类和标注的学习,研究及部署上线](https://github.com/BrikerMan/Kashgari) 33 | * [当Bert遇上Keras:这可能是Bert最简单的打开姿势](https://spaces.ac.cn/archives/6736) 34 | * [Keras ALBERT](https://github.com/TinkerMob/keras_albert_model) 35 | 36 | ### 使用官方模型 37 | 38 | [特征提取展示](./demo/load_model/load_and_extract.py)中使用官方预训练好的`chinese_L-12_H-768_A-12`可以得到和官方工具一样的结果。 39 | 40 | [预测展示](./demo/load_model/load_and_predict.py)中可以填补出缺失词并预测是否是上下文。 41 | 42 | ### 使用TPU 43 | 44 | [特征提取示例](https://colab.research.google.com/github/CyberZHG/keras-bert/blob/master/demo/load_model/keras_bert_load_and_extract_tpu.ipynb)中展示了如何在TPU上进行特征提取。 45 | 46 | [分类示例](https://colab.research.google.com/github/CyberZHG/keras-bert/blob/master/demo/tune/keras_bert_classification_tpu.ipynb)中在IMDB数据集上对模型进行了微调以适应新的分类任务。 47 | 48 | ### 分词 49 | 50 | `Tokenizer`类可以用来进行分词工作,包括归一化和英文部分的最大贪心匹配等,在CJK字符集内的中文会以单字分隔。 51 | 52 | ```python 53 | from keras_bert import Tokenizer 54 | 55 | token_dict = { 56 | '[CLS]': 0, 57 | '[SEP]': 1, 58 | 'un': 2, 59 | '##aff': 3, 60 | '##able': 4, 61 | '[UNK]': 5, 62 | } 63 | tokenizer = Tokenizer(token_dict) 64 | print(tokenizer.tokenize('unaffable')) # 分词结果是:`['[CLS]', 'un', '##aff', '##able', '[SEP]']` 65 | 66 | indices, segments = tokenizer.encode('unaffable') 67 | print(indices) # 词对应的下标:`[0, 2, 3, 4, 1]` 68 | print(segments) # 段落对应下标:`[0, 0, 0, 0, 0]` 69 | 70 | print(tokenizer.tokenize(first='unaffable', second='钢')) 71 | # 分词结果是:`['[CLS]', 'un', '##aff', '##able', '[SEP]', '钢', '[SEP]']` 72 | indices, segments = tokenizer.encode(first='unaffable', second='钢', max_len=10) 73 | print(indices) # 词对应的下标:`[0, 2, 3, 4, 1, 5, 1, 0, 0, 0]` 74 | print(segments) # 段落对应下标:`[0, 0, 0, 0, 0, 1, 1, 0, 0, 0]` 75 | ``` 76 | 77 | `Tokenizer`也提供了尝试去寻找分词后的结果在原始文本中的起始和终止下标的功能,输入可以是decode后的结果,包含少量的错词: 78 | 79 | ```python 80 | from keras_bert import Tokenizer 81 | 82 | intervals = Tokenizer.rematch("All 
rights reserved.", ["[UNK]", "righs", "[UNK]", "ser", "[UNK]", "[UNK]"]) 83 | # 结果是:[(0, 3), (4, 10), (11, 13), (13, 16), (16, 19), (19, 20)] 84 | ``` 85 | 86 | ### 训练和使用 87 | 88 | 训练过程推荐使用官方的代码。这个代码库内包含一个的训练过程,`training`为`True`的情况下使用的是带warmup的Adam优化器: 89 | 90 | ```python 91 | from tensorflow import keras 92 | from keras_bert import get_base_dict, get_model, compile_model, gen_batch_inputs 93 | 94 | 95 | # 随便的输入样例: 96 | sentence_pairs = [ 97 | [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']], 98 | [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']], 99 | [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']], 100 | ] 101 | 102 | 103 | # 构建自定义词典 104 | token_dict = get_base_dict() # 初始化特殊符号,如`[CLS]` 105 | for pairs in sentence_pairs: 106 | for token in pairs[0] + pairs[1]: 107 | if token not in token_dict: 108 | token_dict[token] = len(token_dict) 109 | token_list = list(token_dict.keys()) # Used for selecting a random word 110 | 111 | 112 | # 构建和训练模型 113 | model = get_model( 114 | token_num=len(token_dict), 115 | head_num=5, 116 | transformer_num=12, 117 | embed_dim=25, 118 | feed_forward_dim=100, 119 | seq_len=20, 120 | pos_num=20, 121 | dropout_rate=0.05, 122 | ) 123 | compile_model(model) 124 | model.summary() 125 | 126 | def _generator(): 127 | while True: 128 | yield gen_batch_inputs( 129 | sentence_pairs, 130 | token_dict, 131 | token_list, 132 | seq_len=20, 133 | mask_rate=0.3, 134 | swap_sentence_rate=1.0, 135 | ) 136 | 137 | model.fit_generator( 138 | generator=_generator(), 139 | steps_per_epoch=1000, 140 | epochs=100, 141 | validation_data=_generator(), 142 | validation_steps=100, 143 | callbacks=[ 144 | keras.callbacks.EarlyStopping(monitor='val_loss', patience=5) 145 | ], 146 | ) 147 | 148 | 149 | # 使用训练好的模型 150 | inputs, output_layer = get_model( 151 | token_num=len(token_dict), 152 | head_num=5, 153 | transformer_num=12, 154 | embed_dim=25, 155 | feed_forward_dim=100, 156 | seq_len=20, 157 | pos_num=20, 158 | dropout_rate=0.05, 159 | training=False, # 当`training`是`False`,返回值是输入和输出 160 | trainable=False, # 模型是否可训练,默认值和`training`相同 161 | output_layer_num=4, # 最后几层的输出将合并在一起作为最终的输出,只有当`training`是`False`有效 162 | ) 163 | ``` 164 | 165 | #### 关于`training`和`trainable` 166 | 167 | 虽然看起来相似,但这两个参数是不相关的。`training`表示是否在训练BERT语言模型,当为`True`时完整的BERT模型会被返回,当为`False`时没有MLM和NSP相关计算的结构,返回输入层和根据`output_layer_num`合并最后几层的输出。加载的层是否可训练只跟`trainable`有关。 168 | 169 | 此外,`trainable`可以是一个包含字符串的列表,如果某一层的前缀出现在列表中,则当前层是可训练的。在使用预训练模型时,如果不想再训练嵌入层,可以传入`trainable=['Encoder']`来只对编码层进行调整。 170 | 171 | ### 使用Warmup 172 | 173 | `AdamWarmup`优化器可用于学习率的「热身」与「衰减」。学习率将在`warmpup_steps`步线性增长到`lr`,并在总共`decay_steps`步后线性减少到`min_lr`。辅助函数`calc_train_steps`可用于计算这两个步数: 174 | 175 | ```python 176 | import numpy as np 177 | from keras_bert import AdamWarmup, calc_train_steps 178 | 179 | train_x = np.random.standard_normal((1024, 100)) 180 | 181 | total_steps, warmup_steps = calc_train_steps( 182 | num_example=train_x.shape[0], 183 | batch_size=32, 184 | epochs=10, 185 | warmup_proportion=0.1, 186 | ) 187 | 188 | optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5) 189 | ``` 190 | 191 | ### 关于输入 192 | 193 | 在`training`为`True`的情况下,输入包含三项:token下标、segment下标、被masked的词的模版。当`training`为`False`时输入只包含前两项。位置下标由于是固定的,会在模型内部生成,不需要手动再输入一遍。被masked的词的模版在输入被masked的词是值为1,否则为0。 194 | 195 | ### 下载预训练模型 196 | 197 | 库中记录了一些预训练模型的下载地址,可以通过如下方式获得解压后的checkpoint的路径: 198 | 199 | ```python 200 | from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths 201 | 202 | 
model_path = get_pretrained(PretrainedList.multi_cased_base) 203 | paths = get_checkpoint_paths(model_path) 204 | print(paths.config, paths.checkpoint, paths.vocab) 205 | ``` 206 | 207 | ### 提取特征 208 | 209 | 如果不需要微调,只想提取词/句子的特征,则可以使用`extract_embeddings`来简化流程。如提取每个句子对应的全部词的特征: 210 | 211 | ```python 212 | from keras_bert import extract_embeddings 213 | 214 | model_path = 'xxx/yyy/uncased_L-12_H-768_A-12' 215 | texts = ['all work and no play', 'makes jack a dull boy~'] 216 | 217 | embeddings = extract_embeddings(model_path, texts) 218 | ``` 219 | 220 | 返回的结果是一个list,长度和输入文本的个数相同,每个元素都是numpy的数组,默认会根据输出的长度进行裁剪,所以在这个例子中输出的大小分别为`(7, 768)`和`(8, 768)`。 221 | 222 | 如果输入是成对的句子,想使用最后4层特征,且提取`NSP`位输出和max-pooling的结果,则可以用: 223 | 224 | ```python 225 | from keras_bert import extract_embeddings, POOL_NSP, POOL_MAX 226 | 227 | model_path = 'xxx/yyy/uncased_L-12_H-768_A-12' 228 | texts = [ 229 | ('all work and no play', 'makes jack a dull boy'), 230 | ('makes jack a dull boy', 'all work and no play'), 231 | ] 232 | 233 | embeddings = extract_embeddings(model_path, texts, output_layer_num=4, poolings=[POOL_NSP, POOL_MAX]) 234 | ``` 235 | 236 | 输出结果中不再包含词的特征,`NSP`和max-pooling的输出会拼接在一起,每个numpy数组的大小为`(768 x 4 x 2,)`。 237 | 238 | 第二个参数接受的是一个generator,如果想读取文件并生成特征,可以用下面的方法: 239 | 240 | ```python 241 | import codecs 242 | from keras_bert import extract_embeddings 243 | 244 | model_path = 'xxx/yyy/uncased_L-12_H-768_A-12' 245 | 246 | with codecs.open('xxx.txt', 'r', 'utf8') as reader: 247 | texts = map(lambda x: x.strip(), reader) 248 | embeddings = extract_embeddings(model_path, texts) 249 | ``` 250 | 251 | ### 模型存储与加载 252 | 253 | ```python 254 | from keras_bert import load_trained_model_from_checkpoint, get_custom_objects 255 | 256 | model = load_trained_model_from_checkpoint('xxx', 'yyy') 257 | model.save('save_path.h5') 258 | model.load('save_path.h5', custom_objects=get_custom_objects()) 259 | ``` 260 | 261 | ### 使用任务嵌入 262 | 263 | 如果有多任务训练的需求,可以启用任务嵌入层,针对不同任务将嵌入的结果加上不同的编码,注意要让`Embedding-Task`层可训练: 264 | 265 | ```python 266 | from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths, load_trained_model_from_checkpoint 267 | 268 | model_path = get_pretrained(PretrainedList.multi_cased_base) 269 | paths = get_checkpoint_paths(model_path) 270 | model = load_trained_model_from_checkpoint( 271 | config_file=paths.config, 272 | checkpoint_file=paths.checkpoint, 273 | training=False, 274 | trainable=True, 275 | use_task_embed=True, 276 | task_num=10, 277 | ) 278 | ``` 279 | -------------------------------------------------------------------------------- /keras_bert/tokenizer.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | from keras_bert.bert import TOKEN_CLS, TOKEN_SEP, TOKEN_UNK 3 | 4 | 5 | class Tokenizer(object): 6 | 7 | def __init__(self, 8 | token_dict, 9 | token_cls=TOKEN_CLS, 10 | token_sep=TOKEN_SEP, 11 | token_unk=TOKEN_UNK, 12 | pad_index=0, 13 | cased=False): 14 | """Initialize tokenizer. 15 | 16 | :param token_dict: A dict maps tokens to indices. 17 | :param token_cls: The token represents classification. 18 | :param token_sep: The token represents separator. 19 | :param token_unk: The token represents unknown token. 20 | :param pad_index: The index to pad. 21 | :param cased: Whether to keep the case. 
22 | """ 23 | self._token_dict = token_dict 24 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 25 | self._token_cls = token_cls 26 | self._token_sep = token_sep 27 | self._token_unk = token_unk 28 | self._pad_index = pad_index 29 | self._cased = cased 30 | 31 | @staticmethod 32 | def _truncate(first_tokens, second_tokens=None, max_len=None): 33 | if max_len is None: 34 | return 35 | 36 | if second_tokens is not None: 37 | while True: 38 | total_len = len(first_tokens) + len(second_tokens) 39 | if total_len <= max_len - 3: # 3 for [CLS] .. tokens_a .. [SEP] .. tokens_b [SEP] 40 | break 41 | if len(first_tokens) > len(second_tokens): 42 | first_tokens.pop() 43 | else: 44 | second_tokens.pop() 45 | else: 46 | del first_tokens[max_len - 2:] # 2 for [CLS] .. tokens .. [SEP] 47 | 48 | def _pack(self, first_tokens, second_tokens=None): 49 | first_packed_tokens = [self._token_cls] + first_tokens + [self._token_sep] 50 | if second_tokens is not None: 51 | second_packed_tokens = second_tokens + [self._token_sep] 52 | return first_packed_tokens + second_packed_tokens, len(first_packed_tokens), len(second_packed_tokens) 53 | else: 54 | return first_packed_tokens, len(first_packed_tokens), 0 55 | 56 | def _convert_tokens_to_ids(self, tokens): 57 | unk_id = self._token_dict.get(self._token_unk) 58 | return [self._token_dict.get(token, unk_id) for token in tokens] 59 | 60 | def tokenize(self, first, second=None): 61 | """Split text to tokens. 62 | 63 | :param first: First text. 64 | :param second: Second text. 65 | :return: A list of strings. 66 | """ 67 | first_tokens = self._tokenize(first) 68 | second_tokens = self._tokenize(second) if second is not None else None 69 | tokens, _, _ = self._pack(first_tokens, second_tokens) 70 | return tokens 71 | 72 | def encode(self, first, second=None, max_len=None): 73 | first_tokens = self._tokenize(first) 74 | second_tokens = self._tokenize(second) if second is not None else None 75 | self._truncate(first_tokens, second_tokens, max_len) 76 | tokens, first_len, second_len = self._pack(first_tokens, second_tokens) 77 | 78 | token_ids = self._convert_tokens_to_ids(tokens) 79 | segment_ids = [0] * first_len + [1] * second_len 80 | 81 | if max_len is not None: 82 | pad_len = max_len - first_len - second_len 83 | token_ids += [self._pad_index] * pad_len 84 | segment_ids += [0] * pad_len 85 | 86 | return token_ids, segment_ids 87 | 88 | def decode(self, ids): 89 | sep = ids.index(self._token_dict[self._token_sep]) 90 | try: 91 | stop = ids.index(self._pad_index) 92 | except ValueError as e: 93 | stop = len(ids) 94 | tokens = [self._token_dict_inv[i] for i in ids] 95 | first = tokens[1:sep] 96 | if sep < stop - 1: 97 | second = tokens[sep + 1:stop - 1] 98 | return first, second 99 | return first 100 | 101 | def _tokenize(self, text): 102 | if not self._cased: 103 | text = unicodedata.normalize('NFD', text) 104 | text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn']) 105 | text = text.lower() 106 | spaced = '' 107 | for ch in text: 108 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 109 | spaced += ' ' + ch + ' ' 110 | elif self._is_space(ch): 111 | spaced += ' ' 112 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 113 | continue 114 | else: 115 | spaced += ch 116 | tokens = [] 117 | for word in spaced.strip().split(): 118 | tokens += self._word_piece_tokenize(word) 119 | return tokens 120 | 121 | def _word_piece_tokenize(self, word): 122 | if word in self._token_dict: 123 | return [word] 124 | tokens = [] 125 
| start, stop = 0, 0 126 | while start < len(word): 127 | stop = len(word) 128 | while stop > start: 129 | sub = word[start:stop] 130 | if start > 0: 131 | sub = '##' + sub 132 | if sub in self._token_dict: 133 | break 134 | stop -= 1 135 | if start == stop: 136 | stop += 1 137 | tokens.append(sub) 138 | start = stop 139 | return tokens 140 | 141 | @staticmethod 142 | def _is_punctuation(ch): 143 | code = ord(ch) 144 | return 33 <= code <= 47 or \ 145 | 58 <= code <= 64 or \ 146 | 91 <= code <= 96 or \ 147 | 123 <= code <= 126 or \ 148 | unicodedata.category(ch).startswith('P') 149 | 150 | @staticmethod 151 | def _is_cjk_character(ch): 152 | code = ord(ch) 153 | return 0x4E00 <= code <= 0x9FFF or \ 154 | 0x3400 <= code <= 0x4DBF or \ 155 | 0x20000 <= code <= 0x2A6DF or \ 156 | 0x2A700 <= code <= 0x2B73F or \ 157 | 0x2B740 <= code <= 0x2B81F or \ 158 | 0x2B820 <= code <= 0x2CEAF or \ 159 | 0xF900 <= code <= 0xFAFF or \ 160 | 0x2F800 <= code <= 0x2FA1F 161 | 162 | @staticmethod 163 | def _is_space(ch): 164 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 165 | unicodedata.category(ch) == 'Zs' 166 | 167 | @staticmethod 168 | def _is_control(ch): 169 | return unicodedata.category(ch) in ('Cc', 'Cf') 170 | 171 | @staticmethod 172 | def rematch(text, tokens, cased=False, unknown_token=TOKEN_UNK): 173 | """Try to find the indices of tokens in the original text. 174 | 175 | >>> Tokenizer.rematch("All rights reserved.", ["all", "rights", "re", "##ser", "##ved", "."]) 176 | [(0, 3), (4, 10), (11, 13), (13, 16), (16, 19), (19, 20)] 177 | >>> Tokenizer.rematch("All rights reserved.", ["all", "rights", "re", "##ser", "[UNK]", "."]) 178 | [(0, 3), (4, 10), (11, 13), (13, 16), (16, 19), (19, 20)] 179 | >>> Tokenizer.rematch("All rights reserved.", ["[UNK]", "rights", "[UNK]", "##ser", "[UNK]", "[UNK]"]) 180 | [(0, 3), (4, 10), (11, 13), (13, 16), (16, 19), (19, 20)] 181 | >>> Tokenizer.rematch("All rights reserved.", ["[UNK]", "righs", "[UNK]", "ser", "[UNK]", "[UNK]"]) 182 | [(0, 3), (4, 10), (11, 13), (13, 16), (16, 19), (19, 20)] 183 | >>> Tokenizer.rematch("All rights reserved.", 184 | ... ["[UNK]", "rights", "[UNK]", "[UNK]", "[UNK]", "[UNK]"]) # doctest:+ELLIPSIS 185 | [(0, 3), (4, 10), (11, ... 19), (19, 20)] 186 | >>> Tokenizer.rematch("All rights reserved.", ["all rights", "reserved", "."]) 187 | [(0, 10), (11, 19), (19, 20)] 188 | >>> Tokenizer.rematch("All rights reserved.", ["all rights", "reserved", "."], cased=True) 189 | [(0, 10), (11, 19), (19, 20)] 190 | >>> Tokenizer.rematch("#hash tag ##", ["#", "hash", "tag", "##"]) 191 | [(0, 1), (1, 5), (6, 9), (10, 12)] 192 | >>> Tokenizer.rematch("嘛呢,吃了吗?", ["[UNK]", "呢", ",", "[UNK]", "了", "吗", "?"]) 193 | [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)] 194 | >>> Tokenizer.rematch(" 吃了吗? ", ["吃", "了", "吗", "?"]) 195 | [(2, 3), (3, 4), (4, 5), (5, 6)] 196 | 197 | :param text: Original text. 198 | :param tokens: Decoded list of tokens. 199 | :param cased: Whether it is cased. 200 | :param unknown_token: The representation of unknown token. 201 | :return: A list of tuples represents the start and stop locations in the original text. 
202 | """ 203 | decoded, token_offsets = '', [] 204 | for token in tokens: 205 | token_offsets.append([len(decoded), 0]) 206 | if token == unknown_token: 207 | token = '#' 208 | if not cased: 209 | token = token.lower() 210 | if len(token) > 2 and token.startswith('##'): 211 | token = token[2:] 212 | elif len(decoded) > 0: 213 | token = ' ' + token 214 | token_offsets[-1][0] += 1 215 | decoded += token 216 | token_offsets[-1][1] = len(decoded) 217 | 218 | heading = 0 219 | text = text.rstrip() 220 | for i in range(len(text)): 221 | if not Tokenizer._is_space(text[i]): 222 | break 223 | heading += 1 224 | text = text[heading:] 225 | len_text, len_decode = len(text), len(decoded) 226 | costs = [[0] * (len_decode + 1) for _ in range(2)] 227 | paths = [[(-1, -1)] * (len_decode + 1) for _ in range(len_text + 1)] 228 | curr, prev = 0, 1 229 | 230 | for j in range(len_decode + 1): 231 | costs[curr][j] = j 232 | for i in range(1, len_text + 1): 233 | curr, prev = prev, curr 234 | costs[curr][0] = i 235 | ch = text[i - 1] 236 | if not cased: 237 | ch = ch.lower() 238 | for j in range(1, len_decode + 1): 239 | costs[curr][j] = costs[prev][j - 1] 240 | paths[i][j] = (i - 1, j - 1) 241 | if ch != decoded[j - 1]: 242 | costs[curr][j] = costs[prev][j - 1] 243 | paths[i][j] = (i - 1, j - 1) 244 | if costs[prev][j] < costs[curr][j]: 245 | costs[curr][j] = costs[prev][j] 246 | paths[i][j] = (i - 1, j) 247 | if costs[curr][j - 1] < costs[curr][j]: 248 | costs[curr][j] = costs[curr][j - 1] 249 | paths[i][j] = (i, j - 1) 250 | costs[curr][j] += 1 251 | 252 | matches = [0] * (len_decode + 1) 253 | position = (len_text, len_decode) 254 | while position != (-1, -1): 255 | i, j = position 256 | matches[j] = i 257 | position = paths[i][j] 258 | 259 | intervals = [[matches[offset[0]], matches[offset[1]]] for offset in token_offsets] 260 | for i, interval in enumerate(intervals): 261 | token_a, token_b = text[interval[0]:interval[1]], tokens[i] 262 | if len(token_b) > 2 and token_b.startswith('##'): 263 | token_b = token_b[2:] 264 | if not cased: 265 | token_a, token_b = token_a.lower(), token_b.lower() 266 | if token_a == token_b: 267 | continue 268 | if i == 0: 269 | border = 0 270 | else: 271 | border = intervals[i - 1][1] 272 | for j in range(interval[0] - 1, border - 1, -1): 273 | if Tokenizer._is_space(text[j]): 274 | break 275 | interval[0] -= 1 276 | if i + 1 == len(intervals): 277 | border = len_text 278 | else: 279 | border = intervals[i + 1][0] 280 | for j in range(interval[1], border): 281 | if Tokenizer._is_space(text[j]): 282 | break 283 | interval[1] += 1 284 | intervals = [(interval[0] + heading, interval[1] + heading) for interval in intervals] 285 | return intervals 286 | -------------------------------------------------------------------------------- /demo/load_model/keras_bert_load_and_extract_tpu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 0, 6 | "metadata": { 7 | "id": "doNFRjPqiBhM", 8 | "colab_type": "code", 9 | "outputId": "8e86b76a-fae7-4a0b-db24-447986c06f96", 10 | "colab": { 11 | "base_uri": "https://localhost:8080/", 12 | "height": 119.0 13 | } 14 | }, 15 | "outputs": [ 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "Archive: uncased_L-12_H-768_A-12.zip\n", 21 | " inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta \n", 22 | " inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001 \n", 23 | " inflating: 
uncased_L-12_H-768_A-12/vocab.txt \n", 24 | " inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index \n", 25 | " inflating: uncased_L-12_H-768_A-12/bert_config.json \n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "# @title Preparation\n", 31 | "!pip install -q keras-bert\n", 32 | "!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n", 33 | "!unzip -o uncased_L-12_H-768_A-12.zip" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 0, 39 | "metadata": { 40 | "id": "KUQ8UtquieFj", 41 | "colab_type": "code", 42 | "colab": {} 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "# @title Environment\n", 47 | "import os\n", 48 | "\n", 49 | "pretrained_path = 'uncased_L-12_H-768_A-12'\n", 50 | "config_path = os.path.join(pretrained_path, 'bert_config.json')\n", 51 | "checkpoint_path = os.path.join(pretrained_path, 'bert_model.ckpt')\n", 52 | "vocab_path = os.path.join(pretrained_path, 'vocab.txt')\n", 53 | "\n", 54 | "# TF_KERAS must be added to environment variables in order to use TPU\n", 55 | "os.environ['TF_KERAS'] = '1'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 0, 61 | "metadata": { 62 | "id": "exDvuSwPevQP", 63 | "colab_type": "code", 64 | "colab": {} 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "# @title Initialize TPU Strategy\n", 69 | "\n", 70 | "import tensorflow as tf\n", 71 | "from keras_bert import get_custom_objects\n", 72 | "\n", 73 | "TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']\n", 74 | "resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)\n", 75 | "tf.contrib.distribute.initialize_tpu_system(resolver)\n", 76 | "strategy = tf.contrib.distribute.TPUStrategy(resolver)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 0, 82 | "metadata": { 83 | "id": "sVTPNxOyj4HJ", 84 | "colab_type": "code", 85 | "colab": {} 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "# @title Load Basic Model\n", 90 | "import codecs\n", 91 | "from keras_bert import load_trained_model_from_checkpoint\n", 92 | "\n", 93 | "token_dict = {}\n", 94 | "with codecs.open(vocab_path, 'r', 'utf8') as reader:\n", 95 | " for line in reader:\n", 96 | " token = line.strip()\n", 97 | " token_dict[token] = len(token_dict)\n", 98 | "\n", 99 | "with strategy.scope():\n", 100 | " model = load_trained_model_from_checkpoint(config_path, checkpoint_path)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 0, 106 | "metadata": { 107 | "id": "xioN-O_vtztC", 108 | "colab_type": "code", 109 | "outputId": "e147fde2-d1cb-439b-d63f-a649be4a0def", 110 | "colab": { 111 | "base_uri": "https://localhost:8080/", 112 | "height": 377.0 113 | } 114 | }, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "['[CLS]', 'from', 'that', 'day', 'forth', '.', '.', '.', 'my', 'arm', 'changed', '.', '.', '.', 'and', 'a', 'voice', 'echoed', '[SEP]']\n", 121 | "[CLS] [0.24250675737857819, 0.04605229198932648, -0.24484458565711975, -0.5553151369094849, -0.16091349720954895, -0.046603765338659286, 0.2648216784000397, 0.4632352590560913, -0.15888850390911102, -0.5054463148117065, 0.1872796267271042, -0.0844820961356163, -0.10551590472459793, 1.1445354223251343, 0.35309019684791565, 0.4587448537349701, 0.22255975008010864, 0.23128646612167358, 0.17905741930007935]\n", 122 | "from [0.2858668565750122, 0.12927496433258057, 0.08937370777130127, -0.06506256759166718, -0.18307062983512878, 0.4357893466949463, -0.4666714668273926, 
1.1149680614471436, 0.26170825958251953, -1.0477269887924194, -0.7197380661964417, -0.30874621868133545, 0.3589649498462677, 0.43190720677375793, 0.510287880897522, 0.4445205330848694, 0.6695327162742615, 0.11726009100675583, 0.34817394614219666]\n", 123 | "that [-0.7514970302581787, 0.14548861980438232, 0.32245880365371704, -0.043174318969249725, 0.5105547904968262, 0.6441463828086853, 0.3476734757423401, 1.8521389961242676, -0.3416602611541748, -0.32603877782821655, 0.3870062232017517, -0.781717836856842, 0.7790629267692566, 0.38004353642463684, -0.026217259466648102, 1.1450964212417603, 0.8676536083221436, -0.006007872521877289, 0.4079648554325104]\n", 124 | "day [-0.5767358541488647, 0.16873040795326233, 0.7379938364028931, -0.43202054500579834, -0.2541211247444153, -0.12445046007633209, 0.0982133075594902, 0.6879273056983948, -0.5429892539978027, -1.1291855573654175, 0.10011367499828339, -0.7181406021118164, 0.11406952142715454, 0.7394105792045593, -0.14538182318210602, 0.8026691675186157, -0.0491754524409771, -0.5183604955673218, 0.5328009128570557]\n", 125 | "forth [0.6604690551757812, 0.44872960448265076, 0.22819921374320984, -1.302493691444397, -0.5676481127738953, 0.7631499171257019, 0.31730082631111145, 0.9745714664459229, 0.7384849786758423, -0.12949073314666748, 0.6890379786491394, -0.2063390463590622, 0.7543689608573914, 1.060841679573059, 0.2897346615791321, 1.460550308227539, 0.1994684338569641, -0.3415696620941162, 0.3051512837409973]\n", 126 | ". [0.7019387483596802, -1.0037682056427002, 0.5318847894668579, -0.2360236793756485, 0.7502506375312805, 0.3656868040561676, 0.47489359974861145, 1.2967716455459595, 0.2107715904712677, -1.135636568069458, -0.2472764253616333, -0.8922532796859741, -0.5542468428611755, 1.2267390489578247, 0.49479609727859497, 0.5181231498718262, 0.246188685297966, -0.028614960610866547, 0.36292847990989685]\n", 127 | ". [0.6940287351608276, -0.9777721762657166, 0.8698938488960266, 0.10904823988676071, 0.6616824269294739, -0.5494440793991089, 0.6290906667709351, 0.9006298780441284, -0.3196481764316559, -1.679004192352295, -0.45238935947418213, -1.0310773849487305, -0.4837859869003296, 0.8608845472335815, 0.3741556406021118, 0.7373818755149841, 0.06710120290517807, -0.005348920822143555, 0.5634738206863403]\n", 128 | ". 
[0.4909120500087738, -0.5309440493583679, 0.4432699978351593, 0.08617030084133148, 0.07127449661493301, 0.002989581786096096, 0.5167111158370972, 0.9511327743530273, 0.20563672482967377, -1.380079984664917, -0.13269132375717163, -0.5576552152633667, -0.6243664026260376, 0.33799853920936584, 0.7936240434646606, 0.006669636815786362, 0.23250898718833923, -0.2893766462802887, 0.013550074771046638]\n", 129 | "my [0.1981644332408905, 0.18493112921714783, 0.6035155653953552, -0.4148944020271301, -0.26000604033470154, 0.7049614787101746, 0.33354446291923523, 1.4654905796051025, -0.03241407871246338, -0.844508171081543, 0.22912199795246124, -0.4619494676589966, 0.1021852195262909, 0.6551948189735413, 0.4241233468055725, 0.28182658553123474, 0.8504406809806824, -0.7569733262062073, -0.03992771729826927]\n", 130 | "arm [0.1387634128332138, -0.18610626459121704, 0.043312449008226395, -0.30445870757102966, 0.25981974601745605, 0.43626904487609863, -0.1595546454191208, 0.6088374853134155, -0.2485939860343933, -0.947198748588562, -0.2202623188495636, -0.43322864174842834, 0.35308578610420227, 0.5505412220954895, -0.27582991123199463, 0.8192852735519409, 0.7734602093696594, 0.19990915060043335, -0.1323583722114563]\n", 131 | "changed [0.6847427487373352, -0.06571120023727417, 0.740653932094574, -0.3307543098926544, 0.5178371667861938, 0.4792756140232086, -0.1748911589384079, 0.7557358145713806, -0.8578546643257141, 0.029164336621761322, 0.8551721572875977, -0.9205213189125061, 0.11773020774126053, 0.9219750165939331, 0.4660824239253998, 0.33519038558006287, 0.5005586743354797, -0.4561012089252472, -0.8589122891426086]\n", 132 | ". [0.691735029220581, -0.9538367390632629, 0.6850607991218567, -0.33083242177963257, 0.6519606113433838, 0.18196794390678406, 0.14063923060894012, 1.0543153285980225, 0.09589402377605438, -1.1536606550216675, -0.3813367486000061, -0.7072638869285583, -0.5489786863327026, 1.2570656538009644, 0.4382341802120209, 0.4493916630744934, 0.46416541934013367, -0.1259787380695343, 0.13374735414981842]\n", 133 | ". [0.49735528230667114, -0.824055016040802, 1.0104886293411255, 0.1664600521326065, 0.6780160069465637, -0.6264775991439819, 0.3941092789173126, 0.7728183269500732, -0.33566147089004517, -1.6417758464813232, -0.3855264186859131, -1.174580454826355, -0.48833924531936646, 0.8608969449996948, 0.4212232530117035, 0.6263501048088074, 0.11568772792816162, -0.18229623138904572, 0.4028193950653076]\n", 134 | ". 
[0.5248441100120544, -0.8752838373184204, 0.8074136972427368, -0.30946090817451477, 0.5306240916252136, -0.11764449626207352, 0.6163278222084045, 0.8180994987487793, 0.0724523738026619, -0.7624974250793457, -0.3387123942375183, -0.6998348832130432, -0.7396746873855591, 0.9056875705718994, 0.3457651436328888, 0.07841179519891739, -0.011172758415341377, -0.5140762329101562, 0.0032963640987873077]\n", 135 | "and [-0.1338833123445511, -0.19814497232437134, -0.16618409752845764, 0.07875220477581024, -0.1724027693271637, 0.7233635783195496, -0.1553294062614441, 1.6001187562942505, 0.045651875436306, -1.3001012802124023, 0.1902690976858139, -0.34323716163635254, 0.19466280937194824, 0.6093133687973022, 0.4001396894454956, 0.5367133617401123, 0.609708845615387, -0.5849275588989258, -0.5499515533447266]\n", 136 | "a [-0.07800914347171783, 0.08851209282875061, -0.12622755765914917, -0.36778488755226135, 0.5426658391952515, 0.3223933279514313, 0.3393816351890564, 0.8235282897949219, -0.43569034337997437, -0.829987645149231, 0.05009227991104126, -0.7572609186172485, -0.5095729827880859, 1.2285490036010742, -0.5328006148338318, -0.08091999590396881, 0.8362939357757568, -0.3866221606731415, -0.07156673073768616]\n", 137 | "voice [0.43880149722099304, -0.12388401478528976, -0.09448350965976715, -0.4879209101200104, 0.6075019836425781, 0.32898378372192383, 0.6265825033187866, 0.728482723236084, -0.46667003631591797, -0.5817348957061768, -0.4383503198623657, -0.31208136677742004, -0.4085504412651062, 0.6028963327407837, -0.6348038911819458, -0.43311747908592224, -0.12351862341165543, -0.07814896106719971, -0.16226796805858612]\n", 138 | "echoed [0.6161936521530151, -0.22712212800979614, -0.0740918517112732, -0.7913801074028015, 0.09446655958890915, 0.6521013975143433, -0.08286319673061371, 0.4983268082141876, -0.629666268825531, -0.40694940090179443, 0.12792044878005981, -0.11819512397050858, -0.4466920495033264, 0.8771257996559143, -0.8426066637039185, -0.34973663091659546, 0.11151856929063797, -0.30014142394065857, -0.5564588308334351]\n", 139 | "[SEP] [0.7558457851409912, 0.36590367555618286, -0.17892876267433167, 0.39153221249580383, -0.5322665572166443, -0.613459587097168, 0.4913185238838196, -0.35385745763778687, 0.4216068983078003, -0.1137060597538948, 0.17552773654460907, -0.10072151571512222, 0.1765543669462204, 0.14826729893684387, -0.25414443016052246, -0.06873864680528641, 0.23133708536624908, -0.04865124076604843, 0.37981322407722473]\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "# @title Extraction\n", 145 | "import numpy as np\n", 146 | "from keras_bert import Tokenizer\n", 147 | "\n", 148 | "tokenizer = Tokenizer(token_dict)\n", 149 | "text = 'From that day forth... my arm changed... 
and a voice echoed'\n", 150 | "tokens = tokenizer.tokenize(text)\n", 151 | "indices, segments = tokenizer.encode(first=text, max_len=512)\n", 152 | "print(tokens)\n", 153 | "\n", 154 | "predicts = model.predict([np.array([indices] * 8), np.array([segments] * 8)])[0]\n", 155 | " \n", 156 | "for i, token in enumerate(tokens):\n", 157 | " print(token, predicts[i].tolist()[:19])" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "colab": { 163 | "name": "keras_bert_load_and_extract_tpu.ipynb", 164 | "version": "0.3.2", 165 | "provenance": [], 166 | "collapsed_sections": [] 167 | }, 168 | "kernelspec": { 169 | "name": "python3", 170 | "display_name": "Python 3" 171 | }, 172 | "accelerator": "TPU" 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 0 176 | } 177 | -------------------------------------------------------------------------------- /keras_bert/bert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow import keras 3 | 4 | from keras_pos_embd import PositionEmbedding 5 | from keras_layer_normalization import LayerNormalization 6 | from keras_transformer import get_encoders, gelu 7 | from keras_transformer import get_custom_objects as get_encoder_custom_objects 8 | from .layers import get_inputs, get_embedding, TokenEmbedding, EmbeddingSimilarity, Masked, Extract, TaskEmbedding 9 | from .optimizers import AdamWarmup 10 | 11 | 12 | __all__ = [ 13 | 'TOKEN_PAD', 'TOKEN_UNK', 'TOKEN_CLS', 'TOKEN_SEP', 'TOKEN_MASK', 14 | 'gelu', 'get_model', 'compile_model', 'get_base_dict', 'gen_batch_inputs', 'get_token_embedding', 15 | 'get_custom_objects', 16 | ] 17 | 18 | 19 | TOKEN_PAD = '' # Token for padding 20 | TOKEN_UNK = '[UNK]' # Token for unknown words 21 | TOKEN_CLS = '[CLS]' # Token for classification 22 | TOKEN_SEP = '[SEP]' # Token for separation 23 | TOKEN_MASK = '[MASK]' # Token for masking 24 | 25 | 26 | def get_model(token_num, 27 | pos_num=512, 28 | seq_len=512, 29 | embed_dim=768, 30 | transformer_num=12, 31 | head_num=12, 32 | feed_forward_dim=3072, 33 | dropout_rate=0.1, 34 | attention_activation=None, 35 | feed_forward_activation='gelu', 36 | training=True, 37 | trainable=None, 38 | output_layer_num=1, 39 | use_task_embed=False, 40 | task_num=10): 41 | """Get BERT model. 42 | 43 | See: https://arxiv.org/pdf/1810.04805.pdf 44 | 45 | :param token_num: Number of tokens. 46 | :param pos_num: Maximum position. 47 | :param seq_len: Maximum length of the input sequence or None. 48 | :param embed_dim: Dimensions of embeddings. 49 | :param transformer_num: Number of transformers. 50 | :param head_num: Number of heads in multi-head attention in each transformer. 51 | :param feed_forward_dim: Dimension of the feed forward layer in each transformer. 52 | :param dropout_rate: Dropout rate. 53 | :param attention_activation: Activation for attention layers. 54 | :param feed_forward_activation: Activation for feed-forward layers. 55 | :param training: A built model with MLM and NSP outputs will be returned if it is `True`, 56 | otherwise the input layers and the last feature extraction layer will be returned. 57 | :param trainable: Whether the model is trainable; a list of layer-name prefixes may be given to make only the matching layers trainable. 58 | :param output_layer_num: The number of layers whose outputs will be concatenated as a single output. 59 | Only available when `training` is `False`. 60 | :param use_task_embed: Whether to add task embeddings to the existing embeddings. 61 | :param task_num: The number of tasks. 62 | :return: The built model when `training` is `True`, otherwise the input layers and the output feature tensor.
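        Example (a minimal usage sketch for illustration; the vocabulary size and sequence length below are assumed values, not defaults taken from this file):

            # Feature-extraction variant: returns the input layers and the output tensor.
            inputs, output = get_model(token_num=30000, seq_len=128, training=False)
            feature_model = keras.models.Model(inputs, output)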
63 | """ 64 | if attention_activation == 'gelu': 65 | attention_activation = gelu 66 | if feed_forward_activation == 'gelu': 67 | feed_forward_activation = gelu 68 | if trainable is None: 69 | trainable = training 70 | 71 | def _trainable(_layer): 72 | if isinstance(trainable, (list, tuple, set)): 73 | for prefix in trainable: 74 | if _layer.name.startswith(prefix): 75 | return True 76 | return False 77 | return trainable 78 | 79 | inputs = get_inputs(seq_len=seq_len) 80 | embed_layer, embed_weights = get_embedding( 81 | inputs, 82 | token_num=token_num, 83 | embed_dim=embed_dim, 84 | pos_num=pos_num, 85 | dropout_rate=dropout_rate, 86 | ) 87 | if use_task_embed: 88 | task_input = keras.layers.Input( 89 | shape=(1,), 90 | name='Input-Task', 91 | ) 92 | embed_layer = TaskEmbedding( 93 | input_dim=task_num, 94 | output_dim=embed_dim, 95 | mask_zero=False, 96 | name='Embedding-Task', 97 | )([embed_layer, task_input]) 98 | inputs = inputs[:2] + [task_input, inputs[-1]] 99 | if dropout_rate > 0.0: 100 | dropout_layer = keras.layers.Dropout( 101 | rate=dropout_rate, 102 | name='Embedding-Dropout', 103 | )(embed_layer) 104 | else: 105 | dropout_layer = embed_layer 106 | embed_layer = LayerNormalization( 107 | trainable=trainable, 108 | name='Embedding-Norm', 109 | )(dropout_layer) 110 | transformed = get_encoders( 111 | encoder_num=transformer_num, 112 | input_layer=embed_layer, 113 | head_num=head_num, 114 | hidden_dim=feed_forward_dim, 115 | attention_activation=attention_activation, 116 | feed_forward_activation=feed_forward_activation, 117 | dropout_rate=dropout_rate, 118 | ) 119 | if training: 120 | mlm_dense_layer = keras.layers.Dense( 121 | units=embed_dim, 122 | activation=feed_forward_activation, 123 | name='MLM-Dense', 124 | )(transformed) 125 | mlm_norm_layer = LayerNormalization(name='MLM-Norm')(mlm_dense_layer) 126 | mlm_pred_layer = EmbeddingSimilarity(name='MLM-Sim')([mlm_norm_layer, embed_weights]) 127 | masked_layer = Masked(name='MLM')([mlm_pred_layer, inputs[-1]]) 128 | extract_layer = Extract(index=0, name='Extract')(transformed) 129 | nsp_dense_layer = keras.layers.Dense( 130 | units=embed_dim, 131 | activation='tanh', 132 | name='NSP-Dense', 133 | )(extract_layer) 134 | nsp_pred_layer = keras.layers.Dense( 135 | units=2, 136 | activation='softmax', 137 | name='NSP', 138 | )(nsp_dense_layer) 139 | model = keras.models.Model(inputs=inputs, outputs=[masked_layer, nsp_pred_layer]) 140 | for layer in model.layers: 141 | layer.trainable = _trainable(layer) 142 | return model 143 | else: 144 | if use_task_embed: 145 | inputs = inputs[:3] 146 | else: 147 | inputs = inputs[:2] 148 | model = keras.models.Model(inputs=inputs, outputs=transformed) 149 | for layer in model.layers: 150 | layer.trainable = _trainable(layer) 151 | if isinstance(output_layer_num, int): 152 | output_layer_num = min(output_layer_num, transformer_num) 153 | output_layer_num = [-i for i in range(1, output_layer_num + 1)] 154 | outputs = [] 155 | for layer_index in output_layer_num: 156 | if layer_index < 0: 157 | layer_index = transformer_num + layer_index 158 | layer_index += 1 159 | layer = model.get_layer(name='Encoder-{}-FeedForward-Norm'.format(layer_index)) 160 | outputs.append(layer.output) 161 | if len(outputs) > 1: 162 | transformed = keras.layers.Concatenate(name='Encoder-Output')(list(reversed(outputs))) 163 | else: 164 | transformed = outputs[0] 165 | return inputs, transformed 166 | 167 | 168 | def compile_model(model, 169 | weight_decay=0.01, 170 | decay_steps=100000, 171 | warmup_steps=10000, 172 
| learning_rate=1e-4): 173 | """Compile the model with warmup optimizer and sparse cross-entropy loss. 174 | 175 | :param model: The built model. 176 | :param weight_decay: Weight decay rate. 177 | :param decay_steps: The learning rate will decay linearly to zero over `decay_steps` steps. 178 | :param warmup_steps: The learning rate will increase linearly to `learning_rate` over the first `warmup_steps` steps. 179 | :param learning_rate: Learning rate. 180 | :return: None. The model is compiled in place. 181 | """ 182 | model.compile( 183 | optimizer=AdamWarmup( 184 | decay_steps=decay_steps, 185 | warmup_steps=warmup_steps, 186 | learning_rate=learning_rate, 187 | weight_decay=weight_decay, 188 | weight_decay_pattern=['embeddings', 'kernel', 'W1', 'W2', 'Wk', 'Wq', 'Wv', 'Wo'], 189 | ), 190 | loss=keras.losses.sparse_categorical_crossentropy, 191 | ) 192 | 193 | 194 | def get_custom_objects(): 195 | """Get all custom objects for loading saved models.""" 196 | custom_objects = get_encoder_custom_objects() 197 | custom_objects['PositionEmbedding'] = PositionEmbedding 198 | custom_objects['TokenEmbedding'] = TokenEmbedding 199 | custom_objects['EmbeddingSimilarity'] = EmbeddingSimilarity 200 | custom_objects['TaskEmbedding'] = TaskEmbedding 201 | custom_objects['Masked'] = Masked 202 | custom_objects['Extract'] = Extract 203 | custom_objects['AdamWarmup'] = AdamWarmup 204 | return custom_objects 205 | 206 | 207 | def get_base_dict(): 208 | """Get basic dictionary containing special tokens.""" 209 | return { 210 | TOKEN_PAD: 0, 211 | TOKEN_UNK: 1, 212 | TOKEN_CLS: 2, 213 | TOKEN_SEP: 3, 214 | TOKEN_MASK: 4, 215 | } 216 | 217 | 218 | def get_token_embedding(model): 219 | """Get token embedding from model. 220 | 221 | :param model: The built model. 222 | :return: The output weights of embeddings. 223 | """ 224 | return model.get_layer('Embedding-Token').output[1] 225 | 226 | 227 | def gen_batch_inputs(sentence_pairs, 228 | token_dict, 229 | token_list, 230 | seq_len=512, 231 | mask_rate=0.15, 232 | mask_mask_rate=0.8, 233 | mask_random_rate=0.1, 234 | swap_sentence_rate=0.5, 235 | force_mask=True): 236 | """Generate a batch of inputs and outputs for training. 237 | 238 | :param sentence_pairs: A list of pairs containing lists of tokens. 239 | :param token_dict: The dictionary containing special tokens. 240 | :param token_list: A list containing all tokens. 241 | :param seq_len: Length of the sequence. 242 | :param mask_rate: The rate of choosing a token for prediction. 243 | :param mask_mask_rate: The rate of replacing the chosen token with `TOKEN_MASK`. 244 | :param mask_random_rate: The rate of replacing the chosen token with a random word. 245 | :param swap_sentence_rate: The rate of swapping the second sentence with one from another pair. 246 | :param force_mask: If `True`, at least one position in each sample will be masked. 247 | :return: All the inputs and outputs.
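        Example (a minimal usage sketch for illustration; the sentence pair, the tokens and the sequence length below are assumptions, not values from this file):

            token_dict = get_base_dict()               # start from the special tokens
            pairs = [[['this', 'is'], ['a', 'demo']]]  # one pre-tokenized sentence pair
            for pair in pairs:
                for tokens in pair:
                    for token in tokens:
                        if token not in token_dict:
                            token_dict[token] = len(token_dict)
            token_list = list(token_dict.keys())       # used for random token replacement
            inputs, outputs = gen_batch_inputs(pairs, token_dict, token_list, seq_len=16)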
248 | """ 249 | batch_size = len(sentence_pairs) 250 | base_dict = get_base_dict() 251 | unknown_index = token_dict[TOKEN_UNK] 252 | # Generate sentence swapping mapping 253 | nsp_outputs = np.zeros((batch_size,)) 254 | mapping = {} 255 | if swap_sentence_rate > 0.0: 256 | indices = [index for index in range(batch_size) if np.random.random() < swap_sentence_rate] 257 | mapped = indices[:] 258 | np.random.shuffle(mapped) 259 | for i in range(len(mapped)): 260 | if indices[i] != mapped[i]: 261 | nsp_outputs[indices[i]] = 1.0 262 | mapping = {indices[i]: mapped[i] for i in range(len(indices))} 263 | # Generate MLM 264 | token_inputs, segment_inputs, masked_inputs = [], [], [] 265 | mlm_outputs = [] 266 | for i in range(batch_size): 267 | first, second = sentence_pairs[i][0], sentence_pairs[mapping.get(i, i)][1] 268 | segment_inputs.append(([0] * (len(first) + 2) + [1] * (seq_len - (len(first) + 2)))[:seq_len]) 269 | tokens = [TOKEN_CLS] + first + [TOKEN_SEP] + second + [TOKEN_SEP] 270 | tokens = tokens[:seq_len] 271 | tokens += [TOKEN_PAD] * (seq_len - len(tokens)) 272 | token_input, masked_input, mlm_output = [], [], [] 273 | has_mask = False 274 | for token in tokens: 275 | mlm_output.append(token_dict.get(token, unknown_index)) 276 | if token not in base_dict and np.random.random() < mask_rate: 277 | has_mask = True 278 | masked_input.append(1) 279 | r = np.random.random() 280 | if r < mask_mask_rate: 281 | token_input.append(token_dict[TOKEN_MASK]) 282 | elif r < mask_mask_rate + mask_random_rate: 283 | while True: 284 | token = token_list[np.random.randint(0, len(token_list))] 285 | if token not in base_dict: 286 | token_input.append(token_dict[token]) 287 | break 288 | else: 289 | token_input.append(token_dict.get(token, unknown_index)) 290 | else: 291 | masked_input.append(0) 292 | token_input.append(token_dict.get(token, unknown_index)) 293 | if force_mask and not has_mask: 294 | masked_input[1] = 1 295 | token_inputs.append(token_input) 296 | masked_inputs.append(masked_input) 297 | mlm_outputs.append(mlm_output) 298 | inputs = [np.asarray(x) for x in [token_inputs, segment_inputs, masked_inputs]] 299 | outputs = [np.asarray(np.expand_dims(x, axis=-1)) for x in [mlm_outputs, nsp_outputs]] 300 | return inputs, outputs 301 | -------------------------------------------------------------------------------- /demo/tune/keras_bert_classification_tpu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "keras_bert_classification_tpu.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "TPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "doNFRjPqiBhM", 22 | "colab_type": "code", 23 | "outputId": "725008ed-61f8-433c-82ce-9293bc07629c", 24 | "colab": { 25 | "base_uri": "https://localhost:8080/", 26 | "height": 289 27 | } 28 | }, 29 | "source": [ 30 | "# @title Preparation\n", 31 | "!pip install -q keras-bert keras-rectified-adam\n", 32 | "!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n", 33 | "!unzip -o uncased_L-12_H-768_A-12.zip" 34 | ], 35 | "execution_count": 1, 36 | "outputs": [ 37 | { 38 | "output_type": "stream", 39 | "text": [ 40 | " Building wheel for keras-bert (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 41 | " Building wheel for keras-rectified-adam (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 42 | " Building wheel for keras-transformer (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 43 | " Building wheel for keras-pos-embd (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 44 | " Building wheel for keras-multi-head (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 45 | " Building wheel for keras-layer-normalization (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 46 | " Building wheel for keras-position-wise-feed-forward (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 47 | " Building wheel for keras-embed-sim (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 48 | " Building wheel for keras-self-attention (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 49 | "Archive: uncased_L-12_H-768_A-12.zip\n", 50 | " creating: uncased_L-12_H-768_A-12/\n", 51 | " inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta \n", 52 | " inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001 \n", 53 | " inflating: uncased_L-12_H-768_A-12/vocab.txt \n", 54 | " inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index \n", 55 | " inflating: uncased_L-12_H-768_A-12/bert_config.json \n" 56 | ], 57 | "name": "stdout" 58 | } 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "bzoFRUGmh6a3", 65 | "colab_type": "code", 66 | "colab": {} 67 | }, 68 | "source": [ 69 | "# @title Constants\n", 70 | "\n", 71 | "SEQ_LEN = 128\n", 72 | "BATCH_SIZE = 128\n", 73 | "EPOCHS = 5\n", 74 | "LR = 1e-4" 75 | ], 76 | "execution_count": 2, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "id": "KUQ8UtquieFj", 83 | "colab_type": "code", 84 | "colab": {} 85 | }, 86 | "source": [ 87 | "# @title Environment\n", 88 | "import os\n", 89 | "\n", 90 | "pretrained_path = 'uncased_L-12_H-768_A-12'\n", 91 | "config_path = os.path.join(pretrained_path, 'bert_config.json')\n", 92 | "checkpoint_path = os.path.join(pretrained_path, 'bert_model.ckpt')\n", 93 | "vocab_path = os.path.join(pretrained_path, 'vocab.txt')\n", 94 | "\n", 95 | "# TF_KERAS must be added to environment variables in order to use TPU\n", 96 | "os.environ['TF_KERAS'] = '1'" 97 | ], 98 | "execution_count": 3, 99 | "outputs": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "metadata": { 104 | "id": "gGzsxkLTpRrs", 105 | "colab_type": "code", 106 | "colab": { 107 | "base_uri": "https://localhost:8080/", 108 | "height": 170 109 | }, 110 | "outputId": "6bc483bd-fee3-4cb1-c317-3ff6983389f2" 111 | }, 112 | "source": [ 113 | "# @title Initialize TPU Strategy\n", 114 | "\n", 115 | "import tensorflow as tf\n", 116 | "from keras_bert import get_custom_objects\n", 117 | "\n", 118 | "TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']\n", 119 | "resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)\n", 120 | "tf.contrib.distribute.initialize_tpu_system(resolver)\n", 121 | "strategy = tf.contrib.distribute.TPUStrategy(resolver)" 122 | ], 123 | "execution_count": 4, 124 | "outputs": [ 125 | { 126 | "output_type": "stream", 127 | "text": [ 128 | "WARNING: Logging before flag parsing goes to stderr.\n", 129 | "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", 130 | "For more information, please see:\n", 131 | " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", 132 | " * https://github.com/tensorflow/addons\n", 133 | " * https://github.com/tensorflow/io (for I/O related ops)\n", 134 | "If you depend on functionality not listed 
there, please file an issue.\n", 135 | "\n" 136 | ], 137 | "name": "stderr" 138 | } 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "metadata": { 144 | "id": "sVTPNxOyj4HJ", 145 | "colab_type": "code", 146 | "outputId": "a6b0bc85-93cb-4642-ea81-77d78af6a186", 147 | "colab": { 148 | "base_uri": "https://localhost:8080/", 149 | "height": 139 150 | } 151 | }, 152 | "source": [ 153 | "# @title Load Basic Model\n", 154 | "import codecs\n", 155 | "from keras_bert import load_trained_model_from_checkpoint\n", 156 | "\n", 157 | "token_dict = {}\n", 158 | "with codecs.open(vocab_path, 'r', 'utf8') as reader:\n", 159 | " for line in reader:\n", 160 | " token = line.strip()\n", 161 | " token_dict[token] = len(token_dict)\n", 162 | "\n", 163 | "with strategy.scope():\n", 164 | " model = load_trained_model_from_checkpoint(\n", 165 | " config_path,\n", 166 | " checkpoint_path,\n", 167 | " training=True,\n", 168 | " trainable=True,\n", 169 | " seq_len=SEQ_LEN,\n", 170 | " )" 171 | ], 172 | "execution_count": 5, 173 | "outputs": [ 174 | { 175 | "output_type": "stream", 176 | "text": [], 177 | "name": "stderr" 178 | } 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "xioN-O_vtztC", 185 | "colab_type": "code", 186 | "outputId": "05526ce5-d0fe-4ede-fe4a-199997244f27", 187 | "colab": { 188 | "base_uri": "https://localhost:8080/", 189 | "height": 51 190 | } 191 | }, 192 | "source": [ 193 | "# @title Download IMDB Data\n", 194 | "import tensorflow as tf\n", 195 | "\n", 196 | "dataset = tf.keras.utils.get_file(\n", 197 | " fname=\"aclImdb.tar.gz\", \n", 198 | " origin=\"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\", \n", 199 | " extract=True,\n", 200 | ")" 201 | ], 202 | "execution_count": 6, 203 | "outputs": [ 204 | { 205 | "output_type": "stream", 206 | "text": [ 207 | "Downloading data from http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n", 208 | "84131840/84125825 [==============================] - 3s 0us/step\n" 209 | ], 210 | "name": "stdout" 211 | } 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "metadata": { 217 | "id": "xfC3Nh8pnckd", 218 | "colab_type": "code", 219 | "outputId": "00f6e5d4-f6ca-4626-9b3b-9525efbf4596", 220 | "colab": { 221 | "base_uri": "https://localhost:8080/", 222 | "height": 85 223 | } 224 | }, 225 | "source": [ 226 | "# @title Convert Data to Array\n", 227 | "import os\n", 228 | "import numpy as np\n", 229 | "from tqdm import tqdm\n", 230 | "from keras_bert import Tokenizer\n", 231 | "\n", 232 | "tokenizer = Tokenizer(token_dict)\n", 233 | "\n", 234 | "\n", 235 | "def load_data(path):\n", 236 | " global tokenizer\n", 237 | " indices, sentiments = [], []\n", 238 | " for folder, sentiment in (('neg', 0), ('pos', 1)):\n", 239 | " folder = os.path.join(path, folder)\n", 240 | " for name in tqdm(os.listdir(folder)):\n", 241 | " with open(os.path.join(folder, name), 'r') as reader:\n", 242 | " text = reader.read()\n", 243 | " ids, segments = tokenizer.encode(text, max_len=SEQ_LEN)\n", 244 | " indices.append(ids)\n", 245 | " sentiments.append(sentiment)\n", 246 | " items = list(zip(indices, sentiments))\n", 247 | " np.random.shuffle(items)\n", 248 | " indices, sentiments = zip(*items)\n", 249 | " indices = np.array(indices)\n", 250 | " mod = indices.shape[0] % BATCH_SIZE\n", 251 | " if mod > 0:\n", 252 | " indices, sentiments = indices[:-mod], sentiments[:-mod]\n", 253 | " return [indices, np.zeros_like(indices)], np.array(sentiments)\n", 254 | " \n", 255 | " \n", 256 | "train_path = 
os.path.join(os.path.dirname(dataset), 'aclImdb', 'train')\n", 257 | "test_path = os.path.join(os.path.dirname(dataset), 'aclImdb', 'test')\n", 258 | "\n", 259 | "\n", 260 | "train_x, train_y = load_data(train_path)\n", 261 | "test_x, test_y = load_data(test_path)" 262 | ], 263 | "execution_count": 7, 264 | "outputs": [ 265 | { 266 | "output_type": "stream", 267 | "text": [ 268 | "100%|██████████| 12500/12500 [00:44<00:00, 277.81it/s]\n", 269 | "100%|██████████| 12500/12500 [00:46<00:00, 268.06it/s]\n", 270 | "100%|██████████| 12500/12500 [00:44<00:00, 282.12it/s]\n", 271 | "100%|██████████| 12500/12500 [00:45<00:00, 277.77it/s]\n" 272 | ], 273 | "name": "stderr" 274 | } 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "metadata": { 280 | "id": "OhMA1j7wnqSm", 281 | "colab_type": "code", 282 | "colab": {} 283 | }, 284 | "source": [ 285 | "# @title Build Custom Model\n", 286 | "from tensorflow.python import keras\n", 287 | "from keras_radam import RAdam\n", 288 | "\n", 289 | "with strategy.scope():\n", 290 | " inputs = model.inputs[:2]\n", 291 | " dense = model.get_layer('NSP-Dense').output\n", 292 | " outputs = keras.layers.Dense(units=2, activation='softmax')(dense)\n", 293 | " \n", 294 | " model = keras.models.Model(inputs, outputs)\n", 295 | " model.compile(\n", 296 | " RAdam(lr=LR),\n", 297 | " loss='sparse_categorical_crossentropy',\n", 298 | " metrics=['sparse_categorical_accuracy'],\n", 299 | " )" 300 | ], 301 | "execution_count": 8, 302 | "outputs": [] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "metadata": { 307 | "id": "jmOLb7lWvDvl", 308 | "colab_type": "code", 309 | "colab": {} 310 | }, 311 | "source": [ 312 | "# @title Initialize Variables\n", 313 | "import tensorflow as tf\n", 314 | "import tensorflow.keras.backend as K\n", 315 | "\n", 316 | "sess = K.get_session()\n", 317 | "uninitialized_variables = set([i.decode('ascii') for i in sess.run(tf.report_uninitialized_variables())])\n", 318 | "init_op = tf.variables_initializer(\n", 319 | " [v for v in tf.global_variables() if v.name.split(':')[0] in uninitialized_variables]\n", 320 | ")\n", 321 | "sess.run(init_op)" 322 | ], 323 | "execution_count": 9, 324 | "outputs": [] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "metadata": { 329 | "id": "QgP7bCQxrZpQ", 330 | "colab_type": "code", 331 | "outputId": "d2167bd6-2a03-48bd-e6c6-d31c63c57e65", 332 | "colab": { 333 | "base_uri": "https://localhost:8080/", 334 | "height": 204 335 | } 336 | }, 337 | "source": [ 338 | "# @title Fit\n", 339 | "\n", 340 | "model.fit(\n", 341 | " train_x,\n", 342 | " train_y,\n", 343 | " epochs=EPOCHS,\n", 344 | " batch_size=BATCH_SIZE,\n", 345 | ")" 346 | ], 347 | "execution_count": 10, 348 | "outputs": [ 349 | { 350 | "output_type": "stream", 351 | "text": [ 352 | "Epoch 1/5\n", 353 | "195/195 [==============================] - 84s 432ms/step - loss: 0.7142 - sparse_categorical_accuracy: 0.5161\n", 354 | "Epoch 2/5\n", 355 | "195/195 [==============================] - 38s 196ms/step - loss: 0.6928 - sparse_categorical_accuracy: 0.5467\n", 356 | "Epoch 3/5\n", 357 | "195/195 [==============================] - 38s 195ms/step - loss: 0.6061 - sparse_categorical_accuracy: 0.6643\n", 358 | "Epoch 4/5\n", 359 | "195/195 [==============================] - 38s 195ms/step - loss: 0.4952 - sparse_categorical_accuracy: 0.7586\n", 360 | "Epoch 5/5\n", 361 | "195/195 [==============================] - 38s 195ms/step - loss: 0.4192 - sparse_categorical_accuracy: 0.8067\n" 362 | ], 363 | "name": "stdout" 364 | }, 365 | { 366 | "output_type": 
"execute_result", 367 | "data": { 368 | "text/plain": [ 369 | "" 370 | ] 371 | }, 372 | "metadata": { 373 | "tags": [] 374 | }, 375 | "execution_count": 11 376 | } 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "metadata": { 382 | "id": "ZBSba3vprlRD", 383 | "colab_type": "code", 384 | "outputId": "b1afd7a1-dd12-44c1-cb6e-0631d5896311", 385 | "colab": { 386 | "base_uri": "https://localhost:8080/", 387 | "height": 34 388 | } 389 | }, 390 | "source": [ 391 | "# @title Predict\n", 392 | "\n", 393 | "predicts = model.predict(test_x, verbose=True).argmax(axis=-1)" 394 | ], 395 | "execution_count": 12, 396 | "outputs": [ 397 | { 398 | "output_type": "stream", 399 | "text": [], 400 | "name": "stdout" 401 | } 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "metadata": { 407 | "id": "Wo1aps8prrCq", 408 | "colab_type": "code", 409 | "colab": {} 410 | }, 411 | "source": [ 412 | "# @title Accuracy\n", 413 | "\n", 414 | "print(np.sum(test_y == predicts) / test_y.shape[0])" 415 | ], 416 | "execution_count": 13, 417 | "outputs": [] 418 | } 419 | ] 420 | } -------------------------------------------------------------------------------- /demo/load_model/keras_bert_load_and_extract.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "jUMerb-ChqzU", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "# Load & Extract\n", 11 | "\n", 12 | "## Download Pretrained Weights" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": { 19 | "id": "j6ggPFIzkjR7", 20 | "colab_type": "code", 21 | "colab": {} 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "!pip install -q keras-bert" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": { 32 | "colab_type": "code", 33 | "id": "VN1LrB1P9mhH", 34 | "colab": {} 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "!wget -q https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": { 45 | "id": "-xYtjuUvhYY4", 46 | "colab_type": "code", 47 | "outputId": "cc6e84b9-6026-4443-b629-a57720efbc2d", 48 | "colab": { 49 | "base_uri": "https://localhost:8080/", 50 | "height": 119.0 51 | } 52 | }, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "Archive: chinese_L-12_H-768_A-12.zip\n", 59 | " inflating: chinese_L-12_H-768_A-12/bert_model.ckpt.meta \n", 60 | " inflating: chinese_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001 \n", 61 | " inflating: chinese_L-12_H-768_A-12/vocab.txt \n", 62 | " inflating: chinese_L-12_H-768_A-12/bert_model.ckpt.index \n", 63 | " inflating: chinese_L-12_H-768_A-12/bert_config.json \n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "!unzip -o chinese_L-12_H-768_A-12.zip" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "laBEGbeeiD0X", 75 | "colab_type": "text" 76 | }, 77 | "source": [ 78 | "## Build Model & Dictionary" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "AKC4cOstjAKK", 85 | "colab_type": "text" 86 | }, 87 | "source": [ 88 | "Set paths:" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": { 95 | "id": "lJYa8uvOiA6t", 96 | "colab_type": "code", 97 | "colab": {} 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "import os\n", 102 | "\n", 103 | "pretrained_path = 
'chinese_L-12_H-768_A-12'\n", 104 | "config_path = os.path.join(pretrained_path, 'bert_config.json')\n", 105 | "checkpoint_path = os.path.join(pretrained_path, 'bert_model.ckpt')\n", 106 | "vocab_path = os.path.join(pretrained_path, 'vocab.txt')" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "XmboeInri6bD", 113 | "colab_type": "text" 114 | }, 115 | "source": [ 116 | "Enable `tf.keras` by adding `TF_KERAS` to environment variables:" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": { 123 | "id": "LOQHPPyii5Tk", 124 | "colab_type": "code", 125 | "colab": {} 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "os.environ['TF_KERAS'] = '1'" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "RUlKT2XIjJO3", 136 | "colab_type": "text" 137 | }, 138 | "source": [ 139 | "Build the dictionary:" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": { 146 | "id": "NHEAc16hjIEA", 147 | "colab_type": "code", 148 | "colab": {} 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "import codecs\n", 153 | "\n", 154 | "token_dict = {}\n", 155 | "with codecs.open(vocab_path, 'r', 'utf8') as reader:\n", 156 | " for line in reader:\n", 157 | " token = line.strip()\n", 158 | " token_dict[token] = len(token_dict)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": { 164 | "id": "QXnghzK-jRC1", 165 | "colab_type": "text" 166 | }, 167 | "source": [ 168 | "Build the model:" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 7, 174 | "metadata": { 175 | "id": "ETadpJ5_jTgY", 176 | "colab_type": "code", 177 | "outputId": "11847961-846a-4711-cc0a-4f3c0cb3c8b1", 178 | "colab": { 179 | "base_uri": "https://localhost:8080/", 180 | "height": 4219.0 181 | } 182 | }, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 189 | "Instructions for updating:\n", 190 | "Colocations handled automatically by placer.\n", 191 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", 192 | "Instructions for updating:\n", 193 | "Please use `rate` instead of `keep_prob`. 
Rate should be set to `rate = 1 - keep_prob`.\n", 194 | "________________________________________________________________________________________________________________________\n", 195 | "Layer (type) Output Shape Param # Connected to \n", 196 | "========================================================================================================================\n", 197 | "Input-Token (InputLayer) (None, 512) 0 \n", 198 | "________________________________________________________________________________________________________________________\n", 199 | "Input-Segment (InputLayer) (None, 512) 0 \n", 200 | "________________________________________________________________________________________________________________________\n", 201 | "Embedding-Token (TokenEmbedding) [(None, 512, 768), (21128, 16226304 Input-Token[0][0] \n", 202 | "________________________________________________________________________________________________________________________\n", 203 | "Embedding-Segment (Embedding) (None, 512, 768) 1536 Input-Segment[0][0] \n", 204 | "________________________________________________________________________________________________________________________\n", 205 | "Embedding-Token-Segment (Add) (None, 512, 768) 0 Embedding-Token[0][0] \n", 206 | " Embedding-Segment[0][0] \n", 207 | "________________________________________________________________________________________________________________________\n", 208 | "Embedding-Position (PositionEmbedding) (None, 512, 768) 393216 Embedding-Token-Segment[0][0] \n", 209 | "________________________________________________________________________________________________________________________\n", 210 | "Embedding-Dropout (Dropout) (None, 512, 768) 0 Embedding-Position[0][0] \n", 211 | "________________________________________________________________________________________________________________________\n", 212 | "Embedding-Norm (LayerNormalization) (None, 512, 768) 1536 Embedding-Dropout[0][0] \n", 213 | "________________________________________________________________________________________________________________________\n", 214 | "Encoder-1-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Embedding-Norm[0][0] \n", 215 | "________________________________________________________________________________________________________________________\n", 216 | "Encoder-1-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-1-MultiHeadSelfAttention[0][0] \n", 217 | "________________________________________________________________________________________________________________________\n", 218 | "Encoder-1-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Embedding-Norm[0][0] \n", 219 | " Encoder-1-MultiHeadSelfAttention-Dropout\n", 220 | "________________________________________________________________________________________________________________________\n", 221 | "Encoder-1-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-1-MultiHeadSelfAttention-Add[0][\n", 222 | "________________________________________________________________________________________________________________________\n", 223 | "Encoder-1-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-1-MultiHeadSelfAttention-Norm[0]\n", 224 | "________________________________________________________________________________________________________________________\n", 225 | "Encoder-1-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-1-FeedForward[0][0] \n", 226 | 
"________________________________________________________________________________________________________________________\n", 227 | "Encoder-1-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-1-MultiHeadSelfAttention-Norm[0]\n", 228 | " Encoder-1-FeedForward-Dropout[0][0] \n", 229 | "________________________________________________________________________________________________________________________\n", 230 | "Encoder-1-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-1-FeedForward-Add[0][0] \n", 231 | "________________________________________________________________________________________________________________________\n", 232 | "Encoder-2-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-1-FeedForward-Norm[0][0] \n", 233 | "________________________________________________________________________________________________________________________\n", 234 | "Encoder-2-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-2-MultiHeadSelfAttention[0][0] \n", 235 | "________________________________________________________________________________________________________________________\n", 236 | "Encoder-2-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-1-FeedForward-Norm[0][0] \n", 237 | " Encoder-2-MultiHeadSelfAttention-Dropout\n", 238 | "________________________________________________________________________________________________________________________\n", 239 | "Encoder-2-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-2-MultiHeadSelfAttention-Add[0][\n", 240 | "________________________________________________________________________________________________________________________\n", 241 | "Encoder-2-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-2-MultiHeadSelfAttention-Norm[0]\n", 242 | "________________________________________________________________________________________________________________________\n", 243 | "Encoder-2-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-2-FeedForward[0][0] \n", 244 | "________________________________________________________________________________________________________________________\n", 245 | "Encoder-2-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-2-MultiHeadSelfAttention-Norm[0]\n", 246 | " Encoder-2-FeedForward-Dropout[0][0] \n", 247 | "________________________________________________________________________________________________________________________\n", 248 | "Encoder-2-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-2-FeedForward-Add[0][0] \n", 249 | "________________________________________________________________________________________________________________________\n", 250 | "Encoder-3-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-2-FeedForward-Norm[0][0] \n", 251 | "________________________________________________________________________________________________________________________\n", 252 | "Encoder-3-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-3-MultiHeadSelfAttention[0][0] \n", 253 | "________________________________________________________________________________________________________________________\n", 254 | "Encoder-3-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-2-FeedForward-Norm[0][0] \n", 255 | " Encoder-3-MultiHeadSelfAttention-Dropout\n", 256 | "________________________________________________________________________________________________________________________\n", 257 | "Encoder-3-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-3-MultiHeadSelfAttention-Add[0][\n", 258 | 
"________________________________________________________________________________________________________________________\n", 259 | "Encoder-3-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-3-MultiHeadSelfAttention-Norm[0]\n", 260 | "________________________________________________________________________________________________________________________\n", 261 | "Encoder-3-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-3-FeedForward[0][0] \n", 262 | "________________________________________________________________________________________________________________________\n", 263 | "Encoder-3-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-3-MultiHeadSelfAttention-Norm[0]\n", 264 | " Encoder-3-FeedForward-Dropout[0][0] \n", 265 | "________________________________________________________________________________________________________________________\n", 266 | "Encoder-3-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-3-FeedForward-Add[0][0] \n", 267 | "________________________________________________________________________________________________________________________\n", 268 | "Encoder-4-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-3-FeedForward-Norm[0][0] \n", 269 | "________________________________________________________________________________________________________________________\n", 270 | "Encoder-4-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-4-MultiHeadSelfAttention[0][0] \n", 271 | "________________________________________________________________________________________________________________________\n", 272 | "Encoder-4-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-3-FeedForward-Norm[0][0] \n", 273 | " Encoder-4-MultiHeadSelfAttention-Dropout\n", 274 | "________________________________________________________________________________________________________________________\n", 275 | "Encoder-4-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-4-MultiHeadSelfAttention-Add[0][\n", 276 | "________________________________________________________________________________________________________________________\n", 277 | "Encoder-4-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-4-MultiHeadSelfAttention-Norm[0]\n", 278 | "________________________________________________________________________________________________________________________\n", 279 | "Encoder-4-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-4-FeedForward[0][0] \n", 280 | "________________________________________________________________________________________________________________________\n", 281 | "Encoder-4-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-4-MultiHeadSelfAttention-Norm[0]\n", 282 | " Encoder-4-FeedForward-Dropout[0][0] \n", 283 | "________________________________________________________________________________________________________________________\n", 284 | "Encoder-4-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-4-FeedForward-Add[0][0] \n", 285 | "________________________________________________________________________________________________________________________\n", 286 | "Encoder-5-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-4-FeedForward-Norm[0][0] \n", 287 | "________________________________________________________________________________________________________________________\n", 288 | "Encoder-5-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-5-MultiHeadSelfAttention[0][0] \n", 289 | 
"________________________________________________________________________________________________________________________\n", 290 | "Encoder-5-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-4-FeedForward-Norm[0][0] \n", 291 | " Encoder-5-MultiHeadSelfAttention-Dropout\n", 292 | "________________________________________________________________________________________________________________________\n", 293 | "Encoder-5-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-5-MultiHeadSelfAttention-Add[0][\n", 294 | "________________________________________________________________________________________________________________________\n", 295 | "Encoder-5-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-5-MultiHeadSelfAttention-Norm[0]\n", 296 | "________________________________________________________________________________________________________________________\n", 297 | "Encoder-5-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-5-FeedForward[0][0] \n", 298 | "________________________________________________________________________________________________________________________\n", 299 | "Encoder-5-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-5-MultiHeadSelfAttention-Norm[0]\n", 300 | " Encoder-5-FeedForward-Dropout[0][0] \n", 301 | "________________________________________________________________________________________________________________________\n", 302 | "Encoder-5-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-5-FeedForward-Add[0][0] \n", 303 | "________________________________________________________________________________________________________________________\n", 304 | "Encoder-6-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-5-FeedForward-Norm[0][0] \n", 305 | "________________________________________________________________________________________________________________________\n", 306 | "Encoder-6-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-6-MultiHeadSelfAttention[0][0] \n", 307 | "________________________________________________________________________________________________________________________\n", 308 | "Encoder-6-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-5-FeedForward-Norm[0][0] \n", 309 | " Encoder-6-MultiHeadSelfAttention-Dropout\n", 310 | "________________________________________________________________________________________________________________________\n", 311 | "Encoder-6-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-6-MultiHeadSelfAttention-Add[0][\n", 312 | "________________________________________________________________________________________________________________________\n", 313 | "Encoder-6-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-6-MultiHeadSelfAttention-Norm[0]\n", 314 | "________________________________________________________________________________________________________________________\n", 315 | "Encoder-6-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-6-FeedForward[0][0] \n", 316 | "________________________________________________________________________________________________________________________\n", 317 | "Encoder-6-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-6-MultiHeadSelfAttention-Norm[0]\n", 318 | " Encoder-6-FeedForward-Dropout[0][0] \n", 319 | "________________________________________________________________________________________________________________________\n", 320 | "Encoder-6-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-6-FeedForward-Add[0][0] \n", 321 | 
"________________________________________________________________________________________________________________________\n", 322 | "Encoder-7-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-6-FeedForward-Norm[0][0] \n", 323 | "________________________________________________________________________________________________________________________\n", 324 | "Encoder-7-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-7-MultiHeadSelfAttention[0][0] \n", 325 | "________________________________________________________________________________________________________________________\n", 326 | "Encoder-7-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-6-FeedForward-Norm[0][0] \n", 327 | " Encoder-7-MultiHeadSelfAttention-Dropout\n", 328 | "________________________________________________________________________________________________________________________\n", 329 | "Encoder-7-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-7-MultiHeadSelfAttention-Add[0][\n", 330 | "________________________________________________________________________________________________________________________\n", 331 | "Encoder-7-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-7-MultiHeadSelfAttention-Norm[0]\n", 332 | "________________________________________________________________________________________________________________________\n", 333 | "Encoder-7-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-7-FeedForward[0][0] \n", 334 | "________________________________________________________________________________________________________________________\n", 335 | "Encoder-7-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-7-MultiHeadSelfAttention-Norm[0]\n", 336 | " Encoder-7-FeedForward-Dropout[0][0] \n", 337 | "________________________________________________________________________________________________________________________\n", 338 | "Encoder-7-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-7-FeedForward-Add[0][0] \n", 339 | "________________________________________________________________________________________________________________________\n", 340 | "Encoder-8-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-7-FeedForward-Norm[0][0] \n", 341 | "________________________________________________________________________________________________________________________\n", 342 | "Encoder-8-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-8-MultiHeadSelfAttention[0][0] \n", 343 | "________________________________________________________________________________________________________________________\n", 344 | "Encoder-8-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-7-FeedForward-Norm[0][0] \n", 345 | " Encoder-8-MultiHeadSelfAttention-Dropout\n", 346 | "________________________________________________________________________________________________________________________\n", 347 | "Encoder-8-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-8-MultiHeadSelfAttention-Add[0][\n", 348 | "________________________________________________________________________________________________________________________\n", 349 | "Encoder-8-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-8-MultiHeadSelfAttention-Norm[0]\n", 350 | "________________________________________________________________________________________________________________________\n", 351 | "Encoder-8-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-8-FeedForward[0][0] \n", 352 | 
"________________________________________________________________________________________________________________________\n", 353 | "Encoder-8-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-8-MultiHeadSelfAttention-Norm[0]\n", 354 | " Encoder-8-FeedForward-Dropout[0][0] \n", 355 | "________________________________________________________________________________________________________________________\n", 356 | "Encoder-8-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-8-FeedForward-Add[0][0] \n", 357 | "________________________________________________________________________________________________________________________\n", 358 | "Encoder-9-MultiHeadSelfAttention (Mult (None, None, 768) 2362368 Encoder-8-FeedForward-Norm[0][0] \n", 359 | "________________________________________________________________________________________________________________________\n", 360 | "Encoder-9-MultiHeadSelfAttention-Dropo (None, None, 768) 0 Encoder-9-MultiHeadSelfAttention[0][0] \n", 361 | "________________________________________________________________________________________________________________________\n", 362 | "Encoder-9-MultiHeadSelfAttention-Add ( (None, 512, 768) 0 Encoder-8-FeedForward-Norm[0][0] \n", 363 | " Encoder-9-MultiHeadSelfAttention-Dropout\n", 364 | "________________________________________________________________________________________________________________________\n", 365 | "Encoder-9-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-9-MultiHeadSelfAttention-Add[0][\n", 366 | "________________________________________________________________________________________________________________________\n", 367 | "Encoder-9-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-9-MultiHeadSelfAttention-Norm[0]\n", 368 | "________________________________________________________________________________________________________________________\n", 369 | "Encoder-9-FeedForward-Dropout (Dropout (None, 512, 768) 0 Encoder-9-FeedForward[0][0] \n", 370 | "________________________________________________________________________________________________________________________\n", 371 | "Encoder-9-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-9-MultiHeadSelfAttention-Norm[0]\n", 372 | " Encoder-9-FeedForward-Dropout[0][0] \n", 373 | "________________________________________________________________________________________________________________________\n", 374 | "Encoder-9-FeedForward-Norm (LayerNorma (None, 512, 768) 1536 Encoder-9-FeedForward-Add[0][0] \n", 375 | "________________________________________________________________________________________________________________________\n", 376 | "Encoder-10-MultiHeadSelfAttention (Mul (None, None, 768) 2362368 Encoder-9-FeedForward-Norm[0][0] \n", 377 | "________________________________________________________________________________________________________________________\n", 378 | "Encoder-10-MultiHeadSelfAttention-Drop (None, None, 768) 0 Encoder-10-MultiHeadSelfAttention[0][0] \n", 379 | "________________________________________________________________________________________________________________________\n", 380 | "Encoder-10-MultiHeadSelfAttention-Add (None, 512, 768) 0 Encoder-9-FeedForward-Norm[0][0] \n", 381 | " Encoder-10-MultiHeadSelfAttention-Dropou\n", 382 | "________________________________________________________________________________________________________________________\n", 383 | "Encoder-10-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-10-MultiHeadSelfAttention-Add[0]\n", 384 | 
"________________________________________________________________________________________________________________________\n", 385 | "Encoder-10-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-10-MultiHeadSelfAttention-Norm[0\n", 386 | "________________________________________________________________________________________________________________________\n", 387 | "Encoder-10-FeedForward-Dropout (Dropou (None, 512, 768) 0 Encoder-10-FeedForward[0][0] \n", 388 | "________________________________________________________________________________________________________________________\n", 389 | "Encoder-10-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-10-MultiHeadSelfAttention-Norm[0\n", 390 | " Encoder-10-FeedForward-Dropout[0][0] \n", 391 | "________________________________________________________________________________________________________________________\n", 392 | "Encoder-10-FeedForward-Norm (LayerNorm (None, 512, 768) 1536 Encoder-10-FeedForward-Add[0][0] \n", 393 | "________________________________________________________________________________________________________________________\n", 394 | "Encoder-11-MultiHeadSelfAttention (Mul (None, None, 768) 2362368 Encoder-10-FeedForward-Norm[0][0] \n", 395 | "________________________________________________________________________________________________________________________\n", 396 | "Encoder-11-MultiHeadSelfAttention-Drop (None, None, 768) 0 Encoder-11-MultiHeadSelfAttention[0][0] \n", 397 | "________________________________________________________________________________________________________________________\n", 398 | "Encoder-11-MultiHeadSelfAttention-Add (None, 512, 768) 0 Encoder-10-FeedForward-Norm[0][0] \n", 399 | " Encoder-11-MultiHeadSelfAttention-Dropou\n", 400 | "________________________________________________________________________________________________________________________\n", 401 | "Encoder-11-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-11-MultiHeadSelfAttention-Add[0]\n", 402 | "________________________________________________________________________________________________________________________\n", 403 | "Encoder-11-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-11-MultiHeadSelfAttention-Norm[0\n", 404 | "________________________________________________________________________________________________________________________\n", 405 | "Encoder-11-FeedForward-Dropout (Dropou (None, 512, 768) 0 Encoder-11-FeedForward[0][0] \n", 406 | "________________________________________________________________________________________________________________________\n", 407 | "Encoder-11-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-11-MultiHeadSelfAttention-Norm[0\n", 408 | " Encoder-11-FeedForward-Dropout[0][0] \n", 409 | "________________________________________________________________________________________________________________________\n", 410 | "Encoder-11-FeedForward-Norm (LayerNorm (None, 512, 768) 1536 Encoder-11-FeedForward-Add[0][0] \n", 411 | "________________________________________________________________________________________________________________________\n", 412 | "Encoder-12-MultiHeadSelfAttention (Mul (None, None, 768) 2362368 Encoder-11-FeedForward-Norm[0][0] \n", 413 | "________________________________________________________________________________________________________________________\n", 414 | "Encoder-12-MultiHeadSelfAttention-Drop (None, None, 768) 0 Encoder-12-MultiHeadSelfAttention[0][0] \n", 415 | 
"________________________________________________________________________________________________________________________\n", 416 | "Encoder-12-MultiHeadSelfAttention-Add (None, 512, 768) 0 Encoder-11-FeedForward-Norm[0][0] \n", 417 | " Encoder-12-MultiHeadSelfAttention-Dropou\n", 418 | "________________________________________________________________________________________________________________________\n", 419 | "Encoder-12-MultiHeadSelfAttention-Norm (None, 512, 768) 1536 Encoder-12-MultiHeadSelfAttention-Add[0]\n", 420 | "________________________________________________________________________________________________________________________\n", 421 | "Encoder-12-FeedForward (FeedForward) (None, 512, 768) 4722432 Encoder-12-MultiHeadSelfAttention-Norm[0\n", 422 | "________________________________________________________________________________________________________________________\n", 423 | "Encoder-12-FeedForward-Dropout (Dropou (None, 512, 768) 0 Encoder-12-FeedForward[0][0] \n", 424 | "________________________________________________________________________________________________________________________\n", 425 | "Encoder-12-FeedForward-Add (Add) (None, 512, 768) 0 Encoder-12-MultiHeadSelfAttention-Norm[0\n", 426 | " Encoder-12-FeedForward-Dropout[0][0] \n", 427 | "________________________________________________________________________________________________________________________\n", 428 | "Encoder-12-FeedForward-Norm (LayerNorm (None, 512, 768) 1536 Encoder-12-FeedForward-Add[0][0] \n", 429 | "========================================================================================================================\n", 430 | "Total params: 101,677,056\n", 431 | "Trainable params: 0\n", 432 | "Non-trainable params: 101,677,056\n", 433 | "________________________________________________________________________________________________________________________\n" 434 | ] 435 | } 436 | ], 437 | "source": [ 438 | "from keras_bert import load_trained_model_from_checkpoint\n", 439 | "\n", 440 | "model = load_trained_model_from_checkpoint(config_path, checkpoint_path)\n", 441 | "model.summary(line_length=120)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": { 447 | "id": "w9kAo-ZDjZ21", 448 | "colab_type": "text" 449 | }, 450 | "source": [ 451 | "## Tokenization" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 8, 457 | "metadata": { 458 | "id": "DDsAcJI6jdjA", 459 | "colab_type": "code", 460 | "outputId": "e21c0038-c9d0-48bd-dc18-6051bf3d7929", 461 | "colab": { 462 | "base_uri": "https://localhost:8080/", 463 | "height": 51.0 464 | } 465 | }, 466 | "outputs": [ 467 | { 468 | "name": "stdout", 469 | "output_type": "stream", 470 | "text": [ 471 | "[101, 6427, 6241, 3563, 1798, 102, 0, 0, 0, 0]\n", 472 | "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n" 473 | ] 474 | } 475 | ], 476 | "source": [ 477 | "from keras_bert import Tokenizer\n", 478 | "\n", 479 | "tokenizer = Tokenizer(token_dict)\n", 480 | "text = '语言模型'\n", 481 | "tokens = tokenizer.tokenize(text)\n", 482 | "indices, segments = tokenizer.encode(first=text, max_len=512)\n", 483 | "print(indices[:10])\n", 484 | "print(segments[:10])" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": { 490 | "id": "iofhUkhokBp6", 491 | "colab_type": "text" 492 | }, 493 | "source": [ 494 | "## Extract Feature" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 9, 500 | "metadata": { 501 | "id": "vQvu-SfGjiDI", 502 | "colab_type": "code", 503 | "outputId": 
"3fe3da96-f3f3-43c4-c9be-a752bb33740d", 504 | "colab": { 505 | "base_uri": "https://localhost:8080/", 506 | "height": 119.0 507 | } 508 | }, 509 | "outputs": [ 510 | { 511 | "name": "stdout", 512 | "output_type": "stream", 513 | "text": [ 514 | "[CLS] [-0.6325103044509888, 0.20302410423755646, 0.07936538010835648, -0.03284265100955963, 0.5668085813522339]\n", 515 | "语 [-0.7588362097740173, 0.0965188592672348, 1.0718743801116943, 0.005039289593696594, 0.6887993812561035]\n", 516 | "言 [0.5477026104927063, -0.7921162843704224, 0.44435110688209534, -0.7112641930580139, 1.2048895359039307]\n", 517 | "模 [-0.29242411255836487, 0.6052717566490173, 0.49968627095222473, -0.42457854747772217, 0.42855408787727356]\n", 518 | "型 [-0.7473456263542175, 0.49431660771369934, 0.7185154557228088, -0.8723534941673279, 0.8349594473838806]\n", 519 | "[SEP] [-0.8741379976272583, -0.2165030986070633, 1.33883798122406, -0.10587061941623688, 0.3960897624492645]\n" 520 | ] 521 | } 522 | ], 523 | "source": [ 524 | "import numpy as np\n", 525 | "\n", 526 | "predicts = model.predict([np.array([indices]), np.array([segments])])[0]\n", 527 | "for i, token in enumerate(tokens):\n", 528 | " print(token, predicts[i].tolist()[:5])" 529 | ] 530 | } 531 | ], 532 | "metadata": { 533 | "colab": { 534 | "name": "keras_bert_load_and_extract.ipynb", 535 | "version": "0.3.2", 536 | "provenance": [], 537 | "collapsed_sections": [] 538 | }, 539 | "kernelspec": { 540 | "name": "python3", 541 | "display_name": "Python 3" 542 | } 543 | }, 544 | "nbformat": 4, 545 | "nbformat_minor": 0 546 | } 547 | --------------------------------------------------------------------------------