├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── demo.py ├── eval_main.py ├── eval_odqa.py ├── eval_utils.py ├── graph_retriever ├── README.md ├── __init__.py ├── modeling_graph_retriever.py ├── run_graph_retriever.py └── utils.py ├── img └── odqa_overview-1.png ├── pipeline ├── __init__.py ├── graph_retriever.py ├── reader.py ├── sequential_sentence_selector.py └── tfidf_retriever.py ├── quick_start_hotpot.sh ├── reader ├── README.md ├── __init__.py ├── modeling_reader.py ├── modeling_utils.py ├── rc_utils.py └── run_reader_confidence.py ├── requirements.txt ├── retriever ├── README.md ├── __init__.py ├── build_db.py ├── build_tfidf.py ├── doc_db.py ├── interactive.py ├── tfidf_doc_ranker.py ├── tfidf_vectorizer_article.py ├── tokenizers.py └── utils.py └── sequential_sentence_selector ├── README.md ├── modeling_sequential_sentence_selector.py ├── run_sequential_sentence_selector.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.zip 3 | *.pyc 4 | *.json 5 | *.tsv 6 | *.csv 7 | *.db 8 | *.bin 9 | *.npz 10 | __pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Akari Asai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Retrieve Reasoning Paths over Wikipedia Graph for Question Answering 2 |

3 | 4 | This is the official implementation of the following paper: 5 | Akari Asai, Kazuma Hashimoto, Hannaneh Hajishirzi, Richard Socher, Caiming Xiong. [Learning to Retrieve Reasoning Paths over Wikipedia Graph for Question Answering](https://arxiv.org/abs/1911.10470). In: Proceedings of ICLR. 2020 6 | 7 | In the paper, we introduce a graph-based retriever-reader framework that learns to retrieve reasoning paths (a reasoning path = a chain of multiple paragraphs to answer multi-hop questions) from English Wikipedia using its graphical structure, and further verify and extract answers from the selected reasoning paths. Our experimental results show state-of-the-art results across three diverse open-domain QA datasets: [HotpotQA (full wiki)](https://hotpotqa.github.io/), [Natural Questions](https://ai.google.com/research/NaturalQuestions/) Open, [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) Open. 8 | 9 | *Acknowledgements*: To implement our BERT-based modules, we used the [huggingface's transformers](https://huggingface.co/transformers/) library. The implementation of TF-IDF based document ranker and splitter started from the [DrQA](https://github.com/facebookresearch/DrQA) and [document-qa](https://github.com/allenai/document-qa) repositories. Huge thanks to the contributors of those amazing repositories! 10 | 11 | 12 | ## Quick Links 13 | 0. [Quick Run on HotpotQA](#0-quick-run-on-hotpotqa) 14 | 1. [Installation](#1-installation) 15 | 2. [Train](#2-train) 16 | 3. [Evaluation](#3-evaluation) 17 | 4. [Interactive Demo](#4-interactive-demo) 18 | 5. [Others](#5-others) 19 | 6. [Citation and Contact](#citation-and-contact) 20 | 21 | ## 0. Quick Run on HotpotQA 22 | We provide [quick_start_hotpot.sh](quick_start_hotpot.sh), with which you can easily set up and run evaluation on HotpotQA full wiki (on the first 100 questions). 23 | 24 | The script will 25 | 1. download our trained models and evaluation data (See [Installation](#1-installation) for the details), 26 | 2. run the whole pipeline on the evaluation data (See [Evaluation](#3-evaluation)), and 27 | 3. calculate the QA scores and supporting facts scores. 28 | 29 | The evaluation will give us the following results: 30 | ``` 31 | {'em': 0.6, 'f1': 0.7468968253968253, 'prec': 0.754030303030303, 'recall': 0.7651666666666667, 'sp_em': 0.49, 'sp_f1': 0.7769365079365077, 'sp_prec': 0.8275, 'sp_recall': 0.7488333333333332, 'joint_em': 0.33, 'joint_f1': 0.6249458756180065, 'joint_prec': 0.6706212121212122, 'joint_recall': 0.6154999999999999} 32 | ``` 33 | 34 | Wanna try your own open-domain question? See [Interactive Demo](#4-interactive-demo)! Once you run the [quick_start_hotpot.sh](quick_start_hotpot.sh), you can easily switch to the demo mode by changing some options in the command. 35 | 36 | ## 1. Installation 37 | ### Requirements 38 | 39 | Our framework requires Python 3.5 or higher. We do not support Python 2.X. 40 | 41 | It also requires installing [pytorch-pretrained-bert version (version 0.6.2)](https://github.com/huggingface/transformers/tree/v0.6.2) and [PyTorch](http://pytorch.org/) version 1.0 or higher. The other dependencies are listed in [requirements.txt](requirements.txt). 42 | We are planning to [migrate from pytorch-pretrained-bert](https://huggingface.co/transformers/migration.html) to transformers soon. 
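If you want to double-check your environment before moving on, a quick sanity check along the following lines (a minimal sketch, not part of the repository) should run without errors once the dependencies are installed:

```py
# Minimal environment sanity check (illustrative only; assumes the pinned
# pytorch-pretrained-bert 0.6.2 and PyTorch >= 1.0 are installed).
import sys
import torch
from pytorch_pretrained_bert import BertTokenizer

print("Python:", sys.version.split()[0])   # should be 3.5 or higher
print("PyTorch:", torch.__version__)       # should be 1.0 or higher
print("CUDA available:", torch.cuda.is_available())

# Smoke-test the BERT tokenizer used throughout the pipeline.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
print(tokenizer.tokenize("Who won the Super Bowl in 2015?"))
```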
43 | 44 | ### Set up 45 | Run the following commands to clone the repository and install our framework: 46 | 47 | ```bash 48 | git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git 49 | cd learning_to_retrieve_reasoning_paths 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | ### Downloading trained models 54 | All the trained models used in our paper for the three datasets are available in google drive: 55 | - HotpotQA full wiki: [hotpot_models.zip](https://drive.google.com/open?id=1ra37xtEXSROG_f90XxR4kgElGJWUHQyM) 56 | - Natural Questions Open: [nq_models.zip](https://drive.google.com/open?id=120JNI49nK-W014cjneuXJuUC09To3KeQ) 57 | - SQuAD Open: [squad_models.zip](https://drive.google.com/open?id=1_z54KceuYXnA0fJvdASDubti9Zb8aAR4): for SQuAD Open, please download `.db` and `.npz` files following [DrQA](https://github.com/facebookresearch/DrQA/blob/master/download.sh) repository. 58 | 59 | Alternatively, you can download a zip file containing all models by using [gdown](https://pypi.org/project/gdown/). 60 | 61 | e.g., download HotpotQA models 62 | ```bash 63 | mkdir models 64 | cd models 65 | gdown https://drive.google.com/uc?id=1ra37xtEXSROG_f90XxR4kgElGJWUHQyM 66 | unzip hotpot_models.zip 67 | rm hotpot_models.zip 68 | cd .. 69 | ``` 70 | **Note: the size of the zip file is about 4GB for HotpotQA models, and once it is extracted, the total size of the models is more than 8GB (including the introductory paragraph only Wikipedia database). The `nq_models.zip` include full Wikipedia database, which is around 30GB once extracted.** 71 | 72 | 73 | ### Downloading data 74 | #### for training 75 | - You can download all of the training datasets from [here (google drive)](https://drive.google.com/drive/folders/1nYQOtoxJiiL5XK6PHOeluTWgh3PTOCVD?usp=sharing). 76 | - We create (1) data to train graph-based retriever, and (2) data to train reader by augmenting the publicly available machine reading comprehension datasets (HotpotQA, SQuAD and Natural Questions). 77 | See the details of the process in Section 3.1.2 and Section 3.2 in [our paper](https://arxiv.org/pdf/1911.10470.pdf). 78 | 79 | #### for evaluation 80 | - Following previous work such as [DrQA](https://github.com/shmsw25/qa-hard-em) or [qa-hard-em](https://github.com/shmsw25/qa-hard-em), we convert the original machine reading comprehension datasets to sets of question and answer pairs. You can download our preprocessed data from [here](https://drive.google.com/open?id=1na7vxYWadK2kS2aqg88RDMFir8PP-2lS). 81 | 82 | - For HotpotQA, we only use question-answer pairs as input, but we need to use the original HotpotQA development set (either fullwiki or distractor) to evaluate supporting fact evaluations from [HotpotQA's website](https://hotpotqa.github.io/). 83 | 84 | ```bash 85 | mkdir data 86 | cd data 87 | mkdir hotpot 88 | cd hotpot 89 | gdown https://drive.google.com/uc?id=1m_7ZJtWQsZ8qDqtItDTWYlsEHDeVHbPt # download preprocessed full wiki data 90 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json # download the original full wiki data for sp evaluation. 91 | cd ../.. 92 | ``` 93 | 94 | ## 2. Train 95 | In this work, we use a two-stage training approach, which lets you train the reader and retriever independently and easily switch to new reader models. 96 | The details of the training process can be seen in the README files in [graph_retriever](graph_retriever), [reader](reader) and [sequence_sentence_selector](sequential_sentence_selector). 
97 | 98 | You can download our pre-trained models from the link mentioned above. 99 | 100 | ## 3. Evaluation 101 | After downloading a TF-IDF retriever, training a graph-retriever and reader models, you can test the performance of our entire system. 102 | 103 | 104 | #### HotpotQA 105 | If you set up using `quick_start_hotpot.sh`, you can run full evaluation by setting the `--eval_file_path` option to `data/hotpot/hotpot_fullwiki_first_100.jsonl` . 106 | 107 | ```bash 108 | python eval_main.py \ 109 | --eval_file_path data/hotpot/hotpot_fullwiki_data.jsonl \ 110 | --eval_file_path_sp data/hotpot/hotpot_dev_distractor_v1.json \ 111 | --graph_retriever_path models/hotpot_models/graph_retriever_path/pytorch_model.bin \ 112 | --reader_path models/hotpot_models/reader \ 113 | --sequential_sentence_selector_path models/hotpot_models/sequential_sentence_selector/pytorch_model.bin \ 114 | --tfidf_path models/hotpot_models/tfidf_retriever/wiki_open_full_new_db_intro_only-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \ 115 | --db_path models/hotpot_models/wiki_db/wiki_abst_only_hotpotqa_w_original_title.db \ 116 | --bert_model_sequential_sentence_selector bert-large-uncased --do_lower_case \ 117 | --tfidf_limit 500 --eval_batch_size 4 --pruning_by_links --beam_graph_retriever 8 \ 118 | --beam_sequential_sentence_selector 8 --max_para_num 2000 --sp_eval 119 | ``` 120 | 121 | The evaluation will give us the following results (equivalent to our reported results): 122 | ``` 123 | {'em': 0.6049966239027684, 'f1': 0.7330873757783022, 'prec': 0.7613180885780131, 'recall': 0.7421444532461545, 'sp_em': 0.49169480081026334, 'sp_f1': 0.7605390258327606, 'sp_prec': 0.8103758721584524, 'sp_recall': 0.7325846435805953, 'joint_em': 0.35827143821742063, 'joint_f1': 0.6143774960171196, 'joint_prec': 0.679462464277477, 'joint_recall': 0.5987834193329556} 124 | ``` 125 | 126 | #### SQuAD Open 127 | 128 | ```bash 129 | python eval_main.py \ 130 | --eval_file_path data/squad/squad_open_domain_data.jsonl \ 131 | --graph_retriever_path models/squad_models/selector/pytorch_model.bin \ 132 | --reader_path models/squad_models/reader \ 133 | --tfidf_path DrQA/data/wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \ 134 | --db_path DrQA/data/wikipedia/docs.db \ 135 | --bert_model bert-base-uncased --do_lower_case \ 136 | --tfidf_limit 50 --eval_batch_size 4 \ 137 | --beam_graph_retriever 8 --max_para_num 2000 --use_full_article 138 | ``` 139 | 140 | #### Natural Questions 141 | 142 | ``` 143 | python eval_main.py \ 144 | --eval_file_path data/nq_open_domain_data.jsonl \ 145 | --graph_retriever_path models/nq/selector/pytorch_model.bin --reader_path models/nq/reader/ \ 146 | --tfidf_path models/nq_models/tfidf_retriever/wiki_20181220_nq_hyper_linked-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \ 147 | --db_path models/nq_models/wiki_db/wiki_20181220_nq_hyper_linked.db \ 148 | --bert_model bert-base-uncased --do_lower_case --tfidf_limit 20 --eval_batch_size 4 --pruning_by_links \ 149 | --beam_graph_retriever 8 --max_para_num 2000 --use_full_article 150 | ``` 151 | 152 | #### (optional) Using TagMe for initial retrieval 153 | As mentioned in Appendix B.7 in our paper, you can optionally use an entity linking system ([TagMe](https://sobigdata.d4science.org/web/tagme/tagme-help)) for the initial retrieval. 154 | 155 | To uee TagMe, 156 | 1. [register](https://services.d4science.org/group/d4science-services-gateway/explore?siteId=22994) to get API key, and 157 | 3. 
set the API key via `--tagme_api_key` option, and set `--tagme` option true. 158 | 159 | ``` 160 | python eval_main.py \ 161 | --eval_file_path data/nq_open_domain_data.jsonl \ 162 | --graph_retriever_path models/nq/selector/pytorch_model.bin --reader_path models/nq/reader/ \ 163 | --tfidf_path models/nq_models/tfidf_retriever/wiki_20181220_nq_hyper_linked-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \ 164 | --db_path models/nq_models/wiki_db/wiki_20181220_nq_hyper_linked.db \ 165 | --bert_model bert-base-uncased --do_lower_case --tfidf_limit 20 --eval_batch_size 4 --pruning_by_links --beam 8 --max_para_num 2000 --use_full_article --tagme --tagme_api_key YOUR_API_KEY 166 | ``` 167 | 168 | *The implementation of the two-step TF-IDF retrieval module (article retrieval --> paragraph-level re-ranking) for Natural Questions is currently in progress, which might give slightly lower scores than the reported results in our paper. We'll fix the issue soon.* 169 | 170 | ## 4. Interactive demo 171 | You could run interactive demo and ask open-domain questions. Our model answers the question with supporting facts. 172 | 173 | If you set up using `quick_start.sh` script, you can run full evaluation by changing the script name to from `eval_main.py` to `demo.py`, and removing `--eval_file_path` and `--eval_file_path_sp` options. 174 | 175 | e.g., 176 | ```bash 177 | python demo.py \ 178 | --graph_retriever_path models/hotpot_models/graph_retriever_path/pytorch_model.bin \ 179 | --reader_path models/hotpot_models/reader \ 180 | --sequential_sentence_selector_path models/hotpot_models/sequential_sentence_selector/pytorch_model.bin \ 181 | --tfidf_path models/hotpot_models/tfidf_retriever/wiki_open_full_new_db_intro_only-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \ 182 | --db_path models/hotpot_models/wiki_db/wiki_abst_only_hotpotqa_w_original_title.db \ 183 | --do_lower_case --beam 4 --quiet --max_para_num 200 \ 184 | --tfidf_limit 20 --pruning_by_links \ 185 | ``` 186 | 187 | An output example is as follows: 188 | ``` 189 | #### Reader results #### 190 | [ { 191 | "q_id": "DEMO_0", 192 | "question": "Bordan Tkachuk was the CEO of a company that provides what sort of products?", 193 | "answer": "IT products and services", 194 | "context": [ 195 | "Cintas_0", 196 | "Bordan Tkachuk_0", 197 | "Viglen_0" 198 | ] 199 | } 200 | ] 201 | 202 | #### Supporting facts #### 203 | [ 204 | { 205 | "q_id": "DEMO_0", 206 | "supporting facts": { 207 | "Viglen_0": [ 208 | [0, "Viglen Ltd provides IT products and services, including storage systems, servers, workstations and data/voice communications equipment and services." 209 | ] 210 | ], 211 | "Bordan Tkachuk_0": [ 212 | [0, "Bordan Tkachuk ( ) is a British business executive, the former CEO of Viglen, also known from his appearances on the BBC-produced British version of \"The Apprentice,\" interviewing for his boss Lord Sugar." 213 | ] 214 | ] 215 | } 216 | } 217 | ] 218 | 219 | ``` 220 | 221 | 222 | ## 5. Others 223 | ### Distant supervision & negative examples data generation 224 | In this work, we augment the original MRC data with negative and distant supervision examples to make our retriever and reader robust to inference time noise. Our experimental results show these training strategy gives significant performance improvements. 225 | 226 | All of the training data is available [here (google drive)](https://drive.google.com/drive/u/1/folders/1nYQOtoxJiiL5XK6PHOeluTWgh3PTOCVD). 
227 | 228 | *We are planning to release our codes to augment training data with negative examples and distant examples to guide future research in open-domain QA fields. Please stay tuned!* 229 | 230 | ### Dataset format 231 | For quick experiments and detailed human analysis, we save intermediate results for each step: original Q-A pair (format A), TF-IDF retrieval (format B), our graph-based (format C) retriever. 232 | 233 | #### Format A (eval data, the input of TF-IDF retriever) 234 | For the evaluation pipeline, our initial input is a simple `jsonlines` format where each line contains one example with `id = [str]`, `question = [str]` and `answer = List[str]` (or `answers = List[str]` for datasets where multiple answers are annotated for each question) information. 235 | 236 | For SQuAD Open and HotpotQA fullwiki, you can download the preprocessed format A files from [here](https://drive.google.com/file/d/1f3YtvImDxB9h6GuVGFelzxvgwEcqxuHn/view?usp=sharing). 237 | 238 | e.g., HotpotQA fullwiki dev 239 | ``` 240 | { 241 | "id": "5ab3b0bf5542992ade7c6e39", 242 | "question": "What year did Guns N Roses perform a promo for a movie starring Arnold Schwarzenegger 243 | as a former New York Police detective?", 244 | "answer": ["1999"] 245 | } 246 | ``` 247 | 248 | e.g., SQuAD Open dev 249 | ```py 250 | { 251 | "id": "56beace93aeaaa14008c91e0", 252 | "question": "What venue did Super Bowl 50 take place in?", 253 | "answers": ["Levi's Stadium", "Levi's Stadium", 254 | "Levi's Stadium in the San Francisco Bay Area at Santa Clara"] 255 | } 256 | ``` 257 | 258 | #### Format B (TF-IDF retriever output) 259 | For TF-IDF results, we store the data as a list of `JSON`, and each data point contains several information. 260 | 261 | - `q_id = [str]` 262 | - `question = [str]` 263 | - `answer = List[str]` 264 | - `context = Dict[str, str]`: Top $N$ paragraphs which are ranked high by our TF-IDF retriever. 265 | - `all_linked_para_title_dic = Dict[str, List[str]]`: Hyper-linked paragraphs' titles from paragraphs in `context`. 266 | - `all_linked_paras_dic = Dict[str, str]`: the paragraphs of the hyper-linked paragraphs. 267 | 268 | For training data, we have additional items that are used as ground-truth reasoning paths. 269 | - `short_gold = List[str]` 270 | - `redundant_gold = List[str]` 271 | - `all_redundant_gold = List[List[str]]` 272 | 273 | e.g., HotpotQA fullwiki dev 274 | 275 | ```py 276 | { 277 | "question": 'Were Scott Derrickson and Ed Wood of the same nationality?'. 278 | "q_id": "5ab3b0bf5542992ade7c6e39", 279 | "context": 280 | {"Scott Derrickson_0": "Scott Derrickson (born July 16, 1966) is an American director,....", 281 | "Ed Wood'_0": "...", ....}, 282 | 'all_linked_para_title_dic': 283 | {"Scott Derrickson_0": ['Los Angeles_0', 'California_0', 'Horror film_0', ...]}, 284 | 'all_linked_paras_dic': 285 | {"Los Angeles_0": "Los Angeles, officially the City of Los Angeles and often known by its initials L.A., is ...", ...}, 286 | 'short_gold':[], 287 | 'redundant_gold': [], 288 | 'all_redundant_gold': [] 289 | } 290 | ``` 291 | 292 | #### Format C (Graph-based retriever output) 293 | The graph-based retriever's output is a list of `JSON` objects as follows: 294 | 295 | - `q_id = [str]` 296 | - `titles = [str]`: a sequence of titles (the top one reasoning path) 297 | - `topk_titles = List[List[str]]`: k sequences of titles (the top k reasoning paths). 298 | - `context = Dict[str, str]`: the paragraphs which are included in top reasoning paths. 
299 | 300 | ```py 301 | { 302 | "q_id": "5a713ea95542994082a3e6e4", 303 | "titles": ["Alvaro Mexia_0", "Boruca_0"], 304 | "topk_titles": [ 305 | ["Alvaro Mexia_0", "Boruca_0"], 306 | ["Alvaro Mexia_0", "Indigenous peoples of Florida_0"], 307 | ["Alvaro Mexia_0"], 308 | ["List of Ambassadors of Spain to the United States_0", "Boruca_0"], 309 | ["Alvaro Mexia_0", "St. Augustine, Florida_0"], 310 | ["Alvaro Mexia_0", "Cape Canaveral, Florida_0"], 311 | ["Alvaro Mexia_0", "Florida_0"], 312 | ["Parque de la Bombilla (Mexico City)_0", "Alvaro Mexia_0", "Boruca_0"]], 313 | "context": { 314 | "Alvaro Mexia_0": "Alvaro Mexia was a 17th-century Spanish explorer and cartographer of the east coast of Florida....", "Boruca_0": "The Boruca (also known as the Brunca or the Brunka) are an indigenous people living in Costa Rica"} 315 | } 316 | ``` 317 | 318 | 319 | ## Citation and Contact 320 | If you find this codebase is useful or use in your work, please cite our paper. 321 | ``` 322 | @inproceedings{ 323 | asai2020learning, 324 | title={Learning to Retrieve Reasoning Paths over Wikipedia Graph for Question Answering}, 325 | author={Akari Asai and Kazuma Hashimoto and Hannaneh Hajishirzi and Richard Socher and Caiming Xiong}, 326 | booktitle={International Conference on Learning Representations}, 327 | year={2020} 328 | } 329 | ``` 330 | 331 | Please contact Akari Asai ([@AkariAsai](https://twitter.com/AkariAsai?s=20), akari[at]cs.washington.edu) for questions and suggestions. 332 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/__init__.py -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | 6 | from pipeline.tfidf_retriever import TfidfRetriever 7 | from pipeline.graph_retriever import GraphRetriever 8 | from pipeline.reader import Reader 9 | from pipeline.sequential_sentence_selector import SequentialSentenceSelector 10 | 11 | import logging 12 | class DisableLogger(): 13 | def __enter__(self): 14 | logging.disable(logging.CRITICAL) 15 | def __exit__(self, a, b, c): 16 | logging.disable(logging.NOTSET) 17 | 18 | class ODQA: 19 | def __init__(self, args): 20 | 21 | self.args = args 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() and not self.args.no_cuda else "cpu") 24 | 25 | # TF-IDF Retriever 26 | self.tfidf_retriever = TfidfRetriever(self.args.db_path, self.args.tfidf_path) 27 | 28 | # Graph Retriever 29 | self.graph_retriever = GraphRetriever(self.args, device) 30 | 31 | # Reader 32 | self.reader = Reader(self.args, device) 33 | 34 | # Supporting facts selector 35 | self.sequential_sentence_selector = SequentialSentenceSelector(self.args, device) 36 | 37 | def predict(self, 38 | questions: list): 39 | 40 | print('-- Retrieving paragraphs by TF-IDF...', flush=True) 41 | tfidf_retrieval_output = [] 42 | for i in range(len(questions)): 43 | question = questions[i] 44 | tfidf_retrieval_output += self.tfidf_retriever.get_abstract_tfidf('DEMO_{}'.format(i), question, self.args) 45 | 46 | print('-- Running the graph-based recurrent retriever model...', flush=True) 47 | graph_retrieval_output = self.graph_retriever.predict(tfidf_retrieval_output, 
self.tfidf_retriever, self.args) 48 | 49 | print('-- Running the reader model...', flush=True) 50 | answer, title = self.reader.predict(graph_retrieval_output, self.args) 51 | 52 | reader_output = [{'q_id': s['q_id'], 53 | 'question': s['question'], 54 | 'answer': answer[s['q_id']], 55 | 'context': title[s['q_id']]} for s in graph_retrieval_output] 56 | 57 | if self.args.sequential_sentence_selector_path is not None: 58 | print('-- Running the supporting facts retriever...', flush=True) 59 | supporting_facts = self.sequential_sentence_selector.predict(reader_output, self.tfidf_retriever, self.args) 60 | else: 61 | supporting_facts = [] 62 | 63 | return tfidf_retrieval_output, graph_retrieval_output, reader_output, supporting_facts 64 | 65 | 66 | def main(): 67 | 68 | parser = argparse.ArgumentParser() 69 | 70 | ## Required parameters 71 | parser.add_argument("--graph_retriever_path", 72 | default=None, 73 | type=str, 74 | required=True, 75 | help="Graph retriever model path.") 76 | parser.add_argument("--reader_path", 77 | default=None, 78 | type=str, 79 | required=True, 80 | help="Reader model path.") 81 | parser.add_argument("--tfidf_path", 82 | default=None, 83 | type=str, 84 | required=True, 85 | help="TF-IDF path.") 86 | parser.add_argument("--db_path", 87 | default=None, 88 | type=str, 89 | required=True, 90 | help="DB path.") 91 | 92 | ## Other parameters 93 | parser.add_argument("--sequential_sentence_selector_path", 94 | default=None, 95 | type=str, 96 | help="Supporting facts model path.") 97 | parser.add_argument("--max_sent_num", 98 | default=30, 99 | type=int) 100 | parser.add_argument("--max_sf_num", 101 | default=15, 102 | type=int) 103 | 104 | 105 | parser.add_argument("--bert_model_graph_retriever", default='bert-base-uncased', type=str, 106 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 107 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 108 | "bert-base-multilingual-cased, bert-base-chinese.") 109 | 110 | parser.add_argument("--bert_model_sequential_sentence_selector", default='bert-large-uncased', type=str, 111 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 112 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 113 | "bert-base-multilingual-cased, bert-base-chinese.") 114 | 115 | parser.add_argument("--max_seq_length", 116 | default=378, 117 | type=int, 118 | help="The maximum total input sequence length after WordPiece tokenization. \n" 119 | "Sequences longer than this will be truncated, and sequences shorter \n" 120 | "than this will be padded.") 121 | 122 | parser.add_argument("--max_seq_length_sequential_sentence_selector", 123 | default=256, 124 | type=int, 125 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 126 | "Sequences longer than this will be truncated, and sequences shorter \n" 127 | "than this will be padded.") 128 | 129 | parser.add_argument("--do_lower_case", 130 | action='store_true', 131 | help="Set this flag if you are using an uncased model.") 132 | parser.add_argument("--no_cuda", 133 | action='store_true', 134 | help="Whether not to use CUDA when available") 135 | 136 | # RNN graph retriever-specific parameters 137 | parser.add_argument("--max_para_num", 138 | default=10, 139 | type=int) 140 | 141 | parser.add_argument('--eval_batch_size', 142 | type=int, 143 | default=5, 144 | help="Eval batch size") 145 | 146 | parser.add_argument('--beam_graph_retriever', 147 | type=int, 148 | default=1, 149 | help="Beam size for Graph Retriever") 150 | parser.add_argument('--beam_sequential_sentence_selector', 151 | type=int, 152 | default=1, 153 | help="Beam size for Sequential Sentence Selector") 154 | 155 | parser.add_argument('--min_select_num', 156 | type=int, 157 | default=1, 158 | help="Minimum number of selected paragraphs") 159 | parser.add_argument('--max_select_num', 160 | type=int, 161 | default=3, 162 | help="Maximum number of selected paragraphs") 163 | parser.add_argument("--no_links", 164 | action='store_true', 165 | help="Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)") 166 | parser.add_argument("--pruning_by_links", 167 | action='store_true', 168 | help="Whether to do pruning by links (and top 1)") 169 | parser.add_argument("--expand_links", 170 | action='store_true', 171 | help="Whether to expand links with paragraphs in the same article (for NQ)") 172 | parser.add_argument('--tfidf_limit', 173 | type=int, 174 | default=None, 175 | help="Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)") 176 | 177 | parser.add_argument("--split_chunk", default=100, type=int, 178 | help="Chunk size for BERT encoding at inference time") 179 | parser.add_argument("--eval_chunk", default=500, type=int, 180 | help="Chunk size for inference of graph_retriever") 181 | 182 | parser.add_argument("--tagme", 183 | action='store_true', 184 | help="Whether to use tagme at inference") 185 | parser.add_argument('--topk', 186 | type=int, 187 | default=2, 188 | help="Whether to use how many paragraphs from the previous steps") 189 | 190 | parser.add_argument("--n_best_size", default=5, type=int, 191 | help="The total number of n-best predictions to generate in the nbest_predictions.json " 192 | "output file.") 193 | parser.add_argument("--max_answer_length", default=30, type=int, 194 | help="The maximum length of an answer that can be generated. This is needed because the start " 195 | "and end predictions are not conditioned on one another.") 196 | parser.add_argument("--max_query_length", default=64, type=int, 197 | help="The maximum number of tokens for the question. 
Questions longer than this will " 198 | "be truncated to this length.") 199 | parser.add_argument("--doc_stride", default=128, type=int, 200 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 201 | 202 | 203 | odqa = ODQA(parser.parse_args()) 204 | 205 | print() 206 | while True: 207 | questions = input('Questions: ') 208 | questions = questions.strip() 209 | if questions == 'q': 210 | break 211 | elif questions == '': 212 | continue 213 | 214 | questions = questions.strip().split('|||') 215 | tfidf_retrieval_output, graph_retriever_output, reader_output, supporting_facts = odqa.predict(questions) 216 | 217 | if graph_retriever_output is None: 218 | print() 219 | print('Invalid question! "{}"'.format(question)) 220 | print() 221 | continue 222 | 223 | print() 224 | print('#### Retrieval results ####') 225 | print(json.dumps(graph_retriever_output, indent=4)) 226 | print() 227 | 228 | print('#### Reader results ####') 229 | print(json.dumps(reader_output, indent=4)) 230 | print() 231 | 232 | if len(supporting_facts) > 0: 233 | print('#### Supporting facts ####') 234 | print(json.dumps(supporting_facts, indent=4)) 235 | print() 236 | 237 | 238 | if __name__ == "__main__": 239 | with DisableLogger(): 240 | main() 241 | -------------------------------------------------------------------------------- /eval_main.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from eval_odqa import ODQAEval 4 | from eval_utils import evaluate, evaluate_w_sp_facts, convert_qa_sp_results_into_hp_eval_format 5 | 6 | 7 | def main(): 8 | 9 | odqa = ODQAEval() 10 | 11 | if odqa.args.sequential_sentence_selector_path is not None: 12 | tfidf_retrieval_output, selector_output, reader_output, sp_selector_output = odqa.eval() 13 | if odqa.args.sp_eval is True: 14 | # eval the performance; based on F1 & EM. 15 | predictions = convert_qa_sp_results_into_hp_eval_format( 16 | reader_output, sp_selector_output, odqa.args.db_path) 17 | results = evaluate_w_sp_facts( 18 | odqa.args.eval_file_path_sp, predictions, odqa.args.sampled) 19 | else: 20 | results = evaluate(odqa.args.eval_file_path, reader_output) 21 | print(results) 22 | 23 | else: 24 | tfidf_retrieval_output, selector_output, reader_output = odqa.eval() 25 | # eval the performance; based on F1 & EM. 26 | results = evaluate(odqa.args.eval_file_path, reader_output) 27 | 28 | print("EM :{0}, F1: {1}".format(results['exact_match'], results['f1'])) 29 | 30 | # Save the intermediate results. 
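    # Each block below runs only when the corresponding --*_save_path argument was given,
    # and writes that stage's intermediate output as pretty-printed JSON.
    # Note: sp_selector_output is only defined when --sequential_sentence_selector_path is set,
    # so --sequence_sentence_selector_save_path should only be used together with that option.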
31 | if odqa.args.tfidf_results_save_path is not None: 32 | print('#### save TFIDF Retrieval results to {}####'.format( 33 | odqa.args.tfidf_results_save_path)) 34 | with open(odqa.args.tfidf_results_save_path, "w") as writer: 35 | writer.write(json.dumps(tfidf_retrieval_output, indent=4) + "\n") 36 | 37 | if odqa.args.selector_results_save_path is not None: 38 | print('#### save graph-based Retrieval results to {} ####'.format( 39 | odqa.args.selector_results_save_path)) 40 | with open(odqa.args.selector_results_save_path, "w") as writer: 41 | writer.write(json.dumps(selector_output, indent=4) + "\n") 42 | 43 | if odqa.args.reader_results_save_path is not None: 44 | print('#### save reader results to {} ####'.format( 45 | odqa.args.reader_results_save_path)) 46 | with open(odqa.args.reader_results_save_path, "w") as writer: 47 | writer.write(json.dumps(reader_output, indent=4) + "\n") 48 | 49 | if odqa.args.sequence_sentence_selector_save_path is not None: 50 | print("#### save sentence selector results to {} ####".format( 51 | odqa.args.sequence_sentence_selector_save_path)) 52 | with open(odqa.args.sequence_sentence_selector_save_path, "w") as writer: 53 | writer.write(json.dumps(sp_selector_output, indent=4) + "\n") 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /eval_odqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | from tqdm import tqdm 5 | 6 | from pipeline.tfidf_retriever import TfidfRetriever 7 | from pipeline.graph_retriever import GraphRetriever 8 | from pipeline.reader import Reader 9 | from pipeline.sequential_sentence_selector import SequentialSentenceSelector 10 | 11 | from eval_utils import read_jsonlines 12 | 13 | # ODQA components for evaluation 14 | class ODQAEval: 15 | def __init__(self): 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | ## Required parameters 20 | parser.add_argument("--eval_file_path", 21 | default=None, 22 | type=str, 23 | required=True, 24 | help="Eval data file path") 25 | parser.add_argument("--eval_file_path_sp", 26 | default=None, 27 | type=str, 28 | required=False, 29 | help="Eval data file path for supporting fact evaluation (only for HotpotQA)") 30 | parser.add_argument("--graph_retriever_path", 31 | default=None, 32 | type=str, 33 | required=True, 34 | help="Selector model path.") 35 | parser.add_argument("--reader_path", 36 | default=None, 37 | type=str, 38 | required=True, 39 | help="Reader model path.") 40 | parser.add_argument("--sequential_sentence_selector_path", 41 | default=None, 42 | type=str, 43 | required=False, 44 | help="supporting fact selector model path.") 45 | parser.add_argument("--tfidf_path", 46 | default=None, 47 | type=str, 48 | required=True, 49 | help="TF-IDF path.") 50 | parser.add_argument("--db_path", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="DB path.") 55 | 56 | parser.add_argument("--bert_model_graph_retriever", default='bert-base-uncased', type=str, 57 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 58 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 59 | "bert-base-multilingual-cased, bert-base-chinese.") 60 | 61 | parser.add_argument("--bert_model_sequential_sentence_selector", default='bert-base-uncased', type=str, 62 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 63 | "bert-large-uncased, 
bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 64 | "bert-base-multilingual-cased, bert-base-chinese.") 65 | 66 | ## Other parameters 67 | parser.add_argument('--eval_batch_size', 68 | type=int, 69 | default=5, 70 | help="Eval batch size") 71 | 72 | parser.add_argument("--max_seq_length", 73 | default=378, 74 | type=int, 75 | help="The maximum total input sequence length after WordPiece tokenization. \n" 76 | "Sequences longer than this will be truncated, and sequences shorter \n" 77 | "than this will be padded.") 78 | 79 | parser.add_argument("--max_seq_length_sequential_sentence_selector", 80 | default=256, 81 | type=int, 82 | help="The maximum total input sequence length after WordPiece tokenization. \n" 83 | "Sequences longer than this will be truncated, and sequences shorter \n" 84 | "than this will be padded.") 85 | 86 | parser.add_argument("--do_lower_case", 87 | action='store_true', 88 | help="Set this flag if you are using an uncased model.") 89 | parser.add_argument("--no_cuda", 90 | action='store_true', 91 | help="Whether not to use CUDA when available") 92 | 93 | # RNN selector-specific parameters 94 | parser.add_argument("--max_para_num", 95 | default=10, 96 | type=int) 97 | 98 | parser.add_argument('--beam_graph_retriever', 99 | type=int, 100 | default=1, 101 | help="Beam size for Graph Retriever") 102 | parser.add_argument('--beam_sequential_sentence_selector', 103 | type=int, 104 | default=1, 105 | help="Beam size for Sequential Sentence Selector") 106 | 107 | parser.add_argument('--min_select_num', 108 | type=int, 109 | default=1, 110 | help="Minimum number of selected paragraphs") 111 | parser.add_argument('--max_select_num', 112 | type=int, 113 | default=3, 114 | help="Maximum number of selected paragraphs") 115 | parser.add_argument("--no_links", 116 | action='store_true', 117 | help="Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)") 118 | parser.add_argument("--pruning_by_links", 119 | action='store_true', 120 | help="Whether to do pruning by links (and top 1)") 121 | parser.add_argument("--expand_links", 122 | action='store_true', 123 | help="Whether to expand links with paragraphs in the same article (for NQ)") 124 | parser.add_argument('--tfidf_limit', 125 | type=int, 126 | default=100, 127 | help="Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)") 128 | parser.add_argument('--pruning_l', 129 | type=int, 130 | default=10, 131 | help="Set the maximum number of paragraphs retrieved from the same article.") 132 | 133 | parser.add_argument("--split_chunk", 134 | default=100, 135 | type=int, 136 | help="Chunk size for BERT encoding at inference time") 137 | 138 | parser.add_argument("--eval_chunk", default=500, type=int, 139 | help="Chunk size for inference of graph_retriever") 140 | 141 | # To use TagMe, you need to register first here https://sobigdata.d4science.org/web/tagme/tagme-help . 
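        # When --tagme is enabled at inference time, also pass your personal key via --tagme_api_key below.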
142 | parser.add_argument("--tagme", 143 | action='store_true', 144 | help="Whether to use tagme at inference") 145 | parser.add_argument("--tagme_api_key", 146 | type=str, 147 | default=None, 148 | help="Set the TagMe private API key if you use TagMe.") 149 | 150 | parser.add_argument('--topk', 151 | type=int, 152 | default=2, 153 | help="Whether to use how many paragraphs from the previous steps") 154 | 155 | parser.add_argument("--n_best_size", default=5, type=int, 156 | help="The total number of n-best predictions to generate in the nbest_predictions.json " 157 | "output file.") 158 | parser.add_argument("--max_answer_length", default=30, type=int, 159 | help="The maximum length of an answer that can be generated. This is needed because the start " 160 | "and end predictions are not conditioned on one another.") 161 | parser.add_argument("--max_query_length", default=64, type=int, 162 | help="The maximum number of tokens for the question. Questions longer than this will " 163 | "be truncated to this length.") 164 | parser.add_argument("--doc_stride", default=128, type=int, 165 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 166 | parser.add_argument("--max_sent_num", 167 | default=30, 168 | type=int) 169 | parser.add_argument("--max_sf_num", 170 | default=15, 171 | type=int) 172 | # save intermediate results 173 | parser.add_argument('--tfidf_results_save_path', 174 | type=str, 175 | default=None, 176 | help="If specified, the TF-IDF results will be saved in the file path") 177 | 178 | parser.add_argument('--selector_results_save_path', 179 | type=str, 180 | default=None, 181 | help="If specified, the selector results will be saved in the file path") 182 | 183 | parser.add_argument('--reader_results_save_path', 184 | type=str, 185 | default=None, 186 | help="If specified, the reader results will be saved in the file path") 187 | 188 | parser.add_argument('--sequence_sentence_selector_save_path', 189 | type=str, 190 | default=None, 191 | help="If specified, the reader results will be saved in the file path") 192 | 193 | parser.add_argument('--saved_tfidf_retrieval_outputs_path', 194 | type=str, 195 | default=None, 196 | help="If specified, load the saved TF-IDF retrieval results from the path.") 197 | 198 | parser.add_argument('--saved_selector_outputs_path', 199 | type=str, 200 | default=None, 201 | help="If specified, load the saved reasoning path retrieval results from the path.") 202 | 203 | parser.add_argument("--sp_eval", 204 | action='store_true', 205 | help="set true if you evaluate supporting fact evaluations while running QA evaluation (HotpotQA only).") 206 | 207 | parser.add_argument("--sampled", 208 | action='store_true', 209 | help="evaluate on sampled examples; only for debugging and quick demo.") 210 | 211 | parser.add_argument("--use_full_article", 212 | action='store_true', 213 | help="Set true if you use all of the wikipedia paragraphs, not limiting to intro paragraphs.") 214 | 215 | parser.add_argument("--prune_after_agg", 216 | action='store_true', 217 | help="Pruning after aggregating all paragraphs from top k TFIDF paragraphs.") 218 | 219 | 220 | self.args = parser.parse_args() 221 | 222 | self.device = torch.device("cuda" if torch.cuda.is_available() 223 | and not self.args.no_cuda else "cpu") 224 | 225 | # Retriever 226 | self.retriever = TfidfRetriever( 227 | self.args.db_path, self.args.tfidf_path, self.args.use_full_article, self.args.pruning_l) 228 | 229 | def retrieve(self, eval_questions): 230 | 
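        """Run the TF-IDF retriever over all evaluation questions.

        For each question, this collects the top TF-IDF paragraphs (article-level
        paragraphs with their hyperlinked titles when --use_full_article is set,
        otherwise introductory-paragraph-only results), and caches the
        title-to-hyperlinked-titles mapping on the retriever for the later steps.
        Returns the list of TF-IDF retrieval outputs for all questions.
        """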
tfidf_retrieval_output = [] 231 | for _, eval_q in enumerate(tqdm(eval_questions, desc="Question")): 232 | if self.args.use_full_article is True: 233 | tfidf_retrieval_output += self.retriever.get_article_tfidf_with_hyperlinked_titles( 234 | eval_q["id"], eval_q["question"], self.args) 235 | else: 236 | tfidf_retrieval_output += self.retriever.get_abstract_tfidf( 237 | eval_q["id"], eval_q["question"], self.args) 238 | # create examples with retrieval results. 239 | print("retriever") 240 | print(len(tfidf_retrieval_output)) 241 | 242 | # with `use_full_article` setting, we store the title to hyperlinked map and store 243 | # it as retriever's property, 244 | if self.args.use_full_article is True: 245 | title2hyperlink_dic = {} 246 | for example in tfidf_retrieval_output: 247 | title2hyperlink_dic.update( 248 | example["all_linked_para_title_dic"]) 249 | self.retriever.store_title2hyperlink_dic(title2hyperlink_dic) 250 | 251 | return tfidf_retrieval_output 252 | 253 | def select(self, tfidf_retrieval_output): 254 | # Selector 255 | selector = GraphRetriever(self.args, self.device) 256 | 257 | selector_output = selector.predict( 258 | tfidf_retrieval_output, self.retriever, self.args) 259 | print("selector") 260 | print(len(selector_output)) 261 | 262 | return selector_output 263 | 264 | def read(self, selector_output): 265 | # Reader 266 | reader = Reader(self.args, self.device) 267 | 268 | answers, titles = reader.predict(selector_output, self.args) 269 | reader_output = {} 270 | print("reader") 271 | print(len(answers)) 272 | print(answers) 273 | for s in selector_output: 274 | reader_output[s["q_id"]] = answers[s["q_id"]] 275 | 276 | return reader_output, titles 277 | 278 | def explain(self, reader_output_sp): 279 | sequential_sentence_selector = SequentialSentenceSelector(self.args, self.device) 280 | supporting_facts = sequential_sentence_selector.predict( 281 | reader_output_sp, self.retriever, self.args) 282 | 283 | return supporting_facts 284 | 285 | 286 | def eval(self): 287 | # load eval data 288 | # the eval data is described in '{"id": "q_id", "question": "q1", answer": ["a11", ..., "a1i"]}' format (jsonlines) as in DrQA repository. 289 | # TODO: create eval data for HotpotQA, SQuAD and Natural Questions Open. 290 | eval_questions = read_jsonlines(self.args.eval_file_path) 291 | 292 | # Run (or load) graph retriever 293 | # FIXME: do not override saved results. 
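        # Precedence: if saved graph retriever (selector) outputs are given, both the TF-IDF and
        # selection steps are skipped (and tfidf_retrieval_output stays None); otherwise, saved
        # TF-IDF outputs (if any) skip only the initial retrieval step.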
294 | tfidf_retrieval_output = None 295 | if self.args.saved_selector_outputs_path: 296 | selector_output = json.load( 297 | open(self.args.saved_selector_outputs_path)) 298 | else: 299 | if self.args.saved_tfidf_retrieval_outputs_path: 300 | tfidf_retrieval_output = json.load( 301 | open(self.args.saved_tfidf_retrieval_outputs_path)) 302 | else: 303 | tfidf_retrieval_output = self.retrieve(eval_questions) 304 | 305 | selector_output = self.select(tfidf_retrieval_output) 306 | 307 | # read and extract answers from reasoning paths 308 | reader_output, titles = self.read(selector_output) 309 | 310 | if self.args.sequential_sentence_selector_path is None: 311 | return tfidf_retrieval_output, selector_output, reader_output 312 | else: 313 | reader_output_sp = [{'q_id': s['q_id'], 314 | 'question': s['question'], 315 | 'answer': reader_output[s['q_id']], 316 | 'context': titles[s['q_id']]} for s in selector_output] 317 | sp_selector_output = self.explain(reader_output_sp) 318 | print(sp_selector_output) 319 | return tfidf_retrieval_output, selector_output, reader_output, sp_selector_output 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | # FIXME: temporary using SQuAD's eval scripts. HotpotQA using different official scripts. 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import random 9 | import jsonlines 10 | from retriever.doc_db import DocDB 11 | 12 | 13 | def normalize_answer(s): 14 | """Lower text and remove punctuation, articles and extra whitespace.""" 15 | def remove_articles(text): 16 | return re.sub(r'\b(a|an|the)\b', ' ', text) 17 | 18 | def white_space_fix(text): 19 | return ' '.join(text.split()) 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return ''.join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 29 | 30 | 31 | def f1_score(prediction, ground_truth): 32 | prediction_tokens = normalize_answer(prediction).split() 33 | ground_truth_tokens = normalize_answer(ground_truth).split() 34 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 35 | num_same = sum(common.values()) 36 | if num_same == 0: 37 | return 0 38 | precision = 1.0 * num_same / len(prediction_tokens) 39 | recall = 1.0 * num_same / len(ground_truth_tokens) 40 | f1 = (2 * precision * recall) / (precision + recall) 41 | return f1 42 | 43 | 44 | def exact_match_score(prediction, ground_truth): 45 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 46 | 47 | 48 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 49 | scores_for_ground_truths = [] 50 | for ground_truth in ground_truths: 51 | score = metric_fn(prediction, ground_truth) 52 | scores_for_ground_truths.append(score) 53 | return max(scores_for_ground_truths) 54 | 55 | def evaluate(eval_file_path, predictions, quiet=False, multiple_gts=True): 56 | eval_data = read_jsonlines(eval_file_path) 57 | f1 = exact_match = total = 0 58 | 59 | for qa in eval_data: 60 | q_id = qa['id'] 61 | if str(q_id) not in predictions: 62 | print("q_id: {0} is missing.".format(q_id)) 63 | continue 64 | if multiple_gts is True: 65 | ground_truths = qa['answers'] 66 | else: 67 | ground_truths = qa['answer'] 68 | prediction = 
predictions[q_id] 69 | exact_match += metric_max_over_ground_truths( 70 | exact_match_score, prediction, ground_truths) 71 | f1 += metric_max_over_ground_truths( 72 | f1_score, prediction, ground_truths) 73 | total += 1 74 | 75 | exact_match = 100.0 * exact_match / total 76 | f1 = 100.0 * f1 / total 77 | 78 | return {'exact_match': exact_match, 'f1': f1} 79 | 80 | 81 | def f1_score_normalized(prediction, ground_truth): 82 | normalized_prediction = normalize_answer(prediction) 83 | normalized_ground_truth = normalize_answer(ground_truth) 84 | 85 | ZERO_METRIC = (0, 0, 0) 86 | 87 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 88 | return ZERO_METRIC 89 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 90 | return ZERO_METRIC 91 | 92 | prediction_tokens = normalized_prediction.split() 93 | ground_truth_tokens = normalized_ground_truth.split() 94 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 95 | num_same = sum(common.values()) 96 | if num_same == 0: 97 | return ZERO_METRIC 98 | precision = 1.0 * num_same / len(prediction_tokens) 99 | recall = 1.0 * num_same / len(ground_truth_tokens) 100 | f1 = (2 * precision * recall) / (precision + recall) 101 | return f1, precision, recall 102 | 103 | 104 | def update_answer(metrics, prediction, gold): 105 | em = exact_match_score(prediction, gold) 106 | f1, prec, recall = f1_score_normalized(prediction, gold) 107 | metrics['em'] += float(em) 108 | metrics['f1'] += f1 109 | metrics['prec'] += prec 110 | metrics['recall'] += recall 111 | return em, prec, recall 112 | 113 | def update_sp(metrics, prediction, gold): 114 | print(prediction) 115 | cur_sp_pred = set(map(tuple, prediction)) 116 | gold_sp_pred = set(map(tuple, gold)) 117 | tp, fp, fn = 0, 0, 0 118 | for e in cur_sp_pred: 119 | if e in gold_sp_pred: 120 | tp += 1 121 | else: 122 | fp += 1 123 | for e in gold_sp_pred: 124 | if e not in cur_sp_pred: 125 | fn += 1 126 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 127 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 128 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 129 | em = 1.0 if fp + fn == 0 else 0.0 130 | metrics['sp_em'] += em 131 | metrics['sp_f1'] += f1 132 | metrics['sp_prec'] += prec 133 | metrics['sp_recall'] += recall 134 | return em, prec, recall 135 | 136 | 137 | def convert_qa_sp_results_into_hp_eval_format(reader_output, sp_selector_output, db_path): 138 | db = DocDB(db_path) 139 | sp_dict = {} 140 | 141 | for sp_pred in sp_selector_output: 142 | q_id = sp_pred["q_id"] 143 | sp_dict[q_id] = [] 144 | sp_fact_pred = sp_pred["supporting facts"] 145 | 146 | for title in sp_fact_pred: 147 | orig_title = db.get_original_title(title) 148 | for sent_pred in sp_fact_pred[title]: 149 | sp_dict[q_id].append([orig_title, sent_pred[0]]) 150 | 151 | return {"answer": reader_output, "sp": sp_dict} 152 | 153 | def evaluate_w_sp_facts(eval_file_path, prediction, sampled=False): 154 | with open(eval_file_path) as f: 155 | gold = json.load(f) 156 | 157 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 158 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 159 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 160 | for dp in gold: 161 | cur_id = dp['_id'] 162 | can_eval_joint = True 163 | if cur_id not in prediction['answer']: 164 | can_eval_joint = False 165 | if sampled is False: 166 | print('missing answer {}'.format(cur_id)) 167 | else: 
168 | em, prec, recall = update_answer( 169 | metrics, prediction['answer'][cur_id], dp['answer']) 170 | if cur_id not in prediction['sp']: 171 | can_eval_joint = False 172 | if sampled is False: 173 | print('missing answer {}'.format(cur_id)) 174 | else: 175 | sp_em, sp_prec, sp_recall = update_sp( 176 | metrics, prediction['sp'][cur_id], dp['supporting_facts']) 177 | 178 | if can_eval_joint: 179 | joint_prec = prec * sp_prec 180 | joint_recall = recall * sp_recall 181 | if joint_prec + joint_recall > 0: 182 | joint_f1 = 2 * joint_prec * joint_recall / \ 183 | (joint_prec + joint_recall) 184 | else: 185 | joint_f1 = 0. 186 | joint_em = em * sp_em 187 | 188 | metrics['joint_em'] += joint_em 189 | metrics['joint_f1'] += joint_f1 190 | metrics['joint_prec'] += joint_prec 191 | metrics['joint_recall'] += joint_recall 192 | 193 | if sampled is True: 194 | N = len(prediction["answer"]) 195 | else: 196 | N = len(gold) 197 | for k in metrics.keys(): 198 | metrics[k] /= N 199 | 200 | return metrics 201 | 202 | def read_jsonlines(eval_file_name): 203 | lines = [] 204 | print("loading examples from {0}".format(eval_file_name)) 205 | with jsonlines.open(eval_file_name) as reader: 206 | for obj in reader: 207 | lines.append(obj) 208 | return lines 209 | 210 | 211 | -------------------------------------------------------------------------------- /graph_retriever/README.md: -------------------------------------------------------------------------------- 1 | # Graph-based Recurrent Retriever 2 | 3 | This directory includes codes for our graph-based recurrent retriever model described in Section 3.1.1 of our paper. 4 | 5 | Table of contents: 6 | - 1. Training 7 | - 2. Inference (optional) 8 | 9 | ## 1. Training 10 | We use `run_graph_retriever.py` to train the model. 11 | We first need to prepare training data for each task; in our paper, we ued HotpotQA, HotpotQA distractor, Natural Questions Open, and SQuAD Open. 12 | To train the model with the same settings used in our paper, we explain some of the most important arguments below. 13 | See each example for more details about the values of the arguments. 14 | 15 | Note: it is not possible to perfectly reproduce the experimental results even with exactly the same settings, due to device or environmental differences. 16 | When we trained the model for the SQuAD Open dataset five times with five different random seeds, the standard deviation of the final QA EM score was around `0.5`. 17 | Therefore, you would occasionally face a different score by ~1% when changing random seeds or something. 18 | In our paper, we reported all the results based on the default seed (42) in the BERT code base. 19 | 20 | - `--task`
21 | This is a required argument to specify which dataset we use. 22 | Currently, the valid values are `hotpot_open`, `hotpot_distractor`, `squad`, and `nq`. 23 | 24 | - `--train_file_path`
25 | This has to be specified for training. 26 | This can be either a single file or a set of split files; for the latter case, we simply use the `glob` package to read all the files associated with a partial file path. 27 | For example, if we have three files named `./data_1.json`, `./data_2.json`, and `./data_3.json`, then `--train_file_path ./data_` allows you to load all three files via `glob.glob("./data_*")`. 28 | This is useful when a single data file is too big and we want to split it into smaller files. 29 | 30 | - `--output_dir`<br>
31 | This is a directory path to save model checkpoints; a checkpoint is saved every half epoch during training. 32 | The checkpoint files are `pytorch_model_0.5.bin`, `pytorch_model_1.bin`, `pytorch_model_1.5.bin`, etc. 33 | 34 | - `--max_para_num`
35 | This is the number of paragraphs associated with a question. 36 | If `--max_para_num` is `N` and the number of the ground-truth paragraphs is `2` for the question, then there are `N-2` paragraphs as negative examples for training. 37 | We expect higher accuracy with larger values of `N`, but there is a trade-off with the training time. 38 | 39 | - `--tfidf_limit`
40 | This is specifically used for HotpotQA, where we use negative examples from both TF-IDF-based and hyperlink-based paragraphs. 41 | If `--max_para_num` is `N` and `--tfidf_limit` is `M` (`N` >= `M`), then there are `M` TF-IDF-based negative examples and `N-M` hyperlink-based negative examples. 42 | 43 | - `--neg_chunk`
44 | This is used to control GPU memory consumption. 45 | Our model training needs to handle many paragraphs for a question with BERT, so it is not feasible to run the forward/backward functions on all of them at once. 46 | To resolve this issue, this argument splits the negative examples into small chunks, with the chunk size given by its value. 47 | We used NVIDIA V100 GPUs (with 16GB memory) for our experiments, and `--neg_chunk 8` works. 48 | For other GPUs with less memory, please consider using a smaller number for this argument. 49 | It should be noted that changing this value does not affect the results; our model does not use the softmax normalization, and thus we can run the forward/backward functions separately for each chunk. 50 | 51 | - `--train_batch_size` & `--gradient_accumulation_steps`<br>
52 | These control the mini-batch size and how often we update the model parameters with the optimizer. 53 | More importantly, these depend on how many GPUs we can use. 54 | Due to the model size of BERT and the number of paragraphs to be processed, one GPU can handle one example. 55 | That means, `--train_batch_size 4` and `--gradient_accumulation_steps 4` work on a single GPU, but `--train_batch_size 4` and `--gradient_accumulation_steps 1` do not work due to OOM. 56 | However, if we have four GPUs, for example, the latter setting works because the four examples can be handled by the four GPUs. 57 | 58 | - `--use_redundant` and `--use_multiple_redundant`
59 | These enable the data augmentation technique for the sake of robustness. 60 | `--use_redundant` allows you to use one additional training example for a question, and `--use_multiple_redundant` allows you to use multiple examples. 61 | To further control how many additional examples are used for training, you can specify `--max_redundant_num`. 62 | 63 | - `--max_select_num`<br>
64 | This specifies the maximum number of reasoning steps in our model. 65 | This value should be `K+1`, where `K` is the number of ground-truth paragraphs, and `1` is for the EOE symbol. 66 | You further need to add `1` when using the `--use_redundant` option. 67 | For example, for HotpotQA, `K` is 2 and we used the `--use_redundant` option, so the total value is `4`. 68 | 69 | - `--example_limit`<br>
70 | This allows you to sanity-check the code by limiting the number of examples (i.e., questions) to load from each file. 71 | This is useful for checking that everything runs correctly in your environment. 72 | 73 | ### HotpotQA 74 | For HotpotQA, we used up to 50 paragraphs for each question, and among them up to 40 paragraphs are from the TF-IDF retriever. 75 | The other 10 paragraphs are from hyperlinks. 76 | The following command assumes the use of four V100 GPUs, but with other devices, you might modify `--neg_chunk` and `--gradient_accumulation_steps`. 77 | 78 | ```bash 79 | python run_graph_retriever.py \ 80 | --task hotpot_open \ 81 | --bert_model bert-base-uncased --do_lower_case \ 82 | --train_file_path \ 83 | --output_dir \ 84 | --max_para_num 50 \ 85 | --tfidf_limit 40 \ 86 | --neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 1 \ 87 | --learning_rate 3e-5 --num_train_epochs 3 \ 88 | --use_redundant \ 89 | --max_select_num 4 \ 90 | ``` 91 | 92 | You can use the files in `hotpotqa_new_selector_train_data_db_2017_10_12_fix.zip` for `--train_file_path`. 93 | 94 | ### HotpotQA distractor 95 | For HotpotQA distractor, `--max_para_num` is always 10, due to the task setting. 96 | The following command assumes the use of one V100 GPU, but with other devices, you might modify `--neg_chunk`. 97 | 98 | ```bash 99 | python run_graph_retriever.py \ 100 | --task hotpot_distractor \ 101 | --bert_model bert-base-uncased --do_lower_case \ 102 | --train_file_path \ 103 | --output_dir \ 104 | --max_para_num 10 \ 105 | --neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 4 \ 106 | --learning_rate 3e-5 --num_train_epochs 3 \ 107 | --max_select_num 3 108 | ``` 109 | 110 | You can use `hotpot_train_order_sensitive.json` for `--train_file_path`. 111 | 112 | ### Natural Questions Open 113 | For Natural Questions, we used up to 80 paragraphs for each question; changing this to 50 or so does not make a significant difference, but in general, the more negative examples, the better (at least, not worse). 114 | To encourage the multi-step nature, we used the `--use_multiple_redundant` option with a larger mini-batch size, because the number of training examples is significantly increased. 115 | The following command assumes the use of four V100 GPUs, but with other devices, you might modify `--neg_chunk` and `--gradient_accumulation_steps`. 116 | 117 | ```bash 118 | python run_graph_retriever.py \ 119 | --task nq \ 120 | --bert_model bert-base-uncased --do_lower_case \ 121 | --train_file_path \ 122 | --output_dir \ 123 | --max_para_num 80 \ 124 | --neg_chunk 8 --train_batch_size 8 --gradient_accumulation_steps 2 \ 125 | --learning_rate 2e-5 --num_train_epochs 3 \ 126 | --use_multiple_redundant \ 127 | --max_select_num 3 \ 128 | ``` 129 | 130 | You can use the files in `nq_selector_train.tar.gz` for `--train_file_path`. 131 | 132 | ### SQuAD Open 133 | For SQuAD, we used up to 50 paragraphs for each question. 134 | The following command assumes the use of four V100 GPUs, but with other devices, you might modify `--neg_chunk` and `--gradient_accumulation_steps`. 
135 | 136 | ```bash 137 | python run_graph_retriever.py \ 138 | --task squad \ 139 | --bert_model bert-base-uncased --do_lower_case \ 140 | --train_file_path \ 141 | --output_dir \ 142 | --max_para_num 50 \ 143 | --neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 1 \ 144 | --learning_rate 2e-5 --num_train_epochs 3 \ 145 | --max_select_num 2 146 | ``` 147 | 148 | You can use `squad_tfidf_rgs_train_tfidf_top_negative_example.json` for `--train_file_path`. 149 | 150 | ## 2. Inference 151 | The trained models can be evaluated in our pipelined evaluation script. 152 | However, we can also use `run_graph_retriever.py` to run trained models for inference. 153 | This is particularly useful for sanity-checking retrieval accuracy on HotpotQA distractor. 154 | To run the models with the same settings used in our paper, we explain some of the most important arguments below. 155 | Some of the arguments are also used in the pipelined evaluation. 156 | 157 | - `--dev_file_path`<br>
158 | This has to be specified for evaluation or inference. 159 | The semantics are the same as for `--train_file_path`: you can use either a single file or a set of multiple files. 160 | 161 | - `--pred_file`<br>
162 | This has to be specified for evaluation or inference. 163 | This is the file path where the model's prediction results are saved for use in the next reading step. 164 | Note that if you used split files for `--dev_file_path`, you may need to merge the output JSON files later. 165 | 166 | - `--output_dir` and `--model_suffix`<br>
167 | `--output_dir` is the same directory you used for training. 168 | To load a trained model from `pytorch_model_1.5.bin`, you need to set `--model_suffix 1.5`. 169 | 170 | - `--max_para_num`<br>
171 | This is used to specify the maximum number of paragraphs, including hyper-linked ones, for each question. 172 | This is typically set to a large number to cover all the possible paragraphs. 173 | 174 | - `--beam`
175 | This is used to specify a beam size for our beam search algorithm to retrieve reasoning paths. 176 | 177 | - `--pruning_by_links`
178 | This option enables pruning during the beam search, based on hyper-links. 179 | 180 | - `--expand_links`<br>
181 | This option is used to add within-document links, along with the default hyper-links on Wikipedia. 182 | 183 | - `--no_links`<br>
184 | This option disables the use of link information, so you can see how effective the links are. 185 | 186 | - `--tagme`<br>
187 | This is used to add TagMe-based paragraphs for each question, for better initial retrieval. 188 | 189 | - `--eval_chunk`
190 | This option's purpose is similar to that of `--neg_chunk`. 191 | If an evaluation file is too big to fit in CPU RAM, this option specifies a chunk size so that evaluation does not process all the examples at once. 192 | 193 | - `--split_chunk`<br>
194 | This is useful to control the use of GPU RAM for the BERT encoding. 195 | This is the number of paragraphs to be encoded by BERT together. 196 | The smaller the value is, the less GPU memory is consumed. 197 | 198 | ### HotpotQA 199 | 200 | ```bash 201 | python run_graph_retriever.py \ 202 | --task hotpot_open \ 203 | --bert_model bert-base-uncased --do_lower_case \ 204 | --dev_file_path \ 205 | --pred_file \ 206 | --output_dir \ 207 | --model_suffix 2 \ 208 | --max_para_num 2500 \ 209 | --beam 8 \ 210 | --pruning_by_links \ 211 | --eval_chunk 500 \ 212 | --split_chunk 300 213 | ``` 214 | 215 | ### HotpotQA distractor 216 | 217 | ```bash 218 | python run_graph_retriever.py \ 219 | --task hotpot_distractor \ 220 | --bert_model bert-base-uncased --do_lower_case \ 221 | --dev_file_path \ 222 | --pred_file \ 223 | --output_dir \ 224 | --model_suffix 2 \ 225 | --max_para_num 10 \ 226 | --beam 8 \ 227 | --pruning_by_links \ 228 | --eval_chunk 500 \ 229 | --split_chunk 300 230 | ``` 231 | 232 | You can use `hotpot_fake_sq_dev_new.json` for `--train_file_path`. 233 | 234 | ### Natural Questions Open 235 | 236 | ```bash 237 | python run_graph_retriever.py \ 238 | --task nq \ 239 | --bert_model bert-base-uncased --do_lower_case \ 240 | --dev_file_path \ 241 | --pred_file \ 242 | --output_dir \ 243 | --model_suffix 2 \ 244 | --max_para_num 2000 \ 245 | --beam 8 \ 246 | --pruning_by_links \ 247 | --expand_links \ 248 | --tagme \ 249 | --eval_chunk 500 \ 250 | --split_chunk 300 251 | ``` 252 | 253 | ### SQuAD Open 254 | 255 | ```bash 256 | python run_graph_retriever.py \ 257 | --task squad \ 258 | --bert_model bert-base-uncased --do_lower_case \ 259 | --dev_file_path \ 260 | --pred_file \ 261 | --output_dir \ 262 | --model_suffix 2 \ 263 | --max_para_num 500 \ 264 | --beam 8 \ 265 | --no_links \ 266 | --eval_chunk 500 \ 267 | --split_chunk 300 268 | ``` 269 | -------------------------------------------------------------------------------- /graph_retriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/graph_retriever/__init__.py -------------------------------------------------------------------------------- /graph_retriever/modeling_graph_retriever.py: -------------------------------------------------------------------------------- 1 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | 8 | try: 9 | from graph_retriever.utils import tokenize_question 10 | from graph_retriever.utils import tokenize_paragraph 11 | from graph_retriever.utils import expand_links 12 | except: 13 | from utils import tokenize_question 14 | from utils import tokenize_paragraph 15 | from utils import expand_links 16 | 17 | class BertForGraphRetriever(BertPreTrainedModel): 18 | 19 | def __init__(self, config, graph_retriever_config): 20 | super(BertForGraphRetriever, self).__init__(config) 21 | 22 | self.graph_retriever_config = graph_retriever_config 23 | 24 | self.bert = BertModel(config) 25 | 26 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 27 | 28 | # Initial state 29 | self.s = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1)) 30 | 31 | # Scaling factor for weight norm 32 | self.g = Parameter(torch.FloatTensor(1).fill_(1.0)) 33 | 34 | # RNN weight 35 | self.rw = 
nn.Linear(2*config.hidden_size, config.hidden_size) 36 | 37 | # EOE and output bias 38 | self.eos = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1)) 39 | self.bias = Parameter(torch.FloatTensor(1).zero_()) 40 | 41 | self.apply(self.init_bert_weights) 42 | self.cpu = torch.device('cpu') 43 | 44 | ''' 45 | state: (B, 1, D) 46 | ''' 47 | def weight_norm(self, state): 48 | state = state / state.norm(dim = 2).unsqueeze(2) 49 | state = self.g * state 50 | return state 51 | 52 | ''' 53 | input_ids, token_type_ids, attention_mask: (B, N, L) 54 | B: batch size 55 | N: maximum number of Q-P pairs 56 | L: maximum number of input tokens 57 | ''' 58 | def encode(self, input_ids, token_type_ids, attention_mask, split_chunk = None): 59 | B = input_ids.size(0) 60 | N = input_ids.size(1) 61 | L = input_ids.size(2) 62 | input_ids = input_ids.contiguous().view(B*N, L) 63 | token_type_ids = token_type_ids.contiguous().view(B*N, L) 64 | attention_mask = attention_mask.contiguous().view(B*N, L) 65 | 66 | # [CLS] vectors for Q-P pairs 67 | if split_chunk is None: 68 | encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 69 | pooled_output = encoded_layers[:, 0] 70 | 71 | # an option to reduce GPU memory consumption at eval time, by splitting all the Q-P pairs into smaller chunks 72 | else: 73 | assert type(split_chunk) == int 74 | 75 | TOTAL = input_ids.size(0) 76 | start = 0 77 | 78 | while start < TOTAL: 79 | end = min(start+split_chunk-1, TOTAL-1) 80 | chunk_len = end-start+1 81 | 82 | input_ids_ = input_ids[start:start+chunk_len, :] 83 | token_type_ids_ = token_type_ids[start:start+chunk_len, :] 84 | attention_mask_ = attention_mask[start:start+chunk_len, :] 85 | 86 | encoded_layers, pooled_output_ = self.bert(input_ids_, token_type_ids_, attention_mask_, output_all_encoded_layers=False) 87 | encoded_layers = encoded_layers[:, 0] 88 | 89 | if start == 0: 90 | pooled_output = encoded_layers 91 | else: 92 | pooled_output = torch.cat((pooled_output, encoded_layers), dim = 0) 93 | 94 | start = end+1 95 | 96 | pooled_output = pooled_output.contiguous() 97 | 98 | paragraphs = pooled_output.view(pooled_output.size(0)//N, N, pooled_output.size(1)) # (B, N, D), D: BERT dim 99 | EOE = self.eos.unsqueeze(0).unsqueeze(0) # (1, 1, D) 100 | EOE = EOE.expand(paragraphs.size(0), EOE.size(1), EOE.size(2)) # (B, 1, D) 101 | EOE = self.bert.encoder.layer[-1].output.LayerNorm(EOE) 102 | paragraphs = torch.cat((paragraphs, EOE), dim = 1) # (B, N+1, D) 103 | 104 | # Initial state 105 | state = self.s.expand(paragraphs.size(0), 1, self.s.size(0)) 106 | state = self.weight_norm(state) 107 | 108 | return paragraphs, state 109 | 110 | ''' 111 | input_ids, token_type_ids, attention_mask: (B, N, L) 112 | - B: batch size 113 | - N: maximum number of Q-P pairs 114 | - L: maximum number of input tokens 115 | 116 | output_mask, target: (B, max_num_steps, N+1) 117 | ''' 118 | def forward(self, input_ids, token_type_ids, attention_mask, output_mask, target, max_num_steps): 119 | 120 | paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask) 121 | 122 | for i in range(max_num_steps): 123 | if i == 0: 124 | h = state 125 | else: 126 | input = paragraphs[:, i-1:i, :] # (B, 1, D) 127 | state = torch.cat((state, input), dim = 2) # (B, 1, 2*D) 128 | state = self.rw(state) # (B, 1, D) 129 | state = self.weight_norm(state) 130 | h = torch.cat((h, state), dim = 1) # ...--> (B, max_num_steps, D) 131 | 132 | h = self.dropout(h) 133 | output = 
torch.bmm(h, paragraphs.transpose(1, 2)) # (B, max_num_steps, N+1) 134 | output = output + self.bias 135 | 136 | loss = F.binary_cross_entropy_with_logits(output, target, weight = output_mask, reduction = 'mean') 137 | return loss 138 | 139 | def beam_search(self, input_ids, token_type_ids, attention_mask, examples, tokenizer, retriever, split_chunk): 140 | beam = self.graph_retriever_config.beam 141 | B = input_ids.size(0) 142 | N = self.graph_retriever_config.max_para_num 143 | 144 | pred = [] 145 | prob = [] 146 | 147 | topk_pred = [] 148 | topk_prob = [] 149 | 150 | eos_index = N 151 | 152 | init_paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask, split_chunk = split_chunk) 153 | 154 | # Output matrix to be populated 155 | ps = torch.FloatTensor(N+1, self.s.size(0)).zero_().to(self.s.device) # (N+1, D) 156 | 157 | for i in range(B): 158 | init_context_len = len(examples[i].context) 159 | 160 | # Populating the output matrix by the initial encoding 161 | ps[:init_context_len, :].copy_(init_paragraphs[i, :init_context_len, :]) 162 | ps[-1, :].copy_(init_paragraphs[i, -1, :]) 163 | encoded_titles = set(examples[i].title_order) 164 | 165 | pred_ = [[[], [], 1.0] for _ in range(beam)] # [hist_1, topk_1, score_1], [hist_2, topk_2, score_2], ... 166 | prob_ = [[] for _ in range(beam)] 167 | 168 | state_ = state[i:i+1] # (1, 1, D) 169 | state_ = state_.expand(beam, 1, state_.size(2)) # -> (beam, 1, D) 170 | state_tmp = torch.FloatTensor(state_.size()).zero_().to(state_.device) 171 | 172 | for j in range(self.graph_retriever_config.max_select_num): 173 | if j > 0: 174 | input = [p[0][-1] for p in pred_] 175 | input = torch.LongTensor(input).to(ps.device) 176 | input = ps[input].unsqueeze(1) # (beam, 1, D) 177 | state_ = torch.cat((state_, input), dim = 2) # (beam, 1, 2*D) 178 | state_ = self.rw(state_) # (beam, 1, D) 179 | state_ = self.weight_norm(state_) 180 | 181 | # Opening new links from the previous predictions (pupulating the output matrix dynamically) 182 | if j > 0: 183 | prev_title_size = len(examples[i].title_order) 184 | new_titles = [] 185 | for b in range(beam): 186 | prev_pred = pred_[b][0][-1] 187 | 188 | if prev_pred == eos_index: 189 | continue 190 | 191 | prev_title = examples[i].title_order[prev_pred] 192 | 193 | if prev_title not in examples[i].all_linked_paras_dic: 194 | 195 | if retriever is None: 196 | continue 197 | else: 198 | linked_paras_dic = retriever.get_hyperlinked_abstract_paragraphs( 199 | prev_title, examples[i].question) 200 | examples[i].all_linked_paras_dic[prev_title] = {} 201 | examples[i].all_linked_paras_dic[prev_title].update(linked_paras_dic) 202 | examples[i].all_paras.update(linked_paras_dic) 203 | 204 | for linked_title in examples[i].all_linked_paras_dic[prev_title]: 205 | if linked_title in encoded_titles or len(examples[i].title_order) == N: 206 | continue 207 | 208 | encoded_titles.add(linked_title) 209 | new_titles.append(linked_title) 210 | examples[i].title_order.append(linked_title) 211 | 212 | if len(new_titles) > 0: 213 | 214 | tokens_q = tokenize_question(examples[i].question, tokenizer) 215 | input_ids = [] 216 | input_masks = [] 217 | segment_ids = [] 218 | for linked_title in new_titles: 219 | linked_para = examples[i].all_paras[linked_title] 220 | 221 | input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(linked_para, tokens_q, self.graph_retriever_config.max_seq_length, tokenizer) 222 | input_ids.append(input_ids_) 223 | input_masks.append(input_masks_) 224 | segment_ids.append(segment_ids_) 225 | 226 | 
input_ids = torch.LongTensor([input_ids]).to(ps.device) 227 | token_type_ids = torch.LongTensor([segment_ids]).to(ps.device) 228 | attention_mask = torch.LongTensor([input_masks]).to(ps.device) 229 | 230 | paragraphs, _ = self.encode(input_ids, token_type_ids, attention_mask, split_chunk = split_chunk) 231 | paragraphs = paragraphs.squeeze(0) 232 | ps[prev_title_size:prev_title_size+len(new_titles)].copy_(paragraphs[:len(new_titles), :]) 233 | 234 | if retriever is not None and self.graph_retriever_config.expand_links: 235 | expand_links(examples[i].all_paras, examples[i].all_linked_paras_dic, examples[i].all_paras) 236 | 237 | output = torch.bmm(state_, ps.unsqueeze(0).expand(beam, ps.size(0), ps.size(1)).transpose(1, 2)) # (beam, 1, N+1) 238 | output = output + self.bias 239 | output = torch.sigmoid(output) 240 | 241 | output = output.to(self.cpu) 242 | 243 | if j == 0: 244 | output[:, :, len(examples[i].context):] = 0.0 245 | else: 246 | if len(examples[i].title_order) < N: 247 | output[:, :, len(examples[i].title_order):N] = 0.0 248 | for b in range(beam): 249 | 250 | # Omitting previous predictions 251 | for k in range(len(pred_[b][0])): 252 | output[b, :, pred_[b][0][k]] = 0.0 253 | 254 | # Links & topK-based pruning 255 | if self.graph_retriever_config.pruning_by_links: 256 | if pred_[b][0][-1] == eos_index: 257 | output[b, :, :eos_index] = 0.0 258 | output[b, :, eos_index] = 1.0 259 | 260 | elif examples[i].title_order[pred_[b][0][-1]] not in examples[i].all_linked_paras_dic: 261 | for k in range(len(examples[i].title_order)): 262 | if k not in pred_[b][1]: 263 | output[b, :, k] = 0.0 264 | 265 | else: 266 | for k in range(len(examples[i].title_order)): 267 | if k not in pred_[b][1] and examples[i].title_order[k] not in examples[i].all_linked_paras_dic[examples[i].title_order[pred_[b][0][-1]]]: 268 | output[b, :, k] = 0.0 269 | 270 | # always >= M before EOS 271 | if j <= self.graph_retriever_config.min_select_num-1: 272 | output[:, :, -1] = 0.0 273 | 274 | 275 | score = [p[2] for p in pred_] 276 | score = torch.FloatTensor(score) 277 | score = score.unsqueeze(1).unsqueeze(2) # (beam, 1, 1) 278 | score = output * score 279 | 280 | output = output.squeeze(1) # (beam, N+1) 281 | score = score.squeeze(1) # (beam, N+1) 282 | new_pred_ = [] 283 | new_prob_ = [] 284 | 285 | b = 0 286 | while b < beam: 287 | s, p = torch.max(score.view(score.size(0)*score.size(1)), dim = 0) 288 | s = s.item() 289 | p = p.item() 290 | row = p // score.size(1) 291 | col = p % score.size(1) 292 | 293 | if j == 0: 294 | score[:, col] = 0.0 295 | else: 296 | score[row, col] = 0.0 297 | 298 | p = [[index for index in pred_[row][0]] + [col], 299 | output[row].topk(k = 2, dim = 0)[1].tolist(), 300 | s] 301 | new_pred_.append(p) 302 | 303 | p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()] 304 | new_prob_.append(p) 305 | 306 | state_tmp[b].copy_(state_[row]) 307 | b += 1 308 | 309 | pred_ = new_pred_ 310 | prob_ = new_prob_ 311 | state_ = state_.clone() 312 | state_.copy_(state_tmp) 313 | 314 | if pred_[0][0][-1] == eos_index: 315 | break 316 | 317 | topk_pred.append([]) 318 | topk_prob.append([]) 319 | for index__ in range(beam): 320 | 321 | pred_tmp = [] 322 | for index in pred_[index__][0]: 323 | if index == eos_index: 324 | break 325 | pred_tmp.append(index) 326 | 327 | if index__ == 0: 328 | pred.append(pred_tmp) 329 | prob.append(prob_[0]) 330 | 331 | topk_pred[-1].append(pred_tmp) 332 | topk_prob[-1].append(prob_[index__]) 333 | 334 | return pred, prob, topk_pred, topk_prob 335 | 
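# A minimal usage sketch of beam_search, following the call in pipeline/graph_retriever.py;
# it assumes a loaded model and tokenizer plus the prepared (B, N, L) input tensors and the
# corresponding InputExample list:
#
#   pred, prob, topk_pred, topk_prob = model.beam_search(
#       input_ids, segment_ids, input_masks,
#       examples=examples, tokenizer=tokenizer,
#       retriever=None,     # or a TfidfRetriever to open new hyperlinked paragraphs on the fly
#       split_chunk=300)    # encode paragraphs with BERT in chunks of 300 to limit GPU RAM
#
# `pred` holds, for each example, the best reasoning path as indices into example.title_order
# (the EOE index removed); `prob` holds the per-step output probabilities for that path, and
# `topk_pred` / `topk_prob` hold the same information for all beams.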
-------------------------------------------------------------------------------- /img/odqa_overview-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/img/odqa_overview-1.png -------------------------------------------------------------------------------- /pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/pipeline/__init__.py -------------------------------------------------------------------------------- /pipeline/graph_retriever.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | 4 | import torch 5 | 6 | from graph_retriever.utils import InputExample 7 | from graph_retriever.utils import InputFeatures 8 | from graph_retriever.utils import tokenize_question 9 | from graph_retriever.utils import tokenize_paragraph 10 | from graph_retriever.utils import GraphRetrieverConfig 11 | from graph_retriever.utils import expand_links 12 | from graph_retriever.modeling_graph_retriever import BertForGraphRetriever 13 | 14 | from pytorch_pretrained_bert.tokenization import BertTokenizer 15 | 16 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 17 | 18 | import logging 19 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 20 | datefmt = '%m/%d/%Y %H:%M:%S', 21 | level = logging.INFO) 22 | logger = logging.getLogger(__name__) 23 | 24 | def create_examples(jsn, graph_retriever_config): 25 | 26 | task = graph_retriever_config.task 27 | 28 | examples = [] 29 | 30 | ''' 31 | Find the mximum size of the initial context (links are not included) 32 | ''' 33 | graph_retriever_config.max_context_size = 0 34 | 35 | for data in jsn: 36 | 37 | guid = data['q_id'] 38 | question = data['question'] 39 | context = data['context'] # {context title: paragraph} 40 | all_linked_paras_dic = {} # {context title: {linked title: paragraph}} 41 | 42 | ''' 43 | Use TagMe-based context at test time. 44 | ''' 45 | if graph_retriever_config.tagme: 46 | assert 'tagged_context' in data 47 | 48 | ''' 49 | Reformat "tagged_context" if needed (c.f. 
the "context" case above) 50 | ''' 51 | if type(data['tagged_context']) == list: 52 | tagged_context = {c[0]: c[1] for c in data['tagged_context']} 53 | data['tagged_context'] = tagged_context 54 | 55 | ''' 56 | Append valid paragraphs from "tagged_context" to "context" 57 | ''' 58 | for tagged_title in data['tagged_context']: 59 | tagged_text = data['tagged_context'][tagged_title] 60 | if tagged_title not in context and tagged_title is not None and tagged_title.strip() != '' and tagged_text is not None and tagged_text.strip() != '': 61 | context[tagged_title] = tagged_text 62 | 63 | ''' 64 | Clean "context" by removing invalid paragraphs 65 | ''' 66 | removed_keys = [] 67 | for title in context: 68 | if title is None or title.strip() == '' or context[title] is None or context[title].strip() == '': 69 | removed_keys.append(title) 70 | for key in removed_keys: 71 | context.pop(key) 72 | 73 | all_paras = {} 74 | for title in context: 75 | all_paras[title] = context[title] 76 | 77 | if graph_retriever_config.expand_links: 78 | expand_links(context, all_linked_paras_dic, all_paras) 79 | 80 | graph_retriever_config.max_context_size = max(graph_retriever_config.max_context_size, len(context)) 81 | 82 | examples.append(InputExample(guid = guid, 83 | q = question, 84 | c = context, 85 | para_dic = all_linked_paras_dic, 86 | s_g = None, r_g = None, all_r_g = None, 87 | all_paras = all_paras)) 88 | 89 | return examples 90 | 91 | def convert_examples_to_features(examples, max_seq_length, max_para_num, graph_retriever_config, tokenizer): 92 | """Loads a data file into a list of `InputBatch`s.""" 93 | 94 | max_para_num = graph_retriever_config.max_context_size 95 | graph_retriever_config.max_para_num = max(graph_retriever_config.max_para_num, max_para_num) 96 | 97 | max_steps = graph_retriever_config.max_select_num 98 | 99 | DUMMY = [0] * max_seq_length 100 | features = [] 101 | 102 | for (ex_index, example) in enumerate(examples): 103 | tokens_q = tokenize_question(example.question, tokenizer) 104 | 105 | title2index = {} 106 | input_ids = [] 107 | input_masks = [] 108 | segment_ids = [] 109 | 110 | titles_list = list(example.context.keys()) 111 | for p in titles_list: 112 | 113 | if len(input_ids) == max_para_num: 114 | break 115 | 116 | if p in title2index: 117 | continue 118 | 119 | title2index[p] = len(title2index) 120 | example.title_order.append(p) 121 | p = example.context[p] 122 | 123 | input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(p, tokens_q, max_seq_length, tokenizer) 124 | input_ids.append(input_ids_) 125 | input_masks.append(input_masks_) 126 | segment_ids.append(segment_ids_) 127 | 128 | num_paragraphs_no_links = len(input_ids) 129 | 130 | assert len(input_ids) <= max_para_num 131 | 132 | num_paragraphs = len(input_ids) 133 | 134 | output_masks = [([1.0] * len(input_ids) + [0.0] * (max_para_num - len(input_ids) + 1)) for _ in range(max_para_num + 2)] 135 | 136 | assert len(example.context) == num_paragraphs_no_links 137 | for i in range(len(output_masks[0])): 138 | if i >= num_paragraphs_no_links: 139 | output_masks[0][i] = 0.0 140 | 141 | for i in range(len(input_ids)): 142 | output_masks[i+1][i] = 0.0 143 | 144 | padding = [DUMMY] * (max_para_num - len(input_ids)) 145 | input_ids += padding 146 | input_masks += padding 147 | segment_ids += padding 148 | 149 | features.append( 150 | InputFeatures(input_ids=input_ids, 151 | input_masks=input_masks, 152 | segment_ids=segment_ids, 153 | output_masks = output_masks, 154 | num_paragraphs = num_paragraphs, 155 | num_steps = -1, 
156 | ex_index = ex_index)) 157 | 158 | return features 159 | 160 | class GraphRetriever: 161 | def __init__(self, 162 | args, 163 | device): 164 | 165 | self.graph_retriever_config = GraphRetrieverConfig(example_limit = None, 166 | task = None, 167 | max_seq_length = args.max_seq_length, 168 | max_select_num = args.max_select_num, 169 | max_para_num = args.max_para_num, 170 | tfidf_limit = None, 171 | 172 | train_file_path = None, 173 | use_redundant = None, 174 | use_multiple_redundant = None, 175 | max_redundant_num = None, 176 | 177 | dev_file_path = None, 178 | beam = args.beam_graph_retriever, 179 | min_select_num = args.min_select_num, 180 | no_links = args.no_links, 181 | pruning_by_links = args.pruning_by_links, 182 | expand_links = args.expand_links, 183 | eval_chunk = args.eval_chunk, 184 | tagme = args.tagme, 185 | topk = args.topk, 186 | db_save_path = None) 187 | 188 | print('initializing GraphRetriever...', flush=True) 189 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_graph_retriever, do_lower_case=args.do_lower_case) 190 | model_state_dict = torch.load(args.graph_retriever_path) 191 | self.model = BertForGraphRetriever.from_pretrained(args.bert_model_graph_retriever, state_dict=model_state_dict, graph_retriever_config = self.graph_retriever_config) 192 | self.device = device 193 | self.model.to(self.device) 194 | self.model.eval() 195 | print('Done!', flush=True) 196 | 197 | def predict(self, 198 | tfidf_retrieval_output, 199 | retriever, 200 | args): 201 | 202 | pred_output = [] 203 | 204 | eval_examples = create_examples(tfidf_retrieval_output, self.graph_retriever_config) 205 | 206 | TOTAL_NUM = len(eval_examples) 207 | eval_start_index = 0 208 | 209 | while eval_start_index < TOTAL_NUM: 210 | eval_end_index = min(eval_start_index+self.graph_retriever_config.eval_chunk-1, TOTAL_NUM-1) 211 | chunk_len = eval_end_index - eval_start_index + 1 212 | 213 | features = convert_examples_to_features(eval_examples[eval_start_index:eval_start_index+chunk_len], args.max_seq_length, args.max_para_num, self.graph_retriever_config, self.tokenizer) 214 | 215 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 216 | all_input_masks = torch.tensor([f.input_masks for f in features], dtype=torch.long) 217 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 218 | all_output_masks = torch.tensor([f.output_masks for f in features], dtype=torch.float) 219 | all_num_paragraphs = torch.tensor([f.num_paragraphs for f in features], dtype=torch.long) 220 | all_num_steps = torch.tensor([f.num_steps for f in features], dtype=torch.long) 221 | all_ex_indices = torch.tensor([f.ex_index for f in features], dtype=torch.long) 222 | eval_data = TensorDataset(all_input_ids, all_input_masks, all_segment_ids, all_output_masks, all_num_paragraphs, all_num_steps, all_ex_indices) 223 | 224 | eval_sampler = SequentialSampler(eval_data) 225 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 226 | logger.info('Examples from '+str(eval_start_index)+' to '+str(eval_end_index)) 227 | for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(eval_dataloader, desc="Evaluating"): 228 | 229 | batch_max_len = input_masks.sum(dim = 2).max().item() 230 | batch_max_para_num = num_paragraphs.max().item() 231 | batch_max_steps = num_steps.max().item() 232 | 233 | input_ids = input_ids[:, :batch_max_para_num, :batch_max_len] 234 | input_masks = input_masks[:, 
:batch_max_para_num, :batch_max_len] 235 | segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len] 236 | output_masks = output_masks[:, :batch_max_para_num+2, :batch_max_para_num+1] 237 | output_masks[:, 1:, -1] = 1.0 # Ignore EOS in the first step 238 | 239 | input_ids = input_ids.to(self.device) 240 | input_masks = input_masks.to(self.device) 241 | segment_ids = segment_ids.to(self.device) 242 | output_masks = output_masks.to(self.device) 243 | 244 | examples = [eval_examples[eval_start_index+ex_indices[i].item()] for i in range(input_ids.size(0))] 245 | 246 | with torch.no_grad(): 247 | pred, prob, topk_pred, topk_prob = self.model.beam_search(input_ids, segment_ids, input_masks, examples = examples, tokenizer = self.tokenizer, retriever = retriever, split_chunk = args.split_chunk) 248 | 249 | for i in range(len(pred)): 250 | e = examples[i] 251 | 252 | titles = [e.title_order[p] for p in pred[i]] 253 | question = e.question 254 | 255 | pred_output.append({}) 256 | pred_output[-1]['q_id'] = e.guid 257 | 258 | pred_output[-1]['question'] = question 259 | 260 | topk_titles = [[e.title_order[p] for p in topk_pred[i][j]] for j in range(len(topk_pred[i]))] 261 | pred_output[-1]['topk_titles'] = topk_titles 262 | 263 | topk_probs = [] 264 | pred_output[-1]['topk_probs'] = topk_probs 265 | 266 | context = {} 267 | context_from_tfidf = set() 268 | context_from_hyperlink = set() 269 | for ts in topk_titles: 270 | for t in ts: 271 | context[t] = e.all_paras[t] 272 | 273 | if t in e.context: 274 | context_from_tfidf.add(t) 275 | else: 276 | context_from_hyperlink.add(t) 277 | 278 | pred_output[-1]['context'] = context 279 | pred_output[-1]['context_from_tfidf'] = list(context_from_tfidf) 280 | pred_output[-1]['context_from_hyperlink'] = list(context_from_hyperlink) 281 | 282 | eval_start_index = eval_end_index + 1 283 | del features 284 | del all_input_ids 285 | del all_input_masks 286 | del all_segment_ids 287 | del all_output_masks 288 | del all_num_paragraphs 289 | del all_num_steps 290 | del all_ex_indices 291 | del eval_data 292 | 293 | return pred_output 294 | -------------------------------------------------------------------------------- /pipeline/reader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_pretrained_bert.tokenization import BertTokenizer 4 | 5 | from reader.modeling_reader import BertForQuestionAnsweringConfidence 6 | from reader.rc_utils import read_squad_style_hotpot_examples, \ 7 | convert_examples_to_features, write_predictions_yes_no_beam 8 | 9 | import collections 10 | 11 | from tqdm import tqdm 12 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 13 | 14 | RawResult = collections.namedtuple("RawResult", 15 | ["unique_id", "start_logits", "end_logits", "switch_logits"]) 16 | 17 | class Reader: 18 | def __init__(self, 19 | args, 20 | device): 21 | 22 | print('initializing Reader...', flush=True) 23 | self.model = BertForQuestionAnsweringConfidence.from_pretrained(args.reader_path, num_labels=4, no_masking=True) 24 | self.tokenizer = BertTokenizer.from_pretrained(args.reader_path, args.do_lower_case) 25 | self.device = device 26 | 27 | self.model.to(device) 28 | self.model.eval() 29 | print('Done!', flush=True) 30 | 31 | def convert_retriever_output(self, 32 | retriever_output): 33 | 34 | selected_paras_top_n = {str(item["q_id"]): item["topk_titles"] 35 | for item in retriever_output} 36 | 37 | context_dic = {str(item["q_id"]): item["context"] 38 | for item in 
retriever_output} 39 | 40 | squad_style_data = {'data': [], 'version': '1.1'} 41 | 42 | retrieved_para_dict = {} 43 | 44 | for data in retriever_output: 45 | example_id = data['q_id'] 46 | question_text = data['question'] 47 | pred_para_titles = selected_paras_top_n[example_id] 48 | 49 | for selected_paras in pred_para_titles: 50 | title, context = "", "" 51 | 52 | for para_title in selected_paras: 53 | paragraphs = context_dic[example_id][para_title] 54 | context += paragraphs 55 | 56 | title = para_title 57 | context += " " 58 | # post process to remove unnecessary spaces. 59 | if context[0] == " ": 60 | context = context[1:] 61 | if context[-1] == " ": 62 | context = context[: -1] 63 | 64 | context = context.replace(" ", " ") 65 | 66 | squad_example = {'context': context, 'para_titles': selected_paras, 67 | 'qas': [{'question': question_text, 'id': example_id}]} 68 | squad_style_data["data"].append( 69 | {'title': title, 'paragraphs': [squad_example]}) 70 | 71 | return squad_style_data 72 | 73 | def predict(self, 74 | retriever_output, 75 | args): 76 | 77 | squad_style_data = self.convert_retriever_output(retriever_output) 78 | 79 | e = read_squad_style_hotpot_examples(squad_style_hotpot_dev=squad_style_data, 80 | is_training=False, 81 | version_2_with_negative=False, 82 | store_path_prob=False) 83 | 84 | features = convert_examples_to_features( 85 | examples=e, 86 | tokenizer=self.tokenizer, 87 | max_seq_length=args.max_seq_length, 88 | doc_stride=args.doc_stride, 89 | max_query_length=args.max_query_length, 90 | is_training=False, 91 | quiet = True) 92 | 93 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 94 | all_input_masks = torch.tensor([f.input_mask for f in features], dtype=torch.long) 95 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 96 | eval_data = TensorDataset(all_input_ids, all_input_masks, all_segment_ids) 97 | 98 | eval_sampler = SequentialSampler(eval_data) 99 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 100 | 101 | all_results = [] 102 | 103 | f_offset = 0 104 | for input_ids, input_masks, segment_ids in tqdm(eval_dataloader, desc="Evaluating"): 105 | input_ids = input_ids.to(self.device) 106 | input_masks = input_masks.to(self.device) 107 | segment_ids = segment_ids.to(self.device) 108 | with torch.no_grad(): 109 | batch_start_logits, batch_end_logits, batch_switch_logits = self.model(input_ids, segment_ids, input_masks) 110 | 111 | for i in range(input_ids.size(0)): 112 | start_logits = batch_start_logits[i].detach().cpu().tolist() 113 | end_logits = batch_end_logits[i].detach().cpu().tolist() 114 | switch_logits = batch_switch_logits[i].detach().cpu().tolist() 115 | eval_feature = features[f_offset+i] 116 | unique_id = int(features[f_offset+i].unique_id) 117 | all_results.append(RawResult(unique_id=unique_id, 118 | start_logits=start_logits, 119 | end_logits=end_logits, 120 | switch_logits=switch_logits)) 121 | f_offset += input_ids.size(0) 122 | 123 | return write_predictions_yes_no_beam(e, features, all_results, 124 | args.n_best_size, args.max_answer_length, 125 | args.do_lower_case, None, 126 | None, None, False, 127 | False, None, 128 | output_selected_paras=True, 129 | quiet = True) 130 | -------------------------------------------------------------------------------- /pipeline/sequential_sentence_selector.py: -------------------------------------------------------------------------------- 1 | from 
sequential_sentence_selector.modeling_sequential_sentence_selector import BertForSequentialSentenceSelector 2 | from sequential_sentence_selector.run_sequential_sentence_selector import InputExample 3 | from sequential_sentence_selector.run_sequential_sentence_selector import InputFeatures 4 | from sequential_sentence_selector.run_sequential_sentence_selector import DataProcessor 5 | from sequential_sentence_selector.run_sequential_sentence_selector import convert_examples_to_features 6 | 7 | from pytorch_pretrained_bert.tokenization import BertTokenizer 8 | 9 | import torch 10 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 11 | 12 | from tqdm import tqdm 13 | 14 | class SequentialSentenceSelector: 15 | def __init__(self, 16 | args, 17 | device): 18 | 19 | if args.sequential_sentence_selector_path is None: 20 | return None 21 | 22 | print('initializing SequentialSentenceSelector...', flush=True) 23 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_sequential_sentence_selector, do_lower_case=args.do_lower_case) 24 | model_state_dict = torch.load(args.sequential_sentence_selector_path) 25 | self.model = BertForSequentialSentenceSelector.from_pretrained(args.bert_model_sequential_sentence_selector, state_dict=model_state_dict) 26 | self.device = device 27 | self.model.to(self.device) 28 | self.model.eval() 29 | 30 | self.processor = DataProcessor() 31 | print('Done!', flush=True) 32 | 33 | def convert_reader_output(self, 34 | reader_output, 35 | tfidf_retriever): 36 | new_output = [] 37 | 38 | for data in reader_output: 39 | entry = {} 40 | entry['q_id'] = data['q_id'] 41 | entry['question'] = data['question'] 42 | entry['answer'] = data['answer'] 43 | entry['titles'] = data['context'] 44 | entry['context'] = tfidf_retriever.load_abstract_para_text(entry['titles'], keep_sentence_split = True) 45 | entry['supporting_facts'] = {t: [] for t in entry['titles']} 46 | new_output.append(entry) 47 | 48 | return new_output 49 | 50 | def predict(self, 51 | reader_output, 52 | tfidf_retriever, 53 | args): 54 | 55 | reader_output = self.convert_reader_output(reader_output, tfidf_retriever) 56 | eval_examples = self.processor.create_examples(reader_output) 57 | eval_features = convert_examples_to_features( 58 | eval_examples, args.max_seq_length_sequential_sentence_selector, args.max_sent_num, args.max_sf_num, self.tokenizer) 59 | 60 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 61 | all_input_masks = torch.tensor([f.input_masks for f in eval_features], dtype=torch.long) 62 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 63 | all_output_masks = torch.tensor([f.output_masks for f in eval_features], dtype=torch.float) 64 | all_num_sents = torch.tensor([f.num_sents for f in eval_features], dtype=torch.long) 65 | all_num_sfs = torch.tensor([f.num_sfs for f in eval_features], dtype=torch.long) 66 | all_ex_indices = torch.tensor([f.ex_index for f in eval_features], dtype=torch.long) 67 | eval_data = TensorDataset(all_input_ids, 68 | all_input_masks, 69 | all_segment_ids, 70 | all_output_masks, 71 | all_num_sents, 72 | all_num_sfs, 73 | all_ex_indices) 74 | # Run prediction for full data 75 | eval_sampler = SequentialSampler(eval_data) 76 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 77 | 78 | pred_output = [] 79 | 80 | for input_ids, input_masks, segment_ids, output_masks, num_sents, num_sfs, ex_indices in tqdm(eval_dataloader, 
desc="Evaluating"): 81 | batch_max_len = input_masks.sum(dim = 2).max().item() 82 | batch_max_sent_num = num_sents.max().item() 83 | batch_max_sf_num = num_sfs.max().item() 84 | 85 | input_ids = input_ids[:, :batch_max_sent_num, :batch_max_len] 86 | input_masks = input_masks[:, :batch_max_sent_num, :batch_max_len] 87 | segment_ids = segment_ids[:, :batch_max_sent_num, :batch_max_len] 88 | output_masks = output_masks[:, :batch_max_sent_num+2, :batch_max_sent_num+1] 89 | 90 | output_masks[:, 1:, -1] = 1.0 # Ignore EOE in the first step 91 | 92 | input_ids = input_ids.to(self.device) 93 | input_masks = input_masks.to(self.device) 94 | segment_ids = segment_ids.to(self.device) 95 | output_masks = output_masks.to(self.device) 96 | 97 | examples = [eval_examples[ex_indices[i].item()] for i in range(input_ids.size(0))] 98 | 99 | with torch.no_grad(): 100 | pred, prob, topk_pred, topk_prob = self.model.beam_search(input_ids, segment_ids, input_masks, output_masks, max_num_steps = args.max_sf_num+1, examples = examples, beam = args.beam_sequential_sentence_selector) 101 | 102 | for i in range(len(pred)): 103 | e = examples[i] 104 | 105 | sfs = {} 106 | for p in pred[i]: 107 | offset = 0 108 | for j in range(len(e.titles)): 109 | if p >= offset and p < offset+len(e.context[e.titles[j]]): 110 | if e.titles[j] not in sfs: 111 | sfs[e.titles[j]] = [[p-offset, e.context[e.titles[j]][p-offset]]] 112 | else: 113 | sfs[e.titles[j]].append([p-offset, e.context[e.titles[j]][p-offset]]) 114 | break 115 | offset += len(e.context[e.titles[j]]) 116 | 117 | # Hack 118 | for title in e.titles: 119 | if title not in sfs and len(sfs) < 2: 120 | sfs[title] = [0] 121 | 122 | output = {} 123 | output['q_id'] = e.guid 124 | output['supporting facts'] = sfs 125 | pred_output.append(output) 126 | 127 | return pred_output 128 | 129 | -------------------------------------------------------------------------------- /pipeline/tfidf_retriever.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from retriever.doc_db import DocDB 4 | from retriever.tfidf_doc_ranker import TfidfDocRanker 5 | from retriever.tfidf_vectorizer_article import TopTfIdf 6 | 7 | from retriever.utils import load_para_collections_from_tfidf_id_intro_only, \ 8 | load_para_and_linked_titles_dict_from_tfidf_id, prune_top_k_paragraphs, \ 9 | normalize 10 | 11 | from retriever.tfidf_vectorizer_article import TopTfIdf 12 | 13 | import logging 14 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 15 | datefmt = '%m/%d/%Y %H:%M:%S', 16 | level = logging.INFO) 17 | logger = logging.getLogger(__name__) 18 | 19 | class TfidfRetriever: 20 | def __init__(self, 21 | db_save_path: str, 22 | tfidf_model_path: str, 23 | use_full_article: bool = False, 24 | pruning_l: int = 10): 25 | 26 | print('initializing TfidfRetriever...', flush=True) 27 | self.db = DocDB(db_save_path) 28 | if tfidf_model_path is not None: 29 | self.ranker = TfidfDocRanker(tfidf_path=tfidf_model_path, strict=False) 30 | self.use_full_article = use_full_article 31 | # FIXME: make this to dynamically set the pruning_l 32 | self.pruning_l = pruning_l 33 | print('Done!', flush=True) 34 | 35 | def store_title2hyperlink_dic(self, title2hyperlink_dic): 36 | self.title2hyperlink_dic = title2hyperlink_dic 37 | 38 | def load_abstract_para_text(self, 39 | doc_names, 40 | keep_sentence_split = False): 41 | 42 | context = {} 43 | 44 | for doc_name in doc_names: 45 | para_title_text_pairs = 
load_para_collections_from_tfidf_id_intro_only(doc_name, self.db) 46 | if len(para_title_text_pairs) == 0: 47 | logger.warning("{} is missing".format(doc_name)) 48 | continue 49 | else: 50 | para_title_text_pairs = para_title_text_pairs[0] 51 | if keep_sentence_split: 52 | para_title_text_pairs = {para_title_text_pairs[0]: para_title_text_pairs[1]} 53 | else: 54 | para_title_text_pairs = {para_title_text_pairs[0]: "".join(para_title_text_pairs[1])} 55 | context.update(para_title_text_pairs) 56 | 57 | return context 58 | 59 | # load sampled text included in the target articles, with two stage tfidf retrieval. 60 | def load_sampled_para_text_and_linked_titles(self, 61 | doc_names, 62 | question, 63 | pruning_l, 64 | prune_after_agg=True): 65 | 66 | context = {} 67 | linked_titles = {} 68 | tfidf_vectorizer = TopTfIdf(n_to_select=pruning_l, 69 | filter_dist_one=True, rank=True) 70 | para_dict_all = {} 71 | linked_titles_dict_all = {} 72 | for doc_name in doc_names: 73 | paras_dict, linked_titles_dict = load_para_and_linked_titles_dict_from_tfidf_id( 74 | doc_name, self.db) 75 | if len(paras_dict) == 0: 76 | continue 77 | if prune_after_agg is True: 78 | para_dict_all.update(paras_dict) 79 | linked_titles_dict_all.update(linked_titles_dict) 80 | else: 81 | pruned_para_dict = prune_top_k_paragraphs( 82 | question, paras_dict, tfidf_vectorizer, pruning_l) 83 | 84 | # add top pruning_l paragraphs from the target article. 85 | context.update(pruned_para_dict) 86 | # add hyperlinked paragraphs of the top pruning_l paragraphs from the target article. 87 | pruned_linked_titles = {k: v for k, v in linked_titles_dict.items() if k in pruned_para_dict} 88 | assert len(pruned_para_dict) == len(pruned_linked_titles) 89 | linked_titles.update(pruned_linked_titles) 90 | 91 | if prune_after_agg is True: 92 | pruned_para_dict = prune_top_k_paragraphs(question, para_dict_all, tfidf_vectorizer, pruning_l) 93 | context.update(pruned_para_dict) 94 | pruned_linked_titles = { 95 | k: v for k, v in linked_titles_dict_all.items() if k in pruned_para_dict} 96 | assert len(pruned_para_dict) == len(pruned_linked_titles) 97 | linked_titles.update(pruned_linked_titles) 98 | 99 | return context, linked_titles 100 | 101 | def retrieve_titles_w_tag_me(self, question, tagme_api_key): 102 | import tagme 103 | tagme.GCUBE_TOKEN = tagme_api_key 104 | q_annotations = tagme.annotate(question) 105 | tagged_titles = [] 106 | for ann in q_annotations.get_annotations(0.1): 107 | tagged_titles.append(ann.entity_title) 108 | return tagged_titles 109 | 110 | def load_sampled_tagged_para_text(self, question, pruning_l, tagme_api_key): 111 | tagged_titles = self.retrieve_titles_w_tag_me(question, tagme_api_key) 112 | tagged_doc_names = [normalize(title) for title in tagged_titles] 113 | 114 | context, _ = self.load_sampled_para_text_and_linked_titles( 115 | tagged_doc_names, question, pruning_l) 116 | 117 | return context 118 | 119 | def get_abstract_tfidf(self, 120 | q_id, 121 | question, 122 | args): 123 | 124 | doc_names, _ = self.ranker.closest_docs(question, k=args.tfidf_limit) 125 | # Add TFIDF close documents 126 | context = self.load_abstract_para_text(doc_names) 127 | 128 | return [{"question": question, 129 | "context": context, 130 | "q_id": q_id}] 131 | 132 | def get_article_tfidf_with_hyperlinked_titles(self, q_id,question, args): 133 | """ 134 | Retrieve articles with their corresponding hyperlinked titles. 135 | Due to efficiency, we sample top k articles, and then sample top l paragraphs from each article. 
136 | (so, eventually we get k*l paragraphs with tfidf-based pruning.) 137 | We also store the hyperlinked titles for each paragraph. 138 | """ 139 | 140 | tfidf_limit, pruning_l, prune_after_agg = args.tfidf_limit, args.pruning_l, args.prune_after_agg 141 | doc_names, _ = self.ranker.closest_docs(question, k=tfidf_limit) 142 | context, hyper_linked_titles = self.load_sampled_para_text_and_linked_titles( 143 | doc_names, question, pruning_l, prune_after_agg) 144 | 145 | if args.tagme is True and args.tagme_api_key is not None: 146 | # if add TagMe 147 | tagged_context = self.load_sampled_tagged_para_text( 148 | question, pruning_l, args.tagme_api_key) 149 | 150 | return [{"question": question, 151 | "context": context, 152 | "tagged_context": tagged_context, 153 | "all_linked_para_title_dic": hyper_linked_titles, 154 | "q_id": q_id}] 155 | else: 156 | return [{"question": question, 157 | "context": context, 158 | "all_linked_para_title_dic": hyper_linked_titles, 159 | "q_id": q_id}] 160 | 161 | 162 | def get_hyperlinked_abstract_paragraphs(self, 163 | title: str, 164 | question: str = None): 165 | 166 | if self.use_full_article is True and self.title2hyperlink_dic is not None: 167 | if title not in self.title2hyperlink_dic: 168 | return {} 169 | hyper_linked_titles = self.title2hyperlink_dic[title] 170 | elif self.use_full_article is True and self.title2hyperlink_dic is None: 171 | # for full article version, we need to store title2hyperlink_dic beforehand. 172 | raise NotImplementedError() 173 | else: 174 | hyper_linked_titles = self.db.get_hyper_linked(normalize(title)) 175 | 176 | if hyper_linked_titles is None: 177 | return {} 178 | # if there are any hyperlinked titles, add the information to all_linked_paragraph 179 | all_linked_paras_dic = {} 180 | 181 | if self.use_full_article is True and self.title2hyperlink_dic is not None: 182 | for hyper_linked_para_title in hyper_linked_titles: 183 | paras_dict, _ = load_para_and_linked_titles_dict_from_tfidf_id( 184 | hyper_linked_para_title, self.db) 185 | # Sometimes article titles are updated over times but the hyperlinked titles are not. e.g., Winds <--> Wind 186 | # in our current database, we do not handle these "redirect" cases and thus we cannot recover. 187 | # If we cannot retrieve the hyperlinked articles, we just discard these articles. 188 | if len(paras_dict) == 0: 189 | continue 190 | tfidf_vectorizer = TopTfIdf(n_to_select=self.pruning_l, 191 | filter_dist_one=True, rank=True) 192 | pruned_para_dict = prune_top_k_paragraphs( 193 | question, paras_dict, tfidf_vectorizer, self.pruning_l) 194 | 195 | all_linked_paras_dic.update(pruned_para_dict) 196 | 197 | else: 198 | for hyper_linked_para_title in hyper_linked_titles: 199 | para_title_text_pairs = load_para_collections_from_tfidf_id_intro_only( 200 | hyper_linked_para_title, self.db) 201 | # Sometimes article titles are updated over times but the hyperlinked titles are not. e.g., Winds <--> Wind 202 | # in our current database, we do not handle these "redirect" cases and thus we cannot recover. 203 | # If we cannot retrieve the hyperlinked articles, we just discard these articles. 
204 | if len(para_title_text_pairs) == 0: 205 | continue 206 | 207 | para_title_text_pairs = {para[0]: "".join(para[1]) 208 | for para in para_title_text_pairs} 209 | 210 | all_linked_paras_dic.update(para_title_text_pairs) 211 | 212 | return all_linked_paras_dic 213 | -------------------------------------------------------------------------------- /quick_start_hotpot.sh: -------------------------------------------------------------------------------- 1 | # download trained models 2 | mkdir models 3 | cd models 4 | gdown https://drive.google.com/uc?id=1ra37xtEXSROG_f90XxR4kgElGJWUHQyM 5 | unzip hotpot_models.zip 6 | rm hotpot_models.zip 7 | cd .. 8 | 9 | # download eval data 10 | mkdir data 11 | cd data 12 | mkdir hotpot 13 | cd hotpot 14 | gdown https://drive.google.com/uc?id=1m_7ZJtWQsZ8qDqtItDTWYlsEHDeVHbPt 15 | gdown https://drive.google.com/uc?id=1D-Uj4DPMZWkSouzw5Gg5YhkGiBHSqCJp 16 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json 17 | cd ../.. 18 | 19 | # run evaluation scripts 20 | python eval_main.py \ 21 | --eval_file_path data/hotpot/hotpot_fullwiki_first_100.jsonl \ 22 | --eval_file_path_sp data/hotpot/hotpot_dev_distractor_v1.json \ 23 | --graph_retriever_path models/hotpot_models/graph_retriever_path/pytorch_model.bin \ 24 | --reader_path models/hotpot_models/reader \ 25 | --sequential_sentence_selector_path models/hotpot_models/sequential_sentence_selector/pytorch_model.bin \ 26 | --tfidf_path models/hotpot_models/tfidf_retriever/wiki_open_full_new_db_intro_only-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \ 27 | --db_path models/hotpot_models/wiki_db/wiki_abst_only_hotpotqa_w_original_title.db \ 28 | --bert_model_sequential_sentence_selector bert-large-uncased --do_lower_case \ 29 | --tfidf_limit 500 --eval_batch_size 4 --pruning_by_links --beam_graph_retriever 8 \ 30 | --beam_sequential_sentence_selector 8 --max_para_num 2000 --sp_eval --sampled 31 | -------------------------------------------------------------------------------- /reader/README.md: -------------------------------------------------------------------------------- 1 | ## Reasoning Path Reader 2 | This directory includes codes for our reasoning path reader model described in Section 3.2 of our paper. 3 | Our reader model is based on BERT QA model ([Devlin et al. 2019](https://arxiv.org/abs/1810.04805)), and we extend it to jointly predict answer spans and plausibility of reasoning paths selected by our retriever components. 4 | 5 | Table of contents: 6 | - 1. Training 7 | - 2. Evaluation 8 | 9 | ## 1. Training 10 | ### Training data 11 | We use [rc_utils.py](rc_utils.py) to train our reasoning path reader models. 12 | To train our reader, we first convert the original MRC datasets into SQuAD (v.2) data format, adding distant examples and negative examples. 13 | 14 | We provide the pre-processed train and dev data files for all three datasets here (google drive): 15 | 16 | - [HotpotQA reader train data](https://drive.google.com/file/d/1BZXSZXN99Mb7--4u0x58cixBTon1PX8N/view?usp=sharing) 17 | - [SQuAD reader train data](https://drive.google.com/file/d/1aMTXIxYZCAC6sX5mZt6nytYxeKvjuigq/view?usp=sharing) 18 | - [Natural Questions train data](https://drive.google.com/file/d/1wUlRkC3_yJnEzdxduFE__yQSfWa_3l0j/view?usp=sharing) 19 | 20 | 21 | We explain some of the some required arguments below. 22 | 23 | - `--bert_model`
24 | This is the BERT model type (e.g., `bert-base-uncased`). In our paper, we experiment with both `bert-base-uncased` and `bert-large-uncased-whole-word-masking`. 25 | 26 | - `--output_dir`<br>
27 | This is a directory path to save model checkpoints; a checkpoint is saved every half epoch during training. 28 | 29 | - `--train_file`
30 | This is the file path to the training data, which you can download from the links above. 31 | 32 | - `--version_2_with_negative`<br>
33 | Please add this option to train our reader model with negative examples. 34 | 35 | - `--do_lower_case`
36 | We use the lower-cased version of BERT, following previous papers in machine reading comprehension. To reproduce the results, please add this option. 37 | 38 | There are also some optional arguments; please see the full list in our [rc_utils.py](rc_utils.py). 39 | 40 | - `--train_batch_size`<br>
41 | This specifies the batch size during training (default=`32`). 42 | *To train BERT-large QA models, you will likely need to reduce the batch size (default 32) to fit your GPU memory.* 43 | 44 | - `--max_seq_length`<br>
45 | This sets the maximum length of the input sequence; when the input exceeds this limit, we split the data into several windows. 46 | 47 | - `--predict_file`<br>
48 | This is the file path to your inference data if you would like to evaluate the reader performance (see the details below). Your `predict_file` must be in SQuAD v2.0 format, like `train_file`. 49 | 50 | You can run training with the command below. 51 | 52 | ```bash 53 | python run_reader_confidence.py \ 54 | --bert_model bert-base-uncased \ 55 | --output_dir /path/to/your/output/dir \ 56 | --train_file /path/to/your/train/file \ 57 | --predict_file /path/to/your/eval/file \ 58 | --max_seq_length 384 \ 59 | --do_train \ 60 | --do_predict \ 61 | --do_lower_case \ 62 | --version_2_with_negative 63 | ``` 64 | 65 | e.g., HotpotQA 66 | 67 | ```bash 68 | python run_reader_confidence.py \ 69 | --bert_model bert-base-uncased \ 70 | --output_dir output_hotpot_bert_base \ 71 | --train_file data/hotpot/hotpot_reader_train_data.json \ 72 | --predict_file data/hotpot/hotpot_dev_squad_v2.0_format.json \ 73 | --max_seq_length 384 \ 74 | --do_train \ 75 | --do_predict \ 76 | --do_lower_case \ 77 | --version_2_with_negative 78 | ``` 79 | 80 | ## 2. Evaluation 81 | As the main goal of this work is improving open-domain QA performance, we recommend running the full pipeline to evaluate your reader performance. 82 | Alternatively, you can run a sanity check in the HotpotQA gold-paragraph-only setting. 83 | 84 | #### Sanity check on HotpotQA gold only setting 85 | For the sanity check, you can evaluate the reader model performance on the preprocessed dev file. 86 | 87 | The original HotpotQA questions contain 10 paragraphs; we discard the 8 distractor paragraphs and keep only the gold paragraphs. The preprocessed data is also available [here](https://drive.google.com/open?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB). 88 | 89 | You can download [the SQuAD 2.0 evaluation script](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/). 90 | 91 | **note: the F1 calculation is slightly different from the original HotpotQA eval script. We use the SQuAD 2.0 evaluation script for a quick sanity check. 
Please do not use the number to report the performance on HotpotQA.** 92 | 93 | You can run evaluation with the command below: 94 | 95 | ```bash 96 | python evaluate-v2.0.py \ 97 | /path/to/eval/file/hotpot_dev_squad_v2.0_format.json \ 98 | /path/to/your/output/dir/predictions.json 99 | ``` 100 | The F1/EM scores of the bert-base-uncased model on the gold-paragraph only HotpotQA distractor dev data is as follows: 101 | 102 | ```py 103 | { 104 | "exact": 60.60769750168805, 105 | "f1": 74.45707974099558, 106 | "total": 7405, 107 | "HasAns_exact": 60.60769750168805, 108 | "HasAns_f1": 74.45707974099558, 109 | "HasAns_total": 7405 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /reader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/reader/__init__.py -------------------------------------------------------------------------------- /reader/modeling_reader.py: -------------------------------------------------------------------------------- 1 | from pytorch_pretrained_bert.modeling import BertModel, BertPreTrainedModel 2 | from torch import nn 3 | from torch.nn import CrossEntropyLoss 4 | import torch 5 | 6 | 7 | class BERTLayerNorm(nn.Module): 8 | def __init__(self, config, variance_epsilon=1e-12): 9 | """Construct a layernorm module in the TF style (epsilon inside the square root). 10 | """ 11 | super(BERTLayerNorm, self).__init__() 12 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 13 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 14 | self.variance_epsilon = variance_epsilon 15 | 16 | def forward(self, x): 17 | u = x.mean(-1, keepdim=True) 18 | s = (x - u).pow(2).mean(-1, keepdim=True) 19 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 20 | return self.gamma * x + self.beta 21 | 22 | class BertForQuestionAnsweringConfidence(BertPreTrainedModel): 23 | 24 | def __init__(self, config, num_labels, no_masking, lambda_scale=1.0): 25 | super(BertForQuestionAnsweringConfidence, self).__init__(config) 26 | self.bert = BertModel(config) 27 | self.num_labels = num_labels 28 | self.no_masking = no_masking 29 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 30 | self.qa_outputs = nn.Linear( 31 | config.hidden_size, 2) # [N, L, H] => [N, L, 2] 32 | self.qa_classifier = nn.Linear( 33 | config.hidden_size, self.num_labels) # [N, H] => [N, n_class] 34 | self.lambda_scale = lambda_scale 35 | 36 | def init_weights(module): 37 | if isinstance(module, (nn.Linear, nn.Embedding)): 38 | module.weight.data.normal_( 39 | mean=0.0, std=config.initializer_range) 40 | elif isinstance(module, BERTLayerNorm): 41 | module.beta.data.normal_( 42 | mean=0.0, std=config.initializer_range) 43 | module.gamma.data.normal_( 44 | mean=0.0, std=config.initializer_range) 45 | if isinstance(module, nn.Linear): 46 | module.bias.data.zero_() 47 | 48 | self.apply(init_weights) 49 | 50 | def forward(self, input_ids, token_type_ids, attention_mask, 51 | start_positions=None, end_positions=None, switch_list=None): 52 | sequence_output, pooled_output = self.bert( 53 | input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 54 | 55 | # Calculate the sequence logits. 
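        # `sequence_output` is [batch, seq_len, hidden]; `qa_outputs` maps each token to two
        # scores, split below into `start_logits` and `end_logits` of shape [batch, seq_len].
        # The pooled [CLS] vector then goes through dropout and `qa_classifier` to produce
        # `switch_logits` over `num_labels` answer classes (class 0 = span answer; see the
        # span-loss masking below), which is also how the plausibility of the selected
        # reasoning path is scored.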
56 | logits = self.qa_outputs(sequence_output) 57 | start_logits, end_logits = logits.split(1, dim=-1) 58 | start_logits = start_logits.squeeze(-1) 59 | end_logits = end_logits.squeeze(-1) 60 | # Calculate the class logits 61 | pooled_output = self.dropout(pooled_output) 62 | switch_logits = self.qa_classifier(pooled_output) 63 | 64 | if start_positions is not None and end_positions is not None and switch_list is not None: 65 | # If we are on multi-GPU, split add a dimension 66 | if len(start_positions.size()) > 1: 67 | start_positions = start_positions.squeeze(-1) 68 | if len(end_positions.size()) > 1: 69 | end_positions = end_positions.squeeze(-1) 70 | 71 | ignored_index = start_logits.size(1) 72 | start_positions.clamp_(0, ignored_index) 73 | end_positions.clamp_(0, ignored_index) 74 | loss_fct = CrossEntropyLoss( 75 | ignore_index=ignored_index, reduce=False) 76 | 77 | # if no_masking is True, we do not mask the no-answer examples' 78 | # span losses. 79 | if self.no_masking is True: 80 | start_losses = loss_fct(start_logits, start_positions) 81 | end_losses = loss_fct(end_logits, end_positions) 82 | 83 | else: 84 | # You care about the span only when switch is 0 85 | span_mask = (switch_list == 0).type(torch.FloatTensor).cuda() 86 | start_losses = loss_fct( 87 | start_logits, start_positions) * span_mask 88 | end_losses = loss_fct(end_logits, end_positions) * span_mask 89 | 90 | switch_losses = loss_fct(switch_logits, switch_list) 91 | assert len(start_losses) == len( 92 | end_losses) == len(switch_losses) 93 | return self.lambda_scale * (start_losses + end_losses) + switch_losses 94 | 95 | elif start_positions is None or end_positions is None or switch_list is None: 96 | return start_logits, end_logits, switch_logits 97 | 98 | else: 99 | raise NotImplementedError() 100 | -------------------------------------------------------------------------------- /reader/modeling_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | 5 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 6 | BertModel, BertTokenizer) 7 | 8 | MODEL_CLASSES = { 9 | 'bert': (BertConfig, BertModel, BertTokenizer), 10 | } 11 | 12 | def get_bert_model_from_pytorch_transformers(model_name): 13 | config_class, model_class, tokenizer_class = MODEL_CLASSES['bert'] 14 | config = config_class.from_pretrained(model_name) 15 | model = model_class.from_pretrained(model_name, from_tf=bool('.ckpt' in model_name), config=config) 16 | 17 | tokenizer = tokenizer_class.from_pretrained(model_name) 18 | 19 | vocab_file_name = './vocabulary_'+model_name+'.txt' 20 | 21 | if not os.path.exists(vocab_file_name): 22 | index = 0 23 | with open(vocab_file_name, "w", encoding="utf-8") as writer: 24 | for token, token_index in sorted(tokenizer.vocab.items(), key=lambda kv: kv[1]): 25 | if index != token_index: 26 | assert False 27 | index = token_index 28 | writer.write(token + u'\n') 29 | index += 1 30 | 31 | return model.state_dict(), vocab_file_name 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown 2 | glob2 3 | Jinja2 4 | joblib 5 | jsonlines 6 | nltk 7 | numpy 8 | pathlib2 9 | prettytable 10 | pytorch-pretrained-bert==0.6.2 11 | pytorch-transformers==1.2.0 12 | regex 13 | scipy 14 | scikit-learn 15 | spacy==2.2.3 16 | tagme==0.1.3 17 | torch==1.3.0 18 | tqdm 19 | 
ujson 20 | unicodecsv 21 | urllib3 -------------------------------------------------------------------------------- /retriever/README.md: -------------------------------------------------------------------------------- 1 | # TF-IDF based retriever 2 | For efficiency and scalability, we bootstrap the retrieval process with TF-IDF based retrieval. In particular, we first sample seed paragraphs based on TF-IDF scores and then sequentially retrieve sub-graphs connected from the seed paragraphs. 3 | 4 | Table of contents: 5 | - 1. Preprocessing Wikipedia dump 6 | - 2. Building database 7 | - 3. Building the TF-IDF N-grams 8 | - 4. Interactive mode 9 | 10 | *Acknowledgement: The code base started from the amazing [DrQA](https://github.com/facebookresearch/DrQA)'s document retriever code. We really appreciate the efforts by the DrQA authors.* 11 | 12 | ## 1. Preprocessing Wikipedia dump 13 | 14 | #### Download Wikipedia dumps 15 | First, you need to download the Wikipedia dump. 16 | 17 | - **HotpotQA**: You do not need to download the Wikipedia dump yourself, as the authors provide the preprocessed dump on the [HotpotQA official website](https://hotpotqa.github.io/wiki-readme.html). If you plan to use our model for HotpotQA only, we recommend downloading the [intro-paragraph only version](https://nlp.stanford.edu/projects/hotpotqa/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2). You can extract the files following the instructions on the website. 18 | 19 | - **SQuAD**: Although you can use the same dump as in HotpotQA, we recommend using the older dump and the DB distributed by the [DrQA repository](https://github.com/facebookresearch/DrQA/blob/master/download.sh). Wikipedia is frequently edited by users, and thus if you use the newer version, some answers (originally included in the context) are lost. Please refer to **Appendix B.5** of our paper for the details of this finding. Although the DB does not preserve hyperlink information, we do not observe a large performance drop on SQuAD without link-based hops, as the questions are mostly single-hop or inner-article multi-hop. 20 | 21 | - **Natural Questions**: For Natural Questions, for similar reasons to those mentioned above, we recommend using the [dump](https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-meta-current.xml.bz2) from December 2018, which is also used in previous work on NQ Open ([Lee et al., 2019](https://arxiv.org/abs/1906.00300); [Min et al., 2019](https://arxiv.org/abs/1909.04849)). 22 | 23 | 24 | #### Extract articles 25 | As mentioned, you do not need to preprocess the dump yourself for HotpotQA or SQuAD. If you attempt to experiment on Natural Questions or another Wikipedia dump by yourself, you need to extract the articles. 26 | 27 | 1. Install [wikiextractor](https://github.com/attardi/wikiextractor) 28 | 2. Run `Wikiextractor.py` with `--json` and `-l`. The first option makes the output an easy-to-read-and-process `jsonlines` format, and the second option preserves the hyperlinks, which are crucial for our framework. 29 | 30 | 31 | ## 2. Building database 32 | 33 | To efficiently store and access our documents, we store them in a sqlite database. 34 | To create a sqlite db from a corpus of documents, run: 35 | 36 | ```bash 37 | python build_db.py /path/to/data /path/to/saved/db.db --hotpoqa_format 38 | ``` 39 | 40 | **Note** 41 | Do not forget the `--hotpoqa_format` option when you process Wikipedia data for HotpotQA experiments.
The HotpotQA authors kindly provide the preprocessed dump, and the titles and sentence separations must be kept consistent with it for the supporting fact evaluation. 42 | 43 | For the introductory-paragraph-only Wikipedia, the total number of paragraphs stored in the DB should be 5,233,329. 44 | 45 | ``` 46 | $ python build_db.py $PATH_TO_WIKI_DIR/enwiki-20171001-pages-meta-current-withlinks-abstracts enwiki_intro.db --intro_only 47 | 07/01/2019 11:31:51 PM: [ Reading into database... ] 48 | 100%|███████████████████████████████████████| 15517/15517 [01:47<00:00, 143.97it/s] 49 | 07/01/2019 11:33:39 PM: [ Read 5233329 docs. ] 50 | 07/01/2019 11:33:39 PM: [ Committing... ] 51 | ``` 52 | 53 | If you create the DB from the 2018/12/20 dump, the total number of articles will be 5,771,730. 54 | ``` 55 | $ python retriever/build_db.py $PATH_WIKI_DIR enwiki_20181220_all.db 56 | 02/01/2020 11:52:28 PM: [ Reading into database... ] 57 | 100%|██████████████████████████████████████| 16399/16399 [04:03<00:00, 67.28it/s] 58 | 02/01/2020 11:56:32 PM: [ Read 5771730 docs. ] 59 | 02/01/2020 11:56:32 PM: [ Committing... ] 60 | ``` 61 | 62 | **Note: in this setting the total number of paragraphs would be 30M and the DB size would be 27 GB.** 63 | 64 | Optional arguments: 65 | ``` 66 | --preprocess File path to a python module that defines a `preprocess` function. 67 | --num-workers Number of CPU processes (for tokenizing, etc). 68 | ``` 69 | 70 | #### Keeping hyperlinks in `doc_text` 71 | Due to the nature of Wikipedia hyperlinks, a hyperlink connects a paragraph (the source paragraph) to an article (the target article), although if we only consider introductory paragraphs, the relations are always paragraph-to-paragraph. 72 | 73 | To store this relationship efficiently, for the multi-paragraph settings (e.g., Natural Questions Open), we keep the hyperlink information in the `doc_text`. 74 | 75 | e.g., Seattle 76 | ``` 77 | Seattle is a seaport city on the West Coast of the United States. 78 | ``` 79 | 80 | ## 3. Building the TF-IDF N-grams 81 | 82 | To build a TF-IDF weighted word-doc sparse matrix from the documents stored in the sqlite db, run: 83 | 84 | ```bash 85 | python build_tfidf.py /path/to/doc/db.db /path/to/output/dir 86 | ``` 87 | 88 | e.g., 89 | ```bash 90 | python build_tfidf.py enwiki_intro.db tfidf_results_from_enwiki_intro_only/ 91 | ``` 92 | 93 | The sparse matrix and its associated metadata will be saved to the output directory (i.e., `tfidf_results_from_enwiki_intro_only`) as `<db name>-tfidf-ngram=<N>-hash=<N>-tokenizer=<tokenizer>.npz`. 94 | 95 | 96 | Optional arguments: 97 | ``` 98 | --ngram Use up to N-size n-grams (e.g. 2 = unigrams + bigrams). By default only ngrams without stopwords or punctuation are kept. 99 | --hash-size Number of buckets to use for hashing ngrams. 100 | --tokenizer String option specifying tokenizer type to use (e.g. 'corenlp'). 101 | --num-workers Number of CPU processes (for tokenizing, etc). 102 | ``` 103 | 104 | **Note: If you build the TF-IDF matrix from the full Wikipedia paragraphs, it ends up consuming a very large amount of CPU memory, which your local machine might not accommodate. In that case, please use the DB & .npz files we distribute, or consider using a machine with more memory.** 105 | 106 | 107 | ## 4. Interactive mode 108 | You can play with the TF-IDF retriever in interactive mode :) 109 | If you set `with_content=True` in the `process` function, you can see the paragraph text as well as the title.
110 | 111 | ```bash 112 | python scripts/retriever/interactive.py --model /path/to/model \ 113 | --db_save_path /path/to/db file 114 | ``` 115 | e.g., 116 | 117 | ```bash 118 | python interactive.py --model tfidf_wiki_abst/wiki_open_full_new_db_intro_only-tfidf-ngram\=2-hash\=16777216-tokenizer\=simple.npz \ 119 | --db_save_path wiki_open_full_new_db_intro_only.db 120 | ``` 121 | ``` 122 | >>> process('At what university can the building that served as the fictional household that includes Gomez and Morticia be found?', k=1, with_content=True) 123 | +------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 124 | | Rank | Doc Id | Doc Text | 125 | +------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 126 | | 1 | The Addams Family_0 | The Addams Family is a fictional household created by American cartoonist Charles Addams. The Addams Family characters have traditionally included Gomez and Morticia Addams, their children Wednesday and Pugsley, close family members Uncle Fester and Grandmama, their butler Lurch, the disembodied hand Thing, and Gomez's Cousin Itt. | 127 | +------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 128 | ``` 129 | -------------------------------------------------------------------------------- /retriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/retriever/__init__.py -------------------------------------------------------------------------------- /retriever/build_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # The codes are started from DrQA (https://github.com/facebookresearch/DrQA) library. 
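# A minimal sketch of how the resulting DB can be queried afterwards, assuming the
# `documents` table created below and the DocDB helper in retriever/doc_db.py
# (the db path and doc id here are illustrative only):
#
#   from retriever.doc_db import DocDB
#   db = DocDB(db_path='enwiki_intro.db')
#   doc_ids = db.get_doc_ids()                # e.g. 'The Addams Family_0'
#   text = db.get_doc_text(doc_ids[0])        # sentences joined by tabs (HotpotQA format)
#   links = db.get_hyper_linked(doc_ids[0])   # normalized titles of outgoing hyperlinks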
3 | """A script to read in and store documents in a sqlite database.""" 4 | 5 | import argparse 6 | import sqlite3 7 | import json 8 | import os 9 | import logging 10 | import importlib.util 11 | import glob 12 | 13 | from multiprocessing import Pool as ProcessPool 14 | from tqdm import tqdm 15 | 16 | try: 17 | from retriever.utils import normalize, process_jsonlines_hotpotqa, process_jsonlines 18 | except: 19 | from utils import normalize, process_jsonlines_hotpotqa, process_jsonlines 20 | 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.INFO) 23 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 24 | console = logging.StreamHandler() 25 | console.setFormatter(fmt) 26 | logger.addHandler(console) 27 | 28 | 29 | # ------------------------------------------------------------------------------ 30 | # Import helper 31 | # ------------------------------------------------------------------------------ 32 | 33 | 34 | PREPROCESS_FN = None 35 | 36 | 37 | def init(filename): 38 | global PREPROCESS_FN 39 | if filename: 40 | PREPROCESS_FN = import_module(filename).preprocess 41 | 42 | 43 | def import_module(filename): 44 | """Import a module given a full path to the file.""" 45 | spec = importlib.util.spec_from_file_location('doc_filter', filename) 46 | module = importlib.util.module_from_spec(spec) 47 | spec.loader.exec_module(module) 48 | return module 49 | 50 | 51 | # ------------------------------------------------------------------------------ 52 | # Store corpus. 53 | # ------------------------------------------------------------------------------ 54 | 55 | 56 | def iter_files(path): 57 | """Walk through all files located under a root path.""" 58 | if os.path.isfile(path): 59 | yield path 60 | elif os.path.isdir(path): 61 | for dirpath, _, filenames in os.walk(path): 62 | for f in filenames: 63 | yield os.path.join(dirpath, f) 64 | else: 65 | raise RuntimeError('Path %s is invalid' % path) 66 | 67 | 68 | def get_contents_hotpotqa(filename): 69 | """Parse the contents of a file. Each line is a JSON encoded document.""" 70 | global PREPROCESS_FN 71 | documents = [] 72 | extracted_items = process_jsonlines_hotpotqa(filename) 73 | for extracted_item in extracted_items: 74 | title = extracted_item["title"] 75 | text = extracted_item["plain_text"] 76 | original_title = extracted_item["original_title"] 77 | hyper_linked_titles = extracted_item["hyper_linked_titles"] 78 | 79 | documents.append((title, text, 80 | hyper_linked_titles, original_title)) 81 | return documents 82 | 83 | def get_contents(filename): 84 | """Parse the contents of a file. Each line is a JSON encoded document.""" 85 | global PREPROCESS_FN 86 | documents = [] 87 | extracted_items = process_jsonlines(filename) 88 | for extracted_item in extracted_items: 89 | title = extracted_item["title"] 90 | text = extracted_item["plain_text"] 91 | original_title = extracted_item["original_title"] 92 | hyper_linked_titles = extracted_item["hyper_linked_titles"] 93 | 94 | documents.append((title, text, 95 | hyper_linked_titles, original_title)) 96 | return documents 97 | 98 | def store_contents(wiki_dir, save_path, preprocess, num_workers=None, hotpotqa_format=False): 99 | """Preprocess and store a corpus of documents in sqlite. 100 | 101 | Args: 102 | data_path: Root path to directory (or directory of directories) of files 103 | containing json encoded documents (must have `id` and `text` fields). 104 | save_path: Path to output sqlite db. 
105 | preprocess: Path to file defining a custom `preprocess` function. Takes 106 | in and outputs a structured doc. 107 | num_workers: Number of parallel processes to use when reading docs. 108 | """ 109 | filenames = [f for f in glob.glob( 110 | wiki_dir + "/*/wiki_*", recursive=True) if ".bz2" not in f] 111 | if os.path.isfile(save_path): 112 | raise RuntimeError('%s already exists! Not overwriting.' % save_path) 113 | 114 | logger.info('Reading into database...') 115 | conn = sqlite3.connect(save_path) 116 | c = conn.cursor() 117 | c.execute( 118 | "CREATE TABLE documents (id PRIMARY KEY, text, linked_title, original_title);") 119 | 120 | workers = ProcessPool(num_workers, initializer=init, 121 | initargs=(preprocess,)) 122 | count = 0 123 | # Due to the slight difference of input format between preprocessed HotpotQA wikipedia data and 124 | # the ones by Wikiextractor, we call different functions for data process. 125 | if hotpotqa_format is True: 126 | content_processing_method = get_contents_hotpotqa 127 | else: 128 | content_processing_method = get_contents 129 | 130 | with tqdm(total=len(filenames)) as pbar: 131 | for pairs in tqdm(workers.imap_unordered(content_processing_method, filenames)): 132 | count += len(pairs) 133 | c.executemany( 134 | "INSERT OR REPLACE INTO documents VALUES (?,?,?,?)", pairs) 135 | pbar.update() 136 | logger.info('Read %d docs.' % count) 137 | logger.info('Committing...') 138 | conn.commit() 139 | conn.close() 140 | 141 | # ------------------------------------------------------------------------------ 142 | # Main. 143 | # ------------------------------------------------------------------------------ 144 | 145 | 146 | if __name__ == '__main__': 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument('wiki_dir', type=str, help='/path/to/data') 149 | parser.add_argument('save_path', type=str, help='/path/to/saved/db.db') 150 | parser.add_argument('--preprocess', type=str, default=None, 151 | help=('File path to a python module that defines ' 152 | 'a `preprocess` function')) 153 | parser.add_argument('--num-workers', type=int, default=None, 154 | help='Number of CPU processes (for tokenizing, etc)') 155 | parser.add_argument('--hotpoqa_format', action='store_true', 156 | help='the input files are hotpotqa format.') 157 | args = parser.parse_args() 158 | 159 | store_contents( 160 | args.wiki_dir, args.save_path, args.preprocess, args.num_workers, 161 | args.hotpoqa_format 162 | ) 163 | -------------------------------------------------------------------------------- /retriever/build_tfidf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # The codes are based on DrQA (https://github.com/facebookresearch/DrQA) library. 
3 | """A script to build the tf-idf document matrices for retrieval.""" 4 | 5 | import numpy as np 6 | import scipy.sparse as sp 7 | import argparse 8 | import os 9 | import math 10 | import logging 11 | 12 | from multiprocessing import Pool as ProcessPool 13 | from multiprocessing.util import Finalize 14 | from functools import partial 15 | from collections import Counter 16 | 17 | try: 18 | from retriever.tfidf_doc_ranker import TfidfDocRanker 19 | from retriever.doc_db import DocDB 20 | from retriever.tokenizers import SimpleTokenizer 21 | from retriever.utils import filter_ngram, hash, save_sparse_csr 22 | except: 23 | from tfidf_doc_ranker import TfidfDocRanker 24 | from doc_db import DocDB 25 | from tokenizers import SimpleTokenizer 26 | from utils import filter_ngram, hash, save_sparse_csr 27 | 28 | logger = logging.getLogger() 29 | logger.setLevel(logging.INFO) 30 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 31 | console = logging.StreamHandler() 32 | console.setFormatter(fmt) 33 | logger.addHandler(console) 34 | 35 | 36 | # ------------------------------------------------------------------------------ 37 | # Set retriever class 38 | # ------------------------------------------------------------------------------ 39 | 40 | 41 | def get_class(name): 42 | if name == 'tfidf': 43 | return TfidfDocRanker 44 | if name == 'sqlite': 45 | return DocDB 46 | raise RuntimeError('Invalid retriever class: %s' % name) 47 | 48 | # ------------------------------------------------------------------------------ 49 | # Multiprocessing functions 50 | # ------------------------------------------------------------------------------ 51 | 52 | 53 | DOC2IDX = None 54 | PROCESS_TOK = None 55 | PROCESS_DB = None 56 | 57 | 58 | def init(tokenizer_class, db_class, db_opts): 59 | global PROCESS_TOK, PROCESS_DB 60 | PROCESS_TOK = tokenizer_class() 61 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 62 | PROCESS_DB = db_class(**db_opts) 63 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 64 | 65 | 66 | def fetch_text(doc_id): 67 | global PROCESS_DB 68 | return PROCESS_DB.get_doc_text(doc_id) 69 | 70 | 71 | def fetch_text_multi_para(doc_id): 72 | global PROCESS_DB 73 | paras, _ = PROCESS_DB.get_doc_text_hyper_linked_titles_for_articles(doc_id) 74 | if len(paras) > 0: 75 | paras = "\n\n".join(paras) 76 | return paras 77 | 78 | 79 | def tokenize(text): 80 | global PROCESS_TOK 81 | return PROCESS_TOK.tokenize(text) 82 | 83 | 84 | # ------------------------------------------------------------------------------ 85 | # Build article --> word count sparse matrix. 86 | # ------------------------------------------------------------------------------ 87 | 88 | def count(ngram, hash_size, multi_para, doc_id): 89 | """Fetch the text of a document and compute hashed ngrams counts.""" 90 | global DOC2IDX 91 | # FIXME: remove hard coding. 92 | row, col, data = [], [], [] 93 | # Tokenize 94 | 95 | if multi_para is True: 96 | # 1. if multi_para is true, the doc contains multiple paragraphs separated by \n\n and with links. 97 | tokens = tokenize(fetch_text_multi_para(doc_id)) 98 | else: 99 | # 2. if not, only intro docs are retrieved and the sentences are separated by \t. 100 | # remove sentence separations ("\t") (only for HotpotQA). 101 | tokens = tokenize(fetch_text(doc_id).replace("\t", "")) 102 | 103 | # Get ngrams from tokens, with stopword/punctuation filtering. 
104 | ngrams = tokens.ngrams( 105 | n=ngram, uncased=True, filter_fn=filter_ngram 106 | ) 107 | 108 | # Hash ngrams and count occurences 109 | counts = Counter([hash(gram, hash_size) 110 | for gram in ngrams]) 111 | 112 | # Return in sparse matrix data format. 113 | row.extend(counts.keys()) 114 | col.extend([DOC2IDX[doc_id]] * len(counts)) 115 | data.extend(counts.values()) 116 | return row, col, data 117 | 118 | 119 | def get_count_matrix(args, db, db_opts): 120 | """Form a sparse word to document count matrix (inverted index). 121 | 122 | M[i, j] = # times word i appears in document j. 123 | """ 124 | # Map doc_ids to indexes 125 | global DOC2IDX 126 | db_class = get_class(db) 127 | with db_class(**db_opts) as doc_db: 128 | doc_ids = doc_db.get_doc_ids() 129 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)} 130 | 131 | # Setup worker pool 132 | # TODO: Add tokenizer's choice. 133 | tok_class = SimpleTokenizer 134 | workers = ProcessPool( 135 | args.num_workers, 136 | initializer=init, 137 | initargs=(tok_class, db_class, db_opts) 138 | ) 139 | 140 | # Compute the count matrix in steps (to keep in memory) 141 | logger.info('Mapping...') 142 | row, col, data = [], [], [] 143 | step = max(int(len(doc_ids) / 10), 1) 144 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)] 145 | _count = partial(count, args.ngram, args.hash_size, args.multi_para) 146 | for i, batch in enumerate(batches): 147 | logger.info('-' * 25 + 'Batch %d/%d' % 148 | (i + 1, len(batches)) + '-' * 25) 149 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch): 150 | row.extend(b_row) 151 | col.extend(b_col) 152 | data.extend(b_data) 153 | workers.close() 154 | workers.join() 155 | 156 | logger.info('Creating sparse matrix...') 157 | count_matrix = sp.csr_matrix( 158 | (data, (row, col)), shape=(args.hash_size, len(doc_ids)) 159 | ) 160 | count_matrix.sum_duplicates() 161 | return count_matrix, (DOC2IDX, doc_ids) 162 | 163 | 164 | # ------------------------------------------------------------------------------ 165 | # Transform count matrix to different forms. 166 | # ------------------------------------------------------------------------------ 167 | 168 | 169 | def get_tfidf_matrix(cnts): 170 | """Convert the word count matrix into tfidf one. 171 | 172 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 173 | * tf = term frequency in document 174 | * N = number of documents 175 | * Nt = number of occurences of term in all documents 176 | """ 177 | Ns = get_doc_freqs(cnts) 178 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5)) 179 | idfs[idfs < 0] = 0 180 | idfs = sp.diags(idfs, 0) 181 | tfs = cnts.log1p() 182 | tfidfs = idfs.dot(tfs) 183 | return tfidfs 184 | 185 | 186 | def get_doc_freqs(cnts): 187 | """Return word --> # of docs it appears in.""" 188 | binary = (cnts > 0).astype(int) 189 | freqs = np.array(binary.sum(1)).squeeze() 190 | return freqs 191 | 192 | 193 | # ------------------------------------------------------------------------------ 194 | # Main. 195 | # ------------------------------------------------------------------------------ 196 | 197 | 198 | if __name__ == '__main__': 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument('db_path', type=str, default=None, 201 | help='Path to sqlite db holding document texts') 202 | parser.add_argument('out_dir', type=str, default=None, 203 | help='Directory for saving output files') 204 | parser.add_argument('--ngram', type=int, default=2, 205 | help=('Use up to N-size n-grams ' 206 | '(e.g. 
2 = unigrams + bigrams)')) 207 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)), 208 | help='Number of buckets to use for hashing ngrams') 209 | parser.add_argument('--tokenizer', type=str, default='simple', 210 | help=("String option specifying tokenizer type to use " 211 | "(e.g. 'corenlp')")) 212 | parser.add_argument('--num-workers', type=int, default=None, 213 | help='Number of CPU processes (for tokenizing, etc)') 214 | parser.add_argument('--multi_para', action='store_true', 215 | help='set true if the db contains multiple paragraphs, not intro-paragraph only.') 216 | args = parser.parse_args() 217 | 218 | logging.info('Counting words...') 219 | count_matrix, doc_dict = get_count_matrix( 220 | args, 'sqlite', {'db_path': args.db_path} 221 | ) 222 | 223 | logger.info('Making tfidf vectors...') 224 | tfidf = get_tfidf_matrix(count_matrix) 225 | 226 | logger.info('Getting word-doc frequencies...') 227 | freqs = get_doc_freqs(count_matrix) 228 | 229 | basename = os.path.splitext(os.path.basename(args.db_path))[0] 230 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=%s' % 231 | (args.ngram, args.hash_size, args.tokenizer)) 232 | 233 | # check if output_dir exists; if not, create the output_dir. 234 | if not os.path.exists(args.out_dir): 235 | os.makedirs(args.out_dir) 236 | filename = os.path.join(args.out_dir, basename) 237 | 238 | logger.info('Saving to %s.npz' % filename) 239 | metadata = { 240 | 'doc_freqs': freqs, 241 | 'tokenizer': args.tokenizer, 242 | 'hash_size': args.hash_size, 243 | 'ngram': args.ngram, 244 | 'doc_dict': doc_dict 245 | } 246 | save_sparse_csr(filename, tfidf, metadata) 247 | -------------------------------------------------------------------------------- /retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import argparse 3 | import time 4 | try: 5 | from retriever.utils import find_hyper_linked_titles, remove_tags, normalize 6 | except: 7 | from utils import find_hyper_linked_titles, remove_tags, normalize 8 | 9 | class DocDB(object): 10 | """Sqlite backed document storage. 11 | 12 | Implements get_doc_text(doc_id). 
13 | """ 14 | 15 | def __init__(self, db_path=None): 16 | self.path = db_path 17 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 18 | 19 | def __enter__(self): 20 | return self 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def close(self): 26 | """Close the connection to the database.""" 27 | self.connection.close() 28 | 29 | def get_doc_ids(self): 30 | """Fetch all ids of docs stored in the db.""" 31 | cursor = self.connection.cursor() 32 | cursor.execute("SELECT id FROM documents") 33 | results = [r[0] for r in cursor.fetchall()] 34 | cursor.close() 35 | return results 36 | 37 | def get_doc_text(self, doc_id): 38 | """Fetch the raw text of the doc for 'doc_id'.""" 39 | cursor = self.connection.cursor() 40 | cursor.execute( 41 | "SELECT text FROM documents WHERE id = ?", 42 | (doc_id,) 43 | ) 44 | result = cursor.fetchone() 45 | cursor.close() 46 | return result if result is None else result[0] 47 | 48 | def get_hyper_linked(self, doc_id): 49 | """Fetch the hyper-linked titles of the doc for 'doc_id'.""" 50 | cursor = self.connection.cursor() 51 | cursor.execute( 52 | "SELECT linked_title FROM documents WHERE id = ?", 53 | (doc_id,) 54 | ) 55 | result = cursor.fetchone() 56 | cursor.close() 57 | return result if (result is None or len(result[0]) == 0) else [normalize(title) for title in result[0].split("\t")] 58 | 59 | def get_original_title(self, doc_id): 60 | """Fetch the original title name of the doc.""" 61 | cursor = self.connection.cursor() 62 | cursor.execute( 63 | "SELECT original_title FROM documents WHERE id = ?", 64 | (doc_id,) 65 | ) 66 | result = cursor.fetchone() 67 | cursor.close() 68 | return result if result is None else result[0] 69 | 70 | def get_doc_text_hyper_linked_titles_for_articles(self, doc_id): 71 | """ 72 | fetch all of the paragraphs with their corresponding hyperlink titles. 73 | e.g., 74 | >>> paras, links = db.get_doc_text_hyper_linked_titles_for_articles("Tokyo Imperial Palace_0") 75 | >>> paras[2] 76 | 'It is built on the site of the old Edo Castle. The total area including the gardens is . During the height of the 1980s Japanese property bubble, the palace grounds were valued by some to be more than the value of all of the real estate in the state of California.' 
77 | >>> links[2] 78 | ['Edo Castle', 'Japanese asset price bubble', 'Real estate', 'California'] 79 | """ 80 | cursor = self.connection.cursor() 81 | cursor.execute( 82 | "SELECT text FROM documents WHERE id = ?", 83 | (doc_id,) 84 | ) 85 | result = cursor.fetchone() 86 | cursor.close() 87 | if result is None: 88 | return [], [] 89 | else: 90 | hyper_linked_paragraphs = result[0].split("\n\n") 91 | paragraphs, hyper_linked_titles = [], [] 92 | 93 | for hyper_linked_paragraph in hyper_linked_paragraphs: 94 | paragraphs.append(remove_tags(hyper_linked_paragraph)) 95 | hyper_linked_titles.append([normalize(title) for title in find_hyper_linked_titles( 96 | hyper_linked_paragraph)]) 97 | 98 | return paragraphs, hyper_linked_titles 99 | -------------------------------------------------------------------------------- /retriever/interactive.py: -------------------------------------------------------------------------------- 1 | """Interactive mode for the tfidf DrQA retriever module.""" 2 | 3 | import argparse 4 | import code 5 | import prettytable 6 | import logging 7 | try: 8 | from retriever.doc_db import DocDB 9 | from retriever.tfidf_doc_ranker import TfidfDocRanker 10 | except: 11 | from doc_db import DocDB 12 | from tfidf_doc_ranker import TfidfDocRanker 13 | 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.INFO) 16 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 17 | console = logging.StreamHandler() 18 | console.setFormatter(fmt) 19 | logger.addHandler(console) 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--model', type=str, default=None) 23 | parser.add_argument('--db_save_path', type=str, default=None) 24 | args = parser.parse_args() 25 | db = DocDB(args.db_save_path) 26 | 27 | logger.info('Initializing ranker...') 28 | ranker = TfidfDocRanker(tfidf_path=args.model) 29 | 30 | 31 | # ------------------------------------------------------------------------------ 32 | # Drop in to interactive 33 | # ------------------------------------------------------------------------------ 34 | 35 | 36 | def process(query, k=1, with_content=False): 37 | doc_names, doc_scores = ranker.closest_docs(query, k) 38 | if with_content is True: 39 | doc_text = [] 40 | for doc_name in doc_names: 41 | doc_text.append(db.get_doc_text(doc_name)) 42 | table = prettytable.PrettyTable( 43 | ['Rank', 'Doc Id', 'Doc Text'] 44 | ) 45 | 46 | for i in range(len(doc_names)): 47 | table.add_row([i + 1, doc_names[i], doc_text[i]]) 48 | print(table) 49 | 50 | else: 51 | table = prettytable.PrettyTable( 52 | ['Rank', 'Doc Id', 'Doc Score'] 53 | ) 54 | for i in range(len(doc_names)): 55 | table.add_row([i + 1, doc_names[i], '%.5g' % doc_scores[i]]) 56 | print(table) 57 | 58 | 59 | banner = """ 60 | Interactive TF-IDF DrQA Retriever 61 | >> process(question, k=1) 62 | >> usage() 63 | """ 64 | 65 | 66 | def usage(): 67 | print(banner) 68 | 69 | 70 | code.interact(banner=banner, local=locals()) 71 | -------------------------------------------------------------------------------- /retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | try: 17 | from retriever.tokenizers import SimpleTokenizer 18 | from retriever.utils import load_sparse_csr, filter_ngram, hash, normalize 19 | except: 20 | from tokenizers import SimpleTokenizer 21 | from utils import load_sparse_csr, filter_ngram, hash, normalize 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class TfidfDocRanker(object): 27 | """Loads a pre-weighted inverted index of token/document terms. 28 | Scores new queries by taking sparse dot products. 29 | """ 30 | 31 | def __init__(self, tfidf_path=None, strict=True): 32 | """ 33 | Args: 34 | tfidf_path: path to saved model file 35 | strict: fail on empty queries or continue (and return empty result) 36 | """ 37 | # Load from disk 38 | tfidf_path = tfidf_path 39 | logger.info('Loading %s' % tfidf_path) 40 | matrix, metadata = load_sparse_csr(tfidf_path) 41 | self.doc_mat = matrix 42 | self.ngrams = metadata['ngram'] 43 | self.hash_size = metadata['hash_size'] 44 | self.tokenizer = SimpleTokenizer() 45 | self.doc_freqs = metadata['doc_freqs'].squeeze() 46 | self.doc_dict = metadata['doc_dict'] 47 | self.num_docs = len(self.doc_dict[0]) 48 | self.strict = strict 49 | 50 | def get_doc_index(self, doc_id): 51 | """Convert doc_id --> doc_index""" 52 | return self.doc_dict[0][doc_id] 53 | 54 | def get_doc_id(self, doc_index): 55 | """Convert doc_index --> doc_id""" 56 | return self.doc_dict[1][doc_index] 57 | 58 | def closest_docs(self, query, k=1): 59 | """Closest docs by dot product between query and documents 60 | in tfidf weighted word vector space. 61 | """ 62 | spvec = self.text2spvec(query) 63 | res = spvec * self.doc_mat 64 | 65 | if len(res.data) <= k: 66 | o_sort = np.argsort(-res.data) 67 | else: 68 | o = np.argpartition(-res.data, k)[0:k] 69 | o_sort = o[np.argsort(-res.data[o])] 70 | 71 | doc_scores = res.data[o_sort] 72 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 73 | return doc_ids, doc_scores 74 | 75 | def batch_closest_docs(self, queries, k=1, num_workers=None): 76 | """Process a batch of closest_docs requests multithreaded. 77 | Note: we can use plain threads here as scipy is outside of the GIL. 78 | """ 79 | with ThreadPool(num_workers) as threads: 80 | closest_docs = partial(self.closest_docs, k=k) 81 | results = threads.map(closest_docs, queries) 82 | return results 83 | 84 | def parse(self, query): 85 | """Parse the query into tokens (either ngrams or tokens).""" 86 | tokens = self.tokenizer.tokenize(query) 87 | return tokens.ngrams(n=self.ngrams, uncased=True, 88 | filter_fn=filter_ngram) 89 | 90 | def text2spvec(self, query): 91 | """Create a sparse tfidf-weighted word vector from query. 92 | 93 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 94 | """ 95 | # Get hashed ngrams 96 | # TODO: do we need to have normalize? 
97 | words = self.parse(normalize(query)) 98 | wids = [hash(w, self.hash_size) for w in words] 99 | 100 | if len(wids) == 0: 101 | if self.strict: 102 | raise RuntimeError('No valid word in: %s' % query) 103 | else: 104 | logger.warning('No valid word in: %s' % query) 105 | return sp.csr_matrix((1, self.hash_size)) 106 | 107 | # Count TF 108 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 109 | tfs = np.log1p(wids_counts) 110 | 111 | # Count IDF 112 | Ns = self.doc_freqs[wids_unique] 113 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 114 | idfs[idfs < 0] = 0 115 | 116 | # TF-IDF 117 | data = np.multiply(tfs, idfs) 118 | 119 | # One row, sparse csr matrix 120 | indptr = np.array([0, len(wids_unique)]) 121 | spvec = sp.csr_matrix( 122 | (data, wids_unique, indptr), shape=(1, self.hash_size) 123 | ) 124 | 125 | return spvec 126 | -------------------------------------------------------------------------------- /retriever/tfidf_vectorizer_article.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.feature_extraction.text import TfidfVectorizer 3 | from sklearn.metrics import pairwise_distances 4 | from sklearn.metrics.pairwise import cosine_similarity 5 | 6 | 7 | class TopTfIdf(): 8 | def __init__(self, n_to_select: int, filter_dist_one: bool = False, rank=True): 9 | self.rank = rank 10 | self.n_to_select = n_to_select 11 | self.filter_dist_one = filter_dist_one 12 | 13 | def prune(self, question, paragraphs, return_scores=False): 14 | if not self.filter_dist_one and len(paragraphs) == 1: 15 | return paragraphs 16 | 17 | tfidf = TfidfVectorizer(strip_accents="unicode", 18 | stop_words="english") 19 | text = [] 20 | for para in paragraphs: 21 | text.append(para) 22 | try: 23 | para_features = tfidf.fit_transform(text) 24 | except ValueError: 25 | return [] 26 | # question should be tokenized beforehand 27 | q_features = tfidf.transform([question]) 28 | dists = cosine_similarity(q_features, para_features, "cosine").ravel() 29 | # in case of ties, use the earlier paragraph 30 | sorted_ix = np.argsort(dists)[::-1] 31 | if return_scores is True: 32 | return sorted_ix, dists 33 | else: 34 | return sorted_ix 35 | 36 | def dists(self, question, paragraphs): 37 | tfidf = TfidfVectorizer(strip_accents="unicode", 38 | stop_words=self.stop.words) 39 | text = [] 40 | for para in paragraphs: 41 | text.append(" ".join(" ".join(s) for s in para.text)) 42 | try: 43 | para_features = tfidf.fit_transform(text) 44 | q_features = tfidf.transform([" ".join(question)]) 45 | except ValueError: 46 | return [] 47 | 48 | dists = pairwise_distances(q_features, para_features, "cosine").ravel() 49 | # in case of ties, use the earlier paragraph 50 | sorted_ix = np.lexsort(([x for x in range(len(paragraphs))], dists)) 51 | 52 | if self.filter_dist_one: 53 | return [(paragraphs[i], dists[i]) for i in sorted_ix[:self.n_to_select] if dists[i] < 1.0] 54 | else: 55 | return [(paragraphs[i], dists[i]) for i in sorted_ix[:self.n_to_select]] 56 | -------------------------------------------------------------------------------- /retriever/tokenizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | import copy 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Tokens(object): 19 | """A class to represent a list of tokenized text.""" 20 | TEXT = 0 21 | TEXT_WS = 1 22 | SPAN = 2 23 | POS = 3 24 | LEMMA = 4 25 | NER = 5 26 | 27 | def __init__(self, data, annotators, opts=None): 28 | self.data = data 29 | self.annotators = annotators 30 | self.opts = opts or {} 31 | 32 | def __len__(self): 33 | """The number of tokens.""" 34 | return len(self.data) 35 | 36 | def slice(self, i=None, j=None): 37 | """Return a view of the list of tokens from [i, j).""" 38 | new_tokens = copy.copy(self) 39 | new_tokens.data = self.data[i: j] 40 | return new_tokens 41 | 42 | def untokenize(self): 43 | """Returns the original text (with whitespace reinserted).""" 44 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 45 | 46 | def words(self, uncased=False): 47 | """Returns a list of the text of each token 48 | Args: 49 | uncased: lower cases text 50 | """ 51 | if uncased: 52 | return [t[self.TEXT].lower() for t in self.data] 53 | else: 54 | return [t[self.TEXT] for t in self.data] 55 | 56 | def offsets(self): 57 | """Returns a list of [start, end) character offsets of each token.""" 58 | return [t[self.SPAN] for t in self.data] 59 | 60 | def pos(self): 61 | """Returns a list of part-of-speech tags of each token. 62 | Returns None if this annotation was not included. 63 | """ 64 | if 'pos' not in self.annotators: 65 | return None 66 | return [t[self.POS] for t in self.data] 67 | 68 | def lemmas(self): 69 | """Returns a list of the lemmatized text of each token. 70 | Returns None if this annotation was not included. 71 | """ 72 | if 'lemma' not in self.annotators: 73 | return None 74 | return [t[self.LEMMA] for t in self.data] 75 | 76 | def entities(self): 77 | """Returns a list of named-entity-recognition tags of each token. 78 | Returns None if this annotation was not included. 79 | """ 80 | if 'ner' not in self.annotators: 81 | return None 82 | return [t[self.NER] for t in self.data] 83 | 84 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 85 | """Returns a list of all ngrams from length 1 to n. 
86 | Args: 87 | n: upper limit of ngram length 88 | uncased: lower cases text 89 | filter_fn: user function that takes in an ngram list and returns 90 | True or False to keep or not keep the ngram 91 | as_string: return the ngram as a string vs list 92 | """ 93 | def _skip(gram): 94 | if not filter_fn: 95 | return False 96 | return filter_fn(gram) 97 | 98 | words = self.words(uncased) 99 | ngrams = [(s, e + 1) 100 | for s in range(len(words)) 101 | for e in range(s, min(s + n, len(words))) 102 | if not _skip(words[s:e + 1])] 103 | 104 | # Concatenate into strings 105 | if as_strings: 106 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 107 | 108 | return ngrams 109 | 110 | def entity_groups(self): 111 | """Group consecutive entity tokens with the same NER tag.""" 112 | entities = self.entities() 113 | if not entities: 114 | return None 115 | non_ent = self.opts.get('non_ent', 'O') 116 | groups = [] 117 | idx = 0 118 | while idx < len(entities): 119 | ner_tag = entities[idx] 120 | # Check for entity tag 121 | if ner_tag != non_ent: 122 | # Chomp the sequence 123 | start = idx 124 | while (idx < len(entities) and entities[idx] == ner_tag): 125 | idx += 1 126 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 127 | else: 128 | idx += 1 129 | return groups 130 | 131 | 132 | class Tokenizer(object): 133 | """Base tokenizer class. 134 | Tokenizers implement tokenize, which should return a Tokens class. 135 | """ 136 | 137 | def tokenize(self, text): 138 | raise NotImplementedError 139 | 140 | def shutdown(self): 141 | pass 142 | 143 | def __del__(self): 144 | self.shutdown() 145 | 146 | 147 | class SimpleTokenizer(Tokenizer): 148 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 149 | NON_WS = r'[^\p{Z}\p{C}]' 150 | 151 | def __init__(self, **kwargs): 152 | """ 153 | Args: 154 | annotators: None or empty set (only tokenizes). 155 | """ 156 | self._regexp = regex.compile( 157 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 158 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 159 | ) 160 | if len(kwargs.get('annotators', {})) > 0: 161 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 162 | (type(self).__name__, kwargs.get('annotators'))) 163 | self.annotators = set() 164 | 165 | def tokenize(self, text): 166 | data = [] 167 | matches = [m for m in self._regexp.finditer(text)] 168 | for i in range(len(matches)): 169 | # Get text 170 | token = matches[i].group() 171 | 172 | # Get whitespace 173 | span = matches[i].span() 174 | start_ws = span[0] 175 | if i + 1 < len(matches): 176 | end_ws = matches[i + 1].span()[0] 177 | else: 178 | end_ws = span[1] 179 | 180 | # Format data 181 | data.append(( 182 | token, 183 | text[start_ws: end_ws], 184 | span, 185 | )) 186 | return Tokens(data, self.annotators) 187 | -------------------------------------------------------------------------------- /retriever/utils.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | import jsonlines 3 | import re 4 | from urllib.parse import unquote 5 | import regex 6 | import numpy as np 7 | import scipy.sparse as sp 8 | from sklearn.utils import murmurhash3_32 9 | 10 | import logging 11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 12 | datefmt='%m/%d/%Y %H:%M:%S', 13 | level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def normalize(text): 18 | """Resolve different type of unicode encodings / capitarization in HotpotQA data.""" 19 | text = unicodedata.normalize('NFD', text) 20 | return text[0].capitalize() + text[1:] 21 | 22 | 23 | def make_wiki_id(title, para_index): 24 | title_id = "{0}_{1}".format(normalize(title), para_index) 25 | return title_id 26 | 27 | 28 | def find_hyper_linked_titles(text_w_links): 29 | titles = re.findall(r'href=[\'"]?([^\'" >]+)', text_w_links) 30 | titles = [unquote(title) for title in titles] 31 | titles = [title[0].capitalize() + title[1:] for title in titles] 32 | return titles 33 | 34 | 35 | TAG_RE = re.compile(r'<[^>]+>') 36 | 37 | 38 | def remove_tags(text): 39 | return TAG_RE.sub('', text) 40 | 41 | 42 | def process_jsonlines(filename): 43 | """ 44 | This is process_jsonlines method for extracted Wikipedia file. 45 | After extracting items by using Wikiextractor (with `--json` and `--links` options), 46 | you will get the files named with wiki_xx, where each line contains the information of each article. 47 | e.g., 48 | {"id": "316", "url": "https://en.wikipedia.org/wiki?curid=316", "title": "Academy Award for Best Production Design", 49 | "text": "Academy Award for Best Production Design\n\nThe Academy Award for 50 | Best Production Design recognizes achievement for art direction \n\n"} 51 | This function takes these input and extract items. 52 | Each article contains one or more than one paragraphs, and each paragraphs are separeated by \n\n. 53 | """ 54 | # item should be nested list 55 | extracted_items = [] 56 | with jsonlines.open(filename) as reader: 57 | for obj in reader: 58 | wiki_id = obj["id"] 59 | title = obj["title"] 60 | title_id = make_wiki_id(title, 0) 61 | text_with_links = obj["text"] 62 | 63 | hyper_linked_titles_text = "" 64 | # When we consider the whole article as a document unit (e.g., SQuAD Open, Natural Questions Open) 65 | # we'll keep the links with the original articles, and dynamically process and extract the links 66 | # when we process with our selector. 
67 | extracted_items.append({"wiki_id": wiki_id, "title": title_id, 68 | "plain_text": text_with_links, 69 | "hyper_linked_titles": hyper_linked_titles_text, 70 | "original_title": title}) 71 | 72 | return extracted_items 73 | 74 | def process_jsonlines_hotpotqa(filename): 75 | """ 76 | This is process_jsonlines method for intro-only processed_wikipedia file. 77 | The item example: 78 | {"id": "45668011", "url": "https://en.wikipedia.org/wiki?curid=45668011", "title": "Flouch Roundabout", 79 | "text": ["Flouch Roundabout is a roundabout near Penistone, South Yorkshire, England, where the A628 meets the A616."], 80 | "charoffset": [[[0, 6],...]] 81 | "text_with_links" : ["Flouch Roundabout is a roundabout near Penistone, 82 | South Yorkshire, England, where the A628 83 | meets the A616."], 84 | "charoffset_with_links": [[[0, 6], ... [213, 214]]]} 85 | """ 86 | # item should be nested list 87 | extracted_items = [] 88 | with jsonlines.open(filename) as reader: 89 | for obj in reader: 90 | wiki_id = obj["id"] 91 | title = obj["title"] 92 | title_id = make_wiki_id(title, 0) 93 | plain_text = "\t".join(obj["text"]) 94 | text_with_links = "\t".join(obj["text_with_links"]) 95 | 96 | hyper_linked_titles = [] 97 | hyper_linked_titles = find_hyper_linked_titles(text_with_links) 98 | if len(hyper_linked_titles) > 0: 99 | hyper_linked_titles_text = "\t".join(hyper_linked_titles) 100 | else: 101 | hyper_linked_titles_text = "" 102 | extracted_items.append({"wiki_id": wiki_id, "title": title_id, 103 | "plain_text": plain_text, 104 | "hyper_linked_titles": hyper_linked_titles_text, 105 | "original_title": title}) 106 | 107 | return extracted_items 108 | 109 | 110 | # ------------------------------------------------------------------------------ 111 | # Sparse matrix saving/loading helpers. 112 | # ------------------------------------------------------------------------------ 113 | 114 | 115 | def save_sparse_csr(filename, matrix, metadata=None): 116 | data = { 117 | 'data': matrix.data, 118 | 'indices': matrix.indices, 119 | 'indptr': matrix.indptr, 120 | 'shape': matrix.shape, 121 | 'metadata': metadata, 122 | } 123 | np.savez(filename, **data) 124 | 125 | 126 | def load_sparse_csr(filename): 127 | loader = np.load(filename, allow_pickle=True) 128 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 129 | loader['indptr']), shape=loader['shape']) 130 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 131 | 132 | # ------------------------------------------------------------------------------ 133 | # Token hashing. 134 | # ------------------------------------------------------------------------------ 135 | 136 | 137 | def hash(token, num_buckets): 138 | """Unsigned 32 bit murmurhash for feature hashing.""" 139 | return murmurhash3_32(token, positive=True) % num_buckets 140 | 141 | # ------------------------------------------------------------------------------ 142 | # Text cleaning. 
143 | # ------------------------------------------------------------------------------ 144 | 145 | 146 | STOPWORDS = { 147 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 148 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 149 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 150 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 151 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 152 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 153 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 154 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 155 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 156 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 157 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 158 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 159 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 160 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 161 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 162 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 163 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 164 | } 165 | 166 | 167 | def filter_word(text): 168 | """Take out english stopwords, punctuation, and compound endings.""" 169 | text = normalize(text) 170 | if regex.match(r'^\p{P}+$', text): 171 | return True 172 | if text.lower() in STOPWORDS: 173 | return True 174 | return False 175 | 176 | 177 | def filter_ngram(gram, mode='any'): 178 | """Decide whether to keep or discard an n-gram. 179 | Args: 180 | gram: list of tokens (length N) 181 | mode: Option to throw out ngram if 182 | 'any': any single token passes filter_word 183 | 'all': all tokens pass filter_word 184 | 'ends': book-ended by filterable tokens 185 | """ 186 | filtered = [filter_word(w) for w in gram] 187 | if mode == 'any': 188 | return any(filtered) 189 | elif mode == 'all': 190 | return all(filtered) 191 | elif mode == 'ends': 192 | return filtered[0] or filtered[-1] 193 | else: 194 | raise ValueError('Invalid mode: %s' % mode) 195 | 196 | 197 | def get_field(d, field_list): 198 | """get the subfield associated to a list of elastic fields 199 | E.g. 
['file', 'filename'] to d['file']['filename']
200 |     """
201 |     if isinstance(field_list, str):
202 |         return d[field_list]
203 |     else:
204 |         idx = d.copy()
205 |         for field in field_list:
206 |             idx = idx[field]
207 |         return idx
208 | 
209 | 
210 | def load_para_collections_from_tfidf_id_intro_only(tfidf_id, db):
211 |     if "_0" not in tfidf_id:
212 |         tfidf_id = "{0}_0".format(tfidf_id)
213 |     if db.get_doc_text(tfidf_id) is None:
214 |         logger.warning("{0} is missing".format(tfidf_id))
215 |         return []
216 |     return [[tfidf_id, db.get_doc_text(tfidf_id).split("\t")]]
217 | 
218 | def load_linked_titles_from_tfidf_id(tfidf_id, db):
219 |     para_titles = db.get_paras_with_article(tfidf_id)
220 |     linked_titles_all = []
221 |     for para_title in para_titles:
222 |         linked_title_per_para = db.get_hyper_linked(para_title)
223 |         if len(linked_title_per_para) > 0:
224 |             linked_titles_all += linked_title_per_para.split("\t")
225 |     return linked_titles_all
226 | 
227 | def load_para_and_linked_titles_dict_from_tfidf_id(tfidf_id, db):
228 |     """
229 |     Load paragraphs and hyperlinked titles from the DB.
230 |     This method is mainly for the Natural Questions Open benchmark.
231 |     """
232 |     # will be fixed in a later version; the current tfidf weights use indexed titles as keys.
233 |     if "_0" not in tfidf_id:
234 |         tfidf_id = "{0}_0".format(tfidf_id)
235 |     paras, linked_titles = db.get_doc_text_hyper_linked_titles_for_articles(
236 |         tfidf_id)
237 |     if len(paras) == 0:
238 |         logger.warning("{0} is missing".format(tfidf_id))
239 |         return [], []
240 | 
241 |     paras_dict = {}
242 |     linked_titles_dict = {}
243 |     article_name = tfidf_id.split("_0")[0]
244 |     # store the paras_dict and linked_titles_dict; skip the first para (title).
245 |     for para_idx, (para, linked_title_list) in enumerate(zip(paras[1:], linked_titles[1:])):
246 |         paras_dict["{0}_{1}".format(article_name, para_idx)] = para
247 |         linked_titles_dict["{0}_{1}".format(
248 |             article_name, para_idx)] = linked_title_list
249 | 
250 |     return paras_dict, linked_titles_dict
251 | 
252 | def prune_top_k_paragraphs(question_text, paragraphs, tfidf_vectorizer, pruning_l=10):
253 |     para_titles, para_text = list(paragraphs.keys()), list(paragraphs.values())
254 |     # keep only the top l paragraphs, using the question as the query, to reduce the search space.
255 |     top_tfidf_para_indices = tfidf_vectorizer.prune(
256 |         question_text, para_text)[:pruning_l]
257 |     para_title_text_pairs_pruned = {}
258 | 
259 |     # store the selected paras in a dictionary.
260 |     for idx in top_tfidf_para_indices:
261 |         para_title_text_pairs_pruned[para_titles[idx]] = para_text[idx]
262 | 
263 |     return para_title_text_pairs_pruned
264 | 
--------------------------------------------------------------------------------
/sequential_sentence_selector/README.md:
--------------------------------------------------------------------------------
1 | # Sequential Sentence Selector
2 | 
3 | This directory includes the code for our sequential sentence selector model described in Appendix A.4 of our paper.
4 | The model has the same architecture as our proposed graph retriever model for retrieving reasoning paths, adapted here to the supporting fact prediction task in HotpotQA.
5 | 
6 | ## Training
7 | We use `run_sequential_sentence_selector.py` to train the model.
8 | The overall code is based on that of our graph retriever.
9 | Here is an example command used for our paper.
10 | Since we used `pytorch-pretrained-bert` to develop our models, we cannot directly use the BERT-whole-word-masking configurations.
11 | However, you can still use, for example, `bert-large-uncased-whole-word-masking`, as long as `pytorch-transformers` is installed in your environment.
12 | Refer to `./utils.py` for this trick.
13 | 
14 | ```bash
15 | python run_sequential_sentence_selector.py \
16 | --bert_model bert-large-uncased-whole-word-masking \
17 | --train_file_path \
18 | --output_dir \
19 | --do_lower_case \
20 | --train_batch_size 12 \
21 | --gradient_accumulation_steps 1 \
22 | --num_train_epochs 3 \
23 | --learning_rate 3e-5
24 | ```
25 | 
26 | Once you train your model, you can use it in our evaluation pipeline script for HotpotQA.
27 | You can use `hotpot_sf_selector_order_train.json` for `--train_file_path`.
28 | 
--------------------------------------------------------------------------------
/sequential_sentence_selector/modeling_sequential_sentence_selector.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import CrossEntropyLoss
7 | from torch.nn.parameter import Parameter
8 | 
9 | class BertForSequentialSentenceSelector(BertPreTrainedModel):
10 | 
11 |     def __init__(self, config):
12 |         super(BertForSequentialSentenceSelector, self).__init__(config)
13 | 
14 |         self.bert = BertModel(config)
15 | 
16 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
17 | 
18 |         # Initial state
19 |         self.s = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1))
20 | 
21 |         # Scaling factor for weight norm
22 |         self.g = Parameter(torch.FloatTensor(1).fill_(1.0))
23 | 
24 |         # RNN weight
25 |         self.rw = nn.Linear(2*config.hidden_size, config.hidden_size)
26 | 
27 |         # EOE and output bias
28 |         self.eos = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1))
29 |         self.bias = Parameter(torch.FloatTensor(1).zero_())
30 | 
31 |         self.apply(self.init_bert_weights)
32 |         self.cpu = torch.device('cpu')
33 | 
34 |     '''
35 |     state: (B, 1, D)
36 |     '''
37 |     def weight_norm(self, state):
38 |         state = state / state.norm(dim = 2).unsqueeze(2)
39 |         state = self.g * state
40 |         return state
41 | 
42 |     def encode(self, input_ids, token_type_ids, attention_mask, split_chunk = None):
43 |         B = input_ids.size(0)
44 |         N = input_ids.size(1)
45 |         input_ids = input_ids.contiguous().view(input_ids.size(0)*input_ids.size(1), input_ids.size(2))
46 |         token_type_ids = token_type_ids.contiguous().view(token_type_ids.size(0)*token_type_ids.size(1), token_type_ids.size(2))
47 |         attention_mask = attention_mask.contiguous().view(attention_mask.size(0)*attention_mask.size(1), attention_mask.size(2))
48 | 
49 |         # [CLS] vectors for Q-P pairs
50 |         if split_chunk is None:
51 |             encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
52 |             pooled_output = encoded_layers[:, 0]
53 |         else:
54 |             TOTAL = input_ids.size(0)
55 |             start = 0
56 | 
57 |             while start < TOTAL:
58 |                 end = min(start+split_chunk-1, TOTAL-1)
59 |                 chunk_len = end-start+1
60 | 
61 |                 input_ids_ = input_ids[start:start+chunk_len, :]
62 |                 token_type_ids_ = token_type_ids[start:start+chunk_len, :]
63 |                 attention_mask_ = attention_mask[start:start+chunk_len, :]
64 | 
65 |                 encoded_layers, pooled_output_ = self.bert(input_ids_, token_type_ids_, attention_mask_, output_all_encoded_layers=False)
66 |                 encoded_layers = encoded_layers[:, 0]
67 | 
68 |                 if start == 0:
69 |                     pooled_output = encoded_layers
70 |                 else:
71 |                     pooled_output = 
torch.cat((pooled_output, encoded_layers), dim = 0) 72 | 73 | start = end+1 74 | 75 | pooled_output = pooled_output.contiguous() 76 | 77 | paragraphs = pooled_output.view(pooled_output.size(0)//N, N, pooled_output.size(1)) # (B, N, D) 78 | EOE = self.eos.unsqueeze(0).unsqueeze(0) # (1, 1, D) 79 | EOE = EOE.expand(paragraphs.size(0), EOE.size(1), EOE.size(2)) # (B, 1, D) 80 | EOE = self.bert.encoder.layer[-1].output.LayerNorm(EOE) 81 | paragraphs = torch.cat((paragraphs, EOE), dim = 1) # (B, N+1, D) 82 | 83 | # Initial state 84 | state = self.s.expand(paragraphs.size(0), 1, self.s.size(0)) 85 | state = self.weight_norm(state) 86 | 87 | return paragraphs, state 88 | 89 | def forward(self, input_ids, token_type_ids, attention_mask, output_mask, target, target_ids, max_sf_num): 90 | 91 | paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask) 92 | 93 | for i in range(max_sf_num+1): 94 | if i == 0: 95 | h = state 96 | else: 97 | for j in range(target_ids.size(0)): 98 | index = target_ids[j, i-1] 99 | input_ = paragraphs[j:j+1, index:index+1, :] # (B, 1, D) 100 | 101 | if j == 0: 102 | input = input_ 103 | else: 104 | input = torch.cat((input, input_), dim = 0) 105 | 106 | state = torch.cat((state, input), dim = 2) # (B, 1, 2*D) 107 | state = self.rw(state) # (B, 1, D) 108 | state = self.weight_norm(state) 109 | h = torch.cat((h, state), dim = 1) # ...--> (B, max_num_steps, D) 110 | 111 | h = self.dropout(h) 112 | output = torch.bmm(h, paragraphs.transpose(1, 2)) # (B, max_num_steps, N+1) 113 | output = output + self.bias 114 | loss = F.binary_cross_entropy_with_logits(output, target, weight = output_mask, reduction = 'mean') 115 | return loss 116 | 117 | def beam_search(self, input_ids, token_type_ids, attention_mask, output_mask, max_num_steps, examples, beam = 2): 118 | 119 | B = input_ids.size(0) 120 | paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask, split_chunk = 300) 121 | 122 | pred = [] 123 | prob = [] 124 | 125 | topk_pred = [] 126 | topk_prob = [] 127 | 128 | eoe_index = paragraphs.size(1)-1 129 | 130 | output_mask = output_mask.to(self.cpu) 131 | 132 | for i in range(B): 133 | pred_ = [[[], 1.0, 0] for _ in range(beam)] # [hist_1, score_1, len_1], [hist_2, score_2, len_2], ... 
134 | prob_ = [[] for _ in range(beam)] 135 | 136 | state_ = state[i:i+1] # (1, 1, D) 137 | state_ = state_.expand(beam, 1, state_.size(2)) # -> (beam, 1, D) 138 | state_tmp = torch.FloatTensor(state_.size()).zero_().to(state_.device) 139 | ps = paragraphs[i:i+1] # (1, N+1, D) 140 | ps = ps.expand(beam, ps.size(1), ps.size(2)) # -> (beam, N+1, D) 141 | 142 | for j in range(max_num_steps): 143 | if j > 0: 144 | input = [p[0][-1] for p in pred_] 145 | input = torch.LongTensor(input).to(paragraphs.device) 146 | input = ps[0][input].unsqueeze(1) # (beam, 1, D) 147 | state_ = torch.cat((state_, input), dim = 2) # (beam, 1, 2*D) 148 | state_ = self.rw(state_) # (beam, 1, D) 149 | state_ = self.weight_norm(state_) 150 | 151 | output = torch.bmm(state_, ps.transpose(1, 2)) # (beam, 1, N+1) 152 | output = output + self.bias 153 | output = torch.sigmoid(output) 154 | 155 | output = output.to(self.cpu) 156 | 157 | if j == 0: 158 | output = output * output_mask[i:i+1, 0:1, :] 159 | else: 160 | for b in range(beam): 161 | for k in range(len(pred_[b][0])): 162 | output[b:b+1] *= output_mask[i:i+1, pred_[b][0][k]+1:pred_[b][0][k]+2, :] 163 | 164 | e = examples[i] 165 | # Predict at least 1 sentence anyway 166 | if j <= 0: 167 | output[:, :, -1] = 0.0 168 | # No further constraints for single title 169 | elif len(e.titles) == 1: 170 | pass 171 | else: 172 | for b in range(beam): 173 | sfs = set() 174 | for p in pred_[b][0]: 175 | if p == eoe_index: 176 | break 177 | offset = 0 178 | for k in range(len(e.titles)): 179 | if p >= offset and p < offset+len(e.context[e.titles[k]]): 180 | sfs.add(e.titles[k]) 181 | break 182 | offset += len(e.context[e.titles[k]]) 183 | if len(sfs) == 1: 184 | output[b, :, -1] = 0.0 185 | 186 | score = [p[1] for p in pred_] 187 | score = torch.FloatTensor(score) 188 | score = score.unsqueeze(1).unsqueeze(2) # (beam, 1, 1) 189 | score = output * score 190 | 191 | output = output.squeeze(1) # (beam, N+1) 192 | score = score.squeeze(1) # (beam, N+1) 193 | new_pred_ = [] 194 | new_prob_ = [] 195 | 196 | for b in range(beam): 197 | s, p = torch.max(score.view(score.size(0)*score.size(1)), dim = 0) 198 | s = s.item() 199 | p = p.item() 200 | row = p // score.size(1) 201 | col = p % score.size(1) 202 | 203 | p = [[index for index in pred_[row][0]] + [col], 204 | score[row, col].item(), 205 | pred_[row][2] + (1 if col != eoe_index else 0)] 206 | new_pred_.append(p) 207 | 208 | p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()] 209 | new_prob_.append(p) 210 | 211 | state_tmp[b].copy_(state_[row]) 212 | 213 | if j == 0: 214 | score[:, col] = 0.0 215 | else: 216 | score[row, col] = 0.0 217 | 218 | pred_ = new_pred_ 219 | prob_ = new_prob_ 220 | state_ = state_.clone() 221 | state_.copy_(state_tmp) 222 | 223 | if pred_[0][0][-1] == eoe_index: 224 | break 225 | 226 | topk_pred.append([]) 227 | topk_prob.append([]) 228 | for index__ in range(beam): 229 | 230 | pred_tmp = [] 231 | for index in pred_[index__][0]: 232 | if index == eoe_index: 233 | break 234 | pred_tmp.append(index) 235 | 236 | if index__ == 0: 237 | pred.append(pred_tmp) 238 | prob.append(prob_[0]) 239 | 240 | topk_pred[-1].append(pred_tmp) 241 | topk_prob[-1].append(prob_[index__]) 242 | 243 | return pred, prob, topk_pred, topk_prob 244 | -------------------------------------------------------------------------------- /sequential_sentence_selector/run_sequential_sentence_selector.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 
2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import csv 6 | import os 7 | import logging 8 | import argparse 9 | import random 10 | from tqdm import tqdm, trange 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler 15 | from torch.utils.data.distributed import DistributedSampler 16 | 17 | from pytorch_pretrained_bert.tokenization import BertTokenizer 18 | from pytorch_pretrained_bert.optimization import BertAdam 19 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 20 | 21 | try: 22 | from sequential_sentence_selector.modeling_sequential_sentence_selector import BertForSequentialSentenceSelector 23 | except: 24 | from modeling_sequential_sentence_selector import BertForSequentialSentenceSelector 25 | 26 | import json 27 | 28 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 29 | datefmt='%m/%d/%Y %H:%M:%S', 30 | level=logging.INFO) 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class InputExample(object): 35 | """A single training/test example for simple sequence classification.""" 36 | 37 | def __init__(self, guid, q, a, t, c, sf): 38 | 39 | self.guid = guid 40 | self.question = q 41 | self.answer = a 42 | self.titles = t 43 | self.context = c 44 | self.supporting_facts = sf 45 | 46 | 47 | class InputFeatures(object): 48 | """A single set of features of data.""" 49 | 50 | def __init__(self, input_ids, input_masks, segment_ids, target_ids, output_masks, num_sents, num_sfs, ex_index): 51 | self.input_ids = input_ids 52 | self.input_masks = input_masks 53 | self.segment_ids = segment_ids 54 | self.target_ids = target_ids 55 | self.output_masks = output_masks 56 | self.num_sents = num_sents 57 | self.num_sfs = num_sfs 58 | 59 | self.ex_index = ex_index 60 | 61 | 62 | class DataProcessor: 63 | 64 | def get_train_examples(self, file_name): 65 | return self.create_examples(json.load(open(file_name, 'r'))) 66 | 67 | def create_examples(self, jsn): 68 | examples = [] 69 | max_sent_num = 0 70 | for data in jsn: 71 | guid = data['q_id'] 72 | question = data['question'] 73 | titles = data['titles'] 74 | context = data['context'] # {title: [s1, s2, ...]} 75 | # {title: [index1, index2, ...]} 76 | supporting_facts = data['supporting_facts'] 77 | 78 | max_sent_num = max(max_sent_num, sum( 79 | [len(context[title]) for title in context])) 80 | 81 | examples.append(InputExample( 82 | guid, question, data['answer'], titles, context, supporting_facts)) 83 | 84 | return examples 85 | 86 | 87 | def convert_examples_to_features(examples, max_seq_length, max_sent_num, max_sf_num, tokenizer, train=False): 88 | """Loads a data file into a list of `InputBatch`s.""" 89 | 90 | DUMMY = [0] * max_seq_length 91 | DUMMY_ = [0.0] * max_sent_num 92 | features = [] 93 | logger.info('#### Constructing features... 
####') 94 | for (ex_index, example) in enumerate(tqdm(examples, desc='Example')): 95 | 96 | tokens_q = tokenizer.tokenize( 97 | 'Q: {} A: {}'.format(example.question, example.answer)) 98 | tokens_q = ['[CLS]'] + tokens_q + ['[SEP]'] 99 | 100 | input_ids = [] 101 | input_masks = [] 102 | segment_ids = [] 103 | 104 | for title in example.titles: 105 | sents = example.context[title] 106 | for (i, s) in enumerate(sents): 107 | 108 | if len(input_ids) == max_sent_num: 109 | break 110 | 111 | tokens_s = tokenizer.tokenize( 112 | s)[:max_seq_length-len(tokens_q)-1] 113 | tokens_s = tokens_s + ['[SEP]'] 114 | 115 | padding = [0] * (max_seq_length - 116 | len(tokens_s) - len(tokens_q)) 117 | 118 | input_ids_ = tokenizer.convert_tokens_to_ids( 119 | tokens_q + tokens_s) 120 | input_masks_ = [1] * len(input_ids_) 121 | segment_ids_ = [0] * len(tokens_q) + [1] * len(tokens_s) 122 | 123 | input_ids_ += padding 124 | input_ids.append(input_ids_) 125 | 126 | input_masks_ += padding 127 | input_masks.append(input_masks_) 128 | 129 | segment_ids_ += padding 130 | segment_ids.append(segment_ids_) 131 | 132 | assert len(input_ids_) == max_seq_length 133 | assert len(input_masks_) == max_seq_length 134 | assert len(segment_ids_) == max_seq_length 135 | 136 | target_ids = [] 137 | target_offset = 0 138 | 139 | for title in example.titles: 140 | sfs = example.supporting_facts[title] 141 | for i in sfs: 142 | if i < len(example.context[title]) and i+target_offset < len(input_ids): 143 | target_ids.append(i+target_offset) 144 | else: 145 | logger.warning('') 146 | logger.warning('Invalid annotation: {}'.format(sfs)) 147 | logger.warning('Invalid annotation: {}'.format( 148 | example.context[title])) 149 | 150 | target_offset += len(example.context[title]) 151 | 152 | assert len(input_ids) <= max_sent_num 153 | assert len(target_ids) <= max_sf_num 154 | 155 | num_sents = len(input_ids) 156 | num_sfs = len(target_ids) 157 | 158 | output_masks = [([1.0] * len(input_ids) + [0.0] * (max_sent_num - 159 | len(input_ids) + 1)) for _ in range(max_sent_num + 2)] 160 | 161 | if train: 162 | 163 | for i in range(len(target_ids)): 164 | for j in range(len(target_ids)): 165 | if i == j: 166 | continue 167 | 168 | output_masks[i][target_ids[j]] = 0.0 169 | 170 | for i in range(len(output_masks)): 171 | if i >= num_sfs+1: 172 | for j in range(len(output_masks[i])): 173 | output_masks[i][j] = 0.0 174 | 175 | else: 176 | for i in range(len(input_ids)): 177 | output_masks[i+1][i] = 0.0 178 | 179 | target_ids += [0] * (max_sf_num - len(target_ids)) 180 | 181 | padding = [DUMMY] * (max_sent_num - len(input_ids)) 182 | input_ids += padding 183 | input_masks += padding 184 | segment_ids += padding 185 | 186 | features.append( 187 | InputFeatures(input_ids=input_ids, 188 | input_masks=input_masks, 189 | segment_ids=segment_ids, 190 | target_ids=target_ids, 191 | output_masks=output_masks, 192 | num_sents=num_sents, 193 | num_sfs=num_sfs, 194 | ex_index=ex_index)) 195 | 196 | logger.info('Done!') 197 | 198 | return features 199 | 200 | 201 | def warmup_linear(x, warmup=0.002): 202 | if x < warmup: 203 | return x/warmup 204 | return 1.0 - x 205 | 206 | 207 | def main(): 208 | parser = argparse.ArgumentParser() 209 | 210 | ## Required parameters 211 | parser.add_argument("--bert_model", default=None, type=str, required=True, 212 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 213 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 214 | "bert-base-multilingual-cased, 
bert-base-chinese.") 215 | parser.add_argument("--output_dir", 216 | default=None, 217 | type=str, 218 | required=True, 219 | help="The output directory where the model predictions and checkpoints will be written.") 220 | 221 | parser.add_argument("--train_file_path", 222 | type=str, 223 | default=None, 224 | required=True, 225 | help="File path to training data") 226 | 227 | ## Other parameters 228 | parser.add_argument("--max_seq_length", 229 | default=256, 230 | type=int, 231 | help="The maximum total input sequence length after WordPiece tokenization. \n" 232 | "Sequences longer than this will be truncated, and sequences shorter \n" 233 | "than this will be padded.") 234 | parser.add_argument("--max_sent_num", 235 | default=30, 236 | type=int) 237 | parser.add_argument("--max_sf_num", 238 | default=15, 239 | type=int) 240 | parser.add_argument("--do_lower_case", 241 | action='store_true', 242 | help="Set this flag if you are using an uncased model.") 243 | parser.add_argument("--train_batch_size", 244 | default=1, 245 | type=int, 246 | help="Total batch size for training.") 247 | parser.add_argument("--eval_batch_size", 248 | default=5, 249 | type=int, 250 | help="Total batch size for eval.") 251 | parser.add_argument("--learning_rate", 252 | default=5e-5, 253 | type=float, 254 | help="The initial learning rate for Adam. (def: 5e-5)") 255 | parser.add_argument("--num_train_epochs", 256 | default=5.0, 257 | type=float, 258 | help="Total number of training epochs to perform.") 259 | parser.add_argument("--warmup_proportion", 260 | default=0.1, 261 | type=float, 262 | help="Proportion of training to perform linear learning rate warmup for. " 263 | "E.g., 0.1 = 10%% of training.") 264 | parser.add_argument("--no_cuda", 265 | action='store_true', 266 | help="Whether not to use CUDA when available") 267 | parser.add_argument('--seed', 268 | type=int, 269 | default=42, 270 | help="random seed for initialization") 271 | parser.add_argument('--gradient_accumulation_steps', 272 | type=int, 273 | default=1, 274 | help="Number of updates steps to accumulate before performing a backward/update pass.") 275 | 276 | args = parser.parse_args() 277 | 278 | cpu = torch.device('cpu') 279 | 280 | device = torch.device("cuda" if torch.cuda.is_available() 281 | and not args.no_cuda else "cpu") 282 | n_gpu = torch.cuda.device_count() 283 | 284 | args.train_batch_size = int( 285 | args.train_batch_size / args.gradient_accumulation_steps) 286 | 287 | random.seed(args.seed) 288 | np.random.seed(args.seed) 289 | torch.manual_seed(args.seed) 290 | if n_gpu > 0: 291 | torch.cuda.manual_seed_all(args.seed) 292 | 293 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 294 | raise ValueError( 295 | "Output directory ({}) already exists and is not empty.".format(args.output_dir)) 296 | os.makedirs(args.output_dir, exist_ok=True) 297 | 298 | processor = DataProcessor() 299 | 300 | # Prepare model 301 | if args.bert_model != 'bert-large-uncased-whole-word-masking': 302 | tokenizer = BertTokenizer.from_pretrained( 303 | args.bert_model, do_lower_case=args.do_lower_case) 304 | 305 | model = BertForSequentialSentenceSelector.from_pretrained(args.bert_model, 306 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1)) 307 | else: 308 | model = BertForSequentialSentenceSelector.from_pretrained('bert-large-uncased', 309 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1)) 310 | from utils import get_bert_model_from_pytorch_transformers 311 | 312 | state_dict, vocab_file = 
get_bert_model_from_pytorch_transformers( 313 | args.bert_model) 314 | model.bert.load_state_dict(state_dict) 315 | tokenizer = BertTokenizer.from_pretrained( 316 | vocab_file, do_lower_case=args.do_lower_case) 317 | 318 | logger.info( 319 | 'The {} model is successfully loaded!'.format(args.bert_model)) 320 | 321 | model.to(device) 322 | if n_gpu > 1: 323 | model = torch.nn.DataParallel(model) 324 | 325 | global_step = 0 326 | nb_tr_steps = 0 327 | tr_loss = 0 328 | 329 | POSITIVE = 1.0 330 | NEGATIVE = 0.0 331 | 332 | # Load training examples 333 | train_examples = None 334 | num_train_steps = None 335 | train_examples = processor.get_train_examples(args.train_file_path) 336 | train_features = convert_examples_to_features( 337 | train_examples, args.max_seq_length, args.max_sent_num, args.max_sf_num, tokenizer, train=True) 338 | 339 | num_train_steps = int( 340 | len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 341 | 342 | # Prepare optimizer 343 | param_optimizer = list(model.named_parameters()) 344 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 345 | optimizer_grouped_parameters = [ 346 | {'params': [p for n, p in param_optimizer if not any( 347 | nd in n for nd in no_decay)], 'weight_decay': 0.01}, 348 | {'params': [p for n, p in param_optimizer if any( 349 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 350 | ] 351 | t_total = num_train_steps 352 | 353 | optimizer = BertAdam(optimizer_grouped_parameters, 354 | lr=args.learning_rate, 355 | warmup=args.warmup_proportion, 356 | t_total=t_total, 357 | max_grad_norm=1.0) 358 | 359 | logger.info("***** Running training *****") 360 | logger.info(" Num examples = %d", len(train_features)) 361 | logger.info(" Batch size = %d", args.train_batch_size) 362 | logger.info(" Num steps = %d", num_train_steps) 363 | 364 | all_input_ids = torch.tensor( 365 | [f.input_ids for f in train_features], dtype=torch.long) 366 | all_input_masks = torch.tensor( 367 | [f.input_masks for f in train_features], dtype=torch.long) 368 | all_segment_ids = torch.tensor( 369 | [f.segment_ids for f in train_features], dtype=torch.long) 370 | all_target_ids = torch.tensor( 371 | [f.target_ids for f in train_features], dtype=torch.long) 372 | all_output_masks = torch.tensor( 373 | [f.output_masks for f in train_features], dtype=torch.float) 374 | all_num_sents = torch.tensor( 375 | [f.num_sents for f in train_features], dtype=torch.long) 376 | all_num_sfs = torch.tensor( 377 | [f.num_sfs for f in train_features], dtype=torch.long) 378 | train_data = TensorDataset(all_input_ids, 379 | all_input_masks, 380 | all_segment_ids, 381 | all_target_ids, 382 | all_output_masks, 383 | all_num_sents, 384 | all_num_sfs) 385 | train_sampler = RandomSampler(train_data) 386 | train_dataloader = DataLoader( 387 | train_data, sampler=train_sampler, batch_size=args.train_batch_size) 388 | 389 | model.train() 390 | epc = 0 391 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 392 | tr_loss = 0 393 | nb_tr_examples, nb_tr_steps = 0, 0 394 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 395 | input_masks = batch[1] 396 | batch_max_len = input_masks.sum(dim=2).max().item() 397 | 398 | target_ids = batch[3] 399 | 400 | num_sents = batch[5] 401 | batch_max_sent_num = num_sents.max().item() 402 | 403 | num_sfs = batch[6] 404 | batch_max_sf_num = num_sfs.max().item() 405 | 406 | output_masks_cpu = (batch[4])[ 407 | :, :batch_max_sf_num+1, :batch_max_sent_num+1] 408 | 409 | batch = 
tuple(t.to(device) for t in batch) 410 | input_ids, input_masks, segment_ids, _, output_masks, __, ___ = batch 411 | B = input_ids.size(0) 412 | 413 | input_ids = input_ids[:, :batch_max_sent_num, :batch_max_len] 414 | input_masks = input_masks[:, :batch_max_sent_num, :batch_max_len] 415 | segment_ids = segment_ids[:, :batch_max_sent_num, :batch_max_len] 416 | target_ids = target_ids[:, :batch_max_sf_num] 417 | # 1 for EOE 418 | output_masks = output_masks[:, 419 | :batch_max_sf_num+1, :batch_max_sent_num+1] 420 | 421 | target = torch.FloatTensor(output_masks.size()).fill_( 422 | NEGATIVE) # (B, NUM_STEPS, |S|+1) <- 1 for EOE 423 | for i in range(B): 424 | output_masks[i, :num_sfs[i]+1, -1] = 1.0 # for EOE 425 | target[i, num_sfs[i], -1].fill_(POSITIVE) 426 | 427 | for j in range(num_sfs[i].item()): 428 | target[i, j, target_ids[i, j]].fill_(POSITIVE) 429 | target = target.to(device) 430 | 431 | loss = model(input_ids, segment_ids, input_masks, 432 | output_masks, target, target_ids, batch_max_sf_num) 433 | 434 | if n_gpu > 1: 435 | loss = loss.mean() # mean() to average on multi-gpu. 436 | if args.gradient_accumulation_steps > 1: 437 | loss = loss / args.gradient_accumulation_steps 438 | 439 | loss.backward() 440 | 441 | tr_loss += loss.item() 442 | 443 | nb_tr_examples += B 444 | nb_tr_steps += 1 445 | if (step + 1) % args.gradient_accumulation_steps == 0: 446 | # modify learning rate with special warm up BERT uses 447 | lr_this_step = args.learning_rate * \ 448 | warmup_linear(global_step/t_total, args.warmup_proportion) 449 | for param_group in optimizer.param_groups: 450 | param_group['lr'] = lr_this_step 451 | optimizer.step() 452 | optimizer.zero_grad() 453 | global_step += 1 454 | 455 | model_to_save = model.module if hasattr( 456 | model, 'module') else model # Only save the model it-self 457 | output_model_file = os.path.join( 458 | args.output_dir, "pytorch_model_"+str(epc+1)+".bin") 459 | torch.save(model_to_save.state_dict(), output_model_file) 460 | epc += 1 461 | 462 | 463 | if __name__ == "__main__": 464 | main() 465 | -------------------------------------------------------------------------------- /sequential_sentence_selector/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | 5 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 6 | BertModel, BertTokenizer) 7 | 8 | MODEL_CLASSES = { 9 | 'bert': (BertConfig, BertModel, BertTokenizer), 10 | } 11 | 12 | def get_bert_model_from_pytorch_transformers(model_name): 13 | config_class, model_class, tokenizer_class = MODEL_CLASSES['bert'] 14 | config = config_class.from_pretrained(model_name) 15 | model = model_class.from_pretrained(model_name, from_tf=bool('.ckpt' in model_name), config=config) 16 | 17 | tokenizer = tokenizer_class.from_pretrained(model_name) 18 | 19 | vocab_file_name = './vocabulary_'+model_name+'.txt' 20 | 21 | if not os.path.exists(vocab_file_name): 22 | index = 0 23 | with open(vocab_file_name, "w", encoding="utf-8") as writer: 24 | for token, token_index in sorted(tokenizer.vocab.items(), key=lambda kv: kv[1]): 25 | if index != token_index: 26 | assert False 27 | index = token_index 28 | writer.write(token + u'\n') 29 | index += 1 30 | 31 | return model.state_dict(), vocab_file_name 32 | --------------------------------------------------------------------------------