├── .gitignore ├── README.md ├── agent ├── __init__.py ├── client.py └── core.py ├── analysis.ipynb ├── basic_tokenizer.py ├── cache_bm25.ipynb ├── config.py ├── demo.ipynb ├── dense_encoder.py ├── dense_indexer.py ├── dpr ├── __init__.py ├── data │ ├── __init__.py │ ├── qa_validation.py │ └── reader_data.py ├── indexer │ └── faiss_indexers.py ├── models │ ├── __init__.py │ ├── biencoder.py │ ├── fairseq_models.py │ ├── hf_models.py │ ├── pytext_models.py │ └── reader.py ├── options.py └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── dist_utils.py │ ├── model_utils.py │ └── tokenizers.py ├── drqa ├── __init__.py ├── pipeline │ ├── __init__.py │ └── drqa.py ├── reader │ ├── __init__.py │ ├── config.py │ ├── data.py │ ├── layers.py │ ├── model.py │ ├── predictor.py │ ├── rnn_reader.py │ ├── utils.py │ └── vector.py ├── retriever │ ├── __init__.py │ ├── doc_db.py │ ├── elastic_doc_ranker.py │ ├── tfidf_doc_ranker.py │ └── utils.py └── tokenizers │ ├── __init__.py │ ├── corenlp_tokenizer.py │ ├── regexp_tokenizer.py │ ├── simple_tokenizer.py │ ├── spacy_tokenizer.py │ └── tokenizer.py ├── encode_corpus.py ├── env ├── __init__.py ├── client.py ├── core.py └── server.py ├── evaluation ├── __init__.py └── ir.py ├── figs ├── demo.gif └── simple_demo.png ├── game.ipynb ├── gen_cmd.ipynb ├── gen_step_data.ipynb ├── generate_dense.py ├── hotpot_evaluate_plus.py ├── index_sparse.5.ipynb ├── index_sparse.ipynb ├── install_corenlp.sh ├── mdr ├── __init__.py ├── qa │ ├── __init__.py │ ├── basic_tokenizer.py │ ├── config.py │ ├── data_utils.py │ ├── hotpot_evaluate_v1.py │ ├── qa_dataset.py │ ├── qa_model.py │ ├── qa_trainer.py │ ├── train.md │ ├── train_ranker.py │ └── utils.py └── retrieval │ ├── __init__.py │ ├── config.py │ ├── criterions.py │ ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── encode_datasets.py │ ├── fever_dataset.py │ ├── mhop_dataset.py │ ├── sp_datasets.py │ └── unified_dataset.py │ ├── decomposed_analysis.py │ ├── fever.ipynb │ ├── hotpot.ipynb │ ├── interactive_retrieval.py │ ├── mhop_trainer.py │ ├── models │ ├── hop1_retriever.py │ ├── mhop_retriever.py │ ├── retriever.py │ └── unified_retriever.py │ ├── single_trainer.py │ ├── train_single.py │ └── utils │ ├── basic_tokenizer.py │ ├── gen_index_id_map.py │ ├── mhop_utils.py │ ├── tokenizer.py │ └── utils.py ├── mdr_encode_corpus.py ├── mdr_eval.py ├── models ├── __init__.py ├── reranker.py └── union_model.py ├── refine_hard_neg.ipynb ├── requirements.txt ├── reranking_data.py ├── retriever.py ├── text_clean.py ├── train_union.py ├── transition_data.py ├── utils ├── __init__.py ├── data_utils.py ├── model_utils.py ├── rank_losses.py ├── tensor_utils.py ├── text_utils.py └── utils.py └── wiki_world.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | .idea/ 4 | .vscode/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # Flask stuff: 12 | instance/ 13 | .webassets-cache 14 | 15 | # Scrapy stuff: 16 | .scrapy 17 | 18 | # Jupyter Notebook 19 | .ipynb_checkpoints 20 | 21 | # IPython 22 | profile_default/ 23 | ipython_config.py 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AISO 2 | 3 | This repository hosts the authors' implementation of the paper [Adaptive Information Seeking for Open-Domain Question Answering](https://arxiv.org/pdf/2109.06747.pdf), published in EMNLP 
2021. 4 | 5 | ![demo](figs/demo.gif) 6 | 7 | ## Usage 8 | 9 | ### Set up environment 10 | 11 | Our experiments are conducted with Python 3.6 and PyTorch 1.4. 12 | 13 | We employ [GoldEn retriever]() as our query reformulator for the sparse retriever, so you also need to install Elasticsearch, Java (>= 8), and CoreNLP (run `install_corenlp.sh`). 14 | 15 | ### Prepare index and training data 16 | 17 | See [index_sparse.ipynb](index_sparse.ipynb), [gen_step_data.ipynb](gen_step_data.ipynb). 18 | 19 | ### Training 20 | 21 | See [train_union.py](train_union.py). 22 | 23 | ### Inference 24 | 25 | See [game.ipynb](game.ipynb). 26 | 27 | ## Demo 28 | 29 | The front-end of the demo in the first GIF is not open source, but we provide a simple visual interface based on Jupyter widgets in the [notebook](demo.ipynb). 30 | 31 | ![simple demo](figs/simple_demo.png) 32 | 33 | ## TODO 34 | 35 | - [ ] Convert Jupyter notebooks to scripts 36 | - [ ] Add more details about environment setup and dependencies 37 | - [ ] Upload processed training data and model checkpoints 38 | 39 | ## Citation 40 | 41 | ``` 42 | @inproceedings{zhu-etal-2021-adaptive, 43 | title = "Adaptive Information Seeking for Open-Domain Question Answering", 44 | author = "Zhu, Yunchang and 45 | Pang, Liang and 46 | Lan, Yanyan and 47 | Shen, Huawei and 48 | Cheng, Xueqi", 49 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing", 50 | month = nov, 51 | year = "2021", 52 | address = "Online and Punta Cana, Dominican Republic", 53 | publisher = "Association for Computational Linguistics", 54 | url = "https://aclanthology.org/2021.emnlp-main.293", 55 | pages = "3615--3626", 56 | } 57 | ``` 58 | 59 | -------------------------------------------------------------------------------- /agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/agent/__init__.py -------------------------------------------------------------------------------- /agent/client.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from requests import delete, get, post 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class Agent(object): 8 | 9 | def __init__(self, server='10.60.1.79:17101'): 10 | self.server = server 11 | 12 | def reset(self): 13 | delete(f'http://{self.server}/games') 14 | 15 | def memory(self, game_id, dtype='dict'): 16 | return get(f'http://{self.server}/memory/{game_id}', params={"dtype": dtype}).json() 17 | 18 | def add_evidence(self, game_id, p_id): 19 | post(f'http://{self.server}/memory/{game_id}/{p_id}') 20 | 21 | def del_evidence(self, game_id, p_id): 22 | delete(f'http://{self.server}/memory/{game_id}/{p_id}') 23 | 24 | def act(self, game_ids, questions, observations=None, review=False): 25 | args = {"game_ids": game_ids, "questions": questions, "observations": observations, "review": review} 26 | return post(f'http://{self.server}/actions', json=args).json() 27 | 28 | def proposals(self, game_id, step): 29 | return get(f'http://{self.server}/proposals/{game_id}/{step}').json() 30 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | ADDITIONAL_SPECIAL_TOKENS = { 4 | "YES": "[unused0]", 5 | "NO": "[unused1]", 6 | "SOP": "[unused2]", 7 | "NONE": "[unused3]" 8 | } 9 | FUNCTIONS =
("ANSWER", "BM25", "MDR", "LINK") 10 | FUNC2ID = {func: idx for idx, func in enumerate(FUNCTIONS)} 11 | NA_POS = 3 12 | 13 | 14 | def common_args(): 15 | parser = argparse.ArgumentParser() 16 | 17 | # input and output 18 | parser.add_argument("--corpus_path", type=str, default="data/corpus/hotpot-paragraph-5.tsv") 19 | parser.add_argument("--train_file", type=str, default="data/hotpot-step-train.jsonl") 20 | parser.add_argument("--predict_file", type=str, default="data/hotpot-step-dev.jsonl") 21 | parser.add_argument("--output_dir", default="ckpts", type=str, 22 | help="The output directory where the model checkpoints and predictions will be written.") 23 | 24 | parser.add_argument("--do_train", default=False, action='store_true', 25 | help="Whether to run training.") 26 | parser.add_argument("--do_predict", default=False, action='store_true', 27 | help="Whether to run eval on the dev set.") 28 | parser.add_argument("--do_test", default=False, action="store_true", 29 | help="for final test submission") 30 | 31 | # model 32 | parser.add_argument("--encoder_name", default="google/electra-base-discriminator", type=str) 33 | parser.add_argument("--init_checkpoint", default=None, type=str, 34 | help="Initial checkpoint (usually from a pre-trained BERT model).") 35 | 36 | # data 37 | parser.add_argument("--max_seq_len", default=512, type=int, 38 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 39 | "longer than this will be truncated, and sequences shorter than this will be padded.") 40 | parser.add_argument("--max_q_len", default=96, type=int) 41 | parser.add_argument("--max_obs_len", default=256, type=int) 42 | parser.add_argument("--max_ans_len", default=64, type=int) 43 | parser.add_argument("--hard_negs_per_state", type=int, default=2, 44 | help="how many hard negative observations per state") 45 | parser.add_argument("--memory_size", type=int, default=3, 46 | help="max num of passages stored in memory") 47 | parser.add_argument("--max_distractors", type=int, default=2, 48 | help="max num of distractor passages in context") 49 | parser.add_argument("--num_workers", default=0, type=int) 50 | parser.add_argument("--strict", action="store_true", help="whether to strictly use original data of dataset") 51 | 52 | parser.add_argument("--no_cuda", default=False, action='store_true', 53 | help="Whether not to use CUDA when available") 54 | parser.add_argument('--fp16', action='store_true') 55 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 56 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
57 | "See details at https://nvidia.github.io/apex/amp.html") 58 | parser.add_argument("--local_rank", type=int, default=-1, 59 | help="local_rank for distributed training on gpus") 60 | 61 | parser.add_argument("--per_gpu_infer_batch_size", default=16, type=int, 62 | help="Batch size per GPU for inference.") 63 | parser.add_argument("--save-prediction", default="", type=str) 64 | 65 | parser.add_argument("--sp_pred", action="store_true", help="whether to predict sentence sp") 66 | 67 | parser.add_argument('--debug', action='store_true') 68 | 69 | parser.add_argument('--seed', type=int, default=42, 70 | help="random seed for initialization") 71 | 72 | return parser 73 | 74 | 75 | def train_args(): 76 | parser = common_args() 77 | 78 | parser.add_argument('--tag', default=None, type=str, 79 | help='The comment to the experiment') 80 | parser.add_argument('--comment', default=None, type=str, 81 | help='The comment to the experiment') 82 | 83 | # model 84 | parser.add_argument('--cmd_dropout_prob', type=float, default=0.1) 85 | 86 | # optimization 87 | parser.add_argument("--sp_weight", default=0.0, type=float, help="weight of the sp loss") 88 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 89 | help="Batch size per GPU for training.") 90 | parser.add_argument("--learning_rate", default=1e-5, type=float, 91 | help="The initial learning rate for Adam.") 92 | parser.add_argument("--warmup_ratio", default=0.0, type=float, 93 | help="Linear warmup over warmup_steps.") 94 | parser.add_argument("--weight_decay", default=0.0, type=float, 95 | help="Weight decay if we apply some.") 96 | parser.add_argument("--use_adam", action="store_true", 97 | help="use adam or adamW") 98 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 99 | help="Epsilon for Adam optimizer.") 100 | parser.add_argument("--num_train_epochs", default=5, type=int, 101 | help="Total number of training epochs to perform.") 102 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 103 | help="Number of updates steps to accumulate before performing a backward/update pass.") 104 | parser.add_argument("--max_grad_norm", default=2.0, type=float, 105 | help="Max gradient norm.") 106 | 107 | parser.add_argument('--log_period', type=int, default=100, 108 | help="Log every X updates steps.") 109 | parser.add_argument('--eval_period', type=int, default=1000, 110 | help="Evaluate every X updates steps.") 111 | parser.add_argument("--criterion_metric", default="joint_f1") 112 | parser.add_argument('--save_period', type=int, default=-1, 113 | help="Save checkpoint every X updates steps.") 114 | 115 | return parser.parse_args() 116 | -------------------------------------------------------------------------------- /dense_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import logging 3 | from typing import Dict, List, Union, Tuple 4 | # from tqdm import trange 5 | from tqdm.auto import trange 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | 12 | from transformers import RobertaTokenizer, DistilBertTokenizer, DistilBertModel 13 | from mdr.retrieval.models.retriever import RobertaCtxEncoder 14 | 15 | from utils.model_utils import get_device 16 | from utils.tensor_utils import pad_tensors, to_device 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | # Define type aliases 21 | TextPair = Tuple[str, str] 22 | 23 | 24 | class 
DenseEncoder(metaclass=ABCMeta): 25 | @abstractmethod 26 | def encode_queries(self, queries: List, batch_size: int = None, **kwargs) -> np.ndarray: 27 | pass 28 | 29 | @abstractmethod 30 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = None, **kwargs) -> np.ndarray: 31 | pass 32 | 33 | 34 | class MDREncoder(DenseEncoder): 35 | def __init__(self, model: Union[RobertaCtxEncoder, nn.DataParallel], tokenizer: RobertaTokenizer, 36 | max_q_len: int = None, max_p_len: int = None): 37 | self.model = model 38 | self.tokenizer = tokenizer 39 | self.max_q_len = max_q_len 40 | self.max_p_len = max_p_len 41 | 42 | def encode(self, texts_or_text_pairs: Union[List[str], List[TextPair]], 43 | max_seq_len: int, batch_size: int = None) -> torch.Tensor: 44 | total = len(texts_or_text_pairs) 45 | if batch_size is None or batch_size <= 0: 46 | batch_size = max(total, 256) 47 | device = get_device(self.model) 48 | vectors = [] 49 | self.model.eval() 50 | with torch.no_grad(): 51 | for batch_start in trange(0, total, batch_size, disable=(total / batch_size <= 10.)): 52 | inputs = self.tokenizer.batch_encode_plus( 53 | batch_text_or_text_pairs=texts_or_text_pairs[batch_start:batch_start + batch_size], 54 | padding=True, truncation=True, max_length=max_seq_len, return_tensors="pt" 55 | ) 56 | inputs = to_device({"input_ids": inputs["input_ids"], "input_mask": inputs["attention_mask"]}, device) 57 | embeddings = self.model(inputs)['embed'] 58 | vectors.append(embeddings.cpu()) 59 | vectors = torch.cat(vectors, dim=0).contiguous() 60 | assert vectors.shape[0] == total 61 | return vectors 62 | 63 | def encode_queries(self, queries: List[TextPair], batch_size: int = None, **kwargs) -> np.ndarray: 64 | try: 65 | max_q_len = int(kwargs['max_q_len']) 66 | assert max_q_len > 0 67 | except: 68 | max_q_len = self.max_q_len 69 | return self.encode(queries, max_q_len, batch_size).numpy() 70 | 71 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = None, **kwargs) -> np.ndarray: 72 | text_pairs = [(para['title'], para['text'] if para['text'] else para['title']) for para in corpus] 73 | return self.encode(text_pairs, self.max_p_len, batch_size).numpy() 74 | 75 | 76 | class TASEncoder(DenseEncoder): 77 | def __init__(self, model: Union[DistilBertModel, nn.DataParallel], tokenizer: DistilBertTokenizer, 78 | max_q_len: int = None, max_p_len: int = None): 79 | self.model = model 80 | self.tokenizer = tokenizer 81 | self.max_q_len = max_q_len 82 | self.max_p_len = max_p_len 83 | 84 | def encode(self, texts: List[str], max_seq_len: int, batch_size: int = None) -> torch.Tensor: 85 | total = len(texts) 86 | if batch_size is None or batch_size <= 0: 87 | batch_size = max(total, 256) 88 | device = get_device(self.model) 89 | vectors = [] 90 | self.model.eval() 91 | with torch.no_grad(): 92 | for batch_start in trange(0, total, batch_size, disable=(total / batch_size <= 10.)): 93 | inputs = self.tokenizer.batch_encode_plus( 94 | batch_text_or_text_pairs=texts[batch_start:batch_start + batch_size], 95 | padding=True, truncation=True, max_length=max_seq_len, return_tensors="pt" 96 | ) 97 | inputs = to_device(inputs, device) 98 | embeddings = self.model(**inputs)[0][:, 0, :] 99 | vectors.append(embeddings.cpu()) 100 | vectors = torch.cat(vectors, dim=0).contiguous() 101 | assert vectors.shape[0] == total 102 | return vectors 103 | 104 | def encode_queries(self, queries: List[str], batch_size: int = None, **kwargs) -> np.ndarray: 105 | return self.encode(queries, self.max_q_len, 
batch_size).numpy() 106 | 107 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = None, **kwargs) -> np.ndarray: 108 | texts = [para['text'] for para in corpus] 109 | return self.encode(texts, self.max_p_len, batch_size).numpy() 110 | 111 | 112 | class PassageDataset(Dataset): 113 | 114 | def __init__(self, corpus, tokenizer, max_seq_len): 115 | super().__init__() 116 | self.corpus = [(p_id, corpus[p_id]) for p_id in sorted(corpus.keys())] 117 | self.tokenizer = tokenizer 118 | self.max_seq_len = max_seq_len 119 | 120 | def __getitem__(self, index): 121 | p_id, para = self.corpus[index] 122 | para_codes = self.tokenizer.encode_plus(para["text"], truncation=True, max_length=self.max_seq_len, 123 | return_tensors="pt") 124 | for k in para_codes.keys(): 125 | para_codes[k].squeeze_(0) 126 | para_codes['p_id'] = p_id 127 | 128 | return para_codes 129 | 130 | def __len__(self): 131 | return len(self.corpus) 132 | 133 | 134 | def collate_passages(samples, pad_id=0): 135 | if len(samples) == 0: 136 | return {} 137 | 138 | nn_input = { 139 | "input_ids": pad_tensors([sample['input_ids'] for sample in samples], pad_id), # (B, T) 140 | "attention_mask": pad_tensors([sample['attention_mask'] for sample in samples], 0), # (B, T) 141 | } 142 | if 'token_type_ids' in samples[0]: 143 | nn_input['token_type_ids'] = pad_tensors([sample['token_type_ids'] for sample in samples], 0), # (B, T) 144 | 145 | batch = {key: [] for key in samples[0] if key not in nn_input} 146 | for sample in samples: 147 | for k in batch: 148 | batch[k].append(sample[k]) 149 | batch['nn_input'] = nn_input 150 | 151 | return batch 152 | -------------------------------------------------------------------------------- /dpr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/dpr/__init__.py -------------------------------------------------------------------------------- /dpr/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/dpr/data/__init__.py -------------------------------------------------------------------------------- /dpr/data/qa_validation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Set of utilities for Q&A results validation tasks - Retriver passage validation and Reader predicted answer validation 10 | """ 11 | 12 | import collections 13 | import logging 14 | import string 15 | import unicodedata 16 | from functools import partial 17 | from multiprocessing import Pool as ProcessPool 18 | from typing import Tuple, List, Dict 19 | 20 | import regex as re 21 | 22 | from dpr.utils.tokenizers import SimpleTokenizer 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 27 | 28 | 29 | def calculate_matches(all_docs: Dict[object, Tuple[str, str]], answers: List[List[str]], 30 | closest_docs: List[Tuple[List[object], List[float]]], workers_num: int, 31 | match_type: str) -> QAMatchStats: 32 | """ 33 | Evaluates answers presence in the set of documents. 
This function is supposed to be used with a large collection of 34 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 35 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 36 | :param answers: list of answers's list. One list per question 37 | :param closest_docs: document ids of the top results along with their scores 38 | :param workers_num: amount of parallel threads to process data 39 | :param match_type: type of answer matching. Refer to has_answer code for available options 40 | :return: matching information tuple. 41 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 42 | valid matches across an entire dataset. 43 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 44 | """ 45 | global dpr_all_documents 46 | dpr_all_documents = all_docs 47 | 48 | tok_opts = {} 49 | tokenizer = SimpleTokenizer(**tok_opts) 50 | 51 | processes = ProcessPool( 52 | processes=workers_num, 53 | ) 54 | 55 | logger.info('Matching answers in top docs...') 56 | 57 | get_score_partial = partial(check_answer, match_type=match_type, tokenizer=tokenizer) 58 | 59 | questions_answers_docs = zip(answers, closest_docs) 60 | 61 | scores = processes.map(get_score_partial, questions_answers_docs) 62 | 63 | logger.info('Per question validation results len=%d', len(scores)) 64 | 65 | n_docs = len(closest_docs[0][0]) 66 | top_k_hits = [0] * n_docs 67 | for question_hits in scores: 68 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 69 | if best_hit is not None: 70 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 71 | 72 | return QAMatchStats(top_k_hits, scores) 73 | 74 | 75 | def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: 76 | """Search through all the top docs to see if they have any of the answers.""" 77 | answers, (doc_ids, doc_scores) = questions_answers_docs 78 | 79 | global dpr_all_documents 80 | hits = [] 81 | 82 | for i, doc_id in enumerate(doc_ids): 83 | doc = dpr_all_documents[doc_id] 84 | text = doc[0] 85 | 86 | answer_found = False 87 | if text is None: # cannot find the document for some reason 88 | logger.warning("no doc in db") 89 | hits.append(False) 90 | continue 91 | 92 | if has_answer(answers, text, tokenizer, match_type): 93 | answer_found = True 94 | hits.append(answer_found) 95 | return hits 96 | 97 | 98 | def has_answer(answers, text, tokenizer, match_type) -> bool: 99 | """Check if a document contains an answer string. 100 | If `match_type` is string, token matching is done between the text and answer. 101 | If `match_type` is regex, we search the whole text with the regex. 
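    Illustrative example (not part of the original DPR module; assumes the
    SimpleTokenizer imported above). String matching is done over uncased
    tokens, so an answer counts only if its full token sequence appears:

        tok = SimpleTokenizer()
        has_answer(["Punta Cana"], "EMNLP 2021 was held in Punta Cana.", tok, "string")  # True
        has_answer(["199"], "She was born in 1999.", tok, "string")  # False: '199' != token '1999'
        has_answer([r"19\d\d"], "She was born in 1999.", tok, "regex")  # True: regex searches the raw text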
102 | """ 103 | text = _normalize(text) 104 | 105 | if match_type == 'string': 106 | # Answer is a list of possible strings 107 | text = tokenizer.tokenize(text).words(uncased=True) 108 | 109 | for single_answer in answers: 110 | single_answer = _normalize(single_answer) 111 | single_answer = tokenizer.tokenize(single_answer) 112 | single_answer = single_answer.words(uncased=True) 113 | 114 | for i in range(0, len(text) - len(single_answer) + 1): 115 | if single_answer == text[i: i + len(single_answer)]: 116 | return True 117 | 118 | elif match_type == 'regex': 119 | # Answer is a regex 120 | for single_answer in answers: 121 | single_answer = _normalize(single_answer) 122 | if regex_match(text, single_answer): 123 | return True 124 | return False 125 | 126 | 127 | def regex_match(text, pattern): 128 | """Test if a regex pattern is contained within a text.""" 129 | try: 130 | pattern = re.compile( 131 | pattern, 132 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, 133 | ) 134 | except BaseException: 135 | return False 136 | return pattern.search(text) is not None 137 | 138 | 139 | # function for the reader model answer validation 140 | def exact_match_score(prediction, ground_truth): 141 | return _normalize_answer(prediction) == _normalize_answer(ground_truth) 142 | 143 | 144 | def _normalize_answer(s): 145 | def remove_articles(text): 146 | return re.sub(r'\b(a|an|the)\b', ' ', text) 147 | 148 | def white_space_fix(text): 149 | return ' '.join(text.split()) 150 | 151 | def remove_punc(text): 152 | exclude = set(string.punctuation) 153 | return ''.join(ch for ch in text if ch not in exclude) 154 | 155 | def lower(text): 156 | return text.lower() 157 | 158 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 159 | 160 | 161 | def _normalize(text): 162 | return unicodedata.normalize('NFD', text) 163 | -------------------------------------------------------------------------------- /dpr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import importlib 9 | 10 | """ 11 | 'Router'-like set of methods for component initialization with lazy imports 12 | """ 13 | 14 | 15 | def init_hf_bert_biencoder(args, **kwargs): 16 | if importlib.util.find_spec("transformers") is None: 17 | raise RuntimeError('Please install transformers lib') 18 | from .hf_models import get_bert_biencoder_components 19 | return get_bert_biencoder_components(args, **kwargs) 20 | 21 | 22 | def init_hf_bert_reader(args, **kwargs): 23 | if importlib.util.find_spec("transformers") is None: 24 | raise RuntimeError('Please install transformers lib') 25 | from .hf_models import get_bert_reader_components 26 | return get_bert_reader_components(args, **kwargs) 27 | 28 | 29 | def init_pytext_bert_biencoder(args, **kwargs): 30 | if importlib.util.find_spec("pytext") is None: 31 | raise RuntimeError('Please install pytext lib') 32 | from .pytext_models import get_bert_biencoder_components 33 | return get_bert_biencoder_components(args, **kwargs) 34 | 35 | 36 | def init_fairseq_roberta_biencoder(args, **kwargs): 37 | if importlib.util.find_spec("fairseq") is None: 38 | raise RuntimeError('Please install fairseq lib') 39 | from .fairseq_models import get_roberta_biencoder_components 40 | return get_roberta_biencoder_components(args, **kwargs) 41 | 42 | 43 | def init_hf_bert_tenzorizer(args, **kwargs): 44 | if importlib.util.find_spec("transformers") is None: 45 | raise RuntimeError('Please install transformers lib') 46 | from .hf_models import get_bert_tensorizer 47 | return get_bert_tensorizer(args) 48 | 49 | 50 | def init_hf_roberta_tenzorizer(args, **kwargs): 51 | if importlib.util.find_spec("transformers") is None: 52 | raise RuntimeError('Please install transformers lib') 53 | from .hf_models import get_roberta_tensorizer 54 | return get_roberta_tensorizer(args) 55 | 56 | 57 | BIENCODER_INITIALIZERS = { 58 | 'hf_bert': init_hf_bert_biencoder, 59 | 'pytext_bert': init_pytext_bert_biencoder, 60 | 'fairseq_roberta': init_fairseq_roberta_biencoder, 61 | } 62 | 63 | READER_INITIALIZERS = { 64 | 'hf_bert': init_hf_bert_reader, 65 | } 66 | 67 | TENSORIZER_INITIALIZERS = { 68 | 'hf_bert': init_hf_bert_tenzorizer, 69 | 'hf_roberta': init_hf_roberta_tenzorizer, 70 | 'pytext_bert': init_hf_bert_tenzorizer, # using HF's code as of now 71 | 'fairseq_roberta': init_hf_roberta_tenzorizer, # using HF's code as of now 72 | } 73 | 74 | 75 | def init_comp(initializers_dict, type, args, **kwargs): 76 | if type in initializers_dict: 77 | return initializers_dict[type](args, **kwargs) 78 | else: 79 | raise RuntimeError('unsupported model type: {}'.format(type)) 80 | 81 | 82 | def init_biencoder_components(encoder_type: str, args, **kwargs): 83 | return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) 84 | 85 | 86 | def init_reader_components(encoder_type: str, args, **kwargs): 87 | return init_comp(READER_INITIALIZERS, encoder_type, args, **kwargs) 88 | 89 | 90 | def init_tenzorizer(encoder_type: str, args, **kwargs): 91 | return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) 92 | -------------------------------------------------------------------------------- /dpr/models/fairseq_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ 9 | Encoder model wrappers based on Fairseq code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | from fairseq.models.roberta.hub_interface import RobertaHubInterface 16 | from fairseq.models.roberta.model import RobertaModel as FaiseqRobertaModel 17 | from fairseq.optim.adam import FairseqAdam 18 | from torch import Tensor as T 19 | from torch import nn 20 | 21 | from dpr.models.hf_models import get_roberta_tensorizer 22 | from .biencoder import BiEncoder 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def get_roberta_biencoder_components(args, inference_only: bool = False, **kwargs): 28 | question_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) 29 | ctx_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) 30 | biencoder = BiEncoder(question_encoder, ctx_encoder) 31 | optimizer = get_fairseq_adamw_optimizer(biencoder, args) if not inference_only else None 32 | 33 | tensorizer = get_roberta_tensorizer(args) 34 | 35 | return tensorizer, biencoder, optimizer 36 | 37 | 38 | def get_fairseq_adamw_optimizer(model: nn.Module, args): 39 | setattr(args, 'lr', [args.learning_rate]) 40 | return FairseqAdam(args, model.parameters()).optimizer 41 | 42 | 43 | class RobertaEncoder(nn.Module): 44 | 45 | def __init__(self, fairseq_roberta_hub: RobertaHubInterface): 46 | super(RobertaEncoder, self).__init__() 47 | self.fairseq_roberta = fairseq_roberta_hub 48 | 49 | @classmethod 50 | def from_pretrained(cls, pretrained_dir_path: str): 51 | model = FaiseqRobertaModel.from_pretrained(pretrained_dir_path) 52 | return cls(model) 53 | 54 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: 55 | roberta_out = self.fairseq_roberta.extract_features(input_ids) 56 | cls_out = roberta_out[:, 0, :] 57 | return roberta_out, cls_out, None 58 | 59 | def get_out_size(self): 60 | raise NotImplementedError 61 | -------------------------------------------------------------------------------- /dpr/models/pytext_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ 9 | Encoder model wrappers based on HuggingFace code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | import torch 16 | from pytext.models.representations.transformer_sentence_encoder import TransformerSentenceEncoder 17 | from pytext.optimizer.optimizers import AdamW 18 | from torch import Tensor as T 19 | from torch import nn 20 | 21 | from .biencoder import BiEncoder 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def get_bert_biencoder_components(args, inference_only: bool = False): 27 | # since bert tokenizer is the same in HF and pytext/fairseq, just use HF's implementation here for now 28 | from .hf_models import get_tokenizer, BertTensorizer 29 | 30 | tokenizer = get_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) 31 | 32 | question_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, 33 | projection_dim=args.projection_dim, dropout=args.dropout, 34 | vocab_size=tokenizer.vocab_size, 35 | padding_idx=tokenizer.pad_token_type_id 36 | ) 37 | 38 | ctx_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, 39 | projection_dim=args.projection_dim, dropout=args.dropout, 40 | vocab_size=tokenizer.vocab_size, 41 | padding_idx=tokenizer.pad_token_type_id 42 | ) 43 | 44 | biencoder = BiEncoder(question_encoder, ctx_encoder) 45 | 46 | optimizer = get_optimizer(biencoder, 47 | learning_rate=args.learning_rate, 48 | adam_eps=args.adam_eps, weight_decay=args.weight_decay, 49 | ) if not inference_only else None 50 | 51 | tensorizer = BertTensorizer(tokenizer, args.sequence_length) 52 | return tensorizer, biencoder, optimizer 53 | 54 | 55 | def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, 56 | weight_decay: float = 0.0) -> torch.optim.Optimizer: 57 | cfg = AdamW.Config() 58 | cfg.lr = learning_rate 59 | cfg.weight_decay = weight_decay 60 | cfg.eps = adam_eps 61 | optimizer = AdamW.from_config(cfg, model) 62 | return optimizer 63 | 64 | 65 | def get_pytext_bert_base_cfg(): 66 | cfg = TransformerSentenceEncoder.Config() 67 | cfg.embedding_dim = 768 68 | cfg.ffn_embedding_dim = 3072 69 | cfg.num_encoder_layers = 12 70 | cfg.num_attention_heads = 12 71 | cfg.num_segments = 2 72 | cfg.use_position_embeddings = True 73 | cfg.offset_positions_by_padding = True 74 | cfg.apply_bert_init = True 75 | cfg.encoder_normalize_before = True 76 | cfg.activation_fn = "gelu" 77 | cfg.projection_dim = 0 78 | cfg.max_seq_len = 512 79 | cfg.multilingual = False 80 | cfg.freeze_embeddings = False 81 | cfg.n_trans_layers_to_freeze = 0 82 | cfg.use_torchscript = False 83 | return cfg 84 | 85 | 86 | class PytextBertEncoder(TransformerSentenceEncoder): 87 | 88 | def __init__(self, config: TransformerSentenceEncoder.Config, 89 | padding_idx: int, 90 | vocab_size: int, 91 | projection_dim: int = 0, 92 | *args, 93 | **kwarg 94 | ): 95 | 96 | TransformerSentenceEncoder.__init__(self, config, False, padding_idx, vocab_size, *args, **kwarg) 97 | 98 | assert config.embedding_dim > 0, 'Encoder hidden_size can\'t be zero' 99 | self.encode_proj = nn.Linear(config.embedding_dim, projection_dim) if projection_dim != 0 else None 100 | 101 | @classmethod 102 | def init_encoder(cls, pretrained_file: str = None, projection_dim: int = 0, dropout: float = 0.1, 103 | vocab_size: int = 0, 104 | padding_idx: int = 0, **kwargs): 105 | cfg = get_pytext_bert_base_cfg() 106 | 107 | if dropout != 0: 108 | cfg.dropout = dropout 109 | cfg.attention_dropout = dropout 110 | cfg.activation_dropout = dropout 111 | 112 | 
encoder = cls(cfg, padding_idx, vocab_size, projection_dim, **kwargs) 113 | 114 | if pretrained_file: 115 | logger.info('Loading pre-trained pytext encoder state from %s', pretrained_file) 116 | state = torch.load(pretrained_file) 117 | encoder.load_state_dict(state) 118 | return encoder 119 | 120 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: 121 | pooled_output = super().forward((input_ids, attention_mask, token_type_ids, None))[0] 122 | if self.encode_proj: 123 | pooled_output = self.encode_proj(pooled_output) 124 | 125 | return None, pooled_output, None 126 | 127 | def get_out_size(self): 128 | if self.encode_proj: 129 | return self.encode_proj.out_features 130 | return self.representation_dim 131 | -------------------------------------------------------------------------------- /dpr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/dpr/utils/__init__.py -------------------------------------------------------------------------------- /dpr/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for general purpose data processing 10 | """ 11 | 12 | import json 13 | import logging 14 | import math 15 | import pickle 16 | import random 17 | from typing import List, Iterator, Callable 18 | 19 | from torch import Tensor as T 20 | 21 | logger = logging.getLogger() 22 | 23 | 24 | def read_serialized_data_from_files(paths: List[str]) -> List: 25 | results = [] 26 | for i, path in enumerate(paths): 27 | with open(path, "rb") as reader: 28 | logger.info('Reading file %s', path) 29 | data = pickle.load(reader) 30 | results.extend(data) 31 | logger.info('Aggregated data size: {}'.format(len(results))) 32 | logger.info('Total data size: {}'.format(len(results))) 33 | return results 34 | 35 | 36 | def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List: 37 | results = [] 38 | if upsample_rates is None: 39 | upsample_rates = [1] * len(paths) 40 | 41 | assert len(upsample_rates) == len(paths), 'up-sample rates parameter doesn\'t match input files amount' 42 | 43 | for i, path in enumerate(paths): 44 | with open(path, 'r', encoding="utf-8") as f: 45 | logger.info('Reading file %s' % path) 46 | data = json.load(f) 47 | upsample_factor = int(upsample_rates[i]) 48 | data = data * upsample_factor 49 | results.extend(data) 50 | logger.info('Aggregated data size: {}'.format(len(results))) 51 | return results 52 | 53 | 54 | class ShardedDataIterator(object): 55 | """ 56 | General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of 57 | the data. 58 | Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. 59 | It fills the extra sample by just taking first samples in a shard. 60 | It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). 
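    Illustrative usage with toy data (a minimal sketch, not from the original code):

        data = list(range(10))
        shard0 = ShardedDataIterator(data, shard_id=0, num_shards=2, batch_size=3, shuffle=False)
        for batch in shard0.iterate_data(epoch=0):
            print(batch)  # [0, 1, 2] then [3, 4]; shard 1 would yield [5, 6, 7] then [8, 9]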
61 | """ 62 | def __init__(self, data: list, shard_id: int = 0, num_shards: int = 1, batch_size: int = 1, shuffle=True, 63 | shuffle_seed: int = 0, offset: int = 0, 64 | strict_batch_size: bool = False 65 | ): 66 | 67 | self.data = data 68 | total_size = len(data) 69 | 70 | self.shards_num = max(num_shards, 1) 71 | self.shard_id = max(shard_id, 0) 72 | 73 | samples_per_shard = math.ceil(total_size / self.shards_num) 74 | 75 | self.shard_start_idx = self.shard_id * samples_per_shard 76 | 77 | self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size) 78 | 79 | if strict_batch_size: 80 | self.max_iterations = math.ceil(samples_per_shard / batch_size) 81 | else: 82 | self.max_iterations = int(samples_per_shard / batch_size) 83 | 84 | logger.debug( 85 | 'samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d', samples_per_shard, 86 | self.shard_start_idx, 87 | self.shard_end_idx, 88 | self.max_iterations) 89 | 90 | self.iteration = offset # to track in-shard iteration status 91 | self.shuffle = shuffle 92 | self.batch_size = batch_size 93 | self.shuffle_seed = shuffle_seed 94 | self.strict_batch_size = strict_batch_size 95 | 96 | def total_data_len(self) -> int: 97 | return len(self.data) 98 | 99 | def iterate_data(self, epoch: int = 0) -> Iterator[List]: 100 | if self.shuffle: 101 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 102 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 103 | epoch_rnd.shuffle(self.data) 104 | 105 | # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations 106 | 107 | max_iterations = self.max_iterations - self.iteration 108 | 109 | shard_samples = self.data[self.shard_start_idx:self.shard_end_idx] 110 | for i in range(self.iteration * self.batch_size, len(shard_samples), self.batch_size): 111 | items = shard_samples[i:i + self.batch_size] 112 | if self.strict_batch_size and len(items) < self.batch_size: 113 | logger.debug('Extending batch to max size') 114 | items.extend(shard_samples[0:self.batch_size - len(items)]) 115 | self.iteration += 1 116 | yield items 117 | 118 | # some shards may done iterating while the others are at the last batch. Just return the first batch 119 | while self.iteration < max_iterations: 120 | logger.debug('Fulfilling non complete shard='.format(self.shard_id)) 121 | self.iteration += 1 122 | batch = shard_samples[0:self.batch_size] 123 | yield batch 124 | 125 | logger.debug('Finished iterating, iteration={}, shard={}'.format(self.iteration, self.shard_id)) 126 | # reset the iteration status 127 | self.iteration = 0 128 | 129 | def get_iteration(self) -> int: 130 | return self.iteration 131 | 132 | def apply(self, visitor_func: Callable): 133 | for sample in self.data: 134 | visitor_func(sample) 135 | 136 | 137 | def normalize_question(question: str) -> str: 138 | if question[-1] == '?': 139 | question = question[:-1] 140 | return question 141 | 142 | 143 | class Tensorizer(object): 144 | """ 145 | Component for all text to model input data conversions and related utility methods 146 | """ 147 | 148 | # Note: title, if present, is supposed to be put before text (i.e. 
optional title + document body) 149 | def text_to_tensor(self, text: str, title: str = None, add_special_tokens: bool = True): 150 | raise NotImplementedError 151 | 152 | def get_pair_separator_ids(self) -> T: 153 | raise NotImplementedError 154 | 155 | def get_pad_id(self) -> int: 156 | raise NotImplementedError 157 | 158 | def get_attn_mask(self, tokens_tensor: T): 159 | raise NotImplementedError 160 | 161 | def is_sub_word_id(self, token_id: int): 162 | raise NotImplementedError 163 | 164 | def to_string(self, token_ids, skip_special_tokens=True): 165 | raise NotImplementedError 166 | 167 | def set_pad_to_max(self, pad: bool): 168 | raise NotImplementedError 169 | -------------------------------------------------------------------------------- /dpr/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for distributed model training 10 | """ 11 | 12 | import pickle 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | def get_rank(): 19 | return dist.get_rank() 20 | 21 | 22 | def get_world_size(): 23 | return dist.get_world_size() 24 | 25 | 26 | def get_default_group(): 27 | return dist.group.WORLD 28 | 29 | 30 | def all_reduce(tensor, group=None): 31 | if group is None: 32 | group = get_default_group() 33 | return dist.all_reduce(tensor, group=group) 34 | 35 | 36 | def all_gather_list(data, group=None, max_size=16384): 37 | """Gathers arbitrary data from all nodes into a list. 38 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 39 | data. Note that *data* must be picklable. 
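    For illustration (a minimal sketch with hypothetical values): with a world size of 2,
    each worker passes its own small dict and every worker receives the same rank-ordered list:

        stats = all_gather_list({"rank": get_rank(), "loss": 0.37})
        # on every worker: [{"rank": 0, "loss": ...}, {"rank": 1, "loss": ...}]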
40 | Args: 41 | data (Any): data from the local worker to be gathered on other workers 42 | group (optional): group of the collective 43 | """ 44 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 45 | 46 | enc = pickle.dumps(data) 47 | enc_size = len(enc) 48 | 49 | if enc_size + SIZE_STORAGE_BYTES > max_size: 50 | raise ValueError( 51 | 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) 52 | 53 | rank = get_rank() 54 | world_size = get_world_size() 55 | buffer_size = max_size * world_size 56 | 57 | if not hasattr(all_gather_list, '_buffer') or \ 58 | all_gather_list._buffer.numel() < buffer_size: 59 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 60 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 61 | 62 | buffer = all_gather_list._buffer 63 | buffer.zero_() 64 | cpu_buffer = all_gather_list._cpu_buffer 65 | 66 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( 67 | 256 ** SIZE_STORAGE_BYTES) 68 | 69 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 70 | 71 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 72 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 73 | 74 | start = rank * max_size 75 | size = enc_size + SIZE_STORAGE_BYTES 76 | buffer[start: start + size].copy_(cpu_buffer[:size]) 77 | 78 | all_reduce(buffer, group=group) 79 | 80 | try: 81 | result = [] 82 | for i in range(world_size): 83 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 84 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 85 | if size > 0: 86 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) 87 | return result 88 | except pickle.UnpicklingError: 89 | raise Exception( 90 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 91 | 'workers to enter the function together, so this error usually indicates ' 92 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 93 | 'sync if one of them runs out of memory, or if there are other conditions ' 94 | 'in your training script that can cause one worker to finish an epoch ' 95 | 'while other workers are still iterating over their portions of the data.' 96 | ) 97 | -------------------------------------------------------------------------------- /dpr/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import collections 9 | import glob 10 | import logging 11 | import os 12 | from typing import List 13 | 14 | import torch 15 | from torch import nn 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.serialization import default_restore_location 18 | 19 | logger = logging.getLogger() 20 | 21 | CheckpointState = collections.namedtuple("CheckpointState", 22 | ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 23 | 'encoder_params']) 24 | 25 | 26 | def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1, 27 | local_rank: int = -1, 28 | fp16: bool = False, 29 | fp16_opt_level: str = "O1") -> (nn.Module, torch.optim.Optimizer): 30 | model.to(device) 31 | if fp16: 32 | try: 33 | import apex 34 | from apex import amp 35 | apex.amp.register_half_function(torch, "einsum") 36 | except ImportError: 37 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 38 | 39 | model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) 40 | 41 | if n_gpu > 1: 42 | model = torch.nn.DataParallel(model) 43 | 44 | if local_rank != -1: 45 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], 46 | output_device=local_rank, 47 | find_unused_parameters=True) 48 | return model, optimizer 49 | 50 | 51 | def move_to_cuda(sample): 52 | if len(sample) == 0: 53 | return {} 54 | 55 | def _move_to_cuda(maybe_tensor): 56 | if torch.is_tensor(maybe_tensor): 57 | return maybe_tensor.cuda() 58 | elif isinstance(maybe_tensor, dict): 59 | return { 60 | key: _move_to_cuda(value) 61 | for key, value in maybe_tensor.items() 62 | } 63 | elif isinstance(maybe_tensor, list): 64 | return [_move_to_cuda(x) for x in maybe_tensor] 65 | elif isinstance(maybe_tensor, tuple): 66 | return [_move_to_cuda(x) for x in maybe_tensor] 67 | else: 68 | return maybe_tensor 69 | 70 | return _move_to_cuda(sample) 71 | 72 | 73 | def move_to_device(sample, device): 74 | if len(sample) == 0: 75 | return {} 76 | 77 | def _move_to_device(maybe_tensor, device): 78 | if torch.is_tensor(maybe_tensor): 79 | return maybe_tensor.to(device) 80 | elif isinstance(maybe_tensor, dict): 81 | return { 82 | key: _move_to_device(value, device) 83 | for key, value in maybe_tensor.items() 84 | } 85 | elif isinstance(maybe_tensor, list): 86 | return [_move_to_device(x, device) for x in maybe_tensor] 87 | elif isinstance(maybe_tensor, tuple): 88 | return [_move_to_device(x, device) for x in maybe_tensor] 89 | else: 90 | return maybe_tensor 91 | 92 | return _move_to_device(sample, device) 93 | 94 | 95 | def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1): 96 | """ Create a schedule with a learning rate that decreases linearly after 97 | linearly increasing during a warmup period. 
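    Illustrative values (not from the original code): with warmup_steps=100 and
    training_steps=1000, the LR multiplier is 0.5 at step 50, 1.0 at step 100,
    and then decays linearly to 0.0 at step 1000:

        scheduler = get_schedule_linear(optimizer, warmup_steps=100, training_steps=1000)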
98 | """ 99 | 100 | def lr_lambda(current_step): 101 | if current_step < warmup_steps: 102 | return float(current_step) / float(max(1, warmup_steps)) 103 | return max( 104 | 0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps)) 105 | ) 106 | 107 | return LambdaLR(optimizer, lr_lambda, last_epoch) 108 | 109 | 110 | def init_weights(modules: List): 111 | for module in modules: 112 | if isinstance(module, (nn.Linear, nn.Embedding)): 113 | module.weight.data.normal_(mean=0.0, std=0.02) 114 | elif isinstance(module, nn.LayerNorm): 115 | module.bias.data.zero_() 116 | module.weight.data.fill_(1.0) 117 | if isinstance(module, nn.Linear) and module.bias is not None: 118 | module.bias.data.zero_() 119 | 120 | 121 | def get_model_obj(model: nn.Module): 122 | return model.module if hasattr(model, 'module') else model 123 | 124 | 125 | def get_model_file(args, file_prefix) -> str: 126 | if args.model_file and os.path.exists(args.model_file): 127 | return args.model_file 128 | 129 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + '*')) if args.output_dir else [] 130 | logger.info('Checkpoint files %s', out_cp_files) 131 | model_file = None 132 | 133 | if len(out_cp_files) > 0: 134 | model_file = max(out_cp_files, key=os.path.getctime) 135 | return model_file 136 | 137 | 138 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 139 | logger.info('Reading saved model from %s', model_file) 140 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu')) 141 | logger.info('model_state_dict keys %s', state_dict.keys()) 142 | return CheckpointState(**state_dict) 143 | -------------------------------------------------------------------------------- /drqa/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import sys 10 | from pathlib import PosixPath 11 | 12 | if sys.version_info < (3, 5): 13 | raise RuntimeError('DrQA supports Python 3.5 or higher.') 14 | 15 | DATA_DIR = ( 16 | os.getenv('DRQA_DATA') or 17 | os.path.join(PosixPath(__file__).absolute().parents[1].as_posix(), 'data') 18 | ) 19 | 20 | from . import tokenizers 21 | from . import reader 22 | from . import retriever 23 | from . import pipeline 24 | -------------------------------------------------------------------------------- /drqa/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from ..tokenizers import CoreNLPTokenizer 10 | from ..retriever import TfidfDocRanker 11 | from ..retriever import DocDB 12 | from .. 
import DATA_DIR 13 | 14 | DEFAULTS = { 15 | 'tokenizer': CoreNLPTokenizer, 16 | 'ranker': TfidfDocRanker, 17 | 'db': DocDB, 18 | 'reader_model': os.path.join(DATA_DIR, 'reader/multitask.mdl'), 19 | } 20 | 21 | 22 | def set_default(key, value): 23 | global DEFAULTS 24 | DEFAULTS[key] = value 25 | 26 | 27 | from .drqa import DrQA 28 | -------------------------------------------------------------------------------- /drqa/reader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from ..tokenizers import CoreNLPTokenizer 10 | from .. import DATA_DIR 11 | 12 | 13 | DEFAULTS = { 14 | 'tokenizer': CoreNLPTokenizer, 15 | 'model': os.path.join(DATA_DIR, 'reader/single.mdl'), 16 | } 17 | 18 | 19 | def set_default(key, value): 20 | global DEFAULTS 21 | DEFAULTS[key] = value 22 | 23 | from .model import DocReader 24 | from .predictor import Predictor 25 | from . import config 26 | from . import vector 27 | from . import data 28 | from . import utils 29 | -------------------------------------------------------------------------------- /drqa/reader/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Model architecture/optimization options for DrQA document reader.""" 8 | 9 | import argparse 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | # Index of arguments concerning the core model architecture 15 | MODEL_ARCHITECTURE = { 16 | 'model_type', 'embedding_dim', 'hidden_size', 'doc_layers', 17 | 'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge', 18 | 'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf' 19 | } 20 | 21 | # Index of arguments concerning the model optimizer/training 22 | MODEL_OPTIMIZER = { 23 | 'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay', 24 | 'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb', 25 | 'max_len', 'grad_clipping', 'tune_partial' 26 | } 27 | 28 | 29 | def str2bool(v): 30 | return v.lower() in ('yes', 'true', 't', '1', 'y') 31 | 32 | 33 | def add_model_args(parser): 34 | parser.register('type', 'bool', str2bool) 35 | 36 | # Model architecture 37 | model = parser.add_argument_group('DrQA Reader Model Architecture') 38 | model.add_argument('--model-type', type=str, default='rnn', 39 | help='Model architecture type') 40 | model.add_argument('--embedding-dim', type=int, default=300, 41 | help='Embedding size if embedding_file is not given') 42 | model.add_argument('--hidden-size', type=int, default=128, 43 | help='Hidden size of RNN units') 44 | model.add_argument('--doc-layers', type=int, default=3, 45 | help='Number of encoding layers for document') 46 | model.add_argument('--question-layers', type=int, default=3, 47 | help='Number of encoding layers for question') 48 | model.add_argument('--rnn-type', type=str, default='lstm', 49 | help='RNN type: LSTM, GRU, or RNN') 50 | 51 | # Model specific details 52 | detail = parser.add_argument_group('DrQA Reader Model Details') 53 | 
detail.add_argument('--concat-rnn-layers', type='bool', default=True, 54 | help='Combine hidden states from each encoding layer') 55 | detail.add_argument('--question-merge', type=str, default='self_attn', 56 | help='The way of computing the question representation') 57 | detail.add_argument('--use-qemb', type='bool', default=True, 58 | help='Whether to use weighted question embeddings') 59 | detail.add_argument('--use-in-question', type='bool', default=True, 60 | help='Whether to use in_question_* features') 61 | detail.add_argument('--use-pos', type='bool', default=True, 62 | help='Whether to use pos features') 63 | detail.add_argument('--use-ner', type='bool', default=True, 64 | help='Whether to use ner features') 65 | detail.add_argument('--use-lemma', type='bool', default=True, 66 | help='Whether to use lemma features') 67 | detail.add_argument('--use-tf', type='bool', default=True, 68 | help='Whether to use term frequency features') 69 | 70 | # Optimization details 71 | optim = parser.add_argument_group('DrQA Reader Optimization') 72 | optim.add_argument('--dropout-emb', type=float, default=0.4, 73 | help='Dropout rate for word embeddings') 74 | optim.add_argument('--dropout-rnn', type=float, default=0.4, 75 | help='Dropout rate for RNN states') 76 | optim.add_argument('--dropout-rnn-output', type='bool', default=True, 77 | help='Whether to dropout the RNN output') 78 | optim.add_argument('--optimizer', type=str, default='adamax', 79 | help='Optimizer: sgd or adamax') 80 | optim.add_argument('--learning-rate', type=float, default=0.1, 81 | help='Learning rate for SGD only') 82 | optim.add_argument('--grad-clipping', type=float, default=10, 83 | help='Gradient clipping') 84 | optim.add_argument('--weight-decay', type=float, default=0, 85 | help='Weight decay factor') 86 | optim.add_argument('--momentum', type=float, default=0, 87 | help='Momentum factor') 88 | optim.add_argument('--fix-embeddings', type='bool', default=True, 89 | help='Keep word embeddings fixed (use pretrained)') 90 | optim.add_argument('--tune-partial', type=int, default=0, 91 | help='Backprop through only the top N question words') 92 | optim.add_argument('--rnn-padding', type='bool', default=False, 93 | help='Explicitly account for padding in RNN encoding') 94 | optim.add_argument('--max-len', type=int, default=15, 95 | help='The max span allowed during decoding') 96 | 97 | 98 | def get_model_args(args): 99 | """Filter args for model ones. 100 | 101 | From a args Namespace, return a new Namespace with *only* the args specific 102 | to the model architecture or optimization. (i.e. the ones defined here.) 103 | """ 104 | global MODEL_ARCHITECTURE, MODEL_OPTIMIZER 105 | required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER 106 | arg_values = {k: v for k, v in vars(args).items() if k in required_args} 107 | return argparse.Namespace(**arg_values) 108 | 109 | 110 | def override_model_args(old_args, new_args): 111 | """Set args to new parameters. 112 | 113 | Decide which model args to keep and which to override when resolving a set 114 | of saved args and new args. 115 | 116 | We keep the new optimation, but leave the model architecture alone. 
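    Illustrative example (hypothetical values, not from the original DrQA code):

        old = argparse.Namespace(learning_rate=0.1, hidden_size=128)
        new = argparse.Namespace(learning_rate=0.05, hidden_size=256)
        merged = override_model_args(old, new)
        # merged.learning_rate == 0.05  (optimizer arg, overridden)
        # merged.hidden_size == 128     (architecture arg, kept)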
117 | """
118 | global MODEL_OPTIMIZER
119 | old_args, new_args = vars(old_args), vars(new_args)
120 | for k in old_args.keys():
121 | if k in new_args and old_args[k] != new_args[k]:
122 | if k in MODEL_OPTIMIZER:
123 | logger.info('Overriding saved %s: %s --> %s' %
124 | (k, old_args[k], new_args[k]))
125 | old_args[k] = new_args[k]
126 | else:
127 | logger.info('Keeping saved %s: %s' % (k, old_args[k]))
128 | return argparse.Namespace(**old_args)
129 | -------------------------------------------------------------------------------- /drqa/reader/data.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Data processing/loading helpers."""
8 |
9 | import numpy as np
10 | import logging
11 | import unicodedata
12 |
13 | from torch.utils.data import Dataset
14 | from torch.utils.data.sampler import Sampler
15 | from .vector import vectorize
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | # ------------------------------------------------------------------------------
21 | # Dictionary class for tokens.
22 | # ------------------------------------------------------------------------------
23 |
24 |
25 | class Dictionary(object):
26 | NULL = '<NULL>'
27 | UNK = '<UNK>'
28 | START = 2
29 |
30 | @staticmethod
31 | def normalize(token):
32 | return unicodedata.normalize('NFD', token)
33 |
34 | def __init__(self):
35 | self.tok2ind = {self.NULL: 0, self.UNK: 1}
36 | self.ind2tok = {0: self.NULL, 1: self.UNK}
37 |
38 | def __len__(self):
39 | return len(self.tok2ind)
40 |
41 | def __iter__(self):
42 | return iter(self.tok2ind)
43 |
44 | def __contains__(self, key):
45 | if type(key) == int:
46 | return key in self.ind2tok
47 | elif type(key) == str:
48 | return self.normalize(key) in self.tok2ind
49 |
50 | def __getitem__(self, key):
51 | if type(key) == int:
52 | return self.ind2tok.get(key, self.UNK)
53 | if type(key) == str:
54 | return self.tok2ind.get(self.normalize(key),
55 | self.tok2ind.get(self.UNK))
56 |
57 | def __setitem__(self, key, item):
58 | if type(key) == int and type(item) == str:
59 | self.ind2tok[key] = item
60 | elif type(key) == str and type(item) == int:
61 | self.tok2ind[key] = item
62 | else:
63 | raise RuntimeError('Invalid (key, item) types.')
64 |
65 | def add(self, token):
66 | token = self.normalize(token)
67 | if token not in self.tok2ind:
68 | index = len(self.tok2ind)
69 | self.tok2ind[token] = index
70 | self.ind2tok[index] = token
71 |
72 | def tokens(self):
73 | """Get dictionary tokens.
74 |
75 | Return all the words indexed by this dictionary, except for special
76 | tokens.
77 | """
78 | tokens = [k for k in self.tok2ind.keys()
79 | if k not in {'<NULL>', '<UNK>'}]
80 | return tokens
81 |
82 |
83 | # ------------------------------------------------------------------------------
84 | # PyTorch dataset class for SQuAD (and SQuAD-like) data.
85 | # ------------------------------------------------------------------------------ 86 | 87 | 88 | class ReaderDataset(Dataset): 89 | 90 | def __init__(self, examples, model, single_answer=False): 91 | self.model = model 92 | self.examples = examples 93 | self.single_answer = single_answer 94 | 95 | def __len__(self): 96 | return len(self.examples) 97 | 98 | def __getitem__(self, index): 99 | return vectorize(self.examples[index], self.model, self.single_answer) 100 | 101 | def lengths(self): 102 | return [(len(ex['document']), len(ex['question'])) 103 | for ex in self.examples] 104 | 105 | 106 | # ------------------------------------------------------------------------------ 107 | # PyTorch sampler returning batched of sorted lengths (by doc and question). 108 | # ------------------------------------------------------------------------------ 109 | 110 | 111 | class SortedBatchSampler(Sampler): 112 | 113 | def __init__(self, lengths, batch_size, shuffle=True): 114 | self.lengths = lengths 115 | self.batch_size = batch_size 116 | self.shuffle = shuffle 117 | 118 | def __iter__(self): 119 | lengths = np.array( 120 | [(-l[0], -l[1], np.random.random()) for l in self.lengths], 121 | dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)] 122 | ) 123 | indices = np.argsort(lengths, order=('l1', 'l2', 'rand')) 124 | batches = [indices[i:i + self.batch_size] 125 | for i in range(0, len(indices), self.batch_size)] 126 | if self.shuffle: 127 | np.random.shuffle(batches) 128 | return iter([i for batch in batches for i in batch]) 129 | 130 | def __len__(self): 131 | return len(self.lengths) 132 | -------------------------------------------------------------------------------- /drqa/reader/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """DrQA Document Reader predictor""" 8 | 9 | import logging 10 | 11 | from multiprocessing import Pool as ProcessPool 12 | from multiprocessing.util import Finalize 13 | 14 | from .vector import vectorize, batchify 15 | from .model import DocReader 16 | from . import DEFAULTS, utils 17 | from .. import tokenizers 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # ------------------------------------------------------------------------------ 23 | # Tokenize + annotate 24 | # ------------------------------------------------------------------------------ 25 | 26 | PROCESS_TOK = None 27 | 28 | 29 | def init(tokenizer_class, annotators): 30 | global PROCESS_TOK 31 | PROCESS_TOK = tokenizer_class(annotators=annotators) 32 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 33 | 34 | 35 | def tokenize(text): 36 | global PROCESS_TOK 37 | return PROCESS_TOK.tokenize(text) 38 | 39 | 40 | # ------------------------------------------------------------------------------ 41 | # Predictor class. 42 | # ------------------------------------------------------------------------------ 43 | 44 | 45 | class Predictor(object): 46 | """Load a pretrained DocReader model and predict inputs on the fly.""" 47 | 48 | def __init__(self, model=None, tokenizer=None, normalize=True, 49 | embedding_file=None, num_workers=None): 50 | """ 51 | Args: 52 | model: path to saved model file. 53 | tokenizer: option string to select tokenizer class. 
54 | normalize: squash output score to 0-1 probabilities with a softmax. 55 | embedding_file: if provided, will expand dictionary to use all 56 | available pretrained vectors in this file. 57 | num_workers: number of CPU processes to use to preprocess batches. 58 | """ 59 | logger.info('Initializing model...') 60 | self.model = DocReader.load(model or DEFAULTS['model'], 61 | normalize=normalize) 62 | 63 | if embedding_file: 64 | logger.info('Expanding dictionary...') 65 | words = utils.index_embedding_words(embedding_file) 66 | added = self.model.expand_dictionary(words) 67 | self.model.load_embeddings(added, embedding_file) 68 | 69 | logger.info('Initializing tokenizer...') 70 | annotators = tokenizers.get_annotators_for_model(self.model) 71 | if not tokenizer: 72 | tokenizer_class = DEFAULTS['tokenizer'] 73 | else: 74 | tokenizer_class = tokenizers.get_class(tokenizer) 75 | 76 | if num_workers is None or num_workers > 0: 77 | self.workers = ProcessPool( 78 | num_workers, 79 | initializer=init, 80 | initargs=(tokenizer_class, annotators), 81 | ) 82 | else: 83 | self.workers = None 84 | self.tokenizer = tokenizer_class(annotators=annotators) 85 | 86 | def predict(self, document, question, candidates=None, top_n=1): 87 | """Predict a single document - question pair.""" 88 | results = self.predict_batch([(document, question, candidates,)], top_n) 89 | return results[0] 90 | 91 | def predict_batch(self, batch, top_n=1): 92 | """Predict a batch of document - question pairs.""" 93 | documents, questions, candidates = [], [], [] 94 | for b in batch: 95 | documents.append(b[0]) 96 | questions.append(b[1]) 97 | candidates.append(b[2] if len(b) == 3 else None) 98 | candidates = candidates if any(candidates) else None 99 | 100 | # Tokenize the inputs, perhaps multi-processed. 
101 | if self.workers: 102 | q_tokens = self.workers.map_async(tokenize, questions) 103 | d_tokens = self.workers.map_async(tokenize, documents) 104 | q_tokens = list(q_tokens.get()) 105 | d_tokens = list(d_tokens.get()) 106 | else: 107 | q_tokens = list(map(self.tokenizer.tokenize, questions)) 108 | d_tokens = list(map(self.tokenizer.tokenize, documents)) 109 | 110 | examples = [] 111 | for i in range(len(questions)): 112 | examples.append({ 113 | 'id': i, 114 | 'question': q_tokens[i].words(), 115 | 'qlemma': q_tokens[i].lemmas(), 116 | 'document': d_tokens[i].words(), 117 | 'lemma': d_tokens[i].lemmas(), 118 | 'pos': d_tokens[i].pos(), 119 | 'ner': d_tokens[i].entities(), 120 | }) 121 | 122 | # Stick document tokens in candidates for decoding 123 | if candidates: 124 | candidates = [{'input': d_tokens[i], 'cands': candidates[i]} 125 | for i in range(len(candidates))] 126 | 127 | # Build the batch and run it through the model 128 | batch_exs = batchify([vectorize(e, self.model) for e in examples]) 129 | s, e, score = self.model.predict(batch_exs, candidates, top_n) 130 | 131 | # Retrieve the predicted spans 132 | results = [] 133 | for i in range(len(s)): 134 | predictions = [] 135 | for j in range(len(s[i])): 136 | span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize() 137 | predictions.append((span, score[i][j].item())) 138 | results.append(predictions) 139 | return results 140 | 141 | def cuda(self): 142 | self.model.cuda() 143 | 144 | def cpu(self): 145 | self.model.cpu() 146 | -------------------------------------------------------------------------------- /drqa/reader/rnn_reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Implementation of the RNN based DrQA reader.""" 8 | 9 | import torch 10 | import torch.nn as nn 11 | from . 
import layers 12 | 13 | 14 | # ------------------------------------------------------------------------------ 15 | # Network 16 | # ------------------------------------------------------------------------------ 17 | 18 | 19 | class RnnDocReader(nn.Module): 20 | RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} 21 | 22 | def __init__(self, args, normalize=True): 23 | super(RnnDocReader, self).__init__() 24 | # Store config 25 | self.args = args 26 | 27 | # Word embeddings (+1 for padding) 28 | self.embedding = nn.Embedding(args.vocab_size, 29 | args.embedding_dim, 30 | padding_idx=0) 31 | 32 | # Projection for attention weighted question 33 | if args.use_qemb: 34 | self.qemb_match = layers.SeqAttnMatch(args.embedding_dim) 35 | 36 | # Input size to RNN: word emb + question emb + manual features 37 | doc_input_size = args.embedding_dim + args.num_features 38 | if args.use_qemb: 39 | doc_input_size += args.embedding_dim 40 | 41 | # RNN document encoder 42 | self.doc_rnn = layers.StackedBRNN( 43 | input_size=doc_input_size, 44 | hidden_size=args.hidden_size, 45 | num_layers=args.doc_layers, 46 | dropout_rate=args.dropout_rnn, 47 | dropout_output=args.dropout_rnn_output, 48 | concat_layers=args.concat_rnn_layers, 49 | rnn_type=self.RNN_TYPES[args.rnn_type], 50 | padding=args.rnn_padding, 51 | ) 52 | 53 | # RNN question encoder 54 | self.question_rnn = layers.StackedBRNN( 55 | input_size=args.embedding_dim, 56 | hidden_size=args.hidden_size, 57 | num_layers=args.question_layers, 58 | dropout_rate=args.dropout_rnn, 59 | dropout_output=args.dropout_rnn_output, 60 | concat_layers=args.concat_rnn_layers, 61 | rnn_type=self.RNN_TYPES[args.rnn_type], 62 | padding=args.rnn_padding, 63 | ) 64 | 65 | # Output sizes of rnn encoders 66 | doc_hidden_size = 2 * args.hidden_size 67 | question_hidden_size = 2 * args.hidden_size 68 | if args.concat_rnn_layers: 69 | doc_hidden_size *= args.doc_layers 70 | question_hidden_size *= args.question_layers 71 | 72 | # Question merging 73 | if args.question_merge not in ['avg', 'self_attn']: 74 | raise NotImplementedError('merge_mode = %s' % args.merge_mode) 75 | if args.question_merge == 'self_attn': 76 | self.self_attn = layers.LinearSeqAttn(question_hidden_size) 77 | 78 | # Bilinear attention for span start/end 79 | self.start_attn = layers.BilinearSeqAttn( 80 | doc_hidden_size, 81 | question_hidden_size, 82 | normalize=normalize, 83 | ) 84 | self.end_attn = layers.BilinearSeqAttn( 85 | doc_hidden_size, 86 | question_hidden_size, 87 | normalize=normalize, 88 | ) 89 | 90 | def forward(self, x1, x1_f, x1_mask, x2, x2_mask): 91 | """Inputs: 92 | x1 = document word indices [batch * len_d] 93 | x1_f = document word features indices [batch * len_d * nfeat] 94 | x1_mask = document padding mask [batch * len_d] 95 | x2 = question word indices [batch * len_q] 96 | x2_mask = question padding mask [batch * len_q] 97 | """ 98 | # Embed both document and question 99 | x1_emb = self.embedding(x1) 100 | x2_emb = self.embedding(x2) 101 | 102 | # Dropout on embeddings 103 | if self.args.dropout_emb > 0: 104 | x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb, 105 | training=self.training) 106 | x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb, 107 | training=self.training) 108 | 109 | # Form document encoding inputs 110 | drnn_input = [x1_emb] 111 | 112 | # Add attention-weighted question representation 113 | if self.args.use_qemb: 114 | x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask) 115 | drnn_input.append(x2_weighted_emb) 116 | 117 | # 
Add manual features 118 | if self.args.num_features > 0: 119 | drnn_input.append(x1_f) 120 | 121 | # Encode document with RNN 122 | doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask) 123 | 124 | # Encode question with RNN + merge hiddens 125 | question_hiddens = self.question_rnn(x2_emb, x2_mask) 126 | if self.args.question_merge == 'avg': 127 | q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask) 128 | elif self.args.question_merge == 'self_attn': 129 | q_merge_weights = self.self_attn(question_hiddens, x2_mask) 130 | question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights) 131 | 132 | # Predict start and end positions 133 | start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask) 134 | end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask) 135 | return start_scores, end_scores 136 | -------------------------------------------------------------------------------- /drqa/reader/vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Functions for putting examples into torch format.""" 8 | 9 | from collections import Counter 10 | import torch 11 | 12 | 13 | def vectorize(ex, model, single_answer=False): 14 | """Torchify a single example.""" 15 | args = model.args 16 | word_dict = model.word_dict 17 | feature_dict = model.feature_dict 18 | 19 | # Index words 20 | document = torch.LongTensor([word_dict[w] for w in ex['document']]) 21 | question = torch.LongTensor([word_dict[w] for w in ex['question']]) 22 | 23 | # Create extra features vector 24 | if len(feature_dict) > 0: 25 | features = torch.zeros(len(ex['document']), len(feature_dict)) 26 | else: 27 | features = None 28 | 29 | # f_{exact_match} 30 | if args.use_in_question: 31 | q_words_cased = {w for w in ex['question']} 32 | q_words_uncased = {w.lower() for w in ex['question']} 33 | q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None 34 | for i in range(len(ex['document'])): 35 | if ex['document'][i] in q_words_cased: 36 | features[i][feature_dict['in_question']] = 1.0 37 | if ex['document'][i].lower() in q_words_uncased: 38 | features[i][feature_dict['in_question_uncased']] = 1.0 39 | if q_lemma and ex['lemma'][i] in q_lemma: 40 | features[i][feature_dict['in_question_lemma']] = 1.0 41 | 42 | # f_{token} (POS) 43 | if args.use_pos: 44 | for i, w in enumerate(ex['pos']): 45 | f = 'pos=%s' % w 46 | if f in feature_dict: 47 | features[i][feature_dict[f]] = 1.0 48 | 49 | # f_{token} (NER) 50 | if args.use_ner: 51 | for i, w in enumerate(ex['ner']): 52 | f = 'ner=%s' % w 53 | if f in feature_dict: 54 | features[i][feature_dict[f]] = 1.0 55 | 56 | # f_{token} (TF) 57 | if args.use_tf: 58 | counter = Counter([w.lower() for w in ex['document']]) 59 | l = len(ex['document']) 60 | for i, w in enumerate(ex['document']): 61 | features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l 62 | 63 | # Maybe return without target 64 | if 'answers' not in ex: 65 | return document, features, question, ex['id'] 66 | 67 | # ...or with target(s) (might still be empty if answers is empty) 68 | if single_answer: 69 | assert(len(ex['answers']) > 0) 70 | start = torch.LongTensor(1).fill_(ex['answers'][0][0]) 71 | end = torch.LongTensor(1).fill_(ex['answers'][0][1]) 72 | else: 73 | start = [a[0] for a in 
ex['answers']] 74 | end = [a[1] for a in ex['answers']] 75 | 76 | return document, features, question, start, end, ex['id'] 77 | 78 | 79 | def batchify(batch): 80 | """Gather a batch of individual examples into one batch.""" 81 | NUM_INPUTS = 3 82 | NUM_TARGETS = 2 83 | NUM_EXTRA = 1 84 | 85 | ids = [ex[-1] for ex in batch] 86 | docs = [ex[0] for ex in batch] 87 | features = [ex[1] for ex in batch] 88 | questions = [ex[2] for ex in batch] 89 | 90 | # Batch documents and features 91 | max_length = max([d.size(0) for d in docs]) 92 | x1 = torch.LongTensor(len(docs), max_length).zero_() 93 | x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1) 94 | if features[0] is None: 95 | x1_f = None 96 | else: 97 | x1_f = torch.zeros(len(docs), max_length, features[0].size(1)) 98 | for i, d in enumerate(docs): 99 | x1[i, :d.size(0)].copy_(d) 100 | x1_mask[i, :d.size(0)].fill_(0) 101 | if x1_f is not None: 102 | x1_f[i, :d.size(0)].copy_(features[i]) 103 | 104 | # Batch questions 105 | max_length = max([q.size(0) for q in questions]) 106 | x2 = torch.LongTensor(len(questions), max_length).zero_() 107 | x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1) 108 | for i, q in enumerate(questions): 109 | x2[i, :q.size(0)].copy_(q) 110 | x2_mask[i, :q.size(0)].fill_(0) 111 | 112 | # Maybe return without targets 113 | if len(batch[0]) == NUM_INPUTS + NUM_EXTRA: 114 | return x1, x1_f, x1_mask, x2, x2_mask, ids 115 | 116 | elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS: 117 | # ...Otherwise add targets 118 | if torch.is_tensor(batch[0][3]): 119 | y_s = torch.cat([ex[3] for ex in batch]) 120 | y_e = torch.cat([ex[4] for ex in batch]) 121 | else: 122 | y_s = [ex[3] for ex in batch] 123 | y_e = [ex[4] for ex in batch] 124 | else: 125 | raise RuntimeError('Incorrect number of inputs per example.') 126 | 127 | return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids 128 | -------------------------------------------------------------------------------- /drqa/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from .. import DATA_DIR 10 | 11 | DEFAULTS = { 12 | 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'), 13 | 'tfidf_path': os.path.join( 14 | DATA_DIR, 15 | 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz' 16 | ), 17 | 'elastic_url': 'localhost:9200' 18 | } 19 | 20 | 21 | def set_default(key, value): 22 | global DEFAULTS 23 | DEFAULTS[key] = value 24 | 25 | 26 | def get_class(name): 27 | if name == 'tfidf': 28 | return TfidfDocRanker 29 | if name == 'sqlite': 30 | return DocDB 31 | if name == 'elasticsearch': 32 | return ElasticDocRanker 33 | raise RuntimeError('Invalid retriever class: %s' % name) 34 | 35 | 36 | from .doc_db import DocDB 37 | from .tfidf_doc_ranker import TfidfDocRanker 38 | from .elastic_doc_ranker import ElasticDocRanker 39 | -------------------------------------------------------------------------------- /drqa/retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | from . import DEFAULTS 12 | 13 | 14 | class DocDB(object): 15 | """Sqlite backed document storage. 16 | 17 | Implements get_doc_text(doc_id). 18 | """ 19 | 20 | def __init__(self, db_path=None): 21 | self.path = db_path or DEFAULTS['db_path'] 22 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 23 | 24 | def __enter__(self): 25 | return self 26 | 27 | def __exit__(self, *args): 28 | self.close() 29 | 30 | def path(self): 31 | """Return the path to the file that backs this database.""" 32 | return self.path 33 | 34 | def close(self): 35 | """Close the connection to the database.""" 36 | self.connection.close() 37 | 38 | def get_doc_ids(self): 39 | """Fetch all ids of docs stored in the db.""" 40 | cursor = self.connection.cursor() 41 | cursor.execute("SELECT id FROM documents") 42 | results = [r[0] for r in cursor.fetchall()] 43 | cursor.close() 44 | return results 45 | 46 | def get_doc_text(self, doc_id): 47 | """Fetch the raw text of the doc for 'doc_id'.""" 48 | cursor = self.connection.cursor() 49 | cursor.execute( 50 | "SELECT text FROM documents WHERE id = ?", 51 | (utils.normalize(doc_id),) 52 | ) 53 | result = cursor.fetchone() 54 | cursor.close() 55 | return result if result is None else result[0] 56 | -------------------------------------------------------------------------------- /drqa/retriever/elastic_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with an ElasticSearch index""" 8 | 9 | import logging 10 | import scipy.sparse as sp 11 | 12 | from multiprocessing.pool import ThreadPool 13 | from functools import partial 14 | from elasticsearch import Elasticsearch 15 | 16 | from . import utils 17 | from . import DEFAULTS 18 | from .. import tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class ElasticDocRanker(object): 24 | """ Connect to an ElasticSearch index. 
25 | Score pairs based on Elasticsearch 26 | """ 27 | 28 | def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None, elastic_field_doc_name=None, strict=True, elastic_field_content=None): 29 | """ 30 | Args: 31 | elastic_url: URL of the ElasticSearch server containing port 32 | elastic_index: Index name of ElasticSearch 33 | elastic_fields: Fields of the Elasticsearch index to search in 34 | elastic_field_doc_name: Field containing the name of the document (index) 35 | strict: fail on empty queries or continue (and return empty result) 36 | elastic_field_content: Field containing the content of document in plaint text 37 | """ 38 | # Load from disk 39 | elastic_url = elastic_url or DEFAULTS['elastic_url'] 40 | logger.info('Connecting to %s' % elastic_url) 41 | self.es = Elasticsearch(hosts=elastic_url) 42 | self.elastic_index = elastic_index 43 | self.elastic_fields = elastic_fields 44 | self.elastic_field_doc_name = elastic_field_doc_name 45 | self.elastic_field_content = elastic_field_content 46 | self.strict = strict 47 | 48 | # Elastic Ranker 49 | 50 | def get_doc_index(self, doc_id): 51 | """Convert doc_id --> doc_index""" 52 | field_index = self.elastic_field_doc_name 53 | if isinstance(field_index, list): 54 | field_index = '.'.join(field_index) 55 | result = self.es.search(index=self.elastic_index, body={'query':{'match': 56 | {field_index: doc_id}}}) 57 | return result['hits']['hits'][0]['_id'] 58 | 59 | 60 | def get_doc_id(self, doc_index): 61 | """Convert doc_index --> doc_id""" 62 | result = self.es.search(index=self.elastic_index, body={'query': { 'match': {"_id": doc_index}}}) 63 | source = result['hits']['hits'][0]['_source'] 64 | return utils.get_field(source, self.elastic_field_doc_name) 65 | 66 | def closest_docs(self, query, k=1): 67 | """Closest docs by using ElasticSearch 68 | """ 69 | results = self.es.search(index=self.elastic_index, body={'size':k ,'query': 70 | {'multi_match': { 71 | 'query': query, 72 | 'type': 'most_fields', 73 | 'fields': self.elastic_fields}}}) 74 | hits = results['hits']['hits'] 75 | doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name) for row in hits] 76 | doc_scores = [row['_score'] for row in hits] 77 | return doc_ids, doc_scores 78 | 79 | def batch_closest_docs(self, queries, k=1, num_workers=None): 80 | """Process a batch of closest_docs requests multithreaded. 81 | Note: we can use plain threads here as scipy is outside of the GIL. 
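Illustrative usage (the index and field names below are hypothetical and must match however the Elasticsearch index was actually built):

    ranker = ElasticDocRanker(elastic_index='enwiki_docs',
                              elastic_fields=['title', 'text'],
                              elastic_field_doc_name='title',
                              elastic_field_content='text')
    batch_results = ranker.batch_closest_docs(['who wrote hamlet',
                                               'capital of france'], k=5)
    # batch_results[i] is the (doc_ids, doc_scores) pair for queries[i].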
82 | """ 83 | with ThreadPool(num_workers) as threads: 84 | closest_docs = partial(self.closest_docs, k=k) 85 | results = threads.map(closest_docs, queries) 86 | return results 87 | 88 | # Elastic DB 89 | 90 | def __enter__(self): 91 | return self 92 | 93 | def close(self): 94 | """Close the connection to the database.""" 95 | self.es = None 96 | 97 | def get_doc_ids(self): 98 | """Fetch all ids of docs stored in the db.""" 99 | results = self.es.search(index= self.elastic_index, body={ 100 | "query": {"match_all": {}}}) 101 | doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name) for result in results['hits']['hits']] 102 | return doc_ids 103 | 104 | def get_doc_text(self, doc_id): 105 | """Fetch the raw text of the doc for 'doc_id'.""" 106 | idx = self.get_doc_index(doc_id) 107 | result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx) 108 | return result if result is None else result['_source'][self.elastic_field_content] 109 | 110 | -------------------------------------------------------------------------------- /drqa/retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | from . import utils 17 | from . import DEFAULTS 18 | from .. import tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TfidfDocRanker(object): 24 | """Loads a pre-weighted inverted index of token/document terms. 25 | Scores new queries by taking sparse dot products. 26 | """ 27 | 28 | def __init__(self, tfidf_path=None, strict=True): 29 | """ 30 | Args: 31 | tfidf_path: path to saved model file 32 | strict: fail on empty queries or continue (and return empty result) 33 | """ 34 | # Load from disk 35 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 36 | logger.info('Loading %s' % tfidf_path) 37 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 38 | self.doc_mat = matrix 39 | self.ngrams = metadata['ngram'] 40 | self.hash_size = metadata['hash_size'] 41 | self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() 42 | self.doc_freqs = metadata['doc_freqs'].squeeze() 43 | self.doc_dict = metadata['doc_dict'] 44 | self.num_docs = len(self.doc_dict[0]) 45 | self.strict = strict 46 | 47 | def get_doc_index(self, doc_id): 48 | """Convert doc_id --> doc_index""" 49 | return self.doc_dict[0][doc_id] 50 | 51 | def get_doc_id(self, doc_index): 52 | """Convert doc_index --> doc_id""" 53 | return self.doc_dict[1][doc_index] 54 | 55 | def closest_docs(self, query, k=1): 56 | """Closest docs by dot product between query and documents 57 | in tfidf weighted word vector space. 
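Illustrative usage (the query is arbitrary; with no tfidf_path argument the ranker loads DEFAULTS['tfidf_path']):

    ranker = TfidfDocRanker()
    doc_ids, doc_scores = ranker.closest_docs('who wrote hamlet', k=5)
    # doc_ids come from the index's doc_dict; doc_scores are the tfidf dot products.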
58 | """ 59 | spvec = self.text2spvec(query) 60 | res = spvec * self.doc_mat 61 | 62 | if len(res.data) <= k: 63 | o_sort = np.argsort(-res.data) 64 | else: 65 | o = np.argpartition(-res.data, k)[0:k] 66 | o_sort = o[np.argsort(-res.data[o])] 67 | 68 | doc_scores = res.data[o_sort] 69 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 70 | return doc_ids, doc_scores 71 | 72 | def batch_closest_docs(self, queries, k=1, num_workers=None): 73 | """Process a batch of closest_docs requests multithreaded. 74 | Note: we can use plain threads here as scipy is outside of the GIL. 75 | """ 76 | with ThreadPool(num_workers) as threads: 77 | closest_docs = partial(self.closest_docs, k=k) 78 | results = threads.map(closest_docs, queries) 79 | return results 80 | 81 | def parse(self, query): 82 | """Parse the query into tokens (either ngrams or tokens).""" 83 | tokens = self.tokenizer.tokenize(query) 84 | return tokens.ngrams(n=self.ngrams, uncased=True, 85 | filter_fn=utils.filter_ngram) 86 | 87 | def text2spvec(self, query): 88 | """Create a sparse tfidf-weighted word vector from query. 89 | 90 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 91 | """ 92 | # Get hashed ngrams 93 | words = self.parse(utils.normalize(query)) 94 | wids = [utils.hash(w, self.hash_size) for w in words] 95 | 96 | if len(wids) == 0: 97 | if self.strict: 98 | raise RuntimeError('No valid word in: %s' % query) 99 | else: 100 | logger.warning('No valid word in: %s' % query) 101 | return sp.csr_matrix((1, self.hash_size)) 102 | 103 | # Count TF 104 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 105 | tfs = np.log1p(wids_counts) 106 | 107 | # Count IDF 108 | Ns = self.doc_freqs[wids_unique] 109 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 110 | idfs[idfs < 0] = 0 111 | 112 | # TF-IDF 113 | data = np.multiply(tfs, idfs) 114 | 115 | # One row, sparse csr matrix 116 | indptr = np.array([0, len(wids_unique)]) 117 | spvec = sp.csr_matrix( 118 | (data, wids_unique, indptr), shape=(1, self.hash_size) 119 | ) 120 | 121 | return spvec 122 | -------------------------------------------------------------------------------- /drqa/retriever/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | import scipy.sparse as sp 13 | from sklearn.utils import murmurhash3_32 14 | 15 | 16 | # ------------------------------------------------------------------------------ 17 | # Sparse matrix saving/loading helpers. 
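# A minimal save/load round trip with the helpers below (the file name is
# hypothetical; save_sparse_csr simply delegates to np.savez):
#
#     mat = sp.csr_matrix(np.eye(3))
#     save_sparse_csr('tiny_index.npz', mat, metadata={'ngram': 2})
#     mat2, meta = load_sparse_csr('tiny_index.npz')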
18 | # ------------------------------------------------------------------------------ 19 | 20 | 21 | def save_sparse_csr(filename, matrix, metadata=None): 22 | data = { 23 | 'data': matrix.data, 24 | 'indices': matrix.indices, 25 | 'indptr': matrix.indptr, 26 | 'shape': matrix.shape, 27 | 'metadata': metadata, 28 | } 29 | np.savez(filename, **data) 30 | 31 | 32 | def load_sparse_csr(filename): 33 | loader = np.load(filename) 34 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 35 | loader['indptr']), shape=loader['shape']) 36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 37 | 38 | 39 | # ------------------------------------------------------------------------------ 40 | # Token hashing. 41 | # ------------------------------------------------------------------------------ 42 | 43 | 44 | def hash(token, num_buckets): 45 | """Unsigned 32 bit murmurhash for feature hashing.""" 46 | return murmurhash3_32(token, positive=True) % num_buckets 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Text cleaning. 51 | # ------------------------------------------------------------------------------ 52 | 53 | 54 | STOPWORDS = { 55 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 56 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 57 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 58 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 59 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 60 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 61 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 62 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 63 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 64 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 65 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 66 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 67 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 68 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 69 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 70 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 71 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 72 | } 73 | 74 | 75 | def normalize(text): 76 | """Resolve different type of unicode encodings.""" 77 | return unicodedata.normalize('NFD', text) 78 | 79 | 80 | def filter_word(text): 81 | """Take out english stopwords, punctuation, and compound endings.""" 82 | text = normalize(text) 83 | if regex.match(r'^\p{P}+$', text): 84 | return True 85 | if text.lower() in STOPWORDS: 86 | return True 87 | return False 88 | 89 | 90 | def filter_ngram(gram, mode='any'): 91 | """Decide whether to keep or discard an n-gram. 
92 | 93 | Args: 94 | gram: list of tokens (length N) 95 | mode: Option to throw out ngram if 96 | 'any': any single token passes filter_word 97 | 'all': all tokens pass filter_word 98 | 'ends': book-ended by filterable tokens 99 | """ 100 | filtered = [filter_word(w) for w in gram] 101 | if mode == 'any': 102 | return any(filtered) 103 | elif mode == 'all': 104 | return all(filtered) 105 | elif mode == 'ends': 106 | return filtered[0] or filtered[-1] 107 | else: 108 | raise ValueError('Invalid mode: %s' % mode) 109 | 110 | def get_field(d, field_list): 111 | """get the subfield associated to a list of elastic fields 112 | E.g. ['file', 'filename'] to d['file']['filename'] 113 | """ 114 | if isinstance(field_list, str): 115 | return d[field_list] 116 | else: 117 | idx = d.copy() 118 | for field in field_list: 119 | idx = idx[field] 120 | return idx 121 | -------------------------------------------------------------------------------- /drqa/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | 41 | raise RuntimeError('Invalid tokenizer: %s' % name) 42 | 43 | 44 | def get_annotators_for_args(args): 45 | annotators = set() 46 | if args.use_pos: 47 | annotators.add('pos') 48 | if args.use_lemma: 49 | annotators.add('lemma') 50 | if args.use_ner: 51 | annotators.add('ner') 52 | return annotators 53 | 54 | 55 | def get_annotators_for_model(model): 56 | return get_annotators_for_args(model.args) 57 | -------------------------------------------------------------------------------- /drqa/tokenizers/corenlp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Simple wrapper around the Stanford CoreNLP pipeline. 8 | 9 | Serves commands to a java subprocess running the jar. Requires java 8. 10 | """ 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | 19 | 20 | class CoreNLPTokenizer(Tokenizer): 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: set that can include pos, lemma, and ner. 
26 | classpath: Path to the corenlp directory of jars 27 | mem: Java heap memory 28 | """ 29 | self.classpath = (kwargs.get('classpath') or 30 | DEFAULTS['corenlp_classpath']) 31 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 32 | self.mem = kwargs.get('mem', '2g') 33 | self._launch() 34 | 35 | def _launch(self): 36 | """Start the CoreNLP jar with pexpect.""" 37 | annotators = ['tokenize', 'ssplit'] 38 | if 'ner' in self.annotators: 39 | annotators.extend(['pos', 'lemma', 'ner']) 40 | elif 'lemma' in self.annotators: 41 | annotators.extend(['pos', 'lemma']) 42 | elif 'pos' in self.annotators: 43 | annotators.extend(['pos']) 44 | annotators = ','.join(annotators) 45 | options = ','.join(['untokenizable=noneDelete', 46 | 'invertible=true']) 47 | cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, 48 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 49 | annotators, '-tokenize.options', options, 50 | '-outputFormat', 'json', '-prettyPrint', 'false'] 51 | 52 | # We use pexpect to keep the subprocess alive and feed it commands. 53 | # Because we don't want to get hit by the max terminal buffer size, 54 | # we turn off canonical input processing to have unlimited bytes. 55 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 56 | self.corenlp.setecho(False) 57 | self.corenlp.sendline('stty -icanon') 58 | self.corenlp.sendline(' '.join(cmd)) 59 | self.corenlp.delaybeforesend = 0 60 | self.corenlp.delayafterread = 0 61 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 62 | 63 | @staticmethod 64 | def _convert(token): 65 | if token == '-LRB-': 66 | return '(' 67 | if token == '-RRB-': 68 | return ')' 69 | if token == '-LSB-': 70 | return '[' 71 | if token == '-RSB-': 72 | return ']' 73 | if token == '-LCB-': 74 | return '{' 75 | if token == '-RCB-': 76 | return '}' 77 | return token 78 | 79 | def tokenize(self, text): 80 | # Since we're feeding text to the commandline, we're waiting on seeing 81 | # the NLP> prompt. Hacky! 82 | if 'NLP>' in text: 83 | raise RuntimeError('Bad token (NLP>) in text!') 84 | 85 | # Sending q will cause the process to quit -- manually override 86 | if text.lower().strip() == 'q': 87 | token = text.strip() 88 | index = text.index(token) 89 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 90 | return Tokens(data, self.annotators) 91 | 92 | # Minor cleanup before tokenizing. 
93 | clean_text = text.replace('\n', ' ')
94 |
95 | self.corenlp.sendline(clean_text.encode('utf-8'))
96 | self.corenlp.expect_exact('NLP>', searchwindowsize=100)
97 |
98 | # Skip to start of output (may have been stderr logging messages)
99 | output = self.corenlp.before
100 | start = output.find(b'{"sentences":')
101 | output = json.loads(output[start:].decode('utf-8'))
102 |
103 | data = []
104 | tokens = [t for s in output['sentences'] for t in s['tokens']]
105 | for i in range(len(tokens)):
106 | # Get whitespace
107 | start_ws = tokens[i]['characterOffsetBegin']
108 | if i + 1 < len(tokens):
109 | end_ws = tokens[i + 1]['characterOffsetBegin']
110 | else:
111 | end_ws = tokens[i]['characterOffsetEnd']
112 |
113 | data.append((
114 | self._convert(tokens[i]['word']),
115 | text[start_ws: end_ws],
116 | (tokens[i]['characterOffsetBegin'],
117 | tokens[i]['characterOffsetEnd']),
118 | tokens[i].get('pos', None),
119 | tokens[i].get('lemma', None),
120 | tokens[i].get('ner', None)
121 | ))
122 | return Tokens(data, self.annotators)
123 | -------------------------------------------------------------------------------- /drqa/tokenizers/regexp_tokenizer.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers.
8 |
9 | However it is purely in Python, supports robust untokenization, unicode,
10 | and requires minimal dependencies.
11 | """
12 |
13 | import regex
14 | import logging
15 | from .tokenizer import Tokens, Tokenizer
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class RegexpTokenizer(Tokenizer):
21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
23 | r'\.(?=\p{Z})')
24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
28 | CONTRACTION1 = r"can(?=not\b)"
29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
32 | END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
33 | END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
34 | DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
35 | ELLIPSES = r'\.\.\.|\u2026'
36 | PUNCT = r'\p{P}'
37 | NON_WS = r'[^\p{Z}\p{C}]'
38 |
39 | def __init__(self, **kwargs):
40 | """
41 | Args:
42 | annotators: None or empty set (only tokenizes).
43 | substitutions: if true, normalizes some special tokens (quotes, dashes, ellipses).
44 | """
45 | self._regexp = regex.compile(
46 | '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
47 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
48 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
49 | '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
50 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
51 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
52 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
53 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
54 | self.NON_WS),
55 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
56 | )
57 | if len(kwargs.get('annotators', {})) > 0:
58 | logger.warning('%s only tokenizes!
Skipping annotators: %s' % 59 | (type(self).__name__, kwargs.get('annotators'))) 60 | self.annotators = set() 61 | self.substitutions = kwargs.get('substitutions', True) 62 | 63 | def tokenize(self, text): 64 | data = [] 65 | matches = [m for m in self._regexp.finditer(text)] 66 | for i in range(len(matches)): 67 | # Get text 68 | token = matches[i].group() 69 | 70 | # Make normalizations for special token types 71 | if self.substitutions: 72 | groups = matches[i].groupdict() 73 | if groups['sdquote']: 74 | token = "``" 75 | elif groups['edquote']: 76 | token = "''" 77 | elif groups['ssquote']: 78 | token = "`" 79 | elif groups['esquote']: 80 | token = "'" 81 | elif groups['dash']: 82 | token = '--' 83 | elif groups['ellipses']: 84 | token = '...' 85 | 86 | # Get whitespace 87 | span = matches[i].span() 88 | start_ws = span[0] 89 | if i + 1 < len(matches): 90 | end_ws = matches[i + 1].span()[0] 91 | else: 92 | end_ws = span[1] 93 | 94 | # Format data 95 | data.append(( 96 | token, 97 | text[start_ws: end_ws], 98 | span, 99 | )) 100 | return Tokens(data, self.annotators) 101 | -------------------------------------------------------------------------------- /drqa/tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /drqa/tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Tokenizer that is backed by spaCy (spacy.io). 8 | 9 | Requires spaCy package and the spaCy english model. 
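Illustrative usage (assumes the English model is installed and available under this tokenizer's default name 'en'):

    tok = SpacyTokenizer(annotators={'pos', 'lemma', 'ner'})
    tokens = tok.tokenize('AISO answers open-domain questions.')
    words, tags, lemmas = tokens.words(), tokens.pos(), tokens.lemmas()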
10 | """ 11 | 12 | import spacy 13 | import copy 14 | from .tokenizer import Tokens, Tokenizer 15 | 16 | 17 | class SpacyTokenizer(Tokenizer): 18 | 19 | def __init__(self, **kwargs): 20 | """ 21 | Args: 22 | annotators: set that can include pos, lemma, and ner. 23 | model: spaCy model to use (either path, or keyword like 'en'). 24 | """ 25 | model = kwargs.get('model', 'en') 26 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 27 | nlp_kwargs = {'parser': False} 28 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 29 | nlp_kwargs['tagger'] = False 30 | if 'ner' not in self.annotators: 31 | nlp_kwargs['entity'] = False 32 | self.nlp = spacy.load(model, **nlp_kwargs) 33 | 34 | def tokenize(self, text): 35 | # We don't treat new lines as tokens. 36 | clean_text = text.replace('\n', ' ') 37 | tokens = self.nlp.tokenizer(clean_text) 38 | if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 39 | self.nlp.tagger(tokens) 40 | if 'ner' in self.annotators: 41 | self.nlp.entity(tokens) 42 | 43 | data = [] 44 | for i in range(len(tokens)): 45 | # Get whitespace 46 | start_ws = tokens[i].idx 47 | if i + 1 < len(tokens): 48 | end_ws = tokens[i + 1].idx 49 | else: 50 | end_ws = tokens[i].idx + len(tokens[i].text) 51 | 52 | data.append(( 53 | tokens[i].text, 54 | text[start_ws: end_ws], 55 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 56 | tokens[i].tag_, 57 | tokens[i].lemma_, 58 | tokens[i].ent_type_, 59 | )) 60 | 61 | # Set special option for non-entity tag: '' vs 'O' in spaCy 62 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 63 | -------------------------------------------------------------------------------- /drqa/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | class Tokens(object): 13 | """A class to represent a list of tokenized text.""" 14 | TEXT = 0 15 | TEXT_WS = 1 16 | SPAN = 2 17 | POS = 3 18 | LEMMA = 4 19 | NER = 5 20 | 21 | def __init__(self, data, annotators, opts=None): 22 | self.data = data 23 | self.annotators = annotators 24 | self.opts = opts or {} 25 | 26 | def __len__(self): 27 | """The number of tokens.""" 28 | return len(self.data) 29 | 30 | def slice(self, i=None, j=None): 31 | """Return a view of the list of tokens from [i, j).""" 32 | new_tokens = copy.copy(self) 33 | new_tokens.data = self.data[i: j] 34 | return new_tokens 35 | 36 | def untokenize(self): 37 | """Returns the original text (with whitespace reinserted).""" 38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 39 | 40 | def words(self, uncased=False): 41 | """Returns a list of the text of each token 42 | 43 | Args: 44 | uncased: lower cases text 45 | """ 46 | if uncased: 47 | return [t[self.TEXT].lower() for t in self.data] 48 | else: 49 | return [t[self.TEXT] for t in self.data] 50 | 51 | def offsets(self): 52 | """Returns a list of [start, end) character offsets of each token.""" 53 | return [t[self.SPAN] for t in self.data] 54 | 55 | def pos(self): 56 | """Returns a list of part-of-speech tags of each token. 57 | Returns None if this annotation was not included. 
58 | """ 59 | if 'pos' not in self.annotators: 60 | return None 61 | return [t[self.POS] for t in self.data] 62 | 63 | def lemmas(self): 64 | """Returns a list of the lemmatized text of each token. 65 | Returns None if this annotation was not included. 66 | """ 67 | if 'lemma' not in self.annotators: 68 | return None 69 | return [t[self.LEMMA] for t in self.data] 70 | 71 | def entities(self): 72 | """Returns a list of named-entity-recognition tags of each token. 73 | Returns None if this annotation was not included. 74 | """ 75 | if 'ner' not in self.annotators: 76 | return None 77 | return [t[self.NER] for t in self.data] 78 | 79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 80 | """Returns a list of all ngrams from length 1 to n. 81 | 82 | Args: 83 | n: upper limit of ngram length 84 | uncased: lower cases text 85 | filter_fn: user function that takes in an ngram list and returns 86 | True or False to keep or not keep the ngram 87 | as_string: return the ngram as a string vs list 88 | """ 89 | def _skip(gram): 90 | if not filter_fn: 91 | return False 92 | return filter_fn(gram) 93 | 94 | words = self.words(uncased) 95 | ngrams = [(s, e + 1) 96 | for s in range(len(words)) 97 | for e in range(s, min(s + n, len(words))) 98 | if not _skip(words[s:e + 1])] 99 | 100 | # Concatenate into strings 101 | if as_strings: 102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 103 | 104 | return ngrams 105 | 106 | def entity_groups(self): 107 | """Group consecutive entity tokens with the same NER tag.""" 108 | entities = self.entities() 109 | if not entities: 110 | return None 111 | non_ent = self.opts.get('non_ent', 'O') 112 | groups = [] 113 | idx = 0 114 | while idx < len(entities): 115 | ner_tag = entities[idx] 116 | # Check for entity tag 117 | if ner_tag != non_ent: 118 | # Chomp the sequence 119 | start = idx 120 | while (idx < len(entities) and entities[idx] == ner_tag): 121 | idx += 1 122 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 123 | else: 124 | idx += 1 125 | return groups 126 | 127 | 128 | class Tokenizer(object): 129 | """Base tokenizer class. 130 | Tokenizers implement tokenize, which should return a Tokens class. 
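A minimal subclass sketch (whitespace splitting only, for illustration; the real tokenizers in this package also fill the POS/lemma/NER slots):

    class WhitespaceTokenizer(Tokenizer):
        def tokenize(self, text):
            data, search_from = [], 0
            for word in text.split():
                start = text.index(word, search_from)
                end = start + len(word)
                # (TEXT, TEXT_WS, SPAN); TEXT_WS keeps at most one trailing space.
                data.append((word, text[start:end + 1], (start, end)))
                search_from = end
            return Tokens(data, annotators=set())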
131 | """ 132 | def tokenize(self, text): 133 | raise NotImplementedError 134 | 135 | def shutdown(self): 136 | pass 137 | 138 | def __del__(self): 139 | self.shutdown() 140 | -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/env/__init__.py -------------------------------------------------------------------------------- /env/client.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from requests import delete, get, post 3 | 4 | from env.core import BaseEnv 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class Env(BaseEnv): 10 | 11 | def __init__(self, server='10.60.1.79:17101'): 12 | self.server = server 13 | self._corpus = dict() 14 | self._title2id = dict() 15 | 16 | def reset(self): 17 | delete(f'http://{self.server}/states') 18 | 19 | def get(self, p_id): 20 | if p_id not in self._corpus: 21 | self._corpus[p_id] = get(f'http://{self.server}/passages/{p_id}').json() 22 | return self._corpus[p_id] 23 | 24 | def title2id(self, norm_title): 25 | if norm_title not in self._title2id: 26 | self._title2id[norm_title] = post(f'http://{self.server}/title2id', json={"norm_title": norm_title}).json() 27 | return self._title2id[norm_title] 28 | 29 | def step(self, command, session_id=None, exclusion=None): 30 | args = {"command": command, "session_id": session_id, "exclusion": exclusion} 31 | return post(f'http://{self.server}/executions', json=args).json() 32 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/ir.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict, Tuple 3 | 4 | import pytrec_eval 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def evaluate_trec(qrels: Dict[str, Dict[str, int]], results: Dict[str, Dict[str, float]], 10 | k_values: List[int]) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]: 11 | ndcg, _map, recall, precision = dict(), dict(), dict(), dict() 12 | for k in k_values: 13 | ndcg[f"NDCG@{k}"] = 0.0 14 | _map[f"MAP@{k}"] = 0.0 15 | recall[f"Recall@{k}"] = 0.0 16 | precision[f"P@{k}"] = 0.0 17 | 18 | map_string = "map_cut." + ",".join([str(k) for k in k_values]) 19 | ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) 20 | recall_string = "recall." + ",".join([str(k) for k in k_values]) 21 | precision_string = "P." 
+ ",".join([str(k) for k in k_values]) 22 | evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string}) 23 | scores = evaluator.evaluate(results) 24 | 25 | for q_id in scores.keys(): 26 | for k in k_values: 27 | ndcg[f"NDCG@{k}"] += scores[q_id][f"ndcg_cut_{k}"] 28 | _map[f"MAP@{k}"] += scores[q_id][f"map_cut_{k}"] 29 | recall[f"Recall@{k}"] += scores[q_id][f"recall_{k}"] 30 | precision[f"P@{k}"] += scores[q_id][f"P_{k}"] 31 | 32 | for k in k_values: 33 | ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(qrels) * 100., 3) 34 | _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(qrels) * 100., 3) 35 | recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(qrels) * 100., 3) 36 | precision[f"P@{k}"] = round(precision[f"P@{k}"] / len(qrels) * 100., 3) 37 | 38 | return ndcg, _map, recall, precision 39 | 40 | 41 | def evaluate_custom(qrels: Dict[str, Dict[str, int]], results: Dict[str, Dict[str, float]], 42 | k_values: List[int], metric: str) -> Tuple[Dict[str, float]]: 43 | if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]: 44 | return mrr(qrels, results, k_values) 45 | 46 | elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]: 47 | return recall_cap(qrels, results, k_values) 48 | 49 | elif metric.lower() in ["hole", "hole@k"]: 50 | return hole(qrels, results, k_values) 51 | 52 | 53 | def mrr(qrels: Dict[str, Dict[str, int]], results: Dict[str, Dict[str, float]], 54 | k_values: List[int]) -> Tuple[Dict[str, float]]: 55 | measures = dict() 56 | for k in k_values: 57 | measures[f"MRR@{k}"] = 0.0 58 | 59 | k_max, top_hits = max(k_values), dict() 60 | for q_id, hits in results.items(): 61 | top_hits[q_id] = sorted(hits.items(), key=lambda item: item[1], reverse=True)[0:k_max] 62 | 63 | for q_id in set(qrels) & set(top_hits): 64 | q_relevant_paras = set([p_id for p_id in qrels[q_id] if qrels[q_id][p_id] > 0]) 65 | for k in k_values: 66 | for rank, hit in enumerate(top_hits[q_id][0:k]): 67 | if hit[0] in q_relevant_paras: 68 | measures[f"MRR@{k}"] += 1.0 / (rank + 1) 69 | break 70 | 71 | for k in k_values: 72 | measures[f"MRR@{k}"] = round(measures[f"MRR@{k}"] / len(qrels) * 100., 3) 73 | 74 | return measures 75 | 76 | 77 | def recall_cap(qrels: Dict[str, Dict[str, int]], results: Dict[str, Dict[str, float]], 78 | k_values: List[int]) -> Tuple[Dict[str, float]]: 79 | measures = dict() 80 | for k in k_values: 81 | measures[f"R_cap@{k}"] = 0.0 82 | 83 | k_max = max(k_values) 84 | for query_id, doc_scores in results.items(): 85 | top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 86 | query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0] 87 | for k in k_values: 88 | retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0] 89 | denominator = min(len(query_relevant_docs), k) 90 | measures[f"R_cap@{k}"] += (len(retrieved_docs) / denominator) 91 | 92 | for k in k_values: 93 | measures[f"R_cap@{k}"] = round(measures[f"R_cap@{k}"] / len(results) * 100., 3) 94 | 95 | return measures 96 | 97 | 98 | def hole(qrels: Dict[str, Dict[str, int]], results: Dict[str, Dict[str, float]], 99 | k_values: List[int]) -> Tuple[Dict[str, float]]: 100 | measures = {} 101 | for k in k_values: 102 | measures[f"Hole@{k}"] = 0.0 103 | 104 | annotated_corpus = set() 105 | for _, docs in qrels.items(): 106 | for doc_id, score in docs.items(): 107 | annotated_corpus.add(doc_id) 108 | 109 | k_max = max(k_values) 110 | for _, scores in results.items(): 111 | top_hits = 
sorted(scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 112 | for k in k_values: 113 | hole_docs = [row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus] 114 | measures[f"Hole@{k}"] += len(hole_docs) / k 115 | 116 | for k in k_values: 117 | measures[f"Hole@{k}"] = round(measures[f"Hole@{k}"] / len(results) * 100., 3) 118 | 119 | return measures 120 | -------------------------------------------------------------------------------- /figs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/figs/demo.gif -------------------------------------------------------------------------------- /figs/simple_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/figs/simple_demo.png -------------------------------------------------------------------------------- /generate_dense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line tool that produces embeddings for a large documents base based on the pretrained ctx & question encoders 10 | Supposed to be used in a 'sharded' way to speed up the process. 11 | 12 | export CUDA_VISIBLE_DEVICES=0,1,2,3 13 | 14 | python generate_dense.py \ 15 | --model_file ckpts/dpr/retriever/multiset/bert-base-encoder.cp \ 16 | --ctx_file data/corpus/enwiki-20171001-paragraph-5.tsv \ 17 | --out_file data/vector/dpr/enwiki-20171001-paragraph-5 \ 18 | --batch_size 256 \ 19 | --num_shards 8 \ 20 | --shard_id 0 21 | 22 | """ 23 | import os 24 | import pathlib 25 | 26 | import argparse 27 | import logging 28 | import pickle 29 | from tqdm import trange 30 | from typing import List, Tuple 31 | 32 | import numpy as np 33 | import torch 34 | from torch import nn 35 | 36 | from dpr.models import init_biencoder_components 37 | from dpr.options import (add_encoder_params, setup_args_gpu, print_args, set_encoder_params_from_state, 38 | add_tokenizer_params, add_cuda_params) 39 | from dpr.utils.data_utils import Tensorizer 40 | from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint, move_to_device 41 | 42 | logger = logging.getLogger() 43 | logger.setLevel(logging.INFO) 44 | if logger.hasHandlers(): 45 | logger.handlers.clear() 46 | console = logging.StreamHandler() 47 | logger.addHandler(console) 48 | 49 | 50 | def gen_ctx_vectors(ctx_rows: List[Tuple[object, str, str]], 51 | model: nn.Module, 52 | tensorizer: Tensorizer, 53 | insert_title: bool = True) -> List[Tuple[object, np.array]]: 54 | n = len(ctx_rows) 55 | bsz = args.batch_size 56 | results = [] 57 | for batch_start in trange(0, n, bsz): 58 | ctx_ids = [] 59 | batch_token_tensors = [] 60 | for r in ctx_rows[batch_start:batch_start + bsz]: 61 | ctx_ids.append(r[0]) 62 | batch_token_tensors.append(tensorizer.text_to_tensor(r[1], title=r[2] if insert_title else None)) 63 | 64 | ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), args.device) 65 | ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), args.device) 66 | ctx_attn_mask = 
move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), args.device) 67 | with torch.no_grad(): 68 | _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask) 69 | out = out.cpu() 70 | 71 | assert len(ctx_ids) == out.size(0) 72 | 73 | results.extend([(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]) 74 | 75 | return results 76 | 77 | 78 | def main(): 79 | saved_state = load_states_from_checkpoint(args.model_file) 80 | set_encoder_params_from_state(saved_state.encoder_params, args) 81 | print_args(args) 82 | 83 | tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True) 84 | 85 | encoder = encoder.ctx_model 86 | 87 | encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu, 88 | args.local_rank, args.fp16, args.fp16_opt_level) 89 | encoder.eval() 90 | 91 | # load weights from the model file 92 | model_to_load = get_model_obj(encoder) 93 | logger.info('Loading saved model state ...') 94 | logger.debug('saved model keys =%s', saved_state.model_dict.keys()) 95 | 96 | prefix_len = len('ctx_model.') 97 | ctx_state = {key[prefix_len:]: value 98 | for (key, value) in saved_state.model_dict.items() if key.startswith('ctx_model.')} 99 | model_to_load.load_state_dict(ctx_state) 100 | 101 | logger.info(f'reading data from {args.ctx_file} ...') 102 | rows = [] 103 | with open(args.ctx_file) as tsv_file: 104 | # file format: doc_id, doc_text, title(, xx)* 105 | num_field = None 106 | for line in tsv_file: 107 | segs = line.strip().split('\t') 108 | pid, text, title = segs[:3] 109 | if pid != 'id': 110 | rows.append((pid, text, title)) 111 | else: 112 | num_field = len(segs) 113 | if len(segs) != num_field: 114 | logger.warning(f'Wrong line format: {pid}') 115 | 116 | shard_size = len(rows) // args.num_shards 117 | start_idx = args.shard_id * shard_size 118 | end_idx = start_idx + shard_size if args.shard_id != args.num_shards - 1 else len(rows) 119 | logger.info(f'Producing encodings for passages [{start_idx:,d}, {end_idx:,d}) ' 120 | f'({args.shard_id}/{args.num_shards} of {len(rows):,d})') 121 | rows = rows[start_idx:end_idx] 122 | 123 | data = gen_ctx_vectors(rows, encoder, tensorizer, True) 124 | 125 | file = args.out_file + '_' + str(args.shard_id) + '.pkl' 126 | pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True) 127 | logger.info(f'{len(data):,d} passages processed. 
Writing results to {file}') 128 | with open(file, mode='wb') as f: 129 | pickle.dump(data, f) 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser() 134 | 135 | add_encoder_params(parser) 136 | add_tokenizer_params(parser) 137 | add_cuda_params(parser) 138 | 139 | parser.add_argument('--ctx_file', type=str, default=None, help='Path to passages set .tsv file') 140 | parser.add_argument('--out_file', required=True, type=str, default=None, 141 | help='Output file path (prefix) to write results to') 142 | parser.add_argument('--shard_id', type=int, default=0, help="Number(0-based) of data shard to process") 143 | parser.add_argument('--num_shards', type=int, default=1, help="Total amount of data shards") 144 | parser.add_argument('--batch_size', type=int, default=32, help="Batch size for the passage encoder forward pass") 145 | args = parser.parse_args() 146 | 147 | assert args.model_file, 'Please specify --model_file checkpoint to init model weights' 148 | 149 | setup_args_gpu(args) 150 | 151 | main() 152 | -------------------------------------------------------------------------------- /install_corenlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -e 9 | 10 | # By default download to the data directory I guess 11 | read -p "Specify download path or enter to use default (corenlp): " path 12 | DOWNLOAD_PATH="${path:-corenlp}" 13 | echo "Will download to: $DOWNLOAD_PATH" 14 | 15 | # Download zip, unzip 16 | pushd "/tmp" 17 | wget -O "stanford-corenlp-full-2017-06-09.zip" "http://nlp.stanford.edu/software/stanford-corenlp-full-2017-06-09.zip" 18 | unzip "stanford-corenlp-full-2017-06-09.zip" 19 | rm "stanford-corenlp-full-2017-06-09.zip" 20 | popd 21 | 22 | # Put jars in DOWNLOAD_PATH 23 | mkdir -p "$DOWNLOAD_PATH" 24 | mv "/tmp/stanford-corenlp-full-2017-06-09/"*".jar" "$DOWNLOAD_PATH/" 25 | 26 | # Append to bashrc, instructions 27 | while read -p "Add to ~/.bashrc CLASSPATH (recommended)? [yes/no]: " choice; do 28 | case "$choice" in 29 | yes ) 30 | echo "export CLASSPATH=\$CLASSPATH:$DOWNLOAD_PATH/*" >> ~/.bashrc; 31 | break ;; 32 | no ) 33 | break ;; 34 | * ) echo "Please answer yes or no." ;; 35 | esac 36 | done 37 | 38 | printf "\n*** NOW RUN: ***\n\nexport CLASSPATH=\$CLASSPATH:$DOWNLOAD_PATH/*\n\n****************\n" 39 | -------------------------------------------------------------------------------- /mdr/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from . import qa 4 | from . 
import retrieval -------------------------------------------------------------------------------- /mdr/qa/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /mdr/qa/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ast import parse 3 | 4 | from typing import NamedTuple 5 | 6 | from torch.nn import parallel 7 | class ClusterConfig(NamedTuple): 8 | dist_backend: str 9 | dist_url: str 10 | 11 | def common_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | # task 15 | parser.add_argument("--train_file", type=str, 16 | default="../data/nq-with-neg-train.txt") 17 | parser.add_argument("--predict_file", type=str, 18 | default="../data/nq-with-neg-dev.txt") 19 | parser.add_argument("--num_workers", default=10, type=int) 20 | parser.add_argument("--do_train", default=False, 21 | action='store_true', help="Whether to run training.") 22 | parser.add_argument("--do_predict", default=False, 23 | action='store_true', help="Whether to run eval on the dev set.") 24 | parser.add_argument("--do_test", default=False, action="store_true", help="for final test submission") 25 | 26 | # model 27 | parser.add_argument("--model_name", 28 | default="bert-base-uncased", type=str) 29 | parser.add_argument("--init_checkpoint", type=str, 30 | help="Initial checkpoint (usually from a pre-trained BERT model).", 31 | default="") 32 | parser.add_argument("--max_seq_len", default=512, type=int, 33 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 34 | "longer than this will be truncated, and sequences shorter than this will be padded.") 35 | parser.add_argument("--max_q_len", default=64, type=int) 36 | parser.add_argument("--max_ans_len", default=35, type=int) 37 | parser.add_argument('--fp16', action='store_true') 38 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 39 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
40 | "See details at https://nvidia.github.io/apex/amp.html") 41 | parser.add_argument("--no_cuda", default=False, action='store_true', 42 | help="Whether not to use CUDA when available") 43 | parser.add_argument("--local_rank", type=int, default=-1, 44 | help="local_rank for distributed training on gpus") 45 | parser.add_argument("--predict_batch_size", default=256, 46 | type=int, help="Total batch size for predictions.") 47 | parser.add_argument("--save-prediction", default="", type=str) 48 | 49 | parser.add_argument("--sp-pred", action="store_true", help="whether to predict sentence sp") 50 | return parser 51 | 52 | def train_args(): 53 | parser = common_args() 54 | # optimization 55 | parser.add_argument('--prefix', type=str, default="eval") 56 | parser.add_argument("--weight_decay", default=0.0, type=float, 57 | help="Weight decay if we apply some.") 58 | parser.add_argument("--output_dir", default="./logs", type=str, 59 | help="The output directory where the model checkpoints will be written.") 60 | parser.add_argument("--train_batch_size", default=128, 61 | type=int, help="Total batch size for training.") 62 | parser.add_argument("--num_q_per_gpu", default=1) 63 | parser.add_argument("--learning_rate", default=1e-5, 64 | type=float, help="The initial learning rate for Adam.") 65 | parser.add_argument("--num_train_epochs", default=5, type=float, 66 | help="Total number of training epochs to perform.") 67 | parser.add_argument('--seed', type=int, default=3, 68 | help="random seed for initialization") 69 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 70 | help="Number of updates steps to accumualte before performing a backward/update pass.") 71 | parser.add_argument('--eval-period', type=int, default=2500) 72 | parser.add_argument("--max_grad_norm", default=2.0, type=float, help="Max gradient norm.") 73 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 74 | parser.add_argument("--neg-num", type=int, default=9, help="how many neg/distant passage chains to use") 75 | parser.add_argument("--shared-norm", action="store_true") 76 | parser.add_argument("--qa-drop", default=0, type=float) 77 | parser.add_argument("--rank-drop", default=0, type=float) 78 | parser.add_argument("--sp-drop", default=0, type=float) 79 | parser.add_argument("--final-metric", default="joint_f1") 80 | parser.add_argument("--use-adam", action="store_true", help="use adam or adamW") 81 | parser.add_argument("--warmup-ratio", default=0, type=float, help="Linear warmup over warmup_steps.") 82 | parser.add_argument("--sp-weight", default=0, type=float, help="weight of the sp loss") 83 | return parser.parse_args() 84 | -------------------------------------------------------------------------------- /mdr/qa/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | def explore(path): 7 | train = json.load(open(path)) 8 | 9 | neg_counts = [] 10 | for item in train: 11 | tfidf_neg = item["tfidf_neg"] 12 | linked_neg = item["linked_neg"] 13 | neg_counts.append(len(tfidf_neg + linked_neg)) 14 | 15 | import pdb; pdb.set_trace() 16 | return 17 | 18 | def load_corpus(corpus_path="/private/home/xwhan/data/hotpot/tfidf/abstracts.txt"): 19 | content = [json.loads(l) for l in open(corpus_path).readlines()] 20 | title2doc = {item["title"]:item["text"] for item in content} 21 | 22 | if __name__ == "__main__": 23 | 
explore("/private/home/xwhan/data/hotpot/hotpot_rerank_train_2_neg_types.json") -------------------------------------------------------------------------------- /mdr/qa/hotpot_evaluate_v1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ujson as json 3 | import re 4 | import string 5 | from collections import Counter 6 | import pickle 7 | 8 | def normalize_answer(s): 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def f1_score(prediction, ground_truth): 27 | normalized_prediction = normalize_answer(prediction) 28 | normalized_ground_truth = normalize_answer(ground_truth) 29 | 30 | ZERO_METRIC = (0, 0, 0) 31 | 32 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 33 | return ZERO_METRIC 34 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 35 | return ZERO_METRIC 36 | 37 | prediction_tokens = normalized_prediction.split() 38 | ground_truth_tokens = normalized_ground_truth.split() 39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 40 | num_same = sum(common.values()) 41 | if num_same == 0: 42 | return ZERO_METRIC 43 | precision = 1.0 * num_same / len(prediction_tokens) 44 | recall = 1.0 * num_same / len(ground_truth_tokens) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1, precision, recall 47 | 48 | 49 | def exact_match_score(prediction, ground_truth): 50 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 51 | 52 | def update_answer(metrics, prediction, gold): 53 | em = exact_match_score(prediction, gold) 54 | f1, prec, recall = f1_score(prediction, gold) 55 | metrics['em'] += float(em) 56 | metrics['f1'] += f1 57 | metrics['prec'] += prec 58 | metrics['recall'] += recall 59 | return em, prec, recall 60 | 61 | def update_sp(metrics, prediction, gold): 62 | cur_sp_pred = set(map(tuple, prediction)) 63 | gold_sp_pred = set(map(tuple, gold)) 64 | tp, fp, fn = 0, 0, 0 65 | for e in cur_sp_pred: 66 | if e in gold_sp_pred: 67 | tp += 1 68 | else: 69 | fp += 1 70 | for e in gold_sp_pred: 71 | if e not in cur_sp_pred: 72 | fn += 1 73 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 74 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 75 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 76 | em = 1.0 if fp + fn == 0 else 0.0 77 | metrics['sp_em'] += em 78 | metrics['sp_f1'] += f1 79 | metrics['sp_prec'] += prec 80 | metrics['sp_recall'] += recall 81 | return em, prec, recall 82 | 83 | def eval(prediction_file, gold_file): 84 | with open(prediction_file) as f: 85 | prediction = json.load(f) 86 | with open(gold_file) as f: 87 | gold = json.load(f) 88 | 89 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 90 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 91 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 92 | for dp in gold: 93 | cur_id = dp['_id'] 94 | can_eval_joint = True 95 | if cur_id not in prediction['answer']: 96 | print('missing answer {}'.format(cur_id)) 97 | can_eval_joint = False 98 | else: 99 | em, prec, 
recall = update_answer( 100 | metrics, prediction['answer'][cur_id], dp['answer']) 101 | if cur_id not in prediction['sp']: 102 | print('missing sp fact {}'.format(cur_id)) 103 | can_eval_joint = False 104 | else: 105 | sp_em, sp_prec, sp_recall = update_sp( 106 | metrics, prediction['sp'][cur_id], dp['supporting_facts']) 107 | 108 | if can_eval_joint: 109 | joint_prec = prec * sp_prec 110 | joint_recall = recall * sp_recall 111 | if joint_prec + joint_recall > 0: 112 | joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) 113 | else: 114 | joint_f1 = 0. 115 | joint_em = em * sp_em 116 | 117 | metrics['joint_em'] += joint_em 118 | metrics['joint_f1'] += joint_f1 119 | metrics['joint_prec'] += joint_prec 120 | metrics['joint_recall'] += joint_recall 121 | 122 | N = len(gold) 123 | for k in metrics.keys(): 124 | metrics[k] /= N 125 | 126 | print(metrics) 127 | 128 | if __name__ == '__main__': 129 | eval(sys.argv[1], sys.argv[2]) 130 | 131 | -------------------------------------------------------------------------------- /mdr/qa/qa_model.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import AutoModel, BertModel 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | class BertPooler(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 12 | self.activation = nn.Tanh() 13 | 14 | def forward(self, hidden_states): 15 | # We "pool" the model by simply taking the hidden state corresponding 16 | # to the first token. 17 | first_token_tensor = hidden_states[:, 0] 18 | pooled_output = self.dense(first_token_tensor) 19 | pooled_output = self.activation(pooled_output) 20 | return pooled_output 21 | 22 | class QAModel(nn.Module): 23 | 24 | def __init__(self, 25 | config, 26 | args 27 | ): 28 | super().__init__() 29 | self.model_name = args.model_name 30 | self.sp_weight = args.sp_weight 31 | self.sp_pred = args.sp_pred 32 | self.encoder = AutoModel.from_pretrained(args.model_name) 33 | 34 | if "electra" in args.model_name: 35 | self.pooler = BertPooler(config) 36 | 37 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 38 | self.rank = nn.Linear(config.hidden_size, 1) # noan 39 | 40 | if self.sp_pred: 41 | self.sp = nn.Linear(config.hidden_size, 1) 42 | self.loss_fct = CrossEntropyLoss(ignore_index=-1, reduction="none") 43 | 44 | def forward(self, batch): 45 | 46 | outputs = self.encoder(batch['input_ids'], batch['attention_mask'], batch.get('token_type_ids', None)) 47 | 48 | if "electra" in self.model_name: 49 | sequence_output = outputs[0] 50 | pooled_output = self.pooler(sequence_output) 51 | else: 52 | sequence_output, pooled_output = outputs[0], outputs[1] 53 | 54 | logits = self.qa_outputs(sequence_output) 55 | outs = [o.squeeze(-1) for o in logits.split(1, dim=-1)] 56 | outs = [o.float().masked_fill(batch["paragraph_mask"].ne(1), float("-inf")).type_as(o) for o in outs] 57 | 58 | start_logits, end_logits = outs[0], outs[1] 59 | rank_score = self.rank(pooled_output) 60 | 61 | if self.sp_pred: 62 | gather_index = batch["sent_offsets"].unsqueeze(2).expand(-1, -1, sequence_output.size()[-1]) 63 | sent_marker_rep = torch.gather(sequence_output, 1, gather_index) 64 | sp_score = self.sp(sent_marker_rep).squeeze(2) 65 | else: 66 | sp_score = None 67 | 68 | if self.training: 69 | 70 | rank_target = batch["label"] 71 | if self.sp_pred: 72 | sp_loss = 
F.binary_cross_entropy_with_logits(sp_score, batch["sent_labels"].float(), reduction="none") 73 | sp_loss = (sp_loss * batch["sent_offsets"]) * batch["label"] 74 | sp_loss = sp_loss.sum() 75 | 76 | start_positions, end_positions = batch["starts"], batch["ends"] 77 | 78 | rank_loss = F.binary_cross_entropy_with_logits(rank_score, rank_target.float(), reduction="sum") 79 | 80 | start_losses = [self.loss_fct(start_logits, starts) for starts in torch.unbind(start_positions, dim=1)] 81 | end_losses = [self.loss_fct(end_logits, ends) for ends in torch.unbind(end_positions, dim=1)] 82 | loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) 83 | log_prob = - loss_tensor 84 | log_prob = log_prob.float().masked_fill(log_prob == 0, float('-inf')).type_as(log_prob) 85 | marginal_probs = torch.sum(torch.exp(log_prob), dim=1) 86 | m_prob = [marginal_probs[idx] for idx in marginal_probs.nonzero()] 87 | if len(m_prob) == 0: 88 | span_loss = self.loss_fct(start_logits, start_logits.new_zeros( 89 | start_logits.size(0)).long()-1).sum() 90 | else: 91 | span_loss = - torch.log(torch.cat(m_prob)).sum() 92 | 93 | if self.sp_pred: 94 | loss = rank_loss + span_loss + sp_loss * self.sp_weight 95 | else: 96 | loss = rank_loss + span_loss 97 | return loss.unsqueeze(0) 98 | 99 | return { 100 | 'start_logits': start_logits, 101 | 'end_logits': end_logits, 102 | 'rank_score': rank_score, 103 | "sp_score": sp_score 104 | } 105 | -------------------------------------------------------------------------------- /mdr/qa/train.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \ 4 | --do_train \ 5 | --prefix qa_wwm_bert_title_mark_eval_debug \ 6 | --predict_batch_size 512 \ 7 | --model_name bert-large-uncased-whole-word-masking \ 8 | --train_batch_size 80 \ 9 | --learning_rate 3e-5 \ 10 | --fp16 \ 11 | --train_file /private/home/xwhan/data/hotpot/dense_train_b10_top20_outputs.json \ 12 | --predict_file /private/home/xwhan/data/hotpot/dense_val_outputs.json \ 13 | --seed 3 \ 14 | --eval-period 10 \ 15 | --max_seq_len 512 \ 16 | --max_q_len 100 \ 17 | --gradient_accumulation_steps 8 \ 18 | --neg-num 4 19 | 20 | 21 | # spanbert debug, fp16 does not work 22 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \ 23 | --do_train \ 24 | --prefix ranked_spanbert_debug \ 25 | --predict_batch_size 1024 \ 26 | --model_name spanbert \ 27 | --train_batch_size 48 \ 28 | --learning_rate 3e-5 \ 29 | --train_file /private/home/xwhan/data/hotpot/dense_train_b10_top20_outputs_sents.json \ 30 | --predict_file /private/home/xwhan/data/hotpot/dense_val_outputs_sents.json \ 31 | --seed 3 \ 32 | --eval-period 500 \ 33 | --max_seq_len 512 \ 34 | --max_q_len 64 \ 35 | --gradient_accumulation_steps 8 \ 36 | --neg-num 5 \ 37 | --use-adam 38 | 39 | # test electra 40 | CUDA_VISIBLE_DEVICES=0 python train_qa.py \ 41 | --do_train \ 42 | --prefix electra_large_debug_sn \ 43 | --predict_batch_size 1024 \ 44 | --model_name google/electra-large-discriminator \ 45 | --train_batch_size 12 \ 46 | --learning_rate 5e-5 \ 47 | --train_file /private/home/xwhan/data/hotpot/dense_train_b100_k100_sents.json \ 48 | --predict_file /private/home/xwhan/data/hotpot/dense_val_b30_k30_roberta_sents.json \ 49 | --seed 42 \ 50 | --eval-period 250 \ 51 | --max_seq_len 512 \ 52 | --max_q_len 64 \ 53 | --gradient_accumulation_steps 8 \ 54 | --neg-num 11 \ 55 | --fp16 \ 56 | --use-adam \ 57 | 
--warmup-ratio 0.1 \ 58 | --sp-weight 0.05 \ 59 | --sp-pred \ 60 | --shared-norm 61 | 62 | 63 | # QA evaluation 64 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \ 65 | --do_predict \ 66 | --predict_batch_size 2000 \ 67 | --model_name google/electra-large-discriminator \ 68 | --fp16 \ 69 | --predict_file /private/home/xwhan/data/hotpot/dense_val_b100_k100_roberta_best_sents.json \ 70 | --max_seq_len 512 \ 71 | --max_q_len 64 \ 72 | --init_checkpoint qa/logs/08-10-2020/electra_val_top30-epoch7-lr5e-05-seed42-rdrop0-qadrop0-decay0-qpergpu2-aggstep8-clip2-evalper250-evalbsize1024-negnum5-warmup0.1-adamTrue-spweight0.025/checkpoint_best.pt \ 73 | --sp-pred \ 74 | --max_ans_len 30 \ 75 | --save-prediction hotpot_val_top100.json 76 | 77 | # QA evaluation with wwm 78 | 79 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \ 80 | --do_predict \ 81 | --predict_batch_size 1024 \ 82 | --model_name bert-large-uncased-whole-word-masking \ 83 | --fp16 \ 84 | --predict_file /private/home/xwhan/data/hotpot/dense_hotpot_val_b250_k250_roberta_best_sents.json \ 85 | --max_seq_len 512 \ 86 | --max_q_len 64 \ 87 | --init_checkpoint qa/logs/08-17-2020/wwm_val_top50-epoch7-lr5e-05-seed42-rdrop0-qadrop0-decay0-qpergpu2-aggstep8-clip2-evalper250-evalbsize1024-negnum5-warmup0.2-adamTrue-spweight0.025-snFalse/checkpoint_best.pt \ 88 | --sp-pred \ 89 | --max_ans_len 30 \ 90 | --save-prediction hotpot_val_wwm_top250.json 91 | 92 | 93 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \ 94 | --do_predict \ 95 | --predict_batch_size 1024 \ 96 | --model_name google/electra-large-discriminator \ 97 | --fp16 \ 98 | --predict_file /private/home/xwhan/data/hotpot/dense_val_b50_k50_roberta_best_sents.json \ 99 | --max_seq_len 512 \ 100 | --max_q_len 64 \ 101 | --init_checkpoint qa/logs/08-10-2020/electra_val_top30-epoch7-lr5e-05-seed42-rdrop0-qadrop0-decay0-qpergpu2-aggstep8-clip2-evalper250-evalbsize1024-negnum5-warmup0.1-adamTrue-spweight0.025/checkpoint_best.pt \ 102 | --sp-pred \ 103 | --max_ans_len 30 \ 104 | --save-prediction hotpot_val_b5_k5.json \ 105 | 106 | srun --gres=gpu:8 --partition learnfair --time=48:00:00 --mem 500G --constraint volta32gb --cpus-per-task 80 --pty /bin/bash -l 107 | -------------------------------------------------------------------------------- /mdr/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from . import data 9 | from . import models 10 | from . 
import utils 11 | -------------------------------------------------------------------------------- /mdr/retrieval/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ast import parse 3 | from typing import NamedTuple 4 | 5 | 6 | class ClusterConfig(NamedTuple): 7 | dist_backend: str 8 | dist_url: str 9 | 10 | 11 | def common_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | # task 15 | parser.add_argument("--train_file", type=str, default="../data/nq-with-neg-train.txt") 16 | parser.add_argument("--predict_file", type=str, default="../data/nq-with-neg-dev.txt") 17 | parser.add_argument("--num_workers", default=20, type=int) 18 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training") 19 | parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set") 20 | 21 | # model 22 | parser.add_argument("--model_name", default="bert-base-uncased", type=str) 23 | parser.add_argument("--init_checkpoint", default="", type=str, 24 | help="Initial checkpoint (usually from a pre-trained BERT model).") 25 | parser.add_argument("--max_c_len", default=512, type=int, 26 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 27 | "longer than this will be truncated, and sequences shorter than this will be padded.") 28 | parser.add_argument("--max_q_len", default=50, type=int, 29 | help="The maximum number of tokens for the question. Questions longer than this will " 30 | "be truncated to this length.") 31 | parser.add_argument('--fp16', action='store_true') 32 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 33 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
34 | "See details at https://nvidia.github.io/apex/amp.html") 35 | parser.add_argument("--no_cuda", default=False, action='store_true', 36 | help="Whether not to use CUDA when available") 37 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") 38 | parser.add_argument("--max_q_sp_len", default=50, type=int) 39 | parser.add_argument("--sent_level", action="store_true") 40 | parser.add_argument("--rnn_retriever", action="store_true") 41 | parser.add_argument("--predict_batch_size", default=512, type=int, help="Batch size for prediction") 42 | parser.add_argument("--shared_encoder", action="store_true") 43 | 44 | # multi vector scheme 45 | parser.add_argument("--multi_vector", type=int, default=1) 46 | parser.add_argument("--scheme", type=str, default="none", help="how to get the multivector, layerwise or tokenwise") 47 | 48 | # momentum 49 | parser.add_argument("--momentum", action="store_true") 50 | parser.add_argument("--init_retriever", type=str, default="") 51 | parser.add_argument("--k", type=int, default=38400, help="memory bank size") 52 | parser.add_argument("--m", type=float, default=0.999, help="momentum") 53 | 54 | # NQ multi-hop trial 55 | parser.add_argument("--nq-multi", action="store_true", 56 | help="train the NQ retrieval model to recover from error cases") 57 | 58 | return parser 59 | 60 | 61 | def train_args(): 62 | parser = common_args() 63 | # optimization 64 | parser.add_argument('--prefix', type=str, default="eval") 65 | parser.add_argument("--weight_decay", default=0.0, type=float, 66 | help="Weight decay if we apply some.") 67 | parser.add_argument("--temperature", default=1, type=float) 68 | parser.add_argument("--output_dir", default="./logs", type=str, 69 | help="The output directory where the model checkpoints will be written.") 70 | parser.add_argument("--train_batch_size", default=128, 71 | type=int, help="Total batch size for training.") 72 | parser.add_argument("--learning_rate", default=1e-5, 73 | type=float, help="The initial learning rate for Adam.") 74 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 75 | help="Epsilon for Adam optimizer.") 76 | parser.add_argument("--num_train_epochs", default=50, type=float, 77 | help="Total number of training epochs to perform.") 78 | parser.add_argument("--save_checkpoints_steps", default=20000, type=int, 79 | help="How often to save the model checkpoint.") 80 | parser.add_argument("--iterations_per_loop", default=1000, type=int, 81 | help="How many steps to make in each estimator call.") 82 | parser.add_argument("--accumulate_gradients", type=int, default=1, 83 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 84 | parser.add_argument('--seed', type=int, default=3, 85 | help="random seed for initialization") 86 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 87 | help="Number of updates steps to accumulate before performing a backward/update pass.") 88 | parser.add_argument('--eval_period', type=int, default=2500) 89 | parser.add_argument("--max_grad_norm", default=2.0, type=float, help="Max gradient norm.") 90 | parser.add_argument("--stop_drop", default=0, type=float) 91 | parser.add_argument("--use_adam", action="store_true") 92 | parser.add_argument("--warmup_ratio", default=0, type=float, help="Linear warmup over warmup_steps.") 93 | 94 | return parser.parse_args() 95 | 96 | 97 | def encode_args(): 98 | parser = common_args() 99 | parser.add_argument('--corpus_file', 
required=True, type=str, default=None, help='Path to passages .tsv file') 100 | parser.add_argument("--strict", action="store_true", help="whether to strictly use original data of dataset") 101 | parser.add_argument('--embedding_prefix', required=True, type=str, default=None, 102 | help='Output path(prefix) to write embeddings to') 103 | parser.add_argument('--num_shards', type=int, default=1, help="Total amount of data shards") 104 | parser.add_argument('--shard_id', type=int, default=0, help="Number(0-based) of data shard to process") 105 | args = parser.parse_args() 106 | assert args.init_checkpoint, 'Please specify --init_checkpoint checkpoint to init model weights' 107 | 108 | return args 109 | -------------------------------------------------------------------------------- /mdr/retrieval/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mdr/retrieval/data/encode_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | from .data_utils import collate_tokens 5 | import unicodedata 6 | import re 7 | import os 8 | 9 | 10 | def normalize(text): 11 | """Resolve different type of unicode encodings.""" 12 | return unicodedata.normalize('NFD', text) 13 | 14 | 15 | def convert_brc(string): 16 | string = re.sub('-LRB-', '(', string) 17 | string = re.sub('-RRB-', ')', string) 18 | string = re.sub('-LSB-', '[', string) 19 | string = re.sub('-RSB-', ']', string) 20 | string = re.sub('-LCB-', '{', string) 21 | string = re.sub('-RCB-', '}', string) 22 | string = re.sub('-COLON-', ':', string) 23 | return string 24 | 25 | 26 | class EmDataset(Dataset): 27 | 28 | def __init__(self, tokenizer, data_path, max_c_len, strict=False): 29 | super().__init__() 30 | self.tokenizer = tokenizer 31 | self.max_c_len = max_c_len 32 | self.strict = strict 33 | print(f"Max sequence length: {self.max_c_len}") 34 | 35 | print(f"Loading data from {data_path} ...") 36 | if data_path.endswith("tsv"): 37 | self.data = [] 38 | with open(data_path) as tsv_file: 39 | num_field = None 40 | for line in tsv_file: 41 | segs = line.strip().split('\t') 42 | p_id, text, title = segs[:3] 43 | if p_id != 'id': 44 | p_id, text, title = p_id.strip(), text.strip(), title.strip() 45 | if self.strict: 46 | sentence_spans = [tuple(span) for span in eval(segs[-1])] 47 | text = text[sentence_spans[0][0]:sentence_spans[-1][1]] 48 | self.data.append({"p_id": p_id, "text": text, "title": title}) 49 | else: 50 | num_field = len(segs) 51 | if len(segs) != num_field: 52 | print(f'Wrong line format: {p_id}') 53 | elif "fever" in data_path: 54 | raw_data = [json.loads(line) for line in tqdm(open(data_path).readlines())] 55 | self.data = [] 56 | for obj in raw_data: 57 | self.data.append(obj) 58 | else: 59 | self.data = [json.loads(line) for line in open(data_path).readlines()] 60 | print(f"loaded {len(self.data)} passages") 61 | 62 | def __getitem__(self, index): 63 | sample = self.data[index] 64 | 65 | if "Roberta" in self.tokenizer.__class__.__name__ and sample["text"].strip() == "": 66 | print(f"empty passage: {sample['title']}") 67 | sample["text"] = sample["title"] 68 | # if sample["text"].endswith("."): 69 | # sample["text"] = sample["text"][:-1] 70 | 71 | para_codes = self.tokenizer.encode_plus(sample["title"].strip(), text_pair=sample['text'].strip(), 72 | truncation=True, 
max_length=self.max_c_len, return_tensors="pt") 73 | para_codes['p_id'] = sample['p_id'] 74 | 75 | return para_codes 76 | 77 | def __len__(self): 78 | return len(self.data) 79 | 80 | 81 | def em_collate(samples): 82 | if len(samples) == 0: 83 | return {} 84 | 85 | batch = { 86 | "p_id": [s['p_id'] for s in samples], 87 | 'input_ids': collate_tokens([s['input_ids'].view(-1) for s in samples], 0), 88 | 'input_mask': collate_tokens([s['attention_mask'].view(-1) for s in samples], 0), 89 | } 90 | 91 | if "token_type_ids" in samples[0]: 92 | batch["input_type_ids"] = collate_tokens([s['token_type_ids'].view(-1) for s in samples], 0) 93 | 94 | return batch 95 | -------------------------------------------------------------------------------- /mdr/retrieval/data/fever_dataset.py: -------------------------------------------------------------------------------- 1 | from torch import normal 2 | from torch.utils.data import Dataset 3 | import torch 4 | import json 5 | import random 6 | import unicodedata 7 | import re 8 | 9 | def normalize(text): 10 | """Resolve different type of unicode encodings.""" 11 | return unicodedata.normalize('NFD', text) 12 | 13 | def convert_brc(string): 14 | string = re.sub('-LRB-', '(', string) 15 | string = re.sub('-RRB-', ')', string) 16 | string = re.sub('-LSB-', '[', string) 17 | string = re.sub('-RSB-', ']', string) 18 | string = re.sub('-LCB-', '{', string) 19 | string = re.sub('-RCB-', '}', string) 20 | string = re.sub('-COLON-', ':', string) 21 | return string 22 | 23 | class FeverDataset(Dataset): 24 | 25 | def __init__(self, 26 | tokenizer, 27 | data_path, 28 | max_q_len, 29 | max_q_sp_len, 30 | max_c_len, 31 | train=False, 32 | ): 33 | super().__init__() 34 | self.tokenizer = tokenizer 35 | self.max_q_len = max_q_len 36 | self.max_c_len = max_c_len 37 | self.max_q_sp_len = max_q_sp_len 38 | self.train = train 39 | print(f"Loading data from {data_path}") 40 | self.data = [json.loads(line) for line in open(data_path).readlines()] 41 | print(f"Total sample count {len(self.data)}") 42 | 43 | def encode_para(self, para, max_len): 44 | para["title"] = normalize(para["title"]) 45 | # para["text"] = convert_brc(para["text"]) 46 | 47 | return self.tokenizer.encode_plus(para["title"].strip(), text_pair=para["text"].strip(), max_length=max_len, return_tensors="pt") 48 | 49 | def __getitem__(self, index): 50 | sample = self.data[index] 51 | question = sample["claim"] 52 | 53 | evidence_multi = [e for e in sample["evidence"] if len(set([p["title"] for p in e])) > 1] 54 | neg_paras = sample["tfidf_neg"] + sample["linked_neg"] 55 | 56 | if self.train: 57 | random.shuffle(evidence_multi) 58 | random.shuffle(neg_paras) 59 | start_para, bridge_para = evidence_multi[0][0], evidence_multi[0][1] 60 | 61 | start_para_codes = self.encode_para(start_para, self.max_c_len) 62 | bridge_para_codes = self.encode_para(bridge_para, self.max_c_len) 63 | neg_codes_1 = self.encode_para(neg_paras[0], self.max_c_len) 64 | neg_codes_2 = self.encode_para(neg_paras[1], self.max_c_len) 65 | 66 | q_sp_codes = self.tokenizer.encode_plus(question, text_pair=start_para["text"].strip(), max_length=self.max_q_sp_len, return_tensors="pt") 67 | q_codes = self.tokenizer.encode_plus(question, max_length=self.max_q_len, return_tensors="pt") 68 | 69 | return { 70 | "q_codes": q_codes, 71 | "q_sp_codes": q_sp_codes, 72 | "start_para_codes": start_para_codes, 73 | "bridge_para_codes": bridge_para_codes, 74 | "neg_codes_1": neg_codes_1, 75 | "neg_codes_2": neg_codes_2, 76 | } 77 | 78 | def __len__(self): 79 | 
return len(self.data) 80 | 81 | -------------------------------------------------------------------------------- /mdr/retrieval/data/mhop_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import random 4 | 5 | from .data_utils import collate_tokens 6 | 7 | 8 | class MhopDataset(Dataset): 9 | 10 | def __init__(self, tokenizer, data_path, max_q_len, max_q_sp_len, max_c_len, train=False): 11 | super().__init__() 12 | self.tokenizer = tokenizer 13 | self.max_q_len = max_q_len 14 | self.max_c_len = max_c_len 15 | self.max_q_sp_len = max_q_sp_len 16 | self.train = train 17 | print(f"Loading data from {data_path}") 18 | self.data = [json.loads(line) for line in open(data_path).readlines()] 19 | if train: 20 | 21 | # import pdb 22 | # pdb.set_trace() 23 | 24 | # debug TODO: remove for final release 25 | for idx in range(len(self.data)): 26 | self.data[idx]["neg_paras"] = self.data[idx]["tfidf_neg"] 27 | self.data = [_ for _ in self.data if len(_["neg_paras"]) >= 2] 28 | else: 29 | for idx in range(len(self.data)): 30 | self.data[idx]["neg_paras"] = self.data[idx]["tfidf_neg"] 31 | self.data = [_ for _ in self.data if len(_["neg_paras"]) >= 2] 32 | 33 | print(f"Total sample count {len(self.data)}") 34 | 35 | def encode_para(self, para, max_len): 36 | return self.tokenizer.encode_plus(para["title"].strip(), text_pair=para["text"].strip(), 37 | truncation=True, max_length=max_len, return_tensors="pt") 38 | 39 | def __getitem__(self, index): 40 | sample = self.data[index] 41 | question = sample['question'] 42 | if question.endswith("?"): 43 | question = question[:-1] 44 | if sample["type"] == "comparison": 45 | random.shuffle(sample["pos_paras"]) 46 | start_para, bridge_para = sample["pos_paras"] 47 | else: 48 | for para in sample["pos_paras"]: 49 | if para["title"] != sample["bridge"]: 50 | start_para = para 51 | else: 52 | bridge_para = para 53 | if self.train: 54 | random.shuffle(sample["neg_paras"]) 55 | 56 | start_para_codes = self.encode_para(start_para, self.max_c_len) 57 | bridge_para_codes = self.encode_para(bridge_para, self.max_c_len) 58 | neg_codes_1 = self.encode_para(sample["neg_paras"][0], self.max_c_len) 59 | neg_codes_2 = self.encode_para(sample["neg_paras"][1], self.max_c_len) 60 | 61 | q_sp_codes = self.tokenizer.encode_plus(question, text_pair=start_para["text"].strip(), 62 | truncation=True, max_length=self.max_q_sp_len, return_tensors="pt") 63 | q_codes = self.tokenizer.encode_plus(question, truncation=True, max_length=self.max_q_len, return_tensors="pt") 64 | 65 | return { 66 | "q_codes": q_codes, 67 | "q_sp_codes": q_sp_codes, 68 | "start_para_codes": start_para_codes, 69 | "bridge_para_codes": bridge_para_codes, 70 | "neg_codes_1": neg_codes_1, 71 | "neg_codes_2": neg_codes_2, 72 | } 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | 78 | def mhop_collate(samples, pad_id=0): 79 | if len(samples) == 0: 80 | return {} 81 | 82 | batch = { 83 | 'q_input_ids': collate_tokens([s["q_codes"]["input_ids"].view(-1) for s in samples], 0), 84 | 'q_mask': collate_tokens([s["q_codes"]["attention_mask"].view(-1) for s in samples], 0), 85 | 86 | 'q_sp_input_ids': collate_tokens([s["q_sp_codes"]["input_ids"].view(-1) for s in samples], 0), 87 | 'q_sp_mask': collate_tokens([s["q_sp_codes"]["attention_mask"].view(-1) for s in samples], 0), 88 | 89 | 'c1_input_ids': collate_tokens([s["start_para_codes"]["input_ids"] for s in samples], 0), 90 | 'c1_mask': 
collate_tokens([s["start_para_codes"]["attention_mask"] for s in samples], 0), 91 | 92 | 'c2_input_ids': collate_tokens([s["bridge_para_codes"]["input_ids"] for s in samples], 0), 93 | 'c2_mask': collate_tokens([s["bridge_para_codes"]["attention_mask"] for s in samples], 0), 94 | 95 | 'neg1_input_ids': collate_tokens([s["neg_codes_1"]["input_ids"] for s in samples], 0), 96 | 'neg1_mask': collate_tokens([s["neg_codes_1"]["attention_mask"] for s in samples], 0), 97 | 98 | 'neg2_input_ids': collate_tokens([s["neg_codes_2"]["input_ids"] for s in samples], 0), 99 | 'neg2_mask': collate_tokens([s["neg_codes_2"]["attention_mask"] for s in samples], 0), 100 | 101 | } 102 | 103 | if "token_type_ids" in samples[0]["q_codes"]: 104 | batch.update({ 105 | 'q_type_ids': collate_tokens([s["q_codes"]["token_type_ids"].view(-1) for s in samples], 0), 106 | 'c1_type_ids': collate_tokens([s["start_para_codes"]["token_type_ids"] for s in samples], 0), 107 | 'c2_type_ids': collate_tokens([s["bridge_para_codes"]["token_type_ids"] for s in samples], 0), 108 | "q_sp_type_ids": collate_tokens([s["q_sp_codes"]["token_type_ids"].view(-1) for s in samples], 0), 109 | 'neg1_type_ids': collate_tokens([s["neg_codes_1"]["token_type_ids"] for s in samples], 0), 110 | 'neg2_type_ids': collate_tokens([s["neg_codes_2"]["token_type_ids"] for s in samples], 0), 111 | }) 112 | 113 | if "sent_ids" in samples[0]["start_para_codes"]: 114 | batch["c1_sent_target"] = collate_tokens([s["start_para_codes"]["sent_ids"] for s in samples], -1) 115 | batch["c1_sent_offsets"] = collate_tokens([s["start_para_codes"]["sent_offsets"] for s in samples], 0), 116 | 117 | return batch 118 | -------------------------------------------------------------------------------- /mdr/retrieval/decomposed_analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def decomposed_errors(): 5 | top1_pred = [json.loads(l) for l in open("/private/home/xwhan/data/hotpot/dense_val_b1_top1.json").readlines()] 6 | analysis_folder = "/private/home/xwhan/data/hotpot/analysis" 7 | 8 | start_errors, bridge_errors, failed = [], [], [] 9 | correct = [] 10 | for item in top1_pred: 11 | pred_titles = [_[0] for _ in item["candidate_chains"][0]] 12 | gold_titles = [_[0] for _ in item["sp"]] 13 | if set(pred_titles) == set(gold_titles): 14 | if item["type"] == "bridge": 15 | correct.append(item) 16 | continue 17 | if item["type"] == "bridge": 18 | start_title = None 19 | for t in gold_titles: 20 | if t != item["bridge"]: 21 | start_title = t 22 | assert start_title is not None 23 | if item["bridge"] in pred_titles and start_title not in pred_titles: 24 | start_errors.append(item) 25 | elif item["bridge"] not in pred_titles and start_title in pred_titles: 26 | bridge_errors.append(item) 27 | else: 28 | failed.append(item) 29 | 30 | with open(analysis_folder + "/correct.json", "w") as g: 31 | for _ in correct: 32 | _["predicted"] = _.pop("candidate_chains")[0] 33 | g.write(json.dumps(_) + "\n") 34 | 35 | with open(analysis_folder + "/start_errors.json", "w") as g: 36 | for _ in start_errors: 37 | _["predicted"] = _.pop("candidate_chains")[0] 38 | g.write(json.dumps(_) + "\n") 39 | 40 | with open(analysis_folder + "/bridge_errors.json", "w") as g: 41 | for _ in bridge_errors: 42 | _["predicted"] = _.pop("candidate_chains")[0] 43 | g.write(json.dumps(_) + "\n") 44 | 45 | with open(analysis_folder + "/total_errors.json", "w") as g: 46 | for _ in failed: 47 | _["predicted"] = _.pop("candidate_chains")[0] 48 | 
g.write(json.dumps(_) + "\n") 49 | 50 | 51 | print(len(correct)) 52 | print(len(start_errors)) 53 | print(len(bridge_errors)) 54 | print(len(failed)) 55 | 56 | import random 57 | def collect_gold_decomposition(): 58 | """ 59 | interactively collect 60 | """ 61 | dev_qdmr = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/dev.json").readlines()] 62 | bridge_dev = [_ for _ in dev_qdmr if _["type"] == "bridge"] 63 | 64 | random.shuffle(bridge_dev) 65 | idx = 0 66 | samples_to_inspect = [] 67 | while True: 68 | print(f"\n-----{len(samples_to_inspect)} samples collected so far-----") 69 | sample = bridge_dev[idx] 70 | idx += 1 71 | print(f"Original Q: {sample['q']}") 72 | print(f"Decomposed Q: {sample['q_decom']}") 73 | print(f"Supporting Passages: {sample['sp']}") 74 | subq1 = input("Type SUB Q1:") 75 | if subq1 == "bad": 76 | continue 77 | elif subq1 == "stop": 78 | break 79 | subq2 = input("Type SUB Q2:") 80 | samples_to_inspect.append({ 81 | "id": sample["id"], 82 | "sp": sample["sp"], 83 | "orig_q": sample['q'], 84 | "subQ_1": subq1, 85 | "subQ_2": subq2 86 | }) 87 | 88 | print(f"{len(samples_to_inspect)} samples collected in total..") 89 | 90 | with open("/private/home/xwhan/data/QDMR/inspect.json", "w") as g: 91 | for _ in samples_to_inspect: 92 | g.write(json.dumps(_) + "\n") 93 | 94 | def qdmr_utils(): 95 | """ 96 | change file format for decomposed and end-to-end retrieval 97 | """ 98 | qdmr_data = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/inspect.json").readlines()] 99 | 100 | mhop_data, decomposed_data = [], [] 101 | for idx, item in enumerate(qdmr_data): 102 | if idx in [65,66,67]: 103 | continue 104 | sp = [_["title"] for _ in item["sp"]] 105 | question = item["orig_q"] 106 | mhop_data.append({ 107 | "question": question, 108 | "sp": sp, 109 | "type": "bridge", 110 | "_id": item["id"] 111 | }) 112 | decomposed_data.append(item) 113 | 114 | # with open("/private/home/xwhan/data/QDMR/qdmr_decomposed.json", "w") as g: 115 | # for item in decomposed_data: 116 | # g.write(json.dumps(item) + "\n") 117 | 118 | with open("/private/home/xwhan/data/QDMR/qdmr_e2e.json", "w") as g: 119 | for item in mhop_data: 120 | g.write(json.dumps(item) + "\n") 121 | 122 | 123 | def analyze_results(): 124 | decomposed_results = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/qdmr_decomposed_results.json")] 125 | e2e_results = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/qdmr_e2e_results.json")] 126 | better = 0 127 | worse = 0 128 | both = 0 129 | for res1, res2 in zip(decomposed_results, e2e_results): 130 | sp_titles = set([_[0] for _ in res1["sp"]]) 131 | 132 | res1_top1 = [_[0] for _ in res1["candidate_chains"][0]] 133 | res2_top1 = [_[0] for _ in res2["candidate_chains"][0]] 134 | 135 | assert res1["_id"] == res2["_id"] 136 | 137 | question = res1["question"] 138 | q_pairs = res1["q_pairs"] 139 | 140 | if set(res2_top1) == sp_titles and set(res1_top1) != sp_titles: 141 | # print(sp_titles) 142 | # import pdb; pdb.set_trace() 143 | better += 1 144 | elif set(res2_top1) != sp_titles and set(res1_top1) == sp_titles: 145 | worse += 1 146 | elif set(res2_top1) == sp_titles and set(res1_top1) == sp_titles: 147 | both += 1 148 | 149 | print(both) 150 | print(better) 151 | print(worse) 152 | print(len(decomposed_results)) 153 | 154 | if __name__ == "__main__": 155 | # collect_gold_decomposition() 156 | # qdmr_utils() 157 | 158 | analyze_results() 159 | -------------------------------------------------------------------------------- 
/mdr/retrieval/interactive_retrieval.py: -------------------------------------------------------------------------------- 1 | from models.mhop_retriever import MhopRetriever 2 | 3 | import faiss 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | from transformers import AutoConfig, AutoTokenizer 8 | import json 9 | import logging 10 | import argparse 11 | from .utils.utils import (load_saved, move_to_cuda) 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--topk', type=int, default=2, help="topk paths") 15 | parser.add_argument('--num-workers', type=int, default=10) 16 | parser.add_argument('--max-q-len', type=int, default=70) 17 | parser.add_argument('--max-c-len', type=int, default=300) 18 | parser.add_argument('--max-q-sp-len', type=int, default=350) 19 | parser.add_argument('--model-name', type=str, default='bert-base-uncased') 20 | parser.add_argument('--gpu', action="store_true") 21 | parser.add_argument('--shared-encoder', action="store_true") 22 | parser.add_argument("--stop-drop", default=0, type=float) 23 | args = parser.parse_args() 24 | 25 | index_path = "index/abstracts_v0_fixed.npy"  # NOTE: index, corpus and checkpoint paths are specific to the original authors' setup 26 | corpus_path = "index/abstracts_id2doc.json" 27 | model_path = "logs/08-05-2020/baseline_v0_fixed-seed16-bsz150-fp16True-lr2e-05-decay0.0-warm0.1-valbsz3000-sharedTrue-multi1-schemenone/checkpoint_best.pt" 28 | 29 | print(f"Loading corpus and index...") 30 | id2doc = json.load(open(corpus_path)) 31 | index_vectors = np.load(index_path).astype('float32') 32 | 33 | index = faiss.IndexFlatIP(768) 34 | index.add(index_vectors) 35 | res = faiss.StandardGpuResources() 36 | index = faiss.index_cpu_to_gpu(res, 1, index) 37 | 38 | print(f"Loading retrieval model...") 39 | bert_config = AutoConfig.from_pretrained("bert-base-uncased") 40 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 41 | model = MhopRetriever(bert_config, args) 42 | model = load_saved(model, model_path, exact=False) 43 | 44 | cuda = torch.device('cuda') 45 | model.to(cuda) 46 | from apex import amp 47 | 48 | model = amp.initialize(model, opt_level='O1') 49 | model.eval() 50 | 51 | while True: 52 | question = input("Type Question:") 53 | # question = "the Danish musicians who died in 1931"  # example query 54 | batch_q_encodes = tokenizer.batch_encode_plus([question], max_length=args.max_q_len, pad_to_max_length=True, 55 | return_tensors="pt") 56 | batch_q_encodes = move_to_cuda(dict(batch_q_encodes)) 57 | q_embeds = model.encode_q(batch_q_encodes["input_ids"], batch_q_encodes["attention_mask"], 58 | batch_q_encodes.get("token_type_ids", None)) 59 | q_embeds_numpy = q_embeds.cpu().contiguous().numpy() 60 | D, I = index.search(q_embeds_numpy, 1) 61 | 62 | print(I) 63 | -------------------------------------------------------------------------------- /mdr/retrieval/models/hop1_retriever.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig, BertPreTrainedModel 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | from torch.nn import CrossEntropyLoss 6 | 7 | 8 | class Retriever1hop(nn.Module): 9 | 10 | def __init__(self, 11 | config, 12 | args 13 | ): 14 | super().__init__() 15 | 16 | self.bert_q = BertModel.from_pretrained(args.bert_model_name) 17 | self.bert_c = BertModel.from_pretrained(args.bert_model_name) 18 | self.hidden_size = config.hidden_size 19 | 20 | def forward(self, batch): 21 | # representations 22 | q_hidden_states = self.bert_q(batch['q_input_ids'], 
batch['q_mask'], batch['q_type_ids'])[0] 23 | q_cls = q_hidden_states[:,0,:] 24 | c_hidden_states = self.bert_c(batch['c_input_ids'], batch['c_mask'], batch['c_type_ids'])[0] 25 | c_cls = c_hidden_states[:, 0, :] 26 | neg_c_cls = self.bert_c(batch['neg_input_ids'], batch['neg_mask'], batch['neg_type_ids'])[0][:, 0, :] 27 | 28 | # sentence-level representations 29 | gather_index = batch["c_sent_offsets"].unsqueeze(2).expand(-1,-1,self.hidden_size) # B x |S| x h 30 | c_sent_rep = torch.gather(c_hidden_states, 1, gather_index) 31 | 32 | outputs = {'q': q_cls, 'c':c_cls, "neg_c": neg_c_cls, "c_sent_rep": c_sent_rep} 33 | 34 | return outputs 35 | 36 | -------------------------------------------------------------------------------- /mdr/retrieval/models/mhop_retriever.py: -------------------------------------------------------------------------------- 1 | from torch import embedding 2 | from transformers import AutoModel 3 | import torch.nn as nn 4 | import torch 5 | 6 | 7 | class RobertaRetriever(nn.Module): 8 | 9 | def __init__(self, config, args): 10 | super().__init__() 11 | 12 | self.encoder = AutoModel.from_pretrained(args.model_name) 13 | self.project = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size), 14 | nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)) 15 | 16 | def encode_seq(self, input_ids, mask, token_type_ids=None): 17 | cls_rep = self.encoder(input_ids, mask, token_type_ids)[0][:, 0, :] 18 | vector = self.project(cls_rep) 19 | return vector 20 | 21 | def forward(self, batch): 22 | c1 = self.encode_seq(batch['c1_input_ids'], batch['c1_mask']) 23 | c2 = self.encode_seq(batch['c2_input_ids'], batch['c2_mask']) 24 | 25 | neg_1 = self.encode_seq(batch['neg1_input_ids'], batch['neg1_mask']) 26 | neg_2 = self.encode_seq(batch['neg2_input_ids'], batch['neg2_mask']) 27 | 28 | q = self.encode_seq(batch['q_input_ids'], batch['q_mask']) 29 | q_sp1 = self.encode_seq(batch['q_sp_input_ids'], batch['q_sp_mask']) 30 | vectors = {'q': q, 'c1': c1, "c2": c2, "neg_1": neg_1, "neg_2": neg_2, "q_sp1": q_sp1} 31 | return vectors 32 | 33 | def encode_q(self, input_ids, mask, token_type_ids=None): 34 | return self.encode_seq(input_ids, mask, token_type_ids) 35 | 36 | 37 | class RobertaMomentumRetriever(nn.Module): 38 | 39 | def __init__(self, config, args): 40 | super().__init__() 41 | 42 | self.encoder_q = RobertaRetriever(config, args) 43 | self.encoder_k = RobertaRetriever(config, args) 44 | 45 | if args.init_retriever != "": 46 | print(f"Load pretrained retriever from {args.init_retriever}") 47 | self.load_retriever(args.init_retriever) 48 | 49 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 50 | param_k.data.copy_(param_q.data) # initialize 51 | param_k.requires_grad = False # not update by gradient 52 | 53 | self.k = args.k 54 | self.m = args.m 55 | self.register_buffer("queue", torch.randn(self.k, config.hidden_size)) 56 | # add layernorm? 
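# The "queue" buffer registered above is a MoCo-style memory bank: dequeue_and_enqueue()
# below overwrites its slots in ring-buffer fashion with context embeddings (c1/c2, per its
# docstring) produced by the momentum key encoder, while queue_ptr tracks the next slot to
# fill and wraps around after k entries. Presumably the cached vectors supply additional
# negatives beyond the in-batch ones when training the query encoder (not shown in this file).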
57 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 58 | 59 | def load_retriever(self, path): 60 | state_dict = torch.load(path) 61 | 62 | def filter_name(x): return x[7:] if x.startswith('module.') else x 63 | 64 | state_dict = {filter_name(k): v for k, v in state_dict.items() if filter_name(k) in self.encoder_q.state_dict()} 65 | self.encoder_q.load_state_dict(state_dict) 66 | return 67 | 68 | @torch.no_grad() 69 | def momentum_update_key_encoder(self): 70 | """ 71 | Momentum update of the key encoder 72 | """ 73 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 74 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 75 | 76 | @torch.no_grad() 77 | def dequeue_and_enqueue(self, embeddings): 78 | """ 79 | memory bank of previous context embeddings, c1 and c2 80 | """ 81 | # gather keys before updating queue 82 | batch_size = embeddings.shape[0] 83 | ptr = int(self.queue_ptr) 84 | if ptr + batch_size > self.k: 85 | batch_size = self.k - ptr 86 | embeddings = embeddings[:batch_size] 87 | 88 | # if self.k % batch_size != 0: 89 | # return 90 | # assert self.k % batch_size == 0 # for simplicity 91 | 92 | # replace the keys at ptr (dequeue and enqueue) 93 | self.queue[ptr:ptr + batch_size, :] = embeddings 94 | 95 | ptr = (ptr + batch_size) % self.k # move pointer 96 | self.queue_ptr[0] = ptr 97 | return 98 | 99 | def forward(self, batch): 100 | q = self.encoder_q.encode_seq(batch['q_input_ids'], batch['q_mask']) 101 | q_sp1 = self.encoder_q.encode_seq(batch['q_sp_input_ids'], batch['q_sp_mask']) 102 | 103 | if self.training: 104 | with torch.no_grad(): 105 | c1 = self.encoder_k.encode_seq(batch['c1_input_ids'], batch['c1_mask']) 106 | c2 = self.encoder_k.encode_seq(batch['c2_input_ids'], batch['c2_mask']) 107 | 108 | neg_1 = self.encoder_k.encode_seq(batch['neg1_input_ids'], batch['neg1_mask']) 109 | neg_2 = self.encoder_k.encode_seq(batch['neg2_input_ids'], batch['neg2_mask']) 110 | else: 111 | # whether to use the momentum encoder for inference 112 | c1 = self.encoder_k.encode_seq(batch['c1_input_ids'], batch['c1_mask']) 113 | c2 = self.encoder_k.encode_seq(batch['c2_input_ids'], batch['c2_mask']) 114 | 115 | neg_1 = self.encoder_k.encode_seq(batch['neg1_input_ids'], batch['neg1_mask']) 116 | neg_2 = self.encoder_k.encode_seq(batch['neg2_input_ids'], batch['neg2_mask']) 117 | 118 | vectors = {'q': q, 'c1': c1, "c2": c2, "neg_1": neg_1, "neg_2": neg_2, "q_sp1": q_sp1} 119 | return vectors 120 | -------------------------------------------------------------------------------- /mdr/retrieval/utils/gen_index_id_map.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | mapping = {} 4 | with open('../data/para_doc.db') as f_in: 5 | for idx, line in enumerate(f_in): 6 | sample = json.loads(line.strip()) 7 | mapping[idx] = sample['id'] 8 | with open('index_data/idx_id.json', 'w') as f_out: 9 | json.dump(mapping, f_out) 10 | 11 | -------------------------------------------------------------------------------- /mdr/retrieval/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_tokens_to_ids(vocab, tokens): 28 | """Converts a sequence of tokens into ids using the vocab.""" 29 | ids = [] 30 | for token in tokens: 31 | ids.append(vocab[token]) 32 | return ids 33 | 34 | def whitespace_tokenize(text): 35 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 36 | text = text.strip() 37 | if not text: 38 | return [] 39 | tokens = text.split() 40 | return tokens 41 | 42 | 43 | def convert_to_unicode(text): 44 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 45 | if six.PY3: 46 | if isinstance(text, str): 47 | return text 48 | elif isinstance(text, bytes): 49 | return text.decode("utf-8", "ignore") 50 | else: 51 | raise ValueError("Unsupported string type: %s" % (type(text))) 52 | elif six.PY2: 53 | if isinstance(text, str): 54 | return text.decode("utf-8", "ignore") 55 | elif isinstance(text, unicode): 56 | return text 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | else: 60 | raise ValueError("Not running on Python2 or Python 3?") 61 | 62 | 63 | def _is_whitespace(char): 64 | """Checks whether `chars` is a whitespace character.""" 65 | # \t, \n, and \r are technically contorl characters but we treat them 66 | # as whitespace since they are generally considered as such. 67 | if char == " " or char == "\t" or char == "\n" or char == "\r": 68 | return True 69 | cat = unicodedata.category(char) 70 | if cat == "Zs": 71 | return True 72 | return False 73 | 74 | 75 | def _is_control(char): 76 | """Checks whether `chars` is a control character.""" 77 | # These are technically control characters but we count them as whitespace 78 | # characters. 79 | if char == "\t" or char == "\n" or char == "\r": 80 | return False 81 | cat = unicodedata.category(char) 82 | if cat.startswith("C"): 83 | return True 84 | return False 85 | 86 | class BasicTokenizer(object): 87 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 88 | 89 | def __init__(self, do_lower_case=True): 90 | """Constructs a BasicTokenizer. 91 | Args: 92 | do_lower_case: Whether to lower case the input. 
93 | """ 94 | self.do_lower_case = do_lower_case 95 | 96 | def tokenize(self, text): 97 | """Tokenizes a piece of text.""" 98 | text = convert_to_unicode(text) 99 | text = self._clean_text(text) 100 | orig_tokens = whitespace_tokenize(text) 101 | split_tokens = [] 102 | for token in orig_tokens: 103 | if self.do_lower_case: 104 | token = token.lower() 105 | token = self._run_strip_accents(token) 106 | split_tokens.extend(self._run_split_on_punc(token)) 107 | 108 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 109 | return output_tokens 110 | 111 | def _run_strip_accents(self, text): 112 | """Strips accents from a piece of text.""" 113 | text = unicodedata.normalize("NFD", text) 114 | output = [] 115 | for char in text: 116 | cat = unicodedata.category(char) 117 | if cat == "Mn": 118 | continue 119 | output.append(char) 120 | return "".join(output) 121 | 122 | def _run_split_on_punc(self, text): 123 | """Splits punctuation on a piece of text.""" 124 | chars = list(text) 125 | i = 0 126 | start_new_word = True 127 | output = [] 128 | while i < len(chars): 129 | char = chars[i] 130 | if _is_punctuation(char): 131 | output.append([char]) 132 | start_new_word = True 133 | else: 134 | if start_new_word: 135 | output.append([]) 136 | start_new_word = False 137 | output[-1].append(char) 138 | i += 1 139 | 140 | return ["".join(x) for x in output] 141 | 142 | def _clean_text(self, text): 143 | """Performs invalid character removal and whitespace cleanup on text.""" 144 | output = [] 145 | for char in text: 146 | cp = ord(char) 147 | if cp == 0 or cp == 0xfffd or _is_control(char): 148 | continue 149 | if _is_whitespace(char): 150 | output.append(" ") 151 | else: 152 | output.append(char) 153 | return "".join(output) 154 | 155 | 156 | def _is_punctuation(char): 157 | """Checks whether `chars` is a punctuation character.""" 158 | cp = ord(char) 159 | # We treat all non-letter/number ASCII as punctuation. 160 | # Characters such as "^", "$", and "`" are not in the Unicode 161 | # Punctuation class but we treat them as punctuation anyways, for 162 | # consistency. 
163 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 164 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 165 | return True 166 | cat = unicodedata.category(char) 167 | if cat.startswith("P"): 168 | return True 169 | return False 170 | 171 | 172 | def process(s, tokenizer): 173 | try: 174 | return tokenizer.tokenize(s) 175 | except: 176 | print('failed on', s) 177 | raise 178 | 179 | if __name__ == "__main__": 180 | _is_whitespace("a") 181 | -------------------------------------------------------------------------------- /mdr/retrieval/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sqlite3 3 | import unicodedata 4 | 5 | 6 | def load_saved(model, path, exact=True): 7 | try: 8 | state_dict = torch.load(path) 9 | except: 10 | state_dict = torch.load(path, map_location=torch.device('cpu')) 11 | 12 | def filter_name(x): 13 | return x[7:] if x.startswith('module.') else x 14 | 15 | if exact: 16 | state_dict = {filter_name(k): v for (k, v) in state_dict.items()} 17 | else: 18 | state_dict = {filter_name(k): v for (k, v) in state_dict.items() if filter_name(k) in model.state_dict()} 19 | model.load_state_dict(state_dict) 20 | return model 21 | 22 | 23 | def move_to_cuda(sample): 24 | if len(sample) == 0: 25 | return {} 26 | 27 | def _move_to_cuda(maybe_tensor): 28 | if torch.is_tensor(maybe_tensor): 29 | return maybe_tensor.cuda() 30 | elif isinstance(maybe_tensor, dict): 31 | return { 32 | key: _move_to_cuda(value) 33 | for key, value in maybe_tensor.items() 34 | } 35 | elif isinstance(maybe_tensor, list): 36 | return [_move_to_cuda(x) for x in maybe_tensor] 37 | else: 38 | return maybe_tensor 39 | 40 | return _move_to_cuda(sample) 41 | 42 | 43 | def convert_to_half(sample): 44 | if len(sample) == 0: 45 | return {} 46 | 47 | def _convert_to_half(maybe_floatTensor): 48 | if torch.is_tensor(maybe_floatTensor) and maybe_floatTensor.type() == "torch.FloatTensor": 49 | return maybe_floatTensor.half() 50 | elif isinstance(maybe_floatTensor, dict): 51 | return { 52 | key: _convert_to_half(value) 53 | for key, value in maybe_floatTensor.items() 54 | } 55 | elif isinstance(maybe_floatTensor, list): 56 | return [_convert_to_half(x) for x in maybe_floatTensor] 57 | else: 58 | return maybe_floatTensor 59 | 60 | return _convert_to_half(sample) 61 | 62 | 63 | class AverageMeter(object): 64 | """Computes and stores the average and current value""" 65 | 66 | def __init__(self): 67 | self.reset() 68 | 69 | def reset(self): 70 | self.val = 0 71 | self.avg = 0 72 | self.sum = 0 73 | self.count = 0 74 | 75 | def update(self, val, n=1): 76 | self.val = val 77 | self.sum += val * n 78 | self.count += n 79 | self.avg = self.sum / self.count 80 | 81 | 82 | def normalize(text): 83 | """Resolve different type of unicode encodings.""" 84 | return unicodedata.normalize('NFD', text) 85 | 86 | 87 | class DocDB(object): 88 | """Sqlite backed document storage. 89 | 90 | Implements get_doc_text(doc_id). 
91 | """ 92 | 93 | def __init__(self, db_path=None): 94 | self.path = db_path 95 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 96 | 97 | def __enter__(self): 98 | return self 99 | 100 | def __exit__(self, *args): 101 | self.close() 102 | 103 | def close(self): 104 | """Close the connection to the database.""" 105 | self.connection.close() 106 | 107 | def get_doc_ids(self): 108 | """Fetch all ids of docs stored in the db.""" 109 | cursor = self.connection.cursor() 110 | cursor.execute("SELECT id FROM documents") 111 | results = [r[0] for r in cursor.fetchall()] 112 | cursor.close() 113 | return results 114 | 115 | def get_doc_text(self, doc_id): 116 | """Fetch the raw text of the doc for 'doc_id'.""" 117 | cursor = self.connection.cursor() 118 | cursor.execute( 119 | "SELECT text FROM documents WHERE id = ?", 120 | (normalize(doc_id),) 121 | ) 122 | result = cursor.fetchone() 123 | cursor.close() 124 | return result if result is None else result[0] 125 | 126 | 127 | def para_has_answer(answer, para, tokenizer): 128 | assert isinstance(answer, list) 129 | text = normalize(para) 130 | tokens = tokenizer.tokenize(text) 131 | text = tokens.words(uncased=True) 132 | assert len(text) == len(tokens) 133 | for single_answer in answer: 134 | single_answer = normalize(single_answer) 135 | single_answer = tokenizer.tokenize(single_answer) 136 | single_answer = single_answer.words(uncased=True) 137 | for i in range(0, len(text) - len(single_answer) + 1): 138 | if single_answer == text[i: i + len(single_answer)]: 139 | return True 140 | return False 141 | 142 | 143 | def complex_ans_recall(): 144 | """ 145 | calculate retrieval metrics for complexwebQ 146 | """ 147 | import json 148 | import numpy as np 149 | from basic_tokenizer import SimpleTokenizer 150 | tok = SimpleTokenizer() 151 | 152 | predictions = json.load( 153 | open("/private/home/xwhan/code/learning_to_retrieve_reasoning_paths/results/complexwebq_retrieval_res.json")) 154 | raw_dev = [json.loads(l) for l in open("/private/home/xwhan/data/ComplexWebQ/complexwebq_dev_qas.txt").readlines()] 155 | id2qas = {_["id"]: _ for _ in raw_dev} 156 | 157 | assert len(predictions) == len(raw_dev) 158 | answer_recalls = [] 159 | for item in predictions: 160 | qid = item["q_id"] 161 | title2passage = item["context"] 162 | gold_answers = id2qas[qid]["answer"] 163 | 164 | chain_coverage = [] 165 | for chain in item["topk_titles"]: 166 | chain_text = " ".join([title2passage[_] for _ in chain]) 167 | chain_coverage.append(para_has_answer(gold_answers, chain_text, tok)) 168 | answer_recalls.append(np.sum(chain_coverage) > 0) 169 | print(len(answer_recalls)) 170 | print(np.mean(answer_recalls)) 171 | 172 | 173 | if __name__ == "__main__": 174 | complex_ans_recall() 175 | -------------------------------------------------------------------------------- /mdr_encode_corpus.py: -------------------------------------------------------------------------------- 1 | """ 2 | Description: encode text corpus into a store of dense vectors. 
3 | 4 | Usage (adjust the batch size according to your GPU memory): 5 | 6 | export MODEL_CHECKPOINT=ckpts/mdr/doc_encoder.pt 7 | export CUDA_VISIBLE_DEVICES=0,1,2,3 8 | python mdr_encode_corpus.py \ 9 | --predict_batch_size 512 \ 10 | --model_name roberta-base \ 11 | --init_checkpoint ${MODEL_CHECKPOINT} \ 12 | --corpus_file data/corpus/hotpot-paragraph-5.tsv \ 13 | --embedding_prefix data/vector/mdr/hotpot-paragraph-5 \ 14 | --max_c_len 300 \ 15 | --num_workers 4 \ 16 | --num_shards 1 \ 17 | --shard_id 0 \ 18 | --strict 19 | 20 | """ 21 | import os 22 | import pathlib 23 | import pickle 24 | from tqdm import tqdm 25 | # from tqdm.auto import tqdm 26 | 27 | import torch 28 | from torch.utils.data import DataLoader, Subset 29 | 30 | from transformers import AutoConfig, AutoTokenizer 31 | from mdr.retrieval.config import encode_args 32 | from mdr.retrieval.data.encode_datasets import EmDataset, em_collate 33 | from mdr.retrieval.models.retriever import CtxEncoder, RobertaCtxEncoder 34 | from mdr.retrieval.utils.utils import move_to_cuda, load_saved 35 | 36 | 37 | def main(): 38 | args = encode_args() 39 | if args.fp16: 40 | import apex 41 | apex.amp.register_half_function(torch, 'einsum') 42 | 43 | if args.local_rank == -1 or args.no_cuda: 44 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 45 | n_gpu = torch.cuda.device_count() 46 | else: 47 | device = torch.device("cuda", args.local_rank) 48 | n_gpu = 1 49 | torch.distributed.init_process_group(backend='nccl') 50 | 51 | bert_config = AutoConfig.from_pretrained(args.model_name) 52 | if "roberta" in args.model_name: 53 | model = RobertaCtxEncoder(bert_config, args) 54 | else: 55 | model = CtxEncoder(bert_config, args) 56 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 57 | 58 | eval_dataset = EmDataset(tokenizer, args.corpus_file, args.max_c_len, args.strict) 59 | shard_size = len(eval_dataset) // args.num_shards 60 | start_idx = args.shard_id * shard_size 61 | end_idx = start_idx + shard_size if args.shard_id != args.num_shards - 1 else len(eval_dataset) 62 | sub_eval_dataset = Subset(eval_dataset, list(range(start_idx, end_idx))) 63 | print(f'Producing encodings for passages [{start_idx:,d}, {end_idx:,d}) ' 64 | f'({args.shard_id}/{args.num_shards} of {len(eval_dataset):,d})') 65 | eval_dataloader = DataLoader(sub_eval_dataset, batch_size=args.predict_batch_size, collate_fn=em_collate, 66 | pin_memory=True, num_workers=args.num_workers) 67 | 68 | assert args.init_checkpoint != "" 69 | print(f'Loading the model checkpoint from {args.init_checkpoint} ...') 70 | model = load_saved(model, args.init_checkpoint, exact=False) 71 | model.to(device) 72 | model.eval() 73 | 74 | if args.fp16: 75 | try: 76 | from apex import amp 77 | except ImportError: 78 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 79 | model = amp.initialize(model, opt_level=args.fp16_opt_level) 80 | 81 | if args.local_rank != -1: 82 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 83 | output_device=args.local_rank) 84 | elif n_gpu > 1: 85 | model = torch.nn.DataParallel(model) 86 | 87 | embeddings = predict(model, eval_dataloader) 88 | assert len(embeddings) == end_idx - start_idx 89 | assert embeddings[0][0] == eval_dataset[start_idx]['p_id'] 90 | assert embeddings[-1][0] == eval_dataset[end_idx - 1]['p_id'] 91 | 92 | if args.strict: 93 | if 'strict' not in args.embedding_prefix: 94 | args.embedding_prefix = 
f"{args.embedding_prefix}-strict" 95 | out_file = f"{args.embedding_prefix}_{args.shard_id}.pkl" 96 | pathlib.Path(os.path.dirname(out_file)).mkdir(parents=True, exist_ok=True) 97 | print(f'Encoded passages into {len(embeddings)} x {embeddings[0][1].shape} embeddings, writing to {out_file} ...') 98 | with open(out_file, mode='wb') as f: 99 | pickle.dump(embeddings, f) 100 | 101 | 102 | def predict(model, eval_dataloader): 103 | model.eval() 104 | 105 | embeddings = [] 106 | for batch in tqdm(eval_dataloader): 107 | batch_to_feed = move_to_cuda(batch) 108 | with torch.no_grad(): 109 | out = model(batch_to_feed) 110 | batch_embedding = out['embed'].cpu().numpy() 111 | 112 | assert len(batch['p_id']) == batch_embedding.shape[0] 113 | embeddings.extend([(batch['p_id'][i], batch_embedding[i]) for i in range(batch_embedding.shape[0])]) 114 | 115 | model.train() 116 | return embeddings 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/models/__init__.py -------------------------------------------------------------------------------- /models/reranker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import BCEWithLogitsLoss 4 | 5 | from transformers import AutoModel # , AutoConfig 6 | 7 | from utils.rank_losses import list_mle # , list_net 8 | from .union_model import get_paras_weight 9 | 10 | 11 | class Reranker(nn.Module): 12 | 13 | def __init__(self, encoder_name): 14 | super(Reranker, self).__init__() 15 | self.encoder = AutoModel.from_pretrained(encoder_name) 16 | self.hidden_size = self.encoder.config.hidden_size 17 | 18 | self.reranker = nn.Linear(self.hidden_size, 1) 19 | 20 | self.bce_loss = BCEWithLogitsLoss(reduction='none') 21 | 22 | def forward(self, batch): 23 | # (B, T, H) 24 | seq_hiddens = self.encoder(batch['input_ids'], batch['attention_mask'], batch.get('token_type_ids', None))[0] 25 | 26 | para_num = [len(paras_mark) for paras_mark in batch['paras_mark']] # (B,) 27 | para_hiddens = [seq_hiddens[i, paras_mark] for i, paras_mark in enumerate(batch['paras_mark'])] # (B, _P, H) 28 | # (B, _P) 29 | para_logits = self.reranker(torch.cat(para_hiddens, dim=0)).squeeze(-1).split(para_num, dim=0) 30 | 31 | if self.training: 32 | # (B, _P) 33 | paras_loss = self.bce_loss(torch.cat(para_logits), torch.cat(batch['paras_label'])).split(para_num, dim=0) 34 | para_loss = torch.stack([(_paras_loss * get_paras_weight(_paras_loss, obs_weight=-1)).sum() 35 | for _paras_loss in paras_loss], dim=0) # (B,) 36 | 37 | memory_loss = torch.zeros_like(para_loss) # (B,) 38 | for i, (_para_logits, _paras_label) in enumerate(zip(para_logits, batch['paras_label'])): 39 | if _paras_label.size(0) > 1 and _paras_label.max() != _paras_label.min(): 40 | memory_loss[i] = list_mle(_para_logits.unsqueeze(0), _paras_label.unsqueeze(0)) 41 | # memory_loss[i] = list_net(_para_logits.unsqueeze(0), _paras_label.unsqueeze(0), irrelevant_val=0.) 
42 | 43 | loss = (para_loss + memory_loss).mean() 44 | 45 | return loss 46 | 47 | return para_logits 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==3.0.2 2 | scipy 3 | scikit-learn 4 | fuzzysearch 5 | python-Levenshtein 6 | elasticsearch 7 | redis 8 | tqdm 9 | regex 10 | ujson 11 | pexpect 12 | prettytable>=2.0.0 13 | flask-restful 14 | requests -------------------------------------------------------------------------------- /text_clean.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | 3 | import re 4 | import unicodedata 5 | from html import unescape 6 | from urllib.parse import unquote 7 | 8 | 9 | def is_whitespace(char): 10 | """Checks whether `chars` is a whitespace character.""" 11 | # \t, \n, and \r are technically control characters but we treat them 12 | # as whitespace since they are generally considered as such. 13 | if char in [" ", "\t", "\n", "\r"]: 14 | return True 15 | cat = unicodedata.category(char) 16 | if cat == "Zs": 17 | return True 18 | return False 19 | 20 | 21 | def is_control(char): 22 | """Checks whether `chars` is a control character.""" 23 | # These are technically control characters but we count them as whitespace characters. 24 | if char in ["\t", "\n", "\r"]: 25 | return False 26 | cat = unicodedata.category(char) 27 | if cat.startswith("C"): 28 | return True 29 | return False 30 | 31 | 32 | def is_punctuation(char): 33 | """Checks whether `chars` is a punctuation character.""" 34 | if char in ["~", "¥", "×"]: 35 | return True 36 | cp = ord(char) 37 | # We treat all non-letter/number ASCII as punctuation. 38 | # Characters such as "^", "$", and "`" are not in the Unicode 39 | # Punctuation class but we treat them as punctuation anyways, for 40 | # consistency. 41 | if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126: 42 | return True 43 | cat = unicodedata.category(char) 44 | if cat.startswith("P"): 45 | return True 46 | return False 47 | 48 | 49 | def is_chinese_char(char): 50 | """Checks whether CP is the codepoint of a CJK character.""" 51 | # This defines a "chinese character" as anything in the CJK Unicode block: 52 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 53 | # 54 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 55 | # despite its name. The modern Korean Hangul alphabet is a different block, 56 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 57 | # space-separated words, so they are not treated specially and handled 58 | # like the all of the other languages. 
59 | cp = ord(char) 60 | if (0x4E00 <= cp <= 0x9FFF or 61 | 0x3400 <= cp <= 0x4DBF or 62 | 0x20000 <= cp <= 0x2A6DF or 63 | 0x2A700 <= cp <= 0x2B73F or 64 | 0x2B740 <= cp <= 0x2B81F or 65 | 0x2B820 <= cp <= 0x2CEAF or 66 | 0xF900 <= cp <= 0xFAFF or 67 | 0x2F800 <= cp <= 0x2FA1F): 68 | return True 69 | 70 | return False 71 | 72 | 73 | def clean_text(text): 74 | # unescaped_text = unescape(text) 75 | # unquoted_text = unquote(unescaped_text, 'utf-8') 76 | output = [] 77 | for char in text: 78 | cp = ord(char) 79 | if cp == 0 or cp == 0xfffd or is_control(char): 80 | continue 81 | if is_whitespace(char): 82 | output.append(" ") 83 | # elif char in ["–"]: 84 | # output.append("-") 85 | else: 86 | output.append(char) 87 | output_text = ''.join(output) 88 | # output_text = re.sub(r' {2,}', ' ', output_text).strip() 89 | return output_text 90 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zycdev/AISO/e7fd24ef009f9467997d7c14056d9afd13d7031f/utils/__init__.py -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def get_device(model): 8 | return next(model.parameters()).device 9 | 10 | 11 | def init_weights(modules: List): 12 | for module in modules: 13 | if isinstance(module, (nn.Linear, nn.Embedding)): 14 | module.weight.data.normal_(mean=0.0, std=0.02) 15 | elif isinstance(module, nn.LayerNorm): 16 | module.bias.data.zero_() 17 | module.weight.data.fill_(1.0) 18 | if isinstance(module, nn.Linear) and module.bias is not None: 19 | module.bias.data.zero_() 20 | 21 | 22 | def load_state(model, path, exact=True, strict=True): 23 | state_dict = torch.load(path, map_location=torch.device('cpu')) 24 | 25 | def filter_name(x): 26 | return x[7:] if x.startswith('module.') else x 27 | 28 | if exact: 29 | state_dict = {filter_name(k): v for (k, v) in state_dict.items()} 30 | else: 31 | state_dict = {filter_name(k): v for (k, v) in state_dict.items() if filter_name(k) in model.state_dict()} 32 | model.load_state_dict(state_dict, strict=strict) 33 | return model 34 | 35 | 36 | def save_model(model, path): 37 | model_to_save = model.module if hasattr(model, 'module') else model 38 | if torch.__version__ >= '1.4': 39 | torch.save(model_to_save.state_dict(), path, _use_new_zipfile_serialization=False) 40 | else: 41 | torch.save(model_to_save.state_dict(), path) 42 | -------------------------------------------------------------------------------- /utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # noinspection PyTypeChecker 5 | def mask_pos0(tensor, pos_mask): 6 | """ 7 | 8 | Args: 9 | tensor (torch.Tensor): (*) 10 | pos_mask (torch.Tensor): same size as logits 11 | 1 for positions that are NOT MASKED, 0 for MASKED positions. 
12 | 13 | Returns: 14 | torch.Tensor: same size as logits 15 | """ 16 | if tensor.dtype == torch.float16: 17 | return tensor * pos_mask - 65500 * (1 - pos_mask) 18 | else: 19 | return tensor * pos_mask - 1e30 * (1 - pos_mask) 20 | 21 | 22 | # noinspection PyTypeChecker 23 | def mask_pos(x, mask): 24 | """ 25 | 26 | Args: 27 | x (torch.Tensor): 28 | mask (torch.Tensor): same shape as x, 1 for available position, 0 for masked position 29 | 30 | Returns: 31 | torch.Tensor: 32 | """ 33 | return x - 10000.0 * (1.0 - mask) 34 | 35 | 36 | def pad_tensors(tensors, pad_val, left_pad=False, move_eos_to_beginning=False, eos_val=None): 37 | """Convert a list of 1d tensors into a padded 2d tensor.""" 38 | 39 | def copy_tensor(src, dst): 40 | assert dst.numel() == src.numel() 41 | if move_eos_to_beginning: 42 | assert src[-1] == eos_val 43 | dst[0] = eos_val 44 | dst[1:] = src[:-1] 45 | else: 46 | dst.copy_(src) 47 | 48 | if len(tensors[0].size()) > 1: 49 | tensors = [x.view(-1) for x in tensors] 50 | batch_size = len(tensors) 51 | max_len = max(x.size(0) for x in tensors) 52 | padded_tensor = tensors[0].new_full((batch_size, max_len), pad_val, requires_grad=tensors[0].requires_grad) 53 | for i, x in enumerate(tensors): 54 | copy_tensor(x, padded_tensor[i, max_len - len(x):] if left_pad else padded_tensor[i, :len(x)]) 55 | return padded_tensor 56 | 57 | 58 | def to_cuda(obj): 59 | if torch.is_tensor(obj): 60 | return obj.cuda() 61 | elif isinstance(obj, dict): 62 | return {k: to_cuda(v) for k, v in obj.items()} 63 | elif isinstance(obj, list): 64 | return [to_cuda(x) for x in obj] 65 | else: 66 | return obj 67 | 68 | 69 | def to_device(obj, device): 70 | if torch.is_tensor(obj): 71 | return obj.to(device) 72 | elif isinstance(obj, dict): 73 | return {k: to_device(v, device) for k, v in obj.items()} 74 | elif isinstance(obj, list): 75 | return [to_device(x, device) for x in obj] 76 | elif isinstance(obj, tuple): 77 | return tuple(to_device(x, device) for x in obj) 78 | else: 79 | return obj 80 | -------------------------------------------------------------------------------- /wiki_world.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from html import unescape 3 | import logging 4 | 5 | import faiss 6 | 7 | from retriever import SparseRetriever, DenseRetriever 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class WikiWorld(object): 13 | def __init__(self, corpus, title2id, sparse_retriever, dense_retriever, bm25_redis, mdr_redis, 14 | for_hotpot=True, strict=False, max_ret_size=500): 15 | """ 16 | 17 | Args: 18 | corpus (dict): 19 | title2id (dict): 20 | sparse_retriever (SparseRetriever): 21 | dense_retriever (DenseRetriever): 22 | bm25_redis (redis.Redis): 23 | mdr_redis (redis.Redis): 24 | for_hotpot (bool): 25 | strict (bool): 26 | """ 27 | self._corpus = corpus 28 | self.title2id = title2id 29 | 30 | self.query_filter = {"term": {"for_hotpot": True}} if for_hotpot else None 31 | self.sparse_retriever = sparse_retriever 32 | # redis.Redis(host='10.60.1.79', port=6379, db=0, password='redis4zyc', decode_responses=True) 33 | self.bm25_redis = bm25_redis 34 | self.bm25_offset = defaultdict(int) 35 | 36 | self.max_q_len = 70 37 | self.max_q_sp_len = 350 38 | self.dense_retriever = dense_retriever 39 | # redis.Redis(host='10.60.1.79', port=6379, db=1, password='redis4zyc', decode_responses=True) 40 | self.mdr_redis = mdr_redis 41 | self.mdr_offset = defaultdict(int) 42 | 43 | self.strict = strict 44 | 
self.max_ret_size = max_ret_size 45 | 46 | def reset(self): 47 | self.bm25_offset.clear() 48 | self.mdr_offset.clear() 49 | 50 | def get(self, p_id): 51 | if p_id not in self._corpus: 52 | return None 53 | para = {"para_id": p_id} 54 | para.update(self._corpus[p_id]) 55 | para['refs'] = {para['text'][span[0]:span[1]]: (tgt_title, span) 56 | for tgt_title, anchors in para['hyperlinks'].items() for span in anchors} 57 | return para 58 | 59 | def link(self, tgt_title, q_id=None, excluded=None): 60 | if excluded is not None: 61 | excluded = set(excluded) 62 | 63 | if tgt_title not in self.title2id: 64 | logger.warning(f"{q_id}: invalid link 『{tgt_title}』") 65 | return None 66 | 67 | tgt_id = self.title2id[tgt_title] 68 | if excluded is not None and tgt_id in excluded: 69 | logger.warning(f"{q_id}: link target 『{tgt_title}』 should be excluded") 70 | 71 | return tgt_id 72 | 73 | def bm25(self, query, q_id=None, excluded=None): 74 | if excluded is not None: 75 | excluded = set(excluded) 76 | session_id = (q_id, query) 77 | 78 | if self.bm25_redis.exists(query) and (self.bm25_redis.llen(query) >= self.max_ret_size or 79 | self.bm25_redis.lindex(query, -1) == 'EOL'): 80 | hits = self.bm25_redis.lrange(query, 0, -1) 81 | else: 82 | hits = [hit['_id'] for hit in self.sparse_retriever.search(query, self.max_ret_size, 83 | filter_dic=self.query_filter, 84 | n_retrieval=self.max_ret_size * 2)] 85 | if len(hits) < self.max_ret_size: 86 | hits.append('EOL') 87 | self.bm25_redis.delete(query) 88 | self.bm25_redis.rpush(query, *hits) 89 | 90 | if self.bm25_offset[session_id] >= len(hits): 91 | return None 92 | for hit_id in hits[self.bm25_offset[session_id]:]: 93 | if hit_id == 'EOL': # don't increase offset if reach the end of retrieval list 94 | return None 95 | self.bm25_offset[session_id] += 1 96 | if excluded is None or hit_id not in excluded: 97 | return hit_id 98 | return None 99 | 100 | def mdr(self, question, expansion=None, q_id=None, excluded=None): 101 | if question.endswith('?'): 102 | question = question[:-1] 103 | if excluded is not None: 104 | excluded = set(excluded) 105 | if expansion is None: 106 | key = question 107 | query = question 108 | else: 109 | sp = self.get(expansion) 110 | key = f"{question}\t+++\t{unescape(sp['title'])}" 111 | expansion_text = sp['text'] 112 | if self.strict: 113 | expansion_text = expansion_text[sp['sentence_spans'][0][0]:sp['sentence_spans'][-1][1]] 114 | query = (question, expansion_text if expansion_text else sp['title']) 115 | session_id = (q_id, key) 116 | 117 | if self.mdr_redis.exists(key) and (self.mdr_redis.llen(key) >= self.max_ret_size or 118 | self.mdr_redis.lindex(key, -1) == 'EOL'): 119 | hits = self.mdr_redis.lrange(key, 0, -1) 120 | else: 121 | faiss.omp_set_num_threads(1) 122 | hits = self.dense_retriever.search(query, max(self.max_ret_size, 1000), 123 | self.max_q_len if expansion is None else self.max_q_sp_len)[0] 124 | if len(hits) < max(self.max_ret_size, 1000): 125 | hits.append('EOL') 126 | self.mdr_redis.delete(key) 127 | self.mdr_redis.rpush(key, *hits) 128 | 129 | if self.mdr_offset[session_id] >= len(hits): 130 | return None 131 | for hit_id in hits[self.mdr_offset[session_id]:]: 132 | if hit_id == 'EOL': # don't increase offset if reach the end of retrieval list 133 | return None 134 | self.mdr_offset[session_id] += 1 135 | if excluded is None or hit_id not in excluded: 136 | return hit_id 137 | return None 138 | 139 | def execute(self, command, q_id=None, excluded=None): 140 | func_name = command[0] 141 | if func_name == 'BM25': 
142 | query = command[1] 143 | return self.bm25(query, q_id=q_id, excluded=excluded) 144 | elif func_name == 'MDR': 145 | question, expansion = command[1] 146 | return self.mdr(question, expansion, q_id=q_id, excluded=excluded) 147 | elif func_name == 'LINK': 148 | tgt_title = command[1] 149 | return self.link(tgt_title, q_id=q_id, excluded=excluded) 150 | else: 151 | logger.warning(f'unresolved func: {func_name} in WikiWorld') 152 | return None 153 | --------------------------------------------------------------------------------
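
The `execute` dispatcher in `wiki_world.py` accepts a command tuple whose first element names the action (`BM25`, `MDR`, or `LINK`) and whose second element carries that action's argument, as read from the code above. The sketch below only illustrates this command interface; it assumes a `WikiWorld` instance (`world`) has already been constructed with the corpus, retrievers, and Redis caches loaded as in the notebooks, and the question, passage ids, and page title are placeholders.

```python
# Illustrative sketch (not part of the repository): `world` is an already-built WikiWorld,
# `question` is an arbitrary HotpotQA-style question. execute() returns a passage id or None.
question = "Which magazine was started first, Arthur's Magazine or First for Women?"
collected = []  # passage ids observed so far, passed as `excluded` to avoid repeats

# Sparse step: command[1] is a BM25 query string (possibly a reformulated query).
p1 = world.execute(('BM25', question), q_id='example', excluded=collected)
if p1:
    collected.append(p1)

# Dense step: command[1] is (question, expansion passage id or None) for MDR retrieval.
p2 = world.execute(('MDR', (question, p1)), q_id='example', excluded=collected)
if p2:
    collected.append(p2)

# Link step: command[1] is the target page title of a hyperlink in an observed passage.
p3 = world.execute(('LINK', "Arthur's Magazine"), q_id='example', excluded=collected)
```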