":
125 | self._onRowEnd()
126 | elif "| ", ""]:
129 | self._on_cell_end()
130 |
131 | return self.all_tables
132 |
133 | def _on_table_start(self):
134 | caption = self.title
135 | parent_table = self.current_table
136 | if parent_table:
137 | self.tables_stack.append(parent_table)
138 |
139 | caption = parent_table.caption
140 | if parent_table.body and parent_table.body[-1].cells:
141 | current_cell = self.current_table.body[-1].cells[-1]
142 | caption += " | " + " ".join(current_cell.value_tokens)
143 |
144 | t = Table()
145 | t.caption = caption
146 | self.current_table = t
147 | self.all_tables.append(t)
148 |
149 | def _on_table_end(self):
150 | t = self.current_table
151 | if t:
152 | if self.tables_stack: # t is a nested table
153 | self.current_table = self.tables_stack.pop()
154 | if self.current_table.body:
155 | current_cell = self.current_table.body[-1].cells[-1]
156 | current_cell.nested_tables.append(t)
157 | else:
158 | logger.error("table end without table object")
159 |
160 | def _onRowStart(self):
161 | self.current_table.body.append(Row())
162 |
163 | def _onRowEnd(self):
164 | pass
165 |
166 | def _onCellStart(self):
167 | current_row = self.current_table.body[-1]
168 | current_row.cells.append(Cell())
169 |
170 | def _on_cell_end(self):
171 | pass
172 |
173 | def _on_content(self, token):
174 | if self.current_table.body:
175 | current_row = self.current_table.body[-1]
176 | current_cell = current_row.cells[-1]
177 | current_cell.value_tokens.append(token)
178 | else: # tokens outside of row/cells. Just append to the table caption.
179 | self.current_table.caption += " " + token
180 |
181 |
182 | def read_nq_tables_jsonl(path: str, out_file: str = None) -> Dict[str, Table]:
183 | tables_with_issues = 0
184 | single_row_tables = 0
185 | nested_tables = 0
186 | regular_tables = 0
187 | total_tables = 0
188 | total_rows = 0
189 | tables_dict = {}
190 |
191 | with jsonlines.open(path, mode="r") as jsonl_reader:
192 | for jline in jsonl_reader:
193 | tokens = jline["tokens"]
194 |
195 | if "( hide ) This section has multiple issues" in " ".join(tokens):
196 | tables_with_issues += 1
197 | continue
198 | mask = jline["html_mask"]
199 | # _page_url = jline["doc_url"]
200 | title = jline["title"]
201 | p = NQTableParser(tokens, mask, title)
202 | tables = p.parse()
203 |
204 | nested_tables += len(tables[1:])
205 |
206 | for t in tables:
207 | total_tables += 1
208 |
209 | # count rows that have at least one non-empty cell
210 | non_empty_rows = sum([1 for r in t.body if r.cells and any([True for c in r.cells if c.value_tokens])])
211 |
212 | if non_empty_rows <= 1:
213 | single_row_tables += 1
214 | else:
215 | regular_tables += 1
216 | total_rows += len(t.body)
217 |
218 | if t.get_key() not in tables_dict:
219 | tables_dict[t.get_key()] = t
220 |
221 | if len(tables_dict) % 1000 == 0:
222 | logger.info("tables_dict %d", len(tables_dict))
223 |
224 | logger.info("regular tables %d", regular_tables)
225 | logger.info("tables_with_issues %d", tables_with_issues)
226 | logger.info("single_row_tables %d", single_row_tables)
227 | logger.info("nested_tables %d", nested_tables)
228 |
229 | if out_file:
230 | convert_to_csv_for_lucene(tables_dict, out_file)
231 | return tables_dict
232 |
233 |
234 | def get_table_string_for_answer_check(table: Table): # this doesn't use caption
235 | table_text = ""
236 | for r in table.body:
237 | table_text += " . ".join([" ".join(c.value_tokens) for c in r.cells])
238 | table_text += " . "
239 | return table_text
240 |
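# Illustrative sketch of how a table flattens into the " . "-separated text used for
# answer matching. It assumes only what the parser code above already relies on:
# Table(), Row() and Cell() can be built with no arguments and expose .body, .cells
# and .value_tokens.
def _demo_table_string():
    cell = Cell()
    cell.value_tokens = ["Paris", "France"]
    row = Row()
    row.cells = [cell]
    t = Table()
    t.body = [row]
    return get_table_string_for_answer_check(t)  # -> "Paris France . "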
241 |
242 | def convert_to_csv_for_lucene(tables_dict, out_file: str):
243 | id = 0
244 | with open(out_file, "w", newline="") as csvfile:
245 | writer = csv.writer(csvfile, delimiter="\t")
246 | for _, v in tables_dict.items():
247 | id += 1
248 | # strip all
249 | table_text = get_table_string_for_answer_check(v)
250 | writer.writerow([id, table_text, v.caption])
251 | logger.info("Saved to %s", out_file)
252 |
253 |
254 | def convert_jsonl_to_qas_tsv(path, out):
255 | results = []
256 | with jsonlines.open(path, mode="r") as jsonl_reader:
257 | for jline in jsonl_reader:
258 | q = jline["question"]
259 | answers = []
260 | if "short_answers" in jline:
261 | answers = jline["short_answers"]
262 |
263 | results.append((q, answers))
264 |
265 | with open(out, "w", newline="") as csvfile:
266 | writer = csv.writer(csvfile, delimiter="\t")
267 | for r in results:
268 | writer.writerow([r[0], r[1]])
269 | logger.info("Saved to %s", out)
270 |
271 |
272 | def tokenize(text):
273 | doc = nlp(text)
274 | return [token.text.lower() for token in doc]
275 |
276 |
277 | def normalize(text):
278 | """Resolve different type of unicode encodings."""
279 | return unicodedata.normalize("NFD", text)
280 |
281 |
282 | def prepare_answers(answers) -> List[List[str]]:
283 | r = []
284 | for single_answer in answers:
285 | single_answer = normalize(single_answer)
286 | single_answer = single_answer.lower().split(" ") # tokenize(single_answer)
287 | r.append(single_answer)
288 | return r
289 |
290 |
291 | def has_prepared_answer(prep_answers: List[List[str]], text: List[str]):
292 | """Check if a document contains an answer string."""
293 | text = [normalize(token).lower() for token in text]
294 |
295 | for single_answer in prep_answers:
296 | for i in range(0, len(text) - len(single_answer) + 1):
297 | if single_answer == text[i : i + len(single_answer)]:
298 | return True
299 | return False
300 |
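# Illustrative sketch: prepare_answers lowercases, NFD-normalizes and whitespace-splits
# each answer once, so has_prepared_answer can slide a token window over candidate cell
# tokens without re-tokenizing the answers on every call.
def _demo_answer_matching():
    prep = prepare_answers(["Barack Obama"])
    assert has_prepared_answer(prep, ["President", "Barack", "Obama", "in", "2009"])
    assert not has_prepared_answer(prep, ["Michelle", "Obama"])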
301 |
302 | def has_answer(answers, text, regMatch=False):
303 | """Check if a document contains an answer string."""
304 |
305 | text = normalize(text)
306 |
307 | if regMatch:
308 | single_answer = normalize(answers[0])
309 | if regex_match(text, single_answer):
310 | return True
311 | else:
312 | # Answer is a list of possible strings
313 | text = tokenize(text)
314 |
315 | for single_answer in answers:
316 | single_answer = normalize(single_answer)
317 | single_answer = tokenize(single_answer)
318 |
319 | for i in range(0, len(text) - len(single_answer) + 1):
320 | if single_answer == text[i : i + len(single_answer)]:
321 | return True
322 | return False
323 |
324 |
325 | def convert_search_res_to_dpr_and_eval(
326 | res_file, all_tables_file_jsonl, nq_table_file, out_file, gold_res_file: str = None
327 | ):
328 | db = {}
329 | id = 0
330 | tables_dict = read_nq_tables_jsonl(all_tables_file_jsonl)
331 | for _, v in tables_dict.items():
332 | id += 1
333 | db[id] = v
334 |
335 | logger.info("db size %s", len(db))
336 | total = 0
337 | dpr_results = {}
338 | import torch
339 |
340 | bm25_per_topk_hits = torch.tensor([0] * 100)
341 | qas = []
342 | with open(res_file) as tsvfile:
343 | reader = csv.reader(tsvfile, delimiter="\t")
344 | # file format: id, text
345 | for row in reader:
346 | total += 1
347 | q = row[0]
348 | answers = eval(row[1])
349 | prep_answers = prepare_answers(answers)
350 | qas.append((q, prep_answers))
351 | question_hns = []
352 | question_positives = []
353 | answers_table_links = []
354 |
355 | for k, bm25result in enumerate(row[2:]):
356 | score, id = bm25result.split(",")
357 | table = db[int(id)]
358 | answer_locations = []
359 |
360 | def check_answer(tokens, row_idx: int, cell_idx: int):
361 | if has_prepared_answer(prep_answers, tokens):
362 | answer_locations.append((row_idx, cell_idx))
363 |
364 | # get string representation to find answer
365 | if (len(question_positives) >= 10 and len(question_hns) >= 10) or (len(question_hns) >= 30):
366 | break
367 |
368 | # table_str = get_table_string_for_answer_check(table)
369 | table.visit(check_answer)
370 | has_answer = len(answer_locations) > 0
371 |
372 | if has_answer:
373 | question_positives.append(table)
374 | answers_table_links.append(answer_locations)
375 | else:
376 | question_hns.append(table)
377 |
378 | dpr_results[q] = (question_positives, question_hns, answers_table_links)
379 | if len(dpr_results) % 100 == 0:
380 | logger.info("dpr_results %s", len(dpr_results))
381 |
382 | logger.info("dpr_results size %s", len(dpr_results))
383 | logger.info("total %s", total)
384 | logger.info("bm25_per_topk_hits %s", bm25_per_topk_hits)
385 |
386 | if gold_res_file:
387 | logger.info("Processing gold_res_file")
388 | with open(gold_res_file) as cFile:
389 | csvReader = csv.reader(cFile, delimiter=",")
390 | for row in csvReader:
391 | q_id = int(row[0])
392 | qas_tuple = qas[q_id]
393 | prep_answers = qas_tuple[1]
394 | question_gold_positive_match = None
395 | q = qas_tuple[0]
396 | answers_links = None
397 | for field in row[1:]:
398 | psg_id = int(field.split()[0])
399 | table = db[psg_id]
400 | answer_locations = []
401 |
402 | def check_answer(tokens, row_idx: int, cell_idx: int):
403 | if has_prepared_answer(prep_answers, tokens):
404 | answer_locations.append((row_idx, cell_idx))
405 |
406 | table.visit(check_answer)
407 | has_answer = len(answer_locations) > 0
408 | if has_answer and question_gold_positive_match is None:
409 | question_gold_positive_match = table
410 | question_gold_positive_match.gold_match = True
411 | answers_links = answer_locations
412 |
413 | if question_gold_positive_match is None:
414 | logger.info("No gold match for q=%s, q_id=%s", q, q_id)
415 | else: # inject into ctx+ at the first position
416 | question_positives, hns, ans_links = dpr_results[q]
417 | question_positives.insert(0, question_gold_positive_match)
418 | ans_links.insert(0, answers_links)
419 |
420 | out_results = []
421 | with jsonlines.open(nq_table_file, mode="r") as jsonl_reader:
422 | for jline in jsonl_reader:
423 | q = jline["question"]
424 | gold_positive_table = jline["contexts"][0]
425 | mask = gold_positive_table["html_mask"]
426 | # page_url = jline['doc_url']
427 | title = jline["title"]
428 | p = NQTableParser(gold_positive_table["tokens"], mask, title)
429 | tables = p.parse()
430 | # select the one with the answer(s)
431 | prep_answers = prepare_answers(jline["short_answers"])
432 |
433 | tables_with_answers = []
434 | tables_answer_locations = []
435 |
436 | for t in tables:
437 | answer_locations = []
438 |
439 | def check_answer(tokens, row_idx: int, cell_idx: int):
440 | if has_prepared_answer(prep_answers, tokens):
441 | answer_locations.append((row_idx, cell_idx))
442 |
443 | t.visit(check_answer)
444 | has_answer = len(answer_locations) > 0
445 | if has_answer:
446 | tables_with_answers.append(t)
447 | tables_answer_locations.append(answer_locations)
448 |
449 | if not tables_with_answers:
450 | logger.info("No answer in gold table(s) for q=%s", q)
451 |
452 | positive_ctxs, hard_neg_ctxs, answers_table_links = dpr_results[q]
453 | positive_ctxs = positive_ctxs + tables_with_answers
454 | tables_answer_locations = answers_table_links + tables_answer_locations
455 | assert len(positive_ctxs) == len(tables_answer_locations)
456 | positive_ctxs = [t.to_dpr_json() for t in positive_ctxs]
457 |
458 | # set has_answer attributes
459 | for i, ctx_json in enumerate(positive_ctxs):
460 | answer_links = tables_answer_locations[i]
461 | ctx_json["answer_pos"] = answer_links
462 | hard_neg_ctxs = [t.to_dpr_json() for t in hard_neg_ctxs]
463 | out_results.append(
464 | {
465 | "question": q,
466 | "id": jline["example_id"],
467 | "answers": jline["short_answers"],
468 | "positive_ctxs": positive_ctxs,
469 | "hard_negative_ctxs": hard_neg_ctxs,
470 | }
471 | )
472 |
473 | logger.info("out_results size %s", len(out_results))
474 |
475 | with jsonlines.open(out_file, mode="w") as writer: # encoding="utf-8", .encode('utf-8')
476 | for r in out_results:
477 | writer.write(r)
478 |
479 | logger.info("Saved to %s", out_file)
480 |
481 |
482 | def convert_long_ans_to_dpr(nq_table_file, out_file):
483 | out_results = []
484 | with jsonlines.open(nq_table_file, mode="r") as jsonl_reader:
485 | for jline in jsonl_reader:
486 | q = jline["question"]
487 |
488 | gold_positive_table = jline["contexts"]
489 |
490 | mask = gold_positive_table["la_ans_tokens_html_mask"]
491 | # page_url = jline['doc_url']
492 | title = jline["title"]
493 |
494 | p = NQTableParser(gold_positive_table["la_ans_tokens"], mask, title)
495 | tables = p.parse()
496 | # select the one with the answer(s)
497 |
498 | positive_ctxs = [tables[0].to_dpr_json()]
499 |
500 | out_results.append(
501 | {
502 | "question": q,
503 | "id": jline["example_id"],
504 | "answers": [],
505 | "positive_ctxs": positive_ctxs,
506 | "hard_negative_ctxs": [],
507 | }
508 | )
509 |
510 | logger.info("out_results size %s", len(out_results))
511 |
512 | with jsonlines.open(out_file, mode="w") as writer: # encoding="utf-8", .encode('utf-8')
513 | for r in out_results:
514 | writer.write(r)
515 |
516 | logger.info("Saved to %s", out_file)
517 |
518 |
519 | def parse_qa_csv_file(location):
520 | res = []
521 | with open(location) as ifile:
522 | reader = csv.reader(ifile, delimiter="\t")
523 | for row in reader:
524 | question = row[0]
525 | answers = eval(row[1])
526 | res.append((question, answers))
527 | return res
528 |
529 |
530 | def calc_questions_overlap(tables_file, regular_file, dev_file):
531 | tab_questions = set()
532 |
533 | with jsonlines.open(tables_file, mode="r") as jsonl_reader:
534 | logger.info("Reading file %s" % tables_file)
535 | for jline in jsonl_reader:
536 | q = jline["question"]
537 | tab_questions.add(q)
538 |
539 | reg_questions = set()
540 |
541 | if regular_file[-4:] == ".csv":
542 | qas = parse_qa_csv_file(regular_file)
543 | for qa in qas:
544 | reg_questions.add(qa[0])
545 | else:
546 | with open(regular_file, "r", encoding="utf-8") as f:
547 | logger.info("Reading file %s" % regular_file)
548 | data = json.load(f)
549 | for item in data:
550 | q = item["question"]
551 | reg_questions.add(q)
552 | if dev_file:
553 | if dev_file[-4:] == ".csv":
554 | qas = parse_qa_csv_file(dev_file)
555 | for qa in qas:
556 | reg_questions.add(qa[0])
557 | else:
558 | with open(dev_file, "r", encoding="utf-8") as f:
559 | logger.info("Reading file %s" % dev_file)
560 | data = json.load(f)
561 | for item in data:
562 | q = item["question"]
563 | reg_questions.add(q)
564 |
565 | logger.info("tab_questions %d", len(tab_questions))
566 | logger.info("reg_questions %d", len(reg_questions))
567 | logger.info("overlap %d", len(tab_questions.intersection(reg_questions)))
568 |
569 |
570 | def convert_train_jsonl_to_ctxmatch(path: str, out_file: str):
571 | def get_table_string_for_ctx_match(table: dict): # this includes the caption
572 | table_text = table["caption"] + " . "
573 | for r in table["rows"]:
574 | table_text += " . ".join([c["value"] for c in r["columns"]])
575 | table_text += " . "
576 | return table_text
577 |
578 | results = []
579 | with jsonlines.open(path, mode="r") as jsonl_reader:
580 | for jline in jsonl_reader:
581 | if len(jline["positive_ctxs"]) == 0:
582 | continue
583 | ctx_pos = jline["positive_ctxs"][0]
584 | table_str = get_table_string_for_ctx_match(ctx_pos)
585 | q = jline["question"]
586 | results.append((q, table_str))
587 |
588 | if len(results) % 1000 == 0:
589 | logger.info("results %d", len(results))
590 |
591 | shards_sz = 3000
592 | shard = 0
593 |
594 | for s in range(0, len(results), shards_sz):
595 | chunk = results[s : s + shards_sz]
596 | shard_file = out_file + ".shard_{}".format(shard)
597 | with jsonlines.open(shard_file, mode="w") as writer:
598 | logger.info("Saving to %s", shard_file)
599 | for i, item in enumerate(chunk):
600 | writer.write({"id": s + i, "question": item[0], "context": item[1]})
601 | shard += 1
602 |
603 |
604 | # TODO: tmp copy-paste fix to avoid circular dependency
605 | def regex_match(text, pattern):
606 | """Test if a regex pattern is contained within a text."""
607 | try:
608 | pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE)
609 | except BaseException:
610 | return False
611 | return pattern.search(text) is not None
612 |
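# Illustrative sketch: regex_match compiles the pattern case-insensitively and reports
# whether it occurs anywhere in the text; an invalid pattern is treated as a non-match
# instead of raising.
def _demo_regex_match():
    assert regex_match("The Eiffel Tower is in Paris.", r"eiffel\s+tower")
    assert not regex_match("some text", r"[unclosed")  # bad pattern -> False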
--------------------------------------------------------------------------------
/dpr/data/reader_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """
9 | Set of utilities for the Reader model related data processing tasks
10 | """
11 |
12 | import collections
13 | import glob
14 | import json
15 | import logging
16 | import math
17 | import multiprocessing
18 | import os
19 | import pickle
20 | from functools import partial
21 | from typing import Tuple, List, Dict, Iterable, Optional
22 |
23 | import torch
24 | from torch import Tensor as T
25 | from tqdm import tqdm
26 |
27 | from dpr.utils.data_utils import (
28 | Tensorizer,
29 | read_serialized_data_from_files,
30 | read_data_from_json_files,
31 | Dataset as DprDataset,
32 | )
33 |
34 | logger = logging.getLogger()
35 |
36 |
37 | class ReaderPassage(object):
38 | """
39 | Container to collect and cache all Q&A passage-related attributes before generating the reader input
40 | """
41 |
42 | def __init__(
43 | self,
44 | id=None,
45 | text: str = None,
46 | title: str = None,
47 | score=None,
48 | has_answer: bool = None,
49 | ):
50 | self.id = id
51 | # string passage representations
52 | self.passage_text = text
53 | self.title = title
54 | self.score = score
55 | self.has_answer = has_answer
56 | self.passage_token_ids = None
57 | # offset of the actual passage (i.e. not the question or title) in the sequence_ids
58 | self.passage_offset = None
59 | self.answers_spans = None
60 | # full question+title+passage sequence token ids (set in create_reader_sample_ids)
61 | self.sequence_ids = None
62 |
63 | def on_serialize(self):
64 | # store only final sequence_ids and the ctx offset
65 | self.sequence_ids = self.sequence_ids.numpy()
66 | self.passage_text = None
67 | self.title = None
68 | self.passage_token_ids = None
69 |
70 | def on_deserialize(self):
71 | self.sequence_ids = torch.tensor(self.sequence_ids)
72 |
73 |
74 | class ReaderSample(object):
75 | """
76 | Container to collect all Q&A passages data per single question
77 | """
78 |
79 | def __init__(
80 | self,
81 | question: str,
82 | answers: List,
83 | positive_passages: List[ReaderPassage] = [],
84 | negative_passages: List[ReaderPassage] = [],
85 | passages: List[ReaderPassage] = [],
86 | ):
87 | self.question = question
88 | self.answers = answers
89 | self.positive_passages = positive_passages
90 | self.negative_passages = negative_passages
91 | self.passages = passages
92 |
93 | def on_serialize(self):
94 | for passage in self.passages + self.positive_passages + self.negative_passages:
95 | passage.on_serialize()
96 |
97 | def on_deserialize(self):
98 | for passage in self.passages + self.positive_passages + self.negative_passages:
99 | passage.on_deserialize()
100 |
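# Illustrative round-trip sketch: on_serialize keeps only sequence_ids (as a numpy
# array) and the passage offset for pickling, dropping text/title/token ids, and
# on_deserialize restores sequence_ids to a torch tensor. The ids below are arbitrary.
def _demo_passage_serialization():
    p = ReaderPassage(id=1, text="some passage text", title="some title")
    p.sequence_ids = torch.tensor([101, 2054, 102, 2023, 102])
    p.passage_offset = 3
    sample = ReaderSample("what is this?", ["an answer"], passages=[p])
    sample.on_serialize()  # p.passage_text and p.title are now None
    sample.on_deserialize()  # p.sequence_ids is a torch.Tensor again
    return sample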
101 |
102 | class ExtractiveReaderDataset(torch.utils.data.Dataset):
103 | def __init__(
104 | self,
105 | files: str,
106 | is_train: bool,
107 | gold_passages_src: str,
108 | tensorizer: Tensorizer,
109 | run_preprocessing: bool,
110 | num_workers: int,
111 | ):
112 | self.files = files
113 | self.data = []
114 | self.is_train = is_train
115 | self.gold_passages_src = gold_passages_src
116 | self.tensorizer = tensorizer
117 | self.run_preprocessing = run_preprocessing
118 | self.num_workers = num_workers
119 |
120 | def __getitem__(self, index):
121 | return self.data[index]
122 |
123 | def __len__(self):
124 | return len(self.data)
125 |
126 | def calc_total_data_len(self):
127 | if not self.data:
128 | self.load_data()
129 | return len(self.data)
130 |
131 | def load_data(
132 | self,
133 | ):
134 | if self.data:
135 | return
136 |
137 | data_files = glob.glob(self.files)
138 | logger.info("Data files: %s", data_files)
139 | if not data_files:
140 | raise RuntimeError("No Data files found")
141 | preprocessed_data_files = self._get_preprocessed_files(data_files)
142 | self.data = read_serialized_data_from_files(preprocessed_data_files)
143 |
144 | def _get_preprocessed_files(
145 | self,
146 | data_files: List,
147 | ):
148 |
149 | serialized_files = [file for file in data_files if file.endswith(".pkl")]
150 | if serialized_files:
151 | return serialized_files
152 | assert len(data_files) == 1, "Only 1 source file pre-processing is supported."
153 |
154 | # data may have been serialized and cached before, try to find cached files in the same dir
155 | def _find_cached_files(path: str):
156 | dir_path, base_name = os.path.split(path)
157 | base_name = base_name.replace(".json", "")
158 | out_file_prefix = os.path.join(dir_path, base_name)
159 | out_file_pattern = out_file_prefix + "*.pkl"
160 | return glob.glob(out_file_pattern), out_file_prefix
161 |
162 | serialized_files, out_file_prefix = _find_cached_files(data_files[0])
163 | if serialized_files:
164 | logger.info("Found preprocessed files. %s", serialized_files)
165 | return serialized_files
166 |
167 | logger.info("Data are not preprocessed for reader training. Start pre-processing ...")
168 |
169 | # start pre-processing and save results
170 | def _run_preprocessing(tensorizer: Tensorizer):
171 | # temporarily disable auto-padding to reduce the disk space usage of the serialized files
172 | tensorizer.set_pad_to_max(False)
173 | serialized_files = convert_retriever_results(
174 | self.is_train,
175 | data_files[0],
176 | out_file_prefix,
177 | self.gold_passages_src,
178 | self.tensorizer,
179 | num_workers=self.num_workers,
180 | )
181 | tensorizer.set_pad_to_max(True)
182 | return serialized_files
183 |
184 | if self.run_preprocessing:
185 | serialized_files = _run_preprocessing(self.tensorizer)
186 | # TODO: check if pytorch process group is initialized
187 | # torch.distributed.barrier()
188 | else:
189 | # torch.distributed.barrier()
190 | serialized_files, _ = _find_cached_files(data_files[0])
191 | return serialized_files
192 |
193 |
194 | SpanPrediction = collections.namedtuple(
195 | "SpanPrediction",
196 | [
197 | "prediction_text",
198 | "span_score",
199 | "relevance_score",
200 | "passage_index",
201 | "passage_token_ids",
202 | ],
203 | )
204 |
205 | # configuration for reader model passage selection
206 | ReaderPreprocessingCfg = collections.namedtuple(
207 | "ReaderPreprocessingCfg",
208 | [
209 | "use_tailing_sep",
210 | "skip_no_positves",
211 | "include_gold_passage",
212 | "gold_page_only_positives",
213 | "max_positives",
214 | "max_negatives",
215 | "min_negatives",
216 | "max_retriever_passages",
217 | ],
218 | )
219 |
220 | DEFAULT_PREPROCESSING_CFG_TRAIN = ReaderPreprocessingCfg(
221 | use_tailing_sep=False,
222 | skip_no_positves=True,
223 | include_gold_passage=False, # True - for speech Q&A
224 | gold_page_only_positives=True,
225 | max_positives=20,
226 | max_negatives=50,
227 | min_negatives=150,
228 | max_retriever_passages=200,
229 | )
230 |
231 | DEFAULT_EVAL_PASSAGES = 100
232 |
233 |
234 | def preprocess_retriever_data(
235 | samples: List[Dict],
236 | gold_info_file: Optional[str],
237 | tensorizer: Tensorizer,
238 | cfg: ReaderPreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
239 | is_train_set: bool = True,
240 | ) -> Iterable[ReaderSample]:
241 | """
242 | Converts retriever results into reader training data.
243 | :param samples: samples from the retriever's json file results
244 | :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
245 | :param tensorizer: Tensorizer object for text to model input tensors conversions
246 | :param cfg: ReaderPreprocessingCfg object with positive and negative passage selection parameters
247 | :param is_train_set: if the data should be processed as a train set
248 | :return: iterable of ReaderSample objects which can be consumed by the reader model
249 | """
250 | sep_tensor = tensorizer.get_pair_separator_ids() # the separator can be multi-token
251 | gold_passage_map, canonical_questions = _get_gold_ctx_dict(gold_info_file) if gold_info_file else ({}, {})
252 |
253 | no_positive_passages = 0
254 | positives_from_gold = 0
255 |
256 | def create_reader_sample_ids(sample: ReaderPassage, question: str):
257 | question_and_title = tensorizer.text_to_tensor(sample.title, title=question, add_special_tokens=True)
258 | if sample.passage_token_ids is None:
259 | sample.passage_token_ids = tensorizer.text_to_tensor(sample.passage_text, add_special_tokens=False)
260 |
261 | all_concatenated, shift = _concat_pair(
262 | question_and_title,
263 | sample.passage_token_ids,
264 | tailing_sep=sep_tensor if cfg.use_tailing_sep else None,
265 | )
266 |
267 | sample.sequence_ids = all_concatenated
268 | sample.passage_offset = shift
269 | assert shift > 1
270 | if sample.has_answer and is_train_set:
271 | sample.answers_spans = [(span[0] + shift, span[1] + shift) for span in sample.answers_spans]
272 | return sample
273 |
274 | for sample in samples:
275 | question = sample["question"]
276 | question_txt = sample["query_text"] if "query_text" in sample else question
277 |
278 | if canonical_questions and question_txt in canonical_questions:
279 | question_txt = canonical_questions[question_txt]
280 |
281 | positive_passages, negative_passages = _select_reader_passages(
282 | sample,
283 | question_txt,
284 | tensorizer,
285 | gold_passage_map,
286 | cfg.gold_page_only_positives,
287 | cfg.max_positives,
288 | cfg.max_negatives,
289 | cfg.min_negatives,
290 | cfg.max_retriever_passages,
291 | cfg.include_gold_passage,
292 | is_train_set,
293 | )
294 | # create concatenated sequence ids for each passage and adjust answer spans
295 | positive_passages = [create_reader_sample_ids(s, question) for s in positive_passages]
296 | negative_passages = [create_reader_sample_ids(s, question) for s in negative_passages]
297 |
298 | if is_train_set and len(positive_passages) == 0:
299 | no_positive_passages += 1
300 | if cfg.skip_no_positves:
301 | continue
302 |
303 | if next(iter(ctx for ctx in positive_passages if ctx.score == -1), None):
304 | positives_from_gold += 1
305 |
306 | if is_train_set:
307 | yield ReaderSample(
308 | question,
309 | sample["answers"],
310 | positive_passages=positive_passages,
311 | negative_passages=negative_passages,
312 | )
313 | else:
314 | yield ReaderSample(question, sample["answers"], passages=negative_passages)
315 |
316 | logger.info("no positive passages samples: %d", no_positive_passages)
317 | logger.info("positive passages from gold samples: %d", positives_from_gold)
318 |
319 |
320 | def convert_retriever_results(
321 | is_train_set: bool,
322 | input_file: str,
323 | out_file_prefix: str,
324 | gold_passages_file: str,
325 | tensorizer: Tensorizer,
326 | num_workers: int = 8,
327 | ) -> List[str]:
328 | """
329 | Converts a file with dense retriever results (or any compatible format) into the reader input data and
330 | serializes them into a set of files.
331 | Conversion splits the input data into multiple chunks and processes them in parallel. Each chunk's results are stored
332 | in a separate file with name out_file_prefix.{number}.pkl
333 | :param is_train_set: if the data should be processed for a train set (i.e. with answer span detection)
334 | :param input_file: path to a json file with data to convert
335 | :param out_file_prefix: output path prefix.
336 | :param gold_passages_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
337 | :param tensorizer: Tensorizer object for text to model input tensors conversions
338 | :param num_workers: the number of parallel processes for conversion
339 | :return: names of files with serialized results
340 | """
341 | with open(input_file, "r", encoding="utf-8") as f:
342 | samples = json.loads("".join(f.readlines()))
343 | logger.info("Loaded %d questions + retrieval results from %s", len(samples), input_file)
344 | workers = multiprocessing.Pool(num_workers)
345 | ds_size = len(samples)
346 | step = max(math.ceil(ds_size / num_workers), 1)
347 | chunks = [samples[i : i + step] for i in range(0, ds_size, step)]
348 | chunks = [(i, chunks[i]) for i in range(len(chunks))]
349 |
350 | logger.info("Split data into %d chunks", len(chunks))
351 |
352 | processed = 0
353 | _parse_batch = partial(
354 | _preprocess_reader_samples_chunk,
355 | out_file_prefix=out_file_prefix,
356 | gold_passages_file=gold_passages_file,
357 | tensorizer=tensorizer,
358 | is_train_set=is_train_set,
359 | )
360 | serialized_files = []
361 | for file_name in workers.map(_parse_batch, chunks):
362 | processed += 1
363 | serialized_files.append(file_name)
364 | logger.info("Chunks processed %d", processed)
365 | logger.info("Data saved to %s", file_name)
366 | logger.info("Preprocessed data stored in %s", serialized_files)
367 | return serialized_files
368 |
369 |
370 | def get_best_spans(
371 | tensorizer: Tensorizer,
372 | start_logits: List,
373 | end_logits: List,
374 | ctx_ids: List,
375 | max_answer_length: int,
376 | passage_idx: int,
377 | relevance_score: float,
378 | top_spans: int = 1,
379 | ) -> List[SpanPrediction]:
380 | """
381 | Finds the best answer span for the extractive Q&A model
382 | """
383 | scores = []
384 | for (i, s) in enumerate(start_logits):
385 | for (j, e) in enumerate(end_logits[i : i + max_answer_length]):
386 | scores.append(((i, i + j), s + e))
387 |
388 | scores = sorted(scores, key=lambda x: x[1], reverse=True)
389 |
390 | chosen_span_intervals = []
391 | best_spans = []
392 |
393 | for (start_index, end_index), score in scores:
394 | assert start_index <= end_index
395 | length = end_index - start_index + 1
396 | assert length <= max_answer_length
397 |
398 | if any(
399 | [
400 | start_index <= prev_start_index <= prev_end_index <= end_index
401 | or prev_start_index <= start_index <= end_index <= prev_end_index
402 | for (prev_start_index, prev_end_index) in chosen_span_intervals
403 | ]
404 | ):
405 | continue
406 |
407 | # extend bpe subtokens to full tokens
408 | start_index, end_index = _extend_span_to_full_words(tensorizer, ctx_ids, (start_index, end_index))
409 |
410 | predicted_answer = tensorizer.to_string(ctx_ids[start_index : end_index + 1])
411 | best_spans.append(SpanPrediction(predicted_answer, score, relevance_score, passage_idx, ctx_ids))
412 | chosen_span_intervals.append((start_index, end_index))
413 |
414 | if len(chosen_span_intervals) == top_spans:
415 | break
416 | return best_spans
417 |
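# Illustrative sketch of get_best_spans on a toy passage: it scores every (start, end)
# pair within max_answer_length, keeps non-overlapping top spans and widens them to
# full words via the tensorizer. _ToyTensorizer is a hypothetical duck-typed stand-in
# (not the real DPR Tensorizer) implementing only the two methods this code path
# calls: is_sub_word_id and to_string.
class _ToyTensorizer:
    _vocab = {1: "paris", 2: "is", 3: "the", 4: "capital"}

    def is_sub_word_id(self, token_id):
        return False  # no BPE continuation pieces in this toy vocabulary

    def to_string(self, token_ids):
        return " ".join(self._vocab[t] for t in token_ids)


def _demo_get_best_spans():
    ctx_ids = [1, 2, 3, 4]
    start_logits = [0.1, 0.2, 0.1, 0.9]
    end_logits = [0.1, 0.1, 0.2, 0.8]
    spans = get_best_spans(
        _ToyTensorizer(),
        start_logits,
        end_logits,
        ctx_ids,
        max_answer_length=2,
        passage_idx=0,
        relevance_score=1.0,
        top_spans=1,
    )
    return spans  # single SpanPrediction with prediction_text == "capital"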
418 |
419 | def _select_reader_passages(
420 | sample: Dict,
421 | question: str,
422 | tensorizer: Tensorizer,
423 | gold_passage_map: Optional[Dict[str, ReaderPassage]],
424 | gold_page_only_positives: bool,
425 | max_positives: int,
426 | max1_negatives: int,
427 | max2_negatives: int,
428 | max_retriever_passages: int,
429 | include_gold_passage: bool,
430 | is_train_set: bool,
431 | ) -> Tuple[List[ReaderPassage], List[ReaderPassage]]:
432 | answers = sample["answers"]
433 |
434 | ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]][0:max_retriever_passages]
435 | answers_token_ids = [tensorizer.text_to_tensor(a, add_special_tokens=False) for a in answers]
436 |
437 | if is_train_set:
438 | positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
439 | negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
440 | else:
441 | positive_samples = []
442 | negative_samples = ctxs
443 |
444 | positive_ctxs_from_gold_page = (
445 | list(
446 | filter(
447 | lambda ctx: _is_from_gold_wiki_page(gold_passage_map, ctx.title, question),
448 | positive_samples,
449 | )
450 | )
451 | if gold_page_only_positives and gold_passage_map
452 | else []
453 | )
454 |
455 | def find_answer_spans(ctx: ReaderPassage):
456 | if ctx.has_answer:
457 | if ctx.passage_token_ids is None:
458 | ctx.passage_token_ids = tensorizer.text_to_tensor(ctx.passage_text, add_special_tokens=False)
459 |
460 | answer_spans = [
461 | _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i]) for i in range(len(answers))
462 | ]
463 |
464 | # flatten spans list
465 | answer_spans = [item for sublist in answer_spans for item in sublist]
466 | answers_spans = list(filter(None, answer_spans))
467 | ctx.answers_spans = answers_spans
468 |
469 | if not answers_spans:
470 | logger.warning(
471 | "No answer found in passage id=%s text=%s, answers=%s, question=%s",
472 | ctx.id,
473 | "", # ctx.passage_text
474 | answers,
475 | question,
476 | )
477 | ctx.has_answer = bool(answers_spans)
478 | return ctx
479 |
480 | # check if any of the selected ctx+ has answer spans
481 | selected_positive_ctxs = list(
482 | filter(
483 | lambda ctx: ctx.has_answer,
484 | [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page],
485 | )
486 | )
487 |
488 | if not selected_positive_ctxs: # fallback to positive ctx not from gold pages
489 | selected_positive_ctxs = list(
490 | filter(
491 | lambda ctx: ctx.has_answer,
492 | [find_answer_spans(ctx) for ctx in positive_samples],
493 | )
494 | )[0:max_positives]
495 |
496 | # optionally include gold passage itself if it is still not in the positives list
497 | if include_gold_passage and question in gold_passage_map:
498 | gold_passage = gold_passage_map[question]
499 | included_gold_passage = next(
500 | iter(ctx for ctx in selected_positive_ctxs if ctx.passage_text == gold_passage.passage_text),
501 | None,
502 | )
503 | if not included_gold_passage:
504 | gold_passage.has_answer = True
505 | gold_passage = find_answer_spans(gold_passage)
506 | if not gold_passage.has_answer:
507 | logger.warning("No answer found in gold passage: %s", gold_passage)
508 | else:
509 | selected_positive_ctxs.append(gold_passage)
510 |
511 | max_negatives = (
512 | min(max(10 * len(selected_positive_ctxs), max1_negatives), max2_negatives)
513 | if is_train_set
514 | else DEFAULT_EVAL_PASSAGES
515 | )
516 | negative_samples = negative_samples[0:max_negatives]
517 | return selected_positive_ctxs, negative_samples
518 |
519 |
520 | def _find_answer_positions(ctx_ids: T, answer: T) -> List[Tuple[int, int]]:
521 | c_len = ctx_ids.size(0)
522 | a_len = answer.size(0)
523 | answer_occurences = []
524 | for i in range(0, c_len - a_len + 1):
525 | if (answer == ctx_ids[i : i + a_len]).all():
526 | answer_occurences.append((i, i + a_len - 1))
527 | return answer_occurences
528 |
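# Illustrative sketch: _find_answer_positions slides the answer ids over the context
# ids and returns the inclusive (start, end) indices of every exact match.
def _demo_find_answer_positions():
    ctx = torch.tensor([5, 7, 9, 5, 7])
    ans = torch.tensor([5, 7])
    return _find_answer_positions(ctx, ans)  # -> [(0, 1), (3, 4)]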
529 |
530 | def _concat_pair(t1: T, t2: T, middle_sep: T = None, tailing_sep: T = None):
531 | middle = [middle_sep] if middle_sep else []
532 | r = [t1] + middle + [t2] + ([tailing_sep] if tailing_sep else [])
533 | return torch.cat(r, dim=0), t1.size(0) + len(middle)
534 |
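# Illustrative sketch: _concat_pair returns the concatenated tensor together with the
# offset at which the second part starts, which is how passage_offset is derived for
# shifting answer spans in create_reader_sample_ids above.
def _demo_concat_pair():
    question_and_title_ids = torch.tensor([101, 2054, 102])
    passage_ids = torch.tensor([2023, 2003])
    seq, offset = _concat_pair(question_and_title_ids, passage_ids)
    return seq, offset  # tensor([101, 2054, 102, 2023, 2003]), 3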
535 |
536 | def _get_gold_ctx_dict(file: str) -> Tuple[Dict[str, ReaderPassage], Dict[str, str]]:
537 | gold_passage_infos = {} # question|question_tokens -> ReaderPassage (with title and gold ctx)
538 |
539 | # The original NQ dataset has two forms of the same question - original and tokenized.
540 | # The tokenized form is not fully consistent with the original question for some encoder tokenizers.
541 | # Specifically, this is the case for the BERT tokenizer.
542 | # Depending on which form was used for retriever training and results generation, it may be useful to convert
543 | # all questions to the canonical original representation.
544 | original_questions = {} # question from tokens -> original question (NQ only)
545 |
546 | with open(file, "r", encoding="utf-8") as f:
547 | logger.info("Reading file %s" % file)
548 | data = json.load(f)["data"]
549 |
550 | for sample in data:
551 | question = sample["question"]
552 | question_from_tokens = sample["question_tokens"] if "question_tokens" in sample else question
553 | original_questions[question_from_tokens] = question
554 | title = sample["title"].lower()
555 | context = sample["context"] # Note: This one is cased
556 | rp = ReaderPassage(sample["example_id"], text=context, title=title)
557 | if question in gold_passage_infos:
558 | logger.info("Duplicate question %s", question)
559 | rp_exist = gold_passage_infos[question]
560 | logger.info(
561 | "Duplicate question gold info: title new =%s | old title=%s",
562 | title,
563 | rp_exist.title,
564 | )
565 | logger.info("Duplicate question gold info: new ctx =%s ", context)
566 | logger.info("Duplicate question gold info: old ctx =%s ", rp_exist.passage_text)
567 | gold_passage_infos[question] = rp
568 | gold_passage_infos[question_from_tokens] = rp
569 | return gold_passage_infos, original_questions
570 |
571 |
572 | def _is_from_gold_wiki_page(gold_passage_map: Dict[str, ReaderPassage], passage_title: str, question: str):
573 | gold_info = gold_passage_map.get(question, None)
574 | if gold_info:
575 | return passage_title.lower() == gold_info.title.lower()
576 | return False
577 |
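# Illustrative sketch: a positive passage counts as coming from the gold page when its
# title matches, case-insensitively, the gold passage recorded for the question.
def _demo_is_from_gold_wiki_page():
    gold_map = {"who wrote hamlet?": ReaderPassage(title="William Shakespeare")}
    assert _is_from_gold_wiki_page(gold_map, "william shakespeare", "who wrote hamlet?")
    assert not _is_from_gold_wiki_page(gold_map, "Hamlet", "who wrote hamlet?")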
578 |
579 | def _extend_span_to_full_words(tensorizer: Tensorizer, tokens: List[int], span: Tuple[int, int]) -> Tuple[int, int]:
580 | start_index, end_index = span
581 | max_len = len(tokens)
582 | while start_index > 0 and tensorizer.is_sub_word_id(tokens[start_index]):
583 | start_index -= 1
584 |
585 | while end_index < max_len - 1 and tensorizer.is_sub_word_id(tokens[end_index + 1]):
586 | end_index += 1
587 |
588 | return start_index, end_index
589 |
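# Illustrative sketch: _extend_span_to_full_words widens a span whose boundaries fall
# on BPE continuation pieces until both ends sit on word starts. _WordPieceStub is a
# hypothetical duck-typed stand-in that flags the usual "##" continuation prefix; the
# token list holds strings only for readability, since is_sub_word_id is the sole
# operation applied to its items here.
class _WordPieceStub:
    def is_sub_word_id(self, token):
        return isinstance(token, str) and token.startswith("##")


def _demo_extend_span():
    tokens = ["em", "##bed", "##ding", "layer"]
    # a span starting on "##bed" is pulled back to "em"; the end already sits on "layer"
    return _extend_span_to_full_words(_WordPieceStub(), tokens, (1, 3))  # -> (0, 3)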
590 |
591 | def _preprocess_reader_samples_chunk(
592 | samples: List,
593 | out_file_prefix: str,
594 | gold_passages_file: str,
595 | tensorizer: Tensorizer,
596 | is_train_set: bool,
597 | ) -> str:
598 | chunk_id, samples = samples
599 | logger.info("Start batch %d", len(samples))
600 | iterator = preprocess_retriever_data(
601 | samples,
602 | gold_passages_file,
603 | tensorizer,
604 | is_train_set=is_train_set,
605 | )
606 |
607 | results = []
608 |
609 | iterator = tqdm(iterator)
610 | for i, r in enumerate(iterator):
611 | r.on_serialize()
612 | results.append(r)
613 |
614 | out_file = out_file_prefix + "." + str(chunk_id) + ".pkl"
615 | with open(out_file, mode="wb") as f:
616 | logger.info("Serialize %d results to %s", len(results), out_file)
617 | pickle.dump(results, f)
618 | return out_file
619 |
--------------------------------------------------------------------------------