├── README.md
├── encode
│   ├── readme.md
│   ├── bert_tensorrt.py
│   ├── bert_base_build.py
│   ├── tokenization.py
│   └── data_processing_new.py
└── .gitignore

/README.md:
--------------------------------------------------------------------------------
1 | # BERT4TensorRT
2 | BERT sentence encoding with TensorRT
3 | 
--------------------------------------------------------------------------------
/encode/readme.md:
--------------------------------------------------------------------------------
1 | # Encoding sentences with TensorRT
2 | NVIDIA also provides an official FasterTransformer.
3 | This project was adapted from the code under demo/BERT.
4 | 
5 | ## Converting the model to a .engine file: changes made in the build script (bert_base_build.py)
6 | 1. Removed the SQuAD functions
7 | 2. Removed the calls to those SQuAD functions
8 | 
9 | ## Next, modify the data_processing file (data_processing_new.py)
10 | 1. It originally processed MC (machine comprehension) data; it now processes plain sentences.
11 | 
12 | 
13 | ## Modify BERT_TRT (bert_tensorrt.py)
14 | 1. It originally processed MC data; it now processes plain sentences.
15 | 
16 | 
17 | 
18 | Once the steps above are done, the TensorRT BERT encoder is complete; a usage sketch follows below.
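## Usage sketch
The sketch below illustrates the intended two-step flow under some assumptions: it runs inside the TensorRT demo/BERT container, the compiled plugin libraries referenced in bert_base_build.py are available, and the checkpoint/config paths are placeholders to be replaced with your own.

```python
# Step 1: serialize the TensorRT engine from a fine-tuned TensorFlow checkpoint.
import bert_base_build

bert_base_build.main(
    inputbase='/workspace/models/fine-tuned/chinese_bert_base/bert_model.ckpt',  # placeholder checkpoint
    B=1,      # max batch size
    S=128,    # sequence length the engine is built for
    bert_path='/workspace/models/fine-tuned/chinese_bert_base',                  # dir containing bert_config.json (placeholder)
    outputbase='./bert_base.engine')

# Step 2: encode a sentence with the serialized engine (this is what bert_tensorrt.py does):
# tokenize the text, build input_ids / segment_ids / input_mask with
# data_processing_new.convert_examples_to_features, copy them to the GPU, run the
# engine, and reshape the output to (128, 768); row 0 is the [CLS] sentence vector.
```

After step 1, running bert_tensorrt.py prints the inference time in milliseconds and leaves the 768-dimensional [CLS] vector in `a`.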
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/encode/bert_tensorrt.py:
--------------------------------------------------------------------------------
1 | import time
2 | import ctypes
3 | 
4 | import numpy as np
5 | import pycuda.driver as cuda
6 | import pycuda.autoinit
7 | import tensorrt as trt
8 | 
9 | import data_processing_new as dpn
10 | import tokenization
11 | 
12 | TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
13 | 
14 | # Load the shared libraries that provide the custom BERT TensorRT plugins.
15 | nvinfer = ctypes.CDLL("libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL)
16 | cm = ctypes.CDLL("/workspace/TensorRT/demo/BERT/build/libcommon.so", mode=ctypes.RTLD_GLOBAL)
17 | pg = ctypes.CDLL("/workspace/TensorRT/demo/BERT/build/libbert_plugins.so", mode=ctypes.RTLD_GLOBAL)
18 | 
19 | bert_engine = '/workspace/TensorRT/demo/BERT/python/bert_base.engine'
20 | vocab_file = '/workspace/models/fine-tuned/chinese_bert_base/vocab.txt'
21 | batch_size = 1
22 | max_batch_size = 1
23 | max_query_length = 64
24 | max_seq_length = 128
25 | 
26 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
27 | 
28 | # Tokenize the sample sentence and build padded input_ids / segment_ids / input_mask arrays.
29 | text = '早安'
30 | eval_start_time = time.time()
31 | input_features = dpn.convert_examples_to_features(text, None, tokenizer, max_seq_length)
32 | 
33 | with open("./bert_base.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, \
34 |     runtime.deserialize_cuda_engine(f.read()) as engine, engine.create_execution_context() as context:
35 | 
36 |     print("List engine binding:")
37 |     for binding in engine:
38 |         print(" - {}: {}, Shape {}, {}".format(
39 |             "Input" if engine.binding_is_input(binding) else "Output",
40 |             binding,
41 |             engine.get_binding_shape(binding),
42 |             engine.get_binding_dtype(binding)))
43 | 
44 |     def binding_nbytes(binding):
45 |         return trt.volume(engine.get_binding_shape(binding)) * engine.get_binding_dtype(binding).itemsize
46 | 
47 |     # Device buffers for the three inputs, plus pinned host and device buffers for the output.
48 |     d_inputs = [cuda.mem_alloc(binding_nbytes(binding)) for binding in engine if engine.binding_is_input(binding)]
49 |     h_output = cuda.pagelocked_empty(tuple(engine.get_binding_shape(3)), dtype=np.float32)
50 |     d_output = cuda.mem_alloc(h_output.nbytes)
51 | 
52 |     stream = cuda.Stream()
53 |     print("\nRunning Inference...")
54 | 
55 |     cuda.memcpy_htod_async(d_inputs[0], input_features["input_ids"], stream)
56 |     cuda.memcpy_htod_async(d_inputs[1], input_features["segment_ids"], stream)
57 |     cuda.memcpy_htod_async(d_inputs[2], input_features["input_mask"], stream)
58 | 
59 |     context.execute_async(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle)
60 |     cuda.memcpy_dtoh_async(h_output, d_output, stream)
61 |     stream.synchronize()
62 | 
63 |     eval_time_elapsed = time.time() - eval_start_time
64 |     print(eval_time_elapsed * 1000)
65 | 
66 |     # The output holds one 768-dim vector per token; row 0 is the [CLS] sentence embedding.
67 |     a = h_output.reshape(128,
768)[0, :] -------------------------------------------------------------------------------- /encode/bert_base_build.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import ctypes 3 | import argparse 4 | import numpy as np 5 | import json 6 | import sys 7 | import re 8 | import os 9 | 10 | 11 | try: 12 | from tensorflow.python import pywrap_tensorflow as pyTF 13 | except ImportError as err: 14 | sys.stderr.write("""Error: Failed to import tensorflow module ({})""".format(err)) 15 | sys.exit() 16 | 17 | nvinfer = ctypes.CDLL("libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL) 18 | cm = ctypes.CDLL("/workspace/TensorRT/demo/BERT/build/libcommon.so", mode=ctypes.RTLD_GLOBAL) 19 | pg = ctypes.CDLL("/workspace/TensorRT/demo/BERT/build/libbert_plugins.so", mode=ctypes.RTLD_GLOBAL) 20 | 21 | 22 | """ 23 | TensorRT Initialization 24 | """ 25 | TRT_LOGGER = trt.Logger(trt.Logger.INFO) 26 | trt.init_libnvinfer_plugins(TRT_LOGGER, "") 27 | plg_registry = trt.get_plugin_registry() 28 | qkv2_plg_creator = plg_registry.get_plugin_creator("CustomQKVToContextPlugin", "1", "") 29 | skln_plg_creator = plg_registry.get_plugin_creator("CustomSkipLayerNormPlugin", "1", "") 30 | gelu_plg_creator = plg_registry.get_plugin_creator("CustomGeluPlugin", "1", "") 31 | emln_plg_creator = plg_registry.get_plugin_creator("CustomEmbLayerNormPlugin", "1", "") 32 | 33 | 34 | """ 35 | Attentions Keys 36 | """ 37 | WQ = "query_kernel" 38 | BQ = "query_bias" 39 | WK = "key_kernel" 40 | BK = "key_bias" 41 | WV = "value_kernel" 42 | BV = "value_bias" 43 | WQKV = "qkv_kernel" 44 | BQKV = "qkv_bias" 45 | 46 | 47 | """ 48 | Transformer Keys 49 | """ 50 | W_AOUT = "attention_output_dense_kernel" 51 | B_AOUT = "attention_output_dense_bias" 52 | AOUT_LN_BETA = "attention_output_layernorm_beta" 53 | AOUT_LN_GAMMA = "attention_output_layernorm_gamma" 54 | W_MID = "intermediate_dense_kernel" 55 | B_MID = "intermediate_dense_bias" 56 | W_LOUT = "output_dense_kernel" 57 | B_LOUT = "output_dense_bias" 58 | LOUT_LN_BETA = "output_layernorm_beta" 59 | LOUT_LN_GAMMA = "output_layernorm_gamma" 60 | 61 | 62 | 63 | class BertConfig: 64 | def __init__(self, bert_config_path): 65 | with open(bert_config_path, 'r') as f: 66 | data = json.load(f) 67 | self.num_attention_heads = data['num_attention_heads'] 68 | self.hidden_size = data['hidden_size'] 69 | self.intermediate_size = data['intermediate_size'] 70 | self.num_hidden_layers = data['num_hidden_layers'] 71 | self.use_fp16 = True 72 | 73 | 74 | def set_tensor_name(tensor, prefix, name): 75 | tensor.name = prefix + name 76 | 77 | def set_layer_name(layer, prefix, name, out_idx = 0): 78 | set_tensor_name(layer.get_output(out_idx), prefix, name) 79 | 80 | def attention_layer_opt(prefix, config, init_dict, network, input_tensor, imask): 81 | """ 82 | Add the attention layer 83 | """ 84 | assert(len(input_tensor.shape) == 4) 85 | S, hidden_size, _, _ = input_tensor.shape 86 | num_heads = config.num_attention_heads 87 | head_size = int(hidden_size / num_heads) 88 | 89 | Wall = init_dict[prefix + WQKV] 90 | Ball = init_dict[prefix + BQKV] 91 | 92 | mult_all = network.add_fully_connected(input_tensor, 3 * hidden_size, Wall, Ball) 93 | set_layer_name(mult_all, prefix, "qkv_mult") 94 | 95 | has_mask = imask != None 96 | 97 | pf_hidden_size = trt.PluginField("hidden_size", np.array([hidden_size], np.int32), trt.PluginFieldType.INT32) 98 | pf_num_heads = trt.PluginField("num_heads", np.array([num_heads], np.int32), trt.PluginFieldType.INT32) 99 
| pf_S = trt.PluginField("S", np.array([S], np.int32), trt.PluginFieldType.INT32) 100 | pf_has_mask = trt.PluginField("has_mask", np.array([has_mask], np.int32), trt.PluginFieldType.INT32) 101 | 102 | pfc = trt.PluginFieldCollection([pf_hidden_size, pf_num_heads, pf_S, pf_has_mask]) 103 | qkv2ctx_plug = qkv2_plg_creator.create_plugin("qkv2ctx", pfc) 104 | 105 | qkv_in = [mult_all.get_output(0), imask] 106 | qkv2ctx = network.add_plugin_v2(qkv_in, qkv2ctx_plug) 107 | set_layer_name(qkv2ctx, prefix, "context_layer") 108 | return qkv2ctx 109 | 110 | 111 | def skipln(prefix, init_dict, network, input_tensor, skip): 112 | """ 113 | Add the skip layer 114 | """ 115 | idims = input_tensor.shape 116 | assert len(idims) == 4 117 | hidden_size = idims[1] 118 | 119 | pf_ld = trt.PluginField("ld", np.array([hidden_size], np.int32), trt.PluginFieldType.INT32) 120 | wbeta = init_dict[prefix + "beta"] 121 | pf_beta = trt.PluginField("beta", wbeta.numpy(), trt.PluginFieldType.FLOAT32) 122 | wgamma = init_dict[prefix + "gamma"] 123 | pf_gamma = trt.PluginField("gamma", wgamma.numpy(), trt.PluginFieldType.FLOAT32) 124 | 125 | pfc = trt.PluginFieldCollection([pf_ld, pf_beta, pf_gamma]) 126 | skipln_plug = skln_plg_creator.create_plugin("skipln", pfc) 127 | 128 | skipln_inputs = [input_tensor, skip] 129 | layer = network.add_plugin_v2(skipln_inputs, skipln_plug) 130 | return layer 131 | 132 | 133 | def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imask): 134 | """ 135 | Add the transformer layer 136 | """ 137 | idims = input_tensor.shape 138 | assert len(idims) == 4 139 | hidden_size = idims[1] 140 | 141 | context_transposed = attention_layer_opt(prefix + "attention_self_", config, init_dict, network, input_tensor, imask) 142 | attention_heads = context_transposed.get_output(0) 143 | 144 | W_aout = init_dict[prefix + W_AOUT] 145 | B_aout = init_dict[prefix + B_AOUT] 146 | attention_out_fc = network.add_fully_connected(attention_heads, hidden_size, W_aout, B_aout) 147 | 148 | skiplayer = skipln(prefix + "attention_output_layernorm_", init_dict, network, attention_out_fc.get_output(0), input_tensor) 149 | attention_ln = skiplayer.get_output(0) 150 | 151 | W_mid = init_dict[prefix + W_MID] 152 | B_mid = init_dict[prefix + B_MID] 153 | mid_dense = network.add_fully_connected(attention_ln, config.intermediate_size, W_mid, B_mid) 154 | 155 | mid_dense_out = mid_dense.get_output(0) 156 | 157 | pfc = trt.PluginFieldCollection() 158 | plug = gelu_plg_creator.create_plugin("gelu", pfc) 159 | 160 | gelu_layer = network.add_plugin_v2([mid_dense_out], plug) 161 | 162 | intermediate_act = gelu_layer.get_output(0) 163 | set_tensor_name(intermediate_act, prefix, "gelu") 164 | 165 | # Dense to hidden size 166 | W_lout = init_dict[prefix + W_LOUT] 167 | B_lout = init_dict[prefix + B_LOUT] 168 | 169 | out_dense = network.add_fully_connected(intermediate_act, hidden_size, W_lout, B_lout) 170 | set_layer_name(out_dense, prefix + "output_", "dense") 171 | out_layer = skipln(prefix + "output_layernorm_", init_dict, network, out_dense.get_output(0), attention_ln) 172 | out_ln = out_layer.get_output(0) 173 | 174 | set_tensor_name(out_ln, prefix + "output_", "reshape") 175 | 176 | return out_ln 177 | 178 | 179 | def bert_model(config, init_dict, network, input_tensor, input_mask): 180 | """ 181 | Create the bert model 182 | """ 183 | prev_input = input_tensor 184 | for layer in range(0, config.num_hidden_layers): 185 | ss = "l{}_".format(layer) 186 | prev_input = transformer_layer_opt(ss, config, init_dict, 
network, prev_input, input_mask) 187 | return prev_input 188 | 189 | 190 | def load_weights(inputbase): 191 | """ 192 | Load the weights from the tensorflow checkpoint 193 | """ 194 | weights_dict = dict() 195 | 196 | try: 197 | reader = pyTF.NewCheckpointReader(inputbase) 198 | tensor_dict = reader.get_variable_to_shape_map() 199 | 200 | # There might be training-related variables in the checkpoint that can be discarded 201 | param_names = [key for key in sorted(tensor_dict) if 'adam' not in key and 'global_step' not in key and 'pooler' not in key] 202 | count = len(param_names) 203 | TRT_LOGGER.log(TRT_LOGGER.INFO, str(count)) 204 | 205 | for pn in param_names: 206 | toks = pn.lower().split('/') 207 | if 'encoder' in pn: 208 | assert ('layer' in pn) 209 | l = (re.findall('\d+', pn))[0] 210 | outname = 'l{}_'.format(l) + '_'.join(toks[3:]) 211 | else: 212 | outname = '_'.join(toks) 213 | 214 | tensor = reader.get_tensor(pn) 215 | shape = tensor.shape 216 | if pn.find('kernel') != -1: 217 | TRT_LOGGER.log(TRT_LOGGER.INFO, "Transposing {}\n".format(np)) 218 | tensor = np.transpose(tensor) 219 | 220 | shape = tensor.shape 221 | flat_tensor = tensor.flatten() 222 | shape_str = '{} '.format(len(shape)) + ' '.join([str(d) for d in shape]) 223 | weights_dict[outname] = trt.Weights(flat_tensor) 224 | 225 | TRT_LOGGER.log(TRT_LOGGER.INFO, "Orig.name: {:}, TRT name: {:}, shape: {:}".format(pn, outname, shape_str)) 226 | 227 | additional_dict = dict() 228 | for key, value in weights_dict.items(): 229 | pos = key.find(BQ) 230 | if pos != -1: 231 | hidden_size = value.size 232 | prefix = key[:pos] 233 | 234 | Bq_ = value 235 | Bk_ = weights_dict[prefix + BK] 236 | Bv_ = weights_dict[prefix + BV] 237 | Wq_ = weights_dict[prefix + WQ] 238 | Wk_ = weights_dict[prefix + WK] 239 | Wv_ = weights_dict[prefix + WV] 240 | 241 | mat_size = hidden_size * hidden_size 242 | wcount = 3 * mat_size 243 | Wall = np.zeros(wcount, np.float32) 244 | bcount = 3 * hidden_size 245 | Ball = np.zeros(bcount, np.float32) 246 | Wall[0:mat_size] = Wq_.numpy()[0:mat_size] 247 | Wall[mat_size:2*mat_size] = Wk_.numpy()[0:mat_size] 248 | Wall[2*mat_size:3*mat_size] = Wv_.numpy()[0:mat_size] 249 | Ball[0:hidden_size] = Bq_.numpy()[0:hidden_size] 250 | Ball[hidden_size:2*hidden_size] = Bk_.numpy()[0:hidden_size] 251 | Ball[2*hidden_size:3*hidden_size] = Bv_.numpy()[0:hidden_size] 252 | 253 | additional_dict[prefix + WQKV] = trt.Weights(Wall) 254 | additional_dict[prefix + BQKV] = trt.Weights(Ball) 255 | 256 | except Exception as error: 257 | TRT_LOGGER.log(TRT_LOGGER.ERROR, str(error)) 258 | 259 | weights_dict.update(additional_dict) 260 | return weights_dict 261 | 262 | 263 | def main(inputbase, B, S, bert_path, outputbase): 264 | bert_config_path = os.path.join(bert_path, 'bert_config.json') 265 | TRT_LOGGER.log(TRT_LOGGER.INFO, bert_config_path) 266 | config = BertConfig(bert_config_path) 267 | 268 | # Load weights from checkpoint file 269 | init_dict = load_weights(inputbase) 270 | 271 | with trt.Builder(TRT_LOGGER) as builder: 272 | builder.max_batch_size = B 273 | builder.max_workspace_size = 5000 * (1024 * 1024) 274 | builder.fp16_mode = True 275 | builder.strict_type_constraints = False 276 | 277 | ty = trt.PluginFieldType.FLOAT32 278 | 279 | w = init_dict["bert_embeddings_layernorm_beta"] 280 | wbeta = trt.PluginField("bert_embeddings_layernorm_beta", w.numpy(), ty) 281 | 282 | w = init_dict["bert_embeddings_layernorm_gamma"] 283 | wgamma = trt.PluginField("bert_embeddings_layernorm_gamma", w.numpy(), ty) 284 | 285 | w = 
init_dict["bert_embeddings_word_embeddings"] 286 | wwordemb = trt.PluginField("bert_embeddings_word_embeddings", w.numpy(), ty) 287 | 288 | w = init_dict["bert_embeddings_token_type_embeddings"] 289 | wtokemb = trt.PluginField("bert_embeddings_token_type_embeddings", w.numpy(), ty) 290 | 291 | w = init_dict["bert_embeddings_position_embeddings"] 292 | wposemb = trt.PluginField("bert_embeddings_position_embeddings", w.numpy(), ty) 293 | 294 | pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb]) 295 | fn = emln_plg_creator.create_plugin("embeddings", pfc) 296 | 297 | with builder.create_network() as network: 298 | input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(S, )) 299 | segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(S, )) 300 | input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(S, )) 301 | inputs = [input_ids, segment_ids, input_mask] 302 | emb_layer = network.add_plugin_v2(inputs, fn) 303 | 304 | embeddings = emb_layer.get_output(0) 305 | mask_idx = emb_layer.get_output(1) 306 | 307 | bert_out = bert_model(config, init_dict, network, embeddings, mask_idx) 308 | 309 | # bert_output = bert_out.get_output(0) 310 | network.mark_output(bert_out) 311 | 312 | engine = builder.build_cuda_engine(network) 313 | 314 | TRT_LOGGER.log(TRT_LOGGER.INFO, "Serializing the engine....") 315 | serialized_engine = engine.serialize() 316 | TRT_LOGGER.log(TRT_LOGGER.INFO, "Saving the engine....") 317 | with open(outputbase, 'wb') as fout: 318 | fout.write(serialized_engine) 319 | TRT_LOGGER.log(TRT_LOGGER.INFO, "Done.") 320 | 321 | 322 | 323 | if __name__ == "__main__": 324 | model = '/workspace/models/fine-tuned/chinese_bert_base/bert_model.ckpt' 325 | output = './bert_base_en.engine' 326 | sequence = 128 327 | batchsize = 1 328 | config = '/workspace/models/fine-tuned/uncased_L-12_H-768_A-12' 329 | 330 | inputbase = model 331 | outputbase = output 332 | B = int(batchsize) 333 | S = int(sequence) 334 | bert_path = config 335 | 336 | main(inputbase, B, S, bert_path, outputbase) 337 | # Required to work around a double free issue in TRT 5.1 338 | os._exit(0) 339 | 340 | 341 | -------------------------------------------------------------------------------- /encode/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /encode/data_processing_new.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Modifications copyright (C) 2019 NVIDIA Corp. 
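# Illustrative use of convert_examples_to_features (defined below); the vocab path is a
# placeholder. For a single sentence text_b is None, and the result is an OrderedDict whose
# "input_ids", "input_mask" and "segment_ids" entries are int32 arrays padded to max_seq_length:
#
#     tokenizer = tokenization.FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
#     feats = convert_examples_to_features("早安", None, tokenizer, max_seq_length=128)
#     feats["input_ids"].shape   # (128,)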
17 | 18 | import tokenization 19 | import collections 20 | import numpy as np 21 | import six 22 | import math 23 | 24 | 25 | def convert_doc_tokens(paragraph_text): 26 | 27 | """ Return the list of tokens from the doc text """ 28 | def is_whitespace(c): 29 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 30 | return True 31 | return False 32 | 33 | doc_tokens = [] 34 | prev_is_whitespace = True 35 | for c in paragraph_text: 36 | if is_whitespace(c): 37 | prev_is_whitespace = True 38 | else: 39 | if prev_is_whitespace: 40 | doc_tokens.append(c) 41 | else: 42 | doc_tokens[-1] += c 43 | prev_is_whitespace = False 44 | 45 | return doc_tokens 46 | 47 | 48 | def _check_is_max_context(doc_spans, cur_span_index, position): 49 | """Check if this is the 'max context' doc span for the token.""" 50 | 51 | # Because of the sliding window approach taken to scoring documents, a single 52 | # token can appear in multiple documents. E.g. 53 | # Doc: the man went to the store and bought a gallon of milk 54 | # Span A: the man went to the 55 | # Span B: to the store and bought 56 | # Span C: and bought a gallon of 57 | # ... 58 | # 59 | # Now the word 'bought' will have two scores from spans B and C. We only 60 | # want to consider the score with "maximum context", which we define as 61 | # the *minimum* of its left and right context (the *sum* of left and 62 | # right context will always be the same, of course). 63 | # 64 | # In the example the maximum context for 'bought' would be span C since 65 | # it has 1 left context and 3 right context, while span B has 4 left context 66 | # and 0 right context. 67 | best_score = None 68 | best_span_index = None 69 | for (span_index, doc_span) in enumerate(doc_spans): 70 | end = doc_span.start + doc_span.length - 1 71 | if position < doc_span.start: 72 | continue 73 | if position > end: 74 | continue 75 | num_left_context = position - doc_span.start 76 | num_right_context = end - position 77 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 78 | if best_score is None or score > best_score: 79 | best_score = score 80 | best_span_index = span_index 81 | 82 | return cur_span_index == best_span_index 83 | 84 | 85 | def convert_examples_to_features(text_a, text_b, tokenizer, max_seq_length): 86 | 87 | tokens_a = tokenizer.tokenize(text_a) 88 | tokens_b = None 89 | if text_b: 90 | tokens_b = tokenizer.tokenize(text_b) 91 | 92 | if tokens_b: 93 | # Modifies `tokens_a` and `tokens_b` in place so that the total 94 | # length is less than the specified length. 95 | # Account for [CLS], [SEP], [SEP] with "- 3" 96 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 97 | else: 98 | # Account for [CLS] and [SEP] with "- 2" 99 | if len(tokens_a) > max_seq_length - 2: 100 | tokens_a = tokens_a[0:(max_seq_length - 2)] 101 | 102 | tokens = [] 103 | input_type_ids = [] 104 | 105 | tokens.append("[CLS]") 106 | input_type_ids.append(0) 107 | for token in tokens_a: 108 | tokens.append(token) 109 | input_type_ids.append(0) 110 | tokens.append("[SEP]") 111 | input_type_ids.append(0) 112 | 113 | if tokens_b: 114 | for token in tokens_b: 115 | tokens.append(token) 116 | input_type_ids.append(1) 117 | tokens.append("[SEP]") 118 | input_type_ids.append(1) 119 | 120 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 121 | 122 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 123 | # tokens are attended to. 124 | input_mask = [1] * len(input_ids) 125 | 126 | # Zero-pad up to the sequence length. 
127 | while len(input_ids) < max_seq_length: 128 | input_ids.append(0) 129 | input_mask.append(0) 130 | input_type_ids.append(0) 131 | 132 | assert len(input_ids) == max_seq_length 133 | assert len(input_mask) == max_seq_length 134 | assert len(input_type_ids) == max_seq_length 135 | 136 | def create_int_feature(values): 137 | feature = np.asarray(values, dtype=np.int32, order=None) 138 | return feature 139 | 140 | features = collections.OrderedDict() 141 | features["input_ids"] = create_int_feature(input_ids) 142 | features["input_mask"] = create_int_feature(input_mask) 143 | features["segment_ids"] = create_int_feature(input_type_ids) 144 | features["tokens"] = tokens 145 | features["token_to_orig_map"] = '' 146 | features["token_is_max_context"] = '' 147 | return features 148 | 149 | 150 | def _get_best_indexes(logits, n_best_size): 151 | """Get the n-best logits from a list.""" 152 | 153 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 154 | 155 | best_indexes = [] 156 | for i in range(len(index_and_score)): 157 | if i >= n_best_size: 158 | break 159 | best_indexes.append(index_and_score[i][0]) 160 | return best_indexes 161 | 162 | 163 | def get_final_text(pred_text, orig_text, do_lower_case): 164 | """Project the tokenized prediction back to the original text.""" 165 | 166 | # When we created the data, we kept track of the alignment between original 167 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 168 | # now `orig_text` contains the span of our original text corresponding to the 169 | # span that we predicted. 170 | # 171 | # However, `orig_text` may contain extra characters that we don't want in 172 | # our prediction. 173 | # 174 | # For example, let's say: 175 | # pred_text = steve smith 176 | # orig_text = Steve Smith's 177 | # 178 | # We don't want to return `orig_text` because it contains the extra "'s". 179 | # 180 | # We don't want to return `pred_text` because it's already been normalized 181 | # (the SQuAD eval script also does punctuation stripping/lower casing but 182 | # our tokenizer does additional normalization like stripping accent 183 | # characters). 184 | # 185 | # What we really want to return is "Steve Smith". 186 | # 187 | # Therefore, we have to apply a semi-complicated alignment heruistic between 188 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 189 | # can fail in certain cases in which case we just return `orig_text`. 190 | 191 | def _strip_spaces(text): 192 | ns_chars = [] 193 | ns_to_s_map = collections.OrderedDict() 194 | for (i, c) in enumerate(text): 195 | if c == " ": 196 | continue 197 | ns_to_s_map[len(ns_chars)] = i 198 | ns_chars.append(c) 199 | ns_text = "".join(ns_chars) 200 | return (ns_text, ns_to_s_map) 201 | 202 | # We first tokenize `orig_text`, strip whitespace from the result 203 | # and `pred_text`, and check if they are the same length. If they are 204 | # NOT the same length, the heuristic has failed. If they are the same 205 | # length, we assume the characters are one-to-one aligned. 
206 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 207 | 208 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 209 | 210 | start_position = tok_text.find(pred_text) 211 | if start_position == -1: 212 | return orig_text 213 | end_position = start_position + len(pred_text) - 1 214 | 215 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 216 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 217 | 218 | if len(orig_ns_text) != len(tok_ns_text): 219 | return orig_text 220 | 221 | # We then project the characters in `pred_text` back to `orig_text` using 222 | # the character-to-character alignment. 223 | tok_s_to_ns_map = {} 224 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 225 | tok_s_to_ns_map[tok_index] = i 226 | 227 | orig_start_position = None 228 | if start_position in tok_s_to_ns_map: 229 | ns_start_position = tok_s_to_ns_map[start_position] 230 | if ns_start_position in orig_ns_to_s_map: 231 | orig_start_position = orig_ns_to_s_map[ns_start_position] 232 | 233 | if orig_start_position is None: 234 | return orig_text 235 | 236 | orig_end_position = None 237 | if end_position in tok_s_to_ns_map: 238 | ns_end_position = tok_s_to_ns_map[end_position] 239 | if ns_end_position in orig_ns_to_s_map: 240 | orig_end_position = orig_ns_to_s_map[ns_end_position] 241 | 242 | if orig_end_position is None: 243 | return orig_text 244 | 245 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 246 | return output_text 247 | 248 | 249 | def _compute_softmax(scores): 250 | """Compute softmax probability over raw logits.""" 251 | if not scores: 252 | return [] 253 | 254 | max_score = None 255 | for score in scores: 256 | if max_score is None or score > max_score: 257 | max_score = score 258 | 259 | exp_scores = [] 260 | total_sum = 0.0 261 | for score in scores: 262 | x = math.exp(score - max_score) 263 | exp_scores.append(x) 264 | total_sum += x 265 | 266 | probs = [] 267 | for score in exp_scores: 268 | probs.append(score / total_sum) 269 | return probs 270 | 271 | 272 | def get_predictions(doc_tokens, features, start_logits, end_logits, n_best_size, max_answer_length): 273 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 274 | "PrelimPrediction", 275 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 276 | 277 | prediction = "" 278 | scores_diff_json = 0.0 279 | 280 | prelim_predictions = [] 281 | # keep track of the minimum score of null start+end of position 0 282 | score_null = 1000000 # large and positive 283 | min_null_feature_index = 0 # the paragraph slice with min mull score 284 | null_start_logit = 0 # the start logit at the slice with min null score 285 | null_end_logit = 0 # the end logit at the slice with min null score 286 | 287 | start_indexes = _get_best_indexes(start_logits, n_best_size) 288 | end_indexes = _get_best_indexes(end_logits, n_best_size) 289 | 290 | # if we could have irrelevant answers, get the min score of irrelevant 291 | version_2_with_negative = True 292 | if version_2_with_negative: 293 | feature_null_score = start_logits[0] + end_logits[0] 294 | if feature_null_score < score_null: 295 | score_null = feature_null_score 296 | min_null_feature_index = 0 297 | null_start_logit = start_logits[0] 298 | null_end_logit = end_logits[0] 299 | 300 | for start_index in start_indexes: 301 | for end_index in end_indexes: 302 | # We could hypothetically create invalid predictions, e.g., predict 303 | # that the start of the span is in the question. 
We throw out all 304 | # invalid predictions. 305 | if start_index >= len(features['tokens']): 306 | continue 307 | if end_index >= len(features['tokens']): 308 | continue 309 | if start_index not in features['token_to_orig_map']: 310 | continue 311 | if end_index not in features['token_to_orig_map']: 312 | continue 313 | if not features['token_is_max_context'].get(start_index, False): 314 | continue 315 | if end_index < start_index: 316 | continue 317 | length = end_index - start_index + 1 318 | if length > max_answer_length: 319 | continue 320 | prelim_predictions.append( 321 | _PrelimPrediction( 322 | feature_index=0, 323 | start_index=start_index, 324 | end_index=end_index, 325 | start_logit=start_logits[start_index], 326 | end_logit=end_logits[end_index])) 327 | 328 | if version_2_with_negative: 329 | prelim_predictions.append( 330 | _PrelimPrediction( 331 | feature_index=min_null_feature_index, 332 | start_index=0, 333 | end_index=0, 334 | start_logit=null_start_logit, 335 | end_logit=null_end_logit)) 336 | 337 | prelim_predictions = sorted( 338 | prelim_predictions, 339 | key=lambda x: (x.start_logit + x.end_logit), 340 | reverse=True) 341 | 342 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 343 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 344 | 345 | seen_predictions = {} 346 | nbest = [] 347 | for pred in prelim_predictions: 348 | if len(nbest) >= n_best_size: 349 | break 350 | 351 | if pred.start_index > 0: # this is a non-null prediction 352 | tok_tokens = features['tokens'][pred.start_index:(pred.end_index + 1)] 353 | orig_doc_start = features['token_to_orig_map'][pred.start_index] 354 | orig_doc_end = features['token_to_orig_map'][pred.end_index] 355 | orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)] 356 | tok_text = " ".join(tok_tokens) 357 | 358 | # De-tokenize WordPieces that have been split off. 359 | tok_text = tok_text.replace(" ##", "") 360 | tok_text = tok_text.replace("##", "") 361 | 362 | # Clean whitespace 363 | tok_text = tok_text.strip() 364 | tok_text = " ".join(tok_text.split()) 365 | orig_text = " ".join(orig_tokens) 366 | 367 | final_text = get_final_text(tok_text, orig_text, True) 368 | if final_text in seen_predictions: 369 | continue 370 | 371 | seen_predictions[final_text] = True 372 | else: 373 | final_text = "" 374 | seen_predictions[final_text] = True 375 | 376 | nbest.append( 377 | _NbestPrediction( 378 | text=final_text, 379 | start_logit=pred.start_logit, 380 | end_logit=pred.end_logit)) 381 | 382 | # if we didn't inlude the empty option in the n-best, inlcude it 383 | if version_2_with_negative: 384 | if "" not in seen_predictions: 385 | nbest.append( 386 | _NbestPrediction( 387 | text="", start_logit=null_start_logit, 388 | end_logit=null_end_logit)) 389 | # In very rare edge cases we could have no valid predictions. So we 390 | # just create a nonce prediction in this case to avoid failure. 
391 | if not nbest: 392 | nbest.append( 393 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 394 | 395 | assert len(nbest) >= 1 396 | 397 | total_scores = [] 398 | best_non_null_entry = None 399 | for entry in nbest: 400 | total_scores.append(entry.start_logit + entry.end_logit) 401 | if not best_non_null_entry: 402 | if entry.text: 403 | best_non_null_entry = entry 404 | 405 | probs = _compute_softmax(total_scores) 406 | 407 | nbest_json = [] 408 | for (i, entry) in enumerate(nbest): 409 | output = collections.OrderedDict() 410 | output["text"] = entry.text 411 | output["probability"] = probs[i] 412 | output["start_logit"] = entry.start_logit 413 | output["end_logit"] = entry.end_logit 414 | nbest_json.append(output) 415 | 416 | assert len(nbest_json) >= 1 417 | 418 | null_score_diff_threshold = 0.0 419 | if not version_2_with_negative: 420 | prediction = nbest_json[0]["text"] 421 | else: 422 | # predict "" iff the null score - the score of best non-null > threshold 423 | score_diff = score_null - best_non_null_entry.start_logit - ( 424 | best_non_null_entry.end_logit) 425 | scores_diff_json = score_diff 426 | if score_diff > null_score_diff_threshold: 427 | prediction = "" 428 | else: 429 | prediction = best_non_null_entry.text 430 | 431 | return prediction, nbest_json, scores_diff_json 432 | --------------------------------------------------------------------------------