├── pycocoevalcap ├── __init__.py ├── bleu │ ├── __init__.py │ ├── LICENSE │ ├── bleu.py │ └── bleu_scorer.py ├── tokenizer │ ├── __init__.py │ ├── stanford-corenlp-3.4.1.jar │ └── ptbtokenizer.py └── eval.py ├── .gitignore ├── codes ├── models │ ├── biencoder.py │ ├── span_predictor.py │ └── seq2seq.py ├── util.py ├── PassageData.py ├── DataLoader.py ├── cli.py ├── download_data.py ├── ambigqa_evaluate_script.py ├── run.py ├── README.md ├── QGData.py └── QAData.py ├── evidence.md ├── README.md └── ambigqa_evaluate_script.py /pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tmp* 2 | *.json 3 | __pycache__ 4 | */__pycache__ 5 | */*/__pycache__ 6 | */*.sh 7 | */*.err 8 | */*.out 9 | */out 10 | 11 | -------------------------------------------------------------------------------- /pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shmsw25/AmbigQA/HEAD/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 
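# Note: compute_score(gts, res) below expects two dicts keyed by the same example ids;
#       each res[id] must be a single-element hypothesis list and each gts[id] a non-empty
#       list of reference strings, as asserted inside compute_score().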
6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /codes/models/biencoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | 6 | from transformers import BertPreTrainedModel, BertModel 7 | 8 | class MyBiEncoder(BertPreTrainedModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self.ctx_model = BertModel(config) 12 | self.question_model = BertModel(config) 13 | self.init_weights() 14 | 15 | def forward_qbert(self, input_ids, attention_mask): 16 | return self.question_model(input_ids=input_ids, attention_mask=attention_mask)[0][:,0,:] 17 | 18 | def forward_pbert(self, input_ids, attention_mask): 19 | return self.ctx_model(input_ids=input_ids, attention_mask=attention_mask)[0][:,0,:] 20 | 21 | def forward(self, 22 | q_input_ids, q_attention_mask, 23 | p_input_ids, p_attention_mask, 24 | is_training=False): 25 | ''' 26 | :q_input_ids, q_attention_mask, q_token_type_ids: [N, L] 27 | :p_input_ids, p_attention_mask, p_token_type_ids: [N, M, L] 28 | ''' 29 | N, M, L = p_input_ids.size() 30 | question_output = self.forward_qbert(q_input_ids, q_attention_mask) 31 | passage_output = self.forward_pbert(input_ids=p_input_ids.view(-1, L), 32 | attention_mask=p_attention_mask.view(-1, L)) 33 | if is_training: 34 | inner_prods = torch.matmul(question_output, passage_output.transpose(0, 1)) 35 | loss_fct = CrossEntropyLoss() 36 | labels = M * torch.arange(N, dtype=torch.long).cuda() 37 | total_loss = loss_fct(inner_prods, labels) # [N, N*M] 38 | return total_loss 39 | else: 40 | return question_output, passage_output 41 | 42 | 43 | -------------------------------------------------------------------------------- /pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | from tokenizer.ptbtokenizer import PTBTokenizer 3 | from bleu.bleu import Bleu 4 | from meteor.meteor import Meteor 5 | from rouge.rouge import Rouge 6 | from cider.cider import Cider 7 | from spice.spice import Spice 8 | 9 | class COCOEvalCap: 10 | def __init__(self, coco, cocoRes): 11 | self.evalImgs = [] 12 | self.eval = {} 13 | self.imgToEval = {} 14 | self.coco = coco 15 | self.cocoRes = cocoRes 16 | self.params = {'image_id': coco.getImgIds()} 17 | 18 | def evaluate(self): 19 | imgIds = 
self.params['image_id'] 20 | # imgIds = self.coco.getImgIds() 21 | gts = {} 22 | res = {} 23 | for imgId in imgIds: 24 | gts[imgId] = self.coco.imgToAnns[imgId] 25 | res[imgId] = self.cocoRes.imgToAnns[imgId] 26 | 27 | # ================================================= 28 | # Tokenization 29 | # ================================================= 30 | print('tokenization...') 31 | tokenizer = PTBTokenizer() 32 | gts = tokenizer.tokenize(gts) 33 | res = tokenizer.tokenize(res) 34 | 35 | # ================================================= 36 | # Set up scorers 37 | # ================================================= 38 | print('setting up scorers...') 39 | scorers = [ 40 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 41 | (Meteor(),"METEOR"), 42 | (Rouge(), "ROUGE_L"), 43 | (Cider(), "CIDEr"), 44 | (Spice(), "SPICE") 45 | ] 46 | 47 | # ================================================= 48 | # Compute scores 49 | # ================================================= 50 | for scorer, method in scorers: 51 | print('computing %s score...'%(scorer.method())) 52 | score, scores = scorer.compute_score(gts, res) 53 | if type(method) == list: 54 | for sc, scs, m in zip(score, scores, method): 55 | self.setEval(sc, m) 56 | self.setImgToEvalImgs(scs, gts.keys(), m) 57 | print("%s: %0.3f"%(m, sc)) 58 | else: 59 | self.setEval(score, method) 60 | self.setImgToEvalImgs(scores, gts.keys(), method) 61 | print("%s: %0.3f"%(method, score)) 62 | self.setEvalImgs() 63 | 64 | def setEval(self, score, method): 65 | self.eval[method] = score 66 | 67 | def setImgToEvalImgs(self, scores, imgIds, method): 68 | for imgId, score in zip(imgIds, scores): 69 | if not imgId in self.imgToEval: 70 | self.imgToEval[imgId] = {} 71 | self.imgToEval[imgId]["image_id"] = imgId 72 | self.imgToEval[imgId][method] = score 73 | 74 | def setEvalImgs(self): 75 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations.
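# Note: tokenize() shells out to the bundled stanford-corenlp-3.4.1.jar via `java -cp`,
#       so a Java runtime must be available on PATH for this wrapper to run.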
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(mode="w", 44 | delete=False, 45 | dir=path_to_jar_dirname) 46 | tmp_file.write(sentences) 47 | tmp_file.close() 48 | 49 | # ====================================================== 50 | # tokenize sentence 51 | # ====================================================== 52 | cmd.append(os.path.basename(tmp_file.name)) 53 | #cmd[0] = os.path.join(path_to_jar_dirname, os.path.basename(tmp_file.name)) + "/" + cmd[0] 54 | cmd = [" ".join(cmd)] 55 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 56 | stdout=subprocess.PIPE, shell=True) 57 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 58 | lines = token_lines.decode().split('\n') 59 | # remove temp file 60 | os.remove(tmp_file.name) 61 | 62 | # ====================================================== 63 | # create dictionary for tokenized captions 64 | # ====================================================== 65 | for k, line in zip(image_id, lines): 66 | if not k in final_tokenized_captions_for_image: 67 | final_tokenized_captions_for_image[k] = [] 68 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 69 | if w not in PUNCTUATIONS]) 70 | final_tokenized_captions_for_image[k].append(tokenized_caption) 71 | 72 | return final_tokenized_captions_for_image 73 | -------------------------------------------------------------------------------- /codes/models/span_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | 6 | from transformers import BertForQuestionAnswering, AlbertForQuestionAnswering 7 | 8 | class SpanPredictor(BertForQuestionAnswering): 9 | def __init__(self, config): 10 | config.num_labels = 2 11 | super().__init__(config) 12 | self.qa_classifier = nn.Linear(config.hidden_size, 1) 13 | 14 | def forward(self, 15 | input_ids=None, attention_mask=None, 16 | token_type_ids=None, inputs_embeds=None, 17 | start_positions=None, end_positions=None, answer_mask=None, 18 | is_training=False): 
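        # Shape convention (inferred from the reshaping below): input_ids, attention_mask and
        # token_type_ids are [N, M, L], where N is the number of questions in the batch, M is the
        # number of passages per question, and L is the sequence length. Training returns a scalar
        # loss; inference returns start/end logits of shape [N, M, L] and passage-selection logits [N, M].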
19 | 20 | N, M, L = input_ids.size() 21 | output = self.bert(input_ids.view(N*M, L), 22 | attention_mask=attention_mask.view(N*M, L), 23 | token_type_ids=token_type_ids.view(N*M, L), 24 | inputs_embeds=None if inputs_embeds is None else inputs_embeds.view(N*M, L, -1))[0] 25 | logits = self.qa_outputs(output) 26 | start_logits, end_logits = logits.split(1, dim=-1) 27 | start_logits = start_logits.squeeze(-1) 28 | end_logits = end_logits.squeeze(-1) 29 | sel_logits = self.qa_classifier(output[:,0,:]) 30 | 31 | if is_training: 32 | start_positions, end_positions, answer_mask = \ 33 | start_positions.view(N*M, -1), end_positions.view(N*M, -1), answer_mask.view(N*M, -1) 34 | return get_loss(start_positions, end_positions, answer_mask, 35 | start_logits, end_logits, sel_logits, N, M) 36 | else: 37 | return start_logits.view(N, M, L), end_logits.view(N, M, L), sel_logits.view(N, M) 38 | 39 | 40 | class AlbertSpanPredictor(AlbertForQuestionAnswering): 41 | def __init__(self, config): 42 | config.num_labels = 2 43 | super().__init__(config) 44 | self.qa_classifier = nn.Linear(config.hidden_size, 1) 45 | 46 | def forward(self, 47 | input_ids=None, attention_mask=None, 48 | token_type_ids=None, inputs_embeds=None, 49 | start_positions=None, end_positions=None, answer_mask=None, 50 | is_training=False): 51 | 52 | N, M, L = input_ids.size() 53 | output = self.albert(input_ids.view(N*M, L), 54 | attention_mask=attention_mask.view(N*M, L), 55 | token_type_ids=token_type_ids.view(N*M, L), 56 | inputs_embeds=None if inputs_embeds is None else inputs_embeds.view(N*M, L, -1))[0] 57 | logits = self.qa_outputs(output) 58 | start_logits, end_logits = logits.split(1, dim=-1) 59 | start_logits = start_logits.squeeze(-1) 60 | end_logits = end_logits.squeeze(-1) 61 | sel_logits = self.qa_classifier(output[:,0,:]) 62 | 63 | if is_training: 64 | start_positions, end_positions, answer_mask = \ 65 | start_positions.view(N*M, -1), end_positions.view(N*M, -1), answer_mask.view(N*M, -1) 66 | return get_loss(start_positions, end_positions, answer_mask, 67 | start_logits, end_logits, sel_logits, N, M) 68 | else: 69 | return start_logits.view(N, M, L), end_logits.view(N, M, L), sel_logits.view(N, M) 70 | 71 | def get_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, sel_logits, N, M): 72 | answer_mask = answer_mask.type(torch.FloatTensor).cuda() 73 | ignored_index = start_logits.size(1) 74 | start_positions.clamp_(0, ignored_index) 75 | end_positions.clamp_(0, ignored_index) 76 | loss_fct = CrossEntropyLoss(reduce=False, ignore_index=ignored_index) 77 | 78 | sel_logits = sel_logits.view(N, M) 79 | sel_labels = torch.zeros(N, dtype=torch.long).cuda() 80 | sel_loss = torch.sum(loss_fct(sel_logits, sel_labels)) 81 | start_losses = [(loss_fct(start_logits, _start_positions) * _span_mask) \ 82 | for (_start_positions, _span_mask) \ 83 | in zip(torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1))] 84 | end_losses = [(loss_fct(end_logits, _end_positions) * _span_mask) \ 85 | for (_end_positions, _span_mask) \ 86 | in zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1))] 87 | loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + \ 88 | torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) 89 | loss_tensor=loss_tensor.view(N, M, -1).max(dim=1)[0] 90 | span_loss = _take_mml(loss_tensor) 91 | return span_loss + sel_loss 92 | 93 | def _take_mml(loss_tensor): 94 | marginal_likelihood = torch.sum(torch.exp( 95 | - loss_tensor - 1e10 * 
(loss_tensor==0).float()), 1) 96 | return -torch.sum(torch.log(marginal_likelihood + \ 97 | torch.ones(loss_tensor.size(0)).cuda()*(marginal_likelihood==0).float())) 98 | 99 | -------------------------------------------------------------------------------- /codes/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor, nn 4 | from transformers import T5ForConditionalGeneration, BartForConditionalGeneration 5 | 6 | class MyBart(BartForConditionalGeneration): 7 | def forward(self, input_ids, attention_mask=None, encoder_outputs=None, 8 | decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None, 9 | use_cache=False, is_training=False): 10 | 11 | if is_training: 12 | decoder_start_token_id = self.config.decoder_start_token_id 13 | _decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape) 14 | _decoder_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone() 15 | _decoder_input_ids[..., 0] = decoder_start_token_id 16 | #_decoder_input_ids = decoder_input_ids.clone() 17 | #_decoder_input_ids[..., 0] = decoder_start_token_id 18 | #new_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape) 19 | #new_decoder_input_ids[..., :-1] = decoder_input_ids[..., 1:].clone() 20 | #decoder_input_ids = new_decoder_input_ids 21 | #print (input_ids[0,:10].detach().cpu().tolist()) 22 | #print (_decoder_input_ids[0,:10].detach().cpu().tolist()) 23 | #print (decoder_input_ids[0, :10].detach().cpu().tolist()) 24 | else: 25 | _decoder_input_ids = decoder_input_ids.clone() 26 | #print (input_ids[0,:10].detach().cpu().tolist()) 27 | #print (_decoder_input_ids[0].detach().cpu().tolist()) 28 | 29 | outputs = self.model( 30 | input_ids, 31 | attention_mask=attention_mask, 32 | encoder_outputs=encoder_outputs, 33 | decoder_input_ids=_decoder_input_ids, 34 | decoder_attention_mask=decoder_attention_mask, 35 | decoder_cached_states=decoder_cached_states, 36 | use_cache=use_cache, 37 | ) 38 | lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias) 39 | if is_training: 40 | loss_fct = nn.CrossEntropyLoss(reduce=False) 41 | losses = loss_fct(lm_logits.view(-1, self.config.vocab_size), 42 | decoder_input_ids.view(-1)) 43 | loss = torch.sum(losses * decoder_attention_mask.float().view(-1)) 44 | return loss 45 | return (lm_logits, ) + outputs[1:] 46 | 47 | class MyT5(T5ForConditionalGeneration): 48 | def forward(self, input_ids=None, attention_mask=None, encoder_outputs=None, 49 | decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None, 50 | decoder_past_key_value_states=None, 51 | use_cache=False, is_training=False): 52 | 53 | if encoder_outputs is None: 54 | encoder_outputs = self.encoder( 55 | input_ids=input_ids, attention_mask=attention_mask, 56 | inputs_embeds=None, head_mask=None 57 | ) 58 | hidden_states = encoder_outputs[0] 59 | 60 | _decoder_input_ids = decoder_input_ids 61 | _decoder_attention_mask = decoder_attention_mask 62 | if is_training: 63 | decoder_start_token_id = self.config.decoder_start_token_id 64 | _decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape) 65 | _decoder_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone() 66 | _decoder_input_ids = _decoder_input_ids + self.config.eos_token_id * (_decoder_input_ids==0).long() 67 | _decoder_input_ids[..., 0] = decoder_start_token_id 68 | _decoder_attention_mask = 
decoder_attention_mask.new_zeros(decoder_attention_mask.shape) 69 | _decoder_attention_mask[..., 1:] = decoder_attention_mask[..., :-1].clone() 70 | _decoder_attention_mask[..., 0] = 1 71 | else: 72 | print (_decoder_input_ids) 73 | print (_decoder_attention_mask) 74 | decoder_outputs = self.decoder( 75 | input_ids=_decoder_input_ids, 76 | attention_mask=_decoder_attention_mask, 77 | inputs_embeds=None, 78 | past_key_value_states=None, 79 | encoder_hidden_states=hidden_states, 80 | encoder_attention_mask=attention_mask, 81 | head_mask=None, 82 | use_cache=use_cache, 83 | ) 84 | 85 | sequence_output = decoder_outputs[0] 86 | sequence_output = sequence_output * (self.model_dim ** -0.5) 87 | lm_logits = self.lm_head(sequence_output) 88 | 89 | decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here 90 | if is_training: 91 | loss_fct = nn.CrossEntropyLoss(reduce=False) 92 | losses = loss_fct(lm_logits.view(-1, self.config.vocab_size), 93 | decoder_input_ids.view(-1)) 94 | loss = torch.sum(losses * decoder_attention_mask.float().view(-1)) 95 | return loss 96 | 97 | return decoder_outputs + encoder_outputs 98 | 99 | 100 | -------------------------------------------------------------------------------- /evidence.md: -------------------------------------------------------------------------------- 1 | # AmbigQA/AmbigNQ README 2 | 3 | We released semi-oracle evidence passages for researchers interested in multi-answer extraction and disambiguation rather than retrieval. This document describes how they are obtained, statistics and the upperbound when using these evidence passages. 4 | 5 | 6 | ## Content 7 | 1. [General information about the data](#general-information-about-the-data) 8 | * [Data format](#data-format) 9 | * [When to use this data](#when-to-use-this-data) 10 | * [Statistics and performance upperbound](#statistics-and-performance-upperbound) 11 | 2. [Data Creation](#data-creation) 12 | 13 | 14 | ## General information about the data 15 | 16 | 17 | The number of Wikipedia articles per question is 3.0 on average. 18 | 19 | ### Data Format 20 | 21 | The json file is a list, which i-th item is a dictionary containing `id`, `question`, `annotations` (as in the original AmbigQA data) as well as `articles_plain_text` and `articles_html_text`. `articles_plain_text` is a list of articles in plain text (Markdown), such as: 22 | ```python 23 | [ 24 | "# Dexter (season 1)\n\nThe first season of Dexter is an adaptation of Jeff Lindsay's first novel in a series of the same name, Darkly Dreaming Dexter. ...", 25 | "# Chrisstian Camargo\n\nChristian Camargo is an American actor, producer, writer and director. ... ## Early years\n\nCamargo was born ...", 26 | "# List of Dexter characters\n\nThis is a list of characters ... * Michael C. Hall\n* Maxwell Huckabee (age 3) * Nicholas Vigneau (young Dexter, season 7) ..." 27 | ] 28 | ``` 29 | `article_html_text` is a list of articles in an html format, such as: 30 | ```python 31 | [ 32 | "
<h1> Dexter (season 1) </h1>\n\nThe first season of Dexter is an adaptation of Jeff Lindsay's first novel in a series of the same name, Darkly Dreaming Dexter. ...", 33 | "<h1> Chrisstian Camargo </h1>\n\nChristian Camargo is an American actor, producer, writer and director. ... <h2> Early years </h2>\n\nCamargo was born ...", 34 | "<h1> List of Dexter characters </h1>\n\nThis is a list of characters ... <li> Michael C. Hall </li> <li> Maxwell Huckabee (age 3) </li> <li> Nicholas Vigneau (young Dexter, season 7) </li>
<li> ... </li>" 35 | ] 36 | ``` 37 | 38 | ### When to use this data 39 | 40 | We recommend using this data if you want to focus on multi-answer extraction and disambiguation given evidence text. 41 | The end-to-end QA model is supposed to retrieve evidence text on its own, but evidence retrieval itself is a very difficult problem and current retrieval models are not good at retrieving high-coverage evidence text (reference: [this paper](https://arxiv.org/abs/2104.08445)). While we encourage making progress on the retrieval part, we are releasing this semi-oracle evidence data so that progress on the subsequent parts is not blocked by progress in retrieval. 42 | 43 | 44 | While the size of the evidence text can be a variable in the end-to-end QA model, we set the size of the semi-oracle evidence to be approximately 10,000 words, following much of the recent work in QA that uses 100 passages * 100 words per passage. 45 | 46 | 47 | ### Statistics and performance upperbound 48 | 49 | #### Distributions of the number of Wikipedia articles per question (%) 50 | | | 1 | 2 | 3 | 4+ | 51 | |---|---|---|---|---| 52 | | Train | 0.1 | 0.1 | 99.4 | 0.3 | 53 | | Dev | 0.0 | 0.0 | 99.5 | 0.4 | 54 | | Test | 0.0 | 0.0 | 99.5 | 0.5 | 55 | 56 | #### Distributions of the number of tokens per question (%) 57 | (based on the plain text, white-space tokenization) 58 | | | 0--5000 | 5000--10000 | 10000--15000 | 15000--20000 | 20000-- | 59 | |---|---|---|---|---|---| 60 | | Train | 30.9 | 33.2 | 19.1 | 9.2 | 7.7 | 61 | | Dev | 29.8 | 33.9 | 19.4 | 8.8 | 8.0 | 62 | | Test | 29.0 | 34.8 | 18.0 | 9.8 | 8.3 | 63 | 64 | #### Answer coverage and performance upperbound 65 | (Performance upperbound is the same for both answer F1 and QG F1) 66 | 67 | | | Macro-Avg coverage | Perf upperbound (all) | Perf upperbound (multi-only) | 68 | |---|---|---|---| 69 | | Train | 78.2 | 80.1 | 77.1 | 70 | | Dev | 84.4 | 86.6 | 82.2 | 71 | | Test | 83.0 | 85.6 | 81.3 | 72 | 73 | 74 | #### Distributions of the number of covered answers (%) 75 | 76 | | | 0 | 1 | 2 | 3 | 4+ | 77 | |---|---|---|---|---|---| 78 | | Train | 10.1 | 62.8 | 33.6 | 23.8 | 10.1 | 79 | | Dev | 15.7 | 58.5 | 42.4 | 30.4 | 15.7 | 80 | | Test | 18.8 | 56.1 | 45.6 | 36.0 | 18.8 | 81 | 82 | 83 | ## Data Creation 84 | 85 | We use the Wikipedia dump of 02/01/2020, which is the same one as used in the [AmbigQA paper](https://arxiv.org/abs/2004.10645). We preprocess the dump so that each article includes headers, plain text and lists (tables and infoboxes are excluded). We exclude disambiguation pages, following prior work (DrQA, DPR and more). 86 | 87 | We look through the annotators' interaction logs and identify positive and negative articles as follows. 88 | * Positive articles: we examine the articles that annotators clicked (or, if they clicked a disambiguation page, the articles linked from it), and include those that contain any valid answer as positive articles. 89 | * Negative articles: we consider all articles that annotators have seen (including title-only views). This includes articles returned by the search engine and all articles linked from a disambiguation page. Among those, articles that do not contain any valid answer are considered negative articles. 90 | 91 | Once we obtain positive and negative articles, we create a set of articles by (1) first including all positive articles, and (2) if the number of positive articles is less than 3, sampling negative articles as follows. 92 | 1. Create a BM25 index using all positive and negative articles.
93 | 2. Compute BM25 scores of each article using the question as a query. 94 | 3. Compute a weight probability using a softmax of BM25 scores. 95 | 4. Sample articles based on the weight probability, until the number of unique articles is 3. 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /codes/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | from joblib import Parallel, delayed 6 | 7 | def decode_span_batch(features, scores, tokenizer, max_answer_length, 8 | n_paragraphs=None, topk_answer=1, verbose=False, n_jobs=1, 9 | save_psg_sel_only=False): 10 | assert len(features)==len(scores) 11 | iter=zip(features, scores) 12 | if n_jobs>1: 13 | def f(t): 14 | return decode_span(t[0], tokenizer, t[1][0], t[1][1], t[1][2], max_answer_length, 15 | n_paragraphs=n_paragraphs, topk_answer=topk_answer, 16 | save_psg_sel_only=save_psg_sel_only) 17 | return Parallel(n_jobs=n_jobs)(delayed(f)(t) for t in iter) 18 | if verbose: 19 | iter = tqdm(iter) 20 | predictions = [decode_span(feature, tokenizer, start_logits, end_logits, sel_logits, 21 | max_answer_length, n_paragraphs, topk_answer, save_psg_sel_only) \ 22 | for (feature, (start_logits, end_logits, sel_logits)) in iter] 23 | return predictions 24 | 25 | def decode_span(feature, tokenizer, start_logits_list, end_logits_list, sel_logits_list, 26 | max_answer_length, n_paragraphs=None, topk_answer=1, save_psg_sel_only=False): 27 | all_positive_token_ids, all_positive_input_mask = feature 28 | assert len(start_logits_list)==len(end_logits_list)==len(sel_logits_list) 29 | assert type(sel_logits_list[0])==float 30 | log_softmax_switch_logits_list = _compute_log_softmax(sel_logits_list[:len(all_positive_token_ids)]) 31 | 32 | if save_psg_sel_only: 33 | return np.argsort(-np.array(log_softmax_switch_logits_list)).tolist() 34 | 35 | sorted_logits = sorted(enumerate(zip(start_logits_list, end_logits_list, sel_logits_list)), 36 | key=lambda x: -x[1][2]) 37 | nbest = [] 38 | for passage_index, (start_logits, end_logits, switch_logits) in sorted_logits: 39 | scores = [] 40 | if len(all_positive_token_ids)<=passage_index: 41 | continue 42 | 43 | positive_token_ids = all_positive_token_ids[passage_index] 44 | positive_input_mask = all_positive_input_mask[passage_index] 45 | start_offset = 1 + positive_token_ids.index(tokenizer.sep_token_id) 46 | end_offset = positive_input_mask.index(0) if 0 in positive_input_mask else len(positive_input_mask) 47 | 48 | positive_token_ids = positive_token_ids[start_offset:end_offset] 49 | start_logits = start_logits[start_offset:end_offset] 50 | end_logits = end_logits[start_offset:end_offset] 51 | log_softmax_start_logits = _compute_log_softmax(start_logits) 52 | log_softmax_end_logits = _compute_log_softmax(end_logits) 53 | 54 | for (i, s) in enumerate(start_logits): 55 | for (j, e) in enumerate(end_logits[i:i+max_answer_length]): 56 | scores.append(((i, i+j), s+e)) 57 | 58 | scores = sorted(scores, key=lambda x: x[1], reverse=True) 59 | chosen_span_intervals = [] 60 | 61 | for (start_index, end_index), score in scores: 62 | if end_index < start_index: 63 | continue 64 | length = end_index - start_index + 1 65 | if length > max_answer_length: 66 | continue 67 | if any([start_index<=prev_start_index<=prev_end_index<=end_index or 68 | prev_start_index<=start_index<=end_index<=prev_end_index 69 | for 
(prev_start_index, prev_end_index) in chosen_span_intervals]): 70 | continue 71 | 72 | answer_text = tokenizer.decode(positive_token_ids[start_index:end_index+1], 73 | skip_special_tokens=True, 74 | clean_up_tokenization_spaces=True).strip() 75 | passage_text = tokenizer.decode(positive_token_ids[:start_index], 76 | skip_special_tokens=True, 77 | clean_up_tokenization_spaces=True).strip() + \ 78 | " " + answer_text + " " + \ 79 | tokenizer.decode(positive_token_ids[end_index+1:], 80 | skip_special_tokens=True, 81 | clean_up_tokenization_spaces=True).strip() 82 | 83 | nbest.append({ 84 | 'text': answer_text, 85 | 'passage_index': passage_index, 86 | 'passage': passage_text, 87 | 'log_softmax': log_softmax_switch_logits_list[passage_index] + \ 88 | log_softmax_start_logits[start_index] + \ 89 | log_softmax_end_logits[end_index]}) 90 | 91 | chosen_span_intervals.append((start_index, end_index)) 92 | if topk_answer>-1 and topk_answer==len(chosen_span_intervals): 93 | break 94 | 95 | if len(nbest)==0: 96 | nbest = [{'text': 'empty', 'log_softmax': -99999, 'passage_index': 0, 'passage': ''}] 97 | 98 | sorted_nbest = sorted(nbest, key=lambda x: -x["log_softmax"]) 99 | 100 | if n_paragraphs is None: 101 | return sorted_nbest[:topk_answer] if topk_answer>-1 else sorted_nbest 102 | else: 103 | return [[pred for pred in sorted_nbest if pred['passage_index'] max_score: 114 | max_score = score 115 | exp_scores = [] 116 | total_sum = 0.0 117 | for score in scores: 118 | x = math.exp(score - max_score) 119 | exp_scores.append(x) 120 | total_sum += x 121 | probs = [] 122 | for score in exp_scores: 123 | probs.append(score / total_sum) 124 | return np.log(probs).tolist() 125 | -------------------------------------------------------------------------------- /codes/PassageData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import gzip 4 | import numpy as np 5 | from collections import defaultdict 6 | 7 | import torch 8 | from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler 9 | 10 | from ambigqa_evaluate_script import normalize_answer 11 | from DataLoader import MyDataLoader 12 | 13 | class PassageData(object): 14 | def __init__(self, logger, args, tokenizer): 15 | self.logger = logger 16 | self.args = args 17 | self.data_path = os.path.join(args.dpr_data_dir, 18 | "data/wikipedia_split/psgs_w100{}.tsv.gz".format("_20200201" if args.wiki_2020 else "")) 19 | 20 | self.passages = None 21 | self.titles = None 22 | self.tokenizer = tokenizer 23 | self.tokenized_data = None 24 | 25 | def load_db(self, subset=None): 26 | if not self.args.skip_db_load: 27 | self.passages = {} 28 | self.titles = {} 29 | with gzip.open(self.data_path, "rb") as f: 30 | _ = f.readline() 31 | offset = 0 32 | for line in f: 33 | if subset is None or offset in subset: 34 | _id, passage, title = line.decode().strip().split("\t") 35 | assert int(_id)-1==offset 36 | self.passages[offset] = passage.lower() 37 | self.titles[offset] = title.lower() 38 | offset += 1 39 | assert subset is None or len(subset)==len(self.titles)==len(self.passages) 40 | self.logger.info("Loaded {} passages".format(len(self.passages))) 41 | 42 | def load_tokenized_data(self, model_name, all=False, do_return=False, subset=None, index=None): 43 | def _get_cache_path(index): 44 | if model_name=="bert": 45 | cache_path = self.data_path.replace(".tsv.gz", "_{}_BertTokenized.pkl".format(index)) 46 | elif model_name=="albert": 47 | cache_path = 
self.data_path.replace(".tsv.gz", "_{}_AlbertTokenized.pkl".format(index)) 48 | elif model_name=="bart": 49 | cache_path = self.data_path.replace(".tsv.gz", "_{}_BartTokenized.pkl".format(index)) 50 | else: 51 | raise NotImplementedError(model_name) 52 | return cache_path 53 | 54 | if subset is not None and not os.path.exists(_get_cache_path(0)): 55 | assert not self.args.skip_db_load 56 | if self.titles is None or self.passages is None: 57 | self.load_db(subset) 58 | final_tokenized_data = {"input_ids": {}, "attention_mask": {}} 59 | psg_ids = list(subset) 60 | input_data = [self.titles[_id] + " " + self.tokenizer.sep_token + " " + self.passages[_id] 61 | for _id in psg_ids] 62 | tokenized_data = self.tokenizer.batch_encode_plus(input_data, 63 | max_length=128, 64 | pad_to_max_length=model_name in ["albert", "bert"]) 65 | input_ids = {_id: _input_ids 66 | for _id, _input_ids in zip(psg_ids, tokenized_data["input_ids"])} 67 | attention_mask = {_id: _attention_mask 68 | for _id, _attention_mask in zip(psg_ids, tokenized_data["attention_mask"])} 69 | final_tokenized_data = {"input_ids": input_ids, "attention_mask": attention_mask} 70 | elif all: 71 | for index in range(10): 72 | curr_tokenized_data = self.load_tokenized_data(model_name, all=False, do_return=True, subset=subset, index=index) 73 | if index==0: 74 | tokenized_data = curr_tokenized_data 75 | elif subset is None: 76 | tokenized_data["input_ids"] += curr_tokenized_data["input_ids"] 77 | tokenized_data["attention_mask"] += curr_tokenized_data["attention_mask"] 78 | else: 79 | tokenized_data["input_ids"].update(curr_tokenized_data["input_ids"]) 80 | tokenized_data["attention_mask"].update(curr_tokenized_data["attention_mask"]) 81 | final_tokenized_data = tokenized_data 82 | else: 83 | index=self.args.db_index if index is None else index 84 | assert 0<=index<10 85 | cache_path = _get_cache_path(index) 86 | if os.path.exists(cache_path): 87 | with open(cache_path, "rb") as f: 88 | tokenized_data = pkl.load(f) 89 | else: 90 | assert not self.args.skip_db_load 91 | if self.titles is None or self.passages is None: 92 | self.load_db() 93 | # tokenize 2.2M for each thread 94 | min_idx = index*2200000 95 | max_idx = min(len(self.titles), (index+1)*2200000) 96 | self.logger.info("Start tokenizing from {} to {}".format(min_idx, max_idx)) 97 | input_data = [self.titles[_id] + " " + self.tokenizer.sep_token + " " + self.passages[_id] 98 | for _id in range(min_idx, max_idx)] 99 | tokenized_data = self.tokenizer.batch_encode_plus(input_data, 100 | max_length=128, 101 | pad_to_max_length=model_name in ["albert", "bert"]) 102 | with open(cache_path, "wb") as f: 103 | pkl.dump({"input_ids": tokenized_data["input_ids"], 104 | "attention_mask": tokenized_data["attention_mask"]}, f) 105 | 106 | if subset is None: 107 | final_tokenized_data = tokenized_data 108 | else: 109 | # only keep 2200000*i 110 | start, end = 2200000*index, 2200000*(index+1) 111 | final_tokenized_data = {"input_ids": {}, "attention_mask": {}} 112 | for passage_idx in subset: 113 | if start<=passage_idx0 for positive_input_ids in self.positive_input_ids]) 92 | 93 | self.is_training = is_training 94 | self.train_M = train_M 95 | self.test_M = test_M 96 | 97 | def __len__(self): 98 | return len(self.positive_input_ids) 99 | 100 | def __getitem__(self, idx): 101 | if not self.is_training: 102 | input_ids = self.positive_input_ids[idx][:self.test_M] 103 | input_mask = self.positive_input_mask[idx][:self.test_M] 104 | token_type_ids = self.positive_token_type_ids[idx][:self.test_M] 105 
| return [self._pad(t, self.test_M) for t in [input_ids, input_mask, token_type_ids]] 106 | 107 | # sample positive 108 | positive_idx = np.random.choice(len(self.positive_input_ids[idx])) 109 | #positive_idx = 0 110 | positive_input_ids = self.positive_input_ids[idx][positive_idx] 111 | positive_input_mask = self.positive_input_mask[idx][positive_idx] 112 | positive_token_type_ids = self.positive_token_type_ids[idx][positive_idx] 113 | positive_start_positions = self.positive_start_positions[idx][positive_idx] 114 | positive_end_positions = self.positive_end_positions[idx][positive_idx] 115 | positive_answer_mask = self.positive_answer_mask[idx][positive_idx] 116 | 117 | # sample negatives 118 | negative_idxs = np.random.permutation(range(len(self.negative_input_ids[idx])))[:self.train_M-1] 119 | negative_input_ids = [self.negative_input_ids[idx][i] for i in negative_idxs] 120 | negative_input_mask = [self.negative_input_mask[idx][i] for i in negative_idxs] 121 | negative_token_type_ids = [self.negative_token_type_ids[idx][i] for i in negative_idxs] 122 | negative_input_ids, negative_input_mask, negative_token_type_ids = \ 123 | [self._pad(t, self.train_M-1) for t in [negative_input_ids, negative_input_mask, negative_token_type_ids]] 124 | 125 | # aggregate 126 | input_ids = torch.cat([positive_input_ids.unsqueeze(0), negative_input_ids], dim=0) 127 | input_mask = torch.cat([positive_input_mask.unsqueeze(0), negative_input_mask], dim=0) 128 | token_type_ids = torch.cat([positive_token_type_ids.unsqueeze(0), negative_token_type_ids], dim=0) 129 | start_positions, end_positions, answer_mask = \ 130 | [self._pad([t], self.train_M) for t in [positive_start_positions, 131 | positive_end_positions, 132 | positive_answer_mask]] 133 | return input_ids, input_mask, token_type_ids, start_positions, end_positions, answer_mask 134 | 135 | def tensorize(self, key): 136 | return [torch.LongTensor(t) for t in self.data[key]] if key in self.data.keys() else None 137 | 138 | def _pad(self, input_ids, M): 139 | if len(input_ids)==0: 140 | return torch.zeros((M, self.negative_input_ids[0].size(1)), dtype=torch.long) 141 | if type(input_ids)==list: 142 | input_ids = torch.stack(input_ids) 143 | if len(input_ids)==M: 144 | return input_ids 145 | return torch.cat([input_ids, 146 | torch.zeros((M-input_ids.size(0), input_ids.size(1)), dtype=torch.long)], 147 | dim=0) 148 | 149 | class MyDataLoader(DataLoader): 150 | 151 | def __init__(self, args, dataset, is_training, batch_size=None): 152 | if is_training: 153 | sampler=RandomSampler(dataset) 154 | batch_size = args.train_batch_size if batch_size is None else batch_size 155 | else: 156 | sampler=SequentialSampler(dataset) 157 | batch_size = args.predict_batch_size if batch_size is None else batch_size 158 | 159 | super(MyDataLoader, self).__init__(dataset, sampler=sampler, batch_size=batch_size) 160 | 161 | 162 | -------------------------------------------------------------------------------- /pycocoevalcap/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def xrange(*args): 24 | return range(*args) 25 | 26 | def precook(s, n=4, out=False): 27 | """Takes a string as input and returns an object that can be given to 28 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 29 | can take string arguments as well.""" 30 | words = s.split() 31 | counts = defaultdict(int) 32 | for k in xrange(1,n+1): 33 | for i in xrange(len(words)-k+1): 34 | ngram = tuple(words[i:i+k]) 35 | counts[ngram] += 1 36 | return (len(words), counts) 37 | 38 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 39 | '''Takes a list of reference sentences for a single segment 40 | and returns an object that encapsulates everything that BLEU 41 | needs to know about them.''' 42 | 43 | reflen = [] 44 | maxcounts = {} 45 | for ref in refs: 46 | rl, counts = precook(ref, n) 47 | reflen.append(rl) 48 | for (ngram,count) in counts.items(): 49 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 50 | 51 | # Calculate effective reference sentence length. 52 | if eff == "shortest": 53 | reflen = min(reflen) 54 | elif eff == "average": 55 | reflen = float(sum(reflen))/len(reflen) 56 | 57 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 58 | 59 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 60 | 61 | return (reflen, maxcounts) 62 | 63 | def cook_test(test, ref, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | reflen, refmaxcounts = ref 68 | testlen, counts = precook(test, n, True) 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)] 82 | 83 | result['correct'] = [0]*n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 86 | 87 | return result 88 | 89 | class BleuScorer(object): 90 | """Bleu scorer. 91 | """ 92 | 93 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 94 | # special_reflen is used in oracle (proportional effective ref len for a node). 
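    # Usage sketch: crefs/ctest accumulate "cooked" (length, n-gram count) statistics, one entry
    # per segment, added via `scorer += (hypothesis, references)`; corpus-level BLEU-1..n with a
    # brevity penalty is then computed lazily in compute_score().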
95 | 96 | def copy(self): 97 | ''' copy the refs.''' 98 | new = BleuScorer(n=self.n) 99 | new.ctest = copy.copy(self.ctest) 100 | new.crefs = copy.copy(self.crefs) 101 | new._score = None 102 | return new 103 | 104 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 105 | ''' singular instance ''' 106 | 107 | self.n = n 108 | self.crefs = [] 109 | self.ctest = [] 110 | self.cook_append(test, refs) 111 | self.special_reflen = special_reflen 112 | 113 | def cook_append(self, test, refs): 114 | '''called by constructor and __iadd__ to avoid creating new instances.''' 115 | 116 | if refs is not None: 117 | self.crefs.append(cook_refs(refs)) 118 | if test is not None: 119 | cooked_test = cook_test(test, self.crefs[-1]) 120 | self.ctest.append(cooked_test) ## N.B.: -1 121 | else: 122 | self.ctest.append(None) # lens of crefs and ctest have to match 123 | 124 | self._score = None ## need to recompute 125 | 126 | def ratio(self, option=None): 127 | self.compute_score(option=option) 128 | return self._ratio 129 | 130 | def score_ratio(self, option=None): 131 | '''return (bleu, len_ratio) pair''' 132 | return (self.fscore(option=option), self.ratio(option=option)) 133 | 134 | def score_ratio_str(self, option=None): 135 | return "%.4f (%.2f)" % self.score_ratio(option) 136 | 137 | def reflen(self, option=None): 138 | self.compute_score(option=option) 139 | return self._reflen 140 | 141 | def testlen(self, option=None): 142 | self.compute_score(option=option) 143 | return self._testlen 144 | 145 | def retest(self, new_test): 146 | if type(new_test) is str: 147 | new_test = [new_test] 148 | assert len(new_test) == len(self.crefs), new_test 149 | self.ctest = [] 150 | for t, rs in zip(new_test, self.crefs): 151 | self.ctest.append(cook_test(t, rs)) 152 | self._score = None 153 | 154 | return self 155 | 156 | def rescore(self, new_test): 157 | ''' replace test(s) with new test(s), and returns the new score.''' 158 | 159 | return self.retest(new_test).compute_score() 160 | 161 | def size(self): 162 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 163 | return len(self.crefs) 164 | 165 | def __iadd__(self, other): 166 | '''add an instance (e.g., from another sentence).''' 167 | 168 | if type(other) is tuple: 169 | ## avoid creating new BleuScorer instances 170 | self.cook_append(other[0], other[1]) 171 | else: 172 | assert self.compatible(other), "incompatible BLEUs." 
173 | self.ctest.extend(other.ctest) 174 | self.crefs.extend(other.crefs) 175 | self._score = None ## need to recompute 176 | 177 | return self 178 | 179 | def compatible(self, other): 180 | return isinstance(other, BleuScorer) and self.n == other.n 181 | 182 | def single_reflen(self, option="average"): 183 | return self._single_reflen(self.crefs[0][0], option) 184 | 185 | def _single_reflen(self, reflens, option=None, testlen=None): 186 | 187 | if option == "shortest": 188 | reflen = min(reflens) 189 | elif option == "average": 190 | reflen = float(sum(reflens))/len(reflens) 191 | elif option == "closest": 192 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 193 | else: 194 | assert False, "unsupported reflen option %s" % option 195 | 196 | return reflen 197 | 198 | def recompute_score(self, option=None, verbose=0): 199 | self._score = None 200 | return self.compute_score(option, verbose) 201 | 202 | def compute_score(self, option=None, verbose=0): 203 | n = self.n 204 | small = 1e-9 205 | tiny = 1e-15 ## so that if guess is 0 still return 0 206 | bleu_list = [[] for _ in range(n)] 207 | 208 | if self._score is not None: 209 | return self._score 210 | 211 | if option is None: 212 | option = "average" if len(self.crefs) == 1 else "closest" 213 | 214 | self._testlen = 0 215 | self._reflen = 0 216 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 217 | 218 | # for each sentence 219 | for comps in self.ctest: 220 | testlen = comps['testlen'] 221 | self._testlen += testlen 222 | 223 | if self.special_reflen is None: ## need computation 224 | reflen = self._single_reflen(comps['reflen'], option, testlen) 225 | else: 226 | reflen = self.special_reflen 227 | 228 | self._reflen += reflen 229 | 230 | for key in ['guess','correct']: 231 | for k in xrange(n): 232 | totalcomps[key][k] += comps[key][k] 233 | 234 | # append per image bleu score 235 | bleu = 1. 236 | for k in xrange(n): 237 | bleu *= (float(comps['correct'][k]) + tiny) \ 238 | /(float(comps['guess'][k]) + small) 239 | bleu_list[k].append(bleu ** (1./(k+1))) 240 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 241 | if ratio < 1: 242 | for k in xrange(n): 243 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 244 | 245 | if verbose > 1: 246 | print (comps, reflen) 247 | 248 | totalcomps['reflen'] = self._reflen 249 | totalcomps['testlen'] = self._testlen 250 | 251 | bleus = [] 252 | bleu = 1. 253 | for k in xrange(n): 254 | bleu *= float(totalcomps['correct'][k] + tiny) \ 255 | / (totalcomps['guess'][k] + small) 256 | bleus.append(bleu ** (1./(k+1))) 257 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 258 | if ratio < 1: 259 | for k in xrange(n): 260 | bleus[k] *= math.exp(1 - 1/ratio) 261 | 262 | if verbose > 0: 263 | print (totalcomps) 264 | print ("ratio:", ratio) 265 | 266 | self._score = bleus 267 | return self._score, bleu_list 268 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AmbigQA/AmbigNQ README 2 | 3 | This is the repository documenting the paper 4 | [AmbigQA: Answering Ambiguous Open-domain Questions](https://arxiv.org/abs/2004.10645) (EMNLP 2020) 5 | by Sewon Min, Julian Michael, Hannaneh Hajishirzi, and Luke Zettlemoyer. 
6 | 7 | * [Website](https://nlp.cs.washington.edu/ambigqa) 8 | * Read the [paper](https://arxiv.org/abs/2004.10645) 9 | * Download the dataset: [AmbigNQ light ver.](https://nlp.cs.washington.edu/ambigqa/data/ambignq_light.zip) / [AmbigNQ full ver.](https://nlp.cs.washington.edu/ambigqa/data/ambignq.zip) / [AmbigNQ evidence ver. *(new!)*](https://nlp.cs.washington.edu/ambigqa/data/ambignq_with_evidence_articles.zip) / [NQ-open](https://nlp.cs.washington.edu/ambigqa/data/nqopen.zip) 10 | * **Update (07/2020)**: Try running [baseline codes][codes] 11 | * **Update (11/2021)**: We released semi-oracle evidence passages for researchers interested in multi-answer extraction and disambiguation rather than retrieval. Please read [evidence.md](evidence.md) for details. 12 | 13 | ## Content 14 | 1. [Citation](#citation) 15 | 2. [Dataset Contents](#dataset-contents) 16 | * [AmbigNQ](#ambignq) 17 | * [AmbigNQ with evidence articles](#ambignq-with-evidence-articles) 18 | * [NQ-open](#nq-open) 19 | * [Additional resources](#additional-resources) 20 | 3. [Evaluation script](#evaluation-script) 21 | 4. [Baseline codes](#baseline-codes) 22 | 5. [Leaderboard submission guide](#leaderboard-submission-guide) 23 | 24 | ## Citation 25 | 26 | If you find the AmbigQA task or AmbigNQ dataset useful, please cite our paper: 27 | ``` 28 | @inproceedings{ min2020ambigqa, 29 | title={ {A}mbig{QA}: Answering Ambiguous Open-domain Questions }, 30 | author={ Min, Sewon and Michael, Julian and Hajishirzi, Hannaneh and Zettlemoyer, Luke }, 31 | booktitle={ EMNLP }, 32 | year={2020} 33 | } 34 | ``` 35 | 36 | Please also make sure to credit and cite the creators of Natural Questions, 37 | the dataset which we built ours off of: 38 | ``` 39 | @article{ kwiatkowski2019natural, 40 | title={ Natural questions: a benchmark for question answering research}, 41 | author={ Kwiatkowski, Tom and Palomaki, Jennimaria and Redfield, Olivia and Collins, Michael and Parikh, Ankur and Alberti, Chris and Epstein, Danielle and Polosukhin, Illia and Devlin, Jacob and Lee, Kenton and others }, 42 | journal={ Transactions of the Association for Computational Linguistics }, 43 | year={ 2019 } 44 | } 45 | ``` 46 | 47 | 48 | ## Dataset Contents 49 | 50 | ### AmbigNQ 51 | 52 | [Click here to download the light version of the data (1.1M)](https://nlp.cs.washington.edu/ambigqa/data/ambignq_light.zip). 53 | 54 | [Click here to download the full version of the data (18M)](https://nlp.cs.washington.edu/ambigqa/data/ambignq.zip). 55 | 56 | We provide two distributions of our new dataset AmbigNQ: a `full` version with all annotation metadata 57 | and a `light` version with only inputs and outputs. 58 | 59 | The full version contains 60 | - train.json (47M) 61 | - dev.json (17M) 62 | 63 | The light version contains 64 | - train_light.json (3.3M) 65 | - dev_light.json (977K) 66 | 67 | `train.json` and `dev.json` files contain a list of dictionary that represents a single datapoint, with the following keys 68 | 69 | - `id` (string): an identifier for the question, consistent with the original NQ dataset. 70 | - `question` (string): a question. This is identical to the question in the original NQ except we postprocess the string to start uppercase and end with a question mark. 71 | - `annotations` (a list of dictionaries): a list of all acceptable outputs, where each output is a dictionary that represents either a single answer or multiple question-answer pairs. 
72 | - `type`: `singleAnswer` or `multipleQAs` 73 | - (If `type` is `singleAnswer`) `answer`: a list of strings that are all acceptable answer texts 74 | - (If `type` is `multipleQAs`) `qaPairs`: a list of dictionaries with `question` and `answer`. `question` is a string, and `answer` is a list of strings that are all acceptable answer texts 75 | - `viewed_doc_titles` (a list of strings): a list of titles of Wikipedia pages viewed by crowdworkers during annotations. This is an underestimate, since Wikipedia pages viewed through hyperlinks are not included. Note that this should not be the input to a system. It is fine to use it as extra supervision, but please keep in mind that it is an underestimate. 76 | - `used_queries` (a list of dictionaries): a list of dictionaries containing the search queries and results that were used by crowdworkers during annotations. Each dictionary contains `query` (a string) and `results` (a list of dictionaries containing `title` and `snippet`). Search results are obtained through the Google Search API restricted to Wikipedia (details in the paper). Note that this should not be the input to a system. It is fine to use it as extra supervision. 77 | - `nq_answer` (a list of strings): the list of annotated answers in the original NQ. 78 | - `nq_doc_title` (string): an associated Wikipedia page title in the original NQ. 79 | 80 | `{train|dev}_light.json` are formatted the same way, but only contain `id`, `question` and `annotations`. 81 | 82 | 83 | ### AmbigNQ with evidence articles 84 | 85 | [Click here to download the data (575M)](https://nlp.cs.washington.edu/ambigqa/data/ambignq_with_evidence_articles.zip). 86 | 87 | Please read [evidence.md](evidence.md) for details. 88 | 89 | The evidence version contains 90 | - train_with_evidence_articles.json (1.2G) 91 | - dev_with_evidence_articles.json (241M) 92 | - test_with_evidence_articles_without_answers.json (245M) 93 | 94 | They contain a list of dictionary that represents a single datapoint, just as the above. In addition to `id`, `question` and `annotations` (omitted in the test data), each dictionary contains 95 | 96 | - `articles_plain_text`: a list of articles in the plain text. 97 | - `articles_html_text`: a list of articles in the HTML text. 98 | 99 | *In order to evaluate your model on the test data*: Follow [Leaderboard submission guide](#leaderboard-submission-guide) to submit your model predictions on the test questions. 100 | 101 | 102 | ### NQ-open 103 | 104 | [Click here to download the data (3.9M)](https://nlp.cs.washington.edu/ambigqa/data/nqopen.zip). 105 | 106 | 107 | We release our split of NQ-open, for comparison and use as weak supervision: 108 | 109 | - nqopen-train.json (9.7M) 110 | - nqopen-dev.json (1.1M) 111 | - nqopen-test.json (489K) 112 | 113 | Each file contains a list of dictionaries representing a single datapoint, with the following keys 114 | 115 | - `id` (string): an identifier that is consistent with the original NQ. 116 | - `question` (string): a question. 117 | - `answer` (a list of strings): a list of acceptable answer texts. 
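
For illustration, here is a minimal sketch (not part of the released code) of loading the light AmbigNQ split and the NQ-open split described above; the local paths are hypothetical and assume the zip files have been extracted into a `data/` directory:

```python
import json

# Hypothetical paths; adjust to wherever you extracted the downloads.
with open("data/dev_light.json") as f:
    ambignq_dev = json.load(f)   # list of {"id", "question", "annotations"}

with open("data/nqopen-dev.json") as f:
    nqopen_dev = json.load(f)    # list of {"id", "question", "answer"}

for d in ambignq_dev[:3]:
    for ann in d["annotations"]:
        if ann["type"] == "singleAnswer":
            print(d["question"], ann["answer"])
        else:  # "multipleQAs"
            print(d["question"], [(qa["question"], qa["answer"]) for qa in ann["qaPairs"]])
```
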
118 | 119 | ### Additional resources 120 | 121 | - `docs.db`: sqlite db that is consistent with [DrQA](https://github.com/facebookresearch/DrQA); containing plain text only, no disambiguation pages 122 | - `docs-html.db`: sqlite db that is consistent with [DrQA](https://github.com/facebookresearch/DrQA), containing html, no disambiguation pages 123 | - Top 100 Wikipedia passages retrieved from Dense Passage Retrieval 124 | 125 | ## Evaluation script 126 | 127 | The evaluation script is [here](https://github.com/shmsw25/AmbigQA/blob/master/ambigqa_evaluate_script.py). 128 | It has been tested on Python 3.5 and 3.6. 129 | 130 | Step 1. Follow the instruction in [coco-caption](https://github.com/tylin/coco-caption) for setup. If you want to compute F1 answer only, you can skip this. 131 | 132 | Step 2. Run the evaluation script via 133 | ``` 134 | python ambigqa_evaluation_script.py --reference_path {reference data file} --prediction_path {prediction file} 135 | ``` 136 | 137 | The prediction should be a json file with a dictionary that has `id` as a key and a prediction object as a value. A prediction object should be in the following format. 138 | 139 | - a list of strings (answers), if you only want to compute answer F1. 140 | - a list of dictionaries with "question" and "answer" as keys, if you want to compute full metrics. 141 | 142 | Example: 143 | 144 | To only compute answer F1: 145 | ``` 146 | { 147 | "-6631842452804060768": ["1624", "1664"], 148 | ... 149 | } 150 | ``` 151 | 152 | To compute full metrics: 153 | ``` 154 | { 155 | "-6631842452804060768": [ 156 | {"question": "When was city of new york city founded with dutch protection?", "answer": "1624"}, 157 | {"question": "When was city of new york city founded and renamed with english name?", "answer": "1664"} 158 | ], 159 | ... 160 | } 161 | ``` 162 | 163 | ## Baseline codes 164 | 165 | Try running [baseline codes][codes] (instructions in its README), which includes DPR retrieval, DPR reader and SpanSeqGen. This includes codes and scripts for both NQ-open and AmbigNQ. 166 | 167 | 168 | ## Leaderboard submission guide 169 | 170 | Create a prediction file using the questions on NQ-open test data, and email it to [Sewon Min](mailto:sewon@cs.washington.edu). 171 | 172 | Please make sure you include the following in the email: 173 | 174 | - test prediction file. Make sure that the format is in line with the official evaluation script. As you are not supposed to know which subset of NQ-open test set is AmbigNQ, your file should contain predictions for all NQ-open test examples. 175 | - whether the prediction is in the standard setting or zero-shot setting, i.e. whether the model was trained on AmbigNQ train data or not. 176 | - the name of the model 177 | - [optional] dev prediction file and expected dev results. This is to double-check there is no unexpected problem. 178 | - [optional] the institution, and link to the paper/code/demo. They can be updated later. 179 | 180 | If you use semi-oracle articles described [here](evidence.md), please make sure to mention it. 181 | 182 | 183 | Notes 184 | - Models will be sorted by `F1 answer (all) + F1 edit-f1` (standard) or `F1 answer (all)` (zero-shot). 185 | - Please allow for up to one week ahead of time before getting the test numbers and/or your numbers appear on the leaderboard. 186 | - We limit the number of submissions to be 20 per year and 5 per month. 
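
For convenience, here is a minimal sketch (not part of the released code) of serializing predictions into the JSON format expected by the evaluation script and the leaderboard; the example entry simply repeats the one shown in the evaluation section, and the output file name is arbitrary:

```python
import json

# Keys are NQ-open question ids; values are lists of disambiguated question-answer pairs
# (or plain lists of answer strings if you only want answer F1).
predictions = {
    "-6631842452804060768": [
        {"question": "When was city of new york city founded with dutch protection?", "answer": "1624"},
        {"question": "When was city of new york city founded and renamed with english name?", "answer": "1664"},
    ],
}

with open("test_predictions.json", "w") as f:
    json.dump(predictions, f)
```
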
187 | 188 | 189 | [codes]: https://github.com/shmsw25/AmbigQA/tree/master/codes 190 | 191 | 192 | 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /codes/cli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import argparse 23 | import logging 24 | 25 | import random 26 | import numpy as np 27 | import torch 28 | 29 | from run import run 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | 34 | ## Basic parameters 35 | parser.add_argument("--task", default="qa", choices=["dpr", "qa", "qg"], 36 | type=str) 37 | parser.add_argument("--train_file", default="data/nqopen/train.json", 38 | type=str) 39 | parser.add_argument("--predict_file", default="data/nqopen/dev.json", 40 | type=str) 41 | parser.add_argument("--output_dir", default=None, type=str, required=True) 42 | parser.add_argument("--dpr_data_dir", default="/checkpoint/sewonmin/dpr", type=str, 43 | help="path where you downloaded DPR related files" 44 | "(Wikipedia DB, checkpoints, etc)") 45 | parser.add_argument("--do_train", action='store_true') 46 | parser.add_argument("--do_predict", action='store_true') 47 | parser.add_argument("--do_prepro_only", action='store_true') 48 | parser.add_argument("--ambigqa", action='store_true', 49 | help="[For AmbigQA] specify if you are experimenting with AmbigQA") 50 | parser.add_argument("--skip_inference", action='store_true', 51 | help="Instead of periodically evaluating on the dev set and" 52 | "only storing the best checkpoint, store all checkpoints" 53 | "without evaluation on the dev set;" 54 | "this saves time while requires more disk memory") 55 | parser.add_argument("--skip_db_load", action='store_true') 56 | parser.add_argument("--db_index", default=-1, type=int) 57 | parser.add_argument("--wiki_2020", action='store_true', 58 | help="[For AmbigQA] Use Wikipedia dump from 02/01/2020" 59 | "instead of 12/20/2018") 60 | 61 | ## Model parameters 62 | parser.add_argument('--bert_name', type=str, default='bert-base-uncased') 63 | parser.add_argument("--cache_dir", default="", type=str, 64 | help="Where do you want to store the pre-trained models downloaded from s3") 65 | parser.add_argument("--checkpoint", type=str, 66 | help="Initial checkpoint; when not specified, it will use pretrained BERT/BART models", \ 67 | default=None) 68 | parser.add_argument("--resume_global_step", type=int, default=0) 69 | parser.add_argument("--do_lowercase", action='store_true', default=True) 70 | 71 | # Preprocessing-related parameters 72 | parser.add_argument('--max_passage_length', type=int, default=200) 73 | 
parser.add_argument('--max_question_length', type=int, default=32) 74 | parser.add_argument('--train_M', type=int, default=24, 75 | help="# of passages / question in DPR reader") 76 | parser.add_argument('--test_M', type=int, default=50, 77 | help="# of passages / question in DPR reader") 78 | parser.add_argument("--max_n_answers", default=10, type=int) 79 | parser.add_argument('--n_jobs', type=int, default=12) 80 | parser.add_argument("--append_another_bos", action='store_true', 81 | help="For SpanSeqGen, append extra BOS token in the" 82 | "beginning of the sequence (by default, automatically" 83 | "set to `True` when using BART)") 84 | parser.add_argument("--psg_sel_dir", type=str, default=None, 85 | help="For SpanSeqGen, DPR reader path which contains" 86 | "passage selection predictions") 87 | parser.add_argument("--discard_not_found_answers", action='store_true', 88 | help="For SpanSeqGen, do not learn to generate answers" 89 | "if they are not found in DPR passages") 90 | parser.add_argument("--consider_order_for_multiple_answers", action='store_true', 91 | help="[For AmbigQA] Generate answers in the same order" 92 | "as they appear in DPR passages") 93 | parser.add_argument("--nq_answer_as_prefix", action='store_true', 94 | help="[For AmbigQA] For co-training, use known answer as prefix" 95 | "to generate extra answers") 96 | 97 | # Training-related parameters 98 | parser.add_argument("--train_batch_size", default=40, type=int, 99 | help="Batch size per GPU/CPU for training.") 100 | parser.add_argument("--predict_batch_size", default=400, type=int, 101 | help="Batch size per GPU/CPU for evaluation.") 102 | 103 | parser.add_argument("--learning_rate", default=1e-5, type=float, 104 | help="The initial learning rate for Adam.") 105 | parser.add_argument("--warmup_proportion", default=0.01, type=float, 106 | help="Weight decay if we apply some.") 107 | parser.add_argument("--weight_decay", default=0.0, type=float, 108 | help="Weight deay if we apply some.") 109 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 110 | help="Epsilon for Adam optimizer.") 111 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 112 | help="Max gradient norm.") 113 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, 114 | help="Max gradient norm.") 115 | parser.add_argument("--num_train_epochs", default=10000.0, type=float, 116 | help="Total number of training epochs to perform.") 117 | parser.add_argument("--warmup_steps", default=0, type=int, 118 | help="Linear warmup over warmup_steps.") 119 | parser.add_argument('--wait_step', type=int, default=10) 120 | 121 | ## Evaluation-related parameters 122 | parser.add_argument("--n_best_size", default=1, type=int, 123 | help="The total number of n-best predictions to generate in the nbest_predictions.json output file.") 124 | parser.add_argument("--max_answer_length", default=10, type=int, 125 | help="The maximum length of an answer that can be generated. This is needed because the start " 126 | "and end predictions are not conditioned on one another.") 127 | parser.add_argument("--verbose", action='store_true', 128 | help="If true, all of the warnings related to data processing will be printed. " 129 | "A number of warnings are expected for a normal SQuAD evaluation.") 130 | parser.add_argument('--eval_period', type=int, default=400, 131 | help="Evaluate & save model") 132 | parser.add_argument('--prefix', type=str, default=None, 133 | help="Prefix for saving predictions; split name (e.g. 
`dev` or `test`) if not specified") 134 | parser.add_argument('--n_paragraphs', type=str, default=None, 135 | help="A list of numbers separated by comma, for ablations on number of passages per question (e.g. `20,50,100`)") 136 | parser.add_argument("--save_psg_sel_only", action='store_true', 137 | help="For DPR reader, only save the passage selection predictions without span predictions (mainly for preprocessing for SpanSeqGen)") 138 | parser.add_argument('--topk_answer', type=int, default=1, 139 | help="# of top answers per question to save") 140 | 141 | ## Other parameters 142 | parser.add_argument('--debug', action='store_true', 143 | help="Use a subset of data for debugging") 144 | parser.add_argument('--overwrite_output_dir', action='store_true', 145 | help="Overwrite the content of the output directory") 146 | parser.add_argument('--overwrite_cache', action='store_true', 147 | help="Overwrite the cached training and evaluation sets") 148 | parser.add_argument('--seed', type=int, default=42, 149 | help="random seed for initialization") 150 | parser.add_argument("--local_rank", type=int, default=-1, 151 | help="local_rank for distributed training on gpus") 152 | parser.add_argument('--fp16', action='store_true', 153 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 154 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 155 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 156 | "See details at https://nvidia.github.io/apex/amp.html") 157 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 158 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 159 | args = parser.parse_args() 160 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 161 | print("Output directory () already exists and is not empty.") 162 | if not os.path.exists(args.output_dir): 163 | os.makedirs(args.output_dir, exist_ok=True) 164 | 165 | ##### Start writing logs 166 | 167 | log_filename = "{}log.txt".format("" if args.do_train else "eval_") 168 | 169 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 170 | datefmt='%m/%d/%Y %H:%M:%S', 171 | level=logging.INFO, 172 | handlers=[logging.FileHandler(os.path.join(args.output_dir, log_filename)), 173 | logging.StreamHandler()]) 174 | logger = logging.getLogger(__name__) 175 | logger.info(args) 176 | logger.info(args.output_dir) 177 | 178 | random.seed(args.seed) 179 | np.random.seed(args.seed) 180 | torch.manual_seed(args.seed) 181 | args.n_gpu = torch.cuda.device_count() 182 | 183 | if args.n_gpu > 0: 184 | torch.cuda.manual_seed_all(args.seed) 185 | 186 | if not args.do_train and not args.do_predict: 187 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 188 | 189 | if args.do_train: 190 | if not args.train_file: 191 | raise ValueError("If `do_train` is True, then `train_file` must be specified.") 192 | if not args.predict_file: 193 | raise ValueError("If `do_train` is True, then `predict_file` must be specified.") 194 | 195 | if args.do_predict: 196 | if not args.predict_file: 197 | raise ValueError("If `do_predict` is True, then `predict_file` must be specified.") 198 | 199 | logger.info("Using {} gpus".format(args.n_gpu)) 200 | 201 | if args.bert_name.startswith("bart") or args.bert_name.startswith("t5"): 202 | args.is_seq2seq = True 203 | elif args.bert_name.startswith("bert") or 
args.bert_name.startswith("roberta") or args.bert_name.startswith("albert"): 204 | args.is_seq2seq = False 205 | else: 206 | raise NotImplementedError("Pretrained model not recognized: {}".format(args.bert_name)) 207 | run(args, logger) 208 | 209 | if __name__=='__main__': 210 | main() 211 | -------------------------------------------------------------------------------- /codes/download_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line tool to download various preprocessed data sources & checkpoints for DPR 10 | """ 11 | 12 | import gzip 13 | import os 14 | import pathlib 15 | 16 | import argparse 17 | import wget 18 | 19 | NQ_LICENSE_FILES = [ 20 | 'https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE', 21 | 'https://dl.fbaipublicfiles.com/dpr/nq_license/README', 22 | ] 23 | 24 | RESOURCES_MAP = { 25 | # Wikipedia DB 2018/12/20 (provided by DPR) 26 | 'data.wikipedia_split.psgs_w100': { 27 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz', 28 | 'original_ext': '.tsv.gz', 29 | 'compressed': False, 30 | 'desc': 'Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)' 31 | }, 32 | 33 | # Wikipedia DB 2020/02/01 (provided by DPR) 34 | 'data.wikipedia_split.psgs_w100_20200201': { 35 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/psgs_w100_20200201.tsv.gz', 36 | 'original_ext': '.tsv.gz', 37 | 'compressed': False, 38 | 'desc': 'Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)' 39 | }, 40 | 41 | # QA data / checkpoints provided by DPR 42 | 'data.retriever.nq-dev': { 43 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz', 44 | 'original_ext': '.json', 45 | 'compressed': True, 46 | 'desc': 'NQ dev subset with passages pools for the Retriever train time validation', 47 | 'license_files': NQ_LICENSE_FILES, 48 | }, 49 | 50 | 'data.retriever.nq-train': { 51 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz', 52 | 'original_ext': '.json', 53 | 'compressed': True, 54 | 'desc': 'NQ train subset with passages pools for the Retriever training', 55 | 'license_files': NQ_LICENSE_FILES, 56 | }, 57 | 58 | 'data.retriever.qas.nq-dev': { 59 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv', 60 | 'original_ext': '.csv', 61 | 'compressed': False, 62 | 'desc': 'NQ dev subset for Retriever validation and IR results generation', 63 | 'license_files': NQ_LICENSE_FILES, 64 | }, 65 | 66 | 'data.retriever.qas.nq-test': { 67 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv', 68 | 'original_ext': '.csv', 69 | 'compressed': False, 70 | 'desc': 'NQ test subset for Retriever validation and IR results generation', 71 | 'license_files': NQ_LICENSE_FILES, 72 | }, 73 | 74 | 'data.retriever.qas.nq-train': { 75 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-train.qa.csv', 76 | 'original_ext': '.csv', 77 | 'compressed': False, 78 | 'desc': 'NQ train subset for Retriever validation and IR results generation', 79 | 'license_files': NQ_LICENSE_FILES, 80 | }, 81 | 82 | 'data.gold_passages_info.nq_train': { 83 | 's3_url': 
'https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-train_gold_info.json.gz', 84 | 'original_ext': '.json', 85 | 'compressed': True, 86 | 'desc': 'Original NQ (our train subset) gold positive passages and alternative question tokenization', 87 | 'license_files': NQ_LICENSE_FILES, 88 | }, 89 | 90 | 'data.gold_passages_info.nq_dev': { 91 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-dev_gold_info.json.gz', 92 | 'original_ext': '.json', 93 | 'compressed': True, 94 | 'desc': 'Original NQ (our dev subset) gold positive passages and alternative question tokenization', 95 | 'license_files': NQ_LICENSE_FILES, 96 | }, 97 | 98 | 'data.gold_passages_info.nq_test': { 99 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-test_gold_info.json.gz', 100 | 'original_ext': '.json', 101 | 'compressed': True, 102 | 'desc': 'Original NQ (our test, original dev subset) gold positive passages and alternative question ' 103 | 'tokenization', 104 | 'license_files': NQ_LICENSE_FILES, 105 | }, 106 | 107 | 'data.retriever_results.nq.single.wikipedia_passages': { 108 | 's3_url': ['https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single/nq/wiki_passages_{}'.format(i) for i in 109 | range(50)], 110 | 'original_ext': '.pkl', 111 | 'compressed': False, 112 | 'desc': 'Encoded wikipedia files using a biencoder checkpoint(' 113 | 'checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ' 114 | }, 115 | 116 | 'data.retriever_results.nq.single.test': { 117 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz', 118 | 'original_ext': '.json', 119 | 'compressed': True, 120 | 'desc': 'Retrieval results of NQ test dataset for the encoder trained on NQ', 121 | 'license_files': NQ_LICENSE_FILES, 122 | }, 123 | 'data.retriever_results.nq.single.dev': { 124 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-dev.json.gz', 125 | 'original_ext': '.json', 126 | 'compressed': True, 127 | 'desc': 'Retrieval results of NQ dev dataset for the encoder trained on NQ', 128 | 'license_files': NQ_LICENSE_FILES, 129 | }, 130 | 'data.retriever_results.nq.single.train': { 131 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-train.json.gz', 132 | 'original_ext': '.json', 133 | 'compressed': True, 134 | 'desc': 'Retrieval results of NQ train dataset for the encoder trained on NQ', 135 | 'license_files': NQ_LICENSE_FILES, 136 | }, 137 | 138 | 'checkpoint.retriever.single.nq.bert-base-encoder': { 139 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/retriever/single/nq/hf_bert_base.cp', 140 | 'original_ext': '.cp', 141 | 'compressed': False, 142 | 'desc': 'Biencoder weights trained on NQ data and HF bert-base-uncased model' 143 | }, 144 | 145 | 'checkpoint.retriever.multiset.bert-base-encoder': { 146 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/multiset/hf_bert_base.cp', 147 | 'original_ext': '.cp', 148 | 'compressed': False, 149 | 'desc': 'Biencoder weights trained on multi set data and HF bert-base-uncased model' 150 | }, 151 | 152 | 'data.reader.nq.single.train': { 153 | 's3_url': ['https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/train.{}.pkl'.format(i) for i in range(8)], 154 | 'original_ext': '.pkl', 155 | 'compressed': False, 156 | 'desc': 'Reader model NQ train dataset input data preprocessed from retriever results (also trained on NQ)', 157 | 'license_files': NQ_LICENSE_FILES, 158 | }, 159 | 160 | 'data.reader.nq.single.dev': { 161 | 
's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/dev.0.pkl', 162 | 'original_ext': '.pkl', 163 | 'compressed': False, 164 | 'desc': 'Reader model NQ dev dataset input data preprocessed from retriever results (also trained on NQ)', 165 | 'license_files': NQ_LICENSE_FILES, 166 | }, 167 | 168 | 'data.reader.nq.single.test': { 169 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/test.0.pkl', 170 | 'original_ext': '.pkl', 171 | 'compressed': False, 172 | 'desc': 'Reader model NQ test dataset input data preprocessed from retriever results (also trained on NQ)', 173 | 'license_files': NQ_LICENSE_FILES, 174 | }, 175 | 176 | 'checkpoint.reader.nq-single.hf-bert-base': { 177 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single/hf_bert_base.cp', 178 | 'original_ext': '.cp', 179 | 'compressed': False, 180 | 'desc': 'Reader weights trained on NQ-single retriever results and HF bert-base-uncased model' 181 | }, 182 | 183 | 'checkpoint.reader.nq-trivia-hybrid.hf-bert-base': { 184 | 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-trivia-hybrid/hf_bert_base.cp', 185 | 'original_ext': '.cp', 186 | 'compressed': False, 187 | 'desc': 'Reader weights trained on Trivia multi hybrid retriever results and HF bert-base-uncased model' 188 | }, 189 | 190 | # resources provided by AmbigQA 191 | 'data.ambigqa.train_light': { 192 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/train_light.json', 193 | 'original_ext': '.json', 194 | 'compressed': False, 195 | 'desc': 'Train file for AmbigQA' 196 | }, 197 | 198 | 'data.ambigqa.dev_light': { 199 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/dev_light.json', 200 | 'original_ext': '.json', 201 | 'compressed': False, 202 | 'desc': 'Dev file for AmbigQA' 203 | }, 204 | 205 | 'data.nqopen.train': { 206 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/nqopen-train.json', 207 | 'original_ext': '.json', 208 | 'compressed': False, 209 | 'desc': 'Train file for NQ-open' 210 | }, 211 | 212 | 'data.nqopen.dev': { 213 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/nqopen-dev.json', 214 | 'original_ext': '.json', 215 | 'compressed': False, 216 | 'desc': 'Dev file for NQ-open' 217 | }, 218 | 219 | 'data.nqopen.test': { 220 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/nqopen-test.json', 221 | 'original_ext': '.json', 222 | 'compressed': False, 223 | 'desc': 'Test file for NQ-open' 224 | }, 225 | 226 | 'data.nqopen.train_id2answers': { 227 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/train_id2answers.json', 228 | 'original_ext': '.json', 229 | 'compressed': False, 230 | 'desc': 'Dev id to official answers provided by Google' 231 | }, 232 | 233 | 'data.nqopen.dev_id2answers': { 234 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/dev_id2answers.json', 235 | 'original_ext': '.json', 236 | 'compressed': False, 237 | 'desc': 'Dev id to official answers provided by Google' 238 | }, 239 | 240 | 'data.nqopen.test_id2answers': { 241 | 's3_url': 'https://nlp.cs.washington.edu/ambigqa/data/test_id2answers.json', 242 | 'original_ext': '.json', 243 | 'compressed': False, 244 | 'desc': 'Test id to official answers provided by Google' 245 | }, 246 | 247 | 248 | } 249 | 250 | 251 | def unpack(gzip_file: str, out_file: str): 252 | print('Uncompressing ', gzip_file) 253 | input = gzip.GzipFile(gzip_file, 'rb') 254 | s = input.read() 255 | input.close() 256 | output = open(out_file, 'wb') 257 | output.write(s) 258 | output.close() 259 | print('Saved to ', 
out_file) 260 | 261 | 262 | def download_resource(s3_url: str, original_ext: str, compressed: bool, resource_key: str, out_dir: str) -> str: 263 | print('Loading from ', s3_url) 264 | 265 | # create local dir 266 | path_names = resource_key.split('.') 267 | 268 | root_dir = out_dir if out_dir else './' 269 | save_root = os.path.join(root_dir, *path_names[:-1]) # last segment is for file name 270 | 271 | pathlib.Path(save_root).mkdir(parents=True, exist_ok=True) 272 | 273 | local_file = os.path.join(save_root, path_names[-1] + ('.tmp' if compressed else original_ext)) 274 | 275 | if os.path.exists(local_file): 276 | print('File already exist ', local_file) 277 | return save_root 278 | 279 | wget.download(s3_url, out=local_file) 280 | 281 | print('Saved to ', local_file) 282 | 283 | if compressed: 284 | uncompressed_file = os.path.join(save_root, path_names[-1] + original_ext) 285 | unpack(local_file, uncompressed_file) 286 | os.remove(local_file) 287 | return save_root 288 | 289 | 290 | def download_file(s3_url: str, out_dir: str, file_name: str): 291 | print('Loading from ', s3_url) 292 | local_file = os.path.join(out_dir, file_name) 293 | 294 | if os.path.exists(local_file): 295 | print('File already exist ', local_file) 296 | return 297 | 298 | wget.download(s3_url, out=local_file) 299 | print('Saved to ', local_file) 300 | 301 | 302 | def download(resource_key: str, out_dir: str = None): 303 | if resource_key not in RESOURCES_MAP: 304 | # match by prefix 305 | resources = [k for k in RESOURCES_MAP.keys() if k.startswith(resource_key)] 306 | if resources: 307 | for key in resources: 308 | download(key, out_dir) 309 | else: 310 | print('no resources found for specified key') 311 | return 312 | download_info = RESOURCES_MAP[resource_key] 313 | 314 | s3_url = download_info['s3_url'] 315 | 316 | save_root_dir = None 317 | if isinstance(s3_url, list): 318 | for i, url in enumerate(s3_url): 319 | save_root_dir = download_resource(url, 320 | download_info['original_ext'], 321 | download_info['compressed'], 322 | '{}_{}'.format(resource_key, i), 323 | out_dir) 324 | else: 325 | save_root_dir = download_resource(s3_url, 326 | download_info['original_ext'], 327 | download_info['compressed'], 328 | resource_key, 329 | out_dir) 330 | 331 | license_files = download_info.get('license_files', None) 332 | if not license_files: 333 | return 334 | 335 | download_file(license_files[0], save_root_dir, 'LICENSE') 336 | download_file(license_files[1], save_root_dir, 'README') 337 | 338 | 339 | def main(): 340 | parser = argparse.ArgumentParser() 341 | 342 | parser.add_argument("--output_dir", default="./", type=str, 343 | help="The output directory to download file") 344 | parser.add_argument("--resource", type=str, 345 | help="Resource name. See RESOURCES_MAP for all possible values") 346 | args = parser.parse_args() 347 | if args.resource: 348 | download(args.resource, args.output_dir) 349 | else: 350 | print('Please specify resource value. 
Possible options are:') 351 | for k, v in RESOURCES_MAP.items(): 352 | print('Resource key={} description: {}'.format(k, v['desc'])) 353 | 354 | 355 | if __name__ == '__main__': 356 | main() 357 | -------------------------------------------------------------------------------- /ambigqa_evaluate_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import string 5 | import argparse 6 | import numpy as np 7 | #from collections import Counter, defaultdict 8 | 9 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 10 | from pycocoevalcap.bleu.bleu import Bleu 11 | 12 | tokenizer = PTBTokenizer() 13 | 14 | class QAPairEvaluation(object): 15 | 16 | def __init__(self, reference, prediction, metrics="all"): 17 | ''' 18 | :param: samples: a list of annotated data 19 | :param: predictions: a dictionary with id as key and prediction as value 20 | prediction can be either 21 | - a list of strings 22 | - a list of dictionaries with quetion and answer as keys 23 | ''' 24 | self.reference = reference 25 | self.prediction = [prediction[sample["id"]] for sample in reference] 26 | self.metrics = metrics 27 | METRICS_ANSWER = ["F1 answer"] 28 | METRICS_QG = ["F1 bleu1", "F1 bleu2", "F1 bleu3", "F1 bleu4", "F1 edit-f1"] 29 | 30 | if metrics=="all" and type(self.prediction[0][0])==str: 31 | self.metrics = METRICS_ANSWER 32 | elif metrics=="all": 33 | self.metrics = METRICS_ANSWER+METRICS_QG 34 | 35 | assert len(set(self.metrics)-set(METRICS_ANSWER)-set(METRICS_QG))==0 36 | self.QG_METRICS_TO_COMPUTE = [m for m in ["bleu1", "bleu2", "bleu3", "bleu4", "rouge-l", "edit-f1"] if any([metric.endswith(m) for metric in self.metrics])] 37 | 38 | if len(self.QG_METRICS_TO_COMPUTE)>0: 39 | # if evaluating QG, tokenize prompt question, 40 | # reference question and predicted question 41 | data_to_tokenize = {} 42 | for i, ref in enumerate(self.reference): 43 | data_to_tokenize["prompt.{}".format(i)] = [{"caption": ref["question"]}] 44 | for j, annotation in enumerate(ref["annotations"]): 45 | if annotation['type']=='multipleQAs': 46 | for k, pair in enumerate(annotation['qaPairs']): 47 | data_to_tokenize["ref.{}.{}.{}".format(i, j, k)] = \ 48 | [{'caption': sent.strip()} for sent in pair["question"].split('|') if len(sent.strip())>0] 49 | for i, pred in enumerate(self.prediction): 50 | for j, pair in enumerate(pred): 51 | data_to_tokenize["gen.{}.{}".format(i, j)] = [{"caption": pair["question"]}] 52 | 53 | all_tokens = tokenizer.tokenize(data_to_tokenize) 54 | for key, values in all_tokens.items(): 55 | values = {'sent': [normalize_answer(value) for value in values]} 56 | if key.startswith("prompt."): 57 | i = key.split(".")[1] 58 | self.reference[int(i)]["question"] = values 59 | elif key.startswith("ref."): 60 | i, j, k = key.split('.')[1:] 61 | self.reference[int(i)]["annotations"][int(j)]["qaPairs"][int(k)]["question"] = values 62 | elif key.startswith("gen."): 63 | i, j = key.split(".")[1:] 64 | self.prediction[int(i)][int(j)]["question"] = values 65 | else: 66 | raise NotImplementedError() 67 | 68 | self.is_multi = [not any([ann["type"]=="singleAnswer" for ann in ref["annotations"]]) \ 69 | for ref in self.reference] 70 | self.results = [self.get_all_metrics(idx) for idx in range(len(self.reference))] 71 | 72 | def print_all_metrics(self): 73 | for metric in self.metrics: 74 | result = [e[metric] for e in self.results] 75 | result_multi_only = [e[metric] for e, is_multi in zip(self.results, self.is_multi) \ 76 | if 
is_multi] 77 | if metric=="F1 answer": 78 | print ("%s\t%.3f (all)\t%.3f (multi only)" % (metric, np.mean(result), np.mean(result_multi_only))) 79 | else: 80 | print ("%s\t%.3f" % (metric, np.mean(result_multi_only))) 81 | 82 | def get_metric(self, metric): 83 | return np.mean([e[metric] for e in self.results]) 84 | 85 | def get_all_metrics(self, idx): 86 | evaluation = {} 87 | promptQuestion = self.reference[idx]["question"] 88 | annotations = self.reference[idx]["annotations"] 89 | if type(self.prediction[idx][0])==dict: 90 | # prediction contains a set of question-answer pairs 91 | predictions = [pair["answer"] for pair in self.prediction[idx]] 92 | questions = [pair["question"] for pair in self.prediction[idx]] 93 | else: 94 | # prediction contains a set of answers 95 | predictions = self.prediction[idx] 96 | questions = None 97 | 98 | for annotation in annotations: 99 | # iterate each annotation and take the maximum metrics 100 | if annotation['type']=='singleAnswer': 101 | f1 = get_f1([annotation['answer']], predictions) 102 | for metric in self.metrics: 103 | if metric.startswith('F1'): 104 | evaluation[metric] = max(evaluation.get(metric, 0), f1) 105 | elif annotation['type']=='multipleQAs': 106 | matching_pairs = [] 107 | evaluation['F1 answer'] = max(evaluation.get("F1 answer", 0), 108 | get_f1([answer['answer'] for answer in annotation['qaPairs']], predictions)) 109 | if questions is None: 110 | # skip the below if not evaluating QG 111 | continue 112 | for i, answer in enumerate(annotation["qaPairs"]): 113 | for j, prediction in enumerate(predictions): 114 | # get every reference-prediction pair with the correct answer prediction 115 | em = get_exact_match(answer['answer'], prediction) 116 | if em: 117 | qg_evals = get_qg_metrics(questions[j], 118 | answer['question'], 119 | promptQuestion, 120 | self.QG_METRICS_TO_COMPUTE) 121 | matching_pairs.append((i, j, qg_evals)) 122 | 123 | def _get_qg_f1(metric_func): 124 | curr_matching_pairs = sorted(matching_pairs, key=lambda x: metric_func(x[2]), reverse=True) 125 | occupied_answers = [False for _ in annotation["qaPairs"]] 126 | occupied_predictions = [False for _ in predictions] 127 | tot = 0 128 | # find non-overapping reference-prediction pairs 129 | # that match the answer prediction 130 | # to get the evaluation score 131 | for (i, j, e) in curr_matching_pairs: 132 | if occupied_answers[i] or occupied_predictions[j]: 133 | continue 134 | occupied_answers[i] = True 135 | occupied_predictions[j] = True 136 | tot += metric_func(e) 137 | assert np.sum(occupied_answers)==np.sum(occupied_predictions) 138 | return 2 * tot / (len(occupied_answers)+len(occupied_predictions)) 139 | 140 | for metric in self.QG_METRICS_TO_COMPUTE: 141 | metric_name = "F1 {}".format(metric) 142 | if metric_name in self.metrics: 143 | e = _get_qg_f1(lambda x: x[metric]) 144 | evaluation[metric_name] = max(evaluation.get(metric_name, 0), e) 145 | else: 146 | raise NotImplementedError() 147 | 148 | assert len(self.metrics)==len(evaluation), (self.metrics, evaluation.keys()) 149 | return evaluation 150 | 151 | def get_qg_metrics(generated, question, promptQuestion, metrics): 152 | 153 | evaluation = {} 154 | 155 | # computing bleu scores 156 | for name, score in zip(['bleu{}'.format(i) for i in range(1, 5)], 157 | Bleu(4).compute_score(question, generated)[0]): 158 | if name in metrics: 159 | evaluation[name] = score 160 | 161 | # computing edit-f1 score 162 | if 'edit-f1' in metrics: 163 | def _get_edits(tokens1, tokens2): 164 | allCommon = [] 165 | while 
True: 166 | commons = list(set(tokens1) & set(tokens2)) 167 | if len(commons)==0: 168 | break 169 | allCommon += commons 170 | for c in commons: 171 | ind1, ind2 = tokens1.index(c), tokens2.index(c) 172 | tokens1 = tokens1[:ind1]+tokens1[ind1+1:] 173 | tokens2 = tokens2[:ind2]+tokens2[ind2+1:] 174 | deleted = ["[DELETED]"+token for token in tokens1] 175 | added = ["[ADDED]"+token for token in tokens2] 176 | common = ["[FIXED]"+token for token in allCommon] 177 | return deleted+added #+common 178 | 179 | assert len(generated)==len(promptQuestion)==1 180 | generated = generated["sent"][0].split(" ") 181 | promptQuestion = promptQuestion["sent"][0].split(" ") 182 | prediction = _get_edits(promptQuestion, generated) 183 | edit_f1 = 0 184 | for _question in question["sent"]: 185 | _question = _question.split(" ") 186 | reference = _get_edits(promptQuestion, _question) 187 | # now compare the reference edits and predicted edits 188 | if len(reference)==len(prediction)==0: 189 | # rarely, reference has no edits after normalization 190 | # then, if the prediction also has no edits, it gets full score 191 | edit_f1 = 1 192 | elif len(reference)==0 or len(prediction)==0: 193 | # if only one of them has no edits, zero score 194 | edit_f1 = max(edit_f1, 0) 195 | else: 196 | # otherwise, compute F1 score between prediction and reference 197 | edit_f1 = max(edit_f1, get_f1(prediction, reference, is_equal=lambda x, y: x==y)) 198 | evaluation["edit-f1"] = edit_f1 199 | 200 | assert len(metrics)==len(evaluation) 201 | return evaluation 202 | 203 | def get_exact_match(answers1, answers2): 204 | if type(answers1)==list: 205 | if len(answers1)==0: 206 | return 0 207 | return np.max([get_exact_match(a, answers2) for a in answers1]) 208 | if type(answers2)==list: 209 | if len(answers2)==0: 210 | return 0 211 | return np.max([get_exact_match(answers1, a) for a in answers2]) 212 | return (normalize_answer(answers1) == normalize_answer(answers2)) 213 | 214 | def normalize_answer(s): 215 | 216 | def remove_articles(text): 217 | return re.sub(r'\b(a|an|the)\b', ' ', text) 218 | 219 | def white_space_fix(text): 220 | return ' '.join(text.split()) 221 | 222 | def remove_punc(text): 223 | exclude = set(string.punctuation) 224 | return ''.join(ch for ch in text if ch not in exclude) 225 | 226 | def lower(text): 227 | return text.lower() 228 | 229 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 230 | 231 | def get_f1(answers, predictions, is_equal=get_exact_match): 232 | ''' 233 | :answers: a list of list of strings 234 | :predictions: a list of strings 235 | ''' 236 | assert len(answers)>0 and len(predictions)>0, (answers, predictions) 237 | occupied_answers = [False for _ in answers] 238 | occupied_predictions = [False for _ in predictions] 239 | for i, answer in enumerate(answers): 240 | for j, prediction in enumerate(predictions): 241 | if occupied_answers[i] or occupied_predictions[j]: 242 | continue 243 | em = is_equal(answer, prediction) 244 | if em: 245 | occupied_answers[i] = True 246 | occupied_predictions[j] = True 247 | assert np.sum(occupied_answers)==np.sum(occupied_predictions) 248 | a, b = np.mean(occupied_answers), np.mean(occupied_predictions) 249 | if a+b==0: 250 | return 0 251 | return 2*a*b/(a+b) 252 | 253 | def load_reference(reference_path): 254 | if os.path.exists(reference_path): 255 | with open(reference_path, "r") as f: 256 | reference = json.load(f) 257 | if not (type(reference)==list and \ 258 | all([type(ref)==dict and "id" in ref and "question" in ref and "annotations" in 
ref and \ 259 | type(ref["question"])==str and type(ref["annotations"])==list and \ 260 | all([type(ann)==dict and ann["type"] in ["singleAnswer", "multipleQAs"] for ann in ref["annotations"]]) \ 261 | for ref in reference])): 262 | raise Exception("Reference file {} is wrong".format(reference_path)) 263 | else: 264 | raise Exception("Reference file {} not found".format(reference_path)) 265 | return reference 266 | 267 | def load_prediction(prediction_path, ids): 268 | if os.path.exists(prediction_path): 269 | with open(prediction_path, "r") as f: 270 | prediction = json.load(f) 271 | if str(list(prediction.keys())[0])==int: 272 | prediction = {str(key):value for key, value in prediction.items()} 273 | if type(list(prediction.values())[0])==str: 274 | prediction = {key:[value] for key, value in prediction.items()} 275 | if not (type(prediction)==dict and \ 276 | len(ids-set(prediction.keys()))==0): 277 | raise Exception("Prediction file {} is wrong".format(prediction_path)) 278 | if not (all([type(pred)==list for pred in prediction.values()]) and \ 279 | (all([type(p)==str for pred in prediction.values() for p in pred]) or \ 280 | all([type(p)==dict and "question" in p and "answer" in p \ 281 | and type(p["question"])==type(p["answer"])==str for pred in prediction.values() for p in pred]))): 282 | raise Exception("Prediction file {} has a wrong format".format(prediction_path)) 283 | else: 284 | raise Exception("Prediction file {} not found".format(prediction_path)) 285 | return prediction 286 | 287 | if __name__=="__main__": 288 | parser = argparse.ArgumentParser() 289 | parser.add_argument('--reference_path', type=str, required=True) 290 | parser.add_argument('--prediction_path', type=str, required=True) 291 | args = parser.parse_args() 292 | 293 | reference = load_reference(args.reference_path) 294 | ids = set([d["id"] for d in reference]) 295 | prediction = load_prediction(args.prediction_path, ids) 296 | evaluation = QAPairEvaluation(reference, prediction) 297 | evaluation.print_all_metrics() 298 | 299 | 300 | -------------------------------------------------------------------------------- /codes/ambigqa_evaluate_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import string 5 | import argparse 6 | import numpy as np 7 | from collections import Counter, defaultdict 8 | 9 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 10 | from pycocoevalcap.bleu.bleu import Bleu 11 | 12 | tokenizer = PTBTokenizer() 13 | 14 | class QAPairEvaluation(object): 15 | 16 | def __init__(self, reference, prediction, metrics="all"): 17 | ''' 18 | :param: samples: a list of annotated data 19 | :param: predictions: a dictionary with id as key and prediction as value 20 | prediction can be either 21 | - a list of strings 22 | - a list of dictionaries with quetion and answer as keys 23 | ''' 24 | self.reference = reference 25 | self.prediction = [prediction[sample["id"]] for sample in reference] 26 | self.metrics = metrics 27 | METRICS_ANSWER = ["F1 answer"] 28 | METRICS_QG = ["F1 bleu1", "F1 bleu2", "F1 bleu3", "F1 bleu4", "F1 edit-f1"] 29 | 30 | if metrics=="all" and type(self.prediction[0][0])==str: 31 | self.metrics = METRICS_ANSWER 32 | elif metrics=="all": 33 | self.metrics = METRICS_ANSWER+METRICS_QG 34 | 35 | assert len(set(self.metrics)-set(METRICS_ANSWER)-set(METRICS_QG))==0 36 | self.QG_METRICS_TO_COMPUTE = [m for m in ["bleu1", "bleu2", "bleu3", "bleu4", "rouge-l", "edit-f1"] if 
any([metric.endswith(m) for metric in self.metrics])] 37 | 38 | if len(self.QG_METRICS_TO_COMPUTE)>0: 39 | # if evaluating QG, tokenize prompt question, 40 | # reference question and predicted question 41 | data_to_tokenize = {} 42 | for i, ref in enumerate(self.reference): 43 | data_to_tokenize["prompt.{}".format(i)] = [{"caption": ref["question"]}] 44 | for j, annotation in enumerate(ref["annotations"]): 45 | if annotation['type']=='multipleQAs': 46 | for k, pair in enumerate(annotation['qaPairs']): 47 | data_to_tokenize["ref.{}.{}.{}".format(i, j, k)] = \ 48 | [{'caption': sent.strip()} for sent in pair["question"].split('|') if len(sent.strip())>0] 49 | for i, pred in enumerate(self.prediction): 50 | for j, pair in enumerate(pred): 51 | data_to_tokenize["gen.{}.{}".format(i, j)] = [{"caption": pair["question"]}] 52 | 53 | all_tokens = tokenizer.tokenize(data_to_tokenize) 54 | for key, values in all_tokens.items(): 55 | values = {'sent': [normalize_answer(value) for value in values]} 56 | if key.startswith("prompt."): 57 | i = key.split(".")[1] 58 | self.reference[int(i)]["question"] = values 59 | elif key.startswith("ref."): 60 | i, j, k = key.split('.')[1:] 61 | self.reference[int(i)]["annotations"][int(j)]["qaPairs"][int(k)]["question"] = values 62 | elif key.startswith("gen."): 63 | i, j = key.split(".")[1:] 64 | self.prediction[int(i)][int(j)]["question"] = values 65 | else: 66 | raise NotImplementedError() 67 | 68 | self.is_multi = [not any([ann["type"]=="singleAnswer" for ann in ref["annotations"]]) \ 69 | for ref in self.reference] 70 | self.results = [self.get_all_metrics(idx) for idx in range(len(self.reference))] 71 | 72 | def print_all_metrics(self): 73 | for metric in self.metrics: 74 | result = [e[metric] for e in self.results] 75 | result_multi_only = [e[metric] for e, is_multi in zip(self.results, self.is_multi) \ 76 | if is_multi] 77 | if metric=="F1 answer": 78 | print ("%s\t%.3f (all)\t%.3f (multi only)" % (metric, np.mean(result), np.mean(result_multi_only))) 79 | else: 80 | print ("%s\t%.3f" % (metric, np.mean(result_multi_only))) 81 | 82 | def get_metric(self, metric): 83 | return np.mean([e[metric] for e in self.results]) 84 | 85 | def get_all_metrics(self, idx): 86 | evaluation = {} 87 | promptQuestion = self.reference[idx]["question"] 88 | annotations = self.reference[idx]["annotations"] 89 | if type(self.prediction[idx][0])==dict: 90 | # prediction contains a set of question-answer pairs 91 | predictions = [pair["answer"] for pair in self.prediction[idx]] 92 | questions = [pair["question"] for pair in self.prediction[idx]] 93 | else: 94 | # prediction contains a set of answers 95 | predictions = self.prediction[idx] 96 | questions = None 97 | 98 | for annotation in annotations: 99 | # iterate each annotation and take the maximum metrics 100 | if annotation['type']=='singleAnswer': 101 | f1 = get_f1([annotation['answer']], predictions) 102 | for metric in self.metrics: 103 | if metric.startswith('F1'): 104 | evaluation[metric] = max(evaluation.get(metric, 0), f1) 105 | elif annotation['type']=='multipleQAs': 106 | matching_pairs = [] 107 | evaluation['F1 answer'] = max(evaluation.get("F1 answer", 0), 108 | get_f1([answer['answer'] for answer in annotation['qaPairs']], predictions)) 109 | if questions is None: 110 | # skip the below if not evaluating QG 111 | continue 112 | for i, answer in enumerate(annotation["qaPairs"]): 113 | for j, prediction in enumerate(predictions): 114 | # get every reference-prediction pair with the correct answer prediction 115 | em = 
get_exact_match(answer['answer'], prediction) 116 | if em: 117 | qg_evals = get_qg_metrics(questions[j], 118 | answer['question'], 119 | promptQuestion, 120 | self.QG_METRICS_TO_COMPUTE) 121 | matching_pairs.append((i, j, qg_evals)) 122 | 123 | def _get_qg_f1(metric_func): 124 | curr_matching_pairs = sorted(matching_pairs, key=lambda x: metric_func(x[2]), reverse=True) 125 | occupied_answers = [False for _ in annotation["qaPairs"]] 126 | occupied_predictions = [False for _ in predictions] 127 | tot = 0 128 | # find non-overapping reference-prediction pairs 129 | # that match the answer prediction 130 | # to get the evaluation score 131 | for (i, j, e) in curr_matching_pairs: 132 | if occupied_answers[i] or occupied_predictions[j]: 133 | continue 134 | occupied_answers[i] = True 135 | occupied_predictions[j] = True 136 | tot += metric_func(e) 137 | assert np.sum(occupied_answers)==np.sum(occupied_predictions) 138 | return 2 * tot / (len(occupied_answers)+len(occupied_predictions)) 139 | 140 | for metric in self.QG_METRICS_TO_COMPUTE: 141 | metric_name = "F1 {}".format(metric) 142 | if metric_name in self.metrics: 143 | e = _get_qg_f1(lambda x: x[metric]) 144 | evaluation[metric_name] = max(evaluation.get(metric_name, 0), e) 145 | else: 146 | raise NotImplementedError() 147 | 148 | assert len(self.metrics)==len(evaluation), (self.metrics, evaluation.keys()) 149 | return evaluation 150 | 151 | def get_qg_metrics(generated, question, promptQuestion, metrics): 152 | 153 | evaluation = {} 154 | 155 | # computing bleu scores 156 | for name, score in zip(['bleu{}'.format(i) for i in range(1, 5)], 157 | Bleu(4).compute_score(question, generated)[0]): 158 | if name in metrics: 159 | evaluation[name] = score 160 | 161 | # computing edit-f1 score 162 | if 'edit-f1' in metrics: 163 | def _get_edits(tokens1, tokens2): 164 | allCommon = [] 165 | while True: 166 | commons = list(set(tokens1) & set(tokens2)) 167 | if len(commons)==0: 168 | break 169 | allCommon += commons 170 | for c in commons: 171 | ind1, ind2 = tokens1.index(c), tokens2.index(c) 172 | tokens1 = tokens1[:ind1]+tokens1[ind1+1:] 173 | tokens2 = tokens2[:ind2]+tokens2[ind2+1:] 174 | deleted = ["[DELETED]"+token for token in tokens1] 175 | added = ["[ADDED]"+token for token in tokens2] 176 | common = ["[FIXED]"+token for token in allCommon] 177 | return deleted+added #+common 178 | 179 | assert len(generated)==len(promptQuestion)==1 180 | generated = generated["sent"][0].split(" ") 181 | promptQuestion = promptQuestion["sent"][0].split(" ") 182 | prediction = _get_edits(promptQuestion, generated) 183 | edit_f1 = 0 184 | for _question in question["sent"]: 185 | _question = _question.split(" ") 186 | reference = _get_edits(promptQuestion, _question) 187 | # now compare the reference edits and predicted edits 188 | if len(reference)==len(prediction)==0: 189 | # rarely, reference has no edits after normalization 190 | # then, if the prediction also has no edits, it gets full score 191 | edit_f1 = 1 192 | elif len(reference)==0 or len(prediction)==0: 193 | # if only one of them has no edits, zero score 194 | edit_f1 = max(edit_f1, 0) 195 | else: 196 | # otherwise, compute F1 score between prediction and reference 197 | edit_f1 = max(edit_f1, get_f1(prediction, reference, is_equal=lambda x, y: x==y)) 198 | evaluation["edit-f1"] = edit_f1 199 | 200 | assert len(metrics)==len(evaluation) 201 | return evaluation 202 | 203 | def get_exact_match(answers1, answers2): 204 | if type(answers1)==list: 205 | if len(answers1)==0: 206 | return 0 207 | return 
np.max([get_exact_match(a, answers2) for a in answers1]) 208 | if type(answers2)==list: 209 | if len(answers2)==0: 210 | return 0 211 | return np.max([get_exact_match(answers1, a) for a in answers2]) 212 | return (normalize_answer(answers1) == normalize_answer(answers2)) 213 | 214 | def normalize_answer(s): 215 | 216 | def remove_articles(text): 217 | return re.sub(r'\b(a|an|the)\b', ' ', text) 218 | 219 | def white_space_fix(text): 220 | return ' '.join(text.split()) 221 | 222 | def remove_punc(text): 223 | exclude = set(string.punctuation) 224 | return ''.join(ch for ch in text if ch not in exclude) 225 | 226 | def lower(text): 227 | return text.lower() 228 | 229 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 230 | 231 | def get_f1(answers, predictions, is_equal=get_exact_match, return_p_and_r=False): 232 | ''' 233 | :answers: a list of list of strings 234 | :predictions: a list of strings 235 | ''' 236 | assert len(answers)>0 and len(predictions)>0, (answers, predictions) 237 | occupied_answers = [False for _ in answers] 238 | occupied_predictions = [False for _ in predictions] 239 | for i, answer in enumerate(answers): 240 | for j, prediction in enumerate(predictions): 241 | if occupied_answers[i] or occupied_predictions[j]: 242 | continue 243 | em = is_equal(answer, prediction) 244 | if em: 245 | occupied_answers[i] = True 246 | occupied_predictions[j] = True 247 | assert np.sum(occupied_answers)==np.sum(occupied_predictions) 248 | a, b = np.mean(occupied_answers), np.mean(occupied_predictions) 249 | if return_p_and_r: 250 | if a+b==0: 251 | return 0., 0., 0. 252 | return 2*a*b/(a+b), float(a), float(b) 253 | if a+b==0: 254 | return 0. 255 | return 2*a*b/(a+b) 256 | 257 | def load_reference(reference_path): 258 | if os.path.exists(reference_path): 259 | with open(reference_path, "r") as f: 260 | reference = json.load(f) 261 | if not (type(reference)==list and \ 262 | all([type(ref)==dict and "id" in ref and "question" in ref and "annotations" in ref and \ 263 | type(ref["question"])==str and type(ref["annotations"])==list and \ 264 | all([type(ann)==dict and ann["type"] in ["singleAnswer", "multipleQAs"] for ann in ref["annotations"]]) \ 265 | for ref in reference])): 266 | raise Exception("Reference file {} is wrong".format(reference_path)) 267 | else: 268 | raise Exception("Reference file {} not found".format(reference_path)) 269 | return reference 270 | 271 | def load_prediction(prediction_path, ids): 272 | if os.path.exists(prediction_path): 273 | with open(prediction_path, "r") as f: 274 | prediction = json.load(f) 275 | if str(list(prediction.keys())[0])==int: 276 | prediction = {str(key):value for key, value in prediction.items()} 277 | if type(list(prediction.values())[0])==str: 278 | prediction = {key:[value] for key, value in prediction.items()} 279 | if not (type(prediction)==dict and \ 280 | len(ids-set(prediction.keys()))==0): 281 | raise Exception("Prediction file {} is wrong".format(prediction_path)) 282 | if not (all([type(pred)==list for pred in prediction.values()]) and \ 283 | (all([type(p)==str for pred in prediction.values() for p in pred]) or \ 284 | all([type(p)==dict and "question" in p and "answer" in p \ 285 | and type(p["question"])==type(p["answer"])==str for pred in prediction.values() for p in pred]))): 286 | raise Exception("Prediction file {} has a wrong format".format(prediction_path)) 287 | else: 288 | raise Exception("Prediction file {} not found".format(prediction_path)) 289 | return prediction 290 | 291 | if __name__=="__main__": 
292 | parser = argparse.ArgumentParser() 293 | parser.add_argument('--reference_path', type=str, required=True) 294 | parser.add_argument('--prediction_path', type=str, required=True) 295 | args = parser.parse_args() 296 | 297 | reference = load_reference(args.reference_path) 298 | ids = set([d["id"] for d in reference]) 299 | prediction = load_prediction(args.prediction_path, ids) 300 | evaluation = QAPairEvaluation(reference, prediction) 301 | evaluation.print_all_metrics() 302 | 303 | 304 | -------------------------------------------------------------------------------- /codes/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | from tqdm import tqdm 6 | from transformers import BartTokenizer, AlbertTokenizer, BertTokenizer 7 | from transformers import BartConfig, AlbertConfig, BertConfig 8 | from transformers import AdamW, get_linear_schedule_with_warmup 9 | 10 | from QAData import QAData, AmbigQAData 11 | from QGData import QGData, AmbigQGData 12 | from PassageData import PassageData 13 | 14 | from models.span_predictor import SpanPredictor, AlbertSpanPredictor 15 | from models.seq2seq import MyBart 16 | from models.seq2seq_with_prefix import MyBartWithPrefix 17 | from models.biencoder import MyBiEncoder 18 | 19 | def run(args, logger): 20 | 21 | args.dpr = args.task=="dpr" 22 | args.is_seq2seq = 'bart' in args.bert_name 23 | if 'bart' in args.bert_name: 24 | tokenizer = BartTokenizer.from_pretrained(args.bert_name) 25 | tokenizer.add_tokens([""]) 26 | Model = MyBartWithPrefix if args.do_predict and args.nq_answer_as_prefix else MyBart 27 | Config = BartConfig 28 | args.append_another_bos = True 29 | elif 'albert' in args.bert_name: 30 | tokenizer = AlbertTokenizer.from_pretrained(args.bert_name) 31 | Model = AlbertSpanPredictor 32 | Config = AlbertConfig 33 | elif 'bert' in args.bert_name: 34 | tokenizer = BertTokenizer.from_pretrained(args.bert_name) 35 | Model = MyBiEncoder if args.dpr else SpanPredictor 36 | Config = BertConfig 37 | else: 38 | raise NotImplementedError() 39 | 40 | if args.dpr: 41 | Model = MyBiEncoder 42 | args.checkpoint = os.path.join(args.dpr_data_dir, "checkpoint/retriever/multiset/bert-base-encoder.cp") 43 | assert not args.do_train, "Training DPR is not supported yet" 44 | 45 | passages = PassageData(logger, args, tokenizer) 46 | 47 | def _getQAData(): 48 | if args.task=="qg": 49 | return AmbigQGData if args.ambigqa else QGData 50 | return AmbigQAData if args.ambigqa else QAData 51 | 52 | def _load_from_checkpoint(checkpoint): 53 | def convert_to_single_gpu(state_dict): 54 | if "model_dict" in state_dict: 55 | state_dict = state_dict["model_dict"] 56 | def _convert(key): 57 | if key.startswith('module.'): 58 | return key[7:] 59 | return key 60 | return {_convert(key):value for key, value in state_dict.items()} 61 | state_dict = convert_to_single_gpu(torch.load(checkpoint)) 62 | model = Model(Config.from_pretrained(args.bert_name)) 63 | if "bart" in args.bert_name: 64 | model.resize_token_embeddings(len(tokenizer)) 65 | logger.info("Loading from {}".format(checkpoint)) 66 | return model.from_pretrained(None, config=model.config, state_dict=state_dict) 67 | 68 | if args.do_train and args.skip_inference: 69 | dev_data = None 70 | else: 71 | dev_data = _getQAData()(logger, args, args.predict_file, False, passages) 72 | dev_data.load_dataset(tokenizer) 73 | if args.do_prepro_only: 74 | dev_data.load_dpr_data() 75 | exit() 76 | dev_data.load_dataloader() 77 | 78 | if 
args.do_train: 79 | train_data = _getQAData()(logger, args, args.train_file, True, passages) 80 | train_data.load_dataset(tokenizer) 81 | train_data.load_dataloader() 82 | 83 | if args.checkpoint is not None: 84 | model = _load_from_checkpoint(args.checkpoint) 85 | else: 86 | model = Model.from_pretrained(args.bert_name) 87 | if "bart" in args.bert_name: 88 | model.resize_token_embeddings(len(tokenizer)) 89 | if args.n_gpu>1: 90 | model = torch.nn.DataParallel(model) 91 | model.to(torch.device("cuda")) 92 | 93 | no_decay = ['bias', 'LayerNorm.weight'] 94 | optimizer_grouped_parameters = [ 95 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 96 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 97 | ] 98 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 99 | scheduler = get_linear_schedule_with_warmup(optimizer, 100 | num_warmup_steps=args.warmup_steps, 101 | num_training_steps=100000) 102 | train(args, logger, model, train_data, dev_data, optimizer, scheduler) 103 | 104 | if args.do_predict: 105 | checkpoint = os.path.join(args.output_dir, 'best-model.pt') if args.checkpoint is None else args.checkpoint 106 | model = _load_from_checkpoint(checkpoint) 107 | logger.info("Loading checkpoint from {}".format(checkpoint)) 108 | if args.n_gpu>1 and 'bert' in args.bert_name: 109 | model = torch.nn.DataParallel(model) 110 | model.to(torch.device("cuda")) 111 | model.eval() 112 | ems = inference(model, dev_data, save_predictions=True) 113 | logger.info("%s on test data = %.2f" % (dev_data.metric, np.mean(ems)*100)) 114 | 115 | def train(args, logger, model, train_data, dev_data, optimizer, scheduler): 116 | model.train() 117 | global_step = 0 118 | train_losses = [] 119 | best_accuracy = -1 120 | stop_training=False 121 | 122 | for _ in range(args.resume_global_step): 123 | optimizer.step() 124 | scheduler.step() 125 | 126 | logger.info("Start training!") 127 | for epoch in range(int(args.num_train_epochs)): 128 | for batch in train_data.dataloader: 129 | global_step += 1 130 | batch = [b.to(torch.device("cuda")) for b in batch] 131 | if args.is_seq2seq: 132 | loss = model(input_ids=batch[0], attention_mask=batch[1], 133 | decoder_input_ids=batch[2], decoder_attention_mask=batch[3], 134 | is_training=True) 135 | else: 136 | loss = model(input_ids=batch[0], attention_mask=batch[1], token_type_ids=batch[2], 137 | start_positions=batch[3], end_positions=batch[4], answer_mask=batch[5], 138 | is_training=True) 139 | if args.n_gpu > 1: 140 | loss = loss.mean() # mean() to average on multi-gpu. 
141 | if torch.isnan(loss).data: 142 | logger.info("Stop training because loss=%s" % (loss.data)) 143 | stop_training=True 144 | break 145 | train_losses.append(loss.detach().cpu()) 146 | loss.backward() 147 | 148 | if global_step % args.gradient_accumulation_steps == 0: 149 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 150 | optimizer.step() # We have accumulated enought gradients 151 | scheduler.step() 152 | model.zero_grad() 153 | 154 | if global_step % args.eval_period == 0: 155 | if args.skip_inference: 156 | logger.info("Step %d (epoch %d) Train loss %.2f" % ( 157 | global_step, 158 | epoch, 159 | np.mean(train_losses))) 160 | train_losses = [] 161 | model_state_dict = {k:v.cpu() for (k, v) in model.state_dict().items()} 162 | torch.save(model_state_dict, os.path.join(args.output_dir, 163 | "best-model-{}.pt".format(str(global_step).zfill(6)))) 164 | else: 165 | model.eval() 166 | curr_em = inference(model, dev_data) 167 | logger.info("Step %d Train loss %.2f %s %.2f%% on epoch=%d" % ( 168 | global_step, 169 | np.mean(train_losses), 170 | dev_data.metric, 171 | curr_em*100, 172 | epoch)) 173 | train_losses = [] 174 | if best_accuracy < curr_em: 175 | model_state_dict = {k:v.cpu() for (k, v) in model.state_dict().items()} 176 | torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt")) 177 | logger.info("Saving model with best %s: %.2f%% -> %.2f%% on epoch=%d, global_step=%d" % \ 178 | (dev_data.metric, best_accuracy*100.0, curr_em*100.0, epoch, global_step)) 179 | best_accuracy = curr_em 180 | wait_step = 0 181 | stop_training = False 182 | else: 183 | wait_step += 1 184 | if wait_step >= args.wait_step: 185 | stop_training = True 186 | break 187 | model.train() 188 | if stop_training: 189 | break 190 | 191 | def inference(model, dev_data, save_predictions=False): 192 | if dev_data.args.dpr: 193 | return inference_dpr(model, dev_data, save_predictions) 194 | if "bart" in dev_data.args.bert_name: 195 | return inference_seq2seq(model if dev_data.args.n_gpu==1 or dev_data.args.do_predict else model.module, dev_data, save_predictions) 196 | return inference_span_predictor(model, dev_data, save_predictions) 197 | 198 | def inference_dpr(model, dev_data, save_predictions): 199 | 200 | def _inference(dataloader, is_passages): 201 | if dev_data.args.n_gpu>1: 202 | curr_model = model.module.ctx_model if is_passages else model.module.question_model 203 | curr_model = torch.nn.DataParallel(curr_model) 204 | else: 205 | curr_model = model.ctx_model if is_passages else model.question_model 206 | vectors = [] 207 | for i, batch in tqdm(enumerate(dataloader)): 208 | with torch.no_grad(): 209 | batch = [b.to(torch.device("cuda")) for b in batch] 210 | outputs = curr_model(input_ids=batch[0], attention_mask=batch[1])[0][:,0,:] 211 | vectors.append(outputs.detach().cpu().numpy()) 212 | return np.concatenate(vectors, axis=0) 213 | 214 | checkpoint = dev_data.args.checkpoint 215 | assert checkpoint is not None 216 | import faiss 217 | postfix = "_20200201" if dev_data.args.wiki_2020 else "" 218 | index_path = checkpoint[:checkpoint.index(".")] + "{}.IndexFlatIP".format(postfix) 219 | if os.path.exists(index_path): 220 | index = faiss.read_index(index_path) 221 | else: 222 | checkpoint = dev_data.args.checkpoint 223 | # load passage vectors 224 | index = dev_data.args.db_index 225 | if index==-1: 226 | for index in range(10): 227 | pvec_path = checkpoint[:checkpoint.index(".")] + ".psgs_w100{}_{}.npy".format(postfix, index) 228 | assert 
os.path.exists(pvec_path) 229 | if index==0: 230 | pvec = np.load(pvec_path) 231 | else: 232 | pvec = np.concatenate([pvec, np.load(pvec_path)], axis=0) 233 | else: 234 | pvec_path = checkpoint[:checkpoint.index(".")] + ".psgs_w100{}_{}.npy".format(postfix, index) 235 | print (pvec_path) 236 | if os.path.exists(pvec_path): 237 | pvec = np.load(pvec_path) 238 | else: 239 | dev_data.passages.load_tokenized_data("bert") 240 | dev_data.passages.load_dataset("bert") 241 | dataloader = dev_data.passages.load_dataloader( 242 | dev_data.args.predict_batch_size, 243 | is_training=False, 244 | do_return=True) 245 | if dev_data.args.verbose: 246 | dataloader = tqdm(dataloader) 247 | pvec = _inference(dataloader, is_passages=True) 248 | np.save(pvec_path, pvec) 249 | exit() 250 | print (pvec.shape) 251 | index = faiss.IndexFlatIP(pvec.shape[1]) 252 | index.add(pvec) 253 | faiss.write_index(index, index_path) 254 | qvec = _inference(dev_data.dataloader, is_passages=False) #model.inference(dev_data.dataloader, is_passages=False) 255 | print (qvec.shape) 256 | D, I = index.search(qvec, 100) 257 | assert D.shape == I.shape == (qvec.shape[0], 100) 258 | predictions = I.tolist() 259 | accuracy = dev_data.passages.evaluate(predictions, dev_data.get_answers()) 260 | if save_predictions: 261 | dev_data.save_predictions(predictions) 262 | return np.mean(accuracy) 263 | 264 | def inference_seq2seq(model, dev_data, save_predictions=False): 265 | predictions = [] 266 | bos_token_id = dev_data.tokenizer.bos_token_id 267 | if dev_data.args.task=="qa": 268 | max_answer_length = dev_data.args.max_answer_length 269 | assert max_answer_length>=25 or not dev_data.args.ambigqa 270 | else: 271 | max_answer_length = dev_data.args.max_question_length 272 | if dev_data.args.verbose: 273 | dev_data.dataloader = tqdm(dev_data.dataloader) 274 | for i, batch in enumerate(dev_data.dataloader): 275 | with torch.no_grad(): 276 | decoder_start_token_id = None if not dev_data.args.nq_answer_as_prefix else \ 277 | [[model.config.decoder_start_token_id] + tokens[:min(max_answer_length-2, tokens.index(dev_data.tokenizer.eos_token_id))] for tokens in batch[2].tolist()] 278 | batch = [b.to(torch.device("cuda")) for b in batch[:2]] 279 | outputs = model.generate(input_ids=batch[0], 280 | attention_mask=batch[1], 281 | num_beams=4, 282 | max_length=max_answer_length, 283 | early_stopping=True, 284 | decoder_start_token_id=decoder_start_token_id, 285 | num_return_sequences=4 if decoder_start_token_id is not None else 1 286 | ) 287 | for input_, output in zip(batch[0], outputs): 288 | predictions.append(dev_data.decode(output)) 289 | if save_predictions: 290 | dev_data.save_predictions(predictions) 291 | return np.mean(dev_data.evaluate(predictions)) 292 | 293 | def inference_span_predictor(model, dev_data, save_predictions=False): 294 | outputs = [] 295 | if dev_data.args.verbose: 296 | dev_data.dataloader = tqdm(dev_data.dataloader) 297 | for i, batch in enumerate(dev_data.dataloader): 298 | with torch.no_grad(): 299 | batch = [b.to(torch.device("cuda")) for b in batch] 300 | batch_start_logits, batch_end_logits, batch_sel_logits = model( 301 | input_ids=batch[0], attention_mask=batch[1], token_type_ids=batch[2]) 302 | batch_start_logits = batch_start_logits.detach().cpu().tolist() 303 | batch_end_logits = batch_end_logits.detach().cpu().tolist() 304 | batch_sel_logits = batch_sel_logits.detach().cpu().tolist() 305 | assert len(batch_start_logits)==len(batch_end_logits)==len(batch_sel_logits) 306 | for start_logit, end_logit, sel_logit in 
zip(batch_start_logits, batch_end_logits, batch_sel_logits): 307 | outputs.append((start_logit, end_logit, sel_logit)) 308 | 309 | if save_predictions and dev_data.args.n_paragraphs is None: 310 | n_paragraphs = [dev_data.args.test_M] 311 | elif save_predictions: 312 | n_paragraphs = [int(n) for n in dev_data.args.n_paragraphs.split(",")] 313 | else: 314 | n_paragraphs = None 315 | predictions = dev_data.decode_span(outputs, 316 | n_paragraphs=n_paragraphs) 317 | if save_predictions: 318 | dev_data.save_predictions(predictions) 319 | return np.mean(dev_data.evaluate(predictions, n_paragraphs=n_paragraphs)) 320 | 321 | 322 | -------------------------------------------------------------------------------- /codes/README.md: -------------------------------------------------------------------------------- 1 | # AmbigQA Baseline Models 2 | 3 | **Update as of 07/2020**: Codes for running DPR retrieval, DPR reader and BART reader (SpanSeqGen) on NQ-open and AmbigQA are ready. Stay tuned for Question Generation models! 4 | 5 | This repo contains multiple models for open-domain question answering. This code is based on [PyTorch][pytorch] and [HuggingFace Transformers][hf]. 6 | 7 | This is an original implementation of "Sewon Min, Julian Michael, Hannaneh Hajishirzi, Luke Zettlemoyer. [AmbigQA: Answering Ambiguous Open-domain Questions][ambigqa-paper]. 2020". 8 | ``` 9 | @article{ min2020ambigqa, 10 | title={ {A}mbig{QA}: Answering Ambiguous Open-domain Questions }, 11 | author={ Min, Sewon and Michael, Julian and Hajishirzi, Hannaneh and Zettlemoyer, Luke }, 12 | journal={ arXiv preprint arXiv:2004.10645 }, 13 | year={2020} 14 | } 15 | ``` 16 | 17 | This also contains a re-implementation of "Vladimir Karpukhin*, Barlas Oguz*, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, Wen-tau Yih. [Dense Passage Retrieval for Open-domain Question Answering. 2020][dpr-paper]", as part of AmbigQA models. The original implementation can be found [here][dpr-code]. This codebase achieves higher accuracy. 18 | ``` 19 | @article{ karpukhin2020dense, 20 | title={ Dense Passage Retrieval for Open-domain Question Answering }, 21 | author={ Karpukhin, Vladimir and Oguz, Barlas and Min, Sewon and Lewis, Patrick and Wu, Ledell and Edunov, Sergey and Chen, Danqi and Yih, Wen-tau }, 22 | journal={ arXiv preprint arXiv:2004.04906 }, 23 | year={2020} 24 | } 25 | ``` 26 | 27 | ## Content 28 | 1. [Installation](#installation) 29 | 2. [Download data](#download-data) 30 | 3. Instructions for Training & Testing 31 | * [DPR Retrieval](#dpr-retrieval) 32 | * [DPR Reader (Span Selection Model)](#dpr-reader-span-selection-model) 33 | * [SpanSeqGen (BART Reader)](#spanseqgen-bart-reader) 34 | * [Finetuning on AmbigQA](#finetuning-on-ambigqa) 35 | 4. [Results](#results) 36 | * [Results with less resources](#results-with-less-resources) 37 | 5. [Interactive Demo for Question Answering](#interactive) 38 | 6. [Pretrained model checkpoint](#need-preprocessed-data--pretrained-models--predictions) 39 | 40 | ## Installation 41 | Tested with python 3.6.12 42 | ``` 43 | pip install torch==1.1.0 44 | pip install git+https://github.com/huggingface/transformers.git@7b75aa9fa55bee577e2c7403301ed31103125a35 45 | pip install wget 46 | ``` 47 | 48 | Also, move `pycocoevalcap` to current directory 49 | ``` 50 | mv ../pycocoevalcap pycocoevalcap 51 | ``` 52 | 53 | ## Download data 54 | Let `data_dir` be a directory to save data. 
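For example, assuming a hypothetical location such as `~/ambigqa-data`:
```
data_dir=~/ambigqa-data   # any writable path works
mkdir -p ${data_dir}
```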
55 | ``` 56 | python3 download_data.py --resource data.wikipedia_split.psgs_w100 --output_dir ${data_dir} # provided by original DPR 57 | python3 download_data.py --resource data.wikipedia_split.psgs_w100_20200201 --output_dir ${data_dir} # only for AmbigQA 58 | python3 download_data.py --resource data.nqopen --output_dir ${data_dir} 59 | python3 download_data.py --resource data.gold_passages_info.nq_train --output_dir ${data_dir} 60 | python3 download_data.py --resource data.ambigqa --output_dir ${data_dir} 61 | ``` 62 | 63 | ## DPR Retrieval 64 | For training DPR retrieval, please refer to the [original implementation][dpr-code]. This code is for taking a checkpoint from the original implementation and running inference. 65 | 66 | Step 1: Download the DPR retrieval checkpoint provided by the original DPR implementation. 67 | ``` 68 | python3 download_data.py --resource checkpoint.retriever.multiset.bert-base-encoder --output_dir ${dpr_data_dir} 69 | ``` 70 | 71 | Step 2: Run inference to obtain passage vectors. 72 | ``` 73 | for i in 0 1 2 3 4 5 6 7 8 9 ; do \ # for parallelization 74 | python3 cli.py --do_predict --bert_name bert-base-uncased --output_dir out/dpr --dpr_data_dir ${data_dir} --task dpr --predict_batch_size 3200 --db_index $i ; \ 75 | done 76 | ``` 77 | - `--predict_batch_size` of 3200 is good for one 32GB GPU. 78 | - Specify `--verbose` to print a progress bar. 79 | - This script will tokenize passages in Wikipedia, which will take time. If you want to pre-tokenize first and then launch the job on GPUs afterward, please do the following: (1) run the above command with `--do_prepro_only`, and (2) re-run the above command without `--do_prepro_only`. 80 | 81 | Each run will take around 1.5 hours with one 32GB GPU. 82 | 83 | Step 3: Run inference to obtain question vectors and save the retrieval predictions. 84 | ``` 85 | python3 cli.py --bert_name bert-base-uncased --output_dir out/dpr --dpr_data_dir ${data_dir} --do_predict --task dpr --predict_batch_size 3200 --predict_file data/nqopen/{train|dev|test}.json 86 | ``` 87 | 88 | This script will print out the recall rate and save the retrieval results as `out/dpr/{train|dev|test}_predictions.json`. 89 | 90 | Tip1: Running this for the first time, regardless of the data split, will create the DPR index and save it, so that subsequent runs can reuse it. If you do not want to create the DPR index multiple times, you can run on one data split first, and run the others afterward. If you have the resources to run them in parallel, it may save time to just run all of them in parallel. 91 | 92 | Tip2: If you are fine with not printing the recall rate, you can specify `--skip_db_load` to save time. It will then print the recall as 0, but the prediction file will still be saved. 93 | 94 | ## DPR Reader (Span Selection Model) 95 | 96 | For training on NQ-open, run 97 | ``` 98 | python3 cli.py --do_train --task qa --output_dir out/nq-span-selection \ 99 | --dpr_data_dir ${data_dir} \ 100 | --train_file data/nqopen/train.json \ 101 | --predict_file data/nqopen/dev.json \ 102 | --bert_name {bert-base-uncased|bert-large-uncased} \ 103 | --train_batch_size 32 --train_M 32 --predict_batch_size 128 \ 104 | --eval_period 2000 --wait_step 10 105 | ``` 106 | 107 | - This script will save preprocessed input data so that it can re-load it once created. You might want to preprocess the data before launching a job on GPUs. 108 | - `train_batch_size` is # of questions / batch, and `train_M` is # of passages / question.
Thus, # of (question, passage) / batch is `train_batch_size * train_M`, which matters for GPU usage. With one 32GB GPU and bert-base-uncased, you can use `train_batch_size * train_M` of 128, per the hyperparameters specified in the command above. 109 | - `eval_period` is the interval at which to test on the dev data. The script will only save the best checkpoint based on the dev data. If you prefer, you can specify `--skip_inference` to skip inference on the dev data and save all checkpoints. You can then run the inference script (described next) on the dev data using every checkpoint, and choose the best checkpoint. 110 | - `wait_step` is the number of evaluation intervals to wait without improvement over the best checkpoint before training stops. 111 | 112 | When training is done, run the following command for prediction. 113 | ``` 114 | python3 cli.py --do_predict --task qa --output_dir out/nq-span-selection \ 115 | --dpr_data_dir ${data_dir} \ 116 | --predict_file data/nqopen/{dev|test}.json \ 117 | --bert_name {bert-base-uncased|bert-large-uncased} \ 118 | --predict_batch_size 32 119 | ``` 120 | This command runs predictions using `out/nq-span-selection/best-model.pt` by default. If you want to run predictions using another checkpoint, please specify its path with `--checkpoint`. 121 | 122 | 123 | ## SpanSeqGen (BART Reader) 124 | 125 | Note: this model is different from the BART closed-book QA model (implemented [here][bart-closed-book-qa]), because this model reads DPR-retrieved passages as input. 126 | 127 | First, tokenize passages with the BART tokenizer. 128 | ``` 129 | for i in 0 1 2 3 4 5 6 7 8 9 ; do \ # for parallelization 130 | python3 cli.py --bert_name bart-large --output_dir out/dpr --dpr_data_dir ${data_dir} --do_predict --do_prepro_only --task dpr --predict_batch_size 3200 --db_index $i ; \ 131 | done 132 | ``` 133 | 134 | Then, save passage selection from the trained DPR reader: 135 | ``` 136 | python3 cli.py --do_predict --task qa --output_dir out/nq-span-selection \ 137 | --dpr_data_dir ${data_dir} \ 138 | --predict_file data/nqopen/{train|dev|test}.json \ 139 | --bert_name {bert-base-uncased|bert-large-uncased} \ 140 | --predict_batch_size 32 --save_psg_sel_only 141 | ``` 142 | 143 | Now, train a model on NQ-open by: 144 | ``` 145 | python3 cli.py --do_train --task qa --output_dir out/nq-span-seq-gen \ 146 | --dpr_data_dir ${data_dir} \ 147 | --train_file data/nqopen/train.json \ 148 | --predict_file data/nqopen/dev.json \ 149 | --psg_sel_dir out/nq-span-selection \ 150 | --bert_name bart-large \ 151 | --discard_not_found_answers \ 152 | --train_batch_size 20 --predict_batch_size 40 \ 153 | --eval_period 2000 --wait_step 10 154 | ``` 155 | 156 | ## Finetuning on AmbigQA 157 | 158 | In order to experiment on AmbigQA, you can simply repeat the process used for NQ-open, with only two differences: (i) specifying `--ambigqa` and `--wiki_2020` in several places and (ii) initializing weights from models trained on NQ-open. Step-by-step instructions are as follows. 159 | 160 | First, make DPR retrieval predictions using Wikipedia 2020. You can do so by simply repeating Step 2 and Step 3 of [DPR Retrieval](#dpr-retrieval) with `--wiki_2020` specified.
161 | ``` 162 | for i in 0 1 2 3 4 5 6 7 8 9 ; do \ # for parallelization 163 | python3 cli.py --do_predict --bert_name bert-base-uncased --output_dir out/dpr --dpr_data_dir ${data_dir} --task dpr --predict_batch_size 3200 --db_index $i --wiki_2020 ; \ 164 | done 165 | python3 cli.py --bert_name bert-base-uncased --output_dir out/dpr --dpr_data_dir ${data_dir} --do_predict --task dpr --predict_batch_size 3200 --predict_file data/nqopen/{train|dev|test}.json --wiki_2020 166 | ``` 167 | 168 | In order to fine-tune the DPR span selection model on AmbigQA, run a training command similar to the NQ training command, but with `--ambigqa` and `--wiki_2020` specified. We also use a smaller `eval_period`, since the dataset is smaller. 169 | ``` 170 | python3 cli.py --do_train --task qa --output_dir out/ambignq-span-selection \ 171 | --dpr_data_dir ${data_dir} \ 172 | --train_file data/ambigqa/train_light.json \ 173 | --predict_file data/ambigqa/dev_light.json \ 174 | --bert_name {bert-base-uncased|bert-large-uncased} \ 175 | --train_batch_size 32 --train_M 32 --predict_batch_size 32 \ 176 | --eval_period 500 --wait_step 10 --topk_answer 3 --ambigqa --wiki_2020 177 | ``` 178 | 179 | In order to fine-tune SpanSeqGen on AmbigQA, first run the inference script over DPR to get highly ranked passages, just like we did on NQ. 180 | ``` 181 | python3 cli.py --do_predict --task qa --output_dir out/nq-span-selection \ 182 | --dpr_data_dir ${data_dir} \ 183 | --predict_file data/nqopen/{train|dev|test}.json \ 184 | --bert_name {bert-base-uncased|bert-large-uncased} \ 185 | --predict_batch_size 32 --save_psg_sel_only --wiki_2020 186 | ``` 187 | 188 | Next, train SpanSeqGen on AmbigNQ via the following command, which specifies `--ambigqa`, `--wiki_2020` and `--max_answer_length 25`. 189 | ``` 190 | python3 cli.py --do_train --task qa --output_dir out/ambignq-span-seq-gen \ 191 | --dpr_data_dir ${data_dir} \ 192 | --train_file data/ambigqa/train_light.json \ 193 | --predict_file data/ambigqa/dev_light.json \ 194 | --psg_sel_dir out/nq-span-selection \ 195 | --bert_name bart-large \ 196 | --discard_not_found_answers \ 197 | --train_batch_size 20 --predict_batch_size 40 \ 198 | --eval_period 500 --wait_step 10 --ambigqa --wiki_2020 --max_answer_length 25 199 | ``` 200 | 201 | ## Hyperparameter details 202 | 203 | **On NQ-open:** For BERT-base, we use `train_batch_size=32, train_M=32` (w/ eight 32GB GPUs). For BERT-large, we use `train_batch_size=8, train_M=16` (w/ four 32GB GPUs). For BART, we use `train_batch_size=24` (w/ four 32GB GPUs). For others, we use default hyperparameters. 204 | 205 | **On AmbigQA:** We use `train_batch_size=8` for BERT-base and `train_batch_size=24` for BART. We use `learning_rate=5e-6` for both. 206 | 207 | ## Results 208 | 209 | | | NQ-open (dev) | NQ-open (test) | AmbigQA zero-shot (dev) | AmbigQA zero-shot (test) | AmbigQA (dev) | AmbigQA (test) | 210 | |---|---|---|---|---|---|---| 211 | |DPR (original implementation)| 39.8 | 41.5 | 35.2/26.5 | 30.1/23.2 | 37.1/28.4 | 32.3/24.8 | 212 | |DPR (this code)| 40.6 | 41.6 | 35.2/23.9 | 29.9/21.4 | 36.8/25.8 | 33.3/23.4 | 213 | |DPR (this code) w/ BERT-large| 43.2 | 44.3 | - | - | - | - | 214 | |SpanSeqGen (reported)| 42.0 | 42.2 | 36.4/24.8 | 30.8/20.7 | 39.7/29.3 | 33.5/24.5 | 215 | |SpanSeqGen (this code)| 43.1 | 45.0 | 37.4/26.1 | 33.2/22.6 | 40.3/29.2 | 35.5/25.8 | 216 | 217 | The two numbers for AmbigQA are the F1 score on all questions and the F1 score on questions with multiple QA pairs only.
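For reference, the NQ-open numbers are exact match (EM) after SQuAD-style answer normalization. The repository's actual implementation lives in `ambigqa_evaluate_script.py` (`normalize_answer` / `get_exact_match`); the snippet below is only a minimal sketch of that kind of check and may differ from it in details.
```python
import re, string

def normalize(text):
    # lowercase, drop punctuation and articles, collapse whitespace
    text = "".join(ch for ch in text.lower() if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, gold_answers):
    # credit the prediction if it matches any gold answer after normalization
    return float(any(normalize(prediction) == normalize(gold) for gold in gold_answers))

print(exact_match("The Beatles", ["Beatles", "the beatles"]))  # 1.0
```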
218 | 219 | By default, the models are based on BERT-base and BART-large. 220 | 221 | *Note (as of 07/2020)*: The numbers are slightly different from those reported in the paper, because the numbers in the paper are based on experiments with fairseq. We re-implemented the models with Huggingface Transformers and were able to obtain similar or better numbers. We will update the numbers in the next version of the paper. 222 | 223 | *Note*: There happen to be two versions of NQ answers which marginally differ in tokenization methods (e.g. `July 15 , 2020` vs. `July 15, 2020` or `2019 - 2020` vs. `2019--2020`). 224 | Research papers outside Google ([#1][dpr-paper], [#2][ambigqa-paper], [#3][hard-em], [#4][path-retriever], [#5][rag], [#6][colbert], [#7][fusion-decoder], [#8][graph-retriever]) have been using [this version](https://nlp.cs.washington.edu/ambigqa/data/nqopen.zip), and in June 2020 the original NQ/NQ-open authors released the [original version](https://github.com/efficientqa/nq-open) that has been used in research papers from Google ([#1][orqa], [#2][realm], [#3][t5qa]). 225 | We verified that the performance differences are marginal when applying simple postprocessing (e.g. `text.replace(" - ", "-").replace(" : ", ":")`). 226 | The numbers reported here, as well as the code, follow Google's original version. Compared to the previous version, the performance difference is 40.6 (original) vs. 40.3 (previous) vs. 40.7 (union of the two) on the dev set and 41.6 (original) vs. 41.7 (previous) vs. 41.8 (union of the two) on the test set. 227 | Nonetheless, we advise using the original version provided by Google in the future. 228 | 229 | ### Results with less resources 230 | 231 | The readers are not very sensitive to hyperparameters (`train_batch_size` and `train_M`). In case you want to experiment with fewer resources and check reproducibility, here are our results depending on the number of 32GB GPUs. 232 | 233 | DPR with BERT-base: 234 | | Num. of 32GB GPU(s) | (`train_batch_size`, `train_M`) | NQ-open (dev) | NQ-open (test) | 235 | |---|---|---|---| 236 | | 1 | (8, 16) | 40.5 | 41.4 | 237 | | 2 | (16, 16) | 40.9 | 41.1 | 238 | | 4 | (16, 32) | 41.2 | 41.1 | 239 | | 8 | (32, 32) | 40.6 | 41.6 | 240 | 241 | DPR with BERT-large: 242 | | Num. of 32GB GPU(s) | (`train_batch_size`, `train_M`) | NQ-open (dev) | NQ-open (test) | 243 | |---|---|---|---| 244 | | 2 | (8, 8) | 42.0 | 43.4 | 245 | | 4 | (8, 16) | 43.2 | 44.3 | 246 | | 8 | (16, 16) | 42.2 | 43.2 | 247 | 248 | 249 | ## Interactive 250 | 251 | You can run DPR interactively as follows. 252 | 253 | ```python 254 | from InteractiveDPR import InteractiveDPR 255 | interactive_dpr = InteractiveDPR(dpr_data_dir=path_to_dpr_data_dir, reader_checkpoint=path_to_reader_checkpoint) 256 | question = "When did harry potter and the sorcerer's stone movie come out?" 257 | print (interactive_dpr.predict(question, topk_answer=5, only_text=True)) 258 | ``` 259 | 260 | For details, please refer to `InteractiveDPR.py`. 261 | 262 | 263 | ## Need preprocessed data / pretrained models / predictions?
264 | 265 | **DPR** 266 | - [DPR predictions on NQ](https://nlp.cs.washington.edu/ambigqa/models/nq-dpr.zip): contains passage idxs from wikipedia 20181220 for NQ train/dev/test 267 | - [DPR predictions on AmbigNQ](https://nlp.cs.washington.edu/ambigqa/models/ambigqa-dpr.zip): contains passage idxs from wikipedia 20200201 for AmbigNQ train/dev and NQ test (AmbigNQ test set is hidden and you need to submit NQ test predictions to submit to the leaderboard) 268 | 269 | **Question Answering** 270 | Click in order to download checkpoints: 271 | - [DPR Reader trained on NQ (387M)][checkpoint-nq-dpr] 272 | - [DPR Reader (w/ BERT-large) trained on NQ (1.2G)][checkpoint-nq-dpr-large] 273 | - [DPR Reader trained on AmbigNQ (387M)][checkpoint-ambignq-dpr] 274 | - [SpanSeqGen trained on NQ (1.8G)][checkpoint-nq-bart] 275 | - [SpanSeqGen trained on AmbigNQ (1.8G)][checkpoint-ambignq-bart] 276 | 277 | **Passage Reranking from DPR Reader** 278 | - [Reranking result (37M)](https://nlp.cs.washington.edu/ambigqa/models/reranking_results.zip): contain reranking for NQ train/dev/test (aligned with [Wikipedia 2018](https://github.com/shmsw25/AmbigQA/blob/master/codes/download_data.py#L26) and AmbigQA train/dev (aligned with [Wikipedia 2020](https://github.com/shmsw25/AmbigQA/blob/master/codes/download_data.py#L34)). 279 | 280 | For a sanity check, the recall accuracy should be as follows. (For AmbigQA, macro-average of recall.) 281 | 282 | | k | NQ train | NQ dev | NQ test | AmbigQA train | AmbigQA dev | 283 | |---|---|---|---|---|---| 284 | | 1 |80.4|59.8|59.4|58.3|51.8| 285 | | 5 |86.8|75.9|76.3|72.7|70.0| 286 | | 10 |87.8|79.9|80.8|76.2|74.8| 287 | | 100 |89.2|86.2|87.4|81.2|83.1| 288 | 289 | **Question Disambiguation** 290 | Coming soon! 291 | 292 | [ambigqa-paper]: https://arxiv.org/abs/2004.10645 293 | [dpr-paper]: https://arxiv.org/abs/2004.04906 294 | [dpr-code]: https://github.com/facebookresearch/DPR 295 | [bart-closed-book-qa]: https://github.com/shmsw25/bart-closed-book-qa 296 | [hf]: https://huggingface.co/transformers/ 297 | [pytorch]: https://pytorch.org/ 298 | 299 | [hard-em]: https://arxiv.org/abs/1909.04849 300 | [path-retriever]: https://arxiv.org/abs/1911.10470 301 | [rag]: https://arxiv.org/abs/2005.11401 302 | [fusion-decoder]: https://arxiv.org/abs/2007.01282 303 | [colbert]: https://arxiv.org/abs/2007.00814 304 | [graph-retriever]: https://arxiv.org/abs/1911.03868 305 | 306 | [orqa]: https://arxiv.org/abs/1906.00300 307 | [realm]: https://arxiv.org/abs/2002.08909 308 | [t5qa]: https://arxiv.org/abs/2002.08910 309 | 310 | [checkpoint-nq-dpr]: https://nlp.cs.washington.edu/ambigqa/models/nq-bert-base-uncased-32-32-0.zip 311 | [checkpoint-nq-dpr-large]: https://nlp.cs.washington.edu/ambigqa/models/nq-bert-large-uncased-16-16-0.zip 312 | [checkpoint-ambignq-dpr]: https://nlp.cs.washington.edu/ambigqa/models/ambignq-bert-base-uncased-8-32-0.zip 313 | [checkpoint-nq-bart]: https://nlp.cs.washington.edu/ambigqa/models/nq-bart-large-24-0.zip 314 | [checkpoint-ambignq-bart]: https://nlp.cs.washington.edu/ambigqa/models/ambignq-bart-large-12-0.zip 315 | 316 | 317 | 318 | 319 | 320 | 321 | -------------------------------------------------------------------------------- /codes/QGData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gzip 4 | import re 5 | import pickle as pkl 6 | import string 7 | import numpy as np 8 | from tqdm import tqdm 9 | from collections import Counter, defaultdict 10 | 11 | import torch 12 | from 
torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler 13 | 14 | from QAData import QAData, AmbigQAData 15 | from DataLoader import MySimpleQADataset, MySimpleQADatasetForPair, MyDataLoader 16 | from util import decode_span_batch 17 | 18 | # for evaluation 19 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 20 | from ambigqa_evaluate_script import normalize_answer, get_exact_match, get_f1, get_qg_metrics 21 | from pycocoevalcap.bleu.bleu import Bleu 22 | 23 | class QGData(QAData): 24 | def __init__(self, logger, args, data_path, is_training, passages=None): 25 | super(QGData, self).__init__(logger, args, data_path, is_training, passages) 26 | self.qg_tokenizer = PTBTokenizer() 27 | self.metric = "Bleu" 28 | if not self.is_training: 29 | self.qg_tokenizer = PTBTokenizer() 30 | 31 | def load_dpr_data(self): 32 | dpr_retrieval_path = "out/dpr/{}_predictions.json".format( 33 | self.data_type+"_20200201" if self.args.wiki_2020 else self.data_type) 34 | postfix = self.tokenizer.__class__.__name__.replace("zer", "zed") 35 | dpr_tokenized_path = dpr_retrieval_path.replace(".json", "_{}_qg.json".format(postfix)) 36 | assert "Bart" in postfix 37 | return self.load_dpr_data_bart(dpr_retrieval_path, dpr_tokenized_path) 38 | 39 | def load_dpr_data_bart(self, dpr_retrieval_path, dpr_tokenized_path): 40 | self.logger.info("{}\n{}".format(dpr_retrieval_path, dpr_tokenized_path)) 41 | if os.path.exists(dpr_tokenized_path): 42 | self.logger.info("Loading DPR data from {}".format(dpr_tokenized_path)) 43 | with open(dpr_tokenized_path, "r") as f: 44 | self.tokenized_data = json.load(f) 45 | else: 46 | self.logger.info("Start processing DPR data") 47 | if self.passages.tokenized_data is None: 48 | self.passages.load_tokenized_data("bart", all=True) 49 | if "train_for_inference" not in dpr_retrieval_path: 50 | dpr_retrieval_path = dpr_retrieval_path.replace("train", "train_for_inference") 51 | with open(dpr_retrieval_path, "r") as f: 52 | dpr_passages = json.load(f) 53 | assert len(dpr_passages)==len(self) 54 | assert self.args.psg_sel_dir is not None 55 | psg_sel_fn = os.path.join(self.args.psg_sel_dir, 56 | "{}{}_psg_sel.json".format( 57 | self.data_type.replace("train", "train_for_inference") if "for_inference" not in self.data_type else self.data_type, 58 | "_20200201" if self.args.wiki_2020 else "")) 59 | self.logger.info("Loading passage selection from DPR reader: {}".format(psg_sel_fn)) 60 | with open(psg_sel_fn, "r") as f: 61 | fg_passages = json.load(f) 62 | assert len(fg_passages)==len(dpr_passages) 63 | dpr_passages = [[psgs[i] for i in fg_psgs] for psgs, fg_psgs in zip(dpr_passages, fg_passages)] 64 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 65 | assert len(dpr_passages)==len(input_ids)==len(attention_mask) 66 | bos_token_id = self.tokenizer.bos_token_id 67 | 68 | def _get_tokenized_answer(idx): 69 | tokens = self.tokenized_data[2][idx] 70 | if 0 in self.tokenized_data[3][idx]: 71 | tokens = tokens[:self.tokenized_data[3][idx].index(0)] 72 | assert tokens[0]==tokens[1]==self.tokenizer.bos_token_id and tokens[-1]==self.tokenizer.eos_token_id 73 | return tokens[2:-1] 74 | 75 | def _included(tokens, curr_input_ids, end_of_answer): 76 | for i in range(end_of_answer, 1024-len(tokens)+1): 77 | if curr_input_ids[i:i+len(tokens)]==tokens: 78 | return True 79 | return False 80 | 81 | has_valid = [] 82 | new_input_ids, new_attention_mask, new_decoder_input_ids, new_decoder_attention_mask, 
new_metadata = [], [], [], [], [] 83 | for idx, (curr_input_ids, curr_attention_mask, curr_metadata, dpr_ids) in enumerate(zip( 84 | input_ids, attention_mask, metadata, dpr_passages)): 85 | dpr_input_ids = [self.passages.tokenized_data["input_ids"][_id] for _id in dpr_ids] 86 | dpr_attention_mask = [self.passages.tokenized_data["attention_mask"][_id] for _id in dpr_ids] 87 | 88 | # create multiple inputs 89 | answer_input_ids_list, answer_attention_mask_list, is_valid_list = [], [], [] 90 | for answer_idx in range(*curr_metadata): 91 | end_of_answer = decoder_input_ids[answer_idx].index(self.tokenizer.eos_token_id)+1 92 | answer_input_ids = decoder_input_ids[answer_idx][:end_of_answer] 93 | answer_attention_mask = decoder_attention_mask[answer_idx][:end_of_answer] 94 | offset = 0 95 | while len(answer_input_ids)<1024: 96 | assert dpr_input_ids[offset][0] == bos_token_id 97 | assert len(dpr_input_ids[offset])==len(dpr_attention_mask[offset]) 98 | assert np.sum(dpr_attention_mask[offset])==len(dpr_attention_mask[offset]) 99 | answer_input_ids += dpr_input_ids[offset][1:] 100 | answer_attention_mask += dpr_attention_mask[offset][1:] 101 | offset += 1 102 | assert len(answer_input_ids)==len(answer_attention_mask) 103 | answer_input_ids_list.append(answer_input_ids[:1024]) 104 | answer_attention_mask_list.append(answer_attention_mask[:1024]) 105 | is_valid_list.append(_included( 106 | decoder_input_ids[answer_idx][2:end_of_answer-1], 107 | answer_input_ids, end_of_answer)) 108 | 109 | has_valid.append(any(is_valid_list)) 110 | if self.is_training: 111 | if not any(is_valid_list): 112 | is_valid_list = [True for _ in is_valid_list] 113 | new_metadata.append((len(new_input_ids), len(new_input_ids)+sum(is_valid_list))) 114 | new_input_ids += [answer_input_ids for answer_input_ids, is_valid in 115 | zip(answer_input_ids_list, is_valid_list) if is_valid] 116 | new_attention_mask += [answer_attention_mask for answer_attention_mask, is_valid in 117 | zip(answer_attention_mask_list, is_valid_list) if is_valid] 118 | else: 119 | index = is_valid_list.index(True) if any(is_valid_list) else 0 120 | new_metadata.append((len(new_input_ids), len(new_input_ids)+1)) 121 | new_input_ids.append(answer_input_ids_list[index]) 122 | new_attention_mask.append(answer_attention_mask_list[index]) 123 | new_decoder_input_ids.append(curr_input_ids) 124 | new_decoder_attention_mask.append(curr_attention_mask) 125 | 126 | assert len(new_input_ids)==len(new_attention_mask)==new_metadata[-1][-1] 127 | self.tokenized_data = [new_input_ids, new_attention_mask, new_decoder_input_ids, new_decoder_attention_mask, new_metadata] 128 | with open(dpr_tokenized_path, "w") as f: 129 | json.dump(self.tokenized_data, f) 130 | self.logger.info("Finish saving tokenized DPR data at {}".format(dpr_tokenized_path)) 131 | self.logger.info("%.1f%% questions have at least one answer mentioned in passages" % (100*np.mean(has_valid))) 132 | 133 | 134 | def load_dataset(self, tokenizer, do_return=False): 135 | if self.tokenized_data is None: 136 | self.load_tokenized_data(tokenizer) 137 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 138 | self.dataset = MySimpleQADataset(input_ids, 139 | attention_mask, 140 | decoder_input_ids if self.is_training or self.args.nq_answer_as_prefix else None, 141 | decoder_attention_mask if self.is_training or self.args.nq_answer_as_prefix else None, 142 | in_metadata=metadata, 143 | out_metadata=None, 144 | is_training=self.is_training, 145 | 
answer_as_prefix=self.args.nq_answer_as_prefix) 146 | self.logger.info("Loaded {} examples from {} data".format(len(self.dataset), self.data_type)) 147 | 148 | if do_return: 149 | return self.dataset 150 | 151 | def load_dataloader(self, do_return=False): 152 | self.dataloader = MyDataLoader(self.args, self.dataset, is_training=self.is_training) 153 | if do_return: 154 | return self.dataloader 155 | 156 | def evaluate(self, predictions, n_paragraphs=None): 157 | assert len(predictions)==len(self), (len(predictions), len(self)) 158 | bleu = [] 159 | 160 | # first, tokenize 161 | data_to_tokenize = {} 162 | for i, (d, pred) in enumerate(zip(self.data, predictions)): 163 | data_to_tokenize["ref.{}".format(i)] = [{"caption": d["question"]}] 164 | data_to_tokenize["gen.{}".format(i)] = [{"caption": pred if type(pred)==str else pred[0]}] 165 | all_tokens = self.qg_tokenizer.tokenize(data_to_tokenize) 166 | for i in range(len(self.data)): 167 | reference = {"sent": [normalize_answer(text) for text in all_tokens["ref.{}".format(i)]]} 168 | generated = {"sent": [normalize_answer(text) for text in all_tokens["gen.{}".format(i)]]} 169 | bleu.append(Bleu(4).compute_score(reference, generated)[0][-1]) 170 | return np.mean(bleu) 171 | 172 | def save_predictions(self, predictions): 173 | assert len(predictions)==len(self), (len(predictions), len(self)) 174 | save_path = os.path.join(self.args.output_dir, "{}{}_predictions.json".format( 175 | self.data_type if self.args.prefix is None else self.args.prefix, 176 | "_20200201" if self.args.wiki_2020 and not self.args.ambigqa else "")) 177 | with open(save_path, "w") as f: 178 | json.dump(predictions, f) 179 | self.logger.info("Saved prediction in {}".format(save_path)) 180 | 181 | class AmbigQGData(AmbigQAData, QAData): 182 | def __init__(self, logger, args, data_path, is_training, passages=None): 183 | super(AmbigQGData, self).__init__(logger, args, data_path, is_training, passages) 184 | 185 | with open("/".join(data_path.split("/")[:-2]) + "/nqopen/{}.json".format(self.data_type), "r") as f: 186 | orig_data = json.load(f) 187 | id_to_orig_idx = {d["id"]:i for i, d in enumerate(orig_data)} 188 | 189 | self.ref_questions = [] 190 | self.ref_answers = [] 191 | # we will only consider questions with multiple answers 192 | for i, d in enumerate(self.data): 193 | if not all([ann["type"]=="multipleQAs" for ann in d["annotations"]]): 194 | self.ref_questions.append(None) 195 | self.ref_answers.append(None) 196 | continue 197 | questions, answers = [], [] 198 | for annotation in d["annotations"]: 199 | questions.append([[q.strip() for q in pair["question"].split("|")] for pair in annotation["qaPairs"]]) 200 | answers.append([list(set(pair["answer"])) for pair in annotation["qaPairs"]]) 201 | assert type(answers)==list and \ 202 | all([type(answer)==list for answer in answers]) and \ 203 | all([type(_a)==str for answer in answers+questions for _answer in answer for _a in _answer]) 204 | self.ref_questions.append(questions) 205 | self.ref_answers.append(answers) 206 | self.data[i]["orig_idx"] = id_to_orig_idx[d["id"]] 207 | 208 | 209 | self.SEP = "" 210 | self.qg_tokenizer = PTBTokenizer() 211 | self.metric = "EDIT-F1" 212 | if not self.is_training: 213 | self.qg_tokenizer = PTBTokenizer() 214 | 215 | # override 216 | def load_dpr_data(self): 217 | dpr_retrieval_path = "out/dpr/{}_predictions.json".format( 218 | self.data_type+"_20200201" if self.args.wiki_2020 else self.data_type) 219 | postfix = self.tokenizer.__class__.__name__.replace("zer", "zed") 220 | assert 
"Bart" in postfix 221 | dpr_tokenized_path = dpr_retrieval_path.replace("predictions.json", "ambigqa_predictions_{}_qg.json".format(postfix)) 222 | self.load_dpr_data_bart(dpr_retrieval_path, dpr_tokenized_path) 223 | 224 | # in attention_mask, 1 means answer + passages, 2 means prompt, 3 means other answers 225 | do_include_prompt=True 226 | do_include_others=True 227 | new_input_ids, new_attention_mask = [], [] 228 | for input_ids, attention_mask in zip(self.tokenized_data[0], self.tokenized_data[1]): 229 | _input_ids = [_id for _id, mask in zip(input_ids, attention_mask) 230 | if mask==1 or (do_include_prompt and mask==2) or (do_include_others and mask==3)] 231 | _attention_mask = [1 for mask in attention_mask 232 | if mask==1 or (do_include_prompt and mask==2) or (do_include_others and mask==3)] 233 | assert len(_input_ids)==len(_attention_mask) 234 | while len(_input_ids)<1024: 235 | _input_ids.append(self.tokenizer.pad_token_id) 236 | _attention_mask.append(0) 237 | new_input_ids.append(_input_ids[:1024]) 238 | new_attention_mask.append(_attention_mask[:1024]) 239 | self.tokenized_data[0] = new_input_ids 240 | self.tokenized_data[1] = new_attention_mask 241 | 242 | 243 | # override 244 | def load_dpr_data_bart(self, dpr_retrieval_path, dpr_tokenized_path): 245 | 246 | self.logger.info(dpr_tokenized_path) 247 | 248 | if self.is_training and self.args.consider_order_for_multiple_answers: 249 | dpr_tokenized_path = dpr_tokenized_path.replace(".json", "_ordered.json") 250 | 251 | if os.path.exists(dpr_tokenized_path): 252 | self.logger.info("Loading DPR data from {}".format(dpr_tokenized_path)) 253 | with open(dpr_tokenized_path, "r") as f: 254 | self.tokenized_data = json.load(f) 255 | return 256 | 257 | import itertools 258 | self.logger.info("Start processing DPR data from {}".format(dpr_retrieval_path)) 259 | if self.passages.tokenized_data is None: 260 | self.passages.load_tokenized_data("bart", all=True) 261 | 262 | with open(dpr_retrieval_path.format(self.data_type).replace("train", "train_for_inference"), "r") as f: 263 | dpr_passages = json.load(f) 264 | assert self.args.psg_sel_dir is not None 265 | 266 | psg_sel_fn = os.path.join(self.args.psg_sel_dir, 267 | "{}{}_psg_sel.json".format( 268 | self.data_type.replace("train", "train_for_inference"), 269 | "_20200201" if self.args.wiki_2020 else "")) 270 | self.logger.info("Loading passage selection from DPR reader: {}".format(psg_sel_fn)) 271 | with open(psg_sel_fn, "r") as f: 272 | fg_passages = json.load(f) 273 | assert len(fg_passages)==len(dpr_passages) 274 | dpr_passages = [[psgs[i] for i in fg_psgs] for psgs, fg_psgs in zip(dpr_passages, fg_passages)] 275 | 276 | # added to convert original DPR data to AmbigQA DPR data 277 | dpr_passages = [dpr_passages[d["orig_idx"]] for d in self.data] 278 | 279 | assert len(dpr_passages)==len(self) 280 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 281 | assert len(dpr_passages)==len(input_ids)==len(attention_mask)==len(metadata) 282 | bos_token_id = self.tokenizer.bos_token_id 283 | eos_token_id = self.tokenizer.eos_token_id 284 | pad_token_id = self.tokenizer.pad_token_id 285 | sep_token_id = self.tokenizer.convert_tokens_to_ids(self.SEP) 286 | assert type(bos_token_id)==type(eos_token_id)==type(sep_token_id)==int 287 | 288 | 289 | def _get_tokenized_answer(idx): 290 | tokens = decoder_input_ids[idx] 291 | if 0 in decoder_attention_mask[idx]: 292 | tokens = tokens[:decoder_attention_mask[idx].index(0)] 293 | assert 
tokens[0]==tokens[1]==bos_token_id and tokens[-1]==eos_token_id 294 | return tokens[2:-1] 295 | 296 | def _included(tokens, curr_input_ids): 297 | for i in range(len(curr_input_ids)+1): 298 | if curr_input_ids[i:i+len(tokens)]==tokens: 299 | return True 300 | return False 301 | 302 | new_input_ids, new_attention_mask = [], [] 303 | new_output, new_metadata = [], [] 304 | chosen_list = [] 305 | for idx, (curr_input_ids, curr_attention_mask, dpr_ids) in tqdm(enumerate( 306 | zip(input_ids, attention_mask, dpr_passages))): 307 | if self.ref_questions[idx] is None: 308 | continue 309 | 310 | end_of_question = curr_input_ids.index(self.tokenizer.eos_token_id)+1 311 | q_input_ids = curr_input_ids[:end_of_question] 312 | 313 | p_input_ids, p_attention_mask = [], [] 314 | dpr_input_ids = [self.passages.tokenized_data["input_ids"][_id] for _id in dpr_ids] 315 | dpr_attention_mask = [self.passages.tokenized_data["attention_mask"][_id] for _id in dpr_ids] 316 | offset = 0 317 | while len(p_input_ids)<1024: 318 | assert dpr_input_ids[offset][0] == bos_token_id 319 | assert len(dpr_input_ids[offset])==len(dpr_attention_mask[offset]) 320 | assert np.sum(dpr_attention_mask[offset])==len(dpr_attention_mask[offset]) 321 | p_input_ids += dpr_input_ids[offset][1:] 322 | p_attention_mask += dpr_attention_mask[offset][1:] 323 | 324 | tokenized_ref_answers_list, is_valid_list, n_missing_list = [], [], [] 325 | for ref_questions, ref_answers, ref_metadata in zip(self.ref_questions[idx], 326 | self.ref_answers[idx], 327 | metadata[idx]): 328 | # ref_metadata: [(0, 1), (1, 4)] 329 | assert type(ref_metadata[0][0])==int 330 | assert [len(ref_answer)==end-start for ref_answer, (start, end) 331 | in zip(ref_answers, ref_metadata)] 332 | tokenized_ref_answers = [[_get_tokenized_answer(i) for i in range(*m)] for m in ref_metadata] 333 | is_valid = [[_included(tokens, p_input_ids) for tokens in _tokens] for _tokens in tokenized_ref_answers] 334 | n_missing = np.sum([not any(v) for v in is_valid]) 335 | tokenized_ref_answers_list.append(tokenized_ref_answers) 336 | is_valid_list.append(is_valid) 337 | n_missing_list.append(n_missing) 338 | 339 | min_n_missing = np.min(n_missing_list) 340 | annotation_indices = [ann_idx for ann_idx in range(len(n_missing_list)) 341 | if n_missing_list[ann_idx]==min_n_missing] 342 | 343 | def _form_data(annotation_idx): 344 | ref_questions = self.ref_questions[idx][annotation_idx] 345 | ref_answers = self.ref_answers[idx][annotation_idx] 346 | tokenized_ref_answers = tokenized_ref_answers_list[annotation_idx] 347 | assert len(ref_questions)==len(ref_answers)==len(tokenized_ref_answers)==len(is_valid_list[annotation_idx]) 348 | final_ref_questions, final_ref_answers = [], [] 349 | chosen_indices = [] 350 | for (ref_question, ref_answer, tok_ref_answer, is_valid) in \ 351 | zip(ref_questions, ref_answers, tokenized_ref_answers, is_valid_list[annotation_idx]): 352 | assert len(ref_answer)==len(tok_ref_answer)==len(is_valid) 353 | chosen_idx = is_valid.index(True) if True in is_valid else 0 354 | chosen_indices.append(chosen_idx) 355 | final_ref_questions.append(ref_question[0]) 356 | final_ref_answers.append(tok_ref_answer[chosen_idx]) 357 | for i, final_ref_question in enumerate(final_ref_questions): 358 | input_ids = [bos_token_id, bos_token_id] + final_ref_answers[i] 359 | attention_mask = [1 for _ in input_ids] 360 | input_ids += [sep_token_id] + q_input_ids 361 | attention_mask += [2 for _ in range(len(q_input_ids)+1)] 362 | for j, answer in enumerate(final_ref_answers): 363 | if j==i: 
continue 364 | input_ids += [sep_token_id] + answer 365 | attention_mask += [3 for _ in range(len(answer)+1)] 366 | input_ids += p_input_ids 367 | attention_mask += p_attention_mask 368 | assert len(input_ids)==len(attention_mask) 369 | new_input_ids.append(input_ids) 370 | new_attention_mask.append(attention_mask) 371 | new_output.append(final_ref_question) 372 | return chosen_indices 373 | 374 | start = len(new_output) 375 | if self.is_training: 376 | start = len(new_output) 377 | for annotation_idx in annotation_indices: 378 | _form_data(annotation_idx) 379 | else: 380 | annotation_idx = annotation_indices[0] 381 | chosen_indices = _form_data(annotation_idx) 382 | chosen_list.append({"annotation_idx": annotation_idx, 383 | "answer_idx": chosen_indices}) 384 | assert len(new_output)-start > 0 385 | new_metadata.append((start, len(new_output))) 386 | 387 | if self.is_training: 388 | new_output = self.tokenizer.batch_encode_plus(new_output, max_length=32, pad_to_max_length=True) 389 | new_decoder_input_ids = new_output["input_ids"] 390 | new_decoder_attention_mask = new_output["attention_mask"] 391 | else: 392 | new_decoder_input_ids, new_decoder_attention_mask = None, None 393 | 394 | self.tokenized_data = [new_input_ids, new_attention_mask, 395 | new_decoder_input_ids, new_decoder_attention_mask, new_metadata] 396 | if not self.is_training: 397 | self.tokenized_data.append(chosen_list) 398 | with open(dpr_tokenized_path, "w") as f: 399 | json.dump(self.tokenized_data, f) 400 | self.logger.info("Finish saving tokenized DPR data") 401 | 402 | 403 | def load_dataset(self, tokenizer, do_return=False): 404 | if self.tokenized_data is None: 405 | self.load_tokenized_data(tokenizer) 406 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data[:5] 407 | self.dataset = MySimpleQADatasetForPair(input_ids, 408 | attention_mask, 409 | decoder_input_ids if self.is_training or self.args.nq_answer_as_prefix else None, 410 | decoder_attention_mask if self.is_training or self.args.nq_answer_as_prefix else None, 411 | metadata=metadata, 412 | is_training=self.is_training) 413 | self.logger.info("Loaded {} examples from {} data".format(len(self.dataset), self.data_type)) 414 | 415 | if do_return: 416 | return self.dataset 417 | 418 | 419 | # override 420 | def evaluate(self, predictions, n_paragraphs=None): 421 | metadata, chosen_list = self.tokenized_data[-2:] 422 | assert np.sum([ref_questions is not None for ref_questions in self.ref_questions])==len(metadata) 423 | assert len(predictions)==metadata[-1][-1] and len(chosen_list)==len(metadata) 424 | # first, tokenize 425 | data_to_tokenize = {} 426 | indices = [] 427 | offset = 0 428 | for i, (d, ref_questions, ref_answers) in enumerate(zip(self.data, self.ref_questions, self.ref_answers)): 429 | if ref_questions is None: continue 430 | data_to_tokenize["prompt.{}".format(i)] = [{"caption": d["question"]}] 431 | ann_idx = chosen_list[offset]["annotation_idx"] 432 | answer_idx = chosen_list[offset]["answer_idx"] 433 | start, end = metadata[offset] 434 | assert len(ref_questions[ann_idx])==len(ref_answers[ann_idx])==len(answer_idx)==end-start 435 | indices.append((i, len(answer_idx))) 436 | for j, (ref_question, pred, a_idx) in enumerate( 437 | zip(ref_questions[ann_idx], predictions[start:end], answer_idx)): 438 | assert type(ref_question)==list 439 | data_to_tokenize["gen.{}.{}".format(i, j)] = [{"caption": pred if type(pred)==str else pred[0]}] 440 | data_to_tokenize["ref.{}.{}".format(i, j)] = [{"caption": 
ref} for ref in ref_question] 441 | offset += 1 442 | 443 | assert offset==len(metadata) 444 | all_tokens = self.qg_tokenizer.tokenize(data_to_tokenize) 445 | 446 | def _get(key): 447 | return {'sent': [normalize_answer(value) for value in all_tokens[key]]} 448 | 449 | bleu, f1s = [], [] 450 | def _get_qg_metrics(gens, refs, prompt, metrics): 451 | return np.mean([get_qg_metrics(gen, ref, prompt, metrics) for gen, ref in zip(gens, refs)]) 452 | 453 | for (i, n) in indices: 454 | curr_bleu, curr_f1s = [], [] 455 | for j in range(n): 456 | e = get_qg_metrics(_get("gen.{}.{}".format(i, j)), 457 | _get("ref.{}.{}".format(i, j)), 458 | _get("prompt.{}".format(i)), 459 | metrics=["bleu4", "edit-f1"]) 460 | curr_bleu.append(e["bleu4"]) 461 | curr_f1s.append(e["edit-f1"]) 462 | bleu.append(np.mean(curr_bleu)) 463 | f1s.append(np.mean(curr_f1s)) 464 | self.logger.info("BLEU=%.1f\tEDIT-F1=%.1f" % (100*np.mean(bleu), 100*np.mean(f1s))) 465 | return np.mean(f1s) 466 | 467 | 468 | -------------------------------------------------------------------------------- /codes/QAData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gzip 4 | import re 5 | import pickle as pkl 6 | import string 7 | import numpy as np 8 | from tqdm import tqdm 9 | from collections import Counter, defaultdict 10 | 11 | import torch 12 | from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler 13 | 14 | from DataLoader import MySimpleQADataset, MyQADataset, MyDataLoader 15 | from util import decode_span_batch 16 | 17 | # for evaluation 18 | from ambigqa_evaluate_script import normalize_answer, get_exact_match, get_f1, get_qg_metrics 19 | from pycocoevalcap.bleu.bleu import Bleu 20 | 21 | class QAData(object): 22 | 23 | def __init__(self, logger, args, data_path, is_training, passages=None): 24 | self.data_path = data_path 25 | self.passages = passages 26 | if args.debug: 27 | self.data_path = data_path.replace("train", "dev") 28 | if "test" in self.data_path: 29 | self.data_type = "test" 30 | elif "dev" in self.data_path: 31 | self.data_type = "dev" 32 | elif "train" in self.data_path: 33 | self.data_type = "train" if is_training else "train_for_inference" 34 | else: 35 | raise NotImplementedError() 36 | 37 | with open(self.data_path, "r") as f: 38 | self.data = json.load(f) 39 | 40 | if "data" in self.data: 41 | self.data = self.data["data"] 42 | 43 | if "answers" in self.data[0]: 44 | self.data = [{"id": d["id"], "question": d["question"], "answer": d["answers"]} for d in self.data] 45 | 46 | if args.debug: 47 | self.data = self.data[:40] 48 | assert type(self.data)==list 49 | 50 | if not args.ambigqa: 51 | id2answer_path = os.path.join("/".join(self.data_path.split("/")[:-1]), 52 | "{}_id2answers.json".format(self.data_type.replace("train_for_inference", "train"))) 53 | with open(id2answer_path, "r") as f: 54 | id2answers = json.load(f) 55 | for i, d in enumerate(self.data): 56 | if is_training: 57 | self.data[i]["answer"] += id2answers[d["id"]] 58 | else: 59 | self.data[i]["answer"] = id2answers[d["id"]] 60 | 61 | self.is_training = is_training 62 | self.load = not args.debug 63 | self.logger = logger 64 | self.args = args 65 | self.metric = "EM" 66 | self.tokenizer = None 67 | self.tokenized_data = None 68 | self.dpr_tokenized_data = None 69 | self.dataset = None 70 | self.dataloader = None 71 | 72 | def __len__(self): 73 | return len(self.data) 74 | 75 | def get_answers(self): 76 | return [d["answer"] for d in 
self.data] 77 | 78 | def decode(self, tokens): 79 | if type(tokens[0])==list: 80 | return [self.decode(_tokens) for _tokens in tokens] 81 | return self.tokenizer.decode(tokens, 82 | skip_special_tokens=True, 83 | clean_up_tokenization_spaces=True).strip().replace(" - ", "-").replace(" : ", ":") 84 | 85 | def decode_span(self, outputs, n_paragraphs): 86 | assert len(self.data)==len(self.tokenized_data["positive_input_ids"])==\ 87 | len(self.tokenized_data["positive_input_mask"])==len(outputs), \ 88 | (len(self.data), len(self.tokenized_data["positive_input_ids"]), 89 | len(self.tokenized_data["positive_input_mask"]), len(outputs)) 90 | return decode_span_batch(list(zip(self.tokenized_data["positive_input_ids"], 91 | self.tokenized_data["positive_input_mask"])), 92 | outputs, 93 | tokenizer=self.tokenizer, 94 | max_answer_length=self.args.max_answer_length, 95 | n_paragraphs=n_paragraphs, 96 | topk_answer=self.args.topk_answer, 97 | verbose=self.args.verbose, 98 | n_jobs=self.args.n_jobs, 99 | save_psg_sel_only=self.args.save_psg_sel_only) 100 | 101 | def flatten(self, answers): 102 | new_answers, metadata = [], [] 103 | for answer in answers: 104 | assert type(answer)==list 105 | metadata.append((len(new_answers), len(new_answers)+len(answer))) 106 | new_answers += answer 107 | return new_answers, metadata 108 | 109 | def load_tokenized_data(self, tokenizer): 110 | self.tokenizer = tokenizer 111 | postfix = tokenizer.__class__.__name__.replace("zer", "zed") 112 | assert "Bart" in postfix or "Bert" in postfix or "Albert" in postfix 113 | assert ("Bart" in postfix and self.args.append_another_bos) \ 114 | or ("Bart" not in postfix and not self.args.append_another_bos) 115 | preprocessed_path = os.path.join( 116 | "/".join(self.data_path.split("/")[:-1]), 117 | self.data_path.split("/")[-1].replace( 118 | ".tsv" if self.data_path.endswith(".tsv") else ".json", 119 | "{}{}-{}.json".format( 120 | "-uncased" if self.args.do_lowercase else "", 121 | "-xbos" if self.args.append_another_bos else "", 122 | postfix))) 123 | if self.load and os.path.exists(preprocessed_path): 124 | self.logger.info("Loading pre-tokenized data from {}".format(preprocessed_path)) 125 | with open(preprocessed_path, "r") as f: 126 | tokenized_data = json.load(f) 127 | else: 128 | print ("Start tokenizing...") 129 | questions = [d["question"] if d["question"].endswith("?") else d["question"]+"?" 
130 | for d in self.data] 131 | answers = [d["answer"] for d in self.data] 132 | answers, metadata = self.flatten(answers) 133 | if self.args.do_lowercase: 134 | questions = [question.lower() for question in questions] 135 | answers = [answer.lower() for answer in answers] 136 | if self.args.append_another_bos: 137 | questions = [" "+question for question in questions] 138 | answers = [" " +answer for answer in answers] 139 | question_input = tokenizer.batch_encode_plus(questions, 140 | pad_to_max_length=True, 141 | max_length=32) 142 | answer_input = tokenizer.batch_encode_plus(answers, 143 | pad_to_max_length="Bart" in postfix, 144 | max_length=20) 145 | input_ids, attention_mask = question_input["input_ids"], question_input["attention_mask"] 146 | decoder_input_ids, decoder_attention_mask = answer_input["input_ids"], answer_input["attention_mask"] 147 | tokenized_data = [input_ids, attention_mask, 148 | decoder_input_ids, decoder_attention_mask, metadata] 149 | if self.load: 150 | with open(preprocessed_path, "w") as f: 151 | json.dump(tokenized_data, f) 152 | self.tokenized_data = tokenized_data 153 | if not self.args.dpr: 154 | self.load_dpr_data() 155 | 156 | def load_dpr_data(self): 157 | data_type = self.data_type.replace("train_for_inference", "train") 158 | dpr_retrieval_path = "out/dpr/{}_predictions.json".format( 159 | data_type+"_20200201" if self.args.wiki_2020 else data_type) 160 | postfix = self.tokenizer.__class__.__name__.replace("zer", "zed") 161 | dpr_tokenized_path = "out/dpr/{}_predictions_{}.json".format( 162 | self.data_type+"_20200201" if self.args.wiki_2020 else self.data_type, postfix) 163 | if "Bart" in postfix: 164 | return self.load_dpr_data_bart(dpr_retrieval_path, dpr_tokenized_path) 165 | elif "Bert" in postfix or "Albert" in postfix: 166 | return self.load_dpr_data_bert(dpr_retrieval_path, dpr_tokenized_path) 167 | else: 168 | raise NotImplementedError() 169 | 170 | def load_dpr_data_bart(self, dpr_retrieval_path, dpr_tokenized_path): 171 | self.logger.info("{}\n{}".format(dpr_retrieval_path, dpr_tokenized_path)) 172 | if os.path.exists(dpr_tokenized_path): 173 | self.logger.info("Loading DPR data from {}".format(dpr_tokenized_path)) 174 | with open(dpr_tokenized_path, "r") as f: 175 | input_ids, attention_mask = json.load(f) 176 | else: 177 | with open(dpr_retrieval_path, "r") as f: 178 | dpr_passages = json.load(f) 179 | assert len(dpr_passages)==len(self) 180 | assert self.args.psg_sel_dir is not None 181 | data_type = self.data_type.replace("train", "train_for_inference") \ 182 | if "for_inference" not in self.data_type else self.data_type 183 | psg_sel_fn = os.path.join(self.args.psg_sel_dir, 184 | "{}{}_psg_sel.json".format( 185 | data_type, 186 | "_20200201" if self.args.wiki_2020 else "")) 187 | self.logger.info("Loading passage selection from DPR reader: {}".format(psg_sel_fn)) 188 | with open(psg_sel_fn, "r") as f: 189 | fg_passages = json.load(f) 190 | assert len(fg_passages)==len(dpr_passages) 191 | dpr_passages = [[psgs[i] for i in fg_psgs] for psgs, fg_psgs in zip(dpr_passages, fg_passages)] 192 | 193 | self.logger.info("Start processing DPR data") 194 | if self.passages.tokenized_data is None: 195 | subset = set([p_idx for retrieved in dpr_passages for p_idx in retrieved]) 196 | self.passages.load_tokenized_data("bart", subset=subset, all=True) 197 | 198 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 199 | assert len(dpr_passages)==len(input_ids)==len(attention_mask) 200 | bos_token_id 
= self.tokenizer.bos_token_id 201 | 202 | def _get_tokenized_answer(idx): 203 | tokens = decoder_input_ids[idx] 204 | if 0 in decoder_attention_mask[idx]: 205 | tokens = tokens[:decoder_attention_mask[idx].index(0)] 206 | assert tokens[0]==tokens[1]==bos_token_id and tokens[-1]==self.tokenizer.eos_token_id 207 | return tokens[2:-1] 208 | 209 | for idx, (curr_input_ids, curr_attention_mask, curr_metadata, dpr_ids) in enumerate(zip( 210 | input_ids, attention_mask, metadata, dpr_passages)): 211 | dpr_input_ids = [self.passages.tokenized_data["input_ids"][_id] for _id in dpr_ids] 212 | dpr_attention_mask = [self.passages.tokenized_data["attention_mask"][_id] for _id in dpr_ids] 213 | offset = 0 214 | end_of_question = curr_input_ids.index(self.tokenizer.eos_token_id)+1 215 | input_ids[idx] = curr_input_ids[:end_of_question] 216 | attention_mask[idx] = curr_attention_mask[:end_of_question] 217 | 218 | while len(input_ids[idx])<1024: 219 | assert dpr_input_ids[offset][0] == bos_token_id 220 | assert len(dpr_input_ids[offset])==len(dpr_attention_mask[offset]) 221 | assert np.sum(dpr_attention_mask[offset])==len(dpr_attention_mask[offset]) 222 | input_ids[idx] += dpr_input_ids[offset][1:] 223 | attention_mask[idx] += dpr_attention_mask[offset][1:] 224 | offset += 1 225 | assert len(input_ids)==len(attention_mask) 226 | input_ids[idx] = input_ids[idx][:1024] 227 | attention_mask[idx] = attention_mask[idx][:1024] 228 | 229 | with open(dpr_tokenized_path, "w") as f: 230 | json.dump([input_ids, attention_mask], f) 231 | self.logger.info("Finish saving tokenized DPR data at {}".format(dpr_tokenized_path)) 232 | 233 | self.tokenized_data[0] = input_ids 234 | self.tokenized_data[1] = attention_mask 235 | 236 | if self.is_training and self.args.discard_not_found_answers: 237 | self.discard_not_found_answers() 238 | 239 | def discard_not_found_answers(self): 240 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 241 | new_input_ids, new_attention_mask, new_decoder_input_ids, new_decoder_attention_mask, new_metadata = [], [], [], [], [] 242 | 243 | skipped_idxs = [] 244 | 245 | self.logger.info("Discarding training examples where retrieval fails...") 246 | 247 | def _get_tokenized_answer(idx): 248 | tokens = self.tokenized_data[2][idx] 249 | if 0 in self.tokenized_data[3][idx]: 250 | tokens = tokens[:self.tokenized_data[3][idx].index(0)] 251 | assert tokens[0]==tokens[1]==self.tokenizer.bos_token_id and tokens[-1]==self.tokenizer.eos_token_id 252 | return tokens[2:-1] 253 | 254 | for idx, (curr_input_ids, curr_attention_mask, curr_metadata) in enumerate(zip( 255 | input_ids, attention_mask, metadata)): 256 | end_of_question = curr_input_ids.index(self.tokenizer.eos_token_id)+1 257 | def _included(tokens): 258 | for i in range(end_of_question, 1024-len(tokens)+1): 259 | if curr_input_ids[i:i+len(tokens)]==tokens: 260 | return True 261 | return False 262 | 263 | valid_answer_idxs = [answer_idx for answer_idx in range(curr_metadata[0], curr_metadata[1]) 264 | if _included(_get_tokenized_answer(answer_idx))] 265 | if len(valid_answer_idxs)==0: 266 | skipped_idxs.append(idx) 267 | continue 268 | new_input_ids.append(curr_input_ids) 269 | new_attention_mask.append(curr_attention_mask) 270 | new_decoder_input_ids += [decoder_input_ids[i] for i in valid_answer_idxs] 271 | new_decoder_attention_mask += [decoder_attention_mask[i] for i in valid_answer_idxs] 272 | new_metadata.append([len(new_decoder_input_ids)-len(valid_answer_idxs), len(new_decoder_input_ids)]) 
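# Each new_metadata entry records the [start, end) slice of new_decoder_input_ids holding the
# answers of this example that were actually found in its retrieved passages; examples with no
# such answer were skipped above.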
273 | 274 | self.tokenized_data = [new_input_ids, new_attention_mask, new_decoder_input_ids, new_decoder_attention_mask, new_metadata] 275 | 276 | print (len(input_ids), len(new_input_ids), len(skipped_idxs)) 277 | 278 | def load_dpr_data_bert(self, dpr_retrieval_path, dpr_tokenized_path): 279 | if os.path.exists(dpr_tokenized_path): 280 | self.logger.info("Loading DPR data from {}".format(dpr_tokenized_path)) 281 | with open(dpr_tokenized_path, "r") as f: 282 | self.tokenized_data = json.load(f) 283 | return 284 | self.logger.info("Start processing DPR data") 285 | with open(dpr_retrieval_path, "r") as f: 286 | dpr_passages = json.load(f) 287 | 288 | if self.args.ambigqa: 289 | # added to convert original DPR data to AmbigQA DPR data 290 | dpr_passages = [dpr_passages[d["orig_idx"]] for d in self.data] 291 | elif self.is_training: 292 | with open(os.path.join(self.args.dpr_data_dir, "data/gold_passages_info/nq_train.json"), "r") as f: 293 | gold_titles = [d["title"] for d in json.load(f)["data"]] 294 | assert len(gold_titles)==len(self) 295 | 296 | input_ids, attention_mask, answer_input_ids, _, metadata = self.tokenized_data 297 | eos_token_id = self.tokenizer.eos_token_id if "Albert" in dpr_tokenized_path else self.tokenizer.sep_token_id 298 | assert eos_token_id is not None 299 | assert len(dpr_passages)==len(input_ids)==len(attention_mask)==len(metadata) 300 | if self.passages.tokenized_data is None: 301 | subset = set([p_idx for retrieved in dpr_passages for p_idx in retrieved]) 302 | self.passages.load_tokenized_data("albert" if "Albert" in dpr_tokenized_path else "bert", 303 | subset=subset, all=True) 304 | features = defaultdict(list) 305 | max_n_answers = self.args.max_n_answers 306 | oracle_exact_matches = [] 307 | flatten_exact_matches = [] 308 | positive_contains_gold_title = [] 309 | for i, (q_input_ids, q_attention_mask, retrieved) in \ 310 | tqdm(enumerate(zip(input_ids, attention_mask, dpr_passages))): 311 | assert len(q_input_ids)==len(q_attention_mask)==32 312 | q_input_ids = [in_ for in_, mask in zip(q_input_ids, q_attention_mask) if mask] 313 | assert 3<=len(q_input_ids)<=32 314 | p_input_ids = [self.passages.tokenized_data["input_ids"][p_idx] for p_idx in retrieved] 315 | p_attention_mask = [self.passages.tokenized_data["attention_mask"][p_idx] for p_idx in retrieved] 316 | a_input_ids = [answer_input_ids[idx][1:-1] for idx in range(metadata[i][0], metadata[i][1])] 317 | detected_spans = [] 318 | for _p_input_ids in p_input_ids: 319 | detected_spans.append([]) 320 | for _a_input_ids in a_input_ids: 321 | decoded_a_input_ids = self.decode(_a_input_ids) 322 | for j in range(len(_p_input_ids)-len(_a_input_ids)+1): 323 | if _p_input_ids[j:j+len(_a_input_ids)]==_a_input_ids: 324 | detected_spans[-1].append((j+len(q_input_ids), j+len(q_input_ids)+len(_a_input_ids)-1)) 325 | elif "Albert" in dpr_tokenized_path and \ 326 | _p_input_ids[j]==_a_input_ids[0] and \ 327 | 13 in _p_input_ids[j:j+len(_a_input_ids)]: 328 | k = j + len(_a_input_ids)+1  # token-level match failed, possibly because of sentencepiece id 13; fall back to matching decoded text 329 | while k<len(_p_input_ids) and decoded_a_input_ids!=self.decode(_p_input_ids[j:k]): 330 | k += 1 331 | if decoded_a_input_ids==self.decode(_p_input_ids[j:k]): 332 | detected_spans[-1].append((j+len(q_input_ids), len(q_input_ids)+k-1)) 333 | if self.args.ambigqa and self.is_training: 334 | positives = [j for j, spans in enumerate(detected_spans) if len(spans)>0][:20] 335 | negatives = [j for j, spans in enumerate(detected_spans) if len(spans)==0][:50] 336 | if len(positives)==0: 337 | continue 338 | elif self.is_training: 339 | gold_title = normalize_answer(gold_titles[i]) 340 | _positives = [j for j, spans in enumerate(detected_spans) if len(spans)>0] 341 | if len(_positives)==0: 342 | continue 343 | positives = [j for j in _positives if normalize_answer(self.decode(p_input_ids[j][:p_input_ids[j].index(eos_token_id)]))==gold_title] 344 | 
positive_contains_gold_title.append(len(positives)>0) 345 | if len(positives)==0: 346 | positives = _positives[:20] 347 | negatives = [j for j, spans in enumerate(detected_spans) if len(spans)==0][:50] 348 | else: 349 | positives = [j for j in range(len(detected_spans))] 350 | negatives = [] 351 | for key in ["positive_input_ids", "positive_input_mask", "positive_token_type_ids", 352 | "positive_start_positions", "positive_end_positions", "positive_answer_mask", 353 | "negative_input_ids", "negative_input_mask", "negative_token_type_ids"]: 354 | features[key].append([]) 355 | 356 | def _form_input(p_input_ids, p_attention_mask): 357 | assert len(p_input_ids)==len(p_attention_mask) 358 | assert len(p_input_ids)==128 or (len(p_input_ids)<=128 and np.sum(p_attention_mask)==len(p_attention_mask)) 359 | if len(p_input_ids)<128: 360 | p_input_ids += [self.tokenizer.pad_token_id for _ in range(128-len(p_input_ids))] 361 | p_attention_mask += [0 for _ in range(128-len(p_attention_mask))] 362 | input_ids = q_input_ids + p_input_ids + [self.tokenizer.pad_token_id for _ in range(32-len(q_input_ids))] 363 | attention_mask = [1 for _ in range(len(q_input_ids))] + p_attention_mask + [0 for _ in range(32-len(q_input_ids))] 364 | token_type_ids = [0 for _ in range(len(q_input_ids))] + p_attention_mask + [0 for _ in range(32-len(q_input_ids))] 365 | return input_ids, attention_mask, token_type_ids 366 | 367 | for idx in positives: 368 | input_ids, attention_mask, token_type_ids = _form_input(p_input_ids[idx], p_attention_mask[idx]) 369 | features["positive_input_ids"][-1].append(input_ids) 370 | features["positive_input_mask"][-1].append(attention_mask) 371 | features["positive_token_type_ids"][-1].append(token_type_ids) 372 | detected_span = detected_spans[idx] 373 | features["positive_start_positions"][-1].append( 374 | [s[0] for s in detected_span[:max_n_answers]] + [0 for _ in range(max_n_answers-len(detected_span))]) 375 | features["positive_end_positions"][-1].append( 376 | [s[1] for s in detected_span[:max_n_answers]] + [0 for _ in range(max_n_answers-len(detected_span))]) 377 | features["positive_answer_mask"][-1].append( 378 | [1 for _ in detected_span[:max_n_answers]] + [0 for _ in range(max_n_answers-len(detected_span))]) 379 | for idx in negatives: 380 | input_ids, attention_mask, token_type_ids = _form_input(p_input_ids[idx], p_attention_mask[idx]) 381 | features["negative_input_ids"][-1].append(input_ids) 382 | features["negative_input_mask"][-1].append(attention_mask) 383 | features["negative_token_type_ids"][-1].append(token_type_ids) 384 | # for debugging 385 | for p_input_ids, starts, ends, masks in zip(features["positive_input_ids"][-1], 386 | features["positive_start_positions"][-1], 387 | features["positive_end_positions"][-1], 388 | features["positive_answer_mask"][-1]): 389 | if np.sum(masks)==0: continue 390 | assert len(starts)==len(ends)==len(masks)==max_n_answers 391 | decoded_answers = [self.tokenizer.decode(p_input_ids[start:end+1]) for start, end, mask in zip(starts, ends, masks) if mask] 392 | ems = [get_exact_match(decoded_answer, self.data[i]["answer"]) for decoded_answer in decoded_answers] 393 | oracle_exact_matches.append(np.max(ems)) 394 | flatten_exact_matches += ems 395 | print ("oracle exact matches", np.mean(oracle_exact_matches)) 396 | print ("flatten exact matches", np.mean(flatten_exact_matches)) 397 | print ("positive contains gold title", np.mean(positive_contains_gold_title)) 398 | print (len(positive_contains_gold_title)) 399 | self.tokenized_data = features 
400 | 401 | with open(dpr_tokenized_path, "w") as f: 402 | json.dump(self.tokenized_data, f) 403 | 404 | def load_dataset(self, tokenizer, do_return=False): 405 | if self.tokenized_data is None: 406 | self.load_tokenized_data(tokenizer) 407 | if isinstance(self.tokenized_data, dict): 408 | self.dataset = MyQADataset(self.tokenized_data, 409 | is_training=self.is_training, 410 | train_M=self.args.train_M, 411 | test_M=self.args.test_M) 412 | else: 413 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 414 | self.dataset = MySimpleQADataset(input_ids, 415 | attention_mask, 416 | decoder_input_ids if self.is_training or self.args.nq_answer_as_prefix else None, 417 | decoder_attention_mask if self.is_training or self.args.nq_answer_as_prefix else None, 418 | in_metadata=None, 419 | out_metadata=metadata, 420 | is_training=self.is_training, 421 | answer_as_prefix=self.args.nq_answer_as_prefix) 422 | self.logger.info("Loaded {} examples from {} data".format(len(self.dataset), self.data_type)) 423 | 424 | if do_return: 425 | return self.dataset 426 | 427 | def load_dataloader(self, do_return=False): 428 | self.dataloader = MyDataLoader(self.args, self.dataset, is_training=self.is_training) 429 | if do_return: 430 | return self.dataloader 431 | 432 | def evaluate(self, predictions, n_paragraphs=None): 433 | assert len(predictions)==len(self), (len(predictions), len(self)) 434 | if self.args.save_psg_sel_only: 435 | return [-1] 436 | if n_paragraphs is None: 437 | ems = [] 438 | for (prediction, dp) in zip(predictions, self.data): 439 | if type(prediction)==list: 440 | prediction = prediction[0] 441 | if type(prediction)==dict: 442 | prediction = prediction["text"] 443 | ems.append(get_exact_match(prediction, dp["answer"])) 444 | return ems 445 | ems = defaultdict(list) 446 | for (prediction, dp) in zip(predictions, self.data): 447 | assert len(n_paragraphs)==len(prediction) 448 | for pred, n in zip(prediction, n_paragraphs): 449 | if type(pred)==list: 450 | pred = pred[0] 451 | if type(pred)==dict: 452 | pred = pred["text"] 453 | ems[n].append(get_exact_match(pred, dp["answer"])) 454 | for n in n_paragraphs: 455 | self.logger.info("n_paragraphs=%d\t#M=%.2f" % (n, np.mean(ems[n])*100)) 456 | return ems[n_paragraphs[-1]] 457 | 458 | def save_predictions(self, predictions): 459 | assert len(predictions)==len(self), (len(predictions), len(self)) 460 | save_path = os.path.join(self.args.output_dir, "{}{}_predictions.json".format( 461 | self.data_type if self.args.prefix is None else self.args.prefix, 462 | "_20200201" if self.args.wiki_2020 and not self.args.ambigqa else "")) 463 | if self.args.save_psg_sel_only: 464 | save_path = save_path.replace("predictions.json", "psg_sel.json") 465 | with open(save_path, "w") as f: 466 | json.dump(predictions, f) 467 | self.logger.info("Saved prediction in {}".format(save_path)) 468 | 469 | class AmbigQAData(QAData): 470 | def __init__(self, logger, args, data_path, is_training, passages=None): 471 | super(AmbigQAData, self).__init__(logger, args, data_path, is_training, passages) 472 | 473 | with open("/".join(data_path.split("/")[:-2]) + "/nqopen/{}.json".format(self.data_type), "r") as f: 474 | orig_data = json.load(f) 475 | id_to_orig_idx = {d["id"]:i for i, d in enumerate(orig_data)} 476 | 477 | for i, d in enumerate(self.data): 478 | answers = [] 479 | for annotation in d["annotations"]: 480 | assert annotation["type"] in ["singleAnswer", "multipleQAs"] 481 | if annotation["type"]=="singleAnswer": 482 | 
answers.append([list(set(annotation["answer"]))]) 483 | else: 484 | answers.append([list(set(pair["answer"])) for pair in annotation["qaPairs"]]) 485 | assert type(answers)==list and \ 486 | all([type(answer)==list for answer in answers]) and \ 487 | all([type(_a)==str for answer in answers for _answer in answer for _a in _answer]) 488 | self.data[i]["answer"] = answers 489 | self.data[i]["orig_idx"] = id_to_orig_idx[d["id"]] 490 | 491 | self.metric = "F1" 492 | self.SEP = "<SEP>"  # separator token placed between multiple answers for the seq2seq model 493 | 494 | # override 495 | def flatten(self, answers): 496 | new_answers, metadata = [], [] 497 | for _answers in answers: 498 | assert type(_answers)==list 499 | metadata.append([]) 500 | for answer in _answers: 501 | metadata[-1].append([]) 502 | for _answer in answer: 503 | assert len(_answer)>0, _answers 504 | assert type(_answer)==list and type(_answer[0])==str, _answers 505 | metadata[-1][-1].append((len(new_answers), len(new_answers)+len(_answer))) 506 | new_answers += _answer 507 | return new_answers, metadata 508 | 509 | # override 510 | def load_dpr_data(self): 511 | dpr_retrieval_path = "out/dpr/{}_predictions.json".format( 512 | self.data_type+"_20200201" if self.args.wiki_2020 else self.data_type) 513 | postfix = self.tokenizer.__class__.__name__.replace("zer", "zed") 514 | dpr_tokenized_path = self.data_path.replace(".json", "_dpr{}.json".format("_20200201" if self.args.wiki_2020 else "")) 515 | if "Bart" in postfix: 516 | return self.load_dpr_data_bart(dpr_retrieval_path, dpr_tokenized_path) 517 | metadata, new_metadata = self.tokenized_data[-1], [] 518 | for curr_metadata in metadata: 519 | new_metadata.append((curr_metadata[0][0][0], curr_metadata[-1][-1][-1])) 520 | self.tokenized_data[-1] = new_metadata 521 | return self.load_dpr_data_bert(dpr_retrieval_path, dpr_tokenized_path) 522 | 523 | # override 524 | def load_dpr_data_bart(self, dpr_retrieval_path, dpr_tokenized_path): 525 | 526 | if self.is_training and self.args.consider_order_for_multiple_answers: 527 | dpr_tokenized_path = dpr_tokenized_path.replace(".json", "_ordered.json") 528 | 529 | self.logger.info(dpr_retrieval_path) 530 | self.logger.info(dpr_tokenized_path) 531 | 532 | if os.path.exists(dpr_tokenized_path): 533 | self.logger.info("Loading DPR data from {}".format(dpr_tokenized_path)) 534 | with open(dpr_tokenized_path, "r") as f: 535 | self.tokenized_data = json.load(f) 536 | return 537 | 538 | import itertools 539 | self.logger.info("Start processing DPR data from {}".format(dpr_retrieval_path)) 540 | if self.passages.tokenized_data is None: 541 | self.passages.load_tokenized_data("bart", all=True) 542 | 543 | with open(dpr_retrieval_path.format(self.data_type).replace("train", "train_for_inference"), "r") as f: 544 | dpr_passages = json.load(f) 545 | assert self.args.psg_sel_dir is not None 546 | 547 | psg_sel_fn = os.path.join(self.args.psg_sel_dir, 548 | "{}{}_psg_sel.json".format( 549 | self.data_type.replace("train", "train_for_inference"), 550 | "_20200201" if self.args.wiki_2020 else "")) 551 | self.logger.info("Loading passage selection from DPR reader: {}".format(psg_sel_fn)) 552 | with open(psg_sel_fn, "r") as f: 553 | fg_passages = json.load(f) 554 | assert len(fg_passages)==len(dpr_passages) 555 | dpr_passages = [[psgs[i] for i in fg_psgs] for psgs, fg_psgs in zip(dpr_passages, fg_passages)] 556 | 557 | # added to convert original DPR data to AmbigQA DPR data 558 | dpr_passages = [dpr_passages[d["orig_idx"]] for d in self.data] 559 | 560 | assert len(dpr_passages)==len(self) 561 | input_ids, attention_mask, 
decoder_input_ids, decoder_attention_mask, metadata = self.tokenized_data 562 | assert len(dpr_passages)==len(input_ids)==len(attention_mask)==len(metadata) 563 | bos_token_id = self.tokenizer.bos_token_id 564 | eos_token_id = self.tokenizer.eos_token_id 565 | pad_token_id = self.tokenizer.pad_token_id 566 | sep_token_id = self.tokenizer.convert_tokens_to_ids(self.SEP) 567 | assert type(bos_token_id)==type(eos_token_id)==type(sep_token_id)==int 568 | 569 | def _get_tokenized_answer(idx): 570 | tokens = decoder_input_ids[idx] 571 | if 0 in decoder_attention_mask[idx]: 572 | tokens = tokens[:decoder_attention_mask[idx].index(0)] 573 | assert tokens[0]==tokens[1]==bos_token_id and tokens[-1]==eos_token_id 574 | return tokens[2:-1] 575 | 576 | new_input_ids, new_attention_mask = [], [] 577 | if self.is_training: 578 | new_decoder_input_ids, new_decoder_attention_mask, new_metadata = [], [], [] 579 | else: 580 | new_decoder_input_ids, new_decoder_attention_mask, new_metadata = None, None, None 581 | for idx, (curr_input_ids, curr_attention_mask, curr_metadata, dpr_ids) in enumerate(zip( 582 | input_ids, attention_mask, metadata, dpr_passages)): 583 | 584 | dpr_input_ids = [self.passages.tokenized_data["input_ids"][_id] for _id in dpr_ids] 585 | dpr_attention_mask = [self.passages.tokenized_data["attention_mask"][_id] for _id in dpr_ids] 586 | 587 | # creating input_ids is done in the same way as NQ-open. 588 | offset = 0 589 | end_of_question = curr_input_ids.index(eos_token_id)+1 590 | input_ids[idx] = curr_input_ids[:end_of_question] 591 | attention_mask[idx] = curr_attention_mask[:end_of_question] 592 | while len(input_ids[idx])<1024: 593 | assert dpr_input_ids[offset][0] == bos_token_id 594 | assert len(dpr_input_ids[offset])==len(dpr_attention_mask[offset]) 595 | assert np.sum(dpr_attention_mask[offset])==len(dpr_attention_mask[offset]) 596 | input_ids[idx] += dpr_input_ids[offset][1:] 597 | attention_mask[idx] += dpr_attention_mask[offset][1:] 598 | offset += 1 599 | 600 | if self.is_training: 601 | # now, re-creating decoder_input_ids and metadata 602 | def _included(tokens): 603 | for i in range(end_of_question, 1024-len(tokens)+1): 604 | if input_ids[idx][i:i+len(tokens)]==tokens: 605 | return True 606 | return False 607 | def _valid(tokens_list): 608 | offset = 0 609 | for i in range(end_of_question, 1024): 610 | if input_ids[idx][i:i+len(tokens_list[offset])]==tokens_list[offset]: 611 | offset += 1 612 | if offset==len(tokens_list): 613 | return True 614 | return False 615 | 616 | for _curr_metadata in curr_metadata: 617 | found_answers = [] 618 | for start, end in _curr_metadata: 619 | _answers = [] 620 | for j in range(start, end): 621 | answer = _get_tokenized_answer(j) 622 | if not _included(answer): continue 623 | if answer in _answers: continue 624 | _answers.append(answer) 625 | if len(_answers)>0: 626 | found_answers.append(_answers) 627 | 628 | if len(found_answers)==0: 629 | continue 630 | 631 | decoder_offset = len(new_decoder_input_ids) 632 | cnt = 0 633 | for _answers in itertools.product(*found_answers): 634 | _answers = list(_answers) 635 | if self.args.consider_order_for_multiple_answers and not _valid(_answers): 636 | continue 637 | answers = [bos_token_id, bos_token_id] 638 | for j, answer in enumerate(_answers): 639 | if j>0: answers.append(sep_token_id) 640 | answers += answer 641 | answers.append(eos_token_id) 642 | answers = answers[:30] 643 | new_decoder_input_ids.append( 644 | answers + [pad_token_id for _ in range(30-len(answers))]) 645 | 
new_decoder_attention_mask.append( 646 | [1 for _ in answers] + [0 for _ in range(30-len(answers))]) 647 | cnt += 1 648 | if cnt==100: 649 | break 650 | assert decoder_offset+cnt==len(new_decoder_input_ids) 651 | if cnt==0: 652 | continue 653 | new_metadata.append([decoder_offset, decoder_offset+cnt]) 654 | 655 | new_input_ids.append(input_ids[idx][:1024]) 656 | new_attention_mask.append(attention_mask[idx][:1024]) 657 | 658 | self.tokenized_data = [new_input_ids, new_attention_mask, new_decoder_input_ids, 659 | new_decoder_attention_mask, new_metadata] 660 | with open(dpr_tokenized_path, "w") as f: 661 | json.dump(self.tokenized_data, f) 662 | self.logger.info("Finish saving tokenized DPR data") 663 | 664 | # override 665 | def evaluate(self, predictions, n_paragraphs=None): 666 | assert len(predictions)==len(self), (len(predictions), len(self)) 667 | f1s, f1s_wo_dupli = [], [] 668 | if self.args.is_seq2seq: 669 | for (prediction, dp) in zip(predictions, self.data): 670 | prediction1 = [text.strip() for text in prediction.split(self.SEP)] 671 | prediction2 = list(set(prediction1)) 672 | f1s.append(np.max([get_f1(answer, prediction1) for answer in dp["answer"]])) 673 | f1s_wo_dupli.append(np.max([get_f1(answer, prediction2) for answer in dp["answer"]])) 674 | self.logger.info("F1=%.2f, F1 w/o dupli=%.2f" % (np.mean(f1s)*100, np.mean(f1s_wo_dupli)*100)) 675 | else: 676 | for (prediction, dp) in zip(predictions, self.data): 677 | preds = [] 678 | if type(prediction[0])==list: 679 | prediction = prediction[-1] 680 | for p in prediction: 681 | if normalize_answer(p["text"]) not in preds: 682 | if p["log_softmax"]>np.log(0.05) or len(preds)==0: 683 | preds.append(normalize_answer(p["text"])) 684 | if p["log_softmax"]<=np.log(0.05) or len(preds)==3: 685 | break 686 | f1s.append(np.max([get_f1(answer, preds) for answer in dp["answer"]])) 687 | return f1s 688 | 689 | 690 | 691 | 692 | 693 | --------------------------------------------------------------------------------