├── README.md
├── __init__.py
├── create_pretraining_data.py
├── extract_features.py
├── modeling.py
├── optimization.py
├── run_classifier_elmo.py
├── run_pretraining.py
├── run_squad.py
├── run_squad_elmo.py
└── tokenization.py
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Dynamic Contextualized Word Representation (DDCWR)
2 | TensorFlow code and pre-trained models for DDCWR
3 |
4 | # Key points
5 | 1. The method is simple: it uses only a feed-forward neural network with an attention mechanism.
6 | 2. Training is fast: only a few epochs are needed, because the parameters are initialized from Google's pre-trained BERT model.
7 | 3. The results are strong: in most cases they match the best models as of 2018-11-13, and are sometimes better. The current best results can be seen on the [GLUE leaderboard](https://gluebenchmark.com/leaderboard).
8 |
9 | # Idea
10 |
11 | The model, Deep_dynamic_word_representation (DDWR), combines the BERT model with ELMo's deep contextualized word representations.
12 |
13 | BERT comes from [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
14 | ELMo comes from [Deep contextualized word representations](https://arxiv.org/abs/1802.05365v2).
15 |
16 | # Basic usage
17 |
18 | ## Download the pre-trained model
19 |
20 | [BERT-Base, Uncased](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)
21 |
22 | ## Download [GLUE data](https://gluebenchmark.com/tasks)
23 |
24 | Download it using this [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e).
25 |
26 | ## Sentence (and sentence-pair) classification tasks
27 |
28 | The only difference from [google-research/bert](https://github.com/google-research/bert) is that `run_classifier_elmo.py` is used instead of `run_classifier.py`:
29 | ```
30 | export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
31 | export GLUE_DIR=/path/to/glue
32 |
33 | python run_classifier_elmo.py \
34 |   --task_name=MRPC \
35 |   --do_train=true \
36 |   --do_eval=true \
37 |   --data_dir=$GLUE_DIR/MRPC \
38 |   --vocab_file=$BERT_BASE_DIR/vocab.txt \
39 |   --bert_config_file=$BERT_BASE_DIR/bert_config.json \
40 |   --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
41 |   --max_seq_length=128 \
42 |   --train_batch_size=32 \
43 |   --learning_rate=2e-5 \
44 |   --num_train_epochs=3.0 \
45 |   --output_dir=/tmp/mrpc_output/
46 | ```
47 |
48 | ### Prediction from classifier
49 | > The same as https://github.com/google-research/bert
50 |
51 | ```
52 | export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
53 | export GLUE_DIR=/path/to/glue
54 | export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier
55 |
56 | python run_classifier_elmo.py \
57 |   --task_name=MRPC \
58 |   --do_predict=true \
59 |   --data_dir=$GLUE_DIR/MRPC \
60 |   --vocab_file=$BERT_BASE_DIR/vocab.txt \
61 |   --bert_config_file=$BERT_BASE_DIR/bert_config.json \
62 |   --init_checkpoint=$TRAINED_CLASSIFIER \
63 |   --max_seq_length=128 \
64 |   --output_dir=/tmp/mrpc_output/
65 | ```
66 | For more options, see [google-research/bert](https://github.com/google-research/bert).
67 |
68 |
69 | ## [SQuAD 1.1](https://rajpurkar.github.io/SQuAD-explorer/)
70 |
71 | > The same as https://github.com/google-research/bert
72 |
73 | The only difference is that `run_squad_elmo.py` is used instead of `run_squad.py`:
74 | ```
75 | python run_squad_elmo.py \
76 |   --vocab_file=$BERT_BASE_DIR/vocab.txt \
77 |   --bert_config_file=$BERT_BASE_DIR/bert_config.json \
78 |   --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
79 |   --do_train=True \
80 |   --train_file=$SQUAD_DIR/train-v1.1.json \
81 |   --do_predict=True \
82 |   --predict_file=$SQUAD_DIR/dev-v1.1.json \
83 |   --train_batch_size=12 \
84 |   --learning_rate=3e-5 \
85 |   --num_train_epochs=2.0 \
86 |   --max_seq_length=384 \
87 |   --doc_stride=128 \
88 |   --output_dir=./tmp/elmo_squad_base/
89 | ```
90 |
91 | ## Experimental results
92 |
93 | ```
94 | python run_squad_elmo.py
95 | {"exact_match": 81.20151371807, "f1": 88.56178500169332}
96 | ```
97 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/create_pretraining_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM/next sentence masked_lm TF examples for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import random
23 |
24 | import tokenization
25 | import tensorflow as tf
26 |
27 | flags = tf.flags
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_string("input_file", None,
32 |                     "Input raw text file (or comma-separated list of files).")
33 |
34 | flags.DEFINE_string(
35 |     "output_file", None,
36 |     "Output TF example file (or comma-separated list of files).")
37 |
38 | flags.DEFINE_string("vocab_file", None,
39 |                     "The vocabulary file that the BERT model was trained on.")
40 |
41 | flags.DEFINE_bool(
42 |     "do_lower_case", True,
43 |     "Whether to lower case the input text. 
Should be True for uncased " 44 | "models and False for cased models.") 45 | 46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 47 | 48 | flags.DEFINE_integer("max_predictions_per_seq", 20, 49 | "Maximum number of masked LM predictions per sequence.") 50 | 51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 52 | 53 | flags.DEFINE_integer( 54 | "dupe_factor", 10, 55 | "Number of times to duplicate the input data (with different masks).") 56 | 57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 58 | 59 | flags.DEFINE_float( 60 | "short_seq_prob", 0.1, 61 | "Probability of creating sequences which are shorter than the " 62 | "maximum length.") 63 | 64 | 65 | class TrainingInstance(object): 66 | """A single training instance (sentence pair).""" 67 | 68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 69 | is_random_next): 70 | self.tokens = tokens 71 | self.segment_ids = segment_ids 72 | self.is_random_next = is_random_next 73 | self.masked_lm_positions = masked_lm_positions 74 | self.masked_lm_labels = masked_lm_labels 75 | 76 | def __str__(self): 77 | s = "" 78 | s += "tokens: %s\n" % (" ".join( 79 | [tokenization.printable_text(x) for x in self.tokens])) 80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 81 | s += "is_random_next: %s\n" % self.is_random_next 82 | s += "masked_lm_positions: %s\n" % (" ".join( 83 | [str(x) for x in self.masked_lm_positions])) 84 | s += "masked_lm_labels: %s\n" % (" ".join( 85 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 86 | s += "\n" 87 | return s 88 | 89 | def __repr__(self): 90 | return self.__str__() 91 | 92 | 93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 94 | max_predictions_per_seq, output_files): 95 | """Create TF example files from `TrainingInstance`s.""" 96 | writers = [] 97 | for output_file in output_files: 98 | writers.append(tf.python_io.TFRecordWriter(output_file)) 99 | 100 | writer_index = 0 101 | 102 | total_written = 0 103 | for (inst_index, instance) in enumerate(instances): 104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 105 | input_mask = [1] * len(input_ids) 106 | segment_ids = list(instance.segment_ids) 107 | assert len(input_ids) <= max_seq_length 108 | 109 | while len(input_ids) < max_seq_length: 110 | input_ids.append(0) 111 | input_mask.append(0) 112 | segment_ids.append(0) 113 | 114 | assert len(input_ids) == max_seq_length 115 | assert len(input_mask) == max_seq_length 116 | assert len(segment_ids) == max_seq_length 117 | 118 | masked_lm_positions = list(instance.masked_lm_positions) 119 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 120 | masked_lm_weights = [1.0] * len(masked_lm_ids) 121 | 122 | while len(masked_lm_positions) < max_predictions_per_seq: 123 | masked_lm_positions.append(0) 124 | masked_lm_ids.append(0) 125 | masked_lm_weights.append(0.0) 126 | 127 | next_sentence_label = 1 if instance.is_random_next else 0 128 | 129 | features = collections.OrderedDict() 130 | features["input_ids"] = create_int_feature(input_ids) 131 | features["input_mask"] = create_int_feature(input_mask) 132 | features["segment_ids"] = create_int_feature(segment_ids) 133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 136 | 
features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 137 | 138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 139 | 140 | writers[writer_index].write(tf_example.SerializeToString()) 141 | writer_index = (writer_index + 1) % len(writers) 142 | 143 | total_written += 1 144 | 145 | if inst_index < 20: 146 | tf.logging.info("*** Example ***") 147 | tf.logging.info("tokens: %s" % " ".join( 148 | [tokenization.printable_text(x) for x in instance.tokens])) 149 | 150 | for feature_name in features.keys(): 151 | feature = features[feature_name] 152 | values = [] 153 | if feature.int64_list.value: 154 | values = feature.int64_list.value 155 | elif feature.float_list.value: 156 | values = feature.float_list.value 157 | tf.logging.info( 158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 159 | 160 | for writer in writers: 161 | writer.close() 162 | 163 | tf.logging.info("Wrote %d total instances", total_written) 164 | 165 | 166 | def create_int_feature(values): 167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 168 | return feature 169 | 170 | 171 | def create_float_feature(values): 172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 173 | return feature 174 | 175 | 176 | def create_training_instances(input_files, tokenizer, max_seq_length, 177 | dupe_factor, short_seq_prob, masked_lm_prob, 178 | max_predictions_per_seq, rng): 179 | """Create `TrainingInstance`s from raw text.""" 180 | all_documents = [[]] 181 | 182 | # Input file format: 183 | # (1) One sentence per line. These should ideally be actual sentences, not 184 | # entire paragraphs or arbitrary spans of text. (Because we use the 185 | # sentence boundaries for the "next sentence prediction" task). 186 | # (2) Blank lines between documents. Document boundaries are needed so 187 | # that the "next sentence prediction" task doesn't span between documents. 188 | for input_file in input_files: 189 | with tf.gfile.GFile(input_file, "r") as reader: 190 | while True: 191 | line = tokenization.convert_to_unicode(reader.readline()) 192 | if not line: 193 | break 194 | line = line.strip() 195 | 196 | # Empty lines are used as document delimiters 197 | if not line: 198 | all_documents.append([]) 199 | tokens = tokenizer.tokenize(line) 200 | if tokens: 201 | all_documents[-1].append(tokens) 202 | 203 | # Remove empty documents 204 | all_documents = [x for x in all_documents if x] 205 | rng.shuffle(all_documents) 206 | 207 | vocab_words = list(tokenizer.vocab.keys()) 208 | instances = [] 209 | for _ in range(dupe_factor): 210 | for document_index in range(len(all_documents)): 211 | instances.extend( 212 | create_instances_from_document( 213 | all_documents, document_index, max_seq_length, short_seq_prob, 214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 215 | 216 | rng.shuffle(instances) 217 | return instances 218 | 219 | 220 | def create_instances_from_document( 221 | all_documents, document_index, max_seq_length, short_seq_prob, 222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 223 | """Creates `TrainingInstance`s for a single document.""" 224 | document = all_documents[document_index] 225 | 226 | # Account for [CLS], [SEP], [SEP] 227 | max_num_tokens = max_seq_length - 3 228 | 229 | # We *usually* want to fill up the entire sequence since we are padding 230 | # to `max_seq_length` anyways, so short sequences are generally wasted 231 | # computation. 
However, we *sometimes* 232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 233 | # sequences to minimize the mismatch between pre-training and fine-tuning. 234 | # The `target_seq_length` is just a rough target however, whereas 235 | # `max_seq_length` is a hard limit. 236 | target_seq_length = max_num_tokens 237 | if rng.random() < short_seq_prob: 238 | target_seq_length = rng.randint(2, max_num_tokens) 239 | 240 | # We DON'T just concatenate all of the tokens from a document into a long 241 | # sequence and choose an arbitrary split point because this would make the 242 | # next sentence prediction task too easy. Instead, we split the input into 243 | # segments "A" and "B" based on the actual "sentences" provided by the user 244 | # input. 245 | instances = [] 246 | current_chunk = [] 247 | current_length = 0 248 | i = 0 249 | while i < len(document): 250 | segment = document[i] 251 | current_chunk.append(segment) 252 | current_length += len(segment) 253 | if i == len(document) - 1 or current_length >= target_seq_length: 254 | if current_chunk: 255 | # `a_end` is how many segments from `current_chunk` go into the `A` 256 | # (first) sentence. 257 | a_end = 1 258 | if len(current_chunk) >= 2: 259 | a_end = rng.randint(1, len(current_chunk) - 1) 260 | 261 | tokens_a = [] 262 | for j in range(a_end): 263 | tokens_a.extend(current_chunk[j]) 264 | 265 | tokens_b = [] 266 | # Random next 267 | is_random_next = False 268 | if len(current_chunk) == 1 or rng.random() < 0.5: 269 | is_random_next = True 270 | target_b_length = target_seq_length - len(tokens_a) 271 | 272 | # This should rarely go for more than one iteration for large 273 | # corpora. However, just to be careful, we try to make sure that 274 | # the random document is not the same as the document 275 | # we're processing. 276 | for _ in range(10): 277 | random_document_index = rng.randint(0, len(all_documents) - 1) 278 | if random_document_index != document_index: 279 | break 280 | 281 | random_document = all_documents[random_document_index] 282 | random_start = rng.randint(0, len(random_document) - 1) 283 | for j in range(random_start, len(random_document)): 284 | tokens_b.extend(random_document[j]) 285 | if len(tokens_b) >= target_b_length: 286 | break 287 | # We didn't actually use these segments so we "put them back" so 288 | # they don't go to waste. 
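          # For example, if `current_chunk` held 5 segments and `a_end` was 2,
          # the last 3 segments were never copied into `tokens_b`, so `i` is
          # rewound by 3 below and those segments are re-read on a later pass
          # of the while loop.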
289 | num_unused_segments = len(current_chunk) - a_end 290 | i -= num_unused_segments 291 | # Actual next 292 | else: 293 | is_random_next = False 294 | for j in range(a_end, len(current_chunk)): 295 | tokens_b.extend(current_chunk[j]) 296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 297 | 298 | assert len(tokens_a) >= 1 299 | assert len(tokens_b) >= 1 300 | 301 | tokens = [] 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in tokens_a: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | 309 | tokens.append("[SEP]") 310 | segment_ids.append(0) 311 | 312 | for token in tokens_b: 313 | tokens.append(token) 314 | segment_ids.append(1) 315 | tokens.append("[SEP]") 316 | segment_ids.append(1) 317 | 318 | (tokens, masked_lm_positions, 319 | masked_lm_labels) = create_masked_lm_predictions( 320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 321 | instance = TrainingInstance( 322 | tokens=tokens, 323 | segment_ids=segment_ids, 324 | is_random_next=is_random_next, 325 | masked_lm_positions=masked_lm_positions, 326 | masked_lm_labels=masked_lm_labels) 327 | instances.append(instance) 328 | current_chunk = [] 329 | current_length = 0 330 | i += 1 331 | 332 | return instances 333 | 334 | 335 | def create_masked_lm_predictions(tokens, masked_lm_prob, 336 | max_predictions_per_seq, vocab_words, rng): 337 | """Creates the predictions for the masked LM objective.""" 338 | 339 | cand_indexes = [] 340 | for (i, token) in enumerate(tokens): 341 | if token == "[CLS]" or token == "[SEP]": 342 | continue 343 | cand_indexes.append(i) 344 | 345 | rng.shuffle(cand_indexes) 346 | 347 | output_tokens = list(tokens) 348 | 349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name 350 | 351 | num_to_predict = min(max_predictions_per_seq, 352 | max(1, int(round(len(tokens) * masked_lm_prob)))) 353 | 354 | masked_lms = [] 355 | covered_indexes = set() 356 | for index in cand_indexes: 357 | if len(masked_lms) >= num_to_predict: 358 | break 359 | if index in covered_indexes: 360 | continue 361 | covered_indexes.add(index) 362 | 363 | masked_token = None 364 | # 80% of the time, replace with [MASK] 365 | if rng.random() < 0.8: 366 | masked_token = "[MASK]" 367 | else: 368 | # 10% of the time, keep original 369 | if rng.random() < 0.5: 370 | masked_token = tokens[index] 371 | # 10% of the time, replace with random word 372 | else: 373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 374 | 375 | output_tokens[index] = masked_token 376 | 377 | masked_lms.append(masked_lm(index=index, label=tokens[index])) 378 | 379 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 380 | 381 | masked_lm_positions = [] 382 | masked_lm_labels = [] 383 | for p in masked_lms: 384 | masked_lm_positions.append(p.index) 385 | masked_lm_labels.append(p.label) 386 | 387 | return (output_tokens, masked_lm_positions, masked_lm_labels) 388 | 389 | 390 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 391 | """Truncates a pair of sequences to a maximum sequence length.""" 392 | while True: 393 | total_length = len(tokens_a) + len(tokens_b) 394 | if total_length <= max_num_tokens: 395 | break 396 | 397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 398 | assert len(trunc_tokens) >= 1 399 | 400 | # We want to sometimes truncate from the front and sometimes from the 401 | # back to add more randomness and avoid biases. 
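    # `del trunc_tokens[0]` drops a token from the front of the longer
    # sequence and `trunc_tokens.pop()` drops one from the back, so each
    # removal is a fair coin flip between the two ends.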
402 | if rng.random() < 0.5: 403 | del trunc_tokens[0] 404 | else: 405 | trunc_tokens.pop() 406 | 407 | 408 | def main(_): 409 | tf.logging.set_verbosity(tf.logging.INFO) 410 | 411 | tokenizer = tokenization.FullTokenizer( 412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 413 | 414 | input_files = [] 415 | for input_pattern in FLAGS.input_file.split(","): 416 | input_files.extend(tf.gfile.Glob(input_pattern)) 417 | 418 | tf.logging.info("*** Reading from input files ***") 419 | for input_file in input_files: 420 | tf.logging.info(" %s", input_file) 421 | 422 | rng = random.Random(FLAGS.random_seed) 423 | instances = create_training_instances( 424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 426 | rng) 427 | 428 | output_files = FLAGS.output_file.split(",") 429 | tf.logging.info("*** Writing to output files ***") 430 | for output_file in output_files: 431 | tf.logging.info(" %s", output_file) 432 | 433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 434 | FLAGS.max_predictions_per_seq, output_files) 435 | 436 | 437 | if __name__ == "__main__": 438 | flags.mark_flag_as_required("input_file") 439 | flags.mark_flag_as_required("output_file") 440 | flags.mark_flag_as_required("vocab_file") 441 | tf.app.run() 442 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import modeling 27 | import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. 
" 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 
123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, 174 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 175 | tvars, init_checkpoint) 176 | if use_tpu: 177 | 178 | def tpu_scaffold(): 179 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 180 | return tf.train.Scaffold() 181 | 182 | scaffold_fn = tpu_scaffold 183 | else: 184 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 185 | 186 | tf.logging.info("**** Trainable Variables ****") 187 | for var in tvars: 188 | init_string = "" 189 | if var.name in initialized_variable_names: 190 | init_string = ", *INIT_FROM_CKPT*" 191 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 192 | init_string) 193 | 194 | all_layers = model.get_all_encoder_layers() 195 | 196 | predictions = { 197 | "unique_id": unique_ids, 198 | } 199 | 200 | for (i, layer_index) in enumerate(layer_indexes): 201 | predictions["layer_output_%d" % i] = all_layers[layer_index] 202 | 203 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 204 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 205 | return output_spec 206 | 207 | return model_fn 208 | 209 | 210 | def convert_examples_to_features(examples, seq_length, tokenizer): 211 | """Loads a data file into a list of `InputBatch`s.""" 212 | 213 | features = [] 214 | for (ex_index, example) in enumerate(examples): 215 | tokens_a = tokenizer.tokenize(example.text_a) 216 | 217 | tokens_b = None 218 | if example.text_b: 219 | tokens_b = tokenizer.tokenize(example.text_b) 220 | 221 | if tokens_b: 222 | # Modifies `tokens_a` and `tokens_b` in place so that the total 223 | # length is less than the specified length. 
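      # For example, with seq_length=128 at most 125 wordpiece tokens remain
      # for the two texts combined, since [CLS], [SEP] and the final [SEP]
      # are appended afterwards.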
224 | # Account for [CLS], [SEP], [SEP] with "- 3" 225 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 226 | else: 227 | # Account for [CLS] and [SEP] with "- 2" 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | if tokens_b: 260 | for token in tokens_b: 261 | tokens.append(token) 262 | input_type_ids.append(1) 263 | tokens.append("[SEP]") 264 | input_type_ids.append(1) 265 | 266 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 267 | 268 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 269 | # tokens are attended to. 270 | input_mask = [1] * len(input_ids) 271 | 272 | # Zero-pad up to the sequence length. 273 | while len(input_ids) < seq_length: 274 | input_ids.append(0) 275 | input_mask.append(0) 276 | input_type_ids.append(0) 277 | 278 | assert len(input_ids) == seq_length 279 | assert len(input_mask) == seq_length 280 | assert len(input_type_ids) == seq_length 281 | 282 | if ex_index < 5: 283 | tf.logging.info("*** Example ***") 284 | tf.logging.info("unique_id: %s" % (example.unique_id)) 285 | tf.logging.info("tokens: %s" % " ".join( 286 | [tokenization.printable_text(x) for x in tokens])) 287 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 288 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 289 | tf.logging.info( 290 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 291 | 292 | features.append( 293 | InputFeatures( 294 | unique_id=example.unique_id, 295 | tokens=tokens, 296 | input_ids=input_ids, 297 | input_mask=input_mask, 298 | input_type_ids=input_type_ids)) 299 | return features 300 | 301 | 302 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 303 | """Truncates a sequence pair in place to the maximum length.""" 304 | 305 | # This is a simple heuristic which will always truncate the longer sequence 306 | # one token at a time. This makes more sense than truncating an equal percent 307 | # of tokens from each, since if one sequence is very short then each token 308 | # that's truncated likely contains more information than a longer sequence. 
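  # Worked example: with max_length=6, tokens_a of length 5 and tokens_b of
  # length 4, the loop below pops one token per iteration from whichever list
  # is currently longer, finishing with lengths 3 and 3.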
309 | while True: 310 | total_length = len(tokens_a) + len(tokens_b) 311 | if total_length <= max_length: 312 | break 313 | if len(tokens_a) > len(tokens_b): 314 | tokens_a.pop() 315 | else: 316 | tokens_b.pop() 317 | 318 | 319 | def read_examples(input_file): 320 | """Read a list of `InputExample`s from an input file.""" 321 | examples = [] 322 | unique_id = 0 323 | with tf.gfile.GFile(input_file, "r") as reader: 324 | while True: 325 | line = tokenization.convert_to_unicode(reader.readline()) 326 | if not line: 327 | break 328 | line = line.strip() 329 | text_a = None 330 | text_b = None 331 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 332 | if m is None: 333 | text_a = line 334 | else: 335 | text_a = m.group(1) 336 | text_b = m.group(2) 337 | examples.append( 338 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 339 | unique_id += 1 340 | return examples 341 | 342 | 343 | def main(_): 344 | tf.logging.set_verbosity(tf.logging.INFO) 345 | 346 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 347 | 348 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 349 | 350 | tokenizer = tokenization.FullTokenizer( 351 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 354 | run_config = tf.contrib.tpu.RunConfig( 355 | master=FLAGS.master, 356 | tpu_config=tf.contrib.tpu.TPUConfig( 357 | num_shards=FLAGS.num_tpu_cores, 358 | per_host_input_for_training=is_per_host)) 359 | 360 | examples = read_examples(FLAGS.input_file) 361 | 362 | features = convert_examples_to_features( 363 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 364 | 365 | unique_id_to_feature = {} 366 | for feature in features: 367 | unique_id_to_feature[feature.unique_id] = feature 368 | 369 | model_fn = model_fn_builder( 370 | bert_config=bert_config, 371 | init_checkpoint=FLAGS.init_checkpoint, 372 | layer_indexes=layer_indexes, 373 | use_tpu=FLAGS.use_tpu, 374 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 375 | 376 | # If TPU is not available, this will fall back to normal Estimator on CPU 377 | # or GPU. 
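  # With use_tpu=False, TPUEstimator falls back to ordinary Estimator
  # behaviour, so the TPU settings in `run_config` are effectively unused and
  # prediction runs on the local CPU/GPU.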
378 | estimator = tf.contrib.tpu.TPUEstimator( 379 | use_tpu=FLAGS.use_tpu, 380 | model_fn=model_fn, 381 | config=run_config, 382 | predict_batch_size=FLAGS.batch_size) 383 | 384 | input_fn = input_fn_builder( 385 | features=features, seq_length=FLAGS.max_seq_length) 386 | 387 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 388 | "w")) as writer: 389 | for result in estimator.predict(input_fn, yield_single_examples=True): 390 | unique_id = int(result["unique_id"]) 391 | feature = unique_id_to_feature[unique_id] 392 | output_json = collections.OrderedDict() 393 | output_json["linex_index"] = unique_id 394 | all_features = [] 395 | for (i, token) in enumerate(feature.tokens): 396 | all_layers = [] 397 | for (j, layer_index) in enumerate(layer_indexes): 398 | layer_output = result["layer_output_%d" % j] 399 | layers = collections.OrderedDict() 400 | layers["index"] = layer_index 401 | layers["values"] = [ 402 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 403 | ] 404 | all_layers.append(layers) 405 | features = collections.OrderedDict() 406 | features["token"] = token 407 | features["layers"] = all_layers 408 | all_features.append(features) 409 | output_json["features"] = all_features 410 | writer.write(json.dumps(output_json) + "\n") 411 | 412 | 413 | if __name__ == "__main__": 414 | flags.mark_flag_as_required("input_file") 415 | flags.mark_flag_as_required("vocab_file") 416 | flags.mark_flag_as_required("bert_config_file") 417 | flags.mark_flag_as_required("init_checkpoint") 418 | flags.mark_flag_as_required("output_file") 419 | tf.app.run() 420 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 
51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 108 | 109 | Example usage: 110 | 111 | ```python 112 | # Already been converted into WordPiece token ids 113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 116 | 117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 119 | 120 | model = modeling.BertModel(config=config, is_training=True, 121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 122 | 123 | label_embeddings = tf.get_variable(...) 124 | pooled_output = model.get_pooled_output() 125 | logits = tf.matmul(pooled_output, label_embeddings) 126 | ... 127 | ``` 128 | """ 129 | 130 | def __init__(self, 131 | config, 132 | is_training, 133 | input_ids, 134 | input_mask=None, 135 | token_type_ids=None, 136 | use_one_hot_embeddings=True, 137 | scope=None): 138 | """Constructor for BertModel. 139 | 140 | Args: 141 | config: `BertConfig` instance. 142 | is_training: bool. 
rue for training model, false for eval model. Controls 143 | whether dropout will be applied. 144 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 149 | it is must faster if this is True, on the CPU or GPU, it is faster if 150 | this is False. 151 | scope: (optional) variable scope. Defaults to "bert". 152 | 153 | Raises: 154 | ValueError: The config is invalid or one of the input tensor shapes 155 | is invalid. 156 | """ 157 | config = copy.deepcopy(config) 158 | if not is_training: 159 | config.hidden_dropout_prob = 0.0 160 | config.attention_probs_dropout_prob = 0.0 161 | 162 | input_shape = get_shape_list(input_ids, expected_rank=2) 163 | batch_size = input_shape[0] 164 | seq_length = input_shape[1] 165 | 166 | if input_mask is None: 167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 168 | 169 | if token_type_ids is None: 170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 171 | 172 | with tf.variable_scope(scope, default_name="bert"): 173 | with tf.variable_scope("embeddings"): 174 | # Perform embedding lookup on the word ids. 175 | (self.embedding_output, self.embedding_table) = embedding_lookup( 176 | input_ids=input_ids, 177 | vocab_size=config.vocab_size, 178 | embedding_size=config.hidden_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob) 196 | 197 | with tf.variable_scope("encoder"): 198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 199 | # mask of shape [batch_size, seq_length, seq_length] which is used 200 | # for the attention scores. 201 | attention_mask = create_attention_mask_from_input_mask( 202 | input_ids, input_mask) 203 | 204 | # Run the stacked transformer. 205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 
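      # With `do_return_all_layers=True`, `transformer_model` returns a list
      # of `num_hidden_layers` tensors, one per encoder block, each of shape
      # [batch_size, seq_length, hidden_size].  These per-layer outputs are
      # the natural hook for the ELMo-style weighted combination described in
      # the README, e.g. (illustrative sketch only, not part of this file):
      #   w = tf.nn.softmax(tf.get_variable("layer_weights", [len(all_layers)]))
      #   mixed = tf.add_n([w[i] * layer for i, layer in enumerate(all_layers)])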
206 | self.all_encoder_layers = transformer_model( 207 | input_tensor=self.embedding_output, 208 | attention_mask=attention_mask, 209 | hidden_size=config.hidden_size, 210 | num_hidden_layers=config.num_hidden_layers, 211 | num_attention_heads=config.num_attention_heads, 212 | intermediate_size=config.intermediate_size, 213 | intermediate_act_fn=get_activation(config.hidden_act), 214 | hidden_dropout_prob=config.hidden_dropout_prob, 215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 216 | initializer_range=config.initializer_range, 217 | do_return_all_layers=True) 218 | 219 | self.sequence_output = self.all_encoder_layers[-1] 220 | # The "pooler" converts the encoded sequence tensor of shape 221 | # [batch_size, seq_length, hidden_size] to a tensor of shape 222 | # [batch_size, hidden_size]. This is necessary for segment-level 223 | # (or segment-pair-level) classification tasks where we need a fixed 224 | # dimensional representation of the segment. 225 | with tf.variable_scope("pooler"): 226 | # We "pool" the model by simply taking the hidden state corresponding 227 | # to the first token. We assume that this has been pre-trained 228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 229 | self.pooled_output = tf.layers.dense( 230 | first_token_tensor, 231 | config.hidden_size, 232 | activation=tf.tanh, 233 | kernel_initializer=create_initializer(config.initializer_range)) 234 | 235 | def get_pooled_output(self): 236 | return self.pooled_output 237 | 238 | def get_sequence_output(self): 239 | """Gets final hidden layer of encoder. 240 | 241 | Returns: 242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 243 | to the final hidden of the transformer encoder. 244 | """ 245 | return self.sequence_output 246 | 247 | def get_all_encoder_layers(self): 248 | return self.all_encoder_layers 249 | 250 | def get_embedding_output(self): 251 | """Gets output of the embedding lookup (i.e., input to the transformer). 252 | 253 | Returns: 254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 255 | to the output of the embedding layer, after summing the word 256 | embeddings with the positional embeddings and the token type embeddings, 257 | then performing layer normalization. This is the input to the transformer. 258 | """ 259 | return self.embedding_output 260 | 261 | def get_embedding_table(self): 262 | return self.embedding_table 263 | 264 | 265 | def gelu(input_tensor): 266 | """Gaussian Error Linear Unit. 267 | 268 | This is a smoother version of the RELU. 269 | Original paper: https://arxiv.org/abs/1606.08415 270 | 271 | Args: 272 | input_tensor: float Tensor to perform activation. 273 | 274 | Returns: 275 | `input_tensor` with the GELU activation applied. 276 | """ 277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 278 | return input_tensor * cdf 279 | 280 | 281 | def get_activation(activation_string): 282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 283 | 284 | Args: 285 | activation_string: String name of the activation function. 286 | 287 | Returns: 288 | A Python function corresponding to the activation function. If 289 | `activation_string` is None, empty, or "linear", this will return None. 290 | If `activation_string` is not a string, it will return `activation_string`. 291 | 292 | Raises: 293 | ValueError: The `activation_string` does not correspond to a known 294 | activation. 
295 | """ 296 | 297 | # We assume that anything that"s not a string is already an activation 298 | # function, so we just return it. 299 | if not isinstance(activation_string, six.string_types): 300 | return activation_string 301 | 302 | if not activation_string: 303 | return None 304 | 305 | act = activation_string.lower() 306 | if act == "linear": 307 | return None 308 | elif act == "relu": 309 | return tf.nn.relu 310 | elif act == "gelu": 311 | return gelu 312 | elif act == "tanh": 313 | return tf.tanh 314 | else: 315 | raise ValueError("Unsupported activation: %s" % act) 316 | 317 | 318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 319 | """Compute the union of the current variables and checkpoint variables.""" 320 | assignment_map = {} 321 | initialized_variable_names = {} 322 | 323 | name_to_variable = collections.OrderedDict() 324 | for var in tvars: 325 | name = var.name 326 | m = re.match("^(.*):\\d+$", name) 327 | if m is not None: 328 | name = m.group(1) 329 | name_to_variable[name] = var 330 | 331 | init_vars = tf.train.list_variables(init_checkpoint) 332 | 333 | assignment_map = collections.OrderedDict() 334 | for x in init_vars: 335 | (name, var) = (x[0], x[1]) 336 | if name not in name_to_variable: 337 | continue 338 | assignment_map[name] = name 339 | initialized_variable_names[name] = 1 340 | initialized_variable_names[name + ":0"] = 1 341 | 342 | return (assignment_map, initialized_variable_names) 343 | 344 | 345 | def dropout(input_tensor, dropout_prob): 346 | """Perform dropout. 347 | 348 | Args: 349 | input_tensor: float Tensor. 350 | dropout_prob: Python float. The probability of dropping out a value (NOT of 351 | *keeping* a dimension as in `tf.nn.dropout`). 352 | 353 | Returns: 354 | A version of `input_tensor` with dropout applied. 355 | """ 356 | if dropout_prob is None or dropout_prob == 0.0: 357 | return input_tensor 358 | 359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 360 | return output 361 | 362 | 363 | def layer_norm(input_tensor, name=None): 364 | """Run layer normalization on the last dimension of the tensor.""" 365 | return tf.contrib.layers.layer_norm( 366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 367 | 368 | 369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 370 | """Runs layer normalization followed by dropout.""" 371 | output_tensor = layer_norm(input_tensor, name) 372 | output_tensor = dropout(output_tensor, dropout_prob) 373 | return output_tensor 374 | 375 | 376 | def create_initializer(initializer_range=0.02): 377 | """Creates a `truncated_normal_initializer` with the given range.""" 378 | return tf.truncated_normal_initializer(stddev=initializer_range) 379 | 380 | 381 | def embedding_lookup(input_ids, 382 | vocab_size, 383 | embedding_size=128, 384 | initializer_range=0.02, 385 | word_embedding_name="word_embeddings", 386 | use_one_hot_embeddings=False): 387 | """Looks up words embeddings for id tensor. 388 | 389 | Args: 390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 391 | ids. 392 | vocab_size: int. Size of the embedding vocabulary. 393 | embedding_size: int. Width of the word embeddings. 394 | initializer_range: float. Embedding initialization range. 395 | word_embedding_name: string. Name of the embedding table. 396 | use_one_hot_embeddings: bool. If True, use one-hot method for word 397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 398 | for TPUs. 
399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs]. 405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | if use_one_hot_embeddings: 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 419 | output = tf.matmul(one_hot_input_ids, embedding_table) 420 | else: 421 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 422 | 423 | input_shape = get_shape_list(input_ids) 424 | 425 | output = tf.reshape(output, 426 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 427 | return (output, embedding_table) 428 | 429 | 430 | def embedding_postprocessor(input_tensor, 431 | use_token_type=False, 432 | token_type_ids=None, 433 | token_type_vocab_size=16, 434 | token_type_embedding_name="token_type_embeddings", 435 | use_position_embeddings=True, 436 | position_embedding_name="position_embeddings", 437 | initializer_range=0.02, 438 | max_position_embeddings=512, 439 | dropout_prob=0.1): 440 | """Performs various post-processing on a word embedding tensor. 441 | 442 | Args: 443 | input_tensor: float Tensor of shape [batch_size, seq_length, 444 | embedding_size]. 445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 447 | Must be specified if `use_token_type` is True. 448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 449 | token_type_embedding_name: string. The name of the embedding table variable 450 | for token type ids. 451 | use_position_embeddings: bool. Whether to add position embeddings for the 452 | position of each token in the sequence. 453 | position_embedding_name: string. The name of the embedding table variable 454 | for positional embeddings. 455 | initializer_range: float. Range of the weight initialization. 456 | max_position_embeddings: int. Maximum sequence length that might ever be 457 | used with this model. This can be longer than the sequence length of 458 | input_tensor, but cannot be shorter. 459 | dropout_prob: float. Dropout probability applied to the final output tensor. 460 | 461 | Returns: 462 | float tensor with same shape as `input_tensor`. 463 | 464 | Raises: 465 | ValueError: One of the tensor shapes or input values is invalid. 
466 | """ 467 | input_shape = get_shape_list(input_tensor, expected_rank=3) 468 | batch_size = input_shape[0] 469 | seq_length = input_shape[1] 470 | width = input_shape[2] 471 | 472 | if seq_length > max_position_embeddings: 473 | raise ValueError("The seq length (%d) cannot be greater than " 474 | "`max_position_embeddings` (%d)" % 475 | (seq_length, max_position_embeddings)) 476 | 477 | output = input_tensor 478 | 479 | if use_token_type: 480 | if token_type_ids is None: 481 | raise ValueError("`token_type_ids` must be specified if" 482 | "`use_token_type` is True.") 483 | token_type_table = tf.get_variable( 484 | name=token_type_embedding_name, 485 | shape=[token_type_vocab_size, width], 486 | initializer=create_initializer(initializer_range)) 487 | # This vocab will be small so we always do one-hot here, since it is always 488 | # faster for a small vocabulary. 489 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 490 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 491 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 492 | token_type_embeddings = tf.reshape(token_type_embeddings, 493 | [batch_size, seq_length, width]) 494 | output += token_type_embeddings 495 | 496 | if use_position_embeddings: 497 | full_position_embeddings = tf.get_variable( 498 | name=position_embedding_name, 499 | shape=[max_position_embeddings, width], 500 | initializer=create_initializer(initializer_range)) 501 | # Since the position embedding table is a learned variable, we create it 502 | # using a (long) sequence length `max_position_embeddings`. The actual 503 | # sequence length might be shorter than this, for faster training of 504 | # tasks that do not have long sequences. 505 | # 506 | # So `full_position_embeddings` is effectively an embedding table 507 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 508 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 509 | # perform a slice. 510 | if seq_length < max_position_embeddings: 511 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 512 | [seq_length, -1]) 513 | else: 514 | position_embeddings = full_position_embeddings 515 | 516 | num_dims = len(output.shape.as_list()) 517 | 518 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 519 | # we broadcast among the first dimensions, which is typically just 520 | # the batch size. 521 | position_broadcast_shape = [] 522 | for _ in range(num_dims - 2): 523 | position_broadcast_shape.append(1) 524 | position_broadcast_shape.extend([seq_length, width]) 525 | position_embeddings = tf.reshape(position_embeddings, 526 | position_broadcast_shape) 527 | output += position_embeddings 528 | 529 | output = layer_norm_and_dropout(output, dropout_prob) 530 | return output 531 | 532 | 533 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 534 | """Create 3D attention mask from a 2D tensor mask. 535 | 536 | Args: 537 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 538 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 539 | 540 | Returns: 541 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 
542 | """ 543 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 544 | batch_size = from_shape[0] 545 | from_seq_length = from_shape[1] 546 | 547 | to_shape = get_shape_list(to_mask, expected_rank=2) 548 | to_seq_length = to_shape[1] 549 | 550 | to_mask = tf.cast( 551 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 552 | 553 | # We don't assume that `from_tensor` is a mask (although it could be). We 554 | # don't actually care if we attend *from* padding tokens (only *to* padding) 555 | # tokens so we create a tensor of all ones. 556 | # 557 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 558 | broadcast_ones = tf.ones( 559 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 560 | 561 | # Here we broadcast along two dimensions to create the mask. 562 | mask = broadcast_ones * to_mask 563 | 564 | return mask 565 | 566 | 567 | def attention_layer(from_tensor, 568 | to_tensor, 569 | attention_mask=None, 570 | num_attention_heads=1, 571 | size_per_head=512, 572 | query_act=None, 573 | key_act=None, 574 | value_act=None, 575 | attention_probs_dropout_prob=0.0, 576 | initializer_range=0.02, 577 | do_return_2d_tensor=False, 578 | batch_size=None, 579 | from_seq_length=None, 580 | to_seq_length=None): 581 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 582 | 583 | This is an implementation of multi-headed attention based on "Attention 584 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 585 | this is self-attention. Each timestep in `from_tensor` attends to the 586 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 587 | 588 | This function first projects `from_tensor` into a "query" tensor and 589 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 590 | of tensors of length `num_attention_heads`, where each tensor is of shape 591 | [batch_size, seq_length, size_per_head]. 592 | 593 | Then, the query and key tensors are dot-producted and scaled. These are 594 | softmaxed to obtain attention probabilities. The value tensors are then 595 | interpolated by these probabilities, then concatenated back to a single 596 | tensor and returned. 597 | 598 | In practice, the multi-headed attention are done with transposes and 599 | reshapes rather than actual separate tensors. 600 | 601 | Args: 602 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 603 | from_width]. 604 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 605 | attention_mask: (optional) int32 Tensor of shape [batch_size, 606 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 607 | attention scores will effectively be set to -infinity for any positions in 608 | the mask that are 0, and will be unchanged for positions that are 1. 609 | num_attention_heads: int. Number of attention heads. 610 | size_per_head: int. Size of each attention head. 611 | query_act: (optional) Activation function for the query transform. 612 | key_act: (optional) Activation function for the key transform. 613 | value_act: (optional) Activation function for the value transform. 614 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 615 | attention probabilities. 616 | initializer_range: float. Range of the weight initializer. 617 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 618 | * from_seq_length, num_attention_heads * size_per_head]. 
If False, the 619 | output will be of shape [batch_size, from_seq_length, num_attention_heads 620 | * size_per_head]. 621 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 622 | of the 3D version of the `from_tensor` and `to_tensor`. 623 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 624 | of the 3D version of the `from_tensor`. 625 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 626 | of the 3D version of the `to_tensor`. 627 | 628 | Returns: 629 | float Tensor of shape [batch_size, from_seq_length, 630 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 631 | true, this will be of shape [batch_size * from_seq_length, 632 | num_attention_heads * size_per_head]). 633 | 634 | Raises: 635 | ValueError: Any of the arguments or tensor shapes are invalid. 636 | """ 637 | 638 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 639 | seq_length, width): 640 | output_tensor = tf.reshape( 641 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 642 | 643 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 644 | return output_tensor 645 | 646 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 647 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 648 | 649 | if len(from_shape) != len(to_shape): 650 | raise ValueError( 651 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 652 | 653 | if len(from_shape) == 3: 654 | batch_size = from_shape[0] 655 | from_seq_length = from_shape[1] 656 | to_seq_length = to_shape[1] 657 | elif len(from_shape) == 2: 658 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 659 | raise ValueError( 660 | "When passing in rank 2 tensors to attention_layer, the values " 661 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 662 | "must all be specified.") 663 | 664 | # Scalar dimensions referenced here: 665 | # B = batch size (number of sequences) 666 | # F = `from_tensor` sequence length 667 | # T = `to_tensor` sequence length 668 | # N = `num_attention_heads` 669 | # H = `size_per_head` 670 | 671 | from_tensor_2d = reshape_to_matrix(from_tensor) 672 | to_tensor_2d = reshape_to_matrix(to_tensor) 673 | 674 | # `query_layer` = [B*F, N*H] 675 | query_layer = tf.layers.dense( 676 | from_tensor_2d, 677 | num_attention_heads * size_per_head, 678 | activation=query_act, 679 | name="query", 680 | kernel_initializer=create_initializer(initializer_range)) 681 | 682 | # `key_layer` = [B*T, N*H] 683 | key_layer = tf.layers.dense( 684 | to_tensor_2d, 685 | num_attention_heads * size_per_head, 686 | activation=key_act, 687 | name="key", 688 | kernel_initializer=create_initializer(initializer_range)) 689 | 690 | # `value_layer` = [B*T, N*H] 691 | value_layer = tf.layers.dense( 692 | to_tensor_2d, 693 | num_attention_heads * size_per_head, 694 | activation=value_act, 695 | name="value", 696 | kernel_initializer=create_initializer(initializer_range)) 697 | 698 | # `query_layer` = [B, N, F, H] 699 | query_layer = transpose_for_scores(query_layer, batch_size, 700 | num_attention_heads, from_seq_length, 701 | size_per_head) 702 | 703 | # `key_layer` = [B, N, T, H] 704 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 705 | to_seq_length, size_per_head) 706 | 707 | # Take the dot product between "query" and "key" to get the raw 708 | # attention scores. 
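  # (Illustrative note) The block below is the scaled dot-product attention of
  # "Attention Is All You Need": softmax(Q K^T / sqrt(H)) V, computed
  # independently for each of the N heads, with H = size_per_head.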
709 | # `attention_scores` = [B, N, F, T] 710 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 711 | attention_scores = tf.multiply(attention_scores, 712 | 1.0 / math.sqrt(float(size_per_head))) 713 | 714 | if attention_mask is not None: 715 | # `attention_mask` = [B, 1, F, T] 716 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 717 | 718 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 719 | # masked positions, this operation will create a tensor which is 0.0 for 720 | # positions we want to attend and -10000.0 for masked positions. 721 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 722 | 723 | # Since we are adding it to the raw scores before the softmax, this is 724 | # effectively the same as removing these entirely. 725 | attention_scores += adder 726 | 727 | # Normalize the attention scores to probabilities. 728 | # `attention_probs` = [B, N, F, T] 729 | attention_probs = tf.nn.softmax(attention_scores) 730 | 731 | # This is actually dropping out entire tokens to attend to, which might 732 | # seem a bit unusual, but is taken from the original Transformer paper. 733 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 734 | 735 | # `value_layer` = [B, T, N, H] 736 | value_layer = tf.reshape( 737 | value_layer, 738 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 739 | 740 | # `value_layer` = [B, N, T, H] 741 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 742 | 743 | # `context_layer` = [B, N, F, H] 744 | context_layer = tf.matmul(attention_probs, value_layer) 745 | 746 | # `context_layer` = [B, F, N, H] 747 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 748 | 749 | if do_return_2d_tensor: 750 | # `context_layer` = [B*F, N*V] 751 | context_layer = tf.reshape( 752 | context_layer, 753 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 754 | else: 755 | # `context_layer` = [B, F, N*V] 756 | context_layer = tf.reshape( 757 | context_layer, 758 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 759 | 760 | return context_layer 761 | 762 | 763 | def transformer_model(input_tensor, 764 | attention_mask=None, 765 | hidden_size=768, 766 | num_hidden_layers=12, 767 | num_attention_heads=12, 768 | intermediate_size=3072, 769 | intermediate_act_fn=gelu, 770 | hidden_dropout_prob=0.1, 771 | attention_probs_dropout_prob=0.1, 772 | initializer_range=0.02, 773 | do_return_all_layers=False): 774 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 775 | 776 | This is almost an exact implementation of the original Transformer encoder. 777 | 778 | See the original paper: 779 | https://arxiv.org/abs/1706.03762 780 | 781 | Also see: 782 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 783 | 784 | Args: 785 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 786 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 787 | seq_length], with 1 for positions that can be attended to and 0 in 788 | positions that should not be. 789 | hidden_size: int. Hidden size of the Transformer. 790 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 791 | num_attention_heads: int. Number of attention heads in the Transformer. 792 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 793 | forward) layer. 794 | intermediate_act_fn: function. 
The non-linear activation function to apply 795 | to the output of the intermediate/feed-forward layer. 796 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 797 | attention_probs_dropout_prob: float. Dropout probability of the attention 798 | probabilities. 799 | initializer_range: float. Range of the initializer (stddev of truncated 800 | normal). 801 | do_return_all_layers: Whether to also return all layers or just the final 802 | layer. 803 | 804 | Returns: 805 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 806 | hidden layer of the Transformer. 807 | 808 | Raises: 809 | ValueError: A Tensor shape or parameter is invalid. 810 | """ 811 | if hidden_size % num_attention_heads != 0: 812 | raise ValueError( 813 | "The hidden size (%d) is not a multiple of the number of attention " 814 | "heads (%d)" % (hidden_size, num_attention_heads)) 815 | 816 | attention_head_size = int(hidden_size / num_attention_heads) 817 | input_shape = get_shape_list(input_tensor, expected_rank=3) 818 | batch_size = input_shape[0] 819 | seq_length = input_shape[1] 820 | input_width = input_shape[2] 821 | 822 | # The Transformer performs sum residuals on all layers so the input needs 823 | # to be the same as the hidden size. 824 | if input_width != hidden_size: 825 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 826 | (input_width, hidden_size)) 827 | 828 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 829 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 830 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 831 | # help the optimizer. 832 | prev_output = reshape_to_matrix(input_tensor) 833 | 834 | all_layer_outputs = [] 835 | for layer_idx in range(num_hidden_layers): 836 | with tf.variable_scope("layer_%d" % layer_idx): 837 | layer_input = prev_output 838 | 839 | with tf.variable_scope("attention"): 840 | attention_heads = [] 841 | with tf.variable_scope("self"): 842 | attention_head = attention_layer( 843 | from_tensor=layer_input, 844 | to_tensor=layer_input, 845 | attention_mask=attention_mask, 846 | num_attention_heads=num_attention_heads, 847 | size_per_head=attention_head_size, 848 | attention_probs_dropout_prob=attention_probs_dropout_prob, 849 | initializer_range=initializer_range, 850 | do_return_2d_tensor=True, 851 | batch_size=batch_size, 852 | from_seq_length=seq_length, 853 | to_seq_length=seq_length) 854 | attention_heads.append(attention_head) 855 | 856 | attention_output = None 857 | if len(attention_heads) == 1: 858 | attention_output = attention_heads[0] 859 | else: 860 | # In the case where we have other sequences, we just concatenate 861 | # them to the self-attention head before the projection. 862 | attention_output = tf.concat(attention_heads, axis=-1) 863 | 864 | # Run a linear projection of `hidden_size` then add a residual 865 | # with `layer_input`. 866 | with tf.variable_scope("output"): 867 | attention_output = tf.layers.dense( 868 | attention_output, 869 | hidden_size, 870 | kernel_initializer=create_initializer(initializer_range)) 871 | attention_output = dropout(attention_output, hidden_dropout_prob) 872 | attention_output = layer_norm(attention_output + layer_input) 873 | 874 | # The activation is only applied to the "intermediate" hidden layer. 
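      # (Illustrative note) The "intermediate" and "output" blocks below form the
      # position-wise feed-forward network FFN(x) = W2 * act_fn(W1 * x + b1) + b2,
      # where W1 maps hidden_size to intermediate_size (e.g. 768 -> 3072) and W2
      # projects back to hidden_size.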
875 | with tf.variable_scope("intermediate"): 876 | intermediate_output = tf.layers.dense( 877 | attention_output, 878 | intermediate_size, 879 | activation=intermediate_act_fn, 880 | kernel_initializer=create_initializer(initializer_range)) 881 | 882 | # Down-project back to `hidden_size` then add the residual. 883 | with tf.variable_scope("output"): 884 | layer_output = tf.layers.dense( 885 | intermediate_output, 886 | hidden_size, 887 | kernel_initializer=create_initializer(initializer_range)) 888 | layer_output = dropout(layer_output, hidden_dropout_prob) 889 | layer_output = layer_norm(layer_output + attention_output) 890 | prev_output = layer_output 891 | all_layer_outputs.append(layer_output) 892 | 893 | if do_return_all_layers: 894 | final_outputs = [] 895 | for layer_output in all_layer_outputs: 896 | final_output = reshape_from_matrix(layer_output, input_shape) 897 | final_outputs.append(final_output) 898 | return final_outputs 899 | else: 900 | final_output = reshape_from_matrix(prev_output, input_shape) 901 | return final_output 902 | 903 | 904 | def get_shape_list(tensor, expected_rank=None, name=None): 905 | """Returns a list of the shape of tensor, preferring static dimensions. 906 | 907 | Args: 908 | tensor: A tf.Tensor object to find the shape of. 909 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 910 | specified and the `tensor` has a different rank, and exception will be 911 | thrown. 912 | name: Optional name of the tensor for the error message. 913 | 914 | Returns: 915 | A list of dimensions of the shape of tensor. All static dimensions will 916 | be returned as python integers, and dynamic dimensions will be returned 917 | as tf.Tensor scalars. 918 | """ 919 | if name is None: 920 | name = tensor.name 921 | 922 | if expected_rank is not None: 923 | assert_rank(tensor, expected_rank, name) 924 | 925 | shape = tensor.shape.as_list() 926 | 927 | non_static_indexes = [] 928 | for (index, dim) in enumerate(shape): 929 | if dim is None: 930 | non_static_indexes.append(index) 931 | 932 | if not non_static_indexes: 933 | return shape 934 | 935 | dyn_shape = tf.shape(tensor) 936 | for index in non_static_indexes: 937 | shape[index] = dyn_shape[index] 938 | return shape 939 | 940 | 941 | def reshape_to_matrix(input_tensor): 942 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 943 | ndims = input_tensor.shape.ndims 944 | if ndims < 2: 945 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 946 | (input_tensor.shape)) 947 | if ndims == 2: 948 | return input_tensor 949 | 950 | width = input_tensor.shape[-1] 951 | output_tensor = tf.reshape(input_tensor, [-1, width]) 952 | return output_tensor 953 | 954 | 955 | def reshape_from_matrix(output_tensor, orig_shape_list): 956 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 957 | if len(orig_shape_list) == 2: 958 | return output_tensor 959 | 960 | output_shape = get_shape_list(output_tensor) 961 | 962 | orig_dims = orig_shape_list[0:-1] 963 | width = output_shape[-1] 964 | 965 | return tf.reshape(output_tensor, orig_dims + [width]) 966 | 967 | 968 | def assert_rank(tensor, expected_rank, name=None): 969 | """Raises an exception if the tensor rank is not of the expected rank. 970 | 971 | Args: 972 | tensor: A tf.Tensor to check the rank of. 973 | expected_rank: Python integer or list of integers, expected rank. 974 | name: Optional name of the tensor for the error message. 
975 | 976 | Raises: 977 | ValueError: If the expected shape doesn't match the actual shape. 978 | """ 979 | if name is None: 980 | name = tensor.name 981 | 982 | expected_rank_dict = {} 983 | if isinstance(expected_rank, six.integer_types): 984 | expected_rank_dict[expected_rank] = True 985 | else: 986 | for x in expected_rank: 987 | expected_rank_dict[x] = True 988 | 989 | actual_rank = tensor.shape.ndims 990 | if actual_rank not in expected_rank_dict: 991 | scope_name = tf.get_variable_scope().name 992 | raise ValueError( 993 | "For the tensor `%s` in scope `%s`, the actual rank " 994 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 995 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 996 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
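  # (Illustrative note) The combined schedule above ramps the learning rate up
  # linearly from 0 toward init_lr over num_warmup_steps and then decays it
  # linearly to 0 at num_train_steps. E.g. with init_lr=2e-5 and
  # num_warmup_steps=100, the learning rate at step 50 is 1e-5.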
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs a AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update. 128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want ot decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
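      # (Illustrative note) The decoupled update applied below is, per parameter:
      #   param <- param - learning_rate * (next_m / (sqrt(next_v) + epsilon)
      #                                     + weight_decay_rate * param)
      # i.e. AdamW-style weight decay rather than L2 regularization in the loss.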
143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /run_classifier_elmo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization 26 | import tokenization 27 | import tensorflow as tf 28 | 29 | flags = tf.flags 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | ## Required parameters 34 | flags.DEFINE_string( 35 | "data_dir", None, 36 | "The input data dir. Should contain the .tsv files (or other data files) " 37 | "for the task.") 38 | 39 | flags.DEFINE_string( 40 | "bert_config_file", None, 41 | "The config json file corresponding to the pre-trained BERT model. " 42 | "This specifies the model architecture.") 43 | 44 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 45 | 46 | flags.DEFINE_string("vocab_file", None, 47 | "The vocabulary file that the BERT model was trained on.") 48 | 49 | flags.DEFINE_string( 50 | "output_dir", None, 51 | "The output directory where the model checkpoints will be written.") 52 | 53 | ## Other parameters 54 | 55 | flags.DEFINE_string( 56 | "init_checkpoint", None, 57 | "Initial checkpoint (usually from a pre-trained BERT model).") 58 | 59 | flags.DEFINE_bool( 60 | "do_lower_case", True, 61 | "Whether to lower case the input text. Should be True for uncased " 62 | "models and False for cased models.") 63 | 64 | flags.DEFINE_integer( 65 | "max_seq_length", 128, 66 | "The maximum total input sequence length after WordPiece tokenization. 
" 67 | "Sequences longer than this will be truncated, and sequences shorter " 68 | "than this will be padded.") 69 | 70 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 71 | 72 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 73 | 74 | flags.DEFINE_bool( 75 | "do_predict", False, 76 | "Whether to run the model in inference mode on the test set.") 77 | 78 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 79 | 80 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 81 | 82 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 83 | 84 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 85 | 86 | flags.DEFINE_float("num_train_epochs", 3.0, 87 | "Total number of training epochs to perform.") 88 | 89 | flags.DEFINE_float( 90 | "warmup_proportion", 0.1, 91 | "Proportion of training to perform linear learning rate warmup for. " 92 | "E.g., 0.1 = 10% of training.") 93 | 94 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 95 | "How often to save the model checkpoint.") 96 | 97 | flags.DEFINE_integer("iterations_per_loop", 1000, 98 | "How many steps to make in each estimator call.") 99 | 100 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 101 | 102 | tf.flags.DEFINE_string( 103 | "tpu_name", None, 104 | "The Cloud TPU to use for training. This should be either the name " 105 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 106 | "url.") 107 | 108 | tf.flags.DEFINE_string( 109 | "tpu_zone", None, 110 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 111 | "specified, we will attempt to automatically detect the GCE project from " 112 | "metadata.") 113 | 114 | tf.flags.DEFINE_string( 115 | "gcp_project", None, 116 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 117 | "specified, we will attempt to automatically detect the GCE project from " 118 | "metadata.") 119 | 120 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 121 | 122 | flags.DEFINE_integer( 123 | "num_tpu_cores", 8, 124 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 125 | 126 | 127 | class InputExample(object): 128 | """A single training/test example for simple sequence classification.""" 129 | 130 | def __init__(self, guid, text_a, text_b=None, label=None): 131 | """Constructs a InputExample. 132 | 133 | Args: 134 | guid: Unique id for the example. 135 | text_a: string. The untokenized text of the first sequence. For single 136 | sequence tasks, only this sequence must be specified. 137 | text_b: (Optional) string. The untokenized text of the second sequence. 138 | Only must be specified for sequence pair tasks. 139 | label: (Optional) string. The label of the example. This should be 140 | specified for train and dev examples, but not for test examples. 
141 | """ 142 | self.guid = guid 143 | self.text_a = text_a 144 | self.text_b = text_b 145 | self.label = label 146 | 147 | 148 | class InputFeatures(object): 149 | """A single set of features of data.""" 150 | 151 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 152 | self.input_ids = input_ids 153 | self.input_mask = input_mask 154 | self.segment_ids = segment_ids 155 | self.label_id = label_id 156 | 157 | 158 | class DataProcessor(object): 159 | """Base class for data converters for sequence classification data sets.""" 160 | 161 | def get_train_examples(self, data_dir): 162 | """Gets a collection of `InputExample`s for the train set.""" 163 | raise NotImplementedError() 164 | 165 | def get_dev_examples(self, data_dir): 166 | """Gets a collection of `InputExample`s for the dev set.""" 167 | raise NotImplementedError() 168 | 169 | def get_test_examples(self, data_dir): 170 | """Gets a collection of `InputExample`s for prediction.""" 171 | raise NotImplementedError() 172 | 173 | def get_labels(self): 174 | """Gets the list of labels for this data set.""" 175 | raise NotImplementedError() 176 | 177 | @classmethod 178 | def _read_tsv(cls, input_file, quotechar=None): 179 | """Reads a tab separated value file.""" 180 | with tf.gfile.Open(input_file, "r") as f: 181 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 182 | lines = [] 183 | for line in reader: 184 | lines.append(line) 185 | return lines 186 | 187 | 188 | class XnliProcessor(DataProcessor): 189 | """Processor for the XNLI data set.""" 190 | 191 | def __init__(self): 192 | self.language = "zh" 193 | 194 | def get_train_examples(self, data_dir): 195 | """See base class.""" 196 | lines = self._read_tsv( 197 | os.path.join(data_dir, "multinli", 198 | "multinli.train.%s.tsv" % self.language)) 199 | examples = [] 200 | for (i, line) in enumerate(lines): 201 | if i == 0: 202 | continue 203 | guid = "train-%d" % (i) 204 | text_a = tokenization.convert_to_unicode(line[0]) 205 | text_b = tokenization.convert_to_unicode(line[1]) 206 | label = tokenization.convert_to_unicode(line[2]) 207 | if label == tokenization.convert_to_unicode("contradictory"): 208 | label = tokenization.convert_to_unicode("contradiction") 209 | examples.append( 210 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 211 | return examples 212 | 213 | def get_dev_examples(self, data_dir): 214 | """See base class.""" 215 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) 216 | examples = [] 217 | for (i, line) in enumerate(lines): 218 | if i == 0: 219 | continue 220 | guid = "dev-%d" % (i) 221 | language = tokenization.convert_to_unicode(line[0]) 222 | if language != tokenization.convert_to_unicode(self.language): 223 | continue 224 | text_a = tokenization.convert_to_unicode(line[6]) 225 | text_b = tokenization.convert_to_unicode(line[7]) 226 | label = tokenization.convert_to_unicode(line[1]) 227 | examples.append( 228 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 229 | return examples 230 | 231 | def get_labels(self): 232 | """See base class.""" 233 | return ["contradiction", "entailment", "neutral"] 234 | 235 | 236 | class MnliProcessor(DataProcessor): 237 | """Processor for the MultiNLI data set (GLUE version).""" 238 | 239 | def get_train_examples(self, data_dir): 240 | """See base class.""" 241 | return self._create_examples( 242 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 243 | 244 | def get_dev_examples(self, data_dir): 245 | """See base class.""" 246 | return 
self._create_examples( 247 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 248 | "dev_matched") 249 | 250 | def get_test_examples(self, data_dir): 251 | """See base class.""" 252 | return self._create_examples( 253 | self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") 254 | 255 | def get_labels(self): 256 | """See base class.""" 257 | return ["contradiction", "entailment", "neutral"] 258 | 259 | def _create_examples(self, lines, set_type): 260 | """Creates examples for the training and dev sets.""" 261 | examples = [] 262 | for (i, line) in enumerate(lines): 263 | if i == 0: 264 | continue 265 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 266 | text_a = tokenization.convert_to_unicode(line[8]) 267 | text_b = tokenization.convert_to_unicode(line[9]) 268 | if set_type == "test": 269 | label = "contradiction" 270 | else: 271 | label = tokenization.convert_to_unicode(line[-1]) 272 | examples.append( 273 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 274 | return examples 275 | 276 | class QqpProcessor(DataProcessor): 277 | """Processor for the QQP data set.""" 278 | 279 | def get_train_examples(self, data_dir): 280 | """See base class.""" 281 | return self._create_examples( 282 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 283 | 284 | def get_dev_examples(self, data_dir): 285 | """See base class.""" 286 | return self._create_examples( 287 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 288 | 289 | def get_test_examples(self, data_dir): 290 | """See base class.""" 291 | return self._create_examples( 292 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 293 | 294 | def get_labels(self): 295 | """See base class.""" 296 | return ["0", "1"] 297 | 298 | def _create_examples(self, lines, set_type): 299 | """Creates examples for the training and dev sets.""" 300 | examples = [] 301 | for (i, line) in enumerate(lines): 302 | if i == 0 or len(line)!=6: 303 | continue 304 | guid = "%s-%s" % (set_type, i) 305 | text_a = tokenization.convert_to_unicode(line[3]) 306 | text_b = tokenization.convert_to_unicode(line[4]) 307 | if set_type == "test": 308 | label = "1" 309 | else: 310 | label = tokenization.convert_to_unicode(line[5]) 311 | examples.append( 312 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 313 | return examples 314 | 315 | 316 | class MrpcProcessor(DataProcessor): 317 | """Processor for the MRPC data set (GLUE version).""" 318 | 319 | def get_train_examples(self, data_dir): 320 | """See base class.""" 321 | return self._create_examples( 322 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 323 | 324 | def get_dev_examples(self, data_dir): 325 | """See base class.""" 326 | return self._create_examples( 327 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 328 | 329 | def get_test_examples(self, data_dir): 330 | """See base class.""" 331 | return self._create_examples( 332 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 333 | 334 | def get_labels(self): 335 | """See base class.""" 336 | return ["0", "1"] 337 | 338 | def _create_examples(self, lines, set_type): 339 | """Creates examples for the training and dev sets.""" 340 | examples = [] 341 | for (i, line) in enumerate(lines): 342 | if i == 0: 343 | continue 344 | guid = "%s-%s" % (set_type, i) 345 | text_a = tokenization.convert_to_unicode(line[3]) 346 | text_b = tokenization.convert_to_unicode(line[4]) 347 | if set_type == "test": 348 | label = "0" 349 | else: 350 | 
label = tokenization.convert_to_unicode(line[0]) 351 | examples.append( 352 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 353 | return examples 354 | 355 | 356 | class ColaProcessor(DataProcessor): 357 | """Processor for the CoLA data set (GLUE version).""" 358 | 359 | def get_train_examples(self, data_dir): 360 | """See base class.""" 361 | return self._create_examples( 362 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 363 | 364 | def get_dev_examples(self, data_dir): 365 | """See base class.""" 366 | return self._create_examples( 367 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 368 | 369 | def get_test_examples(self, data_dir): 370 | """See base class.""" 371 | return self._create_examples( 372 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 373 | 374 | def get_labels(self): 375 | """See base class.""" 376 | return ["0", "1"] 377 | 378 | def _create_examples(self, lines, set_type): 379 | """Creates examples for the training and dev sets.""" 380 | examples = [] 381 | for (i, line) in enumerate(lines): 382 | # Only the test set has a header 383 | if set_type == "test" and i == 0: 384 | continue 385 | guid = "%s-%s" % (set_type, i) 386 | if set_type == "test": 387 | text_a = tokenization.convert_to_unicode(line[1]) 388 | label = "0" 389 | else: 390 | text_a = tokenization.convert_to_unicode(line[3]) 391 | label = tokenization.convert_to_unicode(line[1]) 392 | examples.append( 393 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 394 | return examples 395 | 396 | 397 | def convert_single_example(ex_index, example, label_list, max_seq_length, 398 | tokenizer): 399 | """Converts a single `InputExample` into a single `InputFeatures`.""" 400 | label_map = {} 401 | for (i, label) in enumerate(label_list): 402 | label_map[label] = i 403 | 404 | tokens_a = tokenizer.tokenize(example.text_a) 405 | tokens_b = None 406 | if example.text_b: 407 | tokens_b = tokenizer.tokenize(example.text_b) 408 | 409 | if tokens_b: 410 | # Modifies `tokens_a` and `tokens_b` in place so that the total 411 | # length is less than the specified length. 412 | # Account for [CLS], [SEP], [SEP] with "- 3" 413 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 414 | else: 415 | # Account for [CLS] and [SEP] with "- 2" 416 | if len(tokens_a) > max_seq_length - 2: 417 | tokens_a = tokens_a[0:(max_seq_length - 2)] 418 | 419 | # The convention in BERT is: 420 | # (a) For sequence pairs: 421 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 422 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 423 | # (b) For single sequences: 424 | # tokens: [CLS] the dog is hairy . [SEP] 425 | # type_ids: 0 0 0 0 0 0 0 426 | # 427 | # Where "type_ids" are used to indicate whether this is the first 428 | # sequence or the second sequence. The embedding vectors for `type=0` and 429 | # `type=1` were learned during pre-training and are added to the wordpiece 430 | # embedding vector (and position vector). This is not *strictly* necessary 431 | # since the [SEP] token unambiguously separates the sequences, but it makes 432 | # it easier for the model to learn the concept of sequences. 433 | # 434 | # For classification tasks, the first vector (corresponding to [CLS]) is 435 | # used as as the "sentence vector". Note that this only makes sense because 436 | # the entire model is fine-tuned. 
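  # (Illustrative example) For the single sentence "the dog is hairy ." from the
  # comment above with max_seq_length=10, the resulting features would be:
  #   tokens:      [CLS] the dog is hairy . [SEP]
  #   input_mask:  1 1 1 1 1 1 1 0 0 0   (seven real tokens, three padding slots)
  #   segment_ids: 0 0 0 0 0 0 0 0 0 0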
437 | tokens = [] 438 | segment_ids = [] 439 | tokens.append("[CLS]") 440 | segment_ids.append(0) 441 | for token in tokens_a: 442 | tokens.append(token) 443 | segment_ids.append(0) 444 | tokens.append("[SEP]") 445 | segment_ids.append(0) 446 | 447 | if tokens_b: 448 | for token in tokens_b: 449 | tokens.append(token) 450 | segment_ids.append(1) 451 | tokens.append("[SEP]") 452 | segment_ids.append(1) 453 | 454 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 455 | 456 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 457 | # tokens are attended to. 458 | input_mask = [1] * len(input_ids) 459 | 460 | # Zero-pad up to the sequence length. 461 | while len(input_ids) < max_seq_length: 462 | input_ids.append(0) 463 | input_mask.append(0) 464 | segment_ids.append(0) 465 | 466 | assert len(input_ids) == max_seq_length 467 | assert len(input_mask) == max_seq_length 468 | assert len(segment_ids) == max_seq_length 469 | 470 | label_id = label_map[example.label] 471 | if ex_index < 5: 472 | tf.logging.info("*** Example ***") 473 | tf.logging.info("guid: %s" % (example.guid)) 474 | tf.logging.info("tokens: %s" % " ".join( 475 | [tokenization.printable_text(x) for x in tokens])) 476 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 477 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 478 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 479 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 480 | 481 | feature = InputFeatures( 482 | input_ids=input_ids, 483 | input_mask=input_mask, 484 | segment_ids=segment_ids, 485 | label_id=label_id) 486 | return feature 487 | 488 | 489 | def file_based_convert_examples_to_features( 490 | examples, label_list, max_seq_length, tokenizer, output_file): 491 | """Convert a set of `InputExample`s to a TFRecord file.""" 492 | 493 | writer = tf.python_io.TFRecordWriter(output_file) 494 | 495 | for (ex_index, example) in enumerate(examples): 496 | if ex_index % 10000 == 0: 497 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 498 | 499 | feature = convert_single_example(ex_index, example, label_list, 500 | max_seq_length, tokenizer) 501 | 502 | def create_int_feature(values): 503 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 504 | return f 505 | 506 | features = collections.OrderedDict() 507 | features["input_ids"] = create_int_feature(feature.input_ids) 508 | features["input_mask"] = create_int_feature(feature.input_mask) 509 | features["segment_ids"] = create_int_feature(feature.segment_ids) 510 | features["label_ids"] = create_int_feature([feature.label_id]) 511 | 512 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 513 | writer.write(tf_example.SerializeToString()) 514 | 515 | 516 | def file_based_input_fn_builder(input_file, seq_length, is_training, 517 | drop_remainder): 518 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 519 | 520 | name_to_features = { 521 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 522 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 523 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 524 | "label_ids": tf.FixedLenFeature([], tf.int64), 525 | } 526 | 527 | def _decode_record(record, name_to_features): 528 | """Decodes a record to a TensorFlow example.""" 529 | example = tf.parse_single_example(record, name_to_features) 530 | 531 | # tf.Example only supports tf.int64, but the TPU only 
supports tf.int32. 532 | # So cast all int64 to int32. 533 | for name in list(example.keys()): 534 | t = example[name] 535 | if t.dtype == tf.int64: 536 | t = tf.to_int32(t) 537 | example[name] = t 538 | 539 | return example 540 | 541 | def input_fn(params): 542 | """The actual input function.""" 543 | batch_size = params["batch_size"] 544 | 545 | # For training, we want a lot of parallel reading and shuffling. 546 | # For eval, we want no shuffling and parallel reading doesn't matter. 547 | d = tf.data.TFRecordDataset(input_file) 548 | if is_training: 549 | d = d.repeat() 550 | d = d.shuffle(buffer_size=100) 551 | 552 | d = d.apply( 553 | tf.contrib.data.map_and_batch( 554 | lambda record: _decode_record(record, name_to_features), 555 | batch_size=batch_size, 556 | drop_remainder=drop_remainder)) 557 | 558 | return d 559 | 560 | return input_fn 561 | 562 | 563 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 564 | """Truncates a sequence pair in place to the maximum length.""" 565 | 566 | # This is a simple heuristic which will always truncate the longer sequence 567 | # one token at a time. This makes more sense than truncating an equal percent 568 | # of tokens from each, since if one sequence is very short then each token 569 | # that's truncated likely contains more information than a longer sequence. 570 | while True: 571 | total_length = len(tokens_a) + len(tokens_b) 572 | if total_length <= max_length: 573 | break 574 | if len(tokens_a) > len(tokens_b): 575 | tokens_a.pop() 576 | else: 577 | tokens_b.pop() 578 | 579 | 580 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 581 | labels, num_labels, use_one_hot_embeddings): 582 | """Creates a classification model.""" 583 | model = modeling.BertModel( 584 | config=bert_config, 585 | is_training=is_training, 586 | input_ids=input_ids, 587 | input_mask=input_mask, 588 | token_type_ids=segment_ids, 589 | use_one_hot_embeddings=use_one_hot_embeddings) 590 | 591 | # In the demo, we are doing a simple classification task on the entire 592 | # segment. 593 | # 594 | # If you want to use the token-level output, use model.get_sequence_output() 595 | # instead. 596 | #output_layer = model.get_pooled_output() 597 | 598 | # ELMo method, more info see https://arxiv.org/abs/1802.05365v2 599 | all_encoder_layers = model.get_all_encoder_layers() 600 | 601 | hidden_size = all_encoder_layers[0].shape[-1].value 602 | 603 | ELMo_layer_numbers = 12 604 | if ELMo_layer_numbers > len(all_encoder_layers): 605 | ELMo_layer_numbers = len(all_encoder_layers) 606 | 607 | ELMO_sequence_output_list = [] 608 | with tf.variable_scope("ELMo"): 609 | for layer_idx in range(ELMo_layer_numbers): 610 | sequence_output = all_encoder_layers[layer_idx] 611 | ELMO_sequence_output_list.append(sequence_output) 612 | 613 | s_task = tf.Variable(tf.random_normal([ELMo_layer_numbers,]), name="s_task_layer_weight") 614 | s_task_weight = tf.nn.softmax(s_task) 615 | gama_task = tf.Variable(tf.random_normal([]), name="gama_task") 616 | 617 | ELMO_weighted_output_list = [] 618 | for layer_idx in range(ELMo_layer_numbers): 619 | sequence_hidden = ELMO_sequence_output_list[layer_idx] 620 | s_task_layer = s_task_weight[layer_idx] 621 | ELMO_weighted_output_list.append(sequence_hidden * s_task_layer) 622 | 623 | ELMO_weighted_output = tf.reduce_sum(ELMO_weighted_output_list, axis=0) * gama_task 624 | 625 | with tf.variable_scope("pooler"): 626 | # We "pool" the model by simply taking the hidden state corresponding 627 | # to the first token. 
We assume that this has been pre-trained 628 | first_token_tensor = tf.squeeze(ELMO_weighted_output[:, 0:1, :], axis=1) 629 | pooled_output = tf.layers.dense( 630 | first_token_tensor, 631 | hidden_size, 632 | activation=tf.tanh) 633 | 634 | output_layer = pooled_output 635 | 636 | output_weights = tf.get_variable( 637 | "output_weights", [num_labels, hidden_size], 638 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 639 | 640 | output_bias = tf.get_variable( 641 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 642 | 643 | with tf.variable_scope("loss"): 644 | if is_training: 645 | # I.e., 0.1 dropout 646 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 647 | 648 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 649 | logits = tf.nn.bias_add(logits, output_bias) 650 | probabilities = tf.nn.softmax(logits, axis=-1) 651 | log_probs = tf.nn.log_softmax(logits, axis=-1) 652 | 653 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 654 | 655 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 656 | loss = tf.reduce_mean(per_example_loss) 657 | 658 | return (loss, per_example_loss, logits, probabilities) 659 | 660 | 661 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 662 | num_train_steps, num_warmup_steps, use_tpu, 663 | use_one_hot_embeddings): 664 | """Returns `model_fn` closure for TPUEstimator.""" 665 | 666 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 667 | """The `model_fn` for TPUEstimator.""" 668 | 669 | tf.logging.info("*** Features ***") 670 | for name in sorted(features.keys()): 671 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 672 | 673 | input_ids = features["input_ids"] 674 | input_mask = features["input_mask"] 675 | segment_ids = features["segment_ids"] 676 | label_ids = features["label_ids"] 677 | 678 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 679 | 680 | (total_loss, per_example_loss, logits, probabilities) = create_model( 681 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 682 | num_labels, use_one_hot_embeddings) 683 | 684 | tvars = tf.trainable_variables() 685 | 686 | scaffold_fn = None 687 | if init_checkpoint: 688 | (assignment_map, initialized_variable_names 689 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 690 | if use_tpu: 691 | 692 | def tpu_scaffold(): 693 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 694 | return tf.train.Scaffold() 695 | 696 | scaffold_fn = tpu_scaffold 697 | else: 698 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 699 | 700 | tf.logging.info("**** Trainable Variables ****") 701 | for var in tvars: 702 | init_string = "" 703 | if var.name in initialized_variable_names: 704 | init_string = ", *INIT_FROM_CKPT*" 705 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 706 | init_string) 707 | 708 | output_spec = None 709 | if mode == tf.estimator.ModeKeys.TRAIN: 710 | 711 | train_op = optimization.create_optimizer( 712 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 713 | 714 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 715 | mode=mode, 716 | loss=total_loss, 717 | train_op=train_op, 718 | scaffold_fn=scaffold_fn) 719 | elif mode == tf.estimator.ModeKeys.EVAL: 720 | 721 | def metric_fn(per_example_loss, label_ids, logits): 722 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 723 | accuracy = 
tf.metrics.accuracy(label_ids, predictions) 724 | loss = tf.metrics.mean(per_example_loss) 725 | return { 726 | "eval_accuracy": accuracy, 727 | "eval_loss": loss, 728 | } 729 | 730 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) 731 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 732 | mode=mode, 733 | loss=total_loss, 734 | eval_metrics=eval_metrics, 735 | scaffold_fn=scaffold_fn) 736 | else: 737 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 738 | mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn) 739 | return output_spec 740 | 741 | return model_fn 742 | 743 | 744 | # This function is not used by this file but is still used by the Colab and 745 | # people who depend on it. 746 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 747 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 748 | 749 | all_input_ids = [] 750 | all_input_mask = [] 751 | all_segment_ids = [] 752 | all_label_ids = [] 753 | 754 | for feature in features: 755 | all_input_ids.append(feature.input_ids) 756 | all_input_mask.append(feature.input_mask) 757 | all_segment_ids.append(feature.segment_ids) 758 | all_label_ids.append(feature.label_id) 759 | 760 | def input_fn(params): 761 | """The actual input function.""" 762 | batch_size = params["batch_size"] 763 | 764 | num_examples = len(features) 765 | 766 | # This is for demo purposes and does NOT scale to large data sets. We do 767 | # not use Dataset.from_generator() because that uses tf.py_func which is 768 | # not TPU compatible. The right way to load data is with TFRecordReader. 769 | d = tf.data.Dataset.from_tensor_slices({ 770 | "input_ids": 771 | tf.constant( 772 | all_input_ids, shape=[num_examples, seq_length], 773 | dtype=tf.int32), 774 | "input_mask": 775 | tf.constant( 776 | all_input_mask, 777 | shape=[num_examples, seq_length], 778 | dtype=tf.int32), 779 | "segment_ids": 780 | tf.constant( 781 | all_segment_ids, 782 | shape=[num_examples, seq_length], 783 | dtype=tf.int32), 784 | "label_ids": 785 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 786 | }) 787 | 788 | if is_training: 789 | d = d.repeat() 790 | d = d.shuffle(buffer_size=100) 791 | 792 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 793 | return d 794 | 795 | return input_fn 796 | 797 | 798 | # This function is not used by this file but is still used by the Colab and 799 | # people who depend on it. 
800 | def convert_examples_to_features(examples, label_list, max_seq_length, 801 | tokenizer): 802 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 803 | 804 | features = [] 805 | for (ex_index, example) in enumerate(examples): 806 | if ex_index % 10000 == 0: 807 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 808 | 809 | feature = convert_single_example(ex_index, example, label_list, 810 | max_seq_length, tokenizer) 811 | 812 | features.append(feature) 813 | return features 814 | 815 | 816 | def main(_): 817 | tf.logging.set_verbosity(tf.logging.INFO) 818 | 819 | processors = { 820 | "cola": ColaProcessor, 821 | "mnli": MnliProcessor, 822 | "mrpc": MrpcProcessor, 823 | "xnli": XnliProcessor, 824 | "qqp": QqpProcessor, 825 | } 826 | 827 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 828 | raise ValueError( 829 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 830 | 831 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 832 | 833 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 834 | raise ValueError( 835 | "Cannot use sequence length %d because the BERT model " 836 | "was only trained up to sequence length %d" % 837 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 838 | 839 | tf.gfile.MakeDirs(FLAGS.output_dir) 840 | 841 | task_name = FLAGS.task_name.lower() 842 | 843 | if task_name not in processors: 844 | raise ValueError("Task not found: %s" % (task_name)) 845 | 846 | processor = processors[task_name]() 847 | 848 | label_list = processor.get_labels() 849 | 850 | tokenizer = tokenization.FullTokenizer( 851 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 852 | 853 | tpu_cluster_resolver = None 854 | if FLAGS.use_tpu and FLAGS.tpu_name: 855 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 856 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 857 | 858 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 859 | run_config = tf.contrib.tpu.RunConfig( 860 | cluster=tpu_cluster_resolver, 861 | master=FLAGS.master, 862 | model_dir=FLAGS.output_dir, 863 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 864 | tpu_config=tf.contrib.tpu.TPUConfig( 865 | iterations_per_loop=FLAGS.iterations_per_loop, 866 | num_shards=FLAGS.num_tpu_cores, 867 | per_host_input_for_training=is_per_host)) 868 | 869 | train_examples = None 870 | num_train_steps = None 871 | num_warmup_steps = None 872 | if FLAGS.do_train: 873 | train_examples = processor.get_train_examples(FLAGS.data_dir) 874 | num_train_steps = int( 875 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 876 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 877 | 878 | model_fn = model_fn_builder( 879 | bert_config=bert_config, 880 | num_labels=len(label_list), 881 | init_checkpoint=FLAGS.init_checkpoint, 882 | learning_rate=FLAGS.learning_rate, 883 | num_train_steps=num_train_steps, 884 | num_warmup_steps=num_warmup_steps, 885 | use_tpu=FLAGS.use_tpu, 886 | use_one_hot_embeddings=FLAGS.use_tpu) 887 | 888 | # If TPU is not available, this will fall back to normal Estimator on CPU 889 | # or GPU. 
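  # (Illustrative note) For example, MRPC has roughly 3,668 training examples, so
  # with the default train_batch_size=32, num_train_epochs=3.0 and
  # warmup_proportion=0.1 the values computed above are num_train_steps ~ 343 and
  # num_warmup_steps ~ 34.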
890 | estimator = tf.contrib.tpu.TPUEstimator( 891 | use_tpu=FLAGS.use_tpu, 892 | model_fn=model_fn, 893 | config=run_config, 894 | train_batch_size=FLAGS.train_batch_size, 895 | eval_batch_size=FLAGS.eval_batch_size, 896 | predict_batch_size=FLAGS.predict_batch_size) 897 | 898 | if FLAGS.do_train: 899 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 900 | file_based_convert_examples_to_features( 901 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 902 | tf.logging.info("***** Running training *****") 903 | tf.logging.info(" Num examples = %d", len(train_examples)) 904 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 905 | tf.logging.info(" Num steps = %d", num_train_steps) 906 | train_input_fn = file_based_input_fn_builder( 907 | input_file=train_file, 908 | seq_length=FLAGS.max_seq_length, 909 | is_training=True, 910 | drop_remainder=True) 911 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 912 | 913 | if FLAGS.do_eval: 914 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 915 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 916 | file_based_convert_examples_to_features( 917 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 918 | 919 | tf.logging.info("***** Running evaluation *****") 920 | tf.logging.info(" Num examples = %d", len(eval_examples)) 921 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 922 | 923 | # This tells the estimator to run through the entire set. 924 | eval_steps = None 925 | # However, if running eval on the TPU, you will need to specify the 926 | # number of steps. 927 | if FLAGS.use_tpu: 928 | # Eval will be slightly WRONG on the TPU because it will truncate 929 | # the last batch. 930 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 931 | 932 | eval_drop_remainder = True if FLAGS.use_tpu else False 933 | eval_input_fn = file_based_input_fn_builder( 934 | input_file=eval_file, 935 | seq_length=FLAGS.max_seq_length, 936 | is_training=False, 937 | drop_remainder=eval_drop_remainder) 938 | 939 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 940 | 941 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 942 | with tf.gfile.GFile(output_eval_file, "w") as writer: 943 | tf.logging.info("***** Eval results *****") 944 | for key in sorted(result.keys()): 945 | tf.logging.info(" %s = %s", key, str(result[key])) 946 | writer.write("%s = %s\n" % (key, str(result[key]))) 947 | 948 | if FLAGS.do_predict: 949 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 950 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 951 | file_based_convert_examples_to_features(predict_examples, label_list, 952 | FLAGS.max_seq_length, tokenizer, 953 | predict_file) 954 | 955 | tf.logging.info("***** Running prediction*****") 956 | tf.logging.info(" Num examples = %d", len(predict_examples)) 957 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 958 | 959 | if FLAGS.use_tpu: 960 | # Warning: According to tpu_estimator.py Prediction on TPU is an 961 | # experimental feature and hence not supported here 962 | raise ValueError("Prediction in TPU not supported") 963 | 964 | predict_drop_remainder = True if FLAGS.use_tpu else False 965 | predict_input_fn = file_based_input_fn_builder( 966 | input_file=predict_file, 967 | seq_length=FLAGS.max_seq_length, 968 | is_training=False, 969 | drop_remainder=predict_drop_remainder) 970 | 971 | result = 
estimator.predict(input_fn=predict_input_fn) 972 | 973 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 974 | with tf.gfile.GFile(output_predict_file, "w") as writer: 975 | tf.logging.info("***** Predict results *****") 976 | for prediction in result: 977 | output_line = "\t".join( 978 | str(class_probability) for class_probability in prediction) + "\n" 979 | writer.write(output_line) 980 | 981 | 982 | if __name__ == "__main__": 983 | flags.mark_flag_as_required("data_dir") 984 | flags.mark_flag_as_required("task_name") 985 | flags.mark_flag_as_required("vocab_file") 986 | flags.mark_flag_as_required("bert_config_file") 987 | flags.mark_flag_as_required("output_dir") 988 | tf.app.run() 989 | -------------------------------------------------------------------------------- /run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. 
" 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | scaffold_fn = None 154 | if init_checkpoint: 155 | (assignment_map, initialized_variable_names 156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 157 | if use_tpu: 158 | 159 | def tpu_scaffold(): 160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 161 | return tf.train.Scaffold() 162 | 163 | scaffold_fn = tpu_scaffold 164 | else: 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | 167 | tf.logging.info("**** Trainable Variables ****") 168 | for var in tvars: 169 | init_string = "" 170 | if var.name in initialized_variable_names: 171 | init_string = ", *INIT_FROM_CKPT*" 172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 173 | init_string) 174 | 175 | output_spec = None 176 | if mode == tf.estimator.ModeKeys.TRAIN: 177 | train_op = optimization.create_optimizer( 178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 179 | 180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | train_op=train_op, 184 | scaffold_fn=scaffold_fn) 185 | elif mode == tf.estimator.ModeKeys.EVAL: 186 | 187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 188 | masked_lm_weights, next_sentence_example_loss, 189 | next_sentence_log_probs, next_sentence_labels): 190 | """Computes the loss and accuracy of the model.""" 191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 192 | [-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax( 194 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 197 | masked_lm_weights = 
tf.reshape(masked_lm_weights, [-1]) 198 | masked_lm_accuracy = tf.metrics.accuracy( 199 | labels=masked_lm_ids, 200 | predictions=masked_lm_predictions, 201 | weights=masked_lm_weights) 202 | masked_lm_mean_loss = tf.metrics.mean( 203 | values=masked_lm_example_loss, weights=masked_lm_weights) 204 | 205 | next_sentence_log_probs = tf.reshape( 206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 207 | next_sentence_predictions = tf.argmax( 208 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 210 | next_sentence_accuracy = tf.metrics.accuracy( 211 | labels=next_sentence_labels, predictions=next_sentence_predictions) 212 | next_sentence_mean_loss = tf.metrics.mean( 213 | values=next_sentence_example_loss) 214 | 215 | return { 216 | "masked_lm_accuracy": masked_lm_accuracy, 217 | "masked_lm_loss": masked_lm_mean_loss, 218 | "next_sentence_accuracy": next_sentence_accuracy, 219 | "next_sentence_loss": next_sentence_mean_loss, 220 | } 221 | 222 | eval_metrics = (metric_fn, [ 223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 224 | masked_lm_weights, next_sentence_example_loss, 225 | next_sentence_log_probs, next_sentence_labels 226 | ]) 227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 228 | mode=mode, 229 | loss=total_loss, 230 | eval_metrics=eval_metrics, 231 | scaffold_fn=scaffold_fn) 232 | else: 233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 234 | 235 | return output_spec 236 | 237 | return model_fn 238 | 239 | 240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 241 | label_ids, label_weights): 242 | """Get loss and log probs for the masked LM.""" 243 | input_tensor = gather_indexes(input_tensor, positions) 244 | 245 | with tf.variable_scope("cls/predictions"): 246 | # We apply one more non-linear transformation before the output layer. 247 | # This matrix is not used after pre-training. 248 | with tf.variable_scope("transform"): 249 | input_tensor = tf.layers.dense( 250 | input_tensor, 251 | units=bert_config.hidden_size, 252 | activation=modeling.get_activation(bert_config.hidden_act), 253 | kernel_initializer=modeling.create_initializer( 254 | bert_config.initializer_range)) 255 | input_tensor = modeling.layer_norm(input_tensor) 256 | 257 | # The output weights are the same as the input embeddings, but there is 258 | # an output-only bias for each token. 259 | output_bias = tf.get_variable( 260 | "output_bias", 261 | shape=[bert_config.vocab_size], 262 | initializer=tf.zeros_initializer()) 263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 264 | logits = tf.nn.bias_add(logits, output_bias) 265 | log_probs = tf.nn.log_softmax(logits, axis=-1) 266 | 267 | label_ids = tf.reshape(label_ids, [-1]) 268 | label_weights = tf.reshape(label_weights, [-1]) 269 | 270 | one_hot_labels = tf.one_hot( 271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 272 | 273 | # The `positions` tensor might be zero-padded (if the sequence is too 274 | # short to have the maximum number of predictions). The `label_weights` 275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 276 | # padding predictions. 
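# A minimal illustrative sketch of the weighting below (the numbers are
# assumed, not taken from this repository): with max_predictions_per_seq = 4
# and only 3 real masked positions, label_weights = [1, 1, 1, 0], so the padded
# slot contributes nothing to the loss, e.g.
#   per_example_loss = [2.1, 0.7, 1.4, 5.0]    # hypothetical per-position losses
#   numerator        = 2.1 + 0.7 + 1.4 + 0 * 5.0 = 4.2
#   denominator      = 1 + 1 + 1 + 0 + 1e-5
#   loss             = numerator / denominator ~ 1.4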
277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 278 | numerator = tf.reduce_sum(label_weights * per_example_loss) 279 | denominator = tf.reduce_sum(label_weights) + 1e-5 280 | loss = numerator / denominator 281 | 282 | return (loss, per_example_loss, log_probs) 283 | 284 | 285 | def get_next_sentence_output(bert_config, input_tensor, labels): 286 | """Get loss and log probs for the next sentence prediction.""" 287 | 288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 289 | # "random sentence". This weight matrix is not used after pre-training. 290 | with tf.variable_scope("cls/seq_relationship"): 291 | output_weights = tf.get_variable( 292 | "output_weights", 293 | shape=[2, bert_config.hidden_size], 294 | initializer=modeling.create_initializer(bert_config.initializer_range)) 295 | output_bias = tf.get_variable( 296 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 297 | 298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 299 | logits = tf.nn.bias_add(logits, output_bias) 300 | log_probs = tf.nn.log_softmax(logits, axis=-1) 301 | labels = tf.reshape(labels, [-1]) 302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 304 | loss = tf.reduce_mean(per_example_loss) 305 | return (loss, per_example_loss, log_probs) 306 | 307 | 308 | def gather_indexes(sequence_tensor, positions): 309 | """Gathers the vectors at the specific positions over a minibatch.""" 310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 311 | batch_size = sequence_shape[0] 312 | seq_length = sequence_shape[1] 313 | width = sequence_shape[2] 314 | 315 | flat_offsets = tf.reshape( 316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 317 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 318 | flat_sequence_tensor = tf.reshape(sequence_tensor, 319 | [batch_size * seq_length, width]) 320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 321 | return output_tensor 322 | 323 | 324 | def input_fn_builder(input_files, 325 | max_seq_length, 326 | max_predictions_per_seq, 327 | is_training, 328 | num_cpu_threads=4): 329 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 330 | 331 | def input_fn(params): 332 | """The actual input function.""" 333 | batch_size = params["batch_size"] 334 | 335 | name_to_features = { 336 | "input_ids": 337 | tf.FixedLenFeature([max_seq_length], tf.int64), 338 | "input_mask": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "segment_ids": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "masked_lm_positions": 343 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 344 | "masked_lm_ids": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_weights": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 348 | "next_sentence_labels": 349 | tf.FixedLenFeature([1], tf.int64), 350 | } 351 | 352 | # For training, we want a lot of parallel reading and shuffling. 353 | # For eval, we want no shuffling and parallel reading doesn't matter. 354 | if is_training: 355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 356 | d = d.repeat() 357 | d = d.shuffle(buffer_size=len(input_files)) 358 | 359 | # `cycle_length` is the number of parallel files that get read. 
360 | cycle_length = min(num_cpu_threads, len(input_files)) 361 | 362 | # `sloppy` mode means that the interleaving is not exact. This adds 363 | # even more randomness to the training pipeline. 364 | d = d.apply( 365 | tf.contrib.data.parallel_interleave( 366 | tf.data.TFRecordDataset, 367 | sloppy=is_training, 368 | cycle_length=cycle_length)) 369 | d = d.shuffle(buffer_size=100) 370 | else: 371 | d = tf.data.TFRecordDataset(input_files) 372 | # Since we evaluate for a fixed number of steps we don't want to encounter 373 | # out-of-range exceptions. 374 | d = d.repeat() 375 | 376 | # We must `drop_remainder` on training because the TPU requires fixed 377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 378 | # and we *don't* want to drop the remainder, otherwise we wont cover 379 | # every sample. 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size, 384 | num_parallel_batches=num_cpu_threads, 385 | drop_remainder=True)) 386 | return d 387 | 388 | return input_fn 389 | 390 | 391 | def _decode_record(record, name_to_features): 392 | """Decodes a record to a TensorFlow example.""" 393 | example = tf.parse_single_example(record, name_to_features) 394 | 395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 396 | # So cast all int64 to int32. 397 | for name in list(example.keys()): 398 | t = example[name] 399 | if t.dtype == tf.int64: 400 | t = tf.to_int32(t) 401 | example[name] = t 402 | 403 | return example 404 | 405 | 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | if not FLAGS.do_train and not FLAGS.do_eval: 410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 411 | 412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 413 | 414 | tf.gfile.MakeDirs(FLAGS.output_dir) 415 | 416 | input_files = [] 417 | for input_pattern in FLAGS.input_file.split(","): 418 | input_files.extend(tf.gfile.Glob(input_pattern)) 419 | 420 | tf.logging.info("*** Input Files ***") 421 | for input_file in input_files: 422 | tf.logging.info(" %s" % input_file) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | model_fn = model_fn_builder( 441 | bert_config=bert_config, 442 | init_checkpoint=FLAGS.init_checkpoint, 443 | learning_rate=FLAGS.learning_rate, 444 | num_train_steps=FLAGS.num_train_steps, 445 | num_warmup_steps=FLAGS.num_warmup_steps, 446 | use_tpu=FLAGS.use_tpu, 447 | use_one_hot_embeddings=FLAGS.use_tpu) 448 | 449 | # If TPU is not available, this will fall back to normal Estimator on CPU 450 | # or GPU. 
451 | estimator = tf.contrib.tpu.TPUEstimator( 452 | use_tpu=FLAGS.use_tpu, 453 | model_fn=model_fn, 454 | config=run_config, 455 | train_batch_size=FLAGS.train_batch_size, 456 | eval_batch_size=FLAGS.eval_batch_size) 457 | 458 | if FLAGS.do_train: 459 | tf.logging.info("***** Running training *****") 460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 461 | train_input_fn = input_fn_builder( 462 | input_files=input_files, 463 | max_seq_length=FLAGS.max_seq_length, 464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 465 | is_training=True) 466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 467 | 468 | if FLAGS.do_eval: 469 | tf.logging.info("***** Running evaluation *****") 470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 471 | 472 | eval_input_fn = input_fn_builder( 473 | input_files=input_files, 474 | max_seq_length=FLAGS.max_seq_length, 475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 476 | is_training=False) 477 | 478 | result = estimator.evaluate( 479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 480 | 481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 482 | with tf.gfile.GFile(output_eval_file, "w") as writer: 483 | tf.logging.info("***** Eval results *****") 484 | for key in sorted(result.keys()): 485 | tf.logging.info(" %s = %s", key, str(result[key])) 486 | writer.write("%s = %s\n" % (key, str(result[key]))) 487 | 488 | 489 | if __name__ == "__main__": 490 | flags.mark_flag_as_required("input_file") 491 | flags.mark_flag_as_required("bert_config_file") 492 | flags.mark_flag_as_required("output_dir") 493 | tf.app.run() 494 | -------------------------------------------------------------------------------- /run_squad.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run BERT on SQuAD.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import json 23 | import math 24 | import os 25 | import random 26 | import modeling 27 | import optimization 28 | import tokenization 29 | import six 30 | import tensorflow as tf 31 | 32 | flags = tf.flags 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | ## Required parameters 37 | flags.DEFINE_string( 38 | "bert_config_file", None, 39 | "The config json file corresponding to the pre-trained BERT model. " 40 | "This specifies the model architecture.") 41 | 42 | flags.DEFINE_string("vocab_file", None, 43 | "The vocabulary file that the BERT model was trained on.") 44 | 45 | flags.DEFINE_string( 46 | "output_dir", None, 47 | "The output directory where the model checkpoints will be written.") 48 | 49 | ## Other parameters 50 | flags.DEFINE_string("train_file", None, 51 | "SQuAD json for training. 
E.g., train-v1.1.json") 52 | 53 | flags.DEFINE_string( 54 | "predict_file", None, 55 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 56 | 57 | flags.DEFINE_string( 58 | "init_checkpoint", None, 59 | "Initial checkpoint (usually from a pre-trained BERT model).") 60 | 61 | flags.DEFINE_bool( 62 | "do_lower_case", True, 63 | "Whether to lower case the input text. Should be True for uncased " 64 | "models and False for cased models.") 65 | 66 | flags.DEFINE_integer( 67 | "max_seq_length", 384, 68 | "The maximum total input sequence length after WordPiece tokenization. " 69 | "Sequences longer than this will be truncated, and sequences shorter " 70 | "than this will be padded.") 71 | 72 | flags.DEFINE_integer( 73 | "doc_stride", 128, 74 | "When splitting up a long document into chunks, how much stride to " 75 | "take between chunks.") 76 | 77 | flags.DEFINE_integer( 78 | "max_query_length", 64, 79 | "The maximum number of tokens for the question. Questions longer than " 80 | "this will be truncated to this length.") 81 | 82 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 83 | 84 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.") 85 | 86 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 87 | 88 | flags.DEFINE_integer("predict_batch_size", 8, 89 | "Total batch size for predictions.") 90 | 91 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 92 | 93 | flags.DEFINE_float("num_train_epochs", 3.0, 94 | "Total number of training epochs to perform.") 95 | 96 | flags.DEFINE_float( 97 | "warmup_proportion", 0.1, 98 | "Proportion of training to perform linear learning rate warmup for. " 99 | "E.g., 0.1 = 10% of training.") 100 | 101 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 102 | "How often to save the model checkpoint.") 103 | 104 | flags.DEFINE_integer("iterations_per_loop", 1000, 105 | "How many steps to make in each estimator call.") 106 | 107 | flags.DEFINE_integer( 108 | "n_best_size", 20, 109 | "The total number of n-best predictions to generate in the " 110 | "nbest_predictions.json output file.") 111 | 112 | flags.DEFINE_integer( 113 | "max_answer_length", 30, 114 | "The maximum length of an answer that can be generated. This is needed " 115 | "because the start and end predictions are not conditioned on one another.") 116 | 117 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 118 | 119 | tf.flags.DEFINE_string( 120 | "tpu_name", None, 121 | "The Cloud TPU to use for training. This should be either the name " 122 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 123 | "url.") 124 | 125 | tf.flags.DEFINE_string( 126 | "tpu_zone", None, 127 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 128 | "specified, we will attempt to automatically detect the GCE project from " 129 | "metadata.") 130 | 131 | tf.flags.DEFINE_string( 132 | "gcp_project", None, 133 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 134 | "specified, we will attempt to automatically detect the GCE project from " 135 | "metadata.") 136 | 137 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 138 | 139 | flags.DEFINE_integer( 140 | "num_tpu_cores", 8, 141 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 142 | 143 | flags.DEFINE_bool( 144 | "verbose_logging", False, 145 | "If true, all of the warnings related to data processing will be printed. 
" 146 | "A number of warnings are expected for a normal SQuAD evaluation.") 147 | 148 | 149 | class SquadExample(object): 150 | """A single training/test example for simple sequence classification.""" 151 | 152 | def __init__(self, 153 | qas_id, 154 | question_text, 155 | doc_tokens, 156 | orig_answer_text=None, 157 | start_position=None, 158 | end_position=None): 159 | self.qas_id = qas_id 160 | self.question_text = question_text 161 | self.doc_tokens = doc_tokens 162 | self.orig_answer_text = orig_answer_text 163 | self.start_position = start_position 164 | self.end_position = end_position 165 | 166 | def __str__(self): 167 | return self.__repr__() 168 | 169 | def __repr__(self): 170 | s = "" 171 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 172 | s += ", question_text: %s" % ( 173 | tokenization.printable_text(self.question_text)) 174 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 175 | if self.start_position: 176 | s += ", start_position: %d" % (self.start_position) 177 | if self.start_position: 178 | s += ", end_position: %d" % (self.end_position) 179 | return s 180 | 181 | 182 | class InputFeatures(object): 183 | """A single set of features of data.""" 184 | 185 | def __init__(self, 186 | unique_id, 187 | example_index, 188 | doc_span_index, 189 | tokens, 190 | token_to_orig_map, 191 | token_is_max_context, 192 | input_ids, 193 | input_mask, 194 | segment_ids, 195 | start_position=None, 196 | end_position=None): 197 | self.unique_id = unique_id 198 | self.example_index = example_index 199 | self.doc_span_index = doc_span_index 200 | self.tokens = tokens 201 | self.token_to_orig_map = token_to_orig_map 202 | self.token_is_max_context = token_is_max_context 203 | self.input_ids = input_ids 204 | self.input_mask = input_mask 205 | self.segment_ids = segment_ids 206 | self.start_position = start_position 207 | self.end_position = end_position 208 | 209 | 210 | def read_squad_examples(input_file, is_training): 211 | """Read a SQuAD json file into a list of SquadExample.""" 212 | with tf.gfile.Open(input_file, "r") as reader: 213 | input_data = json.load(reader)["data"] 214 | 215 | def is_whitespace(c): 216 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 217 | return True 218 | return False 219 | 220 | examples = [] 221 | for entry in input_data: 222 | for paragraph in entry["paragraphs"]: 223 | paragraph_text = paragraph["context"] 224 | doc_tokens = [] 225 | char_to_word_offset = [] 226 | prev_is_whitespace = True 227 | for c in paragraph_text: 228 | if is_whitespace(c): 229 | prev_is_whitespace = True 230 | else: 231 | if prev_is_whitespace: 232 | doc_tokens.append(c) 233 | else: 234 | doc_tokens[-1] += c 235 | prev_is_whitespace = False 236 | char_to_word_offset.append(len(doc_tokens) - 1) 237 | 238 | for qa in paragraph["qas"]: 239 | qas_id = qa["id"] 240 | question_text = qa["question"] 241 | start_position = None 242 | end_position = None 243 | orig_answer_text = None 244 | if is_training: 245 | if len(qa["answers"]) != 1: 246 | raise ValueError( 247 | "For training, each question should have exactly 1 answer.") 248 | answer = qa["answers"][0] 249 | orig_answer_text = answer["text"] 250 | answer_offset = answer["answer_start"] 251 | answer_length = len(orig_answer_text) 252 | start_position = char_to_word_offset[answer_offset] 253 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 254 | # Only add answers where the text can be exactly recovered from the 255 | # document. 
If this CAN'T happen it's likely due to weird Unicode 256 | # stuff so we will just skip the example. 257 | # 258 | # Note that this means for training mode, every example is NOT 259 | # guaranteed to be preserved. 260 | actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) 261 | cleaned_answer_text = " ".join( 262 | tokenization.whitespace_tokenize(orig_answer_text)) 263 | if actual_text.find(cleaned_answer_text) == -1: 264 | tf.logging.warning("Could not find answer: '%s' vs. '%s'", 265 | actual_text, cleaned_answer_text) 266 | continue 267 | 268 | example = SquadExample( 269 | qas_id=qas_id, 270 | question_text=question_text, 271 | doc_tokens=doc_tokens, 272 | orig_answer_text=orig_answer_text, 273 | start_position=start_position, 274 | end_position=end_position) 275 | examples.append(example) 276 | return examples 277 | 278 | 279 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 280 | doc_stride, max_query_length, is_training, 281 | output_fn): 282 | """Loads a data file into a list of `InputBatch`s.""" 283 | 284 | unique_id = 1000000000 285 | 286 | for (example_index, example) in enumerate(examples): 287 | query_tokens = tokenizer.tokenize(example.question_text) 288 | 289 | if len(query_tokens) > max_query_length: 290 | query_tokens = query_tokens[0:max_query_length] 291 | 292 | tok_to_orig_index = [] 293 | orig_to_tok_index = [] 294 | all_doc_tokens = [] 295 | for (i, token) in enumerate(example.doc_tokens): 296 | orig_to_tok_index.append(len(all_doc_tokens)) 297 | sub_tokens = tokenizer.tokenize(token) 298 | for sub_token in sub_tokens: 299 | tok_to_orig_index.append(i) 300 | all_doc_tokens.append(sub_token) 301 | 302 | tok_start_position = None 303 | tok_end_position = None 304 | if is_training: 305 | tok_start_position = orig_to_tok_index[example.start_position] 306 | if example.end_position < len(example.doc_tokens) - 1: 307 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 308 | else: 309 | tok_end_position = len(all_doc_tokens) - 1 310 | (tok_start_position, tok_end_position) = _improve_answer_span( 311 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 312 | example.orig_answer_text) 313 | 314 | # The -3 accounts for [CLS], [SEP] and [SEP] 315 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 316 | 317 | # We can have documents that are longer than the maximum sequence length. 318 | # To deal with this we do a sliding window approach, where we take chunks 319 | # of the up to our max length with a stride of `doc_stride`. 
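# A small illustrative sketch of the sliding window below (toy numbers, not
# from this repository): with len(all_doc_tokens) = 10, max_tokens_for_doc = 6
# and doc_stride = 4, the loop produces two overlapping spans,
#   DocSpan(start=0, length=6)   # tokens 0..5
#   DocSpan(start=4, length=6)   # tokens 4..9
# so tokens 4 and 5 occur in both spans; _check_is_max_context later decides
# which span provides the maximum context for each duplicated token.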
320 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 321 | "DocSpan", ["start", "length"]) 322 | doc_spans = [] 323 | start_offset = 0 324 | while start_offset < len(all_doc_tokens): 325 | length = len(all_doc_tokens) - start_offset 326 | if length > max_tokens_for_doc: 327 | length = max_tokens_for_doc 328 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 329 | if start_offset + length == len(all_doc_tokens): 330 | break 331 | start_offset += min(length, doc_stride) 332 | 333 | for (doc_span_index, doc_span) in enumerate(doc_spans): 334 | tokens = [] 335 | token_to_orig_map = {} 336 | token_is_max_context = {} 337 | segment_ids = [] 338 | tokens.append("[CLS]") 339 | segment_ids.append(0) 340 | for token in query_tokens: 341 | tokens.append(token) 342 | segment_ids.append(0) 343 | tokens.append("[SEP]") 344 | segment_ids.append(0) 345 | 346 | for i in range(doc_span.length): 347 | split_token_index = doc_span.start + i 348 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 349 | 350 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 351 | split_token_index) 352 | token_is_max_context[len(tokens)] = is_max_context 353 | tokens.append(all_doc_tokens[split_token_index]) 354 | segment_ids.append(1) 355 | tokens.append("[SEP]") 356 | segment_ids.append(1) 357 | 358 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 359 | 360 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 361 | # tokens are attended to. 362 | input_mask = [1] * len(input_ids) 363 | 364 | # Zero-pad up to the sequence length. 365 | while len(input_ids) < max_seq_length: 366 | input_ids.append(0) 367 | input_mask.append(0) 368 | segment_ids.append(0) 369 | 370 | assert len(input_ids) == max_seq_length 371 | assert len(input_mask) == max_seq_length 372 | assert len(segment_ids) == max_seq_length 373 | 374 | start_position = None 375 | end_position = None 376 | if is_training: 377 | # For training, if our document chunk does not contain an annotation 378 | # we throw it out, since there is nothing to predict. 
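# A brief illustrative sketch of the offset arithmetic below (assumed numbers):
# if the kept span starts at doc_start = 100 and the question has 8 WordPiece
# tokens, then doc_offset = 8 + 2 = 10 (one slot each for [CLS] and the first
# [SEP]), so an answer beginning at tok_start_position = 105 maps to
# start_position = 105 - 100 + 10 = 15 inside the packed
# [CLS] question [SEP] document [SEP] sequence.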
379 | doc_start = doc_span.start 380 | doc_end = doc_span.start + doc_span.length - 1 381 | if (example.start_position < doc_start or 382 | example.end_position < doc_start or 383 | example.start_position > doc_end or example.end_position > doc_end): 384 | continue 385 | 386 | doc_offset = len(query_tokens) + 2 387 | start_position = tok_start_position - doc_start + doc_offset 388 | end_position = tok_end_position - doc_start + doc_offset 389 | 390 | if example_index < 20: 391 | tf.logging.info("*** Example ***") 392 | tf.logging.info("unique_id: %s" % (unique_id)) 393 | tf.logging.info("example_index: %s" % (example_index)) 394 | tf.logging.info("doc_span_index: %s" % (doc_span_index)) 395 | tf.logging.info("tokens: %s" % " ".join( 396 | [tokenization.printable_text(x) for x in tokens])) 397 | tf.logging.info("token_to_orig_map: %s" % " ".join( 398 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) 399 | tf.logging.info("token_is_max_context: %s" % " ".join([ 400 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) 401 | ])) 402 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 403 | tf.logging.info( 404 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 405 | tf.logging.info( 406 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 407 | if is_training: 408 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 409 | tf.logging.info("start_position: %d" % (start_position)) 410 | tf.logging.info("end_position: %d" % (end_position)) 411 | tf.logging.info( 412 | "answer: %s" % (tokenization.printable_text(answer_text))) 413 | 414 | feature = InputFeatures( 415 | unique_id=unique_id, 416 | example_index=example_index, 417 | doc_span_index=doc_span_index, 418 | tokens=tokens, 419 | token_to_orig_map=token_to_orig_map, 420 | token_is_max_context=token_is_max_context, 421 | input_ids=input_ids, 422 | input_mask=input_mask, 423 | segment_ids=segment_ids, 424 | start_position=start_position, 425 | end_position=end_position) 426 | 427 | # Run callback 428 | output_fn(feature) 429 | 430 | unique_id += 1 431 | 432 | 433 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 434 | orig_answer_text): 435 | """Returns tokenized answer spans that better match the annotated answer.""" 436 | 437 | # The SQuAD annotations are character based. We first project them to 438 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 439 | # often find a "better match". For example: 440 | # 441 | # Question: What year was John Smith born? 442 | # Context: The leader was John Smith (1895-1943). 443 | # Answer: 1895 444 | # 445 | # The original whitespace-tokenized answer will be "(1895-1943).". However 446 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 447 | # the exact answer, 1895. 448 | # 449 | # However, this is not always possible. Consider the following: 450 | # 451 | # Question: What country is the top exporter of electornics? 452 | # Context: The Japanese electronics industry is the lagest in the world. 453 | # Answer: Japan 454 | # 455 | # In this case, the annotator chose "Japan" as a character sub-span of 456 | # the word "Japanese". Since our WordPiece tokenizer does not split 457 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 458 | # in SQuAD, but does happen. 
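# Continuing the "(1895-1943)." example above as an illustrative sketch (the
# exact WordPiece output is assumed): the whitespace-based span covers
#   doc_tokens[input_start:input_end + 1] == ["(", "1895", "-", "1943", ")", "."]
# while tok_answer_text == "1895", so the nested search below narrows the span
# to the single "1895" token; if no sub-span joins to exactly tok_answer_text,
# the original (input_start, input_end) is returned unchanged.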
459 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 460 | 461 | for new_start in range(input_start, input_end + 1): 462 | for new_end in range(input_end, new_start - 1, -1): 463 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 464 | if text_span == tok_answer_text: 465 | return (new_start, new_end) 466 | 467 | return (input_start, input_end) 468 | 469 | 470 | def _check_is_max_context(doc_spans, cur_span_index, position): 471 | """Check if this is the 'max context' doc span for the token.""" 472 | 473 | # Because of the sliding window approach taken to scoring documents, a single 474 | # token can appear in multiple documents. E.g. 475 | # Doc: the man went to the store and bought a gallon of milk 476 | # Span A: the man went to the 477 | # Span B: to the store and bought 478 | # Span C: and bought a gallon of 479 | # ... 480 | # 481 | # Now the word 'bought' will have two scores from spans B and C. We only 482 | # want to consider the score with "maximum context", which we define as 483 | # the *minimum* of its left and right context (the *sum* of left and 484 | # right context will always be the same, of course). 485 | # 486 | # In the example the maximum context for 'bought' would be span C since 487 | # it has 1 left context and 3 right context, while span B has 4 left context 488 | # and 0 right context. 489 | best_score = None 490 | best_span_index = None 491 | for (span_index, doc_span) in enumerate(doc_spans): 492 | end = doc_span.start + doc_span.length - 1 493 | if position < doc_span.start: 494 | continue 495 | if position > end: 496 | continue 497 | num_left_context = position - doc_span.start 498 | num_right_context = end - position 499 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 500 | if best_score is None or score > best_score: 501 | best_score = score 502 | best_span_index = span_index 503 | 504 | return cur_span_index == best_span_index 505 | 506 | 507 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 508 | use_one_hot_embeddings): 509 | """Creates a classification model.""" 510 | model = modeling.BertModel( 511 | config=bert_config, 512 | is_training=is_training, 513 | input_ids=input_ids, 514 | input_mask=input_mask, 515 | token_type_ids=segment_ids, 516 | use_one_hot_embeddings=use_one_hot_embeddings) 517 | 518 | final_hidden = model.get_sequence_output() 519 | 520 | final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) 521 | batch_size = final_hidden_shape[0] 522 | seq_length = final_hidden_shape[1] 523 | hidden_size = final_hidden_shape[2] 524 | 525 | output_weights = tf.get_variable( 526 | "cls/squad/output_weights", [2, hidden_size], 527 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 528 | 529 | output_bias = tf.get_variable( 530 | "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()) 531 | 532 | final_hidden_matrix = tf.reshape(final_hidden, 533 | [batch_size * seq_length, hidden_size]) 534 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) 535 | logits = tf.nn.bias_add(logits, output_bias) 536 | 537 | logits = tf.reshape(logits, [batch_size, seq_length, 2]) 538 | logits = tf.transpose(logits, [2, 0, 1]) 539 | 540 | unstacked_logits = tf.unstack(logits, axis=0) 541 | 542 | (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1]) 543 | 544 | return (start_logits, end_logits) 545 | 546 | 547 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 548 | num_train_steps, 
num_warmup_steps, use_tpu, 549 | use_one_hot_embeddings): 550 | """Returns `model_fn` closure for TPUEstimator.""" 551 | 552 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 553 | """The `model_fn` for TPUEstimator.""" 554 | 555 | tf.logging.info("*** Features ***") 556 | for name in sorted(features.keys()): 557 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 558 | 559 | unique_ids = features["unique_ids"] 560 | input_ids = features["input_ids"] 561 | input_mask = features["input_mask"] 562 | segment_ids = features["segment_ids"] 563 | 564 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 565 | 566 | (start_logits, end_logits) = create_model( 567 | bert_config=bert_config, 568 | is_training=is_training, 569 | input_ids=input_ids, 570 | input_mask=input_mask, 571 | segment_ids=segment_ids, 572 | use_one_hot_embeddings=use_one_hot_embeddings) 573 | 574 | tvars = tf.trainable_variables() 575 | 576 | initialized_variable_names = {} 577 | scaffold_fn = None 578 | if init_checkpoint: 579 | (assignment_map, initialized_variable_names 580 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 581 | if use_tpu: 582 | 583 | def tpu_scaffold(): 584 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 585 | return tf.train.Scaffold() 586 | 587 | scaffold_fn = tpu_scaffold 588 | else: 589 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 590 | 591 | tf.logging.info("**** Trainable Variables ****") 592 | for var in tvars: 593 | init_string = "" 594 | if var.name in initialized_variable_names: 595 | init_string = ", *INIT_FROM_CKPT*" 596 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 597 | init_string) 598 | 599 | output_spec = None 600 | if mode == tf.estimator.ModeKeys.TRAIN: 601 | seq_length = modeling.get_shape_list(input_ids)[1] 602 | 603 | def compute_loss(logits, positions): 604 | one_hot_positions = tf.one_hot( 605 | positions, depth=seq_length, dtype=tf.float32) 606 | log_probs = tf.nn.log_softmax(logits, axis=-1) 607 | loss = -tf.reduce_mean( 608 | tf.reduce_sum(one_hot_positions * log_probs, axis=-1)) 609 | return loss 610 | 611 | start_positions = features["start_positions"] 612 | end_positions = features["end_positions"] 613 | 614 | start_loss = compute_loss(start_logits, start_positions) 615 | end_loss = compute_loss(end_logits, end_positions) 616 | 617 | total_loss = (start_loss + end_loss) / 2.0 618 | 619 | train_op = optimization.create_optimizer( 620 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 621 | 622 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 623 | mode=mode, 624 | loss=total_loss, 625 | train_op=train_op, 626 | scaffold_fn=scaffold_fn) 627 | elif mode == tf.estimator.ModeKeys.PREDICT: 628 | predictions = { 629 | "unique_ids": unique_ids, 630 | "start_logits": start_logits, 631 | "end_logits": end_logits, 632 | } 633 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 634 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 635 | else: 636 | raise ValueError( 637 | "Only TRAIN and PREDICT modes are supported: %s" % (mode)) 638 | 639 | return output_spec 640 | 641 | return model_fn 642 | 643 | 644 | def input_fn_builder(input_file, seq_length, is_training, drop_remainder): 645 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 646 | 647 | name_to_features = { 648 | "unique_ids": tf.FixedLenFeature([], tf.int64), 649 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 650 | "input_mask": 
tf.FixedLenFeature([seq_length], tf.int64), 651 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 652 | } 653 | 654 | if is_training: 655 | name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64) 656 | name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64) 657 | 658 | def _decode_record(record, name_to_features): 659 | """Decodes a record to a TensorFlow example.""" 660 | example = tf.parse_single_example(record, name_to_features) 661 | 662 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 663 | # So cast all int64 to int32. 664 | for name in list(example.keys()): 665 | t = example[name] 666 | if t.dtype == tf.int64: 667 | t = tf.to_int32(t) 668 | example[name] = t 669 | 670 | return example 671 | 672 | def input_fn(params): 673 | """The actual input function.""" 674 | batch_size = params["batch_size"] 675 | 676 | # For training, we want a lot of parallel reading and shuffling. 677 | # For eval, we want no shuffling and parallel reading doesn't matter. 678 | d = tf.data.TFRecordDataset(input_file) 679 | if is_training: 680 | d = d.repeat() 681 | d = d.shuffle(buffer_size=100) 682 | 683 | d = d.apply( 684 | tf.contrib.data.map_and_batch( 685 | lambda record: _decode_record(record, name_to_features), 686 | batch_size=batch_size, 687 | drop_remainder=drop_remainder)) 688 | 689 | return d 690 | 691 | return input_fn 692 | 693 | 694 | RawResult = collections.namedtuple("RawResult", 695 | ["unique_id", "start_logits", "end_logits"]) 696 | 697 | 698 | def write_predictions(all_examples, all_features, all_results, n_best_size, 699 | max_answer_length, do_lower_case, output_prediction_file, 700 | output_nbest_file): 701 | """Write final predictions to the json file.""" 702 | tf.logging.info("Writing predictions to: %s" % (output_prediction_file)) 703 | tf.logging.info("Writing nbest to: %s" % (output_nbest_file)) 704 | 705 | example_index_to_features = collections.defaultdict(list) 706 | for feature in all_features: 707 | example_index_to_features[feature.example_index].append(feature) 708 | 709 | unique_id_to_result = {} 710 | for result in all_results: 711 | unique_id_to_result[result.unique_id] = result 712 | 713 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 714 | "PrelimPrediction", 715 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 716 | 717 | all_predictions = collections.OrderedDict() 718 | all_nbest_json = collections.OrderedDict() 719 | for (example_index, example) in enumerate(all_examples): 720 | features = example_index_to_features[example_index] 721 | 722 | prelim_predictions = [] 723 | for (feature_index, feature) in enumerate(features): 724 | result = unique_id_to_result[feature.unique_id] 725 | 726 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 727 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 728 | for start_index in start_indexes: 729 | for end_index in end_indexes: 730 | # We could hypothetically create invalid predictions, e.g., predict 731 | # that the start of the span is in the question. We throw out all 732 | # invalid predictions. 
733 | if start_index >= len(feature.tokens): 734 | continue 735 | if end_index >= len(feature.tokens): 736 | continue 737 | if start_index not in feature.token_to_orig_map: 738 | continue 739 | if end_index not in feature.token_to_orig_map: 740 | continue 741 | if not feature.token_is_max_context.get(start_index, False): 742 | continue 743 | if end_index < start_index: 744 | continue 745 | length = end_index - start_index + 1 746 | if length > max_answer_length: 747 | continue 748 | prelim_predictions.append( 749 | _PrelimPrediction( 750 | feature_index=feature_index, 751 | start_index=start_index, 752 | end_index=end_index, 753 | start_logit=result.start_logits[start_index], 754 | end_logit=result.end_logits[end_index])) 755 | 756 | prelim_predictions = sorted( 757 | prelim_predictions, 758 | key=lambda x: (x.start_logit + x.end_logit), 759 | reverse=True) 760 | 761 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 762 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 763 | 764 | seen_predictions = {} 765 | nbest = [] 766 | for pred in prelim_predictions: 767 | if len(nbest) >= n_best_size: 768 | break 769 | feature = features[pred.feature_index] 770 | 771 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 772 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 773 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 774 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 775 | tok_text = " ".join(tok_tokens) 776 | 777 | # De-tokenize WordPieces that have been split off. 778 | tok_text = tok_text.replace(" ##", "") 779 | tok_text = tok_text.replace("##", "") 780 | 781 | # Clean whitespace 782 | tok_text = tok_text.strip() 783 | tok_text = " ".join(tok_text.split()) 784 | orig_text = " ".join(orig_tokens) 785 | 786 | final_text = get_final_text(tok_text, orig_text, do_lower_case) 787 | if final_text in seen_predictions: 788 | continue 789 | 790 | seen_predictions[final_text] = True 791 | nbest.append( 792 | _NbestPrediction( 793 | text=final_text, 794 | start_logit=pred.start_logit, 795 | end_logit=pred.end_logit)) 796 | 797 | # In very rare edge cases we could have no valid predictions. So we 798 | # just create a nonce prediction in this case to avoid failure. 
799 | if not nbest: 800 | nbest.append( 801 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 802 | 803 | assert len(nbest) >= 1 804 | 805 | total_scores = [] 806 | for entry in nbest: 807 | total_scores.append(entry.start_logit + entry.end_logit) 808 | 809 | probs = _compute_softmax(total_scores) 810 | 811 | nbest_json = [] 812 | for (i, entry) in enumerate(nbest): 813 | output = collections.OrderedDict() 814 | output["text"] = entry.text 815 | output["probability"] = probs[i] 816 | output["start_logit"] = entry.start_logit 817 | output["end_logit"] = entry.end_logit 818 | nbest_json.append(output) 819 | 820 | assert len(nbest_json) >= 1 821 | 822 | all_predictions[example.qas_id] = nbest_json[0]["text"] 823 | all_nbest_json[example.qas_id] = nbest_json 824 | 825 | with tf.gfile.GFile(output_prediction_file, "w") as writer: 826 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 827 | 828 | with tf.gfile.GFile(output_nbest_file, "w") as writer: 829 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 830 | 831 | 832 | def get_final_text(pred_text, orig_text, do_lower_case): 833 | """Project the tokenized prediction back to the original text.""" 834 | 835 | # When we created the data, we kept track of the alignment between original 836 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 837 | # now `orig_text` contains the span of our original text corresponding to the 838 | # span that we predicted. 839 | # 840 | # However, `orig_text` may contain extra characters that we don't want in 841 | # our prediction. 842 | # 843 | # For example, let's say: 844 | # pred_text = steve smith 845 | # orig_text = Steve Smith's 846 | # 847 | # We don't want to return `orig_text` because it contains the extra "'s". 848 | # 849 | # We don't want to return `pred_text` because it's already been normalized 850 | # (the SQuAD eval script also does punctuation stripping/lower casing but 851 | # our tokenizer does additional normalization like stripping accent 852 | # characters). 853 | # 854 | # What we really want to return is "Steve Smith". 855 | # 856 | # Therefore, we have to apply a semi-complicated alignment heruistic between 857 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 858 | # can fail in certain cases in which case we just return `orig_text`. 859 | 860 | def _strip_spaces(text): 861 | ns_chars = [] 862 | ns_to_s_map = collections.OrderedDict() 863 | for (i, c) in enumerate(text): 864 | if c == " ": 865 | continue 866 | ns_to_s_map[len(ns_chars)] = i 867 | ns_chars.append(c) 868 | ns_text = "".join(ns_chars) 869 | return (ns_text, ns_to_s_map) 870 | 871 | # We first tokenize `orig_text`, strip whitespace from the result 872 | # and `pred_text`, and check if they are the same length. If they are 873 | # NOT the same length, the heuristic has failed. If they are the same 874 | # length, we assume the characters are one-to-one aligned. 
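# An illustrative sketch of the projection below, continuing the "Steve Smith"
# example above (the tokenizer output is assumed): with pred_text = "steve smith"
# and orig_text = "Steve Smith's", the basic-tokenized text is
# "steve smith ' s"; stripping spaces leaves strings of equal length
# (12 characters), so the characters are treated as one-to-one aligned, and
# mapping the start/end characters of pred_text back through the two maps
# yields orig_text[0:11] == "Steve Smith".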
875 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 876 | 877 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 878 | 879 | start_position = tok_text.find(pred_text) 880 | if start_position == -1: 881 | if FLAGS.verbose_logging: 882 | tf.logging.info( 883 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 884 | return orig_text 885 | end_position = start_position + len(pred_text) - 1 886 | 887 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 888 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 889 | 890 | if len(orig_ns_text) != len(tok_ns_text): 891 | if FLAGS.verbose_logging: 892 | tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'", 893 | orig_ns_text, tok_ns_text) 894 | return orig_text 895 | 896 | # We then project the characters in `pred_text` back to `orig_text` using 897 | # the character-to-character alignment. 898 | tok_s_to_ns_map = {} 899 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 900 | tok_s_to_ns_map[tok_index] = i 901 | 902 | orig_start_position = None 903 | if start_position in tok_s_to_ns_map: 904 | ns_start_position = tok_s_to_ns_map[start_position] 905 | if ns_start_position in orig_ns_to_s_map: 906 | orig_start_position = orig_ns_to_s_map[ns_start_position] 907 | 908 | if orig_start_position is None: 909 | if FLAGS.verbose_logging: 910 | tf.logging.info("Couldn't map start position") 911 | return orig_text 912 | 913 | orig_end_position = None 914 | if end_position in tok_s_to_ns_map: 915 | ns_end_position = tok_s_to_ns_map[end_position] 916 | if ns_end_position in orig_ns_to_s_map: 917 | orig_end_position = orig_ns_to_s_map[ns_end_position] 918 | 919 | if orig_end_position is None: 920 | if FLAGS.verbose_logging: 921 | tf.logging.info("Couldn't map end position") 922 | return orig_text 923 | 924 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 925 | return output_text 926 | 927 | 928 | def _get_best_indexes(logits, n_best_size): 929 | """Get the n-best logits from a list.""" 930 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 931 | 932 | best_indexes = [] 933 | for i in range(len(index_and_score)): 934 | if i >= n_best_size: 935 | break 936 | best_indexes.append(index_and_score[i][0]) 937 | return best_indexes 938 | 939 | 940 | def _compute_softmax(scores): 941 | """Compute softmax probability over raw logits.""" 942 | if not scores: 943 | return [] 944 | 945 | max_score = None 946 | for score in scores: 947 | if max_score is None or score > max_score: 948 | max_score = score 949 | 950 | exp_scores = [] 951 | total_sum = 0.0 952 | for score in scores: 953 | x = math.exp(score - max_score) 954 | exp_scores.append(x) 955 | total_sum += x 956 | 957 | probs = [] 958 | for score in exp_scores: 959 | probs.append(score / total_sum) 960 | return probs 961 | 962 | 963 | class FeatureWriter(object): 964 | """Writes InputFeature to TF example file.""" 965 | 966 | def __init__(self, filename, is_training): 967 | self.filename = filename 968 | self.is_training = is_training 969 | self.num_features = 0 970 | self._writer = tf.python_io.TFRecordWriter(filename) 971 | 972 | def process_feature(self, feature): 973 | """Write a InputFeature to the TFRecordWriter as a tf.train.Example.""" 974 | self.num_features += 1 975 | 976 | def create_int_feature(values): 977 | feature = tf.train.Feature( 978 | int64_list=tf.train.Int64List(value=list(values))) 979 | return feature 980 | 981 | features = collections.OrderedDict() 982 | 
features["unique_ids"] = create_int_feature([feature.unique_id]) 983 | features["input_ids"] = create_int_feature(feature.input_ids) 984 | features["input_mask"] = create_int_feature(feature.input_mask) 985 | features["segment_ids"] = create_int_feature(feature.segment_ids) 986 | 987 | if self.is_training: 988 | features["start_positions"] = create_int_feature([feature.start_position]) 989 | features["end_positions"] = create_int_feature([feature.end_position]) 990 | 991 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 992 | self._writer.write(tf_example.SerializeToString()) 993 | 994 | def close(self): 995 | self._writer.close() 996 | 997 | 998 | def validate_flags_or_throw(bert_config): 999 | """Validate the input FLAGS or throw an exception.""" 1000 | if not FLAGS.do_train and not FLAGS.do_predict: 1001 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 1002 | 1003 | if FLAGS.do_train: 1004 | if not FLAGS.train_file: 1005 | raise ValueError( 1006 | "If `do_train` is True, then `train_file` must be specified.") 1007 | if FLAGS.do_predict: 1008 | if not FLAGS.predict_file: 1009 | raise ValueError( 1010 | "If `do_predict` is True, then `predict_file` must be specified.") 1011 | 1012 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 1013 | raise ValueError( 1014 | "Cannot use sequence length %d because the BERT model " 1015 | "was only trained up to sequence length %d" % 1016 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 1017 | 1018 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: 1019 | raise ValueError( 1020 | "The max_seq_length (%d) must be greater than max_query_length " 1021 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) 1022 | 1023 | 1024 | def main(_): 1025 | tf.logging.set_verbosity(tf.logging.INFO) 1026 | 1027 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 1028 | 1029 | validate_flags_or_throw(bert_config) 1030 | 1031 | tf.gfile.MakeDirs(FLAGS.output_dir) 1032 | 1033 | tokenizer = tokenization.FullTokenizer( 1034 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 1035 | 1036 | tpu_cluster_resolver = None 1037 | if FLAGS.use_tpu and FLAGS.tpu_name: 1038 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 1039 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 1040 | 1041 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 1042 | run_config = tf.contrib.tpu.RunConfig( 1043 | cluster=tpu_cluster_resolver, 1044 | master=FLAGS.master, 1045 | model_dir=FLAGS.output_dir, 1046 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 1047 | tpu_config=tf.contrib.tpu.TPUConfig( 1048 | iterations_per_loop=FLAGS.iterations_per_loop, 1049 | num_shards=FLAGS.num_tpu_cores, 1050 | per_host_input_for_training=is_per_host)) 1051 | 1052 | train_examples = None 1053 | num_train_steps = None 1054 | num_warmup_steps = None 1055 | if FLAGS.do_train: 1056 | train_examples = read_squad_examples( 1057 | input_file=FLAGS.train_file, is_training=True) 1058 | num_train_steps = int( 1059 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 1060 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 1061 | 1062 | # Pre-shuffle the input to avoid having to make a very large shuffle 1063 | # buffer in in the `input_fn`. 
1064 | rng = random.Random(12345) 1065 | rng.shuffle(train_examples) 1066 | 1067 | model_fn = model_fn_builder( 1068 | bert_config=bert_config, 1069 | init_checkpoint=FLAGS.init_checkpoint, 1070 | learning_rate=FLAGS.learning_rate, 1071 | num_train_steps=num_train_steps, 1072 | num_warmup_steps=num_warmup_steps, 1073 | use_tpu=FLAGS.use_tpu, 1074 | use_one_hot_embeddings=FLAGS.use_tpu) 1075 | 1076 | # If TPU is not available, this will fall back to normal Estimator on CPU 1077 | # or GPU. 1078 | estimator = tf.contrib.tpu.TPUEstimator( 1079 | use_tpu=FLAGS.use_tpu, 1080 | model_fn=model_fn, 1081 | config=run_config, 1082 | train_batch_size=FLAGS.train_batch_size, 1083 | predict_batch_size=FLAGS.predict_batch_size) 1084 | 1085 | if FLAGS.do_train: 1086 | # We write to a temporary file to avoid storing very large constant tensors 1087 | # in memory. 1088 | train_writer = FeatureWriter( 1089 | filename=os.path.join(FLAGS.output_dir, "train.tf_record"), 1090 | is_training=True) 1091 | convert_examples_to_features( 1092 | examples=train_examples, 1093 | tokenizer=tokenizer, 1094 | max_seq_length=FLAGS.max_seq_length, 1095 | doc_stride=FLAGS.doc_stride, 1096 | max_query_length=FLAGS.max_query_length, 1097 | is_training=True, 1098 | output_fn=train_writer.process_feature) 1099 | train_writer.close() 1100 | 1101 | tf.logging.info("***** Running training *****") 1102 | tf.logging.info(" Num orig examples = %d", len(train_examples)) 1103 | tf.logging.info(" Num split examples = %d", train_writer.num_features) 1104 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 1105 | tf.logging.info(" Num steps = %d", num_train_steps) 1106 | del train_examples 1107 | 1108 | train_input_fn = input_fn_builder( 1109 | input_file=train_writer.filename, 1110 | seq_length=FLAGS.max_seq_length, 1111 | is_training=True, 1112 | drop_remainder=True) 1113 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 1114 | 1115 | if FLAGS.do_predict: 1116 | eval_examples = read_squad_examples( 1117 | input_file=FLAGS.predict_file, is_training=False) 1118 | 1119 | eval_writer = FeatureWriter( 1120 | filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), 1121 | is_training=False) 1122 | eval_features = [] 1123 | 1124 | def append_feature(feature): 1125 | eval_features.append(feature) 1126 | eval_writer.process_feature(feature) 1127 | 1128 | convert_examples_to_features( 1129 | examples=eval_examples, 1130 | tokenizer=tokenizer, 1131 | max_seq_length=FLAGS.max_seq_length, 1132 | doc_stride=FLAGS.doc_stride, 1133 | max_query_length=FLAGS.max_query_length, 1134 | is_training=False, 1135 | output_fn=append_feature) 1136 | eval_writer.close() 1137 | 1138 | tf.logging.info("***** Running predictions *****") 1139 | tf.logging.info(" Num orig examples = %d", len(eval_examples)) 1140 | tf.logging.info(" Num split examples = %d", len(eval_features)) 1141 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 1142 | 1143 | all_results = [] 1144 | 1145 | predict_input_fn = input_fn_builder( 1146 | input_file=eval_writer.filename, 1147 | seq_length=FLAGS.max_seq_length, 1148 | is_training=False, 1149 | drop_remainder=False) 1150 | 1151 | # If running eval on the TPU, you will need to specify the number of 1152 | # steps. 
1153 | all_results = [] 1154 | for result in estimator.predict( 1155 | predict_input_fn, yield_single_examples=True): 1156 | if len(all_results) % 1000 == 0: 1157 | tf.logging.info("Processing example: %d" % (len(all_results))) 1158 | unique_id = int(result["unique_ids"]) 1159 | start_logits = [float(x) for x in result["start_logits"].flat] 1160 | end_logits = [float(x) for x in result["end_logits"].flat] 1161 | all_results.append( 1162 | RawResult( 1163 | unique_id=unique_id, 1164 | start_logits=start_logits, 1165 | end_logits=end_logits)) 1166 | 1167 | output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json") 1168 | output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json") 1169 | write_predictions(eval_examples, eval_features, all_results, 1170 | FLAGS.n_best_size, FLAGS.max_answer_length, 1171 | FLAGS.do_lower_case, output_prediction_file, 1172 | output_nbest_file) 1173 | 1174 | 1175 | if __name__ == "__main__": 1176 | flags.mark_flag_as_required("vocab_file") 1177 | flags.mark_flag_as_required("bert_config_file") 1178 | flags.mark_flag_as_required("output_dir") 1179 | tf.app.run() 1180 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 
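# Illustrative behavior of the branches below: printable_text(u"über") returns
# the str "über" on Python 3, and the utf-8 encoded byte string on Python 2.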
52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenization.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens = [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | return convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | 140 | Args: 141 | do_lower_case: Whether to lower case the input. 142 | """ 143 | self.do_lower_case = do_lower_case 144 | 145 | def tokenize(self, text): 146 | """Tokenizes a piece of text.""" 147 | text = convert_to_unicode(text) 148 | text = self._clean_text(text) 149 | 150 | # This was added on November 1st, 2018 for the multilingual and Chinese 151 | # models. This is also applied to the English models now, but it doesn't 152 | # matter since the English models were not trained on any Chinese data 153 | # and generally don't have any Chinese data in them (there are Chinese 154 | # characters in the vocabulary because Wikipedia does have some Chinese 155 | # words in the English Wikipedia.).
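# Illustrative effect of the call below: "ab中文c" becomes "ab 中  文 c", so after
# the later whitespace split each CJK character ends up as its own token.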
156 | text = self._tokenize_chinese_chars(text) 157 | 158 | orig_tokens = whitespace_tokenize(text) 159 | split_tokens = [] 160 | for token in orig_tokens: 161 | if self.do_lower_case: 162 | token = token.lower() 163 | token = self._run_strip_accents(token) 164 | split_tokens.extend(self._run_split_on_punc(token)) 165 | 166 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 167 | return output_tokens 168 | 169 | def _run_strip_accents(self, text): 170 | """Strips accents from a piece of text.""" 171 | text = unicodedata.normalize("NFD", text) 172 | output = [] 173 | for char in text: 174 | cat = unicodedata.category(char) 175 | if cat == "Mn": 176 | continue 177 | output.append(char) 178 | return "".join(output) 179 | 180 | def _run_split_on_punc(self, text): 181 | """Splits punctuation on a piece of text.""" 182 | chars = list(text) 183 | i = 0 184 | start_new_word = True 185 | output = [] 186 | while i < len(chars): 187 | char = chars[i] 188 | if _is_punctuation(char): 189 | output.append([char]) 190 | start_new_word = True 191 | else: 192 | if start_new_word: 193 | output.append([]) 194 | start_new_word = False 195 | output[-1].append(char) 196 | i += 1 197 | 198 | return ["".join(x) for x in output] 199 | 200 | def _tokenize_chinese_chars(self, text): 201 | """Adds whitespace around any CJK character.""" 202 | output = [] 203 | for char in text: 204 | cp = ord(char) 205 | if self._is_chinese_char(cp): 206 | output.append(" ") 207 | output.append(char) 208 | output.append(" ") 209 | else: 210 | output.append(char) 211 | return "".join(output) 212 | 213 | def _is_chinese_char(self, cp): 214 | """Checks whether CP is the codepoint of a CJK character.""" 215 | # This defines a "chinese character" as anything in the CJK Unicode block: 216 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 217 | # 218 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 219 | # despite its name. The modern Korean Hangul alphabet is a different block, 220 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 221 | # space-separated words, so they are not treated specially and handled 222 | # like all of the other languages. 223 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 224 | (cp >= 0x3400 and cp <= 0x4DBF) or # 225 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 226 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 227 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 228 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 229 | (cp >= 0xF900 and cp <= 0xFAFF) or # 230 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 231 | return True 232 | 233 | return False 234 | 235 | def _clean_text(self, text): 236 | """Performs invalid character removal and whitespace cleanup on text.""" 237 | output = [] 238 | for char in text: 239 | cp = ord(char) 240 | if cp == 0 or cp == 0xfffd or _is_control(char): 241 | continue 242 | if _is_whitespace(char): 243 | output.append(" ") 244 | else: 245 | output.append(char) 246 | return "".join(output) 247 | 248 | 249 | class WordpieceTokenizer(object): 250 | """Runs WordPiece tokenization.""" 251 | 252 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 253 | self.vocab = vocab 254 | self.unk_token = unk_token 255 | self.max_input_chars_per_word = max_input_chars_per_word 256 | 257 | def tokenize(self, text): 258 | """Tokenizes a piece of text into its word pieces. 259 | 260 | This uses a greedy longest-match-first algorithm to perform tokenization 261 | using the given vocabulary.
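Concretely (a paraphrase of the loop below): for each whitespace-delimited
token, the longest prefix of the remaining characters that is in the
vocabulary (written with a "##" prefix when it does not start the token) is
emitted and the scan resumes after it; if no vocabulary entry matches at
some position, the whole token is replaced by `unk_token`.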
262 | 263 | For example: 264 | input = "unaffable" 265 | output = ["un", "##aff", "##able"] 266 | 267 | Args: 268 | text: A single token or whitespace separated tokens. This should have 269 | already been passed through `BasicTokenizer`. 270 | 271 | Returns: 272 | A list of wordpiece tokens. 273 | """ 274 | 275 | text = convert_to_unicode(text) 276 | 277 | output_tokens = [] 278 | for token in whitespace_tokenize(text): 279 | chars = list(token) 280 | if len(chars) > self.max_input_chars_per_word: 281 | output_tokens.append(self.unk_token) 282 | continue 283 | 284 | is_bad = False 285 | start = 0 286 | sub_tokens = [] 287 | while start < len(chars): 288 | end = len(chars) 289 | cur_substr = None 290 | while start < end: 291 | substr = "".join(chars[start:end]) 292 | if start > 0: 293 | substr = "##" + substr 294 | if substr in self.vocab: 295 | cur_substr = substr 296 | break 297 | end -= 1 298 | if cur_substr is None: 299 | is_bad = True 300 | break 301 | sub_tokens.append(cur_substr) 302 | start = end 303 | 304 | if is_bad: 305 | output_tokens.append(self.unk_token) 306 | else: 307 | output_tokens.extend(sub_tokens) 308 | return output_tokens 309 | 310 | 311 | def _is_whitespace(char): 312 | """Checks whether `char` is a whitespace character.""" 313 | # \t, \n, and \r are technically control characters but we treat them 314 | # as whitespace since they are generally considered as such. 315 | if char == " " or char == "\t" or char == "\n" or char == "\r": 316 | return True 317 | cat = unicodedata.category(char) 318 | if cat == "Zs": 319 | return True 320 | return False 321 | 322 | 323 | def _is_control(char): 324 | """Checks whether `char` is a control character.""" 325 | # These are technically control characters but we count them as whitespace 326 | # characters. 327 | if char == "\t" or char == "\n" or char == "\r": 328 | return False 329 | cat = unicodedata.category(char) 330 | if cat.startswith("C"): 331 | return True 332 | return False 333 | 334 | 335 | def _is_punctuation(char): 336 | """Checks whether `char` is a punctuation character.""" 337 | cp = ord(char) 338 | # We treat all non-letter/number ASCII as punctuation. 339 | # Characters such as "^", "$", and "`" are not in the Unicode 340 | # Punctuation class but we treat them as punctuation anyways, for 341 | # consistency. 342 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 343 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 344 | return True 345 | cat = unicodedata.category(char) 346 | if cat.startswith("P"): 347 | return True 348 | return False 349 | --------------------------------------------------------------------------------
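For reference, a minimal usage sketch of the tokenization classes above (the vocab path is a placeholder; the exact WordPiece output depends on the vocabulary shipped with the downloaded BERT checkpoint):

```
import tokenization

# Placeholder path: point this at the vocab.txt inside the unzipped
# uncased_L-12_H-768_A-12 checkpoint directory.
tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/bert/uncased_L-12_H-768_A-12/vocab.txt",
    do_lower_case=True)

# BasicTokenizer lower-cases, strips accents and splits punctuation;
# WordpieceTokenizer then applies greedy longest-match-first sub-word splitting.
tokens = tokenizer.tokenize("DDCWR combines BERT and ELMo.")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)
```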