├── .gitignore ├── util_squad_eval.py ├── requirements.txt ├── README.md ├── model.py ├── util_mrqa_official_eval.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # setup 2 | .vscode 3 | 4 | # slurm output 5 | .output 6 | 7 | # data 8 | data 9 | 10 | # wandb logging 11 | wandb 12 | 13 | # misc 14 | __pycache__ 15 | .DS_Store -------------------------------------------------------------------------------- /util_squad_eval.py: -------------------------------------------------------------------------------- 1 | import re 2 | import collections 3 | import string 4 | 5 | 6 | def normalize_answer(s): 7 | def remove_articles(text): 8 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 9 | return re.sub(regex, ' ', text) 10 | 11 | def white_space_fix(text): 12 | return ' '.join(text.split()) 13 | 14 | def remove_punc(text): 15 | exclude = set(string.punctuation) 16 | return ''.join(ch for ch in text if ch not in exclude) 17 | 18 | def lower(text): 19 | return text.lower() 20 | 21 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 22 | 23 | 24 | def get_tokens(s): 25 | if not s: 26 | return [] 27 | return normalize_answer(s).split() 28 | 29 | 30 | def compute_f1(a_gold, a_pred): 31 | gold_toks = get_tokens(a_gold) 32 | pred_toks = get_tokens(a_pred) 33 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 34 | num_same = sum(common.values()) 35 | if len(gold_toks) == 0 or len(pred_toks) == 0: 36 | return int(gold_toks == pred_toks) 37 | if num_same == 0: 38 | return 0 39 | precision = 1.0 * num_same / len(pred_toks) 40 | recall = 1.0 * num_same / len(gold_toks) 41 | f1 = (2 * precision * recall) / (precision + recall) 42 | return f1 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 2 | bert-score==0.3.9 3 | boto3==1.17.104 4 | botocore==1.20.104 5 | certifi==2021.10.8 6 | chardet==4.0.0 7 | click==8.0.1 8 | configparser==5.0.2 9 | cycler==0.10.0 10 | datasets==1.9.0 11 | decorator @ file:///tmp/build/80754af9/decorator_1621259047763/work 12 | dill==0.3.4 13 | docker-pycreds==0.4.0 14 | filelock==3.0.12 15 | flake8 @ file:///tmp/build/80754af9/flake8_1620776156532/work 16 | fsspec==2021.7.0 17 | gitdb==4.0.7 18 | GitPython==3.1.17 19 | huggingface-hub==0.0.8 20 | idna==2.10 21 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877307939/work 22 | ipython @ file:///tmp/build/80754af9/ipython_1617120888991/work 23 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 24 | jedi @ file:///tmp/build/80754af9/jedi_1606932531272/work 25 | jmespath==0.10.0 26 | joblib==1.0.1 27 | kiwisolver==1.3.1 28 | matplotlib==3.4.2 29 | mccabe==0.6.1 30 | mkl-fft==1.3.0 31 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853971975/work 32 | mkl-service==2.3.0 33 | multiprocess==0.70.12.2 34 | nltk==3.6.2 35 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831195661/work 36 | packaging==20.9 37 | pandas==1.3.0 38 | parso==0.7.0 39 | pathtools==0.1.2 40 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 41 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 42 | Pillow==8.2.0 43 | portalocker==2.0.0 44 | promise==2.3 45 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work 46 | 
protobuf==3.17.3 47 | psutil==5.8.0 48 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 49 | pyarrow==4.0.1 50 | pycodestyle @ file:///tmp/build/80754af9/pycodestyle_1615748559966/work 51 | pyflakes @ file:///tmp/build/80754af9/pyflakes_1617200973297/work 52 | Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work 53 | pyparsing==2.4.7 54 | python-dateutil==2.8.1 55 | pytorch-pretrained-bert==0.6.2 56 | pytz==2021.1 57 | PyYAML==5.4.1 58 | regex==2021.4.4 59 | requests==2.25.1 60 | s3transfer==0.4.2 61 | sacrebleu==1.5.1 62 | sacremoses==0.0.45 63 | scipy @ file:///tmp/build/80754af9/scipy_1630606796912/work 64 | sentry-sdk==1.1.0 65 | shortuuid==1.0.1 66 | six @ file:///tmp/build/80754af9/six_1605205306277/work 67 | smmap==4.0.0 68 | subprocess32==3.5.4 69 | tokenizers==0.10.3 70 | torch==1.8.1 71 | torch-dct==0.1.5 72 | torchaudio==0.8.1 73 | torchvision==0.9.1 74 | tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work 75 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work 76 | transformers==4.6.1 77 | typing-extensions==3.10.0.0 78 | urllib3==1.26.5 79 | wandb==0.11.0 80 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 81 | xxhash==2.0.2 82 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work 83 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bandit-qa 2 | Code for [_Simulating Bandit Learning from User Feedback for Extractive Question Answering_](https://arxiv.org/pdf/2203.10079.pdf). 3 | 4 | Please contact the first author if you have any questions. 5 | 6 | ## Table of Contents 7 | - [Basics](#basics) 8 | - [Data](#data) 9 | - [Installation](#installation) 10 | - [Instruction](#instruction) 11 | - [Citation](#citation) 12 | 13 | ## Basics 14 | Brief intro for each file: 15 | - train.py: training script 16 | - model.py: model implementation 17 | - util*.py: code for evaluation 18 | 19 | 20 | ## Data 21 | You can download the MRQA datasets from the [MRQA official repo](https://github.com/mrqa/MRQA-Shared-Task-2019#training-data): training data and in-domain development data. 22 | 23 | You can download small sets of supervised examples for initial training from the [Splinter repo](https://github.com/oriram/splinter): follow the instructions under "Downloading Few-Shot MRQA Splits". 24 | 25 | We suggest creating a _data_ folder and saving all data files there. 26 | 27 | 28 | ## Installation 29 | 1. This project is developed in Python 3.9.5. Using Conda to set up a virtual environment is recommended. 30 | 31 | 2. Install the required dependencies. 32 | ``` 33 | pip install -r requirements.txt 34 | ``` 35 | 3. Install PyTorch from http://pytorch.org/. 36 | 37 | 38 | ## Instruction 39 | You can run the following command to start an online simulation experiment with wandb logging: 40 | 41 | ``` 42 | python train.py --notes 'your own notes for this experiment if needed' --wandb --do_train --do_eval --model SpanBERT/spanbert-base-cased --seed 46 --train_file ??? --dev_file ??? --max_seq_length 512 --doc_stride 128 --eval_metric f1 --num_train_epochs 1 --eval_per_epoch 8 --output_dir .simulation --initialize_model_from_checkpoint ??? --train_batch_size 80 --eval_batch_size 20 --gradient_accumulation_steps 4 --scheduler constant --algo 'R' --turn_off_dropout --argmax_simulation 43 | ``` 44 | 45 | 46 | Each ??? is a placeholder for the path to the file required by that argument. Please read the argparse code at the bottom of train.py to see which other arguments you can configure. A minimal sketch of what one simulation step does is shown below. 47 | 
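The command above drives an online bandit simulation: for each training question the model predicts an answer span, receives simulated binary feedback on that span, and is updated with a REINFORCE-style policy-gradient loss (`--algo 'R'`). The snippet below is a minimal, self-contained sketch of one such step. It mirrors `get_batch_rewards()` and the loss computation in train.py, but the function name `simulate_step` and the concrete numbers (batch size, sequence length, the negative reward value) are illustrative only, not part of this repo.

```
import torch


def simulate_step(start_probs, end_probs, gold_start, gold_end,
                  argmax_simulation=True, negative_reward=-0.1):
    """One simulated round of user feedback for a batch of questions."""
    bs = start_probs.shape[0]
    if argmax_simulation:
        # deterministic policy: predict the most likely start/end positions
        start_samples = torch.argmax(start_probs, dim=1)
        end_samples = torch.argmax(end_probs, dim=1)
    else:
        # stochastic policy: sample a span from the model's distributions
        start_samples = torch.multinomial(start_probs, 1).view(-1)
        end_samples = torch.multinomial(end_probs, 1).view(-1)
    log_prob = (start_probs[torch.arange(bs), start_samples].log()
                + end_probs[torch.arange(bs), end_samples].log())
    # simulated binary feedback: +1 if the predicted span is exactly right,
    # a small negative reward otherwise
    rewards = torch.full((bs,), negative_reward)
    rewards[(start_samples == gold_start) & (end_samples == gold_end)] = 1.0
    # REINFORCE-style policy-gradient loss (the 'R' algorithm)
    loss = (-log_prob * rewards).mean() / 2
    return loss, rewards


# toy usage: a batch of 2 sequences of length 5 with random span distributions
start_probs = torch.softmax(torch.randn(2, 5), dim=-1)
end_probs = torch.softmax(torch.randn(2, 5), dim=-1)
loss, rewards = simulate_step(start_probs, end_probs,
                              gold_start=torch.tensor([1, 2]),
                              gold_end=torch.tensor([2, 4]))
print(loss.item(), rewards.tolist())
```

With `--algo 'Rwb'` or `'Rwmb'`, train.py additionally subtracts a baseline (a constant value or the batch-mean reward) from the rewards before computing this loss, and `--flip_prob` can randomly flip the simulated rewards to model noisy users.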
48 | To obtain a model initially trained on some supervised data, you can either 1) use the training scripts from the [SpanBERT repo](https://github.com/facebookresearch/SpanBERT) on SQuAD and MRQA datasets, or 2) configure train.py in this repo. 49 | 50 | ## Citation 51 | ``` 52 | @inproceedings{Gao2022:banditqa-simulation, 53 | title = {Simulating Bandit Learning from User Feedback for Extractive Question Answering}, 54 | author = {Gao, Ge and 55 | Choi, Eunsol and 56 | Artzi, Yoav}, 57 | booktitle = {Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics}, 58 | month = may, 59 | year = {2022}, 60 | address = {Dublin, Ireland}, 61 | publisher = {Association for Computational Linguistics}, 62 | url = {https://aclanthology.org/2022.acl-long.355}, 63 | pages = {5167--5179} 64 | } 65 | ``` 66 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import CrossEntropyLoss 3 | from transformers import BertModel 4 | import torch 5 | 6 | 7 | class BertForQuestionAnswering(nn.Module): 8 | def __init__(self, model_type: str): 9 | super(BertForQuestionAnswering, self).__init__() 10 | if 'bert-' in model_type: 11 | self.bert = BertModel.from_pretrained(model_type) 12 | else: 13 | raise ValueError('Unsupported model type: {}'.format(model_type)) 14 | 15 | self.qa_outputs = nn.Linear(self.bert.config.hidden_size, 2) # [N, L, H] => [N, L, 2] 16 | 17 | def forward(self, batch, return_prob=False, **kwargs): 18 | ''' 19 | each batch is a list of 5 items (training) or 3 items (inference) 20 | - input_ids: token ids of the input sequence 21 | - attention_mask: mask of the sequence (1 for present, 0 for blank) 22 | - token_type_ids: indicator of the type of sequence. 23 | - e.g. 
in QA, whether it is question or document 24 | - (training) start_positions: list of start positions of the span 25 | - (training) end_positions: list of end positions of the span 26 | ''' 27 | 28 | input_ids, attention_masks, token_type_ids = batch[:3] 29 | # pooler_output, last_hidden_state 30 | if 'distil' in self.bert.config._name_or_path: 31 | output = self.bert( 32 | input_ids=input_ids, 33 | # NOTE token_types_ids is not an argument for distilbert 34 | # token_type_ids=token_type_ids, 35 | attention_mask=attention_masks) 36 | else: 37 | output = self.bert(input_ids=input_ids, 38 | token_type_ids=token_type_ids, 39 | attention_mask=attention_masks) 40 | sequence_output = output.last_hidden_state 41 | logits = self.qa_outputs(sequence_output) # (bs, max_input_len, 2) 42 | start_logits, end_logits = logits.split(1, dim=-1) 43 | start_logits = start_logits.squeeze(-1) # (bs, max_input_len) 44 | end_logits = end_logits.squeeze(-1) # (bs, max_input_len) 45 | 46 | if len(batch) == 5: 47 | start_positions, end_positions = batch[3:] 48 | ignored_index = start_logits.size(1) 49 | start_positions.clamp_(0, ignored_index) 50 | end_positions.clamp_(0, ignored_index) 51 | 52 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 53 | start_loss = loss_fct(start_logits, start_positions) 54 | end_loss = loss_fct(end_logits, end_positions) 55 | total_loss = (start_loss + end_loss) / 2 56 | return total_loss, None 57 | 58 | elif len(batch) == 3: 59 | if not return_prob: 60 | return start_logits, end_logits 61 | else: 62 | return torch.softmax(start_logits, dim=-1), torch.softmax(end_logits, dim=-1) 63 | 64 | else: 65 | raise NotImplementedError() 66 | -------------------------------------------------------------------------------- /util_mrqa_official_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, Facebook, Inc. and its affiliates. All Rights Reserved 2 | """Official evaluation script for the MRQA Workshop Shared Task. 3 | Adapted fromt the SQuAD v1.1 official evaluation script. 4 | Usage: 5 | python official_eval.py dataset_file.jsonl.gz prediction_file.json 6 | """ 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from pathlib import Path 11 | from urllib.parse import urlparse 12 | import argparse 13 | import string 14 | import re 15 | import json 16 | import gzip 17 | import os 18 | from collections import Counter 19 | 20 | 21 | def cached_path(url_or_filename, cache_dir=None): 22 | """ 23 | Given something that might be a URL (or might be a local path), 24 | determine which. If it's a URL, download the file and cache it, and 25 | return the path to the cached file. If it's already a local path, 26 | make sure the file exists and then return the path. 27 | """ 28 | if cache_dir is None: 29 | cache_dir = os.path.dirname(url_or_filename) 30 | if isinstance(url_or_filename, Path): 31 | url_or_filename = str(url_or_filename) 32 | 33 | url_or_filename = os.path.expanduser(url_or_filename) 34 | parsed = urlparse(url_or_filename) 35 | 36 | if parsed.scheme in ('http', 'https', 's3'): 37 | # URL, so get it from the cache (downloading if necessary) 38 | return get_from_cache(url_or_filename, cache_dir) 39 | elif os.path.exists(url_or_filename): 40 | # File, and it exists. 41 | return url_or_filename 42 | elif parsed.scheme == '': 43 | # File, but it doesn't exist. 
44 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 45 | else: 46 | # Something unknown 47 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 48 | 49 | 50 | def normalize_answer(s): 51 | """Lower text and remove punctuation, articles and extra whitespace.""" 52 | def remove_articles(text): 53 | return re.sub(r'\b(a|an|the)\b', ' ', text) 54 | 55 | def white_space_fix(text): 56 | return ' '.join(text.split()) 57 | 58 | def remove_punc(text): 59 | exclude = set(string.punctuation) 60 | return ''.join(ch for ch in text if ch not in exclude) 61 | 62 | def lower(text): 63 | return text.lower() 64 | 65 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 66 | 67 | 68 | def f1_score(prediction, ground_truth): 69 | prediction_tokens = normalize_answer(prediction).split() 70 | ground_truth_tokens = normalize_answer(ground_truth).split() 71 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 72 | num_same = sum(common.values()) 73 | if num_same == 0: 74 | return 0 75 | precision = 1.0 * num_same / len(prediction_tokens) 76 | recall = 1.0 * num_same / len(ground_truth_tokens) 77 | f1 = (2 * precision * recall) / (precision + recall) 78 | return f1 79 | 80 | 81 | def exact_match_score(prediction, ground_truth): 82 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 83 | 84 | 85 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 86 | scores_for_ground_truths = [] 87 | for ground_truth in ground_truths: 88 | score = metric_fn(prediction, ground_truth) 89 | scores_for_ground_truths.append(score) 90 | return max(scores_for_ground_truths) 91 | 92 | 93 | def read_predictions(prediction_file): 94 | with open(prediction_file) as f: 95 | predictions = json.load(f) 96 | return predictions 97 | 98 | 99 | def read_answers(gold_file): 100 | answers = {} 101 | with gzip.open(gold_file, 'rb') as f: 102 | for i, line in enumerate(f): 103 | example = json.loads(line) 104 | if i == 0 and 'header' in example: 105 | continue 106 | for qa in example['qas']: 107 | answers[qa['qid']] = qa['answers'] 108 | return answers 109 | 110 | 111 | def evaluate(answers, predictions, skip_no_answer=False): 112 | f1 = exact_match = total = 0 113 | for qid, ground_truths in answers.items(): 114 | if qid not in predictions: 115 | if not skip_no_answer: 116 | message = 'Unanswered question %s will receive score 0.' 
% qid 117 | print(message) 118 | total += 1 119 | continue 120 | total += 1 121 | prediction = predictions[qid] 122 | exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) 123 | f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) 124 | 125 | exact_match = 100.0 * exact_match / total 126 | f1 = 100.0 * f1 / total 127 | 128 | return {'exact_match': exact_match, 'f1': f1} 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser(description='Evaluation for MRQA Workshop Shared Task') 133 | parser.add_argument('dataset_file', type=str, help='Dataset File') 134 | parser.add_argument('prediction_file', type=str, help='Prediction File') 135 | parser.add_argument('--skip-no-answer', action='store_true') 136 | args = parser.parse_args() 137 | 138 | answers = read_answers(cached_path(args.dataset_file)) 139 | predictions = read_predictions(cached_path(args.prediction_file)) 140 | metrics = evaluate(answers, predictions, args.skip_no_answer) 141 | 142 | print(json.dumps(metrics)) 143 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Run BERT on MRQA. 2 | 3 | https://note.nkmk.me/en/python-break-nested-loops/ 4 | 5 | Script adapted from the span bert repo (Copyright (c) 2019, Facebook, Inc. and its affiliates. All Rights Reserved) 6 | """ 7 | 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import wandb 11 | import argparse 12 | import collections 13 | import json 14 | import logging 15 | import math 16 | import os 17 | import random 18 | import time 19 | import gzip 20 | import datetime 21 | from io import open 22 | 23 | import numpy as np 24 | import torch 25 | from torch.utils.data import DataLoader, TensorDataset 26 | 27 | from transformers import BertTokenizer 28 | from transformers import AdamW 29 | from model import BertForQuestionAnswering 30 | from transformers import get_scheduler 31 | 32 | from pytorch_pretrained_bert.tokenization import BasicTokenizer 33 | from util_mrqa_official_eval import exact_match_score, f1_score, metric_max_over_ground_truths 34 | 35 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 36 | datefmt='%m/%d/%Y %H:%M:%S', 37 | level=logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | PRED_FILE = "predictions.json" 41 | EVAL_FILE = "eval_results.txt" 42 | TEST_FILE = "test_results.txt" 43 | 44 | 45 | class MRQAExample(object): 46 | """ 47 | A single training/test example for the MRQA dataset. 48 | For examples without an answer, the start and end position are -1. 
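    Attributes mirror the constructor arguments: qas_id, question_text, doc_tokens
    (whitespace-delimited context tokens), orig_answer_text, and start_position /
    end_position (word-level indices into doc_tokens).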
49 | """ 50 | def __init__(self, 51 | qas_id, 52 | question_text, 53 | doc_tokens, 54 | orig_answer_text=None, 55 | start_position=None, 56 | end_position=None): 57 | self.qas_id = qas_id 58 | self.question_text = question_text 59 | self.doc_tokens = doc_tokens 60 | self.orig_answer_text = orig_answer_text 61 | self.start_position = start_position 62 | self.end_position = end_position 63 | 64 | def __str__(self): 65 | return self.__repr__() 66 | 67 | def __repr__(self): 68 | s = "" 69 | s += "qas_id: %s" % (self.qas_id) 70 | s += ", question_text: %s" % (self.question_text) 71 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 72 | if self.start_position: 73 | s += ", start_position: %d" % (self.start_position) 74 | if self.end_position: 75 | s += ", end_position: %d" % (self.end_position) 76 | return s 77 | 78 | 79 | class InputFeatures(object): 80 | """A single set of features of data.""" 81 | def __init__(self, 82 | unique_id, 83 | example_index, 84 | doc_span_index, 85 | tokens, 86 | token_to_orig_map, 87 | token_is_max_context, 88 | input_ids, 89 | input_mask, 90 | segment_ids, 91 | start_position=None, 92 | end_position=None): 93 | self.unique_id = unique_id 94 | self.example_index = example_index 95 | self.doc_span_index = doc_span_index 96 | self.tokens = tokens 97 | self.token_to_orig_map = token_to_orig_map 98 | self.token_is_max_context = token_is_max_context 99 | self.input_ids = input_ids 100 | self.input_mask = input_mask 101 | self.segment_ids = segment_ids 102 | self.start_position = start_position 103 | self.end_position = end_position 104 | 105 | 106 | # new function to deal with .gz and .jsonl file 107 | def get_data(input_file): 108 | if input_file.endswith('.gz'): 109 | with gzip.GzipFile(input_file, 'r') as reader: 110 | # skip header 111 | content = reader.read().decode('utf-8').strip().split('\n')[1:] 112 | input_data = [json.loads(line) for line in content] 113 | else: 114 | with open(input_file, 'r', encoding="utf-8") as reader: 115 | # lines = reader.readlines() 116 | # input_data = [json.loads(line) for line in lines] 117 | print(reader.readline()) 118 | input_data = [json.loads(line) for line in reader] 119 | return input_data 120 | 121 | 122 | def read_mrqa_examples(input_file, is_training, ignore=0, percentage=1): 123 | """Read a MRQA json file into a list of MRQAExample.""" 124 | input_data = get_data(input_file) 125 | 126 | def is_whitespace(c): 127 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 128 | return True 129 | return False 130 | 131 | examples = [] 132 | num_answers = 0 133 | num_to_ignore = int(ignore * len(input_data)) 134 | num_to_load = int(percentage * len(input_data)) 135 | if ignore != 0 and percentage != 1 and ignore + percentage == 1: 136 | num_to_load = max(num_to_load, len(input_data) - num_to_ignore) 137 | logger.info('Notes: # documents loaded = {}'.format(num_to_load - num_to_ignore)) 138 | for entry in input_data[num_to_ignore:(num_to_ignore + num_to_load)]: 139 | paragraph_text = entry["context"] 140 | doc_tokens = [] 141 | char_to_word_offset = [] 142 | prev_is_whitespace = True 143 | for c in paragraph_text: 144 | if is_whitespace(c): 145 | prev_is_whitespace = True 146 | else: 147 | if prev_is_whitespace: 148 | doc_tokens.append(c) 149 | else: 150 | doc_tokens[-1] += c 151 | prev_is_whitespace = False 152 | char_to_word_offset.append(len(doc_tokens) - 1) 153 | for qa in entry["qas"]: 154 | qas_id = qa["qid"] 155 | question_text = qa["question"] 156 | start_position = None 157 | end_position = 
None 158 | orig_answer_text = None 159 | if is_training: 160 | answers = qa["detected_answers"] 161 | spans = sorted([span for spans in answers for span in spans['char_spans']]) 162 | # take first span 163 | char_start, char_end = spans[0][0], spans[0][1] 164 | orig_answer_text = paragraph_text[char_start:char_end + 1] 165 | start_position, end_position = char_to_word_offset[char_start], char_to_word_offset[ 166 | char_end] 167 | num_answers += sum([len(spans['char_spans']) for spans in answers]) 168 | example = MRQAExample(qas_id=qas_id, 169 | question_text=question_text, 170 | doc_tokens=doc_tokens, 171 | orig_answer_text=orig_answer_text, 172 | start_position=start_position, 173 | end_position=end_position) 174 | examples.append(example) 175 | logger.info('Num avg answers: {}'.format(num_answers / len(examples))) 176 | return examples 177 | 178 | 179 | def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, 180 | is_training): 181 | """Loads a data file into a list of `InputBatch`s.""" 182 | 183 | unique_id = 1000000000 184 | 185 | features = [] 186 | for (example_index, example) in enumerate(examples): 187 | query_tokens = tokenizer.tokenize(example.question_text) 188 | 189 | if len(query_tokens) > max_query_length: 190 | query_tokens = query_tokens[0:max_query_length] 191 | 192 | tok_to_orig_index = [] 193 | orig_to_tok_index = [] 194 | all_doc_tokens = [] 195 | for (i, token) in enumerate(example.doc_tokens): 196 | orig_to_tok_index.append(len(all_doc_tokens)) 197 | sub_tokens = tokenizer.tokenize(token) 198 | for sub_token in sub_tokens: 199 | tok_to_orig_index.append(i) 200 | all_doc_tokens.append(sub_token) 201 | 202 | tok_start_position = None 203 | tok_end_position = None 204 | if is_training: 205 | tok_start_position = -1 206 | tok_end_position = -1 207 | if is_training: 208 | tok_start_position = orig_to_tok_index[example.start_position] 209 | if example.end_position < len(example.doc_tokens) - 1: 210 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 211 | else: 212 | tok_end_position = len(all_doc_tokens) - 1 213 | (tok_start_position, 214 | tok_end_position) = _improve_answer_span(all_doc_tokens, tok_start_position, 215 | tok_end_position, tokenizer, 216 | example.orig_answer_text) 217 | 218 | # The -3 accounts for [CLS], [SEP] and [SEP] 219 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 220 | # We can have documents that are longer than the maximum sequence length. 221 | # To deal with this we do a sliding window approach, where we take chunks 222 | # of the up to our max length with a stride of `doc_stride`. 
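        # Illustrative example (hypothetical numbers, not tied to any config in this
        # repo): with len(all_doc_tokens) == 1000, max_tokens_for_doc == 381 and
        # doc_stride == 128, the loop below yields spans starting at 0, 128, 256, 384
        # and 512 (each 381 tokens long) plus a final span at 640 covering the
        # remaining 360 tokens, then stops.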
223 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 224 | "DocSpan", ["start", "length"]) 225 | doc_spans = [] 226 | start_offset = 0 227 | while start_offset < len(all_doc_tokens): 228 | length = len(all_doc_tokens) - start_offset 229 | if length > max_tokens_for_doc: 230 | length = max_tokens_for_doc 231 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 232 | if start_offset + length == len(all_doc_tokens): 233 | break 234 | start_offset += min(length, doc_stride) 235 | 236 | for (doc_span_index, doc_span) in enumerate(doc_spans): 237 | tokens = [] 238 | token_to_orig_map = {} 239 | token_is_max_context = {} 240 | segment_ids = [] 241 | tokens.append("[CLS]") 242 | segment_ids.append(0) 243 | for token in query_tokens: 244 | tokens.append(token) 245 | segment_ids.append(0) 246 | tokens.append("[SEP]") 247 | segment_ids.append(0) 248 | 249 | for i in range(doc_span.length): 250 | split_token_index = doc_span.start + i 251 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 252 | 253 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index) 254 | token_is_max_context[len(tokens)] = is_max_context 255 | tokens.append(all_doc_tokens[split_token_index]) 256 | segment_ids.append(1) 257 | tokens.append("[SEP]") 258 | segment_ids.append(1) 259 | 260 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 261 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 262 | # tokens are attended to. 263 | input_mask = [1] * len(input_ids) 264 | # Zero-pad up to the sequence length. 265 | while len(input_ids) < max_seq_length: 266 | input_ids.append(0) 267 | input_mask.append(0) 268 | segment_ids.append(0) 269 | 270 | assert len(input_ids) == max_seq_length 271 | assert len(input_mask) == max_seq_length 272 | assert len(segment_ids) == max_seq_length 273 | 274 | start_position = None 275 | end_position = None 276 | if is_training: 277 | # For training, if our document chunk does not contain an annotation 278 | # we throw it out, since there is nothing to predict. 
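                # Note: the chunk is not literally thrown out below; when the gold span
                # falls outside this chunk, start/end are set to 0 (the [CLS] token).
                # Otherwise the token positions are shifted by doc_offset =
                # len(query_tokens) + 2, accounting for [CLS], the query tokens and the
                # first [SEP]; e.g. (hypothetical numbers) tok_start_position 40 in a
                # chunk with doc_start 30 and a 10-token query maps to 40 - 30 + 12 = 22.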
279 | doc_start = doc_span.start 280 | doc_end = doc_span.start + doc_span.length - 1 281 | out_of_span = False 282 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end): 283 | out_of_span = True 284 | if out_of_span: 285 | start_position = 0 286 | end_position = 0 287 | else: 288 | doc_offset = len(query_tokens) + 2 289 | start_position = tok_start_position - doc_start + doc_offset 290 | end_position = tok_end_position - doc_start + doc_offset 291 | if example_index < 0: 292 | logger.info("*** Example ***") 293 | logger.info("unique_id: %s" % (unique_id)) 294 | logger.info("example_index: %s" % (example_index)) 295 | logger.info("doc_span_index: %s" % (doc_span_index)) 296 | logger.info("tokens: %s" % " ".join(tokens)) 297 | logger.info("token_to_orig_map: %s" % 298 | " ".join(["%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])) 299 | logger.info("token_is_max_context: %s" % 300 | " ".join(["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()])) 301 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 302 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 303 | logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 304 | if is_training: 305 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 306 | logger.info("start_position: %d" % (start_position)) 307 | logger.info("end_position: %d" % (end_position)) 308 | logger.info("answer: %s" % (answer_text)) 309 | 310 | features.append( 311 | InputFeatures(unique_id=unique_id, 312 | example_index=example_index, 313 | doc_span_index=doc_span_index, 314 | tokens=tokens, 315 | token_to_orig_map=token_to_orig_map, 316 | token_is_max_context=token_is_max_context, 317 | input_ids=input_ids, 318 | input_mask=input_mask, 319 | segment_ids=segment_ids, 320 | start_position=start_position, 321 | end_position=end_position)) 322 | unique_id += 1 323 | 324 | return features 325 | 326 | 327 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): 328 | """Returns tokenized answer spans that better match the annotated answer.""" 329 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 330 | 331 | for new_start in range(input_start, input_end + 1): 332 | for new_end in range(input_end, new_start - 1, -1): 333 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 334 | if text_span == tok_answer_text: 335 | return (new_start, new_end) 336 | 337 | return (input_start, input_end) 338 | 339 | 340 | def _check_is_max_context(doc_spans, cur_span_index, position): 341 | """Check if this is the 'max context' doc span for the token.""" 342 | best_score = None 343 | best_span_index = None 344 | for (span_index, doc_span) in enumerate(doc_spans): 345 | end = doc_span.start + doc_span.length - 1 346 | if position < doc_span.start: 347 | continue 348 | if position > end: 349 | continue 350 | num_left_context = position - doc_span.start 351 | num_right_context = end - position 352 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 353 | if best_score is None or score > best_score: 354 | best_score = score 355 | best_span_index = span_index 356 | 357 | return cur_span_index == best_span_index 358 | 359 | 360 | RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) 361 | 362 | 363 | def make_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, 364 | do_lower_case, verbose_logging): 365 | example_index_to_features = 
collections.defaultdict(list) 366 | for feature in all_features: 367 | example_index_to_features[feature.example_index].append(feature) 368 | unique_id_to_result = {} 369 | for result in all_results: 370 | unique_id_to_result[result.unique_id] = result 371 | _PrelimPrediction = collections.namedtuple( 372 | "PrelimPrediction", 373 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 374 | all_predictions = collections.OrderedDict() 375 | all_nbest_json = collections.OrderedDict() 376 | for (example_index, example) in enumerate(all_examples): 377 | features = example_index_to_features[example_index] 378 | prelim_predictions = [] 379 | for (feature_index, feature) in enumerate(features): 380 | result = unique_id_to_result[feature.unique_id] 381 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 382 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 383 | for start_index in start_indexes: 384 | for end_index in end_indexes: 385 | if start_index >= len(feature.tokens): 386 | continue 387 | if end_index >= len(feature.tokens): 388 | continue 389 | if start_index not in feature.token_to_orig_map: 390 | continue 391 | if end_index not in feature.token_to_orig_map: 392 | continue 393 | if not feature.token_is_max_context.get(start_index, False): 394 | continue 395 | if end_index < start_index: 396 | continue 397 | length = end_index - start_index + 1 398 | if length > max_answer_length: 399 | continue 400 | prelim_predictions.append( 401 | _PrelimPrediction(feature_index=feature_index, 402 | start_index=start_index, 403 | end_index=end_index, 404 | start_logit=result.start_logits[start_index], 405 | end_logit=result.end_logits[end_index])) 406 | prelim_predictions = sorted(prelim_predictions, 407 | key=lambda x: (x.start_logit + x.end_logit), 408 | reverse=True) 409 | _NbestPrediction = collections.namedtuple("NbestPrediction", 410 | ["text", "start_logit", "end_logit"]) 411 | seen_predictions = {} 412 | nbest = [] 413 | for pred in prelim_predictions: 414 | if len(nbest) >= n_best_size: 415 | break 416 | feature = features[pred.feature_index] 417 | if pred.start_index > 0: 418 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 419 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 420 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 421 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 422 | tok_text = " ".join(tok_tokens) 423 | tok_text = tok_text.replace(" ##", "") 424 | tok_text = tok_text.replace("##", "") 425 | tok_text = tok_text.strip() 426 | tok_text = " ".join(tok_text.split()) 427 | orig_text = " ".join(orig_tokens) 428 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) 429 | if final_text in seen_predictions: 430 | continue 431 | seen_predictions[final_text] = True 432 | else: 433 | final_text = "" 434 | seen_predictions[final_text] = True 435 | nbest.append( 436 | _NbestPrediction(text=final_text, 437 | start_logit=pred.start_logit, 438 | end_logit=pred.end_logit)) 439 | if not nbest: 440 | nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 441 | assert len(nbest) >= 1 442 | total_scores = [] 443 | best_non_null_entry = None 444 | for entry in nbest: 445 | total_scores.append(entry.start_logit + entry.end_logit) 446 | if not best_non_null_entry: 447 | if entry.text: 448 | best_non_null_entry = entry 449 | probs = _compute_softmax(total_scores) 450 | nbest_json = [] 451 | for (i, entry) in enumerate(nbest): 452 | 
output = collections.OrderedDict() 453 | output["text"] = entry.text 454 | output["probability"] = probs[i] 455 | output["start_logit"] = entry.start_logit 456 | output["end_logit"] = entry.end_logit 457 | nbest_json.append(output) 458 | assert len(nbest_json) >= 1 459 | all_predictions[example.qas_id] = nbest_json[0]["text"] 460 | all_nbest_json[example.qas_id] = nbest_json 461 | return all_predictions, all_nbest_json 462 | 463 | 464 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 465 | """Project the tokenized prediction back to the original text.""" 466 | def _strip_spaces(text): 467 | ns_chars = [] 468 | ns_to_s_map = collections.OrderedDict() 469 | for (i, c) in enumerate(text): 470 | if c == " ": 471 | continue 472 | ns_to_s_map[len(ns_chars)] = i 473 | ns_chars.append(c) 474 | ns_text = "".join(ns_chars) 475 | return (ns_text, ns_to_s_map) 476 | 477 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 478 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 479 | start_position = tok_text.find(pred_text) 480 | if start_position == -1: 481 | if verbose_logging: 482 | logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 483 | return orig_text 484 | end_position = start_position + len(pred_text) - 1 485 | 486 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 487 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 488 | 489 | if len(orig_ns_text) != len(tok_ns_text): 490 | if verbose_logging: 491 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, 492 | tok_ns_text) 493 | return orig_text 494 | 495 | tok_s_to_ns_map = {} 496 | for (i, tok_index) in tok_ns_to_s_map.items(): 497 | tok_s_to_ns_map[tok_index] = i 498 | 499 | orig_start_position = None 500 | if start_position in tok_s_to_ns_map: 501 | ns_start_position = tok_s_to_ns_map[start_position] 502 | if ns_start_position in orig_ns_to_s_map: 503 | orig_start_position = orig_ns_to_s_map[ns_start_position] 504 | 505 | if orig_start_position is None: 506 | if verbose_logging: 507 | logger.info("Couldn't map start position") 508 | return orig_text 509 | 510 | orig_end_position = None 511 | if end_position in tok_s_to_ns_map: 512 | ns_end_position = tok_s_to_ns_map[end_position] 513 | if ns_end_position in orig_ns_to_s_map: 514 | orig_end_position = orig_ns_to_s_map[ns_end_position] 515 | 516 | if orig_end_position is None: 517 | if verbose_logging: 518 | logger.info("Couldn't map end position") 519 | return orig_text 520 | 521 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 522 | return output_text 523 | 524 | 525 | def _get_best_indexes(logits, n_best_size): 526 | """Get the n-best logits from a list.""" 527 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 528 | 529 | best_indexes = [] 530 | for i in range(len(index_and_score)): 531 | if i >= n_best_size: 532 | break 533 | best_indexes.append(index_and_score[i][0]) 534 | return best_indexes 535 | 536 | 537 | def _compute_softmax(scores): 538 | """Compute softmax probability over raw logits.""" 539 | if not scores: 540 | return [] 541 | 542 | max_score = None 543 | for score in scores: 544 | if max_score is None or score > max_score: 545 | max_score = score 546 | 547 | exp_scores = [] 548 | total_sum = 0.0 549 | for score in scores: 550 | x = math.exp(score - max_score) 551 | exp_scores.append(x) 552 | total_sum += x 553 | 554 | probs = [] 555 | for score in exp_scores: 556 | probs.append(score / total_sum) 557 | return 
probs 558 | 559 | 560 | def get_raw_scores(dataset, predictions): 561 | answers = {} 562 | for example in dataset: 563 | for qa in example['qas']: 564 | answers[qa['qid']] = qa['answers'] 565 | exact_scores = {} 566 | f1_scores = {} 567 | for qid, ground_truths in answers.items(): 568 | if qid not in predictions: 569 | print('Missing prediction for %s' % qid) 570 | continue 571 | prediction = predictions[qid] 572 | exact_scores[qid] = metric_max_over_ground_truths(exact_match_score, prediction, 573 | ground_truths) 574 | f1_scores[qid] = metric_max_over_ground_truths(f1_score, prediction, ground_truths) 575 | return exact_scores, f1_scores 576 | 577 | 578 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 579 | if not qid_list: 580 | total = len(exact_scores) 581 | return collections.OrderedDict([ 582 | ('exact', 100.0 * sum(exact_scores.values()) / total), 583 | ('f1', 100.0 * sum(f1_scores.values()) / total), 584 | ('total', total), 585 | ]) 586 | else: 587 | total = len(qid_list) 588 | return collections.OrderedDict([ 589 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 590 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 591 | ('total', total), 592 | ]) 593 | 594 | 595 | def evaluate(args, 596 | model, 597 | device, 598 | eval_dataset, 599 | eval_dataloader, 600 | eval_examples, 601 | eval_features, 602 | verbose=True): 603 | all_results = [] 604 | model.eval() 605 | for input_ids, input_mask, segment_ids, example_indices in eval_dataloader: 606 | input_ids = input_ids.to(device) 607 | input_mask = input_mask.to(device) 608 | segment_ids = segment_ids.to(device) 609 | with torch.no_grad(): 610 | batch_start_logits, batch_end_logits = model([input_ids, input_mask, segment_ids]) 611 | for i, example_index in enumerate(example_indices): 612 | start_logits = batch_start_logits[i].detach().cpu().tolist() 613 | end_logits = batch_end_logits[i].detach().cpu().tolist() 614 | eval_feature = eval_features[example_index.item()] 615 | unique_id = int(eval_feature.unique_id) 616 | all_results.append( 617 | RawResult(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits)) 618 | preds, nbest_preds = \ 619 | make_predictions(eval_examples, eval_features, all_results, 620 | args.n_best_size, args.max_answer_length, 621 | args.do_lower_case, args.verbose_logging) 622 | exact_raw, f1_raw = get_raw_scores(eval_dataset, preds) 623 | result = make_eval_dict(exact_raw, f1_raw) 624 | if verbose: 625 | logger.info("***** Eval results *****") 626 | for key in sorted(result.keys()): 627 | logger.info(" %s = %s", key, str(result[key])) 628 | return result, preds, nbest_preds 629 | 630 | 631 | def load_initialization(model, args): 632 | if os.path.exists(args.initialize_model_from_checkpoint + '/pytorch_model.bin'): 633 | ckpt = torch.load(args.initialize_model_from_checkpoint + '/pytorch_model.bin') 634 | model.load_state_dict(ckpt) 635 | else: 636 | ckpt = torch.load(args.initialize_model_from_checkpoint + '/saved_checkpoint') 637 | assert args.model == ckpt['args']['model'], args.model + ' vs ' + ckpt['args']['model'] 638 | model.load_state_dict(ckpt['model_state_dict']) 639 | logger.info("***** Model Initialization *****") 640 | logger.info("Loaded the model state from a saved checkpoint {}".format( 641 | args.initialize_model_from_checkpoint)) 642 | 643 | 644 | def turn_off_dropout(m): 645 | for mod in m.modules(): 646 | if isinstance(mod, torch.nn.Dropout): 647 | mod.p = 0 648 | 649 | 650 | def tune_bias_only(m): 651 | for name, param in 
m.bert.named_parameters(): 652 | if 'bias' in name or 'LayerNorm' in name: 653 | param.requires_grad = True 654 | else: 655 | param.requires_grad = False 656 | 657 | 658 | def flip(scores, flip_prob, negative_reward): 659 | if flip_prob != 0: 660 | probs = torch.rand(scores.shape).to(scores.device) 661 | # true for values to be flipped 662 | mask = probs < flip_prob 663 | positive = scores == 1 664 | scores[mask & positive] = negative_reward 665 | scores[mask & ~positive] = 1 666 | return scores 667 | 668 | 669 | def get_batch_rewards(start_probs, end_probs, start_positions, end_positions, device, args, 670 | tokenizer, input_ids): 671 | bs = start_probs.shape[0] 672 | if args.argmax_simulation: 673 | start_samples = torch.argmax(start_probs, dim=1) 674 | end_samples = torch.argmax(end_probs, dim=1) 675 | else: 676 | start_samples = torch.multinomial(start_probs, 1).view(-1) 677 | end_samples = torch.multinomial(end_probs, 1).view(-1) 678 | log_prob = start_probs[torch.arange(bs), start_samples].log() + end_probs[torch.arange(bs), 679 | end_samples].log() 680 | 681 | # compute rewards 682 | def binary_reward(): 683 | reward_mask = (start_samples == start_positions) & (end_samples == end_positions) 684 | rewards = torch.tensor([args.negative_reward] * bs).to(device) 685 | rewards[reward_mask] = 1 686 | return rewards 687 | 688 | rewards = eval(args.reward_fn)() 689 | rewards = flip(rewards, args.flip_prob, args.negative_reward) 690 | 691 | return start_samples, end_samples, log_prob, rewards 692 | 693 | 694 | def collect_rewards_offline(model, train_batches, args, device, tokenizer, n_gpu): 695 | total_pos = 0 696 | total_neg = 0 697 | for i in range(len(train_batches)): 698 | batch = train_batches[i] 699 | if n_gpu == 1: 700 | batch = tuple(t.to(device) for t in batch) 701 | 702 | # sampling 703 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch 704 | with torch.no_grad(): 705 | start_probs, end_probs = model(batch=batch[:3], return_prob=True) 706 | 707 | start_samples, end_samples, log_prob, rewards = get_batch_rewards( 708 | start_probs, end_probs, start_positions, end_positions, device, args, tokenizer, 709 | input_ids) 710 | train_batches[i] = [ 711 | input_ids, input_mask, segment_ids, start_samples, end_samples, log_prob, rewards 712 | ] 713 | 714 | count_pos = torch.sum(rewards > 0).item() 715 | total_pos += count_pos 716 | total_neg += input_ids.shape[0] - count_pos 717 | return train_batches, total_pos, total_neg 718 | 719 | 720 | def main(args): 721 | args.timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f') 722 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 723 | n_gpu = torch.cuda.device_count() 724 | logger.info("device: {}, n_gpu: {}".format(device, n_gpu)) 725 | args.n_gpu = n_gpu 726 | 727 | # set up random seeds 728 | random.seed(args.seed) 729 | np.random.seed(args.seed) 730 | torch.manual_seed(args.seed) 731 | if n_gpu > 0: 732 | torch.cuda.manual_seed_all(args.seed) 733 | 734 | # deal with gradient accumulation 735 | if args.gradient_accumulation_steps < 1: 736 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 737 | args.gradient_accumulation_steps)) 738 | # actual bs = bs // g i.e. 
5 = 10 // 2 739 | args.train_batch_size = \ 740 | args.train_batch_size // args.gradient_accumulation_steps 741 | 742 | # parse dataset 743 | if args.dataset is not None: 744 | assert args.train_file is None 745 | assert args.dev_file is None 746 | if args.dataset == 'squad': 747 | args.train_file = 'data/SQuAD_train.jsonl' 748 | args.dev_file = 'data/SQuAD_dev.jsonl.gz' 749 | elif args.dataset == 'hotpot': 750 | args.train_file = 'data/HotpotQA-train.jsonl.gz' 751 | args.dev_file = 'data/HotpotQA-dev.jsonl.gz' 752 | elif args.dataset == 'nq': 753 | args.train_file = 'data/NaturalQuestionsShort-train.jsonl.gz' 754 | args.dev_file = 'data/NaturalQuestionsShort-dev.jsonl.gz' 755 | elif args.dataset == 'news': 756 | args.train_file = 'data/NewsQA-train.jsonl.gz' 757 | args.dev_file = 'data/NewsQA-dev.jsonl.gz' 758 | elif args.dataset == 'search': 759 | args.train_file = 'data/SearchQA-train.jsonl.gz' 760 | args.dev_file = 'data/SearchQA-dev.jsonl.gz' 761 | elif args.dataset == 'trivia': 762 | args.train_file = 'data/TriviaQA-train.jsonl.gz' 763 | args.dev_file = 'data/TriviaQA-dev.jsonl.gz' 764 | else: 765 | raise ValueError('Unknown dataset') 766 | 767 | # if args.dataset is not None and args.pretrainex is not None: 768 | # assert args.initialize_model_from_checkpoint is None 769 | # raise ValueError('What initialization to use?') 770 | 771 | # if args.pretrainon is not None: 772 | # assert args.initialize_model_from_checkpoint is None 773 | # raise ValueError('Which dataset pretrained on?') 774 | 775 | # argparse checkers 776 | if not args.do_train and not args.do_eval: 777 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 778 | if args.do_train: 779 | assert args.train_file is not None 780 | if args.eval_test: 781 | assert args.test_file is not None 782 | # only evaluate on the test set: need an initialization 783 | if args.eval_test and not args.do_train: 784 | assert args.initialize_model_from_checkpoint is not None 785 | 786 | if args.percentage_train_data + args.percentage_train_data_to_ignore > 1: 787 | raise ValueError( 788 | "Problematic combination of percentages on training: {} to train but {} to ignore". 
789 | format(args.percentage_train_data, args.percentage_train_data_to_ignore)) 790 | 791 | # set up logging files 792 | if not os.path.exists(args.output_dir): 793 | os.makedirs(args.output_dir) 794 | # set up the logging for this experiment 795 | args.output_dir += '/' + args.timestamp 796 | os.makedirs(args.output_dir) 797 | if args.do_train: 798 | logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "train.log"), 'w')) 799 | else: 800 | logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "eval.log"), 'w')) 801 | # log args 802 | logger.info(args) 803 | 804 | tokenizer = BertTokenizer.from_pretrained(args.model, do_lower_case=args.do_lower_case) 805 | 806 | if args.do_train and args.do_eval: 807 | # load dev dataset 808 | eval_dataset = get_data(input_file=args.dev_file) 809 | eval_examples = read_mrqa_examples(input_file=args.dev_file, is_training=False) 810 | eval_features = convert_examples_to_features(examples=eval_examples, 811 | tokenizer=tokenizer, 812 | max_seq_length=args.max_seq_length, 813 | doc_stride=args.doc_stride, 814 | max_query_length=args.max_query_length, 815 | is_training=False) 816 | logger.info("***** Dev *****") 817 | logger.info(" Num orig examples = %d", len(eval_examples)) 818 | logger.info(" Num split examples = %d", len(eval_features)) 819 | logger.info(" Batch size = %d", args.eval_batch_size) 820 | args.dev_num_orig_ex = len(eval_examples) 821 | args.dev_num_split_ex = len(eval_features) 822 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 823 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 824 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 825 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 826 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 827 | eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size) 828 | 829 | if args.do_train: 830 | train_examples = read_mrqa_examples(input_file=args.train_file, 831 | is_training=True, 832 | ignore=args.percentage_train_data_to_ignore, 833 | percentage=args.percentage_train_data) 834 | train_features = convert_examples_to_features(examples=train_examples, 835 | tokenizer=tokenizer, 836 | max_seq_length=args.max_seq_length, 837 | doc_stride=args.doc_stride, 838 | max_query_length=args.max_query_length, 839 | is_training=True) 840 | 841 | if args.train_mode == 'sorted' or args.train_mode == 'random_sorted': 842 | train_features = sorted(train_features, key=lambda f: np.sum(f.input_mask)) 843 | else: 844 | random.shuffle(train_features) 845 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 846 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 847 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 848 | all_start_positions = torch.tensor([f.start_position for f in train_features], 849 | dtype=torch.long) 850 | all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) 851 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 852 | all_start_positions, all_end_positions) 853 | train_dataloader = DataLoader(train_data, batch_size=args.train_batch_size) 854 | train_batches = [batch for batch in train_dataloader] 855 | 856 | num_train_optimization_steps = \ 857 | len(train_dataloader) // 
args.gradient_accumulation_steps * args.num_train_epochs 858 | 859 | logger.info("***** Train *****") 860 | logger.info(" Num orig examples = %d", len(train_examples)) 861 | logger.info(" Num split examples = %d", len(train_features)) 862 | logger.info(" Batch size = %d", args.train_batch_size) 863 | logger.info(" Num steps = %d", num_train_optimization_steps) 864 | args.train_num_orig_ex = len(train_examples) 865 | args.train_num_split_ex = len(train_features) 866 | 867 | eval_step = max(1, len(train_batches) // args.eval_per_epoch) 868 | best_result = None 869 | lrs = [args.learning_rate] if args.learning_rate else \ 870 | [1e-4, 9e-5, 8e-5, 7e-5, 6e-5, 5e-5, 3e-5, 2e-5, 1e-5] 871 | for lr in lrs: 872 | if args.initialize_model_from_checkpoint: 873 | model = BertForQuestionAnswering(model_type=args.model) 874 | load_initialization(model=model, args=args) 875 | else: 876 | model = BertForQuestionAnswering(model_type=args.model) 877 | 878 | if args.turn_off_dropout: 879 | turn_off_dropout(model) 880 | 881 | if args.tune_bias_only: 882 | tune_bias_only(model) 883 | 884 | model.to(device) 885 | 886 | if n_gpu > 1: 887 | model = torch.nn.DataParallel(model) 888 | param_optimizer = list(model.named_parameters()) 889 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 890 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 891 | optimizer_grouped_parameters = [{ 892 | 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 893 | 'weight_decay': 894 | 0.01 895 | }, { 896 | 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 897 | 'weight_decay': 898 | 0.0 899 | }] 900 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 901 | lr_scheduler = get_scheduler(args.scheduler, 902 | optimizer=optimizer, 903 | num_warmup_steps=int(num_train_optimization_steps * 904 | args.warmup_proportion), 905 | num_training_steps=num_train_optimization_steps) 906 | 907 | if args.setup == 'offline': 908 | train_batches, total_pos, total_neg = collect_rewards_offline( 909 | model, train_batches, args, device, tokenizer, n_gpu) 910 | logger.info("Offline regret computation: {} positives {} negatives".format( 911 | total_pos, total_neg)) 912 | 913 | if args.wandb: 914 | wandb.init( 915 | project='bandit-qa', 916 | name= 917 | f'{args.percentage_train_data}-{args.train_num_orig_ex}{args.dataset}_{args.algo}_{args.model}_{args.scheduler}={lr}_{args.initialize_model_from_checkpoint}+{args.argmax_simulation}_{args.output_dir}', 918 | notes=args.notes, 919 | config=vars(args)) 920 | wandb.watch(model) 921 | 922 | tr_loss = 0 923 | nb_tr_examples = 0 924 | nb_tr_steps = 0 925 | global_step = 0 926 | start_time = time.time() 927 | simulation_log = None 928 | one_epoch_f1 = None 929 | dev_f1s = [] 930 | steps = [] 931 | total_pos, total_neg = 0, 0 932 | for epoch in range(int(args.num_train_epochs)): 933 | model.train() 934 | logger.info("Start epoch #{} (lr = {})...".format(epoch, lr)) 935 | if args.train_mode == 'random' or args.train_mode == 'random_sorted': 936 | random.shuffle(train_batches) 937 | for step, batch in enumerate(train_batches): 938 | if n_gpu == 1: 939 | batch = tuple(t.to(device) for t in batch) 940 | 941 | start_probs, end_probs = model(batch=batch[:3], return_prob=True) 942 | bs = start_probs.shape[0] 943 | if args.setup == 'online': 944 | input_ids, _, _, start_positions, end_positions = batch 945 | start_samples, end_samples, log_prob, rewards = get_batch_rewards( 946 | start_probs, end_probs, start_positions, 
end_positions, device, args, 947 | tokenizer, input_ids) 948 | count_pos = torch.sum(rewards > 0).item() 949 | total_pos += count_pos 950 | total_neg += bs - count_pos 951 | else: 952 | input_ids, _, _, start_samples, end_samples, old_log_prob, old_rewards = batch 953 | log_prob = start_probs[torch.arange(bs), 954 | start_samples].log() + end_probs[torch.arange(bs), 955 | end_samples].log() 956 | ratios = torch.exp(log_prob - old_log_prob) 957 | rewards = torch.clamp(ratios, 0, 1) * old_rewards 958 | rewards = rewards.detach() 959 | 960 | # compute values 961 | if args.algo == 'Rwb': 962 | values = torch.tensor([-0.05] * bs).to(device) 963 | detached_advantages = rewards - values 964 | elif args.algo == 'Rwmb': 965 | detached_advantages = rewards - rewards.mean() 966 | else: 967 | detached_advantages = rewards 968 | # compute probs 969 | loss = (-log_prob * detached_advantages).mean() / 2 970 | 971 | if n_gpu > 1: 972 | loss = loss.mean() 973 | if args.gradient_accumulation_steps > 1: 974 | loss = loss / args.gradient_accumulation_steps 975 | 976 | tr_loss += loss.item() 977 | nb_tr_examples += input_ids.size(0) 978 | nb_tr_steps += 1 979 | 980 | loss.backward() 981 | if (step + 1) % args.gradient_accumulation_steps == 0: 982 | optimizer.step() 983 | lr_scheduler.step() 984 | optimizer.zero_grad() 985 | global_step += 1 986 | 987 | if args.wandb and (global_step + 1) % 25 == 0: 988 | wandb.log( 989 | { 990 | '(Train) policy loss': loss.item(), 991 | '(Train) reward': rewards.mean().item(), 992 | '(Train) advantage': detached_advantages.mean().item(), 993 | }, 994 | step=global_step) 995 | if simulation_log is not None: 996 | wandb.log(simulation_log, step=global_step) 997 | 998 | if (step + 1) % eval_step == 0 or step + 1 == len(train_batches): 999 | logger.info( 1000 | 'Epoch: {}, Step: {} / {}, used_time = {:.2f}s, loss = {:.6f}'.format( 1001 | epoch, step + 1, len(train_batches), 1002 | time.time() - start_time, tr_loss / nb_tr_steps)) 1003 | 1004 | if args.wandb: 1005 | wandb.log( 1006 | { 1007 | '(Train) loss': loss.item(), 1008 | '(Train) total pos': total_pos, 1009 | '(Train) total neg': total_neg 1010 | }, 1011 | step=global_step) 1012 | if simulation_log is not None: 1013 | wandb.log(simulation_log, step=global_step) 1014 | 1015 | save_model = False 1016 | if args.do_eval: 1017 | result, _, _ = \ 1018 | evaluate(args, model, device, eval_dataset, 1019 | eval_dataloader, eval_examples, eval_features) 1020 | model.train() 1021 | if args.wandb: 1022 | wandb.log(result, step=global_step) 1023 | result['global_step'] = global_step 1024 | result['epoch'] = epoch 1025 | result['learning_rate'] = lr 1026 | result[ 1027 | 'batch_size'] = args.train_batch_size * args.gradient_accumulation_steps 1028 | result['eval_step'] = eval_step 1029 | dev_f1s.append(round(result[args.eval_metric], 1)) 1030 | steps.append(step) 1031 | result['dev_f1s'] = dev_f1s 1032 | result['steps'] = steps 1033 | result['total_pos'] = total_pos 1034 | result['total_neg'] = total_neg 1035 | if (best_result is None) or (result[args.eval_metric] > 1036 | best_result[args.eval_metric]): 1037 | best_result = result 1038 | # save model when getting new best result 1039 | save_model = True 1040 | logger.info( 1041 | "!!! 
Best dev %s (lr=%s, epoch=%d): %.2f" % 1042 | (args.eval_metric, str(lr), epoch, result[args.eval_metric])) 1043 | elif best_result is not None: 1044 | save_model = True 1045 | else: 1046 | # case: no evaluation so just save the latest model 1047 | save_model = True 1048 | if save_model: 1049 | # NOTE changed 1050 | # save the config 1051 | model.bert.config.to_json_file( 1052 | os.path.join(args.output_dir, 'config.json')) 1053 | # save the model 1054 | torch.save( 1055 | { 1056 | 'global_step': global_step, 1057 | 'args': vars(args), 1058 | 'model_state_dict': model.state_dict(), 1059 | 'optimizer_state_dict': optimizer.state_dict(), 1060 | }, os.path.join(args.output_dir, 'saved_checkpoint')) 1061 | if best_result: 1062 | # i.e. best_result is not None 1063 | filename = EVAL_FILE 1064 | if len(lrs) != 1: 1065 | filename = str(lr) + '_' + EVAL_FILE 1066 | with open(os.path.join(args.output_dir, filename), "w") as writer: 1067 | for key in sorted(best_result.keys()): 1068 | writer.write("%s = %s\n" % (key, str(best_result[key]))) 1069 | if epoch == 0: 1070 | one_epoch_f1 = best_result['f1'] 1071 | writer.write("%s = %s\n" % ('one_epoch_f1', one_epoch_f1)) 1072 | if args.save_checkpoint: 1073 | checkpoint = { 1074 | 'global_step': global_step, 1075 | 'args': vars(args), 1076 | 'model_state_dict': model.state_dict(), 1077 | 'optimizer_state_dict': optimizer.state_dict(), 1078 | } 1079 | folder = args.output_dir + '/ckpt' 1080 | # create a folder if not existed 1081 | if not os.path.exists(folder): 1082 | os.makedirs(folder) 1083 | filename = folder + f'/{args.timestamp}_gstep={global_step}' 1084 | torch.save(checkpoint, filename) 1085 | 1086 | if args.eval_test: 1087 | if args.wandb: 1088 | wandb.init( 1089 | project='pqa', 1090 | entity='lil', 1091 | name= 1092 | f'{args.model}_{args.test_file}_{args.initialize_model_from_checkpoint}+{args.argmax_simulation}_{args.output_dir}', 1093 | tags=['eval'], 1094 | notes=args.notes, 1095 | config=vars(args)) 1096 | 1097 | eval_dataset = get_data(args.test_file) 1098 | eval_examples = read_mrqa_examples(input_file=args.test_file, is_training=False) 1099 | eval_features = convert_examples_to_features(examples=eval_examples, 1100 | tokenizer=tokenizer, 1101 | max_seq_length=args.max_seq_length, 1102 | doc_stride=args.doc_stride, 1103 | max_query_length=args.max_query_length, 1104 | is_training=False) 1105 | logger.info("***** Test *****") 1106 | logger.info(" Num orig examples = %d", len(eval_examples)) 1107 | logger.info(" Num split examples = %d", len(eval_features)) 1108 | logger.info(" Batch size = %d", args.eval_batch_size) 1109 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 1110 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 1111 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 1112 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 1113 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 1114 | eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size) 1115 | 1116 | # NOTE change: only evaluate on the test set 1117 | if not args.do_train: 1118 | model = BertForQuestionAnswering(model_type=args.model) 1119 | assert args.initialize_model_from_checkpoint is not None 1120 | load_initialization(model=model, args=args) 1121 | model.to(device) 1122 | result, preds, nbest_preds = evaluate(args, model, device, eval_dataset, eval_dataloader, 1123 | 
1085 | 
1086 |     if args.eval_test:
1087 |         if args.wandb:
1088 |             wandb.init(
1089 |                 project='pqa',
1090 |                 entity='lil',
1091 |                 name=
1092 |                 f'{args.model}_{args.test_file}_{args.initialize_model_from_checkpoint}+{args.argmax_simulation}_{args.output_dir}',
1093 |                 tags=['eval'],
1094 |                 notes=args.notes,
1095 |                 config=vars(args))
1096 | 
1097 |         eval_dataset = get_data(args.test_file)
1098 |         eval_examples = read_mrqa_examples(input_file=args.test_file, is_training=False)
1099 |         eval_features = convert_examples_to_features(examples=eval_examples,
1100 |                                                      tokenizer=tokenizer,
1101 |                                                      max_seq_length=args.max_seq_length,
1102 |                                                      doc_stride=args.doc_stride,
1103 |                                                      max_query_length=args.max_query_length,
1104 |                                                      is_training=False)
1105 |         logger.info("***** Test *****")
1106 |         logger.info("  Num orig examples = %d", len(eval_examples))
1107 |         logger.info("  Num split examples = %d", len(eval_features))
1108 |         logger.info("  Batch size = %d", args.eval_batch_size)
1109 |         all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
1110 |         all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
1111 |         all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
1112 |         all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
1113 |         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
1114 |         eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size)
1115 | 
1116 |         # NOTE change: only evaluate on the test set
1117 |         if not args.do_train:
1118 |             model = BertForQuestionAnswering(model_type=args.model)
1119 |             assert args.initialize_model_from_checkpoint is not None
1120 |             load_initialization(model=model, args=args)
1121 |             model.to(device)
1122 |         result, preds, nbest_preds = evaluate(args, model, device, eval_dataset, eval_dataloader,
1123 |                                               eval_examples, eval_features)
1124 |         with open(os.path.join(args.output_dir, PRED_FILE), "w") as writer:
1125 |             writer.write(json.dumps(preds, indent=4) + "\n")
1126 |         with open(os.path.join(args.output_dir, TEST_FILE), "w") as writer:
1127 |             for key in sorted(result.keys()):
1128 |                 writer.write("%s = %s\n" % (key, str(result[key])))
1129 | 
1130 |         if args.wandb:
1131 |             wandb.log(result, step=0)
1132 | 
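#
# [Editor's note] Illustrative only: one way to re-score the prediction file written
# above using the repo's own util_squad_eval.compute_f1. It assumes `preds` maps
# question ids to predicted answer strings (as dumped to PRED_FILE) and that `golds`
# is a hypothetical mapping from the same ids to lists of reference answers, which you
# would build from the dataset yourself.
#
#     import json
#     from util_squad_eval import compute_f1
#
#     with open(os.path.join(args.output_dir, PRED_FILE)) as f:
#         preds = json.load(f)
#     f1 = 100.0 * sum(
#         max(compute_f1(gold, preds[qid]) for gold in golds[qid])
#         for qid in preds) / len(preds)
#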
1133 | 
1134 | if __name__ == "__main__":
1135 |     parser = argparse.ArgumentParser()
1136 |     parser.add_argument("--model", default=None, type=str, required=True)
1137 |     parser.add_argument(
1138 |         "--output_dir",
1139 |         default=None,
1140 |         type=str,
1141 |         required=True,
1142 |         help="The output directory where the model checkpoints and predictions will be written.")
1143 |     parser.add_argument("--train_file", default=None, type=str)
1144 |     parser.add_argument("--dev_file", default=None, type=str)
1145 |     parser.add_argument("--test_file", default=None, type=str)
1146 |     parser.add_argument("--eval_per_epoch",
1147 |                         default=10,
1148 |                         type=int,
1149 |                         help="How many times to evaluate on the dev set per epoch")
1150 |     parser.add_argument(
1151 |         "--max_seq_length",
1152 |         default=384,
1153 |         type=int,
1154 |         help="The maximum total input sequence length after WordPiece tokenization. Sequences "
1155 |         "longer than this will be truncated, and sequences shorter than this will be padded.")
1156 |     parser.add_argument("--doc_stride",
1157 |                         default=128,
1158 |                         type=int,
1159 |                         help="When splitting up a long document into chunks, "
1160 |                         "how much stride to take between chunks.")
1161 |     parser.add_argument(
1162 |         "--max_query_length",
1163 |         default=64,
1164 |         type=int,
1165 |         help="The maximum number of tokens for the question. Questions longer than this will "
1166 |         "be truncated to this length.")
1167 |     parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
1168 |     parser.add_argument("--do_eval",
1169 |                         action='store_true',
1170 |                         help="Whether to run eval on the dev set.")
1171 |     parser.add_argument("--do_lower_case",
1172 |                         action='store_true',
1173 |                         help="Set this flag if you are using an uncased model.")
1174 |     parser.add_argument("--eval_test",
1175 |                         action='store_true',
1176 |                         help='Whether to run eval on the test set.')
1177 |     parser.add_argument("--train_batch_size",
1178 |                         default=32,
1179 |                         type=int,
1180 |                         help="Total batch size for training.")
1181 |     parser.add_argument("--eval_batch_size",
1182 |                         default=8,
1183 |                         type=int,
1184 |                         help="Total batch size for predictions.")
1185 |     parser.add_argument("--learning_rate",
1186 |                         default=None,
1187 |                         type=float,
1188 |                         help="The initial learning rate for Adam.")
1189 |     parser.add_argument("--num_train_epochs",
1190 |                         default=3.0,
1191 |                         type=float,
1192 |                         help="Total number of training epochs to perform.")
1193 |     parser.add_argument("--eval_metric", default='f1', type=str)
1194 |     parser.add_argument("--train_mode",
1195 |                         type=str,
1196 |                         default='random_sorted',
1197 |                         choices=['random', 'sorted', 'random_sorted'])
1198 |     parser.add_argument(
1199 |         "--warmup_proportion",
1200 |         default=0.1,
1201 |         type=float,
1202 |         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
1203 |         "of training.")
1204 |     parser.add_argument(
1205 |         "--n_best_size",
1206 |         default=20,
1207 |         type=int,
1208 |         help="The total number of n-best predictions to generate in the nbest_predictions.json "
1209 |         "output file.")
1210 |     parser.add_argument("--max_answer_length",
1211 |                         default=30,
1212 |                         type=int,
1213 |                         help="The maximum length of an answer that can be generated. "
1214 |                         "This is needed because the start "
1215 |                         "and end predictions are not conditioned on one another.")
1216 |     parser.add_argument(
1217 |         "--verbose_logging",
1218 |         action='store_true',
1219 |         help="If true, all of the warnings related to data processing will be printed. "
1220 |         "A number of warnings are expected for a normal MRQA evaluation.")
1221 |     parser.add_argument("--no_cuda",
1222 |                         action='store_true',
1223 |                         help="Whether to disable CUDA even when it is available")
1224 |     parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
1225 |     parser.add_argument(
1226 |         '--gradient_accumulation_steps',
1227 |         type=int,
1228 |         default=1,
1229 |         help="Number of update steps to accumulate before performing a backward/update pass.")
1230 | 
1231 |     # below are customized arguments
1232 |     parser.add_argument('--wandb', action='store_true', help='Whether to use wandb for logging.')
1233 |     parser.add_argument('--notes', default='', help='Notes for this experiment (for wandb logging)')
1234 |     parser.add_argument(
1235 |         '--save_checkpoint',
1236 |         action='store_true',
1237 |         help=
1238 |         'Whether to save separate checkpoints during training: recommended to leave this off to save disk space'
1239 |     )
1240 |     parser.add_argument('--percentage_train_data',
1241 |                         type=float,
1242 |                         default=1,
1243 |                         help='Percentage of training data to load: for debugging purposes')
1244 |     parser.add_argument(
1245 |         '--percentage_train_data_to_ignore',
1246 |         type=float,
1247 |         default=0,
1248 |         help=
1249 |         'Percentage of training data to ignore first: for experiments that exclude the initial data used for pre-training'
1250 |     )
1251 |     parser.add_argument(
1252 |         '--argmax_simulation',
1253 |         action='store_true',
1254 |         help='Whether to take the argmax prediction for simulation: we stick with argmax in this work')
1255 |     parser.add_argument(
1256 |         '--reward_fn',
1257 |         default='binary_reward',
1258 |         type=str,
1259 |         choices=['binary_reward'],
1260 |         help='The type of reward function used during training: we stick with binary rewards in this work')
1261 |     parser.add_argument('--initialize_model_from_checkpoint',
1262 |                         default=None,
1263 |                         help='Relative filepath to a saved checkpoint used as model initialization.')
1264 |     parser.add_argument(
1265 |         '--flip_prob',
1266 |         default=0.0,
1267 |         type=float,
1268 |         help='Parameter for the perturbation function: the probability of flipping each reward')
1269 |     parser.add_argument('--scheduler', default='linear', type=str, help='Learning rate scheduler.')
1270 |     parser.add_argument(
1271 |         '--transfer',
1272 |         action='store_true',
1273 |         help='Domain adaptation or not. Not used in the code, only for wandb logging purposes.')
1274 |     parser.add_argument('--turn_off_dropout',
1275 |                         action='store_true',
1276 |                         help='Whether to turn off dropout: recommended for simulation experiments')
1277 |     parser.add_argument(
1278 |         '--tune_bias_only',
1279 |         action='store_true',
1280 |         help='Only tune the bias and LayerNorm parameters in BERT, as well as the classifier on top')
1281 |     parser.add_argument('--algo',
1282 |                         default='R',
1283 |                         choices=['R'],
1284 |                         help='Training algorithm: we stick with R in this work.')
1285 |     parser.add_argument('--negative_reward',
1286 |                         default=-0.1,
1287 |                         type=float,
1288 |                         help='Reward value used for a negative update')
1289 |     parser.add_argument('--setup',
1290 |                         default='online',
1291 |                         type=str,
1292 |                         choices=['online', 'offline'],
1293 |                         help='Online or offline setup')
1294 |     parser.add_argument("--dataset",
1295 |                         default=None,
1296 |                         type=str,
1297 |                         choices=['squad', 'hotpot', 'nq', 'trivia', 'search', 'news'])
1298 |     parser.add_argument("--pretrainon",
1299 |                         default=None,
1300 |                         type=str,
1301 |                         choices=['squad', 'hotpot', 'nq', 'trivia', 'search', 'news'])
1302 |     parser.add_argument("--pretrainex", default=None, type=int, choices=[64, 256, 1024])
1303 |     args = parser.parse_args()
1304 |     main(args)
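#
# [Editor's note] A hypothetical evaluation-only invocation (no --do_train), matching
# the branch above that loads a checkpoint before running on the test set. File paths
# and the checkpoint name are placeholders, not files shipped with the repo.
#
#     python train.py --eval_test \
#         --model SpanBERT/spanbert-base-cased \
#         --test_file data/<dev_set>.jsonl.gz \
#         --initialize_model_from_checkpoint <path/to/saved_checkpoint> \
#         --output_dir .simulation
#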
1305 | 
--------------------------------------------------------------------------------