├── README.md
├── msmarco_doc_preprocess
    ├── preprocessor_multipreprocess_msmarco_doc.py
    └── tokenization_msmarco_doc.py
└── msmarco_doc_train
    ├── modeling
        ├── create_model_bison.py
        ├── modeling_bison.py
        └── optimization_nvidia.py
    ├── train_msmarco_doc.py
    └── utils
        ├── gpu_environment.py
        └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # BISON : BM25-weighted Self-Attention Network for Multi-Fields Document Search
2 | This is the implementation of the paper BISON: BM25-weighted Self-Attention Network for Multi-Fields Document Search (https://arxiv.org/abs/2007.05186), using the MS MARCO Document Ranking task as an example.
3 | 
4 | ## Getting Started
5 | This version is built for distributed training with Horovod.
6 | 
7 | ## Prerequisites
8 | ```
9 | tensorflow>=1.14.0
10 | horovod
11 | ```
12 | 
13 | ## Running
14 | The folder "msmarco_doc_preprocess" preprocesses the MS MARCO data.
15 | The folder "msmarco_doc_train" trains BISON on the preprocessed data.
16 | The entry point is train_msmarco_doc.py.
17 | 
18 | ### Data Preprocessing Example Command
19 | ```
20 | mpirun -np 8 python preprocessor_multipreprocess_msmarco_doc.py \
21 | --task_name=BISON \
22 | --data_dir=your_data_folder_path \
23 | --data_file=your_data_file_name \
24 | --data_line_count=your_data_line_count \
25 | --output_dir=your_output_folder_path \
26 | --vocab_file=your_vocab_file \
27 | --full_word_idf_file=your_idf_file \
28 | --default_idf=your_default_idf \
29 | --do_lower_case=True \
30 | --max_seq_length_query=20 --max_seq_length_url=30 --max_seq_length_title=30 \
31 | --label_col=0 --src_col=2 --url_col=4 --title_col=5 \
32 | --BM25_K1=0.25 --BM25_B=0.4 --BM25_AVGDL_Q=6 --BM25_AVGDL_D=25
33 | ```
34 | 
35 | ### Training Example Command
36 | ```
37 | TF_XLA_FLAGS=--tf_xla_auto_jit=2 mpirun --verbose --allow-run-as-root -np 8 --display-map -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x HOROVOD_GPU_ALLREDUCE=NCCL -x PATH -mca pml ob1 -mca btl self,tcp,openib -mca btl_tcp_if_exclude lo,docker0 \
38 | python train_msmarco_doc.py \
39 | --task_name=BISON \
40 | --do_train=True \
41 | --preprocess_train_dir=your_preprocessed_training_data_folder_path \
42 | --train_line_count=your_data_line_count \
43 | --train_partition_count=your_data_partition_count \
44 | --preprocess_train_file_name=your_preprocess_train_file_name \
45 | --preprocess_eval_dir=your_preprocessed_eval_data_folder_path \
46 | --output_dir=your_output_folder_path \
47 | --query_bert_config_file=your_query_config \
48 | --meta_bert_config_file=your_meta_config \
49 | --max_seq_length_query=20 --max_seq_length_url=30 --max_seq_length_title=30 \
50 | --train_batch_size=512 --learning_rate=8e-5 --num_train_epochs=5 --save_checkpoints_steps=800 --eval_batch_size=500 \
51 | --nce_temperature=1000 --nce_weight=0.5 \
52 | --use_fp16=True --use_xla=True --horovod=True --use_one_hot_embeddings=False --verbose_logging=True
53 | ```
54 | 
--------------------------------------------------------------------------------
/msmarco_doc_preprocess/preprocessor_multipreprocess_msmarco_doc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import sys 24 | import os 25 | import tokenization_msmarco_doc as tokenization 26 | import tensorflow as tf 27 | from tensorflow.python.client import device_lib 28 | import horovod.tensorflow as hvd 29 | import codecs 30 | csv.field_size_limit(sys.maxsize) 31 | flags = tf.flags 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | flags.DEFINE_string("task_name", None, "The name of the task.") 36 | flags.DEFINE_string("data_dir", None, "The input data dir. Should contain the .tsv files (or other data files) for the task.") 37 | flags.DEFINE_string("data_file", None, "Data file name.") 38 | flags.DEFINE_integer("data_line_count", None, "Data file line count.") 39 | flags.DEFINE_string("output_dir", None, "The output directory where the preprocessed data will be written.") 40 | 41 | flags.DEFINE_string("vocab_file", None, "The vocabulary file.") 42 | flags.DEFINE_string("full_word_idf_file", None, "The idf file that used on full word.") 43 | flags.DEFINE_float("default_idf", None, "The default value for words not in idf dict.") 44 | 45 | flags.DEFINE_bool( 46 | "do_lower_case", True, 47 | "Whether to lower case the input text. Should be True for uncased " 48 | "models and False for cased models.") 49 | 50 | flags.DEFINE_integer( 51 | "max_seq_length_query", 20, 52 | "The maximum total input sequence length after WordPiece tokenization. " 53 | "Sequences longer than this will be truncated, and sequences shorter " 54 | "than this will be padded.") 55 | 56 | flags.DEFINE_integer( 57 | "max_seq_length_url", 30, 58 | "The maximum total input sequence length after WordPiece tokenization. " 59 | "Sequences longer than this will be truncated, and sequences shorter " 60 | "than this will be padded.") 61 | 62 | flags.DEFINE_integer( 63 | "max_seq_length_title", 30, 64 | "The maximum total input sequence length after WordPiece tokenization. " 65 | "Sequences longer than this will be truncated, and sequences shorter " 66 | "than this will be padded.") 67 | 68 | flags.DEFINE_integer( 69 | "max_seq_length_body", 128, 70 | "The maximum total input sequence length after WordPiece tokenization. 
" 71 | "Sequences longer than this will be truncated, and sequences shorter " 72 | "than this will be padded.") 73 | 74 | flags.DEFINE_integer("src_col", 2, "src_col") 75 | flags.DEFINE_integer("url_col", 4, "url_col") 76 | flags.DEFINE_integer("title_col", 5, "title_col") 77 | flags.DEFINE_integer("body_col", -1, "body_col") 78 | flags.DEFINE_integer("label_col", -1, "label_col") 79 | 80 | flags.DEFINE_float("BM25_K1", 0.25, "K1 in BM25.") 81 | flags.DEFINE_float("BM25_B", 0.4, "B in BM25.") 82 | flags.DEFINE_float("BM25_AVGDL_Q", 6, "AVG Query Len in BM25.") 83 | flags.DEFINE_float("BM25_AVGDL_D", 25, "AVG Doc Len in BM25.") 84 | 85 | class InputExample(object): 86 | """A single training/test example for simple sequence classification.""" 87 | 88 | def __init__(self, guid, query, url, title, label=None, body=None): 89 | """Constructs a InputExample. 90 | 91 | Args: 92 | guid: Unique id for the example. 93 | text_a: string. The untokenized text of the first sequence. For single 94 | sequence tasks, only this sequence must be specified. 95 | text_b: (Optional) string. The untokenized text of the second sequence. 96 | Only must be specified for sequence pair tasks. 97 | label: (Optional) string. The label of the example. This should be 98 | specified for train and dev examples, but not for test examples. 99 | """ 100 | self.guid = guid 101 | self.query = query 102 | self.url = url 103 | self.title = title 104 | self.label = label 105 | self.body = body 106 | 107 | 108 | class InputFeatures(object): 109 | """A single set of features of data.""" 110 | 111 | def __init__(self, query_input_ids, query_input_mask,query_input_idfs, query_segment_ids, meta_input_ids, meta_input_mask, meta_input_idfs, metaStream_segment_ids, label_id): 112 | self.query_input_ids = query_input_ids 113 | self.query_input_mask = query_input_mask 114 | self.query_input_idfs = query_input_idfs 115 | self.query_segment_ids = query_segment_ids 116 | self.meta_input_ids = meta_input_ids 117 | self.meta_input_idfs = meta_input_idfs 118 | self.meta_input_mask = meta_input_mask 119 | self.metaStream_segment_ids = metaStream_segment_ids 120 | self.label_id = label_id 121 | 122 | 123 | class DataProcessor(object): 124 | """Base class for data converters for sequence classification data sets.""" 125 | 126 | def get_train_examples(self, data_dir): 127 | """Gets a collection of `InputExample`s for the train set.""" 128 | raise NotImplementedError() 129 | 130 | def get_dev_examples(self, data_dir): 131 | """Gets a collection of `InputExample`s for the dev set.""" 132 | raise NotImplementedError() 133 | 134 | def get_test_examples(self, data_dir): 135 | """Gets a collection of `InputExample`s for prediction.""" 136 | raise NotImplementedError() 137 | 138 | def get_labels(self): 139 | """Gets the list of labels for this data set.""" 140 | raise NotImplementedError() 141 | 142 | @classmethod 143 | def _read_tsv(cls, input_file, quotechar=None): 144 | """Reads a tab separated value file.""" 145 | with tf.gfile.Open(input_file, "r") as f: 146 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 147 | lines = [] 148 | for line in reader: 149 | lines.append(line) 150 | return lines 151 | 152 | 153 | class BISONProcessor(DataProcessor): 154 | """Processor for the MRPC data set (GLUE version).""" 155 | 156 | def get_labels(self): 157 | """See base class.""" 158 | return ["0", "1"] 159 | 160 | 161 | def convert_single_example(ex_index, example, label_list, 162 | max_seq_length_query, max_seq_length_url, max_seq_length_title, 
max_seq_length_body, tokenizer): 163 | """Converts a single `InputExample` into a single `InputFeatures`.""" 164 | label_map = {} 165 | for (i, label) in enumerate(label_list): 166 | label_map[label] = i 167 | 168 | query_tokens_a = tokenizer.tokenize(example.query) 169 | 170 | metaStream_url = tokenizer.tokenize(example.url) 171 | metaStream_title = tokenizer.tokenize(example.title) 172 | if example.body is not None: 173 | metaStream_body = tokenizer.tokenize(example.body) 174 | 175 | # query format: [CLS]query tokens[SEP] 176 | if len(query_tokens_a) > max_seq_length_query - 2: 177 | query_tokens_a = query_tokens_a[0:(max_seq_length_query - 2)] 178 | # document format: [CLS]url tokens[SEP]title tokens[SEP]body tokens(optional)[SEP] 179 | if len(metaStream_url) > max_seq_length_url - 2: 180 | metaStream_url = metaStream_url[0:(max_seq_length_url - 2)] 181 | if len(metaStream_title) > max_seq_length_title - 1: 182 | metaStream_title = metaStream_title[0:(max_seq_length_title - 1)] 183 | if example.body is not None: 184 | if len(metaStream_body) > max_seq_length_body - 1: 185 | metaStream_body = metaStream_body[0:(max_seq_length_body - 1)] 186 | 187 | query_tokens = [] 188 | metaStream_tokens = [] 189 | query_segment_ids = [] 190 | metaStream_segment_ids = [] 191 | 192 | query_tokens.append("[CLS]") 193 | metaStream_tokens.append("[CLS]") 194 | query_segment_ids.append(0) 195 | metaStream_segment_ids.append(0) 196 | 197 | for q_token in query_tokens_a: 198 | query_tokens.append(q_token) 199 | query_segment_ids.append(0) 200 | query_tokens.append("[SEP]") 201 | query_segment_ids.append(0) 202 | 203 | for m_token in metaStream_url: 204 | metaStream_tokens.append(m_token) 205 | metaStream_segment_ids.append(0) 206 | metaStream_tokens.append("[SEP]") 207 | metaStream_segment_ids.append(0) 208 | 209 | for m_token in metaStream_title: 210 | metaStream_tokens.append(m_token) 211 | metaStream_segment_ids.append(1) 212 | metaStream_tokens.append("[SEP]") 213 | metaStream_segment_ids.append(1) 214 | 215 | if example.body is not None: 216 | for m_token in metaStream_body: 217 | metaStream_tokens.append(m_token) 218 | metaStream_segment_ids.append(2) 219 | metaStream_tokens.append("[SEP]") 220 | metaStream_segment_ids.append(2) 221 | 222 | query_input_ids = tokenizer.convert_tokens_to_ids(query_tokens) 223 | meta_input_ids = tokenizer.convert_tokens_to_ids(metaStream_tokens) 224 | 225 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 226 | # tokens are attended to. 
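  # In addition to these masks, each WordPiece token below also receives a
  # full-word BM25 weight: the basic (whole-word) tokenization is aligned to
  # the WordPiece tokens, and each word w contributes
  #   idf(w) * tf(w) * (K1 + 1) / (tf(w) + K1 * (1 - B + B * field_len / avgdl)),
  # using BM25_AVGDL_Q for the query and BM25_AVGDL_D for the document; words
  # missing from full_word_idf_file fall back to default_idf, and every
  # WordPiece piece of w inherits the weight of its word. Ids, masks, segment
  # ids and BM25 weights are then zero-padded to the configured max lengths.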
227 | query_input_mask = [1] * len(query_input_ids) 228 | meta_input_mask = [1] * len(meta_input_ids) 229 | 230 | basic_tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 231 | query_words = ["[CLS]"] 232 | query_words.extend(basic_tokenizer.tokenize(example.query)) 233 | query_words.append("[SEP]") 234 | 235 | metaStream_words = ["[CLS]"] 236 | metaStream_words.extend(basic_tokenizer.tokenize(example.url)) 237 | metaStream_words.append("[SEP]") 238 | metaStream_words.extend(basic_tokenizer.tokenize(example.title)) 239 | metaStream_words.append("[SEP]") 240 | if example.body is not None: 241 | metaStream_words.extend(basic_tokenizer.tokenize(example.body)) 242 | metaStream_words.append("[SEP]") 243 | 244 | query_input_idfs = tokenizer.convert_tokens_to_bm25s_by_full_word_for_msmarco_doc(query_tokens, query_words, FLAGS.default_idf, False) 245 | meta_input_idfs = tokenizer.convert_tokens_to_bm25s_by_full_word_for_msmarco_doc(metaStream_tokens, metaStream_words, FLAGS.default_idf, True) 246 | 247 | # Zero-pad up to the sequence length. 248 | while len(query_input_ids) < max_seq_length_query: 249 | query_input_ids.append(0) 250 | query_input_mask.append(0) 251 | query_segment_ids.append(0) 252 | query_input_idfs.append(0) 253 | 254 | max_seq_length_doc = max_seq_length_url + max_seq_length_title 255 | if example.body is not None: 256 | max_seq_length_doc += max_seq_length_body 257 | while len(meta_input_ids) < max_seq_length_doc: 258 | meta_input_ids.append(0) 259 | meta_input_mask.append(0) 260 | metaStream_segment_ids.append(0) 261 | meta_input_idfs.append(0) 262 | 263 | assert len(query_input_ids) == max_seq_length_query 264 | assert len(query_input_mask) == max_seq_length_query 265 | assert len(query_segment_ids) == max_seq_length_query 266 | assert len(query_input_idfs) == max_seq_length_query 267 | assert len(meta_input_ids) == max_seq_length_doc 268 | assert len(meta_input_mask) == max_seq_length_doc 269 | assert len(metaStream_segment_ids) == max_seq_length_doc 270 | assert len(meta_input_idfs) == max_seq_length_doc 271 | 272 | label_id = label_map[example.label] 273 | if ex_index < 5: 274 | tf.logging.info("*** Example ***") 275 | tf.logging.info("guid: %s" % (example.guid)) 276 | tf.logging.info("query_tokens: %s" % " ".join( 277 | [tokenization.printable_text(x) for x in query_tokens])) 278 | tf.logging.info("query_input_ids: %s" % " ".join([str(x) for x in query_input_ids])) 279 | tf.logging.info("query_input_mask: %s" % " ".join([str(x) for x in query_input_mask])) 280 | tf.logging.info("query_segment_ids: %s" % " ".join([str(x) for x in query_segment_ids])) 281 | tf.logging.info("query_input_idfs: %s" % " ".join([str(x) for x in query_input_idfs])) 282 | tf.logging.info("metaStream_tokens: %s" % " ".join( 283 | [tokenization.printable_text(x) for x in metaStream_tokens])) 284 | tf.logging.info("meta_input_ids: %s" % " ".join([str(x) for x in meta_input_ids])) 285 | tf.logging.info("meta_input_mask: %s" % " ".join([str(x) for x in meta_input_mask])) 286 | tf.logging.info("metaStream_segment_ids: %s" % " ".join([str(x) for x in metaStream_segment_ids])) 287 | tf.logging.info("meta_input_idfs: %s" % " ".join([str(x) for x in meta_input_idfs])) 288 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 289 | 290 | feature = InputFeatures( 291 | query_input_ids=query_input_ids, 292 | query_input_mask=query_input_mask, 293 | query_input_idfs=query_input_idfs, 294 | query_segment_ids=query_segment_ids, 295 | meta_input_ids=meta_input_ids, 296 | 
meta_input_mask=meta_input_mask, 297 | meta_input_idfs=meta_input_idfs, 298 | metaStream_segment_ids=metaStream_segment_ids, 299 | label_id=label_id) 300 | return feature 301 | 302 | def file_based_convert_examples_to_features_v2(raw_file_path, 303 | csv_line_count, 304 | rank_size, 305 | rank, 306 | label_list, 307 | max_seq_length_query, 308 | max_seq_length_url, 309 | max_seq_length_title, 310 | max_seq_length_body, 311 | tokenizer, 312 | set_type, 313 | src_col, 314 | url_col, 315 | title_col, 316 | body_col, 317 | label_col, 318 | output_file): 319 | process_count = int(csv_line_count // rank_size) 320 | process_offset = int(process_count * rank) 321 | 322 | if rank == rank_size - 1: 323 | process_count += csv_line_count % rank_size 324 | 325 | idx = 0 326 | 327 | # read the csv file line by line 328 | with tf.gfile.Open(raw_file_path, 'r') as fp: 329 | output_file_parent_path = os.path.dirname(output_file) 330 | 331 | if not os.path.exists(output_file_parent_path): 332 | os.makedirs(output_file_parent_path) 333 | 334 | if os.path.exists(output_file): 335 | os.remove(output_file) 336 | 337 | tf.logging.info("rank:%d, begin write file to:%s" % (rank, output_file)) 338 | 339 | reader = csv.reader((x.replace('\0', '') for x in fp), delimiter="\t", quotechar=None) 340 | writer = tf.python_io.TFRecordWriter(output_file) 341 | 342 | for line in reader: 343 | if idx < process_offset: 344 | if 0 == idx % 1000000: 345 | tf.logging.info("rank:%d, skim example %d of %d" % (rank, idx, process_offset)) 346 | 347 | idx += 1 348 | continue 349 | 350 | if idx >= (process_offset + process_count): 351 | break 352 | 353 | cur_idx = idx - process_offset 354 | 355 | if cur_idx % 10000 == 0: 356 | tf.logging.info("rank:%d, Writing example %d of %d" % (rank, cur_idx, process_count)) 357 | 358 | # create a single example 359 | guid = "%s-%s" % (set_type, idx) 360 | query = tokenization.convert_to_unicode(line[src_col]) 361 | url = tokenization.convert_to_unicode(line[url_col]) 362 | title = tokenization.convert_to_unicode(line[title_col]) 363 | label = tokenization.convert_to_unicode('1') 364 | if not label_col == -1: 365 | label = tokenization.convert_to_unicode(line[label_col]) 366 | body = None 367 | if not body_col == -1: 368 | body = tokenization.convert_to_unicode(line[body_col]) 369 | 370 | single_example = InputExample( 371 | guid=guid, query=query, url=url, title=title, label=label, body=body) 372 | 373 | feature = convert_single_example(cur_idx, 374 | single_example, 375 | label_list, 376 | max_seq_length_query, 377 | max_seq_length_url, 378 | max_seq_length_title, 379 | max_seq_length_body, 380 | tokenizer) 381 | 382 | def create_int_feature(values): 383 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 384 | return f 385 | def create_float_feature(values): 386 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 387 | return f 388 | 389 | features = collections.OrderedDict() 390 | features["query_input_ids"] = create_int_feature(feature.query_input_ids) 391 | features["query_input_mask"] = create_int_feature(feature.query_input_mask) 392 | features["query_input_idfs"] = create_float_feature(feature.query_input_idfs) 393 | features["query_segment_ids"] = create_int_feature(feature.query_segment_ids) 394 | features["meta_input_ids"] = create_int_feature(feature.meta_input_ids) 395 | features["meta_input_mask"] = create_int_feature(feature.meta_input_mask) 396 | features["meta_input_idfs"] = create_float_feature(feature.meta_input_idfs) 397 | 
features["metaStream_segment_ids"] = create_int_feature(feature.metaStream_segment_ids) 398 | features["label_ids"] = create_int_feature([feature.label_id]) 399 | 400 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 401 | writer.write(tf_example.SerializeToString()) 402 | 403 | idx += 1 404 | 405 | writer.close() 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | tf.logging.info("***** Start to preprocess data *****") 410 | hvd.init() 411 | tf.gfile.MakeDirs(FLAGS.output_dir) 412 | data_file = os.path.join(FLAGS.output_dir, str(hvd.rank()), "train.tf_record") 413 | raw_data_file_path = os.path.join(FLAGS.data_dir, FLAGS.data_file) 414 | processors = { 415 | "bison": BISONProcessor, 416 | } 417 | task_name = FLAGS.task_name.lower() 418 | if task_name not in processors: 419 | raise ValueError("Task not found: %s" % (task_name)) 420 | processor = processors[task_name]() 421 | label_list = processor.get_labels() 422 | tokenizer = tokenization.FullTokenizerFullWordIDF( 423 | vocab_file=FLAGS.vocab_file, full_word_idf_file=FLAGS.full_word_idf_file, do_lower_case=FLAGS.do_lower_case) 424 | file_based_convert_examples_to_features_v2(raw_data_file_path, 425 | FLAGS.data_line_count, 426 | hvd.size(), 427 | hvd.rank(), 428 | label_list, 429 | FLAGS.max_seq_length_query, 430 | FLAGS.max_seq_length_url, 431 | FLAGS.max_seq_length_title, 432 | FLAGS.max_seq_length_body, 433 | tokenizer, 434 | "train", 435 | FLAGS.src_col, 436 | FLAGS.url_col, 437 | FLAGS.title_col, 438 | FLAGS.body_col, 439 | FLAGS.label_col, 440 | data_file) 441 | tf.logging.info("***** Data preprocess finished *****") 442 | 443 | if __name__ == "__main__": 444 | flags.mark_flag_as_required("task_name") 445 | flags.mark_flag_as_required("data_dir") 446 | flags.mark_flag_as_required("data_file") 447 | flags.mark_flag_as_required("data_line_count") 448 | flags.mark_flag_as_required("output_dir") 449 | flags.mark_flag_as_required("vocab_file") 450 | flags.mark_flag_as_required("full_word_idf_file") 451 | flags.mark_flag_as_required("default_idf") 452 | tf.app.run() 453 | -------------------------------------------------------------------------------- /msmarco_doc_preprocess/tokenization_msmarco_doc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python # -*- coding: utf-8 -*- 2 | # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Tokenization classes.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import collections 24 | import unicodedata 25 | import six 26 | import tensorflow as tf 27 | import re 28 | import os 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | 40 | flags = tf.flags 41 | FLAGS = flags.FLAGS 42 | 43 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 44 | """Checks whether the casing config is consistent with the checkpoint name.""" 45 | 46 | # The casing has to be passed in by the user and there is no explicit check 47 | # as to whether it matches the checkpoint. The casing information probably 48 | # should have been stored in the bert_config.json file, but it's not, so 49 | # we have to heuristically detect it to validate. 50 | 51 | if not init_checkpoint: 52 | return 53 | 54 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 55 | if m is None: 56 | return 57 | 58 | model_name = m.group(1) 59 | 60 | lower_models = [ 61 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 62 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 63 | ] 64 | 65 | cased_models = [ 66 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 67 | "multi_cased_L-12_H-768_A-12" 68 | ] 69 | 70 | is_bad_config = False 71 | if model_name in lower_models and not do_lower_case: 72 | is_bad_config = True 73 | actual_flag = "False" 74 | case_name = "lowercased" 75 | opposite_flag = "True" 76 | 77 | if model_name in cased_models and do_lower_case: 78 | is_bad_config = True 79 | actual_flag = "True" 80 | case_name = "cased" 81 | opposite_flag = "False" 82 | 83 | if is_bad_config: 84 | raise ValueError( 85 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 86 | "However, `%s` seems to be a %s model, so you " 87 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 88 | "how the model was pre-training. If this error is wrong, please " 89 | "just comment out this check." % (actual_flag, init_checkpoint, 90 | model_name, case_name, opposite_flag)) 91 | 92 | 93 | def convert_to_unicode(text): 94 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 95 | if isinstance(text, str): 96 | return text 97 | elif isinstance(text, bytes): 98 | return text.decode("utf-8", "ignore") 99 | else: 100 | raise ValueError("Unsupported string type: %s" % (type(text))) 101 | 102 | 103 | def printable_text(text): 104 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 105 | 106 | # These functions want `str` for both Python2 and Python3, but in one case 107 | # it's a Unicode string and in the other it's a byte string. 
108 | if isinstance(text, str): 109 | return text 110 | elif isinstance(text, bytes): 111 | return text.decode("utf-8", "ignore") 112 | else: 113 | raise ValueError("Unsupported string type: %s" % (type(text))) 114 | 115 | 116 | def load_vocab(vocab_file): 117 | """Loads a vocabulary file into a dictionary.""" 118 | vocab = collections.OrderedDict() 119 | index = 0 120 | with tf.gfile.GFile(vocab_file, "r") as reader: 121 | while True: 122 | token = convert_to_unicode(reader.readline()) 123 | if not token: 124 | break 125 | token = token.strip() 126 | vocab[token] = index 127 | index += 1 128 | return vocab 129 | 130 | def load_idf(idf_file): 131 | """"Loads a idf file that containing all words' idf value""" 132 | idf = collections.OrderedDict() 133 | with tf.gfile.GFile(idf_file, "r") as reader: 134 | while True: 135 | line = convert_to_unicode(reader.readline()) 136 | if not line: 137 | break 138 | line_arr = line.strip().split("\t") 139 | if len(line_arr) < 2: 140 | continue 141 | token = line_arr[0] 142 | idf_instance = float(line_arr[1]) 143 | idf[token] = idf_instance 144 | return idf 145 | 146 | def convert_tokens_to_bm25s_by_full_word_for_msmarco_doc(full_word_idfs, tokens, words, default_idf, is_doc): 147 | import collections 148 | tfs = collections.Counter(words) 149 | bm25s_list = [] 150 | n_words = len(words) 151 | token_idx = 0 152 | word_idx = 0 153 | temp_token_agg = "" 154 | accumulation = 1 155 | while token_idx < len(tokens) and word_idx < len(words): 156 | temp_token_agg += tokens[token_idx] if "##" not in tokens[token_idx] else tokens[token_idx][2:] 157 | while tokens[token_idx] == "[SEP]" and words[word_idx] != "[SEP]": 158 | word_idx += 1 159 | word = words[word_idx] 160 | # Attention! Token is uncased! 161 | if temp_token_agg == word or temp_token_agg.lower() == word or (token_idx < len(tokens) - 1 and tokens[token_idx + 1] == "[SEP]"): 162 | idf = full_word_idfs[word] if word in full_word_idfs else default_idf 163 | for i in range(accumulation): 164 | tf = tfs[word] / n_words 165 | if is_doc: 166 | atf = tf * (FLAGS.BM25_K1 + 1) / (tf + FLAGS.BM25_K1 * (1 - FLAGS.BM25_B + FLAGS.BM25_B * n_words / FLAGS.BM25_AVGDL_D)) 167 | else: 168 | atf = tf * (FLAGS.BM25_K1 + 1) / (tf + FLAGS.BM25_K1 * (1 - FLAGS.BM25_B + FLAGS.BM25_B * n_words / FLAGS.BM25_AVGDL_Q)) 169 | bm25s_list.append(idf*atf) 170 | temp_token_agg = "" 171 | accumulation = 1 172 | word_idx += 1 173 | token_idx += 1 174 | else: 175 | accumulation += 1 176 | token_idx += 1 177 | return bm25s_list 178 | 179 | 180 | def convert_by_vocab(vocab, items): 181 | """Converts a sequence of [tokens|ids] using the vocab.""" 182 | output = [] 183 | for item in items: 184 | output.append(vocab[item]) 185 | return output 186 | 187 | 188 | def whitespace_tokenize(text): 189 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 190 | text = text.strip() 191 | if not text: 192 | return [] 193 | tokens = text.split() 194 | return tokens 195 | 196 | class FullTokenizerFullWordIDF(object): 197 | # Run a tokenization init with IDF information 198 | def __init__(self, vocab_file,full_word_idf_file, do_lower_case=True): 199 | self.vocab = load_vocab(vocab_file) 200 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 201 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 202 | self.full_word_idfs = load_idf(full_word_idf_file) 203 | 204 | def tokenize(self, text): 205 | split_tokens = [] 206 | for token in self.basic_tokenizer.tokenize(text): 207 | for sub_token in 
self.wordpiece_tokenizer.tokenize(token): 208 | split_tokens.append(sub_token) 209 | 210 | return split_tokens 211 | 212 | def convert_tokens_to_ids(self, tokens): 213 | return convert_by_vocab(self.vocab, tokens) 214 | 215 | def convert_tokens_to_bm25s_by_full_word_for_msmarco_doc(self, tokens, words, default_idf, is_doc): 216 | return convert_tokens_to_bm25s_by_full_word_for_msmarco_doc(self.full_word_idfs, tokens, words, default_idf, is_doc) 217 | 218 | class FullTokenizer(object): 219 | """Runs end-to-end tokenziation.""" 220 | 221 | def __init__(self, vocab_file, do_lower_case=True): 222 | self.vocab = load_vocab(vocab_file) 223 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 224 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 225 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 226 | 227 | def tokenize(self, text): 228 | split_tokens = [] 229 | for token in self.basic_tokenizer.tokenize(text): 230 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 231 | split_tokens.append(sub_token) 232 | 233 | return split_tokens 234 | 235 | def convert_tokens_to_ids(self, tokens): 236 | return convert_by_vocab(self.vocab, tokens) 237 | 238 | def convert_ids_to_tokens(self, ids): 239 | return convert_by_vocab(self.inv_vocab, ids) 240 | 241 | 242 | class BertTokenizer(object): 243 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 244 | 245 | def __init__(self, vocab_file, do_lower_case=True): 246 | if not os.path.isfile(vocab_file): 247 | raise ValueError( 248 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 249 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 250 | self.vocab = load_vocab(vocab_file) 251 | self.ids_to_tokens = collections.OrderedDict( 252 | [(ids, tok) for tok, ids in self.vocab.items()]) 253 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 254 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 255 | 256 | def tokenize(self, text): 257 | split_tokens = [] 258 | for token in self.basic_tokenizer.tokenize(text): 259 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 260 | split_tokens.append(sub_token) 261 | return split_tokens 262 | 263 | def convert_tokens_to_ids(self, tokens): 264 | """Converts a sequence of tokens into ids using the vocab.""" 265 | ids = [] 266 | for token in tokens: 267 | ids.append(self.vocab[token]) 268 | return ids 269 | 270 | def convert_ids_to_tokens(self, ids): 271 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 272 | tokens = [] 273 | for i in ids: 274 | tokens.append(self.ids_to_tokens[i]) 275 | return tokens 276 | 277 | @classmethod 278 | def from_pretrained(cls, pretrained_model_name, do_lower_case=True): 279 | """ 280 | Instantiate a PreTrainedBertModel from a pre-trained model file. 281 | Download and cache the pre-trained model file if needed. 
282 | """ 283 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 284 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 285 | else: 286 | vocab_file = pretrained_model_name 287 | # redirect to the cache, if necessary 288 | try: 289 | resolved_vocab_file = cached_path(vocab_file) 290 | if resolved_vocab_file == vocab_file: 291 | 292 | logger.info("loading vocabulary file {}".format(vocab_file)) 293 | else: 294 | logger.info("loading vocabulary file {} from cache at {}".format( 295 | vocab_file, resolved_vocab_file)) 296 | # Instantiate tokenizer. 297 | tokenizer = cls(resolved_vocab_file, do_lower_case) 298 | except FileNotFoundError: 299 | logger.error( 300 | "Model name '{}' was not found in model name list ({}). " 301 | "We assumed '{}' was a path or url but couldn't find any file " 302 | "associated to this path or url.".format( 303 | pretrained_model_name, 304 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 305 | pretrained_model_name)) 306 | tokenizer = None 307 | return tokenizer 308 | 309 | 310 | class BasicTokenizer(object): 311 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 312 | 313 | def __init__(self, do_lower_case=True): 314 | """Constructs a BasicTokenizer. 315 | 316 | Args: 317 | do_lower_case: Whether to lower case the input. 318 | """ 319 | self.do_lower_case = do_lower_case 320 | 321 | def tokenize(self, text): 322 | """Tokenizes a piece of text.""" 323 | text = convert_to_unicode(text) 324 | text = self._clean_text(text) 325 | # This was added on November 1st, 2018 for the multilingual and Chinese 326 | # models. This is also applied to the English models now, but it doesn't 327 | # matter since the English models were not trained on any Chinese data 328 | # and generally don't have any Chinese data in them (there are Chinese 329 | # characters in the vocabulary because Wikipedia does have some Chinese 330 | # words in the English Wikipedia.). 
331 | text = self._tokenize_chinese_chars(text) 332 | orig_tokens = whitespace_tokenize(text) 333 | split_tokens = [] 334 | for token in orig_tokens: 335 | if self.do_lower_case: 336 | token = token.lower() 337 | token = self._run_strip_accents(token) 338 | split_tokens.extend(self._run_split_on_punc(token)) 339 | 340 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 341 | return output_tokens 342 | 343 | def _run_strip_accents(self, text): 344 | """Strips accents from a piece of text.""" 345 | text = unicodedata.normalize("NFD", text) 346 | output = [] 347 | for char in text: 348 | cat = unicodedata.category(char) 349 | if cat == "Mn": 350 | continue 351 | output.append(char) 352 | return "".join(output) 353 | 354 | def _run_split_on_punc(self, text): 355 | """Splits punctuation on a piece of text.""" 356 | chars = list(text) 357 | i = 0 358 | start_new_word = True 359 | output = [] 360 | while i < len(chars): 361 | char = chars[i] 362 | if _is_punctuation(char): 363 | output.append([char]) 364 | start_new_word = True 365 | else: 366 | if start_new_word: 367 | output.append([]) 368 | start_new_word = False 369 | output[-1].append(char) 370 | i += 1 371 | 372 | return ["".join(x) for x in output] 373 | 374 | def _tokenize_chinese_chars(self, text): 375 | """Adds whitespace around any CJK character.""" 376 | output = [] 377 | for char in text: 378 | cp = ord(char) 379 | if self._is_chinese_char(cp): 380 | output.append(" ") 381 | output.append(char) 382 | output.append(" ") 383 | else: 384 | output.append(char) 385 | return "".join(output) 386 | 387 | def _is_chinese_char(self, cp): 388 | """Checks whether CP is the codepoint of a CJK character.""" 389 | # This defines a "chinese character" as anything in the CJK Unicode block: 390 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 391 | # 392 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 393 | # despite its name. The modern Korean Hangul alphabet is a different block, 394 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 395 | # space-separated words, so they are not treated specially and handled 396 | # like the all of the other languages. 397 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 398 | (cp >= 0x3400 and cp <= 0x4DBF) or # 399 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 400 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 401 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 402 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 403 | (cp >= 0xF900 and cp <= 0xFAFF) or # 404 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 405 | return True 406 | 407 | return False 408 | 409 | def _clean_text(self, text): 410 | """Performs invalid character removal and whitespace cleanup on text.""" 411 | output = [] 412 | for char in text: 413 | cp = ord(char) 414 | if cp == 0 or cp == 0xfffd or _is_control(char): 415 | continue 416 | if _is_whitespace(char): 417 | output.append(" ") 418 | else: 419 | output.append(char) 420 | return "".join(output) 421 | 422 | 423 | class WordpieceTokenizer(object): 424 | """Runs WordPiece tokenization.""" 425 | 426 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 427 | self.vocab = vocab 428 | self.unk_token = unk_token 429 | self.max_input_chars_per_word = max_input_chars_per_word 430 | 431 | def tokenize(self, text): 432 | """Tokenizes a piece of text into its word pieces. 433 | 434 | This uses a greedy longest-match-first algorithm to perform tokenization 435 | using the given vocabulary. 
436 | 437 | For example: 438 | input = "unaffable" 439 | output = ["un", "##aff", "##able"] 440 | 441 | Args: 442 | text: A single token or whitespace separated tokens. This should have 443 | already been passed through `BasicTokenizer. 444 | 445 | Returns: 446 | A list of wordpiece tokens. 447 | """ 448 | 449 | text = convert_to_unicode(text) 450 | 451 | output_tokens = [] 452 | for token in whitespace_tokenize(text): 453 | chars = list(token) 454 | if len(chars) > self.max_input_chars_per_word: 455 | output_tokens.append(self.unk_token) 456 | continue 457 | 458 | is_bad = False 459 | start = 0 460 | sub_tokens = [] 461 | while start < len(chars): 462 | end = len(chars) 463 | cur_substr = None 464 | while start < end: 465 | substr = "".join(chars[start:end]) 466 | if start > 0: 467 | substr = "##" + substr 468 | if substr in self.vocab: 469 | cur_substr = substr 470 | break 471 | end -= 1 472 | if cur_substr is None: 473 | is_bad = True 474 | break 475 | sub_tokens.append(cur_substr) 476 | start = end 477 | 478 | if is_bad: 479 | output_tokens.append(self.unk_token) 480 | else: 481 | output_tokens.extend(sub_tokens) 482 | return output_tokens 483 | 484 | 485 | def _is_whitespace(char): 486 | """Checks whether `chars` is a whitespace character.""" 487 | # \t, \n, and \r are technically contorl characters but we treat them 488 | # as whitespace since they are generally considered as such. 489 | if char == " " or char == "\t" or char == "\n" or char == "\r": 490 | return True 491 | cat = unicodedata.category(char) 492 | if cat == "Zs": 493 | return True 494 | return False 495 | 496 | 497 | def _is_control(char): 498 | """Checks whether `chars` is a control character.""" 499 | # These are technically control characters but we count them as whitespace 500 | # characters. 501 | if char == "\t" or char == "\n" or char == "\r": 502 | return False 503 | cat = unicodedata.category(char) 504 | if cat.startswith("C"): 505 | return True 506 | return False 507 | 508 | 509 | def _is_punctuation(char): 510 | """Checks whether `chars` is a punctuation character.""" 511 | cp = ord(char) 512 | # We treat all non-letter/number ASCII as punctuation. 513 | # Characters such as "^", "$", and "`" are not in the Unicode 514 | # Punctuation class but we treat them as punctuation anyways, for 515 | # consistency. 
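  # The ASCII ranges below correspond to "!"-"/", ":"-"@", "["-"`" and "{"-"~".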
516 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 517 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 518 | return True 519 | cat = unicodedata.category(char) 520 | if cat.startswith("P"): 521 | return True 522 | return False 523 | -------------------------------------------------------------------------------- /msmarco_doc_train/modeling/create_model_bison.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from modeling import modeling_bison, optimization_nvidia 3 | 4 | flags = tf.flags 5 | FLAGS = flags.FLAGS 6 | 7 | def file_based_input_fn_builder(input_file, batch_size, query_seq_length, meta_seq_length, is_training, 8 | drop_remainder, is_fidelity_eval=False, hvd=None): 9 | """Creates an `input_fn` closure to be passed to Estimator.""" 10 | 11 | name_to_features = { 12 | "query_input_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 13 | "query_input_mask": tf.FixedLenFeature([query_seq_length], tf.int64), 14 | "query_input_idfs": tf.FixedLenFeature([query_seq_length], tf.float32), 15 | "query_segment_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 16 | "meta_input_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 17 | "meta_input_mask": tf.FixedLenFeature([meta_seq_length], tf.int64), 18 | "meta_input_idfs": tf.FixedLenFeature([meta_seq_length], tf.float32), 19 | "metaStream_segment_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 20 | "label_ids": tf.FixedLenFeature([], tf.int64), 21 | } 22 | name_to_features_eval = { 23 | "query_input_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 24 | "query_input_mask": tf.FixedLenFeature([query_seq_length], tf.int64), 25 | "query_input_idfs": tf.FixedLenFeature([query_seq_length], tf.float32), 26 | "query_segment_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 27 | "meta_input_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 28 | "meta_input_mask": tf.FixedLenFeature([meta_seq_length], tf.int64), 29 | "meta_input_idfs": tf.FixedLenFeature([meta_seq_length], tf.float32), 30 | "metaStream_segment_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 31 | "label_ids": tf.FixedLenFeature([], tf.int64), 32 | "query_id": tf.FixedLenFeature([], tf.int64), 33 | "IFM": tf.FixedLenFeature([], tf.int64), 34 | "InstanceId": tf.FixedLenFeature([], tf.int64), 35 | "MapLabel": tf.FixedLenFeature([], tf.int64), 36 | "docId": tf.FixedLenFeature([], tf.int64), 37 | } 38 | 39 | def _decode_record(record, name_to_features): 40 | """Decodes a record to a TensorFlow example.""" 41 | example = tf.parse_single_example(record, name_to_features) 42 | 43 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 44 | # So cast all int64 to int32. 45 | for name in list(example.keys()): 46 | t = example[name] 47 | if t.dtype == tf.int64: 48 | t = tf.to_int32(t) 49 | example[name] = t 50 | 51 | return example 52 | 53 | def input_fn(): 54 | """The actual input function.""" 55 | 56 | # For training, we want a lot of parallel reading and shuffling. 57 | # For eval, we want no shuffling and parallel reading doesn't matter. 
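        # Training: list the TFRecord shard files, interleave reads across
        # them, prefetch, shuffle, repeat indefinitely, then decode and batch.
        # Eval: read a single TFRecord file and decode/batch it with
        # map_and_batch, switching to the extended feature schema when
        # is_fidelity_eval is set.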
58 | if is_training: 59 | files = tf.data.Dataset.list_files(input_file) 60 | d = files.interleave(tf.data.TFRecordDataset, cycle_length=32, 61 | num_parallel_calls=24) 62 | d = d.prefetch(buffer_size=batch_size * 4 * 3) 63 | d = d.shuffle(buffer_size=batch_size * 3) 64 | d = d.repeat() 65 | d = d.map(map_func=lambda record: _decode_record(record, name_to_features), 66 | num_parallel_calls=24) 67 | d = d.batch(batch_size) 68 | else: 69 | d = tf.data.TFRecordDataset(input_file) 70 | d = d.apply( 71 | tf.contrib.data.map_and_batch( 72 | lambda record: _decode_record(record, 73 | name_to_features=name_to_features_eval if is_fidelity_eval else name_to_features), 74 | batch_size=batch_size, 75 | drop_remainder=drop_remainder)) 76 | return d 77 | 78 | return input_fn 79 | 80 | 81 | def eval_file_based_input_fn_builder(input_file, query_seq_length, meta_seq_length, drop_remainder=False, is_fidelity_eval=False): 82 | name_to_features = { 83 | "query_input_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 84 | "query_input_mask": tf.FixedLenFeature([query_seq_length], tf.int64), 85 | "query_input_idfs": tf.FixedLenFeature([query_seq_length], tf.float32), 86 | "query_segment_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 87 | "meta_input_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 88 | "meta_input_mask": tf.FixedLenFeature([meta_seq_length], tf.int64), 89 | "meta_input_idfs": tf.FixedLenFeature([meta_seq_length], tf.float32), 90 | "metaStream_segment_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 91 | "label_ids": tf.FixedLenFeature([], tf.int64), 92 | } 93 | 94 | name_to_features_eval = { 95 | "query_input_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 96 | "query_input_mask": tf.FixedLenFeature([query_seq_length], tf.int64), 97 | "query_input_idfs": tf.FixedLenFeature([query_seq_length], tf.float32), 98 | "query_segment_ids": tf.FixedLenFeature([query_seq_length], tf.int64), 99 | "meta_input_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 100 | "meta_input_mask": tf.FixedLenFeature([meta_seq_length], tf.int64), 101 | "meta_input_idfs": tf.FixedLenFeature([meta_seq_length], tf.float32), 102 | "metaStream_segment_ids": tf.FixedLenFeature([meta_seq_length], tf.int64), 103 | "label_ids": tf.FixedLenFeature([], tf.int64), 104 | "query_id": tf.FixedLenFeature([], tf.int64), 105 | "IFM": tf.FixedLenFeature([], tf.int64), 106 | "InstanceId": tf.FixedLenFeature([], tf.int64), 107 | "MapLabel": tf.FixedLenFeature([], tf.int64), 108 | "docId": tf.FixedLenFeature([], tf.int64), 109 | } 110 | 111 | def _decode_record(record, name_to_features): 112 | """Decodes a record to a TensorFlow example.""" 113 | example = tf.parse_single_example(record, name_to_features) 114 | 115 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 116 | # So cast all int64 to int32. 
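        # Only the int64 features are cast; the float32 BM25 weight features
        # (query_input_idfs / meta_input_idfs) are left unchanged.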
117 | for name in list(example.keys()): 118 | t = example[name] 119 | if t.dtype == tf.int64: 120 | t = tf.to_int32(t) 121 | example[name] = t 122 | 123 | return example 124 | 125 | def input_fn(params): 126 | """The actual input function.""" 127 | batch_size = FLAGS.eval_batch_size 128 | 129 | d = tf.data.TFRecordDataset(input_file) 130 | d = d.apply( 131 | tf.contrib.data.map_and_batch( 132 | lambda record: _decode_record(record, 133 | name_to_features=name_to_features_eval if is_fidelity_eval else name_to_features), 134 | batch_size=batch_size, 135 | drop_remainder=drop_remainder)) 136 | 137 | return d 138 | 139 | return input_fn 140 | 141 | 142 | def create_model(query_bert_config, meta_bert_config, is_training, query_input_ids, query_input_mask, query_input_idfs, query_segment_ids, meta_input_ids, meta_input_mask, meta_input_idfs, metaStream_segment_ids, 143 | labels, use_one_hot_embeddings, nce_temperature, nce_weight): 144 | """Creates a classification model.""" 145 | tf.logging.info("*** Query Weighted attention is enabled in Transformer ***") 146 | query_model = modeling_bison.BertModel( 147 | config=query_bert_config, 148 | is_training=is_training, 149 | input_ids=query_input_ids, 150 | input_mask=query_input_mask, 151 | input_idfs=query_input_idfs, 152 | token_type_ids=query_segment_ids, 153 | use_one_hot_embeddings=use_one_hot_embeddings, 154 | scope="query", 155 | compute_type=tf.float16 if FLAGS.use_fp16 else tf.float32) 156 | 157 | tf.logging.info("*** Meta Weighted attention is enabled in Transformer ***") 158 | meta_model = modeling_bison.BertModel( 159 | config=meta_bert_config, 160 | is_training=is_training, 161 | input_ids=meta_input_ids, 162 | input_mask=meta_input_mask, 163 | input_idfs=meta_input_idfs, 164 | token_type_ids=metaStream_segment_ids, 165 | use_one_hot_embeddings=use_one_hot_embeddings, 166 | scope="meta", 167 | compute_type=tf.float16 if FLAGS.use_fp16 else tf.float32) 168 | 169 | query_output_layer = query_model.get_pooled_output() 170 | tf.logging.info( 171 | " !!!!!!!!!!!!!!!!!!!!!!!!!query_output_layer = %s", query_output_layer) 172 | 173 | meta_output_layer = meta_model.get_pooled_output() 174 | tf.logging.info( 175 | " !!!!!!!!!!!!!!!!!!!!!!!!!meta_output_layer = %s", meta_output_layer) 176 | query_hidden_size = query_output_layer.shape[-1].value 177 | meta_hidden_size = meta_output_layer.shape[-1].value 178 | 179 | tf.summary.histogram("query_output_layer", query_output_layer) 180 | tf.summary.histogram("meta_output_layer", meta_output_layer) 181 | 182 | with tf.variable_scope("loss"): 183 | query_vectors = query_output_layer 184 | meta_vectors = meta_output_layer 185 | tf.logging.info( 186 | " !!!!!!!!!!!!!!!!!!!!!!!!!query_vectors = %s", query_vectors) 187 | tf.logging.info( 188 | " !!!!!!!!!!!!!!!!!!!!!!!!!meta_vectors = %s", meta_vectors) 189 | 190 | # Filter NCE cases by query similarity 191 | query_output_layer_l2 = tf.nn.l2_normalize(query_vectors,-1) 192 | t_cross_query_sim = tf.matmul(query_output_layer_l2,query_output_layer_l2, transpose_b=True) 193 | t_cross_query_mask = tf.where(tf.greater(t_cross_query_sim,0.90),-1e12 *tf.ones_like(t_cross_query_sim),tf.zeros_like(t_cross_query_sim)) 194 | 195 | # Generate NCE cases 196 | meta_encodesrc = meta_output_layer 197 | batch_size = tf.shape(meta_encodesrc)[0] 198 | t_encoded_src_norm = tf.nn.l2_normalize(query_vectors, -1) 199 | t_encoded_trg_norm = tf.nn.l2_normalize(meta_vectors, -1) 200 | tf.logging.info( 201 | " !!!!!!!!!!!!!!!!!!!!!!!!!t_encoded_src_norm = %s", t_encoded_src_norm) 202 | 
tf.logging.info( 203 | " !!!!!!!!!!!!!!!!!!!!!!!!!t_encoded_trg_norm = %s", t_encoded_trg_norm) 204 | 205 | t_cross_sim = tf.matmul( 206 | t_encoded_src_norm, t_encoded_trg_norm, transpose_b=True) 207 | t_cross_sim_masked = -1e12 * \ 208 | tf.eye(tf.shape(t_cross_sim)[0]) + t_cross_sim 209 | t_cross_sim_masked = t_cross_query_mask + t_cross_sim_masked 210 | tf.logging.info( 211 | " !!!!!!!!!!!!!!!!!!!!!!!!!t_cross_sim_masked = %s", t_cross_sim_masked) 212 | t_max_neg_idx = tf.reshape(tf.multinomial( 213 | t_cross_sim_masked * nce_temperature, 1), [-1]) 214 | t_max_neg_idx = tf.stop_gradient(t_max_neg_idx) 215 | tf.logging.info( 216 | " !!!!!!!!!!!!!!!!!!!!!!!!!t_max_neg_idx = %s", t_max_neg_idx) 217 | t_neg_encoded_trg_norm = tf.gather(t_encoded_trg_norm, t_max_neg_idx) 218 | tf.logging.info( 219 | " !!!!!!!!!!!!!!!!!!!!!!!!!t_neg_encoded_trg_norm = %s", t_neg_encoded_trg_norm) 220 | 221 | t_encoded_src_norm_concat = tf.concat( 222 | [t_encoded_src_norm, t_encoded_src_norm], 0) 223 | t_encoded_trg_norm_concat = tf.concat( 224 | [t_encoded_trg_norm, t_neg_encoded_trg_norm], 0) 225 | tf.logging.info( 226 | " !!!!!!!!!!!!!!!!!!!!!!!!!Add_t_encoded_src_norm = %s", t_encoded_src_norm_concat) 227 | tf.logging.info( 228 | " !!!!!!!!!!!!!!!!!!!!!!!!!Add_t_encoded_trg_norm = %s", t_encoded_trg_norm_concat) 229 | 230 | tf.logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!raw_labels = %s", labels) 231 | t_label = tf.to_int32(labels) 232 | t_label = tf.pad(t_label, [[0, tf.shape(labels)[0]]]) 233 | tf.logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!pad_labels = %s", t_label) 234 | 235 | t_sim = tf.reduce_sum( 236 | t_encoded_src_norm_concat * t_encoded_trg_norm_concat, 237 | -1) 238 | 239 | pos_logits=tf.boolean_mask(t_sim[:batch_size],tf.equal(t_label[:batch_size],1)) 240 | neg_logits=tf.boolean_mask(t_sim[:batch_size],tf.equal(t_label[:batch_size],0)) 241 | tf.summary.scalar('pos_logits', tf.reduce_mean(pos_logits)) 242 | tf.summary.scalar('neg_logits', tf.reduce_mean(neg_logits)) 243 | tf.summary.scalar('nce_logits', tf.reduce_mean(t_sim[batch_size:])) 244 | if FLAGS.Fix_Sim_Weight: 245 | t_logits = FLAGS.sim_weight * t_sim + FLAGS.sim_bias 246 | else: 247 | v_weights = tf.get_variable( 248 | 'SimWeights', 249 | initializer=tf.ones([1], dtype=tf.float32)) 250 | v_biases = tf.get_variable( 251 | 'SimBiases', 252 | initializer=tf.zeros([1], dtype=tf.float32)) 253 | # monitor sim_weight and sim_bias change 254 | tf.summary.scalar('sim_weight', v_weights[0]) 255 | tf.summary.scalar('sim_bias', v_biases[0]) 256 | t_logits = v_weights * t_sim + v_biases 257 | 258 | loss_label = tf.where(tf.equal(t_label,1),t_label,tf.zeros_like(t_label)) 259 | per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.to_float(loss_label), logits=t_logits) 260 | tf.logging.info( 261 | " !!!!!!!!!!!!!!!!!!!!!!!!!per_example_loss = %s", per_example_loss) 262 | eval_loss = tf.reduce_mean(per_example_loss[:batch_size], 0) 263 | tf.logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!eval_loss = %s", eval_loss) 264 | nce_loss = tf.reduce_mean(per_example_loss[batch_size:], 0) 265 | tf.logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!nce_loss = %s", nce_loss) 266 | loss = nce_weight * nce_loss + (1.0 - nce_weight) * eval_loss 267 | 268 | tf.summary.scalar('eval_loss', (1.0 - nce_weight) * eval_loss) 269 | tf.summary.scalar('nce_loss', nce_weight * nce_loss) 270 | 271 | loss = tf.identity(loss, name='loss') 272 | per_example_loss = tf.identity( 273 | per_example_loss, name='per_example_loss') 274 | query_vectors = tf.identity(query_vectors, 
name='query_vectors') 275 | meta_vectors = tf.identity(meta_vectors, name='meta_vectors') 276 | score = tf.identity(t_sim[:batch_size], name='score') 277 | return (loss, per_example_loss, query_vectors, meta_vectors, score) 278 | 279 | 280 | def model_fn_builder(query_bert_config, meta_bert_config, init_checkpoint, 281 | learning_rate, num_train_steps, num_warmup_steps, use_one_hot_embeddings, nce_temperature, nce_weight, hvd=None): 282 | """Returns `model_fn` closure for Estimator.""" 283 | 284 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 285 | """The `model_fn` for Estimator.""" 286 | def metric_fn(per_example_loss, label_ids, cos_distance): 287 | predictions = tf.where(tf.greater(cos_distance, 0.5), x=tf.ones_like( 288 | cos_distance), y=tf.zeros_like(cos_distance)) 289 | accuracy = tf.metrics.accuracy(label_ids, tf.sigmoid(cos_distance)) 290 | loss = tf.metrics.mean(per_example_loss) 291 | return { 292 | "eval_accuracy": accuracy, 293 | "eval_loss": loss, 294 | } 295 | 296 | tf.logging.info("*** Features ***") 297 | for name in sorted(features.keys()): 298 | tf.logging.info(" name = %s, shape = %s" % 299 | (name, features[name].shape)) 300 | 301 | query_input_ids = features["query_input_ids"] 302 | query_input_mask = features["query_input_mask"] 303 | query_input_idfs = features["query_input_idfs"] 304 | query_segment_ids = features["query_segment_ids"] 305 | meta_input_ids = features["meta_input_ids"] 306 | meta_input_idfs = features["meta_input_idfs"] 307 | meta_input_mask = features["meta_input_mask"] 308 | metaStream_segment_ids = features["metaStream_segment_ids"] 309 | label_ids = features["label_ids"] 310 | 311 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 312 | (total_loss, per_example_loss, query_vectors, meta_vectors, cos_distance) = create_model( 313 | query_bert_config, meta_bert_config, is_training, 314 | query_input_ids, query_input_mask,query_input_idfs, query_segment_ids, meta_input_ids, meta_input_mask, meta_input_idfs, metaStream_segment_ids, label_ids, 315 | use_one_hot_embeddings, nce_temperature, nce_weight) 316 | 317 | tvars = tf.trainable_variables() 318 | initialized_variable_names = {} 319 | if init_checkpoint and (hvd is None or hvd.rank() == 0): 320 | (assignment_map, initialized_variable_names 321 | ) = modeling_bison.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 322 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 323 | 324 | if FLAGS.verbose_logging: 325 | tf.logging.info("**** Trainable Variables ****") 326 | for var in tvars: 327 | init_string = "" 328 | if var.name in initialized_variable_names: 329 | init_string = ", *INIT_FROM_CKPT*" 330 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 331 | init_string) 332 | 333 | output_spec = None 334 | if mode == tf.estimator.ModeKeys.TRAIN: 335 | 336 | train_op, LR, l2_norm = optimization_nvidia.create_optimizer( 337 | total_loss, learning_rate, num_train_steps, num_warmup_steps, 338 | hvd, False, FLAGS.use_fp16, FLAGS.num_accumulation_steps) 339 | tf.logging.info(" loss = %s", total_loss) 340 | tf.summary.scalar('learning_rate', LR) 341 | 342 | logging_hook = tf.estimator.LoggingTensorHook({"loss": total_loss, "learning_rate": LR, "global_norm": l2_norm}, every_n_iter=100) 343 | 344 | output_spec = tf.estimator.EstimatorSpec( 345 | mode=mode, 346 | loss=total_loss, 347 | train_op=train_op, 348 | training_hooks=[logging_hook]) 349 | elif mode == tf.estimator.ModeKeys.EVAL: 350 | eval_metrics = ( 351 | metric_fn, 
[per_example_loss, label_ids, cos_distance]) 352 | 353 | predictions = tf.where(tf.greater(cos_distance, 0.5), x=tf.ones_like(cos_distance), y=tf.zeros_like(cos_distance)) 354 | 355 | accuracy = tf.metrics.accuracy(label_ids, predictions) 356 | auc = tf.metrics.auc(label_ids, predictions) 357 | eval_metric_ops_dict = {'accuracy': accuracy, 358 | 'auc': auc} 359 | output_spec = tf.estimator.EstimatorSpec( 360 | mode=mode, 361 | loss=total_loss, 362 | eval_metric_ops=eval_metric_ops_dict) 363 | else: 364 | cos_distance = tf.expand_dims(cos_distance, -1) 365 | tf.logging.info(" cos_distance = %s, shape = %s", 366 | cos_distance.name, cos_distance.shape) 367 | rt = cos_distance 368 | tf.logging.info(" rt = %s, shape = %s", rt.name, rt.shape) 369 | output_spec = tf.estimator.EstimatorSpec(mode=mode, predictions=rt) 370 | return output_spec 371 | 372 | return model_fn -------------------------------------------------------------------------------- /msmarco_doc_train/modeling/modeling_bison.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import numpy as np 27 | import six 28 | import tensorflow as tf 29 | 30 | from utils.gpu_environment import get_custom_getter 31 | 32 | 33 | class BertConfig(object): 34 | """Configuration for `BertModel`.""" 35 | 36 | def __init__(self, 37 | vocab_size, 38 | hidden_size=768, 39 | num_hidden_layers=12, 40 | num_attention_heads=12, 41 | intermediate_size=3072, 42 | hidden_act="gelu", 43 | hidden_dropout_prob=0.1, 44 | attention_probs_dropout_prob=0.1, 45 | max_position_embeddings=512, 46 | type_vocab_size=16, 47 | initializer_range=0.02): 48 | """Constructs BertConfig. 49 | 50 | Args: 51 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 52 | hidden_size: Size of the encoder layers and the pooler layer. 53 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 54 | num_attention_heads: Number of attention heads for each attention layer in 55 | the Transformer encoder. 56 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 57 | layer in the Transformer encoder. 58 | hidden_act: The non-linear activation function (function or string) in the 59 | encoder and pooler. 60 | hidden_dropout_prob: The dropout probability for all fully connected 61 | layers in the embeddings, encoder, and pooler. 62 | attention_probs_dropout_prob: The dropout ratio for the attention 63 | probabilities. 64 | max_position_embeddings: The maximum sequence length that this model might 65 | ever be used with. 
Typically set this to something large just in case 66 | (e.g., 512 or 1024 or 2048). 67 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 68 | `BertModel`. 69 | initializer_range: The stdev of the truncated_normal_initializer for 70 | initializing all weight matrices. 71 | """ 72 | self.vocab_size = vocab_size 73 | self.hidden_size = hidden_size 74 | self.num_hidden_layers = num_hidden_layers 75 | self.num_attention_heads = num_attention_heads 76 | self.hidden_act = hidden_act 77 | self.intermediate_size = intermediate_size 78 | self.hidden_dropout_prob = hidden_dropout_prob 79 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 80 | self.max_position_embeddings = max_position_embeddings 81 | self.type_vocab_size = type_vocab_size 82 | self.initializer_range = initializer_range 83 | 84 | @classmethod 85 | def from_dict(cls, json_object): 86 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 87 | config = BertConfig(vocab_size=None) 88 | for (key, value) in six.iteritems(json_object): 89 | config.__dict__[key] = value 90 | return config 91 | 92 | @classmethod 93 | def from_json_file(cls, json_file): 94 | """Constructs a `BertConfig` from a json file of parameters.""" 95 | with tf.gfile.GFile(json_file, "r") as reader: 96 | text = reader.read() 97 | return cls.from_dict(json.loads(text)) 98 | 99 | def to_dict(self): 100 | """Serializes this instance to a Python dictionary.""" 101 | output = copy.deepcopy(self.__dict__) 102 | return output 103 | 104 | def to_json_string(self): 105 | """Serializes this instance to a JSON string.""" 106 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 107 | 108 | 109 | class BertModel(object): 110 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 111 | 112 | Example usage: 113 | 114 | ```python 115 | # Already been converted into WordPiece token ids 116 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 117 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 118 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 119 | 120 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 121 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 122 | 123 | model = modeling.BertModel(config=config, is_training=True, 124 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 125 | 126 | label_embeddings = tf.get_variable(...) 127 | pooled_output = model.get_pooled_output() 128 | logits = tf.matmul(pooled_output, label_embeddings) 129 | ... 130 | ``` 131 | """ 132 | 133 | def __init__(self, 134 | config, 135 | is_training, 136 | input_ids, 137 | input_mask=None, 138 | input_idfs=None, 139 | token_type_ids=None, 140 | use_one_hot_embeddings=False, 141 | scope=None, 142 | compute_type=tf.float32): 143 | """Constructor for BertModel. 144 | 145 | Args: 146 | config: `BertConfig` instance. 147 | is_training: bool. true for training model, false for eval model. Controls 148 | whether dropout will be applied. 149 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 150 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 151 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 152 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 153 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 154 | it is much faster if this is True, on the CPU or GPU, it is faster if 155 | this is False. 156 | scope: (optional) variable scope. 
Defaults to "bert". 157 | compute_type: (optional) either float32 or float16. Only applies to GPUs. 158 | 159 | Raises: 160 | ValueError: The config is invalid or one of the input tensor shapes 161 | is invalid. 162 | """ 163 | config = copy.deepcopy(config) 164 | if not is_training: 165 | config.hidden_dropout_prob = 0.0 166 | config.attention_probs_dropout_prob = 0.0 167 | 168 | input_shape = get_shape_list(input_ids, expected_rank=2) 169 | batch_size = input_shape[0] 170 | seq_length = input_shape[1] 171 | 172 | if input_mask is None: 173 | input_mask = tf.ones( 174 | shape=[batch_size, seq_length], dtype=tf.int32) 175 | 176 | if token_type_ids is None: 177 | token_type_ids = tf.zeros( 178 | shape=[batch_size, seq_length], dtype=tf.int32) 179 | 180 | with tf.variable_scope("bert", reuse=tf.AUTO_REUSE, custom_getter=get_custom_getter(compute_type)): 181 | with tf.variable_scope("embeddings", reuse=tf.AUTO_REUSE): 182 | # For good convergence with mixed precision training, 183 | # it is important that the embedding codes remain fp32. 184 | # Perform embedding lookup on the word ids. 185 | (self.embedding_output, self.embedding_table) = embedding_lookup( 186 | input_ids=input_ids, 187 | vocab_size=config.vocab_size, 188 | embedding_size=config.hidden_size, 189 | initializer_range=config.initializer_range, 190 | word_embedding_name="word_embeddings", 191 | use_one_hot_embeddings=use_one_hot_embeddings) 192 | 193 | with tf.variable_scope(scope, custom_getter=get_custom_getter(compute_type)): 194 | with tf.variable_scope("embeddings"): 195 | # Add positional embeddings and token type embeddings, then layer 196 | # normalize and perform dropout. 197 | self.embedding_output = embedding_postprocessor( 198 | input_tensor=self.embedding_output, 199 | use_token_type=True, 200 | token_type_ids=token_type_ids, 201 | token_type_vocab_size=config.type_vocab_size, 202 | token_type_embedding_name="token_type_embeddings", 203 | use_position_embeddings=True, 204 | position_embedding_name="position_embeddings", 205 | initializer_range=config.initializer_range, 206 | max_position_embeddings=config.max_position_embeddings, 207 | dropout_prob=config.hidden_dropout_prob, 208 | use_one_hot_embeddings=use_one_hot_embeddings) 209 | 210 | with tf.variable_scope("encoder"): 211 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 212 | # mask of shape [batch_size, seq_length, seq_length] which is used 213 | # for the attention scores. 
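                # Toy example (shapes and values are hypothetical): for batch_size=2,
                # seq_length=3 and input_mask = [[1, 1, 1], [1, 1, 0]], the call below
                # yields a mask broadcastable to [2, 3, 3] in which every query position
                # of the second example may attend to positions 0 and 1 but not to the
                # padded position 2.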
214 | #input_ids = tf.Print(input_ids,[input_ids],message="Input ids value:",summarize=100) 215 | #input_mask = tf.Print(input_mask,[input_mask], message="Input mask value:",summarize=100) 216 | #input_idfs = tf.Print(input_idfs,[input_idfs], message="Input idfs value:",summarize=100) 217 | attention_mask = create_attention_mask_from_input_mask( 218 | input_ids, input_mask) 219 | #attention_mask = tf.Print(attention_mask,[attention_mask],"First time attention mask:") 220 | 221 | # Use a L1 normalizer to generate the sequence weight from sequence idfs 222 | #tf.logging.info("Flag1") 223 | #input_idfs_sum = tf.reduce_sum(input_idfs,1) 224 | #input_idfs_sum = tf.expand_dims(input_idfs_sum, -1) 225 | #input_idfs_norm = tf.divide(input_idfs, input_idfs_sum) 226 | #tf.logging.info(input_idfs_norm.dtype) 227 | #input_idfs_norm = tf.Print(input_idfs_norm, [input_idfs_norm], message="input_idfs_norm's value:", first_n=2) 228 | seq_weight = tf.where(tf.equal(input_idfs,0), tf.ones_like(input_idfs), input_idfs) 229 | seq_weight = tf.saturate_cast(seq_weight, compute_type) 230 | #seq_weight = tf.Print(seq_weight, [seq_weight], message="seq_weight's value:", first_n=2) 231 | # Run the stacked transformer. 232 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 233 | self.all_encoder_layers = transformer_model( 234 | input_tensor=tf.saturate_cast( 235 | self.embedding_output, compute_type), 236 | seq_weight=seq_weight, 237 | attention_mask=attention_mask, 238 | hidden_size=config.hidden_size, 239 | num_hidden_layers=config.num_hidden_layers, 240 | num_attention_heads=config.num_attention_heads, 241 | intermediate_size=config.intermediate_size, 242 | intermediate_act_fn=get_activation(config.hidden_act), 243 | hidden_dropout_prob=config.hidden_dropout_prob, 244 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 245 | initializer_range=config.initializer_range, 246 | do_return_all_layers=True) 247 | 248 | self.sequence_output = tf.cast( 249 | self.all_encoder_layers[-1], tf.float32) 250 | # The "pooler" converts the encoded sequence tensor of shape 251 | # [batch_size, seq_length, hidden_size] to a tensor of shape 252 | # [batch_size, hidden_size]. This is necessary for segment-level 253 | # (or segment-pair-level) classification tasks where we need a fixed 254 | # dimensional representation of the segment. 255 | with tf.variable_scope("pooler"): 256 | # We "pool" the model by simply taking the hidden state corresponding 257 | # to the first token. We assume that this has been pre-trained 258 | first_token_tensor = tf.squeeze( 259 | self.sequence_output[:, 0:1, :], axis=1) 260 | self.pooled_output = tf.layers.dense( 261 | first_token_tensor, 262 | config.hidden_size, 263 | activation=tf.tanh, 264 | kernel_initializer=create_initializer(config.initializer_range)) 265 | 266 | def get_pooled_output(self): 267 | return self.pooled_output 268 | 269 | def get_sequence_output(self): 270 | """Gets final hidden layer of encoder. 271 | 272 | Returns: 273 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 274 | to the final hidden of the transformer encoder. 275 | """ 276 | return self.sequence_output 277 | 278 | def get_all_encoder_layers(self): 279 | return self.all_encoder_layers 280 | 281 | def get_embedding_output(self): 282 | """Gets output of the embedding lookup (i.e., input to the transformer). 
283 | 284 | Returns: 285 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 286 | to the output of the embedding layer, after summing the word 287 | embeddings with the positional embeddings and the token type embeddings, 288 | then performing layer normalization. This is the input to the transformer. 289 | """ 290 | return self.embedding_output 291 | 292 | def get_embedding_table(self): 293 | return self.embedding_table 294 | 295 | 296 | def gelu(x): 297 | """Gaussian Error Linear Unit. 298 | 299 | This is a smoother version of the RELU. 300 | Original paper: https://arxiv.org/abs/1606.08415 301 | Args: 302 | x: float Tensor to perform activation. 303 | 304 | Returns: 305 | `x` with the GELU activation applied. 306 | """ 307 | cdf = 0.5 * (1.0 + tf.tanh( 308 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) 309 | return x * cdf 310 | 311 | 312 | def get_activation(activation_string): 313 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 314 | 315 | Args: 316 | activation_string: String name of the activation function. 317 | 318 | Returns: 319 | A Python function corresponding to the activation function. If 320 | `activation_string` is None, empty, or "linear", this will return None. 321 | If `activation_string` is not a string, it will return `activation_string`. 322 | 323 | Raises: 324 | ValueError: The `activation_string` does not correspond to a known 325 | activation. 326 | """ 327 | 328 | # We assume that anything that"s not a string is already an activation 329 | # function, so we just return it. 330 | if not isinstance(activation_string, six.string_types): 331 | return activation_string 332 | 333 | if not activation_string: 334 | return None 335 | 336 | act = activation_string.lower() 337 | if act == "linear": 338 | return None 339 | elif act == "relu": 340 | return tf.nn.relu 341 | elif act == "gelu": 342 | return gelu 343 | elif act == "tanh": 344 | return tf.tanh 345 | else: 346 | raise ValueError("Unsupported activation: %s" % act) 347 | 348 | 349 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 350 | """Compute the union of the current variables and checkpoint variables.""" 351 | assignment_map = {} 352 | initialized_variable_names = {} 353 | 354 | name_to_variable = collections.OrderedDict() 355 | for var in tvars: 356 | name = var.name 357 | m = re.match("^(.*):\\d+$", name) 358 | if m is not None: 359 | name = m.group(1) 360 | name_to_variable[name] = var 361 | 362 | init_vars = tf.train.list_variables(init_checkpoint) 363 | 364 | assignment_map = collections.OrderedDict() 365 | for x in init_vars: 366 | (name, var) = (x[0], x[1]) 367 | if name not in name_to_variable: 368 | continue 369 | assignment_map[name] = name 370 | initialized_variable_names[name] = 1 371 | initialized_variable_names[name + ":0"] = 1 372 | 373 | return (assignment_map, initialized_variable_names) 374 | 375 | 376 | def dropout(input_tensor, dropout_prob): 377 | """Perform dropout. 378 | 379 | Args: 380 | input_tensor: float Tensor. 381 | dropout_prob: Python float. The probability of dropping out a value (NOT of 382 | *keeping* a dimension as in `tf.nn.dropout`). 383 | 384 | Returns: 385 | A version of `input_tensor` with dropout applied. 
386 | """ 387 | if dropout_prob is None or dropout_prob == 0.0: 388 | return input_tensor 389 | 390 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 391 | return output 392 | 393 | 394 | def layer_norm(input_tensor, name=None): 395 | """Run layer normalization on the last dimension of the tensor.""" 396 | if input_tensor.dtype == tf.float16: 397 | try: 398 | from modeling.fused_layer_norm import fused_layer_norm 399 | return fused_layer_norm( 400 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name, 401 | use_fused_batch_norm=True) 402 | except ImportError: 403 | return tf.contrib.layers.layer_norm( 404 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 405 | else: 406 | return tf.contrib.layers.layer_norm( 407 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 408 | 409 | 410 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 411 | """Runs layer normalization followed by dropout.""" 412 | output_tensor = layer_norm(input_tensor, name) 413 | output_tensor = dropout(output_tensor, dropout_prob) 414 | return output_tensor 415 | 416 | 417 | def create_initializer(initializer_range=0.02): 418 | """Creates a `truncated_normal_initializer` with the given range.""" 419 | return tf.truncated_normal_initializer(stddev=initializer_range) 420 | 421 | 422 | def embedding_lookup(input_ids, 423 | vocab_size, 424 | embedding_size=128, 425 | initializer_range=0.02, 426 | word_embedding_name="word_embeddings", 427 | use_one_hot_embeddings=False): 428 | """Looks up words embeddings for id tensor. 429 | 430 | Args: 431 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 432 | ids. 433 | vocab_size: int. Size of the embedding vocabulary. 434 | embedding_size: int. Width of the word embeddings. 435 | initializer_range: float. Embedding initialization range. 436 | word_embedding_name: string. Name of the embedding table. 437 | use_one_hot_embeddings: bool. If True, use one-hot method for word 438 | embeddings. If False, use `tf.gather()`. 439 | 440 | Returns: 441 | float Tensor of shape [batch_size, seq_length, embedding_size]. 442 | """ 443 | # This function assumes that the input is of shape [batch_size, seq_length, 444 | # num_inputs]. 445 | # 446 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 447 | # reshape to [batch_size, seq_length, 1]. 
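    # Shape walk-through with hypothetical sizes (batch_size=2, seq_length=20,
    # embedding_size=768): input_ids [2, 20] -> expand_dims -> [2, 20, 1];
    # the gather below produces [2*20*1, 768]; the final reshape restores
    # [2, 20, 1 * 768] = [2, 20, 768].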
448 | if input_ids.shape.ndims == 2: 449 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 450 | 451 | embedding_table = tf.get_variable( 452 | name=word_embedding_name, 453 | shape=[vocab_size, embedding_size], 454 | initializer=create_initializer(initializer_range)) 455 | 456 | flat_input_ids = tf.reshape(input_ids, [-1]) 457 | if use_one_hot_embeddings: 458 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 459 | output = tf.matmul(one_hot_input_ids, embedding_table) 460 | else: 461 | output = tf.gather(embedding_table, flat_input_ids) 462 | 463 | input_shape = get_shape_list(input_ids) 464 | 465 | output = tf.reshape(output, 466 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 467 | return (output, embedding_table) 468 | 469 | 470 | def embedding_postprocessor(input_tensor, 471 | use_token_type=False, 472 | token_type_ids=None, 473 | token_type_vocab_size=16, 474 | token_type_embedding_name="token_type_embeddings", 475 | use_position_embeddings=True, 476 | position_embedding_name="position_embeddings", 477 | initializer_range=0.02, 478 | max_position_embeddings=512, 479 | dropout_prob=0.1, 480 | use_one_hot_embeddings=False): 481 | """Performs various post-processing on a word embedding tensor. 482 | 483 | Args: 484 | input_tensor: float Tensor of shape [batch_size, seq_length, 485 | embedding_size]. 486 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 487 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 488 | Must be specified if `use_token_type` is True. 489 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 490 | token_type_embedding_name: string. The name of the embedding table variable 491 | for token type ids. 492 | use_position_embeddings: bool. Whether to add position embeddings for the 493 | position of each token in the sequence. 494 | position_embedding_name: string. The name of the embedding table variable 495 | for positional embeddings. 496 | initializer_range: float. Range of the weight initialization. 497 | max_position_embeddings: int. Maximum sequence length that might ever be 498 | used with this model. This can be longer than the sequence length of 499 | input_tensor, but cannot be shorter. 500 | dropout_prob: float. Dropout probability applied to the final output tensor. 501 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 502 | embeddings or tf.embedding_lookup() for the word embeddings. 503 | 504 | Returns: 505 | float tensor with same shape as `input_tensor`. 506 | 507 | Raises: 508 | ValueError: One of the tensor shapes or input values is invalid. 509 | """ 510 | input_shape = get_shape_list(input_tensor, expected_rank=3) 511 | batch_size = input_shape[0] 512 | seq_length = input_shape[1] 513 | width = input_shape[2] 514 | 515 | output = input_tensor 516 | 517 | if use_token_type: 518 | if token_type_ids is None: 519 | raise ValueError("`token_type_ids` must be specified if" 520 | "`use_token_type` is True.") 521 | token_type_table = tf.get_variable( 522 | name=token_type_embedding_name, 523 | shape=[token_type_vocab_size, width], 524 | initializer=create_initializer(initializer_range)) 525 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 526 | if use_one_hot_embeddings: 527 | # This vocab will be small so we always do one-hot here, since it is 528 | # always faster for a small vocabulary. 
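            # E.g., with token_type_vocab_size=16, a token type id of 3 becomes a
            # length-16 one-hot row, and the matmul below simply selects row 3 of
            # `token_type_table` (the id and size here are purely illustrative).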
529 | one_hot_ids = tf.one_hot( 530 | flat_token_type_ids, depth=token_type_vocab_size) 531 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 532 | else: 533 | token_type_embeddings = tf.gather( 534 | token_type_table, flat_token_type_ids) 535 | token_type_embeddings = tf.reshape(token_type_embeddings, 536 | [batch_size, seq_length, width]) 537 | output += token_type_embeddings 538 | 539 | if use_position_embeddings: 540 | full_position_embeddings = tf.get_variable( 541 | name=position_embedding_name, 542 | shape=[max_position_embeddings, width], 543 | initializer=create_initializer(initializer_range)) 544 | # Since the position embedding table is a learned variable, we create it 545 | # using a (long) sequence length `max_position_embeddings`. The actual 546 | # sequence length might be shorter than this, for faster training of 547 | # tasks that do not have long sequences. 548 | # 549 | # So `full_position_embeddings` is effectively an embedding table 550 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 551 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 552 | # perform a slice. 553 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 554 | [seq_length, width]) 555 | num_dims = len(output.shape.as_list()) 556 | 557 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 558 | # we broadcast among the first dimensions, which is typically just 559 | # the batch size. 560 | position_broadcast_shape = [] 561 | for _ in range(num_dims - 2): 562 | position_broadcast_shape.append(1) 563 | position_broadcast_shape.extend([seq_length, width]) 564 | position_embeddings = tf.reshape(position_embeddings, 565 | position_broadcast_shape) 566 | output += position_embeddings 567 | 568 | output = layer_norm_and_dropout(output, dropout_prob) 569 | return output 570 | 571 | 572 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 573 | """Create 3D attention mask from a 2D tensor mask. 574 | 575 | Args: 576 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 577 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 578 | 579 | Returns: 580 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 581 | """ 582 | to_mask = tf.cast(to_mask, dtype=tf.float32) 583 | 584 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 585 | batch_size = from_shape[0] 586 | 587 | to_shape = get_shape_list(to_mask, expected_rank=2) 588 | to_seq_length = to_shape[1] 589 | 590 | to_mask = tf.reshape(to_mask, [batch_size, 1, to_seq_length]) 591 | # The mask will be automatically broadcasted to 592 | # [batch_size, from_seq_length, to_seq_length] when it is used in the 593 | # attention layer. 594 | return to_mask 595 | 596 | 597 | def weight_attention_layer(from_tensor, 598 | to_tensor, 599 | seq_weight, 600 | attention_mask=None, 601 | num_attention_heads=1, 602 | size_per_head=512, 603 | query_act=None, 604 | key_act=None, 605 | value_act=None, 606 | attention_probs_dropout_prob=0.0, 607 | initializer_range=0.02, 608 | do_return_2d_tensor=False, 609 | batch_size=None, 610 | from_seq_length=None, 611 | to_seq_length=None, 612 | layer_idx=None): 613 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 614 | 615 | This is an implementation of multi-headed attention based on "Attention 616 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 617 | this is self-attention. 
Each timestep in `from_tensor` attends to the 618 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 619 | 620 | This function first projects `from_tensor` into a "query" tensor and 621 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 622 | of tensors of length `num_attention_heads`, where each tensor is of shape 623 | [batch_size, seq_length, size_per_head]. 624 | 625 | Then, the query and key tensors are dot-producted and scaled. These are 626 | softmaxed to obtain attention probabilities. The value tensors are then 627 | interpolated by these probabilities, then concatenated back to a single 628 | tensor and returned. 629 | 630 | In practice, the multi-headed attention are done with transposes and 631 | reshapes rather than actual separate tensors. 632 | 633 | Args: 634 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 635 | from_width]. 636 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 637 | attention_mask: (optional) int32 Tensor of shape [batch_size, 638 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 639 | attention scores will effectively be set to -infinity for any positions in 640 | the mask that are 0, and will be unchanged for positions that are 1. 641 | num_attention_heads: int. Number of attention heads. 642 | size_per_head: int. Size of each attention head. 643 | query_act: (optional) Activation function for the query transform. 644 | key_act: (optional) Activation function for the key transform. 645 | value_act: (optional) Activation function for the value transform. 646 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 647 | attention probabilities. 648 | initializer_range: float. Range of the weight initializer. 649 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 650 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 651 | output will be of shape [batch_size, from_seq_length, num_attention_heads 652 | * size_per_head]. 653 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 654 | of the 3D version of the `from_tensor` and `to_tensor`. 655 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 656 | of the 3D version of the `from_tensor`. 657 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 658 | of the 3D version of the `to_tensor`. 659 | 660 | Returns: 661 | float Tensor of shape [batch_size, from_seq_length, 662 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 663 | true, this will be of shape [batch_size * from_seq_length, 664 | num_attention_heads * size_per_head]). 665 | 666 | Raises: 667 | ValueError: Any of the arguments or tensor shapes are invalid. 
668 | """ 669 | 670 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 671 | seq_length, width): 672 | output_tensor = tf.reshape( 673 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 674 | 675 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 676 | return output_tensor 677 | 678 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 679 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 680 | 681 | if len(from_shape) != len(to_shape): 682 | raise ValueError( 683 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 684 | 685 | if len(from_shape) == 3: 686 | batch_size = from_shape[0] 687 | from_seq_length = from_shape[1] 688 | to_seq_length = to_shape[1] 689 | elif len(from_shape) == 2: 690 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 691 | raise ValueError( 692 | "When passing in rank 2 tensors to attention_layer, the values " 693 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 694 | "must all be specified.") 695 | 696 | # Scalar dimensions referenced here: 697 | # B = batch size (number of sequences) 698 | # F = `from_tensor` sequence length 699 | # T = `to_tensor` sequence length 700 | # N = `num_attention_heads` 701 | # H = `size_per_head` 702 | 703 | from_tensor_2d = reshape_to_matrix(from_tensor) 704 | to_tensor_2d = reshape_to_matrix(to_tensor) 705 | 706 | # `query_layer` = [B*F, N*H] 707 | query_layer = tf.layers.dense( 708 | from_tensor_2d, 709 | num_attention_heads * size_per_head, 710 | activation=query_act, 711 | name="query", 712 | kernel_initializer=create_initializer(initializer_range)) 713 | 714 | # `key_layer` = [B*T, N*H] 715 | key_layer = tf.layers.dense( 716 | to_tensor_2d, 717 | num_attention_heads * size_per_head, 718 | activation=key_act, 719 | name="key", 720 | kernel_initializer=create_initializer(initializer_range)) 721 | 722 | # `value_layer` = [B*T, N*H] 723 | value_layer = tf.layers.dense( 724 | to_tensor_2d, 725 | num_attention_heads * size_per_head, 726 | activation=value_act, 727 | name="value", 728 | kernel_initializer=create_initializer(initializer_range)) 729 | 730 | # `query_layer` = [B, N, F, H] 731 | query_layer = transpose_for_scores(query_layer, batch_size, 732 | num_attention_heads, from_seq_length, 733 | size_per_head) 734 | 735 | # `key_layer` = [B, N, T, H] 736 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 737 | to_seq_length, size_per_head) 738 | 739 | # Take the dot product between "query" and "key" to get the raw 740 | # attention scores. 
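    # In BISON these raw scores are additionally rescaled per query token by the
    # BM25/IDF-derived `seq_weight`, i.e. roughly
    #   attention_scores = seq_weight * (Q K^T / sqrt(size_per_head)),
    # before the attention mask and softmax are applied below.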
741 | # `attention_scores` = [B, N, F, T] 742 | 743 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 744 | attention_scores = tf.multiply(attention_scores, 745 | 1.0 / math.sqrt(float(size_per_head))) 746 | 747 | 748 | # Apply the weight across tokens to attention scores, assume rank of origin sequence_weight is 2, size is [B, F] 749 | 750 | #Step1 expand dims twice to [B, 1, F, 1] from [B, F] 751 | seq_weight = tf.expand_dims(seq_weight, 1) 752 | seq_weight = tf.expand_dims(seq_weight, -1) 753 | #if layer_idx == 0: 754 | #query_layer = tf.Print(query_layer, [query_layer], message="Init query_layer's value:") 755 | #key_layer = tf.Print(key_layer, [key_layer], message="Init key_layer's value:") 756 | #attention_scores = tf.Print(attention_scores, [attention_scores], message="Init Attention scores's value:") 757 | #seq_weight = tf.Print(seq_weight, [seq_weight], message="Seq weight's value:") 758 | #Step2 apply different weight across tokens. 759 | attention_scores = tf.multiply(seq_weight, attention_scores) 760 | 761 | if attention_mask is not None: 762 | # `attention_mask` = [B, 1, F, T] 763 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 764 | #if layer_idx == 0: 765 | #adder = tf.Print(adder,[adder],message="adder value:", summarize=100) 766 | #attention_mask = tf.Print(attention_mask,[attention_mask],message="attention_mask value:",summarize=30) 767 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 768 | # masked positions, this operation will create a tensor which is 0.0 for 769 | # positions we want to attend and -10000.0 for masked positions. 770 | adder = (1.0 - tf.cast(attention_mask, 771 | attention_scores.dtype)) * -10000.0 772 | 773 | 774 | # Since we are adding it to the raw scores before the softmax, this is 775 | # effectively the same as removing these entirely. 776 | attention_scores += adder 777 | #if layer_idx == 0: 778 | #attention_scores = tf.Print(attention_scores, [attention_scores], message="Third Attention scores's value:",summarize=100) 779 | # Normalize the attention scores to probabilities. 780 | # `attention_probs` = [B, N, F, T] 781 | attention_probs = tf.nn.softmax(attention_scores) 782 | #if layer_idx == 0: 783 | #attention_probs = tf.Print(attention_probs, [attention_probs], message="attention_probs's value:",summarize=100) 784 | # This is actually dropping out entire tokens to attend to, which might 785 | # seem a bit unusual, but is taken from the original Transformer paper. 
786 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 787 | 788 | # `value_layer` = [B, T, N, H] 789 | value_layer = tf.reshape( 790 | value_layer, 791 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 792 | 793 | # `value_layer` = [B, N, T, H] 794 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 795 | 796 | # `context_layer` = [B, N, F, H] 797 | context_layer = tf.matmul(attention_probs, value_layer) 798 | 799 | # `context_layer` = [B, F, N, H] 800 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 801 | 802 | if do_return_2d_tensor: 803 | # `context_layer` = [B*F, N*H] 804 | context_layer = tf.reshape( 805 | context_layer, 806 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 807 | else: 808 | # `context_layer` = [B, F, N*H] 809 | context_layer = tf.reshape( 810 | context_layer, 811 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 812 | #if layer_idx == 0: 813 | #context_layer = tf.Print(context_layer, [context_layer], message="context_layer's value:") 814 | 815 | return context_layer 816 | 817 | 818 | def transformer_model(input_tensor, 819 | seq_weight, 820 | attention_mask=None, 821 | hidden_size=768, 822 | num_hidden_layers=12, 823 | num_attention_heads=12, 824 | intermediate_size=3072, 825 | intermediate_act_fn=gelu, 826 | hidden_dropout_prob=0.1, 827 | attention_probs_dropout_prob=0.1, 828 | initializer_range=0.02, 829 | do_return_all_layers=False): 830 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 831 | 832 | This is almost an exact implementation of the original Transformer encoder. 833 | 834 | See the original paper: 835 | https://arxiv.org/abs/1706.03762 836 | 837 | Also see: 838 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 839 | 840 | Args: 841 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 842 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 843 | seq_length], with 1 for positions that can be attended to and 0 in 844 | positions that should not be. 845 | hidden_size: int. Hidden size of the Transformer. 846 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 847 | num_attention_heads: int. Number of attention heads in the Transformer. 848 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 849 | forward) layer. 850 | intermediate_act_fn: function. The non-linear activation function to apply 851 | to the output of the intermediate/feed-forward layer. 852 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 853 | attention_probs_dropout_prob: float. Dropout probability of the attention 854 | probabilities. 855 | initializer_range: float. Range of the initializer (stddev of truncated 856 | normal). 857 | do_return_all_layers: Whether to also return all layers or just the final 858 | layer. 859 | 860 | Returns: 861 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 862 | hidden layer of the Transformer. 863 | 864 | Raises: 865 | ValueError: A Tensor shape or parameter is invalid. 
866 | """ 867 | if hidden_size % num_attention_heads != 0: 868 | raise ValueError( 869 | "The hidden size (%d) is not a multiple of the number of attention " 870 | "heads (%d)" % (hidden_size, num_attention_heads)) 871 | 872 | attention_head_size = int(hidden_size / num_attention_heads) 873 | input_shape = get_shape_list(input_tensor, expected_rank=3) 874 | batch_size = input_shape[0] 875 | seq_length = input_shape[1] 876 | input_width = input_shape[2] 877 | 878 | # The Transformer performs sum residuals on all layers so the input needs 879 | # to be the same as the hidden size. 880 | if input_width != hidden_size: 881 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 882 | (input_width, hidden_size)) 883 | 884 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 885 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 886 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 887 | # help the optimizer. 888 | prev_output = reshape_to_matrix(input_tensor) 889 | 890 | all_layer_outputs = [] 891 | for layer_idx in range(num_hidden_layers): 892 | with tf.variable_scope("layer_%d" % layer_idx): 893 | layer_input = prev_output 894 | 895 | with tf.variable_scope("attention"): 896 | attention_heads = [] 897 | with tf.variable_scope("self"): 898 | attention_head = weight_attention_layer( 899 | from_tensor=layer_input, 900 | to_tensor=layer_input, 901 | seq_weight=seq_weight, 902 | attention_mask=attention_mask, 903 | num_attention_heads=num_attention_heads, 904 | size_per_head=attention_head_size, 905 | attention_probs_dropout_prob=attention_probs_dropout_prob, 906 | initializer_range=initializer_range, 907 | do_return_2d_tensor=True, 908 | batch_size=batch_size, 909 | from_seq_length=seq_length, 910 | to_seq_length=seq_length, 911 | layer_idx=layer_idx) 912 | attention_heads.append(attention_head) 913 | 914 | attention_output = None 915 | if len(attention_heads) == 1: 916 | attention_output = attention_heads[0] 917 | else: 918 | # In the case where we have other sequences, we just concatenate 919 | # them to the self-attention head before the projection. 920 | attention_output = tf.concat(attention_heads, axis=-1) 921 | #if layer_idx == 0: 922 | #attention_output = tf.Print(attention_output,[attention_output],"attention_output value is:") 923 | 924 | # Run a linear projection of `hidden_size` then add a residual 925 | # with `layer_input`. 926 | with tf.variable_scope("output"): 927 | attention_output = tf.layers.dense( 928 | attention_output, 929 | hidden_size, 930 | kernel_initializer=create_initializer(initializer_range)) 931 | attention_output = dropout( 932 | attention_output, hidden_dropout_prob) 933 | attention_output = layer_norm( 934 | attention_output + layer_input) 935 | 936 | # The activation is only applied to the "intermediate" hidden layer. 937 | with tf.variable_scope("intermediate"): 938 | intermediate_output = tf.layers.dense( 939 | attention_output, 940 | intermediate_size, 941 | activation=intermediate_act_fn, 942 | kernel_initializer=create_initializer(initializer_range)) 943 | 944 | # Down-project back to `hidden_size` then add the residual. 
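            # Together with the `intermediate` projection above, the block below forms
            # the position-wise feed-forward sub-layer, roughly
            #   layer_output = LayerNorm(dropout(W2 * gelu(W1 * attention_output)) + attention_output)
            # (biases omitted; `gelu` is the activation configured by hidden_act by default).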
945 | with tf.variable_scope("output"): 946 | layer_output = tf.layers.dense( 947 | intermediate_output, 948 | hidden_size, 949 | kernel_initializer=create_initializer(initializer_range)) 950 | layer_output = dropout(layer_output, hidden_dropout_prob) 951 | layer_output = layer_norm(layer_output + attention_output) 952 | #if layer_idx == 0: 953 | #layer_output = tf.Print(layer_output,[layer_output],"First Layer Output:") 954 | prev_output = layer_output 955 | all_layer_outputs.append(layer_output) 956 | 957 | 958 | if do_return_all_layers: 959 | final_outputs = [] 960 | for layer_output in all_layer_outputs: 961 | final_output = reshape_from_matrix(layer_output, input_shape) 962 | final_outputs.append(final_output) 963 | return final_outputs 964 | else: 965 | final_output = reshape_from_matrix(prev_output, input_shape) 966 | return final_output 967 | 968 | 969 | def get_shape_list(tensor, expected_rank=None, name=None): 970 | """Returns a list of the shape of tensor, preferring static dimensions. 971 | 972 | Args: 973 | tensor: A tf.Tensor object to find the shape of. 974 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 975 | specified and the `tensor` has a different rank, and exception will be 976 | thrown. 977 | name: Optional name of the tensor for the error message. 978 | 979 | Returns: 980 | A list of dimensions of the shape of tensor. All static dimensions will 981 | be returned as python integers, and dynamic dimensions will be returned 982 | as tf.Tensor scalars. 983 | """ 984 | if name is None: 985 | name = tensor.name 986 | 987 | if expected_rank is not None: 988 | assert_rank(tensor, expected_rank, name) 989 | 990 | shape = tensor.shape.as_list() 991 | 992 | non_static_indexes = [] 993 | for (index, dim) in enumerate(shape): 994 | if dim is None: 995 | non_static_indexes.append(index) 996 | 997 | if not non_static_indexes: 998 | return shape 999 | 1000 | dyn_shape = tf.shape(tensor) 1001 | for index in non_static_indexes: 1002 | shape[index] = dyn_shape[index] 1003 | return shape 1004 | 1005 | 1006 | def reshape_to_matrix(input_tensor): 1007 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 1008 | ndims = input_tensor.shape.ndims 1009 | if ndims < 2: 1010 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 1011 | (input_tensor.shape)) 1012 | if ndims == 2: 1013 | return input_tensor 1014 | 1015 | width = input_tensor.shape[-1] 1016 | output_tensor = tf.reshape(input_tensor, [-1, width]) 1017 | return output_tensor 1018 | 1019 | 1020 | def reshape_from_matrix(output_tensor, orig_shape_list): 1021 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 1022 | if len(orig_shape_list) == 2: 1023 | return output_tensor 1024 | 1025 | output_shape = get_shape_list(output_tensor) 1026 | 1027 | orig_dims = orig_shape_list[0:-1] 1028 | width = output_shape[-1] 1029 | 1030 | return tf.reshape(output_tensor, orig_dims + [width]) 1031 | 1032 | 1033 | def assert_rank(tensor, expected_rank, name=None): 1034 | """Raises an exception if the tensor rank is not of the expected rank. 1035 | 1036 | Args: 1037 | tensor: A tf.Tensor to check the rank of. 1038 | expected_rank: Python integer or list of integers, expected rank. 1039 | name: Optional name of the tensor for the error message. 1040 | 1041 | Raises: 1042 | ValueError: If the expected shape doesn't match the actual shape. 
1043 | """ 1044 | if name is None: 1045 | name = tensor.name 1046 | 1047 | expected_rank_dict = {} 1048 | if isinstance(expected_rank, six.integer_types): 1049 | expected_rank_dict[expected_rank] = True 1050 | else: 1051 | for x in expected_rank: 1052 | expected_rank_dict[x] = True 1053 | 1054 | actual_rank = tensor.shape.ndims 1055 | if actual_rank not in expected_rank_dict: 1056 | scope_name = tf.get_variable_scope().name 1057 | raise ValueError( 1058 | "For the tensor `%s` in scope `%s`, the actual rank " 1059 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 1060 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 1061 | -------------------------------------------------------------------------------- /msmarco_doc_train/modeling/optimization_nvidia.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2018 The Google AI Language Team Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Functions and classes related to optimization (weight updates).""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import re 24 | import tensorflow as tf 25 | from tensorflow.python.ops import array_ops 26 | from tensorflow.python.ops import linalg_ops 27 | from tensorflow.python.ops import math_ops 28 | from horovod.tensorflow.compression import Compression 29 | 30 | 31 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, hvd=None, manual_fp16=False, use_fp16=False, num_accumulation_steps=1, 32 | optimizer_type="adam", allreduce_post_accumulation=False, freeze_embedding=False): 33 | """Creates an optimizer training op.""" 34 | global_step = tf.train.get_or_create_global_step() 35 | 36 | # avoid step change in learning rate at end of warmup phase 37 | if optimizer_type == "adam": 38 | power = 1.0 39 | decayed_learning_rate_at_crossover_point = init_lr * ( 40 | (1.0 - float(num_warmup_steps) / float(num_train_steps)) ** power) 41 | else: 42 | power = 0.5 43 | decayed_learning_rate_at_crossover_point = init_lr 44 | 45 | adjusted_init_lr = init_lr * \ 46 | (init_lr / decayed_learning_rate_at_crossover_point) 47 | print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' % ( 48 | decayed_learning_rate_at_crossover_point, adjusted_init_lr)) 49 | 50 | learning_rate = tf.constant( 51 | value=adjusted_init_lr, shape=[], dtype=tf.float32) 52 | 53 | # Implements linear decay of the learning rate. 54 | learning_rate = tf.train.polynomial_decay( 55 | learning_rate, 56 | global_step, 57 | num_train_steps, 58 | end_learning_rate=0.0, 59 | power=power, 60 | cycle=False) 61 | 62 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 63 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
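    # Worked example (numbers are illustrative): with init_lr = 8e-5 and
    # num_warmup_steps = 1000, the learning rate at global_step = 100 is
    # 100 / 1000 * 8e-5 = 8e-6; once global_step reaches num_warmup_steps the
    # polynomial decay schedule above takes over.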
64 | if num_warmup_steps: 65 | global_steps_int = tf.cast(global_step, tf.int32) 66 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 67 | 68 | global_steps_float = tf.cast(global_steps_int, tf.float32) 69 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 70 | 71 | warmup_percent_done = global_steps_float / warmup_steps_float 72 | warmup_learning_rate = init_lr * warmup_percent_done 73 | 74 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 75 | learning_rate = ( 76 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 77 | 78 | if optimizer_type == "lamb": 79 | print("Initializing LAMB Optimizer") 80 | optimizer = LAMBOptimizer( 81 | learning_rate=learning_rate, 82 | weight_decay_rate=0.01, 83 | beta_1=0.9, 84 | beta_2=0.999, 85 | epsilon=1e-6, 86 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 87 | else: 88 | print("Initializing ADAM Weight Decay Optimizer") 89 | # It is recommended that you use this optimizer for fine tuning, since this 90 | # is how the model was trained (note that the Adam m/v variables are NOT 91 | # loaded from init_checkpoint.) 92 | optimizer = AdamWeightDecayOptimizer( 93 | learning_rate=learning_rate, 94 | weight_decay_rate=0.01, 95 | beta_1=0.9, 96 | beta_2=0.999, 97 | epsilon=1e-6, 98 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 99 | 100 | if hvd is not None and (num_accumulation_steps == 1 or (not allreduce_post_accumulation)): 101 | optimizer = hvd.DistributedOptimizer( 102 | optimizer, sparse_as_dense=True, compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none) 103 | if manual_fp16 or use_fp16: 104 | loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager( 105 | init_loss_scale=2**32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5) 106 | optimizer = tf.contrib.mixed_precision.LossScaleOptimizer( 107 | optimizer, loss_scale_manager) 108 | 109 | tvars = tf.trainable_variables() 110 | if freeze_embedding: 111 | tvars = [v for v in tvars if 'word_embeddings' not in v.name] 112 | grads_and_vars = optimizer.compute_gradients( 113 | loss * 1.0 / num_accumulation_steps, tvars) 114 | 115 | if num_accumulation_steps > 1: 116 | local_step = tf.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False, 117 | initializer=tf.zeros_initializer) 118 | batch_finite = tf.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False, 119 | initializer=tf.ones_initializer) 120 | accum_vars = [tf.get_variable( 121 | name=tvar.name.split(":")[0] + "/accum", 122 | shape=tvar.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) for tvar in tvars] 126 | 127 | reset_step = tf.cast(tf.math.equal(local_step % 128 | num_accumulation_steps, 0), dtype=tf.bool) 129 | local_step = tf.cond(reset_step, lambda: local_step.assign( 130 | tf.ones_like(local_step)), lambda: local_step.assign_add(1)) 131 | 132 | grads_and_vars_and_accums = [(gv[0], gv[1], accum_vars[i]) for i, gv in enumerate( 133 | grads_and_vars) if gv[0] is not None] 134 | grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums)) 135 | 136 | all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite( 137 | g)) for g in grads]) if manual_fp16 or use_fp16 else tf.constant(True, dtype=tf.bool) 138 | batch_finite = tf.cond(reset_step, 139 | lambda: batch_finite.assign(tf.math.logical_and( 140 | tf.constant(True, dtype=tf.bool), all_are_finite)), 141 | lambda: 
batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite))) 142 | 143 | # This is how the model was pre-trained. 144 | # ensure global norm is a finite number 145 | # to prevent clip_by_global_norm from having a hizzy fit. 146 | (clipped_grads, l2_norm) = tf.clip_by_global_norm( 147 | grads, clip_norm=1.0, 148 | use_norm=tf.cond( 149 | all_are_finite, 150 | lambda: tf.global_norm(grads), 151 | lambda: tf.constant(1.0))) 152 | 153 | accum_vars = tf.cond(reset_step, 154 | lambda: [accum_vars[i].assign( 155 | grad) for i, grad in enumerate(clipped_grads)], 156 | lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)]) 157 | 158 | def update(accum_vars): 159 | if allreduce_post_accumulation and hvd is not None: 160 | accum_vars = [hvd.allreduce(tf.convert_to_tensor(accum_var), compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none) if isinstance(accum_var, tf.IndexedSlices) 161 | else hvd.allreduce(accum_var, compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none) for accum_var in accum_vars] 162 | return optimizer.apply_gradients(list(zip(accum_vars, tvars)), global_step=global_step) 163 | 164 | update_step = tf.identity(tf.cast(tf.math.equal( 165 | local_step % num_accumulation_steps, 0), dtype=tf.bool), name="update_step") 166 | update_op = tf.cond(update_step, 167 | lambda: update(accum_vars), lambda: tf.no_op()) 168 | 169 | new_global_step = tf.cond(tf.math.logical_and(update_step, tf.cast(hvd.allreduce( 170 | tf.cast(batch_finite, tf.int32)), tf.bool)), lambda: global_step+1, lambda: global_step) 171 | new_global_step = tf.identity(new_global_step, name='step_update') 172 | train_op = tf.group(update_op, [global_step.assign(new_global_step)]) 173 | else: 174 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 175 | grads, tvars = list(zip(*grads_and_vars)) 176 | all_are_finite = tf.reduce_all( 177 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 or manual_fp16 else tf.constant(True, dtype=tf.bool) 178 | 179 | # This is how the model was pre-trained. 180 | # ensure global norm is a finite number 181 | # to prevent clip_by_global_norm from having a hizzy fit. 
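        # Global-norm clipping below rescales all gradients jointly: with
        # clip_norm=1.0, if the global L2 norm ||g|| of the gradient list exceeds
        # 1.0, every gradient is multiplied by 1.0 / ||g||, otherwise it is left
        # unchanged. When fp16 overflow makes the true norm non-finite, use_norm
        # is forced to 1.0 so that the scaling factor itself stays finite.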
182 | (clipped_grads, l2_norm) = tf.clip_by_global_norm( 183 | grads, clip_norm=1.0, 184 | use_norm=tf.cond( 185 | all_are_finite, 186 | lambda: tf.global_norm(grads), 187 | lambda: tf.constant(1.0))) 188 | 189 | train_op = optimizer.apply_gradients( 190 | list(zip(clipped_grads, tvars)), global_step=global_step) 191 | 192 | new_global_step = tf.cond( 193 | all_are_finite, lambda: global_step + 1, lambda: global_step + 1)### In order to align with fidelity pipeline, prevent steps not same for evaluation and train 194 | new_global_step = tf.identity(new_global_step, name='step_update') 195 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 196 | return train_op, learning_rate, l2_norm 197 | 198 | 199 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 200 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 201 | 202 | def __init__(self, 203 | learning_rate, 204 | weight_decay_rate=0.0, 205 | beta_1=0.9, 206 | beta_2=0.999, 207 | epsilon=1e-6, 208 | exclude_from_weight_decay=None, 209 | name="AdamWeightDecayOptimizer"): 210 | """Constructs a AdamWeightDecayOptimizer.""" 211 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 212 | 213 | self.learning_rate = tf.identity(learning_rate, name='learning_rate') 214 | self.weight_decay_rate = weight_decay_rate 215 | self.beta_1 = beta_1 216 | self.beta_2 = beta_2 217 | self.epsilon = epsilon 218 | self.exclude_from_weight_decay = exclude_from_weight_decay 219 | 220 | def apply_gradients(self, grads_and_vars, global_step=None, name=None, 221 | manual_fp16=False): 222 | """See base class.""" 223 | assignments = [] 224 | for (grad, param) in grads_and_vars: 225 | if grad is None or param is None: 226 | continue 227 | 228 | param_name = self._get_variable_name(param.name) 229 | has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32 230 | if has_shadow: 231 | # create shadow fp32 weights for fp16 variable 232 | param_fp32 = tf.get_variable( 233 | name=param_name + "/shadow", 234 | dtype=tf.float32, 235 | trainable=False, 236 | initializer=tf.cast(param.initialized_value(), tf.float32)) 237 | else: 238 | param_fp32 = param 239 | 240 | m = tf.get_variable( 241 | name=param_name + "/adam_m", 242 | shape=param.shape.as_list(), 243 | dtype=tf.float32, 244 | trainable=False, 245 | initializer=tf.zeros_initializer()) 246 | v = tf.get_variable( 247 | name=param_name + "/adam_v", 248 | shape=param.shape.as_list(), 249 | dtype=tf.float32, 250 | trainable=False, 251 | initializer=tf.zeros_initializer()) 252 | 253 | # Standard Adam update. 254 | next_m = ( 255 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 256 | next_v = ( 257 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 258 | tf.square(grad))) 259 | 260 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 261 | 262 | # Just adding the square of the weights to the loss function is *not* 263 | # the correct way of using L2 regularization/weight decay with Adam, 264 | # since that will interact with the m and v parameters in strange ways. 265 | # 266 | # Instead we want ot decay the weights in a manner that doesn't interact 267 | # with the m/v parameters. This is equivalent to adding the square 268 | # of the weights to the loss with plain (non-momentum) SGD. 
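            # In update form this is the decoupled (AdamW-style) rule:
            #   update    = next_m / (sqrt(next_v) + epsilon) + weight_decay_rate * param
            #   new_param = param - learning_rate * update
            # so the decay term is scaled by the learning rate but never enters m or v.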
269 | if self._do_use_weight_decay(param_name): 270 | update += self.weight_decay_rate * param_fp32 271 | 272 | update_with_lr = self.learning_rate * update 273 | 274 | next_param = param_fp32 - update_with_lr 275 | 276 | if has_shadow: 277 | # cast shadow fp32 weights to fp16 and assign to trainable variable 278 | param.assign(tf.cast(next_param, param.dtype.base_dtype)) 279 | assignments.extend( 280 | [param_fp32.assign(next_param), 281 | m.assign(next_m), 282 | v.assign(next_v)]) 283 | return tf.group(*assignments, name=name) 284 | 285 | def _do_use_weight_decay(self, param_name): 286 | """Whether to use L2 weight decay for `param_name`.""" 287 | if not self.weight_decay_rate: 288 | return False 289 | if self.exclude_from_weight_decay: 290 | for r in self.exclude_from_weight_decay: 291 | if re.search(r, param_name) is not None: 292 | return False 293 | return True 294 | 295 | def _get_variable_name(self, param_name): 296 | """Get the variable name from the tensor name.""" 297 | m = re.match("^(.*):\\d+$", param_name) 298 | if m is not None: 299 | param_name = m.group(1) 300 | return param_name 301 | 302 | 303 | class LAMBOptimizer(tf.train.Optimizer): 304 | """A LAMB optimizer that includes "correct" L2 weight decay.""" 305 | 306 | def __init__(self, 307 | learning_rate, 308 | weight_decay_rate=0.0, 309 | beta_1=0.9, 310 | beta_2=0.999, 311 | epsilon=1e-6, 312 | exclude_from_weight_decay=None, 313 | name="LAMBOptimizer"): 314 | """Constructs a LAMBOptimizer.""" 315 | super(LAMBOptimizer, self).__init__(False, name) 316 | 317 | self.learning_rate = tf.identity(learning_rate, name='learning_rate') 318 | self.weight_decay_rate = weight_decay_rate 319 | self.beta_1 = beta_1 320 | self.beta_2 = beta_2 321 | self.epsilon = epsilon 322 | self.exclude_from_weight_decay = exclude_from_weight_decay 323 | self.steps = 0 324 | 325 | def apply_gradients(self, grads_and_vars, global_step=None, name=None, 326 | manual_fp16=False): 327 | """See base class.""" 328 | assignments = [] 329 | for (grad, param) in grads_and_vars: 330 | if grad is None or param is None: 331 | continue 332 | 333 | param_name = self._get_variable_name(param.name) 334 | has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32 335 | if has_shadow: 336 | # create shadow fp32 weights for fp16 variable 337 | param_fp32 = tf.get_variable( 338 | name=param_name + "/shadow", 339 | dtype=tf.float32, 340 | trainable=False, 341 | initializer=tf.cast(param.initialized_value(), tf.float32)) 342 | else: 343 | param_fp32 = param 344 | 345 | m = tf.get_variable( 346 | name=param_name + "/adam_m", 347 | shape=param.shape.as_list(), 348 | dtype=tf.float32, 349 | trainable=False, 350 | initializer=tf.zeros_initializer()) 351 | v = tf.get_variable( 352 | name=param_name + "/adam_v", 353 | shape=param.shape.as_list(), 354 | dtype=tf.float32, 355 | trainable=False, 356 | initializer=tf.zeros_initializer()) 357 | 358 | # LAMB update 359 | next_m = ( 360 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 361 | next_v = ( 362 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 363 | tf.square(grad))) 364 | 365 | self.steps += 1 366 | beta1_correction = (1 - self.beta_1 ** self.steps) 367 | beta2_correction = (1 - self.beta_2 ** self.steps) 368 | 369 | next_m_unbiased = next_m / beta1_correction 370 | next_v_unbiased = next_v / beta2_correction 371 | 372 | update = next_m_unbiased / \ 373 | (tf.sqrt(next_v_unbiased) + self.epsilon) 374 | 375 | # Just adding the square of the weights to the loss function is *not* 
376 | # the correct way of using L2 regularization/weight decay with Adam, 377 | # since that will interact with the m and v parameters in strange ways. 378 | # 379 | # Instead we want to decay the weights in a manner that doesn't interact 380 | # with the m/v parameters. This is equivalent to adding the square 381 | # of the weights to the loss with plain (non-momentum) SGD. 382 | if self._do_use_weight_decay(param_name): 383 | update += self.weight_decay_rate * param_fp32 384 | # Layer-wise LAMB trust ratio ||w|| / ||update||, falling back to 1.0 when either norm is zero. 385 | w_norm = linalg_ops.norm(param, ord=2) 386 | g_norm = linalg_ops.norm(update, ord=2) 387 | ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( 388 | math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) 389 | 390 | update_with_lr = ratio * self.learning_rate * update 391 | 392 | next_param = param_fp32 - update_with_lr 393 | 394 | if has_shadow: 395 | # cast shadow fp32 weights to fp16 and assign to trainable variable 396 | param.assign(tf.cast(next_param, param.dtype.base_dtype)) 397 | assignments.extend( 398 | [param_fp32.assign(next_param), 399 | m.assign(next_m), 400 | v.assign(next_v)]) 401 | return tf.group(*assignments, name=name) 402 | 403 | def _do_use_weight_decay(self, param_name): 404 | """Whether to use L2 weight decay for `param_name`.""" 405 | if not self.weight_decay_rate: 406 | return False 407 | if self.exclude_from_weight_decay: 408 | for r in self.exclude_from_weight_decay: 409 | if re.search(r, param_name) is not None: 410 | return False 411 | return True 412 | 413 | def _get_variable_name(self, param_name): 414 | """Get the variable name from the tensor name.""" 415 | m = re.match("^(.*):\\d+$", param_name) 416 | if m is not None: 417 | param_name = m.group(1) 418 | return param_name 419 | -------------------------------------------------------------------------------- /msmarco_doc_train/train_msmarco_doc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2018 The Google AI Language Team Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | 17 | """BERT finetuning runner.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import time 24 | import re 25 | import shutil 26 | import math 27 | import collections 28 | import csv 29 | import os 30 | from modeling import modeling_bison 31 | import tensorflow as tf 32 | import horovod.tensorflow as hvd 33 | import time 34 | from utils.utils import LogEvalRunHook, LogTrainRunHook 35 | import numpy as np 36 | from tensorflow.python.client import device_lib 37 | from tensorflow.python.summary.writer import writer_cache 38 | import logging 39 | import sys 40 | 41 | flags = tf.flags 42 | FLAGS = flags.FLAGS 43 | tf.logging.set_verbosity(logging.INFO) 44 | 45 | # Required parameters 46 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 47 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 48 | flags.DEFINE_string("preprocess_train_dir", None, "The preprocess training data directory.") 49 | flags.DEFINE_integer("train_line_count", None, "Data file line count.") 50 | flags.DEFINE_integer("train_partition_count", None, "Total count of training files.") 51 | flags.DEFINE_string("preprocess_train_file_name", "train.tf_record", "The preprocess training file name.") 52 | flags.DEFINE_string("preprocess_eval_dir", None, "The preprocess eval data directory.") 53 | flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.") 54 | 55 | flags.DEFINE_string( 56 | "query_bert_config_file", None, 57 | "The config json file corresponding to the pre-trained BERT model. " 58 | "This specifies the model architecture.") 59 | 60 | flags.DEFINE_string( 61 | "meta_bert_config_file", None, 62 | "The config json file corresponding to the pre-trained BERT model. " 63 | "This specifies the model architecture.") 64 | 65 | # Other parameters 66 | flags.DEFINE_string( 67 | "init_checkpoint", None, 68 | "Initial checkpoint (usually from a pre-trained BERT model).") 69 | 70 | flags.DEFINE_integer( 71 | "max_seq_length_query", 20, 72 | "The maximum total input sequence length after WordPiece tokenization. " 73 | "Sequences longer than this will be truncated, and sequences shorter " 74 | "than this will be padded.") 75 | 76 | flags.DEFINE_integer( 77 | "max_seq_length_url", 30, 78 | "The maximum total input sequence length after WordPiece tokenization. " 79 | "Sequences longer than this will be truncated, and sequences shorter " 80 | "than this will be padded.") 81 | 82 | flags.DEFINE_integer( 83 | "max_seq_length_title", 30, 84 | "The maximum total input sequence length after WordPiece tokenization. " 85 | "Sequences longer than this will be truncated, and sequences shorter " 86 | "than this will be padded.") 87 | 88 | flags.DEFINE_integer( 89 | "max_seq_length_body", 128, 90 | "The maximum total input sequence length after WordPiece tokenization. 
" 91 | "Sequences longer than this will be truncated, and sequences shorter " 92 | "than this will be padded.") 93 | 94 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 95 | 96 | flags.DEFINE_integer("eval_batch_size", 32, "Total batch size for eval.") 97 | 98 | flags.DEFINE_float("learning_rate", 5e-5, 99 | "The initial learning rate for Adam.") 100 | 101 | flags.DEFINE_float("num_train_epochs", 3.0, 102 | "Total number of training epochs to perform.") 103 | 104 | flags.DEFINE_float( 105 | "warmup_proportion", 0.1, 106 | "Proportion of training to perform linear learning rate warmup for. " 107 | "E.g., 0.1 = 10% of training.") 108 | 109 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 110 | "How often to save the model checkpoint.") 111 | 112 | flags.DEFINE_integer("iterations_per_loop", 1000, 113 | "How many steps to make in each estimator call.") 114 | flags.DEFINE_integer("num_accumulation_steps", 1, 115 | "Number of accumulation steps before gradient update" 116 | "Global batch size = num_accumulation_steps * train_batch_size") 117 | flags.DEFINE_bool("use_fp16", False, 118 | "Whether to use fp32 or fp16 arithmetic on GPU.") 119 | flags.DEFINE_bool("use_xla", False, 120 | "Whether to enable XLA JIT compilation.") 121 | flags.DEFINE_bool("horovod", False, 122 | "Whether to use Horovod for multi-gpu runs") 123 | flags.DEFINE_bool("use_one_hot_embeddings", False, 124 | "Whether to use use_one_hot_embeddings") 125 | 126 | flags.DEFINE_float("nce_temperature", 10, "nce_temperature") 127 | flags.DEFINE_float("nce_weight", 0.5, "nce_weight") 128 | flags.DEFINE_string("activation", 'relu', "activation") 129 | 130 | flags.DEFINE_bool( 131 | "verbose_logging", False, 132 | "If true, all of the warnings related to data processing will be printed. 
" 133 | "A number of warnings are expected for a normal SQuAD evaluation.") 134 | 135 | flags.DEFINE_bool( 136 | "Fix_Sim_Weight", False, 137 | "Whether to fix sim weight and bias.") 138 | 139 | flags.DEFINE_integer("sim_weight", None, "sim_weight") 140 | 141 | flags.DEFINE_integer("sim_bias", None, "sim_bias") 142 | 143 | flags.DEFINE_bool("enable_body", False, "whether to enable body in training.") 144 | 145 | 146 | 147 | def get_func_by_task(task: str): 148 | """ 149 | Get model builder function, input builder function by task 150 | """ 151 | if task == "bison": 152 | from modeling.create_model_bison import model_fn_builder, file_based_input_fn_builder, eval_file_based_input_fn_builder 153 | return model_fn_builder, file_based_input_fn_builder, eval_file_based_input_fn_builder 154 | else: 155 | raise ValueError("Unsupported Task: " + task) 156 | 157 | 158 | # this function will check how many example in a tfrecord file (used for eval-file) 159 | # file_path must be str or list 160 | def check_line_count_in_tfrecords(file_path): 161 | line_count = 0 162 | 163 | if isinstance(file_path, str): 164 | for _ in tf.python_io.tf_record_iterator(file_path): 165 | line_count += 1 166 | elif isinstance(file_path, list): 167 | for f in file_path: 168 | for _ in tf.python_io.tf_record_iterator(f): 169 | line_count += 1 170 | else: 171 | raise ValueError('file_path must be str or str-list') 172 | 173 | return line_count 174 | 175 | 176 | # folder: folder path 177 | # files: list 178 | def find_all_file_in_folder(folder, file_paths): 179 | for file_name in os.listdir(folder): 180 | path = os.path.join(folder, file_name) 181 | 182 | if os.path.isfile(path): 183 | file_paths.append(path) 184 | elif os.path.isdir(path): 185 | find_all_file_in_folder(path, file_paths) 186 | 187 | def main(_): 188 | tf.logging.set_verbosity(tf.logging.INFO) 189 | 190 | if FLAGS.horovod: 191 | hvd.init() 192 | 193 | if FLAGS.use_fp16: 194 | os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1" 195 | 196 | model_fn_builder, file_based_input_fn_builder, eval_file_based_input_fn_builder = get_func_by_task(FLAGS.task_name.lower()) 197 | 198 | if not FLAGS.do_train: 199 | raise ValueError("`do_train` must be True.") 200 | 201 | query_bert_config = modeling_bison.BertConfig.from_json_file( 202 | FLAGS.query_bert_config_file) 203 | meta_bert_config = modeling_bison.BertConfig.from_json_file( 204 | FLAGS.meta_bert_config_file) 205 | 206 | # Sequence length check 207 | if FLAGS.max_seq_length_query > query_bert_config.max_position_embeddings: 208 | raise ValueError( 209 | "Cannot use query sequence length %d because the BERT model " 210 | "was only trained up to sequence length %d" % 211 | (FLAGS.max_seq_length_query, query_bert_config.max_position_embeddings)) 212 | 213 | meta_seq_length = FLAGS.max_seq_length_url + FLAGS.max_seq_length_title 214 | if FLAGS.enable_body: 215 | meta_seq_length += FLAGS.max_seq_length_body 216 | if meta_seq_length > meta_bert_config.max_position_embeddings: 217 | raise ValueError( 218 | "Cannot use meta sequence length %d because the BERT model " 219 | "was only trained up to sequence length %d" % 220 | (meta_seq_length, meta_bert_config.max_position_embeddings)) 221 | 222 | master_process = True 223 | training_hooks = [] 224 | global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps 225 | hvd_rank = 0 226 | config = tf.ConfigProto() 227 | if FLAGS.horovod: 228 | tf.logging.info("Multi-GPU training with TF Horovod") 229 | tf.logging.info("hvd.size() = %d hvd.rank() = %d", 
230 | hvd.size(), hvd.rank()) 231 | global_batch_size = FLAGS.train_batch_size * \ 232 | FLAGS.num_accumulation_steps * hvd.size() 233 | master_process = (hvd.rank() == 0) 234 | hvd_rank = hvd.rank() 235 | config.gpu_options.allow_growth = True 236 | config.gpu_options.visible_device_list = str(hvd.local_rank()) 237 | if hvd.size() > 1: 238 | training_hooks.append(hvd.BroadcastGlobalVariablesHook(0)) 239 | if FLAGS.use_xla: 240 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 241 | 242 | tf.gfile.MakeDirs(FLAGS.output_dir) 243 | run_config = tf.estimator.RunConfig( 244 | model_dir=FLAGS.output_dir if master_process else None, 245 | session_config=config, 246 | save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None, 247 | keep_checkpoint_max=10) 248 | 249 | if master_process: 250 | tf.logging.info("***** Configuration *****") 251 | for key in FLAGS.__flags.keys(): 252 | tf.logging.info(' {}: {}'.format(key, getattr(FLAGS, key))) 253 | tf.logging.info("**************************") 254 | 255 | train_examples = None 256 | num_train_steps = None 257 | num_warmup_steps = None 258 | train_examples_count = FLAGS.train_line_count 259 | log_train_run_hook = LogTrainRunHook(global_batch_size, hvd_rank) 260 | training_hooks.append(log_train_run_hook) 261 | 262 | if FLAGS.do_train: 263 | num_train_steps = int( 264 | train_examples_count / global_batch_size * FLAGS.num_train_epochs) 265 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 266 | 267 | model_fn = model_fn_builder( 268 | query_bert_config=query_bert_config, 269 | meta_bert_config=meta_bert_config, 270 | init_checkpoint=FLAGS.init_checkpoint, 271 | learning_rate=FLAGS.learning_rate, 272 | num_train_steps=num_train_steps, 273 | num_warmup_steps=num_warmup_steps, 274 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings, 275 | nce_temperature=FLAGS.nce_temperature, 276 | nce_weight=FLAGS.nce_weight, 277 | hvd=None if not FLAGS.horovod else hvd) 278 | 279 | estimator = tf.estimator.Estimator( 280 | model_fn=model_fn, 281 | config=run_config) 282 | 283 | if FLAGS.do_train: 284 | start_index = 0 285 | end_index = FLAGS.train_partition_count 286 | 287 | if FLAGS.horovod: 288 | tfrecord_per_GPU = int(FLAGS.train_partition_count / hvd.size()) 289 | start_index = hvd.rank() * tfrecord_per_GPU 290 | end_index = start_index+tfrecord_per_GPU 291 | 292 | if hvd.rank() == hvd.size() - 1:  # the last rank also picks up leftover partitions when the count is not divisible by hvd.size() 293 | end_index = FLAGS.train_partition_count 294 | 295 | tf.logging.info("***** Running training *****") 296 | tf.logging.info(" Num examples = %d", train_examples_count) 297 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 298 | tf.logging.info(" Num steps = %d", num_train_steps) 299 | tf.logging.info(" hvd rank = %d", hvd.rank()) 300 | tf.logging.info(" Num start_index = %d", start_index) 301 | tf.logging.info(" Num end_index = %d", end_index) 302 | 303 | train_file_list = [] 304 | for i in range(start_index, end_index): 305 | train_file_list.append(os.path.join( 306 | FLAGS.preprocess_train_dir, str(i), FLAGS.preprocess_train_file_name)) 307 | tf.logging.info("merge "+str(end_index-start_index) + 308 | " preprocessed files from the preprocess dir") 309 | tf.logging.info(train_file_list) 310 | 311 | train_input_fn = file_based_input_fn_builder( 312 | input_file=train_file_list, 313 | batch_size=FLAGS.train_batch_size, 314 | query_seq_length=FLAGS.max_seq_length_query, 315 | meta_seq_length=meta_seq_length, 316 | is_training=True, 317 | drop_remainder=True, 318 | is_fidelity_eval=False,
319 | hvd=None if not FLAGS.horovod else hvd) 320 | 321 | # initialize eval files 322 | # preprocess_eval_dir must be set; every file under preprocess_eval_dir is treated as a tfrecord file 323 | if FLAGS.preprocess_eval_dir is None: 324 | raise ValueError('must set preprocess_eval_dir by hand.') 325 | 326 | all_eval_files = [] 327 | eval_file_list = [] 328 | find_all_file_in_folder(FLAGS.preprocess_eval_dir, all_eval_files) 329 | for i in range(len(all_eval_files)): 330 | if hvd.rank() == i % hvd.size(): 331 | eval_file_list.append(all_eval_files[i]) 332 | 333 | if 0 == len(eval_file_list): 334 | raise ValueError(' Rank: %d got no eval files.' % (hvd.rank())) 335 | 336 | tf.logging.info("**********Check how many eval examples in current rank*************") 337 | eval_examples_count = check_line_count_in_tfrecords(eval_file_list) 338 | 339 | eval_steps = int(math.ceil(eval_examples_count / FLAGS.eval_batch_size)) 340 | 341 | tf.logging.info("***** Running evaluation *****") 342 | tf.logging.info(" Rank: %d will eval files:%s" % 343 | (hvd.rank(), str(eval_file_list))) 344 | tf.logging.info(" Rank: %d eval example count:%d" % 345 | (hvd.rank(), eval_examples_count)) 346 | tf.logging.info(" Rank: %d eval batch size:%d" % 347 | (hvd.rank(), FLAGS.eval_batch_size)) 348 | tf.logging.info(" Rank: %d eval_steps:%d" % 349 | (hvd.rank(), eval_steps)) 350 | 351 | eval_input_fn = eval_file_based_input_fn_builder( 352 | input_file=eval_file_list, 353 | query_seq_length=FLAGS.max_seq_length_query, 354 | meta_seq_length=meta_seq_length, 355 | drop_remainder=False, 356 | is_fidelity_eval=False) 357 | 358 | # create InMemoryEvaluatorHook 359 | in_memory_evaluator = tf.estimator.experimental.InMemoryEvaluatorHook( 360 | estimator=estimator, 361 | steps=eval_steps, # steps must be set explicitly, otherwise the evaluator prints no logs 362 | input_fn=eval_input_fn, 363 | every_n_iter=FLAGS.save_checkpoints_steps, 364 | name="fidelity_eval") 365 | training_hooks.append(in_memory_evaluator) 366 | 367 | train_start_time = time.time() 368 | estimator.train(input_fn=train_input_fn, 369 | max_steps=num_train_steps, hooks=training_hooks) 370 | train_time_elapsed = time.time() - train_start_time 371 | train_time_wo_overhead = log_train_run_hook.total_time 372 | avg_sentences_per_second = num_train_steps * \ 373 | global_batch_size * 1.0 / train_time_elapsed 374 | ss_sentences_per_second = ( 375 | num_train_steps - log_train_run_hook.skipped) * global_batch_size * 1.0 / train_time_wo_overhead 376 | 377 | if master_process: 378 | tf.logging.info("-----------------------------") 379 | tf.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed, 380 | num_train_steps * global_batch_size) 381 | tf.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead, 382 | (num_train_steps - log_train_run_hook.skipped) * global_batch_size) 383 | tf.logging.info( 384 | "Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second) 385 | tf.logging.info( 386 | "Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second) 387 | tf.logging.info("-----------------------------") 388 | 389 | 390 | if __name__ == "__main__": 391 | flags.mark_flag_as_required("task_name") 392 | flags.mark_flag_as_required("do_train") 393 | flags.mark_flag_as_required("preprocess_train_dir") 394 | flags.mark_flag_as_required("train_line_count") 395 | flags.mark_flag_as_required("train_partition_count") 396 | flags.mark_flag_as_required("preprocess_train_file_name") 397 |
flags.mark_flag_as_required("preprocess_eval_dir") 398 | flags.mark_flag_as_required("output_dir") 399 | flags.mark_flag_as_required("query_bert_config_file") 400 | flags.mark_flag_as_required("meta_bert_config_file") 401 | 402 | tf.app.run() 403 | -------------------------------------------------------------------------------- /msmarco_doc_train/utils/gpu_environment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import tensorflow as tf 17 | import numpy as np 18 | 19 | 20 | def float32_variable_storage_getter(getter, name, shape=None, dtype=None, 21 | initializer=None, regularizer=None, 22 | trainable=True, 23 | *args, **kwargs): 24 | """Custom variable getter that forces trainable variables to be stored in 25 | float32 precision and then casts them to the training precision. 26 | """ 27 | storage_dtype = tf.float32 if trainable else dtype 28 | variable = getter(name, shape, dtype=storage_dtype, 29 | initializer=initializer, regularizer=regularizer, 30 | trainable=trainable, 31 | *args, **kwargs) 32 | if trainable and dtype != tf.float32: 33 | variable = tf.cast(variable, dtype) 34 | return variable 35 | 36 | 37 | def get_custom_getter(compute_type): 38 | return float32_variable_storage_getter if compute_type == tf.float16 else None 39 | -------------------------------------------------------------------------------- /msmarco_doc_train/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | 14 | import tensorflow as tf 15 | import time 16 | 17 | # report latency and throughput during eval 18 | 19 | 20 | class LogEvalRunHook(tf.train.SessionRunHook): 21 | def __init__(self, global_batch_size, hvd_rank=-1): 22 | self.global_batch_size = global_batch_size 23 | self.hvd_rank = hvd_rank 24 | self.total_time = 0.0 25 | self.count = 0 26 | self.skipped = 0 27 | self.time_list = [] 28 | 29 | def before_run(self, run_context): 30 | self.t0 = time.time() 31 | 32 | def after_run(self, run_context, run_values): 33 | elapsed_secs = time.time() - self.t0 34 | self.count += 1 35 | 36 | # Removing first 2 (arbitrary) number of startup iterations from perf evaluations 37 | if self.count <= 2: 38 | print("Skipping time record for ", self.count, " due to overhead") 39 | self.skipped += 1 40 | else: 41 | self.time_list.append(elapsed_secs) 42 | self.total_time += elapsed_secs 43 | 44 | # report throughput during training 45 | 46 | 47 | class LogTrainRunHook(tf.train.SessionRunHook): 48 | def __init__(self, global_batch_size, hvd_rank=-1, save_checkpoints_steps=1000): 49 | self.global_batch_size = global_batch_size 50 | self.hvd_rank = hvd_rank 51 | self.save_checkpoints_steps = save_checkpoints_steps 52 | 53 | self.total_time = 0.0 54 | self.count = 0 # Holds number of iterations, including skipped iterations for fp16 loss scaling 55 | 56 | def after_create_session(self, session, coord): 57 | self.init_global_step = session.run(tf.train.get_global_step()) 58 | 59 | def before_run(self, run_context): 60 | self.t0 = time.time() 61 | return tf.train.SessionRunArgs( 62 | fetches=['step_update:0']) 63 | 64 | def after_run(self, run_context, run_values): 65 | elapsed_secs = time.time() - self.t0 66 | self.global_step = run_values.results[0] 67 | self.count += 1 68 | 69 | # Removing first step + first two steps after every checkpoint save 70 | if (self.global_step - self.init_global_step) % self.save_checkpoints_steps <= 1: 71 | print("Skipping time record for ", self.global_step, 72 | " due to checkpoint-saving/warmup overhead") 73 | else: 74 | self.total_time += elapsed_secs 75 | 76 | def end(self, session): 77 | num_global_steps = self.global_step - self.init_global_step 78 | 79 | self.skipped = (num_global_steps // self.save_checkpoints_steps) * 2 + \ 80 | min(2, num_global_steps % self.save_checkpoints_steps) - 1 81 | --------------------------------------------------------------------------------