├── .gitignore ├── LICENSE ├── README.md ├── The BQ Corpus.pdf ├── __init__.py ├── args.py ├── bert_vec.py ├── data ├── dev.csv ├── test.csv └── train.csv ├── extract_feature.py ├── graph.py ├── modeling.py ├── optimization.py ├── requirements.txt ├── similarity.py └── tokenization.py /.gitignore: -------------------------------------------------------------------------------- 1 | /chinese_L-12_H-768_A-12 2 | tmp/ 3 | __pycache__/ 4 | .idea/ 5 | data/data_merger.py 6 | data/data.py 7 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bert-utils 2 | 3 | This project further simplifies Google's open-source [BERT](https://github.com/google-research/bert) code, making it easy to generate sentence vectors and do text classification. 4 | 5 | --- 6 | 7 | ***** New July 1st, 2019 ***** 8 | + Changed how the sentence-vector `graph` file is generated to speed up start-up. Instead of writing a temporary file on every run, the first run of extract_feature.py creates `tmp/result/graph`, 9 | and later runs read that file directly. If the contents of `args.py` are changed, delete the `tmp/result/graph` file. 10 | + Fixed a bug that made the code crash when two processes generated sentence vectors at the same time. 11 | + Switched the text-matching dataset to QA_corpus, which is more authoritative than the Ant Financial data. 12 | 13 | --- 14 | 15 | 1. Download the Chinese BERT model 16 | 17 | Download link: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 18 | 19 | 2. Unpack the downloaded model into the current directory 20 | 21 | 3. Sentence vector generation 22 | 23 | Generating sentence vectors does not require fine-tuning; the pre-trained model is used as-is. See the `main` block of `extract_feature.py` for reference, and note that the argument must be a list. 24 | 25 | The first call has to load the graph and write a new graph file under output_dir, so it is relatively slow; subsequent calls are fast. 26 | ``` 27 | from extract_feature import BertVector 28 | bv = BertVector() 29 | bv.encode(['今天天气不错']) 30 | ``` 31 | 32 | 4. Text classification 33 | 34 | Text classification requires fine-tuning. First place the data in the `data` directory; the training set must be named `train.csv`, the validation set `dev.csv`, and the test set `test.csv`. 35 | `set_mode` must be called first; see the `main` block of `similarity.py` for reference. 36 | 37 | Training: 38 | ``` 39 | from similarity import BertSim 40 | import tensorflow as tf 41 | 42 | bs = BertSim() 43 | bs.set_mode(tf.estimator.ModeKeys.TRAIN) 44 | bs.train() 45 | ``` 46 | 47 | Evaluation: 48 | ``` 49 | from similarity import BertSim 50 | import tensorflow as tf 51 | 52 | bs = BertSim() 53 | bs.set_mode(tf.estimator.ModeKeys.EVAL) 54 | bs.eval() 55 | ``` 56 | 57 | Testing: 58 | ``` 59 | from similarity import BertSim 60 | import tensorflow as tf 61 | 62 | bs = BertSim() 63 | bs.set_mode(tf.estimator.ModeKeys.PREDICT) 64 | bs.test() 65 | ``` 66 | 67 | 5. The demo ships with the QA_corpus dataset (original source [here](http://icrc.hitsz.edu.cn/info/1037/1162.htm)); 68 | for how the data was constructed, see the paper `The BQ Corpus.pdf` included in this repository. -------------------------------------------------------------------------------- /The BQ Corpus.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/terrifyzhao/bert-utils/1d5f3eb649b4ee8a059f7050da483d0cd6d7fff4/The BQ Corpus.pdf -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 |
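As a quick check that the sentence vectors are usable, the sketch below encodes two sentences with `BertVector` and compares them with cosine similarity. It is only a minimal illustration, assuming the `chinese_L-12_H-768_A-12` model has already been unpacked in the project root as described in the README; the `cosine` helper is not part of this repository.

```python
import numpy as np
from extract_feature import BertVector


def cosine(a, b):
    # cosine similarity between two 1-D vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


bv = BertVector()
# encode() expects a list of sentences and returns one vector per sentence
# (768-dimensional with the default layer_indexes in args.py)
vecs = bv.encode(['今天天气不错', '今天天气很好'])
print(cosine(vecs[0], vecs[1]))
```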
-------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | tf.logging.set_verbosity(tf.logging.INFO) 5 | 6 | file_path = os.path.dirname(__file__) 7 | 8 | model_dir = os.path.join(file_path, 'chinese_L-12_H-768_A-12/') 9 | config_name = os.path.join(model_dir, 'bert_config.json') 10 | ckpt_name = os.path.join(model_dir, 'bert_model.ckpt') 11 | output_dir = os.path.join(model_dir, '../tmp/result/') 12 | vocab_file = os.path.join(model_dir, 'vocab.txt') 13 | data_dir = os.path.join(model_dir, '../data/') 14 | 15 | num_train_epochs = 10 16 | batch_size = 128 17 | learning_rate = 0.00005 18 | 19 | # fraction of GPU memory to use 20 | gpu_memory_fraction = 0.8 21 | 22 | # by default the output of the second-to-last layer is used as the sentence vector 23 | layer_indexes = [-2] 24 | 25 | # maximum sequence length; for single short texts it is recommended to lower this value 26 | max_seq_len = 5 27 | 28 | # path of the frozen graph file 29 | graph_file = 'tmp/result/graph' -------------------------------------------------------------------------------- /bert_vec.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from graph import set_logger 3 | from termcolor import colored 4 | 5 | logger = set_logger(colored('BERT_VEC', 'yellow')) 6 | bert_file_name = 'bert_data.pkl' 7 | 8 | 9 | class BertData: 10 | def __init__(self): 11 | self.dic = {} 12 | self._read_dic() 13 | 14 | # insert data in batch 15 | def add_batch_data(self, keys, values): 16 | for key, value in zip(keys, values): 17 | self.dic[key] = value 18 | 19 | # insert a single record 20 | def add_data(self, key, value): 21 | self.dic[key] = value 22 | 23 | # delete a record by key 24 | def delete_data(self, key): 25 | if self.dic and self.dic.get(key, ''): 26 | self.dic.pop(key) 27 | 28 | # get a record by key 29 | def get_data(self, key): 30 | return self.dic.get(key, '') 31 | 32 | # get all records 33 | def get_all_data(self): 34 | return self.dic 35 | 36 | # commit 37 | def commit(self): 38 | self._save_dic() 39 | 40 | def _save_dic(self): 41 | try: 42 | with open(bert_file_name, 'wb') as file: 43 | pickle.dump(self.dic, file) 44 | logger.info('bert data saved successfully') 45 | except Exception: 46 | logger.error('failed to save bert data') 47 | 48 | def _read_dic(self): 49 | try: 50 | with open(bert_file_name, 'rb') as file: 51 | self.dic = pickle.load(file) 52 | except FileNotFoundError: 53 | logger.info('no local bert data found') 54 | 55 | 56 | if __name__ == '__main__': 57 | bd = BertData() 58 | data = [] 59 | vec = [] 60 | import numpy as np 61 | 62 | for i in range(30000): 63 | data.append('阿迪和考虑就鞍山市会计法哈三联空间和福利卡就很烦' + str(i)) 64 | vec.append(np.random.rand(768)) 65 | bd.add_batch_data(data, vec) 66 | # inserts, deletes and updates only reach the local cache file after commit() is called; reads do not need it 67 | bd.commit() 68 | # bd.delete_data('上午好啊天气真的不错0') 69 | # res = bd.get_data('上午好啊天气真的不错1') 70 | # bd.add_data('上午好啊天气真的不错test', [1, 2, 3]) 71 | # res = bd.get_all_data() 72 | # print(res.keys()) 73 | # for i in res.items(): 74 | # print(i[0], ':', i[1]) 75 | # print(res.values()) 76 | -------------------------------------------------------------------------------- /extract_feature.py: -------------------------------------------------------------------------------- 1 | import modeling 2 | import tokenization 3 | from graph import optimize_graph 4 | import args 5 | from queue import Queue 6 | from threading import Thread 7 | import tensorflow as tf 8 | import os 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | 13 | class InputExample(object): 14 | 15 | def __init__(self, unique_id, text_a, text_b): 16 |
self.unique_id = unique_id 17 | self.text_a = text_a 18 | self.text_b = text_b 19 | 20 | 21 | class InputFeatures(object): 22 | """A single set of features of data.""" 23 | 24 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 25 | self.unique_id = unique_id 26 | self.tokens = tokens 27 | self.input_ids = input_ids 28 | self.input_mask = input_mask 29 | self.input_type_ids = input_type_ids 30 | 31 | 32 | class BertVector: 33 | 34 | def __init__(self, batch_size=32): 35 | """ 36 | init BertVector 37 | :param batch_size: Depending on your memory default is 32 38 | """ 39 | self.max_seq_length = args.max_seq_len 40 | self.layer_indexes = args.layer_indexes 41 | self.gpu_memory_fraction = 1 42 | if os.path.exists(args.graph_file): 43 | self.graph_path = args.graph_file 44 | else: 45 | self.graph_path = optimize_graph() 46 | 47 | self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 48 | self.batch_size = batch_size 49 | self.estimator = self.get_estimator() 50 | self.input_queue = Queue(maxsize=1) 51 | self.output_queue = Queue(maxsize=1) 52 | self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) 53 | self.predict_thread.start() 54 | 55 | def get_estimator(self): 56 | from tensorflow.python.estimator.estimator import Estimator 57 | from tensorflow.python.estimator.run_config import RunConfig 58 | from tensorflow.python.estimator.model_fn import EstimatorSpec 59 | 60 | def model_fn(features, labels, mode, params): 61 | with tf.gfile.GFile(self.graph_path, 'rb') as f: 62 | graph_def = tf.GraphDef() 63 | graph_def.ParseFromString(f.read()) 64 | 65 | input_names = ['input_ids', 'input_mask', 'input_type_ids'] 66 | 67 | output = tf.import_graph_def(graph_def, 68 | input_map={k + ':0': features[k] for k in input_names}, 69 | return_elements=['final_encodes:0']) 70 | 71 | return EstimatorSpec(mode=mode, predictions={ 72 | 'encodes': output[0] 73 | }) 74 | 75 | config = tf.ConfigProto() 76 | config.gpu_options.allow_growth = True 77 | config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction 78 | config.log_device_placement = False 79 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 80 | 81 | return Estimator(model_fn=model_fn, config=RunConfig(session_config=config), 82 | params={'batch_size': self.batch_size}, model_dir='../tmp') 83 | 84 | def predict_from_queue(self): 85 | prediction = self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False) 86 | for i in prediction: 87 | self.output_queue.put(i) 88 | 89 | def encode(self, sentence): 90 | self.input_queue.put(sentence) 91 | prediction = self.output_queue.get()['encodes'] 92 | return prediction 93 | 94 | def queue_predict_input_fn(self): 95 | 96 | return (tf.data.Dataset.from_generator( 97 | self.generate_from_queue, 98 | output_types={'unique_ids': tf.int32, 99 | 'input_ids': tf.int32, 100 | 'input_mask': tf.int32, 101 | 'input_type_ids': tf.int32}, 102 | output_shapes={ 103 | 'unique_ids': (None,), 104 | 'input_ids': (None, self.max_seq_length), 105 | 'input_mask': (None, self.max_seq_length), 106 | 'input_type_ids': (None, self.max_seq_length)}).prefetch(10)) 107 | 108 | def generate_from_queue(self): 109 | while True: 110 | features = list(self.convert_examples_to_features(seq_length=self.max_seq_length, tokenizer=self.tokenizer)) 111 | yield { 112 | 'unique_ids': [f.unique_id for f in features], 113 | 'input_ids': [f.input_ids for f in features], 114 | 'input_mask': 
[f.input_mask for f in features], 115 | 'input_type_ids': [f.input_type_ids for f in features] 116 | } 117 | 118 | def input_fn_builder(self, features, seq_length): 119 | """Creates an `input_fn` closure to be passed to Estimator.""" 120 | 121 | all_unique_ids = [] 122 | all_input_ids = [] 123 | all_input_mask = [] 124 | all_input_type_ids = [] 125 | 126 | for feature in features: 127 | all_unique_ids.append(feature.unique_id) 128 | all_input_ids.append(feature.input_ids) 129 | all_input_mask.append(feature.input_mask) 130 | all_input_type_ids.append(feature.input_type_ids) 131 | 132 | def input_fn(params): 133 | """The actual input function.""" 134 | batch_size = params["batch_size"] 135 | 136 | num_examples = len(features) 137 | 138 | # This is for demo purposes and does NOT scale to large data sets. We do 139 | # not use Dataset.from_generator() because that uses tf.py_func which is 140 | # not TPU compatible. The right way to load data is with TFRecordReader. 141 | d = tf.data.Dataset.from_tensor_slices({ 142 | "unique_ids": 143 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 144 | "input_ids": 145 | tf.constant( 146 | all_input_ids, shape=[num_examples, seq_length], 147 | dtype=tf.int32), 148 | "input_mask": 149 | tf.constant( 150 | all_input_mask, 151 | shape=[num_examples, seq_length], 152 | dtype=tf.int32), 153 | "input_type_ids": 154 | tf.constant( 155 | all_input_type_ids, 156 | shape=[num_examples, seq_length], 157 | dtype=tf.int32), 158 | }) 159 | 160 | d = d.batch(batch_size=batch_size, drop_remainder=False) 161 | return d 162 | 163 | return input_fn 164 | 165 | def model_fn_builder(self, bert_config, init_checkpoint, layer_indexes): 166 | """Returns `model_fn` closure for TPUEstimator.""" 167 | 168 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 169 | """The `model_fn` for TPUEstimator.""" 170 | 171 | unique_ids = features["unique_ids"] 172 | input_ids = features["input_ids"] 173 | input_mask = features["input_mask"] 174 | input_type_ids = features["input_type_ids"] 175 | 176 | jit_scope = tf.contrib.compiler.jit.experimental_jit_scope 177 | 178 | with jit_scope(): 179 | model = modeling.BertModel( 180 | config=bert_config, 181 | is_training=False, 182 | input_ids=input_ids, 183 | input_mask=input_mask, 184 | token_type_ids=input_type_ids) 185 | 186 | if mode != tf.estimator.ModeKeys.PREDICT: 187 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 188 | 189 | tvars = tf.trainable_variables() 190 | 191 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, 192 | init_checkpoint) 193 | 194 | tf.logging.info("**** Trainable Variables ****") 195 | for var in tvars: 196 | init_string = "" 197 | if var.name in initialized_variable_names: 198 | init_string = ", *INIT_FROM_CKPT*" 199 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 200 | init_string) 201 | 202 | all_layers = model.get_all_encoder_layers() 203 | 204 | predictions = { 205 | "unique_id": unique_ids, 206 | } 207 | 208 | for (i, layer_index) in enumerate(layer_indexes): 209 | predictions["layer_output_%d" % i] = all_layers[layer_index] 210 | 211 | from tensorflow.python.estimator.model_fn import EstimatorSpec 212 | 213 | output_spec = EstimatorSpec(mode=mode, predictions=predictions) 214 | return output_spec 215 | 216 | return model_fn 217 | 218 | def convert_examples_to_features(self, seq_length, tokenizer): 219 | """Loads a data file into a list of `InputBatch`s.""" 220 | 221 | features = 
[] 222 | input_masks = [] 223 | examples = self._to_example(self.input_queue.get()) 224 | for (ex_index, example) in enumerate(examples): 225 | tokens_a = tokenizer.tokenize(example.text_a) 226 | 227 | # if the sentences's length is more than seq_length, only use sentence's left part 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | # Where "input_ids" are tokens's index in vocabulary 260 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 261 | 262 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 263 | # tokens are attended to. 264 | input_mask = [1] * len(input_ids) 265 | input_masks.append(input_mask) 266 | # Zero-pad up to the sequence length. 267 | while len(input_ids) < seq_length: 268 | input_ids.append(0) 269 | input_mask.append(0) 270 | input_type_ids.append(0) 271 | 272 | assert len(input_ids) == seq_length 273 | assert len(input_mask) == seq_length 274 | assert len(input_type_ids) == seq_length 275 | 276 | if ex_index < 5: 277 | tf.logging.info("*** Example ***") 278 | tf.logging.info("unique_id: %s" % (example.unique_id)) 279 | tf.logging.info("tokens: %s" % " ".join( 280 | [tokenization.printable_text(x) for x in tokens])) 281 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 282 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 283 | tf.logging.info( 284 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 285 | 286 | yield InputFeatures( 287 | unique_id=example.unique_id, 288 | tokens=tokens, 289 | input_ids=input_ids, 290 | input_mask=input_mask, 291 | input_type_ids=input_type_ids) 292 | 293 | def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): 294 | """Truncates a sequence pair in place to the maximum length.""" 295 | 296 | # This is a simple heuristic which will always truncate the longer sequence 297 | # one token at a time. This makes more sense than truncating an equal percent 298 | # of tokens from each, since if one sequence is very short then each token 299 | # that's truncated likely contains more information than a longer sequence. 
300 | while True: 301 | total_length = len(tokens_a) + len(tokens_b) 302 | if total_length <= max_length: 303 | break 304 | if len(tokens_a) > len(tokens_b): 305 | tokens_a.pop() 306 | else: 307 | tokens_b.pop() 308 | 309 | @staticmethod 310 | def _to_example(sentences): 311 | import re 312 | """ 313 | sentences to InputExample 314 | :param sentences: list of strings 315 | :return: list of InputExample 316 | """ 317 | unique_id = 0 318 | for ss in sentences: 319 | line = tokenization.convert_to_unicode(ss) 320 | if not line: 321 | continue 322 | line = line.strip() 323 | text_a = None 324 | text_b = None 325 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 326 | if m is None: 327 | text_a = line 328 | else: 329 | text_a = m.group(1) 330 | text_b = m.group(2) 331 | yield InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b) 332 | unique_id += 1 333 | 334 | 335 | if __name__ == "__main__": 336 | bert = BertVector() 337 | 338 | while True: 339 | question = input('question: ') 340 | v = bert.encode([question]) 341 | print(str(v)) 342 | -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from termcolor import colored 4 | import modeling 5 | import args 6 | import tensorflow as tf 7 | import os 8 | 9 | 10 | def set_logger(context, verbose=False): 11 | logger = logging.getLogger(context) 12 | logger.setLevel(logging.DEBUG if verbose else logging.INFO) 13 | formatter = logging.Formatter( 14 | '%(levelname)-.1s:' + context + ':[%(filename).5s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt= 15 | '%m-%d %H:%M:%S') 16 | console_handler = logging.StreamHandler() 17 | console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) 18 | console_handler.setFormatter(formatter) 19 | logger.handlers = [] 20 | logger.addHandler(console_handler) 21 | return logger 22 | 23 | 24 | def optimize_graph(logger=None, verbose=False): 25 | if not logger: 26 | logger = set_logger(colored('BERT_VEC', 'yellow'), verbose) 27 | try: 28 | # we don't need GPU for optimizing the graph 29 | from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference 30 | tf.gfile.MakeDirs(args.output_dir) 31 | 32 | config_fp = args.config_name 33 | logger.info('model config: %s' % config_fp) 34 | 35 | # load the BERT config file 36 | with tf.gfile.GFile(config_fp, 'r') as f: 37 | bert_config = modeling.BertConfig.from_dict(json.load(f)) 38 | 39 | logger.info('build graph...') 40 | # input placeholders, not sure if they are friendly to XLA 41 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids') 42 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask') 43 | input_type_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_type_ids') 44 | 45 | jit_scope = tf.contrib.compiler.jit.experimental_jit_scope 46 | 47 | with jit_scope(): 48 | input_tensors = [input_ids, input_mask, input_type_ids] 49 | 50 | model = modeling.BertModel( 51 | config=bert_config, 52 | is_training=False, 53 | input_ids=input_ids, 54 | input_mask=input_mask, 55 | token_type_ids=input_type_ids, 56 | use_one_hot_embeddings=False) 57 | 58 | # collect all trainable variables so they can be initialised from the checkpoint 59 | tvars = tf.trainable_variables() 60 | 61 | init_checkpoint = args.ckpt_name 62 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, 63 | init_checkpoint) 64 | 65 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 66 | 67 | # pool the token vectors into one fixed-size sentence vector 68 | with tf.variable_scope("pooling"): 69 | # if only one layer is requested, take the output of that layer 70 | if len(args.layer_indexes) == 1: 71 | encoder_layer = model.all_encoder_layers[args.layer_indexes[0]] 72 | else: 73 | # otherwise gather all requested layers and concatenate them along the last axis (shape: 768 * number of layers) 74 | all_layers = [model.all_encoder_layers[l] for l in args.layer_indexes] 75 | encoder_layer = tf.concat(all_layers, -1) 76 | 77 | mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1) 78 | masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / ( 79 | tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10) 80 | 81 | input_mask = tf.cast(input_mask, tf.float32) 82 | # the sentence vector is a masked mean: token vectors at padded positions are zeroed via input_mask, summed over the sequence and divided by the number of real tokens 83 | pooled = masked_reduce_mean(encoder_layer, input_mask) 84 | pooled = tf.identity(pooled, 'final_encodes') 85 | 86 | output_tensors = [pooled] 87 | tmp_g = tf.get_default_graph().as_graph_def() 88 | 89 | # allow_soft_placement: let TensorFlow pick an available device automatically 90 | config = tf.ConfigProto(allow_soft_placement=True) 91 | with tf.Session(config=config) as sess: 92 | logger.info('load parameters from checkpoint...') 93 | sess.run(tf.global_variables_initializer()) 94 | logger.info('freeze...') 95 | tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors]) 96 | dtypes = [n.dtype for n in input_tensors] 97 | logger.info('optimize...') 98 | tmp_g = optimize_for_inference( 99 | tmp_g, 100 | [n.name[:-2] for n in input_tensors], 101 | [n.name[:-2] for n in output_tensors], 102 | [dtype.as_datatype_enum for dtype in dtypes], 103 | False) 104 | # tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name 105 | tmp_file = args.graph_file 106 | logger.info('write graph to file: %s' % tmp_file) 107 | with tf.gfile.GFile(tmp_file, 'wb') as f: 108 | f.write(tmp_g.SerializeToString()) 109 | return tmp_file 110 | except Exception as e: 111 | logger.error('fail to optimize the graph!') 112 | logger.error(e) 113 |
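To make the pooling step in `optimize_graph` concrete, the same masked mean can be written in plain numpy: token vectors at padded positions are zeroed by the mask, summed over the sequence, and divided by the number of real tokens. This is only an illustrative sketch with made-up shapes, not code used anywhere in the repository.

```python
import numpy as np

batch, seq_len, hidden = 2, 5, 768
encoder_layer = np.random.rand(batch, seq_len, hidden)  # token vectors from the chosen BERT layer
input_mask = np.array([[1, 1, 1, 0, 0],                  # first sentence: 3 real tokens, 2 padding
                       [1, 1, 1, 1, 1]], dtype=np.float32)

# zero out padded positions, then average over the real tokens only
masked = encoder_layer * input_mask[:, :, None]
pooled = masked.sum(axis=1) / (input_mask.sum(axis=1, keepdims=True) + 1e-10)
print(pooled.shape)  # (2, 768) -- one sentence vector per input, like the 'final_encodes' tensor
```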
-------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 
108 | 109 | Example usage: 110 | 111 | ```python 112 | # Already been converted into WordPiece token ids 113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 116 | 117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 119 | 120 | model = modeling.BertModel(config=config, is_training=True, 121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 122 | 123 | label_embeddings = tf.get_variable(...) 124 | pooled_output = model.get_pooled_output() 125 | logits = tf.matmul(pooled_output, label_embeddings) 126 | ... 127 | ``` 128 | """ 129 | 130 | def __init__(self, 131 | config, 132 | is_training, 133 | input_ids, 134 | input_mask=None, 135 | token_type_ids=None, 136 | use_one_hot_embeddings=True, 137 | scope=None): 138 | """Constructor for BertModel. 139 | 140 | Args: 141 | config: `BertConfig` instance. 142 | is_training: bool. rue for training model, false for eval model. Controls 143 | whether dropout will be applied. 144 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 149 | it is must faster if this is True, on the CPU or GPU, it is faster if 150 | this is False. 151 | scope: (optional) variable scope. Defaults to "bert". 152 | 153 | Raises: 154 | ValueError: The config is invalid or one of the input tensor shapes 155 | is invalid. 156 | """ 157 | config = copy.deepcopy(config) 158 | if not is_training: 159 | config.hidden_dropout_prob = 0.0 160 | config.attention_probs_dropout_prob = 0.0 161 | 162 | input_shape = get_shape_list(input_ids, expected_rank=2) 163 | batch_size = input_shape[0] 164 | seq_length = input_shape[1] 165 | 166 | if input_mask is None: 167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 168 | 169 | if token_type_ids is None: 170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 171 | 172 | with tf.variable_scope(scope, default_name="bert"): 173 | with tf.variable_scope("embeddings"): 174 | # Perform embedding lookup on the word ids. 175 | (self.embedding_output, self.embedding_table) = embedding_lookup( 176 | input_ids=input_ids, 177 | vocab_size=config.vocab_size, 178 | embedding_size=config.hidden_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 
185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob) 196 | 197 | with tf.variable_scope("encoder"): 198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 199 | # mask of shape [batch_size, seq_length, seq_length] which is used 200 | # for the attention scores. 201 | attention_mask = create_attention_mask_from_input_mask( 202 | input_ids, input_mask) 203 | 204 | # Run the stacked transformer. 205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 206 | self.all_encoder_layers = transformer_model( 207 | input_tensor=self.embedding_output, 208 | attention_mask=attention_mask, 209 | hidden_size=config.hidden_size, 210 | num_hidden_layers=config.num_hidden_layers, 211 | num_attention_heads=config.num_attention_heads, 212 | intermediate_size=config.intermediate_size, 213 | intermediate_act_fn=get_activation(config.hidden_act), 214 | hidden_dropout_prob=config.hidden_dropout_prob, 215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 216 | initializer_range=config.initializer_range, 217 | do_return_all_layers=True) 218 | 219 | self.sequence_output = self.all_encoder_layers[-1] 220 | # The "pooler" converts the encoded sequence tensor of shape 221 | # [batch_size, seq_length, hidden_size] to a tensor of shape 222 | # [batch_size, hidden_size]. This is necessary for segment-level 223 | # (or segment-pair-level) classification tasks where we need a fixed 224 | # dimensional representation of the segment. 225 | with tf.variable_scope("pooler"): 226 | # We "pool" the model by simply taking the hidden state corresponding 227 | # to the first token. We assume that this has been pre-trained 228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 229 | self.pooled_output = tf.layers.dense( 230 | first_token_tensor, 231 | config.hidden_size, 232 | activation=tf.tanh, 233 | kernel_initializer=create_initializer(config.initializer_range)) 234 | 235 | def get_pooled_output(self): 236 | return self.pooled_output 237 | 238 | def get_sequence_output(self): 239 | """Gets final hidden layer of encoder. 240 | 241 | Returns: 242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 243 | to the final hidden of the transformer encoder. 244 | """ 245 | return self.sequence_output 246 | 247 | def get_all_encoder_layers(self): 248 | return self.all_encoder_layers 249 | 250 | def get_embedding_output(self): 251 | """Gets output of the embedding lookup (i.e., input to the transformer). 252 | 253 | Returns: 254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 255 | to the output of the embedding layer, after summing the word 256 | embeddings with the positional embeddings and the token type embeddings, 257 | then performing layer normalization. This is the input to the transformer. 258 | """ 259 | return self.embedding_output 260 | 261 | def get_embedding_table(self): 262 | return self.embedding_table 263 | 264 | 265 | def gelu(input_tensor): 266 | """Gaussian Error Linear Unit. 267 | 268 | This is a smoother version of the RELU. 
269 | Original paper: https://arxiv.org/abs/1606.08415 270 | 271 | Args: 272 | input_tensor: float Tensor to perform activation. 273 | 274 | Returns: 275 | `input_tensor` with the GELU activation applied. 276 | """ 277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 278 | return input_tensor * cdf 279 | 280 | 281 | def get_activation(activation_string): 282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 283 | 284 | Args: 285 | activation_string: String name of the activation function. 286 | 287 | Returns: 288 | A Python function corresponding to the activation function. If 289 | `activation_string` is None, empty, or "linear", this will return None. 290 | If `activation_string` is not a string, it will return `activation_string`. 291 | 292 | Raises: 293 | ValueError: The `activation_string` does not correspond to a known 294 | activation. 295 | """ 296 | 297 | # We assume that anything that"s not a string is already an activation 298 | # function, so we just return it. 299 | if not isinstance(activation_string, six.string_types): 300 | return activation_string 301 | 302 | if not activation_string: 303 | return None 304 | 305 | act = activation_string.lower() 306 | if act == "linear": 307 | return None 308 | elif act == "relu": 309 | return tf.nn.relu 310 | elif act == "gelu": 311 | return gelu 312 | elif act == "tanh": 313 | return tf.tanh 314 | else: 315 | raise ValueError("Unsupported activation: %s" % act) 316 | 317 | 318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 319 | """Compute the union of the current variables and checkpoint variables.""" 320 | assignment_map = {} 321 | initialized_variable_names = {} 322 | 323 | name_to_variable = collections.OrderedDict() 324 | for var in tvars: 325 | name = var.name 326 | m = re.match("^(.*):\\d+$", name) 327 | if m is not None: 328 | name = m.group(1) 329 | name_to_variable[name] = var 330 | 331 | init_vars = tf.train.list_variables(init_checkpoint) 332 | 333 | assignment_map = collections.OrderedDict() 334 | for x in init_vars: 335 | (name, var) = (x[0], x[1]) 336 | if name not in name_to_variable: 337 | continue 338 | assignment_map[name] = name 339 | initialized_variable_names[name] = 1 340 | initialized_variable_names[name + ":0"] = 1 341 | 342 | return (assignment_map, initialized_variable_names) 343 | 344 | 345 | def dropout(input_tensor, dropout_prob): 346 | """Perform dropout. 347 | 348 | Args: 349 | input_tensor: float Tensor. 350 | dropout_prob: Python float. The probability of dropping out a value (NOT of 351 | *keeping* a dimension as in `tf.nn.dropout`). 352 | 353 | Returns: 354 | A version of `input_tensor` with dropout applied. 
355 | """ 356 | if dropout_prob is None or dropout_prob == 0.0: 357 | return input_tensor 358 | 359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 360 | return output 361 | 362 | 363 | def layer_norm(input_tensor, name=None): 364 | """Run layer normalization on the last dimension of the tensor.""" 365 | return tf.contrib.layers.layer_norm( 366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 367 | 368 | 369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 370 | """Runs layer normalization followed by dropout.""" 371 | output_tensor = layer_norm(input_tensor, name) 372 | output_tensor = dropout(output_tensor, dropout_prob) 373 | return output_tensor 374 | 375 | 376 | def create_initializer(initializer_range=0.02): 377 | """Creates a `truncated_normal_initializer` with the given range.""" 378 | return tf.truncated_normal_initializer(stddev=initializer_range) 379 | 380 | 381 | def embedding_lookup(input_ids, 382 | vocab_size, 383 | embedding_size=128, 384 | initializer_range=0.02, 385 | word_embedding_name="word_embeddings", 386 | use_one_hot_embeddings=False): 387 | """Looks up words embeddings for id tensor. 388 | 389 | Args: 390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 391 | ids. 392 | vocab_size: int. Size of the embedding vocabulary. 393 | embedding_size: int. Width of the word embeddings. 394 | initializer_range: float. Embedding initialization range. 395 | word_embedding_name: string. Name of the embedding table. 396 | use_one_hot_embeddings: bool. If True, use one-hot method for word 397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 398 | for TPUs. 399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs]. 405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | if use_one_hot_embeddings: 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 419 | output = tf.matmul(one_hot_input_ids, embedding_table) 420 | else: 421 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 422 | 423 | input_shape = get_shape_list(input_ids) 424 | 425 | output = tf.reshape(output, 426 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 427 | return (output, embedding_table) 428 | 429 | 430 | def embedding_postprocessor(input_tensor, 431 | use_token_type=False, 432 | token_type_ids=None, 433 | token_type_vocab_size=16, 434 | token_type_embedding_name="token_type_embeddings", 435 | use_position_embeddings=True, 436 | position_embedding_name="position_embeddings", 437 | initializer_range=0.02, 438 | max_position_embeddings=512, 439 | dropout_prob=0.1): 440 | """Performs various post-processing on a word embedding tensor. 441 | 442 | Args: 443 | input_tensor: float Tensor of shape [batch_size, seq_length, 444 | embedding_size]. 445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 
447 | Must be specified if `use_token_type` is True. 448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 449 | token_type_embedding_name: string. The name of the embedding table variable 450 | for token type ids. 451 | use_position_embeddings: bool. Whether to add position embeddings for the 452 | position of each token in the sequence. 453 | position_embedding_name: string. The name of the embedding table variable 454 | for positional embeddings. 455 | initializer_range: float. Range of the weight initialization. 456 | max_position_embeddings: int. Maximum sequence length that might ever be 457 | used with this model. This can be longer than the sequence length of 458 | input_tensor, but cannot be shorter. 459 | dropout_prob: float. Dropout probability applied to the final output tensor. 460 | 461 | Returns: 462 | float tensor with same shape as `input_tensor`. 463 | 464 | Raises: 465 | ValueError: One of the tensor shapes or input values is invalid. 466 | """ 467 | input_shape = get_shape_list(input_tensor, expected_rank=3) 468 | batch_size = input_shape[0] 469 | seq_length = input_shape[1] 470 | width = input_shape[2] 471 | 472 | output = input_tensor 473 | 474 | if use_token_type: 475 | if token_type_ids is None: 476 | raise ValueError("`token_type_ids` must be specified if" 477 | "`use_token_type` is True.") 478 | token_type_table = tf.get_variable( 479 | name=token_type_embedding_name, 480 | shape=[token_type_vocab_size, width], 481 | initializer=create_initializer(initializer_range)) 482 | # This vocab will be small so we always do one-hot here, since it is always 483 | # faster for a small vocabulary. 484 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 485 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 486 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 487 | token_type_embeddings = tf.reshape(token_type_embeddings, 488 | [batch_size, seq_length, width]) 489 | output += token_type_embeddings 490 | 491 | if use_position_embeddings: 492 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 493 | with tf.control_dependencies([assert_op]): 494 | full_position_embeddings = tf.get_variable( 495 | name=position_embedding_name, 496 | shape=[max_position_embeddings, width], 497 | initializer=create_initializer(initializer_range)) 498 | # Since the position embedding table is a learned variable, we create it 499 | # using a (long) sequence length `max_position_embeddings`. The actual 500 | # sequence length might be shorter than this, for faster training of 501 | # tasks that do not have long sequences. 502 | # 503 | # So `full_position_embeddings` is effectively an embedding table 504 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 505 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 506 | # perform a slice. 507 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 508 | [seq_length, -1]) 509 | num_dims = len(output.shape.as_list()) 510 | 511 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 512 | # we broadcast among the first dimensions, which is typically just 513 | # the batch size. 
514 | position_broadcast_shape = [] 515 | for _ in range(num_dims - 2): 516 | position_broadcast_shape.append(1) 517 | position_broadcast_shape.extend([seq_length, width]) 518 | position_embeddings = tf.reshape(position_embeddings, 519 | position_broadcast_shape) 520 | output += position_embeddings 521 | 522 | output = layer_norm_and_dropout(output, dropout_prob) 523 | return output 524 | 525 | 526 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 527 | """Create 3D attention mask from a 2D tensor mask. 528 | 529 | Args: 530 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 531 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 532 | 533 | Returns: 534 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 535 | """ 536 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 537 | batch_size = from_shape[0] 538 | from_seq_length = from_shape[1] 539 | 540 | to_shape = get_shape_list(to_mask, expected_rank=2) 541 | to_seq_length = to_shape[1] 542 | 543 | to_mask = tf.cast( 544 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 545 | 546 | # We don't assume that `from_tensor` is a mask (although it could be). We 547 | # don't actually care if we attend *from* padding tokens (only *to* padding) 548 | # tokens so we create a tensor of all ones. 549 | # 550 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 551 | broadcast_ones = tf.ones( 552 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 553 | 554 | # Here we broadcast along two dimensions to create the mask. 555 | mask = broadcast_ones * to_mask 556 | 557 | return mask 558 | 559 | 560 | def attention_layer(from_tensor, 561 | to_tensor, 562 | attention_mask=None, 563 | num_attention_heads=1, 564 | size_per_head=512, 565 | query_act=None, 566 | key_act=None, 567 | value_act=None, 568 | attention_probs_dropout_prob=0.0, 569 | initializer_range=0.02, 570 | do_return_2d_tensor=False, 571 | batch_size=None, 572 | from_seq_length=None, 573 | to_seq_length=None): 574 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 575 | 576 | This is an implementation of multi-headed attention based on "Attention 577 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 578 | this is self-attention. Each timestep in `from_tensor` attends to the 579 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 580 | 581 | This function first projects `from_tensor` into a "query" tensor and 582 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 583 | of tensors of length `num_attention_heads`, where each tensor is of shape 584 | [batch_size, seq_length, size_per_head]. 585 | 586 | Then, the query and key tensors are dot-producted and scaled. These are 587 | softmaxed to obtain attention probabilities. The value tensors are then 588 | interpolated by these probabilities, then concatenated back to a single 589 | tensor and returned. 590 | 591 | In practice, the multi-headed attention are done with transposes and 592 | reshapes rather than actual separate tensors. 593 | 594 | Args: 595 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 596 | from_width]. 597 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 598 | attention_mask: (optional) int32 Tensor of shape [batch_size, 599 | from_seq_length, to_seq_length]. The values should be 1 or 0. 
The 600 | attention scores will effectively be set to -infinity for any positions in 601 | the mask that are 0, and will be unchanged for positions that are 1. 602 | num_attention_heads: int. Number of attention heads. 603 | size_per_head: int. Size of each attention head. 604 | query_act: (optional) Activation function for the query transform. 605 | key_act: (optional) Activation function for the key transform. 606 | value_act: (optional) Activation function for the value transform. 607 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 608 | attention probabilities. 609 | initializer_range: float. Range of the weight initializer. 610 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 611 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 612 | output will be of shape [batch_size, from_seq_length, num_attention_heads 613 | * size_per_head]. 614 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 615 | of the 3D version of the `from_tensor` and `to_tensor`. 616 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `from_tensor`. 618 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 619 | of the 3D version of the `to_tensor`. 620 | 621 | Returns: 622 | float Tensor of shape [batch_size, from_seq_length, 623 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 624 | true, this will be of shape [batch_size * from_seq_length, 625 | num_attention_heads * size_per_head]). 626 | 627 | Raises: 628 | ValueError: Any of the arguments or tensor shapes are invalid. 629 | """ 630 | 631 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 632 | seq_length, width): 633 | output_tensor = tf.reshape( 634 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 635 | 636 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 637 | return output_tensor 638 | 639 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 640 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 641 | 642 | if len(from_shape) != len(to_shape): 643 | raise ValueError( 644 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 645 | 646 | if len(from_shape) == 3: 647 | batch_size = from_shape[0] 648 | from_seq_length = from_shape[1] 649 | to_seq_length = to_shape[1] 650 | elif len(from_shape) == 2: 651 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 652 | raise ValueError( 653 | "When passing in rank 2 tensors to attention_layer, the values " 654 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 655 | "must all be specified.") 656 | 657 | # Scalar dimensions referenced here: 658 | # B = batch size (number of sequences) 659 | # F = `from_tensor` sequence length 660 | # T = `to_tensor` sequence length 661 | # N = `num_attention_heads` 662 | # H = `size_per_head` 663 | 664 | from_tensor_2d = reshape_to_matrix(from_tensor) 665 | to_tensor_2d = reshape_to_matrix(to_tensor) 666 | 667 | # `query_layer` = [B*F, N*H] 668 | query_layer = tf.layers.dense( 669 | from_tensor_2d, 670 | num_attention_heads * size_per_head, 671 | activation=query_act, 672 | name="query", 673 | kernel_initializer=create_initializer(initializer_range)) 674 | 675 | # `key_layer` = [B*T, N*H] 676 | key_layer = tf.layers.dense( 677 | to_tensor_2d, 678 | num_attention_heads * size_per_head, 679 | activation=key_act, 680 | name="key", 681 | 
kernel_initializer=create_initializer(initializer_range)) 682 | 683 | # `value_layer` = [B*T, N*H] 684 | value_layer = tf.layers.dense( 685 | to_tensor_2d, 686 | num_attention_heads * size_per_head, 687 | activation=value_act, 688 | name="value", 689 | kernel_initializer=create_initializer(initializer_range)) 690 | 691 | # `query_layer` = [B, N, F, H] 692 | query_layer = transpose_for_scores(query_layer, batch_size, 693 | num_attention_heads, from_seq_length, 694 | size_per_head) 695 | 696 | # `key_layer` = [B, N, T, H] 697 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 698 | to_seq_length, size_per_head) 699 | 700 | # Take the dot product between "query" and "key" to get the raw 701 | # attention scores. 702 | # `attention_scores` = [B, N, F, T] 703 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 704 | attention_scores = tf.multiply(attention_scores, 705 | 1.0 / math.sqrt(float(size_per_head))) 706 | 707 | if attention_mask is not None: 708 | # `attention_mask` = [B, 1, F, T] 709 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 710 | 711 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 712 | # masked positions, this operation will create a tensor which is 0.0 for 713 | # positions we want to attend and -10000.0 for masked positions. 714 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 715 | 716 | # Since we are adding it to the raw scores before the softmax, this is 717 | # effectively the same as removing these entirely. 718 | attention_scores += adder 719 | 720 | # Normalize the attention scores to probabilities. 721 | # `attention_probs` = [B, N, F, T] 722 | attention_probs = tf.nn.softmax(attention_scores) 723 | 724 | # This is actually dropping out entire tokens to attend to, which might 725 | # seem a bit unusual, but is taken from the original Transformer paper. 726 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 727 | 728 | # `value_layer` = [B, T, N, H] 729 | value_layer = tf.reshape( 730 | value_layer, 731 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 732 | 733 | # `value_layer` = [B, N, T, H] 734 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 735 | 736 | # `context_layer` = [B, N, F, H] 737 | context_layer = tf.matmul(attention_probs, value_layer) 738 | 739 | # `context_layer` = [B, F, N, H] 740 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 741 | 742 | if do_return_2d_tensor: 743 | # `context_layer` = [B*F, N*V] 744 | context_layer = tf.reshape( 745 | context_layer, 746 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 747 | else: 748 | # `context_layer` = [B, F, N*V] 749 | context_layer = tf.reshape( 750 | context_layer, 751 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 752 | 753 | return context_layer 754 | 755 | 756 | def transformer_model(input_tensor, 757 | attention_mask=None, 758 | hidden_size=768, 759 | num_hidden_layers=12, 760 | num_attention_heads=12, 761 | intermediate_size=3072, 762 | intermediate_act_fn=gelu, 763 | hidden_dropout_prob=0.1, 764 | attention_probs_dropout_prob=0.1, 765 | initializer_range=0.02, 766 | do_return_all_layers=False): 767 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 768 | 769 | This is almost an exact implementation of the original Transformer encoder. 
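# --- Illustrative sketch (not part of modeling.py): the per-head computation
# in attention_layer above, restated in NumPy with made-up shapes: scaled
# dot-product scores, the additive -10000.0 mask trick, softmax, then a
# weighted sum over the value vectors.
import numpy as np

def scaled_dot_product_attention(q, k, v, mask):
    # q: [F, H]; k, v: [T, H]; mask: [F, T] with 1 = attend, 0 = padding
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    scores += (1.0 - mask) * -10000.0                    # masked scores -> effectively -inf
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)           # softmax over the "to" axis
    return probs @ v                                     # [F, H]

q, k, v = np.random.randn(3, 4), np.random.randn(5, 4), np.random.randn(5, 4)
mask = np.ones((3, 5)); mask[:, 3:] = 0.0                # last two "to" positions are padding
assert scaled_dot_product_attention(q, k, v, mask).shape == (3, 4)
# --- end sketch ---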
770 | 771 | See the original paper: 772 | https://arxiv.org/abs/1706.03762 773 | 774 | Also see: 775 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 776 | 777 | Args: 778 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 779 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 780 | seq_length], with 1 for positions that can be attended to and 0 in 781 | positions that should not be. 782 | hidden_size: int. Hidden size of the Transformer. 783 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 784 | num_attention_heads: int. Number of attention heads in the Transformer. 785 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 786 | forward) layer. 787 | intermediate_act_fn: function. The non-linear activation function to apply 788 | to the output of the intermediate/feed-forward layer. 789 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 790 | attention_probs_dropout_prob: float. Dropout probability of the attention 791 | probabilities. 792 | initializer_range: float. Range of the initializer (stddev of truncated 793 | normal). 794 | do_return_all_layers: Whether to also return all layers or just the final 795 | layer. 796 | 797 | Returns: 798 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 799 | hidden layer of the Transformer. 800 | 801 | Raises: 802 | ValueError: A Tensor shape or parameter is invalid. 803 | """ 804 | if hidden_size % num_attention_heads != 0: 805 | raise ValueError( 806 | "The hidden size (%d) is not a multiple of the number of attention " 807 | "heads (%d)" % (hidden_size, num_attention_heads)) 808 | 809 | attention_head_size = int(hidden_size / num_attention_heads) 810 | input_shape = get_shape_list(input_tensor, expected_rank=3) 811 | batch_size = input_shape[0] 812 | seq_length = input_shape[1] 813 | input_width = input_shape[2] 814 | 815 | # The Transformer performs sum residuals on all layers so the input needs 816 | # to be the same as the hidden size. 817 | if input_width != hidden_size: 818 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 819 | (input_width, hidden_size)) 820 | 821 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 822 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 823 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 824 | # help the optimizer. 
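# --- Illustrative sketch (not part of modeling.py): the "keep everything 2D"
# trick described above just flattens [batch, seq, width] to [batch*seq, width]
# once and restores the 3D shape at the end. NumPy round trip with made-up shapes:
import numpy as np

x = np.random.randn(2, 3, 4)          # [batch_size, seq_length, hidden_size]
flat = x.reshape(-1, x.shape[-1])     # [batch_size * seq_length, hidden_size]
restored = flat.reshape(x.shape)      # back to 3D, as reshape_from_matrix does
assert np.array_equal(x, restored)
# --- end sketch ---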
825 | prev_output = reshape_to_matrix(input_tensor) 826 | 827 | all_layer_outputs = [] 828 | for layer_idx in range(num_hidden_layers): 829 | with tf.variable_scope("layer_%d" % layer_idx): 830 | layer_input = prev_output 831 | 832 | with tf.variable_scope("attention"): 833 | attention_heads = [] 834 | with tf.variable_scope("self"): 835 | attention_head = attention_layer( 836 | from_tensor=layer_input, 837 | to_tensor=layer_input, 838 | attention_mask=attention_mask, 839 | num_attention_heads=num_attention_heads, 840 | size_per_head=attention_head_size, 841 | attention_probs_dropout_prob=attention_probs_dropout_prob, 842 | initializer_range=initializer_range, 843 | do_return_2d_tensor=True, 844 | batch_size=batch_size, 845 | from_seq_length=seq_length, 846 | to_seq_length=seq_length) 847 | attention_heads.append(attention_head) 848 | 849 | attention_output = None 850 | if len(attention_heads) == 1: 851 | attention_output = attention_heads[0] 852 | else: 853 | # In the case where we have other sequences, we just concatenate 854 | # them to the self-attention head before the projection. 855 | attention_output = tf.concat(attention_heads, axis=-1) 856 | 857 | # Run a linear projection of `hidden_size` then add a residual 858 | # with `layer_input`. 859 | with tf.variable_scope("output"): 860 | attention_output = tf.layers.dense( 861 | attention_output, 862 | hidden_size, 863 | kernel_initializer=create_initializer(initializer_range)) 864 | attention_output = dropout(attention_output, hidden_dropout_prob) 865 | attention_output = layer_norm(attention_output + layer_input) 866 | 867 | # The activation is only applied to the "intermediate" hidden layer. 868 | with tf.variable_scope("intermediate"): 869 | intermediate_output = tf.layers.dense( 870 | attention_output, 871 | intermediate_size, 872 | activation=intermediate_act_fn, 873 | kernel_initializer=create_initializer(initializer_range)) 874 | 875 | # Down-project back to `hidden_size` then add the residual. 876 | with tf.variable_scope("output"): 877 | layer_output = tf.layers.dense( 878 | intermediate_output, 879 | hidden_size, 880 | kernel_initializer=create_initializer(initializer_range)) 881 | layer_output = dropout(layer_output, hidden_dropout_prob) 882 | layer_output = layer_norm(layer_output + attention_output) 883 | prev_output = layer_output 884 | all_layer_outputs.append(layer_output) 885 | 886 | if do_return_all_layers: 887 | final_outputs = [] 888 | for layer_output in all_layer_outputs: 889 | final_output = reshape_from_matrix(layer_output, input_shape) 890 | final_outputs.append(final_output) 891 | return final_outputs 892 | else: 893 | final_output = reshape_from_matrix(prev_output, input_shape) 894 | return final_output 895 | 896 | 897 | def get_shape_list(tensor, expected_rank=None, name=None): 898 | """Returns a list of the shape of tensor, preferring static dimensions. 899 | 900 | Args: 901 | tensor: A tf.Tensor object to find the shape of. 902 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 903 | specified and the `tensor` has a different rank, and exception will be 904 | thrown. 905 | name: Optional name of the tensor for the error message. 906 | 907 | Returns: 908 | A list of dimensions of the shape of tensor. All static dimensions will 909 | be returned as python integers, and dynamic dimensions will be returned 910 | as tf.Tensor scalars. 
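# --- Illustrative sketch (not part of modeling.py): each encoder block above
# applies the same post-layer-norm residual pattern twice, once around the
# attention projection and once around the feed-forward projection:
# out = layer_norm(x + sublayer(x)). A tiny NumPy version with a stand-in sublayer:
import numpy as np

def layer_norm(x, eps=1e-12):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def residual_block(x, sublayer):
    return layer_norm(x + sublayer(x))

x = np.random.randn(2, 4)                       # [tokens, hidden]
w = np.random.randn(4, 4)                       # stand-in for a dense projection
assert residual_block(x, lambda h: h @ w).shape == x.shape
# --- end sketch ---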
911 | """ 912 | if name is None: 913 | name = tensor.name 914 | 915 | if expected_rank is not None: 916 | assert_rank(tensor, expected_rank, name) 917 | 918 | shape = tensor.shape.as_list() 919 | 920 | non_static_indexes = [] 921 | for (index, dim) in enumerate(shape): 922 | if dim is None: 923 | non_static_indexes.append(index) 924 | 925 | if not non_static_indexes: 926 | return shape 927 | 928 | dyn_shape = tf.shape(tensor) 929 | for index in non_static_indexes: 930 | shape[index] = dyn_shape[index] 931 | return shape 932 | 933 | 934 | def reshape_to_matrix(input_tensor): 935 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 936 | ndims = input_tensor.shape.ndims 937 | if ndims < 2: 938 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 939 | (input_tensor.shape)) 940 | if ndims == 2: 941 | return input_tensor 942 | 943 | width = input_tensor.shape[-1] 944 | output_tensor = tf.reshape(input_tensor, [-1, width]) 945 | return output_tensor 946 | 947 | 948 | def reshape_from_matrix(output_tensor, orig_shape_list): 949 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 950 | if len(orig_shape_list) == 2: 951 | return output_tensor 952 | 953 | output_shape = get_shape_list(output_tensor) 954 | 955 | orig_dims = orig_shape_list[0:-1] 956 | width = output_shape[-1] 957 | 958 | return tf.reshape(output_tensor, orig_dims + [width]) 959 | 960 | 961 | def assert_rank(tensor, expected_rank, name=None): 962 | """Raises an exception if the tensor rank is not of the expected rank. 963 | 964 | Args: 965 | tensor: A tf.Tensor to check the rank of. 966 | expected_rank: Python integer or list of integers, expected rank. 967 | name: Optional name of the tensor for the error message. 968 | 969 | Raises: 970 | ValueError: If the expected shape doesn't match the actual shape. 971 | """ 972 | if name is None: 973 | name = tensor.name 974 | 975 | expected_rank_dict = {} 976 | if isinstance(expected_rank, six.integer_types): 977 | expected_rank_dict[expected_rank] = True 978 | else: 979 | for x in expected_rank: 980 | expected_rank_dict[x] = True 981 | 982 | actual_rank = tensor.shape.ndims 983 | if actual_rank not in expected_rank_dict: 984 | scope_name = tf.get_variable_scope().name 985 | raise ValueError( 986 | "For the tensor `%s` in scope `%s`, the actual rank " 987 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 988 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 989 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 
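# --- Illustrative sketch (not part of optimization.py): create_optimizer above
# combines linear warmup with linear (power=1.0 polynomial) decay to zero.
# A pure-Python restatement of the resulting schedule, with made-up numbers:
def bert_lr(step, init_lr=2e-5, num_train_steps=1000, num_warmup_steps=100):
    if step < num_warmup_steps:
        return init_lr * step / num_warmup_steps       # linear warmup
    return init_lr * (1.0 - step / num_train_steps)    # linear decay to 0.0

assert bert_lr(0) == 0.0                               # starts from zero
assert abs(bert_lr(50) - 1e-5) < 1e-12                 # halfway through warmup
assert abs(bert_lr(500) - 1e-5) < 1e-12                # halfway through decay
assert bert_lr(1000) == 0.0                            # fully decayed at the end
# --- end sketch ---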
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs a AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update. 128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want ot decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
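# --- Illustrative sketch (not part of optimization.py): one update step of the
# decoupled weight decay described above, in NumPy. The decay term is added to
# the Adam update itself (never to the gradient), so it does not contaminate
# the m/v moment estimates. Like AdamWeightDecayOptimizer above, no bias
# correction is applied. Hyperparameters and values are made up.
import numpy as np

def adamw_step(param, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-6, weight_decay=0.01):
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    update = m / (np.sqrt(v) + eps) + weight_decay * param   # decoupled decay, as above
    return param - lr * update, m, v

param, grad = np.ones(3), np.full(3, 0.1)
param, m, v = adamw_step(param, grad, np.zeros(3), np.zeros(3))
# --- end sketch ---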
143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 3 | -------------------------------------------------------------------------------- /similarity.py: -------------------------------------------------------------------------------- 1 | import os 2 | from queue import Queue 3 | from threading import Thread 4 | 5 | import pandas as pd 6 | import tensorflow as tf 7 | import collections 8 | import args 9 | import tokenization 10 | import modeling 11 | import optimization 12 | 13 | 14 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 15 | 16 | 17 | class InputExample(object): 18 | """A single training/test example for simple sequence classification.""" 19 | 20 | def __init__(self, guid, text_a, text_b=None, label=None): 21 | """Constructs a InputExample. 22 | 23 | Args: 24 | guid: Unique id for the example. 25 | text_a: string. The untokenized text of the first sequence. For single 26 | sequence tasks, only this sequence must be specified. 27 | text_b: (Optional) string. The untokenized text of the second sequence. 28 | Only must be specified for sequence pair tasks. 29 | label: (Optional) string. The label of the example. This should be 30 | specified for train and dev examples, but not for test examples. 
31 | """ 32 | self.guid = guid 33 | self.text_a = text_a 34 | self.text_b = text_b 35 | self.label = label 36 | 37 | 38 | class InputFeatures(object): 39 | """A single set of features of data.""" 40 | 41 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 42 | self.input_ids = input_ids 43 | self.input_mask = input_mask 44 | self.segment_ids = segment_ids 45 | self.label_id = label_id 46 | 47 | 48 | class DataProcessor(object): 49 | """Base class for data converters for sequence classification data sets.""" 50 | 51 | def get_train_examples(self, data_dir): 52 | """Gets a collection of `InputExample`s for the train set.""" 53 | raise NotImplementedError() 54 | 55 | def get_dev_examples(self, data_dir): 56 | """Gets a collection of `InputExample`s for the dev set.""" 57 | raise NotImplementedError() 58 | 59 | def get_test_examples(self, data_dir): 60 | """Gets a collection of `InputExample`s for prediction.""" 61 | raise NotImplementedError() 62 | 63 | def get_labels(self): 64 | """Gets the list of labels for this data set.""" 65 | raise NotImplementedError() 66 | 67 | 68 | class SimProcessor(DataProcessor): 69 | def get_train_examples(self, data_dir): 70 | file_path = os.path.join(data_dir, 'train.csv') 71 | train_df = pd.read_csv(file_path, encoding='utf-8') 72 | train_data = [] 73 | for index, train in enumerate(train_df.values): 74 | guid = 'train-%d' % index 75 | text_a = tokenization.convert_to_unicode(str(train[0])) 76 | text_b = tokenization.convert_to_unicode(str(train[1])) 77 | label = str(train[2]) 78 | train_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 79 | return train_data 80 | 81 | def get_dev_examples(self, data_dir): 82 | file_path = os.path.join(data_dir, 'dev.csv') 83 | dev_df = pd.read_csv(file_path, encoding='utf-8') 84 | dev_data = [] 85 | for index, dev in enumerate(dev_df.values): 86 | guid = 'test-%d' % index 87 | text_a = tokenization.convert_to_unicode(str(dev[0])) 88 | text_b = tokenization.convert_to_unicode(str(dev[1])) 89 | label = str(dev[2]) 90 | dev_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 91 | return dev_data 92 | 93 | def get_test_examples(self, data_dir): 94 | file_path = os.path.join(data_dir, 'test.csv') 95 | test_df = pd.read_csv(file_path, encoding='utf-8') 96 | test_data = [] 97 | for index, test in enumerate(test_df.values): 98 | guid = 'test-%d' % index 99 | text_a = tokenization.convert_to_unicode(str(test[0])) 100 | text_b = tokenization.convert_to_unicode(str(test[1])) 101 | label = str(test[2]) 102 | test_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 103 | return test_data 104 | 105 | def get_sentence_examples(self, questions): 106 | for index, data in enumerate(questions): 107 | guid = 'test-%d' % index 108 | text_a = tokenization.convert_to_unicode(str(data[0])) 109 | text_b = tokenization.convert_to_unicode(str(data[1])) 110 | label = str(0) 111 | yield InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) 112 | 113 | def get_labels(self): 114 | return ['0', '1'] 115 | 116 | 117 | class BertSim: 118 | 119 | def __init__(self, batch_size=args.batch_size): 120 | self.mode = None 121 | self.max_seq_length = args.max_seq_len 122 | self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 123 | self.batch_size = batch_size 124 | self.estimator = None 125 | self.processor = SimProcessor() 126 | tf.logging.set_verbosity(tf.logging.INFO) 127 | 128 | def set_mode(self, mode): 129 | 
self.mode = mode 130 | self.estimator = self.get_estimator() 131 | if mode == tf.estimator.ModeKeys.PREDICT: 132 | self.input_queue = Queue(maxsize=1) 133 | self.output_queue = Queue(maxsize=1) 134 | self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) 135 | self.predict_thread.start() 136 | 137 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 138 | labels, num_labels, use_one_hot_embeddings): 139 | """Creates a classification model.""" 140 | model = modeling.BertModel( 141 | config=bert_config, 142 | is_training=is_training, 143 | input_ids=input_ids, 144 | input_mask=input_mask, 145 | token_type_ids=segment_ids, 146 | use_one_hot_embeddings=use_one_hot_embeddings) 147 | 148 | # In the demo, we are doing a simple classification task on the entire 149 | # segment. 150 | # 151 | # If you want to use the token-level output, use model.get_sequence_output() 152 | # instead. 153 | output_layer = model.get_pooled_output() 154 | 155 | hidden_size = output_layer.shape[-1].value 156 | 157 | output_weights = tf.get_variable( 158 | "output_weights", [num_labels, hidden_size], 159 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 160 | 161 | output_bias = tf.get_variable( 162 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 163 | 164 | with tf.variable_scope("loss"): 165 | if is_training: 166 | # I.e., 0.1 dropout 167 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 168 | 169 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 170 | logits = tf.nn.bias_add(logits, output_bias) 171 | probabilities = tf.nn.softmax(logits, axis=-1) 172 | log_probs = tf.nn.log_softmax(logits, axis=-1) 173 | 174 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 175 | 176 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 177 | loss = tf.reduce_mean(per_example_loss) 178 | 179 | return (loss, per_example_loss, logits, probabilities) 180 | 181 | def model_fn_builder(self, bert_config, num_labels, init_checkpoint, learning_rate, 182 | num_train_steps, num_warmup_steps, 183 | use_one_hot_embeddings): 184 | """Returns `model_fn` closurimport_tfe for TPUEstimator.""" 185 | 186 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 187 | from tensorflow.python.estimator.model_fn import EstimatorSpec 188 | 189 | tf.logging.info("*** Features ***") 190 | for name in sorted(features.keys()): 191 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 192 | 193 | input_ids = features["input_ids"] 194 | input_mask = features["input_mask"] 195 | segment_ids = features["segment_ids"] 196 | label_ids = features["label_ids"] 197 | 198 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 199 | 200 | (total_loss, per_example_loss, logits, probabilities) = BertSim.create_model( 201 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 202 | num_labels, use_one_hot_embeddings) 203 | 204 | tvars = tf.trainable_variables() 205 | initialized_variable_names = {} 206 | 207 | if init_checkpoint: 208 | (assignment_map, initialized_variable_names) \ 209 | = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 210 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 211 | 212 | tf.logging.info("**** Trainable Variables ****") 213 | for var in tvars: 214 | init_string = "" 215 | if var.name in initialized_variable_names: 216 | init_string = ", *INIT_FROM_CKPT*" 217 | tf.logging.info(" name = %s, shape = 
%s%s", var.name, var.shape, 218 | init_string) 219 | 220 | if mode == tf.estimator.ModeKeys.TRAIN: 221 | 222 | train_op = optimization.create_optimizer( 223 | total_loss, learning_rate, num_train_steps, num_warmup_steps, False) 224 | 225 | output_spec = EstimatorSpec( 226 | mode=mode, 227 | loss=total_loss, 228 | train_op=train_op) 229 | elif mode == tf.estimator.ModeKeys.EVAL: 230 | 231 | def metric_fn(per_example_loss, label_ids, logits): 232 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 233 | accuracy = tf.metrics.accuracy(label_ids, predictions) 234 | auc = tf.metrics.auc(label_ids, predictions) 235 | loss = tf.metrics.mean(per_example_loss) 236 | return { 237 | "eval_accuracy": accuracy, 238 | "eval_auc": auc, 239 | "eval_loss": loss, 240 | } 241 | 242 | eval_metrics = metric_fn(per_example_loss, label_ids, logits) 243 | output_spec = EstimatorSpec( 244 | mode=mode, 245 | loss=total_loss, 246 | eval_metric_ops=eval_metrics) 247 | else: 248 | output_spec = EstimatorSpec(mode=mode, predictions=probabilities) 249 | 250 | return output_spec 251 | 252 | return model_fn 253 | 254 | def get_estimator(self): 255 | 256 | from tensorflow.python.estimator.estimator import Estimator 257 | from tensorflow.python.estimator.run_config import RunConfig 258 | 259 | bert_config = modeling.BertConfig.from_json_file(args.config_name) 260 | label_list = self.processor.get_labels() 261 | train_examples = self.processor.get_train_examples(args.data_dir) 262 | num_train_steps = int( 263 | len(train_examples) / self.batch_size * args.num_train_epochs) 264 | num_warmup_steps = int(num_train_steps * 0.1) 265 | 266 | if self.mode == tf.estimator.ModeKeys.TRAIN: 267 | init_checkpoint = args.ckpt_name 268 | else: 269 | init_checkpoint = args.output_dir 270 | 271 | model_fn = self.model_fn_builder( 272 | bert_config=bert_config, 273 | num_labels=len(label_list), 274 | init_checkpoint=init_checkpoint, 275 | learning_rate=args.learning_rate, 276 | num_train_steps=num_train_steps, 277 | num_warmup_steps=num_warmup_steps, 278 | use_one_hot_embeddings=False) 279 | 280 | config = tf.ConfigProto() 281 | config.gpu_options.allow_growth = True 282 | config.gpu_options.per_process_gpu_memory_fraction = args.gpu_memory_fraction 283 | config.log_device_placement = False 284 | 285 | return Estimator(model_fn=model_fn, config=RunConfig(session_config=config), model_dir=args.output_dir, 286 | params={'batch_size': self.batch_size}) 287 | 288 | def predict_from_queue(self): 289 | for i in self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False): 290 | self.output_queue.put(i) 291 | 292 | def queue_predict_input_fn(self): 293 | return (tf.data.Dataset.from_generator( 294 | self.generate_from_queue, 295 | output_types={ 296 | 'input_ids': tf.int32, 297 | 'input_mask': tf.int32, 298 | 'segment_ids': tf.int32, 299 | 'label_ids': tf.int32}, 300 | output_shapes={ 301 | 'input_ids': (None, self.max_seq_length), 302 | 'input_mask': (None, self.max_seq_length), 303 | 'segment_ids': (None, self.max_seq_length), 304 | 'label_ids': (1,)}).prefetch(10)) 305 | 306 | def convert_examples_to_features(self, examples, label_list, max_seq_length, tokenizer): 307 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 308 | 309 | for (ex_index, example) in enumerate(examples): 310 | label_map = {} 311 | for (i, label) in enumerate(label_list): 312 | label_map[label] = i 313 | 314 | tokens_a = tokenizer.tokenize(example.text_a) 315 | tokens_b = None 316 | if example.text_b: 317 | 
tokens_b = tokenizer.tokenize(example.text_b) 318 | 319 | if tokens_b: 320 | # Modifies `tokens_a` and `tokens_b` in place so that the total 321 | # length is less than the specified length. 322 | # Account for [CLS], [SEP], [SEP] with "- 3" 323 | self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 324 | else: 325 | # Account for [CLS] and [SEP] with "- 2" 326 | if len(tokens_a) > max_seq_length - 2: 327 | tokens_a = tokens_a[0:(max_seq_length - 2)] 328 | 329 | # The convention in BERT is: 330 | # (a) For sequence pairs: 331 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 332 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 333 | # (b) For single sequences: 334 | # tokens: [CLS] the dog is hairy . [SEP] 335 | # type_ids: 0 0 0 0 0 0 0 336 | # 337 | # Where "type_ids" are used to indicate whether this is the first 338 | # sequence or the second sequence. The embedding vectors for `type=0` and 339 | # `type=1` were learned during pre-training and are added to the wordpiece 340 | # embedding vector (and position vector). This is not *strictly* necessary 341 | # since the [SEP] token unambiguously separates the sequences, but it makes 342 | # it easier for the model to learn the concept of sequences. 343 | # 344 | # For classification tasks, the first vector (corresponding to [CLS]) is 345 | # used as as the "sentence vector". Note that this only makes sense because 346 | # the entire model is fine-tuned. 347 | tokens = [] 348 | segment_ids = [] 349 | tokens.append("[CLS]") 350 | segment_ids.append(0) 351 | for token in tokens_a: 352 | tokens.append(token) 353 | segment_ids.append(0) 354 | tokens.append("[SEP]") 355 | segment_ids.append(0) 356 | 357 | if tokens_b: 358 | for token in tokens_b: 359 | tokens.append(token) 360 | segment_ids.append(1) 361 | tokens.append("[SEP]") 362 | segment_ids.append(1) 363 | 364 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 365 | 366 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 367 | # tokens are attended to. 368 | input_mask = [1] * len(input_ids) 369 | 370 | # Zero-pad up to the sequence length. 
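# --- Illustrative sketch (not part of similarity.py): what the loop above builds
# for a toy, already-tokenized sentence pair (padding to max_seq_length happens
# in the loop that follows). Tokens here are made up for illustration.
tokens_a = ["how", "are", "you"]
tokens_b = ["fine", "thanks"]

tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
input_mask = [1] * len(tokens)

assert tokens == ["[CLS]", "how", "are", "you", "[SEP]", "fine", "thanks", "[SEP]"]
assert segment_ids == [0, 0, 0, 0, 0, 1, 1, 1]
assert input_mask == [1] * 8
# --- end sketch ---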
371 | while len(input_ids) < max_seq_length: 372 | input_ids.append(0) 373 | input_mask.append(0) 374 | segment_ids.append(0) 375 | 376 | assert len(input_ids) == max_seq_length 377 | assert len(input_mask) == max_seq_length 378 | assert len(segment_ids) == max_seq_length 379 | 380 | label_id = label_map[example.label] 381 | if ex_index < 5: 382 | tf.logging.info("*** Example ***") 383 | tf.logging.info("guid: %s" % (example.guid)) 384 | tf.logging.info("tokens: %s" % " ".join( 385 | [tokenization.printable_text(x) for x in tokens])) 386 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 387 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 388 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 389 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 390 | 391 | feature = InputFeatures( 392 | input_ids=input_ids, 393 | input_mask=input_mask, 394 | segment_ids=segment_ids, 395 | label_id=label_id) 396 | 397 | yield feature 398 | 399 | def generate_from_queue(self): 400 | while True: 401 | predict_examples = self.processor.get_sentence_examples(self.input_queue.get()) 402 | features = list(self.convert_examples_to_features(predict_examples, self.processor.get_labels(), 403 | args.max_seq_len, self.tokenizer)) 404 | yield { 405 | 'input_ids': [f.input_ids for f in features], 406 | 'input_mask': [f.input_mask for f in features], 407 | 'segment_ids': [f.segment_ids for f in features], 408 | 'label_ids': [f.label_id for f in features] 409 | } 410 | 411 | def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): 412 | """Truncates a sequence pair in place to the maximum length.""" 413 | 414 | # This is a simple heuristic which will always truncate the longer sequence 415 | # one token at a time. This makes more sense than truncating an equal percent 416 | # of tokens from each, since if one sequence is very short then each token 417 | # that's truncated likely contains more information than a longer sequence. 418 | while True: 419 | total_length = len(tokens_a) + len(tokens_b) 420 | if total_length <= max_length: 421 | break 422 | if len(tokens_a) > len(tokens_b): 423 | tokens_a.pop() 424 | else: 425 | tokens_b.pop() 426 | 427 | def convert_single_example(self, ex_index, example, label_list, max_seq_length, tokenizer): 428 | """Converts a single `InputExample` into a single `InputFeatures`.""" 429 | label_map = {} 430 | for (i, label) in enumerate(label_list): 431 | label_map[label] = i 432 | 433 | tokens_a = tokenizer.tokenize(example.text_a) 434 | tokens_b = None 435 | if example.text_b: 436 | tokens_b = tokenizer.tokenize(example.text_b) 437 | 438 | if tokens_b: 439 | # Modifies `tokens_a` and `tokens_b` in place so that the total 440 | # length is less than the specified length. 441 | # Account for [CLS], [SEP], [SEP] with "- 3" 442 | self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 443 | else: 444 | # Account for [CLS] and [SEP] with "- 2" 445 | if len(tokens_a) > max_seq_length - 2: 446 | tokens_a = tokens_a[0:(max_seq_length - 2)] 447 | 448 | # The convention in BERT is: 449 | # (a) For sequence pairs: 450 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 451 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 452 | # (b) For single sequences: 453 | # tokens: [CLS] the dog is hairy . [SEP] 454 | # type_ids: 0 0 0 0 0 0 0 455 | # 456 | # Where "type_ids" are used to indicate whether this is the first 457 | # sequence or the second sequence. 
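# --- Illustrative sketch (not part of similarity.py): the truncation heuristic
# in _truncate_seq_pair above always pops from whichever sequence is currently
# longer, so a short sentence keeps all its tokens. Toy run with made-up tokens:
def truncate_pair(tokens_a, tokens_b, max_length):
    while len(tokens_a) + len(tokens_b) > max_length:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()

a, b = list("abcdefgh"), list("xy")            # 8 tokens vs. 2 tokens
truncate_pair(a, b, 6)
assert a == list("abcd") and b == list("xy")   # only the long side shrank
# --- end sketch ---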
The embedding vectors for `type=0` and 458 | # `type=1` were learned during pre-training and are added to the wordpiece 459 | # embedding vector (and position vector). This is not *strictly* necessary 460 | # since the [SEP] token unambiguously separates the sequences, but it makes 461 | # it easier for the model to learn the concept of sequences. 462 | # 463 | # For classification tasks, the first vector (corresponding to [CLS]) is 464 | # used as as the "sentence vector". Note that this only makes sense because 465 | # the entire model is fine-tuned. 466 | tokens = [] 467 | segment_ids = [] 468 | tokens.append("[CLS]") 469 | segment_ids.append(0) 470 | for token in tokens_a: 471 | tokens.append(token) 472 | segment_ids.append(0) 473 | tokens.append("[SEP]") 474 | segment_ids.append(0) 475 | 476 | if tokens_b: 477 | for token in tokens_b: 478 | tokens.append(token) 479 | segment_ids.append(1) 480 | tokens.append("[SEP]") 481 | segment_ids.append(1) 482 | 483 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 484 | 485 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 486 | # tokens are attended to. 487 | input_mask = [1] * len(input_ids) 488 | 489 | # Zero-pad up to the sequence length. 490 | while len(input_ids) < max_seq_length: 491 | input_ids.append(0) 492 | input_mask.append(0) 493 | segment_ids.append(0) 494 | 495 | assert len(input_ids) == max_seq_length 496 | assert len(input_mask) == max_seq_length 497 | assert len(segment_ids) == max_seq_length 498 | 499 | label_id = label_map[example.label] 500 | if ex_index < 5: 501 | tf.logging.info("*** Example ***") 502 | tf.logging.info("guid: %s" % (example.guid)) 503 | tf.logging.info("tokens: %s" % " ".join( 504 | [tokenization.printable_text(x) for x in tokens])) 505 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 506 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 507 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 508 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 509 | 510 | feature = InputFeatures( 511 | input_ids=input_ids, 512 | input_mask=input_mask, 513 | segment_ids=segment_ids, 514 | label_id=label_id) 515 | return feature 516 | 517 | def file_based_convert_examples_to_features(self, examples, label_list, max_seq_length, tokenizer, output_file): 518 | """Convert a set of `InputExample`s to a TFRecord file.""" 519 | 520 | writer = tf.python_io.TFRecordWriter(output_file) 521 | 522 | for (ex_index, example) in enumerate(examples): 523 | if ex_index % 10000 == 0: 524 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 525 | 526 | feature = self.convert_single_example(ex_index, example, label_list, 527 | max_seq_length, tokenizer) 528 | 529 | def create_int_feature(values): 530 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 531 | return f 532 | 533 | features = collections.OrderedDict() 534 | features["input_ids"] = create_int_feature(feature.input_ids) 535 | features["input_mask"] = create_int_feature(feature.input_mask) 536 | features["segment_ids"] = create_int_feature(feature.segment_ids) 537 | features["label_ids"] = create_int_feature([feature.label_id]) 538 | 539 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 540 | writer.write(tf_example.SerializeToString()) 541 | 542 | def file_based_input_fn_builder(self, input_file, seq_length, is_training, drop_remainder): 543 | """Creates an `input_fn` closure to be passed 
to TPUEstimator.""" 544 | 545 | name_to_features = { 546 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 547 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 548 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 549 | "label_ids": tf.FixedLenFeature([], tf.int64), 550 | } 551 | 552 | def _decode_record(record, name_to_features): 553 | """Decodes a record to a TensorFlow example.""" 554 | example = tf.parse_single_example(record, name_to_features) 555 | 556 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 557 | # So cast all int64 to int32. 558 | for name in list(example.keys()): 559 | t = example[name] 560 | if t.dtype == tf.int64: 561 | t = tf.to_int32(t) 562 | example[name] = t 563 | 564 | return example 565 | 566 | def input_fn(params): 567 | """The actual input function.""" 568 | batch_size = params["batch_size"] 569 | 570 | # For training, we want a lot of parallel reading and shuffling. 571 | # For eval, we want no shuffling and parallel reading doesn't matter. 572 | d = tf.data.TFRecordDataset(input_file) 573 | if is_training: 574 | d = d.repeat() 575 | d = d.shuffle(buffer_size=100) 576 | 577 | d = d.apply( 578 | tf.contrib.data.map_and_batch( 579 | lambda record: _decode_record(record, name_to_features), 580 | batch_size=batch_size, 581 | drop_remainder=drop_remainder)) 582 | 583 | return d 584 | 585 | return input_fn 586 | 587 | def train(self): 588 | if self.mode is None: 589 | raise ValueError("Please set the 'mode' parameter") 590 | 591 | bert_config = modeling.BertConfig.from_json_file(args.config_name) 592 | 593 | if args.max_seq_len > bert_config.max_position_embeddings: 594 | raise ValueError( 595 | "Cannot use sequence length %d because the BERT model " 596 | "was only trained up to sequence length %d" % 597 | (args.max_seq_len, bert_config.max_position_embeddings)) 598 | 599 | tf.gfile.MakeDirs(args.output_dir) 600 | 601 | label_list = self.processor.get_labels() 602 | 603 | train_examples = self.processor.get_train_examples(args.data_dir) 604 | num_train_steps = int(len(train_examples) / args.batch_size * args.num_train_epochs) 605 | 606 | estimator = self.get_estimator() 607 | 608 | train_file = os.path.join(args.output_dir, "train.tf_record") 609 | self.file_based_convert_examples_to_features(train_examples, label_list, args.max_seq_len, self.tokenizer, 610 | train_file) 611 | tf.logging.info("***** Running training *****") 612 | tf.logging.info(" Num examples = %d", len(train_examples)) 613 | tf.logging.info(" Batch size = %d", args.batch_size) 614 | tf.logging.info(" Num steps = %d", num_train_steps) 615 | train_input_fn = self.file_based_input_fn_builder(input_file=train_file, seq_length=args.max_seq_len, 616 | is_training=True, 617 | drop_remainder=True) 618 | 619 | # early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook( 620 | # estimator, 621 | # metric_name='loss', 622 | # max_steps_without_decrease=10, 623 | # min_steps=num_train_steps) 624 | 625 | # estimator.train(input_fn=train_input_fn, hooks=[early_stopping]) 626 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 627 | 628 | def eval(self): 629 | if self.mode is None: 630 | raise ValueError("Please set the 'mode' parameter") 631 | eval_examples = self.processor.get_dev_examples(args.data_dir) 632 | eval_file = os.path.join(args.output_dir, "eval.tf_record") 633 | label_list = self.processor.get_labels() 634 | self.file_based_convert_examples_to_features( 635 | eval_examples, label_list, args.max_seq_len, self.tokenizer, 
eval_file) 636 | 637 | tf.logging.info("***** Running evaluation *****") 638 | tf.logging.info(" Num examples = %d", len(eval_examples)) 639 | tf.logging.info(" Batch size = %d", self.batch_size) 640 | 641 | eval_input_fn = self.file_based_input_fn_builder( 642 | input_file=eval_file, 643 | seq_length=args.max_seq_len, 644 | is_training=False, 645 | drop_remainder=False) 646 | 647 | estimator = self.get_estimator() 648 | result = estimator.evaluate(input_fn=eval_input_fn, steps=None) 649 | 650 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 651 | with tf.gfile.GFile(output_eval_file, "w") as writer: 652 | tf.logging.info("***** Eval results *****") 653 | for key in sorted(result.keys()): 654 | tf.logging.info(" %s = %s", key, str(result[key])) 655 | writer.write("%s = %s\n" % (key, str(result[key]))) 656 | 657 | def predict(self, sentence1, sentence2): 658 | if self.mode is None: 659 | raise ValueError("Please set the 'mode' parameter") 660 | self.input_queue.put([(sentence1, sentence2)]) 661 | prediction = self.output_queue.get() 662 | return prediction 663 | 664 | 665 | if __name__ == '__main__': 666 | sim = BertSim() 667 | sim.set_mode(tf.estimator.ModeKeys.TRAIN) 668 | sim.train() 669 | sim.set_mode(tf.estimator.ModeKeys.EVAL) 670 | sim.eval() 671 | # sim.set_mode(tf.estimator.ModeKeys.PREDICT) 672 | # while True: 673 | # sentence1 = input('sentence1: ') 674 | # sentence2 = input('sentence2: ') 675 | # predict = sim.predict(sentence1, sentence2) 676 | # print(f'similarity:{predict[0][1]}') 677 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenziation.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens = [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | return 
convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | 140 | Args: 141 | do_lower_case: Whether to lower case the input. 142 | """ 143 | self.do_lower_case = do_lower_case 144 | 145 | def tokenize(self, text): 146 | """Tokenizes a piece of text.""" 147 | text = convert_to_unicode(text) 148 | text = self._clean_text(text) 149 | 150 | # This was added on November 1st, 2018 for the multilingual and Chinese 151 | # models. This is also applied to the English models now, but it doesn't 152 | # matter since the English models were not trained on any Chinese data 153 | # and generally don't have any Chinese data in them (there are Chinese 154 | # characters in the vocabulary because Wikipedia does have some Chinese 155 | # words in the English Wikipedia.). 156 | text = self._tokenize_chinese_chars(text) 157 | 158 | orig_tokens = whitespace_tokenize(text) 159 | split_tokens = [] 160 | for token in orig_tokens: 161 | if self.do_lower_case: 162 | token = token.lower() 163 | token = self._run_strip_accents(token) 164 | split_tokens.extend(self._run_split_on_punc(token)) 165 | 166 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 167 | return output_tokens 168 | 169 | def _run_strip_accents(self, text): 170 | """Strips accents from a piece of text.""" 171 | text = unicodedata.normalize("NFD", text) 172 | output = [] 173 | for char in text: 174 | cat = unicodedata.category(char) 175 | if cat == "Mn": 176 | continue 177 | output.append(char) 178 | return "".join(output) 179 | 180 | def _run_split_on_punc(self, text): 181 | """Splits punctuation on a piece of text.""" 182 | chars = list(text) 183 | i = 0 184 | start_new_word = True 185 | output = [] 186 | while i < len(chars): 187 | char = chars[i] 188 | if _is_punctuation(char): 189 | output.append([char]) 190 | start_new_word = True 191 | else: 192 | if start_new_word: 193 | output.append([]) 194 | start_new_word = False 195 | output[-1].append(char) 196 | i += 1 197 | 198 | return ["".join(x) for x in output] 199 | 200 | def _tokenize_chinese_chars(self, text): 201 | """Adds whitespace around any CJK character.""" 202 | output = [] 203 | for char in text: 204 | cp = ord(char) 205 | if self._is_chinese_char(cp): 206 | output.append(" ") 207 | output.append(char) 208 | output.append(" ") 209 | else: 210 | output.append(char) 211 | return "".join(output) 212 | 213 | def _is_chinese_char(self, cp): 214 | """Checks whether CP is the codepoint of a CJK character.""" 215 | # This defines a "chinese character" as anything in the CJK Unicode block: 216 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 217 | # 218 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 219 | # despite its name. The modern Korean Hangul alphabet is a different block, 220 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 221 | # space-separated words, so they are not treated specially and handled 222 | # like the all of the other languages. 
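# --- Illustrative sketch (not part of tokenization.py): what _run_strip_accents
# above does. NFD decomposition turns accented characters into a base character
# plus combining marks (Unicode category "Mn"), which are then dropped.
import unicodedata

def strip_accents(text):
    return "".join(c for c in unicodedata.normalize("NFD", text)
                   if unicodedata.category(c) != "Mn")

assert strip_accents("café") == "cafe"
assert strip_accents("naïve") == "naive"
# --- end sketch ---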
223 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 224 | (cp >= 0x3400 and cp <= 0x4DBF) or # 225 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 226 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 227 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 228 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 229 | (cp >= 0xF900 and cp <= 0xFAFF) or # 230 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 231 | return True 232 | 233 | return False 234 | 235 | def _clean_text(self, text): 236 | """Performs invalid character removal and whitespace cleanup on text.""" 237 | output = [] 238 | for char in text: 239 | cp = ord(char) 240 | if cp == 0 or cp == 0xfffd or _is_control(char): 241 | continue 242 | if _is_whitespace(char): 243 | output.append(" ") 244 | else: 245 | output.append(char) 246 | return "".join(output) 247 | 248 | 249 | class WordpieceTokenizer(object): 250 | """Runs WordPiece tokenziation.""" 251 | 252 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 253 | self.vocab = vocab 254 | self.unk_token = unk_token 255 | self.max_input_chars_per_word = max_input_chars_per_word 256 | 257 | def tokenize(self, text): 258 | """Tokenizes a piece of text into its word pieces. 259 | 260 | This uses a greedy longest-match-first algorithm to perform tokenization 261 | using the given vocabulary. 262 | 263 | For example: 264 | input = "unaffable" 265 | output = ["un", "##aff", "##able"] 266 | 267 | Args: 268 | text: A single token or whitespace separated tokens. This should have 269 | already been passed through `BasicTokenizer. 270 | 271 | Returns: 272 | A list of wordpiece tokens. 273 | """ 274 | 275 | text = convert_to_unicode(text) 276 | 277 | output_tokens = [] 278 | for token in whitespace_tokenize(text): 279 | chars = list(token) 280 | if len(chars) > self.max_input_chars_per_word: 281 | output_tokens.append(self.unk_token) 282 | continue 283 | 284 | is_bad = False 285 | start = 0 286 | sub_tokens = [] 287 | while start < len(chars): 288 | end = len(chars) 289 | cur_substr = None 290 | while start < end: 291 | substr = "".join(chars[start:end]) 292 | if start > 0: 293 | substr = "##" + substr 294 | if substr in self.vocab: 295 | cur_substr = substr 296 | break 297 | end -= 1 298 | if cur_substr is None: 299 | is_bad = True 300 | break 301 | sub_tokens.append(cur_substr) 302 | start = end 303 | 304 | if is_bad: 305 | output_tokens.append(self.unk_token) 306 | else: 307 | output_tokens.extend(sub_tokens) 308 | return output_tokens 309 | 310 | 311 | def _is_whitespace(char): 312 | """Checks whether `chars` is a whitespace character.""" 313 | # \t, \n, and \r are technically contorl characters but we treat them 314 | # as whitespace since they are generally considered as such. 315 | if char == " " or char == "\t" or char == "\n" or char == "\r": 316 | return True 317 | cat = unicodedata.category(char) 318 | if cat == "Zs": 319 | return True 320 | return False 321 | 322 | 323 | def _is_control(char): 324 | """Checks whether `chars` is a control character.""" 325 | # These are technically control characters but we count them as whitespace 326 | # characters. 327 | if char == "\t" or char == "\n" or char == "\r": 328 | return False 329 | cat = unicodedata.category(char) 330 | if cat.startswith("C"): 331 | return True 332 | return False 333 | 334 | 335 | def _is_punctuation(char): 336 | """Checks whether `chars` is a punctuation character.""" 337 | cp = ord(char) 338 | # We treat all non-letter/number ASCII as punctuation. 
339 | # Characters such as "^", "$", and "`" are not in the Unicode 340 | # Punctuation class but we treat them as punctuation anyways, for 341 | # consistency. 342 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 343 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 344 | return True 345 | cat = unicodedata.category(char) 346 | if cat.startswith("P"): 347 | return True 348 | return False 349 | --------------------------------------------------------------------------------
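# --- Illustrative sketch (not part of tokenization.py): the greedy
# longest-match-first loop of WordpieceTokenizer.tokenize above, restated as a
# standalone function with a made-up toy vocabulary.
def greedy_wordpiece(token, vocab, unk="[UNK]"):
    chars, pieces, start = list(token), [], 0
    while start < len(chars):
        end, cur = len(chars), None
        while start < end:                    # shrink the candidate from the right
            sub = "".join(chars[start:end])
            if start > 0:
                sub = "##" + sub              # continuation pieces are marked with "##"
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:                       # no piece matched: the whole token is unknown
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

vocab = {"un", "##aff", "##able"}
assert greedy_wordpiece("unaffable", vocab) == ["un", "##aff", "##able"]
assert greedy_wordpiece("zzz", vocab) == ["[UNK]"]
# --- end sketch ---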