├── LICENSE.txt
├── README.md
├── bert
│   ├── custom_modeling.py
│   ├── modeling.py
│   ├── optimization.py
│   ├── run_squad_document_full_e2e.py
│   ├── run_triviaqa_wiki_full_e2e.py
│   └── tokenization.py
├── data
│   └── squad
│       └── dev-v1.1.json
├── image
│   └── framework.PNG
├── squad
│   ├── convert_squad_open.py
│   ├── squad_document_utils.py
│   ├── squad_evaluate.py
│   ├── squad_open_utils.py
│   └── squad_utils.py
└── triviaqa
    ├── ablate_triviaqa_unfiltered.py
    ├── ablate_triviaqa_wiki.py
    ├── answer_detection.py
    ├── build_span_corpus.py
    ├── configurable.py
    ├── evidence_corpus.py
    ├── preprocessed_corpus.py
    ├── read_data.py
    ├── triviaqa_document_utils.py
    ├── triviaqa_eval.py
    └── utils.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Retrieve, Read, Rerank: Towards End-to-End Multi-Document Reading Comprehension
2 |
3 | This repo contains the code for the following paper:
4 |
5 | [Retrieve, Read, Rerank: Towards End-to-End Multi-Document Reading Comprehension](https://arxiv.org/abs/1906.04618). Minghao Hu, Yuxing Peng, Zhen Huang, Dongsheng Li. ACL 2019.
6 |
7 | In this paper, we propose an end-to-end neural network for the multi-document reading comprehension task, shown below:
8 |
9 | ![Framework](image/framework.PNG)
10 |
11 |
12 | This network consists of three components:
13 | - Early-stopped retriever
14 | - Distantly-supervised reader
15 | - Answer reranker
16 |
17 | Given multiple documents, the network is designed to retrieve relevant document content, propose multiple answer candidates, and finally rerank these candidates. We utilize [BERT](https://github.com/huggingface/pytorch-pretrained-BERT) to initialize our network. The whole network is trained end-to-end with a multi-task objective.
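18 | 
19 | As a rough illustration (not the repository's actual training code), the objective sums a ranking loss with the reader and reranker losses that `bert/custom_modeling.py` builds in its training mode:
20 | 
21 | ```python
22 | # Hedged sketch of the multi-task objective: in custom_modeling.py the
23 | # reader loss averages a distant cross-entropy over start/end positions,
24 | # and the reranker loss adds a hard distant cross-entropy to a soft MSE.
25 | def multi_task_loss(rank_loss, read_loss, hard_rerank_loss, soft_rerank_loss):
26 |     rerank_loss = hard_rerank_loss + soft_rerank_loss
27 |     return rank_loss + read_loss + rerank_loss
28 | ```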
18 |
19 | ## Pre-trained Models
20 | To reproduce our results, we release the following pre-trained models:
21 | - [squad_doc_base](https://drive.google.com/file/d/16lTmN2wu31QdUvExW_fGcDJxKnR7f912/view?usp=sharing)
22 | - [triviaqa_wiki_base](https://drive.google.com/file/d/1Re_2KxBlCQ9_sxTmkZGoahjX72c1eCfk/view?usp=sharing)
23 | - [triviaqa_unfiltered_base](https://drive.google.com/file/d/1kqF40UhJAC6XkAbywI-YMIg_C5t0oS2Q/view?usp=sharing)
24 |
25 | ## Requirements
26 | - Python 3.6
27 | - [PyTorch 1.1](https://pytorch.org/)
28 |
29 | Download the uncased [BERT-Base](https://drive.google.com/file/d/13I0Gj7v8lYhW5Hwmp5kxm3CTlzWZuok2/view?usp=sharing) model and unzip it in the current directory.
30 |
31 | ## SQuAD-document
32 | To run experiments on the SQuAD-document dataset, first set up the environment:
33 | ```bash
34 | export DATA_DIR=data/squad
35 | export BERT_DIR=bert-base-uncased
36 | ```
37 |
38 | Make sure `train-v1.1.json` and `dev-v1.1.json` are placed in `DATA_DIR`.
39 |
40 | Then run the following command to train the model:
41 | ```shell
42 | python -m bert.run_squad_document_full_e2e \
43 | --vocab_file $BERT_DIR/vocab.txt \
44 | --bert_config_file $BERT_DIR/bert_config.json \
45 | --init_checkpoint $BERT_DIR/pytorch_model.bin \
46 | --do_train \
47 | --do_predict \
48 | --data_dir $DATA_DIR \
49 | --train_file train-v1.1.json \
50 | --predict_file dev-v1.1.json \
51 | --train_batch_size 32 \
52 | --learning_rate 3e-5 \
53 | --num_train_epochs 2.0 \
54 | --output_dir out/squad_doc/01
55 | ```
56 | All experiments in our paper were conducted on 4 NVIDIA Tesla P40 GPUs (22GB memory per card). Training took nearly 22 hours to converge. If you do not have enough GPU capacity, you can adjust several of the following hyper-parameters (note that these changes might cause performance degradation):
58 | - `--train_batch_size`: total batch size for training.
59 | - `--n_para_train`: the number of paragraphs retrieved by TF-IDF during training (denoted as `K` in our paper).
60 | - `--n_best_size_rank`: the number of segments retrieved by the early-stopped retriever (denoted as `N` in our paper).
61 | - `--num_hidden_rank`: the number of Transformer blocks used for retrieval (denoted as `J` in our paper).
62 | - `--gradient_accumulation_steps`: the number of update steps to accumulate before performing a backward/update pass.
63 | - `--optimize_on_cpu`: whether to perform optimization and keep the optimizer averages on CPU.
64 |
65 | The base model can be trained on 2 GeForce GTX TITAN GPUs (12GB memory per card) with the following command:
66 | ```shell
67 | python -m bert.run_squad_document_full_e2e \
68 | --vocab_file $BERT_DIR/vocab.txt \
69 | --bert_config_file $BERT_DIR/bert_config.json \
70 | --init_checkpoint $BERT_DIR/pytorch_model.bin \
71 | --do_train \
72 | --do_predict \
73 | --data_dir $DATA_DIR \
74 | --train_file train-v1.1.json \
75 | --predict_file dev-v1.1.json \
76 | --train_batch_size 32 \
77 | --learning_rate 3e-5 \
78 | --num_train_epochs 2.0 \
79 | --optimize_on_cpu \
80 | --gradient_accumulation_steps 4 \
81 | --output_dir out/squad_doc/01
82 | ```
83 |
84 | Finally, you can get a dev result from `out/squad_doc/01/performance.txt` like this:
85 | ```bash
86 | Ranker, type: test, step: 19332, map: 0.891, mrr: 0.916, top_1: 0.880, top_3: 0.945, top_5: 0.969, top_7: 0.977, retrieval_rate: 0.558
87 | Reader, type: test, step: 19332, test_em: 77.909, test_f1: 84.817
88 | ```
89 |
90 | ## SQuAD-open
91 | Once you have trained a model on document-level SQuAD, you can evaluate it on the open-domain version of the SQuAD dataset.
92 |
93 | First, download the pre-processed [SQuAD-open dev set](https://drive.google.com/file/d/1oBqoNNGVV2yCKvEWv5k91PBUHNDl5q8J/view?usp=sharing) and place it in `data/squad/`.
94 |
95 | Then run the following command to evaluate the model:
96 | ```shell
97 | python -m bert.run_squad_document_full_e2e \
98 | --vocab_file $BERT_DIR/vocab.txt \
99 | --bert_config_file $BERT_DIR/bert_config.json \
100 | --do_predict_open \
101 | --data_dir $DATA_DIR \
102 | --output_dir out/squad_doc/01
103 | ```
104 |
105 | You can get a dev result from `out/squad_doc/01/performance.txt` like this:
106 | ```bash
107 | Ranker, type: test_open, step: 19332, map: 0.000, mrr: 0.000, top_1: 0.000, top_3: 0.000, top_5: 0.000, top_7: 0.000, retrieval_rate: 0.190
108 | Reader, type: test_open, step: 19332, em: 40.123, f1: 48.358
109 | ```
110 |
111 | ## TriviaQA
112 | ### Data Preprocessing
113 | The raw TriviaQA data is expected to be unzipped in `~/data/triviaqa`. Training
114 | or testing in the unfiltered setting requires the unfiltered data to be
115 | downloaded to `~/data/triviaqa-unfiltered`.
116 | ```bash
117 | mkdir -p ~/data/triviaqa
118 | cd ~/data/triviaqa
119 | wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz
120 | tar xf triviaqa-rc.tar.gz
121 | rm triviaqa-rc.tar.gz
122 |
123 | cd ~/data
124 | wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz
125 | tar xf triviaqa-unfiltered.tar.gz
126 | rm triviaqa-unfiltered.tar.gz
127 | ```
128 |
129 | First, tokenize the evidence documents:
130 | ```shell
131 | python -m triviaqa.evidence_corpus --n_processes 8 --max_tokens 200
132 | ```
133 | Here, paragraphs shorter than 200 words are merged.
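134 | 
135 | A hedged sketch of that merging rule (the real logic lives in `triviaqa/evidence_corpus.py`; the function name here is illustrative):
136 | 
137 | ```python
138 | def merge_small_paragraphs(paragraphs, max_tokens=200):
139 |     # Greedily pack consecutive tokenized paragraphs into chunks of at
140 |     # most max_tokens so that short paragraphs get merged together; a
141 |     # single paragraph longer than max_tokens becomes its own chunk.
142 |     merged, current = [], []
143 |     for para in paragraphs:
144 |         if current and len(current) + len(para) > max_tokens:
145 |             merged.append(current)
146 |             current = []
147 |         current = current + para
148 |     if current:
149 |         merged.append(current)
150 |     return merged
151 | ```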
134 |
135 | Then tokenize the questions and locate relevant answer spans in each document. Run
136 | ```shell
137 | python -m triviaqa.build_span_corpus {wiki|unfiltered} --n_processes 8
138 | ```
139 | to build the desired set. This writes pkl files to `data/triviaqa/{wiki|unfiltered}`.
140 |
141 | Next, retrieve top-n paragraphs based on TF-IDF to construct the train and dev sets by
142 | ```shell
143 | python -m triviaqa.ablate_triviaqa_wiki --n_processes 8 --n_para_train 12 --n_para_dev 14 --n_para_test 14 --do_train --do_dev --do_test
144 | python -m triviaqa.ablate_triviaqa_unfiltered --n_processes 8 --n_para_train 12 --n_para_dev 14 --n_para_test 14 --do_train --do_dev --do_test
145 | ```
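146 | 
147 | For intuition, the TF-IDF retrieval step can be sketched as below (a minimal illustration assuming scikit-learn; the repository's own retrieval code lives in the `triviaqa` scripts):
148 | 
149 | ```python
150 | from sklearn.feature_extraction.text import TfidfVectorizer
151 | 
152 | def top_n_paragraphs(question, paragraphs, n):
153 |     # Rank paragraphs by cosine similarity between their TF-IDF vectors
154 |     # and the question's (vectors are L2-normalized by default).
155 |     vectorizer = TfidfVectorizer(strip_accents="unicode", stop_words="english")
156 |     para_vecs = vectorizer.fit_transform(paragraphs)
157 |     question_vec = vectorizer.transform([question])
158 |     scores = (para_vecs @ question_vec.T).toarray().ravel()
159 |     return [paragraphs[i] for i in scores.argsort()[::-1][:n]]
160 | ```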
146 |
147 | ### Wikipedia Domain
148 | To run experiments on the TriviaQA-wiki dataset, first set up the environment:
149 | ```bash
150 | export DATA_DIR=data/triviaqa/wiki
151 | export BERT_DIR=bert-base-uncased
152 | ```
153 |
154 | Then run the following command to train the model:
155 | ```shell
156 | python -m bert.run_triviaqa_wiki_full_e2e \
157 | --vocab_file $BERT_DIR/vocab.txt \
158 | --bert_config_file $BERT_DIR/bert_config.json \
159 | --init_checkpoint $BERT_DIR/pytorch_model.bin \
160 | --do_train \
161 | --do_dev \
162 | --data_dir $DATA_DIR \
163 | --train_batch_size 32 \
164 | --learning_rate 3e-5 \
165 | --num_train_epochs 2.0 \
166 | --output_dir out/triviaqa_wiki/01
167 | ```
168 |
169 | Once the training is finished, a dev result can be obtained from `out/triviaqa_wiki/01/performance.txt` as:
170 | ```bash
171 | Ranker, type: dev, step: 20088, map: 0.778, mrr: 0.849, top_1: 0.797, top_3: 0.888, top_5: 0.918, top_7: 0.932, retrieval_rate: 0.460
172 | Reader, type: dev, step: 20088, em: 68.510, f1: 72.680
173 | ```
174 |
175 | ### Unfiltered Domain
176 | To run experiments on the TriviaQA-unfiltered dataset, first set up the environment:
177 | ```bash
178 | export DATA_DIR=data/triviaqa/unfiltered
179 | export BERT_DIR=bert-base-uncased
180 | ```
181 |
182 | Then run the following command to train the model:
183 | ```shell
184 | python -m bert.run_triviaqa_wiki_full_e2e \
185 | --vocab_file $BERT_DIR/vocab.txt \
186 | --bert_config_file $BERT_DIR/bert_config.json \
187 | --init_checkpoint $BERT_DIR/pytorch_model.bin \
188 | --do_train \
189 | --do_dev \
190 | --data_dir $DATA_DIR \
191 | --train_batch_size 32 \
192 | --learning_rate 3e-5 \
193 | --num_train_epochs 2.0 \
194 | --output_dir out/triviaqa_unfiltered/01
195 | ```
196 |
197 | Once the training is finished, a dev result can be obtained from `out/triviaqa_unfiltered/01/performance.txt` as:
198 | ```bash
199 | Ranker, type: dev, step: 26726, map: 0.737, mrr: 0.781, top_1: 0.749, top_3: 0.806, top_5: 0.824, top_7: 0.831, retrieval_rate: 0.392
200 | Reader, type: dev, step: 26726, em: 63.953, f1: 69.506
201 | ```
202 |
203 | ## Acknowledgements
204 | Some of the preprocessing code was adapted from the [document-qa](https://github.com/allenai/document-qa) implementation.
205 |
206 | The BERT implementation is based on [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT).
207 |
208 | If you find the paper or this repository helpful in your work, please use the following citation:
209 | ```
210 | @inproceedings{hu2019retrieve,
211 | title={Retrieve, Read, Rerank: Towards End-to-End Multi-Document Reading Comprehension},
212 | author={Hu, Minghao and Peng, Yuxing and Huang, Zhen and Li, Dongsheng},
213 | booktitle={Proceedings of ACL},
214 | year={2019}
215 | }
216 | ```
217 |
--------------------------------------------------------------------------------
/bert/custom_modeling.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import copy
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import CrossEntropyLoss, MSELoss
9 |
10 | from bert.modeling import BertConfig, BERTLayerNorm, BERTLayer, BERTEmbeddings, BERTPooler
11 |
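12 | # flatten and reconstruct form a round trip: flatten merges the first two
13 | # dims ([N, M] or [N, M, D] -> [N*M] or [N*M, D]), while reconstruct splits
14 | # them back out using a reference tensor's leading [N, M] shape.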
12 | def flatten(x):
13 | if len(x.size()) == 2:
14 | batch_size = x.size()[0]
15 | seq_length = x.size()[1]
16 | return x.view([batch_size * seq_length])
17 | elif len(x.size()) == 3:
18 | batch_size = x.size()[0]
19 | seq_length = x.size()[1]
20 | hidden_size = x.size()[2]
21 | return x.view([batch_size * seq_length, hidden_size])
22 | else:
23 | raise Exception()
24 |
25 | def reconstruct(x, ref):
26 | if len(x.size()) == 1:
27 | batch_size = ref.size()[0]
28 | turn_num = ref.size()[1]
29 | return x.view([batch_size, turn_num])
30 | elif len(x.size()) == 2:
31 | batch_size = ref.size()[0]
32 | turn_num = ref.size()[1]
33 | sequence_length = x.size()[1]
34 | return x.view([batch_size, turn_num, sequence_length])
35 | else:
36 | raise Exception()
37 |
38 | def flatten_emb_by_sentence(emb, emb_mask):
39 | batch_size = emb.size()[0]
40 | seq_length = emb.size()[1]
41 | flat_emb = flatten(emb)
42 | flat_emb_mask = emb_mask.view([batch_size * seq_length])
43 | return flat_emb[flat_emb_mask.nonzero().squeeze(), :]
44 |
45 | def get_span_representation(span_starts, span_ends, input, input_mask):
46 | '''
47 | :param span_starts: [N, M]
48 | :param span_ends: [N, M]
49 | :param input: [N, L, D]
50 | :param input_mask: [N, L]
51 | :return: [N*M, JR, D], [N*M, JR]
52 | '''
53 | input_mask = input_mask.to(dtype=span_starts.dtype) # fp16 compatibility
54 | input_len = torch.sum(input_mask, dim=-1) # [N]
55 | word_offset = torch.cumsum(input_len, dim=0) # [N]
56 | word_offset -= input_len
57 |
58 | span_starts_offset = span_starts + word_offset.unsqueeze(1)
59 | span_ends_offset = span_ends + word_offset.unsqueeze(1)
60 |
61 | span_starts_offset = span_starts_offset.view([-1]) # [N*M]
62 | span_ends_offset = span_ends_offset.view([-1])
63 |
64 | span_width = span_ends_offset - span_starts_offset + 1
65 | JR = torch.max(span_width)
66 |
67 | context_outputs = flatten_emb_by_sentence(input, input_mask)
261 | if len(start_positions.size()) > 1:
262 | start_positions = start_positions.squeeze(-1)
263 | if len(end_positions.size()) > 1:
264 | end_positions = end_positions.squeeze(-1)
265 | # sometimes the start/end positions are outside our model inputs; we ignore these terms
266 | ignored_index = start_logits.size(1)
267 | start_positions.clamp_(0, ignored_index)
268 | end_positions.clamp_(0, ignored_index)
269 |
270 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
271 | start_loss = loss_fct(start_logits, start_positions)
272 | end_loss = loss_fct(end_logits, end_positions)
273 | read_loss = (start_loss + end_loss) / 2
274 |
275 | assert span_starts is not None and span_ends is not None and hard_labels is not None and soft_labels is not None
276 | span_output, span_mask = get_span_representation(span_starts, span_ends, sequence_output,
277 | attention_mask) # [N*M, JR, D], [N*M, JR]
278 | span_score = self.rerank_affine(span_output)
279 | span_score = span_score.squeeze(-1) # [N*M, JR]
280 | span_pooled_output = get_self_att_representation(span_output, span_score, span_mask) # [N*M, D]
281 |
282 | span_pooled_output = self.rerank_dense(span_pooled_output)
283 | span_pooled_output = self.activation(span_pooled_output)
284 | span_pooled_output = self.dropout(span_pooled_output)
285 | rerank_logits = self.rerank_classifier(span_pooled_output).squeeze(-1)
286 | rerank_logits = reconstruct(rerank_logits, span_starts)
287 |
288 | hard_loss = distant_cross_entropy(rerank_logits, hard_labels)
289 | soft_loss_fct = MSELoss()
290 | soft_loss = soft_loss_fct(rerank_logits, soft_labels.to(dtype=rerank_logits.dtype))
291 | rerank_loss = hard_loss + soft_loss
292 | return read_loss + rerank_loss
293 |
294 | else:
295 | raise Exception
296 |
297 | class BertForRankingAndDistantReadingAndReranking(nn.Module):
298 | def __init__(self, config, num_hidden_rank):
299 | super(BertForRankingAndDistantReadingAndReranking, self).__init__()
301 | self.num_hidden_rank = num_hidden_rank
302 | self.num_hidden_read = config.num_hidden_layers
303 | self.bert = EarlyStopBertModel(config)
304 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
305 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
306 | self.activation = nn.Tanh()
307 | self.rank_affine = nn.Linear(config.hidden_size, 1)
308 | self.rank_dense = nn.Linear(config.hidden_size, config.hidden_size)
309 | self.rank_classifier = nn.Linear(config.hidden_size, 2)
310 | self.read_affine = nn.Linear(config.hidden_size, 2)
311 | self.rerank_affine = nn.Linear(config.hidden_size, 1)
312 | self.rerank_dense = nn.Linear(config.hidden_size, config.hidden_size)
313 | self.rerank_classifier = nn.Linear(config.hidden_size, 1)
314 |
315 | def init_weights(module):
316 | if isinstance(module, (nn.Linear, nn.Embedding)):
317 | # Slightly different from the TF version which uses truncated_normal for initialization
318 | # cf https://github.com/pytorch/pytorch/pull/5617
319 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
320 | elif isinstance(module, BERTLayerNorm):
321 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
322 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
323 | if isinstance(module, nn.Linear):
324 | module.bias.data.zero_()
325 | self.apply(init_weights)
326 |
327 | def forward(self, mode, attention_mask, input_ids=None, token_type_ids=None, rank_labels=None, start_positions=None,
328 | end_positions=None, span_starts=None, span_ends=None, hard_labels=None, soft_labels=None,
329 | sequence_input=None):
330 | if mode == 'rank':
331 | assert input_ids is not None and token_type_ids is not None
332 | all_encoder_layers, _ = self.bert(self.num_hidden_rank, input_ids, token_type_ids, attention_mask)
333 | sequence_output = all_encoder_layers[-1]
334 |
335 | sequence_weights = self.rank_affine(sequence_output).squeeze(-1)
336 | pooled_output = get_self_att_representation(sequence_output, sequence_weights, attention_mask)
337 |
338 | pooled_output = self.rank_dense(pooled_output)
339 | pooled_output = self.activation(pooled_output)
340 | pooled_output = self.dropout(pooled_output)
341 | rank_logits = self.rank_classifier(pooled_output)
342 |
343 | if rank_labels is not None:
344 | rank_loss_fct = CrossEntropyLoss()
345 | rank_loss = rank_loss_fct(rank_logits, rank_labels)
346 | return rank_loss
347 | else:
348 | return rank_logits
349 |
350 | elif mode == 'read_inference':
351 | assert input_ids is not None and token_type_ids is not None
352 | all_encoder_layers, _ = self.bert(self.num_hidden_read, input_ids, token_type_ids, attention_mask)
353 | sequence_output = all_encoder_layers[-1]
354 |
355 | logits = self.read_affine(sequence_output)
356 | start_logits, end_logits = logits.split(1, dim=-1)
357 | start_logits = start_logits.squeeze(-1)
358 | end_logits = end_logits.squeeze(-1)
359 | return start_logits, end_logits, sequence_output
360 |
361 | elif mode == 'rerank_inference':
362 | assert span_starts is not None and span_ends is not None and sequence_input is not None
363 | span_output, span_mask = get_span_representation(span_starts, span_ends, sequence_input,
364 | attention_mask) # [N*M, JR, D], [N*M, JR]
365 |
366 | span_weights = self.rerank_affine(span_output).squeeze(-1)
367 | span_pooled_output = get_self_att_representation(span_output, span_weights, span_mask) # [N*M, D]
368 |
369 | span_pooled_output = self.rerank_dense(span_pooled_output)
370 | span_pooled_output = self.activation(span_pooled_output)
371 | span_pooled_output = self.dropout(span_pooled_output)
372 | rerank_logits = self.rerank_classifier(span_pooled_output).squeeze(-1)
373 | rerank_logits = reconstruct(rerank_logits, span_starts)
374 | return rerank_logits
375 |
376 | elif mode == 'read_rerank_train':
377 | assert input_ids is not None and token_type_ids is not None
378 | assert start_positions is not None and end_positions is not None
379 | all_encoder_layers, _ = self.bert(self.num_hidden_read, input_ids, token_type_ids, attention_mask)
380 | sequence_output = all_encoder_layers[-1]
381 |
382 | logits = self.read_affine(sequence_output)
383 | start_logits, end_logits = logits.split(1, dim=-1)
384 | start_logits = start_logits.squeeze(-1)
385 | end_logits = end_logits.squeeze(-1)
386 |
387 | start_loss = distant_cross_entropy(start_logits, start_positions)
388 | end_loss = distant_cross_entropy(end_logits, end_positions)
389 | read_loss = (start_loss + end_loss) / 2
390 |
391 | assert span_starts is not None and span_ends is not None and hard_labels is not None and soft_labels is not None
392 | span_output, span_mask = get_span_representation(span_starts, span_ends, sequence_output,
393 | attention_mask) # [N*M, JR, D], [N*M, JR]
394 | span_score = self.rerank_affine(span_output)
395 | span_score = span_score.squeeze(-1) # [N*M, JR]
396 | span_pooled_output = get_self_att_representation(span_output, span_score, span_mask) # [N*M, D]
397 |
398 | span_pooled_output = self.rerank_dense(span_pooled_output)
399 | span_pooled_output = self.activation(span_pooled_output)
400 | span_pooled_output = self.dropout(span_pooled_output)
401 | rerank_logits = self.rerank_classifier(span_pooled_output).squeeze(-1)
402 | rerank_logits = reconstruct(rerank_logits, span_starts)
403 |
404 | hard_loss = distant_cross_entropy(rerank_logits, hard_labels)
405 | soft_loss_fct = MSELoss()
406 | soft_loss = soft_loss_fct(rerank_logits, soft_labels.to(dtype=rerank_logits.dtype))
407 | rerank_loss = hard_loss + soft_loss
408 | return read_loss + rerank_loss
409 |
410 | else:
411 | raise Exception
412 |
--------------------------------------------------------------------------------
/bert/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch BERT model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import copy
22 | import json
23 | import math
24 | import six
25 | import torch
26 | import torch.nn as nn
27 | from torch.nn import CrossEntropyLoss
28 |
29 |
30 | def gelu(x):
31 | """Implementation of the gelu activation function.
32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
34 | """
35 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
36 |
37 |
38 | class BertConfig(object):
39 | """Configuration class to store the configuration of a `BertModel`.
40 | """
41 | def __init__(self,
42 | vocab_size,
43 | hidden_size=768,
44 | num_hidden_layers=12,
45 | num_attention_heads=12,
46 | intermediate_size=3072,
47 | hidden_act="gelu",
48 | hidden_dropout_prob=0.1,
49 | attention_probs_dropout_prob=0.1,
50 | max_position_embeddings=512,
51 | type_vocab_size=16,
52 | initializer_range=0.02):
53 | """Constructs BertConfig.
54 |
55 | Args:
56 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
57 | hidden_size: Size of the encoder layers and the pooler layer.
58 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
59 | num_attention_heads: Number of attention heads for each attention layer in
60 | the Transformer encoder.
61 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
62 | layer in the Transformer encoder.
63 | hidden_act: The non-linear activation function (function or string) in the
64 | encoder and pooler.
65 | hidden_dropout_prob: The dropout probability for all fully connected
66 | layers in the embeddings, encoder, and pooler.
67 | attention_probs_dropout_prob: The dropout ratio for the attention
68 | probabilities.
69 | max_position_embeddings: The maximum sequence length that this model might
70 | ever be used with. Typically set this to something large just in case
71 | (e.g., 512 or 1024 or 2048).
72 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
73 | `BertModel`.
74 | initializer_range: The stddev of the truncated_normal_initializer for
75 | initializing all weight matrices.
76 | """
77 | self.vocab_size = vocab_size
78 | self.hidden_size = hidden_size
79 | self.num_hidden_layers = num_hidden_layers
80 | self.num_attention_heads = num_attention_heads
81 | self.hidden_act = hidden_act
82 | self.intermediate_size = intermediate_size
83 | self.hidden_dropout_prob = hidden_dropout_prob
84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
85 | self.max_position_embeddings = max_position_embeddings
86 | self.type_vocab_size = type_vocab_size
87 | self.initializer_range = initializer_range
88 |
89 | @classmethod
90 | def from_dict(cls, json_object):
91 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
92 | config = BertConfig(vocab_size=None)
93 | for (key, value) in six.iteritems(json_object):
94 | config.__dict__[key] = value
95 | return config
96 |
97 | @classmethod
98 | def from_json_file(cls, json_file):
99 | """Constructs a `BertConfig` from a json file of parameters."""
100 | with open(json_file, "r") as reader:
101 | text = reader.read()
102 | return cls.from_dict(json.loads(text))
103 |
104 | def to_dict(self):
105 | """Serializes this instance to a Python dictionary."""
106 | output = copy.deepcopy(self.__dict__)
107 | return output
108 |
109 | def to_json_string(self):
110 | """Serializes this instance to a JSON string."""
111 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
112 |
113 |
114 | class BERTLayerNorm(nn.Module):
115 | def __init__(self, config, variance_epsilon=1e-12):
116 | """Construct a layernorm module in the TF style (epsilon inside the square root).
117 | """
118 | super(BERTLayerNorm, self).__init__()
119 | self.gamma = nn.Parameter(torch.ones(config.hidden_size))
120 | self.beta = nn.Parameter(torch.zeros(config.hidden_size))
121 | self.variance_epsilon = variance_epsilon
122 |
123 | def forward(self, x):
124 | u = x.mean(-1, keepdim=True)
125 | s = (x - u).pow(2).mean(-1, keepdim=True)
126 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
127 | return self.gamma * x + self.beta
128 |
129 |
130 | class BERTEmbeddings(nn.Module):
131 | def __init__(self, config):
132 | super(BERTEmbeddings, self).__init__()
133 | """Construct the embedding module from word, position and token_type embeddings.
134 | """
135 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
136 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
137 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
138 |
139 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
140 | # any TensorFlow checkpoint file
141 | self.LayerNorm = BERTLayerNorm(config)
142 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
143 |
144 | def forward(self, input_ids, token_type_ids=None):
145 | seq_length = input_ids.size(1)
146 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
147 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
148 | if token_type_ids is None:
149 | token_type_ids = torch.zeros_like(input_ids)
150 |
151 | words_embeddings = self.word_embeddings(input_ids)
152 | position_embeddings = self.position_embeddings(position_ids)
153 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
154 |
155 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
156 | embeddings = self.LayerNorm(embeddings)
157 | embeddings = self.dropout(embeddings)
158 | return embeddings
159 |
160 |
161 | class BERTSelfAttention(nn.Module):
162 | def __init__(self, config):
163 | super(BERTSelfAttention, self).__init__()
164 | if config.hidden_size % config.num_attention_heads != 0:
165 | raise ValueError(
166 | "The hidden size (%d) is not a multiple of the number of attention "
167 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
168 | self.num_attention_heads = config.num_attention_heads
169 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
170 | self.all_head_size = self.num_attention_heads * self.attention_head_size
171 |
172 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
173 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
174 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
175 |
176 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
177 |
178 | def transpose_for_scores(self, x):
179 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
180 | x = x.view(*new_x_shape)
181 | return x.permute(0, 2, 1, 3)
182 |
183 | def forward(self, hidden_states, attention_mask):
184 | mixed_query_layer = self.query(hidden_states) # [N, L, H]
185 | mixed_key_layer = self.key(hidden_states)
186 | mixed_value_layer = self.value(hidden_states)
187 |
188 | query_layer = self.transpose_for_scores(mixed_query_layer) # [N, K, L, H//K]
189 | key_layer = self.transpose_for_scores(mixed_key_layer)
190 | value_layer = self.transpose_for_scores(mixed_value_layer)
191 |
192 | # Take the dot product between "query" and "key" to get the raw attention scores.
193 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [N, K, L, L]
194 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
195 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
196 | attention_scores = attention_scores + attention_mask
197 |
198 | # Normalize the attention scores to probabilities.
199 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
200 |
201 | # This is actually dropping out entire tokens to attend to, which might
202 | # seem a bit unusual, but is taken from the original Transformer paper.
203 | attention_probs = self.dropout(attention_probs)
204 |
205 | context_layer = torch.matmul(attention_probs, value_layer) # [N, K, L, H//K]
206 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [N, L, K, H//K]
207 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
208 | context_layer = context_layer.view(*new_context_layer_shape) # [N, L, H]
209 | return context_layer
210 |
211 |
212 | class BERTSelfOutput(nn.Module):
213 | def __init__(self, config):
214 | super(BERTSelfOutput, self).__init__()
215 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
216 | self.LayerNorm = BERTLayerNorm(config)
217 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
218 |
219 | def forward(self, hidden_states, input_tensor):
220 | hidden_states = self.dense(hidden_states)
221 | hidden_states = self.dropout(hidden_states)
222 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
223 | return hidden_states
224 |
225 |
226 | class BERTAttention(nn.Module):
227 | def __init__(self, config):
228 | super(BERTAttention, self).__init__()
229 | self.self = BERTSelfAttention(config)
230 | self.output = BERTSelfOutput(config)
231 |
232 | def forward(self, input_tensor, attention_mask):
233 | self_output = self.self(input_tensor, attention_mask)
234 | attention_output = self.output(self_output, input_tensor)
235 | return attention_output
236 |
237 |
238 | class BERTIntermediate(nn.Module):
239 | def __init__(self, config):
240 | super(BERTIntermediate, self).__init__()
241 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
242 | self.intermediate_act_fn = gelu
243 |
244 | def forward(self, hidden_states):
245 | hidden_states = self.dense(hidden_states)
246 | hidden_states = self.intermediate_act_fn(hidden_states)
247 | return hidden_states
248 |
249 |
250 | class BERTOutput(nn.Module):
251 | def __init__(self, config):
252 | super(BERTOutput, self).__init__()
253 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
254 | self.LayerNorm = BERTLayerNorm(config)
255 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
256 |
257 | def forward(self, hidden_states, input_tensor):
258 | hidden_states = self.dense(hidden_states)
259 | hidden_states = self.dropout(hidden_states)
260 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
261 | return hidden_states
262 |
263 |
264 | class BERTLayer(nn.Module):
265 | def __init__(self, config):
266 | super(BERTLayer, self).__init__()
267 | self.attention = BERTAttention(config)
268 | self.intermediate = BERTIntermediate(config)
269 | self.output = BERTOutput(config)
270 |
271 | def forward(self, hidden_states, attention_mask):
272 | attention_output = self.attention(hidden_states, attention_mask)
273 | intermediate_output = self.intermediate(attention_output)
274 | layer_output = self.output(intermediate_output, attention_output)
275 | return layer_output
276 |
277 |
278 | class BERTEncoder(nn.Module):
279 | def __init__(self, config):
280 | super(BERTEncoder, self).__init__()
281 | layer = BERTLayer(config)
282 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
283 |
284 | def forward(self, hidden_states, attention_mask):
285 | all_encoder_layers = []
286 | for layer_module in self.layer:
287 | hidden_states = layer_module(hidden_states, attention_mask)
288 | all_encoder_layers.append(hidden_states)
289 | return all_encoder_layers
290 |
291 |
292 | class BERTPooler(nn.Module):
293 | def __init__(self, config):
294 | super(BERTPooler, self).__init__()
295 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
296 | self.activation = nn.Tanh()
297 |
298 | def forward(self, hidden_states):
299 | # We "pool" the model by simply taking the hidden state corresponding
300 | # to the first token.
301 | first_token_tensor = hidden_states[:, 0]
302 | pooled_output = self.dense(first_token_tensor)
303 | pooled_output = self.activation(pooled_output)
304 | return pooled_output
305 |
306 |
307 | class BertModel(nn.Module):
308 | """BERT model ("Bidirectional Embedding Representations from a Transformer").
309 |
310 | Example usage:
311 | ```python
312 | # Already been converted into WordPiece token ids
313 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
314 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
315 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
316 |
317 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
318 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
319 |
320 | model = modeling.BertModel(config=config)
321 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
322 | ```
323 | """
324 | def __init__(self, config: BertConfig):
325 | """Constructor for BertModel.
326 |
327 | Args:
328 | config: `BertConfig` instance.
329 | """
330 | super(BertModel, self).__init__()
331 | self.embeddings = BERTEmbeddings(config)
332 | self.encoder = BERTEncoder(config)
333 | self.pooler = BERTPooler(config)
334 |
335 | def forward(self, input_ids, token_type_ids=None, attention_mask=None):
336 | if attention_mask is None:
337 | attention_mask = torch.ones_like(input_ids)
338 | if token_type_ids is None:
339 | token_type_ids = torch.zeros_like(input_ids)
340 |
341 | # We create a 3D attention mask from a 2D tensor mask.
342 | # Sizes are [batch_size, 1, 1, to_seq_length]
343 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
344 | # this attention mask is simpler than the triangular masking of causal attention
345 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
346 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
347 |
348 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
349 | # masked positions, this operation will create a tensor which is 0.0 for
350 | # positions we want to attend and -10000.0 for masked positions.
351 | # Since we are adding it to the raw scores before the softmax, this is
352 | # effectively the same as removing these entirely.
353 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
354 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
355 |
356 | embedding_output = self.embeddings(input_ids, token_type_ids)
357 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
358 | sequence_output = all_encoder_layers[-1]
359 | pooled_output = self.pooler(sequence_output)
360 | return all_encoder_layers, pooled_output
361 |
362 |
363 | class BertForSequenceClassification(nn.Module):
364 | """BERT model for classification.
365 | This module is composed of the BERT model with a linear layer on top of
366 | the pooled output.
367 |
368 | Example usage:
369 | ```python
370 | # Already been converted into WordPiece token ids
371 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
372 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
373 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
374 |
375 | config = BertConfig(vocab_size=32000, hidden_size=512,
376 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
377 |
378 | num_labels = 2
379 |
380 | model = BertForSequenceClassification(config, num_labels)
381 | logits = model(input_ids, token_type_ids, input_mask)
382 | ```
383 | """
384 | def __init__(self, config, num_labels):
385 | super(BertForSequenceClassification, self).__init__()
386 | self.bert = BertModel(config)
387 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
388 | self.classifier = nn.Linear(config.hidden_size, num_labels)
389 |
390 | def init_weights(module):
391 | if isinstance(module, (nn.Linear, nn.Embedding)):
392 | # Slightly different from the TF version which uses truncated_normal for initialization
393 | # cf https://github.com/pytorch/pytorch/pull/5617
394 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
395 | elif isinstance(module, BERTLayerNorm):
396 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
397 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
398 | if isinstance(module, nn.Linear):
399 | module.bias.data.zero_()
400 | self.apply(init_weights)
401 |
402 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
403 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
404 | pooled_output = self.dropout(pooled_output)
405 | logits = self.classifier(pooled_output)
406 |
407 | if labels is not None:
408 | loss_fct = CrossEntropyLoss()
409 | loss = loss_fct(logits, labels)
410 | return loss, logits
411 | else:
412 | return logits
413 |
414 |
415 | class BertForQuestionAnswering(nn.Module):
416 | """BERT model for Question Answering (span extraction).
417 | This module is composed of the BERT model with a linear layer on top of
418 | the sequence output that computes start_logits and end_logits
419 |
420 | Example usage:
421 | ```python
422 | # Already been converted into WordPiece token ids
423 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
424 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
425 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
426 |
427 | config = BertConfig(vocab_size=32000, hidden_size=512,
428 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
429 |
430 | model = BertForQuestionAnswering(config)
431 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
432 | ```
433 | """
434 | def __init__(self, config):
435 | super(BertForQuestionAnswering, self).__init__()
436 | self.bert = BertModel(config)
437 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
438 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
439 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
440 |
441 | def init_weights(module):
442 | if isinstance(module, (nn.Linear, nn.Embedding)):
443 | # Slightly different from the TF version which uses truncated_normal for initialization
444 | # cf https://github.com/pytorch/pytorch/pull/5617
445 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
446 | elif isinstance(module, BERTLayerNorm):
447 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
448 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
449 | if isinstance(module, nn.Linear):
450 | module.bias.data.zero_()
451 | self.apply(init_weights)
452 |
453 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
454 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
455 | sequence_output = all_encoder_layers[-1]
456 | logits = self.qa_outputs(sequence_output)
457 | start_logits, end_logits = logits.split(1, dim=-1)
458 | start_logits = start_logits.squeeze(-1)
459 | end_logits = end_logits.squeeze(-1)
460 |
461 | if start_positions is not None and end_positions is not None:
462 | # If we are on multi-GPU, splitting adds a dimension
463 | if len(start_positions.size()) > 1:
464 | start_positions = start_positions.squeeze(-1)
465 | if len(end_positions.size()) > 1:
466 | end_positions = end_positions.squeeze(-1)
467 | # sometimes the start/end positions are outside our model inputs; we ignore these terms
468 | ignored_index = start_logits.size(1)
469 | start_positions.clamp_(0, ignored_index)
470 | end_positions.clamp_(0, ignored_index)
471 |
472 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
473 | start_loss = loss_fct(start_logits, start_positions)
474 | end_loss = loss_fct(end_logits, end_positions)
475 | total_loss = (start_loss + end_loss) / 2
476 | return total_loss
477 | else:
478 | return start_logits, end_logits
--------------------------------------------------------------------------------
/bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.nn.utils import clip_grad_norm_
21 |
22 | def warmup_cosine(x, warmup=0.002):
23 | if x < warmup:
24 | return x/warmup
25 | return 0.5 * (1.0 + math.cos(math.pi * x))
26 |
27 | def warmup_constant(x, warmup=0.002):
28 | if x < warmup:
29 | return x/warmup
30 | return 1.0
31 |
32 | def warmup_linear(x, warmup=0.002):
33 | if x < warmup:
34 | return x/warmup
35 | return 1.0 - x
36 |
37 | SCHEDULES = {
38 | 'warmup_cosine':warmup_cosine,
39 | 'warmup_constant':warmup_constant,
40 | 'warmup_linear':warmup_linear,
41 | }
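42 | # Each schedule maps training progress x = step/t_total to an LR multiplier:
43 | # below the `warmup` fraction the multiplier ramps linearly from 0 to 1;
44 | # past it, warmup_linear decays toward 0, warmup_cosine follows a half
45 | # cosine, and warmup_constant holds the multiplier at 1.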
42 |
43 |
44 | class BERTAdam(Optimizer):
45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ).
46 | Params:
47 | lr: learning rate
48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
49 | t_total: total number of training steps for the learning
50 | rate schedule, -1 means constant learning rate. Default: -1
51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
52 | b1: Adams b1. Default: 0.9
53 | b2: Adams b2. Default: 0.999
54 | e: Adams epsilon. Default: 1e-6
55 | weight_decay_rate: Weight decay. Default: 0.01
56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
57 | """
58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
60 | max_grad_norm=1.0):
61 | if not lr >= 0.0:
62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
63 | if schedule not in SCHEDULES:
64 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
65 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
67 | if not 0.0 <= b1 < 1.0:
68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
69 | if not 0.0 <= b2 < 1.0:
70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
71 | if not e >= 0.0:
72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
75 | max_grad_norm=max_grad_norm)
76 | super(BERTAdam, self).__init__(params, defaults)
77 |
78 | def get_lr(self):
79 | lr = []
80 | for group in self.param_groups:
81 | for p in group['params']:
82 | state = self.state[p]
83 | if len(state) == 0:
84 | return [0]
85 | if group['t_total'] != -1:
86 | schedule_fct = SCHEDULES[group['schedule']]
87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
88 | else:
89 | lr_scheduled = group['lr']
90 | lr.append(lr_scheduled)
91 | return lr
92 |
93 | def step(self, closure=None):
94 | """Performs a single optimization step.
95 |
96 | Arguments:
97 | closure (callable, optional): A closure that reevaluates the model
98 | and returns the loss.
99 | """
100 | loss = None
101 | if closure is not None:
102 | loss = closure()
103 |
104 | for group in self.param_groups:
105 | for p in group['params']:
106 | if p.grad is None:
107 | continue
108 | grad = p.grad.data
109 | if grad.is_sparse:
110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
111 |
112 | state = self.state[p]
113 |
114 | # State initialization
115 | if len(state) == 0:
116 | state['step'] = 0
117 | # Exponential moving average of gradient values
118 | state['next_m'] = torch.zeros_like(p.data)
119 | # Exponential moving average of squared gradient values
120 | state['next_v'] = torch.zeros_like(p.data)
121 |
122 | next_m, next_v = state['next_m'], state['next_v']
123 | beta1, beta2 = group['b1'], group['b2']
124 |
125 | # Add grad clipping
126 | if group['max_grad_norm'] > 0:
127 | clip_grad_norm_(p, group['max_grad_norm'])
128 |
129 | # Decay the first and second moment running average coefficient
130 | # In-place operations to update the averages at the same time
131 | next_m.mul_(beta1).add_(1 - beta1, grad)
132 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
133 | update = next_m / (next_v.sqrt() + group['e'])
134 |
135 | # Just adding the square of the weights to the loss function is *not*
136 | # the correct way of using L2 regularization/weight decay with Adam,
137 | # since that will interact with the m and v parameters in strange ways.
138 | #
139 | # Instead we want to decay the weights in a manner that doesn't interact
140 | # with the m/v parameters. This is equivalent to adding the square
141 | # of the weights to the loss with plain (non-momentum) SGD.
142 | if group['weight_decay_rate'] > 0.0:
143 | update += group['weight_decay_rate'] * p.data
144 |
145 | if group['t_total'] != -1:
146 | schedule_fct = SCHEDULES[group['schedule']]
147 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
148 | else:
149 | lr_scheduled = group['lr']
150 |
151 | update_with_lr = lr_scheduled * update
152 | p.data.add_(-update_with_lr)
153 |
154 | state['step'] += 1
155 |
156 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
157 | # bias_correction1 = 1 - beta1 ** state['step']
158 | # bias_correction2 = 1 - beta2 ** state['step']
159 |
160 | return loss
161 |
--------------------------------------------------------------------------------
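A minimal usage sketch for the BERTAdam optimizer above. The tiny Linear model and `num_train_steps` are stand-ins (assumptions, not part of this repo), and the no-decay parameter split mirrors common BERT fine-tuning setups rather than anything defined in optimization.py:

import torch
from bert.optimization import BERTAdam

model = torch.nn.Linear(10, 2)   # stand-in for a BERT model
num_train_steps = 1000           # assumed training length
no_decay = ["bias"]              # parameters exempt from weight decay
grouped_params = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay_rate": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay_rate": 0.0},
]
optimizer = BERTAdam(grouped_params, lr=5e-5, warmup=0.1,
                     t_total=num_train_steps, schedule="warmup_linear")

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()       # warmup-scaled update; note step 0 is scaled by warmup_linear(0) == 0
optimizer.zero_grad()
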
/bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import six
24 |
25 |
26 | def convert_to_unicode(text):
27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
28 | if six.PY3:
29 | if isinstance(text, str):
30 | return text
31 | elif isinstance(text, bytes):
32 | return text.decode("utf-8", "ignore")
33 | else:
34 | raise ValueError("Unsupported string type: %s" % (type(text)))
35 | elif six.PY2:
36 | if isinstance(text, str):
37 | return text.decode("utf-8", "ignore")
38 | elif isinstance(text, unicode):
39 | return text
40 | else:
41 | raise ValueError("Unsupported string type: %s" % (type(text)))
42 | else:
43 | raise ValueError("Not running on Python2 or Python 3?")
44 |
45 |
46 | def printable_text(text):
47 | """Returns text encoded in a way suitable for print or `tf.logging`."""
48 |
49 | # These functions want `str` for both Python2 and Python3, but in one case
50 | # it's a Unicode string and in the other it's a byte string.
51 | if six.PY3:
52 | if isinstance(text, str):
53 | return text
54 | elif isinstance(text, bytes):
55 | return text.decode("utf-8", "ignore")
56 | else:
57 | raise ValueError("Unsupported string type: %s" % (type(text)))
58 | elif six.PY2:
59 | if isinstance(text, str):
60 | return text
61 | elif isinstance(text, unicode):
62 | return text.encode("utf-8")
63 | else:
64 | raise ValueError("Unsupported string type: %s" % (type(text)))
65 | else:
66 | raise ValueError("Not running on Python2 or Python 3?")
67 |
68 |
69 | def load_vocab(vocab_file):
70 | """Loads a vocabulary file into a dictionary."""
71 | vocab = collections.OrderedDict()
72 | index = 0
73 | with open(vocab_file, "r", encoding="utf-8") as reader:
74 | while True:
75 | token = convert_to_unicode(reader.readline())
76 | if not token:
77 | break
78 | token = token.strip()
79 | vocab[token] = index
80 | index += 1
81 | return vocab
82 |
83 |
84 | def convert_tokens_to_ids(vocab, tokens):
85 | """Converts a sequence of tokens into ids using the vocab."""
86 | ids = []
87 | for token in tokens:
88 | ids.append(vocab[token])
89 | return ids
90 |
91 |
92 | def whitespace_tokenize(text):
93 | """Runs basic whitespace cleaning and splitting on a peice of text."""
94 | text = text.strip()
95 | if not text:
96 | return []
97 | tokens = text.split()
98 | return tokens
99 |
100 |
101 | class FullTokenizer(object):
102 | """Runs end-to-end tokenziation."""
103 |
104 | def __init__(self, vocab_file, do_lower_case=True):
105 | self.vocab = load_vocab(vocab_file)
106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
108 |
109 | def tokenize(self, text):
110 | split_tokens = []
111 | for token in self.basic_tokenizer.tokenize(text):
112 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
113 | split_tokens.append(sub_token)
114 |
115 | return split_tokens
116 |
117 | def convert_tokens_to_ids(self, tokens):
118 | return convert_tokens_to_ids(self.vocab, tokens)
119 |
120 |
121 | class BasicTokenizer(object):
122 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
123 |
124 | def __init__(self, do_lower_case=True):
125 | """Constructs a BasicTokenizer.
126 |
127 | Args:
128 | do_lower_case: Whether to lower case the input.
129 | """
130 | self.do_lower_case = do_lower_case
131 |
132 | def tokenize(self, text):
133 | """Tokenizes a piece of text."""
134 | text = convert_to_unicode(text)
135 | text = self._clean_text(text)
136 | # This was added on November 1st, 2018 for the multilingual and Chinese
137 | # models. This is also applied to the English models now, but it doesn't
138 | # matter since the English models were not trained on any Chinese data
139 | # and generally don't have any Chinese data in them (there are Chinese
140 | # characters in the vocabulary because Wikipedia does have some Chinese
141 | # words in the English Wikipedia).
142 | text = self._tokenize_chinese_chars(text)
143 | orig_tokens = whitespace_tokenize(text)
144 | split_tokens = []
145 | for token in orig_tokens:
146 | if self.do_lower_case:
147 | token = token.lower()
148 | token = self._run_strip_accents(token)
149 | split_tokens.extend(self._run_split_on_punc(token))
150 |
151 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
152 | return output_tokens
153 |
154 | def _run_strip_accents(self, text):
155 | """Strips accents from a piece of text."""
156 | text = unicodedata.normalize("NFD", text)
157 | output = []
158 | for char in text:
159 | cat = unicodedata.category(char)
160 | if cat == "Mn":
161 | continue
162 | output.append(char)
163 | return "".join(output)
164 |
165 | def _run_split_on_punc(self, text):
166 | """Splits punctuation on a piece of text."""
167 | chars = list(text)
168 | i = 0
169 | start_new_word = True
170 | output = []
171 | while i < len(chars):
172 | char = chars[i]
173 | if _is_punctuation(char):
174 | output.append([char])
175 | start_new_word = True
176 | else:
177 | if start_new_word:
178 | output.append([])
179 | start_new_word = False
180 | output[-1].append(char)
181 | i += 1
182 |
183 | return ["".join(x) for x in output]
184 |
185 | def _tokenize_chinese_chars(self, text):
186 | """Adds whitespace around any CJK character."""
187 | output = []
188 | for char in text:
189 | cp = ord(char)
190 | if self._is_chinese_char(cp):
191 | output.append(" ")
192 | output.append(char)
193 | output.append(" ")
194 | else:
195 | output.append(char)
196 | return "".join(output)
197 |
198 | def _is_chinese_char(self, cp):
199 | """Checks whether CP is the codepoint of a CJK character."""
200 | # This defines a "chinese character" as anything in the CJK Unicode block:
201 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
202 | #
203 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
204 | # despite its name. The modern Korean Hangul alphabet is a different block,
205 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
206 | # space-separated words, so they are not treated specially and handled
207 | # like all of the other languages.
208 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
209 | (cp >= 0x3400 and cp <= 0x4DBF) or #
210 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
211 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
212 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
213 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
214 | (cp >= 0xF900 and cp <= 0xFAFF) or #
215 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
216 | return True
217 |
218 | return False
219 |
220 | def _clean_text(self, text):
221 | """Performs invalid character removal and whitespace cleanup on text."""
222 | output = []
223 | for char in text:
224 | cp = ord(char)
225 | if cp == 0 or cp == 0xfffd or _is_control(char):
226 | continue
227 | if _is_whitespace(char):
228 | output.append(" ")
229 | else:
230 | output.append(char)
231 | return "".join(output)
232 |
233 |
234 | class WordpieceTokenizer(object):
235 | """Runs WordPiece tokenization."""
236 |
237 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
238 | self.vocab = vocab
239 | self.unk_token = unk_token
240 | self.max_input_chars_per_word = max_input_chars_per_word
241 |
242 | def tokenize(self, text):
243 | """Tokenizes a piece of text into its word pieces.
244 |
245 | This uses a greedy longest-match-first algorithm to perform tokenization
246 | using the given vocabulary.
247 |
248 | For example:
249 | input = "unaffable"
250 | output = ["un", "##aff", "##able"]
251 |
252 | Args:
253 | text: A single token or whitespace separated tokens. This should have
254 | already been passed through `BasicTokenizer`.
255 |
256 | Returns:
257 | A list of wordpiece tokens.
258 | """
259 |
260 | text = convert_to_unicode(text)
261 |
262 | output_tokens = []
263 | for token in whitespace_tokenize(text):
264 | chars = list(token)
265 | if len(chars) > self.max_input_chars_per_word:
266 | output_tokens.append(self.unk_token)
267 | continue
268 |
269 | is_bad = False
270 | start = 0
271 | sub_tokens = []
272 | while start < len(chars):
273 | end = len(chars)
274 | cur_substr = None
275 | while start < end:
276 | substr = "".join(chars[start:end])
277 | if start > 0:
278 | substr = "##" + substr
279 | if substr in self.vocab:
280 | cur_substr = substr
281 | break
282 | end -= 1
283 | if cur_substr is None:
284 | is_bad = True
285 | break
286 | sub_tokens.append(cur_substr)
287 | start = end
288 |
289 | if is_bad:
290 | output_tokens.append(self.unk_token)
291 | else:
292 | output_tokens.extend(sub_tokens)
293 | return output_tokens
294 |
295 |
296 | def _is_whitespace(char):
297 | """Checks whether `chars` is a whitespace character."""
298 | # \t, \n, and \r are technically contorl characters but we treat them
299 | # as whitespace since they are generally considered as such.
300 | if char == " " or char == "\t" or char == "\n" or char == "\r":
301 | return True
302 | cat = unicodedata.category(char)
303 | if cat == "Zs":
304 | return True
305 | return False
306 |
307 |
308 | def _is_control(char):
309 | """Checks whether `chars` is a control character."""
310 | # These are technically control characters but we count them as whitespace
311 | # characters.
312 | if char == "\t" or char == "\n" or char == "\r":
313 | return False
314 | cat = unicodedata.category(char)
315 | if cat.startswith("C"):
316 | return True
317 | return False
318 |
319 |
320 | def _is_punctuation(char):
321 | """Checks whether `chars` is a punctuation character."""
322 | cp = ord(char)
323 | # We treat all non-letter/number ASCII as punctuation.
324 | # Characters such as "^", "$", and "`" are not in the Unicode
325 | # Punctuation class but we treat them as punctuation anyways, for
326 | # consistency.
327 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
328 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
329 | return True
330 | cat = unicodedata.category(char)
331 | if cat.startswith("P"):
332 | return True
333 | return False
334 |
--------------------------------------------------------------------------------
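A minimal sketch of the tokenization pipeline above, using a toy in-memory vocab instead of a real vocab.txt (FullTokenizer would normally build this vocab via load_vocab):

import collections
from bert.tokenization import BasicTokenizer, WordpieceTokenizer

vocab = collections.OrderedDict(
    (tok, i) for i, tok in enumerate(["[UNK]", "un", "##aff", "##able", "runs"]))
basic = BasicTokenizer(do_lower_case=True)
wordpiece = WordpieceTokenizer(vocab=vocab)

tokens = []
for token in basic.tokenize("Unaffable runs!"):
    tokens.extend(wordpiece.tokenize(token))
print(tokens)  # ['un', '##aff', '##able', 'runs', '[UNK]'] -- '!' is split off, then unmatched
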
/image/framework.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huminghao16/RE3QA/14faa386b519bed7c94ddff399afdb2c9967de44/image/framework.PNG
--------------------------------------------------------------------------------
/squad/convert_squad_open.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import numpy as np
4 | from os.path import relpath, join, exists, expanduser
5 | from sklearn.feature_extraction.text import TfidfVectorizer
6 | from sklearn.metrics import pairwise_distances
7 | from typing import List, TypeVar, Iterable
8 | from tqdm import tqdm
9 |
10 | import bert.tokenization as tokenization
11 | from triviaqa.evidence_corpus import MergeParagraphs
12 | from triviaqa.build_span_corpus import FastNormalizedAnswerDetector
13 | from squad.squad_document_utils import DocumentAndQuestion
14 |
15 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're',
16 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her',
17 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do',
18 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over',
19 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves',
20 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself',
21 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these',
22 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why',
23 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into',
24 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−',
25 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where',
26 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off',
27 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against',
28 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me',
29 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after',
30 | 'be'}
31 |
32 | class SquadOpenExample(object):
33 | def __init__(self, qas_id, question_text, answer_texts, doc_text):
34 | self.qas_id = qas_id
35 | self.question_text = question_text
36 | self.answer_texts = answer_texts
37 | self.doc_text = doc_text
38 |
39 | def __str__(self):
40 | return self.__repr__()
41 |
42 | def __repr__(self):
43 | s = ""
44 | s += "qas_id: %s" % self.qas_id
45 | s += ", question_text: %s" % self.question_text
46 | s += ", answer_texts: {}".format(self.answer_texts)
47 | s += ", doc_text: %s" % self.doc_text[:1000]
48 | return s
49 |
50 | def rank(tfidf, questions: List[str], paragraphs: List[str]):
51 | para_features = tfidf.fit_transform(paragraphs)
52 | q_features = tfidf.transform(questions)
53 | scores = pairwise_distances(q_features, para_features, "cosine")
54 | return scores
55 |
56 | def main():
57 | parse = argparse.ArgumentParser("Pre-tokenize the SQuAD open dev file")
58 | parse.add_argument("--input_file", type=str, default=join("data", "squad", "squad_dev_open.pkl"))
59 | # This is slow, using more processes is recommended
60 | parse.add_argument("--max_tokens", type=int, default=200, help="Number of maximal tokens in each merged paragraph")
61 | parse.add_argument("--n_to_select", type=int, default=30, help="Number of paragraphs to retrieve")
62 | parse.add_argument("--sort_passage", type=bool, default=True, help="Sort passage according to order")
63 | parse.add_argument("--debug", type=bool, default=False, help="Whether to run in debug mode")
64 | args = parse.parse_args()
65 |
66 | dev_examples = pickle.load(open(args.input_file, 'rb'))
67 |
68 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
69 | splitter = MergeParagraphs(args.max_tokens)
70 | tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words)
71 | detector = FastNormalizedAnswerDetector()
72 |
73 | ir_count, total_doc_length, pruned_doc_length = 0, 0, 0
74 | out = []
75 | for example_ix, example in tqdm(enumerate(dev_examples), total=len(dev_examples)):
76 | paras = [x for x in example.doc_text.split("\n") if len(x) > 0]
77 | paragraphs = [tokenizer.tokenize(x) for x in paras]
78 | merged_paragraphs = splitter.merge(paragraphs)
79 |
80 | scores = rank(tfidf, [example.question_text], [" ".join(x) for x in merged_paragraphs])
81 | para_scores = scores[0]
82 | para_ranks = np.argsort(para_scores)
83 | selection = [i for i in para_ranks[:args.n_to_select]]
84 |
85 | if args.sort_passage:
86 | selection = np.sort(selection)
87 |
88 | doc_tokens = []
89 | for idx in selection:
90 | current_para = merged_paragraphs[idx]
91 | doc_tokens += current_para
92 |
93 | tokenized_answers = [tokenizer.tokenize(x) for x in example.answer_texts]
94 | detector.set_question(tokenized_answers)
95 | if len(detector.any_found(doc_tokens)) > 0:
96 | ir_count += 1
97 |
98 | total_doc_length += sum(len(para) for para in merged_paragraphs)
99 | pruned_doc_length += len(doc_tokens)
100 |
101 | out.append(DocumentAndQuestion(example_ix, example.qas_id, example.question_text, doc_tokens,
102 | '', 0, 0, True))
103 | if args.debug and example_ix > 5:
104 | break
105 | print("Recall of answer existence in documents: {:.3f}".format(ir_count / len(out)))
106 | print("Average length of documents: {:.3f}".format(total_doc_length / len(out)))
107 | print("Average pruned length of documents: {:.3f}".format(pruned_doc_length / len(out)))
108 | output_file = join("data", "squad", "eval_open_{}paras_examples.pkl".format(args.n_to_select))
109 | pickle.dump(out, open(output_file, 'wb'))
110 |
111 | if __name__ == "__main__":
112 | main()
--------------------------------------------------------------------------------
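A minimal, self-contained sketch of the `rank` step above (the stop-word list is omitted here; lower cosine *distance* means a better match, so `np.argsort` yields best-first order):

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import pairwise_distances

tfidf = TfidfVectorizer(strip_accents="unicode")
paragraphs = ["the eiffel tower is in paris", "bert is a language model"]
questions = ["where is the eiffel tower"]

para_features = tfidf.fit_transform(paragraphs)
q_features = tfidf.transform(questions)
scores = pairwise_distances(q_features, para_features, "cosine")[0]
print(np.argsort(scores))  # [0 1] -- the Eiffel-Tower paragraph ranks first
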
/squad/squad_evaluate.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. [Changed name for external importing]"""
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import sys
9 |
10 |
11 | def span_len(span):
12 | return span[1] - span[0]
13 |
14 | def span_overlap(s1, s2):
15 | start = max(s1[0], s2[0])
16 | stop = min(s1[1], s2[1])
17 | if stop > start:
18 | return start, stop
19 | return None
20 |
21 | def span_prec(true_span, pred_span):
22 | overlap = span_overlap(true_span, pred_span)
23 | if overlap is None:
24 | return 0.
25 | return span_len(overlap) / span_len(pred_span)
26 |
27 | def span_recall(true_span, pred_span):
28 | overlap = span_overlap(true_span, pred_span)
29 | if overlap is None:
30 | return 0.
31 | return span_len(overlap) / span_len(true_span)
32 |
33 | def span_f1(true_span, pred_span):
34 | p = span_prec(true_span, pred_span)
35 | r = span_recall(true_span, pred_span)
36 | if p == 0 or r == 0:
37 | return 0.0
38 | return 2. * p * r / (p + r)
39 |
40 |
41 | def normalize_answer(s):
42 | """Lower text and remove punctuation, articles and extra whitespace."""
43 | def remove_articles(text):
44 | return re.sub(r'\b(a|an|the)\b', ' ', text)
45 |
46 | def white_space_fix(text):
47 | return ' '.join(text.split())
48 |
49 | def remove_punc(text):
50 | exclude = set(string.punctuation)
51 | return ''.join(ch for ch in text if ch not in exclude)
52 |
53 | def lower(text):
54 | return text.lower()
55 |
56 | return white_space_fix(remove_articles(remove_punc(lower(s))))
57 |
58 |
59 | def f1_score(prediction, ground_truth):
60 | prediction_tokens = normalize_answer(prediction).split()
61 | ground_truth_tokens = normalize_answer(ground_truth).split()
62 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
63 | num_same = sum(common.values())
64 | if num_same == 0:
65 | return 0
66 | precision = 1.0 * num_same / len(prediction_tokens)
67 | recall = 1.0 * num_same / len(ground_truth_tokens)
68 | f1 = (2 * precision * recall) / (precision + recall)
69 | return f1
70 |
71 |
72 | def exact_match_score(prediction, ground_truth):
73 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
74 |
75 |
76 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
77 | scores_for_ground_truths = []
78 | for ground_truth in ground_truths:
79 | score = metric_fn(prediction, ground_truth)
80 | scores_for_ground_truths.append(score)
81 | return max(scores_for_ground_truths)
82 |
83 |
84 | def evaluate(dataset, predictions):
85 | f1 = exact_match = total = 0
86 | missing_count = 0
87 | for article in dataset:
88 | for paragraph in article['paragraphs']:
89 | for qa in paragraph['qas']:
90 | total += 1
91 | if qa['id'] not in predictions:
92 | missing_count += 1
93 | # message = 'Unanswered question ' + qa['id'] + \
94 | # ' will receive score 0.'
95 | # print(message, file=sys.stderr)
96 | continue
97 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
98 | prediction = predictions[qa['id']]
99 | exact_match += metric_max_over_ground_truths(
100 | exact_match_score, prediction, ground_truths)
101 | f1 += metric_max_over_ground_truths(
102 | f1_score, prediction, ground_truths)
103 |
104 | exact_match = 100.0 * exact_match / (total-missing_count)
105 | f1 = 100.0 * f1 / (total-missing_count)
106 | print("missing prediction on %d examples" % (missing_count))
107 | return {'exact_match': exact_match, 'f1': f1}
108 |
109 |
110 | def merge_eval(main_eval, new_eval):
111 | for k in new_eval:
112 | main_eval[k] = new_eval[k]
113 |
114 |
115 | if __name__ == '__main__':
116 | expected_version = '1.1'
117 | parser = argparse.ArgumentParser(
118 | description='Evaluation for SQuAD ' + expected_version)
119 | parser.add_argument('dataset_file', help='Dataset file')
120 | parser.add_argument('prediction_file', help='Prediction File')
121 | args = parser.parse_args()
122 | with open(args.dataset_file) as dataset_file:
123 | dataset_json = json.load(dataset_file)
124 | # if (dataset_json['version'] != expected_version):
125 | # print('Evaluation expects v-' + expected_version +
126 | # ', but got dataset with v-' + dataset_json['version'],
127 | # file=sys.stderr)
128 | dataset = dataset_json['data']
129 | with open(args.prediction_file) as prediction_file:
130 | predictions = json.load(prediction_file)
131 | print(json.dumps(evaluate(dataset, predictions)))
132 |
133 | # prediction = '1854–1855'
134 | # ground_truths = ['1854']
135 | # print(metric_max_over_ground_truths(
136 | # f1_score, prediction, ground_truths))
137 |
--------------------------------------------------------------------------------
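A small sketch of the metric helpers above (assumes the repo root is on PYTHONPATH); normalize_answer lower-cases and strips articles and punctuation before comparison:

from squad.squad_evaluate import exact_match_score, f1_score, metric_max_over_ground_truths

# "The" is an article and is stripped, so this is an exact match:
print(metric_max_over_ground_truths(exact_match_score, "The Eiffel Tower", ["eiffel tower"]))  # True
# Token-level F1: precision 1/3, recall 1 -> 0.5
print(metric_max_over_ground_truths(f1_score, "in Paris, France", ["Paris"]))  # 0.5
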
/squad/squad_open_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import math
4 | import collections
5 | import numpy as np
6 | from sklearn.feature_extraction.text import TfidfVectorizer
7 | from sklearn.metrics import pairwise_distances
8 | from typing import List, TypeVar, Iterable
9 |
10 | import bert.tokenization as tokenization
11 |
12 | T = TypeVar('T')
13 |
14 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're',
15 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her',
16 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do',
17 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over',
18 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves',
19 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself',
20 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these',
21 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why',
22 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into',
23 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−',
24 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where',
25 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off',
26 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against',
27 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me',
28 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after',
29 | 'be'}
30 |
31 |
32 | def flatten_iterable(listoflists: Iterable[Iterable[T]]) -> List[T]:
33 | return [item for sublist in listoflists for item in sublist]
34 |
35 |
36 | class Question(object):
37 | def __init__(self,
38 | qas_id,
39 | doc_index,
40 | para_index,
41 | question_text,
42 | answer_texts=None):
43 | self.qas_id = qas_id
44 | self.doc_index = doc_index
45 | self.para_index = para_index
46 | self.question_text = question_text
47 | self.answer_texts = answer_texts
48 |
49 | def __str__(self):
50 | return self.__repr__()
51 |
52 | def __repr__(self):
53 | s = ""
54 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
55 | s += "doc_index: %d" % (self.doc_index)
56 | s += "para_index: %d" % (self.para_index)
57 | s += ", question_text: %s" % (
58 | tokenization.printable_text(self.question_text))
59 | if self.answer_texts is not None:
60 | s += ", answer_texts: ".format(self.answer_texts)
61 | return s
62 |
63 |
64 | class Paragraph(object):
65 | def __init__(self,
66 | paragraph_id,
67 | paragraph_text):
68 | self.paragraph_id = paragraph_id
69 | self.paragraph_text = paragraph_text
70 |
71 | def __str__(self):
72 | return self.__repr__()
73 |
74 | def __repr__(self):
75 | s = ""
76 | s += "paragraph_id: %s" % (self.paragraph_id)
77 | return s
78 |
79 |
80 | class Document(object):
81 | def __init__(self, document_id: str, paragraphs: List[Paragraph]):
82 | self.document_id = document_id
83 | self.paragraphs = paragraphs
84 |
85 | def __str__(self):
86 | return self.__repr__()
87 |
88 | def __repr__(self):
89 | s = ""
90 | s += "document_id: %s" % (self.document_id)
91 | s += ", paragraph_num: %s" % (len(self.paragraphs))
92 | return s
93 |
94 | def get_doc_text(self):
95 | all_doc_text = ''
96 | for idx, para in enumerate(self.paragraphs):
97 | if idx == 0:
98 | all_doc_text += para.paragraph_text
99 | else:
100 | all_doc_text += ' '
101 | all_doc_text += para.paragraph_text
102 | return all_doc_text
103 |
104 | def tfidf_rank(questions: List[str], documents: List[str]):
105 | tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words)
106 | doc_features = tfidf.fit_transform(documents)
107 | q_features = tfidf.transform(questions)
108 | scores = pairwise_distances(q_features, doc_features, "cosine")
109 | return scores
110 |
111 | def read_squad_open_examples(input_file, n_to_select, is_training, debug=False):
112 | """Read a SQuAD json file into a list of SquadExample."""
113 | with open(input_file, "r") as reader:
114 | input_data = json.load(reader)["data"]
115 |
116 | documents = []
117 | questions = []
118 | for article_ix, article in enumerate(input_data):
119 | document_id = "%s-%d" % (article['title'], article_ix)
120 | paragraphs = []
121 | for paragraph_ix, paragraph in enumerate(article["paragraphs"]):
122 | paragraph_text = paragraph["context"]
123 | paragraphs.append(Paragraph(paragraph_ix, paragraph_text))
124 |
125 | for qa in paragraph["qas"]:
126 | qas_id = qa["id"]
127 | question_text = qa["question"]
128 | answer_texts = []
129 | for answer in qa["answers"]:
130 | answer_texts.append(answer["text"])
131 | questions.append(Question(qas_id, article_ix, paragraph_ix, question_text, answer_texts))
132 |
133 | documents.append(Document(document_id, paragraphs))
134 | if (article_ix+1) == 10 and debug:
135 | break
136 |
137 | scores = tfidf_rank([x.question_text for x in questions], [x.get_doc_text() for x in documents]) # [1177, 3]
138 |
139 | ir_count = 0
140 | for que_ix, question in enumerate(questions):
141 | doc_scores = scores[que_ix]
142 | doc_ranks = np.argsort(doc_scores)
143 | selection = [i for i in doc_ranks[:n_to_select]]
144 | rank = [i for i in np.arange(n_to_select)]
145 |
146 | if question.doc_index in selection:
147 | ir_count += 1
148 |
149 | if is_training and question.doc_index not in selection:
150 | selection[-1] = question.doc_index
151 |
152 | print("Retrieve {} questions from {} documents".format(len(questions), len(documents)))
153 | print("Recall of answer existence in documents: {:.3f}".format(ir_count / len(questions)))
154 |
155 | if __name__ == "__main__":
156 |     read_squad_open_examples("../data/squad/dev-v1.1.json", 5, False, False)
--------------------------------------------------------------------------------
/triviaqa/ablate_triviaqa_unfiltered.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | from triviaqa.build_span_corpus import TriviaQaUnfilteredDataset, TriviaQaSampleUnfilteredDataset
6 | from triviaqa.preprocessed_corpus import preprocess_par, ExtractMultiParagraphsPerQuestion, TopTfIdf
7 |
8 |
9 | def main():
10 | parser = argparse.ArgumentParser(description='Preprocess TriviaQA unfiltered (open-domain) data')
11 | parser.add_argument("--debug", default=False, action='store_true', help="Whether to run in debug mode.")
12 | parser.add_argument("--data_dir", default="data/triviaqa/unfiltered", type=str, help="Triviaqa wiki data dir")
13 | parser.add_argument('--n_processes', type=int, default=1,
14 | help="Number of processes to use when preprocessing the data "
15 | "(i.e., selecting which paragraphs to train on)")
16 | parser.add_argument('--chunk_size', type=int, default=1000,
17 | help="Size of one chunk")
18 | parser.add_argument('--n_para_train', type=int, default=2,
19 | help="Num of selected paragraphs during training")
20 | parser.add_argument('--n_para_dev', type=int, default=4,
21 | help="Num of selected paragraphs during evaluation")
22 | parser.add_argument('--n_para_test', type=int, default=4,
23 | help="Num of selected paragraphs during testing")
24 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to process train set.")
25 | parser.add_argument("--do_dev", default=False, action='store_true', help="Whether to process dev set.")
26 | parser.add_argument("--do_test", default=False, action='store_true', help="Whether to process test set.")
27 | args = parser.parse_args()
28 |
29 | if args.debug:
30 | corpus = TriviaQaSampleUnfilteredDataset()
31 | else:
32 | corpus = TriviaQaUnfilteredDataset()
33 |
34 | if args.do_train:
35 | train_questions = corpus.get_train() # List[TriviaQaQuestion]
36 | train_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_train, is_training=True),
37 | intern=True, is_training=True)
38 | _train = preprocess_par(train_questions, corpus.evidence, train_preprocesser, args.n_processes, args.chunk_size,
39 | "train")
40 | print("Recall of answer existence in {} set: {:.3f}".format("train", _train.ir_count / len(_train.data)))
41 | print("Average number of documents in {} set: {:.3f}".format("train", _train.total_doc_num / len(_train.data)))
42 | print("Average length of documents in {} set: {:.3f}".format("train", _train.total_doc_length / len(_train.data)))
43 | print("Average pruned length of documents in {} set: {:.3f}".format("train", _train.pruned_doc_length / len(_train.data)))
44 | print("Number of examples: {}".format(len(_train.data)))
45 |
46 | train_examples_path = os.path.join(args.data_dir, "train_{}paras_examples.pkl".format(args.n_para_train))
47 | pickle.dump(_train.data, open(train_examples_path, 'wb'))
48 |
49 | if args.do_dev:
50 | dev_questions = corpus.get_dev()
51 | dev_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_dev, is_training=False),
52 | intern=True, is_training=False)
53 | _dev = preprocess_par(dev_questions, corpus.evidence, dev_preprocesser, args.n_processes, args.chunk_size, "dev")
54 | print("Recall of answer existence in {} set: {:.3f}".format("dev", _dev.ir_count / len(_dev.data)))
55 | print("Average number of documents in {} set: {:.3f}".format("dev", _dev.total_doc_num / len(_dev.data)))
56 | print("Average length of documents in {} set: {:.3f}".format("dev", _dev.total_doc_length / len(_dev.data)))
57 | print("Average pruned length of documents in {} set: {:.3f}".format("dev", _dev.pruned_doc_length / len(_dev.data)))
58 | print("Number of examples: {}".format(len(_dev.data)))
59 |
60 | dev_examples_path = os.path.join(args.data_dir, "dev_{}paras_examples.pkl".format(args.n_para_dev))
61 | pickle.dump(_dev.data, open(dev_examples_path, 'wb'))
62 |
63 | if args.do_test:
64 | test_questions = corpus.get_test()
65 | test_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_test, is_training=False),
66 | intern=True, is_training=False)
67 | _test = preprocess_par(test_questions, corpus.evidence, test_preprocesser, args.n_processes,
68 | args.chunk_size, "test")
69 | print("Recall of answer existence in {} set: {:.3f}".format("test", _test.ir_count / len(_test.data)))
70 | print("Average number of documents in {} set: {:.3f}".format("test", _test.total_doc_num / len(_test.data)))
71 | print("Average length of documents in {} set: {:.3f}".format("test", _test.total_doc_length / len(_test.data)))
72 | print("Average pruned length of documents in {} set: {:.3f}".format("test", _test.pruned_doc_length / len(_test.data)))
73 | print("Number of examples: {}".format(len(_test.data)))
74 |
75 | test_examples_path = os.path.join(args.data_dir, "test_{}paras_examples.pkl".format(args.n_para_test))
76 | pickle.dump(_test.data, open(test_examples_path, 'wb'))
77 |
78 |
79 | if __name__ == "__main__":
80 | main()
--------------------------------------------------------------------------------
/triviaqa/ablate_triviaqa_wiki.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | from triviaqa.build_span_corpus import TriviaQaWikiDataset, TriviaQaSampleWikiDataset
6 | from triviaqa.preprocessed_corpus import preprocess_par, ExtractMultiParagraphsPerQuestion, TopTfIdf
7 |
8 |
9 | def main():
10 | parser = argparse.ArgumentParser(description='Preprocess TriviaQA wiki data')
11 | parser.add_argument("--debug", default=False, action='store_true', help="Whether to run in debug mode.")
12 | parser.add_argument("--data_dir", default="data/triviaqa/wiki", type=str, help="Triviaqa wiki data dir")
13 | parser.add_argument('--n_processes', type=int, default=1,
14 | help="Number of processes to use when preprocessing the data "
15 | "(i.e., selecting which paragraphs to train on)")
16 | parser.add_argument('--chunk_size', type=int, default=1000,
17 | help="Size of one chunk")
18 | parser.add_argument('--n_para_train', type=int, default=2,
19 | help="Num of selected paragraphs during training")
20 | parser.add_argument('--n_para_dev', type=int, default=4,
21 | help="Num of selected paragraphs during evaluation")
22 | parser.add_argument('--n_para_verified', type=int, default=4,
23 | help="Num of selected paragraphs for the verified set")
24 | parser.add_argument('--n_para_test', type=int, default=4,
25 | help="Num of selected paragraphs during testing")
26 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to process train set.")
27 | parser.add_argument("--do_dev", default=False, action='store_true', help="Whether to process dev set.")
28 | parser.add_argument("--do_verified", default=False, action='store_true', help="Whether to process verified set.")
29 | parser.add_argument("--do_test", default=False, action='store_true', help="Whether to process test set.")
30 | args = parser.parse_args()
31 |
32 | if args.debug:
33 | corpus = TriviaQaSampleWikiDataset()
34 | else:
35 | corpus = TriviaQaWikiDataset()
36 |
37 | if args.do_train:
38 | train_questions = corpus.get_train() # List[TriviaQaQuestion]
39 | train_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_train, is_training=True),
40 | intern=True, is_training=True)
41 | _train = preprocess_par(train_questions, corpus.evidence, train_preprocesser, args.n_processes, args.chunk_size,
42 | "train")
43 | print("Recall of answer existence in {} set: {:.3f}".format("train", _train.ir_count / len(_train.data)))
44 | print("Average number of documents in {} set: {:.3f}".format("train", _train.total_doc_num / len(_train.data)))
45 | print("Average length of documents in {} set: {:.3f}".format("train", _train.total_doc_length / len(_train.data)))
46 | print("Average pruned length of documents in {} set: {:.3f}".format("train", _train.pruned_doc_length / len(_train.data)))
47 | print("Number of examples: {}".format(len(_train.data)))
48 |
49 | train_examples_path = os.path.join(args.data_dir, "train_{}paras_examples.pkl".format(args.n_para_train))
50 | pickle.dump(_train.data, open(train_examples_path, 'wb'))
51 |
52 | if args.do_dev:
53 | dev_questions = corpus.get_dev()
54 | dev_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_dev, is_training=False),
55 | intern=True, is_training=False)
56 | _dev = preprocess_par(dev_questions, corpus.evidence, dev_preprocesser, args.n_processes, args.chunk_size, "dev")
57 | print("Recall of answer existence in {} set: {:.3f}".format("dev", _dev.ir_count / len(_dev.data)))
58 | print("Average number of documents in {} set: {:.3f}".format("dev", _dev.total_doc_num / len(_dev.data)))
59 | print("Average length of documents in {} set: {:.3f}".format("dev", _dev.total_doc_length / len(_dev.data)))
60 | print("Average pruned length of documents in {} set: {:.3f}".format("dev", _dev.pruned_doc_length / len(_dev.data)))
61 | print("Number of examples: {}".format(len(_dev.data)))
62 |
63 | dev_examples_path = os.path.join(args.data_dir, "dev_{}paras_examples.pkl".format(args.n_para_dev))
64 | pickle.dump(_dev.data, open(dev_examples_path, 'wb'))
65 |
66 | if args.do_verified:
67 | verified_questions = corpus.get_verified()
68 | verified_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_verified, is_training=False),
69 | intern=True, is_training=False)
70 | _verified = preprocess_par(verified_questions, corpus.evidence, verified_preprocesser, args.n_processes,
71 | args.chunk_size, "verified")
72 | print("Recall of answer existence in {} set: {:.3f}".format("verified", _verified.ir_count / len(_verified.data)))
73 | print("Average number of documents in {} set: {:.3f}".format("verified", _verified.total_doc_num / len(_verified.data)))
74 | print("Average length of documents in {} set: {:.3f}".format("verified", _verified.total_doc_length / len(_verified.data)))
75 | print("Average pruned length of documents in {} set: {:.3f}".format("verified", _verified.pruned_doc_length / len(_verified.data)))
76 | print("Number of examples: {}".format(len(_verified.data)))
77 |
78 | verified_examples_path = os.path.join(args.data_dir, "verified_{}paras_examples.pkl".format(args.n_para_verified))
79 | pickle.dump(_verified.data, open(verified_examples_path, 'wb'))
80 |
81 | if args.do_test:
82 | test_questions = corpus.get_test()
83 | test_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_test, is_training=False),
84 | intern=True, is_training=False)
85 | _test = preprocess_par(test_questions, corpus.evidence, test_preprocesser, args.n_processes,
86 | args.chunk_size, "test")
87 | print("Recall of answer existence in {} set: {:.3f}".format("test", _test.ir_count / len(_test.data)))
88 | print("Average number of documents in {} set: {:.3f}".format("test", _test.total_doc_num / len(_test.data)))
89 | print("Average length of documents in {} set: {:.3f}".format("test", _test.total_doc_length / len(_test.data)))
90 | print("Average pruned length of documents in {} set: {:.3f}".format("test", _test.pruned_doc_length / len(_test.data)))
91 | print("Number of examples: {}".format(len(_test.data)))
92 |
93 | test_examples_path = os.path.join(args.data_dir, "test_{}paras_examples.pkl".format(args.n_para_test))
94 | pickle.dump(_test.data, open(test_examples_path, 'wb'))
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
--------------------------------------------------------------------------------
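Both ablate scripts are preprocessing drivers rather than trainers: they select paragraphs per question and dump example pickles under --data_dir. A typical invocation (an assumed example; it presumes the span corpus has already been built by triviaqa/build_span_corpus.py) might be:

    python -m triviaqa.ablate_triviaqa_wiki --do_train --do_dev --n_processes 8
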
/triviaqa/answer_detection.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 | from typing import List
7 |
8 | from triviaqa.read_data import TriviaQaQuestion
9 | from triviaqa.triviaqa_eval import normalize_answer, f1_score
10 | from triviaqa.utils import flatten_iterable, split
11 |
12 |
13 | class FastNormalizedAnswerDetector(object):
14 | """ almost twice as fast and very,very close to NormalizedAnswerDetector's output """
15 |
16 | def __init__(self):
17 | # These come from the TrivaQA official evaluation script
18 | self.skip = {"a", "an", "the", ""}
19 | self.strip = string.punctuation + "".join([u"‘", u"’", u"´", u"`", "_"])
20 |
21 | self.answer_tokens = None
22 |
23 | def set_question(self, normalized_aliases):
24 | self.answer_tokens = normalized_aliases
25 |
26 | def any_found(self, para): # List[str]
27 | # Normalize the paragraph
28 | words = [w.lower().strip(self.strip) for w in para]
29 | occurrences = []
30 | for answer_ix, answer in enumerate(self.answer_tokens):
31 | # Locations where the first word occurs
32 | if len(answer) == 0:
33 | continue
34 | word_starts = [i for i, w in enumerate(words) if answer[0] == w] # [12, 50, 63 ...]
35 | n_tokens = len(answer) # 2
36 |
37 | # Advance forward until we find all the words, skipping over articles
38 | for start in word_starts:
39 | end = start + 1
40 | ans_token = 1
41 | while ans_token < n_tokens and end < len(words):
42 | next_word = words[end]
43 | if answer[ans_token] == next_word:
44 | ans_token += 1
45 | end += 1
46 | elif next_word in self.skip:
47 | end += 1
48 | else:
49 | break
50 | if n_tokens == ans_token:
51 | occurrences.append((start, end))
52 | return list(set(occurrences))
53 |
54 |
55 | def compute_answer_spans(questions: List[TriviaQaQuestion], corpus, tokenizer,
56 | detector):
57 |
58 | for i, q in enumerate(questions):
59 | if i % 500 == 0:
60 | print("Completed question %d of %d (%.3f)" % (i, len(questions), i/len(questions)))
61 | q.question = tokenizer.tokenize(q.question)
62 | if q.answer is None:
63 | continue
64 | tokenized_aliases = [tokenizer.tokenize(x) for x in q.answer.all_answers]
65 | if len(tokenized_aliases) == 0:
66 | raise ValueError()
67 | detector.set_question(tokenized_aliases)
68 | for doc in q.all_docs:
69 | text = corpus.get_document(doc.doc_id) # List[List[str]]
70 | if text is None:
71 | raise ValueError()
72 | spans = []
73 | offset = 0
74 | for para_ix, para in enumerate(text):
75 | for s, e in detector.any_found(para):
76 | spans.append((s+offset, e+offset-1)) # turn into inclusive span
77 | offset += len(para)
78 | if len(spans) == 0:
79 | spans = np.zeros((0, 2), dtype=np.int32)
80 | else:
81 | spans = np.array(spans, dtype=np.int32)
82 | doc.answer_spans = spans
83 |
84 |
85 | def _compute_answer_spans_chunk(questions, corpus, tokenizer, detector):
86 | compute_answer_spans(questions, corpus, tokenizer, detector)
87 | return questions
88 |
89 |
90 | def compute_answer_spans_par(questions: List[TriviaQaQuestion], corpus,
91 | tokenizer, detector, n_processes: int):
92 | if n_processes == 1:
93 | compute_answer_spans(questions, corpus, tokenizer, detector)
94 | return questions
95 | from multiprocessing import Pool
96 | with Pool(n_processes) as p:
97 | chunks = split(questions, n_processes)
98 | questions = flatten_iterable(p.starmap(_compute_answer_spans_chunk,
99 | [[c, corpus, tokenizer, detector] for c in chunks]))
100 | return questions
--------------------------------------------------------------------------------
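A minimal sketch of the detector above (assumes the repo root is on PYTHONPATH); aliases are given pre-tokenized, matching skips articles, and the returned end index is exclusive:

from triviaqa.answer_detection import FastNormalizedAnswerDetector

detector = FastNormalizedAnswerDetector()
detector.set_question([["eiffel", "tower"]])   # normalized, tokenized answer aliases
paragraph = "The Eiffel Tower is in Paris .".split()
print(detector.any_found(paragraph))           # [(1, 3)] -- exclusive end over 'Eiffel Tower'
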
/triviaqa/build_span_corpus.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pickle
4 | import unicodedata
5 | from itertools import islice
6 | from typing import List, Optional, Dict
7 | from os import mkdir
8 | from os.path import join, exists, expanduser
9 | import bert.tokenization as tokenization
10 | from triviaqa.configurable import Configurable
11 | from triviaqa.read_data import iter_trivia_question, TriviaQaQuestion
12 | from triviaqa.evidence_corpus import TriviaQaEvidenceCorpusTxt
13 | from triviaqa.answer_detection import compute_answer_spans_par, FastNormalizedAnswerDetector
14 |
15 | TRIVIA_QA = join(expanduser("~"), "data", "triviaqa")
16 | TRIVIA_QA_UNFILTERED = join(expanduser("~"), "data", "triviaqa-unfiltered")
17 |
18 |
19 | def build_dataset(name: str, tokenizer, train_files: Dict[str, str],
20 | answer_detector, n_process: int, prune_unmapped_docs=True,
21 | sample=None):
22 | out_dir = join("data", "triviaqa", name)
23 | if not exists(out_dir):
24 | mkdir(out_dir)
25 |
26 | file_map = {} # maps document_id -> filename
27 |
28 | for name, filename in train_files.items():
29 | print("Loading %s questions" % name)
30 | if sample is None:
31 | questions = list(iter_trivia_question(filename, file_map, False))
32 | else:
33 | if isinstance(sample, int):
34 | questions = list(islice(iter_trivia_question(filename, file_map, False), sample))
35 | elif isinstance(sample, dict):
36 | questions = list(islice(iter_trivia_question(filename, file_map, False), sample[name]))
37 | else:
38 | raise ValueError()
39 |
40 | if prune_unmapped_docs:
41 | for q in questions:
42 | if q.web_docs is not None:
43 | q.web_docs = [x for x in q.web_docs if x.doc_id in file_map]
44 | q.entity_docs = [x for x in q.entity_docs if x.doc_id in file_map]
45 |
46 | print("Adding answers for %s question" % name)
47 | corpus = TriviaQaEvidenceCorpusTxt(file_map)
48 | questions = compute_answer_spans_par(questions, corpus, tokenizer, answer_detector, n_process)
49 | for q in questions: # Sanity check, we should have answers for everything (even if of size 0)
50 | if q.answer is None:
51 | continue
52 | for doc in q.all_docs:
53 | if doc.doc_id in file_map:
54 | if doc.answer_spans is None:
55 | raise RuntimeError()
56 |
57 | print("Saving %s question" % name)
58 | with open(join(out_dir, name + ".pkl"), "wb") as f:
59 | pickle.dump(questions, f)
60 |
61 | print("Dumping file mapping")
62 | with open(join(out_dir, "file_map.json"), "w") as f:
63 | json.dump(file_map, f)
64 |
65 | print("Complete")
66 |
67 | class TriviaQaSpanCorpus(Configurable):
68 | def __init__(self, corpus_name):
69 | self.corpus_name = corpus_name # web-sample
70 | self.dir = join("data", "triviaqa", corpus_name)
71 | with open(join(self.dir, "file_map.json"), "r") as f:
72 | file_map = json.load(f)
73 | for k, v in file_map.items():
74 | file_map[k] = unicodedata.normalize("NFD", v)
75 | self.evidence = TriviaQaEvidenceCorpusTxt(file_map) # evidence_corpus.py
76 |
77 | def get_train(self) -> List[TriviaQaQuestion]:
78 | with open(join(self.dir, "train.pkl"), "rb") as f:
79 | return pickle.load(f)
80 |
81 | def get_dev(self) -> List[TriviaQaQuestion]:
82 | with open(join(self.dir, "dev.pkl"), "rb") as f:
83 | return pickle.load(f)
84 |
85 | def get_test(self) -> List[TriviaQaQuestion]:
86 | with open(join(self.dir, "test.pkl"), "rb") as f:
87 | return pickle.load(f)
88 |
89 | def get_verified(self) -> Optional[List[TriviaQaQuestion]]:
90 | verified_dir = join(self.dir, "verified.pkl")
91 | if not exists(verified_dir):
92 | return None
93 | with open(verified_dir, "rb") as f:
94 | return pickle.load(f)
95 |
96 | @property
97 | def name(self):
98 | return self.corpus_name
99 |
100 | class TriviaQaWebDataset(TriviaQaSpanCorpus):
101 | def __init__(self):
102 | super().__init__("web")
103 |
104 | class TriviaQaWikiDataset(TriviaQaSpanCorpus):
105 | def __init__(self):
106 | super().__init__("wiki")
107 |
108 | class TriviaQaUnfilteredDataset(TriviaQaSpanCorpus):
109 | def __init__(self):
110 | super().__init__("unfiltered")
111 |
112 | class TriviaQaSampleWebDataset(TriviaQaSpanCorpus):
113 | def __init__(self):
114 | super().__init__("web-sample")
115 |
116 | class TriviaQaSampleWikiDataset(TriviaQaSpanCorpus):
117 | def __init__(self):
118 | super().__init__("wiki-sample")
119 |
120 | class TriviaQaSampleUnfilteredDataset(TriviaQaSpanCorpus):
121 | def __init__(self):
122 | super().__init__("unfiltered-sample")
123 |
124 | def build_wiki_corpus(n_processes):
125 | build_dataset("wiki", tokenization.BasicTokenizer(do_lower_case=True),
126 | dict(
127 | verified=join(TRIVIA_QA, "qa", "verified-wikipedia-dev.json"),
128 | dev=join(TRIVIA_QA, "qa", "wikipedia-dev.json"),
129 | train=join(TRIVIA_QA, "qa", "wikipedia-train.json"),
130 | test=join(TRIVIA_QA, "qa", "wikipedia-test-without-answers.json")
131 | ),
132 | FastNormalizedAnswerDetector(), n_processes)
133 |
134 | def build_web_corpus(n_processes):
135 | build_dataset("web", tokenization.BasicTokenizer(do_lower_case=True),
136 | dict(
137 | verified=join(TRIVIA_QA, "qa", "verified-web-dev.json"),
138 | dev=join(TRIVIA_QA, "qa", "web-dev.json"),
139 | train=join(TRIVIA_QA, "qa", "web-train.json"),
140 | test=join(TRIVIA_QA, "qa", "web-test-without-answers.json")
141 | ),
142 | FastNormalizedAnswerDetector(), n_processes)
143 |
144 | def build_unfiltered_corpus(n_processes):
145 | build_dataset("unfiltered", tokenization.BasicTokenizer(do_lower_case=True),
146 | dict(
147 | dev=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-dev.json"),
148 | train=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-train.json"),
149 | test=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-test-without-answers.json")
150 | ),
151 | FastNormalizedAnswerDetector(), n_processes)
152 |
153 | def build_wiki_sample_corpus(n_processes):
154 | build_dataset("wiki-sample", tokenization.BasicTokenizer(do_lower_case=True),
155 | dict(
156 | verified=join(TRIVIA_QA, "qa", "verified-wikipedia-dev.json"),
157 | dev=join(TRIVIA_QA, "qa", "wikipedia-dev.json"),
158 | train=join(TRIVIA_QA, "qa", "wikipedia-train.json"),
159 | test=join(TRIVIA_QA, "qa", "wikipedia-test-without-answers.json")
160 | ),
161 | FastNormalizedAnswerDetector(), n_processes, sample=20)
162 |
163 | def build_web_sample_corpus(n_processes):
164 | build_dataset("web-sample", tokenization.BasicTokenizer(do_lower_case=True),
165 | dict(
166 | verified=join(TRIVIA_QA, "qa", "verified-web-dev.json"),
167 | dev=join(TRIVIA_QA, "qa", "web-dev.json"),
168 | train=join(TRIVIA_QA, "qa", "web-train.json"),
169 | test=join(TRIVIA_QA, "qa", "web-test-without-answers.json")
170 | ),
171 | FastNormalizedAnswerDetector(), n_processes, sample=20)
172 |
173 | def build_unfiltered_sample_corpus(n_processes):
174 | build_dataset("unfiltered-sample", tokenization.BasicTokenizer(do_lower_case=True),
175 | dict(
176 | dev=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-dev.json"),
177 | train=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-train.json"),
178 | test=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-test-without-answers.json")
179 | ),
180 | FastNormalizedAnswerDetector(), n_processes, sample=20)
181 |
182 | def main():
183 | parser = argparse.ArgumentParser("Pre-procsess TriviaQA data")
184 | parser.add_argument("corpus", choices=["web", "wiki", "unfiltered", "web-sample", "wiki-sample", "unfiltered-sample"])
185 | parser.add_argument("-n", "--n_processes", type=int, default=1, help="Number of processes to use")
186 | args = parser.parse_args()
187 | if args.corpus == "web":
188 | build_web_corpus(args.n_processes)
189 | elif args.corpus == "wiki":
190 | build_wiki_corpus(args.n_processes)
191 | elif args.corpus == "unfiltered":
192 | build_unfiltered_corpus(args.n_processes)
193 | elif args.corpus == "web-sample":
194 | build_web_sample_corpus(args.n_processes)
195 | elif args.corpus == "wiki-sample":
196 | build_wiki_sample_corpus(args.n_processes)
197 | elif args.corpus == "unfiltered-sample":
198 | build_unfiltered_sample_corpus(args.n_processes)
199 | else:
200 | raise RuntimeError()
201 |
202 |
203 | if __name__ == "__main__":
204 | main()
--------------------------------------------------------------------------------
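Once a corpus has been built (e.g., `python -m triviaqa.build_span_corpus wiki -n 8`), the dataset classes give pickled access to it. A minimal sketch, assuming it is run from the repo root after that build step:

from triviaqa.build_span_corpus import TriviaQaWikiDataset

corpus = TriviaQaWikiDataset()        # reads data/triviaqa/wiki/file_map.json
train = corpus.get_train()            # List[TriviaQaQuestion]
print(len(train), train[0].question)  # question is a token list after answer detection
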
/triviaqa/configurable.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import OrderedDict
3 | from inspect import signature
4 | from warnings import warn
5 |
6 | import numpy as np
7 | from sklearn.base import BaseEstimator
8 |
9 |
10 | class Configuration(object):
11 | def __init__(self, name, version, params):
12 | if not isinstance(name, str):
13 | raise ValueError()
14 | if not isinstance(params, dict):
15 | raise ValueError()
16 | self.name = name
17 | self.version = version
18 | self.params = params
19 |
20 | def __str__(self):
21 | if len(self.params) == 0:
22 | return "%s-v%s" % (self.name, self.version)
23 | json_params = config_to_json(self.params)
24 | if len(json_params) < 200:
25 | return "%s-v%s: %s" % (self.name, self.version, json_params)
26 | else:
27 | return "%s-v%s {...}" % (self.name, self.version)
28 |
29 | def __eq__(self, other):
30 | return isinstance(other, Configuration) and \
31 | self.name == other.name and \
32 | self.version == other.version and \
33 | self.params == other.params
34 |
35 |
36 | class Configurable(object):
37 | """
38 | Configurable classes have names, versions, and a set of parameters that are either "simple" (JSON-serializable)
39 | types or other Configurable objects. Configurable objects should also be serializable via pickle.
40 | Configurable classes are defined mainly to give us a human-readable way of reading off the `parameters`
41 | set for different objects and to attach version numbers to them.
42 |
43 | By default we follow the format sklearn uses for its `BaseEstimator` class, where parameters are automatically
44 | derived based on the constructor parameters.
45 | """
46 |
47 | @classmethod
48 | def _get_param_names(cls):
49 | # fetch the constructor to introspect its parameters (sklearn-style)
50 | init = cls.__init__
51 | if init is object.__init__:
52 | # No explicit constructor to introspect
53 | return []
54 |
55 | init_signature = signature(init)
56 | parameters = [p for p in init_signature.parameters.values()
57 | if p.name != 'self']
58 | if any(p.kind == p.VAR_POSITIONAL for p in parameters):
59 | raise RuntimeError()
60 | return sorted([p.name for p in parameters])
61 |
62 | @property
63 | def name(self):
64 | return self.__class__.__name__
65 |
66 | @property
67 | def version(self):
68 | return 0
69 |
70 | def get_params(self):
71 | out = {}
72 | for key in self._get_param_names():
73 | v = getattr(self, key, None)
74 | if isinstance(v, Configurable):
75 | out[key] = v.get_config()
76 | elif hasattr(v, "get_config"): # for keras objects
77 | out[key] = {"name": v.__class__.__name__, "config": v.get_config()}
78 | else:
79 | out[key] = v
80 | return out
81 |
82 | def get_config(self) -> Configuration:
83 | params = {k: describe(v) for k, v in self.get_params().items()}
84 | return Configuration(self.name, self.version, params)
85 |
86 | def __getstate__(self):
87 | state = dict(self.__dict__)
88 | if "version" in state:
89 | if state["version"] != self.version:
90 | raise RuntimeError()
91 | else:
92 | state["version"] = self.version
93 | return state
94 |
95 | def __setstate__(self, state):
96 | if "version" not in state:
97 | raise RuntimeError("Version should be in state (%s)" % self.__class__.__name__)
98 | if state["version"] != self.version:
99 | warn(("%s loaded with version %s, but class " +
100 | "version is %s") % (self.__class__.__name__, state["version"], self.version))
101 |
102 | if "state" in state:
103 | self.__dict__ = state["state"]
104 | else:
105 | del state["version"]
106 | self.__dict__ = state
107 |
108 |
109 | def describe(obj):
110 | if isinstance(obj, Configurable):
111 | return obj.get_config()
112 | else:
113 | obj_type = type(obj)
114 |
115 | if obj_type in (list, set, frozenset, tuple):
116 | return obj_type([describe(e) for e in obj])
117 | elif isinstance(obj, tuple):
118 | # Named tuple: convert to a plain tuple
119 | return tuple(describe(e) for e in obj)
120 | elif obj_type in (dict, OrderedDict):
121 | output = OrderedDict()
122 | for k, v in obj.items():
123 | if isinstance(k, Configurable):
124 | raise ValueError()
125 | output[k] = describe(v)
126 | return output
127 | else:
128 | return obj
129 |
130 |
131 | class EncodeDescription(json.JSONEncoder):
132 | """ Json encoder that encodes 'Configurable' objects as dictionaries and handles
133 | some numpy types. Note decoding this output will not reproduce the original input,
134 | for these types, this is only intended to be used to produce human readable output.
135 | '"""
136 | def default(self, obj):
137 | if isinstance(obj, np.integer):
138 | return int(obj)
139 | elif isinstance(obj, np.dtype):
140 | return str(obj)
141 | elif isinstance(obj, np.floating):
142 | return float(obj)
143 | elif isinstance(obj, np.bool_):
144 | return bool(obj)
145 | elif isinstance(obj, np.ndarray):
146 | return obj.tolist()
147 | elif isinstance(obj, BaseEstimator): # handle sklearn estimators
148 | return Configuration(obj.__class__.__name__, 0, obj.get_params())
149 | elif isinstance(obj, Configuration):
150 | if "version" in obj.params or "name" in obj.params:
151 | raise ValueError()
152 | out = OrderedDict()
153 | out["name"] = obj.name
154 | if obj.version != 0:
155 | out["version"] = obj.version
156 | out.update(obj.params)
157 | return out
158 | elif isinstance(obj, Configurable):
159 | return obj.get_config()
160 | elif isinstance(obj, set):
161 | return sorted(obj) # Ensure deterministic order
162 | else:
163 | try:
164 | return super().default(obj)
165 | except TypeError:
166 | return str(obj)
167 |
168 |
169 | def config_to_json(data, indent=None):
170 | return json.dumps(data, sort_keys=False, cls=EncodeDescription, indent=indent)
171 |
--------------------------------------------------------------------------------
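
Since `Configurable` derives its parameter list from the constructor signature (sklearn `BaseEstimator` style), a subclass only needs to store its constructor arguments as attributes of the same name. A minimal usage sketch; `TopK` is a hypothetical class, not part of this repo:

from triviaqa.configurable import Configurable, config_to_json

class TopK(Configurable):
    def __init__(self, k: int, shuffle: bool = False):
        self.k = k          # attribute names must match the constructor parameters
        self.shuffle = shuffle

conf = TopK(5).get_config()            # Configuration(name="TopK", version=0, ...)
print(config_to_json(conf, indent=2))  # {"name": "TopK", "k": 5, "shuffle": false}
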
/triviaqa/evidence_corpus.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | from os import walk, mkdir, makedirs
4 | from os.path import relpath, join, exists, expanduser
5 | from typing import Set
6 | from tqdm import tqdm
7 | from typing import List
8 |
9 | import bert.tokenization as tokenization
10 | from triviaqa.utils import split, flatten_iterable, group
11 | from triviaqa.read_data import normalize_wiki_filename
12 |
13 | TRIVIA_QA = join(expanduser("~"), "data", "triviaqa")
14 |
15 | class MergeParagraphs(object):
16 | def __init__(self, max_tokens: int):
17 | self.max_tokens = max_tokens
18 |
19 | def merge(self, paragraphs: List[List[str]]):
20 | all_paragraphs = []
21 |
22 | on_paragraph = []  # text we have collected so far for the current paragraph
23 | cur_tokens = 0 # number of tokens in the current paragraph
24 |
25 | word_ix = 0
26 | for para in paragraphs:
27 | n_words = len(para)
28 | start_token = word_ix
29 | end_token = start_token + n_words
30 | word_ix = end_token
31 |
32 | if cur_tokens + n_words > self.max_tokens:
33 | if cur_tokens != 0: # end the current paragraph
34 | all_paragraphs.append(on_paragraph)
35 | on_paragraph = []
36 | cur_tokens = 0
37 |
38 | if n_words >= self.max_tokens:  # oversized paragraph: emit it as-is; otherwise start a new merged paragraph
39 | all_paragraphs.append(para)
40 | else:
41 | on_paragraph += para
42 | cur_tokens = n_words
43 | else:
44 | on_paragraph += para
45 | cur_tokens += n_words
46 |
47 | if on_paragraph != []:
48 | all_paragraphs.append(on_paragraph)
49 | return all_paragraphs
50 |
51 | def _gather_files(input_root, output_dir, skip_dirs, wiki_only):
52 | if not exists(output_dir):
53 | mkdir(output_dir)
54 |
55 | all_files = []
56 | for root, dirs, filenames in walk(input_root):
57 | if skip_dirs:
58 | output = join(output_dir, relpath(root, input_root))
59 | if exists(output):
60 | continue
61 | path = relpath(root, input_root)
62 | normalized_path = normalize_wiki_filename(path)
63 | if not exists(join(output_dir, normalized_path)):
64 | mkdir(join(output_dir, normalized_path))
65 | all_files += [join(path, x) for x in filenames]
66 | if wiki_only:
67 | all_files = [x for x in all_files if "wikipedia/" in x]
68 | return all_files
69 |
70 | def build_tokenized_files(filenames, input_root, output_root, tokenizer, splitter, override=True) -> Set[str]:
71 | """
72 | For each file in `filenames`, loads the text, tokenizes it with `tokenizer`, and
73 | saves the output to the same relative location in `output_root`.
74 | :return: a set of all the individual words seen
75 | """
76 | voc = set()
77 | for filename in filenames:
78 | out_file = normalize_wiki_filename(filename[:filename.rfind(".")]) + ".txt"
79 | out_file = join(output_root, out_file)
80 | if not override and exists(out_file):
81 | continue
82 | with open(join(input_root, filename), "r") as in_file:
83 | text = in_file.read().strip()
84 | paras = [x for x in text.split("\n") if len(x) > 0]
85 | paragraphs = [tokenizer.tokenize(x) for x in paras]
86 | merged_paragraphs = splitter.merge(paragraphs)
87 |
88 | for para in merged_paragraphs:
89 | for word in para:
90 | voc.add(word)  # add the whole word; `update` would add its individual characters
91 |
92 | with open(out_file, "w") as f:
93 | f.write("\n\n".join(" ".join(para) for para in merged_paragraphs))
94 | return voc
95 |
96 | def build_tokenized_corpus(input_root, tokenizer, splitter, output_dir, skip_dirs=False,
97 | n_processes=1, wiki_only=False):
98 | if not exists(output_dir):
99 | makedirs(output_dir)
100 |
101 | all_files = _gather_files(input_root, output_dir, skip_dirs, wiki_only)
102 |
103 | if n_processes == 1:
104 | voc = build_tokenized_files(tqdm(all_files, ncols=80), input_root, output_dir, tokenizer, splitter)
105 | else:
106 | voc = set()
107 | from multiprocessing import Pool
108 | with Pool(n_processes) as pool:
109 | chunks = split(all_files, n_processes)
110 | chunks = flatten_iterable(group(c, 500) for c in chunks)
111 | pbar = tqdm(total=len(chunks), ncols=80)
112 | for v in pool.imap_unordered(_build_tokenized_files_t,
113 | [[c, input_root, output_dir, tokenizer, splitter] for c in chunks]):
114 | voc.update(v)
115 | pbar.update(1)
116 | pbar.close()
117 |
118 | def _build_tokenized_files_t(arg):
119 | return build_tokenized_files(*arg)
120 |
121 | class TriviaQaEvidenceCorpusTxt(object):
122 | """
123 | Corpus of the tokenized text from the given TriviaQa evidence documents.
124 | Allows the text to be retrieved by document id
125 | """
126 |
127 | _split_para = re.compile("\n\n+")
128 |
129 | def __init__(self, file_id_map=None):
130 | self.directory = join("data", "triviaqa/evidence")
131 | self.file_id_map = file_id_map
132 |
133 | def get_document(self, doc_id):
134 | if self.file_id_map is None:
135 | file_id = doc_id
136 | else:
137 | file_id = self.file_id_map.get(doc_id)
138 |
139 | if file_id is None:
140 | return None
141 |
142 | file_id = join(self.directory, file_id + ".txt")
143 | if not exists(file_id):
144 | return None
145 |
146 | with open(file_id, "r") as f:
147 | text = f.read()
148 | paragraphs = []
149 | for para in self._split_para.split(text):
150 | paragraphs.append(para.split(" "))
151 | return paragraphs # List[List[str]]
152 |
153 | def main():
154 | parse = argparse.ArgumentParser("Pre-tokenize the TriviaQA evidence corpus")
155 | parse.add_argument("-o", "--output_dir", type=str, default=join("data", "triviaqa", "evidence"))
156 | parse.add_argument("-s", "--source", type=str, default=join(TRIVIA_QA, "evidence"))
157 | # This is slow, using more processes is recommended
158 | parse.add_argument("-n", "--n_processes", type=int, default=1, help="Number of processes to use")
159 | parse.add_argument("--max_tokens", type=int, default=200, help="Number of maximal tokens in each merged paragraph")
160 | parse.add_argument("--wiki_only", action="store_true")
161 | args = parse.parse_args()
162 |
163 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
164 | splitter = MergeParagraphs(args.max_tokens)
165 | build_tokenized_corpus(args.source, tokenizer, splitter, args.output_dir,
166 | n_processes=args.n_processes, wiki_only=args.wiki_only)
167 |
168 | if __name__ == "__main__":
169 | main()
--------------------------------------------------------------------------------
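
`MergeParagraphs` packs consecutive token lists greedily up to `max_tokens`, and a paragraph that alone exceeds the budget is emitted as its own chunk. An illustrative sketch of that behavior with toy inputs (not repo data):

from triviaqa.evidence_corpus import MergeParagraphs

paras = [["a"] * 120, ["b"] * 60, ["c"] * 250]
merged = MergeParagraphs(max_tokens=200).merge(paras)
print([len(p) for p in merged])  # [180, 250]: the first two merge, the oversized third stays whole
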
/triviaqa/preprocessed_corpus.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import gzip
3 | import random
4 | import pickle
5 | from collections import Counter
6 | from threading import Lock
7 | from typing import List, Iterable, Optional
8 |
9 | import math
10 | import numpy as np
11 | from sklearn.feature_extraction.text import TfidfVectorizer
12 | from sklearn.metrics import pairwise_distances
13 | from tqdm import tqdm
14 | from triviaqa.utils import split, flatten_iterable, group
15 | from triviaqa.configurable import Configurable
16 | from triviaqa.read_data import TriviaQaQuestion
17 | from triviaqa.triviaqa_document_utils import ExtractedParagraphWithAnswers, DocParagraphWithAnswers, DocumentAndQuestion
18 |
19 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're',
20 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her',
21 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do',
22 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over',
23 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves',
24 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself',
25 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these',
26 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why',
27 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into',
28 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−',
29 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where',
30 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off',
31 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against',
32 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me',
33 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after',
34 | 'be'}
35 |
36 | class ParagraphsSet(object):
37 | def __init__(self, paragraphs: List[ExtractedParagraphWithAnswers], ir_hit: bool):
38 | self.paragraphs = paragraphs
39 | self.ir_hit = ir_hit
40 |
41 | class ParagraphFilter(Configurable):
42 | """ Selects and ranks paragraphs """
43 |
44 | def prune(self, question, paragraphs: List[ExtractedParagraphWithAnswers]):
45 | raise NotImplementedError()
46 |
47 | class TopTfIdf(ParagraphFilter):
48 | def __init__(self, n_to_select: int, is_training: bool=False, sort_passage: bool=True):
49 | self.n_to_select = n_to_select
50 | self.is_training = is_training
51 | self.sort_passage = sort_passage
52 |
53 | def prune(self, question: List[str], paragraphs: List[ExtractedParagraphWithAnswers]):
54 | tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words)
55 | text = []
56 | for para in paragraphs:
57 | text.append(" ".join(para.text))
58 | try:
59 | para_features = tfidf.fit_transform(text)
60 | q_features = tfidf.transform([" ".join(question)])
61 | except ValueError:
62 | return [], 0.  # callers unpack (paragraphs, ir_hit)
63 |
64 | dists = pairwise_distances(q_features, para_features, "cosine").ravel() # [N]
65 | sorted_ix = np.lexsort(([x.start for x in paragraphs], dists)) # in case of ties, use the earlier paragraph [N]
66 |
67 | selection = [i for i in sorted_ix[:self.n_to_select]]
68 | selected_paras = [paragraphs[i] for i in selection]
69 | ir_hit = 0. if all(len(x.answer_spans) == 0 for x in selected_paras) else 1.
70 |
71 | if self.is_training and not ir_hit:
72 | gold_indexes = [i for i, x in enumerate(paragraphs) if len(x.answer_spans) != 0]
73 | gold_index = random.choice(gold_indexes)
74 | selection[-1] = gold_index
75 |
76 | if self.sort_passage:
77 | selection = np.sort(selection)
78 |
79 | return [paragraphs[i] for i in selection], ir_hit
80 |
81 | class ShallowOpenWebRanker(ParagraphFilter):
82 | # Hard coded weight learned from a logistic regression classifier
83 | TFIDF_W = 5.13365065
84 | LOG_WORD_START_W = 0.46022765
85 | FIRST_W = -0.08611607
86 | LOWER_WORD_W = 0.0499123
87 | WORD_W = -0.15537181
88 |
89 | def __init__(self, n_to_select: int, is_training: bool=False, sort_passage: bool=True):
90 | self.n_to_select = n_to_select
91 | self.is_training = is_training
92 | self.sort_passage = sort_passage
93 | self._tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words)
94 |
95 | def score_paragraphs(self, question, paragraphs: List[ExtractedParagraphWithAnswers]):
96 | tfidf = self._tfidf
97 | text = []
98 | for para in paragraphs:
99 | text.append(" ".join(para.text))
100 | try:
101 | para_features = tfidf.fit_transform(text)
102 | q_features = tfidf.transform([" ".join(question)])
103 | except ValueError:
104 | return np.zeros(len(paragraphs))  # degenerate tf-idf (e.g. empty vocabulary): score everything equally
105 |
106 | q_words = {x for x in question if x.lower() not in stop_words}
107 | q_words_lower = {x.lower() for x in q_words}
108 | word_matches_features = np.zeros((len(paragraphs), 2))
109 | for para_ix, para in enumerate(paragraphs):
110 | found = set()
111 | found_lower = set()
112 | for word in para.text:
113 | if word in q_words:
114 | found.add(word)
115 | elif word.lower() in q_words_lower:
116 | found_lower.add(word.lower())
117 | word_matches_features[para_ix, 0] = len(found)
118 | word_matches_features[para_ix, 1] = len(found_lower)
119 |
120 | tfidf = pairwise_distances(q_features, para_features, "cosine").ravel()
121 | starts = np.array([p.start for p in paragraphs])
122 | log_word_start = np.log(starts / 200.0 + 1)
123 | first = starts == 0
124 | scores = tfidf * self.TFIDF_W + self.LOG_WORD_START_W * log_word_start + self.FIRST_W * first + \
125 | self.LOWER_WORD_W * word_matches_features[:, 1] + self.WORD_W * word_matches_features[:, 0]
126 | return scores
127 |
128 | def prune(self, question: List[str], paragraphs: List[ExtractedParagraphWithAnswers]):
129 | scores = self.score_paragraphs(question, paragraphs)
130 | sorted_ix = np.argsort(scores)
131 |
132 | selection = [i for i in sorted_ix[:self.n_to_select]]
133 | selected_paras = [paragraphs[i] for i in selection]
134 | ir_hit = 0. if all(len(x.answer_spans) == 0 for x in selected_paras) else 1.
135 |
136 | if self.is_training and not ir_hit:
137 | gold_indexes = [i for i, x in enumerate(paragraphs) if len(x.answer_spans) != 0]
138 | gold_index = random.choice(gold_indexes)
139 | selection[-1] = gold_index
140 |
141 | if self.sort_passage:
142 | selection = np.sort(selection)
143 |
144 | return [paragraphs[i] for i in selection], ir_hit
145 |
146 |
147 | class Preprocessor(Configurable):
148 |
149 | def preprocess(self, question: Iterable, evidence) -> object:
150 | """ Map elements to an unspecified intermediate format """
151 | raise NotImplementedError()
152 |
153 | def finalize_chunk(self, x):
154 | """ Finalize the output from `preprocess`, in multi-processing senarios this will still be run on
155 | the main thread so it can be used for things like interning """
156 | pass
157 |
158 | def _preprocess_and_count(questions: List, evidence, preprocessor: Preprocessor):
159 | count = len(questions)
160 | output = preprocessor.preprocess(questions, evidence)
161 | return output, count
162 |
163 | def preprocess_par(questions: List, evidence, preprocessor,
164 | n_processes=2, chunk_size=200, name=None):
165 | if chunk_size <= 0:
166 | raise ValueError("Chunk size must be >= 0, but got %s" % chunk_size)
167 | if n_processes is not None and n_processes <= 0:
168 | raise ValueError("n_processes must be >= 1 or None, but got %s" % n_processes)
169 | n_processes = len(questions) if n_processes is None else min(len(questions), n_processes)
170 |
171 | if n_processes == 1:
172 | out = preprocessor.preprocess(tqdm(questions, desc=name, ncols=80), evidence)
173 | preprocessor.finalize_chunk(out)
174 | return out
175 | else:
176 | from multiprocessing import Pool
177 | chunks = split(questions, n_processes)
178 | chunks = flatten_iterable([group(c, chunk_size) for c in chunks])
179 | print("Processing %d chunks with %d processes" % (len(chunks), n_processes))
180 | pbar = tqdm(total=len(questions), desc=name, ncols=80)
181 | lock = Lock()
182 |
183 | def call_back(results):
184 | preprocessor.finalize_chunk(results[0])
185 | with lock:  # FIXME even with the lock, the progress bar still jumps around
186 | pbar.update(results[1])
187 |
188 | with Pool(n_processes) as pool:
189 | results = [pool.apply_async(_preprocess_and_count, [c, evidence, preprocessor], callback=call_back)
190 | for c in chunks]
191 | results = [r.get()[0] for r in results]
192 |
193 | pbar.close()
194 | output = results[0]
195 | for r in results[1:]:
196 | output += r
197 | return output
198 |
199 | class FilteredData(object):
200 | def __init__(self, data: List, true_len: int, ir_count: int,
201 | total_doc_num: int, total_doc_length: int, pruned_doc_length: int):
202 | self.data = data
203 | self.true_len = true_len
204 | self.ir_count = ir_count
205 | self.total_doc_num = total_doc_num
206 | self.total_doc_length = total_doc_length
207 | self.pruned_doc_length = pruned_doc_length
208 |
209 | def __add__(self, other):
210 | return FilteredData(self.data + other.data, self.true_len + other.true_len, self.ir_count + other.ir_count,
211 | self.total_doc_num + other.total_doc_num, self.total_doc_length + other.total_doc_length,
212 | self.pruned_doc_length + other.pruned_doc_length)
213 |
214 | def split_annotated(doc: List[List[str]], spans: np.ndarray):
215 | out = []
216 | offset = 0
217 | for para in doc:
218 | para_start = offset
219 | para_end = para_start + len(para)
220 | para_spans = spans[np.logical_and(spans[:, 0] >= para_start, spans[:, 1] < para_end)] - para_start
221 | out.append(ExtractedParagraphWithAnswers(para, para_start, para_end, para_spans))
222 | offset += len(para)
223 | return out
224 |
225 | class ExtractMultiParagraphsPerQuestion(Preprocessor):
226 | def __init__(self, ranker: ParagraphFilter, intern: bool=False, is_training=False):
227 | self.ranker = ranker
228 | self.intern = intern
229 | self.is_training = is_training
230 |
231 | def preprocess(self, questions: List[TriviaQaQuestion], evidence): # TriviaQaEvidenceCorpusTxt evidence_corpus.py
232 | ir_count, total_doc_num, total_doc_length, pruned_doc_length = 0, 0, 0, 0
233 |
234 | instances = []
235 | for q in questions:
236 | doc_paras = []
237 | doc_count, doc_length = 0, 0
238 | for doc in q.all_docs:
239 | if self.is_training and len(doc.answer_spans) == 0:
240 | continue
241 | text = evidence.get_document(doc.doc_id) # List[List[str]]
242 | if text is None:
243 | raise ValueError("No evidence text found document: " + doc.doc_id)
244 | if doc.answer_spans is not None:
245 | paras = split_annotated(text, doc.answer_spans)
246 | else:
247 | # this is kind of a hack to make the rest of the pipeline work, only
248 | # needed for test cases
249 | paras = split_annotated(text, np.zeros((0, 2), dtype=np.int32))
250 | doc_paras.extend([DocParagraphWithAnswers(x.text, x.start, x.end, x.answer_spans, doc.doc_id)
251 | for x in paras]) # List[DocParagraphWithAnswers]
252 | doc_length += sum(len(para) for para in text)
253 | doc_count += 1
254 |
255 | if len(doc_paras) == 0:
256 | continue
257 |
258 | doc_paras, ir_hit = self.ranker.prune(q.question, doc_paras) # List[ExtractedParagraphWithAnswers] len=4
259 | total_doc_num += doc_count
260 | total_doc_length += doc_length
261 | ir_count += ir_hit
262 |
263 | # merge into documentandquestion
264 | doc_tokens, start_positions, end_positions = [], [], []
265 | for x in doc_paras:
266 | offset_doc = len(doc_tokens)
267 | doc_tokens += x.text
268 | if len(x.answer_spans) != 0:
269 | start_position = x.answer_spans[:, 0] + offset_doc
270 | end_position = x.answer_spans[:, 1] + offset_doc
271 | start_positions.extend(start_position)
272 | end_positions.extend(end_position)
273 | instance = DocumentAndQuestion(q.all_docs[0].doc_id, q.question_id, " ".join(q.question), doc_tokens,
274 | None if q.answer is None else q.answer.all_answers, start_positions,
275 | end_positions)
276 | pruned_doc_length += len(doc_tokens)
277 |
278 | instances.append(instance)
279 | return FilteredData(instances, len(questions), ir_count, total_doc_num, total_doc_length, pruned_doc_length)
280 |
281 | def finalize_chunk(self, f: FilteredData):
282 | if self.intern:
283 | for ins in f.data:
284 | ins.document_id = sys.intern(ins.document_id)
285 | ins.qas_id = sys.intern(ins.qas_id)
286 | ins.question_text = sys.intern(ins.question_text)
287 |
288 |
289 | # class ExtractMultiParagraphs(Preprocessor):
290 | # def __init__(self, ranker: ParagraphFilter, intern: bool=False, is_training=False):
291 | # self.ranker = ranker
292 | # self.intern = intern
293 | # self.is_training = is_training
294 | #
295 | # def preprocess(self, questions: List[TriviaQaQuestion], evidence): # TriviaQaEvidenceCorpusTxt evidence_corpus.py
296 | # true_len = 0
297 | # ir_count, ir_total, pruned_doc_length = 0, 0, 0
298 | #
299 | # instances = []
300 | # for q in questions:
301 | # true_len += len(q.all_docs)
302 | # for doc in q.all_docs:
303 | # if self.is_training and len(doc.answer_spans) == 0:
304 | # continue
305 | # text = evidence.get_document(doc.doc_id) # List[List[str]]
306 | # if text is None:
307 | # raise ValueError("No evidence text found document: " + doc.doc_id)
308 | # if doc.answer_spans is not None:
309 | # paras = split_annotated(text, doc.answer_spans)
310 | # else:
311 | # # this is kind of a hack to make the rest of the pipeline work, only
312 | # # needed for test cases
313 | # paras = split_annotated(text, np.zeros((0, 2), dtype=np.int32))
314 | #
315 | # if len(paras) == 0:
316 | # continue
317 | #
318 | # paras, ir_hit = self.ranker.prune(q.question, paras) # List[ExtractedParagraphWithAnswers] len=4
319 | # ir_count += ir_hit
320 | # ir_total += 1
321 | #
322 | # # merge into documentandquestion
323 | # doc_tokens, start_positions, end_positions = [], [], []
324 | # for x in paras:
325 | # offset_doc = len(doc_tokens)
326 | # doc_tokens += x.text
327 | # if len(x.answer_spans) != 0:
328 | # start_position = x.answer_spans[:, 0] + offset_doc
329 | # end_position = x.answer_spans[:, 1] + offset_doc
330 | # start_positions.extend(start_position)
331 | # end_positions.extend(end_position)
332 | # instance = DocumentAndQuestion(doc.doc_id, q.question_id, " ".join(q.question), doc_tokens,
333 | # q.answer.all_answers, start_positions, end_positions)
334 | # pruned_doc_length += len(doc_tokens)
335 | #
336 | # instances.append(instance)
337 | # return FilteredData(instances, true_len, ir_count, ir_total, pruned_doc_length)
338 | #
339 | # def finalize_chunk(self, f: FilteredData):
340 | # if self.intern:
341 | # for ins in f.data:
342 | # ins.document_id = sys.intern(ins.document_id)
343 | # ins.qas_id = sys.intern(ins.qas_id)
344 | # ins.question_text = sys.intern(ins.question_text)
--------------------------------------------------------------------------------
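
`TopTfIdf.prune` ranks paragraphs by TF-IDF cosine distance to the question, keeps the top `n_to_select`, and reports whether any kept paragraph contains an answer span. A hedged sketch with two toy paragraphs (all values illustrative):

import numpy as np
from triviaqa.preprocessed_corpus import TopTfIdf
from triviaqa.triviaqa_document_utils import ExtractedParagraphWithAnswers

no_spans = np.zeros((0, 2), dtype=np.int32)
paras = [
    ExtractedParagraphWithAnswers("the play hamlet was written by shakespeare".split(), 0, 7, no_spans),
    ExtractedParagraphWithAnswers("the weather in london is often rainy".split(), 7, 14, no_spans),
]
selected, ir_hit = TopTfIdf(n_to_select=1).prune("who wrote hamlet".split(), paras)
print(selected[0].start, ir_hit)  # 0 0.0: the hamlet paragraph wins; no answer spans, so no IR hit
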
/triviaqa/read_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import unicodedata
3 | from os.path import join
4 | from typing import List
5 |
6 | from triviaqa.triviaqa_eval import normalize_answer as triviaqa_normalize_answer
7 |
8 | """
9 | Read and represent trivia-qa data
10 | """
11 |
12 |
13 | def normalize_wiki_filename(filename):
14 | """
15 | Wiki filenames have been a pain, since the data sometimes has filenames encoded in
16 | the incorrect case, and we have to be careful to keep a consistent unicode format.
17 | Our current solution is to require all filenames to be normalized like this
18 | """
19 | return unicodedata.normalize("NFD", filename).lower()
20 |
21 |
22 | class WikipediaEntity(object):
23 | __slots__ = ["value", "normalized_value", "aliases", "normalized_aliases",
24 | "wiki_entity_name", "normalized_wiki_entity_name", "human_answers"]
25 |
26 | def __init__(self, value: str, normalized_value: str, aliases, normalized_aliases: List[str],
27 | wiki_entity_name: str, normalized_wiki_entity_name: str, human_answers):
28 | self.aliases = aliases
29 | self.value = value
30 | self.normalized_value = normalized_value
31 | self.normalized_aliases = normalized_aliases
32 | self.wiki_entity_name = wiki_entity_name
33 | self.normalized_wiki_entity_name = normalized_wiki_entity_name
34 | self.human_answers = human_answers
35 |
36 | @property
37 | def all_answers(self):
38 | if self.human_answers is None:
39 | return self.normalized_aliases
40 | else:
41 | # normalize to be consistent with the other normalized aliases
42 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers]
43 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0]
44 |
45 | def __repr__(self) -> str:
46 | return self.value
47 |
48 |
49 | class Numerical(object):
50 | __slots__ = ["number", "aliases", "normalized_aliases", "value", "unit",
51 | "normalized_value", "multiplier", "human_answers"]
52 |
53 | def __init__(self, number: float, aliases, normalized_aliases, value, unit,
54 | normalized_value, multiplier, human_answers):
55 | self.number = number
56 | self.aliases = aliases
57 | self.normalized_aliases = normalized_aliases
58 | self.value = value
59 | self.unit = unit
60 | self.normalized_value = normalized_value
61 | self.multiplier = multiplier
62 | self.human_answers = human_answers
63 |
64 | @property
65 | def all_answers(self):
66 | if self.human_answers is None:
67 | return self.normalized_aliases
68 | else:
69 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers]
70 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0]
71 |
72 | def __repr__(self) -> str:
73 | return self.value
74 |
75 |
76 | class FreeForm(object):
77 | __slots__ = ["value", "normalized_value", "aliases", "normalized_aliases", "human_answers"]
78 |
79 | def __init__(self, value, normalized_value, aliases, normalized_aliases, human_answers):
80 | self.value = value
81 | self.aliases = aliases
82 | self.normalized_value = normalized_value
83 | self.normalized_aliases = normalized_aliases
84 | self.human_answers = human_answers
85 |
86 | @property
87 | def all_answers(self):
88 | if self.human_answers is None:
89 | return self.normalized_aliases
90 | else:
91 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers]
92 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0]
93 |
94 | def __repr__(self) -> str:
95 | return self.value
96 |
97 |
98 | class Range(object):
99 | __slots__ = ["value", "normalized_value", "aliases", "normalized_aliases",
100 | "start", "end", "unit", "multiplier", "human_answers"]
101 |
102 | def __init__(self, value, normalized_value, aliases, normalized_aliases,
103 | start, end, unit, multiplier, human_answers):
104 | self.value = value
105 | self.normalized_value = normalized_value
106 | self.aliases = aliases
107 | self.normalized_aliases = normalized_aliases
108 | self.start = start
109 | self.end = end
110 | self.unit = unit
111 | self.multiplier = multiplier
112 | self.human_answers = human_answers
113 |
114 | @property
115 | def all_answers(self):
116 | if self.human_answers is None:
117 | return self.normalized_aliases
118 | else:
119 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers]
120 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0]
121 |
122 | def __repr__(self) -> str:
123 | return self.value
124 |
125 |
126 | class TagMeEntityDoc(object):
127 | __slots__ = ["rho", "link_probability", "title", "trivia_qa_selected", "answer_spans"]
128 |
129 | def __init__(self, rho, link_probability, title):
130 | self.rho = rho
131 | self.link_probability = link_probability
132 | self.title = title
133 | self.trivia_qa_selected = False
134 | self.answer_spans = None
135 |
136 | @property
137 | def doc_id(self):
138 | return self.title
139 |
140 | def __repr__(self) -> str:
141 | return "TagMeEntityDoc(%s)" % self.title
142 |
143 |
144 | class SearchEntityDoc(object):
145 | __slots__ = ["title", "trivia_qa_selected", "answer_spans"]
146 |
147 | def __init__(self, title):
148 | self.title = title
149 | self.answer_spans = None
150 | self.trivia_qa_selected = False
151 |
152 | @property
153 | def doc_id(self):
154 | return self.title
155 |
156 | def __repr__(self) -> str:
157 | return "SearchEntityDoc(%s)" % self.title
158 |
159 |
160 | class SearchDoc(object):
161 | __slots__ = ["title", "description", "rank", "url", "trivia_qa_selected", "answer_spans"]
162 |
163 | def __init__(self, title, description, rank, url):
164 | self.title = title
165 | self.description = description
166 | self.rank = rank
167 | self.url = url
168 | self.answer_spans = None
169 | self.trivia_qa_selected = False
170 |
171 | @property
172 | def doc_id(self):
173 | return self.url
174 |
175 | def __repr__(self) -> str:
176 | return "SearchDoc(%s)" % self.title
177 |
178 |
179 | class TriviaQaQuestion(object):
180 | __slots__ = ["question", "question_id", "answer", "entity_docs", "web_docs"]
181 |
182 | def __init__(self, question, question_id, answer, entity_docs, web_docs):
183 | self.question = question
184 | self.question_id = question_id
185 | self.answer = answer
186 | self.entity_docs = entity_docs
187 | self.web_docs = web_docs
188 |
189 | @property
190 | def all_docs(self):
191 | if self.web_docs is not None:
192 | return self.web_docs + self.entity_docs
193 | else:
194 | return self.entity_docs
195 |
196 | def to_compressed_json(self):
197 | return [
198 | self.question,
199 | self.question_id,
200 | [self.answer.__class__.__name__] + [getattr(self.answer, x) for x in self.answer.__slots__],
201 | [[doc.__class__.__name__] + [getattr(doc, x) for x in doc.__slots__] for doc in self.entity_docs],
202 | None if self.web_docs is None else [[getattr(doc, x) for x in doc.__slots__] for doc in self.web_docs],
203 | ]
204 |
205 | @staticmethod
206 | def from_compressed_json(text):
207 | question, quid, answer, entity_docs, web_docs = json.loads(text)
208 | if answer[0] == "WikipediaEntity":
209 | answer = WikipediaEntity(*answer[1:])
210 | elif answer[0] == "Numerical":
211 | answer = Numerical(*answer[1:])
212 | elif answer[0] == "FreeForm":
213 | answer = FreeForm(*answer[1:])
214 | elif answer[0] == "Range":
215 | answer = Range(*answer[1:])
216 | else:
217 | raise ValueError()
218 | for i, doc in enumerate(entity_docs):
219 | if doc[0] == "TagMeEntityDoc":
220 | entity_docs[i] = TagMeEntityDoc(*doc[1:])
221 | elif doc[0] == "SearchEntityDoc":
222 | entity_docs[i] = SearchEntityDoc(*doc[1:])
223 | web_docs = None if web_docs is None else [SearchDoc(*x) for x in web_docs]
224 | return TriviaQaQuestion(question, quid, answer, entity_docs, web_docs)
225 |
226 |
227 | def iter_question_json(filename):
228 | """ Iterates over trivia-qa questions in a JSON file, useful if the file is too large to be
229 | parse all at once """
230 | with open(filename, "r") as f:
231 | if f.readline().strip() != "{":
232 | raise ValueError()
233 | if "Data\": [" not in f.readline():
234 | raise ValueError()
235 | line = f.readline()
236 | while line.strip() == "{":
237 | obj = []
238 | line = f.readline()
239 | while not line.startswith(" }"):
240 | obj.append(line)
241 | line = f.readline()
242 | yield "{" + "".join(obj) + "}"
243 | if not line.startswith(" },"):
244 | # no comma means this was the last element of the data list
245 | return
246 | else:
247 | line = f.readline()
248 | else:
249 | raise ValueError()
250 |
251 |
252 | def build_questions(json_questions, title_to_file, require_filename):
253 | for q in json_questions:
254 | q = json.loads(q)
255 | ans = q.get("Answer")
256 | valid_attempt = q.get("QuestionVerifiedEvalAttempt", False)
257 | if valid_attempt and not q["QuestionPartOfVerifiedEval"]:
258 | continue  # don't bother with questions in the verified set that were rejected
259 | if ans is not None:
260 | answer_type = ans["Type"]
261 | if answer_type == "WikipediaEntity":
262 | answer = WikipediaEntity(ans["NormalizedValue"], ans["Value"], ans["Aliases"], ans["NormalizedAliases"],
263 | ans["MatchedWikiEntityName"], ans["NormalizedMatchedWikiEntityName"],
264 | ans.get("HumanAnswers"))
265 | if not (len(ans) == 7 or (len(ans) == 8 and "HumanAnswers" in ans)):
266 | raise ValueError()
267 | elif answer_type == "Numerical":
268 | answer = Numerical(float(ans["Number"]), ans["Aliases"], ans["NormalizedAliases"],
269 | ans["Value"], ans["Unit"], ans["NormalizedValue"],
270 | ans["Multiplier"], ans.get("HumanAnswers"))
271 | if not (len(ans) == 8 or (len(ans) == 9 and "HumanAnswers" in ans)):
272 | raise ValueError()
273 | elif answer_type == "FreeForm":
274 | answer = FreeForm(ans["Value"], ans["NormalizedValue"], ans["Aliases"],
275 | ans["NormalizedAliases"], ans.get("HumanAnswers"))
276 | if not (len(ans) == 5 or (len(ans) == 6 and "HumanAnswers" in ans)):
277 | raise ValueError()
278 | elif answer_type == "Range":
279 | answer = Range(ans["Value"], ans["NormalizedValue"], ans["Aliases"], ans["NormalizedAliases"],
280 | float(ans["To"]), float(ans["From"]), ans["Unit"],
281 | ans["Multiplier"], ans.get("HumanAnswers"))
282 | if not (len(ans) == 9 or (len(ans) == 10 and "HumanAnswers" in ans)):
283 | if "Number" in ans:
284 | # This appears to be a bug; the number fields in these
285 | # cases seem to be meaningless (and VERY rare)
286 | pass
287 | else:
288 | raise ValueError()
289 | else:
290 | raise ValueError()
291 | else:
292 | answer = None
293 |
294 | entity_pages = []
295 | for page in q["EntityPages"]:
296 | verified_attempt = page.get("DocVerifiedEvalAttempt", False)
297 | if verified_attempt and not page["DocPartOfVerifiedEval"]:
298 | continue
299 | title = page["Title"]
300 | if page["DocSource"] == "Search":
301 | entity_pages.append(SearchEntityDoc(title))
302 | elif page["DocSource"] == "TagMe":
303 | entity_pages.append(TagMeEntityDoc(page.get("Rho"), page.get("LinkProbability"), title))
304 | else:
305 | raise ValueError()
306 | filename = page.get("Filename")
307 | if filename is not None:
308 | filename = join("wikipedia", filename[:filename.rfind(".")])
309 | filename = normalize_wiki_filename(filename)
310 | cur = title_to_file.get(title)
311 | if cur is None:
312 | title_to_file[title] = filename
313 | elif cur != filename:
314 | raise ValueError()
315 | elif require_filename:
316 | raise ValueError()
317 |
318 | if "SearchResults" in q:
319 | web_pages = []
320 | for page in q["SearchResults"]:
321 | verified_attempt = page.get("DocVerifiedEvalAttempt", False)
322 | if verified_attempt and not page["DocPartOfVerifiedEval"]:
323 | continue
324 | url = page["Url"]
325 | web_pages.append(SearchDoc(page["Title"], page["Description"], page["Rank"], url))
326 | filename = page.get("Filename")
327 | if filename is not None:
328 | filename = join("web", filename[:filename.rfind(".")])
329 | cur = title_to_file.get(url)
330 | if cur is None:
331 | title_to_file[url] = filename
332 | elif cur != filename:
333 | raise ValueError()
334 | elif require_filename:
335 | raise ValueError()
336 | else:
337 | web_pages = None
338 |
339 | yield TriviaQaQuestion(q["Question"], q["QuestionId"], answer, entity_pages, web_pages)
340 |
341 |
342 | def iter_trivia_question(filename, file_map, require_filename):
343 | return build_questions(iter_question_json(filename), file_map, require_filename)
344 |
345 |
346 |
--------------------------------------------------------------------------------
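
`normalize_wiki_filename` is what keeps wiki titles, evidence paths, and the `title_to_file` map consistent: NFD unicode normalization plus lower-casing. A short sketch (the filename is illustrative):

from triviaqa.read_data import normalize_wiki_filename

print(normalize_wiki_filename("Amélie_(Film).txt"))  # "amélie_(film).txt", NFD-normalized and lower-cased
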
/triviaqa/triviaqa_document_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import math
4 | import six
5 | import collections
6 | import numpy as np
7 | from typing import List
8 | import bert.tokenization as tokenization
9 | from squad.squad_utils import _improve_answer_span, _get_best_indexes, _compute_softmax, get_final_text
10 | from squad.squad_evaluate import exact_match_score, f1_score, span_f1
11 |
12 |
13 | class ExtractedParagraphWithAnswers(object):
14 | __slots__ = ["text", "start", "end", "answer_spans"]
15 |
16 | def __init__(self, text: List[str], start: int, end: int, answer_spans: np.ndarray):
17 | """
18 | :param text: tokens of the source paragraph(s) that have been merged to form `self`
19 | :param start: start token of this text in the source document
20 | :param end: end token of this text in the source document
21 | """
22 | self.text = text
23 | self.start = start
24 | self.end = end
25 | self.answer_spans = answer_spans
26 |
27 | @property
28 | def n_context_words(self):
29 | return len(self.text)
30 |
31 | def __repr__(self):
32 | s = ""
33 | s += "text: %s ..." % (" ".join(self.text[:10]))
34 | s += ", start: %d" % (self.start)
35 | s += ", end: %d" % (self.end)
36 | s += ", answer_spans: {}".format(self.answer_spans)
37 | return s
38 |
39 |
40 | class DocParagraphWithAnswers(ExtractedParagraphWithAnswers):
41 | __slots__ = ["doc_id"]
42 |
43 | def __init__(self, text: List[str], start: int, end: int, answer_spans: np.ndarray,
44 | doc_id):
45 | super().__init__(text, start, end, answer_spans)
46 | self.doc_id = doc_id
47 |
48 |
49 | class DocumentAndQuestion(object):
50 | def __init__(self,
51 | document_id,
52 | qas_id,
53 | question_text, # str
54 | doc_tokens,
55 | orig_answer_texts=None,
56 | start_positions=None,
57 | end_positions=None):
58 | self.document_id = document_id
59 | self.qas_id = qas_id
60 | self.question_text = question_text
61 | self.doc_tokens = doc_tokens
62 | self.orig_answer_texts = orig_answer_texts
63 | self.start_positions = start_positions
64 | self.end_positions = end_positions
65 |
66 | def __str__(self):
67 | return self.__repr__()
68 |
69 | def __repr__(self):
70 | s = ""
71 | s += "document_id: %s" % (self.document_id)
72 | s += ", qas_id: %s" % (tokenization.printable_text(self.qas_id))
73 | s += ", question_text: %s" % (
74 | tokenization.printable_text(self.question_text))
75 | s += ", doc_tokens: %s ..." % (" ".join(self.doc_tokens[:20]))
76 | s += ", length of doc_tokens: %d" % (len(self.doc_tokens))
77 | if self.orig_answer_texts:
78 | s += ", orig_answer_texts: {}".format(self.orig_answer_texts)
79 | if self.start_positions and self.end_positions:
80 | s += ", start_positions: {}".format(self.start_positions)
81 | s += ", end_positions: {}".format(self.end_positions)
82 | s += ", token_answer: "
83 | for start, end in zip(self.start_positions, self.end_positions):
84 | s += "{}, ".format(" ".join(self.doc_tokens[start:(end+1)]))
85 | return s
86 |
87 |
88 | class InputFeatures(object):
89 | """A single set of features of data."""
90 |
91 | def __init__(self,
92 | unique_id,
93 | example_index,
94 | doc_span_index,
95 | tokens,
96 | token_to_orig_map,
97 | input_ids,
98 | input_mask,
99 | segment_ids,
100 | start_positions=None,
101 | end_positions=None,
102 | start_indexes=None,
103 | end_indexes=None,
104 | is_impossible=None):
105 | self.unique_id = unique_id
106 | self.example_index = example_index
107 | self.doc_span_index = doc_span_index
108 | self.tokens = tokens
109 | self.token_to_orig_map = token_to_orig_map
110 | self.input_ids = input_ids
111 | self.input_mask = input_mask
112 | self.segment_ids = segment_ids
113 | self.start_positions = start_positions
114 | self.end_positions = end_positions
115 | self.start_indexes = start_indexes
116 | self.end_indexes = end_indexes
117 | self.is_impossible = is_impossible
118 |
119 |
120 | def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride,
121 | max_query_length, verbose_logging=False, logger=None):
122 | """Loads a data file into a list of `InputBatch`s."""
123 |
124 | unique_id = 1000000000
125 |
126 | features = []
127 | for (example_index, example) in enumerate(examples):
128 | query_tokens = tokenizer.tokenize(example.question_text)
129 |
130 | if len(query_tokens) > max_query_length:
131 | query_tokens = query_tokens[0:max_query_length]
132 |
133 | tok_to_orig_index = []
134 | orig_to_tok_index = []
135 | all_doc_tokens = []
136 | for (i, token) in enumerate(example.doc_tokens):
137 | orig_to_tok_index.append(len(all_doc_tokens))
138 | sub_tokens = tokenizer.tokenize(token)
139 | for sub_token in sub_tokens:
140 | tok_to_orig_index.append(i)
141 | all_doc_tokens.append(sub_token)
142 |
143 | tok_start_positions = []
144 | tok_end_positions = []
145 | for start_position, end_position in \
146 | zip(example.start_positions, example.end_positions):
147 | tok_start_position = orig_to_tok_index[start_position]
148 | if end_position < len(example.doc_tokens) - 1:
149 | tok_end_position = orig_to_tok_index[end_position + 1] - 1
150 | else:
151 | tok_end_position = len(all_doc_tokens) - 1
152 | tok_start_positions.append(tok_start_position)
153 | tok_end_positions.append(tok_end_position)
154 |
155 | # The -3 accounts for [CLS], [SEP] and [SEP]
156 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
157 |
158 | # We can have documents that are longer than the maximum sequence length.
159 | # To deal with this we do a sliding window approach, where we take chunks
160 | # of up to our max length with a stride of `doc_stride`.
161 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
162 | "DocSpan", ["start", "length"])
163 | doc_spans = []
164 | start_offset = 0
165 | while start_offset < len(all_doc_tokens):
166 | length = len(all_doc_tokens) - start_offset
167 | if length > max_tokens_for_doc:
168 | length = max_tokens_for_doc
169 | doc_spans.append(_DocSpan(start=start_offset, length=length))
170 | if start_offset + length == len(all_doc_tokens):
171 | break
172 | start_offset += min(length, doc_stride)
173 |
174 | for (doc_span_index, doc_span) in enumerate(doc_spans):
175 | tokens = []
176 | token_to_orig_map = {}
177 | segment_ids = []
178 | tokens.append("[CLS]")
179 | segment_ids.append(0)
180 | for token in query_tokens:
181 | tokens.append(token)
182 | segment_ids.append(0)
183 | tokens.append("[SEP]")
184 | segment_ids.append(0)
185 |
186 | for i in range(doc_span.length):
187 | split_token_index = doc_span.start + i
188 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
189 | tokens.append(all_doc_tokens[split_token_index])
190 | segment_ids.append(1)
191 | tokens.append("[SEP]")
192 | segment_ids.append(1)
193 |
194 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
195 |
196 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
197 | # tokens are attended to.
198 | input_mask = [1] * len(input_ids)
199 |
200 | # Zero-pad up to the sequence length.
201 | while len(input_ids) < max_seq_length:
202 | input_ids.append(0)
203 | input_mask.append(0)
204 | segment_ids.append(0)
205 |
206 | assert len(input_ids) == max_seq_length
207 | assert len(input_mask) == max_seq_length
208 | assert len(segment_ids) == max_seq_length
209 |
210 | # For distant supervision, we annotate the positions of all answer spans
211 | start_positions = [0] * len(input_ids)
212 | end_positions = [0] * len(input_ids)
213 | start_indexes, end_indexes = [], []
214 | doc_start = doc_span.start
215 | doc_end = doc_span.start + doc_span.length - 1
216 | is_impossible = True
217 | for tok_start_position, tok_end_position in zip(tok_start_positions, tok_end_positions):
218 | if (tok_start_position >= doc_start and tok_end_position <= doc_end):
219 | doc_offset = len(query_tokens) + 2
220 | start_position = tok_start_position - doc_start + doc_offset
221 | end_position = tok_end_position - doc_start + doc_offset
222 | start_positions[start_position] = 1
223 | end_positions[end_position] = 1
224 | start_indexes.append(start_position)
225 | end_indexes.append(end_position)
226 | is_impossible = False
227 |
228 | if is_impossible:
229 | start_positions[0] = 1
230 | end_positions[0] = 1
231 | start_indexes.append(0)
232 | end_indexes.append(0)
233 |
234 | if example_index < 2 and verbose_logging:
235 | logger.info("*** Example ***")
236 | logger.info("unique_id: %s" % (unique_id))
237 | logger.info("example_index: %s" % (example_index))
238 | logger.info("doc_span_index: %s" % (doc_span_index))
239 | logger.info("doc_span_start: %s" % (doc_span.start))
240 | if is_impossible:
241 | logger.info("impossible example")
242 | else:
243 | logger.info("start_indexes: {}".format(start_indexes))
244 | logger.info("end_indexes: {}".format(end_indexes))
245 |
246 | features.append(
247 | InputFeatures(
248 | unique_id=unique_id,
249 | example_index=example_index,
250 | doc_span_index=doc_span_index,
251 | tokens=tokens,
252 | token_to_orig_map=token_to_orig_map,
253 | input_ids=input_ids,
254 | input_mask=input_mask,
255 | segment_ids=segment_ids,
256 | start_positions=start_positions,
257 | end_positions=end_positions,
258 | start_indexes=start_indexes,
259 | end_indexes=end_indexes,
260 | is_impossible=is_impossible))
261 | unique_id += 1
262 |
263 | if logger is not None and len(features) % 5000 == 0:
264 | logger.info("Processing features: %d" % (len(features)))
265 |
266 | return features
267 |
268 | def annotate_candidates(all_examples, batch_features, batch_results, filter_type, is_training, n_best_size,
269 | max_answer_length, do_lower_case, verbose_logging, logger):
270 | """Annotate top-k candidate answers into features."""
271 | unique_id_to_result = {}
272 | for result in batch_results:
273 | unique_id_to_result[result.unique_id] = result
274 |
275 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
276 | "PrelimPrediction",
277 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit", "rank_logit"])
278 |
279 | batch_span_starts, batch_span_ends, batch_hard_labels, batch_soft_labels = [], [], [], []
280 | for (feature_index, feature) in enumerate(batch_features):
281 | example = all_examples[feature.example_index]
282 | result = unique_id_to_result[feature.unique_id]
283 |
284 | prelim_predictions_per_feature = []
285 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
286 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
287 | for start_index in start_indexes:
288 | for end_index in end_indexes:
289 | # We could hypothetically create invalid predictions, e.g., predict
290 | # that the start of the span is in the question. We throw out all
291 | # invalid predictions.
292 | if start_index >= len(feature.tokens):
293 | continue
294 | if end_index >= len(feature.tokens):
295 | continue
296 | if start_index not in feature.token_to_orig_map:
297 | continue
298 | if end_index not in feature.token_to_orig_map:
299 | continue
300 | if end_index < start_index:
301 | continue
302 | length = end_index - start_index + 1
303 | if length > max_answer_length:
304 | continue
305 |
306 | prelim_predictions_per_feature.append(
307 | _PrelimPrediction(
308 | feature_index=feature_index,
309 | start_index=start_index,
310 | end_index=end_index,
311 | start_logit=result.start_logits[start_index],
312 | end_logit=result.end_logits[end_index],
313 | rank_logit=result.rank_logit))
314 |
315 | prelim_predictions_per_feature = sorted(
316 | prelim_predictions_per_feature,
317 | key=lambda x: (x.start_logit + x.end_logit + x.rank_logit),
318 | reverse=True)
319 |
320 | seen_predictions = {}
321 | span_starts, span_ends, hard_labels, soft_labels = [], [], [], []
322 |
323 | if is_training:
324 | # add no-answer option into candidate answers
325 | span_starts.append(0)
326 | span_ends.append(0)
327 | if feature.is_impossible:
328 | hard_labels.append(1)
329 | soft_labels.append(1.)
330 | else:
331 | hard_labels.append(0)
332 | soft_labels.append(0.)
333 |
334 | for i, pred_i in enumerate(prelim_predictions_per_feature):
335 | if len(span_starts) >= int(n_best_size/4):
336 | break
337 | tok_tokens = feature.tokens[pred_i.start_index:(pred_i.end_index + 1)]
338 | orig_doc_start = feature.token_to_orig_map[pred_i.start_index]
339 | orig_doc_end = feature.token_to_orig_map[pred_i.end_index]
340 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
341 | tok_text = " ".join(tok_tokens)
342 |
343 | # De-tokenize WordPieces that have been split off.
344 | tok_text = tok_text.replace(" ##", "")
345 | tok_text = tok_text.replace("##", "")
346 |
347 | # Clean whitespace
348 | tok_text = tok_text.strip()
349 | tok_text = " ".join(tok_text.split())
350 | orig_text = " ".join(orig_tokens)
351 |
352 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging, logger)
353 | if final_text in seen_predictions:
354 | continue
355 | seen_predictions[final_text] = True
356 |
357 | if is_training:
358 | if pred_i.start_index != 0 and pred_i.end_index != 0:
359 | span_starts.append(pred_i.start_index)
360 | span_ends.append(pred_i.end_index)
361 | if feature.is_impossible:
362 | hard_labels.append(0)
363 | soft_labels.append(0.)
364 | else:
365 | max_em, max_f1 = 0, 0
366 | for orig_answer_text in example.orig_answer_texts:
367 | em = int(exact_match_score(final_text, orig_answer_text))
368 | f1 = float(f1_score(final_text, orig_answer_text))
369 | if em > max_em:
370 | max_em = em
371 | if f1 > max_f1:
372 | max_f1 = f1
373 | hard_labels.append(max_em)
374 | soft_labels.append(max_f1)
375 | else:
376 | span_starts.append(pred_i.start_index)
377 | span_ends.append(pred_i.end_index)
378 |
379 | # filter out redundant candidates
380 | if (i+1) < len(prelim_predictions_per_feature):
381 | indexes = []
382 | for j, pred_j in enumerate(prelim_predictions_per_feature[(i+1):]):
383 | if filter_type == 'em':
384 | if pred_i.start_index == pred_j.start_index or pred_i.end_index == pred_j.end_index:
385 | indexes.append(i + j + 1)
386 | elif filter_type == 'f1':
387 | if span_f1([pred_i.start_index, pred_i.end_index], [pred_j.start_index, pred_j.end_index]) > 0:
388 | indexes.append(i + j + 1)
389 | elif filter_type == 'none':
390 | indexes = []
391 | else:
392 | raise ValueError("Unknown filter_type: %s" % filter_type)
393 | [prelim_predictions_per_feature.pop(index - k) for k, index in enumerate(indexes)]
394 |
395 | # Pad to fixed length
396 | while len(span_starts) < int(n_best_size/4):
397 | span_starts.append(0)
398 | span_ends.append(0)
399 | if is_training:
400 | if feature.is_impossible:
401 | hard_labels.append(1)
402 | soft_labels.append(1.)
403 | else:
404 | hard_labels.append(0)
405 | soft_labels.append(0.)
406 | assert len(span_starts) == int(n_best_size/4)
407 | if is_training:
408 | assert len(hard_labels) == int(n_best_size/4)
409 |
410 | # Add ground truth answer spans if there is no positive label
411 | if is_training:
412 | if max(hard_labels) == 0:
413 | # sample the start/end jointly so the gold span stays aligned
414 | sample_index = random.randrange(len(feature.start_indexes))
415 | span_starts[-1] = feature.start_indexes[sample_index]
416 | span_ends[-1] = feature.end_indexes[sample_index]
417 | hard_labels[-1] = 1
418 | soft_labels[-1] = 1.
419 |
420 | batch_span_starts.append(span_starts)
421 | batch_span_ends.append(span_ends)
422 | batch_hard_labels.append(hard_labels)
423 | batch_soft_labels.append(soft_labels)
424 | return batch_span_starts, batch_span_ends, batch_hard_labels, batch_soft_labels
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
--------------------------------------------------------------------------------
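
`convert_examples_to_features` handles documents longer than `max_seq_length` with the sliding-window loop around `_DocSpan`. A standalone sketch of just that chunking logic, assuming 1000 WordPiece tokens, a 384-token window, and a 128-token stride (typical but illustrative values):

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(n_tokens, max_tokens_for_doc, doc_stride):
    # mirrors the while-loop in convert_examples_to_features
    spans, start_offset = [], 0
    while start_offset < n_tokens:
        length = min(n_tokens - start_offset, max_tokens_for_doc)
        spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == n_tokens:
            break
        start_offset += min(length, doc_stride)
    return spans

print(make_doc_spans(1000, 384, 128))
# six spans starting at 0, 128, 256, 384, 512, 640; the last has length 360
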
/triviaqa/triviaqa_eval.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.0 of the TriviaQA dataset.
2 | Extended from the evaluation script for v1.1 of the SQuAD dataset.
3 |
4 | (Additionally condensed into a single file)
5 | """
6 | from __future__ import print_function
7 |
8 | import json
9 | from collections import Counter
10 | import string
11 | import re
12 | import sys
13 | import argparse
14 |
15 | import unicodedata
16 |
17 |
18 | def normalize_answer(s):
19 | """Lower text and remove punctuation, articles and extra whitespace."""
20 |
21 | def remove_articles(text):
22 | return re.sub(r'\b(a|an|the)\b', ' ', text)
23 |
24 | def white_space_fix(text):
25 | return ' '.join(text.split())
26 |
27 | def handle_punc(text):
28 | exclude = set(string.punctuation + "".join([u"‘", u"’", u"´", u"`"]))
29 | return ''.join(ch if ch not in exclude else ' ' for ch in text)
30 |
31 | def lower(text):
32 | return text.lower()
33 |
34 | def replace_underscore(text):
35 | return text.replace('_', ' ')
36 |
37 | return white_space_fix(remove_articles(handle_punc(lower(replace_underscore(s))))).strip()
38 |
39 |
40 | def f1_score(prediction, ground_truth):
41 | prediction_tokens = normalize_answer(prediction).split()
42 | ground_truth_tokens = normalize_answer(ground_truth).split()
43 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
44 | num_same = sum(common.values())
45 | if num_same == 0:
46 | return 0
47 | precision = 1.0 * num_same / len(prediction_tokens)
48 | recall = 1.0 * num_same / len(ground_truth_tokens)
49 | f1 = (2 * precision * recall) / (precision + recall)
50 | return f1
51 |
52 |
53 | def exact_match_score(prediction, ground_truth):
54 | return normalize_answer(prediction) == normalize_answer(ground_truth)
55 |
56 |
57 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
58 | scores_for_ground_truths = []
59 | for ground_truth in ground_truths:
60 | score = metric_fn(prediction, ground_truth)
61 | scores_for_ground_truths.append(score)
62 | return max(scores_for_ground_truths)
63 |
64 |
65 | def get_ground_truths(answer):
66 | return answer['NormalizedAliases'] + [normalize_answer(ans) for ans in answer.get('HumanAnswers', [])]
67 |
68 |
69 | def get_file_contents(filename, encoding='utf-8'):
70 | with open(filename, encoding=encoding) as f:
71 | content = f.read()
72 | return content
73 |
74 |
75 | def read_json(filename, encoding='utf-8'):
76 | contents = get_file_contents(filename, encoding=encoding)
77 | return json.loads(contents)
78 |
79 |
80 | def is_exact_match(answer_object, prediction):
81 | ground_truths = get_ground_truths(answer_object)
82 | for ground_truth in ground_truths:
83 | if exact_match_score(prediction, ground_truth):
84 | return True
85 | return False
86 |
87 |
88 | def has_exact_match(ground_truths, candidates):
89 | for ground_truth in ground_truths:
90 | if ground_truth in candidates:
91 | return True
92 | return False
93 |
94 |
95 | def get_key_to_ground_truth(data):
96 | if data['Domain'] == 'Wikipedia':
97 | return {datum['QuestionId']: datum['Answer'] for datum in data['Data']}
98 | else:
99 | return get_qd_to_answer(data)
100 |
101 |
102 | def get_key_to_ground_truth_per_question(data):
103 | return {datum['QuestionId']: datum['Answer'] for datum in data['Data']}
104 |
105 |
106 | def get_question_doc_string(qid, doc_name):
107 | return '{}--{}'.format(qid, unicodedata.normalize("NFD", doc_name).lower())
108 |
109 |
110 | def get_qd_to_answer(data):
111 | key_to_answer = {}
112 | for datum in data['Data']:
113 | for page in datum.get('EntityPages', []) + datum.get('SearchResults', []):
114 | qd_tuple = get_question_doc_string(datum['QuestionId'], page['Filename'])
115 | key_to_answer[qd_tuple] = datum['Answer']
116 | return key_to_answer
117 |
118 |
119 | def evaluate_triviaqa(ground_truth, predicted_answers, qid_list=None):
120 |     f1 = exact_match = common = 0
121 |     missing_count = 0
122 |     if qid_list is None:
123 |         qid_list = ground_truth.keys()
124 |     for qid in qid_list:
125 |         if qid not in predicted_answers:
126 |             missing_count += 1
127 |             # message = 'Missed question {} will receive score 0.'.format(qid)
128 |             # print(message, file=sys.stderr)
129 |             continue
130 |         if qid not in ground_truth:
131 |             missing_count += 1
132 |             continue
133 |         common += 1
134 |         prediction = predicted_answers[qid]
135 |         ground_truths = get_ground_truths(ground_truth[qid])
136 |         em_for_this_question = metric_max_over_ground_truths(
137 |             exact_match_score, prediction, ground_truths)
138 |         exact_match += em_for_this_question
139 |         f1_for_this_question = metric_max_over_ground_truths(
140 |             f1_score, prediction, ground_truths)
141 |         f1 += f1_for_this_question
142 |
143 |     exact_match = 100.0 * exact_match / len(qid_list)
144 |     f1 = 100.0 * f1 / len(qid_list)
145 |
146 |     print("missing prediction or ground truth on %d examples" % missing_count)
147 |     return {'exact_match': exact_match, 'f1': f1, 'common': common, 'denominator': len(qid_list),
148 |             'pred_len': len(predicted_answers), 'gold_len': len(ground_truth)}
149 |
150 |
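# Editor's note: a minimal sketch of calling the evaluator directly; the
# question id and alias list below are hypothetical. 'NormalizedAliases'
# is expected to hold already-normalized strings, as in the official
# dataset files:
#
#   gold = {'tc_1': {'NormalizedAliases': ['beatles', 'the beatles']}}
#   preds = {'tc_1': 'The Beatles'}
#   evaluate_triviaqa(gold, preds)
#   # -> {'exact_match': 100.0, 'f1': 100.0, 'common': 1, ...}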
151 | def read_clean_part(datum):
152 |     for key in ['EntityPages', 'SearchResults']:
153 |         new_page_list = []
154 |         for page in datum.get(key, []):
155 |             if page['DocPartOfVerifiedEval']:
156 |                 new_page_list.append(page)
157 |         datum[key] = new_page_list
158 |     assert len(datum['EntityPages']) + len(datum['SearchResults']) > 0
159 |     return datum
160 |
161 |
162 | def read_triviaqa_data(qajson):
163 |     data = read_json(qajson)
164 |     # keep only the questions and documents that are part of the clean (verified) data set
165 |     if data['VerifiedEval']:
166 |         clean_data = []
167 |         for datum in data['Data']:
168 |             if datum['QuestionPartOfVerifiedEval']:
169 |                 if data['Domain'] == 'Web':
170 |                     datum = read_clean_part(datum)
171 |                 clean_data.append(datum)
172 |         data['Data'] = clean_data
173 |     return data
174 |
175 |
176 | def get_args():
177 |     parser = argparse.ArgumentParser(description='Evaluation for TriviaQA')
178 |     parser.add_argument('--dataset_file', help='Dataset file')
179 |     parser.add_argument('--prediction_file', help='Prediction File')
180 |     args = parser.parse_args()
181 |     return args
182 |
183 |
184 | if __name__ == '__main__':
185 |     expected_version = 1.0
186 |     args = get_args()
187 |
188 |     dataset_json = read_triviaqa_data(args.dataset_file)
189 |     if dataset_json['Version'] != expected_version:
190 |         print('Evaluation expects v-{}, but got dataset with v-{}'.format(expected_version, dataset_json['Version']),
191 |               file=sys.stderr)
192 |     key_to_ground_truth = get_key_to_ground_truth(dataset_json)
193 |     predictions = read_json(args.prediction_file)
194 |     eval_dict = evaluate_triviaqa(key_to_ground_truth, predictions)
195 |     print(eval_dict)
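# Editor's note: typical invocation (paths are placeholders):
#
#   python triviaqa/triviaqa_eval.py \
#       --dataset_file path/to/triviaqa-dev.json \
#       --prediction_file path/to/predictions.json
#
# The prediction file is a JSON object mapping question ids (or, for the Web
# domain, question--document keys) to answer strings.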
--------------------------------------------------------------------------------
/triviaqa/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, TypeVar, Iterable
2 | import collections
3 | import re
4 | import string
5 | T = TypeVar('T')
6 |
7 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're',
8 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her',
9 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do',
10 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over',
11 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves',
12 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself',
13 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these',
14 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why',
15 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into',
16 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−',
17 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where',
18 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off',
19 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against',
20 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me',
21 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after',
22 | 'be', "' s", "' t"}
23 |
24 | def flatten_iterable(listoflists: Iterable[Iterable[T]]) -> List[T]:
25 | return [item for sublist in listoflists for item in sublist]
26 |
27 |
28 | def split(lst: List[T], n_groups) -> List[List[T]]:
29 | """ partition `lst` into `n_groups` that are as evenly sized as possible """
30 | per_group = len(lst) // n_groups
31 | remainder = len(lst) % n_groups
32 | groups = []
33 | ix = 0
34 | for _ in range(n_groups):
35 | group_size = per_group
36 | if remainder > 0:
37 | remainder -= 1
38 | group_size += 1
39 | groups.append(lst[ix:ix + group_size])
40 | ix += group_size
41 | return groups
42 |
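# Editor's note: an illustrative call; earlier groups absorb the remainder:
#
#   split(list(range(7)), 3)  ->  [[0, 1, 2], [3, 4], [5, 6]]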
43 | def group(lst: List[T], max_group_size) -> List[List[T]]:
44 | """ partition `lst` into that the mininal number of groups that as evenly sized
45 | as possible and are at most `max_group_size` in size """
46 | if max_group_size is None:
47 | return [lst]
48 | n_groups = (len(lst)+max_group_size-1) // max_group_size
49 | per_group = len(lst) // n_groups
50 | remainder = len(lst) % n_groups
51 | groups = []
52 | ix = 0
53 | for _ in range(n_groups):
54 | group_size = per_group
55 | if remainder > 0:
56 | remainder -= 1
57 | group_size += 1
58 | groups.append(lst[ix:ix + group_size])
59 | ix += group_size
60 | return groups
61 |
62 |
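# Editor's note: `group` first derives the smallest number of groups that
# respects `max_group_size`, then balances them the same way `split` does:
#
#   group(list(range(7)), 3)  ->  [[0, 1, 2], [3, 4], [5, 6]]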
63 | def simple_normalize_answer(s):
64 | """Lower text and remove punctuation, articles and extra whitespace."""
65 | def white_space_fix(text):
66 | return ' '.join(text.split())
67 |
68 | def lower(text):
69 | return text.lower()
70 |
71 | return white_space_fix(lower(s))
72 |
73 |
74 | def normalize_answer(s):
75 | """Lower text and remove punctuation, articles and extra whitespace."""
76 | def remove_articles(text):
77 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
78 | return re.sub(regex, ' ', text)
79 | def white_space_fix(text):
80 | return ' '.join(text.split())
81 | def remove_punc(text):
82 | exclude = set(string.punctuation)
83 | return ''.join(ch for ch in text if ch not in exclude)
84 | def lower(text):
85 | return text.lower()
86 | return white_space_fix(remove_articles(remove_punc(lower(s))))
87 |
88 |
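# Editor's note: unlike the normalizer in triviaqa_eval.py, this variant
# deletes punctuation outright (rather than replacing it with a space) and
# does not special-case underscores or curly quotes, e.g.:
#
#   normalize_answer("The U.S.A.")  ->  "usa"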
89 | def get_tokens(s):
90 | if not s: return []
91 | return normalize_answer(s).split()
92 |
93 |
94 | def compute_f1(a_gold, a_pred):
95 | gold_toks = get_tokens(a_gold)
96 | pred_toks = get_tokens(a_pred)
97 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
98 | num_same = sum(common.values())
99 | if len(gold_toks) == 0 or len(pred_toks) == 0:
100 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
101 | return int(gold_toks == pred_toks)
102 | if num_same == 0:
103 | return 0
104 | precision = 1.0 * num_same / len(pred_toks)
105 | recall = 1.0 * num_same / len(gold_toks)
106 | f1 = (2 * precision * recall) / (precision + recall)
107 | return f1
108 |
109 |
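# Editor's note: the no-answer convention above, in two assumed calls:
#
#   compute_f1("", "")     ->  1   # both empty: they agree
#   compute_f1("", "yes")  ->  0   # one empty, one not: no credit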
110 | def get_max_f1_span(words, answer, window_size):
111 |     """Return the inclusive token span of `words`, at most `window_size` + 1
112 |     tokens long, with the highest F1 against `answer`, plus that F1."""
113 |     max_f1 = 0
114 |     max_span = (0, 0)
115 |
116 |     for idx1 in range(len(words)):
117 |         for idx2 in range(min(window_size + 1, len(words) - idx1)):
118 |             candidate_answer = words[idx1: idx1 + idx2 + 1]
119 |             f1 = compute_f1(' '.join(answer), ' '.join(candidate_answer))
120 |             if f1 > max_f1:
121 |                 max_f1 = f1
122 |                 max_span = (idx1, idx1 + idx2)
123 |     return max_span, max_f1
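# Editor's note: a minimal usage sketch with made-up tokens. The candidate
# span words[3:5] = ['rhythm', 'guitar'] matches the answer exactly, so the
# inclusive index pair (3, 4) is returned together with an F1 of 1.0:
#
#   words = 'john lennon played rhythm guitar'.split()
#   get_max_f1_span(words, ['rhythm', 'guitar'], window_size=3)
#   # -> ((3, 4), 1.0)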
--------------------------------------------------------------------------------