├── .gitignore
├── README.md
├── qa
│   ├── basic_tokenizer.py
│   ├── bert_retrieve_qa.py
│   ├── config.py
│   ├── datasets.py
│   ├── eval_utils.py
│   ├── msmarco_process.py
│   ├── official_eval.py
│   ├── online_sampler.py
│   ├── prepro_dense.py
│   ├── prepro_utils.py
│   ├── tokenizer.py
│   ├── train.py
│   ├── train_dense_qa.sh
│   ├── train_retrieve_qa.py
│   └── utils.py
├── requirements.txt
└── retrieval
    ├── basic_tokenizer.py
    ├── config.py
    ├── datasets.py
    ├── eval_retrieval.py
    ├── gen_index_id_map.py
    ├── get_embed.py
    ├── get_para_embed.sh
    ├── group_paras.py
    ├── retriever.py
    ├── tokenizer.py
    ├── train_retriever.py
    ├── train_retriever_cluster.sh
    ├── train_retriever_single.sh
    ├── trec_process.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | /data
3 | /pretrained_models
4 | *.zip
5 | retrieval/logs/
6 | __MACOSX/
7 | 
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProQA
2 | 
3 | Resource-efficient method for pretraining a dense corpus index for open-domain QA and IR. Given a question, you can use this code to retrieve relevant paragraphs from Wikipedia and extract answers.
4 | 
5 | ## 1. Set up the environment
6 | ```
7 | conda create -n proqa -y python=3.6.9 && conda activate proqa
8 | pip install -r requirements.txt
9 | ```
10 | If you want to use mixed precision training and your GPUs support fp16, follow the [Nvidia Apex repo](https://github.com/NVIDIA/apex) to install Apex.
11 | 
12 | ## 2. Download data (including the corpus, paragraphs paired with the generated questions, etc.)
13 | ```
14 | gdown https://drive.google.com/uc?id=17IMQ5zzfkCNsTZNJqZI5KveoIsaG2ZDt && unzip data.zip
15 | cd data && gdown https://drive.google.com/uc?id=1T1SntmAZxJ6QfNBN39KbAHcMw0JR5MwL
16 | ```
17 | The data folder includes the QA datasets and also the paragraph database ``nq_paras.db``, which can be used with sqlite3. **If the command fails to download the file, please use your browser instead.**
18 | 
19 | ## 3. Use pretrained index and models
20 | Download the pretrained models and data from Google Drive:
21 | ```
22 | gdown https://drive.google.com/uc?id=1fDRHsLk5emLqHSMkkoockoHjRSOEBaZw && unzip pretrained_models.zip
23 | ```
24 | 
25 | ### Test the Retrieval Performance Before QA finetuning
26 | * First, encode all the questions as embeddings (using the WebQuestions test set in this example):
27 | ```
28 | cd retrieval
29 | CUDA_VISIBLE_DEVICES=0 python get_embed.py \
30 |     --do_predict \
31 |     --predict_batch_size 512 \
32 |     --bert_model_name bert-base-uncased \
33 |     --fp16 \
34 |     --predict_file ../data/WebQuestions-test.txt \
35 |     --init_checkpoint ../pretrained_models/retriever.pt \
36 |     --is_query_embed \
37 |     --embed_save_path ../data/wq_test_query_embed.npy
38 | ```
39 | 
40 | * Retrieve the top-k (k=80) paragraphs from the corpus and evaluate recall with simple string matching:
41 | ```
42 | python eval_retrieval.py ../data/WebQuestions-test.txt ../pretrained_models/para_embed.npy ../data/wq_test_query_embed.npy ../data/nq_paras.db
43 | ```
44 | The arguments are the dataset file, the dense corpus index, the question embeddings, and the paragraph database. The results should look like:
45 | ```
46 | Top 80 Recall for 2032 QA pairs: 0.7839566929133859 ...
47 | Top 5 Recall for 2032 QA pairs: 0.5196850393700787 ...
48 | Top 10 Recall for 2032 QA pairs: 0.610236220472441 ...
49 | Top 20 Recall for 2032 QA pairs: 0.687007874015748 ...
50 | Top 50 Recall for 2032 QA pairs: 0.7554133858267716 ...
51 | ```
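For reference, this retrieval step is a maximum-inner-product search between the question embeddings and the pretrained paragraph index. Below is a minimal NumPy sketch of that scoring step; the file names follow the commands above, and the brute-force matrix product is an illustrative simplification of what `eval_retrieval.py` does (the real script also maps paragraph indices back to the database and computes recall by string-matching the gold answers).
```
import numpy as np

# question and paragraph embeddings produced by get_embed.py / the pretrained index
q_embed = np.load("../data/wq_test_query_embed.npy")         # [num_questions, dim]
para_embed = np.load("../pretrained_models/para_embed.npy")  # [num_paragraphs, dim]

# inner-product relevance scores between every question and every paragraph
scores = q_embed @ para_embed.T

# indices of the top-80 paragraphs per question, best first
top_ids = np.argsort(-scores, axis=1)[:, :80]
```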
52 | 
53 | ## 4. Retriever pretraining
54 | ### Use a single pretraining file:
55 | * Under the `retrieval` directory:
56 | ```
57 | cd retrieval
58 | ./train_retriever_single.sh
59 | ```
60 | This script will use the unclustered data for pretraining. After a certain number of updates, pause the training and use the following steps to cluster the data and continue training. This will save a checkpoint under `retrieval/logs/`.
61 | 
62 | ### Use clustered data for pretraining:
63 | #### Generate paragraph clusters
64 | * Generate the paragraph embeddings using the checkpoint from the last step:
65 | ```
66 | mkdir encodings
67 | CUDA_VISIBLE_DEVICES=0 python get_embed.py --do_predict --prefix eval-para \
68 |     --predict_batch_size 300 \
69 |     --bert_model_name bert-base-uncased \
70 |     --fp16 \
71 |     --predict_file ../data/retrieve_train.txt \
72 |     --init_checkpoint ../pretrained_models/retriever.pt \
73 |     --embed_save_path encodings/train_para_embed.npy \
74 |     --eval-workers 32 \
75 |     --fp16
76 | ```
77 | * Generate clusters using the paragraph embeddings:
78 | ```
79 | python group_paras.py
80 | ```
81 | Clustering hyperparameters such as the number of clusters can be found in `group_paras.py`.
82 | 
83 | #### Pretraining using clusters
84 | * Then run the retrieval script:
85 | ```
86 | ./train_retriever_cluster.sh
87 | ```
88 | 
89 | ## 5. QA finetuning
90 | * Generate the paragraph dense index under the `retrieval` directory: ``./get_para_embed.sh``
91 | * Finetune the pretrained model on the QA dataset under the `qa` directory: ``./train_dense_qa.sh``
92 | 
--------------------------------------------------------------------------------
/qa/basic_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Base tokenizer/tokens classes and utilities."""
8 | 
9 | import copy
10 | 
11 | 
12 | class Tokens(object):
13 |     """A class to represent a list of tokenized text."""
14 |     TEXT = 0
15 |     TEXT_WS = 1
16 |     SPAN = 2
17 |     POS = 3
18 |     LEMMA = 4
19 |     NER = 5
20 | 
21 |     def __init__(self, data, annotators, opts=None):
22 |         self.data = data
23 |         self.annotators = annotators
24 |         self.opts = opts or {}
25 | 
26 |     def __len__(self):
27 |         """The number of tokens."""
28 |         return len(self.data)
29 | 
30 |     def slice(self, i=None, j=None):
31 |         """Return a view of the list of tokens from [i, j)."""
32 |         new_tokens = copy.copy(self)
33 |         new_tokens.data = self.data[i: j]
34 |         return new_tokens
35 | 
36 |     def untokenize(self):
37 |         """Returns the original text (with whitespace reinserted)."""
38 |         return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
39 | 
40 |     def words(self, uncased=False):
41 |         """Returns a list of the text of each token
42 | 
43 |         Args:
44 |             uncased: lower cases text
45 |         """
46 |         if uncased:
47 |             return [t[self.TEXT].lower() for t in self.data]
48 |         else:
49 |             return [t[self.TEXT] for t in self.data]
50 | 
51 |     def offsets(self):
52 |         """Returns a list of [start, end) character offsets of each token."""
53 |         return [t[self.SPAN] for t in self.data]
54 | 
55 |     def pos(self):
56 |         """Returns a list of part-of-speech tags of each token.
57 |         Returns None if this annotation was not included.
58 | """ 59 | if 'pos' not in self.annotators: 60 | return None 61 | return [t[self.POS] for t in self.data] 62 | 63 | def lemmas(self): 64 | """Returns a list of the lemmatized text of each token. 65 | Returns None if this annotation was not included. 66 | """ 67 | if 'lemma' not in self.annotators: 68 | return None 69 | return [t[self.LEMMA] for t in self.data] 70 | 71 | def entities(self): 72 | """Returns a list of named-entity-recognition tags of each token. 73 | Returns None if this annotation was not included. 74 | """ 75 | if 'ner' not in self.annotators: 76 | return None 77 | return [t[self.NER] for t in self.data] 78 | 79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 80 | """Returns a list of all ngrams from length 1 to n. 81 | 82 | Args: 83 | n: upper limit of ngram length 84 | uncased: lower cases text 85 | filter_fn: user function that takes in an ngram list and returns 86 | True or False to keep or not keep the ngram 87 | as_string: return the ngram as a string vs list 88 | """ 89 | def _skip(gram): 90 | if not filter_fn: 91 | return False 92 | return filter_fn(gram) 93 | 94 | words = self.words(uncased) 95 | ngrams = [(s, e + 1) 96 | for s in range(len(words)) 97 | for e in range(s, min(s + n, len(words))) 98 | if not _skip(words[s:e + 1])] 99 | 100 | # Concatenate into strings 101 | if as_strings: 102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 103 | 104 | return ngrams 105 | 106 | def entity_groups(self): 107 | """Group consecutive entity tokens with the same NER tag.""" 108 | entities = self.entities() 109 | if not entities: 110 | return None 111 | non_ent = self.opts.get('non_ent', 'O') 112 | groups = [] 113 | idx = 0 114 | while idx < len(entities): 115 | ner_tag = entities[idx] 116 | # Check for entity tag 117 | if ner_tag != non_ent: 118 | # Chomp the sequence 119 | start = idx 120 | while (idx < len(entities) and entities[idx] == ner_tag): 121 | idx += 1 122 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 123 | else: 124 | idx += 1 125 | return groups 126 | 127 | 128 | class Tokenizer(object): 129 | """Base tokenizer class. 130 | Tokenizers implement tokenize, which should return a Tokens class. 
131 |     """
132 | 
133 |     def tokenize(self, text):
134 |         raise NotImplementedError
135 | 
136 |     def shutdown(self):
137 |         pass
138 | 
139 |     def __del__(self):
140 |         self.shutdown()
141 | 
142 | 
143 | import regex
144 | import logging
145 | 
146 | logger = logging.getLogger(__name__)
147 | 
148 | 
149 | class RegexpTokenizer(Tokenizer):
150 |     DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
151 |     TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
152 |              r'\.(?=\p{Z})')
153 |     ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
154 |     ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
155 |     HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
156 |     NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
157 |     CONTRACTION1 = r"can(?=not\b)"
158 |     CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
159 |     START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
160 |     START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
161 |     END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
162 |     END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
163 |     DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
164 |     ELLIPSES = r'\.\.\.|\u2026'
165 |     PUNCT = r'\p{P}'
166 |     NON_WS = r'[^\p{Z}\p{C}]'
167 | 
168 |     def __init__(self, **kwargs):
169 |         """
170 |         Args:
171 |             annotators: None or empty set (only tokenizes).
172 |             substitutions: if true, normalizes some token types (e.g. quotes).
173 |         """
174 |         self._regexp = regex.compile(
175 |             '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
176 |             '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
177 |             '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
178 |             '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
179 |             (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
180 |              self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
181 |              self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
182 |              self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
183 |              self.NON_WS),
184 |             flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
185 |         )
186 |         if len(kwargs.get('annotators', {})) > 0:
187 |             logger.warning('%s only tokenizes! Skipping annotators: %s' %
188 |                            (type(self).__name__, kwargs.get('annotators')))
189 |         self.annotators = set()
190 |         self.substitutions = kwargs.get('substitutions', True)
191 | 
192 |     def tokenize(self, text):
193 |         data = []
194 |         matches = [m for m in self._regexp.finditer(text)]
195 |         for i in range(len(matches)):
196 |             # Get text
197 |             token = matches[i].group()
198 | 
199 |             # Make normalizations for special token types
200 |             if self.substitutions:
201 |                 groups = matches[i].groupdict()
202 |                 if groups['sdquote']:
203 |                     token = "``"
204 |                 elif groups['edquote']:
205 |                     token = "''"
206 |                 elif groups['ssquote']:
207 |                     token = "`"
208 |                 elif groups['esquote']:
209 |                     token = "'"
210 |                 elif groups['dash']:
211 |                     token = '--'
212 |                 elif groups['ellipses']:
213 |                     token = '...'
214 | 
215 |             # Get whitespace
216 |             span = matches[i].span()
217 |             start_ws = span[0]
218 |             if i + 1 < len(matches):
219 |                 end_ws = matches[i + 1].span()[0]
220 |             else:
221 |                 end_ws = span[1]
222 | 
223 |             # Format data
224 |             data.append((
225 |                 token,
226 |                 text[start_ws: end_ws],
227 |                 span,
228 |             ))
229 |         return Tokens(data, self.annotators)
230 | 
231 | 
232 | class SimpleTokenizer(Tokenizer):
233 |     ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
234 |     NON_WS = r'[^\p{Z}\p{C}]'
235 | 
236 |     def __init__(self, **kwargs):
237 |         """
238 |         Args:
239 |             annotators: None or empty set (only tokenizes).
240 |         """
241 |         self._regexp = regex.compile(
242 |             '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
243 |             flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
244 |         )
245 |         if len(kwargs.get('annotators', {})) > 0:
246 |             logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 247 | (type(self).__name__, kwargs.get('annotators'))) 248 | self.annotators = set() 249 | 250 | def tokenize(self, text): 251 | data = [] 252 | matches = [m for m in self._regexp.finditer(text)] 253 | for i in range(len(matches)): 254 | # Get text 255 | token = matches[i].group() 256 | 257 | # Get whitespace 258 | span = matches[i].span() 259 | start_ws = span[0] 260 | if i + 1 < len(matches): 261 | end_ws = matches[i + 1].span()[0] 262 | else: 263 | end_ws = span[1] 264 | 265 | # Format data 266 | data.append(( 267 | token, 268 | text[start_ws: end_ws], 269 | span, 270 | )) 271 | return Tokens(data, self.annotators) 272 | 273 | 274 | 275 | -------------------------------------------------------------------------------- /qa/bert_retrieve_qa.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig, BertPreTrainedModel 2 | import torch.nn as nn 3 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import sys 8 | sys.path.append('../retrieval') 9 | from retriever import BertForRetriever 10 | 11 | 12 | class BertRetrieveQA(nn.Module): 13 | 14 | def __init__(self, 15 | config, 16 | args 17 | ): 18 | super(BertRetrieveQA, self).__init__() 19 | self.shared_norm = args.shared_norm 20 | self.separate = args.separate 21 | self.add_select = args.add_select 22 | self.drop_early = args.drop_early 23 | 24 | if args.use_spanbert: 25 | self.bert = BertModel.from_pretrained(args.spanbert_path) 26 | else: 27 | self.bert = BertModel.from_pretrained(args.bert_model_name) 28 | 29 | # parameters from pretrained index 30 | self.retriever = BertForRetriever(config, args) 31 | if args.retriever_path != "": 32 | self.load_pretrained_retriever(args.retriever_path) 33 | 34 | self.qa_outputs = nn.Linear( 35 | config.hidden_size, 2) 36 | self.qa_drop = nn.Dropout(args.qa_drop) 37 | self.shared_norm = args.shared_norm 38 | 39 | if self.add_select: 40 | self.select_outputs = nn.Linear(config.hidden_size, 1) 41 | 42 | def load_pretrained_retriever(self, path): 43 | state_dict = torch.load(path) 44 | def filter(x): return x[7:] if x.startswith('module.') else x 45 | state_dict = {filter(k): v for (k, v) in state_dict.items()} 46 | self.retriever.load_state_dict(state_dict) 47 | 48 | def freeze_c_encoder(self): 49 | for p in self.retriever.bert_c.parameters(): 50 | p.requires_grad = False 51 | for p in self.retriever.proj_c.parameters(): 52 | p.requires_grad = False 53 | 54 | def freeze_retriever(self): 55 | for p in self.retriever.parameters(): 56 | p.requires_grad = False 57 | 58 | def forward(self, batch): 59 | input_ids, attention_mask, token_type_ids = batch[ 60 | "input_ids"], batch["input_mask"], batch["segment_ids"] 61 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 62 | sequence_output = outputs[0] 63 | 64 | logits = self.qa_outputs(self.qa_drop(sequence_output)) 65 | outs = [o.squeeze(-1) for o in logits.split(1, dim=-1)] 66 | outs = [o.float().masked_fill(batch["paragraph_mask"].ne(1), -1e10).type_as(o) 67 | for o in outs] 68 | 69 | start_logits = outs[0] 70 | end_logits = outs[1] 71 | 72 | input_ids_q, attention_mask_q = batch["input_ids_q"], batch["input_mask_q"] 73 | q_cls = self.retriever.bert_q(input_ids_q, attention_mask_q)[1] 74 | q = self.retriever.proj_q(q_cls) 75 | 76 | rank_logits = q[0].unsqueeze(0).mm(batch["para_embed"].t()) 77 | rank_probs = F.softmax(rank_logits, dim=-1) 78 | 79 | if self.add_select: 80 | 
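            # Added note: the selection head scores each question-paragraph pair from the
            # reader's pooled [CLS] output; it is used either in place of, or alongside,
            # the retriever's rank probabilities (see --add-select / --separate in config.py).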
pooled_output = outputs[1] 81 | select_logits = self.select_outputs(pooled_output) 82 | 83 | if self.training: 84 | start_positions, end_positions, rank_targets = batch[ 85 | "start_positions"], batch["end_positions"], batch["para_targets"] 86 | loss_fct = CrossEntropyLoss(ignore_index=-1, reduction="none") 87 | 88 | if not self.drop_early: 89 | # early loss 90 | para_targets = batch["top5000_labels"].nonzero() 91 | early_losses = [loss_fct(rank_logits, p) 92 | for p in torch.unbind(para_targets)] 93 | if len(early_losses) == 0: 94 | early_loss = loss_fct(start_logits, start_logits.new_zeros( 95 | start_logits.size(0)).long()-1).sum() 96 | else: 97 | early_loss = - \ 98 | torch.log(torch.sum(torch.exp(-torch.cat(early_losses)))) 99 | 100 | if self.add_select: 101 | select_logits_flat = select_logits.view(1, -1) 102 | select_probs = F.softmax(select_logits_flat, dim=-1) 103 | 104 | if self.separate: 105 | select_targets_flat = rank_targets.view(1, -1) 106 | select_targets_flat = select_targets_flat.nonzero()[ 107 | :, 1].unsqueeze(1) 108 | select_losses = [loss_fct(select_logits_flat, r) 109 | for r in torch.unbind(select_targets_flat)] 110 | if len(select_losses) == 0: 111 | select_loss = loss_fct( 112 | select_logits_flat, select_logits_flat.new_zeros(1).long()-1).sum() 113 | else: 114 | select_loss = - torch.log(torch.sum(torch.exp(-torch.cat(select_losses)))) 115 | 116 | 117 | # two ways to calculate the span probabilities 118 | if self.shared_norm: 119 | offset = (torch.arange(start_positions.size( 120 | 0)) * start_logits.size(1)).unsqueeze(1).to(start_positions.device) 121 | start_positions_ = start_positions + \ 122 | (start_positions != -1) * offset 123 | end_positions_ = end_positions + (end_positions != -1) * offset 124 | start_positions_ = start_positions_.view(-1, 1) 125 | end_positions_ = end_positions_.view(-1, 1) 126 | start_logits_flat = start_logits.view(1, -1) 127 | end_logits_flat = end_logits.view(1, -1) 128 | start_losses = [loss_fct(start_logits_flat, s) 129 | for s in torch.unbind(start_positions_)] 130 | end_losses = [loss_fct(end_logits_flat, e) 131 | for e in torch.unbind(end_positions_)] 132 | loss_tensor = - (torch.cat(start_losses) + 133 | torch.cat(end_losses)) 134 | loss_tensor = loss_tensor.view(start_positions.size()) 135 | log_prob = loss_tensor.float().masked_fill( 136 | loss_tensor == 0, float('-inf')).type_as(loss_tensor) 137 | else: 138 | start_losses = [loss_fct(start_logits, starts) for starts in torch.unbind(start_positions, dim=1)] 139 | end_losses = [loss_fct(end_logits, ends) for ends in torch.unbind(end_positions, dim=1)] 140 | loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) 141 | log_prob = - loss_tensor 142 | log_prob = log_prob.float().masked_fill(log_prob == 0, float('-inf')).type_as(log_prob) 143 | 144 | # marginal probabily for each paragraph 145 | probs = torch.exp(log_prob) 146 | marginal_probs = torch.sum(probs, dim=1) 147 | 148 | # joint or separate loss functions 149 | if self.separate: 150 | m_prob = [marginal_probs[idx] for idx in marginal_probs.nonzero()] 151 | if len(m_prob) == 0: 152 | span_loss = loss_fct(start_logits, start_logits.new_zeros( 153 | start_logits.size(0)).long()-1).sum() 154 | else: 155 | span_loss = - torch.log(torch.sum(torch.cat(m_prob))) 156 | total_loss = span_loss + select_loss + early_loss if self.add_select else span_loss + early_loss 157 | 158 | else: 159 | if self.add_select: 160 | rank_probs = select_probs 161 | 162 | 
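                # Added note: joint objective — weight each paragraph's marginal span
                # probability by its rank (or selection) probability, then marginalize
                # over the paragraphs that contain an answer span.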
joint_prob = marginal_probs * rank_probs.view(-1)[:marginal_probs.size(0)] 163 | joint_prob = [joint_prob[idx] for idx in marginal_probs.nonzero()] 164 | if len(joint_prob) == 0: 165 | joint_loss = loss_fct(start_logits, start_logits.new_zeros( 166 | start_logits.size(0)).long()-1).sum() 167 | else: 168 | joint_loss = - torch.log(torch.sum(torch.cat(joint_prob))) 169 | total_loss = joint_loss + early_loss 170 | 171 | return {"loss": total_loss} 172 | 173 | if self.add_select: 174 | return {"start_logits": start_logits, "end_logits": end_logits, "rank_logits": rank_logits, "select_logits": select_logits.view(1, -1)} 175 | else: 176 | return {"start_logits": start_logits, "end_logits": end_logits, "rank_logits": rank_logits} 177 | -------------------------------------------------------------------------------- /qa/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | 7 | # Required parameters 8 | parser.add_argument("--bert_model_name", 9 | default="bert-base-uncased", type=str) 10 | parser.add_argument("--output_dir", default="logs", type=str, 11 | help="The output directory where the model checkpoints will be written.") 12 | parser.add_argument("--weight_decay", default=0.0, type=float, 13 | help="Weight decay if we apply some.") 14 | 15 | # Other parameters 16 | parser.add_argument("--load", default=False, action='store_true') 17 | parser.add_argument("--num_workers", default=5, type=int) 18 | parser.add_argument("--train_file", type=str, 19 | default="../../data/mrqa-train/HotpotQA-tokenized.jsonl") 20 | parser.add_argument("--predict_file", type=str, 21 | default="../../data/mrqa-dev/HotpotQA-tokenized.jsonl") 22 | parser.add_argument("--init_checkpoint", type=str, 23 | help="Initial checkpoint (usually from a pre-trained BERT model).", 24 | default="") 25 | parser.add_argument("--do_lower_case", default=True, action='store_true', 26 | help="Whether to lower case the input text. Should be True for uncased" 27 | "models and False for cased models.") 28 | parser.add_argument("--max_seq_length", default=512, type=int, 29 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 30 | "longer than this will be truncated, and sequences shorter than this will be padded.") 31 | parser.add_argument("--max_query_length", default=50, type=int, 32 | help="The maximum number of tokens for the question. 
Questions longer than this will " 33 | "be truncated to this length.") 34 | parser.add_argument("--do_train", default=False, 35 | action='store_true', help="Whether to run training.") 36 | parser.add_argument("--do_predict", default=False, 37 | action='store_true', help="Whether to run eval on the dev set.") 38 | parser.add_argument("--train_batch_size", default=8, 39 | type=int, help="Total batch size for training.") 40 | parser.add_argument("--predict_batch_size", default=100, 41 | type=int, help="Total batch size for predictions.") 42 | parser.add_argument("--learning_rate", default=5e-5, 43 | type=float, help="The initial learning rate for Adam.") 44 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 45 | help="Epsilon for Adam optimizer.") 46 | parser.add_argument("--num_train_epochs", default=200, type=float, 47 | help="Total number of training epochs to perform.") 48 | parser.add_argument('--wait_step', type=int, default=100) 49 | parser.add_argument("--save_checkpoints_steps", default=1000, type=int, 50 | help="How often to save the model checkpoint.") 51 | parser.add_argument("--iterations_per_loop", default=1000, type=int, 52 | help="How many steps to make in each estimator call.") 53 | parser.add_argument("--no_cuda", default=False, action='store_true', 54 | help="Whether not to use CUDA when available") 55 | parser.add_argument("--local_rank", type=int, default=-1, 56 | help="local_rank for distributed training on gpus") 57 | parser.add_argument("--accumulate_gradients", type=int, default=1, 58 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 59 | parser.add_argument('--seed', type=int, default=3, 60 | help="random seed for initialization") 61 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 62 | help="Number of updates steps to accumualte before performing a backward/update pass.") 63 | parser.add_argument('--eval_period', type=int, default=1000, help="setting to -1: eval only after each epoch") 64 | parser.add_argument('--verbose', action="store_true", default=False) 65 | parser.add_argument('--efficient_eval', action="store_true", help="whether to use fp16 for evaluation") 66 | parser.add_argument('--max_answer_len', default=20, type=int) 67 | parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.") 68 | 69 | parser.add_argument('--fp16', action='store_true') 70 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 71 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
72 | "See details at https://nvidia.github.io/apex/amp.html") 73 | 74 | # BERT QA 75 | parser.add_argument("--qa-drop", default=0, type=float) 76 | parser.add_argument("--rank-drop", default=0, type=float) 77 | 78 | parser.add_argument("--MI", action="store_true", help="Use MI regularization to improve weak supervision") 79 | parser.add_argument("--mi-k", default=10, type=int, help="negative sample number") 80 | parser.add_argument("--max-pool", action="store_true", help="CLS or maxpooling") 81 | parser.add_argument("--eval-workers", default=16, help="parallel data loader", type=int) 82 | parser.add_argument("--save-pred", action="store_true", help="uncertainty analysis") 83 | parser.add_argument("--retriever-path", type=str, default="", help="pretrained retriever checkpoint") 84 | 85 | parser.add_argument("--raw-train-data", type=str, 86 | default="../data/nq-train.txt") 87 | parser.add_argument("--raw-eval-data", type=str, 88 | default="../data/nq-dev.txt") 89 | parser.add_argument("--fix-para-encoder", action="store_true") 90 | parser.add_argument("--db-path", type=str, 91 | default='../data/nq_paras.db') 92 | parser.add_argument("--index-path", type=str, 93 | default="retrieval/index_data/para_embed_100k.npy") 94 | parser.add_argument("--matched-para-path", type=str, 95 | default="../data/wq_ft_train_matched.txt") 96 | 97 | parser.add_argument("--use-spanbert", action="store_true", help="use spanbert for question answering") 98 | parser.add_argument("--spanbert-path", 99 | default="../data/span_bert", type=str) 100 | parser.add_argument("--eval-k", default=5, type=int) 101 | parser.add_argument("--regex", action="store_true", help="for CuratedTrec") 102 | 103 | # investigate different kinds of loss functions 104 | parser.add_argument("--separate", action="store_true", help="separate the rank and reader loss") 105 | parser.add_argument("--add-select", action="store_true", help="replace the rank probability with the selection probility from the reader model ([CLS])") 106 | parser.add_argument("--drop-early", action="store_true", help="drop the early loss on topk5000") 107 | parser.add_argument("--shared-norm", action="store_true", 108 | help="normalize span logits across different paragraphs") 109 | 110 | # parser.add_argument("--fix-retriever", action="store_true") 111 | # parser.add_argument("--joint-train", action="store_true") 112 | # parser.add_argument("--mixed", action="store_true",help="shared norm and also use the rank probabilities in loss") 113 | # parser.add_argument("--use-adam", action="store_true") 114 | # parser.add_argument("--para-embed-path", type=str, default="") 115 | # parser.add_argument("--retrieved-path", type=str, default="") 116 | 117 | # For evaluation 118 | parser.add_argument('--prefix', type=str, default="eval") 119 | parser.add_argument('--debug', action="store_true") 120 | parser.add_argument('--use-top-passage', action="store_true") 121 | parser.add_argument('--topk', default=30, type=int) 122 | parser.add_argument('--save-all', action="store_true", help="save the predictions") 123 | parser.add_argument('--candidates', default="", type=str, help="restrict the predicted spans to be entities") 124 | 125 | args = parser.parse_args() 126 | 127 | return args 128 | -------------------------------------------------------------------------------- /qa/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset, Sampler 2 | import torch 3 | import json 4 | import numpy as np 5 | 
import random 6 | from tqdm import tqdm 7 | 8 | from joblib import Parallel, delayed 9 | 10 | from prepro_utils import hash_question 11 | 12 | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False): 13 | """Convert a list of 1d tensors into a padded 2d tensor.""" 14 | size = max(v.size(0) for v in values) 15 | res = values[0].new(len(values), size).fill_(pad_idx) 16 | 17 | def copy_tensor(src, dst): 18 | assert dst.numel() == src.numel() 19 | if move_eos_to_beginning: 20 | assert src[-1] == eos_idx 21 | dst[0] = eos_idx 22 | dst[1:] = src[:-1] 23 | else: 24 | dst.copy_(src) 25 | 26 | for i, v in enumerate(values): 27 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 28 | return res 29 | 30 | 31 | class OpenQASampler(Sampler): 32 | """ 33 | Shuffle QA pairs not context, make sure data within the batch are from the same QA pair 34 | """ 35 | 36 | def __init__(self, data_source, batch_size): 37 | self.batch_size = batch_size 38 | # for each QA pair, sample negative paragraphs 39 | sample_indice = [] 40 | for qa_idx in range(len(data_source.qids)): 41 | batch_data = [] 42 | batch_data.append(random.choice(data_source.grouped_idx_has_answer[qa_idx])) 43 | assert len(batch_data) >= 1 44 | if len(data_source.grouped_idx_no_answer[qa_idx]) < self.batch_size - len(batch_data): 45 | # print("Too few negative samples...") 46 | # continue 47 | if len(data_source.grouped_idx_no_answer[qa_idx]) == 0: 48 | continue 49 | negative_sample = random.choices(data_source.grouped_idx_no_answer[qa_idx], k=self.batch_size - len(batch_data)) 50 | else: 51 | negative_sample = random.sample(data_source.grouped_idx_no_answer[qa_idx], self.batch_size - len(batch_data)) 52 | batch_data.extend(negative_sample) 53 | assert len(batch_data) == batch_size 54 | sample_indice.append(batch_data) 55 | 56 | print(f"{len(sample_indice)} QA pairs used for training...") 57 | 58 | sample_indice = np.array(sample_indice) 59 | np.random.shuffle(sample_indice) 60 | self.sample_indice = list(sample_indice.flatten()) 61 | 62 | def __len__(self): 63 | return len(self.sample_indice) 64 | 65 | def __iter__(self): 66 | return iter(self.sample_indice) 67 | 68 | 69 | class BatchSampler(Sampler): 70 | """ 71 | use all paragraphs, shuffle the QA pairs 72 | """ 73 | 74 | def __init__(self, data_source, batch_size): 75 | self.batch_size = batch_size 76 | sample_indice = [] 77 | for qa_idx in range(len(data_source.qids)): 78 | batch_data = [] 79 | batch_data.extend(data_source.grouped_idx_has_answer[qa_idx]) 80 | batch_data.extend(data_source.grouped_idx_no_answer[qa_idx]) 81 | assert len(batch_data) == batch_size 82 | sample_indice.append(batch_data) 83 | 84 | print(f"{len(sample_indice)} QA pairs used for training...") 85 | sample_indice = np.array(sample_indice) 86 | np.random.shuffle(sample_indice) 87 | self.sample_indice = list(sample_indice.flatten()) 88 | 89 | def __len__(self): 90 | return len(self.sample_indice) 91 | 92 | def __iter__(self): 93 | return iter(self.sample_indice) 94 | 95 | 96 | class OpenQADataset(Dataset): 97 | 98 | def __init__(self, 99 | tokenizer, 100 | data_path, 101 | max_query_length, 102 | max_length 103 | ): 104 | super().__init__() 105 | self.tokenizer = tokenizer 106 | print(f"Loading tokenized data from {data_path}...") 107 | 108 | 109 | self.qids = [] 110 | self.all_data = [json.loads(line) 111 | for line in tqdm(open(data_path).readlines())] 112 | self.grouped_idx_has_answer = [] 113 | self.grouped_idx_no_answer = [] 114 | for idx, item in 
enumerate(self.all_data): 115 | if len(self.qids) == 0 or item["qid"] != self.qids[-1]: 116 | self.qids.append(item["qid"]) 117 | self.grouped_idx_no_answer.append([]) 118 | self.grouped_idx_has_answer.append([]) 119 | if item["no_answer"] == 0: 120 | self.grouped_idx_has_answer[-1].append(idx) 121 | else: 122 | self.grouped_idx_no_answer[-1].append(idx) 123 | 124 | print(f"{len(self.qids)} QA pairs loaded....") 125 | self.max_query_length = max_query_length 126 | self.max_length = max_length 127 | 128 | def __getitem__(self, index): 129 | sample = self.all_data[index] 130 | qid = sample['qid'] 131 | q_subtoks = sample['q_subtoks'] 132 | if len(q_subtoks) > self.max_query_length: 133 | q_subtoks = q_subtoks[:self.max_query_length] 134 | question = torch.LongTensor(self.binarize_list(q_subtoks)) 135 | para_offset = question.size(0) + 2 136 | 137 | para_subtoks = sample['doc_subtoks'] 138 | max_tokens_for_doc = self.max_length - para_offset - 1 139 | if len(para_subtoks) > max_tokens_for_doc: 140 | para_subtoks = para_subtoks[:max_tokens_for_doc] 141 | 142 | paragraph = torch.LongTensor(self.binarize_list(para_subtoks)) 143 | text, seg = self._join_sents(question, paragraph) 144 | paragraph_mask = torch.zeros(text.shape).bool() 145 | question_mask = torch.zeros(text.shape).bool() 146 | paragraph_mask[para_offset:-1] = 1 147 | question_mask[1:para_offset] = 1 148 | 149 | starts, ends, no_answer = sample["starts"], sample["ends"], sample["no_answer"] 150 | 151 | start_positions, end_positions = [], [] 152 | if not no_answer: 153 | no_answer = 1 154 | for s, e in zip(starts, ends): 155 | assert s <= e 156 | if s >= paragraph.size(0): 157 | continue 158 | else: 159 | start_position = min(s, paragraph.size(0) - 1) + para_offset 160 | end_position = min(e, paragraph.size(0) - 1) + para_offset 161 | no_answer = 0 162 | start_positions.append(start_position) 163 | end_positions.append(end_position) 164 | 165 | if len(start_positions) == 0: 166 | assert no_answer 167 | start_positions.append(-1) 168 | end_positions.append(-1) 169 | 170 | start_tensor, end_tensor, no_answer = torch.LongTensor( 171 | start_positions), torch.LongTensor(end_positions), torch.LongTensor([no_answer]) 172 | 173 | item_tensor = { 174 | 'q': sample["q"], 175 | 'qid': qid, 176 | 'input_ids': text, 177 | 'segment_ids': seg, 178 | "input_ids_q": self._add_special_token(question), 179 | "input_ids_c": self._add_special_token(paragraph), 180 | 'para_offset': para_offset, 181 | 'paragraph_mask': paragraph_mask, 182 | 'question_mask': question_mask, 183 | 'doc_tokens': sample['doc_toks'], 184 | 'q_subtoks': q_subtoks, 185 | 'wp_tokens': para_subtoks, 186 | 'tok_to_orig_index': sample['tok_to_orig_index'], 187 | 'true_answers': sample["true_answers"], 188 | "start": start_tensor, 189 | "end": end_tensor, 190 | "no_answer": no_answer, 191 | } 192 | 193 | return item_tensor 194 | 195 | def _join_sents(self, sent1, sent2): 196 | cls = sent1.new_full((1,), self.tokenizer.vocab["[CLS]"]) 197 | sep = sent1.new_full((1,), self.tokenizer.vocab["[SEP]"]) 198 | sent1 = torch.cat([cls, sent1, sep]) 199 | sent2 = torch.cat([sent2, sep]) 200 | text = torch.cat([sent1, sent2]) 201 | segment1 = torch.zeros(sent1.size(0)).long() 202 | segment2 = torch.ones(sent2.size(0)).long() 203 | segment = torch.cat([segment1, segment2]) 204 | return text, segment 205 | 206 | def _add_special_token(self, sent): 207 | cls = sent.new_full((1,), self.tokenizer.vocab["[CLS]"]) 208 | sep = sent.new_full((1,), self.tokenizer.vocab["[SEP]"]) 209 | sent = 
torch.cat([cls, sent, sep]) 210 | return sent 211 | 212 | 213 | def binarize_list(self, words): 214 | return self.tokenizer.convert_tokens_to_ids(words) 215 | 216 | def tokenize(self, s): 217 | try: 218 | return self.tokenizer.tokenize(s) 219 | except: 220 | print('failed on', s) 221 | raise 222 | 223 | def __len__(self): 224 | return len(self.all_data) 225 | 226 | def openqa_collate(samples): 227 | if len(samples) == 0: 228 | return {} 229 | 230 | input_ids = collate_tokens([s['input_ids'] for s in samples], 0) 231 | start_masks = torch.zeros(input_ids.size()) 232 | for b_idx, s in enumerate(samples): 233 | for _ in s["start"]: 234 | if _ != -1: 235 | start_masks[b_idx, _] = 1 236 | 237 | net_input = { 238 | 'input_ids': input_ids, 239 | 'segment_ids': collate_tokens( 240 | [s['segment_ids'] for s in samples], 0), 241 | 'paragraph_mask': collate_tokens( 242 | [s['paragraph_mask'] for s in samples], 0,), 243 | 'question_mask': collate_tokens([s["question_mask"] for s in samples], 0), 244 | 'start_positions': collate_tokens( 245 | [s['start'] for s in samples], -1), 246 | 'end_positions': collate_tokens( 247 | [s['end'] for s in samples], -1), 248 | 'no_ans_targets': collate_tokens( 249 | [s['no_answer'] for s in samples], 0), 250 | 'input_mask': collate_tokens([torch.ones_like(s["input_ids"]) for s in samples], 0), 251 | 'start_masks': start_masks, 252 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0), 253 | 'input_mask_q': collate_tokens([torch.ones_like(s["input_ids_q"]) for s in samples], 0), 254 | 'input_ids_c': collate_tokens([s['input_ids_c'] for s in samples], 0), 255 | 'input_mask_c': collate_tokens([torch.ones_like(s["input_ids_c"]) for s in samples], 0), 256 | } 257 | 258 | return { 259 | 'id': [s['qid'] for s in samples], 260 | "q": [s['q'] for s in samples], 261 | 'doc_tokens': [s['doc_tokens'] for s in samples], 262 | 'q_subtoks': [s['q_subtoks'] for s in samples], 263 | 'wp_tokens': [s['wp_tokens'] for s in samples], 264 | 'tok_to_orig_index': [s['tok_to_orig_index'] for s in samples], 265 | 'para_offset': [s['para_offset'] for s in samples], 266 | "true_answers": [s['true_answers'] for s in samples], 267 | 'net_input': net_input, 268 | } 269 | 270 | 271 | class top5k_generator(object): 272 | 273 | def __init__(self, 274 | retrieved_path, 275 | embed_path 276 | ): 277 | super().__init__() 278 | retrieved = [json.loads(l) for l in open(retrieved_path).readlines()] 279 | self.para_embed = np.load(embed_path) 280 | 281 | self.qid2para = {} 282 | for item in retrieved: 283 | self.qid2para[hash_question(item["question"])] = {"para_embed_idx": item["para_embed_idx"], "para_labels": item["para_labels"]} 284 | 285 | def generate(self, qid): 286 | para_labels = self.qid2para[qid]["para_labels"] 287 | para_embed_idx = self.qid2para[qid]["para_embed_idx"] 288 | if np.sum(para_labels) > 0: 289 | para_embed = torch.from_numpy(self.para_embed[para_embed_idx]) 290 | para_labels = torch.tensor(para_labels).nonzero().view(-1) 291 | result = {} 292 | result["para_embed"] = para_embed 293 | result["para_labels"] = para_labels 294 | return result 295 | else: 296 | return None 297 | 298 | 299 | if __name__ == "__main__": 300 | data_path = "../data/mrqa-train/SQuAD-tokenized.jsonl" 301 | tokenized_data = [json.loads(_.strip()) 302 | for _ in open(data_path).readlines()] 303 | q_lens = np.array([len(item['q_subtoks']) for item in tokenized_data]) 304 | c_lens = np.array([len(item['doc_subtoks']) for item in tokenized_data]) 305 | import pdb; pdb.set_trace() 306 | 
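A minimal sketch of how the pieces in this file are meant to be wired together for training (the tokenizer choice mirrors the `bert-base-uncased` default and the data path is the one used in the `__main__` block above; the `DataLoader` batch size must equal the sampler's batch size so that each batch holds exactly one question's positive paragraph plus its sampled negatives):
```
from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
train_set = OpenQADataset(tokenizer, "../data/mrqa-train/SQuAD-tokenized.jsonl",
                          max_query_length=50, max_length=512)
# each group of `batch_size` indices comes from a single QA pair:
# one paragraph containing an answer span plus sampled negative paragraphs
sampler = OpenQASampler(train_set, batch_size=8)
loader = DataLoader(train_set, batch_size=8, sampler=sampler,
                    collate_fn=openqa_collate, num_workers=4)
```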
-------------------------------------------------------------------------------- /qa/eval_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | recover the answer string from BERT predictions 3 | """ 4 | 5 | import collections 6 | from tokenizer import BasicTokenizer 7 | import six 8 | 9 | def is_whitespace(c): 10 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 11 | return True 12 | return False 13 | 14 | 15 | def get_final_text(pred_text, orig_text, do_lower_case=False,verbose_logging=True): 16 | """Project the tokenized prediction back to the original text.""" 17 | def _strip_spaces(text): 18 | ns_chars = [] 19 | ns_to_s_map = collections.OrderedDict() 20 | for (i, c) in enumerate(text): 21 | if c == " ": 22 | continue 23 | ns_to_s_map[len(ns_chars)] = i 24 | ns_chars.append(c) 25 | ns_text = "".join(ns_chars) 26 | return (ns_text, ns_to_s_map) 27 | 28 | # We first tokenize `orig_text`, strip whitespace from the result 29 | # and `pred_text`, and check if they are the same length. If they are 30 | # NOT the same length, the heuristic has failed. If they are the same 31 | # length, we assume the characters are one-to-one aligned. 32 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 33 | 34 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 35 | 36 | start_position = tok_text.find(pred_text) 37 | if start_position == -1: 38 | if verbose_logging: 39 | print( 40 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 41 | return orig_text 42 | end_position = start_position + len(pred_text) - 1 43 | 44 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 45 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 46 | 47 | if len(orig_ns_text) != len(tok_ns_text): 48 | if verbose_logging: 49 | print("Length not equal after stripping spaces: '%s' vs '%s'", 50 | orig_ns_text, tok_ns_text) 51 | return orig_text 52 | 53 | # We then project the characters in `pred_text` back to `orig_text` using 54 | # the character-to-character alignment. 
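    # Added note: tok_s_to_ns_map inverts tok_ns_to_s_map — it maps positions in tok_text
    # to positions in the whitespace-stripped text, so the span found in tok_text can be
    # projected back into orig_text through orig_ns_to_s_map.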
55 | tok_s_to_ns_map = {} 56 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 57 | tok_s_to_ns_map[tok_index] = i 58 | 59 | orig_start_position = None 60 | if start_position in tok_s_to_ns_map: 61 | ns_start_position = tok_s_to_ns_map[start_position] 62 | if ns_start_position in orig_ns_to_s_map: 63 | orig_start_position = orig_ns_to_s_map[ns_start_position] 64 | 65 | if orig_start_position is None: 66 | if verbose_logging: 67 | print("Couldn't map start position") 68 | return orig_text 69 | 70 | orig_end_position = None 71 | if end_position in tok_s_to_ns_map: 72 | ns_end_position = tok_s_to_ns_map[end_position] 73 | if ns_end_position in orig_ns_to_s_map: 74 | orig_end_position = orig_ns_to_s_map[ns_end_position] 75 | 76 | if orig_end_position is None: 77 | if verbose_logging: 78 | print("Couldn't map end position") 79 | return orig_text 80 | 81 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 82 | return output_text -------------------------------------------------------------------------------- /qa/msmarco_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def extract_qa_p(path="../data/msmarco-qa/train_v2.1.json", output="../data/msmarco-qa/train.txt"): 5 | data = json.load(open(path)) 6 | data_to_save = [] 7 | for id_, answers in data["answers"].items(): 8 | if answers[0] != 'No Answer Present.': 9 | passages = data["passages"][id_] 10 | query = data["query"][id_] 11 | relevant_p = [] 12 | for p in passages: 13 | if p["is_selected"]: 14 | relevant_p.append(p["passage_text"]) 15 | if len(relevant_p) != 0: 16 | data_to_save.append({"q": query, "answer": answers, "para": " ".join(relevant_p)}) 17 | 18 | with open(output, "w") as g: 19 | for l in data_to_save: 20 | g.write(json.dumps(l) + "\n") 21 | 22 | from tqdm import tqdm 23 | 24 | if __name__ == "__main__": 25 | # extract_qa_p() 26 | 27 | # data = [json.loads(l) 28 | # for l in open("../data/msmarco-qa/dev.txt").readlines()] 29 | 30 | # source_file = open("../data/msmarco-qa/val.source", "w") 31 | # target_file = open("../data/msmarco-qa/val.target", "w") 32 | # for _ in data: 33 | # source_file.write(_["para"] + "\n") 34 | # target_file.write(_["q"] + "\n") 35 | 36 | all_paras = [json.loads(l) for l in open( 37 | "../data/trec-2019/msmarco_paras.txt").readlines()] 38 | source_file = open("../data/msmarco-qa/test.source", "w") 39 | for _ in tqdm(all_paras): 40 | source_file.write(" ".join(_["text"].split()) + "\n") 41 | -------------------------------------------------------------------------------- /qa/official_eval.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for the MRQA Workshop Shared Task. 2 | Adapted fromt the SQuAD v1.1 official evaluation script. 
3 | Usage: 4 | python official_eval.py dataset_file.jsonl.gz prediction_file.json 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import argparse 11 | import string 12 | import re 13 | import json 14 | import gzip 15 | import sys 16 | from collections import Counter 17 | # from allennlp.common.file_utils import cached_path 18 | 19 | 20 | def normalize_answer(s): 21 | """Lower text and remove punctuation, articles and extra whitespace.""" 22 | def remove_articles(text): 23 | return re.sub(r'\b(a|an|the)\b', ' ', text) 24 | 25 | def white_space_fix(text): 26 | return ' '.join(text.split()) 27 | 28 | def remove_punc(text): 29 | exclude = set(string.punctuation) 30 | return ''.join(ch for ch in text if ch not in exclude) 31 | 32 | def lower(text): 33 | return text.lower() 34 | 35 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 36 | 37 | 38 | def regex_match_score(prediction, pattern): 39 | """Check if the prediction matches the given regular expression.""" 40 | try: 41 | compiled = re.compile( 42 | pattern, 43 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE 44 | ) 45 | except BaseException: 46 | print('Regular expression failed to compile: %s' % pattern) 47 | return False 48 | return compiled.match(prediction) is not None 49 | 50 | def f1_score(prediction, ground_truth): 51 | prediction_tokens = normalize_answer(prediction).split() 52 | ground_truth_tokens = normalize_answer(ground_truth).split() 53 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 54 | num_same = sum(common.values()) 55 | if num_same == 0: 56 | return 0 57 | precision = 1.0 * num_same / len(prediction_tokens) 58 | recall = 1.0 * num_same / len(ground_truth_tokens) 59 | f1 = (2 * precision * recall) / (precision + recall) 60 | return f1 61 | 62 | 63 | def exact_match_score(prediction, ground_truth): 64 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 65 | 66 | 67 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 68 | scores_for_ground_truths = [] 69 | for ground_truth in ground_truths: 70 | score = metric_fn(prediction, ground_truth) 71 | scores_for_ground_truths.append(score) 72 | return max(scores_for_ground_truths) 73 | 74 | 75 | def read_predictions(prediction_file): 76 | with open(prediction_file) as f: 77 | predictions = json.load(f) 78 | return predictions 79 | 80 | 81 | def read_answers(gold_file): 82 | answers = {} 83 | with gzip.open(gold_file, 'rb') as f: 84 | for i, line in enumerate(f): 85 | example = json.loads(line) 86 | if i == 0 and 'header' in example: 87 | continue 88 | for qa in example['qas']: 89 | answers[qa['qid']] = qa['answers'] 90 | return answers 91 | 92 | 93 | def evaluate(answers, predictions, skip_no_answer=False): 94 | f1 = exact_match = total = 0 95 | for qid, ground_truths in answers.items(): 96 | if qid not in predictions: 97 | if not skip_no_answer: 98 | message = 'Unanswered question %s will receive score 0.' 
% qid 99 | print(message) 100 | total += 1 101 | continue 102 | total += 1 103 | prediction = predictions[qid] 104 | exact_match += metric_max_over_ground_truths( 105 | exact_match_score, prediction, ground_truths) 106 | f1 += metric_max_over_ground_truths( 107 | f1_score, prediction, ground_truths) 108 | 109 | exact_match = 100.0 * exact_match / total 110 | f1 = 100.0 * f1 / total 111 | 112 | return {'exact_match': exact_match, 'f1': f1} 113 | 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser( 117 | description='Evaluation for MRQA Workshop Shared Task') 118 | parser.add_argument('dataset_file', type=str, help='Dataset File') 119 | parser.add_argument('prediction_file', type=str, help='Prediction File') 120 | parser.add_argument('--skip-no-answer', action='store_true') 121 | args = parser.parse_args() 122 | 123 | # answers = read_answers(cached_path(args.dataset_file)) 124 | # predictions = read_predictions(cached_path(args.prediction_file)) 125 | # metrics = evaluate(answers, predictions, args.skip_no_answer) 126 | 127 | # print(json.dumps(metrics)) 128 | -------------------------------------------------------------------------------- /qa/online_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | import random 5 | from prepro_utils import hash_question, normalize, find_ans_span_with_char_offsets, prepare 6 | from utils import DocDB 7 | import faiss 8 | from official_eval import normalize_answer 9 | from basic_tokenizer import SimpleTokenizer 10 | from prepro_dense import para_has_answer, match_answer_span 11 | from tqdm import tqdm 12 | 13 | from transformers import BertTokenizer 14 | 15 | """ 16 | retrieve paragraphs and find span for top5 on the fly 17 | """ 18 | 19 | 20 | def normalize_para(s): 21 | 22 | def white_space_fix(text): 23 | return ' '.join(text.split()) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(lower(s)) 29 | 30 | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False): 31 | """Convert a list of 1d tensors into a padded 2d tensor.""" 32 | size = max(v.size(0) for v in values) 33 | res = values[0].new(len(values), size).fill_(pad_idx) 34 | 35 | def copy_tensor(src, dst): 36 | assert dst.numel() == src.numel() 37 | if move_eos_to_beginning: 38 | assert src[-1] == eos_idx 39 | dst[0] = eos_idx 40 | dst[1:] = src[:-1] 41 | else: 42 | dst.copy_(src) 43 | 44 | for i, v in enumerate(values): 45 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 46 | return res 47 | 48 | 49 | class OnlineSampler(object): 50 | 51 | def __init__(self, 52 | raw_data, 53 | tokenizer, 54 | max_query_length, 55 | max_length, 56 | db, 57 | para_embed, 58 | index2paraid='retrieval/index_data/idx_id.json', 59 | matched_para_path="", 60 | exact_search=False, 61 | cased=False, 62 | regex=False 63 | ): 64 | 65 | self.max_length = max_length 66 | self.max_query_length = max_query_length 67 | self.para_embed = para_embed 68 | self.cased = cased # spanbert used cased tokenization 69 | self.regex = regex 70 | 71 | if self.cased: 72 | self.cased_tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 73 | 74 | # if not exact_search: 75 | quantizer = faiss.IndexFlatIP(128) 76 | self.index = faiss.IndexIVFFlat(quantizer, 128, 100) 77 | self.index.train(self.para_embed) 78 | self.index.add(self.para_embed) 79 | self.index.nprobe = 20 80 | # else: 81 | # self.index = 
faiss.IndexFlatIP(128) 82 | # self.index.add(self.para_embed) 83 | 84 | self.tokenizer = tokenizer 85 | self.qa_data = [json.loads(l) for l in open(raw_data).readlines()] 86 | self.index2paraid = json.load(open(index2paraid)) 87 | self.para_db = db 88 | self.matched_para_path = matched_para_path 89 | if self.matched_para_path != "": 90 | print(f"Load matched gold paras from {self.matched_para_path}") 91 | annotated = [json.loads(l) for l in tqdm(open( 92 | self.matched_para_path).readlines())] 93 | self.qid2goldparas = {hash_question( 94 | item["question"]): item["matched_paras"] for item in annotated} 95 | 96 | self.basic_tokenizer = SimpleTokenizer() 97 | 98 | def shuffle(self): 99 | random.shuffle(self.qa_data) 100 | 101 | def __len__(self): 102 | return len(self.qa_data) 103 | 104 | def load(self, retriever, k=5): 105 | for qa in self.qa_data: 106 | with torch.no_grad(): 107 | q_ids = torch.LongTensor(self.tokenizer.encode( 108 | qa["question"], max_length=self.max_query_length)).view(1,-1).cuda() 109 | q_masks = torch.ones(q_ids.shape).bool().view(1,-1).cuda() 110 | q_cls = retriever.bert_q(q_ids, q_masks)[1] 111 | q_embed = retriever.proj_q(q_cls).data.cpu().numpy().astype('float32') 112 | 113 | _, I = self.index.search(q_embed, 5000) # retrieve 114 | para_embed_idx = I.reshape(-1) 115 | 116 | if self.cased: 117 | q_ids_cased = torch.LongTensor(self.cased_tokenizer.encode( 118 | qa["question"], max_length=self.max_query_length)).view(1, -1) 119 | 120 | para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx] 121 | para_embeds = self.para_embed[para_embed_idx] 122 | 123 | qid = hash_question(qa["question"]) 124 | gold_paras = self.qid2goldparas[qid] 125 | 126 | # match answer strings 127 | p_labels = [] 128 | batched_examples = [] 129 | topk5000_labels = [int(_ in gold_paras) for _ in para_idx] 130 | 131 | # match answer spans in top5 paras 132 | for p_idx in para_idx[:k]: 133 | p = normalize(self.para_db.get_doc_text(p_idx)) 134 | # p_covered, matched_string = para_has_answer(p, qa["answer"], self.basic_tokenizer) 135 | matched_spans = match_answer_span( 136 | p, qa["answer"], self.basic_tokenizer, match="regex" if self.regex else "string") 137 | p_covered = int(len(matched_spans) > 0) 138 | ans_starts, ans_ends, ans_texts = [], [], [] 139 | 140 | if self.cased: 141 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(p, self.cased_tokenizer) 142 | else: 143 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare( 144 | p, self.tokenizer) 145 | 146 | if p_covered: 147 | for matched_string in matched_spans: 148 | char_starts = [i for i in range( 149 | len(p)) if p.startswith(matched_string, i)] 150 | if len(char_starts) > 0: 151 | char_ends = [start + len(matched_string) - 1 for start in char_starts] 152 | answer = {"text": matched_string, "char_spans": list( 153 | zip(char_starts, char_ends))} 154 | 155 | if self.cased: 156 | ans_spans = find_ans_span_with_char_offsets(answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, self.cased_tokenizer) 157 | else: 158 | ans_spans = find_ans_span_with_char_offsets( 159 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, self.tokenizer) 160 | 161 | for s, e in ans_spans: 162 | ans_starts.append(s) 163 | ans_ends.append(e) 164 | ans_texts.append(matched_string) 165 | batched_examples.append({ 166 | "qid": hash_question(qa["question"]), 167 | "q": qa["question"], 168 | "true_answers": qa["answer"], 169 | 
"doc_subtoks": all_doc_tokens, 170 | "starts": ans_starts, 171 | "ends": ans_ends, 172 | "covered": p_covered 173 | }) 174 | 175 | # # look up saved 176 | # if p_idx in gold_paras: 177 | # p_covered = 1 178 | # all_doc_tokens = gold_paras[p_idx]["doc_subtoks"] 179 | # ans_starts = gold_paras[p_idx]["starts"] 180 | # ans_ends = gold_paras[p_idx]["ends"] 181 | # ans_texts = gold_paras[p_idx]["span_texts"] 182 | # else: 183 | # p_covered = 0 184 | # p = normalize(self.para_db.get_doc_text(p_idx)) 185 | # _, _, _, _, all_doc_tokens = prepare(p, self.tokenizer) 186 | # ans_starts, ans_ends, ans_texts = [], [], [] 187 | 188 | # batched_examples.append({ 189 | # "qid": hash_question(qa["question"]), 190 | # "q": qa["question"], 191 | # "true_answers": qa["answer"], 192 | # "doc_subtoks": all_doc_tokens, 193 | # "starts": ans_starts, 194 | # "ends": ans_ends, 195 | # "covered": p_covered 196 | # }) 197 | p_labels.append(int(p_covered)) 198 | 199 | # calculate loss only when the top5000 covered the answer passage 200 | if np.sum(topk5000_labels) > 0 or np.sum(p_labels) > 0: 201 | # training tensors 202 | for item in batched_examples: 203 | item["input_ids_q"] = q_ids.view(-1).cpu() 204 | 205 | if self.cased: 206 | item["input_ids_q_cased"] = q_ids_cased.view(-1) 207 | para_offset = item["input_ids_q_cased"].size(0) 208 | else: 209 | para_offset = item["input_ids_q"].size(0) 210 | 211 | max_toks_for_doc = self.max_length - para_offset - 1 212 | para_subtoks = item["doc_subtoks"] 213 | if len(para_subtoks) > max_toks_for_doc: 214 | para_subtoks = para_subtoks[:max_toks_for_doc] 215 | 216 | if self.cased: 217 | p_ids = self.cased_tokenizer.convert_tokens_to_ids(para_subtoks) 218 | else: 219 | p_ids = self.tokenizer.convert_tokens_to_ids( 220 | para_subtoks) 221 | item["input_ids_c"] = self._add_special_token(torch.LongTensor(p_ids)) 222 | paragraph = item["input_ids_c"][1:-1] 223 | if self.cased: 224 | item["input_ids"], item["segment_ids"] = self._join_sents( 225 | item["input_ids_q_cased"][1:-1], item["input_ids_c"][1:-1]) 226 | else: 227 | item["input_ids"], item["segment_ids"] = self._join_sents(item["input_ids_q"][1:-1], item["input_ids_c"][1:-1]) 228 | item["para_offset"] = para_offset 229 | item["paragraph_mask"] = torch.zeros(item["input_ids"].shape).bool() 230 | item["paragraph_mask"][para_offset:-1] = 1 231 | 232 | starts, ends, covered = item["starts"], item["ends"], item["covered"] 233 | start_positions, end_positions = [], [] 234 | 235 | covered = item["covered"] 236 | if covered: 237 | covered = 0 238 | for s, e in zip(starts, ends): 239 | assert s <= e 240 | if s >= paragraph.size(0): 241 | continue 242 | else: 243 | start_position = min( 244 | s, paragraph.size(0) - 1) + para_offset 245 | end_position = min(e, paragraph.size(0) - 1) + para_offset 246 | covered = 1 247 | start_positions.append(start_position) 248 | end_positions.append(end_position) 249 | if len(start_positions) == 0: 250 | assert not covered 251 | start_positions.append(-1) 252 | end_positions.append(-1) 253 | 254 | start_tensor, end_tensor, covered = torch.LongTensor( 255 | start_positions), torch.LongTensor(end_positions), torch.LongTensor([covered]) 256 | 257 | item["start"] = start_tensor 258 | item["end"] = end_tensor 259 | item["covered"] = covered 260 | 261 | 262 | yield self.collate(batched_examples, para_embeds, topk5000_labels) 263 | else: 264 | yield {} 265 | 266 | def eval_load(self, retriever, k=5): 267 | for qa in self.qa_data: 268 | with torch.no_grad(): 269 | q_ids = 
torch.LongTensor(self.tokenizer.encode(qa["question"], max_length=self.max_query_length)).view(1, -1).cuda() 270 | q_masks = torch.ones(q_ids.shape).bool().view(1, -1).cuda() 271 | q_cls = retriever.bert_q(q_ids, q_masks)[1] 272 | q_embed = retriever.proj_q( 273 | q_cls).data.cpu().numpy().astype('float32') 274 | _, I = self.index.search(q_embed, k) 275 | para_embed_idx = I.reshape(-1) 276 | para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx] 277 | paras = [normalize(self.para_db.get_doc_text(idx)) 278 | for idx in para_idx] 279 | para_embeds = self.para_embed[para_embed_idx] 280 | 281 | if self.cased: 282 | q_ids_cased = torch.LongTensor(self.cased_tokenizer.encode( 283 | qa["question"], max_length=self.max_query_length)).view(1, -1) 284 | 285 | batched_examples = [] 286 | # match answer spans in top5 paras 287 | for p in paras: 288 | p = normalize(p) 289 | 290 | tokenizer = self.cased_tokenizer if self.cased else self.tokenizer 291 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare( 292 | p, tokenizer) 293 | 294 | batched_examples.append({ 295 | "qid": hash_question(qa["question"]), 296 | "q": qa["question"], 297 | "true_answers": qa["answer"], 298 | "doc_toks": doc_tokens, 299 | "doc_subtoks": all_doc_tokens, 300 | "tok_to_orig_index": tok_to_orig_index, 301 | }) 302 | 303 | for item in batched_examples: 304 | item["input_ids_q"] = q_ids.view(-1).cpu() 305 | 306 | if self.cased: 307 | item["input_ids_q_cased"] = q_ids_cased.view(-1) 308 | para_offset = item["input_ids_q_cased"].size(0) 309 | else: 310 | para_offset = item["input_ids_q"].size(0) 311 | max_toks_for_doc = self.max_length - para_offset - 1 312 | para_subtoks = item["doc_subtoks"] 313 | if len(para_subtoks) > max_toks_for_doc: 314 | para_subtoks = para_subtoks[:max_toks_for_doc] 315 | if self.cased: 316 | p_ids = self.cased_tokenizer.convert_tokens_to_ids( 317 | para_subtoks) 318 | else: 319 | p_ids = self.tokenizer.convert_tokens_to_ids( 320 | para_subtoks) 321 | item["input_ids_c"] = self._add_special_token( 322 | torch.LongTensor(p_ids)) 323 | paragraph = item["input_ids_c"][1:-1] 324 | if self.cased: 325 | item["input_ids"], item["segment_ids"] = self._join_sents( 326 | item["input_ids_q_cased"][1:-1], item["input_ids_c"][1:-1]) 327 | else: 328 | item["input_ids"], item["segment_ids"] = self._join_sents( 329 | item["input_ids_q"][1:-1], item["input_ids_c"][1:-1]) 330 | item["para_offset"] = para_offset 331 | item["paragraph_mask"] = torch.zeros( 332 | item["input_ids"].shape).bool() 333 | item["paragraph_mask"][para_offset:-1] = 1 334 | 335 | yield self.collate(batched_examples, para_embeds) 336 | 337 | 338 | def _add_special_token(self, sent): 339 | cls = sent.new_full((1,), self.tokenizer.vocab["[CLS]"]) 340 | sep = sent.new_full((1,), self.tokenizer.vocab["[SEP]"]) 341 | sent = torch.cat([cls, sent, sep]) 342 | return sent 343 | 344 | def _join_sents(self, sent1, sent2): 345 | cls = sent1.new_full((1,), self.tokenizer.vocab["[CLS]"]) 346 | sep = sent1.new_full((1,), self.tokenizer.vocab["[SEP]"]) 347 | sent1 = torch.cat([cls, sent1, sep]) 348 | sent2 = torch.cat([sent2, sep]) 349 | text = torch.cat([sent1, sent2]) 350 | segment1 = torch.zeros(sent1.size(0)).long() 351 | segment2 = torch.ones(sent2.size(0)).long() 352 | segment = torch.cat([segment1, segment2]) 353 | return text, segment 354 | 355 | def collate(self, samples, para_embeds, topk5000_labels=None): 356 | if len(samples) == 0: 357 | return {} 358 | 359 | input_ids = collate_tokens([s['input_ids'] for 
s in samples], 0) 360 | 361 | if "start" in samples[0]: 362 | assert topk5000_labels is not None 363 | net_input = { 364 | 'input_ids': input_ids, 365 | 'segment_ids': collate_tokens( 366 | [s['segment_ids'] for s in samples], 0), 367 | 'paragraph_mask': collate_tokens( 368 | [s['paragraph_mask'] for s in samples], 0,), 369 | 'start_positions': collate_tokens( 370 | [s['start'] for s in samples], -1), 371 | 'end_positions': collate_tokens( 372 | [s['end'] for s in samples], -1), 373 | 'para_targets': collate_tokens( 374 | [s['covered'] for s in samples], 0), 375 | 'input_mask': collate_tokens([torch.ones_like(s["input_ids"]) for s in samples], 0), 376 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0), 377 | 'input_mask_q': collate_tokens([torch.ones_like(s["input_ids_q"]) for s in samples], 0), 378 | 'para_embed': torch.from_numpy(para_embeds), 379 | "top5000_labels": torch.LongTensor(topk5000_labels) 380 | } 381 | return { 382 | 'id': [s['qid'] for s in samples], 383 | "q": [s['q'] for s in samples], 384 | 'wp_tokens': [s['doc_subtoks'] for s in samples], 385 | 'para_offset': [s['para_offset'] for s in samples], 386 | "true_answers": [s['true_answers'] for s in samples], 387 | 'net_input': net_input, 388 | } 389 | 390 | else: 391 | net_input = { 392 | 'input_ids': input_ids, 393 | 'segment_ids': collate_tokens( 394 | [s['segment_ids'] for s in samples], 0), 395 | 'paragraph_mask': collate_tokens( 396 | [s['paragraph_mask'] for s in samples], 0,), 397 | 'input_mask': collate_tokens([torch.ones_like(s["input_ids"]) for s in samples], 0), 398 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0), 399 | 'input_mask_q': collate_tokens([torch.ones_like(s["input_ids_q"]) for s in samples], 0), 400 | 'para_embed': torch.from_numpy(para_embeds) 401 | } 402 | 403 | return { 404 | 'id': [s['qid'] for s in samples], 405 | "q": [s['q'] for s in samples], 406 | 'doc_tokens': [s['doc_toks'] for s in samples], 407 | 'wp_tokens': [s['doc_subtoks'] for s in samples], 408 | 'tok_to_orig_index': [s['tok_to_orig_index'] for s in samples], 409 | 'para_offset': [s['para_offset'] for s in samples], 410 | "true_answers": [s['true_answers'] for s in samples], 411 | 'net_input': net_input, 412 | } 413 | 414 | 415 | 416 | if __name__ == "__main__": 417 | index_path = "retrieval/index_data/para_embed_3_28_c10000.npy" 418 | raw_data = "../data/nq-train.txt" 419 | 420 | 421 | from transformers import BertConfig, BertTokenizer 422 | from retrieval.retriever import BertForRetriever 423 | from config import get_args 424 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 425 | bert_config = BertConfig.from_pretrained('bert-base-uncased') 426 | args = get_args() 427 | retriever = BertForRetriever(bert_config, args) 428 | 429 | from utils import load_saved 430 | retriever_path = "retrieval/logs/splits_3_28_c10000-seed42-bsz640-fp16True-retrieve-from94_c1000_continue_from_failed-lr1e-05-bert-base-uncased-filterTrue/checkpoint_best.pt" 431 | retriever = load_saved(retriever, retriever_path) 432 | retriever.cuda() 433 | 434 | sampler = OnlineSampler(index_path, raw_data, tokenizer, args.max_query_length, args.max_seq_length) 435 | 436 | sampler.shuffle() 437 | retriever.eval() 438 | for batch in sampler.load(retriever): 439 | if batch is not {}: 440 | print(batch.keys()) 441 | print(batch["net_input"]["para_targets"]) 442 | import pdb; pdb.set_trace() 443 | 444 | -------------------------------------------------------------------------------- /qa/prepro_dense.py: 
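This file precomputes, for every training question, which retrieved paragraphs contain a gold answer (plain string match, or regex match as used for the TREC data) and writes one JSON line per question; the online sampler above loads the result through `matched_para_path` and keys it by `hash_question`. A minimal sketch of inspecting such a file, assuming it was produced by `process_ground_paras` (the path below is just the example used elsewhere in this repo):
```
import json
from prepro_utils import hash_question

# Example output of process_ground_paras(); point this at your own run.
matched_file = "../data/trec_train_matched_20000.txt"

qid2goldparas = {}
for line in open(matched_file):
    item = json.loads(line)  # one QA pair per line
    # "matched_paras" maps a paragraph id to the answer text matched inside it
    qid2goldparas[hash_question(item["question"])] = item["matched_paras"]

covered = sum(len(paras) > 0 for paras in qid2goldparas.values())
print(f"{covered}/{len(qid2goldparas)} questions have an answer-bearing paragraph")
```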
-------------------------------------------------------------------------------- 1 | from prepro_utils import hash_question, normalize, find_ans_span_with_char_offsets, prepare 2 | import json 3 | from utils import DocDB 4 | from official_eval import normalize_answer 5 | import numpy as np 6 | from tqdm import tqdm 7 | from basic_tokenizer import RegexpTokenizer, SimpleTokenizer 8 | 9 | from multiprocessing import Pool as ProcessPool 10 | from multiprocessing.util import Finalize 11 | from functools import partial 12 | import re 13 | 14 | import sys 15 | from transformers import BertTokenizer 16 | 17 | PROCESS_TOK = None 18 | PROCESS_DB = None 19 | BERT_TOK = None 20 | 21 | def init(): 22 | global PROCESS_TOK, PROCESS_DB, BERT_TOK 23 | PROCESS_TOK = SimpleTokenizer() 24 | BERT_TOK = BertTokenizer.from_pretrained("bert-base-uncased") 25 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 26 | PROCESS_DB = DocDB('../data/nq_paras.db') 27 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 28 | 29 | 30 | def regex_match(text, pattern): 31 | """return all spans that match the pattern""" 32 | try: 33 | pattern = re.compile( 34 | pattern, 35 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, 36 | ) 37 | except BaseException: 38 | print('Regular expression failed to compile: %s' % pattern) 39 | return [] 40 | 41 | matched = [x.group() for x in re.finditer(pattern, text)] 42 | return list(set(matched)) 43 | 44 | def para_has_answer(p, answer, tokenizer): 45 | tokens = tokenizer.tokenize(p) 46 | text = tokens.words(uncased=True) 47 | matched = [] 48 | for single_answer in answer: 49 | single_answer = normalize(single_answer) 50 | single_answer = tokenizer.tokenize(single_answer) 51 | single_answer = single_answer.words(uncased=True) 52 | for i in range(0, len(text) - len(single_answer) + 1): 53 | if single_answer == text[i: i + len(single_answer)]: 54 | return True, tokens.slice(i, i + len(single_answer)).untokenize() 55 | return False, "" 56 | 57 | def match_answer_span(p, answer, tokenizer, match="string"): 58 | # p has been normalized 59 | if match == 'string': 60 | tokens = tokenizer.tokenize(p) 61 | text = tokens.words(uncased=True) 62 | matched = set() 63 | for single_answer in answer: 64 | single_answer = normalize(single_answer) 65 | single_answer = tokenizer.tokenize(single_answer) 66 | single_answer = single_answer.words(uncased=True) 67 | for i in range(0, len(text) - len(single_answer) + 1): 68 | if single_answer == text[i: i + len(single_answer)]: 69 | matched.add(tokens.slice(i, i + len(single_answer)).untokenize()) 70 | return list(matched) 71 | elif match == 'regex': 72 | # Answer is a regex 73 | single_answer = normalize(answer[0]) 74 | return regex_match(p, single_answer) 75 | 76 | def process_qa_para(qa_with_result, k=10000, match="string"): 77 | global PROCESS_DB, PROCESS_TOK 78 | qa, result = qa_with_result 79 | matched_paras = {} 80 | for para_id in result["para_id"][:k]: 81 | p = PROCESS_DB.get_doc_text(para_id) 82 | p = normalize(p) 83 | if match == "string": 84 | covered, matched = para_has_answer(p, qa["answer"], PROCESS_TOK) 85 | elif match == "regex": 86 | single_answer = normalize(qa["answer"][0]) 87 | matched = regex_match(p, single_answer) 88 | covered = len(matched) > 0 89 | if covered: 90 | matched_paras[para_id] = matched 91 | qa["matched_paras"] = matched_paras 92 | return qa 93 | 94 | def find_span(example): 95 | global PROCESS_DB, BERT_TOK 96 | annotated = {} 97 | for para_id, matched in example["matched_paras"].items(): 98 | p = 
normalize(PROCESS_DB.get_doc_text(para_id)) 99 | ans_starts, ans_ends, ans_texts = [], [], [] 100 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare( 101 | p, BERT_TOK) 102 | char_starts = [i for i in range( 103 | len(p)) if p.startswith(matched, i)] 104 | assert len(char_starts) > 0 105 | char_ends = [start + len(matched) - 1 for start in char_starts] 106 | answer = {"text": matched, "char_spans": list( 107 | zip(char_starts, char_ends))} 108 | ans_spans = find_ans_span_with_char_offsets( 109 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, BERT_TOK) 110 | for s, e in ans_spans: 111 | ans_starts.append(s) 112 | ans_ends.append(e) 113 | ans_texts.append(matched) 114 | annotated[para_id] = { 115 | # "doc_toks": doc_tokens, 116 | "doc_subtoks": all_doc_tokens, 117 | "starts": ans_starts, 118 | "ends": ans_ends, 119 | "span_texts": [matched], 120 | # "tok_to_orig_index": tok_to_orig_index 121 | } 122 | example["matched_paras"] = annotated 123 | return example 124 | 125 | 126 | def process_ground_paras(retrieved="../data/wq_finetuneq_train_10000.txt", save_path="../data/wq_ft_train_matched.txt", raw_data="../data/wq-train.txt", num_workers=40, debug=False, k=10000, match="string"): 127 | retrieved = [json.loads(l) for l in open(retrieved).readlines()] 128 | raw_data = [json.loads(l) for l in open(raw_data).readlines()] 129 | 130 | tokenizer = SimpleTokenizer() 131 | recall = [] 132 | processes = ProcessPool( 133 | processes=num_workers, 134 | initializer=init, 135 | ) 136 | process_qa_para_partial = partial(process_qa_para, k=k, match=match) 137 | num_tasks = len(raw_data) 138 | results = [] 139 | for _ in tqdm(processes.imap_unordered(process_qa_para_partial, zip(raw_data, retrieved)), total=len(raw_data)): 140 | results.append(_) 141 | 142 | topk_covered = [len(r["matched_paras"])>0 for r in results] 143 | print(np.mean(topk_covered)) 144 | 145 | if debug: 146 | return 147 | 148 | # # annotate those match paras, accelerate training 149 | # processed = [] 150 | # for _ in tqdm(processes.imap_unordered(find_span, results), total=len(results)): 151 | # processed.append(_) 152 | 153 | processes.close() 154 | processes.join() 155 | 156 | with open(save_path, "w") as g: 157 | for _ in results: 158 | g.write(json.dumps(_) + "\n") 159 | 160 | 161 | def debug(retrieved="../data/wq_finetuneq_dev_5000.txt", raw_data="../data/wq-dev.txt", precomputed="../data/wq_ft_dev_matched.txt", k=10): 162 | # check wether it reasonable to precompute a paragraph set 163 | retrieved = [json.loads(l) for l in open(retrieved).readlines()] 164 | raw_data = [json.loads(l) for l in open(raw_data).readlines()] 165 | 166 | annotated = [json.loads(l) for l in open(precomputed).readlines()] 167 | qid2goldparas = {hash_question(item["question"]): item["matched_paras"] for item in annotated} 168 | 169 | topk_covered = [] 170 | for qa, result in tqdm(zip(raw_data, retrieved), total=len(raw_data)): 171 | qid = hash_question(qa["question"]) 172 | covered = 0 173 | for para_id in result["para_id"][:k]: 174 | if para_id in qid2goldparas[qid]: 175 | covered = 1 176 | break 177 | topk_covered.append(covered) 178 | print(np.mean(topk_covered)) 179 | 180 | 181 | if __name__ == "__main__": 182 | 183 | # trec 184 | process_ground_paras(retrieved="../data/trec/trec_finetuneq_train-20000.txt", save_path="../data/trec_train_matched_20000.txt", raw_data="../data/trec-train.txt", num_workers=30, k=20000, match="regex") 185 | 186 | # # wq 187 | # 
process_ground_paras(retrieved="../data/wq_finetuneq_train-combined_15000.txt", 188 | # save_path="../data/wq_ft_train-combined_matched_15000.txt", raw_data="../data/wq-train-combined.txt", num_workers=30, k=15000) 189 | 190 | # nq 191 | #process_ground_paras(retrieved="../data/nq_finetuneq_train_10000.txt", 192 | # save_path="../data/nq_ft_train_matched.txt", raw_data="../data/nq-train.txt", num_workers=40) 193 | 194 | 195 | # # debug 196 | # process_ground_paras( 197 | # retrieved="../data/wq_finetuneq_dev_5000_fi.txt", raw_data="../data/wq-dev.txt", debug=True, k=5) 198 | # process_ground_paras( 199 | # retrieved="../data/wq_finetuneq_dev.txt", raw_data="../data/wq-dev.txt", debug=True, k=5) 200 | # process_ground_paras( 201 | # retrieved="../data/nq_finetuneq_dev_5000_fi.txt", raw_data="../data/nq-dev.txt", debug=True, k=5) 202 | # process_ground_paras( 203 | # retrieved="../data/nq_finetuneq_dev.txt", raw_data="../data/nq-dev.txt", debug=True, k=5) 204 | # #debug(k=30) 205 | # process_ground_paras(retrieved="../data/nq_finetuneq_train_10000.txt", 206 | # save_path="../data/nq_ft_train_matched.txt", raw_data="../data/nq-train.txt", num_workers=40) 207 | 208 | 209 | # # debug 210 | # process_ground_paras( 211 | # retrieved="../data/wq_finetuneq_dev_5000.txt", raw_data="../data/wq-dev.txt", debug=True, k=30) 212 | # debug(k=30) 213 | 214 | -------------------------------------------------------------------------------- /qa/prepro_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tokenizer import _is_whitespace, _is_punctuation, process, whitespace_tokenize 3 | from transformers import BertTokenizer 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | import hashlib 7 | import unicodedata 8 | import re 9 | import sys 10 | import numpy as np 11 | 12 | def hash_question(q): 13 | hash_object = hashlib.md5(q.encode()) 14 | return hash_object.hexdigest() 15 | 16 | def normalize(text): 17 | """Resolve different type of unicode encodings.""" 18 | return unicodedata.normalize('NFD', text) 19 | 20 | def load_mrqa_dataset(path): 21 | raw_data = [json.loads(line.strip()) for line in open(path).readlines()[1:]] 22 | 23 | qa_data = [] 24 | for item in raw_data: 25 | id_ = item["id"] 26 | context = item["context"] 27 | for qa in item["qas"]: 28 | qid = qa["qid"] 29 | question = qa["question"] 30 | answers = qa.get("answers", []) 31 | matched_answers = qa.get("detected_answers", []) 32 | qa_data.append( 33 | { 34 | "qid": qid, 35 | "question": question, 36 | "context": context, 37 | "matched_answers": matched_answers, 38 | "true_answers": answers 39 | } 40 | ) 41 | return qa_data 42 | 43 | 44 | def load_openqa_dataset(path, filter_no_answer=False): 45 | 46 | def _check_no_ans(sample): 47 | no_ans = True 48 | for para in sample["retrieved"]: 49 | if para["matched_answer"] != "": 50 | no_ans = False 51 | return no_ans 52 | return no_ans 53 | 54 | raw_data = [json.loads(line.strip()) for line in open(path).readlines()] 55 | 56 | if filter_no_answer: 57 | raw_data = [item for item in raw_data if not _check_no_ans(item)] 58 | 59 | print(f"Loading {len(raw_data)} QA pairs") 60 | return raw_data 61 | 62 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 63 | orig_answer_text): 64 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 65 | 66 | for new_start in range(input_start, input_end + 1): 67 | for new_end in range(input_end, new_start - 1, -1): 68 | text_span = " 
".join(doc_tokens[new_start:(new_end + 1)]) 69 | if text_span == tok_answer_text: 70 | return (new_start, new_end) 71 | 72 | return (input_start, input_end) 73 | 74 | def find_ans_span_with_char_offsets(detected_ans, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer): 75 | # could return mutiple spans for an answer string 76 | ans_text = detected_ans["text"] 77 | char_spans = detected_ans["char_spans"] 78 | ans_subtok_spans = [] 79 | for char_start, char_end in char_spans: 80 | tok_start = char_to_word_offset[char_start] 81 | tok_end = char_to_word_offset[char_end] # char_end points to the last char of the answer, not one after 82 | sub_tok_start = orig_to_tok_index[tok_start] 83 | 84 | if tok_end < len(doc_tokens) - 1: 85 | sub_tok_end = orig_to_tok_index[tok_end + 1] - 1 86 | else: 87 | sub_tok_end = len(all_doc_tokens) - 1 88 | 89 | actual_text = " ".join(doc_tokens[tok_start:(tok_end + 1)]) 90 | cleaned_answer_text = " ".join(whitespace_tokenize(ans_text)) 91 | if actual_text.find(cleaned_answer_text) == -1: 92 | print("Could not find answer: '{}' vs. '{}'".format( 93 | actual_text, cleaned_answer_text)) 94 | 95 | (sub_tok_start, sub_tok_end) = _improve_answer_span( 96 | all_doc_tokens, sub_tok_start, sub_tok_end, tokenizer, ans_text) 97 | ans_subtok_spans.append((sub_tok_start, sub_tok_end)) 98 | 99 | return ans_subtok_spans 100 | 101 | def tokenize_item(sample, tokenizer): 102 | doc_tokens = [] 103 | char_to_word_offset = [] 104 | prev_is_whitespace = True 105 | for c in sample["context"]: 106 | if _is_whitespace(c): 107 | prev_is_whitespace = True 108 | else: 109 | if prev_is_whitespace: 110 | doc_tokens.append(c) 111 | else: 112 | doc_tokens[-1] += c 113 | prev_is_whitespace = False 114 | char_to_word_offset.append(len(doc_tokens) - 1) 115 | 116 | orig_to_tok_index = [] 117 | tok_to_orig_index = [] 118 | all_doc_tokens = [] 119 | for (i, token) in enumerate(doc_tokens): 120 | orig_to_tok_index.append(len(all_doc_tokens)) 121 | sub_tokens = process(token, tokenizer) 122 | for sub_token in sub_tokens: 123 | tok_to_orig_index.append(i) 124 | all_doc_tokens.append(sub_token) 125 | q_sub_toks = process(sample["question"], tokenizer) 126 | 127 | # finding answer spans 128 | ans_starts, ans_ends, ans_texts = [], [], [] 129 | for answer in sample["matched_answers"]: 130 | ans_spans = find_ans_span_with_char_offsets( 131 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer) 132 | 133 | for (s, e) in ans_spans: 134 | ans_starts.append(s) 135 | ans_ends.append(e) 136 | ans_texts.append(answer["text"]) 137 | 138 | return { 139 | "q_subtoks": q_sub_toks, 140 | "qid": sample["qid"], 141 | "doc_toks": doc_tokens, 142 | "doc_subtoks": all_doc_tokens, 143 | "tok_to_orig_index": tok_to_orig_index, 144 | "starts": ans_starts, 145 | "ends": ans_ends, 146 | "span_texts": ans_texts, 147 | "true_answers": sample["true_answers"] 148 | } 149 | 150 | def prepare(context, tokenizer): 151 | doc_tokens = [] 152 | char_to_word_offset = [] 153 | prev_is_whitespace = True 154 | 155 | for c in context: 156 | if _is_whitespace(c): 157 | prev_is_whitespace = True 158 | else: 159 | if prev_is_whitespace: 160 | doc_tokens.append(c) 161 | else: 162 | doc_tokens[-1] += c 163 | prev_is_whitespace = False 164 | char_to_word_offset.append(len(doc_tokens) - 1) 165 | 166 | orig_to_tok_index = [] 167 | tok_to_orig_index = [] 168 | all_doc_tokens = [] 169 | for (i, token) in enumerate(doc_tokens): 170 | orig_to_tok_index.append(len(all_doc_tokens)) 171 | sub_tokens 
= tokenizer.tokenize(token) 172 | for sub_token in sub_tokens: 173 | tok_to_orig_index.append(i) 174 | all_doc_tokens.append(sub_token) 175 | return doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens 176 | 177 | def tokenize_item_openqa(sample, tokenizer): 178 | """ 179 | process all the retrieved paragraphs of a QA pair 180 | """ 181 | q_sub_toks = process(sample["question"], tokenizer) 182 | 183 | examples = [] 184 | for para_idx, para in enumerate(sample["retrieved"]): 185 | doc_tokens = [] 186 | char_to_word_offset = [] 187 | prev_is_whitespace = True 188 | context = normalize(para["para"]) 189 | 190 | for c in context: 191 | if _is_whitespace(c): 192 | prev_is_whitespace = True 193 | else: 194 | if prev_is_whitespace: 195 | doc_tokens.append(c) 196 | else: 197 | doc_tokens[-1] += c 198 | prev_is_whitespace = False 199 | char_to_word_offset.append(len(doc_tokens) - 1) 200 | 201 | orig_to_tok_index = [] 202 | tok_to_orig_index = [] 203 | all_doc_tokens = [] 204 | for (i, token) in enumerate(doc_tokens): 205 | orig_to_tok_index.append(len(all_doc_tokens)) 206 | sub_tokens = process(token, tokenizer) 207 | for sub_token in sub_tokens: 208 | tok_to_orig_index.append(i) 209 | all_doc_tokens.append(sub_token) 210 | 211 | # finding answer spans 212 | ans_starts, ans_ends, ans_texts = [], [], [] 213 | no_answer = 0 214 | if para["matched_answer"] == "": 215 | ans_starts.append(-1) 216 | ans_ends.append(-1) 217 | ans_texts.append("") 218 | no_answer = 1 219 | else: 220 | ans_texts.append(para["matched_answer"]) 221 | char_starts = [i for i in range( 222 | len(context)) if context.startswith(para["matched_answer"], i)] 223 | 224 | if len(char_starts) == 0: 225 | import pdb; pdb.set_trace() 226 | char_ends = [start + len(para["matched_answer"]) - 1 for start in char_starts] 227 | answer = {"text": para["matched_answer"], "char_spans": list(zip(char_starts, char_ends))} 228 | ans_spans = find_ans_span_with_char_offsets( 229 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer) 230 | for (s, e) in ans_spans: 231 | ans_starts.append(s) 232 | ans_ends.append(e) 233 | ans_texts.append(answer["text"]) 234 | qid = hash_question(sample["question"]) 235 | 236 | examples.append({ 237 | "q": sample["question"], 238 | "q_subtoks": q_sub_toks, 239 | "qid": qid, 240 | "para_id": para_idx, 241 | "doc_toks": doc_tokens, 242 | "doc_subtoks": all_doc_tokens, 243 | "tok_to_orig_index": tok_to_orig_index, 244 | "starts": ans_starts, 245 | "ends": ans_ends, 246 | "span_texts": ans_texts, 247 | "true_answers": sample["gold_answer"], 248 | "no_answer": no_answer, 249 | "bm25": para["bm25"], 250 | }) 251 | 252 | return examples 253 | 254 | def tokenize_items(items, tokenizer, verbose=False, openqa=False): 255 | if verbose: 256 | items = tqdm(items) 257 | if openqa: 258 | results = [] 259 | for _ in items: 260 | results.extend(tokenize_item_openqa(_, tokenizer)) 261 | return results 262 | else: 263 | return [tokenize_item(_, tokenizer) for _ in items] 264 | 265 | def tokenize_data(dataset, bert_model_name="bert-large-cased-whole-word-masking", num_workers=10, save_path=None, openqa=False): 266 | 267 | tokenizer = BertTokenizer.from_pretrained(bert_model_name) 268 | 269 | chunk_size = len(dataset) // num_workers 270 | offsets = [ 271 | _ * chunk_size for _ in range(0, num_workers)] + [len(dataset)] 272 | pool = Pool(processes=num_workers) 273 | print(f'Start multi-processing with {num_workers} workers....') 274 | results = 
[pool.apply_async(tokenize_items, args=( 275 | dataset[offsets[work_id]: offsets[work_id + 1]], tokenizer, True, openqa)) for work_id in range(num_workers)] 276 | outputs = [p.get() for p in results] 277 | samples = [] 278 | for o in outputs: 279 | samples.extend(o) 280 | 281 | # check the average number of matched spans 282 | answer_nums = [len(item["starts"]) 283 | for item in samples if item["no_answer"] == 0] 284 | print(f"Average number of matched answers: {np.mean(answer_nums)}...") 285 | print(f"Processed {len(samples)} examples...") 286 | if save_path: 287 | with open(save_path, 'w') as f: 288 | for s in samples: 289 | f.write(json.dumps(s) + "\n") 290 | else: 291 | return samples 292 | 293 | if __name__ == "__main__": 294 | import argparse 295 | parser = argparse.ArgumentParser() 296 | parser.add_argument( 297 | "--model-name", default="bert-base-uncased", type=str) 298 | parser.add_argument("--data", default="wq", type=str) 299 | parser.add_argument("--split", default="train", type=str) 300 | parser.add_argument("--topk", default=20, type=int) 301 | parser.add_argument("--filter", action="store_true", help="whether to filter no-answer QA pair") 302 | parser.add_argument("--dense-index", action="store_true") 303 | args = parser.parse_args() 304 | 305 | filter_ = True if "train" in args.split else False 306 | 307 | if args.dense_index: 308 | train_raw = load_openqa_dataset( 309 | f"../data/{args.data}/{args.data}-{args.split}-dense-final.txt", filter_no_answer=filter_) 310 | save_path = f"../data/{args.data}/{args.data}-{args.split}-dense-filtered-tokenized.txt" if filter_ else \ 311 | f"../data/{args.data}/{args.data}-{args.split}-dense-tokenized.txt" 312 | else: 313 | train_raw = load_openqa_dataset( 314 | f"../data/{args.data}/{args.data}-{args.split}-openqa-p{args.topk}.txt", filter_no_answer=filter_) 315 | save_path = f"../data/{args.data}/{args.data}-{args.split}-openqa-filtered-tokenized-p{args.topk}-all-matched.txt" if filter_ else \ 316 | f"../data/{args.data}/{args.data}-{args.split}-openqa-tokenized-p{args.topk}-all-matched.txt" 317 | 318 | train_tokenized = tokenize_data(train_raw, bert_model_name=args.model_name, save_path=save_path, openqa=True, num_workers=10) 319 | -------------------------------------------------------------------------------- /qa/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
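# Adapted from the original Google BERT tokenization utilities; qa/prepro_utils.py imports
# _is_whitespace, _is_punctuation, process and whitespace_tokenize from this module.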
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_tokens_to_ids(vocab, tokens): 28 | """Converts a sequence of tokens into ids using the vocab.""" 29 | ids = [] 30 | for token in tokens: 31 | ids.append(vocab[token]) 32 | return ids 33 | 34 | def whitespace_tokenize(text): 35 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 36 | text = text.strip() 37 | if not text: 38 | return [] 39 | tokens = text.split() 40 | return tokens 41 | 42 | 43 | def convert_to_unicode(text): 44 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 45 | if six.PY3: 46 | if isinstance(text, str): 47 | return text 48 | elif isinstance(text, bytes): 49 | return text.decode("utf-8", "ignore") 50 | else: 51 | raise ValueError("Unsupported string type: %s" % (type(text))) 52 | elif six.PY2: 53 | if isinstance(text, str): 54 | return text.decode("utf-8", "ignore") 55 | elif isinstance(text, unicode): 56 | return text 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | else: 60 | raise ValueError("Not running on Python2 or Python 3?") 61 | 62 | 63 | def _is_whitespace(char): 64 | """Checks whether `chars` is a whitespace character.""" 65 | # \t, \n, and \r are technically contorl characters but we treat them 66 | # as whitespace since they are generally considered as such. 67 | if char == " " or char == "\t" or char == "\n" or char == "\r": 68 | return True 69 | cat = unicodedata.category(char) 70 | if cat == "Zs": 71 | return True 72 | return False 73 | 74 | 75 | def _is_control(char): 76 | """Checks whether `chars` is a control character.""" 77 | # These are technically control characters but we count them as whitespace 78 | # characters. 79 | if char == "\t" or char == "\n" or char == "\r": 80 | return False 81 | cat = unicodedata.category(char) 82 | if cat.startswith("C"): 83 | return True 84 | return False 85 | 86 | class BasicTokenizer(object): 87 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 88 | 89 | def __init__(self, do_lower_case=True): 90 | """Constructs a BasicTokenizer. 91 | Args: 92 | do_lower_case: Whether to lower case the input. 
93 | """ 94 | self.do_lower_case = do_lower_case 95 | 96 | def tokenize(self, text): 97 | """Tokenizes a piece of text.""" 98 | text = convert_to_unicode(text) 99 | text = self._clean_text(text) 100 | orig_tokens = whitespace_tokenize(text) 101 | split_tokens = [] 102 | for token in orig_tokens: 103 | if self.do_lower_case: 104 | token = token.lower() 105 | token = self._run_strip_accents(token) 106 | split_tokens.extend(self._run_split_on_punc(token)) 107 | 108 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 109 | return output_tokens 110 | 111 | def _run_strip_accents(self, text): 112 | """Strips accents from a piece of text.""" 113 | text = unicodedata.normalize("NFD", text) 114 | output = [] 115 | for char in text: 116 | cat = unicodedata.category(char) 117 | if cat == "Mn": 118 | continue 119 | output.append(char) 120 | return "".join(output) 121 | 122 | def _run_split_on_punc(self, text): 123 | """Splits punctuation on a piece of text.""" 124 | chars = list(text) 125 | i = 0 126 | start_new_word = True 127 | output = [] 128 | while i < len(chars): 129 | char = chars[i] 130 | if _is_punctuation(char): 131 | output.append([char]) 132 | start_new_word = True 133 | else: 134 | if start_new_word: 135 | output.append([]) 136 | start_new_word = False 137 | output[-1].append(char) 138 | i += 1 139 | 140 | return ["".join(x) for x in output] 141 | 142 | def _clean_text(self, text): 143 | """Performs invalid character removal and whitespace cleanup on text.""" 144 | output = [] 145 | for char in text: 146 | cp = ord(char) 147 | if cp == 0 or cp == 0xfffd or _is_control(char): 148 | continue 149 | if _is_whitespace(char): 150 | output.append(" ") 151 | else: 152 | output.append(char) 153 | return "".join(output) 154 | 155 | 156 | def _is_punctuation(char): 157 | """Checks whether `chars` is a punctuation character.""" 158 | cp = ord(char) 159 | # We treat all non-letter/number ASCII as punctuation. 160 | # Characters such as "^", "$", and "`" are not in the Unicode 161 | # Punctuation class but we treat them as punctuation anyways, for 162 | # consistency. 
163 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 164 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 165 | return True 166 | cat = unicodedata.category(char) 167 | if cat.startswith("P"): 168 | return True 169 | return False 170 | 171 | 172 | def process(s, tokenizer): 173 | try: 174 | return tokenizer.tokenize(s) 175 | except: 176 | print('failed on', s) 177 | raise 178 | 179 | if __name__ == "__main__": 180 | _is_whitespace("a") 181 | -------------------------------------------------------------------------------- /qa/train.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import json 4 | import os 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | from datasets import QADataset, collate 12 | from bert_qa import BertForQuestionAnswering 13 | from transformers import AdamW, BertConfig, BertTokenizer 14 | from torch.utils.tensorboard import SummaryWriter 15 | from eval_utils import get_final_text 16 | from official_eval import metric_max_over_ground_truths, f1_score, exact_match_score 17 | 18 | from utils import move_to_cuda, convert_to_half, AverageMeter 19 | from config import get_args 20 | 21 | def main(): 22 | args = get_args() 23 | 24 | if args.fp16: 25 | try: 26 | import apex 27 | apex.amp.register_half_function(torch, 'einsum') 28 | except ImportError: 29 | raise ImportError( 30 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 31 | 32 | # tb logger 33 | data_name = args.train_file.split("/")[-1].split('-')[0] 34 | model_name = f"{data_name}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-{args.prefix}-lr{args.learning_rate}-{args.bert_model_name}" 35 | tb_logger = SummaryWriter(os.path.join(args.output_dir, "tflogs", model_name)) 36 | args.output_dir = os.path.join(args.output_dir, model_name) 37 | 38 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 39 | print(f"output directory {args.output_dir} already exists and is not empty.") 40 | if not os.path.exists(args.output_dir): 41 | os.makedirs(args.output_dir, exist_ok=True) 42 | 43 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 44 | datefmt='%m/%d/%Y %H:%M:%S', 45 | level=logging.INFO, 46 | handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")), 47 | logging.StreamHandler()]) 48 | logger = logging.getLogger(__name__) 49 | logger.info(args) 50 | 51 | if args.local_rank == -1 or args.no_cuda: 52 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 53 | n_gpu = torch.cuda.device_count() 54 | else: 55 | device = torch.device("cuda", args.local_rank) 56 | n_gpu = 1 57 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 58 | torch.distributed.init_process_group(backend='nccl') 59 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 60 | 61 | if args.accumulate_gradients < 1: 62 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 63 | args.accumulate_gradients)) 64 | 65 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) 66 | random.seed(args.seed) 67 | np.random.seed(args.seed) 68 | torch.manual_seed(args.seed) 69 | if n_gpu > 0: 70 | torch.cuda.manual_seed_all(args.seed) 71 | 72 | 
if not args.do_train and not args.do_predict: 73 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 74 | 75 | if args.do_train: 76 | if not args.train_file: 77 | raise ValueError( 78 | "If `do_train` is True, then `train_file` must be specified.") 79 | if not args.predict_file: 80 | raise ValueError( 81 | "If `do_train` is True, then `predict_file` must be specified.") 82 | 83 | if args.do_predict: 84 | if not args.predict_file: 85 | raise ValueError( 86 | "If `do_predict` is True, then `predict_file` must be specified.") 87 | 88 | bert_config = BertConfig.from_pretrained(args.bert_model_name) 89 | model = BertForQuestionAnswering(bert_config) 90 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name) 91 | 92 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings: 93 | raise ValueError( 94 | "Cannot use sequence length %d because the BERT model " 95 | "was only trained up to sequence length %d" % 96 | (args.max_seq_length, bert_config.max_position_embeddings)) 97 | 98 | eval_dataset = QADataset( 99 | tokenizer, args.predict_file, args.max_query_length, args.max_seq_length) 100 | eval_dataloader = DataLoader(eval_dataset, batch_size=args.predict_batch_size, collate_fn=collate, pin_memory=True) 101 | logger.info(f"Num of dev batches: {len(eval_dataloader)}") 102 | 103 | if args.init_checkpoint is not None: 104 | logger.info("Loading from {}".format(args.init_checkpoint)) 105 | if args.do_train and args.init_checkpoint == "": 106 | model = BertForQuestionAnswering.from_pretrained( 107 | args.bert_model_name) 108 | else: 109 | state_dict = torch.load(args.init_checkpoint) 110 | filter = lambda x: x[7:] if x.startswith('module.') else x 111 | state_dict = {filter(k):v for (k,v) in state_dict.items()} 112 | model.load_state_dict(state_dict) 113 | model.to(device) 114 | 115 | print(f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 116 | 117 | if args.do_train: 118 | no_decay = ['bias', 'LayerNorm.weight'] 119 | optimizer_parameters = [ 120 | {'params': [p for n, p in model.named_parameters() if not any( 121 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 122 | {'params': [p for n, p in model.named_parameters() if any( 123 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 124 | ] 125 | optimizer = AdamW(optimizer_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 126 | 127 | if args.fp16: 128 | try: 129 | from apex import amp 130 | except ImportError: 131 | raise ImportError( 132 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 133 | model, optimizer = amp.initialize( 134 | model, optimizer, opt_level=args.fp16_opt_level) 135 | else: 136 | if args.fp16: 137 | try: 138 | from apex import amp 139 | except ImportError: 140 | raise ImportError( 141 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 142 | model = amp.initialize(model, opt_level=args.fp16_opt_level) 143 | 144 | if args.local_rank != -1: 145 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 146 | output_device=args.local_rank) 147 | elif n_gpu > 1: 148 | model = torch.nn.DataParallel(model) 149 | 150 | if args.do_train: 151 | global_step = 0 152 | best_f1 = (-1, -1) 153 | wait_step = 0 154 | stop_training = False 155 | train_loss_meter = AverageMeter() 156 | logger.info('Start training....') 157 | model.train() 158 | train_dataset = QADataset(tokenizer, args.train_file, 
args.max_query_length, args.max_seq_length) 159 | train_dataloader = DataLoader( 160 | train_dataset, batch_size=args.train_batch_size, collate_fn=collate, shuffle=True, pin_memory=True) 161 | 162 | for epoch in range(int(args.num_train_epochs)): 163 | 164 | for step, batch in enumerate(tqdm(train_dataloader)): 165 | batch = move_to_cuda(batch) 166 | outputs = model(batch["net_input"]) 167 | loss = outputs["span_loss"] 168 | 169 | if n_gpu > 1: 170 | loss = loss.mean() # mean() to average on multi-gpu. 171 | 172 | if args.gradient_accumulation_steps > 1: 173 | loss = loss / args.gradient_accumulation_steps 174 | 175 | if args.fp16: 176 | with amp.scale_loss(loss, optimizer) as scaled_loss: 177 | scaled_loss.backward() 178 | else: 179 | loss.backward() 180 | 181 | train_loss_meter.update(loss.item()) 182 | tb_logger.add_scalar('batch_train_loss', loss.item(), global_step) 183 | 184 | if (step + 1) % args.gradient_accumulation_steps == 0: 185 | if args.fp16: 186 | torch.nn.utils.clip_grad_norm_( 187 | amp.master_params(optimizer), args.max_grad_norm) 188 | else: 189 | torch.nn.utils.clip_grad_norm_( 190 | model.parameters(), args.max_grad_norm) 191 | optimizer.step() # We have accumulated enought gradients 192 | model.zero_grad() 193 | global_step += 1 194 | 195 | if global_step % args.eval_period == 0: 196 | f1 = predict(logger, args, model, eval_dataloader, device, fp16=args.efficient_eval) 197 | logger.info("Step %d Train loss %.2f EM %.2f F1 %.2f on epoch=%d" % ( 198 | global_step, train_loss_meter.avg, f1[0]*100, f1[1]*100, epoch)) 199 | 200 | tb_logger.add_scalar('dev_f1', f1[0]*100, global_step) 201 | tb_logger.add_scalar('dev_em', f1[1]*100, global_step) 202 | 203 | if best_f1 < f1: 204 | logger.info("Saving model with best EM: %.2f (F1 %.2f) -> %.2f (F1 %.2f) on epoch=%d" % \ 205 | (best_f1[1]*100, best_f1[0]*100, f1[1]*100, f1[0]*100, epoch)) 206 | model_state_dict = {k:v.cpu() for (k, v) in model.state_dict().items()} 207 | torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt")) 208 | model = model.to(device) 209 | best_f1 = f1 210 | wait_step = 0 211 | stop_training = False 212 | else: 213 | wait_step += 1 214 | if wait_step == args.wait_step: 215 | stop_training = True 216 | 217 | f1 = predict(logger, args, model, eval_dataloader, 218 | device, fp16=args.efficient_eval) 219 | logger.info("Step %d Train loss %.2f EM %.2f F1 %.2f on epoch=%d" % ( 220 | global_step, train_loss_meter.avg, f1[0]*100, f1[1]*100, epoch)) 221 | tb_logger.add_scalar('dev_f1', f1[0]*100, global_step) 222 | tb_logger.add_scalar('dev_em', f1[1]*100, global_step) 223 | logger.info(f"average training loss {train_loss_meter.avg}") 224 | 225 | 226 | if stop_training: 227 | break 228 | 229 | logger.info("Training finished!") 230 | 231 | # elif args.do_predict: 232 | # if type(model)==list: 233 | # model = [m.eval() for m in model] 234 | # else: 235 | # model.eval() 236 | # f1 = predict(logger, args, model, eval_dataloader, eval_examples, eval_features, 237 | # device, fp16=args.efficient_eval, write_prediction=False) 238 | # logger.info(f"test performance {f1}") 239 | # print(f1) 240 | 241 | 242 | def predict(logger, args, model, eval_dataloader, device, fp16=False): 243 | model.eval() 244 | all_results = [] 245 | 246 | if fp16: 247 | model.half() 248 | 249 | qid2results = {} 250 | for batch in tqdm(eval_dataloader): 251 | batch_to_feed = move_to_cuda(batch["net_input"]) 252 | if fp16: 253 | batch_to_feed = convert_to_half(batch_to_feed) 254 | with torch.no_grad(): 255 | results = 
model(batch_to_feed) 256 | batch_start_logits = results["start_logits"] 257 | batch_end_logits = results["end_logits"] 258 | question_mask = batch_to_feed["paragraph_mask"].ne(1) 259 | outs = [o.float().masked_fill(question_mask, -1e10).type_as(o) 260 | for o in [batch_start_logits, batch_end_logits]] 261 | 262 | span_scores = outs[0][:,:,None] + outs[1][:,None] 263 | max_answer_lens = 20 264 | max_seq_len = span_scores.size(1) 265 | span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), max_answer_lens) 266 | span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask)) 267 | span_scores_masked = span_scores.float().masked_fill((1 - 268 | span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores) 269 | 270 | start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1] 271 | end_position = span_scores_masked.max(dim=2)[1].gather(1, start_position.unsqueeze(1)).squeeze(1) 272 | 273 | para_offset = batch['para_offset'] 274 | start_position_ = list(np.array(start_position.tolist()) - np.array(para_offset)) 275 | end_position_ = list(np.array(end_position.tolist()) - np.array(para_offset)) 276 | 277 | for idx, qid in enumerate(batch['id']): 278 | start = start_position_[idx] 279 | end = end_position_[idx] 280 | tok_to_orig_index = batch['tok_to_orig_index'][idx] 281 | doc_tokens = batch['doc_tokens'][idx] 282 | wp_tokens = batch['wp_tokens'][idx] 283 | orig_doc_start = tok_to_orig_index[start] 284 | orig_doc_end = tok_to_orig_index[end] 285 | orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)] 286 | tok_tokens = wp_tokens[start:end+1] 287 | tok_text = " ".join(tok_tokens) 288 | tok_text = tok_text.replace(" ##", "") 289 | tok_text = tok_text.replace("##", "") 290 | tok_text = tok_text.strip() 291 | tok_text = " ".join(tok_text.split()) 292 | orig_text = " ".join(orig_tokens) 293 | final_text = get_final_text(tok_text, orig_text, logger, do_lower_case=args.do_lower_case, verbose_logging=False) 294 | qid2results[qid] = [final_text, batch['true_answers'][idx]] 295 | 296 | f1s = [metric_max_over_ground_truths(f1_score, item[0], item[1]) for item in qid2results.values()] 297 | ems = [metric_max_over_ground_truths(exact_match_score, item[0], item[1]) for item in qid2results.values()] 298 | 299 | print(f"evaluated {len(f1s)} examples...") 300 | if fp16: 301 | model.float() 302 | model.train() 303 | 304 | return (np.mean(f1s), np.mean(ems)) 305 | 306 | 307 | if __name__ == "__main__": 308 | main() 309 | -------------------------------------------------------------------------------- /qa/train_dense_qa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=3 python train_retrieve_qa.py \ 5 | --do_train \ 6 | --prefix dense-index-trec-nocluser-70k \ 7 | --eval_period -1 \ 8 | --bert_model_name bert-base-uncased \ 9 | --train_batch_size 5 \ 10 | --gradient_accumulation_steps 1 \ 11 | --accumulate_gradients 1 \ 12 | --efficient_eval \ 13 | --learning_rate 1e-5 \ 14 | --fp16 \ 15 | --raw-train-data ../data/trec-train.txt \ 16 | --raw-eval-data ../data/trec-dev.txt \ 17 | --seed 3 \ 18 | --retriever-path ../retrieval/logs/retrieve_train.txt-seed31-bsz640-fp16True-baseline_no_cluster_from_failed_continue-lr1e-05-bert-base-uncased-filterTrue/checkpoint_40000.pt \ 19 | --index-path ../retrieval/encodings/para_embed.npy \ 20 | --fix-para-encoder \ 21 | --num_train_epochs 10 \ 22 | --matched-para-path ../data/trec_train_matched_20000.txt \ 23 | --regex \ 24 | 
--shared-norm \ 25 | # --separate \ 26 | -------------------------------------------------------------------------------- /qa/train_retrieve_qa.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import json 4 | import os 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from copy import deepcopy 10 | 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from bert_retrieve_qa import BertRetrieveQA 14 | from transformers import AdamW, BertConfig, BertTokenizer 15 | from torch.utils.tensorboard import SummaryWriter 16 | from eval_utils import get_final_text 17 | from official_eval import metric_max_over_ground_truths, exact_match_score, regex_match_score 18 | from online_sampler import OnlineSampler 19 | 20 | 21 | from utils import move_to_cuda, convert_to_half, AverageMeter, DocDB 22 | from config import get_args 23 | 24 | from collections import defaultdict, namedtuple 25 | import torch.nn.functional as F 26 | 27 | def load_saved(model, path): 28 | state_dict = torch.load(path) 29 | def filter(x): return x[7:] if x.startswith('module.') else x 30 | state_dict = {filter(k): v for (k, v) in state_dict.items()} 31 | model.load_state_dict(state_dict) 32 | return model 33 | 34 | 35 | def main(): 36 | args = get_args() 37 | 38 | if args.fp16: 39 | try: 40 | import apex 41 | apex.amp.register_half_function(torch, 'einsum') 42 | except ImportError: 43 | raise ImportError( 44 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 45 | 46 | # tb logger 47 | data_name = args.train_file.split("/")[-1].split('-')[0] 48 | model_name = f"dense-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-{args.prefix}-lr{args.learning_rate}-{args.bert_model_name}-qdrop{args.qa_drop}-sn{args.shared_norm}-sep{args.separate}-as{args.add_select}-noearly{args.drop_early}" 49 | if args.do_train: 50 | tb_logger = SummaryWriter(os.path.join( 51 | args.output_dir, "tflogs", "dense", model_name)) 52 | args.output_dir = os.path.join(args.output_dir, model_name) 53 | 54 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 55 | print( 56 | f"output directory {args.output_dir} already exists and is not empty.") 57 | if not os.path.exists(args.output_dir): 58 | os.makedirs(args.output_dir, exist_ok=True) 59 | 60 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 61 | datefmt='%m/%d/%Y %H:%M:%S', 62 | level=logging.INFO, 63 | handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")), 64 | logging.StreamHandler()]) 65 | logger = logging.getLogger(__name__) 66 | logger.info(args) 67 | 68 | if args.local_rank == -1 or args.no_cuda: 69 | device = torch.device( 70 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 71 | n_gpu = torch.cuda.device_count() 72 | else: 73 | device = torch.device("cuda", args.local_rank) 74 | n_gpu = 1 75 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 76 | torch.distributed.init_process_group(backend='nccl') 77 | logger.info("device %s n_gpu %d distributed training %r", 78 | device, n_gpu, bool(args.local_rank != -1)) 79 | 80 | if args.accumulate_gradients < 1: 81 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 82 | args.accumulate_gradients)) 83 | 84 | args.train_batch_size = int( 85 | args.train_batch_size / 
args.accumulate_gradients) 86 | random.seed(args.seed) 87 | np.random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | if n_gpu > 0: 90 | torch.cuda.manual_seed_all(args.seed) 91 | 92 | if not args.do_train and not args.do_predict: 93 | raise ValueError( 94 | "At least one of `do_train` or `do_predict` must be True.") 95 | 96 | if args.do_train: 97 | if not args.train_file: 98 | raise ValueError( 99 | "If `do_train` is True, then `train_file` must be specified.") 100 | if not args.predict_file: 101 | raise ValueError( 102 | "If `do_train` is True, then `predict_file` must be specified.") 103 | 104 | if args.do_predict: 105 | if not args.predict_file: 106 | raise ValueError( 107 | "If `do_predict` is True, then `predict_file` must be specified.") 108 | 109 | bert_config = BertConfig.from_pretrained(args.bert_model_name) 110 | model = BertRetrieveQA(bert_config, args) 111 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name) 112 | 113 | logger.info("Loading para db and pretrained index ...") 114 | para_db = DocDB(args.db_path) 115 | para_embed = np.load(args.index_path).astype('float32') 116 | 117 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings: 118 | raise ValueError( 119 | "Cannot use sequence length %d because the BERT model " 120 | "was only trained up to sequence length %d" % 121 | (args.max_seq_length, bert_config.max_position_embeddings)) 122 | 123 | exact_search = True if args.do_predict else False 124 | eval_dataloader = OnlineSampler(args.raw_eval_data, tokenizer, args.max_query_length, 125 | args.max_seq_length, para_db, para_embed, exact_search=exact_search, cased=args.use_spanbert, regex=args.regex) 126 | 127 | if args.init_checkpoint != "": 128 | model = load_saved(model, args.init_checkpoint) 129 | 130 | model.to(device) 131 | logger.info( 132 | f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 133 | 134 | if args.fix_para_encoder: 135 | model.freeze_c_encoder() 136 | 137 | if args.do_train: 138 | no_decay = ['bias', 'LayerNorm.weight'] 139 | optimizer_parameters = [ 140 | {'params': [p for n, p in model.named_parameters() if not any( 141 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 142 | {'params': [p for n, p in model.named_parameters() if any( 143 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 144 | ] 145 | optimizer = AdamW(optimizer_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 146 | 147 | if args.fp16: 148 | try: 149 | from apex import amp 150 | except ImportError: 151 | raise ImportError( 152 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 153 | model, optimizer = amp.initialize( 154 | model, optimizer, opt_level=args.fp16_opt_level) 155 | else: 156 | if args.fp16: 157 | try: 158 | from apex import amp 159 | except ImportError: 160 | raise ImportError( 161 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 162 | model = amp.initialize(model, opt_level=args.fp16_opt_level) 163 | 164 | if args.local_rank != -1: 165 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 166 | output_device=args.local_rank) 167 | elif n_gpu > 1: 168 | model = torch.nn.DataParallel(model) 169 | 170 | if args.do_train: 171 | global_step = 0 # gradient update step 172 | batch_step = 0 # forward batch count 173 | best_em = 0 174 | wait_step = 0 175 | stop_training = False 176 | train_loss_meter = AverageMeter() 177 | logger.info('Start 
training....') 178 | model.train() 179 | train_dataloader = OnlineSampler( 180 | args.raw_train_data, tokenizer, args.max_query_length, args.max_seq_length, para_db, para_embed, matched_para_path=args.matched_para_path, cased=args.use_spanbert, regex=args.regex) 181 | for epoch in range(int(args.num_train_epochs)): 182 | train_dataloader.shuffle() 183 | failed_retrival = 0 184 | for batch in tqdm(train_dataloader.load(model.retriever, k=args.train_batch_size), total=len(train_dataloader)): 185 | batch_step += 1 186 | if batch == {}: 187 | failed_retrival += 1 188 | continue 189 | batch = move_to_cuda(batch) 190 | outputs = model(batch["net_input"]) 191 | loss = outputs["loss"] 192 | if n_gpu > 1: 193 | loss = loss.mean() # mean() to average on multi-gpu. 194 | if args.gradient_accumulation_steps > 1: 195 | loss = loss / args.gradient_accumulation_steps 196 | if args.fp16: 197 | with amp.scale_loss(loss, optimizer) as scaled_loss: 198 | scaled_loss.backward() 199 | else: 200 | loss.backward() 201 | 202 | train_loss_meter.update(loss.item()) 203 | tb_logger.add_scalar('batch_train_loss', 204 | loss.item(), global_step) 205 | tb_logger.add_scalar('smoothed_train_loss', 206 | train_loss_meter.avg, global_step) 207 | 208 | if (batch_step + 1) % args.gradient_accumulation_steps == 0: 209 | if args.fp16: 210 | torch.nn.utils.clip_grad_norm_( 211 | amp.master_params(optimizer), args.max_grad_norm) 212 | else: 213 | torch.nn.utils.clip_grad_norm_( 214 | model.parameters(), args.max_grad_norm) 215 | optimizer.step() # We have accumulated enought gradients 216 | model.zero_grad() 217 | global_step += 1 218 | 219 | if args.eval_period != -1 and global_step % args.eval_period == 0: 220 | em = predict(args, model, eval_dataloader, 221 | device, fp16=args.efficient_eval) 222 | logger.info("Step %d Train loss %.2f EM %.2f on epoch=%d" % ( 223 | global_step, train_loss_meter.avg, em*100, epoch)) 224 | 225 | tb_logger.add_scalar('dev_em', em*100, global_step) 226 | 227 | if best_em < em: 228 | logger.info("Saving model with best EM: %.2f -> EM %.2f on epoch=%d" % 229 | (best_em*100, em*100, epoch)) 230 | model_state_dict = {k: v.cpu() for ( 231 | k, v) in model.state_dict().items()} 232 | torch.save(model_state_dict, os.path.join( 233 | args.output_dir, "best-model.pt")) 234 | model = model.to(device) 235 | best_em = em 236 | wait_step = 0 237 | stop_training = False 238 | else: 239 | wait_step += 1 240 | if wait_step == args.wait_step: 241 | stop_training = True 242 | 243 | logger.info(f"Failed retrieval: {failed_retrival}/{len(train_dataloader)} ...") 244 | em = predict(args, model, eval_dataloader, 245 | device, fp16=args.efficient_eval) 246 | tb_logger.add_scalar('dev_em', em*100, global_step) 247 | logger.info(f"average training loss {train_loss_meter.avg}") 248 | if best_em < em: 249 | logger.info("Saving model with best EM: %.2f -> %.2f on epoch=%d" % 250 | (best_em*100, em*100, epoch)) 251 | torch.save(model.state_dict(), os.path.join( 252 | args.output_dir, "best-model.pt")) 253 | model = model.to(device) 254 | best_em = em 255 | wait_step = 0 256 | 257 | if epoch > 15: 258 | logger.info(f"Saving model after epoch {epoch + 1}") 259 | torch.save(model.state_dict(), os.path.join( 260 | args.output_dir, f"model-{epoch+1}-{em}.pt")) 261 | 262 | if stop_training: 263 | break 264 | 265 | logger.info("Training finished!") 266 | 267 | elif args.do_predict: 268 | f1 = predict(args, model, eval_dataloader, 269 | device, fp16=args.efficient_eval) 270 | logger.info(f"test performance {f1}") 271 | print(f1) 272 
| 273 | 274 | def predict(args, model, eval_dataloader, device, fp16=False): 275 | model.eval() 276 | if fp16: 277 | model.half() 278 | 279 | all_results = [] 280 | PredictionMeta = collections.namedtuple( 281 | "Prediction", ["text", "rank_score", "passage", "span_score", "question"]) 282 | qid2results = defaultdict(list) 283 | qid2ground = {} 284 | 285 | for batch in tqdm(eval_dataloader.eval_load(model.retriever, args.eval_k), total=len(eval_dataloader)): 286 | 287 | batch_to_feed = move_to_cuda(batch["net_input"]) 288 | if fp16: 289 | batch_to_feed = convert_to_half(batch_to_feed) 290 | with torch.no_grad(): 291 | results = model(batch_to_feed) 292 | batch_start_logits = results["start_logits"] 293 | batch_end_logits = results["end_logits"] 294 | batch_rank_logits = results["rank_logits"] 295 | if args.add_select: 296 | batch_select_logits = results["select_logits"] 297 | 298 | outs = [batch_start_logits, batch_end_logits] 299 | 300 | span_scores = outs[0][:, :, None] + outs[1][:, None] 301 | max_answer_lens = 10 302 | max_seq_len = span_scores.size(1) 303 | span_mask = np.tril( 304 | np.triu(np.ones((max_seq_len, max_seq_len)), 0), max_answer_lens) 305 | span_mask = span_scores.data.new( 306 | max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask)) 307 | span_scores_masked = span_scores.float().masked_fill((1 - 308 | span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores) 309 | 310 | start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1] 311 | end_position = span_scores_masked.max(dim=2)[1].gather( 312 | 1, start_position.unsqueeze(1)).squeeze(1) 313 | 314 | answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist() 315 | 316 | if args.add_select: 317 | rank_logits = batch_select_logits.view(-1).tolist() 318 | else: 319 | rank_logits = batch_rank_logits.view(-1).tolist() 320 | 321 | para_offset = batch['para_offset'] 322 | start_position_ = list( 323 | np.array(start_position.tolist()) - np.array(para_offset)) 324 | end_position_ = list( 325 | np.array(end_position.tolist()) - np.array(para_offset)) 326 | 327 | for idx, qid in enumerate(batch['id']): 328 | start = start_position_[idx] 329 | end = end_position_[idx] 330 | rank_score = rank_logits[idx] 331 | span_score = answer_scores[idx] 332 | tok_to_orig_index = batch['tok_to_orig_index'][idx] 333 | doc_tokens = batch['doc_tokens'][idx] 334 | wp_tokens = batch['wp_tokens'][idx] 335 | orig_doc_start = tok_to_orig_index[start] 336 | orig_doc_end = tok_to_orig_index[end] 337 | orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)] 338 | tok_tokens = wp_tokens[start:end+1] 339 | tok_text = " ".join(tok_tokens) 340 | tok_text = tok_text.replace(" ##", "") 341 | tok_text = tok_text.replace("##", "") 342 | tok_text = tok_text.strip() 343 | tok_text = " ".join(tok_text.split()) 344 | orig_text = " ".join(orig_tokens) 345 | final_text = get_final_text( 346 | tok_text, orig_text, do_lower_case=args.do_lower_case, verbose_logging=False) 347 | question = batch["q"][idx] 348 | qid2results[qid].append( 349 | PredictionMeta( 350 | text=final_text, 351 | rank_score=rank_score, 352 | span_score=span_score, 353 | passage=" ".join(doc_tokens), 354 | question=question, 355 | ) 356 | ) 357 | qid2ground[qid] = batch["true_answers"][idx] 358 | 359 | if args.save_all: 360 | print("Saving all prediction results ...") 361 | with open(f"{args.prefix}_all.json", "w") as g: 362 | json.dump(qid2results, g) 363 | with open(f"{args.prefix}_ground.json", "w") as g: 364 | json.dump(qid2ground, g) 365 | 366 | ## 
linear combination tuning on dev data 367 | best_em = 0 368 | for alpha in [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9, 1]: 369 | results_to_save = [] 370 | ems = [] 371 | for qid in qid2results.keys(): 372 | qid2results[qid] = sorted( 373 | qid2results[qid], key=lambda x: alpha*x.span_score + (1 - alpha)*x.rank_score, reverse=True) 374 | match_fn = regex_match_score if args.regex else exact_match_score 375 | ems.append(metric_max_over_ground_truths( 376 | match_fn, qid2results[qid][0].text, qid2ground[qid])) 377 | results_to_save.append({ 378 | "question": qid2results[qid][0].question, 379 | "para": qid2results[qid][0].passage, 380 | "answer": qid2results[qid][0].text, 381 | "rank_score": qid2results[qid][0].rank_score, 382 | "gold": qid2ground[qid], 383 | "em": ems[-1] 384 | }) 385 | em = np.mean(ems) 386 | if em > best_em: 387 | best_em = em 388 | print(f"evaluated {len(ems)} examples...") 389 | print(f"alpha: {alpha}; avg. EM: {em}") 390 | 391 | if args.save_pred: 392 | with open(f"{args.prefix}_{alpha}.json", "w") as g: 393 | for line in results_to_save: 394 | g.write(json.dumps(line) + "\n") 395 | 396 | if type(model) != list: 397 | if fp16: 398 | model.float() 399 | model.train() 400 | 401 | return best_em 402 | 403 | 404 | if __name__ == "__main__": 405 | main() 406 | 407 | -------------------------------------------------------------------------------- /qa/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import sqlite3 3 | import torch 4 | import unicodedata 5 | 6 | def move_to_cuda(sample): 7 | if len(sample) == 0: 8 | return {} 9 | 10 | def _move_to_cuda(maybe_tensor): 11 | if torch.is_tensor(maybe_tensor): 12 | return maybe_tensor.cuda() 13 | elif isinstance(maybe_tensor, dict): 14 | return { 15 | key: _move_to_cuda(value) 16 | for key, value in maybe_tensor.items() 17 | } 18 | elif isinstance(maybe_tensor, list): 19 | return [_move_to_cuda(x) for x in maybe_tensor] 20 | else: 21 | return maybe_tensor 22 | 23 | return _move_to_cuda(sample) 24 | 25 | def convert_to_half(sample): 26 | if len(sample) == 0: 27 | return {} 28 | 29 | def _convert_to_half(maybe_floatTensor): 30 | if torch.is_tensor(maybe_floatTensor) and maybe_floatTensor.type() == "torch.FloatTensor": 31 | return maybe_floatTensor.half() 32 | elif isinstance(maybe_floatTensor, dict): 33 | return { 34 | key: _convert_to_half(value) 35 | for key, value in maybe_floatTensor.items() 36 | } 37 | elif isinstance(maybe_floatTensor, list): 38 | return [_convert_to_half(x) for x in maybe_floatTensor] 39 | else: 40 | return maybe_floatTensor 41 | 42 | return _convert_to_half(sample) 43 | 44 | 45 | class AverageMeter(object): 46 | """Computes and stores the average and current value""" 47 | 48 | def __init__(self): 49 | self.reset() 50 | 51 | def reset(self): 52 | self.val = 0 53 | self.avg = 0 54 | self.sum = 0 55 | self.count = 0 56 | 57 | def update(self, val, n=1): 58 | self.val = val 59 | self.sum += val * n 60 | self.count += n 61 | self.avg = self.sum / self.count 62 | 63 | 64 | def normalize(text): 65 | """Resolve different type of unicode encodings.""" 66 | return unicodedata.normalize('NFD', text) 67 | 68 | 69 | def load_saved(model, path): 70 | state_dict = torch.load(path) 71 | def filter(x): return x[7:] if x.startswith('module.') else x 72 | state_dict = {filter(k): v for (k, v) in state_dict.items()} 73 | model.load_state_dict(state_dict) 74 | return model 75 | 76 | class DocDB(object): 77 | """Sqlite backed document storage. 
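    Connections are opened with check_same_thread=False, so a single DocDB
    instance can be queried from multiple threads.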
78 | 79 | Implements get_doc_text(doc_id). 80 | """ 81 | 82 | def __init__(self, db_path=None): 83 | self.path = db_path 84 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 85 | 86 | def __enter__(self): 87 | return self 88 | 89 | def __exit__(self, *args): 90 | self.close() 91 | 92 | 93 | def close(self): 94 | """Close the connection to the database.""" 95 | self.connection.close() 96 | 97 | def get_doc_ids(self): 98 | """Fetch all ids of docs stored in the db.""" 99 | cursor = self.connection.cursor() 100 | cursor.execute("SELECT id FROM documents") 101 | results = [r[0] for r in cursor.fetchall()] 102 | cursor.close() 103 | return results 104 | 105 | def get_doc_text(self, doc_id): 106 | """Fetch the raw text of the doc for 'doc_id'.""" 107 | cursor = self.connection.cursor() 108 | cursor.execute( 109 | "SELECT text FROM documents WHERE id = ?", 110 | (normalize(doc_id),) 111 | ) 112 | result = cursor.fetchone() 113 | cursor.close() 114 | return result if result is None else result[0] 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fairseq==0.9.0 2 | faiss-cpu==1.6.3 3 | gdown==3.10.3 4 | joblib==0.13.2 5 | tensorboard==2.0.2 6 | tensorboardX==2.0 7 | tensorflow-estimator==2.0.1 8 | tensorflow-gpu==2.0.1 9 | torch==1.4.0 10 | torchvision==0.5.0 11 | tqdm==4.36.1 12 | transformers==2.5.1 13 | -------------------------------------------------------------------------------- /retrieval/basic_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | 13 | class Tokens(object): 14 | """A class to represent a list of tokenized text.""" 15 | TEXT = 0 16 | TEXT_WS = 1 17 | SPAN = 2 18 | POS = 3 19 | LEMMA = 4 20 | NER = 5 21 | 22 | def __init__(self, data, annotators, opts=None): 23 | self.data = data 24 | self.annotators = annotators 25 | self.opts = opts or {} 26 | 27 | def __len__(self): 28 | """The number of tokens.""" 29 | return len(self.data) 30 | 31 | def slice(self, i=None, j=None): 32 | """Return a view of the list of tokens from [i, j).""" 33 | new_tokens = copy.copy(self) 34 | new_tokens.data = self.data[i: j] 35 | return new_tokens 36 | 37 | def untokenize(self): 38 | """Returns the original text (with whitespace reinserted).""" 39 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 40 | 41 | def words(self, uncased=False): 42 | """Returns a list of the text of each token 43 | 44 | Args: 45 | uncased: lower cases text 46 | """ 47 | if uncased: 48 | return [t[self.TEXT].lower() for t in self.data] 49 | else: 50 | return [t[self.TEXT] for t in self.data] 51 | 52 | def offsets(self): 53 | """Returns a list of [start, end) character offsets of each token.""" 54 | return [t[self.SPAN] for t in self.data] 55 | 56 | def pos(self): 57 | """Returns a list of part-of-speech tags of each token. 58 | Returns None if this annotation was not included. 59 | """ 60 | if 'pos' not in self.annotators: 61 | return None 62 | return [t[self.POS] for t in self.data] 63 | 64 | def lemmas(self): 65 | """Returns a list of the lemmatized text of each token. 
66 | Returns None if this annotation was not included. 67 | """ 68 | if 'lemma' not in self.annotators: 69 | return None 70 | return [t[self.LEMMA] for t in self.data] 71 | 72 | def entities(self): 73 | """Returns a list of named-entity-recognition tags of each token. 74 | Returns None if this annotation was not included. 75 | """ 76 | if 'ner' not in self.annotators: 77 | return None 78 | return [t[self.NER] for t in self.data] 79 | 80 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 81 | """Returns a list of all ngrams from length 1 to n. 82 | 83 | Args: 84 | n: upper limit of ngram length 85 | uncased: lower cases text 86 | filter_fn: user function that takes in an ngram list and returns 87 | True or False to keep or not keep the ngram 88 | as_string: return the ngram as a string vs list 89 | """ 90 | def _skip(gram): 91 | if not filter_fn: 92 | return False 93 | return filter_fn(gram) 94 | 95 | words = self.words(uncased) 96 | ngrams = [(s, e + 1) 97 | for s in range(len(words)) 98 | for e in range(s, min(s + n, len(words))) 99 | if not _skip(words[s:e + 1])] 100 | 101 | # Concatenate into strings 102 | if as_strings: 103 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 104 | 105 | return ngrams 106 | 107 | def entity_groups(self): 108 | """Group consecutive entity tokens with the same NER tag.""" 109 | entities = self.entities() 110 | if not entities: 111 | return None 112 | non_ent = self.opts.get('non_ent', 'O') 113 | groups = [] 114 | idx = 0 115 | while idx < len(entities): 116 | ner_tag = entities[idx] 117 | # Check for entity tag 118 | if ner_tag != non_ent: 119 | # Chomp the sequence 120 | start = idx 121 | while (idx < len(entities) and entities[idx] == ner_tag): 122 | idx += 1 123 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 124 | else: 125 | idx += 1 126 | return groups 127 | 128 | 129 | class Tokenizer(object): 130 | """Base tokenizer class. 131 | Tokenizers implement tokenize, which should return a Tokens class. 132 | """ 133 | 134 | def tokenize(self, text): 135 | raise NotImplementedError 136 | 137 | def shutdown(self): 138 | pass 139 | 140 | def __del__(self): 141 | self.shutdown() 142 | 143 | 144 | import regex 145 | import logging 146 | 147 | logger = logging.getLogger(__name__) 148 | 149 | 150 | class RegexpTokenizer(Tokenizer): 151 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 152 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 153 | r'\.(?=\p{Z})') 154 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 155 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 156 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 157 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 158 | CONTRACTION1 = r"can(?=not\b)" 159 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 160 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 161 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 162 | END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])' 163 | END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]' 164 | DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]' 165 | ELLIPSES = r'\.\.\.|\u2026' 166 | PUNCT = r'\p{P}' 167 | NON_WS = r'[^\p{Z}\p{C}]' 168 | 169 | def __init__(self, **kwargs): 170 | """ 171 | Args: 172 | annotators: None or empty set (only tokenizes). 173 | substitutions: if true, normalizes some token types (e.g. quotes). 
174 | """ 175 | self._regexp = regex.compile( 176 | '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|' 177 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|' 178 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|' 179 | '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' % 180 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN, 181 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2, 182 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE, 183 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT, 184 | self.NON_WS), 185 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 186 | ) 187 | if len(kwargs.get('annotators', {})) > 0: 188 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 189 | (type(self).__name__, kwargs.get('annotators'))) 190 | self.annotators = set() 191 | self.substitutions = kwargs.get('substitutions', True) 192 | 193 | def tokenize(self, text): 194 | data = [] 195 | matches = [m for m in self._regexp.finditer(text)] 196 | for i in range(len(matches)): 197 | # Get text 198 | token = matches[i].group() 199 | 200 | # Make normalizations for special token types 201 | if self.substitutions: 202 | groups = matches[i].groupdict() 203 | if groups['sdquote']: 204 | token = "``" 205 | elif groups['edquote']: 206 | token = "''" 207 | elif groups['ssquote']: 208 | token = "`" 209 | elif groups['esquote']: 210 | token = "'" 211 | elif groups['dash']: 212 | token = '--' 213 | elif groups['ellipses']: 214 | token = '...' 215 | 216 | # Get whitespace 217 | span = matches[i].span() 218 | start_ws = span[0] 219 | if i + 1 < len(matches): 220 | end_ws = matches[i + 1].span()[0] 221 | else: 222 | end_ws = span[1] 223 | 224 | # Format data 225 | data.append(( 226 | token, 227 | text[start_ws: end_ws], 228 | span, 229 | )) 230 | return Tokens(data, self.annotators) 231 | 232 | 233 | class SimpleTokenizer(Tokenizer): 234 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 235 | NON_WS = r'[^\p{Z}\p{C}]' 236 | 237 | def __init__(self, **kwargs): 238 | """ 239 | Args: 240 | annotators: None or empty set (only tokenizes). 241 | """ 242 | self._regexp = regex.compile( 243 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 244 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 245 | ) 246 | if len(kwargs.get('annotators', {})) > 0: 247 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 248 | (type(self).__name__, kwargs.get('annotators'))) 249 | self.annotators = set() 250 | 251 | def tokenize(self, text): 252 | data = [] 253 | matches = [m for m in self._regexp.finditer(text)] 254 | for i in range(len(matches)): 255 | # Get text 256 | token = matches[i].group() 257 | 258 | # Get whitespace 259 | span = matches[i].span() 260 | start_ws = span[0] 261 | if i + 1 < len(matches): 262 | end_ws = matches[i + 1].span()[0] 263 | else: 264 | end_ws = span[1] 265 | 266 | # Format data 267 | data.append(( 268 | token, 269 | text[start_ws: end_ws], 270 | span, 271 | )) 272 | return Tokens(data, self.annotators) 273 | 274 | 275 | 276 | -------------------------------------------------------------------------------- /retrieval/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | 7 | # Required parameters 8 | parser.add_argument("--bert_model_name", 9 | default="bert-large-cased-whole-word-masking", type=str) 10 | parser.add_argument("--output_dir", default="logs", type=str, 11 | help="The output directory where the model checkpoints will be written.") 12 | parser.add_argument("--weight_decay", default=0.0, type=float, 13 | help="Weight decay if we apply some.") 14 | 15 | # Other parameters 16 | parser.add_argument("--load", default=False, action='store_true') 17 | parser.add_argument("--num_workers", default=5, type=int) 18 | parser.add_argument("--train_file", type=str, 19 | default="") 20 | parser.add_argument("--predict_file", type=str, 21 | default="") 22 | parser.add_argument("--init_checkpoint", type=str, 23 | help="Initial checkpoint (usually from a pre-trained BERT model).", 24 | default="") 25 | parser.add_argument("--max_seq_length", default=512, type=int, 26 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 27 | "longer than this will be truncated, and sequences shorter than this will be padded.") 28 | parser.add_argument("--max_query_length", default=30, type=int, 29 | help="The maximum number of tokens for the question. 
Questions longer than this will " 30 | "be truncated to this length.") 31 | parser.add_argument("--do_train", default=False, 32 | action='store_true', help="Whether to run training.") 33 | parser.add_argument("--do_predict", default=False, 34 | action='store_true', help="Whether to run eval on the dev set.") 35 | parser.add_argument("--train_batch_size", default=8, 36 | type=int, help="Total batch size for training.") 37 | parser.add_argument("--predict_batch_size", default=100, 38 | type=int, help="Total batch size for predictions.") 39 | parser.add_argument("--learning_rate", default=5e-5, 40 | type=float, help="The initial learning rate for Adam.") 41 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 42 | help="Epsilon for Adam optimizer.") 43 | parser.add_argument("--num_train_epochs", default=5000, type=float, 44 | help="Total number of training epochs to perform.") 45 | parser.add_argument('--wait_step', type=int, default=100) 46 | parser.add_argument("--save_checkpoints_steps", default=20000, type=int, 47 | help="How often to save the model checkpoint.") 48 | parser.add_argument("--iterations_per_loop", default=1000, type=int, 49 | help="How many steps to make in each estimator call.") 50 | parser.add_argument("--no_cuda", default=False, action='store_true', 51 | help="Whether not to use CUDA when available") 52 | parser.add_argument("--local_rank", type=int, default=-1, 53 | help="local_rank for distributed training on gpus") 54 | parser.add_argument("--accumulate_gradients", type=int, default=1, 55 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 56 | parser.add_argument('--seed', type=int, default=3, 57 | help="random seed for initialization") 58 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 59 | help="Number of updates steps to accumualte before performing a backward/update pass.") 60 | parser.add_argument('--eval-period', type=int, default=2500) 61 | parser.add_argument('--verbose', action="store_true", default=False) 62 | parser.add_argument('--efficient_eval', action="store_true", help="whether to use fp16 for evaluation") 63 | parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.") 64 | 65 | parser.add_argument('--fp16', action='store_true') 66 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 67 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 68 | "See details at https://nvidia.github.io/apex/amp.html") 69 | 70 | 71 | parser.add_argument('--filter', action='store_true', help="1. paragraph too short; 2. 
answer in questions.") 72 | 73 | # For evaluation 74 | parser.add_argument('--prefix', type=str, default="eval") 75 | parser.add_argument('--debug', action="store_true", default=False) 76 | parser.add_argument("--eval-workers", default=32, 77 | help="parallel data loader", type=int) 78 | 79 | 80 | # For encode questions 81 | parser.add_argument("--use-whole-model", action="store_true", help="re encode the questions after QA finetuning") 82 | parser.add_argument("--joint-train", action="store_true") 83 | parser.add_argument("--max-pool", action="store_true", help="CLS or maxpooling") 84 | parser.add_argument("--shared-norm", action="store_true", help="normalize span logits across different paragraphs") 85 | parser.add_argument("--retriever-path", type=str, default="", help="pretrained retriever checkpoint") 86 | parser.add_argument("--qa-drop", default=0, type=float) 87 | 88 | parser.add_argument('--embed_save_path', type=str, default="") 89 | parser.add_argument('--is_query_embed', action="store_true", default=False) 90 | 91 | args = parser.parse_args() 92 | 93 | return args 94 | -------------------------------------------------------------------------------- /retrieval/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset, Sampler 2 | import torch 3 | import json 4 | import numpy as np 5 | import random 6 | from tqdm import tqdm 7 | import re 8 | import string 9 | import os 10 | from multiprocessing import Pool as ProcessPool 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | def remove_articles(text): 15 | return re.sub(r'\b(a|an|the)\b', ' ', text) 16 | 17 | def white_space_fix(text): 18 | return ' '.join(text.split()) 19 | 20 | def remove_punc(text): 21 | exclude = set(string.punctuation) 22 | return ''.join(ch for ch in text if ch not in exclude) 23 | 24 | def lower(text): 25 | return text.lower() 26 | 27 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 28 | 29 | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False): 30 | """Convert a list of 1d tensors into a padded 2d tensor.""" 31 | size = max(v.size(0) for v in values) 32 | res = values[0].new(len(values), size).fill_(pad_idx) 33 | 34 | def copy_tensor(src, dst): 35 | assert dst.numel() == src.numel() 36 | if move_eos_to_beginning: 37 | assert src[-1] == eos_idx 38 | dst[0] = eos_idx 39 | dst[1:] = src[:-1] 40 | else: 41 | dst.copy_(src) 42 | 43 | for i, v in enumerate(values): 44 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 45 | return res 46 | 47 | 48 | class ClusterDataset(Dataset): 49 | 50 | def __init__(self, 51 | tokenizer, 52 | data_folder, 53 | max_query_length, 54 | max_length, 55 | filter=False 56 | ): 57 | super().__init__() 58 | self.tokenizer = tokenizer 59 | self.filter = filter 60 | self.max_query_length = max_query_length 61 | self.max_length = max_length 62 | 63 | print(f"Loading data splits from {data_folder}") 64 | file_lists = [os.path.join(data_folder, f) for f in os.listdir(data_folder)] 65 | 66 | self.data, self.index_clusters = [], [] 67 | processes = ProcessPool(processes=30) 68 | file_datas = processes.map(self.load_file, file_lists) 69 | processes.close() 70 | processes.join() 71 | for file_data in file_datas: 72 | indice = len(self.data) + np.arange(len(file_data)) 73 | self.index_clusters.append(list(indice)) 74 | self.data.extend(file_data) 75 | 76 | 
print(f"Total {len(self.data)} loaded") 77 | 78 | def filter_sample(self, item): 79 | if len(item["Paragraph"].split()) < 20: 80 | return False 81 | if normalize_answer(item["Answer"]) in normalize_answer(item["Question"]): 82 | return False 83 | return True 84 | 85 | def load_file(self, file): 86 | data = [json.loads(line) for line in open(file).readlines()] 87 | if self.filter: 88 | data = [item for item in data if self.filter_sample(item)] 89 | return data 90 | 91 | def __getitem__(self, index): 92 | sample = self.data[index] 93 | question = sample['Question'] 94 | paragraph = sample['Paragraph'] 95 | 96 | question_ids = torch.LongTensor(self.tokenizer.encode( 97 | question, max_length=self.max_query_length)) 98 | question_masks = torch.ones(question_ids.shape).bool() 99 | 100 | paragraph_ids = torch.LongTensor(self.tokenizer.encode( 101 | paragraph, max_length=self.max_length - self.max_query_length)) 102 | paragraph_masks = torch.ones(paragraph_ids.shape).bool() 103 | 104 | return { 105 | 'input_ids_q': question_ids, 106 | 'input_mask_q': question_masks, 107 | 'input_ids_c': paragraph_ids, 108 | 'input_mask_c': paragraph_masks, 109 | } 110 | 111 | def __len__(self): 112 | return len(self.data) 113 | 114 | 115 | class ClusterSampler(Sampler): 116 | 117 | def __init__(self, data_source, batch_size): 118 | """ 119 | batch size: within batch, all samples come from the same cluster 120 | """ 121 | print(f"Sample with batch size {batch_size}") 122 | 123 | index_clusters = data_source.index_clusters 124 | sample_indice = [] 125 | 126 | # shuffle inside each cluster 127 | num_group = 3 128 | for cluster in index_clusters: 129 | groups = [] # 3 adjacent examples share the same para 130 | for i in range(num_group): 131 | groups.append(cluster[i::num_group]) 132 | random.shuffle(groups) 133 | for g in groups: 134 | random.shuffle(g) 135 | sample_indice += g 136 | 137 | # sample batches, avoid adjacent batches always come from the same cluster 138 | self.sample_indice = [] 139 | batch_starts = np.arange(0, len(data_source), batch_size) 140 | np.random.shuffle(batch_starts) 141 | for batch_start in batch_starts: 142 | self.sample_indice += sample_indice[batch_start:batch_start+batch_size] 143 | 144 | assert len(self.sample_indice) == len(data_source) 145 | 146 | def __len__(self): 147 | return len(self.sample_indice) 148 | 149 | def __iter__(self): 150 | return iter(self.sample_indice) 151 | 152 | 153 | class ReDataset(Dataset): 154 | 155 | def __init__(self, 156 | tokenizer, 157 | data_path, 158 | max_query_length, 159 | max_length, 160 | filter=False 161 | ): 162 | super().__init__() 163 | self.tokenizer = tokenizer 164 | self.filter = filter 165 | print(f"Loading data from {data_path}") 166 | 167 | self.data = [json.loads(line) for line in open(data_path).readlines()] 168 | 169 | # filter 170 | original_count = len(self.data) 171 | if self.filter: 172 | self.data = [item for item in self.data if self.filter_sample(item)] 173 | print(f"Using {len(self.data)} out of {original_count}") 174 | 175 | self.max_query_length = max_query_length 176 | self.max_length = max_length 177 | self.group_indexs = [] 178 | num_group = 3 179 | indexs = list(range(len(self.data))) 180 | for i in range(num_group): 181 | self.group_indexs.append(indexs[i::num_group]) 182 | 183 | def filter_sample(self, item): 184 | if len(item["Paragraph"].split()) < 20: 185 | return False 186 | if normalize_answer(item["Answer"]) in normalize_answer(item["Question"]): 187 | return False 188 | return True 189 | 190 | def 
__getitem__(self, index): 191 | sample = self.data[index] 192 | question = sample['Question'] 193 | paragraph = sample['Paragraph'] 194 | 195 | question_ids = torch.LongTensor(self.tokenizer.encode(question, max_length=self.max_query_length)) 196 | question_masks = torch.ones(question_ids.shape).bool() 197 | 198 | paragraph_ids = torch.LongTensor(self.tokenizer.encode(paragraph, max_length=self.max_length - self.max_query_length)) 199 | paragraph_masks = torch.ones(paragraph_ids.shape).bool() 200 | 201 | return { 202 | 'input_ids_q': question_ids, 203 | 'input_mask_q': question_masks, 204 | 'input_ids_c': paragraph_ids, 205 | 'input_mask_c': paragraph_masks, 206 | } 207 | 208 | def __len__(self): 209 | return len(self.data) 210 | 211 | 212 | class ReSampler(Sampler): 213 | """ 214 | Shuffle QA pairs not context, make sure data within the batch are from the same QA pair 215 | """ 216 | 217 | def __init__(self, data_source): 218 | # for each QA pair, sample negative paragraphs 219 | sample_indice = [] 220 | for _ in data_source.group_indexs: 221 | random.shuffle(_) 222 | sample_indice += _ 223 | self.sample_indice = sample_indice 224 | 225 | def __len__(self): 226 | return len(self.sample_indice) 227 | 228 | def __iter__(self): 229 | return iter(self.sample_indice) 230 | 231 | def re_collate(samples): 232 | if len(samples) == 0: 233 | return {} 234 | 235 | return { 236 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0), 237 | 'input_mask_q': collate_tokens([s['input_mask_q'] for s in samples], 0), 238 | 'input_ids_c': collate_tokens([s['input_ids_c'] for s in samples], 0), 239 | 'input_mask_c': collate_tokens([s['input_mask_c'] for s in samples], 0), 240 | } 241 | 242 | class FTDataset(Dataset): 243 | """ 244 | finetune the Question encoder with 245 | """ 246 | 247 | def __init__(self, 248 | tokenizer, 249 | data_path, 250 | max_query_length, 251 | max_length, 252 | filter=False 253 | ): 254 | super().__init__() 255 | 256 | 257 | class EmDataset(Dataset): 258 | 259 | def __init__(self, 260 | tokenizer, 261 | data_path, 262 | max_query_length, 263 | max_length, 264 | is_query_embed, 265 | ): 266 | super().__init__() 267 | self.is_query_embed = is_query_embed 268 | self.tokenizer = tokenizer 269 | 270 | print(f"Loading data from {data_path}") 271 | self.data = [json.loads(_.strip()) 272 | for _ in tqdm(open(data_path).readlines())] 273 | 274 | self.max_length = max_query_length if is_query_embed else max_length 275 | print(f"Max sequence length: {self.max_length}") 276 | 277 | 278 | def __getitem__(self, index): 279 | sample = self.data[index] 280 | if self.is_query_embed: 281 | sent = sample['question'] 282 | else: 283 | sent = sample['text'] 284 | 285 | sent_ids = torch.LongTensor( 286 | self.tokenizer.encode(sent, max_length=self.max_length)) 287 | sent_masks = torch.ones(sent_ids.shape).bool() 288 | 289 | return { 290 | 'input_ids': sent_ids, 291 | 'input_mask': sent_masks, 292 | } 293 | 294 | def __len__(self): 295 | return len(self.data) 296 | 297 | 298 | def em_collate(samples): 299 | if len(samples) == 0: 300 | return {} 301 | 302 | return { 303 | 'input_ids': collate_tokens([s['input_ids'] for s in samples], 0), 304 | 'input_mask': collate_tokens([s['input_mask'] for s in samples], 0), 305 | } 306 | -------------------------------------------------------------------------------- /retrieval/eval_retrieval.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import json 4 | import faiss 5 | import 
argparse 6 | 7 | from multiprocessing import Pool as ProcessPool 8 | from multiprocessing.util import Finalize 9 | from functools import partial 10 | from collections import defaultdict 11 | 12 | from basic_tokenizer import SimpleTokenizer 13 | from utils import DocDB, normalize 14 | 15 | 16 | PROCESS_TOK = None 17 | PROCESS_DB = None 18 | 19 | def init(db_path): 20 | global PROCESS_TOK, PROCESS_DB 21 | PROCESS_TOK = SimpleTokenizer() 22 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 23 | PROCESS_DB = DocDB(db_path) 24 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 25 | 26 | 27 | def para_has_answer(answer, para, return_matched=False): 28 | global PROCESS_DB, PROCESS_TOK 29 | text = normalize(para) 30 | tokens = PROCESS_TOK.tokenize(text) 31 | text = tokens.words(uncased=True) 32 | assert len(text) == len(tokens) 33 | for single_answer in answer: 34 | single_answer = normalize(single_answer) 35 | single_answer = PROCESS_TOK.tokenize(single_answer) 36 | single_answer = single_answer.words(uncased=True) 37 | for i in range(0, len(text) - len(single_answer) + 1): 38 | if single_answer == text[i: i + len(single_answer)]: 39 | if return_matched: 40 | return True, tokens.slice(i, i + len(single_answer)).untokenize() 41 | else: 42 | return True 43 | if return_matched: 44 | return False, "" 45 | return False 46 | 47 | def get_score(answer_doc, topk=80): 48 | """Search through all the top docs to see if they have the answer.""" 49 | question, answer, doc_ids = answer_doc 50 | top5doc_covered = 0 51 | global PROCESS_DB 52 | all_paras = [PROCESS_DB.get_doc_text(doc_id) for doc_id in doc_ids] 53 | 54 | topk_paras = all_paras[:topk] 55 | topkpara_covered = [] 56 | for p in topk_paras: 57 | topkpara_covered.append(int(para_has_answer(answer, p))) 58 | 59 | return { 60 | str(topk): int(np.sum(topkpara_covered) > 0), 61 | "5": int(np.sum(topkpara_covered[:5]) > 0), 62 | "10": int(np.sum(topkpara_covered[:10]) > 0), 63 | "20": int(np.sum(topkpara_covered[:20]) > 0), 64 | "50": int(np.sum(topkpara_covered[:50]) > 0), 65 | } 66 | 67 | 68 | def convert_idx2id(idxs): 69 | idx_id_mapping = json.load(open('../pretrained_models/idx_id.json')) 70 | retrieval_results = [] 71 | for cand_idx in idxs: 72 | out_ids = [] 73 | for _ in cand_idx: 74 | out_ids.append(idx_id_mapping[str(_)]) 75 | retrieval_results.append(out_ids) 76 | return retrieval_results 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('raw_data', type=str, default=None) 81 | parser.add_argument('indexpath', type=str, default=None) 82 | parser.add_argument('query_embed', type=str, default=None) 83 | parser.add_argument('db', type=str, default=None) 84 | parser.add_argument('--topk', type=int, default=80) 85 | parser.add_argument('--num-workers', type=int, default=10) 86 | args = parser.parse_args() 87 | 88 | qas = [json.loads(line) for line in open(args.raw_data).readlines()] 89 | questions = [item["question"] for item in qas] 90 | answers = [item["answer"] for item in qas] 91 | 92 | processes = ProcessPool( 93 | processes=args.num_workers, 94 | initializer=init, 95 | initargs=[args.db] 96 | ) 97 | 98 | d = 128 99 | xq = np.load(args.query_embed).astype('float32') 100 | xb = np.load(args.indexpath).astype('float32') 101 | 102 | index = faiss.IndexFlatIP(d) # build the index 103 | index.add(xb) # add vectors to the index 104 | D, I = index.search(xq, args.topk) # actual search 105 | 106 | retrieval_results = convert_idx2id(I) 107 | 108 | assert len(retrieval_results) == 
len(questions) == len(answers) 109 | answers_docs = zip(questions, answers, retrieval_results) 110 | 111 | get_score_partial = partial( 112 | get_score, topk=args.topk) 113 | results = processes.map(get_score_partial, answers_docs) 114 | 115 | aggregate = defaultdict(list) 116 | for r in results: 117 | for k, v in r.items(): 118 | aggregate[k].append(v) 119 | 120 | for k in aggregate: 121 | results = aggregate[k] 122 | print('Top {} Recall for {} QA pairs: {} ...'.format( 123 | k, len(results), np.mean(results))) 124 | -------------------------------------------------------------------------------- /retrieval/gen_index_id_map.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | mapping = {} 4 | with open('../data/para_doc.db') as f_in: 5 | for idx, line in enumerate(f_in): 6 | sample = json.loads(line.strip()) 7 | mapping[idx] = sample['id'] 8 | with open('index_data/idx_id.json', 'w') as f_out: 9 | json.dump(mapping, f_out) 10 | 11 | -------------------------------------------------------------------------------- /retrieval/get_embed.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import json 4 | import os 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from copy import deepcopy 10 | 11 | from torch.utils.data import DataLoader 12 | from datasets import EmDataset, em_collate 13 | from retriever import BertForRetriever 14 | from transformers import AdamW, BertConfig, BertTokenizer 15 | from utils import move_to_cuda, convert_to_half, AverageMeter 16 | from config import get_args 17 | 18 | from collections import defaultdict, namedtuple 19 | import torch.nn.functional as F 20 | 21 | 22 | def load_saved(model, path): 23 | state_dict = torch.load(path) 24 | def filter(x): return x[7:] if x.startswith('module.') else x 25 | state_dict = {filter(k): v for (k, v) in state_dict.items()} 26 | model.load_state_dict(state_dict) 27 | return model 28 | 29 | def main(): 30 | args = get_args() 31 | 32 | is_query_embed = args.is_query_embed 33 | embed_save_path = args.embed_save_path 34 | 35 | if args.fp16: 36 | try: 37 | import apex 38 | apex.amp.register_half_function(torch, 'einsum') 39 | except ImportError: 40 | raise ImportError( 41 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 42 | 43 | 44 | if args.local_rank == -1 or args.no_cuda: 45 | device = torch.device( 46 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 47 | n_gpu = torch.cuda.device_count() 48 | else: 49 | device = torch.device("cuda", args.local_rank) 50 | n_gpu = 1 51 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 52 | torch.distributed.init_process_group(backend='nccl') 53 | 54 | if args.accumulate_gradients < 1: 55 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 56 | args.accumulate_gradients)) 57 | 58 | args.train_batch_size = int( 59 | args.train_batch_size / args.accumulate_gradients) 60 | random.seed(args.seed) 61 | np.random.seed(args.seed) 62 | torch.manual_seed(args.seed) 63 | if n_gpu > 0: 64 | torch.cuda.manual_seed_all(args.seed) 65 | 66 | if not args.do_train and not args.do_predict: 67 | raise ValueError( 68 | "At least one of `do_train` or `do_predict` must be True.") 69 | 70 | if args.do_train: 71 | if not args.train_file: 72 | raise ValueError( 73 | "If `do_train` is True, then `train_file` must be 
specified.") 74 | if not args.predict_file: 75 | raise ValueError( 76 | "If `do_train` is True, then `predict_file` must be specified.") 77 | 78 | if args.do_predict: 79 | if not args.predict_file: 80 | raise ValueError( 81 | "If `do_predict` is True, then `predict_file` must be specified.") 82 | 83 | bert_config = BertConfig.from_pretrained(args.bert_model_name) 84 | model = BertForRetriever(bert_config, args) 85 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name) 86 | 87 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings: 88 | raise ValueError( 89 | "Cannot use sequence length %d because the BERT model " 90 | "was only trained up to sequence length %d" % 91 | (args.max_seq_length, bert_config.max_position_embeddings)) 92 | 93 | eval_dataset = EmDataset( 94 | tokenizer, args.predict_file, args.max_query_length, args.max_seq_length, is_query_embed) 95 | eval_dataloader = DataLoader( 96 | eval_dataset, batch_size=args.predict_batch_size, collate_fn=em_collate, pin_memory=True, num_workers=args.eval_workers) 97 | 98 | assert args.init_checkpoint != "" 99 | model = load_saved(model, args.init_checkpoint) 100 | 101 | model.to(device) 102 | 103 | if args.do_train: 104 | no_decay = ['bias', 'LayerNorm.weight'] 105 | optimizer_parameters = [ 106 | {'params': [p for n, p in model.named_parameters() if not any( 107 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 108 | {'params': [p for n, p in model.named_parameters() if any( 109 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 110 | ] 111 | optimizer = AdamW(optimizer_parameters, 112 | lr=args.learning_rate, eps=args.adam_epsilon) 113 | 114 | if args.fp16: 115 | try: 116 | from apex import amp 117 | except ImportError: 118 | raise ImportError( 119 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 120 | model, optimizer = amp.initialize( 121 | model, optimizer, opt_level=args.fp16_opt_level) 122 | else: 123 | if args.fp16: 124 | try: 125 | from apex import amp 126 | except ImportError: 127 | raise ImportError( 128 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 129 | model = amp.initialize(model, opt_level=args.fp16_opt_level) 130 | 131 | if args.local_rank != -1: 132 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 133 | output_device=args.local_rank) 134 | elif n_gpu > 1: 135 | model = torch.nn.DataParallel(model) 136 | 137 | 138 | embeds = predict(args, model, eval_dataloader, device, fp16=args.efficient_eval, is_query_embed=is_query_embed) 139 | np.save(embed_save_path, embeds.cpu().numpy()) 140 | 141 | 142 | def predict(args, model, eval_dataloader, device, fp16=False, is_query_embed=True): 143 | if type(model) == list: 144 | model = [m.eval() for m in model] 145 | else: 146 | model.eval() 147 | if fp16: 148 | if type(model) == list: 149 | model = [m.half() for m in model] 150 | else: 151 | model.half() 152 | 153 | num_correct = 0.0 154 | num_total = 0.0 155 | embed_array = [] 156 | for batch in tqdm(eval_dataloader): 157 | batch_to_feed = move_to_cuda(batch) 158 | with torch.no_grad(): 159 | results = model.get_embed(batch_to_feed, is_query_embed) 160 | embed = results['embed'] 161 | embed_array.append(embed) 162 | #print(prediction, target, sum(prediction==target), len(prediction)) 163 | #print(num_total, num_correct) 164 | 165 | ## linear combination tuning on dev data 166 | embed_array = torch.cat(embed_array) 167 | 168 | if fp16: 169 | model.float() 170 | 
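    # restore fp32 weights and training mode after half-precision inference,
    # so the encoder is left in a clean state for any subsequent training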
171 | model.train() 172 | return embed_array 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /retrieval/get_para_embed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=3 python3 get_embed.py \ 3 | --do_predict \ 4 | --prefix eval-para \ 5 | --predict_batch_size 300 \ 6 | --bert_model_name bert-base-uncased \ 7 | --fp16 \ 8 | --predict_file ../data/wiki_splits.txt \ 9 | --init_checkpoint logs/retrieve_train.txt-seed87-bsz640-fp16True-retriever_pretraining_single-lr1e-05-bert-base-uncased-filterTrue/checkpoint_best.pt \ 10 | --embed_save_path encodings/para_embed.npy \ 11 | --eval-workers 32 \ 12 | 13 | -------------------------------------------------------------------------------- /retrieval/group_paras.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | import os 4 | import argparse 5 | 6 | def write_file(file_name, samples): 7 | with open(file_name, 'w') as f_out: 8 | for _ in samples: 9 | f_out.write(_) 10 | 11 | 12 | def group_paras(I, ncentroids, split_path): 13 | samples = [[] for _ in range(ncentroids)] 14 | with open('../data/retrieve_train.txt') as f_in: 15 | for i, line in enumerate(f_in): 16 | samples[I[i][0]].append(line) 17 | for i, group in enumerate(samples): 18 | write_file(split_path + 'split_'+str(i)+'.txt', group) 19 | 20 | def clusering(data, niter=1000, verbose=True, ncentroids=1024, max_points_per_centroid=10000000, gpu_id=0, spherical=False): 21 | # use one gpu 22 | ''' 23 | res = faiss.StandardGpuResources() 24 | cfg = faiss.GpuIndexFlatConfig() 25 | cfg.useFloat16 = False 26 | cfg.device = gpu_id 27 | 28 | d = data.shape[1] 29 | if spherical: 30 | index = faiss.GpuIndexFlatIP(res, d, cfg) 31 | else: 32 | index = faiss.GpuIndexFlatL2(res, d, cfg) 33 | ''' 34 | d = data.shape[1] 35 | if spherical: 36 | index = faiss.IndexFlatIP(d) 37 | else: 38 | index = faiss.IndexFlatL2(d) 39 | 40 | clus = faiss.Clustering(d, ncentroids) 41 | clus.verbose = True 42 | clus.niter = niter 43 | clus.max_points_per_centroid = max_points_per_centroid 44 | 45 | clus.train(x, index) 46 | centroids = faiss.vector_float_to_array(clus.centroids) 47 | centroids = centroids.reshape(ncentroids, d) 48 | 49 | index.reset() 50 | index.add(centroids) 51 | D, I = index.search(data, 1) 52 | 53 | return D, I 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--ncentroids', type=int, default=10000) 58 | parser.add_argument('--niter', type=int, default=250) 59 | parser.add_argument('--max_points_per_centroid', type=int, default=1000) 60 | parser.add_argument('--indexpath', type=str, default=None) 61 | parser.add_argument('--spherical', action='store_true') 62 | args = parser.parse_args() 63 | 64 | 65 | train_para_embed_path = "encodings/train_para_embed.npy" 66 | split_save_path = "../data/data_splits/" 67 | if os.path.exists(split_save_path) and os.listdir(split_save_path): 68 | print(f"output directory {split_save_path} already exists and is not empty.") 69 | if not os.path.exists(split_save_path): 70 | os.makedirs(split_save_path, exist_ok=True) 71 | 72 | x = np.load(train_para_embed_path) 73 | x = np.float32(x) 74 | 75 | D, I = clusering(x, niter=args.niter, ncentroids=args.ncentroids, max_points_per_centroid=args.max_points_per_centroid, spherical=args.spherical) 76 | 77 | group_paras(I, 
args.ncentroids, split_path=split_save_path) 78 | -------------------------------------------------------------------------------- /retrieval/retriever.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import BertModel, BertConfig, BertPreTrainedModel 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss 5 | import torch 6 | 7 | 8 | class BertForRetriever(nn.Module): 9 | 10 | def __init__(self, 11 | config, 12 | args 13 | ): 14 | super(BertForRetriever, self).__init__() 15 | 16 | self.bert_q = BertModel.from_pretrained(args.bert_model_name) 17 | self.bert_c = BertModel.from_pretrained(args.bert_model_name) 18 | 19 | self.proj_q = nn.Linear(config.hidden_size, 128) 20 | self.proj_c = nn.Linear(config.hidden_size, 128) 21 | 22 | def forward(self, batch): 23 | input_ids_q, attention_mask_q = batch["input_ids_q"], batch["input_mask_q"] 24 | q_cls = self.bert_q(input_ids_q, attention_mask_q)[1] 25 | q = self.proj_q(q_cls) 26 | 27 | input_ids_c, attention_mask_c = batch["input_ids_c"], batch["input_mask_c"] 28 | c_cls = self.bert_c(input_ids_c, attention_mask_c)[1] 29 | c = self.proj_c(c_cls) 30 | 31 | return {"q": q, "c": c} 32 | 33 | def get_embed(self, batch, is_query_embed): 34 | 35 | input_ids, attention_mask = batch["input_ids"], batch["input_mask"] 36 | if is_query_embed: 37 | q_cls = self.bert_q(input_ids, attention_mask)[1] 38 | q = self.proj_q(q_cls) 39 | return {'embed': q} 40 | else: 41 | c_cls = self.bert_c(input_ids, attention_mask)[1] 42 | c = self.proj_c(c_cls) 43 | return {'embed': c} 44 | -------------------------------------------------------------------------------- /retrieval/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
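# Tokenization helpers adapted from the original Google BERT release
# (basic text cleanup, whitespace and punctuation splitting).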
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_tokens_to_ids(vocab, tokens): 28 | """Converts a sequence of tokens into ids using the vocab.""" 29 | ids = [] 30 | for token in tokens: 31 | ids.append(vocab[token]) 32 | return ids 33 | 34 | def whitespace_tokenize(text): 35 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 36 | text = text.strip() 37 | if not text: 38 | return [] 39 | tokens = text.split() 40 | return tokens 41 | 42 | 43 | def convert_to_unicode(text): 44 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 45 | if six.PY3: 46 | if isinstance(text, str): 47 | return text 48 | elif isinstance(text, bytes): 49 | return text.decode("utf-8", "ignore") 50 | else: 51 | raise ValueError("Unsupported string type: %s" % (type(text))) 52 | elif six.PY2: 53 | if isinstance(text, str): 54 | return text.decode("utf-8", "ignore") 55 | elif isinstance(text, unicode): 56 | return text 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | else: 60 | raise ValueError("Not running on Python2 or Python 3?") 61 | 62 | 63 | def _is_whitespace(char): 64 | """Checks whether `chars` is a whitespace character.""" 65 | # \t, \n, and \r are technically contorl characters but we treat them 66 | # as whitespace since they are generally considered as such. 67 | if char == " " or char == "\t" or char == "\n" or char == "\r": 68 | return True 69 | cat = unicodedata.category(char) 70 | if cat == "Zs": 71 | return True 72 | return False 73 | 74 | 75 | def _is_control(char): 76 | """Checks whether `chars` is a control character.""" 77 | # These are technically control characters but we count them as whitespace 78 | # characters. 79 | if char == "\t" or char == "\n" or char == "\r": 80 | return False 81 | cat = unicodedata.category(char) 82 | if cat.startswith("C"): 83 | return True 84 | return False 85 | 86 | class BasicTokenizer(object): 87 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 88 | 89 | def __init__(self, do_lower_case=True): 90 | """Constructs a BasicTokenizer. 91 | Args: 92 | do_lower_case: Whether to lower case the input. 
93 | """ 94 | self.do_lower_case = do_lower_case 95 | 96 | def tokenize(self, text): 97 | """Tokenizes a piece of text.""" 98 | text = convert_to_unicode(text) 99 | text = self._clean_text(text) 100 | orig_tokens = whitespace_tokenize(text) 101 | split_tokens = [] 102 | for token in orig_tokens: 103 | if self.do_lower_case: 104 | token = token.lower() 105 | token = self._run_strip_accents(token) 106 | split_tokens.extend(self._run_split_on_punc(token)) 107 | 108 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 109 | return output_tokens 110 | 111 | def _run_strip_accents(self, text): 112 | """Strips accents from a piece of text.""" 113 | text = unicodedata.normalize("NFD", text) 114 | output = [] 115 | for char in text: 116 | cat = unicodedata.category(char) 117 | if cat == "Mn": 118 | continue 119 | output.append(char) 120 | return "".join(output) 121 | 122 | def _run_split_on_punc(self, text): 123 | """Splits punctuation on a piece of text.""" 124 | chars = list(text) 125 | i = 0 126 | start_new_word = True 127 | output = [] 128 | while i < len(chars): 129 | char = chars[i] 130 | if _is_punctuation(char): 131 | output.append([char]) 132 | start_new_word = True 133 | else: 134 | if start_new_word: 135 | output.append([]) 136 | start_new_word = False 137 | output[-1].append(char) 138 | i += 1 139 | 140 | return ["".join(x) for x in output] 141 | 142 | def _clean_text(self, text): 143 | """Performs invalid character removal and whitespace cleanup on text.""" 144 | output = [] 145 | for char in text: 146 | cp = ord(char) 147 | if cp == 0 or cp == 0xfffd or _is_control(char): 148 | continue 149 | if _is_whitespace(char): 150 | output.append(" ") 151 | else: 152 | output.append(char) 153 | return "".join(output) 154 | 155 | 156 | def _is_punctuation(char): 157 | """Checks whether `chars` is a punctuation character.""" 158 | cp = ord(char) 159 | # We treat all non-letter/number ASCII as punctuation. 160 | # Characters such as "^", "$", and "`" are not in the Unicode 161 | # Punctuation class but we treat them as punctuation anyways, for 162 | # consistency. 
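    # The ASCII ranges below cover !"#$%&'()*+,-./ (33-47), :;<=>?@ (58-64),
    # [\]^_` (91-96) and {|}~ (123-126).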
163 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 164 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 165 | return True 166 | cat = unicodedata.category(char) 167 | if cat.startswith("P"): 168 | return True 169 | return False 170 | 171 | 172 | def process(s, tokenizer): 173 | try: 174 | return tokenizer.tokenize(s) 175 | except: 176 | print('failed on', s) 177 | raise 178 | 179 | if __name__ == "__main__": 180 | _is_whitespace("a") 181 | -------------------------------------------------------------------------------- /retrieval/train_retriever.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import json 4 | import os 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from copy import deepcopy 10 | 11 | from torch.utils.data import DataLoader 12 | from datasets import ReDataset, ReSampler, re_collate, ClusterSampler, ClusterDataset 13 | from retriever import BertForRetriever 14 | from transformers import AdamW, BertConfig, BertTokenizer 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | from utils import move_to_cuda, convert_to_half, AverageMeter 18 | from config import get_args 19 | 20 | from collections import defaultdict, namedtuple 21 | import torch.nn.functional as F 22 | from torch.nn import CrossEntropyLoss 23 | 24 | 25 | def load_saved(model, path): 26 | state_dict = torch.load(path) 27 | def filter(x): return x[7:] if x.startswith('module.') else x 28 | state_dict = {filter(k): v for (k, v) in state_dict.items()} 29 | model.load_state_dict(state_dict) 30 | return model 31 | 32 | def main(): 33 | args = get_args() 34 | 35 | if args.fp16: 36 | try: 37 | import apex 38 | apex.amp.register_half_function(torch, 'einsum') 39 | except ImportError: 40 | raise ImportError( 41 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 42 | 43 | # tb logger 44 | data_name = args.train_file.split("/")[-1].split('-')[0] 45 | model_name = f"{data_name}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-{args.prefix}-lr{args.learning_rate}-{args.bert_model_name}-filter{args.filter}" 46 | tb_logger = SummaryWriter(os.path.join( 47 | args.output_dir, "tflogs", model_name)) 48 | args.output_dir = os.path.join(args.output_dir, model_name) 49 | 50 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 51 | print( 52 | f"output directory {args.output_dir} already exists and is not empty.") 53 | if not os.path.exists(args.output_dir): 54 | os.makedirs(args.output_dir, exist_ok=True) 55 | 56 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 57 | datefmt='%m/%d/%Y %H:%M:%S', 58 | level=logging.INFO, 59 | handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")), 60 | logging.StreamHandler()]) 61 | logger = logging.getLogger(__name__) 62 | logger.info(args) 63 | 64 | if args.local_rank == -1 or args.no_cuda: 65 | device = torch.device( 66 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 67 | n_gpu = torch.cuda.device_count() 68 | else: 69 | device = torch.device("cuda", args.local_rank) 70 | n_gpu = 1 71 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 72 | torch.distributed.init_process_group(backend='nccl') 73 | logger.info("device %s n_gpu %d distributed training %r", 74 | device, n_gpu, bool(args.local_rank != -1)) 75 | 76 | if args.accumulate_gradients < 1: 77 | raise ValueError("Invalid 
accumulate_gradients parameter: {}, should be >= 1".format( 78 | args.accumulate_gradients)) 79 | 80 | args.train_batch_size = int( 81 | args.train_batch_size / args.accumulate_gradients) 82 | random.seed(args.seed) 83 | np.random.seed(args.seed) 84 | torch.manual_seed(args.seed) 85 | if n_gpu > 0: 86 | torch.cuda.manual_seed_all(args.seed) 87 | 88 | if not args.do_train and not args.do_predict: 89 | raise ValueError( 90 | "At least one of `do_train` or `do_predict` must be True.") 91 | 92 | if args.do_train: 93 | if not args.train_file: 94 | raise ValueError( 95 | "If `do_train` is True, then `train_file` must be specified.") 96 | if not args.predict_file: 97 | raise ValueError( 98 | "If `do_train` is True, then `predict_file` must be specified.") 99 | 100 | if args.do_predict: 101 | if not args.predict_file: 102 | raise ValueError( 103 | "If `do_predict` is True, then `predict_file` must be specified.") 104 | 105 | bert_config = BertConfig.from_pretrained(args.bert_model_name) 106 | model = BertForRetriever(bert_config, args) 107 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name) 108 | 109 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings: 110 | raise ValueError( 111 | "Cannot use sequence length %d because the BERT model " 112 | "was only trained up to sequence length %d" % 113 | (args.max_seq_length, bert_config.max_position_embeddings)) 114 | 115 | eval_dataset = ReDataset( 116 | tokenizer, args.predict_file, args.max_query_length, args.max_seq_length) 117 | #sampler = ReSampler(eval_dataset) 118 | eval_dataloader = DataLoader( 119 | eval_dataset, batch_size=args.predict_batch_size, collate_fn=re_collate, pin_memory=True, num_workers=args.eval_workers) 120 | logger.info(f"Num of dev batches: {len(eval_dataloader)}") 121 | 122 | if args.init_checkpoint != "": 123 | if ";" in args.init_checkpoint: 124 | models = [] 125 | for path in args.init_checkpoint.split(";"): 126 | instance = deepcopy(load_saved(model, path)) 127 | models.append(instance) 128 | model = models 129 | else: 130 | model = load_saved(model, args.init_checkpoint) 131 | 132 | if type(model) == list: 133 | model = [m.to(device) for m in model] 134 | else: 135 | model.to(device) 136 | print( 137 | f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 138 | 139 | if args.do_train: 140 | no_decay = ['bias', 'LayerNorm.weight'] 141 | optimizer_parameters = [ 142 | {'params': [p for n, p in model.named_parameters() if not any( 143 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 144 | {'params': [p for n, p in model.named_parameters() if any( 145 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 146 | ] 147 | optimizer = AdamW(optimizer_parameters, 148 | lr=args.learning_rate, eps=args.adam_epsilon) 149 | 150 | if args.fp16: 151 | try: 152 | from apex import amp 153 | except ImportError: 154 | raise ImportError( 155 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 156 | model, optimizer = amp.initialize( 157 | model, optimizer, opt_level=args.fp16_opt_level) 158 | else: 159 | if args.fp16: 160 | try: 161 | from apex import amp 162 | except ImportError: 163 | raise ImportError( 164 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 165 | model = amp.initialize(model, opt_level=args.fp16_opt_level) 166 | 167 | if args.local_rank != -1: 168 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 169 | 
output_device=args.local_rank) 170 | elif n_gpu > 1: 171 | model = torch.nn.DataParallel(model) 172 | 173 | if args.do_train: 174 | global_step = 0 # gradient update step 175 | batch_step = 0 # forward batch count 176 | best_acc = 0 177 | wait_step = 0 178 | stop_training = False 179 | train_loss_meter = AverageMeter() 180 | model.train() 181 | 182 | if not os.path.isdir(args.train_file): 183 | train_dataset = ReDataset( 184 | tokenizer, args.train_file, args.max_query_length, args.max_seq_length, args.filter) 185 | sampler = ReSampler(train_dataset) 186 | train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=sampler, pin_memory=True, collate_fn=re_collate, num_workers=8) 187 | else: 188 | train_dataset = ClusterDataset( 189 | tokenizer, args.train_file, args.max_query_length, args.max_seq_length, args.filter) 190 | sampler = ClusterSampler( 191 | train_dataset, args.train_batch_size) 192 | train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=sampler, pin_memory=True, collate_fn=re_collate, num_workers=8) 193 | 194 | logger.info('Start training....') 195 | loss_fct = CrossEntropyLoss() 196 | for epoch in range(int(args.num_train_epochs)): 197 | 198 | for batch in tqdm(train_dataloader): 199 | batch_step += 1 200 | batch = move_to_cuda(batch) 201 | outputs = model(batch) 202 | 203 | product = torch.mm(outputs["q"], outputs["c"].t()) 204 | target = torch.arange(product.size(0)).to(product.device) 205 | loss = loss_fct(product, target) 206 | 207 | if args.gradient_accumulation_steps > 1: 208 | loss = loss / args.gradient_accumulation_steps 209 | 210 | if args.fp16: 211 | with amp.scale_loss(loss, optimizer) as scaled_loss: 212 | scaled_loss.backward() 213 | else: 214 | loss.backward() 215 | 216 | train_loss_meter.update(loss.item()) 217 | tb_logger.add_scalar('batch_train_loss', 218 | loss.item(), global_step) 219 | tb_logger.add_scalar('smoothed_train_loss', 220 | train_loss_meter.avg, global_step) 221 | 222 | if (batch_step + 1) % args.gradient_accumulation_steps == 0: 223 | if args.fp16: 224 | torch.nn.utils.clip_grad_norm_( 225 | amp.master_params(optimizer), args.max_grad_norm) 226 | else: 227 | torch.nn.utils.clip_grad_norm_( 228 | model.parameters(), args.max_grad_norm) 229 | optimizer.step() # We have accumulated enough gradients 230 | model.zero_grad() 231 | global_step += 1 232 | 233 | if global_step % args.save_checkpoints_steps == 0: 234 | torch.save(model.state_dict(), os.path.join( 235 | args.output_dir, f"checkpoint_{global_step}.pt")) 236 | 237 | if global_step % args.eval_period == 0: 238 | acc = predict(args, model, eval_dataloader, 239 | device, fp16=args.efficient_eval) 240 | logger.info("Step %d Train loss %.2f Acc %.2f on epoch=%d" % ( 241 | global_step, train_loss_meter.avg, acc*100, epoch)) 242 | 243 | tb_logger.add_scalar('dev_acc', acc*100, global_step) 244 | 245 | # save most recent model 246 | torch.save(model.state_dict(), os.path.join( 247 | args.output_dir, f"checkpoint_last.pt")) 248 | 249 | if best_acc < acc: 250 | logger.info("Saving model with best Acc %.2f -> Acc %.2f on epoch=%d" % 251 | (best_acc*100, acc*100, epoch)) 252 | # model_state_dict = {k: v.cpu() for ( 253 | # k, v) in model.state_dict().items()} 254 | torch.save(model.state_dict(), os.path.join( 255 | args.output_dir, f"checkpoint_best.pt")) 256 | model = model.to(device) 257 | best_acc = acc 258 | wait_step = 0 259 | stop_training = False 260 | else: 261 | wait_step += 1 262 | if wait_step == args.wait_step: 263 | 
stop_training = True 264 | 265 | 266 | 267 | # acc = predict(args, model, eval_dataloader, 268 | # device, fp16=args.efficient_eval) 269 | # tb_logger.add_scalar('dev_acc', acc*100, global_step) 270 | # logger.info(f"average training loss {train_loss_meter.avg}") 271 | # if best_acc < acc: 272 | # logger.info("Saving model with best Acc %.2f -> Acc %.2f on epoch=%d" % 273 | # (best_acc*100, acc*100, epoch)) 274 | # model_state_dict = {k: v.cpu() for ( 275 | # k, v) in model.state_dict().items()} 276 | # torch.save(model_state_dict, os.path.join( 277 | # args.output_dir, "best-model.pt")) 278 | # model = model.to(device) 279 | # best_acc = acc 280 | # wait_step = 0 281 | 282 | if stop_training: 283 | break 284 | 285 | logger.info("Training finished!") 286 | 287 | elif args.do_predict: 288 | acc = predict(args, model, eval_dataloader, device, fp16=args.efficient_eval) 289 | logger.info(f"test performance {acc}") 290 | print(acc) 291 | 292 | 293 | def predict(args, model, eval_dataloader, device, fp16=False): 294 | if type(model) == list: 295 | model = [m.eval() for m in model] 296 | else: 297 | model.eval() 298 | 299 | if fp16: 300 | if type(model) == list: 301 | model = [m.half() for m in model] 302 | else: 303 | model.half() 304 | 305 | num_correct = 0.0 306 | num_total = 0.0 307 | for batch in tqdm(eval_dataloader): 308 | batch_to_feed = move_to_cuda(batch) 309 | if fp16: 310 | batch_to_feed = convert_to_half(batch_to_feed) 311 | with torch.no_grad(): 312 | results = model(batch_to_feed) 313 | product = torch.mm(results["q"], results["c"].t()) 314 | target = torch.arange(product.size(0)).to(product.device) 315 | prediction = product.argmax(-1) 316 | pred_res = prediction == target 317 | num_total += len(pred_res) 318 | num_correct += sum(pred_res) 319 | 320 | ## linear combination tuning on dev data 321 | acc = num_correct/num_total 322 | best_acc = 0 323 | if acc > best_acc: 324 | best_acc = acc 325 | print(f"evaluated {num_total} examples...") 326 | print(f"avg. 
Acc: {acc}") 327 | 328 | 329 | if fp16: 330 | model.float() 331 | model.train() 332 | 333 | return best_acc 334 | 335 | 336 | if __name__ == "__main__": 337 | main() 338 | -------------------------------------------------------------------------------- /retrieval/train_retriever_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retriever.py \ 4 | --do_train \ 5 | --prefix retriever_pretraining_cluster \ 6 | --predict_batch_size 512 \ 7 | --bert_model_name bert-base-uncased \ 8 | --train_batch_size 640 \ 9 | --gradient_accumulation_steps 8 \ 10 | --accumulate_gradients 8 \ 11 | --efficient_eval \ 12 | --learning_rate 1e-5 \ 13 | --train_file ../data/data_splits/\ 14 | --predict_file ../data/retrieve_dev_shuffled.txt \ 15 | --seed 87 \ 16 | --init_checkpoint logs/retrieve_train.txt-seed87-bsz640-fp16True-retriever_pretraining_single-lr1e-05-bert-base-uncased-filterTrue/checkpoint_last.pt \ 17 | --eval-period 800 \ 18 | --filter 19 | -------------------------------------------------------------------------------- /retrieval/train_retriever_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retriever.py \ 4 | --do_train \ 5 | --prefix retriever_pretraining_single \ 6 | --predict_batch_size 512 \ 7 | --bert_model_name bert-base-uncased \ 8 | --train_batch_size 640 \ 9 | --gradient_accumulation_steps 8 \ 10 | --accumulate_gradients 8 \ 11 | --efficient_eval \ 12 | --learning_rate 1e-5 \ 13 | --fp16 \ 14 | --train_file ../data/retrieve_train.txt \ 15 | --predict_file ../data/retrieve_dev_shuffled.txt \ 16 | --seed 87 \ 17 | --eval-period 800 \ 18 | --filter 19 | -------------------------------------------------------------------------------- /retrieval/trec_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import faiss 7 | 8 | def prepare_corpus(path="../data/trec-2019/collection.tsv", save_path="../data/trec-2019/msmarco_paras.txt"): 9 | corpus = [] 10 | for line in tqdm(open(path).readlines()): 11 | line = line.strip() 12 | pid, text = line.split("\t") 13 | corpus.append({"text": text, "id": int(pid)}) 14 | with open(save_path, "w") as g: 15 | for _ in corpus: 16 | g.write(json.dumps(_) + "\n") 17 | 18 | def extract_labels( 19 | input="../data/trec-2019/qrels.train.tsv", 20 | output="../data/trec-2019/msmacro-train.txt", 21 | queries="../data/trec-2019/queries.train.tsv" 22 | ): 23 | # id2queries 24 | qid2query = {} 25 | for line in open(queries).readlines(): 26 | line = line.strip() 27 | qid, q = line.split("\t")[0], line.split("\t")[1] 28 | if q.endswith("?"): 29 | q = q[:-1] 30 | qid2query[int(qid)] = q 31 | print(len(qid2query)) 32 | 33 | # queries with groundtruths 34 | qid2ground = defaultdict(list) 35 | for line in open(input).readlines(): 36 | line = line.strip() 37 | qid, pid = line.split("\t")[0], line.split("\t")[2] 38 | qid2ground[int(qid)].append(int(pid)) 39 | print(len(qid2ground)) 40 | 41 | # generate data for train/dev 42 | with open(output, "w") as g: 43 | for qid, labels in qid2ground.items(): 44 | question = qid2query[qid] 45 | sample = {"question":question, "labels": labels, "qid": qid} 46 | g.write(json.dumps(sample) + "\n") 47 | 48 | 49 | def debug(): 50 | top1000_dev = 
open("../data/trec-2019/top1000.dev").readlines() 51 | qid2top10000 = defaultdict(list) 52 | for l in top1000_dev: 53 | qid2top10000[int(l.split("\t")[0])].append(int(l.split("\t")[1])) 54 | print(len(qid2top10000)) 55 | 56 | processed_dev = [json.loads(l) for l in tqdm(open( 57 | "../data/trec-2019/processed/dev.txt").readlines())] 58 | qid2ground = {_["qid"]: _["labels"] for _ in processed_dev} 59 | 60 | covered = [] 61 | for qid in qid2top10000.keys(): 62 | top1000_labels = [int(_ in qid2ground[qid]) for _ in qid2top10000[qid]] 63 | covered.append(int(np.sum(top1000_labels) > 0)) 64 | 65 | print(len(covered)) 66 | print(np.mean(covered)) 67 | 68 | 69 | def retrieve_topk(index_path="../data/trec-2019/embeds/msmarco_paras_embed.npy", query_embeds="../data/trec-2019/embeds/msmarco-train-query.npy", query_input="../data/trec-2019/msmacro-train.txt", output="../data/trec-2019/processed/train.txt"): 70 | d = 128 71 | xq = np.load(query_embeds).astype('float32') 72 | xb = np.load(index_path).astype('float32') 73 | 74 | index = faiss.IndexFlatIP(d) # build the index 75 | index.add(xb) # add vectors to the index 76 | D, I = index.search(xq, 10000) # actual search 77 | 78 | raw_data = [json.loads(l) for l in open(query_input).readlines()] 79 | 80 | processed = [] 81 | covered = [] 82 | for idx, para_indice in enumerate(I): 83 | orig_sample = raw_data[idx] 84 | para_embed_idx = [int(_) for _ in para_indice] 85 | para_labels = [int(_ in orig_sample["labels"]) for _ in para_embed_idx] 86 | orig_sample["para_embed_idx"] = para_embed_idx 87 | orig_sample["para_labels"] = para_labels 88 | processed.append(orig_sample) 89 | covered.append(int(np.sum(para_labels) > 0)) 90 | 91 | print(f"Avg recall: {np.mean(covered)}") 92 | with open(output, "w") as g: 93 | for _ in processed: 94 | g.write(json.dumps(_) + "\n") 95 | 96 | 97 | if __name__ == "__main__": 98 | # prepare_corpus() 99 | # extract_labels(input="../data/trec-2019/qrels.dev.small.tsv", 100 | # output="../data/trec-2019/msmacro-dev-small.txt", 101 | # queries="../data/trec-2019/queries.dev.tsv") 102 | 103 | # debug() 104 | 105 | retrieve_topk() 106 | -------------------------------------------------------------------------------- /retrieval/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sqlite3 3 | import unicodedata 4 | 5 | def move_to_cuda(sample): 6 | if len(sample) == 0: 7 | return {} 8 | 9 | def _move_to_cuda(maybe_tensor): 10 | if torch.is_tensor(maybe_tensor): 11 | return maybe_tensor.cuda() 12 | elif isinstance(maybe_tensor, dict): 13 | return { 14 | key: _move_to_cuda(value) 15 | for key, value in maybe_tensor.items() 16 | } 17 | elif isinstance(maybe_tensor, list): 18 | return [_move_to_cuda(x) for x in maybe_tensor] 19 | else: 20 | return maybe_tensor 21 | 22 | return _move_to_cuda(sample) 23 | 24 | def convert_to_half(sample): 25 | if len(sample) == 0: 26 | return {} 27 | 28 | def _convert_to_half(maybe_floatTensor): 29 | if torch.is_tensor(maybe_floatTensor) and maybe_floatTensor.type() == "torch.FloatTensor": 30 | return maybe_floatTensor.half() 31 | elif isinstance(maybe_floatTensor, dict): 32 | return { 33 | key: _convert_to_half(value) 34 | for key, value in maybe_floatTensor.items() 35 | } 36 | elif isinstance(maybe_floatTensor, list): 37 | return [_convert_to_half(x) for x in maybe_floatTensor] 38 | else: 39 | return maybe_floatTensor 40 | 41 | return _convert_to_half(sample) 42 | 43 | 44 | class AverageMeter(object): 45 | """Computes and stores the average 
and current value""" 46 | 47 | def __init__(self): 48 | self.reset() 49 | 50 | def reset(self): 51 | self.val = 0 52 | self.avg = 0 53 | self.sum = 0 54 | self.count = 0 55 | 56 | def update(self, val, n=1): 57 | self.val = val 58 | self.sum += val * n 59 | self.count += n 60 | self.avg = self.sum / self.count 61 | 62 | 63 | def normalize(text): 64 | """Resolve different type of unicode encodings.""" 65 | return unicodedata.normalize('NFD', text) 66 | 67 | 68 | class DocDB(object): 69 | """Sqlite backed document storage. 70 | 71 | Implements get_doc_text(doc_id). 72 | """ 73 | 74 | def __init__(self, db_path=None): 75 | self.path = db_path 76 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 77 | 78 | def __enter__(self): 79 | return self 80 | 81 | def __exit__(self, *args): 82 | self.close() 83 | 84 | def close(self): 85 | """Close the connection to the database.""" 86 | self.connection.close() 87 | 88 | def get_doc_ids(self): 89 | """Fetch all ids of docs stored in the db.""" 90 | cursor = self.connection.cursor() 91 | cursor.execute("SELECT id FROM documents") 92 | results = [r[0] for r in cursor.fetchall()] 93 | cursor.close() 94 | return results 95 | 96 | def get_doc_text(self, doc_id): 97 | """Fetch the raw text of the doc for 'doc_id'.""" 98 | cursor = self.connection.cursor() 99 | cursor.execute( 100 | "SELECT text FROM documents WHERE id = ?", 101 | (normalize(doc_id),) 102 | ) 103 | result = cursor.fetchone() 104 | cursor.close() 105 | return result if result is None else result[0] 106 | --------------------------------------------------------------------------------
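The `DocDB` class in `retrieval/utils.py` is the interface the retrieval code uses to look up paragraph text by id. As a quick orientation, here is a minimal, illustrative sketch of how it might be used on its own; the database path is an assumption for the example (any SQLite file with a `documents(id, text)` table works), not something this snippet ships with:
```
# Illustrative sketch only; the db path below is an assumed example.
from utils import DocDB, normalize

with DocDB("../data/nq_paras.db") as db:       # assumed path to a paragraph database
    doc_ids = db.get_doc_ids()                 # ids of all paragraphs stored in the db
    print(f"{len(doc_ids)} paragraphs in the database")
    text = db.get_doc_text(doc_ids[0])         # raw text, or None if the id is missing
    if text is not None:
        print(normalize(text)[:200])           # resolve unicode encodings before matching
```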