├── LICENSE ├── README.md ├── build_data.sh ├── build_openwebtext_pretraining_dataset.py ├── build_pretraining_dataset.py ├── configure_finetuning.py ├── configure_pretraining.py ├── download_glue_data.py ├── finetune.sh ├── finetune ├── __init__.py ├── classification │ ├── classification_metrics.py │ └── classification_tasks.py ├── feature_spec.py ├── preprocessing.py ├── qa │ ├── mrqa_official_eval.py │ ├── qa_metrics.py │ ├── qa_tasks.py │ ├── squad_official_eval.py │ └── squad_official_eval_v1.py ├── scorer.py ├── tagging │ ├── tagging_metrics.py │ ├── tagging_tasks.py │ └── tagging_utils.py ├── task.py └── task_builder.py ├── model ├── __init__.py ├── modeling.py ├── optimization.py └── tokenization.py ├── pretrain.sh ├── pretrain ├── __init__.py ├── pretrain_data.py └── pretrain_helpers.py ├── run_finetuning.py ├── run_pretraining.py ├── util ├── __init__.py ├── training_utils.py └── utils.py └── vocab.txt /LICENSE: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 | 3 | Copyright (c) [2012]-[2020] Shanghai Yitu Technology Co., Ltd. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | * Neither the name of Shanghai Yitu Technology Co., Ltd. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY SHANGHAI YITU TECHNOLOGY CO., LTD. AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SHANGHAI YITU TECHNOLOGY CO., LTD. OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvBERT 2 | 3 | ## Introduction 4 | 5 | In this repo, we introduce a new architecture **ConvBERT** for pre-training based language model. The code is tested on a V100 GPU. For detailed description and experimental results, please refer to our NeurIPS 2020 paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496). 
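ConvBERT's key operation, span-based dynamic convolution, generates a convolution kernel at each position from a local span of the input and then mixes that span with the generated kernel, so the kernel depends on the current context instead of being shared across all positions. The snippet below is only a rough NumPy sketch of this idea for illustration — the function and variable names are ours, and it omits the query/key projections, multi-head structure, and the self-attention branch used in the real model; see the paper and `model/modeling.py` for the actual implementation.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def span_based_dynamic_conv(x, w_gen, kernel_size=9):
    """Toy span-based dynamic convolution (illustrative only).

    x:     [seq_len, hidden] input representations
    w_gen: [hidden, kernel_size] projection that generates the per-position kernel
    """
    seq_len, hidden = x.shape
    pad = kernel_size // 2
    x_pad = np.pad(x, ((pad, pad), (0, 0)))
    out = np.zeros_like(x)
    for t in range(seq_len):
        span = x_pad[t:t + kernel_size]               # local span centered at position t
        kernel = softmax(span.mean(axis=0) @ w_gen)   # dynamic kernel generated from the span
        out[t] = kernel @ span                        # kernel-weighted mix of the span
    return out

rng = np.random.RandomState(0)
x = rng.randn(16, 64).astype(np.float32)
w_gen = 0.02 * rng.randn(64, 9).astype(np.float32)
print(span_based_dynamic_conv(x, w_gen).shape)  # (16, 64)
```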
6 | 7 | ## Requirements 8 | * Python 3 9 | * tensorflow 1.15 10 | * numpy 11 | * scikit-learn 12 | 13 | ## Experiments 14 | 15 | 16 | ### Pre-training 17 | 18 | These instructions pre-train a medium-small sized ConvBERT model (17M parameters) using the [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) corpus. 19 | 20 | To build the tf-records and pre-train the model, download the [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) corpus (12G) and **set up your data directory** in `build_data.sh` and `pretrain.sh`. Then run 21 | 22 | ```bash 23 | bash build_data.sh 24 | ``` 25 | 26 | The processed data requires roughly 30G of disk space. Then, to pre-train the model, run 27 | 28 | ```bash 29 | bash pretrain.sh 30 | ``` 31 | 32 | See `configure_pretraining.py` for details of the supported hyperparameters. 33 | 34 | ### Fine-tuning 35 | 36 | Here we give instructions for fine-tuning a pre-trained medium-small sized ConvBERT model (17M parameters) on GLUE. You can refer to the Google Colab notebook for a [quick example](https://colab.research.google.com/drive/1WIu2Cc1C8E7ayZBzEmpfd5sXOhe7Ehhz?usp=sharing). See our paper for more details on model performance. The pre-trained model can be found [here](https://drive.google.com/drive/folders/1pSsPcQrGXyt1FB45clALUQf-WTNAbUQa?usp=sharing). (You can also download it from [baidu cloud](https://pan.baidu.com/s/1jPo0e94p2dB8UBz33QuMrQ) with extraction code m9d2.) 37 | 38 | To evaluate the performance on GLUE, you can download the GLUE data by running 39 | ```bash 40 | python3 download_glue_data.py 41 | ``` 42 | Set up the data by running `mv CoLA cola && mv MNLI mnli && mv MRPC mrpc && mv QNLI qnli && mv QQP qqp && mv RTE rte && mv SST-2 sst && mv STS-B sts && mv diagnostic/diagnostic.tsv mnli && mkdir -p $DATA_DIR/finetuning_data && mv * $DATA_DIR/finetuning_data`. After preparing the GLUE data, **set up your data directory** in `finetune.sh` and run 43 | ```bash 44 | bash finetune.sh 45 | ``` 46 | You can test different tasks by changing the configs in `finetune.sh`. 47 | 48 | If you find this repo helpful, please consider citing 49 | ```bibtex 50 | @inproceedings{NEURIPS2020_96da2f59, 51 | author = {Jiang, Zi-Hang and Yu, Weihao and Zhou, Daquan and Chen, Yunpeng and Feng, Jiashi and Yan, Shuicheng}, 52 | booktitle = {Advances in Neural Information Processing Systems}, 53 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin}, 54 | pages = {12837--12848}, 55 | publisher = {Curran Associates, Inc.}, 56 | title = {ConvBERT: Improving BERT with Span-based Dynamic Convolution}, 57 | url = {https://proceedings.neurips.cc/paper/2020/file/96da2f590cd7246bbde0051047b0d6f7-Paper.pdf}, 58 | volume = {33}, 59 | year = {2020} 60 | } 61 | ``` 62 | ## References 63 | 64 | Here are some great resources we benefited from: 65 | 66 | Codebase: Our codebase is based on [ELECTRA](https://github.com/google-research/electra). 
67 | 68 | Dynamic convolution: [Implementation](https://github.com/pytorch/fairseq/blob/265791b727b664d4d7da3abd918a3f6fb70d7337/fairseq/modules/lightconv_layer/lightconv_layer.py#L75) from [Pay Less Attention with Lightweight and Dynamic Convolutions](https://openreview.net/pdf?id=SkVhlh09tX) 69 | 70 | Dataset: [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) from [Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) 71 | 72 | -------------------------------------------------------------------------------- /build_data.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=/path/to/data_dir 2 | # please set your data_dir, like ~/data/convbert 3 | 4 | # extract data 5 | tar xf openwebtext.tar.xz 6 | # move to data_dir 7 | mv openwebtext $DATA_DIR/openwebtext 8 | cp vocab.txt $DATA_DIR/vocab.txt 9 | # build pre-train tf-record 10 | python3 build_openwebtext_pretraining_dataset.py --data-dir $DATA_DIR --num-processes 5 11 | -------------------------------------------------------------------------------- /build_openwebtext_pretraining_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Preprocessess the Open WebText corpus for pre-training.""" 4 | 5 | import argparse 6 | import multiprocessing 7 | import os 8 | import random 9 | import tarfile 10 | import time 11 | import tensorflow.compat.v1 as tf 12 | 13 | import build_pretraining_dataset 14 | from util import utils 15 | 16 | 17 | def write_examples(job_id, args): 18 | """A single process creating and writing out pre-processed examples.""" 19 | job_tmp_dir = os.path.join(args.data_dir, "tmp", "job_" + str(job_id)) 20 | owt_dir = os.path.join(args.data_dir, "openwebtext") 21 | 22 | def log(*args): 23 | msg = " ".join(map(str, args)) 24 | print("Job {}:".format(job_id), msg) 25 | 26 | log("Creating example writer") 27 | example_writer = build_pretraining_dataset.ExampleWriter( 28 | job_id=job_id, 29 | vocab_file=os.path.join(args.data_dir, "vocab.txt"), 30 | output_dir=os.path.join(args.data_dir, "pretrain_tfrecords"), 31 | max_seq_length=args.max_seq_length, 32 | num_jobs=args.num_processes, 33 | blanks_separate_docs=False, 34 | strip_accents=args.strip_accents, 35 | ) 36 | log("Writing tf examples") 37 | fnames = sorted(tf.io.gfile.listdir(owt_dir)) 38 | fnames = [f for (i, f) in enumerate(fnames) 39 | if i % args.num_processes == job_id] 40 | random.shuffle(fnames) 41 | start_time = time.time() 42 | for file_no, fname in enumerate(fnames): 43 | if file_no > 0 and file_no % 10 == 0: 44 | elapsed = time.time() - start_time 45 | log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, " 46 | "{:} examples written".format( 47 | file_no, len(fnames), 100.0 * file_no / len(fnames), int(elapsed), 48 | int((len(fnames) - file_no) / (file_no / elapsed)), 49 | example_writer.n_written)) 50 | utils.rmkdir(job_tmp_dir) 51 | with tarfile.open(os.path.join(owt_dir, fname)) as f: 52 | f.extractall(job_tmp_dir) 53 | extracted_files = tf.io.gfile.listdir(job_tmp_dir) 54 | random.shuffle(extracted_files) 55 | for txt_fname in extracted_files: 56 | example_writer.write_examples(os.path.join(job_tmp_dir, txt_fname)) 57 | example_writer.finish() 58 | log("Done!") 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser(description=__doc__) 63 | parser.add_argument("--data-dir", required=True, 64 | help="Location of data (vocab file, corpus, 
etc).") 65 | parser.add_argument("--max-seq-length", default=128, type=int, 66 | help="Number of tokens per example.") 67 | parser.add_argument("--num-processes", default=1, type=int, 68 | help="Parallelize across multiple processes.") 69 | 70 | # toggle strip-accents and set default to True which is the default behavior 71 | parser.add_argument("--do-strip-accents", dest='strip_accents', 72 | action='store_true', help="Strip accents (default).") 73 | parser.add_argument("--no-strip-accents", dest='strip_accents', 74 | action='store_false', help="Don't strip accents.") 75 | parser.set_defaults(strip_accents=True) 76 | 77 | args = parser.parse_args() 78 | 79 | utils.rmkdir(os.path.join(args.data_dir, "pretrain_tfrecords")) 80 | if args.num_processes == 1: 81 | write_examples(0, args) 82 | else: 83 | jobs = [] 84 | for i in range(args.num_processes): 85 | job = multiprocessing.Process(target=write_examples, args=(i, args)) 86 | jobs.append(job) 87 | job.start() 88 | for job in jobs: 89 | job.join() 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /build_pretraining_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import argparse 4 | import multiprocessing 5 | import os 6 | import random 7 | import time 8 | import tensorflow.compat.v1 as tf 9 | 10 | from model import tokenization 11 | from util import utils 12 | 13 | 14 | def create_int_feature(values): 15 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 16 | return feature 17 | 18 | 19 | class ExampleBuilder(object): 20 | """Given a stream of input text, creates pretraining examples.""" 21 | 22 | def __init__(self, tokenizer, max_length): 23 | self._tokenizer = tokenizer 24 | self._current_sentences = [] 25 | self._current_length = 0 26 | self._max_length = max_length 27 | self._target_length = max_length 28 | 29 | def add_line(self, line): 30 | """Adds a line of text to the current example being built.""" 31 | line = line.strip().replace("\n", " ") 32 | if (not line) and self._current_length != 0: # empty lines separate docs 33 | return self._create_example() 34 | bert_tokens = self._tokenizer.tokenize(line) 35 | bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens) 36 | self._current_sentences.append(bert_tokids) 37 | self._current_length += len(bert_tokids) 38 | if self._current_length >= self._target_length: 39 | return self._create_example() 40 | return None 41 | 42 | def _create_example(self): 43 | """Creates a pre-training example from the current list of sentences.""" 44 | # small chance to only have one segment as in classification tasks 45 | if random.random() < 0.1: 46 | first_segment_target_length = 100000 47 | else: 48 | # -3 due to not yet having [CLS]/[SEP] tokens in the input text 49 | first_segment_target_length = (self._target_length - 3) // 2 50 | 51 | first_segment = [] 52 | second_segment = [] 53 | for sentence in self._current_sentences: 54 | # the sentence goes to the first segment if (1) the first segment is 55 | # empty, (2) the sentence doesn't put the first segment over length or 56 | # (3) 50% of the time when it does put the first segment over length 57 | if (first_segment or 58 | len(first_segment) + len(sentence) < first_segment_target_length or 59 | (second_segment and 60 | len(first_segment) < first_segment_target_length and 61 | random.random() < 0.5)): 62 | first_segment += sentence 63 | else: 64 | 
second_segment += sentence 65 | 66 | # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens 67 | first_segment = first_segment[:self._max_length - 2] 68 | second_segment = second_segment[:max(0, self._max_length - 69 | len(first_segment) - 3)] 70 | 71 | # prepare to start building the next example 72 | self._current_sentences = [] 73 | self._current_length = 0 74 | # small chance for random-length instead of max_length-length example 75 | if random.random() < 0.05: 76 | self._target_length = random.randint(5, self._max_length) 77 | else: 78 | self._target_length = self._max_length 79 | 80 | return self._make_tf_example(first_segment, second_segment) 81 | 82 | def _make_tf_example(self, first_segment, second_segment): 83 | """Converts two "segments" of text into a tf.train.Example.""" 84 | vocab = self._tokenizer.vocab 85 | input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]] 86 | segment_ids = [0] * len(input_ids) 87 | if second_segment: 88 | input_ids += second_segment + [vocab["[SEP]"]] 89 | segment_ids += [1] * (len(second_segment) + 1) 90 | input_mask = [1] * len(input_ids) 91 | input_ids += [0] * (self._max_length - len(input_ids)) 92 | input_mask += [0] * (self._max_length - len(input_mask)) 93 | segment_ids += [0] * (self._max_length - len(segment_ids)) 94 | tf_example = tf.train.Example(features=tf.train.Features(feature={ 95 | "input_ids": create_int_feature(input_ids), 96 | "input_mask": create_int_feature(input_mask), 97 | "segment_ids": create_int_feature(segment_ids) 98 | })) 99 | return tf_example 100 | 101 | 102 | class ExampleWriter(object): 103 | """Writes pre-training examples to disk.""" 104 | 105 | def __init__(self, job_id, vocab_file, output_dir, max_seq_length, 106 | num_jobs, blanks_separate_docs, num_out_files=1000, strip_accents=True): 107 | self._blanks_separate_docs = blanks_separate_docs 108 | tokenizer = tokenization.FullTokenizer( 109 | vocab_file=vocab_file, 110 | do_lower_case=True, 111 | strip_accents=strip_accents) 112 | self._example_builder = ExampleBuilder(tokenizer, max_seq_length) 113 | self._writers = [] 114 | for i in range(num_out_files): 115 | if i % num_jobs == job_id: 116 | output_fname = os.path.join( 117 | output_dir, "pretrain_data.tfrecord-{:}-of-{:}".format( 118 | i, num_out_files)) 119 | self._writers.append(tf.io.TFRecordWriter(output_fname)) 120 | self.n_written = 0 121 | 122 | def write_examples(self, input_file): 123 | """Writes out examples from the provided input file.""" 124 | with tf.io.gfile.GFile(input_file) as f: 125 | for line in f: 126 | line = line.strip() 127 | if line or self._blanks_separate_docs: 128 | example = self._example_builder.add_line(line) 129 | if example: 130 | self._writers[self.n_written % len(self._writers)].write( 131 | example.SerializeToString()) 132 | self.n_written += 1 133 | example = self._example_builder.add_line("") 134 | if example: 135 | self._writers[self.n_written % len(self._writers)].write( 136 | example.SerializeToString()) 137 | self.n_written += 1 138 | 139 | def finish(self): 140 | for writer in self._writers: 141 | writer.close() 142 | 143 | 144 | def write_examples(job_id, args): 145 | """A single process creating and writing out pre-processed examples.""" 146 | 147 | def log(*args): 148 | msg = " ".join(map(str, args)) 149 | print("Job {}:".format(job_id), msg) 150 | 151 | log("Creating example writer") 152 | example_writer = ExampleWriter( 153 | job_id=job_id, 154 | vocab_file=args.vocab_file, 155 | output_dir=args.output_dir, 156 | 
max_seq_length=args.max_seq_length, 157 | num_jobs=args.num_processes, 158 | blanks_separate_docs=args.blanks_separate_docs, 159 | strip_accents=args.strip_accents, 160 | ) 161 | log("Writing tf examples") 162 | fnames = sorted(tf.io.gfile.listdir(args.corpus_dir)) 163 | fnames = [f for (i, f) in enumerate(fnames) 164 | if i % args.num_processes == job_id] 165 | random.shuffle(fnames) 166 | start_time = time.time() 167 | for file_no, fname in enumerate(fnames): 168 | if file_no > 0: 169 | elapsed = time.time() - start_time 170 | log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, " 171 | "{:} examples written".format( 172 | file_no, len(fnames), 100.0 * file_no / len(fnames), int(elapsed), 173 | int((len(fnames) - file_no) / (file_no / elapsed)), 174 | example_writer.n_written)) 175 | example_writer.write_examples(os.path.join(args.corpus_dir, fname)) 176 | example_writer.finish() 177 | log("Done!") 178 | 179 | 180 | def main(): 181 | parser = argparse.ArgumentParser(description=__doc__) 182 | parser.add_argument("--corpus-dir", required=True, 183 | help="Location of pre-training text files.") 184 | parser.add_argument("--vocab-file", required=True, 185 | help="Location of vocabulary file.") 186 | parser.add_argument("--output-dir", required=True, 187 | help="Where to write out the tfrecords.") 188 | parser.add_argument("--max-seq-length", default=128, type=int, 189 | help="Number of tokens per example.") 190 | parser.add_argument("--num-processes", default=1, type=int, 191 | help="Parallelize across multiple processes.") 192 | parser.add_argument("--blanks-separate-docs", default=True, type=bool, 193 | help="Whether blank lines indicate document boundaries.") 194 | 195 | # toggle strip-accents and set default to True which is the default behavior 196 | parser.add_argument("--do-strip-accents", dest='strip_accents', 197 | action='store_true', help="Strip accents (default).") 198 | parser.add_argument("--no-strip-accents", dest='strip_accents', 199 | action='store_false', help="Don't strip accents.") 200 | parser.set_defaults(strip_accents=True) 201 | 202 | args = parser.parse_args() 203 | 204 | utils.rmkdir(args.output_dir) 205 | if args.num_processes == 1: 206 | write_examples(0, args) 207 | else: 208 | jobs = [] 209 | for i in range(args.num_processes): 210 | job = multiprocessing.Process(target=write_examples, args=(i, args)) 211 | jobs.append(job) 212 | job.start() 213 | for job in jobs: 214 | job.join() 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /configure_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Config controlling hyperparameters for fine-tuning.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | 11 | import tensorflow.compat.v1 as tf 12 | 13 | 14 | class FinetuningConfig(object): 15 | """Fine-tuning hyperparameters.""" 16 | 17 | def __init__(self, model_name, data_dir, **kwargs): 18 | # general 19 | self.model_name = model_name 20 | self.debug = False # debug mode for quickly running things 21 | self.log_examples = False # print out some train examples for debugging 22 | self.num_trials = 1 # how many train+eval runs to perform 23 | self.do_train = True # train a model 24 | self.do_eval = True # evaluate the model 25 | self.keep_all_models = False # if False, only keep the last trial's 
ckpt 26 | 27 | # model 28 | self.model_size = "medium-small" # one of "small", "medium-small" or "base" 29 | self.task_names = ["cola"] # which tasks to learn 30 | # ConvBERT additional config 31 | self.conv_kernel_size=9 32 | self.linear_groups=2 33 | self.head_ratio=2 34 | self.conv_type="sdconv" 35 | 36 | # override the default transformer hparams for the provided model size; see 37 | # modeling.BertConfig for the possible hparams and util.training_utils for 38 | # the defaults 39 | self.model_hparam_overrides = ( 40 | kwargs["model_hparam_overrides"] 41 | if "model_hparam_overrides" in kwargs else {}) 42 | self.embedding_size = None # bert hidden size by default 43 | self.vocab_size = 30522 # number of tokens in the vocabulary 44 | self.do_lower_case = True 45 | 46 | # training 47 | self.learning_rate = 3e-4 48 | self.weight_decay_rate = 0.01 49 | self.layerwise_lr_decay = 0.8 # if > 0, the learning rate for a layer is 50 | # lr * lr_decay^(depth - max_depth) i.e., 51 | # shallower layers have lower learning rates 52 | self.num_train_epochs = 3.0 # passes over the dataset during training 53 | self.warmup_proportion = 0.1 # how much of training to warm up the LR for 54 | self.save_checkpoints_steps = 1000000 55 | self.iterations_per_loop = 1000 56 | self.use_tfrecords_if_existing = True # don't make tfrecords and write them 57 | # to disc if existing ones are found 58 | 59 | # writing model outputs to disc 60 | self.write_test_outputs = True # whether to write test set outputs, 61 | # currently supported for GLUE + SQuAD 2.0 62 | self.n_writes_test = 5 # write test set predictions for the first n trials 63 | 64 | # sizing 65 | self.max_seq_length = 128 66 | self.train_batch_size = 32 67 | self.eval_batch_size = 32 68 | self.predict_batch_size = 32 69 | self.double_unordered = True # for tasks like paraphrase where sentence 70 | # order doesn't matter, train the model on 71 | # on both sentence orderings for each example 72 | # for qa tasks 73 | self.max_query_length = 64 # max tokens in q as opposed to context 74 | self.doc_stride = 128 # stride when splitting doc into multiple examples 75 | self.n_best_size = 20 # number of predictions per example to save 76 | self.max_answer_length = 30 # filter out answers longer than this length 77 | self.answerable_classifier = True # answerable classifier for SQuAD 2.0 78 | self.answerable_uses_start_logits = True # more advanced answerable 79 | # classifier using predicted start 80 | self.answerable_weight = 0.5 # weight for answerability loss 81 | self.joint_prediction = True # jointly predict the start and end positions 82 | # of the answer span 83 | self.beam_size = 20 # beam size when doing joint predictions 84 | self.qa_na_threshold = -2.75 # threshold for "no answer" when writing SQuAD 85 | # 2.0 test outputs 86 | 87 | # TPU settings 88 | self.use_tpu = False 89 | self.num_tpu_cores = 1 90 | self.tpu_job_name = None 91 | self.tpu_name = None # cloud TPU to use for training 92 | self.tpu_zone = None # GCE zone where the Cloud TPU is located in 93 | self.gcp_project = None # project name for the Cloud TPU-enabled project 94 | 95 | # default locations of data files 96 | self.data_dir = data_dir 97 | pretrained_model_dir = os.path.join(data_dir, "models", model_name) 98 | self.raw_data_dir = os.path.join(data_dir, "finetuning_data", "{:}").format 99 | self.vocab_file = os.path.join(pretrained_model_dir, "vocab.txt") 100 | if not tf.io.gfile.exists(self.vocab_file): 101 | self.vocab_file = os.path.join(self.data_dir, "vocab.txt") 102 | 
task_names_str = ",".join( 103 | kwargs["task_names"] if "task_names" in kwargs else self.task_names) 104 | self.init_checkpoint = None if self.debug else tf.train.latest_checkpoint(pretrained_model_dir) 105 | self.model_dir = os.path.join(pretrained_model_dir, "finetuning_models", 106 | task_names_str + "_model") 107 | results_dir = os.path.join(pretrained_model_dir, "results") 108 | self.results_txt = os.path.join(results_dir, 109 | task_names_str + "_results.txt") 110 | self.results_pkl = os.path.join(results_dir, 111 | task_names_str + "_results.pkl") 112 | qa_topdir = os.path.join(results_dir, task_names_str + "_qa") 113 | self.qa_eval_file = os.path.join(qa_topdir, "{:}_eval.json").format 114 | self.qa_preds_file = os.path.join(qa_topdir, "{:}_preds.json").format 115 | self.qa_na_file = os.path.join(qa_topdir, "{:}_null_odds.json").format 116 | self.preprocessed_data_dir = os.path.join( 117 | data_dir, "finetuning_tfrecords", 118 | task_names_str + "_tfrecords" + ("-debug" if self.debug else "")) 119 | self.test_predictions = os.path.join( 120 | pretrained_model_dir, "test_predictions", 121 | "{:}_{:}_{:}_predictions.pkl").format 122 | 123 | # update defaults with passed-in hyperparameters 124 | self.update(kwargs) 125 | 126 | # default hyperparameters for single-task models 127 | if len(self.task_names) == 1: 128 | task_name = self.task_names[0] 129 | if task_name == "rte" or task_name == "sts": 130 | self.num_train_epochs = 10.0 131 | elif "squad" in task_name or "qa" in task_name: 132 | self.max_seq_length = 512 133 | self.num_train_epochs = 2.0 134 | self.write_distill_outputs = False 135 | self.write_test_outputs = False 136 | elif task_name == "chunk": 137 | self.max_seq_length = 256 138 | else: 139 | self.num_train_epochs = 3.0 140 | 141 | # default hyperparameters for different model sizes 142 | 143 | 144 | if self.model_size in ["medium-small"]: 145 | self.embedding_size = 128 146 | self.conv_kernel_size=9 147 | self.linear_groups=2 148 | self.head_ratio=2 149 | elif self.model_size in ["small"]: 150 | self.embedding_size = 128 151 | self.conv_kernel_size=9 152 | self.linear_groups=1 153 | self.head_ratio=2 154 | self.learning_rate = 3e-4 155 | elif self.model_size in ["base"]: 156 | self.learning_rate = 1e-4 157 | self.conv_kernel_size=9 158 | self.linear_groups=1 159 | self.head_ratio=2 160 | # debug-mode settings 161 | if self.debug: 162 | self.save_checkpoints_steps = 1000000 163 | self.use_tfrecords_if_existing = False 164 | self.num_trials = 1 165 | self.iterations_per_loop = 1 166 | self.train_batch_size = 32 167 | self.num_train_epochs = 3.0 168 | self.log_examples = True 169 | 170 | # passed-in-arguments override (for example) debug-mode defaults 171 | self.update(kwargs) 172 | 173 | def update(self, kwargs): 174 | for k, v in kwargs.items(): 175 | if k not in self.__dict__: 176 | raise ValueError("Unknown hparam " + k) 177 | self.__dict__[k] = v 178 | -------------------------------------------------------------------------------- /configure_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Config controlling hyperparameters for pre-training.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | 11 | 12 | class PretrainingConfig(object): 13 | """Defines pre-training hyperparameters.""" 14 | 15 | def __init__(self, model_name, data_dir, **kwargs): 16 | self.model_name = model_name 17 | 
self.debug = False # debug mode 18 | self.do_train = True # pre-train 19 | self.do_eval = False # evaluate generator/discriminator on unlabeled data 20 | 21 | # loss functions 22 | self.electra_objective = True # if False, use the BERT objective instead 23 | self.gen_weight = 1.0 # masked language modeling / generator loss 24 | self.disc_weight = 50.0 # discriminator loss 25 | self.mask_prob = 0.15 # percent of input tokens to mask out / replace 26 | 27 | # optimization 28 | self.learning_rate = 5e-4 29 | self.lr_decay_power = 1.0 # linear learning rate decay by default 30 | self.weight_decay_rate = 0.01 31 | self.num_warmup_steps = 10000 32 | 33 | # training settings 34 | self.iterations_per_loop = 200 35 | self.save_checkpoints_steps = 10000 36 | self.num_train_steps = 1000000 37 | self.num_eval_steps = 100 38 | 39 | # model settings 40 | self.model_size = "medium-small" # one of "small", "medium-small", or "base" 41 | # override the default transformer hparams for the provided model size; see 42 | # modeling.BertConfig for the possible hparams and util.training_utils for 43 | # the defaults 44 | self.model_hparam_overrides = ( 45 | kwargs["model_hparam_overrides"] 46 | if "model_hparam_overrides" in kwargs else {}) 47 | self.embedding_size = None # bert hidden size by default 48 | self.vocab_size = 30522 # number of tokens in the vocabulary 49 | self.do_lower_case = True # lowercase the input? 50 | 51 | # ConvBERT additional config 52 | self.conv_kernel_size=9 53 | self.linear_groups=2 54 | self.head_ratio=2 55 | self.conv_type="sdconv" 56 | # generator settings 57 | self.uniform_generator = False # generator is uniform at random 58 | self.untied_generator_embeddings = False # tie generator/discriminator 59 | # token embeddings? 60 | self.untied_generator = True # tie all generator/discriminator weights? 
61 | self.generator_layers = 1.0 # frac of discriminator layers for generator 62 | self.generator_hidden_size = 0.25 # frac of discrim hidden size for gen 63 | self.disallow_correct = False # force the generator to sample incorrect 64 | # tokens (so 15% of tokens are always 65 | # fake) 66 | self.temperature = 1.0 # temperature for sampling from generator 67 | 68 | # batch sizes 69 | self.max_seq_length = 128 70 | self.train_batch_size = 128 71 | self.eval_batch_size = 128 72 | 73 | # TPU settings 74 | self.use_tpu = False 75 | self.tpu_job_name = None 76 | self.num_tpu_cores = 1 77 | self.tpu_name = None # cloud TPU to use for training 78 | self.tpu_zone = None # GCE zone where the Cloud TPU is located in 79 | self.gcp_project = None # project name for the Cloud TPU-enabled project 80 | 81 | # default locations of data files 82 | self.pretrain_tfrecords = os.path.join( 83 | data_dir, "pretrain_tfrecords/pretrain_data.tfrecord*") 84 | self.vocab_file = os.path.join(data_dir, "vocab.txt") 85 | self.model_dir = os.path.join(data_dir, "models", model_name) 86 | results_dir = os.path.join(self.model_dir, "results") 87 | self.results_txt = os.path.join(results_dir, "unsup_results.txt") 88 | self.results_pkl = os.path.join(results_dir, "unsup_results.pkl") 89 | 90 | # update defaults with passed-in hyperparameters 91 | self.update(kwargs) 92 | 93 | self.max_predictions_per_seq = int((self.mask_prob + 0.005) * 94 | self.max_seq_length) 95 | 96 | # debug-mode settings 97 | if self.debug: 98 | self.train_batch_size = 8 99 | self.num_train_steps = 20 100 | self.eval_batch_size = 4 101 | self.iterations_per_loop = 1 102 | self.num_eval_steps = 2 103 | 104 | # defaults for different-sized model 105 | if self.model_size in ["medium-small"]: 106 | self.embedding_size = 128 107 | self.conv_kernel_size=9 108 | self.linear_groups=2 109 | self.head_ratio=2 110 | elif self.model_size in ["small"]: 111 | self.embedding_size = 128 112 | self.conv_kernel_size=9 113 | self.linear_groups=1 114 | self.head_ratio=2 115 | self.learning_rate = 3e-4 116 | elif self.model_size in ["base"]: 117 | self.generator_hidden_size = 1/3 118 | self.learning_rate = 2e-4 119 | self.train_batch_size = 256 120 | self.eval_batch_size = 256 121 | self.conv_kernel_size=9 122 | self.linear_groups=1 123 | self.head_ratio=2 124 | 125 | # passed-in-arguments override (for example) debug-mode defaults 126 | self.update(kwargs) 127 | 128 | def update(self, kwargs): 129 | for k, v in kwargs.items(): 130 | if k not in self.__dict__: 131 | raise ValueError("Unknown hparam " + k) 132 | self.__dict__[k] = v 133 | -------------------------------------------------------------------------------- /download_glue_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 
2 | ''' 3 | 4 | import os 5 | import sys 6 | import shutil 7 | import argparse 8 | import tempfile 9 | import urllib.request 10 | import zipfile 11 | 12 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 13 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 14 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 15 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 16 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 17 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 18 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 19 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 20 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 21 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 22 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 23 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 24 | 25 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 26 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 27 | 28 | def download_and_extract(task, data_dir): 29 | print("Downloading and extracting %s..." 
% task) 30 | data_file = "%s.zip" % task 31 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 32 | with zipfile.ZipFile(data_file) as zip_ref: 33 | zip_ref.extractall(data_dir) 34 | os.remove(data_file) 35 | print("\tCompleted!") 36 | 37 | def format_mrpc(data_dir, path_to_data): 38 | print("Processing MRPC...") 39 | mrpc_dir = os.path.join(data_dir, "MRPC") 40 | if not os.path.isdir(mrpc_dir): 41 | os.mkdir(mrpc_dir) 42 | if path_to_data: 43 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 44 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 45 | else: 46 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 47 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 48 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 49 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 50 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 51 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 52 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 53 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 54 | 55 | dev_ids = [] 56 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 57 | for row in ids_fh: 58 | dev_ids.append(row.strip().split('\t')) 59 | 60 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 61 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 62 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 63 | header = data_fh.readline() 64 | train_fh.write(header) 65 | dev_fh.write(header) 66 | for row in data_fh: 67 | label, id1, id2, s1, s2 = row.strip().split('\t') 68 | if [id1, id2] in dev_ids: 69 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 70 | else: 71 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 72 | 73 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 74 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 75 | header = data_fh.readline() 76 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 77 | for idx, row in enumerate(data_fh): 78 | label, id1, id2, s1, s2 = row.strip().split('\t') 79 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 80 | print("\tCompleted!") 81 | 82 | def download_diagnostic(data_dir): 83 | print("Downloading and extracting diagnostic...") 84 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 85 | os.mkdir(os.path.join(data_dir, "diagnostic")) 86 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 87 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 88 | print("\tCompleted!") 89 | return 90 | 91 | def get_tasks(task_names): 92 | task_names = task_names.split(',') 93 | if "all" in task_names: 94 | tasks = TASKS 95 | else: 96 | tasks = [] 97 | for task_name in task_names: 98 | assert task_name in TASKS, "Task %s not found!" 
% task_name 99 | tasks.append(task_name) 100 | return tasks 101 | 102 | def main(arguments): 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 105 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 106 | type=str, default='all') 107 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 108 | type=str, default='') 109 | args = parser.parse_args(arguments) 110 | 111 | if not os.path.isdir(args.data_dir): 112 | os.mkdir(args.data_dir) 113 | tasks = get_tasks(args.tasks) 114 | 115 | for task in tasks: 116 | if task == 'MRPC': 117 | format_mrpc(args.data_dir, args.path_to_mrpc) 118 | elif task == 'diagnostic': 119 | download_diagnostic(args.data_dir) 120 | else: 121 | download_and_extract(task, args.data_dir) 122 | 123 | 124 | if __name__ == '__main__': 125 | sys.exit(main(sys.argv[1:])) 126 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=/path/to/data_dir 2 | # please set your data_dir, like ~/data/convbert 3 | NAME=convbert_medium-small 4 | 5 | python3 run_finetuning.py --data-dir $DATA_DIR \ 6 | --model-name $NAME --hparams '{"model_size": "medium-small", "task_names": ["cola"]}' -------------------------------------------------------------------------------- /finetune/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 -------------------------------------------------------------------------------- /finetune/classification/classification_metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Evaluation metrics for classification tasks.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import abc 10 | import numpy as np 11 | import scipy 12 | import sklearn 13 | 14 | from finetune import scorer 15 | 16 | 17 | class SentenceLevelScorer(scorer.Scorer): 18 | """Abstract scorer for classification/regression tasks.""" 19 | 20 | __metaclass__ = abc.ABCMeta 21 | 22 | def __init__(self): 23 | super(SentenceLevelScorer, self).__init__() 24 | self._total_loss = 0 25 | self._true_labels = [] 26 | self._preds = [] 27 | 28 | def update(self, results): 29 | super(SentenceLevelScorer, self).update(results) 30 | self._total_loss += results['loss'] 31 | self._true_labels.append(results['label_ids'] if 'label_ids' in results 32 | else results['targets']) 33 | self._preds.append(results['predictions']) 34 | 35 | def get_loss(self): 36 | return self._total_loss / len(self._true_labels) 37 | 38 | 39 | class AccuracyScorer(SentenceLevelScorer): 40 | 41 | def _get_results(self): 42 | correct, count = 0, 0 43 | for y_true, pred in zip(self._true_labels, self._preds): 44 | count += 1 45 | correct += (1 if y_true == pred else 0) 46 | return [ 47 | ('accuracy', 100.0 * correct / count), 48 | ('loss', self.get_loss()), 49 | ] 50 | 51 | 52 | class F1Scorer(SentenceLevelScorer): 53 | """Computes F1 for classification tasks.""" 54 | 55 | def __init__(self): 56 | super(F1Scorer, self).__init__() 57 | self._positive_label = 1 58 | 59 | def _get_results(self): 60 | n_correct, n_predicted, n_gold = 0, 0, 0 61 | for y_true, pred in 
zip(self._true_labels, self._preds): 62 | if y_true == self._positive_label: # gold positives come from the labels 63 | n_gold += 1 64 | if pred == self._positive_label: 65 | n_predicted += 1 66 | if pred == y_true: 67 | n_correct += 1 68 | if n_correct == 0: 69 | p, r, f1 = 0, 0, 0 70 | else: 71 | p = 100.0 * n_correct / n_predicted 72 | r = 100.0 * n_correct / n_gold 73 | f1 = 2 * p * r / (p + r) 74 | return [ 75 | ('precision', p), 76 | ('recall', r), 77 | ('f1', f1), 78 | ('loss', self.get_loss()), 79 | ] 80 | 81 | 82 | class MCCScorer(SentenceLevelScorer): 83 | 84 | def _get_results(self): 85 | return [ 86 | ('mcc', 100 * sklearn.metrics.matthews_corrcoef( 87 | self._true_labels, self._preds)), 88 | ('loss', self.get_loss()), 89 | ] 90 | 91 | 92 | class RegressionScorer(SentenceLevelScorer): 93 | 94 | def _get_results(self): 95 | preds = np.array(self._preds).flatten() 96 | return [ 97 | ('pearson', 100.0 * scipy.stats.pearsonr( 98 | self._true_labels, preds)[0]), 99 | ('spearman', 100.0 * scipy.stats.spearmanr( 100 | self._true_labels, preds)[0]), 101 | ('mse', np.mean(np.square(np.array(self._true_labels) - self._preds))), 102 | ('loss', self.get_loss()), 103 | ] 104 | -------------------------------------------------------------------------------- /finetune/classification/classification_tasks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Text classification and regression tasks.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import abc 23 | import csv 24 | import os 25 | import tensorflow.compat.v1 as tf 26 | 27 | import configure_finetuning 28 | from finetune import feature_spec 29 | from finetune import task 30 | from finetune.classification import classification_metrics 31 | from model import tokenization 32 | from util import utils 33 | 34 | 35 | class InputExample(task.Example): 36 | """A single training/test example for simple sequence classification.""" 37 | 38 | def __init__(self, eid, task_name, text_a, text_b=None, label=None): 39 | super(InputExample, self).__init__(task_name) 40 | self.eid = eid 41 | self.text_a = text_a 42 | self.text_b = text_b 43 | self.label = label 44 | 45 | 46 | class SingleOutputTask(task.Task): 47 | """Task with a single prediction per example (e.g., text classification).""" 48 | 49 | __metaclass__ = abc.ABCMeta 50 | 51 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 52 | tokenizer): 53 | super(SingleOutputTask, self).__init__(config, name) 54 | self._tokenizer = tokenizer 55 | 56 | def get_examples(self, split): 57 | return self._create_examples(read_tsv( 58 | os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"), 59 | max_lines=100 if self.config.debug else None), split) 60 | 61 | @abc.abstractmethod 62 | def _create_examples(self, lines, split): 63 | pass 64 | 65 | def featurize(self, example: InputExample, is_training, log=False): 66 | """Turn an InputExample into a dict of features.""" 67 | tokens_a = self._tokenizer.tokenize(example.text_a) 68 | tokens_b = None 69 | if example.text_b: 70 | tokens_b = self._tokenizer.tokenize(example.text_b) 71 | 72 | if tokens_b: 73 | # Modifies `tokens_a` and `tokens_b` in place so that the total 74 | # length is less than the specified length. 75 | # Account for [CLS], [SEP], [SEP] with "- 3" 76 | _truncate_seq_pair(tokens_a, tokens_b, self.config.max_seq_length - 3) 77 | else: 78 | # Account for [CLS] and [SEP] with "- 2" 79 | if len(tokens_a) > self.config.max_seq_length - 2: 80 | tokens_a = tokens_a[0:(self.config.max_seq_length - 2)] 81 | 82 | # The convention in BERT is: 83 | # (a) For sequence pairs: 84 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 85 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 86 | # (b) For single sequences: 87 | # tokens: [CLS] the dog is hairy . [SEP] 88 | # type_ids: 0 0 0 0 0 0 0 89 | # 90 | # Where "type_ids" are used to indicate whether this is the first 91 | # sequence or the second sequence. The embedding vectors for `type=0` and 92 | # `type=1` were learned during pre-training and are added to the wordpiece 93 | # embedding vector (and position vector). This is not *strictly* necessary 94 | # since the [SEP] token unambiguously separates the sequences, but it 95 | # makes it easier for the model to learn the concept of sequences. 96 | # 97 | # For classification tasks, the first vector (corresponding to [CLS]) is 98 | # used as the "sentence vector". Note that this only makes sense because 99 | # the entire model is fine-tuned. 
100 | tokens = [] 101 | segment_ids = [] 102 | tokens.append("[CLS]") 103 | segment_ids.append(0) 104 | for token in tokens_a: 105 | tokens.append(token) 106 | segment_ids.append(0) 107 | tokens.append("[SEP]") 108 | segment_ids.append(0) 109 | 110 | if tokens_b: 111 | for token in tokens_b: 112 | tokens.append(token) 113 | segment_ids.append(1) 114 | tokens.append("[SEP]") 115 | segment_ids.append(1) 116 | 117 | input_ids = self._tokenizer.convert_tokens_to_ids(tokens) 118 | 119 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 120 | # tokens are attended to. 121 | input_mask = [1] * len(input_ids) 122 | 123 | # Zero-pad up to the sequence length. 124 | while len(input_ids) < self.config.max_seq_length: 125 | input_ids.append(0) 126 | input_mask.append(0) 127 | segment_ids.append(0) 128 | 129 | assert len(input_ids) == self.config.max_seq_length 130 | assert len(input_mask) == self.config.max_seq_length 131 | assert len(segment_ids) == self.config.max_seq_length 132 | 133 | if log: 134 | utils.log(" Example {:}".format(example.eid)) 135 | utils.log(" tokens: {:}".format(" ".join( 136 | [tokenization.printable_text(x) for x in tokens]))) 137 | utils.log(" input_ids: {:}".format(" ".join(map(str, input_ids)))) 138 | utils.log(" input_mask: {:}".format(" ".join(map(str, input_mask)))) 139 | utils.log(" segment_ids: {:}".format(" ".join(map(str, segment_ids)))) 140 | 141 | eid = example.eid 142 | features = { 143 | "input_ids": input_ids, 144 | "input_mask": input_mask, 145 | "segment_ids": segment_ids, 146 | "task_id": self.config.task_names.index(self.name), 147 | self.name + "_eid": eid, 148 | } 149 | self._add_features(features, example, log) 150 | return features 151 | 152 | def _load_glue(self, lines, split, text_a_loc, text_b_loc, label_loc, 153 | skip_first_line=False, eid_offset=0, swap=False): 154 | examples = [] 155 | for (i, line) in enumerate(lines): 156 | try: 157 | if i == 0 and skip_first_line: 158 | continue 159 | eid = i - (1 if skip_first_line else 0) + eid_offset 160 | text_a = tokenization.convert_to_unicode(line[text_a_loc]) 161 | if text_b_loc is None: 162 | text_b = None 163 | else: 164 | text_b = tokenization.convert_to_unicode(line[text_b_loc]) 165 | if "test" in split or "diagnostic" in split: 166 | label = self._get_dummy_label() 167 | else: 168 | label = tokenization.convert_to_unicode(line[label_loc]) 169 | if swap: 170 | text_a, text_b = text_b, text_a 171 | examples.append(InputExample(eid=eid, task_name=self.name, 172 | text_a=text_a, text_b=text_b, label=label)) 173 | except Exception as ex: 174 | utils.log("Error constructing example from line", i, 175 | "for task", self.name + ":", ex) 176 | utils.log("Input causing the error:", line) 177 | return examples 178 | 179 | @abc.abstractmethod 180 | def _get_dummy_label(self): 181 | pass 182 | 183 | @abc.abstractmethod 184 | def _add_features(self, features, example, log): 185 | pass 186 | 187 | 188 | class RegressionTask(SingleOutputTask): 189 | """Task where the output is a real-valued score for the input text.""" 190 | 191 | __metaclass__ = abc.ABCMeta 192 | 193 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 194 | tokenizer, min_value, max_value): 195 | super(RegressionTask, self).__init__(config, name, tokenizer) 196 | self._tokenizer = tokenizer 197 | self._min_value = min_value 198 | self._max_value = max_value 199 | 200 | def _get_dummy_label(self): 201 | return 0.0 202 | 203 | def get_feature_specs(self): 204 | feature_specs = 
[feature_spec.FeatureSpec(self.name + "_eid", []), 205 | feature_spec.FeatureSpec(self.name + "_targets", [], 206 | is_int_feature=False)] 207 | return feature_specs 208 | 209 | def _add_features(self, features, example, log): 210 | label = float(example.label) 211 | assert self._min_value <= label <= self._max_value 212 | # simple normalization of the label 213 | label = (label - self._min_value) / self._max_value 214 | if log: 215 | utils.log(" label: {:}".format(label)) 216 | features[example.task_name + "_targets"] = label 217 | 218 | def get_prediction_module(self, bert_model, features, is_training, 219 | percent_done): 220 | reprs = bert_model.get_pooled_output() 221 | if is_training: 222 | reprs = tf.nn.dropout(reprs, keep_prob=0.9) 223 | 224 | predictions = tf.layers.dense(reprs, 1) 225 | predictions = tf.squeeze(predictions, -1) 226 | 227 | targets = features[self.name + "_targets"] 228 | losses = tf.square(predictions - targets) 229 | outputs = dict( 230 | loss=losses, 231 | predictions=predictions, 232 | targets=features[self.name + "_targets"], 233 | eid=features[self.name + "_eid"] 234 | ) 235 | return losses, outputs 236 | 237 | def get_scorer(self): 238 | return classification_metrics.RegressionScorer() 239 | 240 | 241 | class ClassificationTask(SingleOutputTask): 242 | """Task where the output is a single categorical label for the input text.""" 243 | __metaclass__ = abc.ABCMeta 244 | 245 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 246 | tokenizer, label_list): 247 | super(ClassificationTask, self).__init__(config, name, tokenizer) 248 | self._tokenizer = tokenizer 249 | self._label_list = label_list 250 | 251 | def _get_dummy_label(self): 252 | return self._label_list[0] 253 | 254 | def get_feature_specs(self): 255 | return [feature_spec.FeatureSpec(self.name + "_eid", []), 256 | feature_spec.FeatureSpec(self.name + "_label_ids", [])] 257 | 258 | def _add_features(self, features, example, log): 259 | label_map = {} 260 | for (i, label) in enumerate(self._label_list): 261 | label_map[label] = i 262 | label_id = label_map[example.label] 263 | if log: 264 | utils.log(" label: {:} (id = {:})".format(example.label, label_id)) 265 | features[example.task_name + "_label_ids"] = label_id 266 | 267 | def get_prediction_module(self, bert_model, features, is_training, 268 | percent_done): 269 | num_labels = len(self._label_list) 270 | reprs = bert_model.get_pooled_output() 271 | 272 | if is_training: 273 | reprs = tf.nn.dropout(reprs, keep_prob=0.9) 274 | 275 | logits = tf.layers.dense(reprs, num_labels) 276 | log_probs = tf.nn.log_softmax(logits, axis=-1) 277 | 278 | label_ids = features[self.name + "_label_ids"] 279 | labels = tf.one_hot(label_ids, depth=num_labels, dtype=tf.float32) 280 | 281 | losses = -tf.reduce_sum(labels * log_probs, axis=-1) 282 | 283 | outputs = dict( 284 | loss=losses, 285 | logits=logits, 286 | predictions=tf.argmax(logits, axis=-1), 287 | label_ids=label_ids, 288 | eid=features[self.name + "_eid"], 289 | ) 290 | return losses, outputs 291 | 292 | def get_scorer(self): 293 | return classification_metrics.AccuracyScorer() 294 | 295 | 296 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 297 | """Truncates a sequence pair in place to the maximum length.""" 298 | 299 | # This is a simple heuristic which will always truncate the longer sequence 300 | # one token at a time. 
This makes more sense than truncating an equal percent 301 | # of tokens from each, since if one sequence is very short then each token 302 | # that's truncated likely contains more information than a longer sequence. 303 | while True: 304 | total_length = len(tokens_a) + len(tokens_b) 305 | if total_length <= max_length: 306 | break 307 | if len(tokens_a) > len(tokens_b): 308 | tokens_a.pop() 309 | else: 310 | tokens_b.pop() 311 | 312 | 313 | def read_tsv(input_file, quotechar=None, max_lines=None): 314 | """Reads a tab separated value file.""" 315 | with tf.io.gfile.GFile(input_file, "r") as f: 316 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 317 | lines = [] 318 | for i, line in enumerate(reader): 319 | if max_lines and i >= max_lines: 320 | break 321 | lines.append(line) 322 | return lines 323 | 324 | 325 | class MNLI(ClassificationTask): 326 | """Multi-NLI.""" 327 | 328 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 329 | super(MNLI, self).__init__(config, "mnli", tokenizer, 330 | ["contradiction", "entailment", "neutral"]) 331 | 332 | def get_examples(self, split): 333 | if split == "dev": 334 | split += "_matched" 335 | return self._create_examples(read_tsv( 336 | os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"), 337 | max_lines=100 if self.config.debug else None), split) 338 | 339 | def _create_examples(self, lines, split): 340 | if split == "diagnostic": 341 | return self._load_glue(lines, split, 1, 2, None, True) 342 | else: 343 | return self._load_glue(lines, split, 8, 9, -1, True) 344 | 345 | def get_test_splits(self): 346 | return ["test_matched", "test_mismatched", "diagnostic"] 347 | 348 | 349 | class MRPC(ClassificationTask): 350 | """Microsoft Research Paraphrase Corpus.""" 351 | 352 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 353 | super(MRPC, self).__init__(config, "mrpc", tokenizer, ["0", "1"]) 354 | 355 | def _create_examples(self, lines, split): 356 | examples = [] 357 | examples += self._load_glue(lines, split, 3, 4, 0, True) 358 | if self.config.double_unordered and split == "train": 359 | examples += self._load_glue( 360 | lines, split, 3, 4, 0, True, len(examples), True) 361 | return examples 362 | 363 | 364 | class CoLA(ClassificationTask): 365 | """Corpus of Linguistic Acceptability.""" 366 | 367 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 368 | super(CoLA, self).__init__(config, "cola", tokenizer, ["0", "1"]) 369 | 370 | def _create_examples(self, lines, split): 371 | return self._load_glue(lines, split, 1 if split == "test" else 3, 372 | None, 1, split == "test") 373 | 374 | def get_scorer(self): 375 | return classification_metrics.MCCScorer() 376 | 377 | 378 | class SST(ClassificationTask): 379 | """Stanford Sentiment Treebank.""" 380 | 381 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 382 | super(SST, self).__init__(config, "sst", tokenizer, ["0", "1"]) 383 | 384 | def _create_examples(self, lines, split): 385 | if "test" in split: 386 | return self._load_glue(lines, split, 1, None, None, True) 387 | else: 388 | return self._load_glue(lines, split, 0, None, 1, True) 389 | 390 | 391 | class QQP(ClassificationTask): 392 | """Quora Question Pair.""" 393 | 394 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 395 | super(QQP, self).__init__(config, "qqp", tokenizer, ["0", "1"]) 396 | 397 | def _create_examples(self, lines, split): 398 | return 
self._load_glue(lines, split, 1 if split == "test" else 3, 399 | 2 if split == "test" else 4, 5, True) 400 | 401 | 402 | class RTE(ClassificationTask): 403 | """Recognizing Textual Entailment.""" 404 | 405 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 406 | super(RTE, self).__init__(config, "rte", tokenizer, 407 | ["entailment", "not_entailment"]) 408 | 409 | def _create_examples(self, lines, split): 410 | return self._load_glue(lines, split, 1, 2, 3, True) 411 | 412 | 413 | class WNLI(ClassificationTask): 414 | """Recognizing Textual Entailment.""" 415 | 416 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 417 | super(WNLI, self).__init__(config, "wnli", tokenizer, 418 | ["0", "1"]) 419 | 420 | def _create_examples(self, lines, split): 421 | return self._load_glue(lines, split, 1, 2, 3, True) 422 | 423 | class QNLI(ClassificationTask): 424 | """Question NLI.""" 425 | 426 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 427 | super(QNLI, self).__init__(config, "qnli", tokenizer, 428 | ["entailment", "not_entailment"]) 429 | 430 | def _create_examples(self, lines, split): 431 | return self._load_glue(lines, split, 1, 2, 3, True) 432 | 433 | 434 | class STS(RegressionTask): 435 | """Semantic Textual Similarity.""" 436 | 437 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 438 | super(STS, self).__init__(config, "sts", tokenizer, 0.0, 5.0) 439 | 440 | def _create_examples(self, lines, split): 441 | examples = [] 442 | if split == "test": 443 | examples += self._load_glue(lines, split, -2, -1, None, True) 444 | else: 445 | examples += self._load_glue(lines, split, -3, -2, -1, True) 446 | if self.config.double_unordered and split == "train": 447 | examples += self._load_glue( 448 | lines, split, -3, -2, -1, True, len(examples), True) 449 | return examples 450 | -------------------------------------------------------------------------------- /finetune/feature_spec.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Defines the inputs used when fine-tuning a model.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | import tensorflow.compat.v1 as tf 11 | 12 | import configure_finetuning 13 | 14 | 15 | def get_shared_feature_specs(config: configure_finetuning.FinetuningConfig): 16 | """Non-task-specific model inputs.""" 17 | return [ 18 | FeatureSpec("input_ids", [config.max_seq_length]), 19 | FeatureSpec("input_mask", [config.max_seq_length]), 20 | FeatureSpec("segment_ids", [config.max_seq_length]), 21 | FeatureSpec("task_id", []), 22 | ] 23 | 24 | 25 | class FeatureSpec(object): 26 | """Defines a feature passed as input to the model.""" 27 | 28 | def __init__(self, name, shape, default_value_fn=None, is_int_feature=True): 29 | self.name = name 30 | self.shape = shape 31 | self.default_value_fn = default_value_fn 32 | self.is_int_feature = is_int_feature 33 | 34 | def get_parsing_spec(self): 35 | return tf.io.FixedLenFeature( 36 | self.shape, tf.int64 if self.is_int_feature else tf.float32) 37 | 38 | def get_default_values(self): 39 | if self.default_value_fn: 40 | return self.default_value_fn(self.shape) 41 | else: 42 | return np.zeros( 43 | self.shape, np.int64 if self.is_int_feature else np.float32) 44 | -------------------------------------------------------------------------------- 
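The `FeatureSpec` parsing specs above are what the fine-tuning pipeline uses to decode its TFRecords. A minimal sketch of that round trip, assuming a `config` with `max_seq_length` set and a serialized `tf.train.Example` string named `record` (both hypothetical here):

```python
import tensorflow.compat.v1 as tf

from finetune import feature_spec

# Map each shared feature name to its tf.io.FixedLenFeature parsing spec.
specs = feature_spec.get_shared_feature_specs(config)  # `config` assumed to exist
name_to_parsing_spec = {spec.name: spec.get_parsing_spec() for spec in specs}

# Decode one serialized example and cast int64 features to int32,
# mirroring what Preprocessor._decode_tfrecord does below.
example = tf.io.parse_single_example(record, name_to_parsing_spec)  # `record` assumed
example = {name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
           for name, t in example.items()}
```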
/finetune/preprocessing.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Code for serializing raw fine-tuning data into tfrecords""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import collections 10 | import os 11 | import random 12 | import numpy as np 13 | import tensorflow.compat.v1 as tf 14 | 15 | import configure_finetuning 16 | from finetune import feature_spec 17 | from util import utils 18 | 19 | 20 | class Preprocessor(object): 21 | """Class for loading, preprocessing, and serializing fine-tuning datasets.""" 22 | 23 | def __init__(self, config: configure_finetuning.FinetuningConfig, tasks): 24 | self._config = config 25 | self._tasks = tasks 26 | self._name_to_task = {task.name: task for task in tasks} 27 | 28 | self._feature_specs = feature_spec.get_shared_feature_specs(config) 29 | for task in tasks: 30 | self._feature_specs += task.get_feature_specs() 31 | self._name_to_feature_config = { 32 | spec.name: spec.get_parsing_spec() 33 | for spec in self._feature_specs 34 | } 35 | assert len(self._name_to_feature_config) == len(self._feature_specs) 36 | 37 | def prepare_train(self): 38 | return self._serialize_dataset(self._tasks, True, "train") 39 | 40 | def prepare_predict(self, tasks, split): 41 | return self._serialize_dataset(tasks, False, split) 42 | 43 | def _serialize_dataset(self, tasks, is_training, split): 44 | """Write out the dataset as tfrecords.""" 45 | dataset_name = "_".join(sorted([task.name for task in tasks])) 46 | dataset_name += "_" + split 47 | dataset_prefix = os.path.join( 48 | self._config.preprocessed_data_dir, dataset_name) 49 | tfrecords_path = dataset_prefix + ".tfrecord" 50 | metadata_path = dataset_prefix + ".metadata" 51 | batch_size = (self._config.train_batch_size if is_training else 52 | self._config.eval_batch_size) 53 | 54 | utils.log("Loading dataset", dataset_name) 55 | n_examples = None 56 | if (self._config.use_tfrecords_if_existing and 57 | tf.io.gfile.exists(metadata_path)): 58 | n_examples = utils.load_json(metadata_path)["n_examples"] 59 | 60 | if n_examples is None: 61 | utils.log("Existing tfrecords not found so creating") 62 | examples = [] 63 | for task in tasks: 64 | task_examples = task.get_examples(split) 65 | examples += task_examples 66 | if is_training: 67 | random.shuffle(examples) 68 | utils.mkdir(tfrecords_path.rsplit("/", 1)[0]) 69 | n_examples = self.serialize_examples( 70 | examples, is_training, tfrecords_path, batch_size) 71 | utils.write_json({"n_examples": n_examples}, metadata_path) 72 | 73 | input_fn = self._input_fn_builder(tfrecords_path, is_training) 74 | if is_training: 75 | steps = int(n_examples // batch_size * self._config.num_train_epochs) 76 | else: 77 | steps = n_examples // batch_size 78 | 79 | return input_fn, steps 80 | 81 | def serialize_examples(self, examples, is_training, output_file, batch_size): 82 | """Convert a set of `InputExample`s to a TFRecord file.""" 83 | n_examples = 0 84 | with tf.io.TFRecordWriter(output_file) as writer: 85 | for (ex_index, example) in enumerate(examples): 86 | if ex_index % 2000 == 0: 87 | utils.log("Writing example {:} of {:}".format( 88 | ex_index, len(examples))) 89 | for tf_example in self._example_to_tf_example( 90 | example, is_training, 91 | log=self._config.log_examples and ex_index < 1): 92 | writer.write(tf_example.SerializeToString()) 93 | n_examples += 1 94 | # add padding so the dataset is a multiple of 
batch_size 95 | while n_examples % batch_size != 0: 96 | writer.write(self._make_tf_example(task_id=len(self._config.task_names)) 97 | .SerializeToString()) 98 | n_examples += 1 99 | return n_examples 100 | 101 | def _example_to_tf_example(self, example, is_training, log=False): 102 | examples = self._name_to_task[example.task_name].featurize( 103 | example, is_training, log) 104 | if not isinstance(examples, list): 105 | examples = [examples] 106 | for example in examples: 107 | yield self._make_tf_example(**example) 108 | 109 | def _make_tf_example(self, **kwargs): 110 | """Make a tf.train.Example from the provided features.""" 111 | for k in kwargs: 112 | if k not in self._name_to_feature_config: 113 | raise ValueError("Unknown feature", k) 114 | features = collections.OrderedDict() 115 | for spec in self._feature_specs: 116 | if spec.name in kwargs: 117 | values = kwargs[spec.name] 118 | else: 119 | values = spec.get_default_values() 120 | if (isinstance(values, int) or isinstance(values, bool) or 121 | isinstance(values, float) or isinstance(values, np.float32) or 122 | (isinstance(values, np.ndarray) and values.size == 1)): 123 | values = [values] 124 | if spec.is_int_feature: 125 | feature = tf.train.Feature(int64_list=tf.train.Int64List( 126 | value=list(values))) 127 | else: 128 | feature = tf.train.Feature(float_list=tf.train.FloatList( 129 | value=list(values))) 130 | features[spec.name] = feature 131 | return tf.train.Example(features=tf.train.Features(feature=features)) 132 | 133 | def _input_fn_builder(self, input_file, is_training): 134 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 135 | 136 | def input_fn(params): 137 | """The actual input function.""" 138 | d = tf.data.TFRecordDataset(input_file) 139 | if is_training: 140 | d = d.repeat() 141 | d = d.shuffle(buffer_size=100) 142 | return d.apply( 143 | tf.data.experimental.map_and_batch( 144 | self._decode_tfrecord, 145 | batch_size=params["batch_size"], 146 | drop_remainder=True)) 147 | 148 | return input_fn 149 | 150 | def _decode_tfrecord(self, record): 151 | """Decodes a record to a TensorFlow example.""" 152 | example = tf.io.parse_single_example(record, self._name_to_feature_config) 153 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 154 | # So cast all int64 to int32. 155 | for name, tensor in example.items(): 156 | if tensor.dtype == tf.int64: 157 | example[name] = tf.cast(tensor, tf.int32) 158 | else: 159 | example[name] = tensor 160 | return example 161 | -------------------------------------------------------------------------------- /finetune/qa/mrqa_official_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Official evaluation script for the MRQA Workshop Shared Task. 4 | Adapted fromt the SQuAD v1.1 official evaluation script. 5 | Modified slightly for the ELECTRA codebase. 
6 | """ 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import string 13 | import re 14 | import json 15 | import tensorflow.compat.v1 as tf 16 | from collections import Counter 17 | 18 | import configure_finetuning 19 | 20 | 21 | def normalize_answer(s): 22 | """Lower text and remove punctuation, articles and extra whitespace.""" 23 | def remove_articles(text): 24 | return re.sub(r'\b(a|an|the)\b', ' ', text) 25 | 26 | def white_space_fix(text): 27 | return ' '.join(text.split()) 28 | 29 | def remove_punc(text): 30 | exclude = set(string.punctuation) 31 | return ''.join(ch for ch in text if ch not in exclude) 32 | 33 | def lower(text): 34 | return text.lower() 35 | 36 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 37 | 38 | 39 | def f1_score(prediction, ground_truth): 40 | prediction_tokens = normalize_answer(prediction).split() 41 | ground_truth_tokens = normalize_answer(ground_truth).split() 42 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 43 | num_same = sum(common.values()) 44 | if num_same == 0: 45 | return 0 46 | precision = 1.0 * num_same / len(prediction_tokens) 47 | recall = 1.0 * num_same / len(ground_truth_tokens) 48 | f1 = (2 * precision * recall) / (precision + recall) 49 | return f1 50 | 51 | 52 | def exact_match_score(prediction, ground_truth): 53 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 54 | 55 | 56 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 57 | scores_for_ground_truths = [] 58 | for ground_truth in ground_truths: 59 | score = metric_fn(prediction, ground_truth) 60 | scores_for_ground_truths.append(score) 61 | return max(scores_for_ground_truths) 62 | 63 | 64 | def read_predictions(prediction_file): 65 | with tf.io.gfile.GFile(prediction_file) as f: 66 | predictions = json.load(f) 67 | return predictions 68 | 69 | 70 | def read_answers(gold_file): 71 | answers = {} 72 | with tf.io.gfile.GFile(gold_file, 'r') as f: 73 | for i, line in enumerate(f): 74 | example = json.loads(line) 75 | if i == 0 and 'header' in example: 76 | continue 77 | for qa in example['qas']: 78 | answers[qa['qid']] = qa['answers'] 79 | return answers 80 | 81 | 82 | def evaluate(answers, predictions, skip_no_answer=False): 83 | f1 = exact_match = total = 0 84 | for qid, ground_truths in answers.items(): 85 | if qid not in predictions: 86 | if not skip_no_answer: 87 | message = 'Unanswered question %s will receive score 0.' 
% qid 88 | print(message) 89 | total += 1 90 | continue 91 | total += 1 92 | prediction = predictions[qid] 93 | exact_match += metric_max_over_ground_truths( 94 | exact_match_score, prediction, ground_truths) 95 | f1 += metric_max_over_ground_truths( 96 | f1_score, prediction, ground_truths) 97 | 98 | exact_match = 100.0 * exact_match / total 99 | f1 = 100.0 * f1 / total 100 | 101 | return {'exact_match': exact_match, 'f1': f1} 102 | 103 | 104 | def main(config: configure_finetuning.FinetuningConfig, split, task_name): 105 | answers = read_answers(os.path.join(config.raw_data_dir(task_name), split + ".jsonl")) 106 | predictions = read_predictions(config.qa_preds_file(task_name)) 107 | return evaluate(answers, predictions, True) 108 | -------------------------------------------------------------------------------- /finetune/qa/qa_metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Evaluation metrics for question-answering tasks.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import collections 10 | import numpy as np 11 | import six 12 | 13 | import configure_finetuning 14 | from finetune import scorer 15 | from finetune.qa import mrqa_official_eval 16 | from finetune.qa import squad_official_eval 17 | from finetune.qa import squad_official_eval_v1 18 | from model import tokenization 19 | from util import utils 20 | 21 | 22 | RawResult = collections.namedtuple("RawResult", [ 23 | "unique_id", "start_logits", "end_logits", "answerable_logit", 24 | "start_top_log_probs", "start_top_index", "end_top_log_probs", 25 | "end_top_index" 26 | ]) 27 | 28 | 29 | class SpanBasedQAScorer(scorer.Scorer): 30 | """Runs evaluation for SQuAD 1.1, SQuAD 2.0, and MRQA tasks.""" 31 | 32 | def __init__(self, config: configure_finetuning.FinetuningConfig, task, split, 33 | v2): 34 | super(SpanBasedQAScorer, self).__init__() 35 | self._config = config 36 | self._task = task 37 | self._name = task.name 38 | self._split = split 39 | self._v2 = v2 40 | self._all_results = [] 41 | self._total_loss = 0 42 | self._split = split 43 | self._eval_examples = task.get_examples(split) 44 | 45 | def update(self, results): 46 | super(SpanBasedQAScorer, self).update(results) 47 | self._all_results.append( 48 | RawResult( 49 | unique_id=results["eid"], 50 | start_logits=results["start_logits"], 51 | end_logits=results["end_logits"], 52 | answerable_logit=results["answerable_logit"], 53 | start_top_log_probs=results["start_top_log_probs"], 54 | start_top_index=results["start_top_index"], 55 | end_top_log_probs=results["end_top_log_probs"], 56 | end_top_index=results["end_top_index"], 57 | )) 58 | self._total_loss += results["loss"] 59 | 60 | def get_loss(self): 61 | return self._total_loss / len(self._all_results) 62 | 63 | def _get_results(self): 64 | self.write_predictions() 65 | if self._name == "squad": 66 | squad_official_eval.set_opts(self._config, self._split) 67 | squad_official_eval.main() 68 | return sorted(utils.load_json( 69 | self._config.qa_eval_file(self._name)).items()) 70 | elif self._name == "squadv1": 71 | return sorted(squad_official_eval_v1.main( 72 | self._config, self._split).items()) 73 | else: 74 | return sorted(mrqa_official_eval.main( 75 | self._config, self._split, self._name).items()) 76 | 77 | def write_predictions(self): 78 | """Write final predictions to the json file.""" 79 | unique_id_to_result = {} 80 | for result in self._all_results: 81 
| unique_id_to_result[result.unique_id] = result 82 | 83 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 84 | "PrelimPrediction", 85 | ["feature_index", "start_index", "end_index", "start_logit", 86 | "end_logit"]) 87 | 88 | all_predictions = collections.OrderedDict() 89 | all_nbest_json = collections.OrderedDict() 90 | scores_diff_json = collections.OrderedDict() 91 | 92 | for example in self._eval_examples: 93 | example_id = example.qas_id if "squad" in self._name else example.qid 94 | features = self._task.featurize(example, False, for_eval=True) 95 | 96 | prelim_predictions = [] 97 | # keep track of the minimum score of null start+end of position 0 98 | score_null = 1000000 # large and positive 99 | for (feature_index, feature) in enumerate(features): 100 | result = unique_id_to_result[feature[self._name + "_eid"]] 101 | if self._config.joint_prediction: 102 | start_indexes = result.start_top_index 103 | end_indexes = result.end_top_index 104 | else: 105 | start_indexes = _get_best_indexes(result.start_logits, 106 | self._config.n_best_size) 107 | end_indexes = _get_best_indexes(result.end_logits, 108 | self._config.n_best_size) 109 | # if we could have irrelevant answers, get the min score of irrelevant 110 | if self._v2: 111 | if self._config.answerable_classifier: 112 | feature_null_score = result.answerable_logit 113 | else: 114 | feature_null_score = result.start_logits[0] + result.end_logits[0] 115 | if feature_null_score < score_null: 116 | score_null = feature_null_score 117 | for i, start_index in enumerate(start_indexes): 118 | for j, end_index in enumerate( 119 | end_indexes[i] if self._config.joint_prediction else end_indexes): 120 | # We could hypothetically create invalid predictions, e.g., predict 121 | # that the start of the span is in the question. We throw out all 122 | # invalid predictions. 
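# (The filters below skip, in order: indices outside this feature's tokens,
# a start at position 0 ([CLS]), indices that do not map back to original
# document tokens, starts whose "max context" lies in another span, spans
# whose end precedes their start, and spans longer than max_answer_length.)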
123 | if start_index >= len(feature[self._name + "_tokens"]): 124 | continue 125 | if end_index >= len(feature[self._name + "_tokens"]): 126 | continue 127 | if start_index == 0: 128 | continue 129 | if start_index not in feature[self._name + "_token_to_orig_map"]: 130 | continue 131 | if end_index not in feature[self._name + "_token_to_orig_map"]: 132 | continue 133 | if not feature[self._name + "_token_is_max_context"].get( 134 | start_index, False): 135 | continue 136 | if end_index < start_index: 137 | continue 138 | length = end_index - start_index + 1 139 | if length > self._config.max_answer_length: 140 | continue 141 | start_logit = (result.start_top_log_probs[i] if 142 | self._config.joint_prediction else 143 | result.start_logits[start_index]) 144 | end_logit = (result.end_top_log_probs[i, j] if 145 | self._config.joint_prediction else 146 | result.end_logits[end_index]) 147 | prelim_predictions.append( 148 | _PrelimPrediction( 149 | feature_index=feature_index, 150 | start_index=start_index, 151 | end_index=end_index, 152 | start_logit=start_logit, 153 | end_logit=end_logit)) 154 | 155 | if self._v2: 156 | if len(prelim_predictions) == 0 and self._config.debug: 157 | tokid = sorted(feature[self._name + "_token_to_orig_map"].keys())[0] 158 | prelim_predictions.append(_PrelimPrediction( 159 | feature_index=0, 160 | start_index=tokid, 161 | end_index=tokid + 1, 162 | start_logit=1.0, 163 | end_logit=1.0)) 164 | prelim_predictions = sorted( 165 | prelim_predictions, 166 | key=lambda x: (x.start_logit + x.end_logit), 167 | reverse=True) 168 | 169 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 170 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 171 | 172 | seen_predictions = {} 173 | nbest = [] 174 | for pred in prelim_predictions: 175 | if len(nbest) >= self._config.n_best_size: 176 | break 177 | feature = features[pred.feature_index] 178 | tok_tokens = feature[self._name + "_tokens"][ 179 | pred.start_index:(pred.end_index + 1)] 180 | orig_doc_start = feature[ 181 | self._name + "_token_to_orig_map"][pred.start_index] 182 | orig_doc_end = feature[ 183 | self._name + "_token_to_orig_map"][pred.end_index] 184 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 185 | tok_text = " ".join(tok_tokens) 186 | 187 | # De-tokenize WordPieces that have been split off. 188 | tok_text = tok_text.replace(" ##", "") 189 | tok_text = tok_text.replace("##", "") 190 | 191 | # Clean whitespace 192 | tok_text = tok_text.strip() 193 | tok_text = " ".join(tok_text.split()) 194 | orig_text = " ".join(orig_tokens) 195 | 196 | final_text = get_final_text(self._config, tok_text, orig_text) 197 | if final_text in seen_predictions: 198 | continue 199 | 200 | seen_predictions[final_text] = True 201 | 202 | nbest.append( 203 | _NbestPrediction( 204 | text=final_text, 205 | start_logit=pred.start_logit, 206 | end_logit=pred.end_logit)) 207 | 208 | # In very rare edge cases we could have no valid predictions. So we 209 | # just create a nonce prediction in this case to avoid failure. 
210 | if not nbest: 211 | nbest.append( 212 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 213 | 214 | assert len(nbest) >= 1 215 | 216 | total_scores = [] 217 | best_non_null_entry = None 218 | for entry in nbest: 219 | total_scores.append(entry.start_logit + entry.end_logit) 220 | if not best_non_null_entry: 221 | if entry.text: 222 | best_non_null_entry = entry 223 | 224 | probs = _compute_softmax(total_scores) 225 | 226 | nbest_json = [] 227 | for (i, entry) in enumerate(nbest): 228 | output = collections.OrderedDict() 229 | output["text"] = entry.text 230 | output["probability"] = probs[i] 231 | output["start_logit"] = entry.start_logit 232 | output["end_logit"] = entry.end_logit 233 | nbest_json.append(dict(output)) 234 | 235 | assert len(nbest_json) >= 1 236 | 237 | if not self._v2: 238 | all_predictions[example_id] = nbest_json[0]["text"] 239 | else: 240 | # predict "" iff the null score - the score of best non-null > threshold 241 | if self._config.answerable_classifier: 242 | score_diff = score_null 243 | else: 244 | score_diff = score_null - best_non_null_entry.start_logit - ( 245 | best_non_null_entry.end_logit) 246 | scores_diff_json[example_id] = score_diff 247 | all_predictions[example_id] = best_non_null_entry.text 248 | 249 | all_nbest_json[example_id] = nbest_json 250 | 251 | utils.write_json(dict(all_predictions), 252 | self._config.qa_preds_file(self._name)) 253 | if self._v2: 254 | utils.write_json({ 255 | k: float(v) for k, v in six.iteritems(scores_diff_json)}, 256 | self._config.qa_na_file(self._name)) 257 | 258 | 259 | def _get_best_indexes(logits, n_best_size): 260 | """Get the n-best logits from a list.""" 261 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 262 | 263 | best_indexes = [] 264 | for i in range(len(index_and_score)): 265 | if i >= n_best_size: 266 | break 267 | best_indexes.append(index_and_score[i][0]) 268 | return best_indexes 269 | 270 | 271 | def _compute_softmax(scores): 272 | """Compute softmax probability over raw logits.""" 273 | if not scores: 274 | return [] 275 | 276 | max_score = None 277 | for score in scores: 278 | if max_score is None or score > max_score: 279 | max_score = score 280 | 281 | exp_scores = [] 282 | total_sum = 0.0 283 | for score in scores: 284 | x = np.exp(score - max_score) 285 | exp_scores.append(x) 286 | total_sum += x 287 | 288 | probs = [] 289 | for score in exp_scores: 290 | probs.append(score / total_sum) 291 | return probs 292 | 293 | 294 | def get_final_text(config: configure_finetuning.FinetuningConfig, pred_text, 295 | orig_text): 296 | """Project the tokenized prediction back to the original text.""" 297 | 298 | # When we created the data, we kept track of the alignment between original 299 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 300 | # now `orig_text` contains the span of our original text corresponding to the 301 | # span that we predicted. 302 | # 303 | # However, `orig_text` may contain extra characters that we don't want in 304 | # our prediction. 305 | # 306 | # For example, let's say: 307 | # pred_text = steve smith 308 | # orig_text = Steve Smith's 309 | # 310 | # We don't want to return `orig_text` because it contains the extra "'s". 311 | # 312 | # We don't want to return `pred_text` because it's already been normalized 313 | # (the SQuAD eval script also does punctuation stripping/lower casing but 314 | # our tokenizer does additional normalization like stripping accent 315 | # characters). 
316 | # 317 | # What we really want to return is "Steve Smith". 318 | # 319 | # Therefore, we have to apply a semi-complicated alignment heruistic between 320 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 321 | # can fail in certain cases in which case we just return `orig_text`. 322 | 323 | def _strip_spaces(text): 324 | ns_chars = [] 325 | ns_to_s_map = collections.OrderedDict() 326 | for i, c in enumerate(text): 327 | if c == " ": 328 | continue 329 | ns_to_s_map[len(ns_chars)] = i 330 | ns_chars.append(c) 331 | ns_text = "".join(ns_chars) 332 | return ns_text, dict(ns_to_s_map) 333 | 334 | # We first tokenize `orig_text`, strip whitespace from the result 335 | # and `pred_text`, and check if they are the same length. If they are 336 | # NOT the same length, the heuristic has failed. If they are the same 337 | # length, we assume the characters are one-to-one aligned. 338 | tokenizer = tokenization.BasicTokenizer(do_lower_case=config.do_lower_case) 339 | 340 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 341 | 342 | start_position = tok_text.find(pred_text) 343 | if start_position == -1: 344 | if config.debug: 345 | utils.log( 346 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 347 | return orig_text 348 | end_position = start_position + len(pred_text) - 1 349 | 350 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 351 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 352 | 353 | if len(orig_ns_text) != len(tok_ns_text): 354 | if config.debug: 355 | utils.log("Length not equal after stripping spaces: '%s' vs '%s'", 356 | orig_ns_text, tok_ns_text) 357 | return orig_text 358 | 359 | # We then project the characters in `pred_text` back to `orig_text` using 360 | # the character-to-character alignment. 361 | tok_s_to_ns_map = {} 362 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 363 | tok_s_to_ns_map[tok_index] = i 364 | 365 | orig_start_position = None 366 | if start_position in tok_s_to_ns_map: 367 | ns_start_position = tok_s_to_ns_map[start_position] 368 | if ns_start_position in orig_ns_to_s_map: 369 | orig_start_position = orig_ns_to_s_map[ns_start_position] 370 | 371 | if orig_start_position is None: 372 | if config.debug: 373 | utils.log("Couldn't map start position") 374 | return orig_text 375 | 376 | orig_end_position = None 377 | if end_position in tok_s_to_ns_map: 378 | ns_end_position = tok_s_to_ns_map[end_position] 379 | if ns_end_position in orig_ns_to_s_map: 380 | orig_end_position = orig_ns_to_s_map[ns_end_position] 381 | 382 | if orig_end_position is None: 383 | if config.debug: 384 | utils.log("Couldn't map end position") 385 | return orig_text 386 | 387 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 388 | return output_text 389 | -------------------------------------------------------------------------------- /finetune/qa/qa_tasks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Question answering tasks. 
SQuAD 1.1/2.0 and 2019 MRQA tasks are supported.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import abc 10 | import collections 11 | import json 12 | import os 13 | import six 14 | import tensorflow.compat.v1 as tf 15 | 16 | import configure_finetuning 17 | from finetune import feature_spec 18 | from finetune import task 19 | from finetune.qa import qa_metrics 20 | from model import modeling 21 | from model import tokenization 22 | from util import utils 23 | 24 | 25 | class QAExample(task.Example): 26 | """Question-answering example.""" 27 | 28 | def __init__(self, 29 | task_name, 30 | eid, 31 | qas_id, 32 | qid, 33 | question_text, 34 | doc_tokens, 35 | orig_answer_text=None, 36 | start_position=None, 37 | end_position=None, 38 | is_impossible=False): 39 | super(QAExample, self).__init__(task_name) 40 | self.eid = eid 41 | self.qas_id = qas_id 42 | self.qid = qid 43 | self.question_text = question_text 44 | self.doc_tokens = doc_tokens 45 | self.orig_answer_text = orig_answer_text 46 | self.start_position = start_position 47 | self.end_position = end_position 48 | self.is_impossible = is_impossible 49 | 50 | def __str__(self): 51 | return self.__repr__() 52 | 53 | def __repr__(self): 54 | s = "" 55 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 56 | s += ", question_text: %s" % ( 57 | tokenization.printable_text(self.question_text)) 58 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 59 | if self.start_position: 60 | s += ", start_position: %d" % self.start_position 61 | if self.start_position: 62 | s += ", end_position: %d" % self.end_position 63 | if self.start_position: 64 | s += ", is_impossible: %r" % self.is_impossible 65 | return s 66 | 67 | 68 | def _check_is_max_context(doc_spans, cur_span_index, position): 69 | """Check if this is the 'max context' doc span for the token.""" 70 | 71 | # Because of the sliding window approach taken to scoring documents, a single 72 | # token can appear in multiple documents. E.g. 73 | # Doc: the man went to the store and bought a gallon of milk 74 | # Span A: the man went to the 75 | # Span B: to the store and bought 76 | # Span C: and bought a gallon of 77 | # ... 78 | # 79 | # Now the word 'bought' will have two scores from spans B and C. We only 80 | # want to consider the score with "maximum context", which we define as 81 | # the *minimum* of its left and right context (the *sum* of left and 82 | # right context will always be the same, of course). 83 | # 84 | # In the example the maximum context for 'bought' would be span C since 85 | # it has 1 left context and 3 right context, while span B has 4 left context 86 | # and 0 right context. 
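# Concretely, the code below scores each candidate span as
# min(left context, right context) + 0.01 * span length (the small bonus
# breaks ties in favor of longer spans), and a token "belongs" to the span
# with the highest score.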
87 | best_score = None 88 | best_span_index = None 89 | for (span_index, doc_span) in enumerate(doc_spans): 90 | end = doc_span.start + doc_span.length - 1 91 | if position < doc_span.start: 92 | continue 93 | if position > end: 94 | continue 95 | num_left_context = position - doc_span.start 96 | num_right_context = end - position 97 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 98 | if best_score is None or score > best_score: 99 | best_score = score 100 | best_span_index = span_index 101 | 102 | return cur_span_index == best_span_index 103 | 104 | 105 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 106 | orig_answer_text): 107 | """Returns tokenized answer spans that better match the annotated answer.""" 108 | 109 | # The SQuAD annotations are character based. We first project them to 110 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 111 | # often find a "better match". For example: 112 | # 113 | # Question: What year was John Smith born? 114 | # Context: The leader was John Smith (1895-1943). 115 | # Answer: 1895 116 | # 117 | # The original whitespace-tokenized answer will be "(1895-1943).". However 118 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 119 | # the exact answer, 1895. 120 | # 121 | # However, this is not always possible. Consider the following: 122 | # 123 | # Question: What country is the top exporter of electornics? 124 | # Context: The Japanese electronics industry is the lagest in the world. 125 | # Answer: Japan 126 | # 127 | # In this case, the annotator chose "Japan" as a character sub-span of 128 | # the word "Japanese". Since our WordPiece tokenizer does not split 129 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 130 | # in SQuAD, but does happen. 
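# The search below tries every sub-span of [input_start, input_end] and
# returns the first one whose joined WordPieces exactly match the
# re-tokenized answer text; if none matches, the original span is kept.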
131 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 132 | 133 | for new_start in range(input_start, input_end + 1): 134 | for new_end in range(input_end, new_start - 1, -1): 135 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 136 | if text_span == tok_answer_text: 137 | return new_start, new_end 138 | 139 | return input_start, input_end 140 | 141 | 142 | def is_whitespace(c): 143 | return c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F 144 | 145 | 146 | class QATask(task.Task): 147 | """A span-based question answering tasks (e.g., SQuAD).""" 148 | 149 | __metaclass__ = abc.ABCMeta 150 | 151 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 152 | tokenizer, v2=False): 153 | super(QATask, self).__init__(config, name) 154 | self._tokenizer = tokenizer 155 | self._examples = {} 156 | self.v2 = v2 157 | 158 | def _add_examples(self, examples, example_failures, paragraph, split): 159 | paragraph_text = paragraph["context"] 160 | doc_tokens = [] 161 | char_to_word_offset = [] 162 | prev_is_whitespace = True 163 | for c in paragraph_text: 164 | if is_whitespace(c): 165 | prev_is_whitespace = True 166 | else: 167 | if prev_is_whitespace: 168 | doc_tokens.append(c) 169 | else: 170 | doc_tokens[-1] += c 171 | prev_is_whitespace = False 172 | char_to_word_offset.append(len(doc_tokens) - 1) 173 | 174 | for qa in paragraph["qas"]: 175 | qas_id = qa["id"] if "id" in qa else None 176 | qid = qa["qid"] if "qid" in qa else None 177 | question_text = qa["question"] 178 | start_position = None 179 | end_position = None 180 | orig_answer_text = None 181 | is_impossible = False 182 | if split == "train": 183 | if self.v2: 184 | is_impossible = qa["is_impossible"] 185 | if not is_impossible: 186 | if "detected_answers" in qa: # MRQA format 187 | answer = qa["detected_answers"][0] 188 | answer_offset = answer["char_spans"][0][0] 189 | else: # SQuAD format 190 | answer = qa["answers"][0] 191 | answer_offset = answer["answer_start"] 192 | orig_answer_text = answer["text"] 193 | answer_length = len(orig_answer_text) 194 | start_position = char_to_word_offset[answer_offset] 195 | if answer_offset + answer_length - 1 >= len(char_to_word_offset): 196 | utils.log("End position is out of document!") 197 | example_failures[0] += 1 198 | continue 199 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 200 | 201 | # Only add answers where the text can be exactly recovered from the 202 | # document. If this CAN'T happen it's likely due to weird Unicode 203 | # stuff so we will just skip the example. 204 | # 205 | # Note that this means for training mode, every example is NOT 206 | # guaranteed to be preserved. 207 | actual_text = " ".join( 208 | doc_tokens[start_position:(end_position + 1)]) 209 | cleaned_answer_text = " ".join( 210 | tokenization.whitespace_tokenize(orig_answer_text)) 211 | actual_text = actual_text.lower() 212 | cleaned_answer_text = cleaned_answer_text.lower() 213 | if actual_text.find(cleaned_answer_text) == -1: 214 | utils.log("Could not find answer: '{:}' in doc vs. 
" 215 | "'{:}' in provided answer".format( 216 | tokenization.printable_text(actual_text), 217 | tokenization.printable_text(cleaned_answer_text))) 218 | example_failures[0] += 1 219 | continue 220 | else: 221 | start_position = -1 222 | end_position = -1 223 | orig_answer_text = "" 224 | 225 | example = QAExample( 226 | task_name=self.name, 227 | eid=len(examples), 228 | qas_id=qas_id, 229 | qid=qid, 230 | question_text=question_text, 231 | doc_tokens=doc_tokens, 232 | orig_answer_text=orig_answer_text, 233 | start_position=start_position, 234 | end_position=end_position, 235 | is_impossible=is_impossible) 236 | examples.append(example) 237 | 238 | def get_feature_specs(self): 239 | return [ 240 | feature_spec.FeatureSpec(self.name + "_eid", []), 241 | feature_spec.FeatureSpec(self.name + "_start_positions", []), 242 | feature_spec.FeatureSpec(self.name + "_end_positions", []), 243 | feature_spec.FeatureSpec(self.name + "_is_impossible", []), 244 | ] 245 | 246 | def featurize(self, example: QAExample, is_training, log=False, 247 | for_eval=False): 248 | all_features = [] 249 | query_tokens = self._tokenizer.tokenize(example.question_text) 250 | 251 | if len(query_tokens) > self.config.max_query_length: 252 | query_tokens = query_tokens[0:self.config.max_query_length] 253 | 254 | tok_to_orig_index = [] 255 | orig_to_tok_index = [] 256 | all_doc_tokens = [] 257 | for (i, token) in enumerate(example.doc_tokens): 258 | orig_to_tok_index.append(len(all_doc_tokens)) 259 | sub_tokens = self._tokenizer.tokenize(token) 260 | for sub_token in sub_tokens: 261 | tok_to_orig_index.append(i) 262 | all_doc_tokens.append(sub_token) 263 | 264 | tok_start_position = None 265 | tok_end_position = None 266 | if is_training and example.is_impossible: 267 | tok_start_position = -1 268 | tok_end_position = -1 269 | if is_training and not example.is_impossible: 270 | tok_start_position = orig_to_tok_index[example.start_position] 271 | if example.end_position < len(example.doc_tokens) - 1: 272 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 273 | else: 274 | tok_end_position = len(all_doc_tokens) - 1 275 | (tok_start_position, tok_end_position) = _improve_answer_span( 276 | all_doc_tokens, tok_start_position, tok_end_position, self._tokenizer, 277 | example.orig_answer_text) 278 | 279 | # The -3 accounts for [CLS], [SEP] and [SEP] 280 | max_tokens_for_doc = self.config.max_seq_length - len(query_tokens) - 3 281 | 282 | # We can have documents that are longer than the maximum sequence length. 283 | # To deal with this we do a sliding window approach, where we take chunks 284 | # of the up to our max length with a stride of `doc_stride`. 
285 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 286 | "DocSpan", ["start", "length"]) 287 | doc_spans = [] 288 | start_offset = 0 289 | while start_offset < len(all_doc_tokens): 290 | length = len(all_doc_tokens) - start_offset 291 | if length > max_tokens_for_doc: 292 | length = max_tokens_for_doc 293 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 294 | if start_offset + length == len(all_doc_tokens): 295 | break 296 | start_offset += min(length, self.config.doc_stride) 297 | 298 | for (doc_span_index, doc_span) in enumerate(doc_spans): 299 | tokens = [] 300 | token_to_orig_map = {} 301 | token_is_max_context = {} 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in query_tokens: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | tokens.append("[SEP]") 309 | segment_ids.append(0) 310 | 311 | for i in range(doc_span.length): 312 | split_token_index = doc_span.start + i 313 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 314 | 315 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 316 | split_token_index) 317 | token_is_max_context[len(tokens)] = is_max_context 318 | tokens.append(all_doc_tokens[split_token_index]) 319 | segment_ids.append(1) 320 | tokens.append("[SEP]") 321 | segment_ids.append(1) 322 | 323 | input_ids = self._tokenizer.convert_tokens_to_ids(tokens) 324 | 325 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 326 | # tokens are attended to. 327 | input_mask = [1] * len(input_ids) 328 | 329 | # Zero-pad up to the sequence length. 330 | while len(input_ids) < self.config.max_seq_length: 331 | input_ids.append(0) 332 | input_mask.append(0) 333 | segment_ids.append(0) 334 | 335 | assert len(input_ids) == self.config.max_seq_length 336 | assert len(input_mask) == self.config.max_seq_length 337 | assert len(segment_ids) == self.config.max_seq_length 338 | 339 | start_position = None 340 | end_position = None 341 | if is_training and not example.is_impossible: 342 | # For training, if our document chunk does not contain an annotation 343 | # we throw it out, since there is nothing to predict. 
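# (In practice the chunk is kept: out-of-span chunks get start/end = 0,
# i.e. the [CLS] position, while in-span answers are shifted by
# len(query_tokens) + 2 to account for the leading [CLS] and [SEP] tokens.)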
344 | doc_start = doc_span.start 345 | doc_end = doc_span.start + doc_span.length - 1 346 | out_of_span = False 347 | if not (tok_start_position >= doc_start and 348 | tok_end_position <= doc_end): 349 | out_of_span = True 350 | if out_of_span: 351 | start_position = 0 352 | end_position = 0 353 | else: 354 | doc_offset = len(query_tokens) + 2 355 | start_position = tok_start_position - doc_start + doc_offset 356 | end_position = tok_end_position - doc_start + doc_offset 357 | 358 | if is_training and example.is_impossible: 359 | start_position = 0 360 | end_position = 0 361 | 362 | if log: 363 | utils.log("*** Example ***") 364 | utils.log("doc_span_index: %s" % doc_span_index) 365 | utils.log("tokens: %s" % " ".join( 366 | [tokenization.printable_text(x) for x in tokens])) 367 | utils.log("token_to_orig_map: %s" % " ".join( 368 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) 369 | utils.log("token_is_max_context: %s" % " ".join([ 370 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) 371 | ])) 372 | utils.log("input_ids: %s" % " ".join([str(x) for x in input_ids])) 373 | utils.log("input_mask: %s" % " ".join([str(x) for x in input_mask])) 374 | utils.log("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 375 | if is_training and example.is_impossible: 376 | utils.log("impossible example") 377 | if is_training and not example.is_impossible: 378 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 379 | utils.log("start_position: %d" % start_position) 380 | utils.log("end_position: %d" % end_position) 381 | utils.log("answer: %s" % (tokenization.printable_text(answer_text))) 382 | 383 | features = { 384 | "task_id": self.config.task_names.index(self.name), 385 | self.name + "_eid": (1000 * example.eid) + doc_span_index, 386 | "input_ids": input_ids, 387 | "input_mask": input_mask, 388 | "segment_ids": segment_ids, 389 | } 390 | if for_eval: 391 | features.update({ 392 | self.name + "_doc_span_index": doc_span_index, 393 | self.name + "_tokens": tokens, 394 | self.name + "_token_to_orig_map": token_to_orig_map, 395 | self.name + "_token_is_max_context": token_is_max_context, 396 | }) 397 | if is_training: 398 | features.update({ 399 | self.name + "_start_positions": start_position, 400 | self.name + "_end_positions": end_position, 401 | self.name + "_is_impossible": example.is_impossible 402 | }) 403 | all_features.append(features) 404 | return all_features 405 | 406 | def get_prediction_module(self, bert_model, features, is_training, 407 | percent_done): 408 | final_hidden = bert_model.get_sequence_output() 409 | 410 | final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) 411 | batch_size = final_hidden_shape[0] 412 | seq_length = final_hidden_shape[1] 413 | 414 | answer_mask = tf.cast(features["input_mask"], tf.float32) 415 | answer_mask *= tf.cast(features["segment_ids"], tf.float32) 416 | answer_mask += tf.one_hot(0, seq_length) 417 | 418 | start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1) 419 | 420 | start_top_log_probs = tf.zeros([batch_size, self.config.beam_size]) 421 | start_top_index = tf.zeros([batch_size, self.config.beam_size], tf.int32) 422 | end_top_log_probs = tf.zeros([batch_size, self.config.beam_size, 423 | self.config.beam_size]) 424 | end_top_index = tf.zeros([batch_size, self.config.beam_size, 425 | self.config.beam_size], tf.int32) 426 | if self.config.joint_prediction: 427 | start_logits += 1000.0 * (answer_mask - 1) 428 | start_log_probs = 
tf.nn.log_softmax(start_logits) 429 | start_top_log_probs, start_top_index = tf.nn.top_k( 430 | start_log_probs, k=self.config.beam_size) 431 | 432 | if not is_training: 433 | # batch, beam, length, hidden 434 | end_features = tf.tile(tf.expand_dims(final_hidden, 1), 435 | [1, self.config.beam_size, 1, 1]) 436 | # batch, beam, length 437 | start_index = tf.one_hot(start_top_index, 438 | depth=seq_length, axis=-1, dtype=tf.float32) 439 | # batch, beam, hidden 440 | start_features = tf.reduce_sum( 441 | tf.expand_dims(final_hidden, 1) * 442 | tf.expand_dims(start_index, -1), axis=-2) 443 | # batch, beam, length, hidden 444 | start_features = tf.tile(tf.expand_dims(start_features, 2), 445 | [1, 1, seq_length, 1]) 446 | else: 447 | start_index = tf.one_hot( 448 | features[self.name + "_start_positions"], depth=seq_length, 449 | axis=-1, dtype=tf.float32) 450 | start_features = tf.reduce_sum(tf.expand_dims(start_index, -1) * 451 | final_hidden, axis=1) 452 | start_features = tf.tile(tf.expand_dims(start_features, 1), 453 | [1, seq_length, 1]) 454 | end_features = final_hidden 455 | 456 | final_repr = tf.concat([start_features, end_features], -1) 457 | final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu, 458 | name="qa_hidden") 459 | # batch, beam, length (batch, length when training) 460 | end_logits = tf.squeeze(tf.layers.dense(final_repr, 1), -1, 461 | name="qa_logits") 462 | if is_training: 463 | end_logits += 1000.0 * (answer_mask - 1) 464 | else: 465 | end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1) 466 | 467 | if not is_training: 468 | end_log_probs = tf.nn.log_softmax(end_logits) 469 | end_top_log_probs, end_top_index = tf.nn.top_k( 470 | end_log_probs, k=self.config.beam_size) 471 | end_logits = tf.zeros([batch_size, seq_length]) 472 | else: 473 | end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1) 474 | start_logits += 1000.0 * (answer_mask - 1) 475 | end_logits += 1000.0 * (answer_mask - 1) 476 | 477 | def compute_loss(logits, positions): 478 | one_hot_positions = tf.one_hot( 479 | positions, depth=seq_length, dtype=tf.float32) 480 | log_probs = tf.nn.log_softmax(logits, axis=-1) 481 | loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1) 482 | return loss 483 | 484 | start_positions = features[self.name + "_start_positions"] 485 | end_positions = features[self.name + "_end_positions"] 486 | 487 | start_loss = compute_loss(start_logits, start_positions) 488 | end_loss = compute_loss(end_logits, end_positions) 489 | 490 | losses = (start_loss + end_loss) / 2.0 491 | 492 | answerable_logit = tf.zeros([batch_size]) 493 | if self.config.answerable_classifier: 494 | final_repr = final_hidden[:, 0] 495 | if self.config.answerable_uses_start_logits: 496 | start_p = tf.nn.softmax(start_logits) 497 | start_feature = tf.reduce_sum(tf.expand_dims(start_p, -1) * 498 | final_hidden, axis=1) 499 | final_repr = tf.concat([final_repr, start_feature], -1) 500 | final_repr = tf.layers.dense(final_repr, 512, 501 | activation=modeling.gelu) 502 | answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1) 503 | answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits( 504 | labels=tf.cast(features[self.name + "_is_impossible"], tf.float32), 505 | logits=answerable_logit) 506 | losses += answerable_loss * self.config.answerable_weight 507 | 508 | return losses, dict( 509 | loss=losses, 510 | start_logits=start_logits, 511 | end_logits=end_logits, 512 | answerable_logit=answerable_logit, 513 | start_positions=features[self.name + "_start_positions"], 
514 | end_positions=features[self.name + "_end_positions"], 515 | start_top_log_probs=start_top_log_probs, 516 | start_top_index=start_top_index, 517 | end_top_log_probs=end_top_log_probs, 518 | end_top_index=end_top_index, 519 | eid=features[self.name + "_eid"], 520 | ) 521 | 522 | def get_scorer(self, split="dev"): 523 | return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2) 524 | 525 | 526 | class MRQATask(QATask): 527 | """Class for finetuning tasks from the 2019 MRQA shared task.""" 528 | 529 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 530 | tokenizer): 531 | super(MRQATask, self).__init__(config, name, tokenizer) 532 | 533 | def get_examples(self, split): 534 | if split in self._examples: 535 | utils.log("N EXAMPLES", split, len(self._examples[split])) 536 | return self._examples[split] 537 | 538 | examples = [] 539 | example_failures = [0] 540 | with tf.io.gfile.GFile(os.path.join( 541 | self.config.raw_data_dir(self.name), split + ".jsonl"), "r") as f: 542 | for i, line in enumerate(f): 543 | if self.config.debug and i > 10: 544 | break 545 | paragraph = json.loads(line.strip()) 546 | if "header" in paragraph: 547 | continue 548 | self._add_examples(examples, example_failures, paragraph, split) 549 | self._examples[split] = examples 550 | utils.log("{:} examples created, {:} failures".format( 551 | len(examples), example_failures[0])) 552 | return examples 553 | 554 | def get_scorer(self, split="dev"): 555 | return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2) 556 | 557 | 558 | class SQuADTask(QATask): 559 | """Class for finetuning on SQuAD 2.0 or 1.1.""" 560 | 561 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 562 | tokenizer, v2=False): 563 | super(SQuADTask, self).__init__(config, name, tokenizer, v2=v2) 564 | 565 | def get_examples(self, split): 566 | if split in self._examples: 567 | return self._examples[split] 568 | 569 | with tf.io.gfile.GFile(os.path.join( 570 | self.config.raw_data_dir(self.name), 571 | split + ("-debug" if self.config.debug else "") + ".json"), "r") as f: 572 | input_data = json.load(f)["data"] 573 | 574 | examples = [] 575 | example_failures = [0] 576 | for entry in input_data: 577 | for paragraph in entry["paragraphs"]: 578 | self._add_examples(examples, example_failures, paragraph, split) 579 | self._examples[split] = examples 580 | utils.log("{:} examples created, {:} failures".format( 581 | len(examples), example_failures[0])) 582 | return examples 583 | 584 | def get_scorer(self, split="dev"): 585 | return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2) 586 | 587 | 588 | class SQuAD(SQuADTask): 589 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 590 | super(SQuAD, self).__init__(config, "squad", tokenizer, v2=True) 591 | 592 | 593 | class SQuADv1(SQuADTask): 594 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 595 | super(SQuADv1, self).__init__(config, "squadv1", tokenizer) 596 | 597 | 598 | class NewsQA(MRQATask): 599 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 600 | super(NewsQA, self).__init__(config, "newsqa", tokenizer) 601 | 602 | 603 | class NaturalQuestions(MRQATask): 604 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 605 | super(NaturalQuestions, self).__init__(config, "naturalqs", tokenizer) 606 | 607 | 608 | class SearchQA(MRQATask): 609 | def __init__(self, config: 
configure_finetuning.FinetuningConfig, tokenizer): 610 | super(SearchQA, self).__init__(config, "searchqa", tokenizer) 611 | 612 | 613 | class TriviaQA(MRQATask): 614 | def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer): 615 | super(TriviaQA, self).__init__(config, "triviaqa", tokenizer) 616 | -------------------------------------------------------------------------------- /finetune/qa/squad_official_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Official evaluation script for SQuAD version 2.0. 4 | 5 | In addition to basic functionality, we also compute additional statistics and 6 | plot precision-recall curves if an additional na_prob.json file is provided. 7 | This file is expected to map question ID's to the model's predicted probability 8 | that a question is unanswerable. 9 | 10 | Modified slightly for the ELECTRA codebase. 11 | """ 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import argparse 17 | import collections 18 | import json 19 | import numpy as np 20 | import os 21 | import re 22 | import string 23 | import sys 24 | import tensorflow.compat.v1 as tf 25 | 26 | import configure_finetuning 27 | 28 | OPTS = None 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 32 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 33 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 34 | parser.add_argument('--out-file', '-o', metavar='eval.json', 35 | help='Write accuracy metrics to file (default is stdout).') 36 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 37 | help='Model estimates of probability of no answer.') 38 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 39 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 40 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 41 | help='Save precision-recall curves to directory.') 42 | parser.add_argument('--verbose', '-v', action='store_true') 43 | if len(sys.argv) == 1: 44 | parser.print_help() 45 | sys.exit(1) 46 | return parser.parse_args() 47 | 48 | def set_opts(config: configure_finetuning.FinetuningConfig, split): 49 | global OPTS 50 | Options = collections.namedtuple("Options", [ 51 | "data_file", "pred_file", "out_file", "na_prob_file", "na_prob_thresh", 52 | "out_image_dir", "verbose"]) 53 | OPTS = Options( 54 | data_file=os.path.join( 55 | config.raw_data_dir("squad"), 56 | split + ("-debug" if config.debug else "") + ".json"), 57 | pred_file=config.qa_preds_file("squad"), 58 | out_file=config.qa_eval_file("squad"), 59 | na_prob_file=config.qa_na_file("squad"), 60 | na_prob_thresh=config.qa_na_threshold, 61 | out_image_dir=None, 62 | verbose=False 63 | ) 64 | 65 | def make_qid_to_has_ans(dataset): 66 | qid_to_has_ans = {} 67 | for article in dataset: 68 | for p in article['paragraphs']: 69 | for qa in p['qas']: 70 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 71 | return qid_to_has_ans 72 | 73 | def normalize_answer(s): 74 | """Lower text and remove punctuation, articles and extra whitespace.""" 75 | def remove_articles(text): 76 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 77 | return re.sub(regex, ' ', text) 78 | def white_space_fix(text): 79 | return ' '.join(text.split()) 80 | def remove_punc(text): 81 
| exclude = set(string.punctuation) 82 | return ''.join(ch for ch in text if ch not in exclude) 83 | def lower(text): 84 | return text.lower() 85 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 86 | 87 | def get_tokens(s): 88 | if not s: return [] 89 | return normalize_answer(s).split() 90 | 91 | def compute_exact(a_gold, a_pred): 92 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 93 | 94 | def compute_f1(a_gold, a_pred): 95 | gold_toks = get_tokens(a_gold) 96 | pred_toks = get_tokens(a_pred) 97 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 98 | num_same = sum(common.values()) 99 | if len(gold_toks) == 0 or len(pred_toks) == 0: 100 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 101 | return int(gold_toks == pred_toks) 102 | if num_same == 0: 103 | return 0 104 | precision = 1.0 * num_same / len(pred_toks) 105 | recall = 1.0 * num_same / len(gold_toks) 106 | f1 = (2 * precision * recall) / (precision + recall) 107 | return f1 108 | 109 | def get_raw_scores(dataset, preds): 110 | exact_scores = {} 111 | f1_scores = {} 112 | for article in dataset: 113 | for p in article['paragraphs']: 114 | for qa in p['qas']: 115 | qid = qa['id'] 116 | gold_answers = [a['text'] for a in qa['answers'] 117 | if normalize_answer(a['text'])] 118 | if not gold_answers: 119 | # For unanswerable questions, only correct answer is empty string 120 | gold_answers = [''] 121 | if qid not in preds: 122 | print('Missing prediction for %s' % qid) 123 | continue 124 | a_pred = preds[qid] 125 | # Take max over all gold answers 126 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 127 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 128 | return exact_scores, f1_scores 129 | 130 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 131 | new_scores = {} 132 | for qid, s in scores.items(): 133 | pred_na = na_probs[qid] > na_prob_thresh 134 | if pred_na: 135 | new_scores[qid] = float(not qid_to_has_ans[qid]) 136 | else: 137 | new_scores[qid] = s 138 | return new_scores 139 | 140 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 141 | if not qid_list: 142 | total = len(exact_scores) 143 | return collections.OrderedDict([ 144 | ('exact', 100.0 * sum(exact_scores.values()) / total), 145 | ('f1', 100.0 * sum(f1_scores.values()) / total), 146 | ('total', total), 147 | ]) 148 | else: 149 | total = len(qid_list) 150 | return collections.OrderedDict([ 151 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 152 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 153 | ('total', total), 154 | ]) 155 | 156 | def merge_eval(main_eval, new_eval, prefix): 157 | for k in new_eval: 158 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 159 | 160 | def plot_pr_curve(precisions, recalls, out_image, title): 161 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 162 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 163 | plt.xlabel('Recall') 164 | plt.ylabel('Precision') 165 | plt.xlim([0.0, 1.05]) 166 | plt.ylim([0.0, 1.05]) 167 | plt.title(title) 168 | plt.savefig(out_image) 169 | plt.clf() 170 | 171 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 172 | out_image=None, title=None): 173 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 174 | true_pos = 0.0 175 | cur_p = 1.0 176 | cur_r = 0.0 177 | precisions = [1.0] 178 | recalls = [0.0] 179 | avg_prec = 0.0 180 | for i, qid 
in enumerate(qid_list): 181 | if qid_to_has_ans[qid]: 182 | true_pos += scores[qid] 183 | cur_p = true_pos / float(i+1) 184 | cur_r = true_pos / float(num_true_pos) 185 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 186 | # i.e., if we can put a threshold after this point 187 | avg_prec += cur_p * (cur_r - recalls[-1]) 188 | precisions.append(cur_p) 189 | recalls.append(cur_r) 190 | if out_image: 191 | plot_pr_curve(precisions, recalls, out_image, title) 192 | return {'ap': 100.0 * avg_prec} 193 | 194 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 195 | qid_to_has_ans, out_image_dir): 196 | if out_image_dir and not os.path.exists(out_image_dir): 197 | os.makedirs(out_image_dir) 198 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 199 | if num_true_pos == 0: 200 | return 201 | pr_exact = make_precision_recall_eval( 202 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 203 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 204 | title='Precision-Recall curve for Exact Match score') 205 | pr_f1 = make_precision_recall_eval( 206 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 207 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 208 | title='Precision-Recall curve for F1 score') 209 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 210 | pr_oracle = make_precision_recall_eval( 211 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 212 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 213 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)') 214 | merge_eval(main_eval, pr_exact, 'pr_exact') 215 | merge_eval(main_eval, pr_f1, 'pr_f1') 216 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 217 | 218 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 219 | if not qid_list: 220 | return 221 | x = [na_probs[k] for k in qid_list] 222 | weights = np.ones_like(x) / float(len(x)) 223 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 224 | plt.xlabel('Model probability of no-answer') 225 | plt.ylabel('Proportion of dataset') 226 | plt.title('Histogram of no-answer probability: %s' % name) 227 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 228 | plt.clf() 229 | 230 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 231 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 232 | cur_score = num_no_ans 233 | best_score = cur_score 234 | best_thresh = 0.0 235 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 236 | for i, qid in enumerate(qid_list): 237 | if qid not in scores: continue 238 | if qid_to_has_ans[qid]: 239 | diff = scores[qid] 240 | else: 241 | if preds[qid]: 242 | diff = -1 243 | else: 244 | diff = 0 245 | cur_score += diff 246 | if cur_score > best_score: 247 | best_score = cur_score 248 | best_thresh = na_probs[qid] 249 | return 100.0 * best_score / len(scores), best_thresh 250 | 251 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 252 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 253 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 254 | main_eval['best_exact'] = best_exact 255 | main_eval['best_exact_thresh'] = exact_thresh 256 | main_eval['best_f1'] = best_f1 257 | main_eval['best_f1_thresh'] = f1_thresh 258 | 259 | def main(): 260 | with tf.io.gfile.GFile(OPTS.data_file) as f: 261 | dataset_json = json.load(f) 262 | dataset = dataset_json['data'] 263 | with 
tf.io.gfile.GFile(OPTS.pred_file) as f: 264 | preds = json.load(f) 265 | if OPTS.na_prob_file: 266 | with tf.io.gfile.GFile(OPTS.na_prob_file) as f: 267 | na_probs = json.load(f) 268 | else: 269 | na_probs = {k: 0.0 for k in preds} 270 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 271 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 272 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 273 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 274 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 275 | OPTS.na_prob_thresh) 276 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 277 | OPTS.na_prob_thresh) 278 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 279 | if has_ans_qids: 280 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 281 | merge_eval(out_eval, has_ans_eval, 'HasAns') 282 | if no_ans_qids: 283 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 284 | merge_eval(out_eval, no_ans_eval, 'NoAns') 285 | if OPTS.na_prob_file: 286 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 287 | if OPTS.na_prob_file and OPTS.out_image_dir: 288 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 289 | qid_to_has_ans, OPTS.out_image_dir) 290 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 291 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 292 | if OPTS.out_file: 293 | with tf.io.gfile.GFile(OPTS.out_file, 'w') as f: 294 | json.dump(out_eval, f) 295 | else: 296 | print(json.dumps(out_eval, indent=2)) 297 | 298 | if __name__ == '__main__': 299 | OPTS = parse_args() 300 | if OPTS.out_image_dir: 301 | import matplotlib 302 | matplotlib.use('Agg') 303 | import matplotlib.pyplot as plt 304 | main() 305 | -------------------------------------------------------------------------------- /finetune/qa/squad_official_eval_v1.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Official evaluation script for v1.1 of the SQuAD dataset. 5 | Modified slightly for the ELECTRA codebase. 
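Scores are exact match and token-overlap F1, each taken as the max over all provided ground-truth answers for a question.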
6 | """ 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from collections import Counter 11 | import string 12 | import re 13 | import json 14 | import sys 15 | import os 16 | import collections 17 | import tensorflow.compat.v1 as tf 18 | 19 | import configure_finetuning 20 | 21 | 22 | def normalize_answer(s): 23 | """Lower text and remove punctuation, articles and extra whitespace.""" 24 | def remove_articles(text): 25 | return re.sub(r'\b(a|an|the)\b', ' ', text) 26 | 27 | def white_space_fix(text): 28 | return ' '.join(text.split()) 29 | 30 | def remove_punc(text): 31 | exclude = set(string.punctuation) 32 | return ''.join(ch for ch in text if ch not in exclude) 33 | 34 | def lower(text): 35 | return text.lower() 36 | 37 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 38 | 39 | 40 | def f1_score(prediction, ground_truth): 41 | prediction_tokens = normalize_answer(prediction).split() 42 | ground_truth_tokens = normalize_answer(ground_truth).split() 43 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 44 | num_same = sum(common.values()) 45 | if num_same == 0: 46 | return 0 47 | precision = 1.0 * num_same / len(prediction_tokens) 48 | recall = 1.0 * num_same / len(ground_truth_tokens) 49 | f1 = (2 * precision * recall) / (precision + recall) 50 | return f1 51 | 52 | 53 | def exact_match_score(prediction, ground_truth): 54 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 55 | 56 | 57 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 58 | scores_for_ground_truths = [] 59 | for ground_truth in ground_truths: 60 | score = metric_fn(prediction, ground_truth) 61 | scores_for_ground_truths.append(score) 62 | return max(scores_for_ground_truths) 63 | 64 | 65 | def evaluate(dataset, predictions): 66 | f1 = exact_match = total = 0 67 | for article in dataset: 68 | for paragraph in article['paragraphs']: 69 | for qa in paragraph['qas']: 70 | total += 1 71 | if qa['id'] not in predictions: 72 | message = 'Unanswered question ' + qa['id'] + \ 73 | ' will receive score 0.' 
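# The unanswered question has already been counted in `total`, so skipping
# it here scores 0 for both exact match and F1.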
74 | print(message, file=sys.stderr) 75 | continue 76 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 77 | prediction = predictions[qa['id']] 78 | exact_match += metric_max_over_ground_truths( 79 | exact_match_score, prediction, ground_truths) 80 | f1 += metric_max_over_ground_truths( 81 | f1_score, prediction, ground_truths) 82 | 83 | exact_match = 100.0 * exact_match / total 84 | f1 = 100.0 * f1 / total 85 | 86 | return {'exact_match': exact_match, 'f1': f1} 87 | 88 | 89 | def main(config: configure_finetuning.FinetuningConfig, split): 90 | expected_version = '1.1' 91 | # parser = argparse.ArgumentParser( 92 | # description='Evaluation for SQuAD ' + expected_version) 93 | # parser.add_argument('dataset_file', help='Dataset file') 94 | # parser.add_argument('prediction_file', help='Prediction File') 95 | # args = parser.parse_args() 96 | Args = collections.namedtuple("Args", [ 97 | "dataset_file", "prediction_file" 98 | ]) 99 | args = Args(dataset_file=os.path.join( 100 | config.raw_data_dir("squadv1"), 101 | split + ("-debug" if config.debug else "") + ".json"), 102 | prediction_file=config.qa_preds_file("squadv1")) 103 | with tf.io.gfile.GFile(args.dataset_file) as dataset_file: 104 | dataset_json = json.load(dataset_file) 105 | if dataset_json['version'] != expected_version: 106 | print('Evaluation expects v-' + expected_version + 107 | ', but got dataset with v-' + dataset_json['version'], 108 | file=sys.stderr) 109 | dataset = dataset_json['data'] 110 | with tf.io.gfile.GFile(args.prediction_file) as prediction_file: 111 | predictions = json.load(prediction_file) 112 | return evaluate(dataset, predictions) 113 | 114 | -------------------------------------------------------------------------------- /finetune/scorer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Base class for evaluation metrics.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import abc 10 | 11 | 12 | class Scorer(object): 13 | """Abstract base class for computing evaluation metrics.""" 14 | 15 | __metaclass__ = abc.ABCMeta 16 | 17 | def __init__(self): 18 | self._updated = False 19 | self._cached_results = {} 20 | 21 | @abc.abstractmethod 22 | def update(self, results): 23 | self._updated = True 24 | 25 | @abc.abstractmethod 26 | def get_loss(self): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def _get_results(self): 31 | return [] 32 | 33 | def get_results(self, prefix=""): 34 | results = self._get_results() if self._updated else self._cached_results 35 | self._cached_results = results 36 | self._updated = False 37 | return [(prefix + k, v) for k, v in results] 38 | 39 | def results_str(self): 40 | return " - ".join(["{:}: {:.2f}".format(k, v) 41 | for k, v in self.get_results()]) 42 | -------------------------------------------------------------------------------- /finetune/tagging/tagging_metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Metrics for sequence tagging tasks.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import abc 10 | import six 11 | 12 | import numpy as np 13 | 14 | from finetune import scorer 15 | from finetune.tagging import tagging_utils 16 | 17 | 18 | class WordLevelScorer(scorer.Scorer): 19 | """Base class for tagging scorers.""" 20 | __metaclass__ = abc.ABCMeta 
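# Accumulates the per-example loss, gold labels, and predictions passed to
# update(), truncating each sequence to its number of real words as given by
# labels_mask; subclasses implement _get_results() to turn these into metrics.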
21 | 22 | def __init__(self): 23 | super(WordLevelScorer, self).__init__() 24 | self._total_loss = 0 25 | self._total_words = 0 26 | self._labels = [] 27 | self._preds = [] 28 | 29 | def update(self, results): 30 | super(WordLevelScorer, self).update(results) 31 | self._total_loss += results['loss'] 32 | n_words = int(round(np.sum(results['labels_mask']))) 33 | self._labels.append(results['labels'][:n_words]) 34 | self._preds.append(results['predictions'][:n_words]) 35 | self._total_loss += np.sum(results['loss']) 36 | self._total_words += n_words 37 | 38 | def get_loss(self): 39 | return self._total_loss / max(1, self._total_words) 40 | 41 | 42 | class AccuracyScorer(WordLevelScorer): 43 | """Computes accuracy scores.""" 44 | 45 | def __init__(self, auto_fail_label=None): 46 | super(AccuracyScorer, self).__init__() 47 | self._auto_fail_label = auto_fail_label 48 | 49 | def _get_results(self): 50 | correct, count = 0, 0 51 | for labels, preds in zip(self._labels, self._preds): 52 | for y_true, y_pred in zip(labels, preds): 53 | count += 1 54 | correct += (1 if y_pred == y_true and y_true != self._auto_fail_label 55 | else 0) 56 | return [ 57 | ('accuracy', 100.0 * correct / count), 58 | ('loss', self.get_loss()) 59 | ] 60 | 61 | 62 | class F1Scorer(WordLevelScorer): 63 | """Computes F1 scores.""" 64 | 65 | __metaclass__ = abc.ABCMeta 66 | 67 | def __init__(self): 68 | super(F1Scorer, self).__init__() 69 | self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0 70 | 71 | def _get_results(self): 72 | if self._n_correct == 0: 73 | p, r, f1 = 0, 0, 0 74 | else: 75 | p = 100.0 * self._n_correct / self._n_predicted 76 | r = 100.0 * self._n_correct / self._n_gold 77 | f1 = 2 * p * r / (p + r) 78 | return [ 79 | ('precision', p), 80 | ('recall', r), 81 | ('f1', f1), 82 | ('loss', self.get_loss()), 83 | ] 84 | 85 | 86 | class EntityLevelF1Scorer(F1Scorer): 87 | """Computes F1 score for entity-level tasks such as NER.""" 88 | 89 | def __init__(self, label_mapping): 90 | super(EntityLevelF1Scorer, self).__init__() 91 | self._inv_label_mapping = {v: k for k, v in six.iteritems(label_mapping)} 92 | 93 | def _get_results(self): 94 | self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0 95 | for labels, preds in zip(self._labels, self._preds): 96 | sent_spans = set(tagging_utils.get_span_labels( 97 | labels, self._inv_label_mapping)) 98 | span_preds = set(tagging_utils.get_span_labels( 99 | preds, self._inv_label_mapping)) 100 | self._n_correct += len(sent_spans & span_preds) 101 | self._n_gold += len(sent_spans) 102 | self._n_predicted += len(span_preds) 103 | return super(EntityLevelF1Scorer, self)._get_results() 104 | -------------------------------------------------------------------------------- /finetune/tagging/tagging_tasks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Sequence tagging tasks.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import abc 10 | import collections 11 | import os 12 | import tensorflow.compat.v1 as tf 13 | 14 | import configure_finetuning 15 | from finetune import feature_spec 16 | from finetune import task 17 | from finetune.tagging import tagging_metrics 18 | from finetune.tagging import tagging_utils 19 | from model import tokenization 20 | from pretrain import pretrain_helpers 21 | from util import utils 22 | 23 | 24 | LABEL_ENCODING = "BIOES" 25 | 26 | 27 | class TaggingExample(task.Example): 28 | """A 
single tagged input sequence.""" 29 | 30 | def __init__(self, eid, task_name, words, tags, is_token_level, 31 | label_mapping): 32 | super(TaggingExample, self).__init__(task_name) 33 | self.eid = eid 34 | self.words = words 35 | if is_token_level: 36 | labels = tags 37 | else: 38 | span_labels = tagging_utils.get_span_labels(tags) 39 | labels = tagging_utils.get_tags( 40 | span_labels, len(words), LABEL_ENCODING) 41 | self.labels = [label_mapping[l] for l in labels] 42 | 43 | 44 | class TaggingTask(task.Task): 45 | """Defines a sequence tagging task (e.g., part-of-speech tagging).""" 46 | 47 | __metaclass__ = abc.ABCMeta 48 | 49 | def __init__(self, config: configure_finetuning.FinetuningConfig, name, 50 | tokenizer, is_token_level): 51 | super(TaggingTask, self).__init__(config, name) 52 | self._tokenizer = tokenizer 53 | self._label_mapping_path = os.path.join( 54 | self.config.preprocessed_data_dir, 55 | ("debug_" if self.config.debug else "") + self.name + 56 | "_label_mapping.pkl") 57 | self._is_token_level = is_token_level 58 | self._label_mapping = None 59 | 60 | def get_examples(self, split): 61 | sentences = self._get_labeled_sentences(split) 62 | examples = [] 63 | label_mapping = self._get_label_mapping(split, sentences) 64 | for i, (words, tags) in enumerate(sentences): 65 | examples.append(TaggingExample( 66 | i, self.name, words, tags, self._is_token_level, label_mapping 67 | )) 68 | return examples 69 | 70 | def _get_label_mapping(self, provided_split=None, provided_sentences=None): 71 | if self._label_mapping is not None: 72 | return self._label_mapping 73 | if tf.io.gfile.exists(self._label_mapping_path): 74 | self._label_mapping = utils.load_pickle(self._label_mapping_path) 75 | return self._label_mapping 76 | utils.log("Writing label mapping for task", self.name) 77 | tag_counts = collections.Counter() 78 | train_tags = set() 79 | for split in ["train", "dev", "test"]: 80 | if not tf.io.gfile.exists(os.path.join( 81 | self.config.raw_data_dir(self.name), split + ".txt")): 82 | continue 83 | if split == provided_split: 84 | split_sentences = provided_sentences 85 | else: 86 | split_sentences = self._get_labeled_sentences(split) 87 | for _, tags in split_sentences: 88 | if not self._is_token_level: 89 | span_labels = tagging_utils.get_span_labels(tags) 90 | tags = tagging_utils.get_tags(span_labels, len(tags), LABEL_ENCODING) 91 | for tag in tags: 92 | tag_counts[tag] += 1 93 | if provided_split == "train": 94 | train_tags.add(tag) 95 | if self.name == "ccg": 96 | infrequent_tags = [] 97 | for tag in tag_counts: 98 | if tag not in train_tags: 99 | infrequent_tags.append(tag) 100 | label_mapping = { 101 | label: i for i, label in enumerate(sorted(filter( 102 | lambda t: t not in infrequent_tags, tag_counts.keys()))) 103 | } 104 | n = len(label_mapping) 105 | for tag in infrequent_tags: 106 | label_mapping[tag] = n 107 | else: 108 | labels = sorted(tag_counts.keys()) 109 | label_mapping = {label: i for i, label in enumerate(labels)} 110 | utils.write_pickle(label_mapping, self._label_mapping_path) 111 | self._label_mapping = label_mapping 112 | return label_mapping 113 | 114 | def featurize(self, example: TaggingExample, is_training, log=False): 115 | words_to_tokens = tokenize_and_align(self._tokenizer, example.words) 116 | input_ids = [] 117 | tagged_positions = [] 118 | for word_tokens in words_to_tokens: 119 | if len(words_to_tokens) + len(input_ids) + 1 > self.config.max_seq_length: 120 | input_ids.append(self._tokenizer.vocab["[SEP]"]) 121 | break 122 | if "[CLS]" not 
in word_tokens and "[SEP]" not in word_tokens: 123 | tagged_positions.append(len(input_ids)) 124 | for token in word_tokens: 125 | input_ids.append(self._tokenizer.vocab[token]) 126 | 127 | pad = lambda x: x + [0] * (self.config.max_seq_length - len(x)) 128 | labels = pad(example.labels[:self.config.max_seq_length]) 129 | labeled_positions = pad(tagged_positions) 130 | labels_mask = pad([1.0] * len(tagged_positions)) 131 | segment_ids = pad([1] * len(input_ids)) 132 | input_mask = pad([1] * len(input_ids)) 133 | input_ids = pad(input_ids) 134 | assert len(input_ids) == self.config.max_seq_length 135 | assert len(input_mask) == self.config.max_seq_length 136 | assert len(segment_ids) == self.config.max_seq_length 137 | assert len(labels) == self.config.max_seq_length 138 | assert len(labels_mask) == self.config.max_seq_length 139 | 140 | return { 141 | "input_ids": input_ids, 142 | "input_mask": input_mask, 143 | "segment_ids": segment_ids, 144 | "task_id": self.config.task_names.index(self.name), 145 | self.name + "_eid": example.eid, 146 | self.name + "_labels": labels, 147 | self.name + "_labels_mask": labels_mask, 148 | self.name + "_labeled_positions": labeled_positions 149 | } 150 | 151 | def _get_labeled_sentences(self, split): 152 | sentences = [] 153 | with tf.io.gfile.GFile(os.path.join(self.config.raw_data_dir(self.name), 154 | split + ".txt"), "r") as f: 155 | sentence = [] 156 | for line in f: 157 | line = line.strip().split() 158 | if not line: 159 | if sentence: 160 | words, tags = zip(*sentence) 161 | sentences.append((words, tags)) 162 | sentence = [] 163 | if self.config.debug and len(sentences) > 100: 164 | return sentences 165 | continue 166 | if line[0] == "-DOCSTART-": 167 | continue 168 | word, tag = line[0], line[-1] 169 | sentence.append((word, tag)) 170 | return sentences 171 | 172 | def get_scorer(self): 173 | return tagging_metrics.AccuracyScorer() if self._is_token_level else \ 174 | tagging_metrics.EntityLevelF1Scorer(self._get_label_mapping()) 175 | 176 | def get_feature_specs(self): 177 | return [ 178 | feature_spec.FeatureSpec(self.name + "_eid", []), 179 | feature_spec.FeatureSpec(self.name + "_labels", 180 | [self.config.max_seq_length]), 181 | feature_spec.FeatureSpec(self.name + "_labels_mask", 182 | [self.config.max_seq_length], 183 | is_int_feature=False), 184 | feature_spec.FeatureSpec(self.name + "_labeled_positions", 185 | [self.config.max_seq_length]), 186 | ] 187 | 188 | def get_prediction_module( 189 | self, bert_model, features, is_training, percent_done): 190 | n_classes = len(self._get_label_mapping()) 191 | reprs = bert_model.get_sequence_output() 192 | reprs = pretrain_helpers.gather_positions( 193 | reprs, features[self.name + "_labeled_positions"]) 194 | logits = tf.layers.dense(reprs, n_classes) 195 | losses = tf.nn.softmax_cross_entropy_with_logits( 196 | labels=tf.one_hot(features[self.name + "_labels"], n_classes), 197 | logits=logits) 198 | losses *= features[self.name + "_labels_mask"] 199 | losses = tf.reduce_sum(losses, axis=-1) 200 | return losses, dict( 201 | loss=losses, 202 | logits=logits, 203 | predictions=tf.argmax(logits, axis=-1), 204 | labels=features[self.name + "_labels"], 205 | labels_mask=features[self.name + "_labels_mask"], 206 | eid=features[self.name + "_eid"], 207 | ) 208 | 209 | def _create_examples(self, lines, split): 210 | pass 211 | 212 | 213 | def tokenize_and_align(tokenizer, words, cased=False): 214 | """Splits up words into subword-level tokens.""" 215 | words = ["[CLS]"] + list(words) + ["[SEP]"] 216 | 
basic_tokenizer = tokenizer.basic_tokenizer 217 | tokenized_words = [] 218 | for word in words: 219 | word = tokenization.convert_to_unicode(word) 220 | word = basic_tokenizer._clean_text(word) 221 | if word == "[CLS]" or word == "[SEP]": 222 | word_toks = [word] 223 | else: 224 | if not cased: 225 | word = word.lower() 226 | word = basic_tokenizer._run_strip_accents(word) 227 | word_toks = basic_tokenizer._run_split_on_punc(word) 228 | tokenized_word = [] 229 | for word_tok in word_toks: 230 | tokenized_word += tokenizer.wordpiece_tokenizer.tokenize(word_tok) 231 | tokenized_words.append(tokenized_word) 232 | assert len(tokenized_words) == len(words) 233 | return tokenized_words 234 | 235 | 236 | class Chunking(TaggingTask): 237 | """Text chunking.""" 238 | 239 | def __init__(self, config, tokenizer): 240 | super(Chunking, self).__init__(config, "chunk", tokenizer, False) 241 | -------------------------------------------------------------------------------- /finetune/tagging/tagging_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Utilities for sequence tagging tasks.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | 10 | def get_span_labels(sentence_tags, inv_label_mapping=None): 11 | """Go from token-level labels to list of entities (start, end, class).""" 12 | if inv_label_mapping: 13 | sentence_tags = [inv_label_mapping[i] for i in sentence_tags] 14 | span_labels = [] 15 | last = 'O' 16 | start = -1 17 | for i, tag in enumerate(sentence_tags): 18 | pos, _ = (None, 'O') if tag == 'O' else tag.split('-') 19 | if (pos == 'S' or pos == 'B' or tag == 'O') and last != 'O': 20 | span_labels.append((start, i - 1, last.split('-')[-1])) 21 | if pos == 'B' or pos == 'S' or last == 'O': 22 | start = i 23 | last = tag 24 | if sentence_tags[-1] != 'O': 25 | span_labels.append((start, len(sentence_tags) - 1, 26 | sentence_tags[-1].split('-')[-1])) 27 | return span_labels 28 | 29 | 30 | def get_tags(span_labels, length, encoding): 31 | """Converts a list of entities to token-label labels based on the provided 32 | encoding (e.g., BIOES). 
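For example, with encoding 'BIOES', span_labels [(1, 2, 'PER'), (4, 4, 'LOC')]
and length 6, the returned tags are ['O', 'B-PER', 'E-PER', 'O', 'S-LOC', 'O'].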
33 | """ 34 | 35 | tags = ['O' for _ in range(length)] 36 | for s, e, t in span_labels: 37 | for i in range(s, e + 1): 38 | tags[i] = 'I-' + t 39 | if 'E' in encoding: 40 | tags[e] = 'E-' + t 41 | if 'B' in encoding: 42 | tags[s] = 'B-' + t 43 | if 'S' in encoding and s - e == 0: 44 | tags[s] = 'S-' + t 45 | return tags 46 | -------------------------------------------------------------------------------- /finetune/task.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Defines a supervised NLP task.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import abc 10 | from typing import List, Tuple 11 | 12 | import configure_finetuning 13 | from finetune import feature_spec 14 | from finetune import scorer 15 | from model import modeling 16 | 17 | 18 | class Example(object): 19 | __metaclass__ = abc.ABCMeta 20 | 21 | def __init__(self, task_name): 22 | self.task_name = task_name 23 | 24 | 25 | class Task(object): 26 | """Override this class to add a new fine-tuning task.""" 27 | 28 | __metaclass__ = abc.ABCMeta 29 | 30 | def __init__(self, config: configure_finetuning.FinetuningConfig, name): 31 | self.config = config 32 | self.name = name 33 | 34 | def get_test_splits(self): 35 | return ["test"] 36 | 37 | @abc.abstractmethod 38 | def get_examples(self, split): 39 | pass 40 | 41 | @abc.abstractmethod 42 | def get_scorer(self) -> scorer.Scorer: 43 | pass 44 | 45 | @abc.abstractmethod 46 | def get_feature_specs(self) -> List[feature_spec.FeatureSpec]: 47 | pass 48 | 49 | @abc.abstractmethod 50 | def featurize(self, example: Example, is_training: bool, 51 | log: bool=False): 52 | pass 53 | 54 | @abc.abstractmethod 55 | def get_prediction_module( 56 | self, bert_model: modeling.BertModel, features: dict, is_training: bool, 57 | percent_done: float) -> Tuple: 58 | pass 59 | 60 | def __repr__(self): 61 | return "Task(" + self.name + ")" 62 | -------------------------------------------------------------------------------- /finetune/task_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Returns task instances given the task name.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import configure_finetuning 10 | from finetune.classification import classification_tasks 11 | from finetune.qa import qa_tasks 12 | from finetune.tagging import tagging_tasks 13 | from model import tokenization 14 | 15 | 16 | def get_tasks(config: configure_finetuning.FinetuningConfig): 17 | tokenizer = tokenization.FullTokenizer(vocab_file=config.vocab_file, 18 | do_lower_case=config.do_lower_case) 19 | return [get_task(config, task_name, tokenizer) 20 | for task_name in config.task_names] 21 | 22 | 23 | def get_task(config: configure_finetuning.FinetuningConfig, task_name, 24 | tokenizer): 25 | """Get an instance of a task based on its name.""" 26 | if task_name == "cola": 27 | return classification_tasks.CoLA(config, tokenizer) 28 | elif task_name == "mrpc": 29 | return classification_tasks.MRPC(config, tokenizer) 30 | elif task_name == "mnli": 31 | return classification_tasks.MNLI(config, tokenizer) 32 | elif task_name == "sst": 33 | return classification_tasks.SST(config, tokenizer) 34 | elif task_name == "rte": 35 | return classification_tasks.RTE(config, tokenizer) 36 | elif task_name == "qnli": 37 | return 
classification_tasks.QNLI(config, tokenizer) 38 | elif task_name == "qqp": 39 | return classification_tasks.QQP(config, tokenizer) 40 | elif task_name == "sts": 41 | return classification_tasks.STS(config, tokenizer) 42 | elif task_name == "wnli": 43 | return classification_tasks.WNLI(config, tokenizer) 44 | elif task_name == "squad": 45 | return qa_tasks.SQuAD(config, tokenizer) 46 | elif task_name == "squadv1": 47 | return qa_tasks.SQuADv1(config, tokenizer) 48 | elif task_name == "newsqa": 49 | return qa_tasks.NewsQA(config, tokenizer) 50 | elif task_name == "naturalqs": 51 | return qa_tasks.NaturalQuestions(config, tokenizer) 52 | elif task_name == "triviaqa": 53 | return qa_tasks.TriviaQA(config, tokenizer) 54 | elif task_name == "searchqa": 55 | return qa_tasks.SearchQA(config, tokenizer) 56 | elif task_name == "chunk": 57 | return tagging_tasks.Chunking(config, tokenizer) 58 | else: 59 | raise ValueError("Unknown task " + task_name) 60 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | -------------------------------------------------------------------------------- /model/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Functions and classes related to optimization (weight updates). 4 | Modified from the original BERT code to allow for having separate learning 5 | rates for different layers of the network. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import collections 13 | import re 14 | import tensorflow.compat.v1 as tf 15 | 16 | 17 | def create_optimizer( 18 | loss, learning_rate, num_train_steps, weight_decay_rate=0.0, use_tpu=False, 19 | warmup_steps=0, warmup_proportion=0, lr_decay_power=1.0, 20 | layerwise_lr_decay_power=-1, n_transformer_layers=None): 21 | """Creates an optimizer and training op.""" 22 | global_step = tf.train.get_or_create_global_step() 23 | learning_rate = tf.train.polynomial_decay( 24 | learning_rate, 25 | global_step, 26 | num_train_steps, 27 | end_learning_rate=0.0, 28 | power=lr_decay_power, 29 | cycle=False) 30 | warmup_steps = max(num_train_steps * warmup_proportion, warmup_steps) 31 | learning_rate *= tf.minimum( 32 | 1.0, tf.cast(global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)) 33 | 34 | if layerwise_lr_decay_power > 0: 35 | learning_rate = _get_layer_lrs(learning_rate, layerwise_lr_decay_power, 36 | n_transformer_layers) 37 | optimizer = AdamWeightDecayOptimizer( 38 | learning_rate=learning_rate, 39 | weight_decay_rate=weight_decay_rate, 40 | beta_1=0.9, 41 | beta_2=0.999, 42 | epsilon=1e-6, 43 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 44 | if use_tpu: 45 | optimizer = tf.tpu.CrossShardOptimizer(optimizer) 46 | 47 | tvars = tf.trainable_variables() 48 | grads = tf.gradients(loss, tvars) 49 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 50 | train_op = optimizer.apply_gradients( 51 | zip(grads, tvars), global_step=global_step) 52 | new_global_step = global_step + 1 53 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 54 | return train_op 55 | 56 | 57 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 58 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 59 | 60 | def __init__(self, 61 | learning_rate, 62 | weight_decay_rate=0.0, 
63 | beta_1=0.9, 64 | beta_2=0.999, 65 | epsilon=1e-6, 66 | exclude_from_weight_decay=None, 67 | name="AdamWeightDecayOptimizer"): 68 | """Constructs a AdamWeightDecayOptimizer.""" 69 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 70 | 71 | self.learning_rate = learning_rate 72 | self.weight_decay_rate = weight_decay_rate 73 | self.beta_1 = beta_1 74 | self.beta_2 = beta_2 75 | self.epsilon = epsilon 76 | self.exclude_from_weight_decay = exclude_from_weight_decay 77 | 78 | def _apply_gradients(self, grads_and_vars, learning_rate): 79 | """See base class.""" 80 | assignments = [] 81 | for (grad, param) in grads_and_vars: 82 | if grad is None or param is None: 83 | continue 84 | 85 | param_name = self._get_variable_name(param.name) 86 | 87 | m = tf.get_variable( 88 | name=param_name + "/adam_m", 89 | shape=param.shape.as_list(), 90 | dtype=tf.float32, 91 | trainable=False, 92 | initializer=tf.zeros_initializer()) 93 | v = tf.get_variable( 94 | name=param_name + "/adam_v", 95 | shape=param.shape.as_list(), 96 | dtype=tf.float32, 97 | trainable=False, 98 | initializer=tf.zeros_initializer()) 99 | 100 | # Standard Adam update. 101 | next_m = ( 102 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 103 | next_v = ( 104 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 105 | tf.square(grad))) 106 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 107 | 108 | # Just adding the square of the weights to the loss function is *not* 109 | # the correct way of using L2 regularization/weight decay with Adam, 110 | # since that will interact with the m and v parameters in strange ways. 111 | # 112 | # Instead we want ot decay the weights in a manner that doesn't interact 113 | # with the m/v parameters. This is equivalent to adding the square 114 | # of the weights to the loss with plain (non-momentum) SGD. 
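# In equation form, the decoupled update computed below is:
#   m_t = beta_1 * m + (1 - beta_1) * g
#   v_t = beta_2 * v + (1 - beta_2) * g^2
#   theta <- theta - lr * (m_t / (sqrt(v_t) + eps) + weight_decay_rate * theta)
# (no bias correction is applied to m_t and v_t here).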
115 | if self.weight_decay_rate > 0: 116 | if self._do_use_weight_decay(param_name): 117 | update += self.weight_decay_rate * param 118 | 119 | update_with_lr = learning_rate * update 120 | next_param = param - update_with_lr 121 | 122 | assignments.extend( 123 | [param.assign(next_param), 124 | m.assign(next_m), 125 | v.assign(next_v)]) 126 | 127 | return assignments 128 | 129 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 130 | if isinstance(self.learning_rate, dict): 131 | key_to_grads_and_vars = {} 132 | for grad, var in grads_and_vars: 133 | update_for_var = False 134 | for key in self.learning_rate: 135 | if key in var.name: 136 | update_for_var = True 137 | if key not in key_to_grads_and_vars: 138 | key_to_grads_and_vars[key] = [] 139 | key_to_grads_and_vars[key].append((grad, var)) 140 | if not update_for_var: 141 | raise ValueError("No learning rate specified for variable", var) 142 | assignments = [] 143 | for key, key_grads_and_vars in key_to_grads_and_vars.items(): 144 | assignments += self._apply_gradients(key_grads_and_vars, 145 | self.learning_rate[key]) 146 | else: 147 | assignments = self._apply_gradients(grads_and_vars, self.learning_rate) 148 | return tf.group(*assignments, name=name) 149 | 150 | def _do_use_weight_decay(self, param_name): 151 | """Whether to use L2 weight decay for `param_name`.""" 152 | if not self.weight_decay_rate: 153 | return False 154 | if self.exclude_from_weight_decay: 155 | for r in self.exclude_from_weight_decay: 156 | if re.search(r, param_name) is not None: 157 | return False 158 | return True 159 | 160 | def _get_variable_name(self, param_name): 161 | """Get the variable name from the tensor name.""" 162 | m = re.match("^(.*):\\d+$", param_name) 163 | if m is not None: 164 | param_name = m.group(1) 165 | return param_name 166 | 167 | 168 | def _get_layer_lrs(learning_rate, layer_decay, n_layers): 169 | """Have lower learning rates for layers closer to the input.""" 170 | key_to_depths = collections.OrderedDict({ 171 | "/embeddings/": 0, 172 | "/embeddings_project/": 0, 173 | "task_specific/": n_layers + 2, 174 | }) 175 | for layer in range(n_layers): 176 | key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1 177 | return { 178 | key: learning_rate * (layer_decay ** (n_layers + 2 - depth)) 179 | for key, depth in key_to_depths.items() 180 | } 181 | -------------------------------------------------------------------------------- /model/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Tokenization classes, the same as used for BERT.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import collections 10 | import unicodedata 11 | import six 12 | import tensorflow.compat.v1 as tf 13 | 14 | 15 | 16 | def convert_to_unicode(text): 17 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 18 | if six.PY3: 19 | if isinstance(text, str): 20 | return text 21 | elif isinstance(text, bytes): 22 | return text.decode("utf-8", "ignore") 23 | else: 24 | raise ValueError("Unsupported string type: %s" % (type(text))) 25 | elif six.PY2: 26 | if isinstance(text, str): 27 | return text.decode("utf-8", "ignore") 28 | elif isinstance(text, unicode): 29 | return text 30 | else: 31 | raise ValueError("Unsupported string type: %s" % (type(text))) 32 | else: 33 | raise ValueError("Not running on Python2 or Python 3?") 34 | 35 | 36 | def 
printable_text(text): 37 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 38 | 39 | # These functions want `str` for both Python2 and Python3, but in one case 40 | # it's a Unicode string and in the other it's a byte string. 41 | if six.PY3: 42 | if isinstance(text, str): 43 | return text 44 | elif isinstance(text, bytes): 45 | return text.decode("utf-8", "ignore") 46 | else: 47 | raise ValueError("Unsupported string type: %s" % (type(text))) 48 | elif six.PY2: 49 | if isinstance(text, str): 50 | return text 51 | elif isinstance(text, unicode): 52 | return text.encode("utf-8") 53 | else: 54 | raise ValueError("Unsupported string type: %s" % (type(text))) 55 | else: 56 | raise ValueError("Not running on Python2 or Python 3?") 57 | 58 | 59 | def load_vocab(vocab_file): 60 | """Loads a vocabulary file into a dictionary.""" 61 | vocab = collections.OrderedDict() 62 | index = 0 63 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 64 | while True: 65 | token = convert_to_unicode(reader.readline()) 66 | if not token: 67 | break 68 | token = token.strip() 69 | vocab[token] = index 70 | index += 1 71 | return vocab 72 | 73 | 74 | def convert_by_vocab(vocab, items): 75 | """Converts a sequence of [tokens|ids] using the vocab.""" 76 | output = [] 77 | for item in items: 78 | output.append(vocab[item]) 79 | return output 80 | 81 | 82 | def convert_tokens_to_ids(vocab, tokens): 83 | return convert_by_vocab(vocab, tokens) 84 | 85 | 86 | def convert_ids_to_tokens(inv_vocab, ids): 87 | return convert_by_vocab(inv_vocab, ids) 88 | 89 | 90 | def whitespace_tokenize(text): 91 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 92 | text = text.strip() 93 | if not text: 94 | return [] 95 | tokens = text.split() 96 | return tokens 97 | 98 | 99 | class FullTokenizer(object): 100 | """Runs end-to-end tokenziation.""" 101 | 102 | def __init__(self, vocab_file, do_lower_case=True, strip_accents=True): 103 | """Constructs a FullTokenizer. 104 | Args: 105 | vocab_file: The vocabulary file. 106 | do_lower_case: Whether to lower case the input. 107 | strip_accents: Whether to strip the accents. 108 | """ 109 | self.vocab = load_vocab(vocab_file) 110 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 111 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, strip_accents=strip_accents) 112 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 113 | 114 | def tokenize(self, text): 115 | split_tokens = [] 116 | for token in self.basic_tokenizer.tokenize(text): 117 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 118 | split_tokens.append(sub_token) 119 | 120 | return split_tokens 121 | 122 | def convert_tokens_to_ids(self, tokens): 123 | return convert_by_vocab(self.vocab, tokens) 124 | 125 | def convert_ids_to_tokens(self, ids): 126 | return convert_by_vocab(self.inv_vocab, ids) 127 | 128 | 129 | class BasicTokenizer(object): 130 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 131 | 132 | def __init__(self, do_lower_case=True, strip_accents=True): 133 | """Constructs a BasicTokenizer. 134 | 135 | Args: 136 | do_lower_case: Whether to lower case the input. 137 | strip_accents: Whether to strip the accents. 
138 | """ 139 | self.do_lower_case = do_lower_case 140 | self.strip_accents = strip_accents 141 | 142 | def tokenize(self, text): 143 | """Tokenizes a piece of text.""" 144 | text = convert_to_unicode(text) 145 | text = self._clean_text(text) 146 | 147 | # This was added on November 1st, 2018 for the multilingual and Chinese 148 | # models. This is also applied to the English models now, but it doesn't 149 | # matter since the English models were not trained on any Chinese data 150 | # and generally don't have any Chinese data in them (there are Chinese 151 | # characters in the vocabulary because Wikipedia does have some Chinese 152 | # words in the English Wikipedia.). 153 | text = self._tokenize_chinese_chars(text) 154 | 155 | orig_tokens = whitespace_tokenize(text) 156 | split_tokens = [] 157 | for token in orig_tokens: 158 | if self.do_lower_case: 159 | token = token.lower() 160 | if self.strip_accents: 161 | token = self._run_strip_accents(token) 162 | split_tokens.extend(self._run_split_on_punc(token)) 163 | 164 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 165 | return output_tokens 166 | 167 | def _run_strip_accents(self, text): 168 | """Strips accents from a piece of text.""" 169 | text = unicodedata.normalize("NFD", text) 170 | output = [] 171 | for char in text: 172 | cat = unicodedata.category(char) 173 | if cat == "Mn": 174 | continue 175 | output.append(char) 176 | return "".join(output) 177 | 178 | def _run_split_on_punc(self, text): 179 | """Splits punctuation on a piece of text.""" 180 | chars = list(text) 181 | i = 0 182 | start_new_word = True 183 | output = [] 184 | while i < len(chars): 185 | char = chars[i] 186 | if _is_punctuation(char): 187 | output.append([char]) 188 | start_new_word = True 189 | else: 190 | if start_new_word: 191 | output.append([]) 192 | start_new_word = False 193 | output[-1].append(char) 194 | i += 1 195 | 196 | return ["".join(x) for x in output] 197 | 198 | def _tokenize_chinese_chars(self, text): 199 | """Adds whitespace around any CJK character.""" 200 | output = [] 201 | for char in text: 202 | cp = ord(char) 203 | if self._is_chinese_char(cp): 204 | output.append(" ") 205 | output.append(char) 206 | output.append(" ") 207 | else: 208 | output.append(char) 209 | return "".join(output) 210 | 211 | def _is_chinese_char(self, cp): 212 | """Checks whether CP is the codepoint of a CJK character.""" 213 | # This defines a "chinese character" as anything in the CJK Unicode block: 214 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 215 | # 216 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 217 | # despite its name. The modern Korean Hangul alphabet is a different block, 218 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 219 | # space-separated words, so they are not treated specially and handled 220 | # like the all of the other languages. 
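# The ranges checked below are the CJK Unified Ideographs block, its
# Extensions A-E, and the CJK Compatibility Ideographs (plus their supplement).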
221 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 222 | (cp >= 0x3400 and cp <= 0x4DBF) or # 223 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 224 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 225 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 226 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 227 | (cp >= 0xF900 and cp <= 0xFAFF) or # 228 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 229 | return True 230 | 231 | return False 232 | 233 | def _clean_text(self, text): 234 | """Performs invalid character removal and whitespace cleanup on text.""" 235 | output = [] 236 | for char in text: 237 | cp = ord(char) 238 | if cp == 0 or cp == 0xfffd or _is_control(char): 239 | continue 240 | if _is_whitespace(char): 241 | output.append(" ") 242 | else: 243 | output.append(char) 244 | return "".join(output) 245 | 246 | 247 | class WordpieceTokenizer(object): 248 | """Runs WordPiece tokenziation.""" 249 | 250 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 251 | self.vocab = vocab 252 | self.unk_token = unk_token 253 | self.max_input_chars_per_word = max_input_chars_per_word 254 | 255 | def tokenize(self, text): 256 | """Tokenizes a piece of text into its word pieces. 257 | 258 | This uses a greedy longest-match-first algorithm to perform tokenization 259 | using the given vocabulary. 260 | 261 | For example: 262 | input = "unaffable" 263 | output = ["un", "##aff", "##able"] 264 | 265 | Args: 266 | text: A single token or whitespace separated tokens. This should have 267 | already been passed through `BasicTokenizer. 268 | 269 | Returns: 270 | A list of wordpiece tokens. 271 | """ 272 | 273 | text = convert_to_unicode(text) 274 | 275 | output_tokens = [] 276 | for token in whitespace_tokenize(text): 277 | chars = list(token) 278 | if len(chars) > self.max_input_chars_per_word: 279 | output_tokens.append(self.unk_token) 280 | continue 281 | 282 | is_bad = False 283 | start = 0 284 | sub_tokens = [] 285 | while start < len(chars): 286 | end = len(chars) 287 | cur_substr = None 288 | while start < end: 289 | substr = "".join(chars[start:end]) 290 | if start > 0: 291 | substr = "##" + substr 292 | if substr in self.vocab: 293 | cur_substr = substr 294 | break 295 | end -= 1 296 | if cur_substr is None: 297 | is_bad = True 298 | break 299 | sub_tokens.append(cur_substr) 300 | start = end 301 | 302 | if is_bad: 303 | output_tokens.append(self.unk_token) 304 | else: 305 | output_tokens.extend(sub_tokens) 306 | return output_tokens 307 | 308 | 309 | def _is_whitespace(char): 310 | """Checks whether `chars` is a whitespace character.""" 311 | # \t, \n, and \r are technically contorl characters but we treat them 312 | # as whitespace since they are generally considered as such. 313 | if char == " " or char == "\t" or char == "\n" or char == "\r": 314 | return True 315 | cat = unicodedata.category(char) 316 | if cat == "Zs": 317 | return True 318 | return False 319 | 320 | 321 | def _is_control(char): 322 | """Checks whether `chars` is a control character.""" 323 | # These are technically control characters but we count them as whitespace 324 | # characters. 325 | if char == "\t" or char == "\n" or char == "\r": 326 | return False 327 | cat = unicodedata.category(char) 328 | if cat.startswith("C"): 329 | return True 330 | return False 331 | 332 | 333 | def _is_punctuation(char): 334 | """Checks whether `chars` is a punctuation character.""" 335 | cp = ord(char) 336 | # We treat all non-letter/number ASCII as punctuation. 
337 | # Characters such as "^", "$", and "`" are not in the Unicode 338 | # Punctuation class but we treat them as punctuation anyways, for 339 | # consistency. 340 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 341 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 342 | return True 343 | cat = unicodedata.category(char) 344 | if cat.startswith("P"): 345 | return True 346 | return False 347 | -------------------------------------------------------------------------------- /pretrain.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=/path/to/data_dir 2 | # please set your data_dir, like ~/data/convbert 3 | NAME=convbert_medium-small 4 | python3 run_pretraining.py --data-dir $DATA_DIR --model-name $NAME --hparams '{"model_size": "medium-small"}' 5 | # The small-sized and medium-small-sized ConvBERT model can run on a V100 GPU, while the based-sized model needs more computation resources. -------------------------------------------------------------------------------- /pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | -------------------------------------------------------------------------------- /pretrain/pretrain_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | """Helpers for preparing pre-training data and supplying them to the model.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import collections 11 | 12 | import numpy as np 13 | import tensorflow.compat.v1 as tf 14 | 15 | import configure_pretraining 16 | from model import tokenization 17 | from util import utils 18 | 19 | 20 | def get_input_fn(config: configure_pretraining.PretrainingConfig, is_training, 21 | num_cpu_threads=4): 22 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 23 | 24 | input_files = [] 25 | for input_pattern in config.pretrain_tfrecords.split(","): 26 | input_files.extend(tf.io.gfile.glob(input_pattern)) 27 | 28 | def input_fn(params): 29 | """The actual input function.""" 30 | batch_size = params["batch_size"] 31 | 32 | name_to_features = { 33 | "input_ids": tf.io.FixedLenFeature([config.max_seq_length], tf.int64), 34 | "input_mask": tf.io.FixedLenFeature([config.max_seq_length], tf.int64), 35 | "segment_ids": tf.io.FixedLenFeature([config.max_seq_length], tf.int64), 36 | } 37 | 38 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 39 | d = d.repeat() 40 | d = d.shuffle(buffer_size=len(input_files)) 41 | 42 | # `cycle_length` is the number of parallel files that get read. 43 | cycle_length = min(num_cpu_threads, len(input_files)) 44 | 45 | # `sloppy` mode means that the interleaving is not exact. This adds 46 | # even more randomness to the training pipeline. 47 | d = d.apply( 48 | tf.data.experimental.parallel_interleave( 49 | tf.data.TFRecordDataset, 50 | sloppy=is_training, 51 | cycle_length=cycle_length)) 52 | d = d.shuffle(buffer_size=100) 53 | 54 | # We must `drop_remainder` on training because the TPU requires fixed 55 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 56 | # and we *don"t* want to drop the remainder, otherwise we wont cover 57 | # every sample. 
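# map_and_batch decodes each serialized tf.Example with _decode_record
# (casting int64 features to int32) while batching, parallelized across
# num_cpu_threads batches.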
58 | d = d.apply( 59 | tf.data.experimental.map_and_batch( 60 | lambda record: _decode_record(record, name_to_features), 61 | batch_size=batch_size, 62 | num_parallel_batches=num_cpu_threads, 63 | drop_remainder=True)) 64 | return d 65 | 66 | return input_fn 67 | 68 | 69 | def _decode_record(record, name_to_features): 70 | """Decodes a record to a TensorFlow example.""" 71 | example = tf.io.parse_single_example(record, name_to_features) 72 | 73 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 74 | # So cast all int64 to int32. 75 | for name in list(example.keys()): 76 | t = example[name] 77 | if t.dtype == tf.int64: 78 | t = tf.cast(t, tf.int32) 79 | example[name] = t 80 | 81 | return example 82 | 83 | 84 | # model inputs - it's a bit nicer to use a namedtuple rather than keep the 85 | # features as a dict 86 | Inputs = collections.namedtuple( 87 | "Inputs", ["input_ids", "input_mask", "segment_ids", "masked_lm_positions", 88 | "masked_lm_ids", "masked_lm_weights"]) 89 | 90 | 91 | def features_to_inputs(features): 92 | return Inputs( 93 | input_ids=features["input_ids"], 94 | input_mask=features["input_mask"], 95 | segment_ids=features["segment_ids"], 96 | masked_lm_positions=(features["masked_lm_positions"] 97 | if "masked_lm_positions" in features else None), 98 | masked_lm_ids=(features["masked_lm_ids"] 99 | if "masked_lm_ids" in features else None), 100 | masked_lm_weights=(features["masked_lm_weights"] 101 | if "masked_lm_weights" in features else None), 102 | ) 103 | 104 | 105 | def get_updated_inputs(inputs, **kwargs): 106 | features = inputs._asdict() 107 | for k, v in kwargs.items(): 108 | features[k] = v 109 | return features_to_inputs(features) 110 | 111 | 112 | ENDC = "\033[0m" 113 | COLORS = ["\033[" + str(n) + "m" for n in list(range(91, 97)) + [90]] 114 | RED = COLORS[0] 115 | BLUE = COLORS[3] 116 | CYAN = COLORS[5] 117 | GREEN = COLORS[1] 118 | 119 | 120 | def print_tokens(inputs: Inputs, inv_vocab, updates_mask=None): 121 | """Pretty-print model inputs.""" 122 | pos_to_tokid = {} 123 | for tokid, pos, weight in zip( 124 | inputs.masked_lm_ids[0], inputs.masked_lm_positions[0], 125 | inputs.masked_lm_weights[0]): 126 | if weight == 0: 127 | pass 128 | else: 129 | pos_to_tokid[pos] = tokid 130 | 131 | text = "" 132 | provided_update_mask = (updates_mask is not None) 133 | if not provided_update_mask: 134 | updates_mask = np.zeros_like(inputs.input_ids) 135 | for pos, (tokid, um) in enumerate( 136 | zip(inputs.input_ids[0], updates_mask[0])): 137 | token = inv_vocab[tokid] 138 | if token == "[PAD]": 139 | break 140 | if pos in pos_to_tokid: 141 | token = RED + token + " (" + inv_vocab[pos_to_tokid[pos]] + ")" + ENDC 142 | if provided_update_mask: 143 | assert um == 1 144 | else: 145 | if provided_update_mask: 146 | assert um == 0 147 | text += token + " " 148 | utils.log(tokenization.printable_text(text)) 149 | -------------------------------------------------------------------------------- /pretrain/pretrain_helpers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Helper functions for pre-training. These mainly deal with the gathering and 4 | scattering needed so the generator only makes predictions for the small number 5 | of masked tokens. 
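The dynamic masking routine used to construct masked-LM pretraining inputs also lives here.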
6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import tensorflow.compat.v1 as tf 13 | 14 | import configure_pretraining 15 | from model import modeling 16 | from model import tokenization 17 | from pretrain import pretrain_data 18 | 19 | 20 | def gather_positions(sequence, positions): 21 | """Gathers the vectors at the specific positions over a minibatch. 22 | 23 | Args: 24 | sequence: A [batch_size, seq_length] or 25 | [batch_size, seq_length, depth] tensor of values 26 | positions: A [batch_size, n_positions] tensor of indices 27 | 28 | Returns: A [batch_size, n_positions] or 29 | [batch_size, n_positions, depth] tensor of the values at the indices 30 | """ 31 | shape = modeling.get_shape_list(sequence, expected_rank=[2, 3]) 32 | depth_dimension = (len(shape) == 3) 33 | if depth_dimension: 34 | B, L, D = shape 35 | else: 36 | B, L = shape 37 | D = 1 38 | sequence = tf.expand_dims(sequence, -1) 39 | position_shift = tf.expand_dims(L * tf.range(B), -1) 40 | flat_positions = tf.reshape(positions + position_shift, [-1]) 41 | flat_sequence = tf.reshape(sequence, [B * L, D]) 42 | gathered = tf.gather(flat_sequence, flat_positions) 43 | if depth_dimension: 44 | return tf.reshape(gathered, [B, -1, D]) 45 | else: 46 | return tf.reshape(gathered, [B, -1]) 47 | 48 | 49 | def scatter_update(sequence, updates, positions): 50 | """Scatter-update a sequence. 51 | 52 | Args: 53 | sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor 54 | updates: A tensor of size batch_size*seq_len(*depth) 55 | positions: A [batch_size, n_positions] tensor 56 | 57 | Returns: A tuple of two tensors. First is a [batch_size, seq_len] or 58 | [batch_size, seq_len, depth] tensor of "sequence" with elements at 59 | "positions" replaced by the values at "updates." Updates to index 0 are 60 | ignored. If there are duplicated positions the update is only applied once. 61 | Second is a [batch_size, seq_len] mask tensor of which inputs were updated. 
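For example, with sequence [[10, 11, 12, 13, 14]], updates [[99, 98]] and
positions [[2, 0]], the result is [[10, 11, 99, 13, 14]] with updates_mask
[[0, 0, 1, 0, 0]] (the update aimed at position 0 is dropped).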
62 | """ 63 | shape = modeling.get_shape_list(sequence, expected_rank=[2, 3]) 64 | depth_dimension = (len(shape) == 3) 65 | if depth_dimension: 66 | B, L, D = shape 67 | else: 68 | B, L = shape 69 | D = 1 70 | sequence = tf.expand_dims(sequence, -1) 71 | N = modeling.get_shape_list(positions)[1] 72 | 73 | shift = tf.expand_dims(L * tf.range(B), -1) 74 | flat_positions = tf.reshape(positions + shift, [-1, 1]) 75 | flat_updates = tf.reshape(updates, [-1, D]) 76 | updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D]) 77 | updates = tf.reshape(updates, [B, L, D]) 78 | 79 | flat_updates_mask = tf.ones([B * N], tf.int32) 80 | updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L]) 81 | updates_mask = tf.reshape(updates_mask, [B, L]) 82 | not_first_token = tf.concat([tf.zeros((B, 1), tf.int32), 83 | tf.ones((B, L - 1), tf.int32)], -1) 84 | updates_mask *= not_first_token 85 | updates_mask_3d = tf.expand_dims(updates_mask, -1) 86 | 87 | # account for duplicate positions 88 | if sequence.dtype == tf.float32: 89 | updates_mask_3d = tf.cast(updates_mask_3d, tf.float32) 90 | updates /= tf.maximum(1.0, updates_mask_3d) 91 | else: 92 | assert sequence.dtype == tf.int32 93 | updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d)) 94 | updates_mask = tf.minimum(updates_mask, 1) 95 | updates_mask_3d = tf.minimum(updates_mask_3d, 1) 96 | 97 | updated_sequence = (((1 - updates_mask_3d) * sequence) + 98 | (updates_mask_3d * updates)) 99 | if not depth_dimension: 100 | updated_sequence = tf.squeeze(updated_sequence, -1) 101 | 102 | return updated_sequence, updates_mask 103 | 104 | 105 | def _get_candidates_mask(inputs: pretrain_data.Inputs, vocab, 106 | disallow_from_mask=None): 107 | """Returns a mask tensor of positions in the input that can be masked out.""" 108 | ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"]] 109 | candidates_mask = tf.ones_like(inputs.input_ids, tf.bool) 110 | for ignore_id in ignore_ids: 111 | candidates_mask &= tf.not_equal(inputs.input_ids, ignore_id) 112 | candidates_mask &= tf.cast(inputs.input_mask, tf.bool) 113 | if disallow_from_mask is not None: 114 | candidates_mask &= ~disallow_from_mask 115 | return candidates_mask 116 | 117 | 118 | def mask(config: configure_pretraining.PretrainingConfig, 119 | inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0, 120 | disallow_from_mask=None, already_masked=None): 121 | """Implementation of dynamic masking. The optional arguments aren't needed for 122 | BERT/ELECTRA and are from early experiments in "strategically" masking out 123 | tokens instead of uniformly at random. 124 | 125 | Args: 126 | config: configure_pretraining.PretrainingConfig 127 | inputs: pretrain_data.Inputs containing input input_ids/input_mask 128 | mask_prob: percent of tokens to mask 129 | proposal_distribution: for non-uniform masking can be a [B, L] tensor 130 | of scores for masking each position. 
131 | disallow_from_mask: a boolean tensor of [B, L] of positions that should 132 | not be masked out 133 | already_masked: a boolean tensor of [B, N] of already masked-out tokens 134 | for multiple rounds of masking 135 | Returns: a pretrain_data.Inputs with masking added 136 | """ 137 | # Get the batch size, sequence length, and max masked-out tokens 138 | N = config.max_predictions_per_seq 139 | B, L = modeling.get_shape_list(inputs.input_ids) 140 | 141 | # Find indices where masking out a token is allowed 142 | vocab = tokenization.FullTokenizer( 143 | config.vocab_file, do_lower_case=config.do_lower_case).vocab 144 | candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask) 145 | 146 | # Set the number of tokens to mask out per example 147 | num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32) 148 | num_to_predict = tf.maximum(1, tf.minimum( 149 | N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32))) 150 | masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32) 151 | if already_masked is not None: 152 | masked_lm_weights *= (1 - already_masked) 153 | 154 | # Get a probability of masking each position in the sequence 155 | candidate_mask_float = tf.cast(candidates_mask, tf.float32) 156 | sample_prob = (proposal_distribution * candidate_mask_float) 157 | sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True) 158 | 159 | # Sample the positions to mask out 160 | sample_prob = tf.stop_gradient(sample_prob) 161 | sample_logits = tf.log(sample_prob) 162 | masked_lm_positions = tf.random.categorical( 163 | sample_logits, N, dtype=tf.int32) 164 | masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32) 165 | 166 | # Get the ids of the masked-out tokens 167 | shift = tf.expand_dims(L * tf.range(B), -1) 168 | flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1]) 169 | masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]), 170 | flat_positions) 171 | masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1]) 172 | masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32) 173 | 174 | # Update the input ids 175 | replace_with_mask_positions = masked_lm_positions * tf.cast( 176 | tf.less(tf.random.uniform([B, N]), 0.85), tf.int32) 177 | inputs_ids, _ = scatter_update( 178 | inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]), 179 | replace_with_mask_positions) 180 | 181 | return pretrain_data.get_updated_inputs( 182 | inputs, 183 | input_ids=tf.stop_gradient(inputs_ids), 184 | masked_lm_positions=masked_lm_positions, 185 | masked_lm_ids=masked_lm_ids, 186 | masked_lm_weights=masked_lm_weights 187 | ) 188 | 189 | 190 | def unmask(inputs: pretrain_data.Inputs): 191 | unmasked_input_ids, _ = scatter_update( 192 | inputs.input_ids, inputs.masked_lm_ids, inputs.masked_lm_positions) 193 | return pretrain_data.get_updated_inputs(inputs, input_ids=unmasked_input_ids) 194 | 195 | 196 | def sample_from_softmax(logits, disallow=None): 197 | if disallow is not None: 198 | logits -= 1000.0 * disallow 199 | uniform_noise = tf.random.uniform( 200 | modeling.get_shape_list(logits), minval=0, maxval=1) 201 | gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9) 202 | return tf.one_hot(tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1, 203 | output_type=tf.int32), logits.shape[-1]) 204 | -------------------------------------------------------------------------------- /run_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Fine-tunes an ELECTRA/ConvBERT 
model on a downstream task.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import collections 11 | import json 12 | 13 | import tensorflow.compat.v1 as tf 14 | import numpy as np 15 | 16 | import configure_finetuning 17 | from finetune import preprocessing 18 | from finetune import task_builder 19 | from model import modeling 20 | from model import optimization 21 | from util import training_utils 22 | from util import utils 23 | 24 | 25 | class FinetuningModel(object): 26 | """Finetuning model with support for multi-task training.""" 27 | 28 | def __init__(self, config: configure_finetuning.FinetuningConfig, tasks, 29 | is_training, features, num_train_steps): 30 | # Create a shared transformer encoder 31 | bert_config = training_utils.get_bert_config(config) 32 | self.bert_config = bert_config 33 | if config.debug: 34 | bert_config.num_hidden_layers = 3 35 | bert_config.hidden_size = 144 36 | bert_config.intermediate_size = 144 * 4 37 | bert_config.num_attention_heads = 4 38 | assert config.max_seq_length <= bert_config.max_position_embeddings 39 | bert_model = modeling.BertModel( 40 | bert_config=bert_config, 41 | is_training=is_training, 42 | input_ids=features["input_ids"], 43 | input_mask=features["input_mask"], 44 | token_type_ids=features["segment_ids"], 45 | use_one_hot_embeddings=config.use_tpu, 46 | embedding_size=config.embedding_size) 47 | percent_done = (tf.cast(tf.train.get_or_create_global_step(), tf.float32) / 48 | tf.cast(num_train_steps, tf.float32)) 49 | 50 | # Add specific tasks 51 | self.outputs = {"task_id": features["task_id"]} 52 | losses = [] 53 | for task in tasks: 54 | with tf.variable_scope("task_specific/" + task.name): 55 | task_losses, task_outputs = task.get_prediction_module( 56 | bert_model, features, is_training, percent_done) 57 | losses.append(task_losses) 58 | self.outputs[task.name] = task_outputs 59 | self.loss = tf.reduce_sum( 60 | tf.stack(losses, -1) * 61 | tf.one_hot(features["task_id"], len(config.task_names))) 62 | 63 | 64 | def model_fn_builder(config: configure_finetuning.FinetuningConfig, tasks, 65 | num_train_steps, pretraining_config=None): 66 | """Returns `model_fn` closure for TPUEstimator.""" 67 | 68 | def model_fn(features, labels, mode, params): 69 | """The `model_fn` for TPUEstimator.""" 70 | utils.log("Building model...") 71 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 72 | model = FinetuningModel( 73 | config, tasks, is_training, features, num_train_steps) 74 | 75 | # Load pre-trained weights from checkpoint 76 | init_checkpoint = config.init_checkpoint 77 | if pretraining_config is not None: 78 | init_checkpoint = tf.train.latest_checkpoint(pretraining_config.model_dir) 79 | utils.log("Using checkpoint", init_checkpoint) 80 | tvars = tf.trainable_variables() 81 | scaffold_fn = None 82 | if init_checkpoint: 83 | assignment_map, _ = modeling.get_assignment_map_from_checkpoint( 84 | tvars, init_checkpoint) 85 | # import pdb;pdb.set_trace() 86 | if config.use_tpu: 87 | def tpu_scaffold(): 88 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 89 | return tf.train.Scaffold() 90 | scaffold_fn = tpu_scaffold 91 | else: 92 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 93 | 94 | # Build model for training or prediction 95 | if mode == tf.estimator.ModeKeys.TRAIN: 96 | train_op = optimization.create_optimizer( 97 | model.loss, config.learning_rate, num_train_steps, 98 | 
weight_decay_rate=config.weight_decay_rate, 99 | use_tpu=config.use_tpu, 100 | warmup_proportion=config.warmup_proportion, 101 | layerwise_lr_decay_power=config.layerwise_lr_decay, 102 | n_transformer_layers=model.bert_config.num_hidden_layers 103 | ) 104 | output_spec = tf.estimator.tpu.TPUEstimatorSpec( 105 | mode=mode, 106 | loss=model.loss, 107 | train_op=train_op, 108 | scaffold_fn=scaffold_fn, 109 | training_hooks=[training_utils.ETAHook( 110 | {} if config.use_tpu else dict(loss=model.loss), 111 | num_train_steps, config.iterations_per_loop, config.use_tpu, 10)]) 112 | else: 113 | assert mode == tf.estimator.ModeKeys.PREDICT 114 | output_spec = tf.estimator.tpu.TPUEstimatorSpec( 115 | mode=mode, 116 | predictions=utils.flatten_dict(model.outputs), 117 | scaffold_fn=scaffold_fn) 118 | 119 | utils.log("Building complete") 120 | def count_params(): 121 | n = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 122 | utils.log("Model size: %dK" % (n/1000)) 123 | count_params() 124 | return output_spec 125 | 126 | return model_fn 127 | 128 | 129 | class ModelRunner(object): 130 | """Fine-tunes a model on a supervised task.""" 131 | 132 | def __init__(self, config: configure_finetuning.FinetuningConfig, tasks, 133 | pretraining_config=None): 134 | self._config = config 135 | self._tasks = tasks 136 | self._preprocessor = preprocessing.Preprocessor(config, self._tasks) 137 | 138 | is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2 139 | tpu_cluster_resolver = None 140 | if config.use_tpu and config.tpu_name: 141 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( 142 | config.tpu_name, zone=config.tpu_zone, project=config.gcp_project) 143 | tpu_config = tf.estimator.tpu.TPUConfig( 144 | iterations_per_loop=config.iterations_per_loop, 145 | num_shards=config.num_tpu_cores, 146 | per_host_input_for_training=is_per_host, 147 | tpu_job_name=config.tpu_job_name) 148 | run_config = tf.estimator.tpu.RunConfig( 149 | cluster=tpu_cluster_resolver, 150 | model_dir=config.model_dir, 151 | save_checkpoints_steps=config.save_checkpoints_steps, 152 | save_checkpoints_secs=None, 153 | tpu_config=tpu_config) 154 | 155 | if self._config.do_train: 156 | (self._train_input_fn, 157 | self.train_steps) = self._preprocessor.prepare_train() 158 | else: 159 | self._train_input_fn, self.train_steps = None, 0 160 | model_fn = model_fn_builder( 161 | config=config, 162 | tasks=self._tasks, 163 | num_train_steps=self.train_steps, 164 | pretraining_config=pretraining_config) 165 | self._estimator = tf.estimator.tpu.TPUEstimator( 166 | use_tpu=config.use_tpu, 167 | model_fn=model_fn, 168 | config=run_config, 169 | train_batch_size=config.train_batch_size, 170 | eval_batch_size=config.eval_batch_size, 171 | predict_batch_size=config.predict_batch_size) 172 | 173 | def train(self): 174 | utils.log("Training for {:} steps".format(self.train_steps)) 175 | self._estimator.train( 176 | input_fn=self._train_input_fn, max_steps=self.train_steps) 177 | 178 | def evaluate(self): 179 | return {task.name: self.evaluate_task(task) for task in self._tasks} 180 | 181 | def evaluate_task(self, task, split="dev", return_results=True): 182 | """Evaluate the current model.""" 183 | utils.log("Evaluating", task.name) 184 | eval_input_fn, _ = self._preprocessor.prepare_predict([task], split) 185 | results = self._estimator.predict(input_fn=eval_input_fn, 186 | yield_single_examples=True) 187 | scorer = task.get_scorer() 188 | for r in results: 189 | if r["task_id"] != 
len(self._tasks): # ignore padding examples 190 | r = utils.nest_dict(r, self._config.task_names) 191 | scorer.update(r[task.name]) 192 | if return_results: 193 | utils.log(task.name + ": " + scorer.results_str()) 194 | utils.log() 195 | return dict(scorer.get_results()) 196 | else: 197 | return scorer 198 | 199 | def write_classification_outputs(self, tasks, trial, split): 200 | """Write classification predictions to disk.""" 201 | utils.log("Writing out predictions for", tasks, split) 202 | predict_input_fn, _ = self._preprocessor.prepare_predict(tasks, split) 203 | results = self._estimator.predict(input_fn=predict_input_fn, 204 | yield_single_examples=True) 205 | # task name -> eid -> model-logits 206 | logits = collections.defaultdict(dict) 207 | for r in results: 208 | if r["task_id"] != len(self._tasks): 209 | r = utils.nest_dict(r, self._config.task_names) 210 | task_name = self._config.task_names[r["task_id"]] 211 | logits[task_name][r[task_name]["eid"]] = ( 212 | r[task_name]["logits"] if "logits" in r[task_name] 213 | else r[task_name]["predictions"]) 214 | for task_name in logits: 215 | utils.log("Pickling predictions for {:} {:} examples ({:})".format( 216 | len(logits[task_name]), task_name, split)) 217 | if trial <= self._config.n_writes_test: 218 | utils.write_pickle(logits[task_name], self._config.test_predictions( 219 | task_name, split, trial)) 220 | 221 | 222 | def write_results(config: configure_finetuning.FinetuningConfig, results): 223 | """Write evaluation metrics to disk.""" 224 | utils.log("Writing results to", config.results_txt) 225 | utils.mkdir(config.results_txt.rsplit("/", 1)[0]) 226 | utils.write_pickle(results, config.results_pkl) 227 | with tf.io.gfile.GFile(config.results_txt, "w") as f: 228 | results_str = "" 229 | for trial_results in results: 230 | for task_name, task_results in trial_results.items(): 231 | if task_name == "time" or task_name == "global_step": 232 | continue 233 | results_str += task_name + ": " + " - ".join( 234 | ["{:}: {:.2f}".format(k, v) 235 | for k, v in task_results.items()]) + "\n" 236 | f.write(results_str) 237 | utils.write_pickle(results, config.results_pkl) 238 | 239 | 240 | def run_finetuning(config: configure_finetuning.FinetuningConfig): 241 | """Run finetuning.""" 242 | 243 | # Setup for training 244 | results = [] 245 | trial = 1 246 | heading_info = "model={:}, trial {:}/{:}".format( 247 | config.model_name, trial, config.num_trials) 248 | heading = lambda msg: utils.heading(msg + ": " + heading_info) 249 | heading("Config") 250 | utils.log_config(config) 251 | generic_model_dir = config.model_dir 252 | tasks = task_builder.get_tasks(config) 253 | 254 | # Train and evaluate num_trials models with different random seeds 255 | while config.num_trials < 0 or trial <= config.num_trials: 256 | config.model_dir = generic_model_dir + "_" + str(trial) 257 | if config.do_train: 258 | utils.rmkdir(config.model_dir) 259 | 260 | model_runner = ModelRunner(config, tasks) 261 | if config.do_train: 262 | heading("Start training") 263 | model_runner.train() 264 | utils.log() 265 | 266 | if config.do_eval: 267 | heading("Run dev set evaluation") 268 | results.append(model_runner.evaluate()) 269 | write_results(config, results) 270 | if config.write_test_outputs and trial <= config.n_writes_test: 271 | heading("Running on the test set and writing the predictions") 272 | for task in tasks: 273 | # Currently only writing preds for GLUE and SQuAD 2.0 is supported 274 | if task.name in ["cola", "mrpc", "mnli", "sst", "rte", "qnli", "qqp", 
275 | "sts","wnli"]: 276 | for split in task.get_test_splits(): 277 | model_runner.write_classification_outputs([task], trial, split) 278 | elif task.name == "squad": 279 | scorer = model_runner.evaluate_task(task, "test", False) 280 | scorer.write_predictions() 281 | preds = utils.load_json(config.qa_preds_file("squad")) 282 | null_odds = utils.load_json(config.qa_na_file("squad")) 283 | for q, _ in preds.items(): 284 | if null_odds[q] > config.qa_na_threshold: 285 | preds[q] = "" 286 | utils.write_json(preds, config.test_predictions( 287 | task.name, "test", trial)) 288 | else: 289 | utils.log("Skipping task", task.name, 290 | "- writing predictions is not supported for this task") 291 | 292 | if trial != config.num_trials and (not config.keep_all_models): 293 | utils.rmrf(config.model_dir) 294 | trial += 1 295 | 296 | 297 | def main(): 298 | parser = argparse.ArgumentParser(description=__doc__) 299 | parser.add_argument("--data-dir", required=True, 300 | help="Location of data files (model weights, etc).") 301 | parser.add_argument("--model-name", required=True, 302 | help="The name of the model being fine-tuned.") 303 | parser.add_argument("--hparams", default="{}", 304 | help="JSON dict of model hyperparameters.") 305 | args = parser.parse_args() 306 | if args.hparams.endswith(".json"): 307 | hparams = utils.load_json(args.hparams) 308 | else: 309 | hparams = json.loads(args.hparams) 310 | tf.logging.set_verbosity(tf.logging.ERROR) 311 | run_finetuning(configure_finetuning.FinetuningConfig( 312 | args.model_name, args.data_dir, **hparams)) 313 | 314 | 315 | if __name__ == "__main__": 316 | main() 317 | -------------------------------------------------------------------------------- /run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Pre-trains an ELECTRA/ConvBERT model.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import collections 11 | import json 12 | 13 | import tensorflow.compat.v1 as tf 14 | import numpy as np 15 | import configure_pretraining 16 | from model import modeling 17 | from model import optimization 18 | from pretrain import pretrain_data 19 | from pretrain import pretrain_helpers 20 | from util import training_utils 21 | from util import utils 22 | 23 | 24 | class PretrainingModel(object): 25 | """Transformer pre-training using the replaced-token-detection task.""" 26 | 27 | def __init__(self, config: configure_pretraining.PretrainingConfig, 28 | features, is_training): 29 | # Set up model config 30 | self._config = config 31 | self._bert_config = training_utils.get_bert_config(config) 32 | if config.debug: 33 | self._bert_config.num_hidden_layers = 3 34 | self._bert_config.hidden_size = 144 35 | self._bert_config.intermediate_size = 144 * 4 36 | self._bert_config.num_attention_heads = 4 37 | 38 | # Mask the input 39 | masked_inputs = pretrain_helpers.mask( 40 | config, pretrain_data.features_to_inputs(features), config.mask_prob) 41 | 42 | # Generator 43 | embedding_size = ( 44 | self._bert_config.hidden_size if config.embedding_size is None else 45 | config.embedding_size) 46 | if config.uniform_generator: 47 | mlm_output = self._get_masked_lm_output(masked_inputs, None) 48 | elif config.electra_objective and config.untied_generator: 49 | generator = self._build_transformer( 50 | masked_inputs, is_training, 51 | bert_config=get_generator_config(config, self._bert_config), 52 | 
embedding_size=(None if config.untied_generator_embeddings 53 | else embedding_size), 54 | untied_embeddings=config.untied_generator_embeddings, 55 | name="generator") 56 | mlm_output = self._get_masked_lm_output(masked_inputs, generator) 57 | else: 58 | generator = self._build_transformer( 59 | masked_inputs, is_training, embedding_size=embedding_size) 60 | mlm_output = self._get_masked_lm_output(masked_inputs, generator) 61 | fake_data = self._get_fake_data(masked_inputs, mlm_output.logits) 62 | self.mlm_output = mlm_output 63 | self.total_loss = config.gen_weight * mlm_output.loss 64 | 65 | # Discriminator 66 | disc_output = None 67 | if config.electra_objective: 68 | discriminator = self._build_transformer( 69 | fake_data.inputs, is_training, reuse=not config.untied_generator, 70 | embedding_size=embedding_size) 71 | disc_output = self._get_discriminator_output( 72 | fake_data.inputs, discriminator, fake_data.is_fake_tokens) 73 | self.total_loss += config.disc_weight * disc_output.loss 74 | 75 | # Evaluation 76 | eval_fn_inputs = { 77 | "input_ids": masked_inputs.input_ids, 78 | "masked_lm_preds": mlm_output.preds, 79 | "mlm_loss": mlm_output.per_example_loss, 80 | "masked_lm_ids": masked_inputs.masked_lm_ids, 81 | "masked_lm_weights": masked_inputs.masked_lm_weights, 82 | "input_mask": masked_inputs.input_mask 83 | } 84 | if config.electra_objective: 85 | eval_fn_inputs.update({ 86 | "disc_loss": disc_output.per_example_loss, 87 | "disc_labels": disc_output.labels, 88 | "disc_probs": disc_output.probs, 89 | "disc_preds": disc_output.preds, 90 | "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1, 91 | output_type=tf.int32) 92 | }) 93 | eval_fn_keys = eval_fn_inputs.keys() 94 | eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys] 95 | 96 | def metric_fn(*args): 97 | """Computes the loss and accuracy of the model.""" 98 | d = {k: arg for k, arg in zip(eval_fn_keys, args)} 99 | metrics = dict() 100 | metrics["masked_lm_accuracy"] = tf.metrics.accuracy( 101 | labels=tf.reshape(d["masked_lm_ids"], [-1]), 102 | predictions=tf.reshape(d["masked_lm_preds"], [-1]), 103 | weights=tf.reshape(d["masked_lm_weights"], [-1])) 104 | metrics["masked_lm_loss"] = tf.metrics.mean( 105 | values=tf.reshape(d["mlm_loss"], [-1]), 106 | weights=tf.reshape(d["masked_lm_weights"], [-1])) 107 | if config.electra_objective: 108 | metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy( 109 | labels=tf.reshape(d["masked_lm_ids"], [-1]), 110 | predictions=tf.reshape(d["sampled_tokids"], [-1]), 111 | weights=tf.reshape(d["masked_lm_weights"], [-1])) 112 | if config.disc_weight > 0: 113 | metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"]) 114 | metrics["disc_auc"] = tf.metrics.auc( 115 | d["disc_labels"] * d["input_mask"], 116 | d["disc_probs"] * tf.cast(d["input_mask"], tf.float32)) 117 | metrics["disc_accuracy"] = tf.metrics.accuracy( 118 | labels=d["disc_labels"], predictions=d["disc_preds"], 119 | weights=d["input_mask"]) 120 | metrics["disc_precision"] = tf.metrics.accuracy( 121 | labels=d["disc_labels"], predictions=d["disc_preds"], 122 | weights=d["disc_preds"] * d["input_mask"]) 123 | metrics["disc_recall"] = tf.metrics.accuracy( 124 | labels=d["disc_labels"], predictions=d["disc_preds"], 125 | weights=d["disc_labels"] * d["input_mask"]) 126 | return metrics 127 | self.eval_metrics = (metric_fn, eval_fn_values) 128 | 129 | def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model): 130 | """Masked language modeling softmax layer.""" 131 | masked_lm_weights = 
inputs.masked_lm_weights 132 | with tf.variable_scope("generator_predictions"): 133 | if self._config.uniform_generator: 134 | logits = tf.zeros(self._bert_config.vocab_size) 135 | logits_tiled = tf.zeros( 136 | modeling.get_shape_list(inputs.masked_lm_ids) + 137 | [self._bert_config.vocab_size]) 138 | logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size]) 139 | logits = logits_tiled 140 | else: 141 | relevant_hidden = pretrain_helpers.gather_positions( 142 | model.get_sequence_output(), inputs.masked_lm_positions) 143 | hidden = tf.layers.dense( 144 | relevant_hidden, 145 | units=modeling.get_shape_list(model.get_embedding_table())[-1], 146 | activation=modeling.get_activation(self._bert_config.hidden_act), 147 | kernel_initializer=modeling.create_initializer( 148 | self._bert_config.initializer_range)) 149 | hidden = modeling.layer_norm(hidden) 150 | output_bias = tf.get_variable( 151 | "output_bias", 152 | shape=[self._bert_config.vocab_size], 153 | initializer=tf.zeros_initializer()) 154 | logits = tf.matmul(hidden, model.get_embedding_table(), 155 | transpose_b=True) 156 | logits = tf.nn.bias_add(logits, output_bias) 157 | 158 | oh_labels = tf.one_hot( 159 | inputs.masked_lm_ids, depth=self._bert_config.vocab_size, 160 | dtype=tf.float32) 161 | 162 | probs = tf.nn.softmax(logits) 163 | log_probs = tf.nn.log_softmax(logits) 164 | label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1) 165 | 166 | numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs) 167 | denominator = tf.reduce_sum(masked_lm_weights) + 1e-6 168 | loss = numerator / denominator 169 | preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32) 170 | 171 | MLMOutput = collections.namedtuple( 172 | "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"]) 173 | return MLMOutput( 174 | logits=logits, probs=probs, per_example_loss=label_log_probs, 175 | loss=loss, preds=preds) 176 | 177 | def _get_discriminator_output(self, inputs, discriminator, labels): 178 | """Discriminator binary classifier.""" 179 | with tf.variable_scope("discriminator_predictions"): 180 | hidden = tf.layers.dense( 181 | discriminator.get_sequence_output(), 182 | units=self._bert_config.hidden_size, 183 | activation=modeling.get_activation(self._bert_config.hidden_act), 184 | kernel_initializer=modeling.create_initializer( 185 | self._bert_config.initializer_range)) 186 | logits = tf.squeeze(tf.layers.dense(hidden, units=1), -1) 187 | weights = tf.cast(inputs.input_mask, tf.float32) 188 | labelsf = tf.cast(labels, tf.float32) 189 | losses = tf.nn.sigmoid_cross_entropy_with_logits( 190 | logits=logits, labels=labelsf) * weights 191 | per_example_loss = (tf.reduce_sum(losses, axis=-1) / 192 | (1e-6 + tf.reduce_sum(weights, axis=-1))) 193 | loss = tf.reduce_sum(losses) / (1e-6 + tf.reduce_sum(weights)) 194 | probs = tf.nn.sigmoid(logits) 195 | preds = tf.cast(tf.round((tf.sign(logits) + 1) / 2), tf.int32) 196 | DiscOutput = collections.namedtuple( 197 | "DiscOutput", ["loss", "per_example_loss", "probs", "preds", 198 | "labels"]) 199 | return DiscOutput( 200 | loss=loss, per_example_loss=per_example_loss, probs=probs, 201 | preds=preds, labels=labels, 202 | ) 203 | 204 | def _get_fake_data(self, inputs, mlm_logits): 205 | """Sample from the generator to create corrupted input.""" 206 | inputs = pretrain_helpers.unmask(inputs) 207 | disallow = tf.one_hot( 208 | inputs.masked_lm_ids, depth=self._bert_config.vocab_size, 209 | dtype=tf.float32) if self._config.disallow_correct else None 210 | 
sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax( 211 | mlm_logits / self._config.temperature, disallow=disallow)) 212 | sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32) 213 | updated_input_ids, masked = pretrain_helpers.scatter_update( 214 | inputs.input_ids, sampled_tokids, inputs.masked_lm_positions) 215 | labels = masked * (1 - tf.cast( 216 | tf.equal(updated_input_ids, inputs.input_ids), tf.int32)) 217 | updated_inputs = pretrain_data.get_updated_inputs( 218 | inputs, input_ids=updated_input_ids) 219 | FakedData = collections.namedtuple("FakedData", [ 220 | "inputs", "is_fake_tokens", "sampled_tokens"]) 221 | return FakedData(inputs=updated_inputs, is_fake_tokens=labels, 222 | sampled_tokens=sampled_tokens) 223 | 224 | def _build_transformer(self, inputs: pretrain_data.Inputs, is_training, 225 | bert_config=None, name="electra", reuse=False, **kwargs): 226 | """Build a transformer encoder network.""" 227 | if bert_config is None: 228 | bert_config = self._bert_config 229 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 230 | return modeling.BertModel( 231 | bert_config=bert_config, 232 | is_training=is_training, 233 | input_ids=inputs.input_ids, 234 | input_mask=inputs.input_mask, 235 | token_type_ids=inputs.segment_ids, 236 | use_one_hot_embeddings=self._config.use_tpu, 237 | scope=name, 238 | **kwargs) 239 | 240 | 241 | def get_generator_config(config: configure_pretraining.PretrainingConfig, 242 | bert_config: modeling.BertConfig): 243 | """Get model config for the generator network.""" 244 | gen_config = modeling.BertConfig.from_dict(bert_config.to_dict()) 245 | gen_config.hidden_size = int(round( 246 | bert_config.hidden_size * config.generator_hidden_size)) 247 | gen_config.num_hidden_layers = int(round( 248 | bert_config.num_hidden_layers * config.generator_layers)) 249 | gen_config.intermediate_size = 4 * gen_config.hidden_size 250 | # gen_config.num_attention_heads = max(1, gen_config.hidden_size // 64) 251 | gen_config.num_attention_heads = max(1, gen_config.num_attention_heads * config.generator_hidden_size) 252 | return gen_config 253 | 254 | 255 | def model_fn_builder(config: configure_pretraining.PretrainingConfig): 256 | """Build the model for training.""" 257 | 258 | def model_fn(features, labels, mode, params): 259 | """Build the model for training.""" 260 | model = PretrainingModel(config, features, 261 | mode == tf.estimator.ModeKeys.TRAIN) 262 | utils.log("Model is built!") 263 | def count_params(): 264 | n = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 265 | utils.log("Model size: %dK" % (n/1000)) 266 | count_params() 267 | if mode == tf.estimator.ModeKeys.TRAIN: 268 | train_op = optimization.create_optimizer( 269 | model.total_loss, config.learning_rate, config.num_train_steps, 270 | weight_decay_rate=config.weight_decay_rate, 271 | use_tpu=config.use_tpu, 272 | warmup_steps=config.num_warmup_steps, 273 | lr_decay_power=config.lr_decay_power 274 | ) 275 | output_spec = tf.estimator.tpu.TPUEstimatorSpec( 276 | mode=mode, 277 | loss=model.total_loss, 278 | train_op=train_op, 279 | training_hooks=[training_utils.ETAHook( 280 | {} if config.use_tpu else dict(loss=model.total_loss), 281 | config.num_train_steps, config.iterations_per_loop, 282 | config.use_tpu,100)] 283 | ) 284 | elif mode == tf.estimator.ModeKeys.EVAL: 285 | output_spec = tf.estimator.tpu.TPUEstimatorSpec( 286 | mode=mode, 287 | loss=model.total_loss, 288 | eval_metrics=model.eval_metrics, 289 | 
evaluation_hooks=[training_utils.ETAHook( 290 | {} if config.use_tpu else dict(loss=model.total_loss), 291 | config.num_eval_steps, config.iterations_per_loop, 292 | config.use_tpu, is_training=False)]) 293 | else: 294 | raise ValueError("Only TRAIN and EVAL modes are supported") 295 | return output_spec 296 | 297 | return model_fn 298 | 299 | 300 | def train_or_eval(config: configure_pretraining.PretrainingConfig): 301 | """Run pre-training or evaluate the pre-trained model.""" 302 | if config.do_train == config.do_eval: 303 | raise ValueError("Exactly one of `do_train` or `do_eval` must be True.") 304 | if config.debug: 305 | utils.rmkdir(config.model_dir) 306 | utils.heading("Config:") 307 | utils.log_config(config) 308 | 309 | is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2 310 | tpu_cluster_resolver = None 311 | if config.use_tpu and config.tpu_name: 312 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( 313 | config.tpu_name, zone=config.tpu_zone, project=config.gcp_project) 314 | tpu_config = tf.estimator.tpu.TPUConfig( 315 | iterations_per_loop=config.iterations_per_loop, 316 | num_shards=(config.num_tpu_cores if config.do_train else 317 | config.num_tpu_cores), 318 | # tpu_job_name=("train_tpu_worker" if config.do_train else 319 | # "lm_eval_tpu_worker"), 320 | per_host_input_for_training=is_per_host) 321 | run_config = tf.estimator.tpu.RunConfig( 322 | cluster=tpu_cluster_resolver, 323 | model_dir=config.model_dir, 324 | save_checkpoints_steps=config.save_checkpoints_steps, 325 | tpu_config=tpu_config) 326 | model_fn = model_fn_builder(config=config) 327 | estimator = tf.estimator.tpu.TPUEstimator( 328 | use_tpu=config.use_tpu, 329 | model_fn=model_fn, 330 | config=run_config, 331 | train_batch_size=config.train_batch_size, 332 | eval_batch_size=config.eval_batch_size) 333 | 334 | if config.do_train: 335 | utils.heading("Running training") 336 | estimator.train(input_fn=pretrain_data.get_input_fn(config, True), 337 | max_steps=config.num_train_steps) 338 | if config.do_eval: 339 | utils.heading("Running evaluation") 340 | result = estimator.evaluate( 341 | input_fn=pretrain_data.get_input_fn(config, False), 342 | steps=config.num_eval_steps) 343 | for key in sorted(result.keys()): 344 | utils.log(" {:} = {:}".format(key, str(result[key]))) 345 | return result 346 | 347 | 348 | def train_one_step(config: configure_pretraining.PretrainingConfig): 349 | """Builds an ELECTRA/ConvBERT model an trains it for one step; useful for debugging.""" 350 | train_input_fn = pretrain_data.get_input_fn(config, True) 351 | features = tf.data.make_one_shot_iterator(train_input_fn(dict( 352 | batch_size=config.train_batch_size))).get_next() 353 | model = PretrainingModel(config, features, True) 354 | with tf.Session() as sess: 355 | sess.run(tf.global_variables_initializer()) 356 | utils.log(sess.run(model.total_loss)) 357 | 358 | 359 | def main(): 360 | parser = argparse.ArgumentParser(description=__doc__) 361 | parser.add_argument("--data-dir", required=True, 362 | help="Location of data files (model weights, etc).") 363 | parser.add_argument("--model-name", required=True, 364 | help="The name of the model being fine-tuned.") 365 | parser.add_argument("--hparams", default="{}", 366 | help="JSON dict of model hyperparameters.") 367 | args = parser.parse_args() 368 | if args.hparams.endswith(".json"): 369 | hparams = utils.load_json(args.hparams) 370 | else: 371 | hparams = json.loads(args.hparams) 372 | tf.logging.set_verbosity(tf.logging.ERROR) 373 | 
train_or_eval(configure_pretraining.PretrainingConfig( 374 | args.model_name, args.data_dir, **hparams)) 375 | 376 | 377 | if __name__ == "__main__": 378 | main() 379 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 -------------------------------------------------------------------------------- /util/training_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Utilities for training the models.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import datetime 10 | import re 11 | import time 12 | import tensorflow.compat.v1 as tf 13 | 14 | from model import modeling 15 | from util import utils 16 | 17 | 18 | class ETAHook(tf.estimator.SessionRunHook): 19 | """Print out the time remaining during training/evaluation.""" 20 | 21 | def __init__(self, to_log, n_steps, iterations_per_loop, on_tpu, 22 | log_every=1, is_training=True): 23 | self._to_log = to_log 24 | self._n_steps = n_steps 25 | self._iterations_per_loop = iterations_per_loop 26 | self._on_tpu = on_tpu 27 | self._log_every = log_every 28 | self._is_training = is_training 29 | self._steps_run_so_far = 0 30 | self._global_step = None 31 | self._global_step_tensor = None 32 | self._start_step = None 33 | self._start_time = None 34 | 35 | def begin(self): 36 | self._global_step_tensor = tf.train.get_or_create_global_step() 37 | 38 | def before_run(self, run_context): 39 | if self._start_time is None: 40 | self._start_time = time.time() 41 | return tf.estimator.SessionRunArgs(self._to_log) 42 | 43 | def after_run(self, run_context, run_values): 44 | self._global_step = run_context.session.run(self._global_step_tensor) 45 | self._steps_run_so_far += self._iterations_per_loop if self._on_tpu else 1 46 | if self._start_step is None: 47 | self._start_step = self._global_step - (self._iterations_per_loop 48 | if self._on_tpu else 1) 49 | self.log(run_values) 50 | 51 | def end(self, session): 52 | self._global_step = session.run(self._global_step_tensor) 53 | self.log() 54 | 55 | def log(self, run_values=None): 56 | step = self._global_step if self._is_training else self._steps_run_so_far 57 | if step % self._log_every != 0: 58 | return 59 | msg = "{:}/{:} = {:.1f}%".format(step, self._n_steps, 60 | 100.0 * step / self._n_steps) 61 | time_elapsed = time.time() - self._start_time 62 | time_per_step = time_elapsed / ( 63 | (step - self._start_step) if self._is_training else step) 64 | msg += ", SPS: {:.1f}".format(1 / time_per_step) 65 | msg += ", ELAP: " + secs_to_str(time_elapsed) 66 | msg += ", ETA: " + secs_to_str( 67 | (self._n_steps - step) * time_per_step) 68 | if run_values is not None: 69 | for tag, value in run_values.results.items(): 70 | msg += " - " + str(tag) + (": {:.4f}".format(value)) 71 | utils.log(msg) 72 | 73 | 74 | def secs_to_str(secs): 75 | s = str(datetime.timedelta(seconds=int(round(secs)))) 76 | s = re.sub("^0:", "", s) 77 | s = re.sub("^0", "", s) 78 | s = re.sub("^0:", "", s) 79 | s = re.sub("^0", "", s) 80 | return s 81 | 82 | 83 | def get_bert_config(config): 84 | """Get model hyperparameters based on a pretraining/finetuning config""" 85 | if config.model_size == "base": 86 | args = {"hidden_size": 768, "num_hidden_layers": 12} 87 | elif config.model_size == "medium-small": 88 | args = {"hidden_size": 384, 
"num_hidden_layers": 12} 89 | elif config.model_size == "small": 90 | args = {"hidden_size": 256, "num_hidden_layers": 12} 91 | else: 92 | raise ValueError("Unknown model size", config.model_size) 93 | args["head_ratio"] = config.head_ratio 94 | args["conv_kernel_size"] = config.conv_kernel_size 95 | args["linear_groups"] = config.linear_groups 96 | args["conv_type"] = config.conv_type 97 | args["vocab_size"] = config.vocab_size 98 | args.update(**config.model_hparam_overrides) 99 | # by default the ff size and num attn heads are determined by the hidden size 100 | args["num_attention_heads"] = max(1, args["hidden_size"] // 64) 101 | args["intermediate_size"] = 4 * args["hidden_size"] 102 | if config.model_size in ["medium-small"]: 103 | args["num_attention_heads"] = max(1, args["hidden_size"] // 48) 104 | args.update(**config.model_hparam_overrides) 105 | return modeling.BertConfig.from_dict(args) 106 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """A collection of general utility functions.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import json 10 | import pickle 11 | import sys 12 | 13 | import tensorflow.compat.v1 as tf 14 | 15 | 16 | def load_json(path): 17 | with tf.io.gfile.GFile(path, "r") as f: 18 | return json.load(f) 19 | 20 | 21 | def write_json(o, path): 22 | if "/" in path: 23 | tf.io.gfile.makedirs(path.rsplit("/", 1)[0]) 24 | with tf.io.gfile.GFile(path, "w") as f: 25 | json.dump(o, f) 26 | 27 | 28 | def load_pickle(path): 29 | with tf.io.gfile.GFile(path, "rb") as f: 30 | return pickle.load(f) 31 | 32 | 33 | def write_pickle(o, path): 34 | if "/" in path: 35 | tf.io.gfile.makedirs(path.rsplit("/", 1)[0]) 36 | with tf.io.gfile.GFile(path, "wb") as f: 37 | pickle.dump(o, f, -1) 38 | 39 | 40 | def mkdir(path): 41 | if not tf.io.gfile.exists(path): 42 | tf.io.gfile.makedirs(path) 43 | 44 | 45 | def rmrf(path): 46 | if tf.io.gfile.exists(path): 47 | tf.io.gfile.rmtree(path) 48 | 49 | 50 | def rmkdir(path): 51 | rmrf(path) 52 | mkdir(path) 53 | 54 | 55 | def log(*args): 56 | msg = " ".join(map(str, args)) 57 | sys.stdout.write(msg + "\n") 58 | sys.stdout.flush() 59 | 60 | 61 | def log_config(config): 62 | for key, value in sorted(config.__dict__.items()): 63 | log(key, value) 64 | log() 65 | 66 | 67 | def heading(*args): 68 | log(80 * "=") 69 | log(*args) 70 | log(80 * "=") 71 | 72 | 73 | def nest_dict(d, prefixes, delim="_"): 74 | """Go from {prefix_key: value} to {prefix: {key: value}}.""" 75 | nested = {} 76 | for k, v in d.items(): 77 | for prefix in prefixes: 78 | if k.startswith(prefix + delim): 79 | if prefix not in nested: 80 | nested[prefix] = {} 81 | nested[prefix][k.split(delim, 1)[1]] = v 82 | else: 83 | nested[k] = v 84 | return nested 85 | 86 | 87 | def flatten_dict(d, delim="_"): 88 | """Go from {prefix: {key: value}} to {prefix_key: value}.""" 89 | flattened = {} 90 | for k, v in d.items(): 91 | if isinstance(v, dict): 92 | for k2, v2 in v.items(): 93 | flattened[k + delim + k2] = v2 94 | else: 95 | flattened[k] = v 96 | return flattened 97 | --------------------------------------------------------------------------------