├── .gitignore
├── README.md
├── qa
│   ├── basic_tokenizer.py
│   ├── bert_retrieve_qa.py
│   ├── config.py
│   ├── datasets.py
│   ├── eval_utils.py
│   ├── msmarco_process.py
│   ├── official_eval.py
│   ├── online_sampler.py
│   ├── prepro_dense.py
│   ├── prepro_utils.py
│   ├── tokenizer.py
│   ├── train.py
│   ├── train_dense_qa.sh
│   ├── train_retrieve_qa.py
│   └── utils.py
├── requirements.txt
└── retrieval
    ├── basic_tokenizer.py
    ├── config.py
    ├── datasets.py
    ├── eval_retrieval.py
    ├── gen_index_id_map.py
    ├── get_embed.py
    ├── get_para_embed.sh
    ├── group_paras.py
    ├── retriever.py
    ├── tokenizer.py
    ├── train_retriever.py
    ├── train_retriever_cluster.sh
    ├── train_retriever_single.sh
    ├── trec_process.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | /data
3 | /pretrained_models
4 | *.zip
5 | retrieval/logs/
6 | __MACOSX/
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProQA
2 |
3 | Resource-efficient method for pretraining a dense corpus index for open-domain QA and IR. Given a question, you can use this code to retrieve relevant paragraphs from Wikipedia and extract answers.
4 |
5 | ## 1. Set up the environments
6 | ```
7 | conda create -n proqa -y python=3.6.9 && conda activate proqa
8 | pip install -r requirements.txt
9 | ```
10 | If you want to use mixed precision training and your GPUs support fp16, follow the [Nvidia Apex repo](https://github.com/NVIDIA/apex) to install Apex.
11 |
12 | ## 2. Download data (including the corpus, paragraphs paired with the generated questions, etc.)
13 | ```
14 | gdown https://drive.google.com/uc?id=17IMQ5zzfkCNsTZNJqZI5KveoIsaG2ZDt && unzip data.zip
15 | cd data && gdown https://drive.google.com/uc?id=1T1SntmAZxJ6QfNBN39KbAHcMw0JR5MwL
16 | ```
17 | The data folder includes the QA datasets as well as the paragraph database ``nq_paras.db``, which can be read with sqlite3. **If the command line fails to download the file, please use your browser instead.**
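As a quick sanity check, the paragraph database can be queried directly with Python's built-in ``sqlite3`` module. The snippet below is a minimal sketch that assumes a DrQA-style schema with a single ``documents(id, text)`` table (paths relative to the repo root); if your copy differs, inspect it with ``.schema`` in the sqlite3 shell first.
```
import sqlite3

# Open the paragraph database from the data folder.
conn = sqlite3.connect("data/nq_paras.db")
cur = conn.cursor()

# Assumed DrQA-style schema: documents(id TEXT PRIMARY KEY, text TEXT).
cur.execute("SELECT id, text FROM documents LIMIT 3")
for para_id, text in cur.fetchall():
    print(para_id, text[:100])

conn.close()
```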
18 |
19 | ## 3. Use pretrained index and models
20 | Download the pretrained models and data from google drive:
21 | ```
22 | gdown https://drive.google.com/uc?id=1fDRHsLk5emLqHSMkkoockoHjRSOEBaZw && unzip pretrained_models.zip
23 | ```
24 |
25 | ### Test the Retrieval Performance Before QA finetuning
26 | * First, encode all the questions as embeddings (using the WebQuestions test set for this example):
27 | ```
28 | cd retrieval
29 | CUDA_VISIBLE_DEVICES=0 python get_embed.py \
30 | --do_predict \
31 | --predict_batch_size 512 \
32 | --bert_model_name bert-base-uncased \
33 | --fp16 \
34 | --predict_file ../data/WebQuestions-test.txt \
35 | --init_checkpoint ../pretrained_models/retriever.pt \
36 | --is_query_embed \
37 | --embed_save_path ../data/wq_test_query_embed.npy
38 | ```
39 |
40 | * Retrieve the top-k (k=80) paragraphs from the corpus and evaluate recall with simple string matching:
41 | ```
42 | python eval_retrieval.py ../data/WebQuestions-test.txt ../pretrained_models/para_embed.npy ../data/wq_test_query_embed.npy ../data/nq_paras.db
43 | ```
44 | The arguments are the dataset file, the dense corpus index, the question embeddings, and the paragraph database. The results should look like:
45 | ```
46 | Top 80 Recall for 2032 QA pairs: 0.7839566929133859 ...
47 | Top 5 Recall for 2032 QA pairs: 0.5196850393700787 ...
48 | Top 10 Recall for 2032 QA pairs: 0.610236220472441 ...
49 | Top 20 Recall for 2032 QA pairs: 0.687007874015748 ...
50 | Top 50 Recall for 2032 QA pairs: 0.7554133858267716 ...
51 | ```
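Under the hood, this retrieval step is a maximum inner product search between the question embeddings and the dense corpus index. The following is a brute-force numpy sketch of what is being computed (it assumes each ``.npy`` file stores one embedding per row, as saved by ``get_embed.py``; the evaluation script itself may use an approximate index for speed):
```
import numpy as np

# Dense corpus index and encoded questions (one row per item).
para_embed = np.load("../pretrained_models/para_embed.npy")    # (num_paras, dim)
query_embed = np.load("../data/wq_test_query_embed.npy")       # (num_questions, dim)

# Inner-product scores between every question and every paragraph.
scores = query_embed @ para_embed.T                            # (num_questions, num_paras)

# Indices of the top-80 paragraphs per question, best first.
topk_idx = np.argsort(-scores, axis=1)[:, :80]
print(topk_idx.shape)
```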
52 |
53 | ## 4. Retriever pretraining
54 | ### Use a single pretraining file:
55 | * Under the `retrieval` directory:
56 | ```
57 | cd retrieval
58 | ./train_retriever_single.sh
59 | ```
60 | This script pretrains on the unclustered data and saves a checkpoint under `retrieval/logs/`. After a certain number of updates, pause training and use the following steps to cluster the data and continue training.
61 |
62 | ### Use clustered data for pretraining:
63 | #### Generate paragraph clusters
64 | * Generate the paragraph embeddings using the checkpoint from the last step:
65 | ```
66 | mkdir encodings
67 | CUDA_VISIBLE_DEVICES=0 python get_embed.py --do_predict --prefix eval-para \
68 | --predict_batch_size 300 \
69 | --bert_model_name bert-base-uncased \
70 | --fp16 \
71 | --predict_file ../data/retrieve_train.txt \
72 | --init_checkpoint ../pretrained_models/retriever.pt \
73 | --embed_save_path encodings/train_para_embed.npy \
74 | --eval-workers 32
75 |
76 | ```
77 | * Generate clusters using the paragraph embeddings:
78 | ```
79 | python group_paras.py
80 | ```
81 | Clustering hyperparameters, such as the number of clusters, can be found in `group_paras.py`.
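For reference, the snippet below is a minimal sketch of clustering paragraph embeddings with faiss k-means; the number of clusters and the paths are placeholders, not the actual settings used in `group_paras.py`.
```
import numpy as np
import faiss

# Paragraph embeddings produced by get_embed.py (one row per paragraph).
para_embed = np.load("encodings/train_para_embed.npy").astype("float32")

num_clusters = 100  # placeholder; see group_paras.py for the real setting
kmeans = faiss.Kmeans(para_embed.shape[1], num_clusters, niter=20, verbose=True)
kmeans.train(para_embed)

# Assign each paragraph to its nearest centroid.
_, cluster_ids = kmeans.index.search(para_embed, 1)
print(cluster_ids.reshape(-1)[:10])
```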
82 |
83 | #### Pretraining using clusters
84 | * Then run the cluster-based training script:
85 | ```
86 | ./train_retriever_cluster.sh
87 | ```
88 |
89 | ## 5. QA finetuning
90 | * Generate the dense paragraph index under the ``retrieval`` directory: ``./get_para_embed.sh``
91 | * Finetune the pretrained model on the QA dataset under the ``qa`` directory: ``./train_dense_qa.sh``
92 |
--------------------------------------------------------------------------------
/qa/basic_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Base tokenizer/tokens classes and utilities."""
8 |
9 | import copy
10 |
11 |
12 | class Tokens(object):
13 | """A class to represent a list of tokenized text."""
14 | TEXT = 0
15 | TEXT_WS = 1
16 | SPAN = 2
17 | POS = 3
18 | LEMMA = 4
19 | NER = 5
20 |
21 | def __init__(self, data, annotators, opts=None):
22 | self.data = data
23 | self.annotators = annotators
24 | self.opts = opts or {}
25 |
26 | def __len__(self):
27 | """The number of tokens."""
28 | return len(self.data)
29 |
30 | def slice(self, i=None, j=None):
31 | """Return a view of the list of tokens from [i, j)."""
32 | new_tokens = copy.copy(self)
33 | new_tokens.data = self.data[i: j]
34 | return new_tokens
35 |
36 | def untokenize(self):
37 | """Returns the original text (with whitespace reinserted)."""
38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
39 |
40 | def words(self, uncased=False):
41 | """Returns a list of the text of each token
42 |
43 | Args:
44 | uncased: lower cases text
45 | """
46 | if uncased:
47 | return [t[self.TEXT].lower() for t in self.data]
48 | else:
49 | return [t[self.TEXT] for t in self.data]
50 |
51 | def offsets(self):
52 | """Returns a list of [start, end) character offsets of each token."""
53 | return [t[self.SPAN] for t in self.data]
54 |
55 | def pos(self):
56 | """Returns a list of part-of-speech tags of each token.
57 | Returns None if this annotation was not included.
58 | """
59 | if 'pos' not in self.annotators:
60 | return None
61 | return [t[self.POS] for t in self.data]
62 |
63 | def lemmas(self):
64 | """Returns a list of the lemmatized text of each token.
65 | Returns None if this annotation was not included.
66 | """
67 | if 'lemma' not in self.annotators:
68 | return None
69 | return [t[self.LEMMA] for t in self.data]
70 |
71 | def entities(self):
72 | """Returns a list of named-entity-recognition tags of each token.
73 | Returns None if this annotation was not included.
74 | """
75 | if 'ner' not in self.annotators:
76 | return None
77 | return [t[self.NER] for t in self.data]
78 |
79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
80 | """Returns a list of all ngrams from length 1 to n.
81 |
82 | Args:
83 | n: upper limit of ngram length
84 | uncased: lower cases text
85 | filter_fn: user function that takes in an ngram list and returns
86 | True or False to keep or not keep the ngram
87 | as_string: return the ngram as a string vs list
88 | """
89 | def _skip(gram):
90 | if not filter_fn:
91 | return False
92 | return filter_fn(gram)
93 |
94 | words = self.words(uncased)
95 | ngrams = [(s, e + 1)
96 | for s in range(len(words))
97 | for e in range(s, min(s + n, len(words)))
98 | if not _skip(words[s:e + 1])]
99 |
100 | # Concatenate into strings
101 | if as_strings:
102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
103 |
104 | return ngrams
105 |
106 | def entity_groups(self):
107 | """Group consecutive entity tokens with the same NER tag."""
108 | entities = self.entities()
109 | if not entities:
110 | return None
111 | non_ent = self.opts.get('non_ent', 'O')
112 | groups = []
113 | idx = 0
114 | while idx < len(entities):
115 | ner_tag = entities[idx]
116 | # Check for entity tag
117 | if ner_tag != non_ent:
118 | # Chomp the sequence
119 | start = idx
120 | while (idx < len(entities) and entities[idx] == ner_tag):
121 | idx += 1
122 | groups.append((self.slice(start, idx).untokenize(), ner_tag))
123 | else:
124 | idx += 1
125 | return groups
126 |
127 |
128 | class Tokenizer(object):
129 | """Base tokenizer class.
130 | Tokenizers implement tokenize, which should return a Tokens class.
131 | """
132 |
133 | def tokenize(self, text):
134 | raise NotImplementedError
135 |
136 | def shutdown(self):
137 | pass
138 |
139 | def __del__(self):
140 | self.shutdown()
141 |
142 |
143 | import regex
144 | import logging
145 |
146 | logger = logging.getLogger(__name__)
147 |
148 |
149 | class RegexpTokenizer(Tokenizer):
150 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
151 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
152 | r'\.(?=\p{Z})')
153 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
154 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
155 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
156 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
157 | CONTRACTION1 = r"can(?=not\b)"
158 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
159 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
160 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
161 | END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
162 | END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
163 | DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
164 | ELLIPSES = r'\.\.\.|\u2026'
165 | PUNCT = r'\p{P}'
166 | NON_WS = r'[^\p{Z}\p{C}]'
167 |
168 | def __init__(self, **kwargs):
169 | """
170 | Args:
171 | annotators: None or empty set (only tokenizes).
172 | substitutions: if true, normalizes some token types (e.g. quotes).
173 | """
174 | self._regexp = regex.compile(
175 | '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
176 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
177 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
178 | '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
179 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
180 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
181 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
182 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
183 | self.NON_WS),
184 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
185 | )
186 | if len(kwargs.get('annotators', {})) > 0:
187 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
188 | (type(self).__name__, kwargs.get('annotators')))
189 | self.annotators = set()
190 | self.substitutions = kwargs.get('substitutions', True)
191 |
192 | def tokenize(self, text):
193 | data = []
194 | matches = [m for m in self._regexp.finditer(text)]
195 | for i in range(len(matches)):
196 | # Get text
197 | token = matches[i].group()
198 |
199 | # Make normalizations for special token types
200 | if self.substitutions:
201 | groups = matches[i].groupdict()
202 | if groups['sdquote']:
203 | token = "``"
204 | elif groups['edquote']:
205 | token = "''"
206 | elif groups['ssquote']:
207 | token = "`"
208 | elif groups['esquote']:
209 | token = "'"
210 | elif groups['dash']:
211 | token = '--'
212 | elif groups['ellipses']:
213 | token = '...'
214 |
215 | # Get whitespace
216 | span = matches[i].span()
217 | start_ws = span[0]
218 | if i + 1 < len(matches):
219 | end_ws = matches[i + 1].span()[0]
220 | else:
221 | end_ws = span[1]
222 |
223 | # Format data
224 | data.append((
225 | token,
226 | text[start_ws: end_ws],
227 | span,
228 | ))
229 | return Tokens(data, self.annotators)
230 |
231 |
232 | class SimpleTokenizer(Tokenizer):
233 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
234 | NON_WS = r'[^\p{Z}\p{C}]'
235 |
236 | def __init__(self, **kwargs):
237 | """
238 | Args:
239 | annotators: None or empty set (only tokenizes).
240 | """
241 | self._regexp = regex.compile(
242 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
243 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
244 | )
245 | if len(kwargs.get('annotators', {})) > 0:
246 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
247 | (type(self).__name__, kwargs.get('annotators')))
248 | self.annotators = set()
249 |
250 | def tokenize(self, text):
251 | data = []
252 | matches = [m for m in self._regexp.finditer(text)]
253 | for i in range(len(matches)):
254 | # Get text
255 | token = matches[i].group()
256 |
257 | # Get whitespace
258 | span = matches[i].span()
259 | start_ws = span[0]
260 | if i + 1 < len(matches):
261 | end_ws = matches[i + 1].span()[0]
262 | else:
263 | end_ws = span[1]
264 |
265 | # Format data
266 | data.append((
267 | token,
268 | text[start_ws: end_ws],
269 | span,
270 | ))
271 | return Tokens(data, self.annotators)
272 |
273 |
274 |
275 |
--------------------------------------------------------------------------------
/qa/bert_retrieve_qa.py:
--------------------------------------------------------------------------------
1 | from transformers import BertModel, BertConfig, BertPreTrainedModel
2 | import torch.nn as nn
3 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | import sys
8 | sys.path.append('../retrieval')
9 | from retriever import BertForRetriever
10 |
11 |
12 | class BertRetrieveQA(nn.Module):
13 |
14 | def __init__(self,
15 | config,
16 | args
17 | ):
18 | super(BertRetrieveQA, self).__init__()
19 | self.shared_norm = args.shared_norm
20 | self.separate = args.separate
21 | self.add_select = args.add_select
22 | self.drop_early = args.drop_early
23 |
24 | if args.use_spanbert:
25 | self.bert = BertModel.from_pretrained(args.spanbert_path)
26 | else:
27 | self.bert = BertModel.from_pretrained(args.bert_model_name)
28 |
29 | # parameters from pretrained index
30 | self.retriever = BertForRetriever(config, args)
31 | if args.retriever_path != "":
32 | self.load_pretrained_retriever(args.retriever_path)
33 |
34 | self.qa_outputs = nn.Linear(
35 | config.hidden_size, 2)
36 | self.qa_drop = nn.Dropout(args.qa_drop)
37 | self.shared_norm = args.shared_norm
38 |
39 | if self.add_select:
40 | self.select_outputs = nn.Linear(config.hidden_size, 1)
41 |
42 | def load_pretrained_retriever(self, path):
43 | state_dict = torch.load(path)
44 | def filter(x): return x[7:] if x.startswith('module.') else x
45 | state_dict = {filter(k): v for (k, v) in state_dict.items()}
46 | self.retriever.load_state_dict(state_dict)
47 |
48 | def freeze_c_encoder(self):
49 | for p in self.retriever.bert_c.parameters():
50 | p.requires_grad = False
51 | for p in self.retriever.proj_c.parameters():
52 | p.requires_grad = False
53 |
54 | def freeze_retriever(self):
55 | for p in self.retriever.parameters():
56 | p.requires_grad = False
57 |
58 | def forward(self, batch):
59 | input_ids, attention_mask, token_type_ids = batch[
60 | "input_ids"], batch["input_mask"], batch["segment_ids"]
61 | outputs = self.bert(input_ids, attention_mask, token_type_ids)
62 | sequence_output = outputs[0]
63 |
64 | logits = self.qa_outputs(self.qa_drop(sequence_output))
65 | outs = [o.squeeze(-1) for o in logits.split(1, dim=-1)]
66 | outs = [o.float().masked_fill(batch["paragraph_mask"].ne(1), -1e10).type_as(o)
67 | for o in outs]
68 |
69 | start_logits = outs[0]
70 | end_logits = outs[1]
71 |
72 | input_ids_q, attention_mask_q = batch["input_ids_q"], batch["input_mask_q"]
73 | q_cls = self.retriever.bert_q(input_ids_q, attention_mask_q)[1]
74 | q = self.retriever.proj_q(q_cls)
75 |
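# Rank logits: inner product between the question embedding (each batch groups the
# candidate paragraphs of a single question) and the precomputed paragraph embeddings.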
76 | rank_logits = q[0].unsqueeze(0).mm(batch["para_embed"].t())
77 | rank_probs = F.softmax(rank_logits, dim=-1)
78 |
79 | if self.add_select:
80 | pooled_output = outputs[1]
81 | select_logits = self.select_outputs(pooled_output)
82 |
83 | if self.training:
84 | start_positions, end_positions, rank_targets = batch[
85 | "start_positions"], batch["end_positions"], batch["para_targets"]
86 | loss_fct = CrossEntropyLoss(ignore_index=-1, reduction="none")
87 |
88 | if not self.drop_early:
89 | # early loss
90 | para_targets = batch["top5000_labels"].nonzero()
91 | early_losses = [loss_fct(rank_logits, p)
92 | for p in torch.unbind(para_targets)]
93 | if len(early_losses) == 0:
94 | early_loss = loss_fct(start_logits, start_logits.new_zeros(
95 | start_logits.size(0)).long()-1).sum()
96 | else:
97 | early_loss = - \
98 | torch.log(torch.sum(torch.exp(-torch.cat(early_losses))))
99 |
100 | if self.add_select:
101 | select_logits_flat = select_logits.view(1, -1)
102 | select_probs = F.softmax(select_logits_flat, dim=-1)
103 |
104 | if self.separate:
105 | select_targets_flat = rank_targets.view(1, -1)
106 | select_targets_flat = select_targets_flat.nonzero()[
107 | :, 1].unsqueeze(1)
108 | select_losses = [loss_fct(select_logits_flat, r)
109 | for r in torch.unbind(select_targets_flat)]
110 | if len(select_losses) == 0:
111 | select_loss = loss_fct(
112 | select_logits_flat, select_logits_flat.new_zeros(1).long()-1).sum()
113 | else:
114 | select_loss = - torch.log(torch.sum(torch.exp(-torch.cat(select_losses))))
115 |
116 |
117 | # two ways to calculate the span probabilities
118 | if self.shared_norm:
119 | offset = (torch.arange(start_positions.size(
120 | 0)) * start_logits.size(1)).unsqueeze(1).to(start_positions.device)
121 | start_positions_ = start_positions + \
122 | (start_positions != -1) * offset
123 | end_positions_ = end_positions + (end_positions != -1) * offset
124 | start_positions_ = start_positions_.view(-1, 1)
125 | end_positions_ = end_positions_.view(-1, 1)
126 | start_logits_flat = start_logits.view(1, -1)
127 | end_logits_flat = end_logits.view(1, -1)
128 | start_losses = [loss_fct(start_logits_flat, s)
129 | for s in torch.unbind(start_positions_)]
130 | end_losses = [loss_fct(end_logits_flat, e)
131 | for e in torch.unbind(end_positions_)]
132 | loss_tensor = - (torch.cat(start_losses) +
133 | torch.cat(end_losses))
134 | loss_tensor = loss_tensor.view(start_positions.size())
135 | log_prob = loss_tensor.float().masked_fill(
136 | loss_tensor == 0, float('-inf')).type_as(loss_tensor)
137 | else:
138 | start_losses = [loss_fct(start_logits, starts) for starts in torch.unbind(start_positions, dim=1)]
139 | end_losses = [loss_fct(end_logits, ends) for ends in torch.unbind(end_positions, dim=1)]
140 | loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1)
141 | log_prob = - loss_tensor
142 | log_prob = log_prob.float().masked_fill(log_prob == 0, float('-inf')).type_as(log_prob)
143 |
144 | # marginal probabily for each paragraph
145 | probs = torch.exp(log_prob)
146 | marginal_probs = torch.sum(probs, dim=1)
147 |
148 | # joint or separate loss functions
149 | if self.separate:
150 | m_prob = [marginal_probs[idx] for idx in marginal_probs.nonzero()]
151 | if len(m_prob) == 0:
152 | span_loss = loss_fct(start_logits, start_logits.new_zeros(
153 | start_logits.size(0)).long()-1).sum()
154 | else:
155 | span_loss = - torch.log(torch.sum(torch.cat(m_prob)))
156 | total_loss = span_loss + select_loss + early_loss if self.add_select else span_loss + early_loss
157 |
158 | else:
159 | if self.add_select:
160 | rank_probs = select_probs
161 |
162 | joint_prob = marginal_probs * rank_probs.view(-1)[:marginal_probs.size(0)]
163 | joint_prob = [joint_prob[idx] for idx in marginal_probs.nonzero()]
164 | if len(joint_prob) == 0:
165 | joint_loss = loss_fct(start_logits, start_logits.new_zeros(
166 | start_logits.size(0)).long()-1).sum()
167 | else:
168 | joint_loss = - torch.log(torch.sum(torch.cat(joint_prob)))
169 | total_loss = joint_loss + early_loss
170 |
171 | return {"loss": total_loss}
172 |
173 | if self.add_select:
174 | return {"start_logits": start_logits, "end_logits": end_logits, "rank_logits": rank_logits, "select_logits": select_logits.view(1, -1)}
175 | else:
176 | return {"start_logits": start_logits, "end_logits": end_logits, "rank_logits": rank_logits}
177 |
--------------------------------------------------------------------------------
/qa/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 |
7 | # Required parameters
8 | parser.add_argument("--bert_model_name",
9 | default="bert-base-uncased", type=str)
10 | parser.add_argument("--output_dir", default="logs", type=str,
11 | help="The output directory where the model checkpoints will be written.")
12 | parser.add_argument("--weight_decay", default=0.0, type=float,
13 | help="Weight decay if we apply some.")
14 |
15 | # Other parameters
16 | parser.add_argument("--load", default=False, action='store_true')
17 | parser.add_argument("--num_workers", default=5, type=int)
18 | parser.add_argument("--train_file", type=str,
19 | default="../../data/mrqa-train/HotpotQA-tokenized.jsonl")
20 | parser.add_argument("--predict_file", type=str,
21 | default="../../data/mrqa-dev/HotpotQA-tokenized.jsonl")
22 | parser.add_argument("--init_checkpoint", type=str,
23 | help="Initial checkpoint (usually from a pre-trained BERT model).",
24 | default="")
25 | parser.add_argument("--do_lower_case", default=True, action='store_true',
26 | help="Whether to lower case the input text. Should be True for uncased"
27 | "models and False for cased models.")
28 | parser.add_argument("--max_seq_length", default=512, type=int,
29 | help="The maximum total input sequence length after WordPiece tokenization. Sequences "
30 | "longer than this will be truncated, and sequences shorter than this will be padded.")
31 | parser.add_argument("--max_query_length", default=50, type=int,
32 | help="The maximum number of tokens for the question. Questions longer than this will "
33 | "be truncated to this length.")
34 | parser.add_argument("--do_train", default=False,
35 | action='store_true', help="Whether to run training.")
36 | parser.add_argument("--do_predict", default=False,
37 | action='store_true', help="Whether to run eval on the dev set.")
38 | parser.add_argument("--train_batch_size", default=8,
39 | type=int, help="Total batch size for training.")
40 | parser.add_argument("--predict_batch_size", default=100,
41 | type=int, help="Total batch size for predictions.")
42 | parser.add_argument("--learning_rate", default=5e-5,
43 | type=float, help="The initial learning rate for Adam.")
44 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
45 | help="Epsilon for Adam optimizer.")
46 | parser.add_argument("--num_train_epochs", default=200, type=float,
47 | help="Total number of training epochs to perform.")
48 | parser.add_argument('--wait_step', type=int, default=100)
49 | parser.add_argument("--save_checkpoints_steps", default=1000, type=int,
50 | help="How often to save the model checkpoint.")
51 | parser.add_argument("--iterations_per_loop", default=1000, type=int,
52 | help="How many steps to make in each estimator call.")
53 | parser.add_argument("--no_cuda", default=False, action='store_true',
54 | help="Whether not to use CUDA when available")
55 | parser.add_argument("--local_rank", type=int, default=-1,
56 | help="local_rank for distributed training on gpus")
57 | parser.add_argument("--accumulate_gradients", type=int, default=1,
58 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)")
59 | parser.add_argument('--seed', type=int, default=3,
60 | help="random seed for initialization")
61 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
62 | help="Number of update steps to accumulate before performing a backward/update pass.")
63 | parser.add_argument('--eval_period', type=int, default=1000, help="setting to -1: eval only after each epoch")
64 | parser.add_argument('--verbose', action="store_true", default=False)
65 | parser.add_argument('--efficient_eval', action="store_true", help="whether to use fp16 for evaluation")
66 | parser.add_argument('--max_answer_len', default=20, type=int)
67 | parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
68 |
69 | parser.add_argument('--fp16', action='store_true')
70 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
71 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
72 | "See details at https://nvidia.github.io/apex/amp.html")
73 |
74 | # BERT QA
75 | parser.add_argument("--qa-drop", default=0, type=float)
76 | parser.add_argument("--rank-drop", default=0, type=float)
77 |
78 | parser.add_argument("--MI", action="store_true", help="Use MI regularization to improve weak supervision")
79 | parser.add_argument("--mi-k", default=10, type=int, help="negative sample number")
80 | parser.add_argument("--max-pool", action="store_true", help="CLS or maxpooling")
81 | parser.add_argument("--eval-workers", default=16, help="parallel data loader", type=int)
82 | parser.add_argument("--save-pred", action="store_true", help="uncertainty analysis")
83 | parser.add_argument("--retriever-path", type=str, default="", help="pretrained retriever checkpoint")
84 |
85 | parser.add_argument("--raw-train-data", type=str,
86 | default="../data/nq-train.txt")
87 | parser.add_argument("--raw-eval-data", type=str,
88 | default="../data/nq-dev.txt")
89 | parser.add_argument("--fix-para-encoder", action="store_true")
90 | parser.add_argument("--db-path", type=str,
91 | default='../data/nq_paras.db')
92 | parser.add_argument("--index-path", type=str,
93 | default="retrieval/index_data/para_embed_100k.npy")
94 | parser.add_argument("--matched-para-path", type=str,
95 | default="../data/wq_ft_train_matched.txt")
96 |
97 | parser.add_argument("--use-spanbert", action="store_true", help="use spanbert for question answering")
98 | parser.add_argument("--spanbert-path",
99 | default="../data/span_bert", type=str)
100 | parser.add_argument("--eval-k", default=5, type=int)
101 | parser.add_argument("--regex", action="store_true", help="for CuratedTrec")
102 |
103 | # investigate different kinds of loss functions
104 | parser.add_argument("--separate", action="store_true", help="separate the rank and reader loss")
105 | parser.add_argument("--add-select", action="store_true", help="replace the rank probability with the selection probility from the reader model ([CLS])")
106 | parser.add_argument("--drop-early", action="store_true", help="drop the early loss on topk5000")
107 | parser.add_argument("--shared-norm", action="store_true",
108 | help="normalize span logits across different paragraphs")
109 |
110 | # parser.add_argument("--fix-retriever", action="store_true")
111 | # parser.add_argument("--joint-train", action="store_true")
112 | # parser.add_argument("--mixed", action="store_true",help="shared norm and also use the rank probabilities in loss")
113 | # parser.add_argument("--use-adam", action="store_true")
114 | # parser.add_argument("--para-embed-path", type=str, default="")
115 | # parser.add_argument("--retrieved-path", type=str, default="")
116 |
117 | # For evaluation
118 | parser.add_argument('--prefix', type=str, default="eval")
119 | parser.add_argument('--debug', action="store_true")
120 | parser.add_argument('--use-top-passage', action="store_true")
121 | parser.add_argument('--topk', default=30, type=int)
122 | parser.add_argument('--save-all', action="store_true", help="save the predictions")
123 | parser.add_argument('--candidates', default="", type=str, help="restrict the predicted spans to be entities")
124 |
125 | args = parser.parse_args()
126 |
127 | return args
128 |
--------------------------------------------------------------------------------
/qa/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, Dataset, Sampler
2 | import torch
3 | import json
4 | import numpy as np
5 | import random
6 | from tqdm import tqdm
7 |
8 | from joblib import Parallel, delayed
9 |
10 | from prepro_utils import hash_question
11 |
12 | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
13 | """Convert a list of 1d tensors into a padded 2d tensor."""
14 | size = max(v.size(0) for v in values)
15 | res = values[0].new(len(values), size).fill_(pad_idx)
16 |
17 | def copy_tensor(src, dst):
18 | assert dst.numel() == src.numel()
19 | if move_eos_to_beginning:
20 | assert src[-1] == eos_idx
21 | dst[0] = eos_idx
22 | dst[1:] = src[:-1]
23 | else:
24 | dst.copy_(src)
25 |
26 | for i, v in enumerate(values):
27 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
28 | return res
29 |
30 |
31 | class OpenQASampler(Sampler):
32 | """
33 | Shuffle QA pairs (not contexts); make sure all examples within a batch come from the same QA pair.
34 | """
35 |
36 | def __init__(self, data_source, batch_size):
37 | self.batch_size = batch_size
38 | # for each QA pair, sample negative paragraphs
39 | sample_indice = []
40 | for qa_idx in range(len(data_source.qids)):
41 | batch_data = []
42 | batch_data.append(random.choice(data_source.grouped_idx_has_answer[qa_idx]))
43 | assert len(batch_data) >= 1
44 | if len(data_source.grouped_idx_no_answer[qa_idx]) < self.batch_size - len(batch_data):
45 | # print("Too few negative samples...")
46 | # continue
47 | if len(data_source.grouped_idx_no_answer[qa_idx]) == 0:
48 | continue
49 | negative_sample = random.choices(data_source.grouped_idx_no_answer[qa_idx], k=self.batch_size - len(batch_data))
50 | else:
51 | negative_sample = random.sample(data_source.grouped_idx_no_answer[qa_idx], self.batch_size - len(batch_data))
52 | batch_data.extend(negative_sample)
53 | assert len(batch_data) == batch_size
54 | sample_indice.append(batch_data)
55 |
56 | print(f"{len(sample_indice)} QA pairs used for training...")
57 |
58 | sample_indice = np.array(sample_indice)
59 | np.random.shuffle(sample_indice)
60 | self.sample_indice = list(sample_indice.flatten())
61 |
62 | def __len__(self):
63 | return len(self.sample_indice)
64 |
65 | def __iter__(self):
66 | return iter(self.sample_indice)
67 |
68 |
69 | class BatchSampler(Sampler):
70 | """
71 | use all paragraphs, shuffle the QA pairs
72 | """
73 |
74 | def __init__(self, data_source, batch_size):
75 | self.batch_size = batch_size
76 | sample_indice = []
77 | for qa_idx in range(len(data_source.qids)):
78 | batch_data = []
79 | batch_data.extend(data_source.grouped_idx_has_answer[qa_idx])
80 | batch_data.extend(data_source.grouped_idx_no_answer[qa_idx])
81 | assert len(batch_data) == batch_size
82 | sample_indice.append(batch_data)
83 |
84 | print(f"{len(sample_indice)} QA pairs used for training...")
85 | sample_indice = np.array(sample_indice)
86 | np.random.shuffle(sample_indice)
87 | self.sample_indice = list(sample_indice.flatten())
88 |
89 | def __len__(self):
90 | return len(self.sample_indice)
91 |
92 | def __iter__(self):
93 | return iter(self.sample_indice)
94 |
95 |
96 | class OpenQADataset(Dataset):
97 |
98 | def __init__(self,
99 | tokenizer,
100 | data_path,
101 | max_query_length,
102 | max_length
103 | ):
104 | super().__init__()
105 | self.tokenizer = tokenizer
106 | print(f"Loading tokenized data from {data_path}...")
107 |
108 |
109 | self.qids = []
110 | self.all_data = [json.loads(line)
111 | for line in tqdm(open(data_path).readlines())]
112 | self.grouped_idx_has_answer = []
113 | self.grouped_idx_no_answer = []
114 | for idx, item in enumerate(self.all_data):
115 | if len(self.qids) == 0 or item["qid"] != self.qids[-1]:
116 | self.qids.append(item["qid"])
117 | self.grouped_idx_no_answer.append([])
118 | self.grouped_idx_has_answer.append([])
119 | if item["no_answer"] == 0:
120 | self.grouped_idx_has_answer[-1].append(idx)
121 | else:
122 | self.grouped_idx_no_answer[-1].append(idx)
123 |
124 | print(f"{len(self.qids)} QA pairs loaded....")
125 | self.max_query_length = max_query_length
126 | self.max_length = max_length
127 |
128 | def __getitem__(self, index):
129 | sample = self.all_data[index]
130 | qid = sample['qid']
131 | q_subtoks = sample['q_subtoks']
132 | if len(q_subtoks) > self.max_query_length:
133 | q_subtoks = q_subtoks[:self.max_query_length]
134 | question = torch.LongTensor(self.binarize_list(q_subtoks))
135 | para_offset = question.size(0) + 2
136 |
137 | para_subtoks = sample['doc_subtoks']
138 | max_tokens_for_doc = self.max_length - para_offset - 1
139 | if len(para_subtoks) > max_tokens_for_doc:
140 | para_subtoks = para_subtoks[:max_tokens_for_doc]
141 |
142 | paragraph = torch.LongTensor(self.binarize_list(para_subtoks))
143 | text, seg = self._join_sents(question, paragraph)
144 | paragraph_mask = torch.zeros(text.shape).bool()
145 | question_mask = torch.zeros(text.shape).bool()
146 | paragraph_mask[para_offset:-1] = 1
147 | question_mask[1:para_offset] = 1
148 |
149 | starts, ends, no_answer = sample["starts"], sample["ends"], sample["no_answer"]
150 |
151 | start_positions, end_positions = [], []
152 | if not no_answer:
153 | no_answer = 1
154 | for s, e in zip(starts, ends):
155 | assert s <= e
156 | if s >= paragraph.size(0):
157 | continue
158 | else:
159 | start_position = min(s, paragraph.size(0) - 1) + para_offset
160 | end_position = min(e, paragraph.size(0) - 1) + para_offset
161 | no_answer = 0
162 | start_positions.append(start_position)
163 | end_positions.append(end_position)
164 |
165 | if len(start_positions) == 0:
166 | assert no_answer
167 | start_positions.append(-1)
168 | end_positions.append(-1)
169 |
170 | start_tensor, end_tensor, no_answer = torch.LongTensor(
171 | start_positions), torch.LongTensor(end_positions), torch.LongTensor([no_answer])
172 |
173 | item_tensor = {
174 | 'q': sample["q"],
175 | 'qid': qid,
176 | 'input_ids': text,
177 | 'segment_ids': seg,
178 | "input_ids_q": self._add_special_token(question),
179 | "input_ids_c": self._add_special_token(paragraph),
180 | 'para_offset': para_offset,
181 | 'paragraph_mask': paragraph_mask,
182 | 'question_mask': question_mask,
183 | 'doc_tokens': sample['doc_toks'],
184 | 'q_subtoks': q_subtoks,
185 | 'wp_tokens': para_subtoks,
186 | 'tok_to_orig_index': sample['tok_to_orig_index'],
187 | 'true_answers': sample["true_answers"],
188 | "start": start_tensor,
189 | "end": end_tensor,
190 | "no_answer": no_answer,
191 | }
192 |
193 | return item_tensor
194 |
195 | def _join_sents(self, sent1, sent2):
196 | cls = sent1.new_full((1,), self.tokenizer.vocab["[CLS]"])
197 | sep = sent1.new_full((1,), self.tokenizer.vocab["[SEP]"])
198 | sent1 = torch.cat([cls, sent1, sep])
199 | sent2 = torch.cat([sent2, sep])
200 | text = torch.cat([sent1, sent2])
201 | segment1 = torch.zeros(sent1.size(0)).long()
202 | segment2 = torch.ones(sent2.size(0)).long()
203 | segment = torch.cat([segment1, segment2])
204 | return text, segment
205 |
206 | def _add_special_token(self, sent):
207 | cls = sent.new_full((1,), self.tokenizer.vocab["[CLS]"])
208 | sep = sent.new_full((1,), self.tokenizer.vocab["[SEP]"])
209 | sent = torch.cat([cls, sent, sep])
210 | return sent
211 |
212 |
213 | def binarize_list(self, words):
214 | return self.tokenizer.convert_tokens_to_ids(words)
215 |
216 | def tokenize(self, s):
217 | try:
218 | return self.tokenizer.tokenize(s)
219 | except:
220 | print('failed on', s)
221 | raise
222 |
223 | def __len__(self):
224 | return len(self.all_data)
225 |
226 | def openqa_collate(samples):
227 | if len(samples) == 0:
228 | return {}
229 |
230 | input_ids = collate_tokens([s['input_ids'] for s in samples], 0)
231 | start_masks = torch.zeros(input_ids.size())
232 | for b_idx, s in enumerate(samples):
233 | for _ in s["start"]:
234 | if _ != -1:
235 | start_masks[b_idx, _] = 1
236 |
237 | net_input = {
238 | 'input_ids': input_ids,
239 | 'segment_ids': collate_tokens(
240 | [s['segment_ids'] for s in samples], 0),
241 | 'paragraph_mask': collate_tokens(
242 | [s['paragraph_mask'] for s in samples], 0,),
243 | 'question_mask': collate_tokens([s["question_mask"] for s in samples], 0),
244 | 'start_positions': collate_tokens(
245 | [s['start'] for s in samples], -1),
246 | 'end_positions': collate_tokens(
247 | [s['end'] for s in samples], -1),
248 | 'no_ans_targets': collate_tokens(
249 | [s['no_answer'] for s in samples], 0),
250 | 'input_mask': collate_tokens([torch.ones_like(s["input_ids"]) for s in samples], 0),
251 | 'start_masks': start_masks,
252 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0),
253 | 'input_mask_q': collate_tokens([torch.ones_like(s["input_ids_q"]) for s in samples], 0),
254 | 'input_ids_c': collate_tokens([s['input_ids_c'] for s in samples], 0),
255 | 'input_mask_c': collate_tokens([torch.ones_like(s["input_ids_c"]) for s in samples], 0),
256 | }
257 |
258 | return {
259 | 'id': [s['qid'] for s in samples],
260 | "q": [s['q'] for s in samples],
261 | 'doc_tokens': [s['doc_tokens'] for s in samples],
262 | 'q_subtoks': [s['q_subtoks'] for s in samples],
263 | 'wp_tokens': [s['wp_tokens'] for s in samples],
264 | 'tok_to_orig_index': [s['tok_to_orig_index'] for s in samples],
265 | 'para_offset': [s['para_offset'] for s in samples],
266 | "true_answers": [s['true_answers'] for s in samples],
267 | 'net_input': net_input,
268 | }
269 |
270 |
271 | class top5k_generator(object):
272 |
273 | def __init__(self,
274 | retrieved_path,
275 | embed_path
276 | ):
277 | super().__init__()
278 | retrieved = [json.loads(l) for l in open(retrieved_path).readlines()]
279 | self.para_embed = np.load(embed_path)
280 |
281 | self.qid2para = {}
282 | for item in retrieved:
283 | self.qid2para[hash_question(item["question"])] = {"para_embed_idx": item["para_embed_idx"], "para_labels": item["para_labels"]}
284 |
285 | def generate(self, qid):
286 | para_labels = self.qid2para[qid]["para_labels"]
287 | para_embed_idx = self.qid2para[qid]["para_embed_idx"]
288 | if np.sum(para_labels) > 0:
289 | para_embed = torch.from_numpy(self.para_embed[para_embed_idx])
290 | para_labels = torch.tensor(para_labels).nonzero().view(-1)
291 | result = {}
292 | result["para_embed"] = para_embed
293 | result["para_labels"] = para_labels
294 | return result
295 | else:
296 | return None
297 |
298 |
299 | if __name__ == "__main__":
300 | data_path = "../data/mrqa-train/SQuAD-tokenized.jsonl"
301 | tokenized_data = [json.loads(_.strip())
302 | for _ in open(data_path).readlines()]
303 | q_lens = np.array([len(item['q_subtoks']) for item in tokenized_data])
304 | c_lens = np.array([len(item['doc_subtoks']) for item in tokenized_data])
305 | import pdb; pdb.set_trace()
306 |
--------------------------------------------------------------------------------
/qa/eval_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | recover the answer string from BERT predictions
3 | """
4 |
5 | import collections
6 | from tokenizer import BasicTokenizer
7 | import six
8 |
9 | def is_whitespace(c):
10 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
11 | return True
12 | return False
13 |
14 |
15 | def get_final_text(pred_text, orig_text, do_lower_case=False, verbose_logging=True):
16 | """Project the tokenized prediction back to the original text."""
17 | def _strip_spaces(text):
18 | ns_chars = []
19 | ns_to_s_map = collections.OrderedDict()
20 | for (i, c) in enumerate(text):
21 | if c == " ":
22 | continue
23 | ns_to_s_map[len(ns_chars)] = i
24 | ns_chars.append(c)
25 | ns_text = "".join(ns_chars)
26 | return (ns_text, ns_to_s_map)
27 |
28 | # We first tokenize `orig_text`, strip whitespace from the result
29 | # and `pred_text`, and check if they are the same length. If they are
30 | # NOT the same length, the heuristic has failed. If they are the same
31 | # length, we assume the characters are one-to-one aligned.
32 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
33 |
34 | tok_text = " ".join(tokenizer.tokenize(orig_text))
35 |
36 | start_position = tok_text.find(pred_text)
37 | if start_position == -1:
38 | if verbose_logging:
39 | print(
40 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
41 | return orig_text
42 | end_position = start_position + len(pred_text) - 1
43 |
44 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
45 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
46 |
47 | if len(orig_ns_text) != len(tok_ns_text):
48 | if verbose_logging:
49 | print("Length not equal after stripping spaces: '%s' vs '%s'" %
50 | (orig_ns_text, tok_ns_text))
51 | return orig_text
52 |
53 | # We then project the characters in `pred_text` back to `orig_text` using
54 | # the character-to-character alignment.
55 | tok_s_to_ns_map = {}
56 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
57 | tok_s_to_ns_map[tok_index] = i
58 |
59 | orig_start_position = None
60 | if start_position in tok_s_to_ns_map:
61 | ns_start_position = tok_s_to_ns_map[start_position]
62 | if ns_start_position in orig_ns_to_s_map:
63 | orig_start_position = orig_ns_to_s_map[ns_start_position]
64 |
65 | if orig_start_position is None:
66 | if verbose_logging:
67 | print("Couldn't map start position")
68 | return orig_text
69 |
70 | orig_end_position = None
71 | if end_position in tok_s_to_ns_map:
72 | ns_end_position = tok_s_to_ns_map[end_position]
73 | if ns_end_position in orig_ns_to_s_map:
74 | orig_end_position = orig_ns_to_s_map[ns_end_position]
75 |
76 | if orig_end_position is None:
77 | if verbose_logging:
78 | print("Couldn't map end position")
79 | return orig_text
80 |
81 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
82 | return output_text
--------------------------------------------------------------------------------
/qa/msmarco_process.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def extract_qa_p(path="../data/msmarco-qa/train_v2.1.json", output="../data/msmarco-qa/train.txt"):
5 | data = json.load(open(path))
6 | data_to_save = []
7 | for id_, answers in data["answers"].items():
8 | if answers[0] != 'No Answer Present.':
9 | passages = data["passages"][id_]
10 | query = data["query"][id_]
11 | relevant_p = []
12 | for p in passages:
13 | if p["is_selected"]:
14 | relevant_p.append(p["passage_text"])
15 | if len(relevant_p) != 0:
16 | data_to_save.append({"q": query, "answer": answers, "para": " ".join(relevant_p)})
17 |
18 | with open(output, "w") as g:
19 | for l in data_to_save:
20 | g.write(json.dumps(l) + "\n")
21 |
22 | from tqdm import tqdm
23 |
24 | if __name__ == "__main__":
25 | # extract_qa_p()
26 |
27 | # data = [json.loads(l)
28 | # for l in open("../data/msmarco-qa/dev.txt").readlines()]
29 |
30 | # source_file = open("../data/msmarco-qa/val.source", "w")
31 | # target_file = open("../data/msmarco-qa/val.target", "w")
32 | # for _ in data:
33 | # source_file.write(_["para"] + "\n")
34 | # target_file.write(_["q"] + "\n")
35 |
36 | all_paras = [json.loads(l) for l in open(
37 | "../data/trec-2019/msmarco_paras.txt").readlines()]
38 | source_file = open("../data/msmarco-qa/test.source", "w")
39 | for _ in tqdm(all_paras):
40 | source_file.write(" ".join(_["text"].split()) + "\n")
41 |
--------------------------------------------------------------------------------
/qa/official_eval.py:
--------------------------------------------------------------------------------
1 | """Official evaluation script for the MRQA Workshop Shared Task.
2 | Adapted from the SQuAD v1.1 official evaluation script.
3 | Usage:
4 | python official_eval.py dataset_file.jsonl.gz prediction_file.json
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import argparse
11 | import string
12 | import re
13 | import json
14 | import gzip
15 | import sys
16 | from collections import Counter
17 | # from allennlp.common.file_utils import cached_path
18 |
19 |
20 | def normalize_answer(s):
21 | """Lower text and remove punctuation, articles and extra whitespace."""
22 | def remove_articles(text):
23 | return re.sub(r'\b(a|an|the)\b', ' ', text)
24 |
25 | def white_space_fix(text):
26 | return ' '.join(text.split())
27 |
28 | def remove_punc(text):
29 | exclude = set(string.punctuation)
30 | return ''.join(ch for ch in text if ch not in exclude)
31 |
32 | def lower(text):
33 | return text.lower()
34 |
35 | return white_space_fix(remove_articles(remove_punc(lower(s))))
36 |
37 |
38 | def regex_match_score(prediction, pattern):
39 | """Check if the prediction matches the given regular expression."""
40 | try:
41 | compiled = re.compile(
42 | pattern,
43 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE
44 | )
45 | except BaseException:
46 | print('Regular expression failed to compile: %s' % pattern)
47 | return False
48 | return compiled.match(prediction) is not None
49 |
50 | def f1_score(prediction, ground_truth):
51 | prediction_tokens = normalize_answer(prediction).split()
52 | ground_truth_tokens = normalize_answer(ground_truth).split()
53 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
54 | num_same = sum(common.values())
55 | if num_same == 0:
56 | return 0
57 | precision = 1.0 * num_same / len(prediction_tokens)
58 | recall = 1.0 * num_same / len(ground_truth_tokens)
59 | f1 = (2 * precision * recall) / (precision + recall)
60 | return f1
61 |
62 |
63 | def exact_match_score(prediction, ground_truth):
64 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
65 |
66 |
67 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
68 | scores_for_ground_truths = []
69 | for ground_truth in ground_truths:
70 | score = metric_fn(prediction, ground_truth)
71 | scores_for_ground_truths.append(score)
72 | return max(scores_for_ground_truths)
73 |
74 |
75 | def read_predictions(prediction_file):
76 | with open(prediction_file) as f:
77 | predictions = json.load(f)
78 | return predictions
79 |
80 |
81 | def read_answers(gold_file):
82 | answers = {}
83 | with gzip.open(gold_file, 'rb') as f:
84 | for i, line in enumerate(f):
85 | example = json.loads(line)
86 | if i == 0 and 'header' in example:
87 | continue
88 | for qa in example['qas']:
89 | answers[qa['qid']] = qa['answers']
90 | return answers
91 |
92 |
93 | def evaluate(answers, predictions, skip_no_answer=False):
94 | f1 = exact_match = total = 0
95 | for qid, ground_truths in answers.items():
96 | if qid not in predictions:
97 | if not skip_no_answer:
98 | message = 'Unanswered question %s will receive score 0.' % qid
99 | print(message)
100 | total += 1
101 | continue
102 | total += 1
103 | prediction = predictions[qid]
104 | exact_match += metric_max_over_ground_truths(
105 | exact_match_score, prediction, ground_truths)
106 | f1 += metric_max_over_ground_truths(
107 | f1_score, prediction, ground_truths)
108 |
109 | exact_match = 100.0 * exact_match / total
110 | f1 = 100.0 * f1 / total
111 |
112 | return {'exact_match': exact_match, 'f1': f1}
113 |
114 |
115 | if __name__ == '__main__':
116 | parser = argparse.ArgumentParser(
117 | description='Evaluation for MRQA Workshop Shared Task')
118 | parser.add_argument('dataset_file', type=str, help='Dataset File')
119 | parser.add_argument('prediction_file', type=str, help='Prediction File')
120 | parser.add_argument('--skip-no-answer', action='store_true')
121 | args = parser.parse_args()
122 |
123 | # answers = read_answers(cached_path(args.dataset_file))
124 | # predictions = read_predictions(cached_path(args.prediction_file))
125 | # metrics = evaluate(answers, predictions, args.skip_no_answer)
126 |
127 | # print(json.dumps(metrics))
128 |
--------------------------------------------------------------------------------
/qa/online_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import numpy as np
4 | import random
5 | from prepro_utils import hash_question, normalize, find_ans_span_with_char_offsets, prepare
6 | from utils import DocDB
7 | import faiss
8 | from official_eval import normalize_answer
9 | from basic_tokenizer import SimpleTokenizer
10 | from prepro_dense import para_has_answer, match_answer_span
11 | from tqdm import tqdm
12 |
13 | from transformers import BertTokenizer
14 |
15 | """
16 | Retrieve paragraphs and find answer spans for the top-k paragraphs on the fly.
17 | """
18 |
19 |
20 | def normalize_para(s):
21 |
22 | def white_space_fix(text):
23 | return ' '.join(text.split())
24 |
25 | def lower(text):
26 | return text.lower()
27 |
28 | return white_space_fix(lower(s))
29 |
30 | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
31 | """Convert a list of 1d tensors into a padded 2d tensor."""
32 | size = max(v.size(0) for v in values)
33 | res = values[0].new(len(values), size).fill_(pad_idx)
34 |
35 | def copy_tensor(src, dst):
36 | assert dst.numel() == src.numel()
37 | if move_eos_to_beginning:
38 | assert src[-1] == eos_idx
39 | dst[0] = eos_idx
40 | dst[1:] = src[:-1]
41 | else:
42 | dst.copy_(src)
43 |
44 | for i, v in enumerate(values):
45 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
46 | return res
47 |
48 |
49 | class OnlineSampler(object):
50 |
51 | def __init__(self,
52 | raw_data,
53 | tokenizer,
54 | max_query_length,
55 | max_length,
56 | db,
57 | para_embed,
58 | index2paraid='retrieval/index_data/idx_id.json',
59 | matched_para_path="",
60 | exact_search=False,
61 | cased=False,
62 | regex=False
63 | ):
64 |
65 | self.max_length = max_length
66 | self.max_query_length = max_query_length
67 | self.para_embed = para_embed
68 | self.cased = cased # spanbert used cased tokenization
69 | self.regex = regex
70 |
71 | if self.cased:
72 | self.cased_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
73 |
74 | # if not exact_search:
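# Approximate nearest-neighbor search: an IVF index with 100 coarse clusters over the
# 128-d paragraph embeddings; each query scans nprobe=20 clusters.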
75 | quantizer = faiss.IndexFlatIP(128)
76 | self.index = faiss.IndexIVFFlat(quantizer, 128, 100)
77 | self.index.train(self.para_embed)
78 | self.index.add(self.para_embed)
79 | self.index.nprobe = 20
80 | # else:
81 | # self.index = faiss.IndexFlatIP(128)
82 | # self.index.add(self.para_embed)
83 |
84 | self.tokenizer = tokenizer
85 | self.qa_data = [json.loads(l) for l in open(raw_data).readlines()]
86 | self.index2paraid = json.load(open(index2paraid))
87 | self.para_db = db
88 | self.matched_para_path = matched_para_path
89 | if self.matched_para_path != "":
90 | print(f"Load matched gold paras from {self.matched_para_path}")
91 | annotated = [json.loads(l) for l in tqdm(open(
92 | self.matched_para_path).readlines())]
93 | self.qid2goldparas = {hash_question(
94 | item["question"]): item["matched_paras"] for item in annotated}
95 |
96 | self.basic_tokenizer = SimpleTokenizer()
97 |
98 | def shuffle(self):
99 | random.shuffle(self.qa_data)
100 |
101 | def __len__(self):
102 | return len(self.qa_data)
103 |
104 | def load(self, retriever, k=5):
105 | for qa in self.qa_data:
106 | with torch.no_grad():
107 | q_ids = torch.LongTensor(self.tokenizer.encode(
108 | qa["question"], max_length=self.max_query_length)).view(1,-1).cuda()
109 | q_masks = torch.ones(q_ids.shape).bool().view(1,-1).cuda()
110 | q_cls = retriever.bert_q(q_ids, q_masks)[1]
111 | q_embed = retriever.proj_q(q_cls).data.cpu().numpy().astype('float32')
112 |
113 | _, I = self.index.search(q_embed, 5000) # retrieve
114 | para_embed_idx = I.reshape(-1)
115 |
116 | if self.cased:
117 | q_ids_cased = torch.LongTensor(self.cased_tokenizer.encode(
118 | qa["question"], max_length=self.max_query_length)).view(1, -1)
119 |
120 | para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx]
121 | para_embeds = self.para_embed[para_embed_idx]
122 |
123 | qid = hash_question(qa["question"])
124 | gold_paras = self.qid2goldparas[qid]
125 |
126 | # match answer strings
127 | p_labels = []
128 | batched_examples = []
129 | topk5000_labels = [int(_ in gold_paras) for _ in para_idx]
130 |
131 | # match answer spans in top5 paras
132 | for p_idx in para_idx[:k]:
133 | p = normalize(self.para_db.get_doc_text(p_idx))
134 | # p_covered, matched_string = para_has_answer(p, qa["answer"], self.basic_tokenizer)
135 | matched_spans = match_answer_span(
136 | p, qa["answer"], self.basic_tokenizer, match="regex" if self.regex else "string")
137 | p_covered = int(len(matched_spans) > 0)
138 | ans_starts, ans_ends, ans_texts = [], [], []
139 |
140 | if self.cased:
141 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(p, self.cased_tokenizer)
142 | else:
143 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(
144 | p, self.tokenizer)
145 |
146 | if p_covered:
147 | for matched_string in matched_spans:
148 | char_starts = [i for i in range(
149 | len(p)) if p.startswith(matched_string, i)]
150 | if len(char_starts) > 0:
151 | char_ends = [start + len(matched_string) - 1 for start in char_starts]
152 | answer = {"text": matched_string, "char_spans": list(
153 | zip(char_starts, char_ends))}
154 |
155 | if self.cased:
156 | ans_spans = find_ans_span_with_char_offsets(answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, self.cased_tokenizer)
157 | else:
158 | ans_spans = find_ans_span_with_char_offsets(
159 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, self.tokenizer)
160 |
161 | for s, e in ans_spans:
162 | ans_starts.append(s)
163 | ans_ends.append(e)
164 | ans_texts.append(matched_string)
165 | batched_examples.append({
166 | "qid": hash_question(qa["question"]),
167 | "q": qa["question"],
168 | "true_answers": qa["answer"],
169 | "doc_subtoks": all_doc_tokens,
170 | "starts": ans_starts,
171 | "ends": ans_ends,
172 | "covered": p_covered
173 | })
174 |
175 | # # look up saved
176 | # if p_idx in gold_paras:
177 | # p_covered = 1
178 | # all_doc_tokens = gold_paras[p_idx]["doc_subtoks"]
179 | # ans_starts = gold_paras[p_idx]["starts"]
180 | # ans_ends = gold_paras[p_idx]["ends"]
181 | # ans_texts = gold_paras[p_idx]["span_texts"]
182 | # else:
183 | # p_covered = 0
184 | # p = normalize(self.para_db.get_doc_text(p_idx))
185 | # _, _, _, _, all_doc_tokens = prepare(p, self.tokenizer)
186 | # ans_starts, ans_ends, ans_texts = [], [], []
187 |
188 | # batched_examples.append({
189 | # "qid": hash_question(qa["question"]),
190 | # "q": qa["question"],
191 | # "true_answers": qa["answer"],
192 | # "doc_subtoks": all_doc_tokens,
193 | # "starts": ans_starts,
194 | # "ends": ans_ends,
195 | # "covered": p_covered
196 | # })
197 | p_labels.append(int(p_covered))
198 |
199 | # calculate loss only when the top5000 covered the answer passage
200 | if np.sum(topk5000_labels) > 0 or np.sum(p_labels) > 0:
201 | # training tensors
202 | for item in batched_examples:
203 | item["input_ids_q"] = q_ids.view(-1).cpu()
204 |
205 | if self.cased:
206 | item["input_ids_q_cased"] = q_ids_cased.view(-1)
207 | para_offset = item["input_ids_q_cased"].size(0)
208 | else:
209 | para_offset = item["input_ids_q"].size(0)
210 |
211 | max_toks_for_doc = self.max_length - para_offset - 1
212 | para_subtoks = item["doc_subtoks"]
213 | if len(para_subtoks) > max_toks_for_doc:
214 | para_subtoks = para_subtoks[:max_toks_for_doc]
215 |
216 | if self.cased:
217 | p_ids = self.cased_tokenizer.convert_tokens_to_ids(para_subtoks)
218 | else:
219 | p_ids = self.tokenizer.convert_tokens_to_ids(
220 | para_subtoks)
221 | item["input_ids_c"] = self._add_special_token(torch.LongTensor(p_ids))
222 | paragraph = item["input_ids_c"][1:-1]
223 | if self.cased:
224 | item["input_ids"], item["segment_ids"] = self._join_sents(
225 | item["input_ids_q_cased"][1:-1], item["input_ids_c"][1:-1])
226 | else:
227 | item["input_ids"], item["segment_ids"] = self._join_sents(item["input_ids_q"][1:-1], item["input_ids_c"][1:-1])
228 | item["para_offset"] = para_offset
229 | item["paragraph_mask"] = torch.zeros(item["input_ids"].shape).bool()
230 | item["paragraph_mask"][para_offset:-1] = 1
231 |
232 | starts, ends, covered = item["starts"], item["ends"], item["covered"]
233 | start_positions, end_positions = [], []
234 |
235 | covered = item["covered"]
236 | if covered:
237 | covered = 0
238 | for s, e in zip(starts, ends):
239 | assert s <= e
240 | if s >= paragraph.size(0):
241 | continue
242 | else:
243 | start_position = min(
244 | s, paragraph.size(0) - 1) + para_offset
245 | end_position = min(e, paragraph.size(0) - 1) + para_offset
246 | covered = 1
247 | start_positions.append(start_position)
248 | end_positions.append(end_position)
249 | if len(start_positions) == 0:
250 | assert not covered
251 | start_positions.append(-1)
252 | end_positions.append(-1)
253 |
254 | start_tensor, end_tensor, covered = torch.LongTensor(
255 | start_positions), torch.LongTensor(end_positions), torch.LongTensor([covered])
256 |
257 | item["start"] = start_tensor
258 | item["end"] = end_tensor
259 | item["covered"] = covered
260 |
261 |
262 | yield self.collate(batched_examples, para_embeds, topk5000_labels)
263 | else:
264 | yield {}
265 |
266 | def eval_load(self, retriever, k=5):
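    |         """Encode the question with retriever.bert_q / proj_q, search the dense
    |         paragraph index for the top-k nearest paragraphs, fetch their text from
    |         the paragraph DB, and yield collated question-paragraph batches."""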
267 | for qa in self.qa_data:
268 | with torch.no_grad():
269 | q_ids = torch.LongTensor(self.tokenizer.encode(qa["question"], max_length=self.max_query_length)).view(1, -1).cuda()
270 | q_masks = torch.ones(q_ids.shape).bool().view(1, -1).cuda()
271 | q_cls = retriever.bert_q(q_ids, q_masks)[1]
272 | q_embed = retriever.proj_q(
273 | q_cls).data.cpu().numpy().astype('float32')
274 | _, I = self.index.search(q_embed, k)
275 | para_embed_idx = I.reshape(-1)
276 | para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx]
277 | paras = [normalize(self.para_db.get_doc_text(idx))
278 | for idx in para_idx]
279 | para_embeds = self.para_embed[para_embed_idx]
280 |
281 | if self.cased:
282 | q_ids_cased = torch.LongTensor(self.cased_tokenizer.encode(
283 | qa["question"], max_length=self.max_query_length)).view(1, -1)
284 |
285 | batched_examples = []
286 |                 # match answer spans in the top-k retrieved paras
287 | for p in paras:
288 | p = normalize(p)
289 |
290 | tokenizer = self.cased_tokenizer if self.cased else self.tokenizer
291 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(
292 | p, tokenizer)
293 |
294 | batched_examples.append({
295 | "qid": hash_question(qa["question"]),
296 | "q": qa["question"],
297 | "true_answers": qa["answer"],
298 | "doc_toks": doc_tokens,
299 | "doc_subtoks": all_doc_tokens,
300 | "tok_to_orig_index": tok_to_orig_index,
301 | })
302 |
303 | for item in batched_examples:
304 | item["input_ids_q"] = q_ids.view(-1).cpu()
305 |
306 | if self.cased:
307 | item["input_ids_q_cased"] = q_ids_cased.view(-1)
308 | para_offset = item["input_ids_q_cased"].size(0)
309 | else:
310 | para_offset = item["input_ids_q"].size(0)
311 | max_toks_for_doc = self.max_length - para_offset - 1
312 | para_subtoks = item["doc_subtoks"]
313 | if len(para_subtoks) > max_toks_for_doc:
314 | para_subtoks = para_subtoks[:max_toks_for_doc]
315 | if self.cased:
316 | p_ids = self.cased_tokenizer.convert_tokens_to_ids(
317 | para_subtoks)
318 | else:
319 | p_ids = self.tokenizer.convert_tokens_to_ids(
320 | para_subtoks)
321 | item["input_ids_c"] = self._add_special_token(
322 | torch.LongTensor(p_ids))
323 | paragraph = item["input_ids_c"][1:-1]
324 | if self.cased:
325 | item["input_ids"], item["segment_ids"] = self._join_sents(
326 | item["input_ids_q_cased"][1:-1], item["input_ids_c"][1:-1])
327 | else:
328 | item["input_ids"], item["segment_ids"] = self._join_sents(
329 | item["input_ids_q"][1:-1], item["input_ids_c"][1:-1])
330 | item["para_offset"] = para_offset
331 | item["paragraph_mask"] = torch.zeros(
332 | item["input_ids"].shape).bool()
333 | item["paragraph_mask"][para_offset:-1] = 1
334 |
335 | yield self.collate(batched_examples, para_embeds)
336 |
337 |
338 | def _add_special_token(self, sent):
339 | cls = sent.new_full((1,), self.tokenizer.vocab["[CLS]"])
340 | sep = sent.new_full((1,), self.tokenizer.vocab["[SEP]"])
341 | sent = torch.cat([cls, sent, sep])
342 | return sent
343 |
344 | def _join_sents(self, sent1, sent2):
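    |         """Build a BERT pair input [CLS] sent1 [SEP] sent2 [SEP]; segment id 0
    |         covers "[CLS] sent1 [SEP]" and segment id 1 covers "sent2 [SEP]"."""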
345 | cls = sent1.new_full((1,), self.tokenizer.vocab["[CLS]"])
346 | sep = sent1.new_full((1,), self.tokenizer.vocab["[SEP]"])
347 | sent1 = torch.cat([cls, sent1, sep])
348 | sent2 = torch.cat([sent2, sep])
349 | text = torch.cat([sent1, sent2])
350 | segment1 = torch.zeros(sent1.size(0)).long()
351 | segment2 = torch.ones(sent2.size(0)).long()
352 | segment = torch.cat([segment1, segment2])
353 | return text, segment
354 |
355 | def collate(self, samples, para_embeds, topk5000_labels=None):
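    |         """Pad and batch the per-paragraph examples. When gold spans are present
    |         (training), also pack start/end positions, paragraph targets and the
    |         top-5000 retrieval labels into net_input."""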
356 | if len(samples) == 0:
357 | return {}
358 |
359 | input_ids = collate_tokens([s['input_ids'] for s in samples], 0)
360 |
361 | if "start" in samples[0]:
362 | assert topk5000_labels is not None
363 | net_input = {
364 | 'input_ids': input_ids,
365 | 'segment_ids': collate_tokens(
366 | [s['segment_ids'] for s in samples], 0),
367 | 'paragraph_mask': collate_tokens(
368 | [s['paragraph_mask'] for s in samples], 0,),
369 | 'start_positions': collate_tokens(
370 | [s['start'] for s in samples], -1),
371 | 'end_positions': collate_tokens(
372 | [s['end'] for s in samples], -1),
373 | 'para_targets': collate_tokens(
374 | [s['covered'] for s in samples], 0),
375 | 'input_mask': collate_tokens([torch.ones_like(s["input_ids"]) for s in samples], 0),
376 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0),
377 | 'input_mask_q': collate_tokens([torch.ones_like(s["input_ids_q"]) for s in samples], 0),
378 | 'para_embed': torch.from_numpy(para_embeds),
379 | "top5000_labels": torch.LongTensor(topk5000_labels)
380 | }
381 | return {
382 | 'id': [s['qid'] for s in samples],
383 | "q": [s['q'] for s in samples],
384 | 'wp_tokens': [s['doc_subtoks'] for s in samples],
385 | 'para_offset': [s['para_offset'] for s in samples],
386 | "true_answers": [s['true_answers'] for s in samples],
387 | 'net_input': net_input,
388 | }
389 |
390 | else:
391 | net_input = {
392 | 'input_ids': input_ids,
393 | 'segment_ids': collate_tokens(
394 | [s['segment_ids'] for s in samples], 0),
395 | 'paragraph_mask': collate_tokens(
396 | [s['paragraph_mask'] for s in samples], 0,),
397 | 'input_mask': collate_tokens([torch.ones_like(s["input_ids"]) for s in samples], 0),
398 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0),
399 | 'input_mask_q': collate_tokens([torch.ones_like(s["input_ids_q"]) for s in samples], 0),
400 | 'para_embed': torch.from_numpy(para_embeds)
401 | }
402 |
403 | return {
404 | 'id': [s['qid'] for s in samples],
405 | "q": [s['q'] for s in samples],
406 | 'doc_tokens': [s['doc_toks'] for s in samples],
407 | 'wp_tokens': [s['doc_subtoks'] for s in samples],
408 | 'tok_to_orig_index': [s['tok_to_orig_index'] for s in samples],
409 | 'para_offset': [s['para_offset'] for s in samples],
410 | "true_answers": [s['true_answers'] for s in samples],
411 | 'net_input': net_input,
412 | }
413 |
414 |
415 |
416 | if __name__ == "__main__":
417 | index_path = "retrieval/index_data/para_embed_3_28_c10000.npy"
418 | raw_data = "../data/nq-train.txt"
419 |
420 |
421 | from transformers import BertConfig, BertTokenizer
422 | from retrieval.retriever import BertForRetriever
423 | from config import get_args
424 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
425 | bert_config = BertConfig.from_pretrained('bert-base-uncased')
426 | args = get_args()
427 | retriever = BertForRetriever(bert_config, args)
428 |
429 | from utils import load_saved
430 | retriever_path = "retrieval/logs/splits_3_28_c10000-seed42-bsz640-fp16True-retrieve-from94_c1000_continue_from_failed-lr1e-05-bert-base-uncased-filterTrue/checkpoint_best.pt"
431 | retriever = load_saved(retriever, retriever_path)
432 | retriever.cuda()
433 |
434 | sampler = OnlineSampler(index_path, raw_data, tokenizer, args.max_query_length, args.max_seq_length)
435 |
436 | sampler.shuffle()
437 | retriever.eval()
438 | for batch in sampler.load(retriever):
439 |         if batch:  # load() yields {} when no retrieved paragraph covers the answer
440 | print(batch.keys())
441 | print(batch["net_input"]["para_targets"])
442 | import pdb; pdb.set_trace()
443 |
444 |
--------------------------------------------------------------------------------
/qa/prepro_dense.py:
--------------------------------------------------------------------------------
1 | from prepro_utils import hash_question, normalize, find_ans_span_with_char_offsets, prepare
2 | import json
3 | from utils import DocDB
4 | from official_eval import normalize_answer
5 | import numpy as np
6 | from tqdm import tqdm
7 | from basic_tokenizer import RegexpTokenizer, SimpleTokenizer
8 |
9 | from multiprocessing import Pool as ProcessPool
10 | from multiprocessing.util import Finalize
11 | from functools import partial
12 | import re
13 |
14 | import sys
15 | from transformers import BertTokenizer
16 |
17 | PROCESS_TOK = None
18 | PROCESS_DB = None
19 | BERT_TOK = None
20 |
21 | def init():
22 | global PROCESS_TOK, PROCESS_DB, BERT_TOK
23 | PROCESS_TOK = SimpleTokenizer()
24 | BERT_TOK = BertTokenizer.from_pretrained("bert-base-uncased")
25 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
26 | PROCESS_DB = DocDB('../data/nq_paras.db')
27 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
28 |
29 |
30 | def regex_match(text, pattern):
31 | """return all spans that match the pattern"""
32 | try:
33 | pattern = re.compile(
34 | pattern,
35 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
36 | )
37 | except BaseException:
38 | print('Regular expression failed to compile: %s' % pattern)
39 | return []
40 |
41 | matched = [x.group() for x in re.finditer(pattern, text)]
42 | return list(set(matched))
43 |
44 | def para_has_answer(p, answer, tokenizer):
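   |     """Return (True, matched_surface_string) if any answer alias occurs as a
   |     token subsequence of paragraph p (uncased match), else (False, "")."""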
45 | tokens = tokenizer.tokenize(p)
46 | text = tokens.words(uncased=True)
47 | matched = []
48 | for single_answer in answer:
49 | single_answer = normalize(single_answer)
50 | single_answer = tokenizer.tokenize(single_answer)
51 | single_answer = single_answer.words(uncased=True)
52 | for i in range(0, len(text) - len(single_answer) + 1):
53 | if single_answer == text[i: i + len(single_answer)]:
54 | return True, tokens.slice(i, i + len(single_answer)).untokenize()
55 | return False, ""
56 |
57 | def match_answer_span(p, answer, tokenizer, match="string"):
58 | # p has been normalized
59 | if match == 'string':
60 | tokens = tokenizer.tokenize(p)
61 | text = tokens.words(uncased=True)
62 | matched = set()
63 | for single_answer in answer:
64 | single_answer = normalize(single_answer)
65 | single_answer = tokenizer.tokenize(single_answer)
66 | single_answer = single_answer.words(uncased=True)
67 | for i in range(0, len(text) - len(single_answer) + 1):
68 | if single_answer == text[i: i + len(single_answer)]:
69 | matched.add(tokens.slice(i, i + len(single_answer)).untokenize())
70 | return list(matched)
71 | elif match == 'regex':
72 | # Answer is a regex
73 | single_answer = normalize(answer[0])
74 | return regex_match(p, single_answer)
75 |
76 | def process_qa_para(qa_with_result, k=10000, match="string"):
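   |     """For one QA pair, check which of the top-k retrieved paragraphs contain an
   |     answer (string or regex match) and map each covering para_id to its matched
   |     answer strings in qa["matched_paras"]."""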
77 | global PROCESS_DB, PROCESS_TOK
78 | qa, result = qa_with_result
79 | matched_paras = {}
80 | for para_id in result["para_id"][:k]:
81 | p = PROCESS_DB.get_doc_text(para_id)
82 | p = normalize(p)
83 | if match == "string":
84 | covered, matched = para_has_answer(p, qa["answer"], PROCESS_TOK)
85 | elif match == "regex":
86 | single_answer = normalize(qa["answer"][0])
87 | matched = regex_match(p, single_answer)
88 | covered = len(matched) > 0
89 | if covered:
90 | matched_paras[para_id] = matched
91 | qa["matched_paras"] = matched_paras
92 | return qa
93 |
94 | def find_span(example):
95 | global PROCESS_DB, BERT_TOK
96 | annotated = {}
97 | for para_id, matched in example["matched_paras"].items():
98 | p = normalize(PROCESS_DB.get_doc_text(para_id))
99 | ans_starts, ans_ends, ans_texts = [], [], []
100 | doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(
101 | p, BERT_TOK)
102 | char_starts = [i for i in range(
103 | len(p)) if p.startswith(matched, i)]
104 | assert len(char_starts) > 0
105 | char_ends = [start + len(matched) - 1 for start in char_starts]
106 | answer = {"text": matched, "char_spans": list(
107 | zip(char_starts, char_ends))}
108 | ans_spans = find_ans_span_with_char_offsets(
109 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, BERT_TOK)
110 | for s, e in ans_spans:
111 | ans_starts.append(s)
112 | ans_ends.append(e)
113 | ans_texts.append(matched)
114 | annotated[para_id] = {
115 | # "doc_toks": doc_tokens,
116 | "doc_subtoks": all_doc_tokens,
117 | "starts": ans_starts,
118 | "ends": ans_ends,
119 | "span_texts": [matched],
120 | # "tok_to_orig_index": tok_to_orig_index
121 | }
122 | example["matched_paras"] = annotated
123 | return example
124 |
125 |
126 | def process_ground_paras(retrieved="../data/wq_finetuneq_train_10000.txt", save_path="../data/wq_ft_train_matched.txt", raw_data="../data/wq-train.txt", num_workers=40, debug=False, k=10000, match="string"):
127 | retrieved = [json.loads(l) for l in open(retrieved).readlines()]
128 | raw_data = [json.loads(l) for l in open(raw_data).readlines()]
129 |
130 | tokenizer = SimpleTokenizer()
131 | recall = []
132 | processes = ProcessPool(
133 | processes=num_workers,
134 | initializer=init,
135 | )
136 | process_qa_para_partial = partial(process_qa_para, k=k, match=match)
137 | num_tasks = len(raw_data)
138 | results = []
139 | for _ in tqdm(processes.imap_unordered(process_qa_para_partial, zip(raw_data, retrieved)), total=len(raw_data)):
140 | results.append(_)
141 |
142 | topk_covered = [len(r["matched_paras"])>0 for r in results]
143 | print(np.mean(topk_covered))
144 |
145 | if debug:
146 | return
147 |
148 |     # # annotate the matched paras to accelerate training
149 | # processed = []
150 | # for _ in tqdm(processes.imap_unordered(find_span, results), total=len(results)):
151 | # processed.append(_)
152 |
153 | processes.close()
154 | processes.join()
155 |
156 | with open(save_path, "w") as g:
157 | for _ in results:
158 | g.write(json.dumps(_) + "\n")
159 |
160 |
161 | def debug(retrieved="../data/wq_finetuneq_dev_5000.txt", raw_data="../data/wq-dev.txt", precomputed="../data/wq_ft_dev_matched.txt", k=10):
162 |     # check whether it is reasonable to precompute a paragraph set
163 | retrieved = [json.loads(l) for l in open(retrieved).readlines()]
164 | raw_data = [json.loads(l) for l in open(raw_data).readlines()]
165 |
166 | annotated = [json.loads(l) for l in open(precomputed).readlines()]
167 | qid2goldparas = {hash_question(item["question"]): item["matched_paras"] for item in annotated}
168 |
169 | topk_covered = []
170 | for qa, result in tqdm(zip(raw_data, retrieved), total=len(raw_data)):
171 | qid = hash_question(qa["question"])
172 | covered = 0
173 | for para_id in result["para_id"][:k]:
174 | if para_id in qid2goldparas[qid]:
175 | covered = 1
176 | break
177 | topk_covered.append(covered)
178 | print(np.mean(topk_covered))
179 |
180 |
181 | if __name__ == "__main__":
182 |
183 | # trec
184 | process_ground_paras(retrieved="../data/trec/trec_finetuneq_train-20000.txt", save_path="../data/trec_train_matched_20000.txt", raw_data="../data/trec-train.txt", num_workers=30, k=20000, match="regex")
185 |
186 | # # wq
187 | # process_ground_paras(retrieved="../data/wq_finetuneq_train-combined_15000.txt",
188 | # save_path="../data/wq_ft_train-combined_matched_15000.txt", raw_data="../data/wq-train-combined.txt", num_workers=30, k=15000)
189 |
190 | # nq
191 | #process_ground_paras(retrieved="../data/nq_finetuneq_train_10000.txt",
192 | # save_path="../data/nq_ft_train_matched.txt", raw_data="../data/nq-train.txt", num_workers=40)
193 |
194 |
195 | # # debug
196 | # process_ground_paras(
197 | # retrieved="../data/wq_finetuneq_dev_5000_fi.txt", raw_data="../data/wq-dev.txt", debug=True, k=5)
198 | # process_ground_paras(
199 | # retrieved="../data/wq_finetuneq_dev.txt", raw_data="../data/wq-dev.txt", debug=True, k=5)
200 | # process_ground_paras(
201 | # retrieved="../data/nq_finetuneq_dev_5000_fi.txt", raw_data="../data/nq-dev.txt", debug=True, k=5)
202 | # process_ground_paras(
203 | # retrieved="../data/nq_finetuneq_dev.txt", raw_data="../data/nq-dev.txt", debug=True, k=5)
204 | # #debug(k=30)
205 | # process_ground_paras(retrieved="../data/nq_finetuneq_train_10000.txt",
206 | # save_path="../data/nq_ft_train_matched.txt", raw_data="../data/nq-train.txt", num_workers=40)
207 |
208 |
209 | # # debug
210 | # process_ground_paras(
211 | # retrieved="../data/wq_finetuneq_dev_5000.txt", raw_data="../data/wq-dev.txt", debug=True, k=30)
212 | # debug(k=30)
213 |
214 |
--------------------------------------------------------------------------------
/qa/prepro_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tokenizer import _is_whitespace, _is_punctuation, process, whitespace_tokenize
3 | from transformers import BertTokenizer
4 | from tqdm import tqdm
5 | from multiprocessing import Pool
6 | import hashlib
7 | import unicodedata
8 | import re
9 | import sys
10 | import numpy as np
11 |
12 | def hash_question(q):
13 | hash_object = hashlib.md5(q.encode())
14 | return hash_object.hexdigest()
15 |
16 | def normalize(text):
17 | """Resolve different type of unicode encodings."""
18 | return unicodedata.normalize('NFD', text)
19 |
20 | def load_mrqa_dataset(path):
21 | raw_data = [json.loads(line.strip()) for line in open(path).readlines()[1:]]
22 |
23 | qa_data = []
24 | for item in raw_data:
25 | id_ = item["id"]
26 | context = item["context"]
27 | for qa in item["qas"]:
28 | qid = qa["qid"]
29 | question = qa["question"]
30 | answers = qa.get("answers", [])
31 | matched_answers = qa.get("detected_answers", [])
32 | qa_data.append(
33 | {
34 | "qid": qid,
35 | "question": question,
36 | "context": context,
37 | "matched_answers": matched_answers,
38 | "true_answers": answers
39 | }
40 | )
41 | return qa_data
42 |
43 |
44 | def load_openqa_dataset(path, filter_no_answer=False):
45 |
46 | def _check_no_ans(sample):
47 | no_ans = True
48 | for para in sample["retrieved"]:
49 | if para["matched_answer"] != "":
50 | no_ans = False
51 |                 return no_ans
52 | return no_ans
53 |
54 | raw_data = [json.loads(line.strip()) for line in open(path).readlines()]
55 |
56 | if filter_no_answer:
57 | raw_data = [item for item in raw_data if not _check_no_ans(item)]
58 |
59 | print(f"Loading {len(raw_data)} QA pairs")
60 | return raw_data
61 |
62 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
63 | orig_answer_text):
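   |     """Search within [input_start, input_end] for a sub-span of doc_tokens whose
   |     joined text exactly matches the re-tokenized answer text; fall back to the
   |     original span if no exact match is found."""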
64 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
65 |
66 | for new_start in range(input_start, input_end + 1):
67 | for new_end in range(input_end, new_start - 1, -1):
68 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
69 | if text_span == tok_answer_text:
70 | return (new_start, new_end)
71 |
72 | return (input_start, input_end)
73 |
74 | def find_ans_span_with_char_offsets(detected_ans, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer):
75 |     # may return multiple subtoken spans for a single answer string
76 | ans_text = detected_ans["text"]
77 | char_spans = detected_ans["char_spans"]
78 | ans_subtok_spans = []
79 | for char_start, char_end in char_spans:
80 | tok_start = char_to_word_offset[char_start]
81 | tok_end = char_to_word_offset[char_end] # char_end points to the last char of the answer, not one after
82 | sub_tok_start = orig_to_tok_index[tok_start]
83 |
84 | if tok_end < len(doc_tokens) - 1:
85 | sub_tok_end = orig_to_tok_index[tok_end + 1] - 1
86 | else:
87 | sub_tok_end = len(all_doc_tokens) - 1
88 |
89 | actual_text = " ".join(doc_tokens[tok_start:(tok_end + 1)])
90 | cleaned_answer_text = " ".join(whitespace_tokenize(ans_text))
91 | if actual_text.find(cleaned_answer_text) == -1:
92 | print("Could not find answer: '{}' vs. '{}'".format(
93 | actual_text, cleaned_answer_text))
94 |
95 | (sub_tok_start, sub_tok_end) = _improve_answer_span(
96 | all_doc_tokens, sub_tok_start, sub_tok_end, tokenizer, ans_text)
97 | ans_subtok_spans.append((sub_tok_start, sub_tok_end))
98 |
99 | return ans_subtok_spans
100 |
101 | def tokenize_item(sample, tokenizer):
102 | doc_tokens = []
103 | char_to_word_offset = []
104 | prev_is_whitespace = True
105 | for c in sample["context"]:
106 | if _is_whitespace(c):
107 | prev_is_whitespace = True
108 | else:
109 | if prev_is_whitespace:
110 | doc_tokens.append(c)
111 | else:
112 | doc_tokens[-1] += c
113 | prev_is_whitespace = False
114 | char_to_word_offset.append(len(doc_tokens) - 1)
115 |
116 | orig_to_tok_index = []
117 | tok_to_orig_index = []
118 | all_doc_tokens = []
119 | for (i, token) in enumerate(doc_tokens):
120 | orig_to_tok_index.append(len(all_doc_tokens))
121 | sub_tokens = process(token, tokenizer)
122 | for sub_token in sub_tokens:
123 | tok_to_orig_index.append(i)
124 | all_doc_tokens.append(sub_token)
125 | q_sub_toks = process(sample["question"], tokenizer)
126 |
127 | # finding answer spans
128 | ans_starts, ans_ends, ans_texts = [], [], []
129 | for answer in sample["matched_answers"]:
130 | ans_spans = find_ans_span_with_char_offsets(
131 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer)
132 |
133 | for (s, e) in ans_spans:
134 | ans_starts.append(s)
135 | ans_ends.append(e)
136 | ans_texts.append(answer["text"])
137 |
138 | return {
139 | "q_subtoks": q_sub_toks,
140 | "qid": sample["qid"],
141 | "doc_toks": doc_tokens,
142 | "doc_subtoks": all_doc_tokens,
143 | "tok_to_orig_index": tok_to_orig_index,
144 | "starts": ans_starts,
145 | "ends": ans_ends,
146 | "span_texts": ans_texts,
147 | "true_answers": sample["true_answers"]
148 | }
149 |
150 | def prepare(context, tokenizer):
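    |     """Whitespace-tokenize `context` and WordPiece-tokenize each word. Returns
    |     doc_tokens (whitespace tokens), char_to_word_offset (char index -> word index),
    |     orig_to_tok_index (word index -> first subtoken index), tok_to_orig_index
    |     (subtoken index -> word index) and all_doc_tokens (all subtokens)."""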
151 | doc_tokens = []
152 | char_to_word_offset = []
153 | prev_is_whitespace = True
154 |
155 | for c in context:
156 | if _is_whitespace(c):
157 | prev_is_whitespace = True
158 | else:
159 | if prev_is_whitespace:
160 | doc_tokens.append(c)
161 | else:
162 | doc_tokens[-1] += c
163 | prev_is_whitespace = False
164 | char_to_word_offset.append(len(doc_tokens) - 1)
165 |
166 | orig_to_tok_index = []
167 | tok_to_orig_index = []
168 | all_doc_tokens = []
169 | for (i, token) in enumerate(doc_tokens):
170 | orig_to_tok_index.append(len(all_doc_tokens))
171 | sub_tokens = tokenizer.tokenize(token)
172 | for sub_token in sub_tokens:
173 | tok_to_orig_index.append(i)
174 | all_doc_tokens.append(sub_token)
175 | return doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens
176 |
177 | def tokenize_item_openqa(sample, tokenizer):
178 | """
179 | process all the retrieved paragraphs of a QA pair
180 | """
181 | q_sub_toks = process(sample["question"], tokenizer)
182 |
183 | examples = []
184 | for para_idx, para in enumerate(sample["retrieved"]):
185 | doc_tokens = []
186 | char_to_word_offset = []
187 | prev_is_whitespace = True
188 | context = normalize(para["para"])
189 |
190 | for c in context:
191 | if _is_whitespace(c):
192 | prev_is_whitespace = True
193 | else:
194 | if prev_is_whitespace:
195 | doc_tokens.append(c)
196 | else:
197 | doc_tokens[-1] += c
198 | prev_is_whitespace = False
199 | char_to_word_offset.append(len(doc_tokens) - 1)
200 |
201 | orig_to_tok_index = []
202 | tok_to_orig_index = []
203 | all_doc_tokens = []
204 | for (i, token) in enumerate(doc_tokens):
205 | orig_to_tok_index.append(len(all_doc_tokens))
206 | sub_tokens = process(token, tokenizer)
207 | for sub_token in sub_tokens:
208 | tok_to_orig_index.append(i)
209 | all_doc_tokens.append(sub_token)
210 |
211 | # finding answer spans
212 | ans_starts, ans_ends, ans_texts = [], [], []
213 | no_answer = 0
214 | if para["matched_answer"] == "":
215 | ans_starts.append(-1)
216 | ans_ends.append(-1)
217 | ans_texts.append("")
218 | no_answer = 1
219 | else:
220 | ans_texts.append(para["matched_answer"])
221 | char_starts = [i for i in range(
222 | len(context)) if context.startswith(para["matched_answer"], i)]
223 |
224 | if len(char_starts) == 0:
225 | import pdb; pdb.set_trace()
226 | char_ends = [start + len(para["matched_answer"]) - 1 for start in char_starts]
227 | answer = {"text": para["matched_answer"], "char_spans": list(zip(char_starts, char_ends))}
228 | ans_spans = find_ans_span_with_char_offsets(
229 | answer, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer)
230 | for (s, e) in ans_spans:
231 | ans_starts.append(s)
232 | ans_ends.append(e)
233 | ans_texts.append(answer["text"])
234 | qid = hash_question(sample["question"])
235 |
236 | examples.append({
237 | "q": sample["question"],
238 | "q_subtoks": q_sub_toks,
239 | "qid": qid,
240 | "para_id": para_idx,
241 | "doc_toks": doc_tokens,
242 | "doc_subtoks": all_doc_tokens,
243 | "tok_to_orig_index": tok_to_orig_index,
244 | "starts": ans_starts,
245 | "ends": ans_ends,
246 | "span_texts": ans_texts,
247 | "true_answers": sample["gold_answer"],
248 | "no_answer": no_answer,
249 | "bm25": para["bm25"],
250 | })
251 |
252 | return examples
253 |
254 | def tokenize_items(items, tokenizer, verbose=False, openqa=False):
255 | if verbose:
256 | items = tqdm(items)
257 | if openqa:
258 | results = []
259 | for _ in items:
260 | results.extend(tokenize_item_openqa(_, tokenizer))
261 | return results
262 | else:
263 | return [tokenize_item(_, tokenizer) for _ in items]
264 |
265 | def tokenize_data(dataset, bert_model_name="bert-large-cased-whole-word-masking", num_workers=10, save_path=None, openqa=False):
266 |
267 | tokenizer = BertTokenizer.from_pretrained(bert_model_name)
268 |
269 | chunk_size = len(dataset) // num_workers
270 | offsets = [
271 | _ * chunk_size for _ in range(0, num_workers)] + [len(dataset)]
272 | pool = Pool(processes=num_workers)
273 | print(f'Start multi-processing with {num_workers} workers....')
274 | results = [pool.apply_async(tokenize_items, args=(
275 | dataset[offsets[work_id]: offsets[work_id + 1]], tokenizer, True, openqa)) for work_id in range(num_workers)]
276 | outputs = [p.get() for p in results]
277 | samples = []
278 | for o in outputs:
279 | samples.extend(o)
280 |
281 | # check the average number of matched spans
282 | answer_nums = [len(item["starts"])
283 | for item in samples if item["no_answer"] == 0]
284 | print(f"Average number of matched answers: {np.mean(answer_nums)}...")
285 | print(f"Processed {len(samples)} examples...")
286 | if save_path:
287 | with open(save_path, 'w') as f:
288 | for s in samples:
289 | f.write(json.dumps(s) + "\n")
290 | else:
291 | return samples
292 |
293 | if __name__ == "__main__":
294 | import argparse
295 | parser = argparse.ArgumentParser()
296 | parser.add_argument(
297 | "--model-name", default="bert-base-uncased", type=str)
298 | parser.add_argument("--data", default="wq", type=str)
299 | parser.add_argument("--split", default="train", type=str)
300 | parser.add_argument("--topk", default=20, type=int)
301 | parser.add_argument("--filter", action="store_true", help="whether to filter no-answer QA pair")
302 | parser.add_argument("--dense-index", action="store_true")
303 | args = parser.parse_args()
304 |
305 | filter_ = True if "train" in args.split else False
306 |
307 | if args.dense_index:
308 | train_raw = load_openqa_dataset(
309 | f"../data/{args.data}/{args.data}-{args.split}-dense-final.txt", filter_no_answer=filter_)
310 | save_path = f"../data/{args.data}/{args.data}-{args.split}-dense-filtered-tokenized.txt" if filter_ else \
311 | f"../data/{args.data}/{args.data}-{args.split}-dense-tokenized.txt"
312 | else:
313 | train_raw = load_openqa_dataset(
314 | f"../data/{args.data}/{args.data}-{args.split}-openqa-p{args.topk}.txt", filter_no_answer=filter_)
315 | save_path = f"../data/{args.data}/{args.data}-{args.split}-openqa-filtered-tokenized-p{args.topk}-all-matched.txt" if filter_ else \
316 | f"../data/{args.data}/{args.data}-{args.split}-openqa-tokenized-p{args.topk}-all-matched.txt"
317 |
318 | train_tokenized = tokenize_data(train_raw, bert_model_name=args.model_name, save_path=save_path, openqa=True, num_workers=10)
319 |
--------------------------------------------------------------------------------
/qa/tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import six
24 | import tensorflow as tf
25 |
26 |
27 | def convert_tokens_to_ids(vocab, tokens):
28 | """Converts a sequence of tokens into ids using the vocab."""
29 | ids = []
30 | for token in tokens:
31 | ids.append(vocab[token])
32 | return ids
33 |
34 | def whitespace_tokenize(text):
35 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
36 | text = text.strip()
37 | if not text:
38 | return []
39 | tokens = text.split()
40 | return tokens
41 |
42 |
43 | def convert_to_unicode(text):
44 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
45 | if six.PY3:
46 | if isinstance(text, str):
47 | return text
48 | elif isinstance(text, bytes):
49 | return text.decode("utf-8", "ignore")
50 | else:
51 | raise ValueError("Unsupported string type: %s" % (type(text)))
52 | elif six.PY2:
53 | if isinstance(text, str):
54 | return text.decode("utf-8", "ignore")
55 | elif isinstance(text, unicode):
56 | return text
57 | else:
58 | raise ValueError("Unsupported string type: %s" % (type(text)))
59 | else:
60 | raise ValueError("Not running on Python2 or Python 3?")
61 |
62 |
63 | def _is_whitespace(char):
64 | """Checks whether `chars` is a whitespace character."""
65 |     # \t, \n, and \r are technically control characters but we treat them
66 | # as whitespace since they are generally considered as such.
67 | if char == " " or char == "\t" or char == "\n" or char == "\r":
68 | return True
69 | cat = unicodedata.category(char)
70 | if cat == "Zs":
71 | return True
72 | return False
73 |
74 |
75 | def _is_control(char):
76 | """Checks whether `chars` is a control character."""
77 | # These are technically control characters but we count them as whitespace
78 | # characters.
79 | if char == "\t" or char == "\n" or char == "\r":
80 | return False
81 | cat = unicodedata.category(char)
82 | if cat.startswith("C"):
83 | return True
84 | return False
85 |
86 | class BasicTokenizer(object):
87 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
88 |
89 | def __init__(self, do_lower_case=True):
90 | """Constructs a BasicTokenizer.
91 | Args:
92 | do_lower_case: Whether to lower case the input.
93 | """
94 | self.do_lower_case = do_lower_case
95 |
96 | def tokenize(self, text):
97 | """Tokenizes a piece of text."""
98 | text = convert_to_unicode(text)
99 | text = self._clean_text(text)
100 | orig_tokens = whitespace_tokenize(text)
101 | split_tokens = []
102 | for token in orig_tokens:
103 | if self.do_lower_case:
104 | token = token.lower()
105 | token = self._run_strip_accents(token)
106 | split_tokens.extend(self._run_split_on_punc(token))
107 |
108 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
109 | return output_tokens
110 |
111 | def _run_strip_accents(self, text):
112 | """Strips accents from a piece of text."""
113 | text = unicodedata.normalize("NFD", text)
114 | output = []
115 | for char in text:
116 | cat = unicodedata.category(char)
117 | if cat == "Mn":
118 | continue
119 | output.append(char)
120 | return "".join(output)
121 |
122 | def _run_split_on_punc(self, text):
123 | """Splits punctuation on a piece of text."""
124 | chars = list(text)
125 | i = 0
126 | start_new_word = True
127 | output = []
128 | while i < len(chars):
129 | char = chars[i]
130 | if _is_punctuation(char):
131 | output.append([char])
132 | start_new_word = True
133 | else:
134 | if start_new_word:
135 | output.append([])
136 | start_new_word = False
137 | output[-1].append(char)
138 | i += 1
139 |
140 | return ["".join(x) for x in output]
141 |
142 | def _clean_text(self, text):
143 | """Performs invalid character removal and whitespace cleanup on text."""
144 | output = []
145 | for char in text:
146 | cp = ord(char)
147 | if cp == 0 or cp == 0xfffd or _is_control(char):
148 | continue
149 | if _is_whitespace(char):
150 | output.append(" ")
151 | else:
152 | output.append(char)
153 | return "".join(output)
154 |
155 |
156 | def _is_punctuation(char):
157 | """Checks whether `chars` is a punctuation character."""
158 | cp = ord(char)
159 | # We treat all non-letter/number ASCII as punctuation.
160 | # Characters such as "^", "$", and "`" are not in the Unicode
161 | # Punctuation class but we treat them as punctuation anyways, for
162 | # consistency.
163 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
164 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
165 | return True
166 | cat = unicodedata.category(char)
167 | if cat.startswith("P"):
168 | return True
169 | return False
170 |
171 |
172 | def process(s, tokenizer):
173 | try:
174 | return tokenizer.tokenize(s)
175 | except:
176 | print('failed on', s)
177 | raise
178 |
179 | if __name__ == "__main__":
180 | _is_whitespace("a")
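    |     # minimal usage sketch: BasicTokenizer lower-cases, strips accents and splits punctuation
    |     print(BasicTokenizer(do_lower_case=True).tokenize("Héllo, World!"))  # ['hello', ',', 'world', '!']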
181 |
--------------------------------------------------------------------------------
/qa/train.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import json
4 | import os
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.distributed import DistributedSampler
11 | from datasets import QADataset, collate
12 | from bert_qa import BertForQuestionAnswering
13 | from transformers import AdamW, BertConfig, BertTokenizer
14 | from torch.utils.tensorboard import SummaryWriter
15 | from eval_utils import get_final_text
16 | from official_eval import metric_max_over_ground_truths, f1_score, exact_match_score
17 |
18 | from utils import move_to_cuda, convert_to_half, AverageMeter
19 | from config import get_args
20 |
21 | def main():
22 | args = get_args()
23 |
24 | if args.fp16:
25 | try:
26 | import apex
27 | apex.amp.register_half_function(torch, 'einsum')
28 | except ImportError:
29 | raise ImportError(
30 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
31 |
32 | # tb logger
33 | data_name = args.train_file.split("/")[-1].split('-')[0]
34 | model_name = f"{data_name}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-{args.prefix}-lr{args.learning_rate}-{args.bert_model_name}"
35 | tb_logger = SummaryWriter(os.path.join(args.output_dir, "tflogs", model_name))
36 | args.output_dir = os.path.join(args.output_dir, model_name)
37 |
38 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
39 | print(f"output directory {args.output_dir} already exists and is not empty.")
40 | if not os.path.exists(args.output_dir):
41 | os.makedirs(args.output_dir, exist_ok=True)
42 |
43 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
44 | datefmt='%m/%d/%Y %H:%M:%S',
45 | level=logging.INFO,
46 | handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")),
47 | logging.StreamHandler()])
48 | logger = logging.getLogger(__name__)
49 | logger.info(args)
50 |
51 | if args.local_rank == -1 or args.no_cuda:
52 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
53 | n_gpu = torch.cuda.device_count()
54 | else:
55 | device = torch.device("cuda", args.local_rank)
56 | n_gpu = 1
57 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
58 | torch.distributed.init_process_group(backend='nccl')
59 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
60 |
61 | if args.accumulate_gradients < 1:
62 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
63 | args.accumulate_gradients))
64 |
65 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
66 | random.seed(args.seed)
67 | np.random.seed(args.seed)
68 | torch.manual_seed(args.seed)
69 | if n_gpu > 0:
70 | torch.cuda.manual_seed_all(args.seed)
71 |
72 | if not args.do_train and not args.do_predict:
73 | raise ValueError("At least one of `do_train` or `do_predict` must be True.")
74 |
75 | if args.do_train:
76 | if not args.train_file:
77 | raise ValueError(
78 | "If `do_train` is True, then `train_file` must be specified.")
79 | if not args.predict_file:
80 | raise ValueError(
81 | "If `do_train` is True, then `predict_file` must be specified.")
82 |
83 | if args.do_predict:
84 | if not args.predict_file:
85 | raise ValueError(
86 | "If `do_predict` is True, then `predict_file` must be specified.")
87 |
88 | bert_config = BertConfig.from_pretrained(args.bert_model_name)
89 | model = BertForQuestionAnswering(bert_config)
90 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
91 |
92 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
93 | raise ValueError(
94 | "Cannot use sequence length %d because the BERT model "
95 | "was only trained up to sequence length %d" %
96 | (args.max_seq_length, bert_config.max_position_embeddings))
97 |
98 | eval_dataset = QADataset(
99 | tokenizer, args.predict_file, args.max_query_length, args.max_seq_length)
100 | eval_dataloader = DataLoader(eval_dataset, batch_size=args.predict_batch_size, collate_fn=collate, pin_memory=True)
101 | logger.info(f"Num of dev batches: {len(eval_dataloader)}")
102 |
103 | if args.init_checkpoint is not None:
104 | logger.info("Loading from {}".format(args.init_checkpoint))
105 | if args.do_train and args.init_checkpoint == "":
106 | model = BertForQuestionAnswering.from_pretrained(
107 | args.bert_model_name)
108 | else:
109 | state_dict = torch.load(args.init_checkpoint)
110 | filter = lambda x: x[7:] if x.startswith('module.') else x
111 | state_dict = {filter(k):v for (k,v) in state_dict.items()}
112 | model.load_state_dict(state_dict)
113 | model.to(device)
114 |
115 | print(f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
116 |
117 | if args.do_train:
118 | no_decay = ['bias', 'LayerNorm.weight']
119 | optimizer_parameters = [
120 | {'params': [p for n, p in model.named_parameters() if not any(
121 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
122 | {'params': [p for n, p in model.named_parameters() if any(
123 | nd in n for nd in no_decay)], 'weight_decay': 0.0}
124 | ]
125 | optimizer = AdamW(optimizer_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
126 |
127 | if args.fp16:
128 | try:
129 | from apex import amp
130 | except ImportError:
131 | raise ImportError(
132 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
133 | model, optimizer = amp.initialize(
134 | model, optimizer, opt_level=args.fp16_opt_level)
135 | else:
136 | if args.fp16:
137 | try:
138 | from apex import amp
139 | except ImportError:
140 | raise ImportError(
141 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
142 | model = amp.initialize(model, opt_level=args.fp16_opt_level)
143 |
144 | if args.local_rank != -1:
145 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
146 | output_device=args.local_rank)
147 | elif n_gpu > 1:
148 | model = torch.nn.DataParallel(model)
149 |
150 | if args.do_train:
151 | global_step = 0
152 | best_f1 = (-1, -1)
153 | wait_step = 0
154 | stop_training = False
155 | train_loss_meter = AverageMeter()
156 | logger.info('Start training....')
157 | model.train()
158 | train_dataset = QADataset(tokenizer, args.train_file, args.max_query_length, args.max_seq_length)
159 | train_dataloader = DataLoader(
160 | train_dataset, batch_size=args.train_batch_size, collate_fn=collate, shuffle=True, pin_memory=True)
161 |
162 | for epoch in range(int(args.num_train_epochs)):
163 |
164 | for step, batch in enumerate(tqdm(train_dataloader)):
165 | batch = move_to_cuda(batch)
166 | outputs = model(batch["net_input"])
167 | loss = outputs["span_loss"]
168 |
169 | if n_gpu > 1:
170 | loss = loss.mean() # mean() to average on multi-gpu.
171 |
172 | if args.gradient_accumulation_steps > 1:
173 | loss = loss / args.gradient_accumulation_steps
174 |
175 | if args.fp16:
176 | with amp.scale_loss(loss, optimizer) as scaled_loss:
177 | scaled_loss.backward()
178 | else:
179 | loss.backward()
180 |
181 | train_loss_meter.update(loss.item())
182 | tb_logger.add_scalar('batch_train_loss', loss.item(), global_step)
183 |
184 | if (step + 1) % args.gradient_accumulation_steps == 0:
185 | if args.fp16:
186 | torch.nn.utils.clip_grad_norm_(
187 | amp.master_params(optimizer), args.max_grad_norm)
188 | else:
189 | torch.nn.utils.clip_grad_norm_(
190 | model.parameters(), args.max_grad_norm)
191 |                     optimizer.step() # We have accumulated enough gradients
192 | model.zero_grad()
193 | global_step += 1
194 |
195 | if global_step % args.eval_period == 0:
196 | f1 = predict(logger, args, model, eval_dataloader, device, fp16=args.efficient_eval)
197 | logger.info("Step %d Train loss %.2f EM %.2f F1 %.2f on epoch=%d" % (
198 |                             global_step, train_loss_meter.avg, f1[1]*100, f1[0]*100, epoch))
199 |
200 | tb_logger.add_scalar('dev_f1', f1[0]*100, global_step)
201 | tb_logger.add_scalar('dev_em', f1[1]*100, global_step)
202 |
203 | if best_f1 < f1:
204 | logger.info("Saving model with best EM: %.2f (F1 %.2f) -> %.2f (F1 %.2f) on epoch=%d" % \
205 | (best_f1[1]*100, best_f1[0]*100, f1[1]*100, f1[0]*100, epoch))
206 | model_state_dict = {k:v.cpu() for (k, v) in model.state_dict().items()}
207 | torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt"))
208 | model = model.to(device)
209 | best_f1 = f1
210 | wait_step = 0
211 | stop_training = False
212 | else:
213 | wait_step += 1
214 | if wait_step == args.wait_step:
215 | stop_training = True
216 |
217 | f1 = predict(logger, args, model, eval_dataloader,
218 | device, fp16=args.efficient_eval)
219 | logger.info("Step %d Train loss %.2f EM %.2f F1 %.2f on epoch=%d" % (
220 |                     global_step, train_loss_meter.avg, f1[1]*100, f1[0]*100, epoch))
221 | tb_logger.add_scalar('dev_f1', f1[0]*100, global_step)
222 | tb_logger.add_scalar('dev_em', f1[1]*100, global_step)
223 | logger.info(f"average training loss {train_loss_meter.avg}")
224 |
225 |
226 | if stop_training:
227 | break
228 |
229 | logger.info("Training finished!")
230 |
231 | # elif args.do_predict:
232 | # if type(model)==list:
233 | # model = [m.eval() for m in model]
234 | # else:
235 | # model.eval()
236 | # f1 = predict(logger, args, model, eval_dataloader, eval_examples, eval_features,
237 | # device, fp16=args.efficient_eval, write_prediction=False)
238 | # logger.info(f"test performance {f1}")
239 | # print(f1)
240 |
241 |
242 | def predict(logger, args, model, eval_dataloader, device, fp16=False):
243 | model.eval()
244 | all_results = []
245 |
246 | if fp16:
247 | model.half()
248 |
249 | qid2results = {}
250 | for batch in tqdm(eval_dataloader):
251 | batch_to_feed = move_to_cuda(batch["net_input"])
252 | if fp16:
253 | batch_to_feed = convert_to_half(batch_to_feed)
254 | with torch.no_grad():
255 | results = model(batch_to_feed)
256 | batch_start_logits = results["start_logits"]
257 | batch_end_logits = results["end_logits"]
258 | question_mask = batch_to_feed["paragraph_mask"].ne(1)
259 | outs = [o.float().masked_fill(question_mask, -1e10).type_as(o)
260 | for o in [batch_start_logits, batch_end_logits]]
261 |
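    |         # span_scores[b, i, j] = start_logit[b, i] + end_logit[b, j]; the triu/tril
    |         # band mask below keeps only spans with start <= end and end - start <= max_answer_lens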
262 | span_scores = outs[0][:,:,None] + outs[1][:,None]
263 | max_answer_lens = 20
264 | max_seq_len = span_scores.size(1)
265 | span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), max_answer_lens)
266 | span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
267 | span_scores_masked = span_scores.float().masked_fill((1 -
268 | span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)
269 |
270 | start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
271 | end_position = span_scores_masked.max(dim=2)[1].gather(1, start_position.unsqueeze(1)).squeeze(1)
272 |
273 | para_offset = batch['para_offset']
274 | start_position_ = list(np.array(start_position.tolist()) - np.array(para_offset))
275 | end_position_ = list(np.array(end_position.tolist()) - np.array(para_offset))
276 |
277 | for idx, qid in enumerate(batch['id']):
278 | start = start_position_[idx]
279 | end = end_position_[idx]
280 | tok_to_orig_index = batch['tok_to_orig_index'][idx]
281 | doc_tokens = batch['doc_tokens'][idx]
282 | wp_tokens = batch['wp_tokens'][idx]
283 | orig_doc_start = tok_to_orig_index[start]
284 | orig_doc_end = tok_to_orig_index[end]
285 | orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
286 | tok_tokens = wp_tokens[start:end+1]
287 | tok_text = " ".join(tok_tokens)
288 | tok_text = tok_text.replace(" ##", "")
289 | tok_text = tok_text.replace("##", "")
290 | tok_text = tok_text.strip()
291 | tok_text = " ".join(tok_text.split())
292 | orig_text = " ".join(orig_tokens)
293 | final_text = get_final_text(tok_text, orig_text, logger, do_lower_case=args.do_lower_case, verbose_logging=False)
294 | qid2results[qid] = [final_text, batch['true_answers'][idx]]
295 |
296 | f1s = [metric_max_over_ground_truths(f1_score, item[0], item[1]) for item in qid2results.values()]
297 | ems = [metric_max_over_ground_truths(exact_match_score, item[0], item[1]) for item in qid2results.values()]
298 |
299 | print(f"evaluated {len(f1s)} examples...")
300 | if fp16:
301 | model.float()
302 | model.train()
303 |
304 | return (np.mean(f1s), np.mean(ems))
305 |
306 |
307 | if __name__ == "__main__":
308 | main()
309 |
--------------------------------------------------------------------------------
/qa/train_dense_qa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=3 python train_retrieve_qa.py \
5 | --do_train \
6 | --prefix dense-index-trec-nocluser-70k \
7 | --eval_period -1 \
8 | --bert_model_name bert-base-uncased \
9 | --train_batch_size 5 \
10 | --gradient_accumulation_steps 1 \
11 | --accumulate_gradients 1 \
12 | --efficient_eval \
13 | --learning_rate 1e-5 \
14 | --fp16 \
15 | --raw-train-data ../data/trec-train.txt \
16 | --raw-eval-data ../data/trec-dev.txt \
17 | --seed 3 \
18 | --retriever-path ../retrieval/logs/retrieve_train.txt-seed31-bsz640-fp16True-baseline_no_cluster_from_failed_continue-lr1e-05-bert-base-uncased-filterTrue/checkpoint_40000.pt \
19 | --index-path ../retrieval/encodings/para_embed.npy \
20 | --fix-para-encoder \
21 | --num_train_epochs 10 \
22 | --matched-para-path ../data/trec_train_matched_20000.txt \
23 | --regex \
24 | --shared-norm \
25 | # --separate \
26 |
--------------------------------------------------------------------------------
/qa/train_retrieve_qa.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import json
4 | import os
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | from copy import deepcopy
10 |
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data.distributed import DistributedSampler
13 | from bert_retrieve_qa import BertRetrieveQA
14 | from transformers import AdamW, BertConfig, BertTokenizer
15 | from torch.utils.tensorboard import SummaryWriter
16 | from eval_utils import get_final_text
17 | from official_eval import metric_max_over_ground_truths, exact_match_score, regex_match_score
18 | from online_sampler import OnlineSampler
19 |
20 |
21 | from utils import move_to_cuda, convert_to_half, AverageMeter, DocDB
22 | from config import get_args
23 |
24 | from collections import defaultdict, namedtuple
25 | import torch.nn.functional as F
26 |
27 | def load_saved(model, path):
28 | state_dict = torch.load(path)
29 | def filter(x): return x[7:] if x.startswith('module.') else x
30 | state_dict = {filter(k): v for (k, v) in state_dict.items()}
31 | model.load_state_dict(state_dict)
32 | return model
33 |
34 |
35 | def main():
36 | args = get_args()
37 |
38 | if args.fp16:
39 | try:
40 | import apex
41 | apex.amp.register_half_function(torch, 'einsum')
42 | except ImportError:
43 | raise ImportError(
44 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
45 |
46 | # tb logger
47 | data_name = args.train_file.split("/")[-1].split('-')[0]
48 | model_name = f"dense-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-{args.prefix}-lr{args.learning_rate}-{args.bert_model_name}-qdrop{args.qa_drop}-sn{args.shared_norm}-sep{args.separate}-as{args.add_select}-noearly{args.drop_early}"
49 | if args.do_train:
50 | tb_logger = SummaryWriter(os.path.join(
51 | args.output_dir, "tflogs", "dense", model_name))
52 | args.output_dir = os.path.join(args.output_dir, model_name)
53 |
54 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
55 | print(
56 | f"output directory {args.output_dir} already exists and is not empty.")
57 | if not os.path.exists(args.output_dir):
58 | os.makedirs(args.output_dir, exist_ok=True)
59 |
60 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
61 | datefmt='%m/%d/%Y %H:%M:%S',
62 | level=logging.INFO,
63 | handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")),
64 | logging.StreamHandler()])
65 | logger = logging.getLogger(__name__)
66 | logger.info(args)
67 |
68 | if args.local_rank == -1 or args.no_cuda:
69 | device = torch.device(
70 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
71 | n_gpu = torch.cuda.device_count()
72 | else:
73 | device = torch.device("cuda", args.local_rank)
74 | n_gpu = 1
75 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
76 | torch.distributed.init_process_group(backend='nccl')
77 | logger.info("device %s n_gpu %d distributed training %r",
78 | device, n_gpu, bool(args.local_rank != -1))
79 |
80 | if args.accumulate_gradients < 1:
81 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
82 | args.accumulate_gradients))
83 |
84 | args.train_batch_size = int(
85 | args.train_batch_size / args.accumulate_gradients)
86 | random.seed(args.seed)
87 | np.random.seed(args.seed)
88 | torch.manual_seed(args.seed)
89 | if n_gpu > 0:
90 | torch.cuda.manual_seed_all(args.seed)
91 |
92 | if not args.do_train and not args.do_predict:
93 | raise ValueError(
94 | "At least one of `do_train` or `do_predict` must be True.")
95 |
96 | if args.do_train:
97 | if not args.train_file:
98 | raise ValueError(
99 | "If `do_train` is True, then `train_file` must be specified.")
100 | if not args.predict_file:
101 | raise ValueError(
102 | "If `do_train` is True, then `predict_file` must be specified.")
103 |
104 | if args.do_predict:
105 | if not args.predict_file:
106 | raise ValueError(
107 | "If `do_predict` is True, then `predict_file` must be specified.")
108 |
109 | bert_config = BertConfig.from_pretrained(args.bert_model_name)
110 | model = BertRetrieveQA(bert_config, args)
111 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
112 |
113 | logger.info("Loading para db and pretrained index ...")
114 | para_db = DocDB(args.db_path)
115 | para_embed = np.load(args.index_path).astype('float32')
116 |
117 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
118 | raise ValueError(
119 | "Cannot use sequence length %d because the BERT model "
120 | "was only trained up to sequence length %d" %
121 | (args.max_seq_length, bert_config.max_position_embeddings))
122 |
123 | exact_search = True if args.do_predict else False
124 | eval_dataloader = OnlineSampler(args.raw_eval_data, tokenizer, args.max_query_length,
125 | args.max_seq_length, para_db, para_embed, exact_search=exact_search, cased=args.use_spanbert, regex=args.regex)
126 |
127 | if args.init_checkpoint != "":
128 | model = load_saved(model, args.init_checkpoint)
129 |
130 | model.to(device)
131 | logger.info(
132 | f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
133 |
134 | if args.fix_para_encoder:
135 | model.freeze_c_encoder()
136 |
137 | if args.do_train:
138 | no_decay = ['bias', 'LayerNorm.weight']
139 | optimizer_parameters = [
140 | {'params': [p for n, p in model.named_parameters() if not any(
141 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
142 | {'params': [p for n, p in model.named_parameters() if any(
143 | nd in n for nd in no_decay)], 'weight_decay': 0.0}
144 | ]
145 | optimizer = AdamW(optimizer_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
146 |
147 | if args.fp16:
148 | try:
149 | from apex import amp
150 | except ImportError:
151 | raise ImportError(
152 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
153 | model, optimizer = amp.initialize(
154 | model, optimizer, opt_level=args.fp16_opt_level)
155 | else:
156 | if args.fp16:
157 | try:
158 | from apex import amp
159 | except ImportError:
160 | raise ImportError(
161 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
162 | model = amp.initialize(model, opt_level=args.fp16_opt_level)
163 |
164 | if args.local_rank != -1:
165 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
166 | output_device=args.local_rank)
167 | elif n_gpu > 1:
168 | model = torch.nn.DataParallel(model)
169 |
170 | if args.do_train:
171 | global_step = 0 # gradient update step
172 | batch_step = 0 # forward batch count
173 | best_em = 0
174 | wait_step = 0
175 | stop_training = False
176 | train_loss_meter = AverageMeter()
177 | logger.info('Start training....')
178 | model.train()
179 | train_dataloader = OnlineSampler(
180 | args.raw_train_data, tokenizer, args.max_query_length, args.max_seq_length, para_db, para_embed, matched_para_path=args.matched_para_path, cased=args.use_spanbert, regex=args.regex)
181 | for epoch in range(int(args.num_train_epochs)):
182 | train_dataloader.shuffle()
183 | failed_retrival = 0
184 | for batch in tqdm(train_dataloader.load(model.retriever, k=args.train_batch_size), total=len(train_dataloader)):
185 | batch_step += 1
186 | if batch == {}:
187 | failed_retrival += 1
188 | continue
189 | batch = move_to_cuda(batch)
190 | outputs = model(batch["net_input"])
191 | loss = outputs["loss"]
192 | if n_gpu > 1:
193 | loss = loss.mean() # mean() to average on multi-gpu.
194 | if args.gradient_accumulation_steps > 1:
195 | loss = loss / args.gradient_accumulation_steps
196 | if args.fp16:
197 | with amp.scale_loss(loss, optimizer) as scaled_loss:
198 | scaled_loss.backward()
199 | else:
200 | loss.backward()
201 |
202 | train_loss_meter.update(loss.item())
203 | tb_logger.add_scalar('batch_train_loss',
204 | loss.item(), global_step)
205 | tb_logger.add_scalar('smoothed_train_loss',
206 | train_loss_meter.avg, global_step)
207 |
208 | if (batch_step + 1) % args.gradient_accumulation_steps == 0:
209 | if args.fp16:
210 | torch.nn.utils.clip_grad_norm_(
211 | amp.master_params(optimizer), args.max_grad_norm)
212 | else:
213 | torch.nn.utils.clip_grad_norm_(
214 | model.parameters(), args.max_grad_norm)
215 |                     optimizer.step() # We have accumulated enough gradients
216 | model.zero_grad()
217 | global_step += 1
218 |
219 | if args.eval_period != -1 and global_step % args.eval_period == 0:
220 | em = predict(args, model, eval_dataloader,
221 | device, fp16=args.efficient_eval)
222 | logger.info("Step %d Train loss %.2f EM %.2f on epoch=%d" % (
223 | global_step, train_loss_meter.avg, em*100, epoch))
224 |
225 | tb_logger.add_scalar('dev_em', em*100, global_step)
226 |
227 | if best_em < em:
228 | logger.info("Saving model with best EM: %.2f -> EM %.2f on epoch=%d" %
229 | (best_em*100, em*100, epoch))
230 | model_state_dict = {k: v.cpu() for (
231 | k, v) in model.state_dict().items()}
232 | torch.save(model_state_dict, os.path.join(
233 | args.output_dir, "best-model.pt"))
234 | model = model.to(device)
235 | best_em = em
236 | wait_step = 0
237 | stop_training = False
238 | else:
239 | wait_step += 1
240 | if wait_step == args.wait_step:
241 | stop_training = True
242 |
243 | logger.info(f"Failed retrieval: {failed_retrival}/{len(train_dataloader)} ...")
244 | em = predict(args, model, eval_dataloader,
245 | device, fp16=args.efficient_eval)
246 | tb_logger.add_scalar('dev_em', em*100, global_step)
247 | logger.info(f"average training loss {train_loss_meter.avg}")
248 | if best_em < em:
249 | logger.info("Saving model with best EM: %.2f -> %.2f on epoch=%d" %
250 | (best_em*100, em*100, epoch))
251 | torch.save(model.state_dict(), os.path.join(
252 | args.output_dir, "best-model.pt"))
253 | model = model.to(device)
254 | best_em = em
255 | wait_step = 0
256 |
257 | if epoch > 15:
258 | logger.info(f"Saving model after epoch {epoch + 1}")
259 | torch.save(model.state_dict(), os.path.join(
260 | args.output_dir, f"model-{epoch+1}-{em}.pt"))
261 |
262 | if stop_training:
263 | break
264 |
265 | logger.info("Training finished!")
266 |
267 | elif args.do_predict:
268 | f1 = predict(args, model, eval_dataloader,
269 | device, fp16=args.efficient_eval)
270 | logger.info(f"test performance {f1}")
271 | print(f1)
272 |
273 |
274 | def predict(args, model, eval_dataloader, device, fp16=False):
275 | model.eval()
276 | if fp16:
277 | model.half()
278 |
279 | all_results = []
280 | PredictionMeta = collections.namedtuple(
281 | "Prediction", ["text", "rank_score", "passage", "span_score", "question"])
282 | qid2results = defaultdict(list)
283 | qid2ground = {}
284 |
285 | for batch in tqdm(eval_dataloader.eval_load(model.retriever, args.eval_k), total=len(eval_dataloader)):
286 |
287 | batch_to_feed = move_to_cuda(batch["net_input"])
288 | if fp16:
289 | batch_to_feed = convert_to_half(batch_to_feed)
290 | with torch.no_grad():
291 | results = model(batch_to_feed)
292 | batch_start_logits = results["start_logits"]
293 | batch_end_logits = results["end_logits"]
294 | batch_rank_logits = results["rank_logits"]
295 | if args.add_select:
296 | batch_select_logits = results["select_logits"]
297 |
298 | outs = [batch_start_logits, batch_end_logits]
299 |
300 | span_scores = outs[0][:, :, None] + outs[1][:, None]
301 | max_answer_lens = 10
302 | max_seq_len = span_scores.size(1)
303 | span_mask = np.tril(
304 | np.triu(np.ones((max_seq_len, max_seq_len)), 0), max_answer_lens)
305 | span_mask = span_scores.data.new(
306 | max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
307 | span_scores_masked = span_scores.float().masked_fill((1 -
308 | span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)
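        # span_scores[b, i, j] = start_logit[i] + end_logit[j] for every candidate span;
        # the triu/tril mask keeps only spans with end >= start and length <= max_answer_lens,
        # and all other entries are pushed to -1e10 before the argmax below.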
309 |
310 | start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
311 | end_position = span_scores_masked.max(dim=2)[1].gather(
312 | 1, start_position.unsqueeze(1)).squeeze(1)
313 |
314 | answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist()
315 |
316 | if args.add_select:
317 | rank_logits = batch_select_logits.view(-1).tolist()
318 | else:
319 | rank_logits = batch_rank_logits.view(-1).tolist()
320 |
321 | para_offset = batch['para_offset']
322 | start_position_ = list(
323 | np.array(start_position.tolist()) - np.array(para_offset))
324 | end_position_ = list(
325 | np.array(end_position.tolist()) - np.array(para_offset))
326 |
327 | for idx, qid in enumerate(batch['id']):
328 | start = start_position_[idx]
329 | end = end_position_[idx]
330 | rank_score = rank_logits[idx]
331 | span_score = answer_scores[idx]
332 | tok_to_orig_index = batch['tok_to_orig_index'][idx]
333 | doc_tokens = batch['doc_tokens'][idx]
334 | wp_tokens = batch['wp_tokens'][idx]
335 | orig_doc_start = tok_to_orig_index[start]
336 | orig_doc_end = tok_to_orig_index[end]
337 | orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
338 | tok_tokens = wp_tokens[start:end+1]
339 | tok_text = " ".join(tok_tokens)
340 | tok_text = tok_text.replace(" ##", "")
341 | tok_text = tok_text.replace("##", "")
342 | tok_text = tok_text.strip()
343 | tok_text = " ".join(tok_text.split())
344 | orig_text = " ".join(orig_tokens)
345 | final_text = get_final_text(
346 | tok_text, orig_text, do_lower_case=args.do_lower_case, verbose_logging=False)
347 | question = batch["q"][idx]
348 | qid2results[qid].append(
349 | PredictionMeta(
350 | text=final_text,
351 | rank_score=rank_score,
352 | span_score=span_score,
353 | passage=" ".join(doc_tokens),
354 | question=question,
355 | )
356 | )
357 | qid2ground[qid] = batch["true_answers"][idx]
358 |
359 | if args.save_all:
360 | print("Saving all prediction results ...")
361 | with open(f"{args.prefix}_all.json", "w") as g:
362 | json.dump(qid2results, g)
363 | with open(f"{args.prefix}_ground.json", "w") as g:
364 | json.dump(qid2ground, g)
365 |
366 | ## linear combination tuning on dev data
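    # Each candidate is rescored as alpha * span_score + (1 - alpha) * rank_score;
    # the sweep below picks the alpha that maximizes exact match on this (dev) data.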
367 | best_em = 0
368 | for alpha in [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9, 1]:
369 | results_to_save = []
370 | ems = []
371 | for qid in qid2results.keys():
372 | qid2results[qid] = sorted(
373 | qid2results[qid], key=lambda x: alpha*x.span_score + (1 - alpha)*x.rank_score, reverse=True)
374 | match_fn = regex_match_score if args.regex else exact_match_score
375 | ems.append(metric_max_over_ground_truths(
376 | match_fn, qid2results[qid][0].text, qid2ground[qid]))
377 | results_to_save.append({
378 | "question": qid2results[qid][0].question,
379 | "para": qid2results[qid][0].passage,
380 | "answer": qid2results[qid][0].text,
381 | "rank_score": qid2results[qid][0].rank_score,
382 | "gold": qid2ground[qid],
383 | "em": ems[-1]
384 | })
385 | em = np.mean(ems)
386 | if em > best_em:
387 | best_em = em
388 | print(f"evaluated {len(ems)} examples...")
389 | print(f"alpha: {alpha}; avg. EM: {em}")
390 |
391 | if args.save_pred:
392 | with open(f"{args.prefix}_{alpha}.json", "w") as g:
393 | for line in results_to_save:
394 | g.write(json.dumps(line) + "\n")
395 |
396 | if type(model) != list:
397 | if fp16:
398 | model.float()
399 | model.train()
400 |
401 | return best_em
402 |
403 |
404 | if __name__ == "__main__":
405 | main()
406 |
407 |
--------------------------------------------------------------------------------
/qa/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import sqlite3
3 | import torch
4 | import unicodedata
5 |
6 | def move_to_cuda(sample):
7 | if len(sample) == 0:
8 | return {}
9 |
10 | def _move_to_cuda(maybe_tensor):
11 | if torch.is_tensor(maybe_tensor):
12 | return maybe_tensor.cuda()
13 | elif isinstance(maybe_tensor, dict):
14 | return {
15 | key: _move_to_cuda(value)
16 | for key, value in maybe_tensor.items()
17 | }
18 | elif isinstance(maybe_tensor, list):
19 | return [_move_to_cuda(x) for x in maybe_tensor]
20 | else:
21 | return maybe_tensor
22 |
23 | return _move_to_cuda(sample)
24 |
25 | def convert_to_half(sample):
26 | if len(sample) == 0:
27 | return {}
28 |
29 | def _convert_to_half(maybe_floatTensor):
30 | if torch.is_tensor(maybe_floatTensor) and maybe_floatTensor.type() == "torch.FloatTensor":
31 | return maybe_floatTensor.half()
32 | elif isinstance(maybe_floatTensor, dict):
33 | return {
34 | key: _convert_to_half(value)
35 | for key, value in maybe_floatTensor.items()
36 | }
37 | elif isinstance(maybe_floatTensor, list):
38 | return [_convert_to_half(x) for x in maybe_floatTensor]
39 | else:
40 | return maybe_floatTensor
41 |
42 | return _convert_to_half(sample)
43 |
44 |
45 | class AverageMeter(object):
46 | """Computes and stores the average and current value"""
47 |
48 | def __init__(self):
49 | self.reset()
50 |
51 | def reset(self):
52 | self.val = 0
53 | self.avg = 0
54 | self.sum = 0
55 | self.count = 0
56 |
57 | def update(self, val, n=1):
58 | self.val = val
59 | self.sum += val * n
60 | self.count += n
61 | self.avg = self.sum / self.count
62 |
63 |
64 | def normalize(text):
65 |     """Resolve different types of unicode encodings."""
66 | return unicodedata.normalize('NFD', text)
67 |
68 |
69 | def load_saved(model, path):
70 | state_dict = torch.load(path)
71 | def filter(x): return x[7:] if x.startswith('module.') else x
72 | state_dict = {filter(k): v for (k, v) in state_dict.items()}
73 | model.load_state_dict(state_dict)
74 | return model
75 |
76 | class DocDB(object):
77 | """Sqlite backed document storage.
78 |
79 | Implements get_doc_text(doc_id).
80 | """
81 |
82 | def __init__(self, db_path=None):
83 | self.path = db_path
84 | self.connection = sqlite3.connect(self.path, check_same_thread=False)
85 |
86 | def __enter__(self):
87 | return self
88 |
89 | def __exit__(self, *args):
90 | self.close()
91 |
92 |
93 | def close(self):
94 | """Close the connection to the database."""
95 | self.connection.close()
96 |
97 | def get_doc_ids(self):
98 | """Fetch all ids of docs stored in the db."""
99 | cursor = self.connection.cursor()
100 | cursor.execute("SELECT id FROM documents")
101 | results = [r[0] for r in cursor.fetchall()]
102 | cursor.close()
103 | return results
104 |
105 | def get_doc_text(self, doc_id):
106 | """Fetch the raw text of the doc for 'doc_id'."""
107 | cursor = self.connection.cursor()
108 | cursor.execute(
109 | "SELECT text FROM documents WHERE id = ?",
110 | (normalize(doc_id),)
111 | )
112 | result = cursor.fetchone()
113 | cursor.close()
114 | return result if result is None else result[0]
115 |
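# A minimal usage sketch for DocDB (assuming the nq_paras.db file from the data
# download; any real paragraph id would replace ids[0]):
#
#   with DocDB("../data/nq_paras.db") as db:
#       ids = db.get_doc_ids()
#       text = db.get_doc_text(ids[0])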
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fairseq==0.9.0
2 | faiss-cpu==1.6.3
3 | gdown==3.10.3
4 | joblib==0.13.2
5 | tensorboard==2.0.2
6 | tensorboardX==2.0
7 | tensorflow-estimator==2.0.1
8 | tensorflow-gpu==2.0.1
9 | torch==1.4.0
10 | torchvision==0.5.0
11 | tqdm==4.36.1
12 | transformers==2.5.1
13 |
--------------------------------------------------------------------------------
/retrieval/basic_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Base tokenizer/tokens classes and utilities."""
8 |
9 | import copy
10 |
11 |
12 |
13 | class Tokens(object):
14 | """A class to represent a list of tokenized text."""
15 | TEXT = 0
16 | TEXT_WS = 1
17 | SPAN = 2
18 | POS = 3
19 | LEMMA = 4
20 | NER = 5
21 |
22 | def __init__(self, data, annotators, opts=None):
23 | self.data = data
24 | self.annotators = annotators
25 | self.opts = opts or {}
26 |
27 | def __len__(self):
28 | """The number of tokens."""
29 | return len(self.data)
30 |
31 | def slice(self, i=None, j=None):
32 | """Return a view of the list of tokens from [i, j)."""
33 | new_tokens = copy.copy(self)
34 | new_tokens.data = self.data[i: j]
35 | return new_tokens
36 |
37 | def untokenize(self):
38 | """Returns the original text (with whitespace reinserted)."""
39 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
40 |
41 | def words(self, uncased=False):
42 | """Returns a list of the text of each token
43 |
44 | Args:
45 | uncased: lower cases text
46 | """
47 | if uncased:
48 | return [t[self.TEXT].lower() for t in self.data]
49 | else:
50 | return [t[self.TEXT] for t in self.data]
51 |
52 | def offsets(self):
53 | """Returns a list of [start, end) character offsets of each token."""
54 | return [t[self.SPAN] for t in self.data]
55 |
56 | def pos(self):
57 | """Returns a list of part-of-speech tags of each token.
58 | Returns None if this annotation was not included.
59 | """
60 | if 'pos' not in self.annotators:
61 | return None
62 | return [t[self.POS] for t in self.data]
63 |
64 | def lemmas(self):
65 | """Returns a list of the lemmatized text of each token.
66 | Returns None if this annotation was not included.
67 | """
68 | if 'lemma' not in self.annotators:
69 | return None
70 | return [t[self.LEMMA] for t in self.data]
71 |
72 | def entities(self):
73 | """Returns a list of named-entity-recognition tags of each token.
74 | Returns None if this annotation was not included.
75 | """
76 | if 'ner' not in self.annotators:
77 | return None
78 | return [t[self.NER] for t in self.data]
79 |
80 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
81 | """Returns a list of all ngrams from length 1 to n.
82 |
83 | Args:
84 | n: upper limit of ngram length
85 | uncased: lower cases text
86 | filter_fn: user function that takes in an ngram list and returns
87 | True or False to keep or not keep the ngram
88 |             as_strings: return the ngram as a string vs list
89 | """
90 | def _skip(gram):
91 | if not filter_fn:
92 | return False
93 | return filter_fn(gram)
94 |
95 | words = self.words(uncased)
96 | ngrams = [(s, e + 1)
97 | for s in range(len(words))
98 | for e in range(s, min(s + n, len(words)))
99 | if not _skip(words[s:e + 1])]
100 |
101 | # Concatenate into strings
102 | if as_strings:
103 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
104 |
105 | return ngrams
106 |
107 | def entity_groups(self):
108 | """Group consecutive entity tokens with the same NER tag."""
109 | entities = self.entities()
110 | if not entities:
111 | return None
112 | non_ent = self.opts.get('non_ent', 'O')
113 | groups = []
114 | idx = 0
115 | while idx < len(entities):
116 | ner_tag = entities[idx]
117 | # Check for entity tag
118 | if ner_tag != non_ent:
119 | # Chomp the sequence
120 | start = idx
121 | while (idx < len(entities) and entities[idx] == ner_tag):
122 | idx += 1
123 | groups.append((self.slice(start, idx).untokenize(), ner_tag))
124 | else:
125 | idx += 1
126 | return groups
127 |
128 |
129 | class Tokenizer(object):
130 | """Base tokenizer class.
131 | Tokenizers implement tokenize, which should return a Tokens class.
132 | """
133 |
134 | def tokenize(self, text):
135 | raise NotImplementedError
136 |
137 | def shutdown(self):
138 | pass
139 |
140 | def __del__(self):
141 | self.shutdown()
142 |
143 |
144 | import regex
145 | import logging
146 |
147 | logger = logging.getLogger(__name__)
148 |
149 |
150 | class RegexpTokenizer(Tokenizer):
151 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
152 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
153 | r'\.(?=\p{Z})')
154 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
155 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
156 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
157 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
158 | CONTRACTION1 = r"can(?=not\b)"
159 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
160 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
161 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
162 |     END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
163 |     END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
164 |     DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
165 |     ELLIPSES = r'\.\.\.|\u2026'
166 |     PUNCT = r'\p{P}'
167 |     NON_WS = r'[^\p{Z}\p{C}]'
168 |
169 |     def __init__(self, **kwargs):
170 |         """
171 |         Args:
172 |             annotators: None or empty set (only tokenizes).
173 |             substitutions: if true, normalizes special token types (e.g. quotes).
174 |         """
175 |         self._regexp = regex.compile(
176 |             '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
177 |             '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
178 |             '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
179 |             '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
180 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
181 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
182 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
183 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
184 | self.NON_WS),
185 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
186 | )
187 | if len(kwargs.get('annotators', {})) > 0:
188 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
189 | (type(self).__name__, kwargs.get('annotators')))
190 | self.annotators = set()
191 | self.substitutions = kwargs.get('substitutions', True)
192 |
193 | def tokenize(self, text):
194 | data = []
195 | matches = [m for m in self._regexp.finditer(text)]
196 | for i in range(len(matches)):
197 | # Get text
198 | token = matches[i].group()
199 |
200 | # Make normalizations for special token types
201 | if self.substitutions:
202 | groups = matches[i].groupdict()
203 | if groups['sdquote']:
204 | token = "``"
205 | elif groups['edquote']:
206 | token = "''"
207 | elif groups['ssquote']:
208 | token = "`"
209 | elif groups['esquote']:
210 | token = "'"
211 | elif groups['dash']:
212 | token = '--'
213 | elif groups['ellipses']:
214 | token = '...'
215 |
216 | # Get whitespace
217 | span = matches[i].span()
218 | start_ws = span[0]
219 | if i + 1 < len(matches):
220 | end_ws = matches[i + 1].span()[0]
221 | else:
222 | end_ws = span[1]
223 |
224 | # Format data
225 | data.append((
226 | token,
227 | text[start_ws: end_ws],
228 | span,
229 | ))
230 | return Tokens(data, self.annotators)
231 |
232 |
233 | class SimpleTokenizer(Tokenizer):
234 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
235 | NON_WS = r'[^\p{Z}\p{C}]'
236 |
237 | def __init__(self, **kwargs):
238 | """
239 | Args:
240 | annotators: None or empty set (only tokenizes).
241 | """
242 | self._regexp = regex.compile(
243 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
244 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
245 | )
246 | if len(kwargs.get('annotators', {})) > 0:
247 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
248 | (type(self).__name__, kwargs.get('annotators')))
249 | self.annotators = set()
250 |
251 | def tokenize(self, text):
252 | data = []
253 | matches = [m for m in self._regexp.finditer(text)]
254 | for i in range(len(matches)):
255 | # Get text
256 | token = matches[i].group()
257 |
258 | # Get whitespace
259 | span = matches[i].span()
260 | start_ws = span[0]
261 | if i + 1 < len(matches):
262 | end_ws = matches[i + 1].span()[0]
263 | else:
264 | end_ws = span[1]
265 |
266 | # Format data
267 | data.append((
268 | token,
269 | text[start_ws: end_ws],
270 | span,
271 | ))
272 | return Tokens(data, self.annotators)
273 |
274 |
275 |
276 |
--------------------------------------------------------------------------------
/retrieval/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 |
7 | # Required parameters
8 | parser.add_argument("--bert_model_name",
9 | default="bert-large-cased-whole-word-masking", type=str)
10 | parser.add_argument("--output_dir", default="logs", type=str,
11 | help="The output directory where the model checkpoints will be written.")
12 | parser.add_argument("--weight_decay", default=0.0, type=float,
13 | help="Weight decay if we apply some.")
14 |
15 | # Other parameters
16 | parser.add_argument("--load", default=False, action='store_true')
17 | parser.add_argument("--num_workers", default=5, type=int)
18 | parser.add_argument("--train_file", type=str,
19 | default="")
20 | parser.add_argument("--predict_file", type=str,
21 | default="")
22 | parser.add_argument("--init_checkpoint", type=str,
23 | help="Initial checkpoint (usually from a pre-trained BERT model).",
24 | default="")
25 | parser.add_argument("--max_seq_length", default=512, type=int,
26 | help="The maximum total input sequence length after WordPiece tokenization. Sequences "
27 | "longer than this will be truncated, and sequences shorter than this will be padded.")
28 | parser.add_argument("--max_query_length", default=30, type=int,
29 | help="The maximum number of tokens for the question. Questions longer than this will "
30 | "be truncated to this length.")
31 | parser.add_argument("--do_train", default=False,
32 | action='store_true', help="Whether to run training.")
33 | parser.add_argument("--do_predict", default=False,
34 | action='store_true', help="Whether to run eval on the dev set.")
35 | parser.add_argument("--train_batch_size", default=8,
36 | type=int, help="Total batch size for training.")
37 | parser.add_argument("--predict_batch_size", default=100,
38 | type=int, help="Total batch size for predictions.")
39 | parser.add_argument("--learning_rate", default=5e-5,
40 | type=float, help="The initial learning rate for Adam.")
41 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
42 | help="Epsilon for Adam optimizer.")
43 | parser.add_argument("--num_train_epochs", default=5000, type=float,
44 | help="Total number of training epochs to perform.")
45 | parser.add_argument('--wait_step', type=int, default=100)
46 | parser.add_argument("--save_checkpoints_steps", default=20000, type=int,
47 | help="How often to save the model checkpoint.")
48 | parser.add_argument("--iterations_per_loop", default=1000, type=int,
49 | help="How many steps to make in each estimator call.")
50 | parser.add_argument("--no_cuda", default=False, action='store_true',
51 | help="Whether not to use CUDA when available")
52 | parser.add_argument("--local_rank", type=int, default=-1,
53 | help="local_rank for distributed training on gpus")
54 | parser.add_argument("--accumulate_gradients", type=int, default=1,
55 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)")
56 | parser.add_argument('--seed', type=int, default=3,
57 | help="random seed for initialization")
58 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
59 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
60 | parser.add_argument('--eval-period', type=int, default=2500)
61 | parser.add_argument('--verbose', action="store_true", default=False)
62 | parser.add_argument('--efficient_eval', action="store_true", help="whether to use fp16 for evaluation")
63 | parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
64 |
65 | parser.add_argument('--fp16', action='store_true')
66 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
67 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
68 | "See details at https://nvidia.github.io/apex/amp.html")
69 |
70 |
71 |     parser.add_argument('--filter', action='store_true', help="filter out noisy samples: 1. paragraph too short; 2. answer already appears in the question.")
72 |
73 | # For evaluation
74 | parser.add_argument('--prefix', type=str, default="eval")
75 | parser.add_argument('--debug', action="store_true", default=False)
76 | parser.add_argument("--eval-workers", default=32,
77 | help="parallel data loader", type=int)
78 |
79 |
80 | # For encode questions
81 |     parser.add_argument("--use-whole-model", action="store_true", help="re-encode the questions after QA finetuning")
82 | parser.add_argument("--joint-train", action="store_true")
83 | parser.add_argument("--max-pool", action="store_true", help="CLS or maxpooling")
84 | parser.add_argument("--shared-norm", action="store_true", help="normalize span logits across different paragraphs")
85 | parser.add_argument("--retriever-path", type=str, default="", help="pretrained retriever checkpoint")
86 | parser.add_argument("--qa-drop", default=0, type=float)
87 |
88 | parser.add_argument('--embed_save_path', type=str, default="")
89 | parser.add_argument('--is_query_embed', action="store_true", default=False)
90 |
91 | args = parser.parse_args()
92 |
93 | return args
94 |
--------------------------------------------------------------------------------
/retrieval/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, Dataset, Sampler
2 | import torch
3 | import json
4 | import numpy as np
5 | import random
6 | from tqdm import tqdm
7 | import re
8 | import string
9 | import os
10 | from multiprocessing import Pool as ProcessPool
11 |
12 | def normalize_answer(s):
13 | """Lower text and remove punctuation, articles and extra whitespace."""
14 | def remove_articles(text):
15 | return re.sub(r'\b(a|an|the)\b', ' ', text)
16 |
17 | def white_space_fix(text):
18 | return ' '.join(text.split())
19 |
20 | def remove_punc(text):
21 | exclude = set(string.punctuation)
22 | return ''.join(ch for ch in text if ch not in exclude)
23 |
24 | def lower(text):
25 | return text.lower()
26 |
27 | return white_space_fix(remove_articles(remove_punc(lower(s))))
28 |
29 | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
30 | """Convert a list of 1d tensors into a padded 2d tensor."""
31 | size = max(v.size(0) for v in values)
32 | res = values[0].new(len(values), size).fill_(pad_idx)
33 |
34 | def copy_tensor(src, dst):
35 | assert dst.numel() == src.numel()
36 | if move_eos_to_beginning:
37 | assert src[-1] == eos_idx
38 | dst[0] = eos_idx
39 | dst[1:] = src[:-1]
40 | else:
41 | dst.copy_(src)
42 |
43 | for i, v in enumerate(values):
44 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
45 | return res
46 |
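# Example for collate_tokens above (right-padding to the longest sequence):
#   collate_tokens([torch.tensor([1, 2, 3]), torch.tensor([4])], pad_idx=0)
#   -> tensor([[1, 2, 3],
#              [4, 0, 0]])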
47 |
48 | class ClusterDataset(Dataset):
49 |
50 | def __init__(self,
51 | tokenizer,
52 | data_folder,
53 | max_query_length,
54 | max_length,
55 | filter=False
56 | ):
57 | super().__init__()
58 | self.tokenizer = tokenizer
59 | self.filter = filter
60 | self.max_query_length = max_query_length
61 | self.max_length = max_length
62 |
63 | print(f"Loading data splits from {data_folder}")
64 | file_lists = [os.path.join(data_folder, f) for f in os.listdir(data_folder)]
65 |
66 | self.data, self.index_clusters = [], []
67 | processes = ProcessPool(processes=30)
68 | file_datas = processes.map(self.load_file, file_lists)
69 | processes.close()
70 | processes.join()
71 | for file_data in file_datas:
72 | indice = len(self.data) + np.arange(len(file_data))
73 | self.index_clusters.append(list(indice))
74 | self.data.extend(file_data)
75 |
76 | print(f"Total {len(self.data)} loaded")
77 |
78 | def filter_sample(self, item):
79 | if len(item["Paragraph"].split()) < 20:
80 | return False
81 | if normalize_answer(item["Answer"]) in normalize_answer(item["Question"]):
82 | return False
83 | return True
84 |
85 | def load_file(self, file):
86 | data = [json.loads(line) for line in open(file).readlines()]
87 | if self.filter:
88 | data = [item for item in data if self.filter_sample(item)]
89 | return data
90 |
91 | def __getitem__(self, index):
92 | sample = self.data[index]
93 | question = sample['Question']
94 | paragraph = sample['Paragraph']
95 |
96 | question_ids = torch.LongTensor(self.tokenizer.encode(
97 | question, max_length=self.max_query_length))
98 | question_masks = torch.ones(question_ids.shape).bool()
99 |
100 | paragraph_ids = torch.LongTensor(self.tokenizer.encode(
101 | paragraph, max_length=self.max_length - self.max_query_length))
102 | paragraph_masks = torch.ones(paragraph_ids.shape).bool()
103 |
104 | return {
105 | 'input_ids_q': question_ids,
106 | 'input_mask_q': question_masks,
107 | 'input_ids_c': paragraph_ids,
108 | 'input_mask_c': paragraph_masks,
109 | }
110 |
111 | def __len__(self):
112 | return len(self.data)
113 |
114 |
115 | class ClusterSampler(Sampler):
116 |
117 | def __init__(self, data_source, batch_size):
118 | """
119 | batch size: within batch, all samples come from the same cluster
120 | """
121 | print(f"Sample with batch size {batch_size}")
122 |
123 | index_clusters = data_source.index_clusters
124 | sample_indice = []
125 |
126 | # shuffle inside each cluster
127 | num_group = 3
128 | for cluster in index_clusters:
129 | groups = [] # 3 adjacent examples share the same para
130 | for i in range(num_group):
131 | groups.append(cluster[i::num_group])
132 | random.shuffle(groups)
133 | for g in groups:
134 | random.shuffle(g)
135 | sample_indice += g
136 |
137 |         # sample batches; avoid adjacent batches always coming from the same cluster
138 | self.sample_indice = []
139 | batch_starts = np.arange(0, len(data_source), batch_size)
140 | np.random.shuffle(batch_starts)
141 | for batch_start in batch_starts:
142 | self.sample_indice += sample_indice[batch_start:batch_start+batch_size]
143 |
144 | assert len(self.sample_indice) == len(data_source)
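        # Net effect: indices are shuffled within each cluster (keeping the 3-way
        # interleaving so the 3 questions sharing a paragraph land in separate groups),
        # consecutive positions in sample_indice therefore come from one cluster, and
        # shuffling batch_starts keeps adjacent batches from always hitting the same cluster.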
145 |
146 | def __len__(self):
147 | return len(self.sample_indice)
148 |
149 | def __iter__(self):
150 | return iter(self.sample_indice)
151 |
152 |
153 | class ReDataset(Dataset):
154 |
155 | def __init__(self,
156 | tokenizer,
157 | data_path,
158 | max_query_length,
159 | max_length,
160 | filter=False
161 | ):
162 | super().__init__()
163 | self.tokenizer = tokenizer
164 | self.filter = filter
165 | print(f"Loading data from {data_path}")
166 |
167 | self.data = [json.loads(line) for line in open(data_path).readlines()]
168 |
169 | # filter
170 | original_count = len(self.data)
171 | if self.filter:
172 | self.data = [item for item in self.data if self.filter_sample(item)]
173 | print(f"Using {len(self.data)} out of {original_count}")
174 |
175 | self.max_query_length = max_query_length
176 | self.max_length = max_length
177 | self.group_indexs = []
178 | num_group = 3
179 | indexs = list(range(len(self.data)))
180 | for i in range(num_group):
181 | self.group_indexs.append(indexs[i::num_group])
182 |
183 | def filter_sample(self, item):
184 | if len(item["Paragraph"].split()) < 20:
185 | return False
186 | if normalize_answer(item["Answer"]) in normalize_answer(item["Question"]):
187 | return False
188 | return True
189 |
190 | def __getitem__(self, index):
191 | sample = self.data[index]
192 | question = sample['Question']
193 | paragraph = sample['Paragraph']
194 |
195 | question_ids = torch.LongTensor(self.tokenizer.encode(question, max_length=self.max_query_length))
196 | question_masks = torch.ones(question_ids.shape).bool()
197 |
198 | paragraph_ids = torch.LongTensor(self.tokenizer.encode(paragraph, max_length=self.max_length - self.max_query_length))
199 | paragraph_masks = torch.ones(paragraph_ids.shape).bool()
200 |
201 | return {
202 | 'input_ids_q': question_ids,
203 | 'input_mask_q': question_masks,
204 | 'input_ids_c': paragraph_ids,
205 | 'input_mask_c': paragraph_masks,
206 | }
207 |
208 | def __len__(self):
209 | return len(self.data)
210 |
211 |
212 | class ReSampler(Sampler):
213 | """
214 |     Shuffle QA pairs, not contexts; make sure data within the batch are from the same QA pair.
215 | """
216 |
217 | def __init__(self, data_source):
218 | # for each QA pair, sample negative paragraphs
219 | sample_indice = []
220 | for _ in data_source.group_indexs:
221 | random.shuffle(_)
222 | sample_indice += _
223 | self.sample_indice = sample_indice
224 |
225 | def __len__(self):
226 | return len(self.sample_indice)
227 |
228 | def __iter__(self):
229 | return iter(self.sample_indice)
230 |
231 | def re_collate(samples):
232 | if len(samples) == 0:
233 | return {}
234 |
235 | return {
236 | 'input_ids_q': collate_tokens([s['input_ids_q'] for s in samples], 0),
237 | 'input_mask_q': collate_tokens([s['input_mask_q'] for s in samples], 0),
238 | 'input_ids_c': collate_tokens([s['input_ids_c'] for s in samples], 0),
239 | 'input_mask_c': collate_tokens([s['input_mask_c'] for s in samples], 0),
240 | }
241 |
242 | class FTDataset(Dataset):
243 | """
244 | finetune the Question encoder with
245 | """
246 |
247 | def __init__(self,
248 | tokenizer,
249 | data_path,
250 | max_query_length,
251 | max_length,
252 | filter=False
253 | ):
254 | super().__init__()
255 |
256 |
257 | class EmDataset(Dataset):
258 |
259 | def __init__(self,
260 | tokenizer,
261 | data_path,
262 | max_query_length,
263 | max_length,
264 | is_query_embed,
265 | ):
266 | super().__init__()
267 | self.is_query_embed = is_query_embed
268 | self.tokenizer = tokenizer
269 |
270 | print(f"Loading data from {data_path}")
271 | self.data = [json.loads(_.strip())
272 | for _ in tqdm(open(data_path).readlines())]
273 |
274 | self.max_length = max_query_length if is_query_embed else max_length
275 | print(f"Max sequence length: {self.max_length}")
276 |
277 |
278 | def __getitem__(self, index):
279 | sample = self.data[index]
280 | if self.is_query_embed:
281 | sent = sample['question']
282 | else:
283 | sent = sample['text']
284 |
285 | sent_ids = torch.LongTensor(
286 | self.tokenizer.encode(sent, max_length=self.max_length))
287 | sent_masks = torch.ones(sent_ids.shape).bool()
288 |
289 | return {
290 | 'input_ids': sent_ids,
291 | 'input_mask': sent_masks,
292 | }
293 |
294 | def __len__(self):
295 | return len(self.data)
296 |
297 |
298 | def em_collate(samples):
299 | if len(samples) == 0:
300 | return {}
301 |
302 | return {
303 | 'input_ids': collate_tokens([s['input_ids'] for s in samples], 0),
304 | 'input_mask': collate_tokens([s['input_mask'] for s in samples], 0),
305 | }
306 |
--------------------------------------------------------------------------------
/retrieval/eval_retrieval.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import json
4 | import faiss
5 | import argparse
6 |
7 | from multiprocessing import Pool as ProcessPool
8 | from multiprocessing.util import Finalize
9 | from functools import partial
10 | from collections import defaultdict
11 |
12 | from basic_tokenizer import SimpleTokenizer
13 | from utils import DocDB, normalize
14 |
15 |
16 | PROCESS_TOK = None
17 | PROCESS_DB = None
18 |
19 | def init(db_path):
20 | global PROCESS_TOK, PROCESS_DB
21 | PROCESS_TOK = SimpleTokenizer()
22 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
23 | PROCESS_DB = DocDB(db_path)
24 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
25 |
26 |
27 | def para_has_answer(answer, para, return_matched=False):
28 | global PROCESS_DB, PROCESS_TOK
29 | text = normalize(para)
30 | tokens = PROCESS_TOK.tokenize(text)
31 | text = tokens.words(uncased=True)
32 | assert len(text) == len(tokens)
33 | for single_answer in answer:
34 | single_answer = normalize(single_answer)
35 | single_answer = PROCESS_TOK.tokenize(single_answer)
36 | single_answer = single_answer.words(uncased=True)
37 | for i in range(0, len(text) - len(single_answer) + 1):
38 | if single_answer == text[i: i + len(single_answer)]:
39 | if return_matched:
40 | return True, tokens.slice(i, i + len(single_answer)).untokenize()
41 | else:
42 | return True
43 | if return_matched:
44 | return False, ""
45 | return False
46 |
47 | def get_score(answer_doc, topk=80):
48 | """Search through all the top docs to see if they have the answer."""
49 | question, answer, doc_ids = answer_doc
50 | top5doc_covered = 0
51 | global PROCESS_DB
52 | all_paras = [PROCESS_DB.get_doc_text(doc_id) for doc_id in doc_ids]
53 |
54 | topk_paras = all_paras[:topk]
55 | topkpara_covered = []
56 | for p in topk_paras:
57 | topkpara_covered.append(int(para_has_answer(answer, p)))
58 |
59 | return {
60 | str(topk): int(np.sum(topkpara_covered) > 0),
61 | "5": int(np.sum(topkpara_covered[:5]) > 0),
62 | "10": int(np.sum(topkpara_covered[:10]) > 0),
63 | "20": int(np.sum(topkpara_covered[:20]) > 0),
64 | "50": int(np.sum(topkpara_covered[:50]) > 0),
65 | }
66 |
67 |
68 | def convert_idx2id(idxs):
69 | idx_id_mapping = json.load(open('../pretrained_models/idx_id.json'))
70 | retrieval_results = []
71 | for cand_idx in idxs:
72 | out_ids = []
73 | for _ in cand_idx:
74 | out_ids.append(idx_id_mapping[str(_)])
75 | retrieval_results.append(out_ids)
76 | return retrieval_results
77 |
78 | if __name__ == '__main__':
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument('raw_data', type=str, default=None)
81 | parser.add_argument('indexpath', type=str, default=None)
82 | parser.add_argument('query_embed', type=str, default=None)
83 | parser.add_argument('db', type=str, default=None)
84 | parser.add_argument('--topk', type=int, default=80)
85 | parser.add_argument('--num-workers', type=int, default=10)
86 | args = parser.parse_args()
87 |
88 | qas = [json.loads(line) for line in open(args.raw_data).readlines()]
89 | questions = [item["question"] for item in qas]
90 | answers = [item["answer"] for item in qas]
91 |
92 | processes = ProcessPool(
93 | processes=args.num_workers,
94 | initializer=init,
95 | initargs=[args.db]
96 | )
97 |
98 | d = 128
99 | xq = np.load(args.query_embed).astype('float32')
100 | xb = np.load(args.indexpath).astype('float32')
101 |
102 | index = faiss.IndexFlatIP(d) # build the index
103 | index.add(xb) # add vectors to the index
104 | D, I = index.search(xq, args.topk) # actual search
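    # IndexFlatIP does exact (brute-force) maximum-inner-product search over the dense
    # paragraph index; I[q] holds the row indices of the top-k paragraphs for question q,
    # which convert_idx2id maps back to paragraph ids below.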
105 |
106 | retrieval_results = convert_idx2id(I)
107 |
108 | assert len(retrieval_results) == len(questions) == len(answers)
109 | answers_docs = zip(questions, answers, retrieval_results)
110 |
111 | get_score_partial = partial(
112 | get_score, topk=args.topk)
113 | results = processes.map(get_score_partial, answers_docs)
114 |
115 | aggregate = defaultdict(list)
116 | for r in results:
117 | for k, v in r.items():
118 | aggregate[k].append(v)
119 |
120 | for k in aggregate:
121 | results = aggregate[k]
122 | print('Top {} Recall for {} QA pairs: {} ...'.format(
123 | k, len(results), np.mean(results)))
124 |
--------------------------------------------------------------------------------
/retrieval/gen_index_id_map.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | mapping = {}
4 | with open('../data/para_doc.db') as f_in:
5 | for idx, line in enumerate(f_in):
6 | sample = json.loads(line.strip())
7 | mapping[idx] = sample['id']
8 | with open('index_data/idx_id.json', 'w') as f_out:
9 | json.dump(mapping, f_out)
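# idx_id.json maps each row of the dense paragraph index (the order in which paragraphs
# were embedded) back to its paragraph id; convert_idx2id in eval_retrieval.py loads
# such a mapping to turn faiss result indices into database ids.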
10 |
11 |
--------------------------------------------------------------------------------
/retrieval/get_embed.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import json
4 | import os
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | from copy import deepcopy
10 |
11 | from torch.utils.data import DataLoader
12 | from datasets import EmDataset, em_collate
13 | from retriever import BertForRetriever
14 | from transformers import AdamW, BertConfig, BertTokenizer
15 | from utils import move_to_cuda, convert_to_half, AverageMeter
16 | from config import get_args
17 |
18 | from collections import defaultdict, namedtuple
19 | import torch.nn.functional as F
20 |
21 |
22 | def load_saved(model, path):
23 | state_dict = torch.load(path)
24 | def filter(x): return x[7:] if x.startswith('module.') else x
25 | state_dict = {filter(k): v for (k, v) in state_dict.items()}
26 | model.load_state_dict(state_dict)
27 | return model
28 |
29 | def main():
30 | args = get_args()
31 |
32 | is_query_embed = args.is_query_embed
33 | embed_save_path = args.embed_save_path
34 |
35 | if args.fp16:
36 | try:
37 | import apex
38 | apex.amp.register_half_function(torch, 'einsum')
39 | except ImportError:
40 | raise ImportError(
41 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
42 |
43 |
44 | if args.local_rank == -1 or args.no_cuda:
45 | device = torch.device(
46 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
47 | n_gpu = torch.cuda.device_count()
48 | else:
49 | device = torch.device("cuda", args.local_rank)
50 | n_gpu = 1
51 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
52 | torch.distributed.init_process_group(backend='nccl')
53 |
54 | if args.accumulate_gradients < 1:
55 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
56 | args.accumulate_gradients))
57 |
58 | args.train_batch_size = int(
59 | args.train_batch_size / args.accumulate_gradients)
60 | random.seed(args.seed)
61 | np.random.seed(args.seed)
62 | torch.manual_seed(args.seed)
63 | if n_gpu > 0:
64 | torch.cuda.manual_seed_all(args.seed)
65 |
66 | if not args.do_train and not args.do_predict:
67 | raise ValueError(
68 | "At least one of `do_train` or `do_predict` must be True.")
69 |
70 | if args.do_train:
71 | if not args.train_file:
72 | raise ValueError(
73 | "If `do_train` is True, then `train_file` must be specified.")
74 | if not args.predict_file:
75 | raise ValueError(
76 | "If `do_train` is True, then `predict_file` must be specified.")
77 |
78 | if args.do_predict:
79 | if not args.predict_file:
80 | raise ValueError(
81 | "If `do_predict` is True, then `predict_file` must be specified.")
82 |
83 | bert_config = BertConfig.from_pretrained(args.bert_model_name)
84 | model = BertForRetriever(bert_config, args)
85 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
86 |
87 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
88 | raise ValueError(
89 | "Cannot use sequence length %d because the BERT model "
90 | "was only trained up to sequence length %d" %
91 | (args.max_seq_length, bert_config.max_position_embeddings))
92 |
93 | eval_dataset = EmDataset(
94 | tokenizer, args.predict_file, args.max_query_length, args.max_seq_length, is_query_embed)
95 | eval_dataloader = DataLoader(
96 | eval_dataset, batch_size=args.predict_batch_size, collate_fn=em_collate, pin_memory=True, num_workers=args.eval_workers)
97 |
98 | assert args.init_checkpoint != ""
99 | model = load_saved(model, args.init_checkpoint)
100 |
101 | model.to(device)
102 |
103 | if args.do_train:
104 | no_decay = ['bias', 'LayerNorm.weight']
105 | optimizer_parameters = [
106 | {'params': [p for n, p in model.named_parameters() if not any(
107 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
108 | {'params': [p for n, p in model.named_parameters() if any(
109 | nd in n for nd in no_decay)], 'weight_decay': 0.0}
110 | ]
111 | optimizer = AdamW(optimizer_parameters,
112 | lr=args.learning_rate, eps=args.adam_epsilon)
113 |
114 | if args.fp16:
115 | try:
116 | from apex import amp
117 | except ImportError:
118 | raise ImportError(
119 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
120 | model, optimizer = amp.initialize(
121 | model, optimizer, opt_level=args.fp16_opt_level)
122 | else:
123 | if args.fp16:
124 | try:
125 | from apex import amp
126 | except ImportError:
127 | raise ImportError(
128 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
129 | model = amp.initialize(model, opt_level=args.fp16_opt_level)
130 |
131 | if args.local_rank != -1:
132 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
133 | output_device=args.local_rank)
134 | elif n_gpu > 1:
135 | model = torch.nn.DataParallel(model)
136 |
137 |
138 | embeds = predict(args, model, eval_dataloader, device, fp16=args.efficient_eval, is_query_embed=is_query_embed)
139 | np.save(embed_save_path, embeds.cpu().numpy())
140 |
141 |
142 | def predict(args, model, eval_dataloader, device, fp16=False, is_query_embed=True):
143 | if type(model) == list:
144 | model = [m.eval() for m in model]
145 | else:
146 | model.eval()
147 | if fp16:
148 | if type(model) == list:
149 | model = [m.half() for m in model]
150 | else:
151 | model.half()
152 |
153 | num_correct = 0.0
154 | num_total = 0.0
155 | embed_array = []
156 | for batch in tqdm(eval_dataloader):
157 | batch_to_feed = move_to_cuda(batch)
158 | with torch.no_grad():
159 | results = model.get_embed(batch_to_feed, is_query_embed)
160 | embed = results['embed']
161 | embed_array.append(embed)
162 | #print(prediction, target, sum(prediction==target), len(prediction))
163 | #print(num_total, num_correct)
164 |
165 |     # concatenate the embeddings from all batches
166 | embed_array = torch.cat(embed_array)
167 |
168 | if fp16:
169 | model.float()
170 |
171 | model.train()
172 | return embed_array
173 |
174 |
175 | if __name__ == "__main__":
176 | main()
177 |
--------------------------------------------------------------------------------
/retrieval/get_para_embed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=3 python3 get_embed.py \
3 | --do_predict \
4 | --prefix eval-para \
5 | --predict_batch_size 300 \
6 | --bert_model_name bert-base-uncased \
7 | --fp16 \
8 | --predict_file ../data/wiki_splits.txt \
9 | --init_checkpoint logs/retrieve_train.txt-seed87-bsz640-fp16True-retriever_pretraining_single-lr1e-05-bert-base-uncased-filterTrue/checkpoint_best.pt \
10 | --embed_save_path encodings/para_embed.npy \
11 | --eval-workers 32 \
12 |
13 |
--------------------------------------------------------------------------------
/retrieval/group_paras.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import faiss
3 | import os
4 | import argparse
5 |
6 | def write_file(file_name, samples):
7 | with open(file_name, 'w') as f_out:
8 | for _ in samples:
9 | f_out.write(_)
10 |
11 |
12 | def group_paras(I, ncentroids, split_path):
13 | samples = [[] for _ in range(ncentroids)]
14 | with open('../data/retrieve_train.txt') as f_in:
15 | for i, line in enumerate(f_in):
16 | samples[I[i][0]].append(line)
17 | for i, group in enumerate(samples):
18 | write_file(split_path + 'split_'+str(i)+'.txt', group)
19 |
20 | def clustering(data, niter=1000, verbose=True, ncentroids=1024, max_points_per_centroid=10000000, gpu_id=0, spherical=False):
21 | # use one gpu
22 | '''
23 | res = faiss.StandardGpuResources()
24 | cfg = faiss.GpuIndexFlatConfig()
25 | cfg.useFloat16 = False
26 | cfg.device = gpu_id
27 |
28 | d = data.shape[1]
29 | if spherical:
30 | index = faiss.GpuIndexFlatIP(res, d, cfg)
31 | else:
32 | index = faiss.GpuIndexFlatL2(res, d, cfg)
33 | '''
34 | d = data.shape[1]
35 | if spherical:
36 | index = faiss.IndexFlatIP(d)
37 | else:
38 | index = faiss.IndexFlatL2(d)
39 |
40 | clus = faiss.Clustering(d, ncentroids)
41 |     clus.verbose = verbose
42 | clus.niter = niter
43 | clus.max_points_per_centroid = max_points_per_centroid
44 |
45 |     clus.train(data, index)
46 | centroids = faiss.vector_float_to_array(clus.centroids)
47 | centroids = centroids.reshape(ncentroids, d)
48 |
49 | index.reset()
50 | index.add(centroids)
51 | D, I = index.search(data, 1)
52 |
53 | return D, I
54 |
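# The function above runs faiss k-means to learn `ncentroids` centroids over the
# paragraph embeddings and then assigns every embedding to its nearest centroid
# (top-1 search); group_paras then writes one training split file per cluster.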
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--ncentroids', type=int, default=10000)
58 | parser.add_argument('--niter', type=int, default=250)
59 | parser.add_argument('--max_points_per_centroid', type=int, default=1000)
60 | parser.add_argument('--indexpath', type=str, default=None)
61 | parser.add_argument('--spherical', action='store_true')
62 | args = parser.parse_args()
63 |
64 |
65 | train_para_embed_path = "encodings/train_para_embed.npy"
66 | split_save_path = "../data/data_splits/"
67 | if os.path.exists(split_save_path) and os.listdir(split_save_path):
68 | print(f"output directory {split_save_path} already exists and is not empty.")
69 | if not os.path.exists(split_save_path):
70 | os.makedirs(split_save_path, exist_ok=True)
71 |
72 | x = np.load(train_para_embed_path)
73 | x = np.float32(x)
74 |
75 |     D, I = clustering(x, niter=args.niter, ncentroids=args.ncentroids, max_points_per_centroid=args.max_points_per_centroid, spherical=args.spherical)
76 |
77 | group_paras(I, args.ncentroids, split_path=split_save_path)
78 |
--------------------------------------------------------------------------------
/retrieval/retriever.py:
--------------------------------------------------------------------------------
1 |
2 | from transformers import BertModel, BertConfig, BertPreTrainedModel
3 | import torch.nn as nn
4 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
5 | import torch
6 |
7 |
8 | class BertForRetriever(nn.Module):
9 |
10 | def __init__(self,
11 | config,
12 | args
13 | ):
14 | super(BertForRetriever, self).__init__()
15 |
16 | self.bert_q = BertModel.from_pretrained(args.bert_model_name)
17 | self.bert_c = BertModel.from_pretrained(args.bert_model_name)
18 |
19 | self.proj_q = nn.Linear(config.hidden_size, 128)
20 | self.proj_c = nn.Linear(config.hidden_size, 128)
21 |
22 | def forward(self, batch):
23 | input_ids_q, attention_mask_q = batch["input_ids_q"], batch["input_mask_q"]
24 | q_cls = self.bert_q(input_ids_q, attention_mask_q)[1]
25 | q = self.proj_q(q_cls)
26 |
27 | input_ids_c, attention_mask_c = batch["input_ids_c"], batch["input_mask_c"]
28 | c_cls = self.bert_c(input_ids_c, attention_mask_c)[1]
29 | c = self.proj_c(c_cls)
30 |
31 | return {"q": q, "c": c}
32 |
33 | def get_embed(self, batch, is_query_embed):
34 |
35 | input_ids, attention_mask = batch["input_ids"], batch["input_mask"]
36 | if is_query_embed:
37 | q_cls = self.bert_q(input_ids, attention_mask)[1]
38 | q = self.proj_q(q_cls)
39 | return {'embed': q}
40 | else:
41 | c_cls = self.bert_c(input_ids, attention_mask)[1]
42 | c = self.proj_c(c_cls)
43 | return {'embed': c}
44 |
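# BertForRetriever only returns the projected 128-d question/paragraph embeddings.
# A minimal sketch of how they could be scored with in-batch negatives (an assumption
# for illustration, not necessarily the exact loss used by the training scripts):
#
#   out = model(batch)                                        # {"q": [B, 128], "c": [B, 128]}
#   scores = torch.matmul(out["q"], out["c"].t())             # [B, B] inner-product matrix
#   target = torch.arange(scores.size(0), device=scores.device)
#   loss = torch.nn.functional.cross_entropy(scores, target)  # match each q to its own c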
--------------------------------------------------------------------------------
/retrieval/tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import six
24 | import tensorflow as tf
25 |
26 |
27 | def convert_tokens_to_ids(vocab, tokens):
28 | """Converts a sequence of tokens into ids using the vocab."""
29 | ids = []
30 | for token in tokens:
31 | ids.append(vocab[token])
32 | return ids
33 |
34 | def whitespace_tokenize(text):
35 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
36 | text = text.strip()
37 | if not text:
38 | return []
39 | tokens = text.split()
40 | return tokens
41 |
42 |
43 | def convert_to_unicode(text):
44 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
45 | if six.PY3:
46 | if isinstance(text, str):
47 | return text
48 | elif isinstance(text, bytes):
49 | return text.decode("utf-8", "ignore")
50 | else:
51 | raise ValueError("Unsupported string type: %s" % (type(text)))
52 | elif six.PY2:
53 | if isinstance(text, str):
54 | return text.decode("utf-8", "ignore")
55 | elif isinstance(text, unicode):
56 | return text
57 | else:
58 | raise ValueError("Unsupported string type: %s" % (type(text)))
59 | else:
60 | raise ValueError("Not running on Python2 or Python 3?")
61 |
62 |
63 | def _is_whitespace(char):
64 | """Checks whether `chars` is a whitespace character."""
65 |     # \t, \n, and \r are technically control characters but we treat them
66 | # as whitespace since they are generally considered as such.
67 | if char == " " or char == "\t" or char == "\n" or char == "\r":
68 | return True
69 | cat = unicodedata.category(char)
70 | if cat == "Zs":
71 | return True
72 | return False
73 |
74 |
75 | def _is_control(char):
76 | """Checks whether `chars` is a control character."""
77 | # These are technically control characters but we count them as whitespace
78 | # characters.
79 | if char == "\t" or char == "\n" or char == "\r":
80 | return False
81 | cat = unicodedata.category(char)
82 | if cat.startswith("C"):
83 | return True
84 | return False
85 |
86 | class BasicTokenizer(object):
87 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
88 |
89 | def __init__(self, do_lower_case=True):
90 | """Constructs a BasicTokenizer.
91 | Args:
92 | do_lower_case: Whether to lower case the input.
93 | """
94 | self.do_lower_case = do_lower_case
95 |
96 | def tokenize(self, text):
97 | """Tokenizes a piece of text."""
98 | text = convert_to_unicode(text)
99 | text = self._clean_text(text)
100 | orig_tokens = whitespace_tokenize(text)
101 | split_tokens = []
102 | for token in orig_tokens:
103 | if self.do_lower_case:
104 | token = token.lower()
105 | token = self._run_strip_accents(token)
106 | split_tokens.extend(self._run_split_on_punc(token))
107 |
108 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
109 | return output_tokens
110 |
111 | def _run_strip_accents(self, text):
112 | """Strips accents from a piece of text."""
113 | text = unicodedata.normalize("NFD", text)
114 | output = []
115 | for char in text:
116 | cat = unicodedata.category(char)
117 | if cat == "Mn":
118 | continue
119 | output.append(char)
120 | return "".join(output)
121 |
122 | def _run_split_on_punc(self, text):
123 | """Splits punctuation on a piece of text."""
124 | chars = list(text)
125 | i = 0
126 | start_new_word = True
127 | output = []
128 | while i < len(chars):
129 | char = chars[i]
130 | if _is_punctuation(char):
131 | output.append([char])
132 | start_new_word = True
133 | else:
134 | if start_new_word:
135 | output.append([])
136 | start_new_word = False
137 | output[-1].append(char)
138 | i += 1
139 |
140 | return ["".join(x) for x in output]
141 |
142 | def _clean_text(self, text):
143 | """Performs invalid character removal and whitespace cleanup on text."""
144 | output = []
145 | for char in text:
146 | cp = ord(char)
147 | if cp == 0 or cp == 0xfffd or _is_control(char):
148 | continue
149 | if _is_whitespace(char):
150 | output.append(" ")
151 | else:
152 | output.append(char)
153 | return "".join(output)
154 |
155 |
156 | def _is_punctuation(char):
157 | """Checks whether `chars` is a punctuation character."""
158 | cp = ord(char)
159 | # We treat all non-letter/number ASCII as punctuation.
160 | # Characters such as "^", "$", and "`" are not in the Unicode
161 | # Punctuation class but we treat them as punctuation anyways, for
162 | # consistency.
163 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
164 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
165 | return True
166 | cat = unicodedata.category(char)
167 | if cat.startswith("P"):
168 | return True
169 | return False
170 |
171 |
172 | def process(s, tokenizer):
173 | try:
174 | return tokenizer.tokenize(s)
175 | except:
176 | print('failed on', s)
177 | raise
178 |
179 | if __name__ == "__main__":
180 | _is_whitespace("a")
181 |
--------------------------------------------------------------------------------
/retrieval/train_retriever.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import json
4 | import os
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | from copy import deepcopy
10 |
11 | from torch.utils.data import DataLoader
12 | from datasets import ReDataset, ReSampler, re_collate, ClusterSampler, ClusterDataset
13 | from retriever import BertForRetriever
14 | from transformers import AdamW, BertConfig, BertTokenizer
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | from utils import move_to_cuda, convert_to_half, AverageMeter
18 | from config import get_args
19 |
20 | from collections import defaultdict, namedtuple
21 | import torch.nn.functional as F
22 | from torch.nn import CrossEntropyLoss
23 |
24 |
25 | def load_saved(model, path):
26 | state_dict = torch.load(path)
27 | def filter(x): return x[7:] if x.startswith('module.') else x
28 | state_dict = {filter(k): v for (k, v) in state_dict.items()}
29 | model.load_state_dict(state_dict)
30 | return model
31 |
32 | def main():
33 | args = get_args()
34 |
35 | if args.fp16:
36 | try:
37 | import apex
38 | apex.amp.register_half_function(torch, 'einsum')
39 | except ImportError:
40 | raise ImportError(
41 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
42 |
43 | # tb logger
44 | data_name = args.train_file.split("/")[-1].split('-')[0]
45 | model_name = f"{data_name}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-{args.prefix}-lr{args.learning_rate}-{args.bert_model_name}-filter{args.filter}"
46 | tb_logger = SummaryWriter(os.path.join(
47 | args.output_dir, "tflogs", model_name))
48 | args.output_dir = os.path.join(args.output_dir, model_name)
49 |
50 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
51 | print(
52 | f"output directory {args.output_dir} already exists and is not empty.")
53 | if not os.path.exists(args.output_dir):
54 | os.makedirs(args.output_dir, exist_ok=True)
55 |
56 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
57 | datefmt='%m/%d/%Y %H:%M:%S',
58 | level=logging.INFO,
59 | handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")),
60 | logging.StreamHandler()])
61 | logger = logging.getLogger(__name__)
62 | logger.info(args)
63 |
64 | if args.local_rank == -1 or args.no_cuda:
65 | device = torch.device(
66 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
67 | n_gpu = torch.cuda.device_count()
68 | else:
69 | device = torch.device("cuda", args.local_rank)
70 | n_gpu = 1
71 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
72 | torch.distributed.init_process_group(backend='nccl')
73 | logger.info("device %s n_gpu %d distributed training %r",
74 | device, n_gpu, bool(args.local_rank != -1))
75 |
76 | if args.accumulate_gradients < 1:
77 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
78 | args.accumulate_gradients))
79 |
80 | args.train_batch_size = int(
81 | args.train_batch_size / args.accumulate_gradients)
82 | random.seed(args.seed)
83 | np.random.seed(args.seed)
84 | torch.manual_seed(args.seed)
85 | if n_gpu > 0:
86 | torch.cuda.manual_seed_all(args.seed)
87 |
88 | if not args.do_train and not args.do_predict:
89 | raise ValueError(
90 | "At least one of `do_train` or `do_predict` must be True.")
91 |
92 | if args.do_train:
93 | if not args.train_file:
94 | raise ValueError(
95 | "If `do_train` is True, then `train_file` must be specified.")
96 | if not args.predict_file:
97 | raise ValueError(
98 | "If `do_train` is True, then `predict_file` must be specified.")
99 |
100 | if args.do_predict:
101 | if not args.predict_file:
102 | raise ValueError(
103 | "If `do_predict` is True, then `predict_file` must be specified.")
104 |
105 | bert_config = BertConfig.from_pretrained(args.bert_model_name)
106 | model = BertForRetriever(bert_config, args)
107 | tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
108 |
109 | if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
110 | raise ValueError(
111 | "Cannot use sequence length %d because the BERT model "
112 | "was only trained up to sequence length %d" %
113 | (args.max_seq_length, bert_config.max_position_embeddings))
114 |
115 | eval_dataset = ReDataset(
116 | tokenizer, args.predict_file, args.max_query_length, args.max_seq_length)
117 | #sampler = ReSampler(eval_dataset)
118 | eval_dataloader = DataLoader(
119 | eval_dataset, batch_size=args.predict_batch_size, collate_fn=re_collate, pin_memory=True, num_workers=args.eval_workers)
120 | logger.info(f"Num of dev batches: {len(eval_dataloader)}")
121 |
122 | if args.init_checkpoint != "":
123 |         if ";" in args.init_checkpoint:  # ";"-separated paths are loaded as a list of models
124 | models = []
125 | for path in args.init_checkpoint.split(";"):
126 | instance = deepcopy(load_saved(model, path))
127 | models.append(instance)
128 | model = models
129 | else:
130 | model = load_saved(model, args.init_checkpoint)
131 |
132 | if type(model) == list:
133 | model = [m.to(device) for m in model]
134 | else:
135 | model.to(device)
136 |     if type(model) != list:
137 |         print(f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
138 |
139 | if args.do_train:
140 | no_decay = ['bias', 'LayerNorm.weight']
141 | optimizer_parameters = [
142 | {'params': [p for n, p in model.named_parameters() if not any(
143 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
144 | {'params': [p for n, p in model.named_parameters() if any(
145 | nd in n for nd in no_decay)], 'weight_decay': 0.0}
146 | ]
147 | optimizer = AdamW(optimizer_parameters,
148 | lr=args.learning_rate, eps=args.adam_epsilon)
149 |
150 | if args.fp16:
151 | try:
152 | from apex import amp
153 | except ImportError:
154 | raise ImportError(
155 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
156 | model, optimizer = amp.initialize(
157 | model, optimizer, opt_level=args.fp16_opt_level)
158 | else:
159 | if args.fp16:
160 | try:
161 | from apex import amp
162 | except ImportError:
163 | raise ImportError(
164 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
165 | model = amp.initialize(model, opt_level=args.fp16_opt_level)
166 |
167 | if args.local_rank != -1:
168 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
169 | output_device=args.local_rank)
170 | elif n_gpu > 1:
171 | model = torch.nn.DataParallel(model)
172 |
173 | if args.do_train:
174 | global_step = 0 # gradient update step
175 | batch_step = 0 # forward batch count
176 | best_acc = 0
177 | wait_step = 0
178 | stop_training = False
179 | train_loss_meter = AverageMeter()
180 | model.train()
181 |
182 |         if not os.path.isdir(args.train_file):  # a single file of training pairs -> ReDataset/ReSampler
183 | train_dataset = ReDataset(
184 | tokenizer, args.train_file, args.max_query_length, args.max_seq_length, args.filter)
185 | sampler = ReSampler(train_dataset)
186 | train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=sampler, pin_memory=True, collate_fn=re_collate, num_workers=8)
187 |         else:  # a directory of clustered data splits -> ClusterDataset/ClusterSampler
188 | train_dataset = ClusterDataset(
189 | tokenizer, args.train_file, args.max_query_length, args.max_seq_length, args.filter)
190 | sampler = ClusterSampler(
191 | train_dataset, args.train_batch_size)
192 | train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=sampler, pin_memory=True, collate_fn=re_collate, num_workers=8)
193 |
194 | logger.info('Start training....')
195 | loss_fct = CrossEntropyLoss()
196 | for epoch in range(int(args.num_train_epochs)):
197 |
198 | for batch in tqdm(train_dataloader):
199 | batch_step += 1
200 | batch = move_to_cuda(batch)
201 | outputs = model(batch)
202 |
203 |                 product = torch.mm(outputs["q"], outputs["c"].t())  # (bsz, bsz) question-paragraph similarity scores
204 |                 target = torch.arange(product.size(0)).to(product.device)  # the i-th question is paired with the i-th paragraph
205 |                 loss = loss_fct(product, target)  # cross-entropy over in-batch negatives
206 |
207 | if args.gradient_accumulation_steps > 1:
208 | loss = loss / args.gradient_accumulation_steps
209 |
210 | if args.fp16:
211 | with amp.scale_loss(loss, optimizer) as scaled_loss:
212 | scaled_loss.backward()
213 | else:
214 | loss.backward()
215 |
216 | train_loss_meter.update(loss.item())
217 | tb_logger.add_scalar('batch_train_loss',
218 | loss.item(), global_step)
219 | tb_logger.add_scalar('smoothed_train_loss',
220 | train_loss_meter.avg, global_step)
221 |
222 | if (batch_step + 1) % args.gradient_accumulation_steps == 0:
223 | if args.fp16:
224 | torch.nn.utils.clip_grad_norm_(
225 | amp.master_params(optimizer), args.max_grad_norm)
226 | else:
227 | torch.nn.utils.clip_grad_norm_(
228 | model.parameters(), args.max_grad_norm)
229 |                     optimizer.step()  # We have accumulated enough gradients
230 | model.zero_grad()
231 | global_step += 1
232 |
233 | if global_step % args.save_checkpoints_steps == 0:
234 | torch.save(model.state_dict(), os.path.join(
235 | args.output_dir, f"checkpoint_{global_step}.pt"))
236 |
237 | if global_step % args.eval_period == 0:
238 | acc = predict(args, model, eval_dataloader,
239 | device, fp16=args.efficient_eval)
240 | logger.info("Step %d Train loss %.2f Acc %.2f on epoch=%d" % (
241 | global_step, train_loss_meter.avg, acc*100, epoch))
242 |
243 | tb_logger.add_scalar('dev_acc', acc*100, global_step)
244 |
245 | # save most recent model
246 | torch.save(model.state_dict(), os.path.join(
247 | args.output_dir, f"checkpoint_last.pt"))
248 |
249 | if best_acc < acc:
250 | logger.info("Saving model with best Acc %.2f -> Acc %.2f on epoch=%d" %
251 | (best_acc*100, acc*100, epoch))
252 | # model_state_dict = {k: v.cpu() for (
253 | # k, v) in model.state_dict().items()}
254 | torch.save(model.state_dict(), os.path.join(
255 | args.output_dir, f"checkpoint_best.pt"))
256 | model = model.to(device)
257 | best_acc = acc
258 | wait_step = 0
259 | stop_training = False
260 | else:
261 | wait_step += 1
262 | if wait_step == args.wait_step:
263 | stop_training = True
264 |
265 |
266 |
267 | # acc = predict(args, model, eval_dataloader,
268 | # device, fp16=args.efficient_eval)
269 | # tb_logger.add_scalar('dev_acc', acc*100, global_step)
270 | # logger.info(f"average training loss {train_loss_meter.avg}")
271 | # if best_acc < acc:
272 | # logger.info("Saving model with best Acc %.2f -> Acc %.2f on epoch=%d" %
273 | # (best_acc*100, acc*100, epoch))
274 | # model_state_dict = {k: v.cpu() for (
275 | # k, v) in model.state_dict().items()}
276 | # torch.save(model_state_dict, os.path.join(
277 | # args.output_dir, "best-model.pt"))
278 | # model = model.to(device)
279 | # best_acc = acc
280 | # wait_step = 0
281 |
282 | if stop_training:
283 | break
284 |
285 | logger.info("Training finished!")
286 |
287 | elif args.do_predict:
288 | acc = predict(args, model, eval_dataloader, device, fp16=args.efficient_eval)
289 | logger.info(f"test performance {acc}")
290 | print(acc)
291 |
292 |
293 | def predict(args, model, eval_dataloader, device, fp16=False):
294 | if type(model) == list:
295 | model = [m.eval() for m in model]
296 | else:
297 | model.eval()
298 |
299 | if fp16:
300 | if type(model) == list:
301 | model = [m.half() for m in model]
302 | else:
303 | model.half()
304 |
305 | num_correct = 0.0
306 | num_total = 0.0
307 | for batch in tqdm(eval_dataloader):
308 | batch_to_feed = move_to_cuda(batch)
309 | if fp16:
310 | batch_to_feed = convert_to_half(batch_to_feed)
311 | with torch.no_grad():
312 | results = model(batch_to_feed)
313 | product = torch.mm(results["q"], results["c"].t())
314 | target = torch.arange(product.size(0)).to(product.device)
315 | prediction = product.argmax(-1)
316 | pred_res = prediction == target
317 | num_total += len(pred_res)
318 |         num_correct += pred_res.sum().item()  # count questions whose paired paragraph ranks first
319 |
320 |     # in-batch accuracy: fraction of questions whose paired paragraph gets the highest score
321 |     acc = num_correct / num_total
322 | 
323 |     print(f"evaluated {int(num_total)} examples...")
324 |     print(f"avg. Acc: {acc}")
325 | 
326 | 
327 |     # restore fp32 weights and training mode after evaluation
328 |     if fp16:
329 |         model.float()
330 |     model.train()
331 | 
332 | 
333 |     return acc
334 |
335 |
336 | if __name__ == "__main__":
337 | main()
338 |
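A minimal, self-contained sketch of the in-batch negative objective computed at lines 203-205 above (the tensors, batch size, and embedding dimension below are made up for illustration; in the script the "q"/"c" embeddings come from BertForRetriever):

```
import torch
from torch.nn import CrossEntropyLoss

bsz, dim = 4, 128                                        # illustrative batch size / embedding size
q = torch.randn(bsz, dim)                                # question embeddings
c = torch.randn(bsz, dim)                                # paragraph embeddings, aligned so c[i] answers q[i]

scores = torch.mm(q, c.t())                              # (bsz, bsz) similarity matrix
target = torch.arange(bsz)                               # the diagonal entries are the positive pairs
loss = CrossEntropyLoss()(scores, target)                # every other paragraph in the batch acts as a negative

accuracy = (scores.argmax(-1) == target).float().mean()  # the metric predict() reports
print(loss.item(), accuracy.item())
```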
--------------------------------------------------------------------------------
/retrieval/train_retriever_cluster.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retriever.py \
4 | --do_train \
5 | --prefix retriever_pretraining_cluster \
6 | --predict_batch_size 512 \
7 | --bert_model_name bert-base-uncased \
8 | --train_batch_size 640 \
9 | --gradient_accumulation_steps 8 \
10 | --accumulate_gradients 8 \
11 | --efficient_eval \
12 | --learning_rate 1e-5 \
13 |     --train_file ../data/data_splits/ \
14 | --predict_file ../data/retrieve_dev_shuffled.txt \
15 | --seed 87 \
16 | --init_checkpoint logs/retrieve_train.txt-seed87-bsz640-fp16True-retriever_pretraining_single-lr1e-05-bert-base-uncased-filterTrue/checkpoint_last.pt \
17 | --eval-period 800 \
18 | --filter
19 |
--------------------------------------------------------------------------------
/retrieval/train_retriever_single.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retriever.py \
4 | --do_train \
5 | --prefix retriever_pretraining_single \
6 | --predict_batch_size 512 \
7 | --bert_model_name bert-base-uncased \
8 | --train_batch_size 640 \
9 | --gradient_accumulation_steps 8 \
10 | --accumulate_gradients 8 \
11 | --efficient_eval \
12 | --learning_rate 1e-5 \
13 | --fp16 \
14 | --train_file ../data/retrieve_train.txt \
15 | --predict_file ../data/retrieve_dev_shuffled.txt \
16 | --seed 87 \
17 | --eval-period 800 \
18 | --filter
19 |
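Note: train_retriever.py divides --train_batch_size by --accumulate_gradients before building the DataLoader, so with the flags above each forward pass sees 640 / 8 = 80 question-paragraph pairs and the optimizer steps once every 8 batches (--gradient_accumulation_steps), i.e. an effective batch of 640 pairs per update. The same arithmetic applies to train_retriever_cluster.sh.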
--------------------------------------------------------------------------------
/retrieval/trec_process.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import faiss
7 |
8 | def prepare_corpus(path="../data/trec-2019/collection.tsv", save_path="../data/trec-2019/msmarco_paras.txt"):
9 | corpus = []
10 | for line in tqdm(open(path).readlines()):
11 | line = line.strip()
12 | pid, text = line.split("\t")
13 | corpus.append({"text": text, "id": int(pid)})
14 | with open(save_path, "w") as g:
15 | for _ in corpus:
16 | g.write(json.dumps(_) + "\n")
17 |
18 | def extract_labels(
19 | input="../data/trec-2019/qrels.train.tsv",
20 | output="../data/trec-2019/msmacro-train.txt",
21 | queries="../data/trec-2019/queries.train.tsv"
22 | ):
23 | # id2queries
24 | qid2query = {}
25 | for line in open(queries).readlines():
26 | line = line.strip()
27 | qid, q = line.split("\t")[0], line.split("\t")[1]
28 | if q.endswith("?"):
29 | q = q[:-1]
30 | qid2query[int(qid)] = q
31 | print(len(qid2query))
32 |
33 | # queries with groundtruths
34 | qid2ground = defaultdict(list)
35 | for line in open(input).readlines():
36 | line = line.strip()
37 | qid, pid = line.split("\t")[0], line.split("\t")[2]
38 | qid2ground[int(qid)].append(int(pid))
39 | print(len(qid2ground))
40 |
41 | # generate data for train/dev
42 | with open(output, "w") as g:
43 | for qid, labels in qid2ground.items():
44 | question = qid2query[qid]
45 | sample = {"question":question, "labels": labels, "qid": qid}
46 | g.write(json.dumps(sample) + "\n")
47 |
48 |
49 | def debug():
50 | top1000_dev = open("../data/trec-2019/top1000.dev").readlines()
51 | qid2top10000 = defaultdict(list)
52 | for l in top1000_dev:
53 | qid2top10000[int(l.split("\t")[0])].append(int(l.split("\t")[1]))
54 | print(len(qid2top10000))
55 |
56 | processed_dev = [json.loads(l) for l in tqdm(open(
57 | "../data/trec-2019/processed/dev.txt").readlines())]
58 | qid2ground = {_["qid"]: _["labels"] for _ in processed_dev}
59 |
60 | covered = []
61 | for qid in qid2top10000.keys():
62 | top1000_labels = [int(_ in qid2ground[qid]) for _ in qid2top10000[qid]]
63 | covered.append(int(np.sum(top1000_labels) > 0))
64 |
65 | print(len(covered))
66 | print(np.mean(covered))
67 |
68 |
69 | def retrieve_topk(index_path="../data/trec-2019/embeds/msmarco_paras_embed.npy", query_embeds="../data/trec-2019/embeds/msmarco-train-query.npy", query_input="../data/trec-2019/msmacro-train.txt", output="../data/trec-2019/processed/train.txt"):
70 |     d = 128  # dimensionality of the question/paragraph embeddings
71 | xq = np.load(query_embeds).astype('float32')
72 | xb = np.load(index_path).astype('float32')
73 |
74 | index = faiss.IndexFlatIP(d) # build the index
75 | index.add(xb) # add vectors to the index
76 | D, I = index.search(xq, 10000) # actual search
77 |
78 | raw_data = [json.loads(l) for l in open(query_input).readlines()]
79 |
80 | processed = []
81 | covered = []
82 | for idx, para_indice in enumerate(I):
83 | orig_sample = raw_data[idx]
84 | para_embed_idx = [int(_) for _ in para_indice]
85 | para_labels = [int(_ in orig_sample["labels"]) for _ in para_embed_idx]
86 | orig_sample["para_embed_idx"] = para_embed_idx
87 | orig_sample["para_labels"] = para_labels
88 | processed.append(orig_sample)
89 | covered.append(int(np.sum(para_labels) > 0))
90 |
91 | print(f"Avg recall: {np.mean(covered)}")
92 | with open(output, "w") as g:
93 | for _ in processed:
94 | g.write(json.dumps(_) + "\n")
95 |
96 |
97 | if __name__ == "__main__":
98 | # prepare_corpus()
99 | # extract_labels(input="../data/trec-2019/qrels.dev.small.tsv",
100 | # output="../data/trec-2019/msmacro-dev-small.txt",
101 | # queries="../data/trec-2019/queries.dev.tsv")
102 |
103 | # debug()
104 |
105 | retrieve_topk()
106 |
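For reference, each line written by retrieve_topk is a JSON object of roughly the following shape; the values below are made up for illustration (para_embed_idx holds row indices into the dense index, and para_labels marks which of those retrieved rows are gold passages):

```
import json

record = {
    "question": "what is the capital of france",
    "labels": [1234, 5678],                 # gold passage ids from the qrels file
    "qid": 42,
    "para_embed_idx": [9, 1234, 77, 5678],  # row indices of the top retrieved paragraphs in the dense index
    "para_labels": [0, 1, 0, 1],            # 1 where the retrieved paragraph is a gold passage
}
print(json.dumps(record))
```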
--------------------------------------------------------------------------------
/retrieval/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sqlite3
3 | import unicodedata
4 |
5 | def move_to_cuda(sample):
6 | if len(sample) == 0:
7 | return {}
8 |
9 | def _move_to_cuda(maybe_tensor):
10 | if torch.is_tensor(maybe_tensor):
11 | return maybe_tensor.cuda()
12 | elif isinstance(maybe_tensor, dict):
13 | return {
14 | key: _move_to_cuda(value)
15 | for key, value in maybe_tensor.items()
16 | }
17 | elif isinstance(maybe_tensor, list):
18 | return [_move_to_cuda(x) for x in maybe_tensor]
19 | else:
20 | return maybe_tensor
21 |
22 | return _move_to_cuda(sample)
23 |
24 | def convert_to_half(sample):
25 | if len(sample) == 0:
26 | return {}
27 |
28 | def _convert_to_half(maybe_floatTensor):
29 | if torch.is_tensor(maybe_floatTensor) and maybe_floatTensor.type() == "torch.FloatTensor":
30 | return maybe_floatTensor.half()
31 | elif isinstance(maybe_floatTensor, dict):
32 | return {
33 | key: _convert_to_half(value)
34 | for key, value in maybe_floatTensor.items()
35 | }
36 | elif isinstance(maybe_floatTensor, list):
37 | return [_convert_to_half(x) for x in maybe_floatTensor]
38 | else:
39 | return maybe_floatTensor
40 |
41 | return _convert_to_half(sample)
42 |
43 |
44 | class AverageMeter(object):
45 | """Computes and stores the average and current value"""
46 |
47 | def __init__(self):
48 | self.reset()
49 |
50 | def reset(self):
51 | self.val = 0
52 | self.avg = 0
53 | self.sum = 0
54 | self.count = 0
55 |
56 | def update(self, val, n=1):
57 | self.val = val
58 | self.sum += val * n
59 | self.count += n
60 | self.avg = self.sum / self.count
61 |
62 |
63 | def normalize(text):
64 | """Resolve different type of unicode encodings."""
65 | return unicodedata.normalize('NFD', text)
66 |
67 |
68 | class DocDB(object):
69 | """Sqlite backed document storage.
70 |
71 | Implements get_doc_text(doc_id).
72 | """
73 |
74 | def __init__(self, db_path=None):
75 | self.path = db_path
76 | self.connection = sqlite3.connect(self.path, check_same_thread=False)
77 |
78 | def __enter__(self):
79 | return self
80 |
81 | def __exit__(self, *args):
82 | self.close()
83 |
84 | def close(self):
85 | """Close the connection to the database."""
86 | self.connection.close()
87 |
88 | def get_doc_ids(self):
89 | """Fetch all ids of docs stored in the db."""
90 | cursor = self.connection.cursor()
91 | cursor.execute("SELECT id FROM documents")
92 | results = [r[0] for r in cursor.fetchall()]
93 | cursor.close()
94 | return results
95 |
96 | def get_doc_text(self, doc_id):
97 | """Fetch the raw text of the doc for 'doc_id'."""
98 | cursor = self.connection.cursor()
99 | cursor.execute(
100 | "SELECT text FROM documents WHERE id = ?",
101 | (normalize(doc_id),)
102 | )
103 | result = cursor.fetchone()
104 | cursor.close()
105 | return result if result is None else result[0]
106 |
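A hypothetical usage sketch for DocDB; the database path below is only an example, and the queries above assume a sqlite file with a documents(id, text) table:

```
from utils import DocDB

# Example only: point this at any sqlite database with a documents(id, text) table.
with DocDB("../data/paras.db") as db:
    ids = db.get_doc_ids()
    print(f"{len(ids)} paragraphs in the database")
    if ids:
        print(db.get_doc_text(ids[0])[:200])   # first 200 characters of one paragraph
```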
--------------------------------------------------------------------------------