├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── demo.py
├── eval_main.py
├── eval_odqa.py
├── eval_utils.py
├── graph_retriever
│   ├── README.md
│   ├── __init__.py
│   ├── modeling_graph_retriever.py
│   ├── run_graph_retriever.py
│   └── utils.py
├── img
│   └── odqa_overview-1.png
├── pipeline
│   ├── __init__.py
│   ├── graph_retriever.py
│   ├── reader.py
│   ├── sequential_sentence_selector.py
│   └── tfidf_retriever.py
├── quick_start_hotpot.sh
├── reader
│   ├── README.md
│   ├── __init__.py
│   ├── modeling_reader.py
│   ├── modeling_utils.py
│   ├── rc_utils.py
│   └── run_reader_confidence.py
├── requirements.txt
├── retriever
│   ├── README.md
│   ├── __init__.py
│   ├── build_db.py
│   ├── build_tfidf.py
│   ├── doc_db.py
│   ├── interactive.py
│   ├── tfidf_doc_ranker.py
│   ├── tfidf_vectorizer_article.py
│   ├── tokenizers.py
│   └── utils.py
└── sequential_sentence_selector
    ├── README.md
    ├── modeling_sequential_sentence_selector.py
    ├── run_sequential_sentence_selector.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
2 | *.zip
3 | *.pyc
4 | *.json
5 | *.tsv
6 | *.csv
7 | *.db
8 | *.bin
9 | *.npz
10 | __pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2020 Akari Asai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Retrieve Reasoning Paths over Wikipedia Graph for Question Answering
2 |
3 |
4 | This is the official implementation of the following paper:
5 | Akari Asai, Kazuma Hashimoto, Hannaneh Hajishirzi, Richard Socher, Caiming Xiong. [Learning to Retrieve Reasoning Paths over Wikipedia Graph for Question Answering](https://arxiv.org/abs/1911.10470). In: Proceedings of ICLR. 2020
6 |
7 | In the paper, we introduce a graph-based retriever-reader framework that learns to retrieve reasoning paths (a reasoning path = a chain of multiple paragraphs needed to answer a multi-hop question) from English Wikipedia using its graph structure, and then verifies and extracts answers from the selected reasoning paths. Our experiments show state-of-the-art results across three diverse open-domain QA datasets: [HotpotQA (full wiki)](https://hotpotqa.github.io/), [Natural Questions](https://ai.google.com/research/NaturalQuestions/) Open, and [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) Open.
8 |
9 | *Acknowledgements*: To implement our BERT-based modules, we used the [Hugging Face transformers](https://huggingface.co/transformers/) library. The implementation of the TF-IDF-based document ranker and splitter started from the [DrQA](https://github.com/facebookresearch/DrQA) and [document-qa](https://github.com/allenai/document-qa) repositories. Huge thanks to the contributors of those amazing repositories!
10 |
11 |
12 | ## Quick Links
13 | 0. [Quick Run on HotpotQA](#0-quick-run-on-hotpotqa)
14 | 1. [Installation](#1-installation)
15 | 2. [Train](#2-train)
16 | 3. [Evaluation](#3-evaluation)
17 | 4. [Interactive Demo](#4-interactive-demo)
18 | 5. [Others](#5-others)
19 | 6. [Citation and Contact](#citation-and-contact)
20 |
21 | ## 0. Quick Run on HotpotQA
22 | We provide [quick_start_hotpot.sh](quick_start_hotpot.sh), with which you can easily set up and run evaluation on HotpotQA full wiki (on the first 100 questions).
23 |
24 | The script will
25 | 1. download our trained models and evaluation data (See [Installation](#1-installation) for the details),
26 | 2. run the whole pipeline on the evaluation data (See [Evaluation](#3-evaluation)), and
27 | 3. calculate the QA scores and supporting facts scores.
28 |
29 | The evaluation will give us the following results:
30 | ```
31 | {'em': 0.6, 'f1': 0.7468968253968253, 'prec': 0.754030303030303, 'recall': 0.7651666666666667, 'sp_em': 0.49, 'sp_f1': 0.7769365079365077, 'sp_prec': 0.8275, 'sp_recall': 0.7488333333333332, 'joint_em': 0.33, 'joint_f1': 0.6249458756180065, 'joint_prec': 0.6706212121212122, 'joint_recall': 0.6154999999999999}
32 | ```
33 |
34 | Wanna try your own open-domain question? See [Interactive Demo](#4-interactive-demo)! Once you run the [quick_start_hotpot.sh](quick_start_hotpot.sh), you can easily switch to the demo mode by changing some options in the command.
35 |
36 | ## 1. Installation
37 | ### Requirements
38 |
39 | Our framework requires Python 3.5 or higher. We do not support Python 2.X.
40 |
41 | It also requires installing [pytorch-pretrained-bert (version 0.6.2)](https://github.com/huggingface/transformers/tree/v0.6.2) and [PyTorch](http://pytorch.org/) version 1.0 or higher. The other dependencies are listed in [requirements.txt](requirements.txt).
42 | We are planning to [migrate from pytorch-pretrained-bert](https://huggingface.co/transformers/migration.html) to transformers soon.
43 |
44 | ### Set up
45 | Run the following commands to clone the repository and install our framework:
46 |
47 | ```bash
48 | git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git
49 | cd learning_to_retrieve_reasoning_paths
50 | pip install -r requirements.txt
51 | ```
52 |
53 | ### Downloading trained models
54 | All the trained models used in our paper for the three datasets are available in google drive:
55 | - HotpotQA full wiki: [hotpot_models.zip](https://drive.google.com/open?id=1ra37xtEXSROG_f90XxR4kgElGJWUHQyM)
56 | - Natural Questions Open: [nq_models.zip](https://drive.google.com/open?id=120JNI49nK-W014cjneuXJuUC09To3KeQ)
57 | - SQuAD Open: [squad_models.zip](https://drive.google.com/open?id=1_z54KceuYXnA0fJvdASDubti9Zb8aAR4): for SQuAD Open, please also download the `.db` and `.npz` files by following the [DrQA](https://github.com/facebookresearch/DrQA/blob/master/download.sh) download script.
58 |
59 | Alternatively, you can download a zip file containing all models by using [gdown](https://pypi.org/project/gdown/).
60 |
61 | e.g., download HotpotQA models
62 | ```bash
63 | mkdir models
64 | cd models
65 | gdown https://drive.google.com/uc?id=1ra37xtEXSROG_f90XxR4kgElGJWUHQyM
66 | unzip hotpot_models.zip
67 | rm hotpot_models.zip
68 | cd ..
69 | ```
70 | **Note: the size of the zip file is about 4GB for the HotpotQA models, and once it is extracted, the total size of the models is more than 8GB (including the introductory-paragraph-only Wikipedia database). The `nq_models.zip` file includes the full Wikipedia database, which is around 30GB once extracted.**
71 |
72 |
73 | ### Downloading data
74 | #### for training
75 | - You can download all of the training datasets from [here (google drive)](https://drive.google.com/drive/folders/1nYQOtoxJiiL5XK6PHOeluTWgh3PTOCVD?usp=sharing).
76 | - We create (1) data to train the graph-based retriever and (2) data to train the reader by augmenting the publicly available machine reading comprehension datasets (HotpotQA, SQuAD and Natural Questions).
77 | See the details of the process in Section 3.1.2 and Section 3.2 in [our paper](https://arxiv.org/pdf/1911.10470.pdf).
78 |
79 | #### for evaluation
80 | - Following previous work such as [DrQA](https://github.com/facebookresearch/DrQA) or [qa-hard-em](https://github.com/shmsw25/qa-hard-em), we convert the original machine reading comprehension datasets to sets of question and answer pairs. You can download our preprocessed data from [here](https://drive.google.com/open?id=1na7vxYWadK2kS2aqg88RDMFir8PP-2lS).
81 |
82 | - For HotpotQA, we only use question-answer pairs as input, but the supporting fact evaluation additionally requires the original HotpotQA development set (either fullwiki or distractor), available from [HotpotQA's website](https://hotpotqa.github.io/).
83 |
84 | ```bash
85 | mkdir data
86 | cd data
87 | mkdir hotpot
88 | cd hotpot
89 | gdown https://drive.google.com/uc?id=1m_7ZJtWQsZ8qDqtItDTWYlsEHDeVHbPt # download preprocessed full wiki data
90 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json # download the original full wiki data for sp evaluation.
91 | cd ../..
92 | ```
93 |
94 | ## 2. Train
95 | In this work, we use a two-stage training approach, which lets you train the reader and retriever independently and easily switch to new reader models.
96 | The details of the training process can be found in the README files in [graph_retriever](graph_retriever), [reader](reader) and [sequential_sentence_selector](sequential_sentence_selector).
97 |
98 | You can download our pre-trained models from the link mentioned above.
99 |
100 | ## 3. Evaluation
101 | After downloading the TF-IDF retriever and training the graph retriever and reader models, you can test the performance of our entire system.
102 |
103 |
104 | #### HotpotQA
105 | If you set up using `quick_start_hotpot.sh` (which evaluates only the first 100 questions in `data/hotpot/hotpot_fullwiki_first_100.jsonl`), you can run the full evaluation by pointing the `--eval_file_path` option at the full development data, as in the command below.
106 |
107 | ```bash
108 | python eval_main.py \
109 | --eval_file_path data/hotpot/hotpot_fullwiki_data.jsonl \
110 | --eval_file_path_sp data/hotpot/hotpot_dev_distractor_v1.json \
111 | --graph_retriever_path models/hotpot_models/graph_retriever_path/pytorch_model.bin \
112 | --reader_path models/hotpot_models/reader \
113 | --sequential_sentence_selector_path models/hotpot_models/sequential_sentence_selector/pytorch_model.bin \
114 | --tfidf_path models/hotpot_models/tfidf_retriever/wiki_open_full_new_db_intro_only-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \
115 | --db_path models/hotpot_models/wiki_db/wiki_abst_only_hotpotqa_w_original_title.db \
116 | --bert_model_sequential_sentence_selector bert-large-uncased --do_lower_case \
117 | --tfidf_limit 500 --eval_batch_size 4 --pruning_by_links --beam_graph_retriever 8 \
118 | --beam_sequential_sentence_selector 8 --max_para_num 2000 --sp_eval
119 | ```
120 |
121 | The evaluation will give us the following results (equivalent to our reported results):
122 | ```
123 | {'em': 0.6049966239027684, 'f1': 0.7330873757783022, 'prec': 0.7613180885780131, 'recall': 0.7421444532461545, 'sp_em': 0.49169480081026334, 'sp_f1': 0.7605390258327606, 'sp_prec': 0.8103758721584524, 'sp_recall': 0.7325846435805953, 'joint_em': 0.35827143821742063, 'joint_f1': 0.6143774960171196, 'joint_prec': 0.679462464277477, 'joint_recall': 0.5987834193329556}
124 | ```
125 |
126 | #### SQuAD Open
127 |
128 | ```bash
129 | python eval_main.py \
130 | --eval_file_path data/squad/squad_open_domain_data.jsonl \
131 | --graph_retriever_path models/squad_models/selector/pytorch_model.bin \
132 | --reader_path models/squad_models/reader \
133 | --tfidf_path DrQA/data/wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \
134 | --db_path DrQA/data/wikipedia/docs.db \
135 | --bert_model bert-base-uncased --do_lower_case \
136 | --tfidf_limit 50 --eval_batch_size 4 \
137 | --beam_graph_retriever 8 --max_para_num 2000 --use_full_article
138 | ```
139 |
140 | #### Natural Questions
141 |
142 | ```bash
143 | python eval_main.py \
144 | --eval_file_path data/nq_open_domain_data.jsonl \
145 | --graph_retriever_path models/nq/selector/pytorch_model.bin --reader_path models/nq/reader/ \
146 | --tfidf_path models/nq_models/tfidf_retriever/wiki_20181220_nq_hyper_linked-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \
147 | --db_path models/nq_models/wiki_db/wiki_20181220_nq_hyper_linked.db \
148 | --bert_model bert-base-uncased --do_lower_case --tfidf_limit 20 --eval_batch_size 4 --pruning_by_links \
149 | --beam_graph_retriever 8 --max_para_num 2000 --use_full_article
150 | ```
151 |
152 | #### (optional) Using TagMe for initial retrieval
153 | As mentioned in Appendix B.7 in our paper, you can optionally use an entity linking system ([TagMe](https://sobigdata.d4science.org/web/tagme/tagme-help)) for the initial retrieval.
154 |
155 | To use TagMe,
156 | 1. [register](https://services.d4science.org/group/d4science-services-gateway/explore?siteId=22994) to get an API key, and
157 | 2. set the API key via the `--tagme_api_key` option and set the `--tagme` flag.
158 |
159 | ```bash
160 | python eval_main.py \
161 | --eval_file_path data/nq_open_domain_data.jsonl \
162 | --graph_retriever_path models/nq/selector/pytorch_model.bin --reader_path models/nq/reader/ \
163 | --tfidf_path models/nq_models/tfidf_retriever/wiki_20181220_nq_hyper_linked-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \
164 | --db_path models/nq_models/wiki_db/wiki_20181220_nq_hyper_linked.db \
165 |  --bert_model bert-base-uncased --do_lower_case --tfidf_limit 20 --eval_batch_size 4 --pruning_by_links --beam_graph_retriever 8 --max_para_num 2000 --use_full_article --tagme --tagme_api_key YOUR_API_KEY
166 | ```
167 |
168 | *The implementation of the two-step TF-IDF retrieval module (article retrieval --> paragraph-level re-ranking) for Natural Questions is still in progress, so it might give slightly lower scores than the results reported in our paper. We'll fix the issue soon.*
169 |
170 | ## 4. Interactive demo
171 | You can run an interactive demo and ask your own open-domain questions; our model answers them along with supporting facts.
172 |
173 | If you set up using the `quick_start_hotpot.sh` script, you can switch to the demo by changing the script name from `eval_main.py` to `demo.py` and removing the `--eval_file_path` and `--eval_file_path_sp` options.
174 |
175 | e.g.,
176 | ```bash
177 | python demo.py \
178 | --graph_retriever_path models/hotpot_models/graph_retriever_path/pytorch_model.bin \
179 | --reader_path models/hotpot_models/reader \
180 | --sequential_sentence_selector_path models/hotpot_models/sequential_sentence_selector/pytorch_model.bin \
181 | --tfidf_path models/hotpot_models/tfidf_retriever/wiki_open_full_new_db_intro_only-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \
182 | --db_path models/hotpot_models/wiki_db/wiki_abst_only_hotpotqa_w_original_title.db \
183 |     --do_lower_case --beam_graph_retriever 4 --max_para_num 200 \
184 |     --tfidf_limit 20 --pruning_by_links
185 | ```
186 |
187 | An output example is as follows:
188 | ```
189 | #### Reader results ####
190 | [ {
191 | "q_id": "DEMO_0",
192 | "question": "Bordan Tkachuk was the CEO of a company that provides what sort of products?",
193 | "answer": "IT products and services",
194 | "context": [
195 | "Cintas_0",
196 | "Bordan Tkachuk_0",
197 | "Viglen_0"
198 | ]
199 | }
200 | ]
201 |
202 | #### Supporting facts ####
203 | [
204 | {
205 | "q_id": "DEMO_0",
206 | "supporting facts": {
207 | "Viglen_0": [
208 | [0, "Viglen Ltd provides IT products and services, including storage systems, servers, workstations and data/voice communications equipment and services."
209 | ]
210 | ],
211 | "Bordan Tkachuk_0": [
212 | [0, "Bordan Tkachuk ( ) is a British business executive, the former CEO of Viglen, also known from his appearances on the BBC-produced British version of \"The Apprentice,\" interviewing for his boss Lord Sugar."
213 | ]
214 | ]
215 | }
216 | }
217 | ]
218 |
219 | ```
220 |
221 |
222 | ## 5. Others
223 | ### Distant supervision & negative examples data generation
224 | In this work, we augment the original MRC data with negative and distant supervision examples to make our retriever and reader robust to inference-time noise. Our experimental results show that these training strategies give significant performance improvements.
225 |
226 | All of the training data is available [here (google drive)](https://drive.google.com/drive/u/1/folders/1nYQOtoxJiiL5XK6PHOeluTWgh3PTOCVD).
227 |
228 | *We are planning to release our code for augmenting training data with negative examples and distant supervision examples to guide future research in open-domain QA. Please stay tuned!*
229 |
230 | ### Dataset format
231 | For quick experiments and detailed human analysis, we save intermediate results at each step: the original Q-A pairs (format A), the TF-IDF retriever output (format B), and the graph-based retriever output (format C).
232 |
233 | #### Format A (eval data, the input of TF-IDF retriever)
234 | For the evaluation pipeline, our initial input is a simple `jsonlines` format where each line contains one example with `id = [str]`, `question = [str]`, and `answer = List[str]` (or `answers = List[str]` for datasets where multiple answers are annotated for each question).
235 |
236 | For SQuAD Open and HotpotQA fullwiki, you can download the preprocessed format A files from [here](https://drive.google.com/file/d/1f3YtvImDxB9h6GuVGFelzxvgwEcqxuHn/view?usp=sharing).
237 |
238 | e.g., HotpotQA fullwiki dev
239 | ```
240 | {
241 | "id": "5ab3b0bf5542992ade7c6e39",
242 | "question": "What year did Guns N Roses perform a promo for a movie starring Arnold Schwarzenegger
243 | as a former New York Police detective?",
244 | "answer": ["1999"]
245 | }
246 | ```
247 |
248 | e.g., SQuAD Open dev
249 | ```py
250 | {
251 | "id": "56beace93aeaaa14008c91e0",
252 | "question": "What venue did Super Bowl 50 take place in?",
253 | "answers": ["Levi's Stadium", "Levi's Stadium",
254 | "Levi's Stadium in the San Francisco Bay Area at Santa Clara"]
255 | }
256 | ```
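
As a minimal sketch (not part of the pipeline itself), a Format A file can be loaded with the `jsonlines` package, the same package `read_jsonlines` in [eval_utils.py](eval_utils.py) uses; the file path below is just an example.

```py
import jsonlines

# Load Format A examples (one JSON object per line) into a list of dicts.
# The path below is only an example; any Format A .jsonl file works.
examples = []
with jsonlines.open("data/hotpot/hotpot_fullwiki_first_100.jsonl") as reader:
    for obj in reader:
        examples.append(obj)

for ex in examples[:3]:
    # HotpotQA uses "answer"; SQuAD Open and NQ Open use "answers".
    gold = ex.get("answers", ex.get("answer"))
    print(ex["id"], ex["question"], gold)
```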
257 |
258 | #### Format B (TF-IDF retriever output)
259 | For TF-IDF results, we store the data as a list of `JSON` objects, where each data point contains the following fields:
260 |
261 | - `q_id = [str]`
262 | - `question = [str]`
263 | - `answer = List[str]`
264 | - `context = Dict[str, str]`: the top $N$ paragraphs ranked highest by our TF-IDF retriever.
265 | - `all_linked_para_title_dic = Dict[str, List[str]]`: Hyper-linked paragraphs' titles from paragraphs in `context`.
266 | - `all_linked_paras_dic = Dict[str, str]`: the text of the hyper-linked paragraphs.
267 |
268 | For training data, we have additional items that are used as ground-truth reasoning paths.
269 | - `short_gold = List[str]`
270 | - `redundant_gold = List[str]`
271 | - `all_redundant_gold = List[List[str]]`
272 |
273 | e.g., HotpotQA fullwiki dev
274 |
275 | ```py
276 | {
277 |   "question": "Were Scott Derrickson and Ed Wood of the same nationality?",
278 | "q_id": "5ab3b0bf5542992ade7c6e39",
279 | "context":
280 | {"Scott Derrickson_0": "Scott Derrickson (born July 16, 1966) is an American director,....",
281 |    "Ed Wood_0": "...", ...},
282 | 'all_linked_para_title_dic':
283 | {"Scott Derrickson_0": ['Los Angeles_0', 'California_0', 'Horror film_0', ...]},
284 | 'all_linked_paras_dic':
285 | {"Los Angeles_0": "Los Angeles, officially the City of Los Angeles and often known by its initials L.A., is ...", ...},
286 | 'short_gold':[],
287 | 'redundant_gold': [],
288 | 'all_redundant_gold': []
289 | }
290 | ```
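
As a rough sketch of how these fields fit together (assuming a Format B file saved with the `--tfidf_results_save_path` option; the file name below is a placeholder), the snippet enumerates, for each TF-IDF paragraph, the hyperlinked paragraphs that could serve as the next hop:

```py
import json

# Placeholder path: a Format B file, e.g., one written via --tfidf_results_save_path.
with open("tfidf_retrieval_output.json") as f:
    examples = json.load(f)

for ex in examples:
    for title, paragraph in ex["context"].items():
        # Titles hyperlinked from this TF-IDF paragraph (candidate next hops).
        linked_titles = ex["all_linked_para_title_dic"].get(title, [])
        # Text of those hyperlinked paragraphs, when available.
        linked_paras = {t: ex["all_linked_paras_dic"][t]
                        for t in linked_titles if t in ex["all_linked_paras_dic"]}
        print(ex["q_id"], title, "->", len(linked_paras), "linked paragraphs")
```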
291 |
292 | #### Format C (Graph-based retriever output)
293 | The graph-based retriever's output is a list of `JSON` objects as follows:
294 |
295 | - `q_id = [str]`
296 | - `titles = List[str]`: a sequence of titles (the top-1 reasoning path).
297 | - `topk_titles = List[List[str]]`: k sequences of titles (the top k reasoning paths).
298 | - `context = Dict[str, str]`: the paragraphs which are included in top reasoning paths.
299 |
300 | ```py
301 | {
302 | "q_id": "5a713ea95542994082a3e6e4",
303 | "titles": ["Alvaro Mexia_0", "Boruca_0"],
304 | "topk_titles": [
305 | ["Alvaro Mexia_0", "Boruca_0"],
306 | ["Alvaro Mexia_0", "Indigenous peoples of Florida_0"],
307 | ["Alvaro Mexia_0"],
308 | ["List of Ambassadors of Spain to the United States_0", "Boruca_0"],
309 | ["Alvaro Mexia_0", "St. Augustine, Florida_0"],
310 | ["Alvaro Mexia_0", "Cape Canaveral, Florida_0"],
311 | ["Alvaro Mexia_0", "Florida_0"],
312 | ["Parque de la Bombilla (Mexico City)_0", "Alvaro Mexia_0", "Boruca_0"]],
313 | "context": {
314 | "Alvaro Mexia_0": "Alvaro Mexia was a 17th-century Spanish explorer and cartographer of the east coast of Florida....", "Boruca_0": "The Boruca (also known as the Brunca or the Brunka) are an indigenous people living in Costa Rica"}
315 | }
316 | ```
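
A minimal sketch of pulling out the top reasoning path and its paragraphs from a Format C record (assuming a file saved with the `--selector_results_save_path` option; the file name below is a placeholder):

```py
import json

# Placeholder path: a Format C file, e.g., one written via --selector_results_save_path.
with open("selector_output.json") as f:
    predictions = json.load(f)

for pred in predictions:
    # "titles" is the single best reasoning path; "context" holds its paragraph texts.
    print(pred["q_id"])
    for title in pred["titles"]:
        paragraph = pred["context"].get(title, "")
        print("  ", title, "->", paragraph[:60], "...")
```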
317 |
318 |
319 | ## Citation and Contact
320 | If you find this codebase useful or use it in your work, please cite our paper.
321 | ```
322 | @inproceedings{
323 | asai2020learning,
324 | title={Learning to Retrieve Reasoning Paths over Wikipedia Graph for Question Answering},
325 | author={Akari Asai and Kazuma Hashimoto and Hannaneh Hajishirzi and Richard Socher and Caiming Xiong},
326 | booktitle={International Conference on Learning Representations},
327 | year={2020}
328 | }
329 | ```
330 |
331 | Please contact Akari Asai ([@AkariAsai](https://twitter.com/AkariAsai?s=20), akari[at]cs.washington.edu) for questions and suggestions.
332 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/__init__.py
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import torch
5 |
6 | from pipeline.tfidf_retriever import TfidfRetriever
7 | from pipeline.graph_retriever import GraphRetriever
8 | from pipeline.reader import Reader
9 | from pipeline.sequential_sentence_selector import SequentialSentenceSelector
10 |
11 | import logging
12 | class DisableLogger():
13 | def __enter__(self):
14 | logging.disable(logging.CRITICAL)
15 | def __exit__(self, a, b, c):
16 | logging.disable(logging.NOTSET)
17 |
18 | class ODQA:
19 | def __init__(self, args):
20 |
21 | self.args = args
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() and not self.args.no_cuda else "cpu")
24 |
25 | # TF-IDF Retriever
26 | self.tfidf_retriever = TfidfRetriever(self.args.db_path, self.args.tfidf_path)
27 |
28 | # Graph Retriever
29 | self.graph_retriever = GraphRetriever(self.args, device)
30 |
31 | # Reader
32 | self.reader = Reader(self.args, device)
33 |
34 | # Supporting facts selector
35 | self.sequential_sentence_selector = SequentialSentenceSelector(self.args, device)
36 |
37 | def predict(self,
38 | questions: list):
39 |
40 | print('-- Retrieving paragraphs by TF-IDF...', flush=True)
41 | tfidf_retrieval_output = []
42 | for i in range(len(questions)):
43 | question = questions[i]
44 | tfidf_retrieval_output += self.tfidf_retriever.get_abstract_tfidf('DEMO_{}'.format(i), question, self.args)
45 |
46 | print('-- Running the graph-based recurrent retriever model...', flush=True)
47 | graph_retrieval_output = self.graph_retriever.predict(tfidf_retrieval_output, self.tfidf_retriever, self.args)
48 |
49 | print('-- Running the reader model...', flush=True)
50 | answer, title = self.reader.predict(graph_retrieval_output, self.args)
51 |
52 | reader_output = [{'q_id': s['q_id'],
53 | 'question': s['question'],
54 | 'answer': answer[s['q_id']],
55 | 'context': title[s['q_id']]} for s in graph_retrieval_output]
56 |
57 | if self.args.sequential_sentence_selector_path is not None:
58 | print('-- Running the supporting facts retriever...', flush=True)
59 | supporting_facts = self.sequential_sentence_selector.predict(reader_output, self.tfidf_retriever, self.args)
60 | else:
61 | supporting_facts = []
62 |
63 | return tfidf_retrieval_output, graph_retrieval_output, reader_output, supporting_facts
64 |
65 |
66 | def main():
67 |
68 | parser = argparse.ArgumentParser()
69 |
70 | ## Required parameters
71 | parser.add_argument("--graph_retriever_path",
72 | default=None,
73 | type=str,
74 | required=True,
75 | help="Graph retriever model path.")
76 | parser.add_argument("--reader_path",
77 | default=None,
78 | type=str,
79 | required=True,
80 | help="Reader model path.")
81 | parser.add_argument("--tfidf_path",
82 | default=None,
83 | type=str,
84 | required=True,
85 | help="TF-IDF path.")
86 | parser.add_argument("--db_path",
87 | default=None,
88 | type=str,
89 | required=True,
90 | help="DB path.")
91 |
92 | ## Other parameters
93 | parser.add_argument("--sequential_sentence_selector_path",
94 | default=None,
95 | type=str,
96 | help="Supporting facts model path.")
97 | parser.add_argument("--max_sent_num",
98 | default=30,
99 | type=int)
100 | parser.add_argument("--max_sf_num",
101 | default=15,
102 | type=int)
103 |
104 |
105 | parser.add_argument("--bert_model_graph_retriever", default='bert-base-uncased', type=str,
106 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
107 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
108 | "bert-base-multilingual-cased, bert-base-chinese.")
109 |
110 | parser.add_argument("--bert_model_sequential_sentence_selector", default='bert-large-uncased', type=str,
111 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
112 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
113 | "bert-base-multilingual-cased, bert-base-chinese.")
114 |
115 | parser.add_argument("--max_seq_length",
116 | default=378,
117 | type=int,
118 | help="The maximum total input sequence length after WordPiece tokenization. \n"
119 | "Sequences longer than this will be truncated, and sequences shorter \n"
120 | "than this will be padded.")
121 |
122 | parser.add_argument("--max_seq_length_sequential_sentence_selector",
123 | default=256,
124 | type=int,
125 | help="The maximum total input sequence length after WordPiece tokenization. \n"
126 | "Sequences longer than this will be truncated, and sequences shorter \n"
127 | "than this will be padded.")
128 |
129 | parser.add_argument("--do_lower_case",
130 | action='store_true',
131 | help="Set this flag if you are using an uncased model.")
132 | parser.add_argument("--no_cuda",
133 | action='store_true',
134 | help="Whether not to use CUDA when available")
135 |
136 | # RNN graph retriever-specific parameters
137 | parser.add_argument("--max_para_num",
138 | default=10,
139 | type=int)
140 |
141 | parser.add_argument('--eval_batch_size',
142 | type=int,
143 | default=5,
144 | help="Eval batch size")
145 |
146 | parser.add_argument('--beam_graph_retriever',
147 | type=int,
148 | default=1,
149 | help="Beam size for Graph Retriever")
150 | parser.add_argument('--beam_sequential_sentence_selector',
151 | type=int,
152 | default=1,
153 | help="Beam size for Sequential Sentence Selector")
154 |
155 | parser.add_argument('--min_select_num',
156 | type=int,
157 | default=1,
158 | help="Minimum number of selected paragraphs")
159 | parser.add_argument('--max_select_num',
160 | type=int,
161 | default=3,
162 | help="Maximum number of selected paragraphs")
163 | parser.add_argument("--no_links",
164 | action='store_true',
165 | help="Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)")
166 | parser.add_argument("--pruning_by_links",
167 | action='store_true',
168 | help="Whether to do pruning by links (and top 1)")
169 | parser.add_argument("--expand_links",
170 | action='store_true',
171 | help="Whether to expand links with paragraphs in the same article (for NQ)")
172 | parser.add_argument('--tfidf_limit',
173 | type=int,
174 | default=None,
175 | help="Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)")
176 |
177 | parser.add_argument("--split_chunk", default=100, type=int,
178 | help="Chunk size for BERT encoding at inference time")
179 | parser.add_argument("--eval_chunk", default=500, type=int,
180 | help="Chunk size for inference of graph_retriever")
181 |
182 | parser.add_argument("--tagme",
183 | action='store_true',
184 | help="Whether to use tagme at inference")
185 | parser.add_argument('--topk',
186 | type=int,
187 | default=2,
188 |                         help="How many paragraphs to use from the previous steps")
189 |
190 | parser.add_argument("--n_best_size", default=5, type=int,
191 | help="The total number of n-best predictions to generate in the nbest_predictions.json "
192 | "output file.")
193 | parser.add_argument("--max_answer_length", default=30, type=int,
194 | help="The maximum length of an answer that can be generated. This is needed because the start "
195 | "and end predictions are not conditioned on one another.")
196 | parser.add_argument("--max_query_length", default=64, type=int,
197 | help="The maximum number of tokens for the question. Questions longer than this will "
198 | "be truncated to this length.")
199 | parser.add_argument("--doc_stride", default=128, type=int,
200 | help="When splitting up a long document into chunks, how much stride to take between chunks.")
201 |
202 |
203 | odqa = ODQA(parser.parse_args())
204 |
205 | print()
206 | while True:
207 | questions = input('Questions: ')
208 | questions = questions.strip()
209 | if questions == 'q':
210 | break
211 | elif questions == '':
212 | continue
213 |
214 | questions = questions.strip().split('|||')
215 | tfidf_retrieval_output, graph_retriever_output, reader_output, supporting_facts = odqa.predict(questions)
216 |
217 | if graph_retriever_output is None:
218 | print()
219 |             print('Invalid question! "{}"'.format(questions))
220 | print()
221 | continue
222 |
223 | print()
224 | print('#### Retrieval results ####')
225 | print(json.dumps(graph_retriever_output, indent=4))
226 | print()
227 |
228 | print('#### Reader results ####')
229 | print(json.dumps(reader_output, indent=4))
230 | print()
231 |
232 | if len(supporting_facts) > 0:
233 | print('#### Supporting facts ####')
234 | print(json.dumps(supporting_facts, indent=4))
235 | print()
236 |
237 |
238 | if __name__ == "__main__":
239 | with DisableLogger():
240 | main()
241 |
--------------------------------------------------------------------------------
/eval_main.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from eval_odqa import ODQAEval
4 | from eval_utils import evaluate, evaluate_w_sp_facts, convert_qa_sp_results_into_hp_eval_format
5 |
6 |
7 | def main():
8 |
9 | odqa = ODQAEval()
10 |
11 | if odqa.args.sequential_sentence_selector_path is not None:
12 | tfidf_retrieval_output, selector_output, reader_output, sp_selector_output = odqa.eval()
13 | if odqa.args.sp_eval is True:
14 | # eval the performance; based on F1 & EM.
15 | predictions = convert_qa_sp_results_into_hp_eval_format(
16 | reader_output, sp_selector_output, odqa.args.db_path)
17 | results = evaluate_w_sp_facts(
18 | odqa.args.eval_file_path_sp, predictions, odqa.args.sampled)
19 | else:
20 | results = evaluate(odqa.args.eval_file_path, reader_output)
21 | print(results)
22 |
23 | else:
24 | tfidf_retrieval_output, selector_output, reader_output = odqa.eval()
25 | # eval the performance; based on F1 & EM.
26 | results = evaluate(odqa.args.eval_file_path, reader_output)
27 |
28 | print("EM :{0}, F1: {1}".format(results['exact_match'], results['f1']))
29 |
30 | # Save the intermediate results.
31 | if odqa.args.tfidf_results_save_path is not None:
32 | print('#### save TFIDF Retrieval results to {}####'.format(
33 | odqa.args.tfidf_results_save_path))
34 | with open(odqa.args.tfidf_results_save_path, "w") as writer:
35 | writer.write(json.dumps(tfidf_retrieval_output, indent=4) + "\n")
36 |
37 | if odqa.args.selector_results_save_path is not None:
38 | print('#### save graph-based Retrieval results to {} ####'.format(
39 | odqa.args.selector_results_save_path))
40 | with open(odqa.args.selector_results_save_path, "w") as writer:
41 | writer.write(json.dumps(selector_output, indent=4) + "\n")
42 |
43 | if odqa.args.reader_results_save_path is not None:
44 | print('#### save reader results to {} ####'.format(
45 | odqa.args.reader_results_save_path))
46 | with open(odqa.args.reader_results_save_path, "w") as writer:
47 | writer.write(json.dumps(reader_output, indent=4) + "\n")
48 |
49 | if odqa.args.sequence_sentence_selector_save_path is not None:
50 | print("#### save sentence selector results to {} ####".format(
51 | odqa.args.sequence_sentence_selector_save_path))
52 | with open(odqa.args.sequence_sentence_selector_save_path, "w") as writer:
53 | writer.write(json.dumps(sp_selector_output, indent=4) + "\n")
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/eval_odqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json
4 | from tqdm import tqdm
5 |
6 | from pipeline.tfidf_retriever import TfidfRetriever
7 | from pipeline.graph_retriever import GraphRetriever
8 | from pipeline.reader import Reader
9 | from pipeline.sequential_sentence_selector import SequentialSentenceSelector
10 |
11 | from eval_utils import read_jsonlines
12 |
13 | # ODQA components for evaluation
14 | class ODQAEval:
15 | def __init__(self):
16 |
17 | parser = argparse.ArgumentParser()
18 |
19 | ## Required parameters
20 | parser.add_argument("--eval_file_path",
21 | default=None,
22 | type=str,
23 | required=True,
24 | help="Eval data file path")
25 | parser.add_argument("--eval_file_path_sp",
26 | default=None,
27 | type=str,
28 | required=False,
29 | help="Eval data file path for supporting fact evaluation (only for HotpotQA)")
30 | parser.add_argument("--graph_retriever_path",
31 | default=None,
32 | type=str,
33 | required=True,
34 | help="Selector model path.")
35 | parser.add_argument("--reader_path",
36 | default=None,
37 | type=str,
38 | required=True,
39 | help="Reader model path.")
40 | parser.add_argument("--sequential_sentence_selector_path",
41 | default=None,
42 | type=str,
43 | required=False,
44 | help="supporting fact selector model path.")
45 | parser.add_argument("--tfidf_path",
46 | default=None,
47 | type=str,
48 | required=True,
49 | help="TF-IDF path.")
50 | parser.add_argument("--db_path",
51 | default=None,
52 | type=str,
53 | required=True,
54 | help="DB path.")
55 |
56 | parser.add_argument("--bert_model_graph_retriever", default='bert-base-uncased', type=str,
57 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
58 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
59 | "bert-base-multilingual-cased, bert-base-chinese.")
60 |
61 | parser.add_argument("--bert_model_sequential_sentence_selector", default='bert-base-uncased', type=str,
62 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
63 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
64 | "bert-base-multilingual-cased, bert-base-chinese.")
65 |
66 | ## Other parameters
67 | parser.add_argument('--eval_batch_size',
68 | type=int,
69 | default=5,
70 | help="Eval batch size")
71 |
72 | parser.add_argument("--max_seq_length",
73 | default=378,
74 | type=int,
75 | help="The maximum total input sequence length after WordPiece tokenization. \n"
76 | "Sequences longer than this will be truncated, and sequences shorter \n"
77 | "than this will be padded.")
78 |
79 | parser.add_argument("--max_seq_length_sequential_sentence_selector",
80 | default=256,
81 | type=int,
82 | help="The maximum total input sequence length after WordPiece tokenization. \n"
83 | "Sequences longer than this will be truncated, and sequences shorter \n"
84 | "than this will be padded.")
85 |
86 | parser.add_argument("--do_lower_case",
87 | action='store_true',
88 | help="Set this flag if you are using an uncased model.")
89 | parser.add_argument("--no_cuda",
90 | action='store_true',
91 | help="Whether not to use CUDA when available")
92 |
93 | # RNN selector-specific parameters
94 | parser.add_argument("--max_para_num",
95 | default=10,
96 | type=int)
97 |
98 | parser.add_argument('--beam_graph_retriever',
99 | type=int,
100 | default=1,
101 | help="Beam size for Graph Retriever")
102 | parser.add_argument('--beam_sequential_sentence_selector',
103 | type=int,
104 | default=1,
105 | help="Beam size for Sequential Sentence Selector")
106 |
107 | parser.add_argument('--min_select_num',
108 | type=int,
109 | default=1,
110 | help="Minimum number of selected paragraphs")
111 | parser.add_argument('--max_select_num',
112 | type=int,
113 | default=3,
114 | help="Maximum number of selected paragraphs")
115 | parser.add_argument("--no_links",
116 | action='store_true',
117 | help="Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)")
118 | parser.add_argument("--pruning_by_links",
119 | action='store_true',
120 | help="Whether to do pruning by links (and top 1)")
121 | parser.add_argument("--expand_links",
122 | action='store_true',
123 | help="Whether to expand links with paragraphs in the same article (for NQ)")
124 | parser.add_argument('--tfidf_limit',
125 | type=int,
126 | default=100,
127 | help="Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)")
128 | parser.add_argument('--pruning_l',
129 | type=int,
130 | default=10,
131 | help="Set the maximum number of paragraphs retrieved from the same article.")
132 |
133 | parser.add_argument("--split_chunk",
134 | default=100,
135 | type=int,
136 | help="Chunk size for BERT encoding at inference time")
137 |
138 | parser.add_argument("--eval_chunk", default=500, type=int,
139 | help="Chunk size for inference of graph_retriever")
140 |
141 | # To use TagMe, you need to register first here https://sobigdata.d4science.org/web/tagme/tagme-help .
142 | parser.add_argument("--tagme",
143 | action='store_true',
144 | help="Whether to use tagme at inference")
145 | parser.add_argument("--tagme_api_key",
146 | type=str,
147 | default=None,
148 | help="Set the TagMe private API key if you use TagMe.")
149 |
150 | parser.add_argument('--topk',
151 | type=int,
152 | default=2,
153 |                         help="How many paragraphs to use from the previous steps")
154 |
155 | parser.add_argument("--n_best_size", default=5, type=int,
156 | help="The total number of n-best predictions to generate in the nbest_predictions.json "
157 | "output file.")
158 | parser.add_argument("--max_answer_length", default=30, type=int,
159 | help="The maximum length of an answer that can be generated. This is needed because the start "
160 | "and end predictions are not conditioned on one another.")
161 | parser.add_argument("--max_query_length", default=64, type=int,
162 | help="The maximum number of tokens for the question. Questions longer than this will "
163 | "be truncated to this length.")
164 | parser.add_argument("--doc_stride", default=128, type=int,
165 | help="When splitting up a long document into chunks, how much stride to take between chunks.")
166 | parser.add_argument("--max_sent_num",
167 | default=30,
168 | type=int)
169 | parser.add_argument("--max_sf_num",
170 | default=15,
171 | type=int)
172 | # save intermediate results
173 | parser.add_argument('--tfidf_results_save_path',
174 | type=str,
175 | default=None,
176 | help="If specified, the TF-IDF results will be saved in the file path")
177 |
178 | parser.add_argument('--selector_results_save_path',
179 | type=str,
180 | default=None,
181 | help="If specified, the selector results will be saved in the file path")
182 |
183 | parser.add_argument('--reader_results_save_path',
184 | type=str,
185 | default=None,
186 | help="If specified, the reader results will be saved in the file path")
187 |
188 | parser.add_argument('--sequence_sentence_selector_save_path',
189 | type=str,
190 | default=None,
191 |                             help="If specified, the sentence selector results will be saved in the file path")
192 |
193 | parser.add_argument('--saved_tfidf_retrieval_outputs_path',
194 | type=str,
195 | default=None,
196 | help="If specified, load the saved TF-IDF retrieval results from the path.")
197 |
198 | parser.add_argument('--saved_selector_outputs_path',
199 | type=str,
200 | default=None,
201 | help="If specified, load the saved reasoning path retrieval results from the path.")
202 |
203 | parser.add_argument("--sp_eval",
204 | action='store_true',
205 | help="set true if you evaluate supporting fact evaluations while running QA evaluation (HotpotQA only).")
206 |
207 | parser.add_argument("--sampled",
208 | action='store_true',
209 | help="evaluate on sampled examples; only for debugging and quick demo.")
210 |
211 | parser.add_argument("--use_full_article",
212 | action='store_true',
213 | help="Set true if you use all of the wikipedia paragraphs, not limiting to intro paragraphs.")
214 |
215 | parser.add_argument("--prune_after_agg",
216 | action='store_true',
217 | help="Pruning after aggregating all paragraphs from top k TFIDF paragraphs.")
218 |
219 |
220 | self.args = parser.parse_args()
221 |
222 | self.device = torch.device("cuda" if torch.cuda.is_available()
223 | and not self.args.no_cuda else "cpu")
224 |
225 | # Retriever
226 | self.retriever = TfidfRetriever(
227 | self.args.db_path, self.args.tfidf_path, self.args.use_full_article, self.args.pruning_l)
228 |
229 | def retrieve(self, eval_questions):
230 | tfidf_retrieval_output = []
231 | for _, eval_q in enumerate(tqdm(eval_questions, desc="Question")):
232 | if self.args.use_full_article is True:
233 | tfidf_retrieval_output += self.retriever.get_article_tfidf_with_hyperlinked_titles(
234 | eval_q["id"], eval_q["question"], self.args)
235 | else:
236 | tfidf_retrieval_output += self.retriever.get_abstract_tfidf(
237 | eval_q["id"], eval_q["question"], self.args)
238 | # create examples with retrieval results.
239 | print("retriever")
240 | print(len(tfidf_retrieval_output))
241 |
242 |         # With the `use_full_article` setting, we build the title-to-hyperlinked-titles map
243 |         # and store it as the retriever's property.
244 | if self.args.use_full_article is True:
245 | title2hyperlink_dic = {}
246 | for example in tfidf_retrieval_output:
247 | title2hyperlink_dic.update(
248 | example["all_linked_para_title_dic"])
249 | self.retriever.store_title2hyperlink_dic(title2hyperlink_dic)
250 |
251 | return tfidf_retrieval_output
252 |
253 | def select(self, tfidf_retrieval_output):
254 | # Selector
255 | selector = GraphRetriever(self.args, self.device)
256 |
257 | selector_output = selector.predict(
258 | tfidf_retrieval_output, self.retriever, self.args)
259 | print("selector")
260 | print(len(selector_output))
261 |
262 | return selector_output
263 |
264 | def read(self, selector_output):
265 | # Reader
266 | reader = Reader(self.args, self.device)
267 |
268 | answers, titles = reader.predict(selector_output, self.args)
269 | reader_output = {}
270 | print("reader")
271 | print(len(answers))
272 | print(answers)
273 | for s in selector_output:
274 | reader_output[s["q_id"]] = answers[s["q_id"]]
275 |
276 | return reader_output, titles
277 |
278 | def explain(self, reader_output_sp):
279 | sequential_sentence_selector = SequentialSentenceSelector(self.args, self.device)
280 | supporting_facts = sequential_sentence_selector.predict(
281 | reader_output_sp, self.retriever, self.args)
282 |
283 | return supporting_facts
284 |
285 |
286 | def eval(self):
287 | # load eval data
288 |         # The eval data is in '{"id": "q_id", "question": "q1", "answer": ["a11", ..., "a1i"]}' format (jsonlines), as in the DrQA repository.
289 | # TODO: create eval data for HotpotQA, SQuAD and Natural Questions Open.
290 | eval_questions = read_jsonlines(self.args.eval_file_path)
291 |
292 | # Run (or load) graph retriever
293 | # FIXME: do not override saved results.
294 | tfidf_retrieval_output = None
295 | if self.args.saved_selector_outputs_path:
296 | selector_output = json.load(
297 | open(self.args.saved_selector_outputs_path))
298 | else:
299 | if self.args.saved_tfidf_retrieval_outputs_path:
300 | tfidf_retrieval_output = json.load(
301 | open(self.args.saved_tfidf_retrieval_outputs_path))
302 | else:
303 | tfidf_retrieval_output = self.retrieve(eval_questions)
304 |
305 | selector_output = self.select(tfidf_retrieval_output)
306 |
307 | # read and extract answers from reasoning paths
308 | reader_output, titles = self.read(selector_output)
309 |
310 | if self.args.sequential_sentence_selector_path is None:
311 | return tfidf_retrieval_output, selector_output, reader_output
312 | else:
313 | reader_output_sp = [{'q_id': s['q_id'],
314 | 'question': s['question'],
315 | 'answer': reader_output[s['q_id']],
316 | 'context': titles[s['q_id']]} for s in selector_output]
317 | sp_selector_output = self.explain(reader_output_sp)
318 | print(sp_selector_output)
319 | return tfidf_retrieval_output, selector_output, reader_output, sp_selector_output
320 |
321 |
322 |
323 |
--------------------------------------------------------------------------------
/eval_utils.py:
--------------------------------------------------------------------------------
1 | # FIXME: temporarily using SQuAD's eval script; HotpotQA uses a different official script.
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import random
9 | import jsonlines
10 | from retriever.doc_db import DocDB
11 |
12 |
13 | def normalize_answer(s):
14 | """Lower text and remove punctuation, articles and extra whitespace."""
15 | def remove_articles(text):
16 | return re.sub(r'\b(a|an|the)\b', ' ', text)
17 |
18 | def white_space_fix(text):
19 | return ' '.join(text.split())
20 |
21 | def remove_punc(text):
22 | exclude = set(string.punctuation)
23 | return ''.join(ch for ch in text if ch not in exclude)
24 |
25 | def lower(text):
26 | return text.lower()
27 |
28 | return white_space_fix(remove_articles(remove_punc(lower(s))))
29 |
30 |
31 | def f1_score(prediction, ground_truth):
32 | prediction_tokens = normalize_answer(prediction).split()
33 | ground_truth_tokens = normalize_answer(ground_truth).split()
34 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
35 | num_same = sum(common.values())
36 | if num_same == 0:
37 | return 0
38 | precision = 1.0 * num_same / len(prediction_tokens)
39 | recall = 1.0 * num_same / len(ground_truth_tokens)
40 | f1 = (2 * precision * recall) / (precision + recall)
41 | return f1
42 |
43 |
44 | def exact_match_score(prediction, ground_truth):
45 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
46 |
47 |
48 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
49 | scores_for_ground_truths = []
50 | for ground_truth in ground_truths:
51 | score = metric_fn(prediction, ground_truth)
52 | scores_for_ground_truths.append(score)
53 | return max(scores_for_ground_truths)
54 |
55 | def evaluate(eval_file_path, predictions, quiet=False, multiple_gts=True):
56 | eval_data = read_jsonlines(eval_file_path)
57 | f1 = exact_match = total = 0
58 |
59 | for qa in eval_data:
60 | q_id = qa['id']
61 | if str(q_id) not in predictions:
62 | print("q_id: {0} is missing.".format(q_id))
63 | continue
64 | if multiple_gts is True:
65 | ground_truths = qa['answers']
66 | else:
67 | ground_truths = qa['answer']
68 | prediction = predictions[q_id]
69 | exact_match += metric_max_over_ground_truths(
70 | exact_match_score, prediction, ground_truths)
71 | f1 += metric_max_over_ground_truths(
72 | f1_score, prediction, ground_truths)
73 | total += 1
74 |
75 | exact_match = 100.0 * exact_match / total
76 | f1 = 100.0 * f1 / total
77 |
78 | return {'exact_match': exact_match, 'f1': f1}
79 |
80 |
81 | def f1_score_normalized(prediction, ground_truth):
82 | normalized_prediction = normalize_answer(prediction)
83 | normalized_ground_truth = normalize_answer(ground_truth)
84 |
85 | ZERO_METRIC = (0, 0, 0)
86 |
87 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
88 | return ZERO_METRIC
89 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
90 | return ZERO_METRIC
91 |
92 | prediction_tokens = normalized_prediction.split()
93 | ground_truth_tokens = normalized_ground_truth.split()
94 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
95 | num_same = sum(common.values())
96 | if num_same == 0:
97 | return ZERO_METRIC
98 | precision = 1.0 * num_same / len(prediction_tokens)
99 | recall = 1.0 * num_same / len(ground_truth_tokens)
100 | f1 = (2 * precision * recall) / (precision + recall)
101 | return f1, precision, recall
102 |
103 |
104 | def update_answer(metrics, prediction, gold):
105 | em = exact_match_score(prediction, gold)
106 | f1, prec, recall = f1_score_normalized(prediction, gold)
107 | metrics['em'] += float(em)
108 | metrics['f1'] += f1
109 | metrics['prec'] += prec
110 | metrics['recall'] += recall
111 | return em, prec, recall
112 |
113 | def update_sp(metrics, prediction, gold):
114 | print(prediction)
115 | cur_sp_pred = set(map(tuple, prediction))
116 | gold_sp_pred = set(map(tuple, gold))
117 | tp, fp, fn = 0, 0, 0
118 | for e in cur_sp_pred:
119 | if e in gold_sp_pred:
120 | tp += 1
121 | else:
122 | fp += 1
123 | for e in gold_sp_pred:
124 | if e not in cur_sp_pred:
125 | fn += 1
126 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
127 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
128 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
129 | em = 1.0 if fp + fn == 0 else 0.0
130 | metrics['sp_em'] += em
131 | metrics['sp_f1'] += f1
132 | metrics['sp_prec'] += prec
133 | metrics['sp_recall'] += recall
134 | return em, prec, recall
135 |
136 |
137 | def convert_qa_sp_results_into_hp_eval_format(reader_output, sp_selector_output, db_path):
138 | db = DocDB(db_path)
139 | sp_dict = {}
140 |
141 | for sp_pred in sp_selector_output:
142 | q_id = sp_pred["q_id"]
143 | sp_dict[q_id] = []
144 | sp_fact_pred = sp_pred["supporting facts"]
145 |
146 | for title in sp_fact_pred:
147 | orig_title = db.get_original_title(title)
148 | for sent_pred in sp_fact_pred[title]:
149 | sp_dict[q_id].append([orig_title, sent_pred[0]])
150 |
151 | return {"answer": reader_output, "sp": sp_dict}
152 |
153 | def evaluate_w_sp_facts(eval_file_path, prediction, sampled=False):
154 | with open(eval_file_path) as f:
155 | gold = json.load(f)
156 |
157 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
158 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
159 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
160 | for dp in gold:
161 | cur_id = dp['_id']
162 | can_eval_joint = True
163 | if cur_id not in prediction['answer']:
164 | can_eval_joint = False
165 | if sampled is False:
166 | print('missing answer {}'.format(cur_id))
167 | else:
168 | em, prec, recall = update_answer(
169 | metrics, prediction['answer'][cur_id], dp['answer'])
170 | if cur_id not in prediction['sp']:
171 | can_eval_joint = False
172 | if sampled is False:
173 |                 print('missing sp {}'.format(cur_id))
174 | else:
175 | sp_em, sp_prec, sp_recall = update_sp(
176 | metrics, prediction['sp'][cur_id], dp['supporting_facts'])
177 |
178 | if can_eval_joint:
179 | joint_prec = prec * sp_prec
180 | joint_recall = recall * sp_recall
181 | if joint_prec + joint_recall > 0:
182 | joint_f1 = 2 * joint_prec * joint_recall / \
183 | (joint_prec + joint_recall)
184 | else:
185 | joint_f1 = 0.
186 | joint_em = em * sp_em
187 |
188 | metrics['joint_em'] += joint_em
189 | metrics['joint_f1'] += joint_f1
190 | metrics['joint_prec'] += joint_prec
191 | metrics['joint_recall'] += joint_recall
192 |
193 | if sampled is True:
194 | N = len(prediction["answer"])
195 | else:
196 | N = len(gold)
197 | for k in metrics.keys():
198 | metrics[k] /= N
199 |
200 | return metrics
201 |
202 | def read_jsonlines(eval_file_name):
203 | lines = []
204 | print("loading examples from {0}".format(eval_file_name))
205 | with jsonlines.open(eval_file_name) as reader:
206 | for obj in reader:
207 | lines.append(obj)
208 | return lines
209 |
210 |
211 |
--------------------------------------------------------------------------------
/graph_retriever/README.md:
--------------------------------------------------------------------------------
1 | # Graph-based Recurrent Retriever
2 |
3 | This directory includes codes for our graph-based recurrent retriever model described in Section 3.1.1 of our paper.
4 |
5 | Table of contents:
6 | - 1. Training
7 | - 2. Inference (optional)
8 |
9 | ## 1. Training
10 | We use `run_graph_retriever.py` to train the model.
11 | We first need to prepare training data for each task; in our paper, we used HotpotQA, HotpotQA distractor, Natural Questions Open, and SQuAD Open.
12 | To train the model with the same settings used in our paper, we explain some of the most important arguments below.
13 | See each example for more details about the values of the arguments.
14 |
15 | Note: it is not possible to perfectly reproduce the experimental results even with exactly the same settings, due to device or environmental differences.
16 | When we trained the model for the SQuAD Open dataset five times with five different random seeds, the standard deviation of the final QA EM score was around `0.5`.
17 | Therefore, you may occasionally see a score difference of around 1% when changing random seeds or other environmental factors.
18 | In our paper, we reported all the results based on the default seed (42) in the BERT code base.
19 |
20 | - `--task`
21 | This is a required argument to specify which dataset we use.
22 | Currently, the valid values are `hotpot_open`, `hotpot_distractor`, `squad`, and `nq`.
23 |
24 | - `--train_file_path`
25 | This has to be specified for training.
26 | This can be either a single file or a set of split files; in the latter case, we simply use the `glob` package to read all the files matching a partial file path.
27 | For example, if we have three files named `./data_1.json`, `./data_2.json`, and `./data_3.json`, then `--train_file_path ./data_` lets you load all three files via `glob.glob('./data_*')` (see the short sketch after this list).
28 | This is useful when the single data file is too big and we want to split it into smaller files.
29 |
30 | - `--output_dir`
31 | This is a directory path to save model checkpoints; a checkpoint is saved every half epoch during training.
32 | The checkpoint files are `pytorch_model_0.5.bin`, `pytorch_model_1.bin`, `pytorch_model_1.5.bin`, etc.
33 |
34 | - `--max_para_num`
35 | This is the number of paragraphs associated with a question.
36 | If `--max_para_num` is `N` and the number of the ground-truth paragraphs is `2` for the question, then there are `N-2` paragraphs as negative examples for training.
37 | We expect higher accuracy with larger values of `N`, but there is a trade-off with the training time.
38 |
39 | - `--tfidf_limit`
40 | This is specifically used for HotpotQA, where we use negative examples from both TF-IDF-based and hyperlink-based paragraphs.
41 | If `--max_para_num` is `N` and `--tfidf_limit` is `M` (`N` >= `M`), then there are `M` TF-IDF-based negative examples and `N-M` hyperlink-based negative examples.
42 |
43 | - `--neg_chunk`
44 | This is used to control GPU memory consumption.
45 | Our model training needs to handle many paragraphs for a question with BERT, so it is not feasible to run forward/backward functions all together.
46 | To resolve this issue, we have this argument to split the negative examples into small chunks, where the chunk size can be specified with this argument.
47 | We used NVIDIA V100 GPUs (with 16GB memory) for our experiments, and `--neg_chunk 8` works.
48 | For other GPUs with less memory, please consider using smaller number for this argument.
49 | It should be noted that, changing this value does not affect the results; our model does not use the softmax normalization, and thus we can run the forward/backward functions separately for each chunk.
50 |
51 | - `--train_batch_size` & `--gradient_accumulation_steps`
52 | These control the mini-batch size and how often we update the model parameters with the optimizer.
53 | More importantly, these depend on how many GPUs we can use.
54 | Due to the model size of BERT and the number of paragraphs to be processed, one GPU can handle only one example at a time.
55 | That means `--train_batch_size 4` with `--gradient_accumulation_steps 4` works on a single GPU, but `--train_batch_size 4` with `--gradient_accumulation_steps 1` does not, due to OOM.
56 | However, if we have four GPUs, for example, the latter setting works because the four examples can be handled by the four GPUs.
57 |
58 | - `--use_redundant` and `--use_multiple_redundant`
59 | These enable a data augmentation technique for the sake of robustness.
60 | `--use_redundant` allows you to use one additional training example per question, and `--use_multiple_redundant` allows you to use multiple additional examples.
61 | To further specify how many examples can be used for the training, you can specify `--max_redundant_num`.
62 |
63 | - `--max_select_num`
64 | This is set to specify the maximum number of reasoning steps in our model.
65 | This value should be `K+1`, where `K` is the number of ground-truth paragraphs, and `1` is for the EOE symbol.
66 | You further need to add `1` when using the `--use_redundant` option.
67 | For example, for HotpotQA, `K` is 2 and we used the `--use_redundant` option, so the total value is `4`.
68 |
69 | - `--example_limit`
70 | This allows you to sanity-check running the code by limiting the number of examples (i.e., questions) loaded from each file.
71 | This is useful for checking that everything works in your environment.
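
As a small illustration of the split-file loading described under `--train_file_path`, here is a minimal sketch (not the repository's actual loading code; the prefix `./data_` and the assumption that each file holds a JSON list of examples are only for illustration):

```python
import glob
import json

# Hypothetical prefix passed via --train_file_path; matches ./data_1.json, ./data_2.json, ...
train_file_path = "./data_"

train_data = []
for file_name in sorted(glob.glob(train_file_path + "*")):
    with open(file_name) as f:
        train_data.extend(json.load(f))  # assumes each split file stores a list of examples

print("{} training examples loaded".format(len(train_data)))
```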
72 |
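To make the `--neg_chunk` remark above concrete, here is a minimal, self-contained sketch (placeholder model and shapes, not the repository's training loop): with an element-wise loss such as binary cross-entropy and no softmax across paragraphs, running the forward/backward pass chunk by chunk accumulates the same gradients as one big pass. The actual code uses a masked mean over the outputs, so the bookkeeping differs slightly, but the idea is the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder scorer standing in for BERT + the retriever head.
scorer = nn.Linear(16, 1)
paragraphs = torch.randn(48, 16)   # 48 candidate paragraphs for one question
labels = torch.zeros(48)
labels[:2] = 1.0                   # two ground-truth paragraphs
loss_fn = nn.BCEWithLogitsLoss(reduction="sum")

# (a) one big forward/backward pass
scorer.zero_grad()
loss_fn(scorer(paragraphs).squeeze(-1), labels).backward()
grad_full = scorer.weight.grad.clone()

# (b) chunked forward/backward passes (cf. --neg_chunk); gradients accumulate chunk by chunk
scorer.zero_grad()
neg_chunk = 8
for start in range(0, paragraphs.size(0), neg_chunk):
    chunk = slice(start, start + neg_chunk)
    loss_fn(scorer(paragraphs[chunk]).squeeze(-1), labels[chunk]).backward()

print(torch.allclose(grad_full, scorer.weight.grad))  # expected: True (up to float tolerance)
```
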
73 | ### HotpotQA
74 | For HotpotQA, we used up to 50 paragraphs for each question; among them, up to 40 paragraphs are from the TF-IDF retriever.
75 | The other 10 paragraphs are from hyperlinks.
76 | The following command assumes the use of four V100 GPUs; with other devices, you might need to modify `--neg_chunk` and `--gradient_accumulation_steps`.
77 |
78 | ```bash
79 | python run_graph_retriever.py \
80 | --task hotpot_open \
81 | --bert_model bert-base-uncased --do_lower_case \
82 | --train_file_path \
83 | --output_dir \
84 | --max_para_num 50 \
85 | --tfidf_limit 40 \
86 | --neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 1 \
87 | --learning_rate 3e-5 --num_train_epochs 3 \
88 | --use_redundant \
89 | --max_select_num 4 \
90 | ```
91 |
92 | You can use the files in `hotpotqa_new_selector_train_data_db_2017_10_12_fix.zip` for `--train_file_path`.
93 |
94 | ### HotpotQA distractor
95 | For HotpotQA distractor, `--max_para_num` is always 10, due to the task setting.
96 | The following command assumes the use of one V100 GPU; with other devices, you might need to modify `--neg_chunk`.
97 |
98 | ```bash
99 | python run_graph_retriever.py \
100 | --task hotpot_distractor \
101 | --bert_model bert-base-uncased --do_lower_case \
102 | --train_file_path \
103 | --output_dir \
104 | --max_para_num 10 \
105 | --neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 4 \
106 | --learning_rate 3e-5 --num_train_epochs 3 \
107 | --max_select_num 3
108 | ```
109 |
110 | You can use `hotpot_train_order_sensitive.json` for `--train_file_path`.
111 |
112 | ### Natural Questions Open
113 | For Natural Questions, we used up to 80 paragraphs for each question; changing this to 50 or so does not make a significant difference, but in general, the more negative examples, the better (at least, not worse).
114 | To encourage multi-step reasoning, we used the `--use_multiple_redundant` option; because this significantly increases the number of training examples, we also used a larger mini-batch size.
115 | The following command assumes the use of four V100 GPUs; with other devices, you might need to modify `--neg_chunk` and `--gradient_accumulation_steps`.
116 |
117 | ```bash
118 | python run_graph_retriever.py \
119 | --task nq \
120 | --bert_model bert-base-uncased --do_lower_case \
121 | --train_file_path \
122 | --output_dir \
123 | --max_para_num 80 \
124 | --neg_chunk 8 --train_batch_size 8 --gradient_accumulation_steps 2 \
125 | --learning_rate 2e-5 --num_train_epochs 3 \
126 | --use_multiple_redundant \
127 | --max_select_num 3 \
128 | ```
129 |
130 | You can use the files in `nq_selector_train.tar.gz` for `--train_file_path`.
131 |
132 | ### SQuAD Open
133 | For SQuAD, we used up to 50 paragraphs for each question.
134 | The following command assumes the use of four V100 GPUs; with other devices, you might need to modify `--neg_chunk` and `--gradient_accumulation_steps`.
135 |
136 | ```bash
137 | python run_graph_retriever.py \
138 | --task squad \
139 | --bert_model bert-base-uncased --do_lower_case \
140 | --train_file_path \
141 | --output_dir \
142 | --max_para_num 50 \
143 | --neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 1 \
144 | --learning_rate 2e-5 --num_train_epochs 3 \
145 | --max_select_num 2
146 | ```
147 |
148 | You can use `squad_tfidf_rgs_train_tfidf_top_negative_example.json` for `--train_file_path`.
149 |
150 | ## 2. Inference
151 | The trained models can be evaluated with our pipelined evaluation script.
152 | However, we can also use `run_graph_retriever.py` to run trained models for inference.
153 | This is particularly useful for sanity-checking retrieval accuracy on HotpotQA distractor.
154 | To run the models with the same settings used in our paper, we explain some of the most important arguments below.
155 | Some of the arguments are also used in the pipelined evaluation.
156 |
157 | - `--dev_file_path`
158 | This has to be specified for evaluation or inference.
159 | The semantics are the same as for `--train_file_path`: you can use either a single file or a set of multiple files.
160 |
161 | - `--pred_file`
162 | This has to be specified for evaluation or inference.
163 | This is a file path to which the model's prediction results are saved, to be used in the next reading step.
164 | Note that if you used split files for `--dev_file_path`, you may need to merge the output JSON files later (see the merging sketch after this argument list).
165 |
166 | - `--output_dir` and `--model_suffix`
167 | `--output_dir` here is the same directory used for training.
168 | To load a trained model from `pytorch_model_1.5.bin`, set `--model_suffix 1.5` (see the small loading sketch after this list).
169 |
170 | - `--max_para_num`
171 | This is used to specify the maximum number of paragraphs, including hyper-linked ones, for each question.
172 | This is typically set to a large number to cover all the possible paragraphs.
173 |
174 | - `--beam`
175 | This is used to specify a beam size for our beam search algorithm to retrieve reasoning paths.
176 |
177 | - `--pruning_by_links`
178 | This option enables pruning during the beam search, based on hyper-links.
179 |
180 | - `--expand_links`
181 | This option is used to add within-document links, along with the default hyper-links on Wikipedia.
182 |
183 | - `--no_links`
184 | This option disables the use of link information, to see how effective the links are.
185 |
186 | - `--tagme`
187 | This is used to add TagMe-based paragraphs for each question, for better initial retrieval.
188 |
189 | - `--eval_chunk`
190 | This option's purpose is similar to that of `--neg_chunk`.
191 | If an evaluation file is too big to fit in CPU RAM, this option specifies a chunk size so that evaluation does not process all the examples at once.
192 |
193 | - `--split_chunk`
194 | This is useful to control the use of GPU RAM for the BERT encoding.
195 | This is the number of paragraphs to be encoded by BERT together.
196 | The smaller the value is, the less GPU memory is consumed.
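
As a rough sketch of how `--output_dir` and `--model_suffix` resolve to a checkpoint file (the pipeline code in `pipeline/graph_retriever.py` loads checkpoints in a similar way; the paths below are placeholders):

```python
import os
import torch

output_dir = "/path/to/output_dir"   # --output_dir used during training
model_suffix = "1.5"                 # --model_suffix 1.5 -> pytorch_model_1.5.bin

checkpoint_path = os.path.join(output_dir, "pytorch_model_{}.bin".format(model_suffix))
state_dict = torch.load(checkpoint_path, map_location="cpu")

# The state dict is then passed to BertForGraphRetriever.from_pretrained(...)
# together with the --bert_model name, as done in pipeline/graph_retriever.py.
print(len(state_dict), "tensors loaded")
```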
197 |
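If you evaluated on split `--dev_file_path` files and produced one prediction file per split, a minimal merging sketch could look like the following (the file names are hypothetical, and it assumes each prediction file stores a JSON list of per-question entries; adapt it to the actual output of `run_graph_retriever.py`):

```python
import glob
import json

merged = []
# Hypothetical naming scheme for per-split prediction files.
for pred_file in sorted(glob.glob("predictions_split_*.json")):
    with open(pred_file) as f:
        merged.extend(json.load(f))

with open("predictions_merged.json", "w") as f:
    json.dump(merged, f)
```
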
198 | ### HotpotQA
199 |
200 | ```bash
201 | python run_graph_retriever.py \
202 | --task hotpot_open \
203 | --bert_model bert-base-uncased --do_lower_case \
204 | --dev_file_path \
205 | --pred_file \
206 | --output_dir \
207 | --model_suffix 2 \
208 | --max_para_num 2500 \
209 | --beam 8 \
210 | --pruning_by_links \
211 | --eval_chunk 500 \
212 | --split_chunk 300
213 | ```
214 |
215 | ### HotpotQA distractor
216 |
217 | ```bash
218 | python run_graph_retriever.py \
219 | --task hotpot_distractor \
220 | --bert_model bert-base-uncased --do_lower_case \
221 | --dev_file_path \
222 | --pred_file \
223 | --output_dir \
224 | --model_suffix 2 \
225 | --max_para_num 10 \
226 | --beam 8 \
227 | --pruning_by_links \
228 | --eval_chunk 500 \
229 | --split_chunk 300
230 | ```
231 |
232 | You can use `hotpot_fake_sq_dev_new.json` for `--dev_file_path`.
233 |
234 | ### Natural Questions Open
235 |
236 | ```bash
237 | python run_graph_retriever.py \
238 | --task nq \
239 | --bert_model bert-base-uncased --do_lower_case \
240 | --dev_file_path \
241 | --pred_file \
242 | --output_dir \
243 | --model_suffix 2 \
244 | --max_para_num 2000 \
245 | --beam 8 \
246 | --pruning_by_links \
247 | --expand_links \
248 | --tagme \
249 | --eval_chunk 500 \
250 | --split_chunk 300
251 | ```
252 |
253 | ### SQuAD Open
254 |
255 | ```bash
256 | python run_graph_retriever.py \
257 | --task squad \
258 | --bert_model bert-base-uncased --do_lower_case \
259 | --dev_file_path \
260 | --pred_file \
261 | --output_dir \
262 | --model_suffix 2 \
263 | --max_para_num 500 \
264 | --beam 8 \
265 | --no_links \
266 | --eval_chunk 500 \
267 | --split_chunk 300
268 | ```
269 |
--------------------------------------------------------------------------------
/graph_retriever/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/graph_retriever/__init__.py
--------------------------------------------------------------------------------
/graph_retriever/modeling_graph_retriever.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.parameter import Parameter
7 |
8 | try:
9 | from graph_retriever.utils import tokenize_question
10 | from graph_retriever.utils import tokenize_paragraph
11 | from graph_retriever.utils import expand_links
12 | except:
13 | from utils import tokenize_question
14 | from utils import tokenize_paragraph
15 | from utils import expand_links
16 |
17 | class BertForGraphRetriever(BertPreTrainedModel):
18 |
19 | def __init__(self, config, graph_retriever_config):
20 | super(BertForGraphRetriever, self).__init__(config)
21 |
22 | self.graph_retriever_config = graph_retriever_config
23 |
24 | self.bert = BertModel(config)
25 |
26 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
27 |
28 | # Initial state
29 | self.s = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1))
30 |
31 | # Scaling factor for weight norm
32 | self.g = Parameter(torch.FloatTensor(1).fill_(1.0))
33 |
34 | # RNN weight
35 | self.rw = nn.Linear(2*config.hidden_size, config.hidden_size)
36 |
37 | # EOE and output bias
38 | self.eos = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1))
39 | self.bias = Parameter(torch.FloatTensor(1).zero_())
40 |
41 | self.apply(self.init_bert_weights)
42 | self.cpu = torch.device('cpu')
43 |
44 | '''
45 | state: (B, 1, D)
46 | '''
47 | def weight_norm(self, state):
48 | state = state / state.norm(dim = 2).unsqueeze(2)
49 | state = self.g * state
50 | return state
51 |
52 | '''
53 | input_ids, token_type_ids, attention_mask: (B, N, L)
54 | B: batch size
55 | N: maximum number of Q-P pairs
56 | L: maximum number of input tokens
57 | '''
58 | def encode(self, input_ids, token_type_ids, attention_mask, split_chunk = None):
59 | B = input_ids.size(0)
60 | N = input_ids.size(1)
61 | L = input_ids.size(2)
62 | input_ids = input_ids.contiguous().view(B*N, L)
63 | token_type_ids = token_type_ids.contiguous().view(B*N, L)
64 | attention_mask = attention_mask.contiguous().view(B*N, L)
65 |
66 | # [CLS] vectors for Q-P pairs
67 | if split_chunk is None:
68 | encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
69 | pooled_output = encoded_layers[:, 0]
70 |
71 | # an option to reduce GPU memory consumption at eval time, by splitting all the Q-P pairs into smaller chunks
72 | else:
73 | assert type(split_chunk) == int
74 |
75 | TOTAL = input_ids.size(0)
76 | start = 0
77 |
78 | while start < TOTAL:
79 | end = min(start+split_chunk-1, TOTAL-1)
80 | chunk_len = end-start+1
81 |
82 | input_ids_ = input_ids[start:start+chunk_len, :]
83 | token_type_ids_ = token_type_ids[start:start+chunk_len, :]
84 | attention_mask_ = attention_mask[start:start+chunk_len, :]
85 |
86 | encoded_layers, pooled_output_ = self.bert(input_ids_, token_type_ids_, attention_mask_, output_all_encoded_layers=False)
87 | encoded_layers = encoded_layers[:, 0]
88 |
89 | if start == 0:
90 | pooled_output = encoded_layers
91 | else:
92 | pooled_output = torch.cat((pooled_output, encoded_layers), dim = 0)
93 |
94 | start = end+1
95 |
96 | pooled_output = pooled_output.contiguous()
97 |
98 | paragraphs = pooled_output.view(pooled_output.size(0)//N, N, pooled_output.size(1)) # (B, N, D), D: BERT dim
99 | EOE = self.eos.unsqueeze(0).unsqueeze(0) # (1, 1, D)
100 | EOE = EOE.expand(paragraphs.size(0), EOE.size(1), EOE.size(2)) # (B, 1, D)
101 | EOE = self.bert.encoder.layer[-1].output.LayerNorm(EOE)
102 | paragraphs = torch.cat((paragraphs, EOE), dim = 1) # (B, N+1, D)
103 |
104 | # Initial state
105 | state = self.s.expand(paragraphs.size(0), 1, self.s.size(0))
106 | state = self.weight_norm(state)
107 |
108 | return paragraphs, state
109 |
110 | '''
111 | input_ids, token_type_ids, attention_mask: (B, N, L)
112 | - B: batch size
113 | - N: maximum number of Q-P pairs
114 | - L: maximum number of input tokens
115 |
116 | output_mask, target: (B, max_num_steps, N+1)
117 | '''
118 | def forward(self, input_ids, token_type_ids, attention_mask, output_mask, target, max_num_steps):
119 |
120 | paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask)
121 |
122 | for i in range(max_num_steps):
123 | if i == 0:
124 | h = state
125 | else:
126 | input = paragraphs[:, i-1:i, :] # (B, 1, D)
127 | state = torch.cat((state, input), dim = 2) # (B, 1, 2*D)
128 | state = self.rw(state) # (B, 1, D)
129 | state = self.weight_norm(state)
130 | h = torch.cat((h, state), dim = 1) # ...--> (B, max_num_steps, D)
131 |
132 | h = self.dropout(h)
133 | output = torch.bmm(h, paragraphs.transpose(1, 2)) # (B, max_num_steps, N+1)
134 | output = output + self.bias
135 |
136 | loss = F.binary_cross_entropy_with_logits(output, target, weight = output_mask, reduction = 'mean')
137 | return loss
138 |
139 | def beam_search(self, input_ids, token_type_ids, attention_mask, examples, tokenizer, retriever, split_chunk):
140 | beam = self.graph_retriever_config.beam
141 | B = input_ids.size(0)
142 | N = self.graph_retriever_config.max_para_num
143 |
144 | pred = []
145 | prob = []
146 |
147 | topk_pred = []
148 | topk_prob = []
149 |
150 | eos_index = N
151 |
152 | init_paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask, split_chunk = split_chunk)
153 |
154 | # Output matrix to be populated
155 | ps = torch.FloatTensor(N+1, self.s.size(0)).zero_().to(self.s.device) # (N+1, D)
156 |
157 | for i in range(B):
158 | init_context_len = len(examples[i].context)
159 |
160 | # Populating the output matrix by the initial encoding
161 | ps[:init_context_len, :].copy_(init_paragraphs[i, :init_context_len, :])
162 | ps[-1, :].copy_(init_paragraphs[i, -1, :])
163 | encoded_titles = set(examples[i].title_order)
164 |
165 | pred_ = [[[], [], 1.0] for _ in range(beam)] # [hist_1, topk_1, score_1], [hist_2, topk_2, score_2], ...
166 | prob_ = [[] for _ in range(beam)]
167 |
168 | state_ = state[i:i+1] # (1, 1, D)
169 | state_ = state_.expand(beam, 1, state_.size(2)) # -> (beam, 1, D)
170 | state_tmp = torch.FloatTensor(state_.size()).zero_().to(state_.device)
171 |
172 | for j in range(self.graph_retriever_config.max_select_num):
173 | if j > 0:
174 | input = [p[0][-1] for p in pred_]
175 | input = torch.LongTensor(input).to(ps.device)
176 | input = ps[input].unsqueeze(1) # (beam, 1, D)
177 | state_ = torch.cat((state_, input), dim = 2) # (beam, 1, 2*D)
178 | state_ = self.rw(state_) # (beam, 1, D)
179 | state_ = self.weight_norm(state_)
180 |
181 | # Opening new links from the previous predictions (populating the output matrix dynamically)
182 | if j > 0:
183 | prev_title_size = len(examples[i].title_order)
184 | new_titles = []
185 | for b in range(beam):
186 | prev_pred = pred_[b][0][-1]
187 |
188 | if prev_pred == eos_index:
189 | continue
190 |
191 | prev_title = examples[i].title_order[prev_pred]
192 |
193 | if prev_title not in examples[i].all_linked_paras_dic:
194 |
195 | if retriever is None:
196 | continue
197 | else:
198 | linked_paras_dic = retriever.get_hyperlinked_abstract_paragraphs(
199 | prev_title, examples[i].question)
200 | examples[i].all_linked_paras_dic[prev_title] = {}
201 | examples[i].all_linked_paras_dic[prev_title].update(linked_paras_dic)
202 | examples[i].all_paras.update(linked_paras_dic)
203 |
204 | for linked_title in examples[i].all_linked_paras_dic[prev_title]:
205 | if linked_title in encoded_titles or len(examples[i].title_order) == N:
206 | continue
207 |
208 | encoded_titles.add(linked_title)
209 | new_titles.append(linked_title)
210 | examples[i].title_order.append(linked_title)
211 |
212 | if len(new_titles) > 0:
213 |
214 | tokens_q = tokenize_question(examples[i].question, tokenizer)
215 | input_ids = []
216 | input_masks = []
217 | segment_ids = []
218 | for linked_title in new_titles:
219 | linked_para = examples[i].all_paras[linked_title]
220 |
221 | input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(linked_para, tokens_q, self.graph_retriever_config.max_seq_length, tokenizer)
222 | input_ids.append(input_ids_)
223 | input_masks.append(input_masks_)
224 | segment_ids.append(segment_ids_)
225 |
226 | input_ids = torch.LongTensor([input_ids]).to(ps.device)
227 | token_type_ids = torch.LongTensor([segment_ids]).to(ps.device)
228 | attention_mask = torch.LongTensor([input_masks]).to(ps.device)
229 |
230 | paragraphs, _ = self.encode(input_ids, token_type_ids, attention_mask, split_chunk = split_chunk)
231 | paragraphs = paragraphs.squeeze(0)
232 | ps[prev_title_size:prev_title_size+len(new_titles)].copy_(paragraphs[:len(new_titles), :])
233 |
234 | if retriever is not None and self.graph_retriever_config.expand_links:
235 | expand_links(examples[i].all_paras, examples[i].all_linked_paras_dic, examples[i].all_paras)
236 |
237 | output = torch.bmm(state_, ps.unsqueeze(0).expand(beam, ps.size(0), ps.size(1)).transpose(1, 2)) # (beam, 1, N+1)
238 | output = output + self.bias
239 | output = torch.sigmoid(output)
240 |
241 | output = output.to(self.cpu)
242 |
243 | if j == 0:
244 | output[:, :, len(examples[i].context):] = 0.0
245 | else:
246 | if len(examples[i].title_order) < N:
247 | output[:, :, len(examples[i].title_order):N] = 0.0
248 | for b in range(beam):
249 |
250 | # Omitting previous predictions
251 | for k in range(len(pred_[b][0])):
252 | output[b, :, pred_[b][0][k]] = 0.0
253 |
254 | # Links & topK-based pruning
255 | if self.graph_retriever_config.pruning_by_links:
256 | if pred_[b][0][-1] == eos_index:
257 | output[b, :, :eos_index] = 0.0
258 | output[b, :, eos_index] = 1.0
259 |
260 | elif examples[i].title_order[pred_[b][0][-1]] not in examples[i].all_linked_paras_dic:
261 | for k in range(len(examples[i].title_order)):
262 | if k not in pred_[b][1]:
263 | output[b, :, k] = 0.0
264 |
265 | else:
266 | for k in range(len(examples[i].title_order)):
267 | if k not in pred_[b][1] and examples[i].title_order[k] not in examples[i].all_linked_paras_dic[examples[i].title_order[pred_[b][0][-1]]]:
268 | output[b, :, k] = 0.0
269 |
270 | # always >= M before EOS
271 | if j <= self.graph_retriever_config.min_select_num-1:
272 | output[:, :, -1] = 0.0
273 |
274 |
275 | score = [p[2] for p in pred_]
276 | score = torch.FloatTensor(score)
277 | score = score.unsqueeze(1).unsqueeze(2) # (beam, 1, 1)
278 | score = output * score
279 |
280 | output = output.squeeze(1) # (beam, N+1)
281 | score = score.squeeze(1) # (beam, N+1)
282 | new_pred_ = []
283 | new_prob_ = []
284 |
285 | b = 0
286 | while b < beam:
287 | s, p = torch.max(score.view(score.size(0)*score.size(1)), dim = 0)
288 | s = s.item()
289 | p = p.item()
290 | row = p // score.size(1)
291 | col = p % score.size(1)
292 |
293 | if j == 0:
294 | score[:, col] = 0.0
295 | else:
296 | score[row, col] = 0.0
297 |
298 | p = [[index for index in pred_[row][0]] + [col],
299 | output[row].topk(k = 2, dim = 0)[1].tolist(),
300 | s]
301 | new_pred_.append(p)
302 |
303 | p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()]
304 | new_prob_.append(p)
305 |
306 | state_tmp[b].copy_(state_[row])
307 | b += 1
308 |
309 | pred_ = new_pred_
310 | prob_ = new_prob_
311 | state_ = state_.clone()
312 | state_.copy_(state_tmp)
313 |
314 | if pred_[0][0][-1] == eos_index:
315 | break
316 |
317 | topk_pred.append([])
318 | topk_prob.append([])
319 | for index__ in range(beam):
320 |
321 | pred_tmp = []
322 | for index in pred_[index__][0]:
323 | if index == eos_index:
324 | break
325 | pred_tmp.append(index)
326 |
327 | if index__ == 0:
328 | pred.append(pred_tmp)
329 | prob.append(prob_[0])
330 |
331 | topk_pred[-1].append(pred_tmp)
332 | topk_prob[-1].append(prob_[index__])
333 |
334 | return pred, prob, topk_pred, topk_prob
335 |
--------------------------------------------------------------------------------
/img/odqa_overview-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/img/odqa_overview-1.png
--------------------------------------------------------------------------------
/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/pipeline/__init__.py
--------------------------------------------------------------------------------
/pipeline/graph_retriever.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 |
4 | import torch
5 |
6 | from graph_retriever.utils import InputExample
7 | from graph_retriever.utils import InputFeatures
8 | from graph_retriever.utils import tokenize_question
9 | from graph_retriever.utils import tokenize_paragraph
10 | from graph_retriever.utils import GraphRetrieverConfig
11 | from graph_retriever.utils import expand_links
12 | from graph_retriever.modeling_graph_retriever import BertForGraphRetriever
13 |
14 | from pytorch_pretrained_bert.tokenization import BertTokenizer
15 |
16 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
17 |
18 | import logging
19 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
20 | datefmt = '%m/%d/%Y %H:%M:%S',
21 | level = logging.INFO)
22 | logger = logging.getLogger(__name__)
23 |
24 | def create_examples(jsn, graph_retriever_config):
25 |
26 | task = graph_retriever_config.task
27 |
28 | examples = []
29 |
30 | '''
31 | Find the maximum size of the initial context (links are not included)
32 | '''
33 | graph_retriever_config.max_context_size = 0
34 |
35 | for data in jsn:
36 |
37 | guid = data['q_id']
38 | question = data['question']
39 | context = data['context'] # {context title: paragraph}
40 | all_linked_paras_dic = {} # {context title: {linked title: paragraph}}
41 |
42 | '''
43 | Use TagMe-based context at test time.
44 | '''
45 | if graph_retriever_config.tagme:
46 | assert 'tagged_context' in data
47 |
48 | '''
49 | Reformat "tagged_context" if needed (c.f. the "context" case above)
50 | '''
51 | if type(data['tagged_context']) == list:
52 | tagged_context = {c[0]: c[1] for c in data['tagged_context']}
53 | data['tagged_context'] = tagged_context
54 |
55 | '''
56 | Append valid paragraphs from "tagged_context" to "context"
57 | '''
58 | for tagged_title in data['tagged_context']:
59 | tagged_text = data['tagged_context'][tagged_title]
60 | if tagged_title not in context and tagged_title is not None and tagged_title.strip() != '' and tagged_text is not None and tagged_text.strip() != '':
61 | context[tagged_title] = tagged_text
62 |
63 | '''
64 | Clean "context" by removing invalid paragraphs
65 | '''
66 | removed_keys = []
67 | for title in context:
68 | if title is None or title.strip() == '' or context[title] is None or context[title].strip() == '':
69 | removed_keys.append(title)
70 | for key in removed_keys:
71 | context.pop(key)
72 |
73 | all_paras = {}
74 | for title in context:
75 | all_paras[title] = context[title]
76 |
77 | if graph_retriever_config.expand_links:
78 | expand_links(context, all_linked_paras_dic, all_paras)
79 |
80 | graph_retriever_config.max_context_size = max(graph_retriever_config.max_context_size, len(context))
81 |
82 | examples.append(InputExample(guid = guid,
83 | q = question,
84 | c = context,
85 | para_dic = all_linked_paras_dic,
86 | s_g = None, r_g = None, all_r_g = None,
87 | all_paras = all_paras))
88 |
89 | return examples
90 |
91 | def convert_examples_to_features(examples, max_seq_length, max_para_num, graph_retriever_config, tokenizer):
92 | """Loads a data file into a list of `InputBatch`s."""
93 |
94 | max_para_num = graph_retriever_config.max_context_size
95 | graph_retriever_config.max_para_num = max(graph_retriever_config.max_para_num, max_para_num)
96 |
97 | max_steps = graph_retriever_config.max_select_num
98 |
99 | DUMMY = [0] * max_seq_length
100 | features = []
101 |
102 | for (ex_index, example) in enumerate(examples):
103 | tokens_q = tokenize_question(example.question, tokenizer)
104 |
105 | title2index = {}
106 | input_ids = []
107 | input_masks = []
108 | segment_ids = []
109 |
110 | titles_list = list(example.context.keys())
111 | for p in titles_list:
112 |
113 | if len(input_ids) == max_para_num:
114 | break
115 |
116 | if p in title2index:
117 | continue
118 |
119 | title2index[p] = len(title2index)
120 | example.title_order.append(p)
121 | p = example.context[p]
122 |
123 | input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(p, tokens_q, max_seq_length, tokenizer)
124 | input_ids.append(input_ids_)
125 | input_masks.append(input_masks_)
126 | segment_ids.append(segment_ids_)
127 |
128 | num_paragraphs_no_links = len(input_ids)
129 |
130 | assert len(input_ids) <= max_para_num
131 |
132 | num_paragraphs = len(input_ids)
133 |
134 | output_masks = [([1.0] * len(input_ids) + [0.0] * (max_para_num - len(input_ids) + 1)) for _ in range(max_para_num + 2)]
135 |
136 | assert len(example.context) == num_paragraphs_no_links
137 | for i in range(len(output_masks[0])):
138 | if i >= num_paragraphs_no_links:
139 | output_masks[0][i] = 0.0
140 |
141 | for i in range(len(input_ids)):
142 | output_masks[i+1][i] = 0.0
143 |
144 | padding = [DUMMY] * (max_para_num - len(input_ids))
145 | input_ids += padding
146 | input_masks += padding
147 | segment_ids += padding
148 |
149 | features.append(
150 | InputFeatures(input_ids=input_ids,
151 | input_masks=input_masks,
152 | segment_ids=segment_ids,
153 | output_masks = output_masks,
154 | num_paragraphs = num_paragraphs,
155 | num_steps = -1,
156 | ex_index = ex_index))
157 |
158 | return features
159 |
160 | class GraphRetriever:
161 | def __init__(self,
162 | args,
163 | device):
164 |
165 | self.graph_retriever_config = GraphRetrieverConfig(example_limit = None,
166 | task = None,
167 | max_seq_length = args.max_seq_length,
168 | max_select_num = args.max_select_num,
169 | max_para_num = args.max_para_num,
170 | tfidf_limit = None,
171 |
172 | train_file_path = None,
173 | use_redundant = None,
174 | use_multiple_redundant = None,
175 | max_redundant_num = None,
176 |
177 | dev_file_path = None,
178 | beam = args.beam_graph_retriever,
179 | min_select_num = args.min_select_num,
180 | no_links = args.no_links,
181 | pruning_by_links = args.pruning_by_links,
182 | expand_links = args.expand_links,
183 | eval_chunk = args.eval_chunk,
184 | tagme = args.tagme,
185 | topk = args.topk,
186 | db_save_path = None)
187 |
188 | print('initializing GraphRetriever...', flush=True)
189 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_graph_retriever, do_lower_case=args.do_lower_case)
190 | model_state_dict = torch.load(args.graph_retriever_path)
191 | self.model = BertForGraphRetriever.from_pretrained(args.bert_model_graph_retriever, state_dict=model_state_dict, graph_retriever_config = self.graph_retriever_config)
192 | self.device = device
193 | self.model.to(self.device)
194 | self.model.eval()
195 | print('Done!', flush=True)
196 |
197 | def predict(self,
198 | tfidf_retrieval_output,
199 | retriever,
200 | args):
201 |
202 | pred_output = []
203 |
204 | eval_examples = create_examples(tfidf_retrieval_output, self.graph_retriever_config)
205 |
206 | TOTAL_NUM = len(eval_examples)
207 | eval_start_index = 0
208 |
209 | while eval_start_index < TOTAL_NUM:
210 | eval_end_index = min(eval_start_index+self.graph_retriever_config.eval_chunk-1, TOTAL_NUM-1)
211 | chunk_len = eval_end_index - eval_start_index + 1
212 |
213 | features = convert_examples_to_features(eval_examples[eval_start_index:eval_start_index+chunk_len], args.max_seq_length, args.max_para_num, self.graph_retriever_config, self.tokenizer)
214 |
215 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
216 | all_input_masks = torch.tensor([f.input_masks for f in features], dtype=torch.long)
217 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
218 | all_output_masks = torch.tensor([f.output_masks for f in features], dtype=torch.float)
219 | all_num_paragraphs = torch.tensor([f.num_paragraphs for f in features], dtype=torch.long)
220 | all_num_steps = torch.tensor([f.num_steps for f in features], dtype=torch.long)
221 | all_ex_indices = torch.tensor([f.ex_index for f in features], dtype=torch.long)
222 | eval_data = TensorDataset(all_input_ids, all_input_masks, all_segment_ids, all_output_masks, all_num_paragraphs, all_num_steps, all_ex_indices)
223 |
224 | eval_sampler = SequentialSampler(eval_data)
225 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
226 | logger.info('Examples from '+str(eval_start_index)+' to '+str(eval_end_index))
227 | for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(eval_dataloader, desc="Evaluating"):
228 |
229 | batch_max_len = input_masks.sum(dim = 2).max().item()
230 | batch_max_para_num = num_paragraphs.max().item()
231 | batch_max_steps = num_steps.max().item()
232 |
233 | input_ids = input_ids[:, :batch_max_para_num, :batch_max_len]
234 | input_masks = input_masks[:, :batch_max_para_num, :batch_max_len]
235 | segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len]
236 | output_masks = output_masks[:, :batch_max_para_num+2, :batch_max_para_num+1]
237 | output_masks[:, 1:, -1] = 1.0 # Ignore EOS in the first step
238 |
239 | input_ids = input_ids.to(self.device)
240 | input_masks = input_masks.to(self.device)
241 | segment_ids = segment_ids.to(self.device)
242 | output_masks = output_masks.to(self.device)
243 |
244 | examples = [eval_examples[eval_start_index+ex_indices[i].item()] for i in range(input_ids.size(0))]
245 |
246 | with torch.no_grad():
247 | pred, prob, topk_pred, topk_prob = self.model.beam_search(input_ids, segment_ids, input_masks, examples = examples, tokenizer = self.tokenizer, retriever = retriever, split_chunk = args.split_chunk)
248 |
249 | for i in range(len(pred)):
250 | e = examples[i]
251 |
252 | titles = [e.title_order[p] for p in pred[i]]
253 | question = e.question
254 |
255 | pred_output.append({})
256 | pred_output[-1]['q_id'] = e.guid
257 |
258 | pred_output[-1]['question'] = question
259 |
260 | topk_titles = [[e.title_order[p] for p in topk_pred[i][j]] for j in range(len(topk_pred[i]))]
261 | pred_output[-1]['topk_titles'] = topk_titles
262 |
263 | topk_probs = []
264 | pred_output[-1]['topk_probs'] = topk_probs
265 |
266 | context = {}
267 | context_from_tfidf = set()
268 | context_from_hyperlink = set()
269 | for ts in topk_titles:
270 | for t in ts:
271 | context[t] = e.all_paras[t]
272 |
273 | if t in e.context:
274 | context_from_tfidf.add(t)
275 | else:
276 | context_from_hyperlink.add(t)
277 |
278 | pred_output[-1]['context'] = context
279 | pred_output[-1]['context_from_tfidf'] = list(context_from_tfidf)
280 | pred_output[-1]['context_from_hyperlink'] = list(context_from_hyperlink)
281 |
282 | eval_start_index = eval_end_index + 1
283 | del features
284 | del all_input_ids
285 | del all_input_masks
286 | del all_segment_ids
287 | del all_output_masks
288 | del all_num_paragraphs
289 | del all_num_steps
290 | del all_ex_indices
291 | del eval_data
292 |
293 | return pred_output
294 |
--------------------------------------------------------------------------------
/pipeline/reader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from pytorch_pretrained_bert.tokenization import BertTokenizer
4 |
5 | from reader.modeling_reader import BertForQuestionAnsweringConfidence
6 | from reader.rc_utils import read_squad_style_hotpot_examples, \
7 | convert_examples_to_features, write_predictions_yes_no_beam
8 |
9 | import collections
10 |
11 | from tqdm import tqdm
12 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
13 |
14 | RawResult = collections.namedtuple("RawResult",
15 | ["unique_id", "start_logits", "end_logits", "switch_logits"])
16 |
17 | class Reader:
18 | def __init__(self,
19 | args,
20 | device):
21 |
22 | print('initializing Reader...', flush=True)
23 | self.model = BertForQuestionAnsweringConfidence.from_pretrained(args.reader_path, num_labels=4, no_masking=True)
24 | self.tokenizer = BertTokenizer.from_pretrained(args.reader_path, args.do_lower_case)
25 | self.device = device
26 |
27 | self.model.to(device)
28 | self.model.eval()
29 | print('Done!', flush=True)
30 |
31 | def convert_retriever_output(self,
32 | retriever_output):
33 |
34 | selected_paras_top_n = {str(item["q_id"]): item["topk_titles"]
35 | for item in retriever_output}
36 |
37 | context_dic = {str(item["q_id"]): item["context"]
38 | for item in retriever_output}
39 |
40 | squad_style_data = {'data': [], 'version': '1.1'}
41 |
42 | retrieved_para_dict = {}
43 |
44 | for data in retriever_output:
45 | example_id = data['q_id']
46 | question_text = data['question']
47 | pred_para_titles = selected_paras_top_n[example_id]
48 |
49 | for selected_paras in pred_para_titles:
50 | title, context = "", ""
51 |
52 | for para_title in selected_paras:
53 | paragraphs = context_dic[example_id][para_title]
54 | context += paragraphs
55 |
56 | title = para_title
57 | context += " "
58 | # post process to remove unnecessary spaces.
59 | if context[0] == " ":
60 | context = context[1:]
61 | if context[-1] == " ":
62 | context = context[: -1]
63 |
64 | context = context.replace(" ", " ")
65 |
66 | squad_example = {'context': context, 'para_titles': selected_paras,
67 | 'qas': [{'question': question_text, 'id': example_id}]}
68 | squad_style_data["data"].append(
69 | {'title': title, 'paragraphs': [squad_example]})
70 |
71 | return squad_style_data
72 |
73 | def predict(self,
74 | retriever_output,
75 | args):
76 |
77 | squad_style_data = self.convert_retriever_output(retriever_output)
78 |
79 | e = read_squad_style_hotpot_examples(squad_style_hotpot_dev=squad_style_data,
80 | is_training=False,
81 | version_2_with_negative=False,
82 | store_path_prob=False)
83 |
84 | features = convert_examples_to_features(
85 | examples=e,
86 | tokenizer=self.tokenizer,
87 | max_seq_length=args.max_seq_length,
88 | doc_stride=args.doc_stride,
89 | max_query_length=args.max_query_length,
90 | is_training=False,
91 | quiet = True)
92 |
93 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
94 | all_input_masks = torch.tensor([f.input_mask for f in features], dtype=torch.long)
95 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
96 | eval_data = TensorDataset(all_input_ids, all_input_masks, all_segment_ids)
97 |
98 | eval_sampler = SequentialSampler(eval_data)
99 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
100 |
101 | all_results = []
102 |
103 | f_offset = 0
104 | for input_ids, input_masks, segment_ids in tqdm(eval_dataloader, desc="Evaluating"):
105 | input_ids = input_ids.to(self.device)
106 | input_masks = input_masks.to(self.device)
107 | segment_ids = segment_ids.to(self.device)
108 | with torch.no_grad():
109 | batch_start_logits, batch_end_logits, batch_switch_logits = self.model(input_ids, segment_ids, input_masks)
110 |
111 | for i in range(input_ids.size(0)):
112 | start_logits = batch_start_logits[i].detach().cpu().tolist()
113 | end_logits = batch_end_logits[i].detach().cpu().tolist()
114 | switch_logits = batch_switch_logits[i].detach().cpu().tolist()
115 | eval_feature = features[f_offset+i]
116 | unique_id = int(features[f_offset+i].unique_id)
117 | all_results.append(RawResult(unique_id=unique_id,
118 | start_logits=start_logits,
119 | end_logits=end_logits,
120 | switch_logits=switch_logits))
121 | f_offset += input_ids.size(0)
122 |
123 | return write_predictions_yes_no_beam(e, features, all_results,
124 | args.n_best_size, args.max_answer_length,
125 | args.do_lower_case, None,
126 | None, None, False,
127 | False, None,
128 | output_selected_paras=True,
129 | quiet = True)
130 |
--------------------------------------------------------------------------------
/pipeline/sequential_sentence_selector.py:
--------------------------------------------------------------------------------
1 | from sequential_sentence_selector.modeling_sequential_sentence_selector import BertForSequentialSentenceSelector
2 | from sequential_sentence_selector.run_sequential_sentence_selector import InputExample
3 | from sequential_sentence_selector.run_sequential_sentence_selector import InputFeatures
4 | from sequential_sentence_selector.run_sequential_sentence_selector import DataProcessor
5 | from sequential_sentence_selector.run_sequential_sentence_selector import convert_examples_to_features
6 |
7 | from pytorch_pretrained_bert.tokenization import BertTokenizer
8 |
9 | import torch
10 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
11 |
12 | from tqdm import tqdm
13 |
14 | class SequentialSentenceSelector:
15 | def __init__(self,
16 | args,
17 | device):
18 |
19 | if args.sequential_sentence_selector_path is None:
20 | return None
21 |
22 | print('initializing SequentialSentenceSelector...', flush=True)
23 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_sequential_sentence_selector, do_lower_case=args.do_lower_case)
24 | model_state_dict = torch.load(args.sequential_sentence_selector_path)
25 | self.model = BertForSequentialSentenceSelector.from_pretrained(args.bert_model_sequential_sentence_selector, state_dict=model_state_dict)
26 | self.device = device
27 | self.model.to(self.device)
28 | self.model.eval()
29 |
30 | self.processor = DataProcessor()
31 | print('Done!', flush=True)
32 |
33 | def convert_reader_output(self,
34 | reader_output,
35 | tfidf_retriever):
36 | new_output = []
37 |
38 | for data in reader_output:
39 | entry = {}
40 | entry['q_id'] = data['q_id']
41 | entry['question'] = data['question']
42 | entry['answer'] = data['answer']
43 | entry['titles'] = data['context']
44 | entry['context'] = tfidf_retriever.load_abstract_para_text(entry['titles'], keep_sentence_split = True)
45 | entry['supporting_facts'] = {t: [] for t in entry['titles']}
46 | new_output.append(entry)
47 |
48 | return new_output
49 |
50 | def predict(self,
51 | reader_output,
52 | tfidf_retriever,
53 | args):
54 |
55 | reader_output = self.convert_reader_output(reader_output, tfidf_retriever)
56 | eval_examples = self.processor.create_examples(reader_output)
57 | eval_features = convert_examples_to_features(
58 | eval_examples, args.max_seq_length_sequential_sentence_selector, args.max_sent_num, args.max_sf_num, self.tokenizer)
59 |
60 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
61 | all_input_masks = torch.tensor([f.input_masks for f in eval_features], dtype=torch.long)
62 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
63 | all_output_masks = torch.tensor([f.output_masks for f in eval_features], dtype=torch.float)
64 | all_num_sents = torch.tensor([f.num_sents for f in eval_features], dtype=torch.long)
65 | all_num_sfs = torch.tensor([f.num_sfs for f in eval_features], dtype=torch.long)
66 | all_ex_indices = torch.tensor([f.ex_index for f in eval_features], dtype=torch.long)
67 | eval_data = TensorDataset(all_input_ids,
68 | all_input_masks,
69 | all_segment_ids,
70 | all_output_masks,
71 | all_num_sents,
72 | all_num_sfs,
73 | all_ex_indices)
74 | # Run prediction for full data
75 | eval_sampler = SequentialSampler(eval_data)
76 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
77 |
78 | pred_output = []
79 |
80 | for input_ids, input_masks, segment_ids, output_masks, num_sents, num_sfs, ex_indices in tqdm(eval_dataloader, desc="Evaluating"):
81 | batch_max_len = input_masks.sum(dim = 2).max().item()
82 | batch_max_sent_num = num_sents.max().item()
83 | batch_max_sf_num = num_sfs.max().item()
84 |
85 | input_ids = input_ids[:, :batch_max_sent_num, :batch_max_len]
86 | input_masks = input_masks[:, :batch_max_sent_num, :batch_max_len]
87 | segment_ids = segment_ids[:, :batch_max_sent_num, :batch_max_len]
88 | output_masks = output_masks[:, :batch_max_sent_num+2, :batch_max_sent_num+1]
89 |
90 | output_masks[:, 1:, -1] = 1.0 # Ignore EOE in the first step
91 |
92 | input_ids = input_ids.to(self.device)
93 | input_masks = input_masks.to(self.device)
94 | segment_ids = segment_ids.to(self.device)
95 | output_masks = output_masks.to(self.device)
96 |
97 | examples = [eval_examples[ex_indices[i].item()] for i in range(input_ids.size(0))]
98 |
99 | with torch.no_grad():
100 | pred, prob, topk_pred, topk_prob = self.model.beam_search(input_ids, segment_ids, input_masks, output_masks, max_num_steps = args.max_sf_num+1, examples = examples, beam = args.beam_sequential_sentence_selector)
101 |
102 | for i in range(len(pred)):
103 | e = examples[i]
104 |
105 | sfs = {}
106 | for p in pred[i]:
107 | offset = 0
108 | for j in range(len(e.titles)):
109 | if p >= offset and p < offset+len(e.context[e.titles[j]]):
110 | if e.titles[j] not in sfs:
111 | sfs[e.titles[j]] = [[p-offset, e.context[e.titles[j]][p-offset]]]
112 | else:
113 | sfs[e.titles[j]].append([p-offset, e.context[e.titles[j]][p-offset]])
114 | break
115 | offset += len(e.context[e.titles[j]])
116 |
117 | # Hack
118 | for title in e.titles:
119 | if title not in sfs and len(sfs) < 2:
120 | sfs[title] = [0]
121 |
122 | output = {}
123 | output['q_id'] = e.guid
124 | output['supporting facts'] = sfs
125 | pred_output.append(output)
126 |
127 | return pred_output
128 |
129 |
--------------------------------------------------------------------------------
/pipeline/tfidf_retriever.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from retriever.doc_db import DocDB
4 | from retriever.tfidf_doc_ranker import TfidfDocRanker
5 | from retriever.tfidf_vectorizer_article import TopTfIdf
6 |
7 | from retriever.utils import load_para_collections_from_tfidf_id_intro_only, \
8 | load_para_and_linked_titles_dict_from_tfidf_id, prune_top_k_paragraphs, \
9 | normalize
10 |
11 | from retriever.tfidf_vectorizer_article import TopTfIdf
12 |
13 | import logging
14 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
15 | datefmt = '%m/%d/%Y %H:%M:%S',
16 | level = logging.INFO)
17 | logger = logging.getLogger(__name__)
18 |
19 | class TfidfRetriever:
20 | def __init__(self,
21 | db_save_path: str,
22 | tfidf_model_path: str,
23 | use_full_article: bool = False,
24 | pruning_l: int = 10):
25 |
26 | print('initializing TfidfRetriever...', flush=True)
27 | self.db = DocDB(db_save_path)
28 | if tfidf_model_path is not None:
29 | self.ranker = TfidfDocRanker(tfidf_path=tfidf_model_path, strict=False)
30 | self.use_full_article = use_full_article
31 | # FIXME: make this to dynamically set the pruning_l
32 | self.pruning_l = pruning_l
33 | print('Done!', flush=True)
34 |
35 | def store_title2hyperlink_dic(self, title2hyperlink_dic):
36 | self.title2hyperlink_dic = title2hyperlink_dic
37 |
38 | def load_abstract_para_text(self,
39 | doc_names,
40 | keep_sentence_split = False):
41 |
42 | context = {}
43 |
44 | for doc_name in doc_names:
45 | para_title_text_pairs = load_para_collections_from_tfidf_id_intro_only(doc_name, self.db)
46 | if len(para_title_text_pairs) == 0:
47 | logger.warning("{} is missing".format(doc_name))
48 | continue
49 | else:
50 | para_title_text_pairs = para_title_text_pairs[0]
51 | if keep_sentence_split:
52 | para_title_text_pairs = {para_title_text_pairs[0]: para_title_text_pairs[1]}
53 | else:
54 | para_title_text_pairs = {para_title_text_pairs[0]: "".join(para_title_text_pairs[1])}
55 | context.update(para_title_text_pairs)
56 |
57 | return context
58 |
59 | # load sampled text included in the target articles, with two stage tfidf retrieval.
60 | def load_sampled_para_text_and_linked_titles(self,
61 | doc_names,
62 | question,
63 | pruning_l,
64 | prune_after_agg=True):
65 |
66 | context = {}
67 | linked_titles = {}
68 | tfidf_vectorizer = TopTfIdf(n_to_select=pruning_l,
69 | filter_dist_one=True, rank=True)
70 | para_dict_all = {}
71 | linked_titles_dict_all = {}
72 | for doc_name in doc_names:
73 | paras_dict, linked_titles_dict = load_para_and_linked_titles_dict_from_tfidf_id(
74 | doc_name, self.db)
75 | if len(paras_dict) == 0:
76 | continue
77 | if prune_after_agg is True:
78 | para_dict_all.update(paras_dict)
79 | linked_titles_dict_all.update(linked_titles_dict)
80 | else:
81 | pruned_para_dict = prune_top_k_paragraphs(
82 | question, paras_dict, tfidf_vectorizer, pruning_l)
83 |
84 | # add top pruning_l paragraphs from the target article.
85 | context.update(pruned_para_dict)
86 | # add hyperlinked paragraphs of the top pruning_l paragraphs from the target article.
87 | pruned_linked_titles = {k: v for k, v in linked_titles_dict.items() if k in pruned_para_dict}
88 | assert len(pruned_para_dict) == len(pruned_linked_titles)
89 | linked_titles.update(pruned_linked_titles)
90 |
91 | if prune_after_agg is True:
92 | pruned_para_dict = prune_top_k_paragraphs(question, para_dict_all, tfidf_vectorizer, pruning_l)
93 | context.update(pruned_para_dict)
94 | pruned_linked_titles = {
95 | k: v for k, v in linked_titles_dict_all.items() if k in pruned_para_dict}
96 | assert len(pruned_para_dict) == len(pruned_linked_titles)
97 | linked_titles.update(pruned_linked_titles)
98 |
99 | return context, linked_titles
100 |
101 | def retrieve_titles_w_tag_me(self, question, tagme_api_key):
102 | import tagme
103 | tagme.GCUBE_TOKEN = tagme_api_key
104 | q_annotations = tagme.annotate(question)
105 | tagged_titles = []
106 | for ann in q_annotations.get_annotations(0.1):
107 | tagged_titles.append(ann.entity_title)
108 | return tagged_titles
109 |
110 | def load_sampled_tagged_para_text(self, question, pruning_l, tagme_api_key):
111 | tagged_titles = self.retrieve_titles_w_tag_me(question, tagme_api_key)
112 | tagged_doc_names = [normalize(title) for title in tagged_titles]
113 |
114 | context, _ = self.load_sampled_para_text_and_linked_titles(
115 | tagged_doc_names, question, pruning_l)
116 |
117 | return context
118 |
119 | def get_abstract_tfidf(self,
120 | q_id,
121 | question,
122 | args):
123 |
124 | doc_names, _ = self.ranker.closest_docs(question, k=args.tfidf_limit)
125 | # Add TFIDF close documents
126 | context = self.load_abstract_para_text(doc_names)
127 |
128 | return [{"question": question,
129 | "context": context,
130 | "q_id": q_id}]
131 |
132 | def get_article_tfidf_with_hyperlinked_titles(self, q_id,question, args):
133 | """
134 | Retrieve articles with their corresponding hyperlinked titles.
135 | Due to efficiency, we sample top k articles, and then sample top l paragraphs from each article.
136 | (so, eventually we get k*l paragraphs with tfidf-based pruning.)
137 | We also store the hyperlinked titles for each paragraph.
138 | """
139 |
140 | tfidf_limit, pruning_l, prune_after_agg = args.tfidf_limit, args.pruning_l, args.prune_after_agg
141 | doc_names, _ = self.ranker.closest_docs(question, k=tfidf_limit)
142 | context, hyper_linked_titles = self.load_sampled_para_text_and_linked_titles(
143 | doc_names, question, pruning_l, prune_after_agg)
144 |
145 | if args.tagme is True and args.tagme_api_key is not None:
146 | # if add TagMe
147 | tagged_context = self.load_sampled_tagged_para_text(
148 | question, pruning_l, args.tagme_api_key)
149 |
150 | return [{"question": question,
151 | "context": context,
152 | "tagged_context": tagged_context,
153 | "all_linked_para_title_dic": hyper_linked_titles,
154 | "q_id": q_id}]
155 | else:
156 | return [{"question": question,
157 | "context": context,
158 | "all_linked_para_title_dic": hyper_linked_titles,
159 | "q_id": q_id}]
160 |
161 |
162 | def get_hyperlinked_abstract_paragraphs(self,
163 | title: str,
164 | question: str = None):
165 |
166 | if self.use_full_article is True and self.title2hyperlink_dic is not None:
167 | if title not in self.title2hyperlink_dic:
168 | return {}
169 | hyper_linked_titles = self.title2hyperlink_dic[title]
170 | elif self.use_full_article is True and self.title2hyperlink_dic is None:
171 | # for full article version, we need to store title2hyperlink_dic beforehand.
172 | raise NotImplementedError()
173 | else:
174 | hyper_linked_titles = self.db.get_hyper_linked(normalize(title))
175 |
176 | if hyper_linked_titles is None:
177 | return {}
178 | # if there are any hyperlinked titles, add the information to all_linked_paragraph
179 | all_linked_paras_dic = {}
180 |
181 | if self.use_full_article is True and self.title2hyperlink_dic is not None:
182 | for hyper_linked_para_title in hyper_linked_titles:
183 | paras_dict, _ = load_para_and_linked_titles_dict_from_tfidf_id(
184 | hyper_linked_para_title, self.db)
185 | # Sometimes article titles are updated over times but the hyperlinked titles are not. e.g., Winds <--> Wind
186 | # in our current database, we do not handle these "redirect" cases and thus we cannot recover.
187 | # If we cannot retrieve the hyperlinked articles, we just discard these articles.
188 | if len(paras_dict) == 0:
189 | continue
190 | tfidf_vectorizer = TopTfIdf(n_to_select=self.pruning_l,
191 | filter_dist_one=True, rank=True)
192 | pruned_para_dict = prune_top_k_paragraphs(
193 | question, paras_dict, tfidf_vectorizer, self.pruning_l)
194 |
195 | all_linked_paras_dic.update(pruned_para_dict)
196 |
197 | else:
198 | for hyper_linked_para_title in hyper_linked_titles:
199 | para_title_text_pairs = load_para_collections_from_tfidf_id_intro_only(
200 | hyper_linked_para_title, self.db)
201 | # Sometimes article titles are updated over times but the hyperlinked titles are not. e.g., Winds <--> Wind
202 | # in our current database, we do not handle these "redirect" cases and thus we cannot recover.
203 | # If we cannot retrieve the hyperlinked articles, we just discard these articles.
204 | if len(para_title_text_pairs) == 0:
205 | continue
206 |
207 | para_title_text_pairs = {para[0]: "".join(para[1])
208 | for para in para_title_text_pairs}
209 |
210 | all_linked_paras_dic.update(para_title_text_pairs)
211 |
212 | return all_linked_paras_dic
213 |
--------------------------------------------------------------------------------
/quick_start_hotpot.sh:
--------------------------------------------------------------------------------
1 | # download trained models
2 | mkdir models
3 | cd models
4 | gdown https://drive.google.com/uc?id=1ra37xtEXSROG_f90XxR4kgElGJWUHQyM
5 | unzip hotpot_models.zip
6 | rm hotpot_models.zip
7 | cd ..
8 |
9 | # download eval data
10 | mkdir data
11 | cd data
12 | mkdir hotpot
13 | cd hotpot
14 | gdown https://drive.google.com/uc?id=1m_7ZJtWQsZ8qDqtItDTWYlsEHDeVHbPt
15 | gdown https://drive.google.com/uc?id=1D-Uj4DPMZWkSouzw5Gg5YhkGiBHSqCJp
16 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json
17 | cd ../..
18 |
19 | # run evaluation scripts
20 | python eval_main.py \
21 | --eval_file_path data/hotpot/hotpot_fullwiki_first_100.jsonl \
22 | --eval_file_path_sp data/hotpot/hotpot_dev_distractor_v1.json \
23 | --graph_retriever_path models/hotpot_models/graph_retriever_path/pytorch_model.bin \
24 | --reader_path models/hotpot_models/reader \
25 | --sequential_sentence_selector_path models/hotpot_models/sequential_sentence_selector/pytorch_model.bin \
26 | --tfidf_path models/hotpot_models/tfidf_retriever/wiki_open_full_new_db_intro_only-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz \
27 | --db_path models/hotpot_models/wiki_db/wiki_abst_only_hotpotqa_w_original_title.db \
28 | --bert_model_sequential_sentence_selector bert-large-uncased --do_lower_case \
29 | --tfidf_limit 500 --eval_batch_size 4 --pruning_by_links --beam_graph_retriever 8 \
30 | --beam_sequential_sentence_selector 8 --max_para_num 2000 --sp_eval --sampled
31 |
--------------------------------------------------------------------------------
/reader/README.md:
--------------------------------------------------------------------------------
1 | ## Reasoning Path Reader
2 | This directory includes the code for our reasoning path reader model described in Section 3.2 of our paper.
3 | Our reader model is based on the BERT QA model ([Devlin et al. 2019](https://arxiv.org/abs/1810.04805)), and we extend it to jointly predict answer spans and the plausibility of reasoning paths selected by our retriever components.
4 |
5 | Table of contents:
6 | - 1. Training
7 | - 2. Evaluation
8 |
9 | ## 1. Training
10 | ### Training data
11 | We use [rc_utils.py](rc_utils.py) to train our reasoning path reader models.
12 | To train our reader, we first convert the original MRC datasets into the SQuAD (v2.0) data format, adding distant examples and negative examples (a minimal sketch of the format is shown below).
13 |
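For reference, here is a minimal record in the standard SQuAD v2.0-style layout (illustrative only; the preprocessed files linked below follow this general structure but may contain additional fields):

```python
# Minimal SQuAD v2.0-style record; negative examples would set "is_impossible" to True
# with an empty "answers" list.
example = {
    "version": "v2.0",
    "data": [{
        "title": "Example article",
        "paragraphs": [{
            "context": "Example article is a short paragraph used for illustration.",
            "qas": [{
                "id": "example-question-0",
                "question": "What is Example article used for?",
                "answers": [{"text": "illustration", "answer_start": 46}],
                "is_impossible": False,
            }],
        }],
    }],
}
```
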
14 | We provide the pre-processed train and dev data files for all three datasets here (google drive):
15 |
16 | - [HotpotQA reader train data](https://drive.google.com/file/d/1BZXSZXN99Mb7--4u0x58cixBTon1PX8N/view?usp=sharing)
17 | - [SQuAD reader train data](https://drive.google.com/file/d/1aMTXIxYZCAC6sX5mZt6nytYxeKvjuigq/view?usp=sharing)
18 | - [Natural Questions train data](https://drive.google.com/file/d/1wUlRkC3_yJnEzdxduFE__yQSfWa_3l0j/view?usp=sharing)
19 |
20 |
21 | We explain some of the required arguments below.
22 |
23 | - `--bert_model`
24 | This is a bert model type (e.g., `bert-base-uncased`). In our paper, we experiment both with `bert-base-uncased` and `bert-large-uncased-whole-word-masking`.
25 |
26 | - `--output_dir`
27 | This is a directory path to save model checkpoints; a checkpoint is saved every half epoch during training.
28 |
29 | - `--train_file`
30 | This is a file path to train data you can download from the link mentioned above.
31 |
32 | - `--version_2_with_negative`
33 | Please add this option to train our reader model with negative examples.
34 |
35 | - `--do_lower_case`
36 | We use lower-cased version of BERT following previous papers in machine reading comprehension. To reproduce the results, please add this option.
37 |
38 | There are some optional arguments; please see the full list from our [rc_utils.py](rc_utils.py).
39 |
40 | - `--train_batch_size`
41 | This specifies the batch size during training (default=`32`).
42 | *To train BERT-large QA models, you will likely need to reduce the batch size (currently set to 32) to fit your GPU memory.*
43 |
44 | - `--max_seq_length`
45 | This sets the maximum input sequence length; when the input exceeds this limit, we split the data into several windows.
46 |
47 | - `--predict_file`
48 | This is a file path to your inference data if you would like to evaluate the reader performance (see the details below). Your `predict_file` must be in the SQuAD v2.0 format, like `train_file`.
49 |
50 | You can run training with the command below.
51 |
52 | ```bash
53 | python run_reader_confidence.py \
54 | --bert_model bert-base-uncased \
55 | --output_dir /path/to/your/output/dir \
56 | --train_file /path/to/your/train/file \
57 | --predict_file /path/to/your/eval/file \
58 | --max_seq_length 384 \
59 | --do_train \
60 | --do_predict \
61 | --do_lower_case \
62 | --version_2_with_negative
63 | ```
64 |
65 | e.g., HotpotQA
66 |
67 | ```bash
68 | python run_reader_confidence.py \
69 | --bert_model bert-base-uncased \
70 | --output_dir output_hotpot_bert_base \
71 | --train_file data/hotpot/hotpot_reader_train_data.json \
72 | --predict_file data/hotpot/hotpot_dev_squad_v2.0_format.json \
73 | --max_seq_length 384 \
74 | --do_train \
75 | --do_predict \
76 | --do_lower_case \
77 | --version_2_with_negative
78 | ```
79 |
80 | ## 2. Evaluation
81 | As the main goal of this work is improving open-domain QA performance, we recommend running the full pipeline to evaluate your reader performance.
82 | Alternatively, you can run a sanity check in the HotpotQA gold-paragraph-only setting.
83 |
84 | #### Sanity check on the HotpotQA gold-paragraph-only setting
85 | For the sanity check, you can run the evaluation of the reader model on the preprocessed dev file.
86 |
87 | Each original HotpotQA question comes with 10 paragraphs; we discard the 8 distractor paragraphs and keep only the gold paragraphs. The preprocessed data is also available [here](https://drive.google.com/open?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB).
88 |
89 | You can download [the SQuAD 2.0 evaluation script](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/).
90 |
91 | **Note: the F1 calculation is slightly different from the original HotpotQA eval script. We use the SQuAD 2.0 evaluation script only for a quick sanity check. Please do not use these numbers to report performance on HotpotQA.**
92 |
93 | You can run evaluation with the command below:
94 |
95 | ```bash
96 | python evaluate-v2.0.py \
97 | /path/to/eval/file/hotpot_dev_squad_v2.0_format.json \
98 | /path/to/your/output/dir/predictions.json
99 | ```
100 | The F1/EM scores of the bert-base-uncased model on the gold-paragraph-only HotpotQA distractor dev data are as follows:
101 |
102 | ```py
103 | {
104 | "exact": 60.60769750168805,
105 | "f1": 74.45707974099558,
106 | "total": 7405,
107 | "HasAns_exact": 60.60769750168805,
108 | "HasAns_f1": 74.45707974099558,
109 | "HasAns_total": 7405
110 | }
111 | ```
112 |
--------------------------------------------------------------------------------
/reader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/reader/__init__.py
--------------------------------------------------------------------------------
/reader/modeling_reader.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert.modeling import BertModel, BertPreTrainedModel
2 | from torch import nn
3 | from torch.nn import CrossEntropyLoss
4 | import torch
5 |
6 |
7 | class BERTLayerNorm(nn.Module):
8 | def __init__(self, config, variance_epsilon=1e-12):
9 | """Construct a layernorm module in the TF style (epsilon inside the square root).
10 | """
11 | super(BERTLayerNorm, self).__init__()
12 | self.gamma = nn.Parameter(torch.ones(config.hidden_size))
13 | self.beta = nn.Parameter(torch.zeros(config.hidden_size))
14 | self.variance_epsilon = variance_epsilon
15 |
16 | def forward(self, x):
17 | u = x.mean(-1, keepdim=True)
18 | s = (x - u).pow(2).mean(-1, keepdim=True)
19 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
20 | return self.gamma * x + self.beta
21 |
22 | class BertForQuestionAnsweringConfidence(BertPreTrainedModel):
23 |
24 | def __init__(self, config, num_labels, no_masking, lambda_scale=1.0):
25 | super(BertForQuestionAnsweringConfidence, self).__init__(config)
26 | self.bert = BertModel(config)
27 | self.num_labels = num_labels
28 | self.no_masking = no_masking
29 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
30 | self.qa_outputs = nn.Linear(
31 | config.hidden_size, 2) # [N, L, H] => [N, L, 2]
32 | self.qa_classifier = nn.Linear(
33 | config.hidden_size, self.num_labels) # [N, H] => [N, n_class]
34 | self.lambda_scale = lambda_scale
35 |
36 | def init_weights(module):
37 | if isinstance(module, (nn.Linear, nn.Embedding)):
38 | module.weight.data.normal_(
39 | mean=0.0, std=config.initializer_range)
40 | elif isinstance(module, BERTLayerNorm):
41 | module.beta.data.normal_(
42 | mean=0.0, std=config.initializer_range)
43 | module.gamma.data.normal_(
44 | mean=0.0, std=config.initializer_range)
45 | if isinstance(module, nn.Linear):
46 | module.bias.data.zero_()
47 |
48 | self.apply(init_weights)
49 |
50 | def forward(self, input_ids, token_type_ids, attention_mask,
51 | start_positions=None, end_positions=None, switch_list=None):
52 | sequence_output, pooled_output = self.bert(
53 | input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
54 |
55 | # Calculate the sequence logits.
56 | logits = self.qa_outputs(sequence_output)
57 | start_logits, end_logits = logits.split(1, dim=-1)
58 | start_logits = start_logits.squeeze(-1)
59 | end_logits = end_logits.squeeze(-1)
60 | # Calculate the class logits
61 | pooled_output = self.dropout(pooled_output)
62 | switch_logits = self.qa_classifier(pooled_output)
63 |
64 | if start_positions is not None and end_positions is not None and switch_list is not None:
65 | # If we are on multi-GPU, split add a dimension
66 | if len(start_positions.size()) > 1:
67 | start_positions = start_positions.squeeze(-1)
68 | if len(end_positions.size()) > 1:
69 | end_positions = end_positions.squeeze(-1)
70 |
71 | ignored_index = start_logits.size(1)
72 | start_positions.clamp_(0, ignored_index)
73 | end_positions.clamp_(0, ignored_index)
74 | loss_fct = CrossEntropyLoss(
75 | ignore_index=ignored_index, reduce=False)
76 |
77 | # if no_masking is True, we do not mask the no-answer examples'
78 | # span losses.
79 | if self.no_masking is True:
80 | start_losses = loss_fct(start_logits, start_positions)
81 | end_losses = loss_fct(end_logits, end_positions)
82 |
83 | else:
84 | # You care about the span only when switch is 0
85 | span_mask = (switch_list == 0).type(torch.FloatTensor).cuda()
86 | start_losses = loss_fct(
87 | start_logits, start_positions) * span_mask
88 | end_losses = loss_fct(end_logits, end_positions) * span_mask
89 |
90 | switch_losses = loss_fct(switch_logits, switch_list)
91 | assert len(start_losses) == len(
92 | end_losses) == len(switch_losses)
93 | return self.lambda_scale * (start_losses + end_losses) + switch_losses
94 |
95 | elif start_positions is None or end_positions is None or switch_list is None:
96 | return start_logits, end_logits, switch_logits
97 |
98 | else:
99 | raise NotImplementedError()
100 |
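# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). `num_labels=4` is
# an assumption here (answer-span / yes / no / no-answer style classes); see
# run_reader_confidence.py for the values actually used in training.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from pytorch_pretrained_bert.modeling import BertConfig

    # Randomly initialized config/model, just to show the forward interface.
    config = BertConfig(vocab_size_or_config_json_file=30522)
    model = BertForQuestionAnsweringConfidence(
        config, num_labels=4, no_masking=False)

    input_ids = torch.randint(0, 30522, (2, 64))
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.ones_like(input_ids)

    # Without gold positions/switch labels, forward returns the three logit tensors.
    start_logits, end_logits, switch_logits = model(
        input_ids, token_type_ids, attention_mask)
    print(start_logits.shape, end_logits.shape, switch_logits.shape)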
--------------------------------------------------------------------------------
/reader/modeling_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 |
5 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
6 | BertModel, BertTokenizer)
7 |
8 | MODEL_CLASSES = {
9 | 'bert': (BertConfig, BertModel, BertTokenizer),
10 | }
11 |
12 | def get_bert_model_from_pytorch_transformers(model_name):
13 | config_class, model_class, tokenizer_class = MODEL_CLASSES['bert']
14 | config = config_class.from_pretrained(model_name)
15 | model = model_class.from_pretrained(model_name, from_tf=bool('.ckpt' in model_name), config=config)
16 |
17 | tokenizer = tokenizer_class.from_pretrained(model_name)
18 |
19 | vocab_file_name = './vocabulary_'+model_name+'.txt'
20 |
21 | if not os.path.exists(vocab_file_name):
22 | index = 0
23 | with open(vocab_file_name, "w", encoding="utf-8") as writer:
24 | for token, token_index in sorted(tokenizer.vocab.items(), key=lambda kv: kv[1]):
25 | if index != token_index:
26 | assert False
27 | index = token_index
28 | writer.write(token + u'\n')
29 | index += 1
30 |
31 | return model.state_dict(), vocab_file_name
32 |
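if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original file): download the
    # `bert-base-uncased` weights via pytorch-transformers and dump its vocabulary
    # to ./vocabulary_bert-base-uncased.txt, so that they can be re-used by the
    # pytorch-pretrained-bert components elsewhere in this repository.
    state_dict, vocab_file = get_bert_model_from_pytorch_transformers("bert-base-uncased")
    print(vocab_file, len(state_dict))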
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gdown
2 | glob2
3 | Jinja2
4 | joblib
5 | jsonlines
6 | nltk
7 | numpy
8 | pathlib2
9 | prettytable
10 | pytorch-pretrained-bert==0.6.2
11 | pytorch-transformers==1.2.0
12 | regex
13 | scipy
14 | scikit-learn
15 | spacy==2.2.3
16 | tagme==0.1.3
17 | torch==1.3.0
18 | tqdm
19 | ujson
20 | unicodecsv
21 | urllib3
--------------------------------------------------------------------------------
/retriever/README.md:
--------------------------------------------------------------------------------
1 | # TF-IDF based retriever
2 | For efficiency and scalability, we bootstrap the retrieval process with a TF-IDF based retriever. In particular, we first sample seed paragraphs based on TF-IDF scores and sequentially retrieve the sub-graphs connected to the seed paragraphs.
3 |
4 | Table of contents:
5 | - 1. Preprocessing Wikipedia dump
6 | - 2. Building database
7 | - 3. Building the TF-IDF N-grams
8 | - 4. Interactive mode
9 |
10 | *Acknowledgement: The code base started from the amazing [DrQA](https://github.com/facebookresearch/DrQA) document retriever code. We really appreciate the efforts of the DrQA authors.*
11 |
12 | ## 1. Preprocessing Wikipedia dump
13 |
14 | #### Download Wikipedia dumps
15 | First, you need to download the Wikipedia dump.
16 |
17 | - **HotpotQA**: You do not need to download the Wikipedia dump yourself, as the authors provide a preprocessed dump on the [HotpotQA official website](https://hotpotqa.github.io/wiki-readme.html). If you plan to use our model for HotpotQA only, we recommend downloading the [intro-paragraph only version](https://nlp.stanford.edu/projects/hotpotqa/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2). You can extract the files following the instructions on that website.
18 |
19 | - **SQuAD**: Although you can use the same dump as in HotpotQA, we recommend using the older dump and the DB distributed by the [DrQA repository](https://github.com/facebookresearch/DrQA/blob/master/download.sh). Wikipedia is frequently edited by users, and thus if you use a newer version, some answers (originally included in the context) are lost. Please refer to the details of this finding in **Appendix B.5** of our paper. Although the DrQA DB does not preserve hyperlink information, we do not observe a large performance drop on SQuAD without link-based hops, as the questions are mostly single-hop or inner-article multi-hop.
20 |
21 | - **Natural Questions**: For Natural Questions, for similar reasons to those mentioned above, we recommend using the [dump](https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-meta-current.xml.bz2) from December 2018, which is also used in previous work on NQ Open ([Lee et al., 2019](https://arxiv.org/abs/1906.00300); [Min et al., 2019](https://arxiv.org/abs/1909.04849)).
22 |
23 |
24 | #### Extract articles
25 | As mentioned above, you do not need to preprocess the dump yourself for HotpotQA or SQuAD. If you want to experiment on Natural Questions or another Wikipedia dump, you need to extract the articles yourself.
26 |
27 | 1. Install [wikiextractor](https://github.com/attardi/wikiextractor)
28 | 2. Run `WikiExtractor.py` with `--json` and `-l`. The first option makes the output an easy-to-read-and-process `jsonlines` format, and the second option preserves the hyperlinks, which is crucial for our framework (see the example command below).
29 |
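A minimal example command (paths are placeholders; recent wikiextractor releases expose the same options via `python -m wikiextractor.WikiExtractor`):

```bash
python WikiExtractor.py /path/to/enwiki-pages-articles.xml.bz2 \
    --json -l -o /path/to/extracted_wiki
```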
30 |
31 | ## 2. Building database
32 |
33 | To efficiently store and access our documents, we store them in a sqlite database.
34 | To create a sqlite db from a corpus of documents, run:
35 |
36 | ```bash
37 | python build_db.py /path/to/data /path/to/saved/db.db --hotpoqa_format
38 | ```
39 |
40 | **Note**
41 | Do not forget the `--hotpoqa_format` option when you process Wikipedia data for HotpotQA experiments. The HotpotQA authors kindly provide a preprocessed dump, and the titles and sentence separations need to stay consistent with it for supporting fact evaluation.
42 |
43 | For the introductory-paragraph-only Wikipedia, the total number of paragraphs stored in the DB should be 5,233,329.
44 |
45 | ```
46 | $ python build_db.py $PATH_TO_WIKI_DIR/enwiki-20171001-pages-meta-current-withlinks-abstracts enwiki_intro.db --intro_only
47 | 07/01/2019 11:31:51 PM: [ Reading into database... ]
48 | 100%|███████████████████████████████████████| 15517/15517 [01:47<00:00, 143.97it/s]
49 | 07/01/2019 11:33:39 PM: [ Read 5233329 docs. ]
50 | 07/01/2019 11:33:39 PM: [ Committing... ]
51 | ```
52 |
53 | If you create the DB from the 2018/12/20 dump, the total number of articles will be 5,771,730.
54 | ```
55 | $ python retriever/build_db.py $PATH_WIKI_DIR enwiki_20181220_all.db
56 | 02/01/2020 11:52:28 PM: [ Reading into database... ]
57 | 100%|██████████████████████████████████████| 16399/16399 [04:03<00:00, 67.28it/s]
58 | 02/01/2020 11:56:32 PM: [ Read 5771730 docs. ]
59 | 02/01/2020 11:56:32 PM: [ Committing... ]
60 | ```
61 |
62 | **Note: the total number of paragraphs would be 30M and the DB size would be 27 GB.**
63 |
64 | Optional arguments:
65 | ```
66 | --preprocess File path to a python module that defines a `preprocess` function.
67 | --num-workers Number of CPU processes (for tokenizing, etc).
68 | ```
69 |
70 | #### Keeping hyperlinks in `doc_text`
71 | Due to the nature of Wikipedia hyperlinks, a hyperlink connects a paragraph (the source paragraph) to an article (the target article), although if we only consider introductory paragraphs, the relations are effectively paragraph-to-paragraph.
72 |
73 | To efficiently store these relationships, for the multi-paragraph settings (e.g., Natural Questions Open), we keep the hyperlink information in the `doc_text` itself.
74 |
75 | e.g., Seattle
76 | ```
77 | Seattle is a seaport city on the West Coast of the United States.
78 | ```
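In the stored `doc_text`, hyperlinks appear as the `<a href="...">` anchors produced by wikiextractor (the Seattle snippet above is shown with the tags stripped). A minimal sketch, with illustrative anchor text, of how those anchors are parsed and removed using the helpers in [utils.py](utils.py) (run from the `retriever` directory):

```py
from utils import find_hyper_linked_titles, remove_tags

# Illustrative paragraph as stored for the multi-paragraph setting.
para = ('Seattle is a seaport city on the '
        '<a href="West%20Coast%20of%20the%20United%20States">West Coast</a> '
        'of the United States.')

print(find_hyper_linked_titles(para))  # ['West Coast of the United States']
print(remove_tags(para))               # plain text with the anchor tags removed
```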
79 |
80 | ## 3. Building the TF-IDF N-grams
81 |
82 | To build a TF-IDF weighted word-doc sparse matrix from the documents stored in the sqlite db, run:
83 |
84 | ```bash
85 | python build_tfidf.py /path/to/doc/db.db /path/to/output/dir
86 | ```
87 |
88 | e.g.,
89 | ```bash
90 | python build_tfidf.py enwiki_intro.db tfidf_results_from_enwiki_intro_only/
91 | ```
92 |
93 | The sparse matrix and its associated metadata will be saved to the output directory (i.e., `tfidf_results_from_enwiki_intro_only`) as `<db-name>-tfidf-ngram=<ngram>-hash=<hash-size>-tokenizer=<tokenizer>.npz`.
94 |
95 |
96 | Optional arguments:
97 | ```
98 | --ngram Use up to N-size n-grams (e.g. 2 = unigrams + bigrams). By default only ngrams without stopwords or punctuation are kept.
99 | --hash-size Number of buckets to use for hashing ngrams.
100 | --tokenizer String option specifying tokenizer type to use (e.g. 'corenlp').
101 | --num-workers Number of CPU processes (for tokenizing, etc).
102 | ```
103 |
104 | **Note: If you build the TF-IDF matrix from the full Wikipedia paragraphs, it consumes a very large amount of CPU memory, which your local machine might not accommodate. In that case, please use the DB & .npz files we distribute, or consider using a machine with more memory.**
105 |
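The resulting `.npz` file can also be loaded programmatically; a minimal sketch with placeholder paths (the same `TfidfDocRanker` class backs the interactive mode described in the next section):

```py
from tfidf_doc_ranker import TfidfDocRanker

# Placeholder path: point it at the .npz produced by build_tfidf.py.
ranker = TfidfDocRanker(
    tfidf_path='tfidf_results_from_enwiki_intro_only/'
               'enwiki_intro-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz')
doc_ids, doc_scores = ranker.closest_docs(
    'Who created the fictional household that includes Gomez and Morticia?', k=5)
for doc_id, score in zip(doc_ids, doc_scores):
    print(doc_id, score)
```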
106 |
107 | ## 4. Interactive mode
108 | You can play with the TF-IDF retriever in interactive mode :)
109 | If you set `with_content=True` in the `process` function, you can see the paragraph text as well as the title.
110 |
111 | ```bash
112 | python interactive.py --model /path/to/model \
113 | --db_save_path /path/to/db/file
114 | ```
115 | e.g.,
116 |
117 | ```bash
118 | python interactive.py --model tfidf_wiki_abst/wiki_open_full_new_db_intro_only-tfidf-ngram\=2-hash\=16777216-tokenizer\=simple.npz \
119 | --db_save_path wiki_open_full_new_db_intro_only.db
120 | ```
121 | ```
122 | >>> process('At what university can the building that served as the fictional household that includes Gomez and Morticia be found?', k=1, with_content=True)
123 | +------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
124 | | Rank | Doc Id | Doc Text |
125 | +------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
126 | | 1 | The Addams Family_0 | The Addams Family is a fictional household created by American cartoonist Charles Addams. The Addams Family characters have traditionally included Gomez and Morticia Addams, their children Wednesday and Pugsley, close family members Uncle Fester and Grandmama, their butler Lurch, the disembodied hand Thing, and Gomez's Cousin Itt. |
127 | +------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
128 | ```
129 |
--------------------------------------------------------------------------------
/retriever/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkariAsai/learning_to_retrieve_reasoning_paths/a020d52cfbbb7d7fca9fa25361e549c85e81875c/retriever/__init__.py
--------------------------------------------------------------------------------
/retriever/build_db.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # The code is adapted from the DrQA (https://github.com/facebookresearch/DrQA) library.
3 | """A script to read in and store documents in a sqlite database."""
4 |
5 | import argparse
6 | import sqlite3
7 | import json
8 | import os
9 | import logging
10 | import importlib.util
11 | import glob
12 |
13 | from multiprocessing import Pool as ProcessPool
14 | from tqdm import tqdm
15 |
16 | try:
17 | from retriever.utils import normalize, process_jsonlines_hotpotqa, process_jsonlines
18 | except:
19 | from utils import normalize, process_jsonlines_hotpotqa, process_jsonlines
20 |
21 | logger = logging.getLogger()
22 | logger.setLevel(logging.INFO)
23 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
24 | console = logging.StreamHandler()
25 | console.setFormatter(fmt)
26 | logger.addHandler(console)
27 |
28 |
29 | # ------------------------------------------------------------------------------
30 | # Import helper
31 | # ------------------------------------------------------------------------------
32 |
33 |
34 | PREPROCESS_FN = None
35 |
36 |
37 | def init(filename):
38 | global PREPROCESS_FN
39 | if filename:
40 | PREPROCESS_FN = import_module(filename).preprocess
41 |
42 |
43 | def import_module(filename):
44 | """Import a module given a full path to the file."""
45 | spec = importlib.util.spec_from_file_location('doc_filter', filename)
46 | module = importlib.util.module_from_spec(spec)
47 | spec.loader.exec_module(module)
48 | return module
49 |
50 |
51 | # ------------------------------------------------------------------------------
52 | # Store corpus.
53 | # ------------------------------------------------------------------------------
54 |
55 |
56 | def iter_files(path):
57 | """Walk through all files located under a root path."""
58 | if os.path.isfile(path):
59 | yield path
60 | elif os.path.isdir(path):
61 | for dirpath, _, filenames in os.walk(path):
62 | for f in filenames:
63 | yield os.path.join(dirpath, f)
64 | else:
65 | raise RuntimeError('Path %s is invalid' % path)
66 |
67 |
68 | def get_contents_hotpotqa(filename):
69 | """Parse the contents of a file. Each line is a JSON encoded document."""
70 | global PREPROCESS_FN
71 | documents = []
72 | extracted_items = process_jsonlines_hotpotqa(filename)
73 | for extracted_item in extracted_items:
74 | title = extracted_item["title"]
75 | text = extracted_item["plain_text"]
76 | original_title = extracted_item["original_title"]
77 | hyper_linked_titles = extracted_item["hyper_linked_titles"]
78 |
79 | documents.append((title, text,
80 | hyper_linked_titles, original_title))
81 | return documents
82 |
83 | def get_contents(filename):
84 | """Parse the contents of a file. Each line is a JSON encoded document."""
85 | global PREPROCESS_FN
86 | documents = []
87 | extracted_items = process_jsonlines(filename)
88 | for extracted_item in extracted_items:
89 | title = extracted_item["title"]
90 | text = extracted_item["plain_text"]
91 | original_title = extracted_item["original_title"]
92 | hyper_linked_titles = extracted_item["hyper_linked_titles"]
93 |
94 | documents.append((title, text,
95 | hyper_linked_titles, original_title))
96 | return documents
97 |
98 | def store_contents(wiki_dir, save_path, preprocess, num_workers=None, hotpotqa_format=False):
99 | """Preprocess and store a corpus of documents in sqlite.
100 |
101 | Args:
102 | wiki_dir: Root path to directory (or directory of directories) of files
103 | containing json encoded documents (must have `id` and `text` fields).
104 | save_path: Path to output sqlite db.
105 | preprocess: Path to file defining a custom `preprocess` function. Takes
106 | in and outputs a structured doc.
107 | num_workers: Number of parallel processes to use when reading docs.
108 | """
109 | filenames = [f for f in glob.glob(
110 | wiki_dir + "/*/wiki_*", recursive=True) if ".bz2" not in f]
111 | if os.path.isfile(save_path):
112 | raise RuntimeError('%s already exists! Not overwriting.' % save_path)
113 |
114 | logger.info('Reading into database...')
115 | conn = sqlite3.connect(save_path)
116 | c = conn.cursor()
117 | c.execute(
118 | "CREATE TABLE documents (id PRIMARY KEY, text, linked_title, original_title);")
119 |
120 | workers = ProcessPool(num_workers, initializer=init,
121 | initargs=(preprocess,))
122 | count = 0
123 | # Due to the slight difference of input format between preprocessed HotpotQA wikipedia data and
124 | # the ones by Wikiextractor, we call different functions for data process.
125 | if hotpotqa_format is True:
126 | content_processing_method = get_contents_hotpotqa
127 | else:
128 | content_processing_method = get_contents
129 |
130 | with tqdm(total=len(filenames)) as pbar:
131 | for pairs in tqdm(workers.imap_unordered(content_processing_method, filenames)):
132 | count += len(pairs)
133 | c.executemany(
134 | "INSERT OR REPLACE INTO documents VALUES (?,?,?,?)", pairs)
135 | pbar.update()
136 | logger.info('Read %d docs.' % count)
137 | logger.info('Committing...')
138 | conn.commit()
139 | conn.close()
140 |
141 | # ------------------------------------------------------------------------------
142 | # Main.
143 | # ------------------------------------------------------------------------------
144 |
145 |
146 | if __name__ == '__main__':
147 | parser = argparse.ArgumentParser()
148 | parser.add_argument('wiki_dir', type=str, help='/path/to/data')
149 | parser.add_argument('save_path', type=str, help='/path/to/saved/db.db')
150 | parser.add_argument('--preprocess', type=str, default=None,
151 | help=('File path to a python module that defines '
152 | 'a `preprocess` function'))
153 | parser.add_argument('--num-workers', type=int, default=None,
154 | help='Number of CPU processes (for tokenizing, etc)')
155 | parser.add_argument('--hotpoqa_format', action='store_true',
156 | help='the input files are hotpotqa format.')
157 | args = parser.parse_args()
158 |
159 | store_contents(
160 | args.wiki_dir, args.save_path, args.preprocess, args.num_workers,
161 | args.hotpoqa_format
162 | )
163 |
--------------------------------------------------------------------------------
/retriever/build_tfidf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # The code is based on the DrQA (https://github.com/facebookresearch/DrQA) library.
3 | """A script to build the tf-idf document matrices for retrieval."""
4 |
5 | import numpy as np
6 | import scipy.sparse as sp
7 | import argparse
8 | import os
9 | import math
10 | import logging
11 |
12 | from multiprocessing import Pool as ProcessPool
13 | from multiprocessing.util import Finalize
14 | from functools import partial
15 | from collections import Counter
16 |
17 | try:
18 | from retriever.tfidf_doc_ranker import TfidfDocRanker
19 | from retriever.doc_db import DocDB
20 | from retriever.tokenizers import SimpleTokenizer
21 | from retriever.utils import filter_ngram, hash, save_sparse_csr
22 | except:
23 | from tfidf_doc_ranker import TfidfDocRanker
24 | from doc_db import DocDB
25 | from tokenizers import SimpleTokenizer
26 | from utils import filter_ngram, hash, save_sparse_csr
27 |
28 | logger = logging.getLogger()
29 | logger.setLevel(logging.INFO)
30 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
31 | console = logging.StreamHandler()
32 | console.setFormatter(fmt)
33 | logger.addHandler(console)
34 |
35 |
36 | # ------------------------------------------------------------------------------
37 | # Set retriever class
38 | # ------------------------------------------------------------------------------
39 |
40 |
41 | def get_class(name):
42 | if name == 'tfidf':
43 | return TfidfDocRanker
44 | if name == 'sqlite':
45 | return DocDB
46 | raise RuntimeError('Invalid retriever class: %s' % name)
47 |
48 | # ------------------------------------------------------------------------------
49 | # Multiprocessing functions
50 | # ------------------------------------------------------------------------------
51 |
52 |
53 | DOC2IDX = None
54 | PROCESS_TOK = None
55 | PROCESS_DB = None
56 |
57 |
58 | def init(tokenizer_class, db_class, db_opts):
59 | global PROCESS_TOK, PROCESS_DB
60 | PROCESS_TOK = tokenizer_class()
61 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
62 | PROCESS_DB = db_class(**db_opts)
63 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
64 |
65 |
66 | def fetch_text(doc_id):
67 | global PROCESS_DB
68 | return PROCESS_DB.get_doc_text(doc_id)
69 |
70 |
71 | def fetch_text_multi_para(doc_id):
72 | global PROCESS_DB
73 | paras, _ = PROCESS_DB.get_doc_text_hyper_linked_titles_for_articles(doc_id)
74 | if len(paras) > 0:
75 | paras = "\n\n".join(paras)
76 | return paras
77 |
78 |
79 | def tokenize(text):
80 | global PROCESS_TOK
81 | return PROCESS_TOK.tokenize(text)
82 |
83 |
84 | # ------------------------------------------------------------------------------
85 | # Build article --> word count sparse matrix.
86 | # ------------------------------------------------------------------------------
87 |
88 | def count(ngram, hash_size, multi_para, doc_id):
89 | """Fetch the text of a document and compute hashed ngrams counts."""
90 | global DOC2IDX
91 | # FIXME: remove hard coding.
92 | row, col, data = [], [], []
93 | # Tokenize
94 |
95 | if multi_para is True:
96 | # 1. if multi_para is true, the doc contains multiple paragraphs separated by \n\n and with links.
97 | tokens = tokenize(fetch_text_multi_para(doc_id))
98 | else:
99 | # 2. if not, only intro docs are retrieved and the sentences are separated by \t.
100 | # remove sentence separations ("\t") (only for HotpotQA).
101 | tokens = tokenize(fetch_text(doc_id).replace("\t", ""))
102 |
103 | # Get ngrams from tokens, with stopword/punctuation filtering.
104 | ngrams = tokens.ngrams(
105 | n=ngram, uncased=True, filter_fn=filter_ngram
106 | )
107 |
108 | # Hash ngrams and count occurences
109 | counts = Counter([hash(gram, hash_size)
110 | for gram in ngrams])
111 |
112 | # Return in sparse matrix data format.
113 | row.extend(counts.keys())
114 | col.extend([DOC2IDX[doc_id]] * len(counts))
115 | data.extend(counts.values())
116 | return row, col, data
117 |
118 |
119 | def get_count_matrix(args, db, db_opts):
120 | """Form a sparse word to document count matrix (inverted index).
121 |
122 | M[i, j] = # times word i appears in document j.
123 | """
124 | # Map doc_ids to indexes
125 | global DOC2IDX
126 | db_class = get_class(db)
127 | with db_class(**db_opts) as doc_db:
128 | doc_ids = doc_db.get_doc_ids()
129 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)}
130 |
131 | # Setup worker pool
132 | # TODO: Add tokenizer's choice.
133 | tok_class = SimpleTokenizer
134 | workers = ProcessPool(
135 | args.num_workers,
136 | initializer=init,
137 | initargs=(tok_class, db_class, db_opts)
138 | )
139 |
140 | # Compute the count matrix in steps (to keep in memory)
141 | logger.info('Mapping...')
142 | row, col, data = [], [], []
143 | step = max(int(len(doc_ids) / 10), 1)
144 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)]
145 | _count = partial(count, args.ngram, args.hash_size, args.multi_para)
146 | for i, batch in enumerate(batches):
147 | logger.info('-' * 25 + 'Batch %d/%d' %
148 | (i + 1, len(batches)) + '-' * 25)
149 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch):
150 | row.extend(b_row)
151 | col.extend(b_col)
152 | data.extend(b_data)
153 | workers.close()
154 | workers.join()
155 |
156 | logger.info('Creating sparse matrix...')
157 | count_matrix = sp.csr_matrix(
158 | (data, (row, col)), shape=(args.hash_size, len(doc_ids))
159 | )
160 | count_matrix.sum_duplicates()
161 | return count_matrix, (DOC2IDX, doc_ids)
162 |
163 |
164 | # ------------------------------------------------------------------------------
165 | # Transform count matrix to different forms.
166 | # ------------------------------------------------------------------------------
167 |
168 |
169 | def get_tfidf_matrix(cnts):
170 | """Convert the word count matrix into tfidf one.
171 |
172 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
173 | * tf = term frequency in document
174 | * N = number of documents
175 | * Nt = number of documents containing the term
176 | """
177 | Ns = get_doc_freqs(cnts)
178 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5))
179 | idfs[idfs < 0] = 0
180 | idfs = sp.diags(idfs, 0)
181 | tfs = cnts.log1p()
182 | tfidfs = idfs.dot(tfs)
183 | return tfidfs
184 |
185 |
186 | def get_doc_freqs(cnts):
187 | """Return word --> # of docs it appears in."""
188 | binary = (cnts > 0).astype(int)
189 | freqs = np.array(binary.sum(1)).squeeze()
190 | return freqs
191 |
192 |
193 | # ------------------------------------------------------------------------------
194 | # Main.
195 | # ------------------------------------------------------------------------------
196 |
197 |
198 | if __name__ == '__main__':
199 | parser = argparse.ArgumentParser()
200 | parser.add_argument('db_path', type=str, default=None,
201 | help='Path to sqlite db holding document texts')
202 | parser.add_argument('out_dir', type=str, default=None,
203 | help='Directory for saving output files')
204 | parser.add_argument('--ngram', type=int, default=2,
205 | help=('Use up to N-size n-grams '
206 | '(e.g. 2 = unigrams + bigrams)'))
207 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)),
208 | help='Number of buckets to use for hashing ngrams')
209 | parser.add_argument('--tokenizer', type=str, default='simple',
210 | help=("String option specifying tokenizer type to use "
211 | "(e.g. 'corenlp')"))
212 | parser.add_argument('--num-workers', type=int, default=None,
213 | help='Number of CPU processes (for tokenizing, etc)')
214 | parser.add_argument('--multi_para', action='store_true',
215 | help='set true if the db contains multiple paragraphs, not intro-paragraph only.')
216 | args = parser.parse_args()
217 |
218 | logging.info('Counting words...')
219 | count_matrix, doc_dict = get_count_matrix(
220 | args, 'sqlite', {'db_path': args.db_path}
221 | )
222 |
223 | logger.info('Making tfidf vectors...')
224 | tfidf = get_tfidf_matrix(count_matrix)
225 |
226 | logger.info('Getting word-doc frequencies...')
227 | freqs = get_doc_freqs(count_matrix)
228 |
229 | basename = os.path.splitext(os.path.basename(args.db_path))[0]
230 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=%s' %
231 | (args.ngram, args.hash_size, args.tokenizer))
232 |
233 | # check if output_dir exists; if not, create the output_dir.
234 | if not os.path.exists(args.out_dir):
235 | os.makedirs(args.out_dir)
236 | filename = os.path.join(args.out_dir, basename)
237 |
238 | logger.info('Saving to %s.npz' % filename)
239 | metadata = {
240 | 'doc_freqs': freqs,
241 | 'tokenizer': args.tokenizer,
242 | 'hash_size': args.hash_size,
243 | 'ngram': args.ngram,
244 | 'doc_dict': doc_dict
245 | }
246 | save_sparse_csr(filename, tfidf, metadata)
247 |
--------------------------------------------------------------------------------
/retriever/doc_db.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | import argparse
3 | import time
4 | try:
5 | from retriever.utils import find_hyper_linked_titles, remove_tags, normalize
6 | except:
7 | from utils import find_hyper_linked_titles, remove_tags, normalize
8 |
9 | class DocDB(object):
10 | """Sqlite backed document storage.
11 |
12 | Implements get_doc_text(doc_id).
13 | """
14 |
15 | def __init__(self, db_path=None):
16 | self.path = db_path
17 | self.connection = sqlite3.connect(self.path, check_same_thread=False)
18 |
19 | def __enter__(self):
20 | return self
21 |
22 | def __exit__(self, *args):
23 | self.close()
24 |
25 | def close(self):
26 | """Close the connection to the database."""
27 | self.connection.close()
28 |
29 | def get_doc_ids(self):
30 | """Fetch all ids of docs stored in the db."""
31 | cursor = self.connection.cursor()
32 | cursor.execute("SELECT id FROM documents")
33 | results = [r[0] for r in cursor.fetchall()]
34 | cursor.close()
35 | return results
36 |
37 | def get_doc_text(self, doc_id):
38 | """Fetch the raw text of the doc for 'doc_id'."""
39 | cursor = self.connection.cursor()
40 | cursor.execute(
41 | "SELECT text FROM documents WHERE id = ?",
42 | (doc_id,)
43 | )
44 | result = cursor.fetchone()
45 | cursor.close()
46 | return result if result is None else result[0]
47 |
48 | def get_hyper_linked(self, doc_id):
49 | """Fetch the hyper-linked titles of the doc for 'doc_id'."""
50 | cursor = self.connection.cursor()
51 | cursor.execute(
52 | "SELECT linked_title FROM documents WHERE id = ?",
53 | (doc_id,)
54 | )
55 | result = cursor.fetchone()
56 | cursor.close()
57 | return result if (result is None or len(result[0]) == 0) else [normalize(title) for title in result[0].split("\t")]
58 |
59 | def get_original_title(self, doc_id):
60 | """Fetch the original title name of the doc."""
61 | cursor = self.connection.cursor()
62 | cursor.execute(
63 | "SELECT original_title FROM documents WHERE id = ?",
64 | (doc_id,)
65 | )
66 | result = cursor.fetchone()
67 | cursor.close()
68 | return result if result is None else result[0]
69 |
70 | def get_doc_text_hyper_linked_titles_for_articles(self, doc_id):
71 | """
72 | fetch all of the paragraphs with their corresponding hyperlink titles.
73 | e.g.,
74 | >>> paras, links = db.get_doc_text_hyper_linked_titles_for_articles("Tokyo Imperial Palace_0")
75 | >>> paras[2]
76 | 'It is built on the site of the old Edo Castle. The total area including the gardens is . During the height of the 1980s Japanese property bubble, the palace grounds were valued by some to be more than the value of all of the real estate in the state of California.'
77 | >>> links[2]
78 | ['Edo Castle', 'Japanese asset price bubble', 'Real estate', 'California']
79 | """
80 | cursor = self.connection.cursor()
81 | cursor.execute(
82 | "SELECT text FROM documents WHERE id = ?",
83 | (doc_id,)
84 | )
85 | result = cursor.fetchone()
86 | cursor.close()
87 | if result is None:
88 | return [], []
89 | else:
90 | hyper_linked_paragraphs = result[0].split("\n\n")
91 | paragraphs, hyper_linked_titles = [], []
92 |
93 | for hyper_linked_paragraph in hyper_linked_paragraphs:
94 | paragraphs.append(remove_tags(hyper_linked_paragraph))
95 | hyper_linked_titles.append([normalize(title) for title in find_hyper_linked_titles(
96 | hyper_linked_paragraph)])
97 |
98 | return paragraphs, hyper_linked_titles
99 |
--------------------------------------------------------------------------------
/retriever/interactive.py:
--------------------------------------------------------------------------------
1 | """Interactive mode for the tfidf DrQA retriever module."""
2 |
3 | import argparse
4 | import code
5 | import prettytable
6 | import logging
7 | try:
8 | from retriever.doc_db import DocDB
9 | from retriever.tfidf_doc_ranker import TfidfDocRanker
10 | except:
11 | from doc_db import DocDB
12 | from tfidf_doc_ranker import TfidfDocRanker
13 |
14 | logger = logging.getLogger()
15 | logger.setLevel(logging.INFO)
16 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
17 | console = logging.StreamHandler()
18 | console.setFormatter(fmt)
19 | logger.addHandler(console)
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--model', type=str, default=None)
23 | parser.add_argument('--db_save_path', type=str, default=None)
24 | args = parser.parse_args()
25 | db = DocDB(args.db_save_path)
26 |
27 | logger.info('Initializing ranker...')
28 | ranker = TfidfDocRanker(tfidf_path=args.model)
29 |
30 |
31 | # ------------------------------------------------------------------------------
32 | # Drop in to interactive
33 | # ------------------------------------------------------------------------------
34 |
35 |
36 | def process(query, k=1, with_content=False):
37 | doc_names, doc_scores = ranker.closest_docs(query, k)
38 | if with_content is True:
39 | doc_text = []
40 | for doc_name in doc_names:
41 | doc_text.append(db.get_doc_text(doc_name))
42 | table = prettytable.PrettyTable(
43 | ['Rank', 'Doc Id', 'Doc Text']
44 | )
45 |
46 | for i in range(len(doc_names)):
47 | table.add_row([i + 1, doc_names[i], doc_text[i]])
48 | print(table)
49 |
50 | else:
51 | table = prettytable.PrettyTable(
52 | ['Rank', 'Doc Id', 'Doc Score']
53 | )
54 | for i in range(len(doc_names)):
55 | table.add_row([i + 1, doc_names[i], '%.5g' % doc_scores[i]])
56 | print(table)
57 |
58 |
59 | banner = """
60 | Interactive TF-IDF DrQA Retriever
61 | >> process(question, k=1)
62 | >> usage()
63 | """
64 |
65 |
66 | def usage():
67 | print(banner)
68 |
69 |
70 | code.interact(banner=banner, local=locals())
71 |
--------------------------------------------------------------------------------
/retriever/tfidf_doc_ranker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Rank documents with TF-IDF scores"""
8 |
9 | import logging
10 | import numpy as np
11 | import scipy.sparse as sp
12 |
13 | from multiprocessing.pool import ThreadPool
14 | from functools import partial
15 |
16 | try:
17 | from retriever.tokenizers import SimpleTokenizer
18 | from retriever.utils import load_sparse_csr, filter_ngram, hash, normalize
19 | except:
20 | from tokenizers import SimpleTokenizer
21 | from utils import load_sparse_csr, filter_ngram, hash, normalize
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class TfidfDocRanker(object):
27 | """Loads a pre-weighted inverted index of token/document terms.
28 | Scores new queries by taking sparse dot products.
29 | """
30 |
31 | def __init__(self, tfidf_path=None, strict=True):
32 | """
33 | Args:
34 | tfidf_path: path to saved model file
35 | strict: fail on empty queries or continue (and return empty result)
36 | """
37 | # Load from disk
38 | tfidf_path = tfidf_path
39 | logger.info('Loading %s' % tfidf_path)
40 | matrix, metadata = load_sparse_csr(tfidf_path)
41 | self.doc_mat = matrix
42 | self.ngrams = metadata['ngram']
43 | self.hash_size = metadata['hash_size']
44 | self.tokenizer = SimpleTokenizer()
45 | self.doc_freqs = metadata['doc_freqs'].squeeze()
46 | self.doc_dict = metadata['doc_dict']
47 | self.num_docs = len(self.doc_dict[0])
48 | self.strict = strict
49 |
50 | def get_doc_index(self, doc_id):
51 | """Convert doc_id --> doc_index"""
52 | return self.doc_dict[0][doc_id]
53 |
54 | def get_doc_id(self, doc_index):
55 | """Convert doc_index --> doc_id"""
56 | return self.doc_dict[1][doc_index]
57 |
58 | def closest_docs(self, query, k=1):
59 | """Closest docs by dot product between query and documents
60 | in tfidf weighted word vector space.
61 | """
62 | spvec = self.text2spvec(query)
63 | res = spvec * self.doc_mat
64 |
65 | if len(res.data) <= k:
66 | o_sort = np.argsort(-res.data)
67 | else:
68 | o = np.argpartition(-res.data, k)[0:k]
69 | o_sort = o[np.argsort(-res.data[o])]
70 |
71 | doc_scores = res.data[o_sort]
72 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
73 | return doc_ids, doc_scores
74 |
75 | def batch_closest_docs(self, queries, k=1, num_workers=None):
76 | """Process a batch of closest_docs requests multithreaded.
77 | Note: we can use plain threads here as scipy is outside of the GIL.
78 | """
79 | with ThreadPool(num_workers) as threads:
80 | closest_docs = partial(self.closest_docs, k=k)
81 | results = threads.map(closest_docs, queries)
82 | return results
83 |
84 | def parse(self, query):
85 | """Parse the query into tokens (either ngrams or tokens)."""
86 | tokens = self.tokenizer.tokenize(query)
87 | return tokens.ngrams(n=self.ngrams, uncased=True,
88 | filter_fn=filter_ngram)
89 |
90 | def text2spvec(self, query):
91 | """Create a sparse tfidf-weighted word vector from query.
92 |
93 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
94 | """
95 | # Get hashed ngrams
96 | # TODO: do we need to have normalize?
97 | words = self.parse(normalize(query))
98 | wids = [hash(w, self.hash_size) for w in words]
99 |
100 | if len(wids) == 0:
101 | if self.strict:
102 | raise RuntimeError('No valid word in: %s' % query)
103 | else:
104 | logger.warning('No valid word in: %s' % query)
105 | return sp.csr_matrix((1, self.hash_size))
106 |
107 | # Count TF
108 | wids_unique, wids_counts = np.unique(wids, return_counts=True)
109 | tfs = np.log1p(wids_counts)
110 |
111 | # Count IDF
112 | Ns = self.doc_freqs[wids_unique]
113 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
114 | idfs[idfs < 0] = 0
115 |
116 | # TF-IDF
117 | data = np.multiply(tfs, idfs)
118 |
119 | # One row, sparse csr matrix
120 | indptr = np.array([0, len(wids_unique)])
121 | spvec = sp.csr_matrix(
122 | (data, wids_unique, indptr), shape=(1, self.hash_size)
123 | )
124 |
125 | return spvec
126 |
--------------------------------------------------------------------------------
/retriever/tfidf_vectorizer_article.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.feature_extraction.text import TfidfVectorizer
3 | from sklearn.metrics import pairwise_distances
4 | from sklearn.metrics.pairwise import cosine_similarity
5 |
6 |
7 | class TopTfIdf():
8 | def __init__(self, n_to_select: int, filter_dist_one: bool = False, rank=True):
9 | self.rank = rank
10 | self.n_to_select = n_to_select
11 | self.filter_dist_one = filter_dist_one
12 |
13 | def prune(self, question, paragraphs, return_scores=False):
14 | if not self.filter_dist_one and len(paragraphs) == 1:
15 | return paragraphs
16 |
17 | tfidf = TfidfVectorizer(strip_accents="unicode",
18 | stop_words="english")
19 | text = []
20 | for para in paragraphs:
21 | text.append(para)
22 | try:
23 | para_features = tfidf.fit_transform(text)
24 | except ValueError:
25 | return []
26 | # question should be tokenized beforehand
27 | q_features = tfidf.transform([question])
28 | dists = cosine_similarity(q_features, para_features).ravel()
29 | # in case of ties, use the earlier paragraph
30 | sorted_ix = np.argsort(dists)[::-1]
31 | if return_scores is True:
32 | return sorted_ix, dists
33 | else:
34 | return sorted_ix
35 |
36 | def dists(self, question, paragraphs):
37 | tfidf = TfidfVectorizer(strip_accents="unicode",
38 | stop_words="english")
39 | text = []
40 | for para in paragraphs:
41 | text.append(" ".join(" ".join(s) for s in para.text))
42 | try:
43 | para_features = tfidf.fit_transform(text)
44 | q_features = tfidf.transform([" ".join(question)])
45 | except ValueError:
46 | return []
47 |
48 | dists = pairwise_distances(q_features, para_features, "cosine").ravel()
49 | # in case of ties, use the earlier paragraph
50 | sorted_ix = np.lexsort(([x for x in range(len(paragraphs))], dists))
51 |
52 | if self.filter_dist_one:
53 | return [(paragraphs[i], dists[i]) for i in sorted_ix[:self.n_to_select] if dists[i] < 1.0]
54 | else:
55 | return [(paragraphs[i], dists[i]) for i in sorted_ix[:self.n_to_select]]
56 |
--------------------------------------------------------------------------------
/retriever/tokenizers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Basic tokenizer that splits text into alpha-numeric tokens and
8 | non-whitespace tokens.
9 | """
10 |
11 | import regex
12 | import logging
13 | import copy
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class Tokens(object):
19 | """A class to represent a list of tokenized text."""
20 | TEXT = 0
21 | TEXT_WS = 1
22 | SPAN = 2
23 | POS = 3
24 | LEMMA = 4
25 | NER = 5
26 |
27 | def __init__(self, data, annotators, opts=None):
28 | self.data = data
29 | self.annotators = annotators
30 | self.opts = opts or {}
31 |
32 | def __len__(self):
33 | """The number of tokens."""
34 | return len(self.data)
35 |
36 | def slice(self, i=None, j=None):
37 | """Return a view of the list of tokens from [i, j)."""
38 | new_tokens = copy.copy(self)
39 | new_tokens.data = self.data[i: j]
40 | return new_tokens
41 |
42 | def untokenize(self):
43 | """Returns the original text (with whitespace reinserted)."""
44 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
45 |
46 | def words(self, uncased=False):
47 | """Returns a list of the text of each token
48 | Args:
49 | uncased: lower cases text
50 | """
51 | if uncased:
52 | return [t[self.TEXT].lower() for t in self.data]
53 | else:
54 | return [t[self.TEXT] for t in self.data]
55 |
56 | def offsets(self):
57 | """Returns a list of [start, end) character offsets of each token."""
58 | return [t[self.SPAN] for t in self.data]
59 |
60 | def pos(self):
61 | """Returns a list of part-of-speech tags of each token.
62 | Returns None if this annotation was not included.
63 | """
64 | if 'pos' not in self.annotators:
65 | return None
66 | return [t[self.POS] for t in self.data]
67 |
68 | def lemmas(self):
69 | """Returns a list of the lemmatized text of each token.
70 | Returns None if this annotation was not included.
71 | """
72 | if 'lemma' not in self.annotators:
73 | return None
74 | return [t[self.LEMMA] for t in self.data]
75 |
76 | def entities(self):
77 | """Returns a list of named-entity-recognition tags of each token.
78 | Returns None if this annotation was not included.
79 | """
80 | if 'ner' not in self.annotators:
81 | return None
82 | return [t[self.NER] for t in self.data]
83 |
84 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
85 | """Returns a list of all ngrams from length 1 to n.
86 | Args:
87 | n: upper limit of ngram length
88 | uncased: lower cases text
89 | filter_fn: user function that takes in an ngram list and returns
90 | True or False to keep or not keep the ngram
91 | as_strings: return each ngram as a string vs a list
92 | """
93 | def _skip(gram):
94 | if not filter_fn:
95 | return False
96 | return filter_fn(gram)
97 |
98 | words = self.words(uncased)
99 | ngrams = [(s, e + 1)
100 | for s in range(len(words))
101 | for e in range(s, min(s + n, len(words)))
102 | if not _skip(words[s:e + 1])]
103 |
104 | # Concatenate into strings
105 | if as_strings:
106 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
107 |
108 | return ngrams
109 |
110 | def entity_groups(self):
111 | """Group consecutive entity tokens with the same NER tag."""
112 | entities = self.entities()
113 | if not entities:
114 | return None
115 | non_ent = self.opts.get('non_ent', 'O')
116 | groups = []
117 | idx = 0
118 | while idx < len(entities):
119 | ner_tag = entities[idx]
120 | # Check for entity tag
121 | if ner_tag != non_ent:
122 | # Chomp the sequence
123 | start = idx
124 | while (idx < len(entities) and entities[idx] == ner_tag):
125 | idx += 1
126 | groups.append((self.slice(start, idx).untokenize(), ner_tag))
127 | else:
128 | idx += 1
129 | return groups
130 |
131 |
132 | class Tokenizer(object):
133 | """Base tokenizer class.
134 | Tokenizers implement tokenize, which should return a Tokens class.
135 | """
136 |
137 | def tokenize(self, text):
138 | raise NotImplementedError
139 |
140 | def shutdown(self):
141 | pass
142 |
143 | def __del__(self):
144 | self.shutdown()
145 |
146 |
147 | class SimpleTokenizer(Tokenizer):
148 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
149 | NON_WS = r'[^\p{Z}\p{C}]'
150 |
151 | def __init__(self, **kwargs):
152 | """
153 | Args:
154 | annotators: None or empty set (only tokenizes).
155 | """
156 | self._regexp = regex.compile(
157 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
158 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
159 | )
160 | if len(kwargs.get('annotators', {})) > 0:
161 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
162 | (type(self).__name__, kwargs.get('annotators')))
163 | self.annotators = set()
164 |
165 | def tokenize(self, text):
166 | data = []
167 | matches = [m for m in self._regexp.finditer(text)]
168 | for i in range(len(matches)):
169 | # Get text
170 | token = matches[i].group()
171 |
172 | # Get whitespace
173 | span = matches[i].span()
174 | start_ws = span[0]
175 | if i + 1 < len(matches):
176 | end_ws = matches[i + 1].span()[0]
177 | else:
178 | end_ws = span[1]
179 |
180 | # Format data
181 | data.append((
182 | token,
183 | text[start_ws: end_ws],
184 | span,
185 | ))
186 | return Tokens(data, self.annotators)
187 |
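if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original file): tokenize a
    # sentence and extract uni/bi-grams. At TF-IDF build time, the ngrams call
    # additionally receives `filter_fn=filter_ngram` from retriever/utils.py to
    # drop stopword/punctuation ngrams.
    tokens = SimpleTokenizer().tokenize('The Addams Family is a fictional household.')
    print(tokens.words(uncased=True))
    print(tokens.ngrams(n=2, uncased=True))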
--------------------------------------------------------------------------------
/retriever/utils.py:
--------------------------------------------------------------------------------
1 | import unicodedata
2 | import jsonlines
3 | import re
4 | from urllib.parse import unquote
5 | import regex
6 | import numpy as np
7 | import scipy.sparse as sp
8 | from sklearn.utils import murmurhash3_32
9 |
10 | import logging
11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
12 | datefmt='%m/%d/%Y %H:%M:%S',
13 | level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def normalize(text):
18 | """Resolve different type of unicode encodings / capitarization in HotpotQA data."""
19 | text = unicodedata.normalize('NFD', text)
20 | return text[0].capitalize() + text[1:]
21 |
22 |
23 | def make_wiki_id(title, para_index):
24 | title_id = "{0}_{1}".format(normalize(title), para_index)
25 | return title_id
26 |
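# Note (illustrative): make_wiki_id("The Addams Family", 0) -> "The Addams Family_0".
# This "<normalized title>_<paragraph index>" string is the document id used as the
# key in the sqlite DB and the TF-IDF index (see the "_0" handling further below).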
27 |
28 | def find_hyper_linked_titles(text_w_links):
29 | titles = re.findall(r'href=[\'"]?([^\'" >]+)', text_w_links)
30 | titles = [unquote(title) for title in titles]
31 | titles = [title[0].capitalize() + title[1:] for title in titles]
32 | return titles
33 |
34 |
35 | TAG_RE = re.compile(r'<[^>]+>')
36 |
37 |
38 | def remove_tags(text):
39 | return TAG_RE.sub('', text)
40 |
41 |
42 | def process_jsonlines(filename):
43 | """
44 | This is process_jsonlines method for extracted Wikipedia file.
45 | After extracting items by using Wikiextractor (with `--json` and `--links` options),
46 | you will get the files named with wiki_xx, where each line contains the information of each article.
47 | e.g.,
48 | {"id": "316", "url": "https://en.wikipedia.org/wiki?curid=316", "title": "Academy Award for Best Production Design",
49 | "text": "Academy Award for Best Production Design\n\nThe Academy Award for
50 | Best Production Design recognizes achievement for art direction \n\n"}
51 | This function takes these input and extract items.
52 | Each article contains one or more paragraphs, and the paragraphs are separated by \n\n.
53 | """
54 | # item should be nested list
55 | extracted_items = []
56 | with jsonlines.open(filename) as reader:
57 | for obj in reader:
58 | wiki_id = obj["id"]
59 | title = obj["title"]
60 | title_id = make_wiki_id(title, 0)
61 | text_with_links = obj["text"]
62 |
63 | hyper_linked_titles_text = ""
64 | # When we consider the whole article as a document unit (e.g., SQuAD Open, Natural Questions Open)
65 | # we'll keep the links with the original articles, and dynamically process and extract the links
66 | # when we process with our selector.
67 | extracted_items.append({"wiki_id": wiki_id, "title": title_id,
68 | "plain_text": text_with_links,
69 | "hyper_linked_titles": hyper_linked_titles_text,
70 | "original_title": title})
71 |
72 | return extracted_items
73 |
74 | def process_jsonlines_hotpotqa(filename):
75 | """
76 | This is process_jsonlines method for intro-only processed_wikipedia file.
77 | The item example:
78 | {"id": "45668011", "url": "https://en.wikipedia.org/wiki?curid=45668011", "title": "Flouch Roundabout",
79 | "text": ["Flouch Roundabout is a roundabout near Penistone, South Yorkshire, England, where the A628 meets the A616."],
80 | "charoffset": [[[0, 6],...]]
81 | "text_with_links" : ["Flouch Roundabout is a roundabout near Penistone,
82 | South Yorkshire, England, where the A628
83 | meets the A616."],
84 | "charoffset_with_links": [[[0, 6], ... [213, 214]]]}
85 | """
86 | # item should be nested list
87 | extracted_items = []
88 | with jsonlines.open(filename) as reader:
89 | for obj in reader:
90 | wiki_id = obj["id"]
91 | title = obj["title"]
92 | title_id = make_wiki_id(title, 0)
93 | plain_text = "\t".join(obj["text"])
94 | text_with_links = "\t".join(obj["text_with_links"])
95 |
96 | hyper_linked_titles = []
97 | hyper_linked_titles = find_hyper_linked_titles(text_with_links)
98 | if len(hyper_linked_titles) > 0:
99 | hyper_linked_titles_text = "\t".join(hyper_linked_titles)
100 | else:
101 | hyper_linked_titles_text = ""
102 | extracted_items.append({"wiki_id": wiki_id, "title": title_id,
103 | "plain_text": plain_text,
104 | "hyper_linked_titles": hyper_linked_titles_text,
105 | "original_title": title})
106 |
107 | return extracted_items
108 |
109 |
110 | # ------------------------------------------------------------------------------
111 | # Sparse matrix saving/loading helpers.
112 | # ------------------------------------------------------------------------------
113 |
114 |
115 | def save_sparse_csr(filename, matrix, metadata=None):
116 | data = {
117 | 'data': matrix.data,
118 | 'indices': matrix.indices,
119 | 'indptr': matrix.indptr,
120 | 'shape': matrix.shape,
121 | 'metadata': metadata,
122 | }
123 | np.savez(filename, **data)
124 |
125 |
126 | def load_sparse_csr(filename):
127 | loader = np.load(filename, allow_pickle=True)
128 | matrix = sp.csr_matrix((loader['data'], loader['indices'],
129 | loader['indptr']), shape=loader['shape'])
130 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None
131 |
132 | # ------------------------------------------------------------------------------
133 | # Token hashing.
134 | # ------------------------------------------------------------------------------
135 |
136 |
137 | def hash(token, num_buckets):
138 | """Unsigned 32 bit murmurhash for feature hashing."""
139 | return murmurhash3_32(token, positive=True) % num_buckets
140 |
141 | # ------------------------------------------------------------------------------
142 | # Text cleaning.
143 | # ------------------------------------------------------------------------------
144 |
145 |
146 | STOPWORDS = {
147 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your',
148 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she',
149 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their',
150 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
151 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
152 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an',
153 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of',
154 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through',
155 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down',
156 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then',
157 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
158 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
159 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can',
160 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've',
161 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven',
162 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren',
163 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``"
164 | }
165 |
166 |
167 | def filter_word(text):
168 | """Take out english stopwords, punctuation, and compound endings."""
169 | text = normalize(text)
170 | if regex.match(r'^\p{P}+$', text):
171 | return True
172 | if text.lower() in STOPWORDS:
173 | return True
174 | return False
175 |
176 |
177 | def filter_ngram(gram, mode='any'):
178 | """Decide whether to keep or discard an n-gram.
179 | Args:
180 | gram: list of tokens (length N)
181 | mode: Option to throw out ngram if
182 | 'any': any single token passes filter_word
183 | 'all': all tokens pass filter_word
184 | 'ends': book-ended by filterable tokens
185 | """
186 | filtered = [filter_word(w) for w in gram]
187 | if mode == 'any':
188 | return any(filtered)
189 | elif mode == 'all':
190 | return all(filtered)
191 | elif mode == 'ends':
192 | return filtered[0] or filtered[-1]
193 | else:
194 | raise ValueError('Invalid mode: %s' % mode)
195 |
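# For example (illustrative): filter_ngram(['of', 'the'], mode='any') and
# filter_ngram(['of', 'the'], mode='all') both return True (every token is a stopword),
# so the bigram "of the" would be discarded under either mode, while
# filter_ngram(['bank', 'of', 'england'], mode='ends') returns False because neither
# end token is filterable, so that trigram is kept.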
196 |
197 | def get_field(d, field_list):
198 | """get the subfield associated to a list of elastic fields
199 | E.g. ['file', 'filename'] to d['file']['filename']
200 | """
201 | if isinstance(field_list, str):
202 | return d[field_list]
203 | else:
204 | idx = d.copy()
205 | for field in field_list:
206 | idx = idx[field]
207 | return idx
208 |
209 |
210 | def load_para_collections_from_tfidf_id_intro_only(tfidf_id, db):
211 | if "_0" not in tfidf_id:
212 | tfidf_id = "{0}_0".format(tfidf_id)
213 | if db.get_doc_text(tfidf_id) is None:
214 | logger.warning("{0} is missing".format(tfidf_id))
215 | return []
216 | return [[tfidf_id, db.get_doc_text(tfidf_id).split("\t")]]
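# Example return value (illustrative), reusing the article from the docstring of
# process_jsonlines_hotpotqa above:
#   load_para_collections_from_tfidf_id_intro_only("Flouch Roundabout", db)
# would yield
#   [["Flouch Roundabout_0", ["Flouch Roundabout is a roundabout near Penistone, ..."]]]
# i.e. a single (paragraph id, sentence list) pair, since the intro-only DB stores one
# tab-separated intro paragraph per article.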
217 |
218 | def load_linked_titles_from_tfidf_id(tfidf_id, db):
219 | para_titles = db.get_paras_with_article(tfidf_id)
220 | linked_titles_all = []
221 | for para_title in para_titles:
222 | linked_title_per_para = db.get_hyper_linked(para_title)
223 |         if linked_title_per_para:  # skip paragraphs with no (or missing) hyperlink entries
224 | linked_titles_all += linked_title_per_para.split("\t")
225 | return linked_titles_all
226 |
227 | def load_para_and_linked_titles_dict_from_tfidf_id(tfidf_id, db):
228 | """
229 | load paragraphs and hyperlinked titles from DB.
230 | This method is mainly for Natural Questions Open benchmark.
231 | """
232 |     # to be fixed in a later version: the current TF-IDF weights use indexed titles (e.g., "Title_0") as keys.
233 | if "_0" not in tfidf_id:
234 | tfidf_id = "{0}_0".format(tfidf_id)
235 | paras, linked_titles = db.get_doc_text_hyper_linked_titles_for_articles(
236 | tfidf_id)
237 | if len(paras) == 0:
238 | logger.warning("{0} is missing".format(tfidf_id))
239 | return [], []
240 |
241 | paras_dict = {}
242 | linked_titles_dict = {}
243 | article_name = tfidf_id.split("_0")[0]
244 | # store the para_dict and linked_titles_dict; skip the first para (title)
245 | for para_idx, (para, linked_title_list) in enumerate(zip(paras[1:], linked_titles[1:])):
246 | paras_dict["{0}_{1}".format(article_name, para_idx)] = para
247 | linked_titles_dict["{0}_{1}".format(
248 | article_name, para_idx)] = linked_title_list
249 |
250 | return paras_dict, linked_titles_dict
251 |
252 | def prune_top_k_paragraphs(question_text, paragraphs, tfidf_vectorizer, pruning_l=10):
253 | para_titles, para_text = list(paragraphs.keys()), list(paragraphs.values())
254 | # prune top l paragraphs using the question as query to reduce the search space.
255 | top_tfidf_para_indices = tfidf_vectorizer.prune(
256 | question_text, para_text)[:pruning_l]
257 | para_title_text_pairs_pruned = {}
258 |
259 | # store the selected paras into dictionary.
260 | for idx in top_tfidf_para_indices:
261 | para_title_text_pairs_pruned[para_titles[idx]] = para_text[idx]
262 |
263 | return para_title_text_pairs_pruned
264 |
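# Putting the helpers together (an illustrative sketch, not part of the original file):
# for a TF-IDF-retrieved article id, load its paragraphs and hyperlinked titles from the
# DB and prune the paragraphs with the question before handing them to the graph retriever.
#
#   paras_dict, linked_titles_dict = load_para_and_linked_titles_dict_from_tfidf_id(tfidf_id, db)
#   pruned_paras = prune_top_k_paragraphs(question, paras_dict, tfidf_vectorizer, pruning_l=10)
#
# Here `db` is a DocDB instance and `tfidf_vectorizer` is assumed to expose the
# prune(question, paragraphs) method used by prune_top_k_paragraphs above.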
--------------------------------------------------------------------------------
/sequential_sentence_selector/README.md:
--------------------------------------------------------------------------------
1 | # Sequential Sentence Selector
2 |
3 | This directory includes the code for our sequential sentence selector model described in Appendix A.4 of our paper.
4 | The model shares its architecture with our proposed graph retriever for reasoning paths, adapted here to the supporting fact prediction task in HotpotQA.
5 |
6 | ## Training
7 | We use `run_sequential_sentence_selector.py` to train the model.
8 | The overall code is based on that of our graph retriever.
9 | Here is an example command used for our paper.
10 | Since our models were developed with `pytorch-pretrained-bert`, the BERT whole-word-masking configurations cannot be loaded directly.
11 | However, if `pytorch-transformers` is also installed in your environment, you can still use, for example, `bert-large-uncased-whole-word-masking`.
12 | Refer to `./utils.py` for this trick; a sketch of how it is used follows the training command below.
13 |
14 | ```bash
15 | python run_sequential_sentence_selector.py \
16 | --bert_model bert-large-uncased-whole-word-masking \
17 | --train_file_path \
18 | --output_dir \
19 | --do_lower_case \
20 | --train_batch_size 12 \
21 | --gradient_accumulation_steps 1 \
22 | --num_train_epochs 3 \
23 | --learning_rate 3e-5
24 | ```
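
The trick in `./utils.py` works roughly as sketched below. This is a simplified, illustrative excerpt of what `run_sequential_sentence_selector.py` does when `--bert_model` is a whole-word-masking model (the cache-directory arguments are omitted):

```python
from pytorch_pretrained_bert.tokenization import BertTokenizer
from modeling_sequential_sentence_selector import BertForSequentialSentenceSelector
from utils import get_bert_model_from_pytorch_transformers

# Instantiate the selector on top of the plain bert-large-uncased architecture, then
# overwrite its BERT weights with the whole-word-masking checkpoint downloaded via
# pytorch-transformers, and rebuild the tokenizer from the exported vocabulary file.
model = BertForSequentialSentenceSelector.from_pretrained('bert-large-uncased')
state_dict, vocab_file = get_bert_model_from_pytorch_transformers(
    'bert-large-uncased-whole-word-masking')
model.bert.load_state_dict(state_dict)
tokenizer = BertTokenizer.from_pretrained(vocab_file, do_lower_case=True)
```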
25 |
26 | Once you train your model, you can use it in our evaluation pipeline script for HotpotQA.
27 | You can use `hotpot_sf_selector_order_train.json` for `--train_file_path`.
28 |
--------------------------------------------------------------------------------
/sequential_sentence_selector/modeling_sequential_sentence_selector.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import CrossEntropyLoss
7 | from torch.nn.parameter import Parameter
8 |
9 | class BertForSequentialSentenceSelector(BertPreTrainedModel):
10 |
11 | def __init__(self, config):
12 | super(BertForSequentialSentenceSelector, self).__init__(config)
13 |
14 | self.bert = BertModel(config)
15 |
16 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
17 |
18 | # Initial state
19 | self.s = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1))
20 |
21 | # Scaling factor for weight norm
22 | self.g = Parameter(torch.FloatTensor(1).fill_(1.0))
23 |
24 | # RNN weight
25 | self.rw = nn.Linear(2*config.hidden_size, config.hidden_size)
26 |
27 | # EOE and output bias
28 | self.eos = Parameter(torch.FloatTensor(config.hidden_size).uniform_(-0.1, 0.1))
29 | self.bias = Parameter(torch.FloatTensor(1).zero_())
30 |
31 | self.apply(self.init_bert_weights)
32 | self.cpu = torch.device('cpu')
33 |
34 | '''
35 | state: (B, 1, D)
36 | '''
37 | def weight_norm(self, state):
38 | state = state / state.norm(dim = 2).unsqueeze(2)
39 | state = self.g * state
40 | return state
41 |
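    # encode() below flattens the (B, N, L) question-sentence inputs into (B*N, L) pairs,
    # runs BERT over them (optionally in chunks of `split_chunk` pairs to bound GPU memory)
    # and keeps the [CLS] vector of each pair. A learned EOE ("end of evidence") embedding
    # is appended as an extra candidate, giving (B, N+1, D) sentence representations, and
    # the weight-normalized initial RNN state is returned alongside them.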
42 | def encode(self, input_ids, token_type_ids, attention_mask, split_chunk = None):
43 | B = input_ids.size(0)
44 | N = input_ids.size(1)
45 | input_ids = input_ids.contiguous().view(input_ids.size(0)*input_ids.size(1), input_ids.size(2))
46 | token_type_ids = token_type_ids.contiguous().view(token_type_ids.size(0)*token_type_ids.size(1), token_type_ids.size(2))
47 | attention_mask = attention_mask.contiguous().view(attention_mask.size(0)*attention_mask.size(1), attention_mask.size(2))
48 |
49 | # [CLS] vectors for Q-P pairs
50 | if split_chunk is None:
51 | encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
52 | pooled_output = encoded_layers[:, 0]
53 | else:
54 | TOTAL = input_ids.size(0)
55 | start = 0
56 |
57 | while start < TOTAL:
58 | end = min(start+split_chunk-1, TOTAL-1)
59 | chunk_len = end-start+1
60 |
61 | input_ids_ = input_ids[start:start+chunk_len, :]
62 | token_type_ids_ = token_type_ids[start:start+chunk_len, :]
63 | attention_mask_ = attention_mask[start:start+chunk_len, :]
64 |
65 | encoded_layers, pooled_output_ = self.bert(input_ids_, token_type_ids_, attention_mask_, output_all_encoded_layers=False)
66 | encoded_layers = encoded_layers[:, 0]
67 |
68 | if start == 0:
69 | pooled_output = encoded_layers
70 | else:
71 | pooled_output = torch.cat((pooled_output, encoded_layers), dim = 0)
72 |
73 | start = end+1
74 |
75 | pooled_output = pooled_output.contiguous()
76 |
77 | paragraphs = pooled_output.view(pooled_output.size(0)//N, N, pooled_output.size(1)) # (B, N, D)
78 | EOE = self.eos.unsqueeze(0).unsqueeze(0) # (1, 1, D)
79 | EOE = EOE.expand(paragraphs.size(0), EOE.size(1), EOE.size(2)) # (B, 1, D)
80 | EOE = self.bert.encoder.layer[-1].output.LayerNorm(EOE)
81 | paragraphs = torch.cat((paragraphs, EOE), dim = 1) # (B, N+1, D)
82 |
83 | # Initial state
84 | state = self.s.expand(paragraphs.size(0), 1, self.s.size(0))
85 | state = self.weight_norm(state)
86 |
87 | return paragraphs, state
88 |
89 | def forward(self, input_ids, token_type_ids, attention_mask, output_mask, target, target_ids, max_sf_num):
90 |
91 | paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask)
92 |
93 | for i in range(max_sf_num+1):
94 | if i == 0:
95 | h = state
96 | else:
97 | for j in range(target_ids.size(0)):
98 | index = target_ids[j, i-1]
99 | input_ = paragraphs[j:j+1, index:index+1, :] # (B, 1, D)
100 |
101 | if j == 0:
102 | input = input_
103 | else:
104 | input = torch.cat((input, input_), dim = 0)
105 |
106 | state = torch.cat((state, input), dim = 2) # (B, 1, 2*D)
107 | state = self.rw(state) # (B, 1, D)
108 | state = self.weight_norm(state)
109 | h = torch.cat((h, state), dim = 1) # ...--> (B, max_num_steps, D)
110 |
111 | h = self.dropout(h)
112 | output = torch.bmm(h, paragraphs.transpose(1, 2)) # (B, max_num_steps, N+1)
113 | output = output + self.bias
114 | loss = F.binary_cross_entropy_with_logits(output, target, weight = output_mask, reduction = 'mean')
115 | return loss
116 |
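    # beam_search() below decodes supporting-fact sequences at inference time. For each
    # example it keeps `beam` partial hypotheses, each stored as [selected indices, score,
    # number of non-EOE selections], and stops once the best hypothesis selects the EOE
    # index. Per example it returns the best index sequence (with EOE stripped), its
    # per-step sigmoid distributions, and the same two lists for all `beam` hypotheses.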
117 | def beam_search(self, input_ids, token_type_ids, attention_mask, output_mask, max_num_steps, examples, beam = 2):
118 |
119 | B = input_ids.size(0)
120 | paragraphs, state = self.encode(input_ids, token_type_ids, attention_mask, split_chunk = 300)
121 |
122 | pred = []
123 | prob = []
124 |
125 | topk_pred = []
126 | topk_prob = []
127 |
128 | eoe_index = paragraphs.size(1)-1
129 |
130 | output_mask = output_mask.to(self.cpu)
131 |
132 | for i in range(B):
133 | pred_ = [[[], 1.0, 0] for _ in range(beam)] # [hist_1, score_1, len_1], [hist_2, score_2, len_2], ...
134 | prob_ = [[] for _ in range(beam)]
135 |
136 | state_ = state[i:i+1] # (1, 1, D)
137 | state_ = state_.expand(beam, 1, state_.size(2)) # -> (beam, 1, D)
138 | state_tmp = torch.FloatTensor(state_.size()).zero_().to(state_.device)
139 | ps = paragraphs[i:i+1] # (1, N+1, D)
140 | ps = ps.expand(beam, ps.size(1), ps.size(2)) # -> (beam, N+1, D)
141 |
142 | for j in range(max_num_steps):
143 | if j > 0:
144 | input = [p[0][-1] for p in pred_]
145 | input = torch.LongTensor(input).to(paragraphs.device)
146 | input = ps[0][input].unsqueeze(1) # (beam, 1, D)
147 | state_ = torch.cat((state_, input), dim = 2) # (beam, 1, 2*D)
148 | state_ = self.rw(state_) # (beam, 1, D)
149 | state_ = self.weight_norm(state_)
150 |
151 | output = torch.bmm(state_, ps.transpose(1, 2)) # (beam, 1, N+1)
152 | output = output + self.bias
153 | output = torch.sigmoid(output)
154 |
155 | output = output.to(self.cpu)
156 |
157 | if j == 0:
158 | output = output * output_mask[i:i+1, 0:1, :]
159 | else:
160 | for b in range(beam):
161 | for k in range(len(pred_[b][0])):
162 | output[b:b+1] *= output_mask[i:i+1, pred_[b][0][k]+1:pred_[b][0][k]+2, :]
163 |
164 | e = examples[i]
165 | # Predict at least 1 sentence anyway
166 | if j <= 0:
167 | output[:, :, -1] = 0.0
168 | # No further constraints for single title
169 | elif len(e.titles) == 1:
170 | pass
171 | else:
172 | for b in range(beam):
173 | sfs = set()
174 | for p in pred_[b][0]:
175 | if p == eoe_index:
176 | break
177 | offset = 0
178 | for k in range(len(e.titles)):
179 | if p >= offset and p < offset+len(e.context[e.titles[k]]):
180 | sfs.add(e.titles[k])
181 | break
182 | offset += len(e.context[e.titles[k]])
183 | if len(sfs) == 1:
184 | output[b, :, -1] = 0.0
185 |
186 | score = [p[1] for p in pred_]
187 | score = torch.FloatTensor(score)
188 | score = score.unsqueeze(1).unsqueeze(2) # (beam, 1, 1)
189 | score = output * score
190 |
191 | output = output.squeeze(1) # (beam, N+1)
192 | score = score.squeeze(1) # (beam, N+1)
193 | new_pred_ = []
194 | new_prob_ = []
195 |
196 | for b in range(beam):
197 | s, p = torch.max(score.view(score.size(0)*score.size(1)), dim = 0)
198 | s = s.item()
199 | p = p.item()
200 | row = p // score.size(1)
201 | col = p % score.size(1)
202 |
203 | p = [[index for index in pred_[row][0]] + [col],
204 | score[row, col].item(),
205 | pred_[row][2] + (1 if col != eoe_index else 0)]
206 | new_pred_.append(p)
207 |
208 | p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()]
209 | new_prob_.append(p)
210 |
211 | state_tmp[b].copy_(state_[row])
212 |
213 | if j == 0:
214 | score[:, col] = 0.0
215 | else:
216 | score[row, col] = 0.0
217 |
218 | pred_ = new_pred_
219 | prob_ = new_prob_
220 | state_ = state_.clone()
221 | state_.copy_(state_tmp)
222 |
223 | if pred_[0][0][-1] == eoe_index:
224 | break
225 |
226 | topk_pred.append([])
227 | topk_prob.append([])
228 | for index__ in range(beam):
229 |
230 | pred_tmp = []
231 | for index in pred_[index__][0]:
232 | if index == eoe_index:
233 | break
234 | pred_tmp.append(index)
235 |
236 | if index__ == 0:
237 | pred.append(pred_tmp)
238 | prob.append(prob_[0])
239 |
240 | topk_pred[-1].append(pred_tmp)
241 | topk_prob[-1].append(prob_[index__])
242 |
243 | return pred, prob, topk_pred, topk_prob
244 |
--------------------------------------------------------------------------------
/sequential_sentence_selector/run_sequential_sentence_selector.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import csv
6 | import os
7 | import logging
8 | import argparse
9 | import random
10 | from tqdm import tqdm, trange
11 |
12 | import numpy as np
13 | import torch
14 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler
15 | from torch.utils.data.distributed import DistributedSampler
16 |
17 | from pytorch_pretrained_bert.tokenization import BertTokenizer
18 | from pytorch_pretrained_bert.optimization import BertAdam
19 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
20 |
21 | try:
22 | from sequential_sentence_selector.modeling_sequential_sentence_selector import BertForSequentialSentenceSelector
23 | except ImportError:
24 | from modeling_sequential_sentence_selector import BertForSequentialSentenceSelector
25 |
26 | import json
27 |
28 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
29 | datefmt='%m/%d/%Y %H:%M:%S',
30 | level=logging.INFO)
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | class InputExample(object):
35 | """A single training/test example for simple sequence classification."""
36 |
37 | def __init__(self, guid, q, a, t, c, sf):
38 |
39 | self.guid = guid
40 | self.question = q
41 | self.answer = a
42 | self.titles = t
43 | self.context = c
44 | self.supporting_facts = sf
45 |
46 |
47 | class InputFeatures(object):
48 | """A single set of features of data."""
49 |
50 | def __init__(self, input_ids, input_masks, segment_ids, target_ids, output_masks, num_sents, num_sfs, ex_index):
51 | self.input_ids = input_ids
52 | self.input_masks = input_masks
53 | self.segment_ids = segment_ids
54 | self.target_ids = target_ids
55 | self.output_masks = output_masks
56 | self.num_sents = num_sents
57 | self.num_sfs = num_sfs
58 |
59 | self.ex_index = ex_index
60 |
61 |
62 | class DataProcessor:
63 |
64 | def get_train_examples(self, file_name):
65 | return self.create_examples(json.load(open(file_name, 'r')))
66 |
67 | def create_examples(self, jsn):
68 | examples = []
69 | max_sent_num = 0
70 | for data in jsn:
71 | guid = data['q_id']
72 | question = data['question']
73 | titles = data['titles']
74 | context = data['context'] # {title: [s1, s2, ...]}
75 | # {title: [index1, index2, ...]}
76 | supporting_facts = data['supporting_facts']
77 |
78 | max_sent_num = max(max_sent_num, sum(
79 | [len(context[title]) for title in context]))
80 |
81 | examples.append(InputExample(
82 | guid, question, data['answer'], titles, context, supporting_facts))
83 |
84 | return examples
85 |
86 |
87 | def convert_examples_to_features(examples, max_seq_length, max_sent_num, max_sf_num, tokenizer, train=False):
88 | """Loads a data file into a list of `InputBatch`s."""
89 |
90 | DUMMY = [0] * max_seq_length
91 | DUMMY_ = [0.0] * max_sent_num
92 | features = []
93 | logger.info('#### Constructing features... ####')
94 | for (ex_index, example) in enumerate(tqdm(examples, desc='Example')):
95 |
96 | tokens_q = tokenizer.tokenize(
97 | 'Q: {} A: {}'.format(example.question, example.answer))
98 | tokens_q = ['[CLS]'] + tokens_q + ['[SEP]']
99 |
100 | input_ids = []
101 | input_masks = []
102 | segment_ids = []
103 |
104 | for title in example.titles:
105 | sents = example.context[title]
106 | for (i, s) in enumerate(sents):
107 |
108 | if len(input_ids) == max_sent_num:
109 | break
110 |
111 | tokens_s = tokenizer.tokenize(
112 | s)[:max_seq_length-len(tokens_q)-1]
113 | tokens_s = tokens_s + ['[SEP]']
114 |
115 | padding = [0] * (max_seq_length -
116 | len(tokens_s) - len(tokens_q))
117 |
118 | input_ids_ = tokenizer.convert_tokens_to_ids(
119 | tokens_q + tokens_s)
120 | input_masks_ = [1] * len(input_ids_)
121 | segment_ids_ = [0] * len(tokens_q) + [1] * len(tokens_s)
122 |
123 | input_ids_ += padding
124 | input_ids.append(input_ids_)
125 |
126 | input_masks_ += padding
127 | input_masks.append(input_masks_)
128 |
129 | segment_ids_ += padding
130 | segment_ids.append(segment_ids_)
131 |
132 | assert len(input_ids_) == max_seq_length
133 | assert len(input_masks_) == max_seq_length
134 | assert len(segment_ids_) == max_seq_length
135 |
136 | target_ids = []
137 | target_offset = 0
138 |
139 | for title in example.titles:
140 | sfs = example.supporting_facts[title]
141 | for i in sfs:
142 | if i < len(example.context[title]) and i+target_offset < len(input_ids):
143 | target_ids.append(i+target_offset)
144 | else:
145 | logger.warning('')
146 | logger.warning('Invalid annotation: {}'.format(sfs))
147 | logger.warning('Invalid annotation: {}'.format(
148 | example.context[title]))
149 |
150 | target_offset += len(example.context[title])
151 |
152 | assert len(input_ids) <= max_sent_num
153 | assert len(target_ids) <= max_sf_num
154 |
155 | num_sents = len(input_ids)
156 | num_sfs = len(target_ids)
157 |
158 | output_masks = [([1.0] * len(input_ids) + [0.0] * (max_sent_num -
159 | len(input_ids) + 1)) for _ in range(max_sent_num + 2)]
160 |
161 | if train:
162 |
163 | for i in range(len(target_ids)):
164 | for j in range(len(target_ids)):
165 | if i == j:
166 | continue
167 |
168 | output_masks[i][target_ids[j]] = 0.0
169 |
170 | for i in range(len(output_masks)):
171 | if i >= num_sfs+1:
172 | for j in range(len(output_masks[i])):
173 | output_masks[i][j] = 0.0
174 |
175 | else:
176 | for i in range(len(input_ids)):
177 | output_masks[i+1][i] = 0.0
178 |
179 | target_ids += [0] * (max_sf_num - len(target_ids))
180 |
181 | padding = [DUMMY] * (max_sent_num - len(input_ids))
182 | input_ids += padding
183 | input_masks += padding
184 | segment_ids += padding
185 |
186 | features.append(
187 | InputFeatures(input_ids=input_ids,
188 | input_masks=input_masks,
189 | segment_ids=segment_ids,
190 | target_ids=target_ids,
191 | output_masks=output_masks,
192 | num_sents=num_sents,
193 | num_sfs=num_sfs,
194 | ex_index=ex_index))
195 |
196 | logger.info('Done!')
197 |
198 | return features
199 |
200 |
201 | def warmup_linear(x, warmup=0.002):
202 | if x < warmup:
203 | return x/warmup
204 | return 1.0 - x
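# For example, with the script's default warmup_proportion of 0.1: at 5% of training the
# multiplier is 0.05 / 0.1 = 0.5 (still warming up), and at 50% of training it is
# 1.0 - 0.5 = 0.5 (decaying linearly toward 0.0 at the end of training).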
205 |
206 |
207 | def main():
208 | parser = argparse.ArgumentParser()
209 |
210 | ## Required parameters
211 | parser.add_argument("--bert_model", default=None, type=str, required=True,
212 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
213 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
214 | "bert-base-multilingual-cased, bert-base-chinese.")
215 | parser.add_argument("--output_dir",
216 | default=None,
217 | type=str,
218 | required=True,
219 | help="The output directory where the model predictions and checkpoints will be written.")
220 |
221 | parser.add_argument("--train_file_path",
222 | type=str,
223 | default=None,
224 | required=True,
225 | help="File path to training data")
226 |
227 | ## Other parameters
228 | parser.add_argument("--max_seq_length",
229 | default=256,
230 | type=int,
231 | help="The maximum total input sequence length after WordPiece tokenization. \n"
232 | "Sequences longer than this will be truncated, and sequences shorter \n"
233 | "than this will be padded.")
234 | parser.add_argument("--max_sent_num",
235 | default=30,
236 | type=int)
237 | parser.add_argument("--max_sf_num",
238 | default=15,
239 | type=int)
240 | parser.add_argument("--do_lower_case",
241 | action='store_true',
242 | help="Set this flag if you are using an uncased model.")
243 | parser.add_argument("--train_batch_size",
244 | default=1,
245 | type=int,
246 | help="Total batch size for training.")
247 | parser.add_argument("--eval_batch_size",
248 | default=5,
249 | type=int,
250 | help="Total batch size for eval.")
251 | parser.add_argument("--learning_rate",
252 | default=5e-5,
253 | type=float,
254 | help="The initial learning rate for Adam. (def: 5e-5)")
255 | parser.add_argument("--num_train_epochs",
256 | default=5.0,
257 | type=float,
258 | help="Total number of training epochs to perform.")
259 | parser.add_argument("--warmup_proportion",
260 | default=0.1,
261 | type=float,
262 | help="Proportion of training to perform linear learning rate warmup for. "
263 | "E.g., 0.1 = 10%% of training.")
264 | parser.add_argument("--no_cuda",
265 | action='store_true',
266 | help="Whether not to use CUDA when available")
267 | parser.add_argument('--seed',
268 | type=int,
269 | default=42,
270 | help="random seed for initialization")
271 | parser.add_argument('--gradient_accumulation_steps',
272 | type=int,
273 | default=1,
274 | help="Number of updates steps to accumulate before performing a backward/update pass.")
275 |
276 | args = parser.parse_args()
277 |
278 | cpu = torch.device('cpu')
279 |
280 | device = torch.device("cuda" if torch.cuda.is_available()
281 | and not args.no_cuda else "cpu")
282 | n_gpu = torch.cuda.device_count()
283 |
284 | args.train_batch_size = int(
285 | args.train_batch_size / args.gradient_accumulation_steps)
286 |
287 | random.seed(args.seed)
288 | np.random.seed(args.seed)
289 | torch.manual_seed(args.seed)
290 | if n_gpu > 0:
291 | torch.cuda.manual_seed_all(args.seed)
292 |
293 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
294 | raise ValueError(
295 | "Output directory ({}) already exists and is not empty.".format(args.output_dir))
296 | os.makedirs(args.output_dir, exist_ok=True)
297 |
298 | processor = DataProcessor()
299 |
300 | # Prepare model
301 | if args.bert_model != 'bert-large-uncased-whole-word-masking':
302 | tokenizer = BertTokenizer.from_pretrained(
303 | args.bert_model, do_lower_case=args.do_lower_case)
304 |
305 | model = BertForSequentialSentenceSelector.from_pretrained(args.bert_model,
306 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1))
307 | else:
308 | model = BertForSequentialSentenceSelector.from_pretrained('bert-large-uncased',
309 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1))
310 | from utils import get_bert_model_from_pytorch_transformers
311 |
312 | state_dict, vocab_file = get_bert_model_from_pytorch_transformers(
313 | args.bert_model)
314 | model.bert.load_state_dict(state_dict)
315 | tokenizer = BertTokenizer.from_pretrained(
316 | vocab_file, do_lower_case=args.do_lower_case)
317 |
318 | logger.info(
319 | 'The {} model is successfully loaded!'.format(args.bert_model))
320 |
321 | model.to(device)
322 | if n_gpu > 1:
323 | model = torch.nn.DataParallel(model)
324 |
325 | global_step = 0
326 | nb_tr_steps = 0
327 | tr_loss = 0
328 |
329 | POSITIVE = 1.0
330 | NEGATIVE = 0.0
331 |
332 | # Load training examples
333 | train_examples = None
334 | num_train_steps = None
335 | train_examples = processor.get_train_examples(args.train_file_path)
336 | train_features = convert_examples_to_features(
337 | train_examples, args.max_seq_length, args.max_sent_num, args.max_sf_num, tokenizer, train=True)
338 |
339 | num_train_steps = int(
340 | len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
341 |
342 | # Prepare optimizer
343 | param_optimizer = list(model.named_parameters())
344 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
345 | optimizer_grouped_parameters = [
346 | {'params': [p for n, p in param_optimizer if not any(
347 | nd in n for nd in no_decay)], 'weight_decay': 0.01},
348 | {'params': [p for n, p in param_optimizer if any(
349 | nd in n for nd in no_decay)], 'weight_decay': 0.0}
350 | ]
351 | t_total = num_train_steps
352 |
353 | optimizer = BertAdam(optimizer_grouped_parameters,
354 | lr=args.learning_rate,
355 | warmup=args.warmup_proportion,
356 | t_total=t_total,
357 | max_grad_norm=1.0)
358 |
359 | logger.info("***** Running training *****")
360 | logger.info(" Num examples = %d", len(train_features))
361 | logger.info(" Batch size = %d", args.train_batch_size)
362 | logger.info(" Num steps = %d", num_train_steps)
363 |
364 | all_input_ids = torch.tensor(
365 | [f.input_ids for f in train_features], dtype=torch.long)
366 | all_input_masks = torch.tensor(
367 | [f.input_masks for f in train_features], dtype=torch.long)
368 | all_segment_ids = torch.tensor(
369 | [f.segment_ids for f in train_features], dtype=torch.long)
370 | all_target_ids = torch.tensor(
371 | [f.target_ids for f in train_features], dtype=torch.long)
372 | all_output_masks = torch.tensor(
373 | [f.output_masks for f in train_features], dtype=torch.float)
374 | all_num_sents = torch.tensor(
375 | [f.num_sents for f in train_features], dtype=torch.long)
376 | all_num_sfs = torch.tensor(
377 | [f.num_sfs for f in train_features], dtype=torch.long)
378 | train_data = TensorDataset(all_input_ids,
379 | all_input_masks,
380 | all_segment_ids,
381 | all_target_ids,
382 | all_output_masks,
383 | all_num_sents,
384 | all_num_sfs)
385 | train_sampler = RandomSampler(train_data)
386 | train_dataloader = DataLoader(
387 | train_data, sampler=train_sampler, batch_size=args.train_batch_size)
388 |
389 | model.train()
390 | epc = 0
391 | for _ in trange(int(args.num_train_epochs), desc="Epoch"):
392 | tr_loss = 0
393 | nb_tr_examples, nb_tr_steps = 0, 0
394 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
395 | input_masks = batch[1]
396 | batch_max_len = input_masks.sum(dim=2).max().item()
397 |
398 | target_ids = batch[3]
399 |
400 | num_sents = batch[5]
401 | batch_max_sent_num = num_sents.max().item()
402 |
403 | num_sfs = batch[6]
404 | batch_max_sf_num = num_sfs.max().item()
405 |
406 | output_masks_cpu = (batch[4])[
407 | :, :batch_max_sf_num+1, :batch_max_sent_num+1]
408 |
409 | batch = tuple(t.to(device) for t in batch)
410 | input_ids, input_masks, segment_ids, _, output_masks, __, ___ = batch
411 | B = input_ids.size(0)
412 |
413 | input_ids = input_ids[:, :batch_max_sent_num, :batch_max_len]
414 | input_masks = input_masks[:, :batch_max_sent_num, :batch_max_len]
415 | segment_ids = segment_ids[:, :batch_max_sent_num, :batch_max_len]
416 | target_ids = target_ids[:, :batch_max_sf_num]
417 | # 1 for EOE
418 | output_masks = output_masks[:,
419 | :batch_max_sf_num+1, :batch_max_sent_num+1]
420 |
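            # Build the gold targets for the sequential prediction. After slicing to the
            # batch maxima, target has shape (B, batch_max_sf_num + 1, batch_max_sent_num + 1):
            # for example i, step j < num_sfs[i] marks the j-th supporting-fact sentence as
            # positive, and step num_sfs[i] marks the extra EOE column (last index) as
            # positive, so the model learns to stop after selecting all supporting facts.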
421 | target = torch.FloatTensor(output_masks.size()).fill_(
422 | NEGATIVE) # (B, NUM_STEPS, |S|+1) <- 1 for EOE
423 | for i in range(B):
424 | output_masks[i, :num_sfs[i]+1, -1] = 1.0 # for EOE
425 | target[i, num_sfs[i], -1].fill_(POSITIVE)
426 |
427 | for j in range(num_sfs[i].item()):
428 | target[i, j, target_ids[i, j]].fill_(POSITIVE)
429 | target = target.to(device)
430 |
431 | loss = model(input_ids, segment_ids, input_masks,
432 | output_masks, target, target_ids, batch_max_sf_num)
433 |
434 | if n_gpu > 1:
435 | loss = loss.mean() # mean() to average on multi-gpu.
436 | if args.gradient_accumulation_steps > 1:
437 | loss = loss / args.gradient_accumulation_steps
438 |
439 | loss.backward()
440 |
441 | tr_loss += loss.item()
442 |
443 | nb_tr_examples += B
444 | nb_tr_steps += 1
445 | if (step + 1) % args.gradient_accumulation_steps == 0:
446 | # modify learning rate with special warm up BERT uses
447 | lr_this_step = args.learning_rate * \
448 | warmup_linear(global_step/t_total, args.warmup_proportion)
449 | for param_group in optimizer.param_groups:
450 | param_group['lr'] = lr_this_step
451 | optimizer.step()
452 | optimizer.zero_grad()
453 | global_step += 1
454 |
455 | model_to_save = model.module if hasattr(
456 |             model, 'module') else model  # Only save the model itself
457 | output_model_file = os.path.join(
458 | args.output_dir, "pytorch_model_"+str(epc+1)+".bin")
459 | torch.save(model_to_save.state_dict(), output_model_file)
460 | epc += 1
461 |
462 |
463 | if __name__ == "__main__":
464 | main()
465 |
--------------------------------------------------------------------------------
/sequential_sentence_selector/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 |
5 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
6 | BertModel, BertTokenizer)
7 |
8 | MODEL_CLASSES = {
9 | 'bert': (BertConfig, BertModel, BertTokenizer),
10 | }
11 |
12 | def get_bert_model_from_pytorch_transformers(model_name):
13 | config_class, model_class, tokenizer_class = MODEL_CLASSES['bert']
14 | config = config_class.from_pretrained(model_name)
15 | model = model_class.from_pretrained(model_name, from_tf=bool('.ckpt' in model_name), config=config)
16 |
17 | tokenizer = tokenizer_class.from_pretrained(model_name)
18 |
19 | vocab_file_name = './vocabulary_'+model_name+'.txt'
20 |
21 | if not os.path.exists(vocab_file_name):
22 | index = 0
23 | with open(vocab_file_name, "w", encoding="utf-8") as writer:
24 | for token, token_index in sorted(tokenizer.vocab.items(), key=lambda kv: kv[1]):
25 |                 # vocab indices must be contiguous so the written line order matches the token ids
26 |                 assert index == token_index
27 | index = token_index
28 | writer.write(token + u'\n')
29 | index += 1
30 |
31 | return model.state_dict(), vocab_file_name
32 |
--------------------------------------------------------------------------------