├── .gitignore
├── download_data.sh
├── bart.py
├── README.md
├── cli.py
├── run.py
└── data.py

/.gitignore:
--------------------------------------------------------------------------------
data
out
__pycache__

--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
#!/bin/sh

mkdir -p data
wget https://nlp.cs.washington.edu/ambigqa/data/nqopen.zip -O data/nqopen.zip
unzip -d data data/nqopen.zip
rm data/nqopen.zip


--------------------------------------------------------------------------------
/bart.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import T5ForConditionalGeneration, BartForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right

class MyBart(BartForConditionalGeneration):
    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
                decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
                use_cache=False, is_training=False):

        if is_training:
            # Teacher forcing: build decoder inputs by shifting the gold answer tokens right.
            _decoder_input_ids = shift_tokens_right(decoder_input_ids, self.config.pad_token_id)
        else:
            _decoder_input_ids = decoder_input_ids

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
        )
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        if is_training:
            # Sum the token-level cross-entropy over non-pad target tokens.
            loss_fct = nn.CrossEntropyLoss(reduction="sum", ignore_index=self.config.pad_token_id)
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size),
                            decoder_input_ids.view(-1))
            return loss
        return (lm_logits, ) + outputs[1:]


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BART version of closed-book QA

This is a BART version of a sequence-to-sequence model for open-domain QA in a closed-book setup, based on [PyTorch](https://pytorch.org/) and [Huggingface's Transformers](https://github.com/huggingface/transformers).

The model is a sequence-to-sequence model that takes a question as input and outputs the answer, without reading any external resource (e.g. passages).
Please refer to [Roberts et al., 2020, How Much Knowledge Can You Pack Into the Parameters of a Language Model?](https://arxiv.org/abs/2002.08910) to learn more about the closed-book QA setup and the original T5-based model. Their code and model checkpoints are available [here](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa).

The model is based on BART-large. Please refer to [Lewis et al., ACL 2020, BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) to learn more about BART.

We experiment with Natural Questions open-domain data (NQ-open), but the code should work on any QA data with question-answer pairs.


## Requirements

This code is tested on Python 3.6.9.

Install PyTorch and Transformers:
```
pip install torch==1.1.0
pip install git+https://github.com/huggingface/transformers.git@7b75aa9fa55bee577e2c7403301ed31103125a35
```
The pinned commit matters: later Transformers releases moved or renamed modules this code imports (e.g. `transformers.modeling_bart`).

Download NQ-open data:
```
chmod +x download_data.sh; ./download_data.sh
```

## Training

```
python cli.py --do_train --output_dir out/nq-bart-closed-qa \
        --train_file data/nqopen-train.json \
        --predict_file data/nqopen-dev.json \
        --train_batch_size ${train_bs} \
        --predict_batch_size ${test_bs} \
        --append_another_bos
```

The script will save the log and the best checkpoint inside `out/nq-bart-closed-qa`.

Other useful flags (please refer to `cli.py` for the full list):
- `eval_period`: interval (in training steps) at which to evaluate on the dev data
- `verbose`: print a progress bar
- `debug`: train and evaluate on a subset of the dev data for debugging purposes

Set `train_batch_size` and `predict_batch_size` depending on GPU availability. With one 16GB GPU, you can use `train_batch_size=64, predict_batch_size=64`.
The model whose results are reported below was trained with `train_batch_size=1024, predict_batch_size=256` using eight 32GB GPUs. Training took roughly 34 hours.

Note:
- The script saves the pre-tokenized data in `data/` once question-answer pairs are tokenized for the first time.
- The model gives the best result when an extra BOS token is prepended (`--append_another_bos`).
- Inference on multiple GPUs is not working for now; we will update the code once it is fixed.

## Inference

```
python cli.py --do_predict --output_dir out/nq-bart-closed-qa \
        --predict_file data/nqopen-dev.json \
        --predict_batch_size ${test_bs} \
        --append_another_bos --prefix dev_
python cli.py --do_predict --output_dir out/nq-bart-closed-qa \
        --predict_file data/nqopen-test.json \
        --predict_batch_size ${test_bs} \
        --append_another_bos --prefix test_
```

It will save the prediction files as `out/nq-bart-closed-qa/{dev|test}_predictions.json`.

## Results

The final Exact Match score we get is 25.05 on the dev data and 24.10 on the test data.

We have made the best model checkpoint and its predictions on the dev/test data available.

- [Best checkpoint + Dev/Test predictions (1.8G)][1]
- [Dev/Test predictions only (228K)][2]

Note that the T5-based model gets 27.0, 29.8, 32.1 and 34.5 on the test set with Base, Large, 3B and 11B, respectively, based on [the original paper](https://arxiv.org/pdf/2002.08910.pdf). Several factors could lead to the performance gap: (i) T5 has more parameters and was trained on more data, and (ii) the original paper includes the dev data for training, whereas this codebase trains only on the train data and uses the dev data for choosing the best checkpoint.
We also did not perform any hyperparameter tuning, as our goal is to provide a basic codebase rather than to achieve the best possible performance; we leave that for future work.

Note that the original paper includes ablations that exclude supervised data from T5 pretraining and reports comparable (or better) numbers: see Appendix C of [the original paper](https://arxiv.org/pdf/2002.08910.pdf) for the details.
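
If you want to query a trained checkpoint directly from Python rather than through `cli.py`, the sketch below mirrors what `run.py` does at prediction time. It is illustrative only: point `checkpoint` at the `.pt` file you trained or downloaded (training saves it as `out/nq-bart-closed-qa/best-model.pt`; adjust the path if yours lives elsewhere), and preprocess the question the same way as during training (lowercased, ending with `?`, with a leading space if the model was trained with `--append_another_bos`).

```
import torch
from transformers import BartTokenizer
from bart import MyBart

checkpoint = "out/nq-bart-closed-qa/best-model.pt"  # assumed path; adjust as needed
tokenizer = BartTokenizer.from_pretrained("bart-large")

# Strip a possible "module." prefix left over from DataParallel, as run.py does.
state_dict = torch.load(checkpoint, map_location="cpu")
state_dict = {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}
model = MyBart.from_pretrained("bart-large", state_dict=state_dict)
model.eval()

question = " who wrote the origin of species?"  # example; leading space mimics --append_another_bos
input_ids = tokenizer.encode(question, return_tensors="pt")
with torch.no_grad():
    output = model.generate(input_ids=input_ids, num_beams=4,
                            max_length=20, early_stopping=True)
print(tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
```

For reproducing the numbers above, use the `--do_predict` commands in the Inference section instead.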

## Contact

Please email [Sewon Min](https://shmsw25.github.io) or open a GitHub issue with any questions.


[1]: http://nlp.cs.washington.edu/ambigqa/models/nq-bart-closed-qa/nq-bart-closed-qa.zip
[2]: http://nlp.cs.washington.edu/ambigqa/models/nq-bart-closed-qa/predictions.zip


--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import argparse
import logging

import random
import numpy as np
import torch

from run import run

def main():
    parser = argparse.ArgumentParser()

    ## Basic parameters
    parser.add_argument("--train_file", default="data/nqopen-train.json")
    parser.add_argument("--predict_file", default="data/nqopen-dev.json")
    parser.add_argument("--output_dir", default=None, type=str, required=True)
    parser.add_argument("--do_train", action='store_true')
    parser.add_argument("--do_predict", action='store_true')

    ## Model parameters
    parser.add_argument("--checkpoint", type=str)
    parser.add_argument("--do_lowercase", action='store_true', default=True)

    # Preprocessing/decoding-related parameters
    parser.add_argument('--max_input_length', type=int, default=32)
    parser.add_argument('--max_output_length', type=int, default=20)
    parser.add_argument('--num_beams', type=int, default=4)
    parser.add_argument("--append_another_bos", action='store_true', default=False)

    # Training-related parameters
    parser.add_argument("--train_batch_size", default=40, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--predict_batch_size", default=400, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate", default=1e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--warmup_proportion", default=0.01, type=float,
                        help="Proportion of training steps to perform linear learning rate warmup for.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
                        help="Number of update steps to accumulate gradients for before performing an optimizer step.")
    parser.add_argument("--num_train_epochs", default=10000.0, type=float,
                        help="Total number of training epochs to perform.")
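    # Note: with the default of 10000 epochs, training in practice runs until early
    # stopping (--wait_step below, checked every --eval_period steps) halts it based
    # on dev-set Exact Match.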
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--wait_step', type=int, default=10)

    # Other parameters
    parser.add_argument("--verbose", action='store_true',
                        help="If true, print a progress bar while training and evaluating.")
    parser.add_argument('--eval_period', type=int, default=1000,
                        help="Evaluate on the dev data (and save the best model) every this many training steps.")
    parser.add_argument('--prefix', type=str, default='',
                        help="Prefix for saving predictions")
    parser.add_argument('--debug', action='store_true',
                        help="Use a subset of data for debugging")
    parser.add_argument('--seed', type=int, default=42,
                        help="Random seed for initialization")
    args = parser.parse_args()
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        print("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    ##### Start writing logs

    log_filename = "{}log.txt".format("" if args.do_train else "eval_")

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=[logging.FileHandler(os.path.join(args.output_dir, log_filename)),
                                  logging.StreamHandler()])
    logger = logging.getLogger(__name__)
    logger.info(args)
    logger.info(args.output_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.n_gpu = torch.cuda.device_count()

    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError("If `do_train` is True, then `train_file` must be specified.")
        if not args.predict_file:
            raise ValueError("If `do_train` is True, then `predict_file` must be specified.")

    if args.do_predict:
        if not args.predict_file:
            raise ValueError("If `do_predict` is True, then `predict_file` must be specified.")

    logger.info("Using {} gpus".format(args.n_gpu))
    run(args, logger)

if __name__=='__main__':
    main()

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch

from transformers import BartTokenizer, BartConfig
from transformers import AdamW, get_linear_schedule_with_warmup

from data import QAData
from bart import MyBart

def run(args, logger):
    tokenizer = BartTokenizer.from_pretrained("bart-large")

    train_data = QAData(logger, args, args.train_file, True)
    dev_data = QAData(logger, args, args.predict_file, False)

    train_data.load_dataset(tokenizer)
    train_data.load_dataloader()

    dev_data.load_dataset(tokenizer)
    dev_data.load_dataloader()

    if args.do_train:
        if args.checkpoint is not None:
            def convert_to_single_gpu(state_dict):
                def _convert(key):
                    if key.startswith('module.'):
                        return key[7:]
                    return key
                return {_convert(key): value for key, value in
                        state_dict.items()}
            model = MyBart.from_pretrained("bart-large",
                                           state_dict=convert_to_single_gpu(torch.load(args.checkpoint)))
        else:
            model = MyBart.from_pretrained("bart-large")
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if torch.cuda.is_available():
            model.to(torch.device("cuda"))

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=args.warmup_steps,
                                                    num_training_steps=100000)
        train(args, logger, model, train_data, dev_data, optimizer, scheduler)

    if args.do_predict:
        checkpoint = os.path.join(args.output_dir, 'best-model.pt')
        def convert_to_single_gpu(state_dict):
            def _convert(key):
                if key.startswith('module.'):
                    return key[7:]
                return key
            return {_convert(key): value for key, value in state_dict.items()}
        model = MyBart.from_pretrained("bart-large",
                                       state_dict=convert_to_single_gpu(torch.load(checkpoint)))
        logger.info("Loading checkpoint from {}".format(checkpoint))
        if torch.cuda.is_available():
            model.to(torch.device("cuda"))
        model.eval()
        ems = inference(model, dev_data, save_predictions=True)
        logger.info("%s on %s data: %.2f" % (dev_data.metric, dev_data.data_type, np.mean(ems)*100))

def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    model.train()
    global_step = 0
    train_losses = []
    best_accuracy = -1
    stop_training = False

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in train_data.dataloader:
            global_step += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch]
            loss = model(input_ids=batch[0], attention_mask=batch[1],
                         decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
                         is_training=True)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
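            # If the loss has become NaN, stop training instead of backpropagating bad gradients.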
            if torch.isnan(loss).data:
                logger.info("Stop training because loss=%s" % (loss.data))
                stop_training = True
                break
            train_losses.append(loss.detach().cpu())
            loss.backward()

            if global_step % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

            if global_step % args.eval_period == 0:
                model.eval()
                curr_em = inference(model if args.n_gpu == 1 else model.module, dev_data)
                logger.info("Step %d Train loss %.2f %s %.2f%% on epoch=%d" % (
                    global_step,
                    np.mean(train_losses),
                    dev_data.metric,
                    curr_em*100,
                    epoch))
                train_losses = []
                if best_accuracy < curr_em:
                    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
                    torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt"))
                    logger.info("Saving model with best %s: %.2f%% -> %.2f%% on epoch=%d, global_step=%d" % \
                                (dev_data.metric, best_accuracy*100.0, curr_em*100.0, epoch, global_step))
                    best_accuracy = curr_em
                    wait_step = 0
                    stop_training = False
                else:
                    wait_step += 1
                    if wait_step >= args.wait_step:
                        stop_training = True
                        break
                model.train()
        if stop_training:
            break

def inference(model, dev_data, save_predictions=False):
    predictions = []
    bos_token_id = dev_data.tokenizer.bos_token_id
    for i, batch in enumerate(dev_data.dataloader):
        if torch.cuda.is_available():
            batch = [b.to(torch.device("cuda")) for b in batch]
        outputs = model.generate(input_ids=batch[0],
                                 attention_mask=batch[1],
                                 num_beams=dev_data.args.num_beams,
                                 max_length=dev_data.args.max_output_length,
                                 early_stopping=True,)
        for input_, output in zip(batch[0], outputs):
            pred = dev_data.decode(output)
            predictions.append(pred)
    if save_predictions:
        dev_data.save_predictions(predictions)
    return np.mean(dev_data.evaluate(predictions))


--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os
import json
import re
import string
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler

class QAData(object):

    def __init__(self, logger, args, data_path, is_training):
        self.data_path = data_path
        if args.debug:
            self.data_path = data_path.replace("train", "dev")
        with open(self.data_path, "r") as f:
            self.data = json.load(f)
        if type(self.data) == dict:
            self.data = self.data["data"]
        if args.debug:
            self.data = self.data[:40]
        assert type(self.data) == list
        assert all(["id" in d for d in self.data]), self.data[0].keys()
        if type(self.data[0]["id"]) == int:
            for i in range(len(self.data)):
                self.data[i]["id"] = str(self.data[i]["id"])

        self.index2id = {i: d["id"] for i, d in enumerate(self.data)}
        self.id2index = {d["id"]: i for i, d in enumerate(self.data)}
        self.is_training = is_training
        self.load = not args.debug
        self.logger = logger
        self.args = args
        if "test" in self.data_path:
            self.data_type = "test"
        elif "dev" in self.data_path:
            self.data_type = "dev"
        elif "train" in self.data_path:
            self.data_type = "train"
        else:
            raise NotImplementedError()
        self.metric = "EM"
        self.max_input_length = self.args.max_input_length
        self.tokenizer = None
        self.dataset = None
        self.dataloader = None
        self.cache = None

    def __len__(self):
        return len(self.data)

    def decode(self, tokens):
        return self.tokenizer.decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True).lower()

    def decode_batch(self, tokens):
        return [self.decode(_tokens) for _tokens in tokens]

    def flatten(self, answers):
        new_answers, metadata = [], []
        for answer in answers:
            metadata.append((len(new_answers), len(new_answers)+len(answer)))
            new_answers += answer
        return new_answers, metadata

    def load_dataset(self, tokenizer, do_return=False):
        self.tokenizer = tokenizer
        postfix = tokenizer.__class__.__name__.replace("zer", "zed")
        preprocessed_path = os.path.join(
            "/".join(self.data_path.split("/")[:-1]),
            self.data_path.split("/")[-1].replace(".json", "-{}.json".format(postfix)))
        if self.load and os.path.exists(preprocessed_path):
            self.logger.info("Loading pre-tokenized data from {}".format(preprocessed_path))
            with open(preprocessed_path, "r") as f:
                input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, \
                    metadata = json.load(f)
        else:
            print("Start tokenizing...")
            questions = [d["question"] if d["question"].endswith("?") else d["question"]+"?"
                         for d in self.data]
            answers = [d["answer"] for d in self.data]
            answers, metadata = self.flatten(answers)
            if self.args.do_lowercase:
                questions = [question.lower() for question in questions]
                answers = [answer.lower() for answer in answers]
            if self.args.append_another_bos:
                questions = [" "+question for question in questions]
                answers = [" "+answer for answer in answers]
            question_input = tokenizer.batch_encode_plus(questions,
                                                         pad_to_max_length=True,
                                                         max_length=self.args.max_input_length)
            answer_input = tokenizer.batch_encode_plus(answers,
                                                       pad_to_max_length=True)
            input_ids, attention_mask = question_input["input_ids"], question_input["attention_mask"]
            decoder_input_ids, decoder_attention_mask = answer_input["input_ids"], answer_input["attention_mask"]
            if self.load:
                preprocessed_data = [input_ids, attention_mask,
                                     decoder_input_ids, decoder_attention_mask,
                                     metadata]
                with open(preprocessed_path, "w") as f:
                    json.dump([input_ids, attention_mask,
                               decoder_input_ids, decoder_attention_mask,
                               metadata], f)
        self.dataset = MyQADataset(input_ids, attention_mask,
                                   decoder_input_ids, decoder_attention_mask,
                                   in_metadata=None, out_metadata=metadata,
                                   is_training=self.is_training)
        self.logger.info("Loaded {} examples from {} data".format(len(self.dataset), self.data_type))

        if do_return:
            return self.dataset

    def load_dataloader(self, do_return=False):
        self.dataloader = MyDataLoader(self.args, self.dataset, self.is_training)
        if do_return:
            return self.dataloader

    def evaluate(self, predictions):
        assert len(predictions) == len(self), (len(predictions), len(self))
        ems = []
        for (prediction, dp) in zip(predictions, self.data):
            ems.append(get_exact_match(prediction, dp["answer"]))
        return ems

    def save_predictions(self, predictions):
        assert len(predictions) == len(self), (len(predictions), len(self))
        prediction_dict = {dp["id"]: prediction for dp, prediction in zip(self.data, predictions)}
        save_path = os.path.join(self.args.output_dir, "{}predictions.json".format(self.args.prefix))
        with open(save_path, "w") as f:
            json.dump(prediction_dict, f)
        self.logger.info("Saved prediction in {}".format(save_path))

def get_exact_match(prediction, groundtruth):
    if type(groundtruth) == list:
        if len(groundtruth) == 0:
            return 0
        return np.max([get_exact_match(prediction, gt) for gt in groundtruth])
    return (normalize_answer(prediction) == normalize_answer(groundtruth))

def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))


class MyQADataset(Dataset):
    def __init__(self,
                 input_ids, attention_mask,
                 decoder_input_ids, decoder_attention_mask,
                 in_metadata=None, out_metadata=None,
                 is_training=False):
        self.input_ids = torch.LongTensor(input_ids)
        self.attention_mask = torch.LongTensor(attention_mask)
        self.decoder_input_ids = torch.LongTensor(decoder_input_ids)
        self.decoder_attention_mask = torch.LongTensor(decoder_attention_mask)
        self.in_metadata = list(zip(range(len(input_ids)), range(1, 1+len(input_ids)))) \
            if in_metadata is None else in_metadata
        self.out_metadata = list(zip(range(len(decoder_input_ids)), range(1, 1+len(decoder_input_ids)))) \
            if out_metadata is None else out_metadata
        self.is_training = is_training

        assert len(self.input_ids) == len(self.attention_mask) == self.in_metadata[-1][-1]
        assert len(self.decoder_input_ids) == len(self.decoder_attention_mask) == self.out_metadata[-1][-1]

    def __len__(self):
        return len(self.in_metadata)

    def __getitem__(self, idx):
        if not self.is_training:
            idx = self.in_metadata[idx][0]
            return self.input_ids[idx], self.attention_mask[idx]

        in_idx = np.random.choice(range(*self.in_metadata[idx]))
        out_idx = np.random.choice(range(*self.out_metadata[idx]))
        return self.input_ids[in_idx], self.attention_mask[in_idx], \
            self.decoder_input_ids[out_idx], self.decoder_attention_mask[out_idx]

class MyDataLoader(DataLoader):

    def __init__(self, args, dataset, is_training):
        if is_training:
            sampler = RandomSampler(dataset)
            batch_size = args.train_batch_size
        else:
            sampler = SequentialSampler(dataset)
            batch_size = args.predict_batch_size
        super(MyDataLoader, self).__init__(dataset, sampler=sampler, batch_size=batch_size)


--------------------------------------------------------------------------------