├── test_tftext.py
├── __init__.py
├── test_mask.py
├── test_tfds_wmt.py
├── LICENSE
├── tokenization_en.py
├── .gitignore
├── test_bert_tf2.py
├── transformer_test.py
├── README.md
├── tokenization_test.py
├── translate.py
├── tokenization.py
├── transformer.py
└── extract_features.py
/test_tftext.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import tokenization
4 | 
5 | tokenizer = tokenization.FullTokenizer(
6 |     vocab_file='/Users/livingmagic/Documents/deeplearning/models/bert/chinese_L-12_H-768_A-12/vocab.txt', do_lower_case=True)
7 | 
8 | tokens = tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
9 | print(tokens)
10 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | 
--------------------------------------------------------------------------------
/test_mask.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def create_padding_mask(seq):
5 |     seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
6 | 
7 |     # add extra dimensions so that we can add the padding
8 |     # to the attention logits.
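    # For example, seq = [[7, 6, 0, 0, 1]] yields a mask of [[[[0., 0., 1., 1., 0.]]]],
    # i.e. 1.0 at the padding positions so they can later be added to the
    # attention logits as large negative values.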
9 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) 10 | 11 | 12 | def create_look_ahead_mask(size): 13 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) 14 | return mask # (seq_len, seq_len) 15 | 16 | 17 | def test(): 18 | seq = tf.constant([[0, 1, 2], [1, 2, 3]]) 19 | dec_target_padding_mask = create_padding_mask(seq) 20 | look_ahead_mask = create_look_ahead_mask(seq.shape[1]) 21 | 22 | combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) 23 | 24 | print(combined_mask) 25 | 26 | 27 | test() 28 | -------------------------------------------------------------------------------- /test_tfds_wmt.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | import tensorflow_text as text 4 | 5 | config = tfds.translate.wmt.WmtConfig( 6 | description="WMT 2019 translation task dataset.", 7 | version="0.0.3", 8 | language_pair=("zh", "en"), 9 | subsets={ 10 | tfds.Split.TRAIN: ["newscommentary_v13"], 11 | tfds.Split.VALIDATION: ["newsdev2017"], 12 | } 13 | ) 14 | 15 | builder = tfds.builder("wmt_translate", config=config) 16 | print(builder.info) 17 | builder.download_and_prepare() 18 | datasets = builder.as_dataset(as_supervised=True) 19 | train_dataset = datasets['train'] 20 | val_dataset = datasets['validation'] 21 | 22 | for zh, en in train_dataset.take(5): 23 | print('zh: {}'.format(zh.numpy())) 24 | print('en: {}'.format(en.numpy())) 25 | 26 | # If you need NumPy arrays 27 | # np_datasets = tfds.as_numpy(datasets) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 livingmagic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tokenization_en.py: -------------------------------------------------------------------------------- 1 | import tensorflow_datasets as tfds 2 | 3 | 4 | def load_dataset(): 5 | config = tfds.translate.wmt.WmtConfig( 6 | description="WMT 2019 translation task dataset.", 7 | version="0.0.3", 8 | language_pair=("zh", "en"), 9 | subsets={ 10 | tfds.Split.TRAIN: ["newscommentary_v13"], 11 | tfds.Split.VALIDATION: ["newsdev2017"], 12 | } 13 | ) 14 | 15 | builder = tfds.builder("wmt_translate", config=config) 16 | print(builder.info.splits) 17 | builder.download_and_prepare() 18 | datasets = builder.as_dataset(as_supervised=True) 19 | 20 | return datasets['train'], datasets['validation'] 21 | 22 | 23 | def make_subword_vocab(examples): 24 | tokenizer_en = tfds.features.text.SubwordTextEncoder.build_from_corpus( 25 | (en.numpy() for zh, en in examples), target_vocab_size=2 ** 13) 26 | sample_string = 'Transformer is awesome.' 27 | tokenized_string = tokenizer_en.encode(sample_string) 28 | print('Tokenized string is {}'.format(tokenized_string)) 29 | original_string = tokenizer_en.decode(tokenized_string) 30 | print('The original string: {}'.format(original_string)) 31 | 32 | tokenizer_en.save_to_file('vocab_en') 33 | 34 | 35 | def load_subword_vocab(vocab_file): 36 | return tfds.features.text.SubwordTextEncoder.load_from_file(vocab_file) 37 | 38 | 39 | if __name__ == "__main__": 40 | train_examples, _ = load_dataset() 41 | tokenizer_en = tfds.features.text.SubwordTextEncoder.build_from_corpus( 42 | (en.numpy() for zh, en in train_examples), target_vocab_size=2 ** 13) 43 | tokenizer_en.save_to_file('vocab_en.txt') 44 | print(tokenizer_en.vocab_size + 2) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .DS_Store 107 | .idea 108 | .idea/* 109 | tmp/ 110 | test.py -------------------------------------------------------------------------------- /test_bert_tf2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import numpy as np 4 | from bert.embeddings import BertEmbeddingsLayer 5 | from tensorflow import keras 6 | from bert import BertModelLayer 7 | from bert.loader import StockBertConfig, load_stock_weights 8 | import logging 9 | 10 | 11 | def test1(): 12 | l_bert = BertModelLayer( 13 | vocab_size=16000, # embedding params 14 | use_token_type=True, 15 | use_position_embeddings=True, 16 | token_type_vocab_size=2, 17 | 18 | num_layers=12, # transformer encoder params 19 | hidden_size=768, 20 | hidden_dropout=0.1, 21 | intermediate_size=4 * 768, 22 | intermediate_activation="gelu", 23 | 24 | name="bert" # any other Keras layer params 25 | ) 26 | 27 | print(l_bert.params) 28 | 29 | 30 | def test2(): 31 | model_dir = "/Users/livingmagic/Documents/deeplearning/models/bert/chinese_L-12_H-768_A-12" 32 | 33 | bert_config_file = os.path.join(model_dir, "bert_config.json") 34 | bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt") 35 | 36 | with tf.io.gfile.GFile(bert_config_file, "r") as reader: 37 | stock_params = StockBertConfig.from_json_string(reader.read()) 38 | bert_params = stock_params.to_bert_model_layer_params() 39 | 40 | l_bert = BertModelLayer.from_params(bert_params, name="bert", trainable=False) 41 | 42 | # # Input and output endpoints 43 | max_seq_len = 128 44 | l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32') 45 | output = l_bert(l_input_ids, training=False) # [batch_size, max_seq_len, hidden_size] 46 | print('Output shape: {}'.format(output.get_shape())) 47 | 48 | bert_model = keras.Model(inputs=l_input_ids, outputs=output) 49 | print(bert_model.trainable_weights) 50 | 51 | # # loading the original pre-trained weights into the BERT layer: 52 | # load_stock_weights(l_bert, bert_ckpt_file) 53 | # print(bert_model.predict(np.arange(0, 128)[np.newaxis, :])) 54 | 55 | 56 | def test_embeddings_layer(): 57 | tf.get_logger().setLevel(logging.INFO) 58 | layer = BertEmbeddingsLayer() 59 | mask = layer.compute_mask(inputs=np.array([1, 2, 0, 0, 1])) 60 | print('mask: {}'.format(tf.cast(mask, tf.float32) * 0.1)) 61 | 62 | 63 | test_embeddings_layer() 
--------------------------------------------------------------------------------
/transformer_test.py:
--------------------------------------------------------------------------------
1 | from transformer import *
2 | import numpy as np
3 | import os
4 | 
5 | 
6 | class TransformerTest(tf.test.TestCase):
7 |     def test_transform(self):
8 |         MODEL_DIR = "/Users/livingmagic/Documents/deeplearning/models/bert/chinese_L-12_H-768_A-12"
9 |         bert_config_file = os.path.join(MODEL_DIR, "bert_config.json")
10 |         bert_ckpt_file = os.path.join(MODEL_DIR, "bert_model.ckpt")
11 | 
12 |         config = Config(num_layers=4, d_model=128, dff=512, num_heads=8)
13 | 
14 |         transformer = Transformer(config=config,
15 |                                   target_vocab_size=8173,
16 |                                   bert_config_file=bert_config_file)
17 | 
18 |         inp = tf.random.uniform((32, 128))
19 |         tar_inp = tf.random.uniform((32, 128))
20 |         fn_out, _ = transformer(inp, tar_inp,
21 |                                 True,
22 |                                 look_ahead_mask=None,
23 |                                 dec_padding_mask=None)
24 |         print(tar_inp.shape)
25 |         print(fn_out.shape)  # (batch_size, tar_seq_len) (batch_size, tar_seq_len, target_vocab_size)
26 | 
27 |         w11 = tf.reduce_sum(transformer.encoder.weights[0]).numpy()
28 |         w12 = tf.reduce_sum(transformer.encoder.weights[1]).numpy()
29 |         # init bert pre-trained weights
30 |         transformer.restore_encoder(bert_ckpt_file)
31 |         w21 = tf.reduce_sum(transformer.encoder.weights[0]).numpy()
32 |         w22 = tf.reduce_sum(transformer.encoder.weights[1]).numpy()
33 |         self.assertNotEqual(w11, w21)
34 |         self.assertNotEqual(w12, w22)
35 | 
36 |     def test_encoder(self):
37 |         MODEL_DIR = "/Users/livingmagic/Documents/deeplearning/models/bert/chinese_L-12_H-768_A-12"
38 |         bert_config_file = os.path.join(MODEL_DIR, "bert_config.json")
39 | 
40 |         bert_encoder = build_encoder(config_file=bert_config_file)
41 |         bert_encoder.trainable = False
42 |         inp = tf.random.uniform((32, 128))
43 |         bert_encoder(inp, training=False)
44 | 
45 |         weight_names = []
46 |         for weight in bert_encoder.weights:
47 |             weight_names.append(weight.name)
48 |         with open('2.txt', 'w') as f:
49 |             f.write(str(weight_names))
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     tf.test.main()
54 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nmt-with-bert-tf2
2 | A Transformer model that translates Chinese to English, using pre-trained BERT as the encoder, built with TensorFlow 2.0.
3 | 
4 | ## Usage
5 | 
6 | You can train the model for free in the Google Colab notebook: [nmt_with_transformer.ipynb](https://colab.research.google.com/github/livingmagic/nmt-with-bert-tf2/blob/master/nmt_with_transformer.ipynb)
7 | 
8 | After 4 epochs, the final training result is: **Epoch 4 Loss 0.5936 Accuracy 0.1355**
9 | 
10 | ```reStructuredText
11 | Epoch 4 Batch 0 Loss 0.6691 Accuracy 0.1323
12 | Epoch 4 Batch 500 Loss 0.6059 Accuracy 0.1335
13 | Epoch 4 Batch 1000 Loss 0.6038 Accuracy 0.1335
14 | Epoch 4 Batch 1500 Loss 0.6016 Accuracy 0.1335
15 | Epoch 4 Batch 2000 Loss 0.6010 Accuracy 0.1342
16 | Epoch 4 Batch 2500 Loss 0.5993 Accuracy 0.1346
17 | Epoch 4 Batch 3000 Loss 0.5968 Accuracy 0.1350
18 | Epoch 4 Batch 3500 Loss 0.5955 Accuracy 0.1353
19 | Saving checkpoint for epoch 4 at ./checkpoints/train/ckpt-4
20 | Epoch 4 Loss 0.5936 Accuracy 0.1355
21 | Time taken for 1 epoch: 3553.9979977607727 secs
22 | ```
23 | 
24 | You can evaluate some texts, such as:
25 | 
26 | ```reStructuredText
27 | Input: 我爱你是一件幸福的事情。
28 | Predicted translation: I love you are a blessing.
29 | ```
30 | 
31 | ```text
32 | Input: 虽然继承了祖荫，但朴槿惠已经证明了自己是个机敏而老练的政治家——她历经20年才爬上韩国大国家党最高领导层并成为全国知名人物。
33 | Predicted translation: While inherited her father, Park has proven that she is a brave and al-keal – a politician who has been able to topple the country’s largest party and become a national leader.
34 | Real translation: While Park derives some of her power from her family pedigree, she has proven to be an astute and seasoned politician – one who climbed the Grand National Party’s leadership ladder over the last two decades to emerge as a national figure.
35 | ```
36 | 
37 | **Notice**: You must add "。" at the end of the input sentence, otherwise you may get an unexpected result.
38 | 
39 | ## Using BERT to extract fixed feature vectors (like ELMo)
40 | 
41 | The Chinese BERT pre-trained model used here is [BERT-Base, Chinese](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip). You can use the example code below to extract features with TensorFlow 2.
42 | 
43 | ```
44 | # Sentence A and Sentence B are separated by the ||| delimiter for sentence
45 | # pair tasks like question answering and entailment.
46 | # For single sentence inputs, put one sentence per line and DON'T use the
47 | # delimiter.
48 | echo '我不是故意说的。 ||| 真心话大冒险!\n富士康科技集团发声明否认“撤离大陆”?' > tmp/input.txt
49 | 
50 | python extract_features.py \
51 |   --input_file=tmp/input.txt \
52 |   --output_file=tmp/output.jsonl \
53 |   --vocab_file=$BERT_BASE_DIR/vocab.txt \
54 |   --bert_config_file=$BERT_BASE_DIR/bert_config.json \
55 |   --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
56 |   --max_seq_length=128 \
57 |   --batch_size=8
58 | ```
59 | 
60 | ## Resources
61 | 
62 | - [BERT](https://arxiv.org/abs/1810.04805) - BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
63 | - [google-research/bert](https://github.com/google-research/bert) - the original BERT implementation
64 | - [kpe/bert-for-tf2](https://github.com/kpe/bert-for-tf2) - A Keras TensorFlow 2.0 implementation of BERT.
--------------------------------------------------------------------------------
/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import tokenization 22 | import six 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from absl import flags 3 | from absl import app 4 | from absl import logging 5 | 6 | from tokenization import FullTokenizer 7 | from tokenization_en import load_subword_vocab 8 | from transformer import Transformer, FileConfig 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | MODEL_DIR = "/Users/livingmagic/Documents/deeplearning/models/bert-nmt/zh-en_bert-tf2_L6-D256/" 13 | 14 | flags.DEFINE_string("bert_config_file", MODEL_DIR + "bert_config.json", "The bert config file") 15 | flags.DEFINE_string("bert_vocab_file", MODEL_DIR + "vocab.txt", 16 | "The vocabulary file that the BERT model was trained on.") 17 | 18 | flags.DEFINE_string("init_checkpoint", MODEL_DIR + "bert_nmt_ckpt", "") 19 | flags.DEFINE_string("config_file", MODEL_DIR + "config.json", "The transformer config file except bert") 20 | flags.DEFINE_string("vocab_file", MODEL_DIR + "vocab_en", "The english vocabulary file") 21 | flags.DEFINE_integer("max_seq_length", 128, "Max length to sequence length") 22 | flags.DEFINE_string("inp_sentence", None, "") 23 | 
24 | 25 | def create_padding_mask(seq): 26 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32) 27 | 28 | # add extra dimensions so that we can add the padding 29 | # to the attention logits. 30 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) 31 | 32 | 33 | def create_look_ahead_mask(size): 34 | """ 35 | The look-ahead mask is used to mask the future tokens in a sequence. 36 | In other words, the mask indicates which entries should not be used. 37 | """ 38 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) 39 | return mask # (seq_len, seq_len) 40 | 41 | 42 | def create_masks(inp, tar): 43 | # Used in the 2nd attention block in the decoder. 44 | # This padding mask is used to mask the encoder outputs. 45 | dec_padding_mask = create_padding_mask(inp) 46 | 47 | # Used in the 1st attention block in the decoder. 48 | # It is used to pad and mask future tokens in the input received by 49 | # the decoder. 50 | look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1]) 51 | dec_target_padding_mask = create_padding_mask(tar) 52 | combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) 53 | 54 | return combined_mask, dec_padding_mask 55 | 56 | 57 | def encode_zh(tokenizer_zh, zh): 58 | tokens_zh = tokenizer_zh.tokenize(zh) 59 | lang1 = tokenizer_zh.convert_tokens_to_ids(['[CLS]'] + tokens_zh + ['[SEP]']) 60 | 61 | return lang1 62 | 63 | 64 | def evaluate(transformer, 65 | tokenizer_zh, 66 | tokenizer_en, 67 | inp_sentence, 68 | max_seq_length): 69 | # normalize input sentence 70 | inp_sentence = encode_zh(tokenizer_zh, inp_sentence) 71 | encoder_input = tf.expand_dims(inp_sentence, 0) 72 | 73 | # as the target is english, the first word to the transformer should be the 74 | # english start token. 75 | decoder_input = [tokenizer_en.vocab_size] 76 | output = tf.expand_dims(decoder_input, 0) 77 | 78 | for i in range(max_seq_length): 79 | combined_mask, dec_padding_mask = create_masks( 80 | encoder_input, output) 81 | 82 | # predictions.shape == (batch_size, seq_len, vocab_size) 83 | predictions, attention_weights = transformer(encoder_input, 84 | output, 85 | False, 86 | combined_mask, 87 | dec_padding_mask) 88 | 89 | # select the last word from the seq_len dimension 90 | predictions = predictions[:, -1:, :] # (batch_size, 1, vocab_size) 91 | 92 | predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32) 93 | 94 | # return the result if the predicted_id is equal to the end token 95 | if tf.equal(predicted_id, tokenizer_en.vocab_size + 1): 96 | return tf.squeeze(output, axis=0), attention_weights 97 | 98 | # concatentate the predicted_id to the output which is given to the decoder 99 | # as its input. 
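        # For example, an output of [[start_token_id]] and a predicted_id of [[42]]
        # become [[start_token_id, 42]], which is fed back in on the next step.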
100 | output = tf.concat([output, predicted_id], axis=-1) 101 | 102 | return tf.squeeze(output, axis=0), attention_weights 103 | 104 | 105 | def main(_): 106 | tokenizer_zh = FullTokenizer( 107 | vocab_file=FLAGS.bert_vocab_file, do_lower_case=True) 108 | 109 | tokenizer_en = load_subword_vocab(FLAGS.vocab_file) 110 | target_vocab_size = tokenizer_en.vocab_size + 2 111 | 112 | config = FileConfig(FLAGS.config_file) 113 | transformer = Transformer(config=config, 114 | target_vocab_size=target_vocab_size, 115 | bert_config_file=FLAGS.bert_config_file) 116 | 117 | inp = tf.random.uniform((1, FLAGS.max_seq_length)) 118 | tar_inp = tf.random.uniform((1, FLAGS.max_seq_length)) 119 | fn_out, _ = transformer(inp, tar_inp, 120 | True, 121 | look_ahead_mask=None, 122 | dec_padding_mask=None) 123 | 124 | transformer.load_weights(FLAGS.init_checkpoint) 125 | 126 | print(transformer.encoder.weights[0]) 127 | 128 | result, _ = evaluate(transformer, 129 | tokenizer_zh, 130 | tokenizer_en, 131 | FLAGS.inp_sentence, 132 | FLAGS.max_seq_length) 133 | 134 | predicted_sentence = tokenizer_en.decode([i for i in result 135 | if i < tokenizer_en.vocab_size]) 136 | 137 | print('Input: {}'.format(FLAGS.inp_sentence)) 138 | print('Predicted translation: {}'.format(predicted_sentence)) 139 | 140 | 141 | if __name__ == "__main__": 142 | flags.mark_flag_as_required("inp_sentence") 143 | app.run(main) 144 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 
52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenziation.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens = [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | return convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | 140 | Args: 141 | do_lower_case: Whether to lower case the input. 142 | """ 143 | self.do_lower_case = do_lower_case 144 | 145 | def tokenize(self, text): 146 | """Tokenizes a piece of text.""" 147 | text = convert_to_unicode(text) 148 | text = self._clean_text(text) 149 | 150 | # This was added on November 1st, 2018 for the multilingual and Chinese 151 | # models. This is also applied to the English models now, but it doesn't 152 | # matter since the English models were not trained on any Chinese data 153 | # and generally don't have any Chinese data in them (there are Chinese 154 | # characters in the vocabulary because Wikipedia does have some Chinese 155 | # words in the English Wikipedia.). 
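    # For example, "ah博推zz" becomes "ah 博  推 zz" here, so the whitespace
    # split below yields ["ah", "博", "推", "zz"].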
156 | text = self._tokenize_chinese_chars(text) 157 | 158 | orig_tokens = whitespace_tokenize(text) 159 | split_tokens = [] 160 | for token in orig_tokens: 161 | if self.do_lower_case: 162 | token = token.lower() 163 | token = self._run_strip_accents(token) 164 | split_tokens.extend(self._run_split_on_punc(token)) 165 | 166 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 167 | return output_tokens 168 | 169 | def _run_strip_accents(self, text): 170 | """Strips accents from a piece of text.""" 171 | text = unicodedata.normalize("NFD", text) 172 | output = [] 173 | for char in text: 174 | cat = unicodedata.category(char) 175 | if cat == "Mn": 176 | continue 177 | output.append(char) 178 | return "".join(output) 179 | 180 | def _run_split_on_punc(self, text): 181 | """Splits punctuation on a piece of text.""" 182 | chars = list(text) 183 | i = 0 184 | start_new_word = True 185 | output = [] 186 | while i < len(chars): 187 | char = chars[i] 188 | if _is_punctuation(char): 189 | output.append([char]) 190 | start_new_word = True 191 | else: 192 | if start_new_word: 193 | output.append([]) 194 | start_new_word = False 195 | output[-1].append(char) 196 | i += 1 197 | 198 | return ["".join(x) for x in output] 199 | 200 | def _tokenize_chinese_chars(self, text): 201 | """Adds whitespace around any CJK character.""" 202 | output = [] 203 | for char in text: 204 | cp = ord(char) 205 | if self._is_chinese_char(cp): 206 | output.append(" ") 207 | output.append(char) 208 | output.append(" ") 209 | else: 210 | output.append(char) 211 | return "".join(output) 212 | 213 | def _is_chinese_char(self, cp): 214 | """Checks whether CP is the codepoint of a CJK character.""" 215 | # This defines a "chinese character" as anything in the CJK Unicode block: 216 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 217 | # 218 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 219 | # despite its name. The modern Korean Hangul alphabet is a different block, 220 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 221 | # space-separated words, so they are not treated specially and handled 222 | # like the all of the other languages. 223 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 224 | (cp >= 0x3400 and cp <= 0x4DBF) or # 225 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 226 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 227 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 228 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 229 | (cp >= 0xF900 and cp <= 0xFAFF) or # 230 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 231 | return True 232 | 233 | return False 234 | 235 | def _clean_text(self, text): 236 | """Performs invalid character removal and whitespace cleanup on text.""" 237 | output = [] 238 | for char in text: 239 | cp = ord(char) 240 | if cp == 0 or cp == 0xfffd or _is_control(char): 241 | continue 242 | if _is_whitespace(char): 243 | output.append(" ") 244 | else: 245 | output.append(char) 246 | return "".join(output) 247 | 248 | 249 | class WordpieceTokenizer(object): 250 | """Runs WordPiece tokenziation.""" 251 | 252 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 253 | self.vocab = vocab 254 | self.unk_token = unk_token 255 | self.max_input_chars_per_word = max_input_chars_per_word 256 | 257 | def tokenize(self, text): 258 | """Tokenizes a piece of text into its word pieces. 259 | 260 | This uses a greedy longest-match-first algorithm to perform tokenization 261 | using the given vocabulary. 
262 | 263 | For example: 264 | input = "unaffable" 265 | output = ["un", "##aff", "##able"] 266 | 267 | Args: 268 | text: A single token or whitespace separated tokens. This should have 269 | already been passed through `BasicTokenizer. 270 | 271 | Returns: 272 | A list of wordpiece tokens. 273 | """ 274 | 275 | text = convert_to_unicode(text) 276 | 277 | output_tokens = [] 278 | for token in whitespace_tokenize(text): 279 | chars = list(token) 280 | if len(chars) > self.max_input_chars_per_word: 281 | output_tokens.append(self.unk_token) 282 | continue 283 | 284 | is_bad = False 285 | start = 0 286 | sub_tokens = [] 287 | while start < len(chars): 288 | end = len(chars) 289 | cur_substr = None 290 | while start < end: 291 | substr = "".join(chars[start:end]) 292 | if start > 0: 293 | substr = "##" + substr 294 | if substr in self.vocab: 295 | cur_substr = substr 296 | break 297 | end -= 1 298 | if cur_substr is None: 299 | is_bad = True 300 | break 301 | sub_tokens.append(cur_substr) 302 | start = end 303 | 304 | if is_bad: 305 | output_tokens.append(self.unk_token) 306 | else: 307 | output_tokens.extend(sub_tokens) 308 | return output_tokens 309 | 310 | 311 | def _is_whitespace(char): 312 | """Checks whether `chars` is a whitespace character.""" 313 | # \t, \n, and \r are technically contorl characters but we treat them 314 | # as whitespace since they are generally considered as such. 315 | if char == " " or char == "\t" or char == "\n" or char == "\r": 316 | return True 317 | cat = unicodedata.category(char) 318 | if cat == "Zs": 319 | return True 320 | return False 321 | 322 | 323 | def _is_control(char): 324 | """Checks whether `chars` is a control character.""" 325 | # These are technically control characters but we count them as whitespace 326 | # characters. 327 | if char == "\t" or char == "\n" or char == "\r": 328 | return False 329 | cat = unicodedata.category(char) 330 | if cat.startswith("C"): 331 | return True 332 | return False 333 | 334 | 335 | def _is_punctuation(char): 336 | """Checks whether `chars` is a punctuation character.""" 337 | cp = ord(char) 338 | # We treat all non-letter/number ASCII as punctuation. 339 | # Characters such as "^", "$", and "`" are not in the Unicode 340 | # Punctuation class but we treat them as punctuation anyways, for 341 | # consistency. 
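  # For example, ord("$") == 36 falls in the 33-47 range below, so "$" is
  # treated as punctuation even though its Unicode category is Sc rather than P*.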
342 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 343 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 344 | return True 345 | cat = unicodedata.category(char) 346 | if cat.startswith("P"): 347 | return True 348 | return False 349 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import tensorflow as tf 4 | 5 | import numpy as np 6 | from bert import BertModelLayer 7 | 8 | from bert.loader import StockBertConfig, map_to_stock_variable_name 9 | 10 | 11 | def get_angles(pos, i, d_model): 12 | angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) 13 | return pos * angle_rates 14 | 15 | 16 | def positional_encoding(seq_length, d_model): 17 | """ 18 | The formula cossin for positional encoding 19 | :param seq_length: 20 | :param d_model: 21 | :return: 22 | """ 23 | angle_rads = get_angles(np.arange(seq_length)[:, np.newaxis], 24 | np.arange(d_model)[np.newaxis, :], 25 | d_model) 26 | 27 | # apply sin to even indices in the array; 2i 28 | sines = np.sin(angle_rads[:, 0::2]) 29 | 30 | # apply cos to odd indices in the array; 2i+1 31 | cosines = np.cos(angle_rads[:, 1::2]) 32 | 33 | pos_encoding = np.concatenate([sines, cosines], axis=-1) 34 | 35 | pos_encoding = pos_encoding[np.newaxis, ...] 36 | 37 | return tf.cast(pos_encoding, dtype=tf.float32) 38 | 39 | 40 | def positional_embedding( 41 | seq_length, 42 | d_model, 43 | position_embedding_name="position_embeddings", 44 | max_position_embeddings=512, 45 | initial_stddev=0.02): 46 | full_position_embeddings = tf.Variable( 47 | initial_value=tf.random.truncated_normal(shape=[max_position_embeddings, d_model], stddev=initial_stddev), 48 | name=position_embedding_name) 49 | # Since the position embedding table is a learned variable, we create it 50 | # using a (long) sequence length `max_position_embeddings`. The actual 51 | # sequence length might be shorter than this, for faster training of 52 | # tasks that do not have long sequences. 53 | # 54 | # So `full_position_embeddings` is effectively an embedding table 55 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 56 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 57 | # perform a slice. 58 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 59 | [seq_length, -1]) 60 | 61 | pos_encoding = position_embeddings[np.newaxis, ...] 62 | 63 | return tf.cast(pos_encoding, dtype=tf.float32) 64 | 65 | 66 | def scaled_dot_product_attention(q, k, v, mask): 67 | """Calculate the attention weights. 68 | q, k, v must have matching leading dimensions. 69 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. 70 | The mask has different shapes depending on its type(padding or look ahead) 71 | but it must be broadcastable for addition. 72 | 73 | Args: 74 | q: query shape == (..., seq_len_q, depth) 75 | k: key shape == (..., seq_len_k, depth) 76 | v: value shape == (..., seq_len_v, depth_v) 77 | mask: Float tensor with shape broadcastable 78 | to (..., seq_len_q, seq_len_k). Defaults to None. 79 | 80 | Returns: 81 | output, attention_weights 82 | """ 83 | 84 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 85 | 86 | # scale matmul_qk 87 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 88 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 89 | 90 | # add the mask to the scaled tensor. 
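    # Positions where mask == 1 receive a large negative value, so the softmax
    # below assigns them an attention weight of (almost) zero.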
91 | if mask is not None: 92 | scaled_attention_logits += (mask * -1e9) 93 | 94 | # softmax is normalized on the last axis (seq_len_k) so that the scores 95 | # add up to 1. 96 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 97 | 98 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 99 | 100 | return output, attention_weights 101 | 102 | 103 | class MultiHeadAttention(tf.keras.layers.Layer): 104 | def __init__(self, d_model, num_heads): 105 | super(MultiHeadAttention, self).__init__() 106 | self.num_heads = num_heads 107 | self.d_model = d_model 108 | 109 | assert d_model % self.num_heads == 0 110 | 111 | self.depth = d_model // self.num_heads 112 | 113 | self.wq = tf.keras.layers.Dense(d_model) 114 | self.wk = tf.keras.layers.Dense(d_model) 115 | self.wv = tf.keras.layers.Dense(d_model) 116 | 117 | self.dense = tf.keras.layers.Dense(d_model) 118 | 119 | def split_heads(self, x, batch_size): 120 | """Split the last dimension into (num_heads, depth). 121 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) 122 | """ 123 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 124 | return tf.transpose(x, perm=[0, 2, 1, 3]) 125 | 126 | def call(self, v, k, q, mask): 127 | batch_size = tf.shape(q)[0] 128 | 129 | q = self.wq(q) # (batch_size, seq_len, d_model) 130 | k = self.wk(k) # (batch_size, seq_len, d_model) 131 | v = self.wv(v) # (batch_size, seq_len, d_model) 132 | 133 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 134 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 135 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 136 | 137 | # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) 138 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 139 | scaled_attention, attention_weights = scaled_dot_product_attention( 140 | q, k, v, mask) 141 | 142 | scaled_attention = tf.transpose(scaled_attention, 143 | perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) 144 | 145 | concat_attention = tf.reshape(scaled_attention, 146 | (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) 147 | 148 | output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model) 149 | 150 | return output, attention_weights 151 | 152 | 153 | def point_wise_feed_forward_network(d_model, dff): 154 | return tf.keras.Sequential([ 155 | tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff) 156 | tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model) 157 | ]) 158 | 159 | 160 | def build_encoder(config_file): 161 | with tf.io.gfile.GFile(config_file, "r") as reader: 162 | stock_params = StockBertConfig.from_json_string(reader.read()) 163 | bert_params = stock_params.to_bert_model_layer_params() 164 | 165 | return BertModelLayer.from_params(bert_params, name="bert") 166 | 167 | 168 | class DecoderLayer(tf.keras.layers.Layer): 169 | def __init__(self, d_model, num_heads, dff, rate=0.1): 170 | super(DecoderLayer, self).__init__() 171 | 172 | self.mha1 = MultiHeadAttention(d_model, num_heads) 173 | self.mha2 = MultiHeadAttention(d_model, num_heads) 174 | 175 | self.ffn = point_wise_feed_forward_network(d_model, dff) 176 | 177 | self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 178 | self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) 179 | self.layernorm3 = 
tf.keras.layers.LayerNormalization(epsilon=1e-6) 180 | 181 | self.dropout1 = tf.keras.layers.Dropout(rate) 182 | self.dropout2 = tf.keras.layers.Dropout(rate) 183 | self.dropout3 = tf.keras.layers.Dropout(rate) 184 | 185 | def call(self, x, enc_output, training, 186 | look_ahead_mask, padding_mask): 187 | # enc_output.shape == (batch_size, input_seq_len, d_model) 188 | 189 | attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) # (batch_size, target_seq_len, d_model) 190 | attn1 = self.dropout1(attn1, training=training) 191 | out1 = self.layernorm1(attn1 + x) 192 | 193 | attn2, attn_weights_block2 = self.mha2( 194 | enc_output, enc_output, out1, padding_mask) # (batch_size, target_seq_len, d_model) 195 | attn2 = self.dropout2(attn2, training=training) 196 | out2 = self.layernorm2(attn2 + out1) # (batch_size, target_seq_len, d_model) 197 | 198 | ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model) 199 | ffn_output = self.dropout3(ffn_output, training=training) 200 | out3 = self.layernorm3(ffn_output + out2) # (batch_size, target_seq_len, d_model) 201 | 202 | return out3, attn_weights_block1, attn_weights_block2 203 | 204 | 205 | class Decoder(tf.keras.layers.Layer): 206 | def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, 207 | rate=0.1): 208 | super(Decoder, self).__init__() 209 | 210 | self.d_model = d_model 211 | self.num_layers = num_layers 212 | 213 | self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model) 214 | self.pos_encoding = positional_encoding(target_vocab_size, self.d_model) 215 | 216 | self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 217 | for _ in range(num_layers)] 218 | self.dropout = tf.keras.layers.Dropout(rate) 219 | 220 | def call(self, x, enc_output, training, 221 | look_ahead_mask, padding_mask): 222 | seq_len = tf.shape(x)[1] 223 | attention_weights = {} 224 | 225 | x = self.embedding(x) # (batch_size, target_seq_len, d_model) 226 | x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) 227 | x += self.pos_encoding[:, :seq_len, :] 228 | 229 | x = self.dropout(x, training=training) 230 | 231 | for i in range(self.num_layers): 232 | x, block1, block2 = self.dec_layers[i](x, enc_output, training, 233 | look_ahead_mask, padding_mask) 234 | 235 | attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1 236 | attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2 237 | 238 | # x.shape == (batch_size, target_seq_len, d_model) 239 | return x, attention_weights 240 | 241 | 242 | class FileConfig(object): 243 | def __init__(self, config_file): 244 | with open(config_file, 'r') as f: 245 | text = f.read() 246 | 247 | params = json.loads(text) 248 | self.num_layers = params['num_layers'] 249 | self.d_model = params['d_model'] 250 | self.dff = params['dff'] 251 | self.num_heads = params['num_heads'] 252 | 253 | 254 | class Config(object): 255 | def __init__(self, num_layers, d_model, dff, num_heads): 256 | self.num_layers = num_layers 257 | self.d_model = d_model 258 | self.dff = dff 259 | self.num_heads = num_heads 260 | 261 | 262 | class Transformer(tf.keras.Model): 263 | def __init__(self, config, 264 | target_vocab_size, 265 | bert_config_file, 266 | bert_training=False, 267 | dropout_rate=0.1, 268 | name='transformer'): 269 | super(Transformer, self).__init__(name=name) 270 | 271 | self.encoder = build_encoder(config_file=bert_config_file) 272 | self.encoder.trainable = bert_training 273 | 274 | self.decoder = Decoder(config.num_layers, config.d_model, 275 | config.num_heads, 
config.dff, 276 | target_vocab_size, dropout_rate) 277 | 278 | self.final_layer = tf.keras.layers.Dense(target_vocab_size) 279 | 280 | def load_stock_weights(self, bert: BertModelLayer, ckpt_file): 281 | assert isinstance(bert, BertModelLayer), "Expecting a BertModelLayer instance as first argument" 282 | assert tf.compat.v1.train.checkpoint_exists(ckpt_file), "Checkpoint does not exist: {}".format(ckpt_file) 283 | ckpt_reader = tf.train.load_checkpoint(ckpt_file) 284 | 285 | bert_prefix = 'transformer/bert' 286 | 287 | weights = [] 288 | for weight in bert.weights: 289 | stock_name = map_to_stock_variable_name(weight.name, bert_prefix) 290 | 291 | if ckpt_reader.has_tensor(stock_name): 292 | value = ckpt_reader.get_tensor(stock_name) 293 | weights.append(value) 294 | else: 295 | raise ValueError("No value for:[{}], i.e.:[{}] in:[{}]".format( 296 | weight.name, stock_name, ckpt_file)) 297 | bert.set_weights(weights) 298 | print("Done loading {} BERT weights from: {} into {} (prefix:{})".format( 299 | len(weights), ckpt_file, bert, bert_prefix)) 300 | 301 | def restore_encoder(self, bert_ckpt_file): 302 | # loading the original pre-trained weights into the BERT layer: 303 | self.load_stock_weights(self.encoder, bert_ckpt_file) 304 | 305 | def call(self, inp, tar, training, look_ahead_mask, dec_padding_mask): 306 | enc_output = self.encoder(inp, training=self.encoder.trainable) # (batch_size, inp_seq_len, d_model) 307 | 308 | # dec_output.shape == (batch_size, tar_seq_len, d_model) 309 | dec_output, attention_weights = self.decoder( 310 | tar, enc_output, training, look_ahead_mask, dec_padding_mask) 311 | 312 | final_output = self.final_layer(dec_output) # (batch_size, tar_seq_len, target_vocab_size) 313 | 314 | return final_output, attention_weights 315 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import tokenization 27 | import numpy as np 28 | import tensorflow as tf 29 | from tensorflow import keras 30 | from tensorflow.python import keras 31 | 32 | from absl import flags 33 | from absl import app 34 | from absl import logging 35 | 36 | from bert import BertModelLayer 37 | from bert.loader import StockBertConfig, load_stock_weights 38 | 39 | FLAGS = flags.FLAGS 40 | 41 | flags.DEFINE_string("input_file", None, "") 42 | 43 | flags.DEFINE_string("output_file", None, "") 44 | 45 | flags.DEFINE_string( 46 | "bert_config_file", None, 47 | "The config json file corresponding to the pre-trained BERT model. 
" 48 | "This specifies the model architecture.") 49 | 50 | flags.DEFINE_integer( 51 | "max_seq_length", 128, 52 | "The maximum total input sequence length after WordPiece tokenization. " 53 | "Sequences longer than this will be truncated, and sequences shorter " 54 | "than this will be padded.") 55 | 56 | flags.DEFINE_string( 57 | "init_checkpoint", None, 58 | "Initial checkpoint (usually from a pre-trained BERT model).") 59 | 60 | flags.DEFINE_string("vocab_file", None, 61 | "The vocabulary file that the BERT model was trained on.") 62 | 63 | flags.DEFINE_bool( 64 | "do_lower_case", True, 65 | "Whether to lower case the input text. Should be True for uncased " 66 | "models and False for cased models.") 67 | 68 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 69 | 70 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 71 | 72 | flags.DEFINE_string("master", None, 73 | "If using a TPU, the address of the master.") 74 | 75 | flags.DEFINE_integer( 76 | "num_tpu_cores", 8, 77 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 78 | 79 | 80 | class InputExample(object): 81 | 82 | def __init__(self, unique_id, text_a, text_b): 83 | self.unique_id = unique_id 84 | self.text_a = text_a 85 | self.text_b = text_b 86 | 87 | 88 | class InputFeatures(object): 89 | """A single set of features of data.""" 90 | 91 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 92 | self.unique_id = unique_id 93 | self.tokens = tokens 94 | self.input_ids = input_ids 95 | self.input_mask = input_mask 96 | self.input_type_ids = input_type_ids 97 | 98 | 99 | def build_dataset(features, seq_length, batch_size): 100 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 101 | 102 | all_unique_ids = [] 103 | all_input_ids = [] 104 | all_input_mask = [] 105 | all_input_type_ids = [] 106 | 107 | for feature in features: 108 | all_unique_ids.append(feature.unique_id) 109 | all_input_ids.append(feature.input_ids) 110 | all_input_mask.append(feature.input_mask) 111 | all_input_type_ids.append(feature.input_type_ids) 112 | 113 | num_examples = len(features) 114 | 115 | # This is for demo purposes and does NOT scale to large data sets. We do 116 | # not use Dataset.from_generator() because that uses tf.py_func which is 117 | # not TPU compatible. The right way to load data is with TFRecordReader. 
118 | d = tf.data.Dataset.from_tensor_slices({ 119 | "unique_ids": 120 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 121 | "input_ids": 122 | tf.constant( 123 | all_input_ids, shape=[num_examples, seq_length], 124 | dtype=tf.int32), 125 | "input_mask": 126 | tf.constant( 127 | all_input_mask, 128 | shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_type_ids": 131 | tf.constant( 132 | all_input_type_ids, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | }) 136 | 137 | return d.batch(batch_size=batch_size, drop_remainder=False) 138 | 139 | 140 | def build_model(bert_config, init_checkpoint, max_seq_len): 141 | bert_params = from_json_file(bert_config) 142 | l_bert = BertModelLayer.from_params(bert_params, name="bert") 143 | 144 | # Input and output endpoints 145 | l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32') 146 | l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32') 147 | l_input_mask = keras.layers.Input(shape=(max_seq_len,), dtype='int32') 148 | output = l_bert([l_input_ids, l_token_type_ids], mask=l_input_mask, 149 | training=False) # [batch_size, max_seq_len, hidden_size] 150 | print('Output shape: {}'.format(output.get_shape())) 151 | 152 | # Build model 153 | model = keras.Model(inputs=[l_input_ids, l_token_type_ids, l_input_mask], outputs=output) 154 | # loading the original pre-trained weights into the BERT layer: 155 | load_stock_weights(l_bert, init_checkpoint) 156 | 157 | return model 158 | 159 | 160 | def convert_examples_to_features(examples, seq_length, tokenizer): 161 | """Loads a data file into a list of `InputBatch`s.""" 162 | 163 | features = [] 164 | for (ex_index, example) in enumerate(examples): 165 | tokens_a = tokenizer.tokenize(example.text_a) 166 | 167 | tokens_b = None 168 | if example.text_b: 169 | tokens_b = tokenizer.tokenize(example.text_b) 170 | 171 | if tokens_b: 172 | # Modifies `tokens_a` and `tokens_b` in place so that the total 173 | # length is less than the specified length. 174 | # Account for [CLS], [SEP], [SEP] with "- 3" 175 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 176 | else: 177 | # Account for [CLS] and [SEP] with "- 2" 178 | if len(tokens_a) > seq_length - 2: 179 | tokens_a = tokens_a[0:(seq_length - 2)] 180 | 181 | # The convention in BERT is: 182 | # (a) For sequence pairs: 183 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 184 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 185 | # (b) For single sequences: 186 | # tokens: [CLS] the dog is hairy . [SEP] 187 | # type_ids: 0 0 0 0 0 0 0 188 | # 189 | # Where "type_ids" are used to indicate whether this is the first 190 | # sequence or the second sequence. The embedding vectors for `type=0` and 191 | # `type=1` were learned during pre-training and are added to the wordpiece 192 | # embedding vector (and position vector). This is not *strictly* necessary 193 | # since the [SEP] token unambiguously separates the sequences, but it makes 194 | # it easier for the model to learn the concept of sequences. 195 | # 196 | # For classification tasks, the first vector (corresponding to [CLS]) is 197 | # used as as the "sentence vector". Note that this only makes sense because 198 | # the entire model is fine-tuned. 
199 | tokens = [] 200 | input_type_ids = [] 201 | tokens.append("[CLS]") 202 | input_type_ids.append(0) 203 | for token in tokens_a: 204 | tokens.append(token) 205 | input_type_ids.append(0) 206 | tokens.append("[SEP]") 207 | input_type_ids.append(0) 208 | 209 | if tokens_b: 210 | for token in tokens_b: 211 | tokens.append(token) 212 | input_type_ids.append(1) 213 | tokens.append("[SEP]") 214 | input_type_ids.append(1) 215 | 216 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 217 | 218 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 219 | # tokens are attended to. 220 | input_mask = [1] * len(input_ids) 221 | 222 | # Zero-pad up to the sequence length. 223 | while len(input_ids) < seq_length: 224 | input_ids.append(0) 225 | input_mask.append(0) 226 | input_type_ids.append(0) 227 | 228 | assert len(input_ids) == seq_length 229 | assert len(input_mask) == seq_length 230 | assert len(input_type_ids) == seq_length 231 | 232 | if ex_index < 5: 233 | logging.info("*** Example ***") 234 | logging.info("unique_id: %s" % (example.unique_id)) 235 | logging.info("tokens: %s" % " ".join( 236 | [tokenization.printable_text(x) for x in tokens])) 237 | logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 238 | logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 239 | logging.info( 240 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 241 | 242 | features.append( 243 | InputFeatures( 244 | unique_id=example.unique_id, 245 | tokens=tokens, 246 | input_ids=input_ids, 247 | input_mask=input_mask, 248 | input_type_ids=input_type_ids)) 249 | return features 250 | 251 | 252 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 253 | """Truncates a sequence pair in place to the maximum length.""" 254 | 255 | # This is a simple heuristic which will always truncate the longer sequence 256 | # one token at a time. This makes more sense than truncating an equal percent 257 | # of tokens from each, since if one sequence is very short then each token 258 | # that's truncated likely contains more information than a longer sequence. 
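  # For example, with max_length=8 and token lists of lengths (6, 5), the loop
  # trims the longer list one token at a time until the lengths are (4, 4).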
259 | while True: 260 | total_length = len(tokens_a) + len(tokens_b) 261 | if total_length <= max_length: 262 | break 263 | if len(tokens_a) > len(tokens_b): 264 | tokens_a.pop() 265 | else: 266 | tokens_b.pop() 267 | 268 | 269 | def read_examples(input_file): 270 | """Read a list of `InputExample`s from an input file.""" 271 | examples = [] 272 | unique_id = 0 273 | with tf.io.gfile.GFile(input_file, "r") as reader: 274 | while True: 275 | line = reader.readline() 276 | if not line: 277 | break 278 | line = line.strip() 279 | text_a = None 280 | text_b = None 281 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 282 | if m is None: 283 | text_a = line 284 | else: 285 | text_a = m.group(1) 286 | text_b = m.group(2) 287 | examples.append( 288 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 289 | unique_id += 1 290 | return examples 291 | 292 | 293 | def from_json_file(bert_config_file): 294 | with tf.io.gfile.GFile(bert_config_file, "r") as reader: 295 | stock_params = StockBertConfig.from_json_string(reader.read()) 296 | bert_params = stock_params.to_bert_model_layer_params() 297 | return bert_params 298 | 299 | 300 | def main(_): 301 | logging.set_verbosity(logging.INFO) 302 | 303 | tokenizer = tokenization.FullTokenizer( 304 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 305 | 306 | examples = read_examples(FLAGS.input_file) 307 | 308 | features = convert_examples_to_features( 309 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 310 | 311 | unique_id_to_feature = {} 312 | for feature in features: 313 | unique_id_to_feature[feature.unique_id] = feature 314 | 315 | # Build model 316 | model = build_model(bert_config=FLAGS.bert_config_file, 317 | init_checkpoint=FLAGS.init_checkpoint, 318 | max_seq_len=FLAGS.max_seq_length) 319 | 320 | dataset = build_dataset( 321 | features=features, seq_length=FLAGS.max_seq_length, batch_size=FLAGS.batch_size) 322 | 323 | with tf.io.gfile.GFile(FLAGS.output_file, "w") as writer: 324 | for item in dataset: 325 | unique_ids = list(item["unique_ids"]) 326 | result = model.predict([item["input_ids"], item["input_type_ids"], item["input_mask"]]) 327 | 328 | for (i, unique_id) in enumerate(unique_ids): 329 | line_result = np.squeeze(result[i]) 330 | feature = unique_id_to_feature[unique_id.numpy()] 331 | output_json = collections.OrderedDict() 332 | output_json["line"] = int(unique_id.numpy()) 333 | all_features = [] 334 | for (j, token) in enumerate(feature.tokens): 335 | features = collections.OrderedDict() 336 | features["token"] = token 337 | features["feature"] = [ 338 | round(float(x), 6) for x in line_result[j].flat 339 | ] 340 | all_features.append(features) 341 | output_json["features"] = all_features 342 | writer.write(json.dumps(output_json, ensure_ascii=False) + "\n") 343 | 344 | 345 | if __name__ == "__main__": 346 | flags.mark_flag_as_required("input_file") 347 | flags.mark_flag_as_required("vocab_file") 348 | flags.mark_flag_as_required("bert_config_file") 349 | flags.mark_flag_as_required("init_checkpoint") 350 | flags.mark_flag_as_required("output_file") 351 | app.run(main) 352 | --------------------------------------------------------------------------------