├── LICENSE.txt
├── README.md
├── bert
│   ├── custom_modeling.py
│   ├── modeling.py
│   ├── optimization.py
│   ├── run_squad_document_full_e2e.py
│   ├── run_triviaqa_wiki_full_e2e.py
│   └── tokenization.py
├── data
│   └── squad
│       └── dev-v1.1.json
├── image
│   └── framework.PNG
├── squad
│   ├── convert_squad_open.py
│   ├── squad_document_utils.py
│   ├── squad_evaluate.py
│   ├── squad_open_utils.py
│   └── squad_utils.py
└── triviaqa
    ├── ablate_triviaqa_unfiltered.py
    ├── ablate_triviaqa_wiki.py
    ├── answer_detection.py
    ├── build_span_corpus.py
    ├── configurable.py
    ├── evidence_corpus.py
    ├── preprocessed_corpus.py
    ├── read_data.py
    ├── triviaqa_document_utils.py
    ├── triviaqa_eval.py
    └── utils.py
/LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner.
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Retrieve, Read, Rerank: Towards End-to-End Multi-Document Reading Comprehension 2 | 3 | This repo contains the code of the following paper: 4 | 5 | [Retrieve, Read, Rerank: Towards End-to-End Multi-Document Reading Comprehension](https://arxiv.org/abs/1906.04618). Minghao Hu, Yuxing Peng, Zhen Huang, Dongsheng Li. ACL 2019. 
6 | 7 | In this paper, we propose an end-to-end neural network for the multi-document reading comprehension task, which is shown below: 8 | 9 | ![framework](image/framework.PNG) 10 |
11 | 12 | This network consists of three components: 13 | - Early-stopped retriever 14 | - Distantly-supervised reader 15 | - Answer reranker 16 | 17 | Given multiple documents, the network is designed to retrieve relevant document content, propose multiple answer candidates, and finally rerank these candidates. We utilize [BERT](https://github.com/huggingface/pytorch-pretrained-BERT) to initialize our network. The whole network is trained end-to-end with a multi-task objective. 18 | 19 | ## Pre-trained Models 20 | To reproduce our results, we release the following pre-trained models: 21 | - [squad_doc_base](https://drive.google.com/file/d/16lTmN2wu31QdUvExW_fGcDJxKnR7f912/view?usp=sharing) 22 | - [triviaqa_wiki_base](https://drive.google.com/file/d/1Re_2KxBlCQ9_sxTmkZGoahjX72c1eCfk/view?usp=sharing) 23 | - [triviaqa_unfiltered_base](https://drive.google.com/file/d/1kqF40UhJAC6XkAbywI-YMIg_C5t0oS2Q/view?usp=sharing) 24 | 25 | ## Requirements 26 | - Python 3.6 27 | - [PyTorch 1.1](https://pytorch.org/) 28 | 29 | Download the uncased [BERT-Base](https://drive.google.com/file/d/13I0Gj7v8lYhW5Hwmp5kxm3CTlzWZuok2/view?usp=sharing) model and unzip it in the current directory. 30 | 31 | ## SQuAD-document 32 | To run experiments on the SQuAD-document dataset, first set up the environment: 33 | ```bash 34 | export DATA_DIR=data/squad 35 | export BERT_DIR=bert-base-uncased 36 | ``` 37 | 38 | Make sure `train-v1.1.json` and `dev-v1.1.json` are placed in `DATA_DIR`. 39 | 40 | Then run the following command to train the model: 41 | ```shell 42 | python -m bert.run_squad_document_full_e2e \ 43 | --vocab_file $BERT_DIR/vocab.txt \ 44 | --bert_config_file $BERT_DIR/bert_config.json \ 45 | --init_checkpoint $BERT_DIR/pytorch_model.bin \ 46 | --do_train \ 47 | --do_predict \ 48 | --data_dir $DATA_DIR \ 49 | --train_file train-v1.1.json \ 50 | --predict_file dev-v1.1.json \ 51 | --train_batch_size 32 \ 52 | --learning_rate 3e-5 \ 53 | --num_train_epochs 2.0 \ 54 | --output_dir out/squad_doc/01 55 | ``` 56 | All experiments in our paper were conducted on 4 NVIDIA TESLA P40 (22GB memory per card). The training took nearly 22 hours to converge. If you do not have enough GPU capacity, you can change several hyper-parameters such as the following (note that 57 | these changes might cause performance degradation): 58 | - `--train_batch_size`: total batch size for training. 59 | - `--n_para_train`: the number of paragraphs retrieved by TF-IDF during training (denoted as `K` in our paper; see the retrieval sketch below). 60 | - `--n_best_size_rank`: the number of segments retrieved by the early-stopped retriever (denoted as `N` in our paper). 61 | - `--num_hidden_rank`: the number of Transformer blocks used for retrieving (denoted as `J` in our paper). 62 | - `--gradient_accumulation_steps`: number of update steps to accumulate before performing a backward/update pass (see the sketch after this list). 63 | - `--optimize_on_cpu`: whether to perform optimization and keep the optimizer averages on CPU.
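For intuition, the TF-IDF retrieval step behind `--n_para_train` can be sketched as follows. This snippet is illustrative only: the function name `retrieve_top_k` and the use of scikit-learn are our own assumptions, not the repository's actual preprocessing code (which lives in the `squad` and `triviaqa` packages).

```python
# Illustrative sketch only: rank paragraphs against a question by TF-IDF
# cosine similarity and keep the top K, mirroring what --n_para_train controls.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def retrieve_top_k(question, paragraphs, k=12):
    vectorizer = TfidfVectorizer(ngram_range=(1, 2), stop_words="english")
    para_matrix = vectorizer.fit_transform(paragraphs)  # [num_paras, vocab]
    question_vec = vectorizer.transform([question])     # [1, vocab]
    scores = cosine_similarity(question_vec, para_matrix)[0]
    ranked = sorted(range(len(paragraphs)), key=lambda i: -scores[i])
    return [paragraphs[i] for i in ranked[:k]]

paragraphs = ["The Broncos defeated the Panthers 24-10.",
              "Super Bowl 50 was played at Levi's Stadium.",
              "The halftime show featured Coldplay."]
print(retrieve_top_k("Where was Super Bowl 50 played?", paragraphs, k=2))
```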
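Similarly, here is a minimal sketch of what `--gradient_accumulation_steps` does; a toy linear model, SGD optimizer, and random batches stand in for the real BERT training loop.

```python
# Illustrative sketch: run several small forward/backward passes and apply one
# optimizer step per window, so a large effective batch fits in limited memory.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # toy stand-in for the BERT-based network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

accumulation_steps = 4  # corresponds to --gradient_accumulation_steps 4
optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)  # forward on a micro-batch
    (loss / accumulation_steps).backward()      # scale so gradients average
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                        # one update per window
        optimizer.zero_grad()
```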
64 | 65 | The base model can be trained on 2 GeForce GTX TITAN (12GB memory per card) with the following command: 66 | ```shell 67 | python -m bert.run_squad_document_full_e2e \ 68 | --vocab_file $BERT_DIR/vocab.txt \ 69 | --bert_config_file $BERT_DIR/bert_config.json \ 70 | --init_checkpoint $BERT_DIR/pytorch_model.bin \ 71 | --do_train \ 72 | --do_predict \ 73 | --data_dir $DATA_DIR \ 74 | --train_file train-v1.1.json \ 75 | --predict_file dev-v1.1.json \ 76 | --train_batch_size 32 \ 77 | --learning_rate 3e-5 \ 78 | --num_train_epochs 2.0 \ 79 | --optimize_on_cpu \ 80 | --gradient_accumulation_steps 4 \ 81 | --output_dir out/squad_doc/01 82 | ``` 83 | 84 | Finally, you can get a dev result from `out/squad_doc/01/performance.txt` like this: 85 | ```bash 86 | Ranker, type: test, step: 19332, map: 0.891, mrr: 0.916, top_1: 0.880, top_3: 0.945, top_5: 0.969, top_7: 0.977, retrieval_rate: 0.558 87 | Reader, type: test, step: 19332, test_em: 77.909, test_f1: 84.817 88 | ``` 89 | 90 | ## SQuAD-open 91 | Once you have trained a model on document-level SQuAD, you can evaluate it on the open-domain version of the SQuAD dataset. 92 | 93 | First, download the pre-processed [SQuAD-open dev set](https://drive.google.com/file/d/1oBqoNNGVV2yCKvEWv5k91PBUHNDl5q8J/view?usp=sharing) and place it in `data/squad/`. 94 | 95 | Then run the following command to evaluate the model: 96 | ```shell 97 | python -m bert.run_squad_document_full_e2e \ 98 | --vocab_file $BERT_DIR/vocab.txt \ 99 | --bert_config_file $BERT_DIR/bert_config.json \ 100 | --do_predict_open \ 101 | --data_dir $DATA_DIR \ 102 | --output_dir out/squad_doc/01 103 | ``` 104 | 105 | You can get a dev result from `out/squad_doc/01/performance.txt` like this: 106 | ```bash 107 | Ranker, type: test_open, step: 19332, map: 0.000, mrr: 0.000, top_1: 0.000, top_3: 0.000, top_5: 0.000, top_7: 0.000, retrieval_rate: 0.190 108 | Reader, type: test_open, step: 19332, em: 40.123, f1: 48.358 109 | ``` 110 | 111 | ## TriviaQA 112 | ### Data Preprocessing 113 | The raw TriviaQA data is expected to be unzipped in `~/data/triviaqa`. Training 114 | or testing in the unfiltered setting requires the unfiltered data to be 115 | downloaded to `~/data/triviaqa-unfiltered`. 116 | ```bash 117 | mkdir -p ~/data/triviaqa 118 | cd ~/data/triviaqa 119 | wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz 120 | tar xf triviaqa-rc.tar.gz 121 | rm triviaqa-rc.tar.gz 122 | 123 | cd ~/data 124 | wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz 125 | tar xf triviaqa-unfiltered.tar.gz 126 | rm triviaqa-unfiltered.tar.gz 127 | ``` 128 | 129 | First, tokenize the evidence documents by running 130 | ```shell 131 | python -m triviaqa.evidence_corpus --n_processes 8 --max_tokens 200 132 | ``` 133 | which merges paragraphs of fewer than 200 words. 134 | 135 | Then tokenize questions and locate relevant answer spans in each document (a simplified sketch of this span-matching step follows below). Run 136 | ```shell 137 | python -m triviaqa.build_span_corpus {wiki|unfiltered} --n_processes 8 138 | ``` 139 | to build the desired set. This builds pkl files in `data/triviaqa/{wiki|unfiltered}`.
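To make the span-location step concrete, here is a toy sketch of distant supervision over tokenized text. It is illustrative only: `find_answer_spans` is a hypothetical helper, while the repository's real logic, with alias handling and normalization, lives in `triviaqa/answer_detection.py`.

```python
# Illustrative sketch: mark every occurrence of an answer string inside a
# tokenized document as a candidate (start, end) span for distant supervision.
def find_answer_spans(doc_tokens, answer_tokens):
    answer = [t.lower() for t in answer_tokens]
    n = len(answer)
    spans = []
    for i in range(len(doc_tokens) - n + 1):
        if [t.lower() for t in doc_tokens[i:i + n]] == answer:
            spans.append((i, i + n - 1))  # inclusive token indices
    return spans

doc = "The Eiffel Tower is in Paris , and Paris is in France".split()
print(find_answer_spans(doc, ["Paris"]))  # [(5, 5), (8, 8)]
```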
140 | 141 | Next, retrieve the top-n paragraphs based on TF-IDF to construct the train and dev sets by 142 | ```shell 143 | python -m triviaqa.ablate_triviaqa_wiki --n_processes 8 --n_para_train 12 --n_para_dev 14 --n_para_test 14 --do_train --do_dev --do_test 144 | python -m triviaqa.ablate_triviaqa_unfiltered --n_processes 8 --n_para_train 12 --n_para_dev 14 --n_para_test 14 --do_train --do_dev --do_test 145 | ``` 146 | 147 | ### Wikipedia Domain 148 | To run experiments on the TriviaQA-wiki dataset, first set up the environment: 149 | ```bash 150 | export DATA_DIR=data/triviaqa/wiki 151 | export BERT_DIR=bert-base-uncased 152 | ``` 153 | 154 | Then run the following command to train the model: 155 | ```shell 156 | python -m bert.run_triviaqa_wiki_full_e2e \ 157 | --vocab_file $BERT_DIR/vocab.txt \ 158 | --bert_config_file $BERT_DIR/bert_config.json \ 159 | --init_checkpoint $BERT_DIR/pytorch_model.bin \ 160 | --do_train \ 161 | --do_dev \ 162 | --data_dir $DATA_DIR \ 163 | --train_batch_size 32 \ 164 | --learning_rate 3e-5 \ 165 | --num_train_epochs 2.0 \ 166 | --output_dir out/triviaqa_wiki/01 167 | ``` 168 | 169 | Once the training is finished, a dev result can be obtained from `out/triviaqa_wiki/01/performance.txt` as: 170 | ```bash 171 | Ranker, type: dev, step: 20088, map: 0.778, mrr: 0.849, top_1: 0.797, top_3: 0.888, top_5: 0.918, top_7: 0.932, retrieval_rate: 0.460 172 | Reader, type: dev, step: 20088, em: 68.510, f1: 72.680 173 | ``` 174 | 175 | ### Unfiltered Domain 176 | To run experiments on the TriviaQA-unfiltered dataset, first set up the environment: 177 | ```bash 178 | export DATA_DIR=data/triviaqa/unfiltered 179 | export BERT_DIR=bert-base-uncased 180 | ``` 181 | 182 | Then run the following command to train the model: 183 | ```shell 184 | python -m bert.run_triviaqa_wiki_full_e2e \ 185 | --vocab_file $BERT_DIR/vocab.txt \ 186 | --bert_config_file $BERT_DIR/bert_config.json \ 187 | --init_checkpoint $BERT_DIR/pytorch_model.bin \ 188 | --do_train \ 189 | --do_dev \ 190 | --data_dir $DATA_DIR \ 191 | --train_batch_size 32 \ 192 | --learning_rate 3e-5 \ 193 | --num_train_epochs 2.0 \ 194 | --output_dir out/triviaqa_unfiltered/01 195 | ``` 196 | 197 | Once the training is finished, a dev result can be obtained from `out/triviaqa_unfiltered/01/performance.txt` as: 198 | ```bash 199 | Ranker, type: dev, step: 26726, map: 0.737, mrr: 0.781, top_1: 0.749, top_3: 0.806, top_5: 0.824, top_7: 0.831, retrieval_rate: 0.392 200 | Reader, type: dev, step: 26726, em: 63.953, f1: 69.506 201 | ``` 202 | 203 | ## Acknowledgements 204 | Some preprocessing code was modified from the [document-qa](https://github.com/allenai/document-qa) implementation. 205 | 206 | The BERT implementation is based on [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT).
207 | 208 | If you find the paper or this repository helpful in your work, please use the following citation: 209 | ``` 210 | @inproceedings{hu2019retrieve, 211 | title={Retrieve, Read, Rerank: Towards End-to-End Multi-Document Reading Comprehension}, 212 | author={Hu, Minghao and Peng, Yuxing and Huang, Zhen and Li, Dongsheng}, 213 | booktitle={Proceedings of ACL}, 214 | year={2019} 215 | } 216 | ``` 217 | -------------------------------------------------------------------------------- /bert/custom_modeling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import copy 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import CrossEntropyLoss, MSELoss 9 | 10 | from bert.modeling import BertConfig, BERTLayerNorm, BERTLayer, BERTEmbeddings, BERTPooler 11 | 12 | def flatten(x): 13 | if len(x.size()) == 2: 14 | batch_size = x.size()[0] 15 | seq_length = x.size()[1] 16 | return x.view([batch_size * seq_length]) 17 | elif len(x.size()) == 3: 18 | batch_size = x.size()[0] 19 | seq_length = x.size()[1] 20 | hidden_size = x.size()[2] 21 | return x.view([batch_size * seq_length, hidden_size]) 22 | else: 23 | raise Exception() 24 | 25 | def reconstruct(x, ref): 26 | if len(x.size()) == 1: 27 | batch_size = ref.size()[0] 28 | turn_num = ref.size()[1] 29 | return x.view([batch_size, turn_num]) 30 | elif len(x.size()) == 2: 31 | batch_size = ref.size()[0] 32 | turn_num = ref.size()[1] 33 | sequence_length = x.size()[1] 34 | return x.view([batch_size, turn_num, sequence_length]) 35 | else: 36 | raise Exception() 37 | 38 | def flatten_emb_by_sentence(emb, emb_mask): 39 | batch_size = emb.size()[0] 40 | seq_length = emb.size()[1] 41 | flat_emb = flatten(emb) 42 | flat_emb_mask = emb_mask.view([batch_size * seq_length]) 43 | return flat_emb[flat_emb_mask.nonzero().squeeze(), :] 44 | 45 | def get_span_representation(span_starts, span_ends, input, input_mask): 46 | ''' 47 | :param span_starts: [N, M] 48 | :param span_ends: [N, M] 49 | :param input: [N, L, D] 50 | :param input_mask: [N, L] 51 | :return: [N*M, JR, D], [N*M, JR] 52 | ''' 53 | input_mask = input_mask.to(dtype=span_starts.dtype) # fp16 compatibility 54 | input_len = torch.sum(input_mask, dim=-1) # [N] 55 | word_offset = torch.cumsum(input_len, dim=0) # [N] 56 | word_offset -= input_len 57 | 58 | span_starts_offset = span_starts + word_offset.unsqueeze(1) 59 | span_ends_offset = span_ends + word_offset.unsqueeze(1) 60 | 61 | span_starts_offset = span_starts_offset.view([-1]) # [N*M] 62 | span_ends_offset = span_ends_offset.view([-1]) 63 | 64 | span_width = span_ends_offset - span_starts_offset + 1 65 | JR = torch.max(span_width) 66 | 67 | context_outputs = flatten_emb_by_sentence(input, input_mask) # [ 1: 262 | start_positions = start_positions.squeeze(-1) 263 | if len(end_positions.size()) > 1: 264 | end_positions = end_positions.squeeze(-1) 265 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 266 | ignored_index = start_logits.size(1) 267 | start_positions.clamp_(0, ignored_index) 268 | end_positions.clamp_(0, ignored_index) 269 | 270 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 271 | start_loss = loss_fct(start_logits, start_positions) 272 | end_loss = loss_fct(end_logits, end_positions) 273 | read_loss = (start_loss + end_loss) / 2 274 | 275 | assert span_starts is not None and span_ends is not None and hard_labels is 
not None and soft_labels is not None 276 | span_output, span_mask = get_span_representation(span_starts, span_ends, sequence_output, 277 | attention_mask) # [N*M, JR, D], [N*M, JR] 278 | span_score = self.rerank_affine(span_output) 279 | span_score = span_score.squeeze(-1) # [N*M, JR] 280 | span_pooled_output = get_self_att_representation(span_output, span_score, span_mask) # [N*M, D] 281 | 282 | span_pooled_output = self.rerank_dense(span_pooled_output) 283 | span_pooled_output = self.activation(span_pooled_output) 284 | span_pooled_output = self.dropout(span_pooled_output) 285 | rerank_logits = self.rerank_classifier(span_pooled_output).squeeze(-1) 286 | rerank_logits = reconstruct(rerank_logits, span_starts) 287 | 288 | hard_loss = distant_cross_entropy(rerank_logits, hard_labels) 289 | soft_loss_fct = MSELoss() 290 | soft_loss = soft_loss_fct(rerank_logits, soft_labels.to(dtype=rerank_logits.dtype)) 291 | rerank_loss = hard_loss + soft_loss 292 | return read_loss + rerank_loss 293 | 294 | else: 295 | raise Exception 296 | 297 | class BertForRankingAndDistantReadingAndReranking(nn.Module): 298 | def __init__(self, config, num_hidden_rank): 299 | super(BertForRankingAndDistantReadingAndReranking, self).__init__() 300 | 301 | self.num_hidden_rank = num_hidden_rank 302 | self.num_hidden_read = config.num_hidden_layers 303 | self.bert = EarlyStopBertModel(config) 304 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 305 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 306 | self.activation = nn.Tanh() 307 | self.rank_affine = nn.Linear(config.hidden_size, 1) 308 | self.rank_dense = nn.Linear(config.hidden_size, config.hidden_size) 309 | self.rank_classifier = nn.Linear(config.hidden_size, 2) 310 | self.read_affine = nn.Linear(config.hidden_size, 2) 311 | self.rerank_affine = nn.Linear(config.hidden_size, 1) 312 | self.rerank_dense = nn.Linear(config.hidden_size, config.hidden_size) 313 | self.rerank_classifier = nn.Linear(config.hidden_size, 1) 314 | 315 | def init_weights(module): 316 | if isinstance(module, (nn.Linear, nn.Embedding)): 317 | # Slightly different from the TF version which uses truncated_normal for initialization 318 | # cf https://github.com/pytorch/pytorch/pull/5617 319 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 320 | elif isinstance(module, BERTLayerNorm): 321 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 322 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 323 | if isinstance(module, nn.Linear): 324 | module.bias.data.zero_() 325 | self.apply(init_weights) 326 | 327 | def forward(self, mode, attention_mask, input_ids=None, token_type_ids=None, rank_labels=None, start_positions=None, 328 | end_positions=None, span_starts=None, span_ends=None, hard_labels=None, soft_labels=None, 329 | sequence_input=None): 330 | if mode == 'rank': 331 | assert input_ids is not None and token_type_ids is not None 332 | all_encoder_layers, _ = self.bert(self.num_hidden_rank, input_ids, token_type_ids, attention_mask) 333 | sequence_output = all_encoder_layers[-1] 334 | 335 | sequence_weights = self.rank_affine(sequence_output).squeeze(-1) 336 | pooled_output = get_self_att_representation(sequence_output, sequence_weights, attention_mask) 337 | 338 | pooled_output = self.rank_dense(pooled_output) 339 | pooled_output = self.activation(pooled_output) 340 | pooled_output = 
self.dropout(pooled_output) 341 | rank_logits = self.rank_classifier(pooled_output) 342 | 343 | if rank_labels is not None: 344 | rank_loss_fct = CrossEntropyLoss() 345 | rank_loss = rank_loss_fct(rank_logits, rank_labels) 346 | return rank_loss 347 | else: 348 | return rank_logits 349 | 350 | elif mode == 'read_inference': 351 | assert input_ids is not None and token_type_ids is not None 352 | all_encoder_layers, _ = self.bert(self.num_hidden_read, input_ids, token_type_ids, attention_mask) 353 | sequence_output = all_encoder_layers[-1] 354 | 355 | logits = self.read_affine(sequence_output) 356 | start_logits, end_logits = logits.split(1, dim=-1) 357 | start_logits = start_logits.squeeze(-1) 358 | end_logits = end_logits.squeeze(-1) 359 | return start_logits, end_logits, sequence_output 360 | 361 | elif mode == 'rerank_inference': 362 | assert span_starts is not None and span_ends is not None and sequence_input is not None 363 | span_output, span_mask = get_span_representation(span_starts, span_ends, sequence_input, 364 | attention_mask) # [N*M, JR, D], [N*M, JR] 365 | 366 | span_weights = self.rerank_affine(span_output).squeeze(-1) 367 | span_pooled_output = get_self_att_representation(span_output, span_weights, span_mask) # [N*M, D] 368 | 369 | span_pooled_output = self.rerank_dense(span_pooled_output) 370 | span_pooled_output = self.activation(span_pooled_output) 371 | span_pooled_output = self.dropout(span_pooled_output) 372 | rerank_logits = self.rerank_classifier(span_pooled_output).squeeze(-1) 373 | rerank_logits = reconstruct(rerank_logits, span_starts) 374 | return rerank_logits 375 | 376 | elif mode == 'read_rerank_train': 377 | assert input_ids is not None and token_type_ids is not None 378 | assert start_positions is not None and end_positions is not None 379 | all_encoder_layers, _ = self.bert(self.num_hidden_read, input_ids, token_type_ids, attention_mask) 380 | sequence_output = all_encoder_layers[-1] 381 | 382 | logits = self.read_affine(sequence_output) 383 | start_logits, end_logits = logits.split(1, dim=-1) 384 | start_logits = start_logits.squeeze(-1) 385 | end_logits = end_logits.squeeze(-1) 386 | 387 | start_loss = distant_cross_entropy(start_logits, start_positions) 388 | end_loss = distant_cross_entropy(end_logits, end_positions) 389 | read_loss = (start_loss + end_loss) / 2 390 | 391 | assert span_starts is not None and span_ends is not None and hard_labels is not None and soft_labels is not None 392 | span_output, span_mask = get_span_representation(span_starts, span_ends, sequence_output, 393 | attention_mask) # [N*M, JR, D], [N*M, JR] 394 | span_score = self.rerank_affine(span_output) 395 | span_score = span_score.squeeze(-1) # [N*M, JR] 396 | span_pooled_output = get_self_att_representation(span_output, span_score, span_mask) # [N*M, D] 397 | 398 | span_pooled_output = self.rerank_dense(span_pooled_output) 399 | span_pooled_output = self.activation(span_pooled_output) 400 | span_pooled_output = self.dropout(span_pooled_output) 401 | rerank_logits = self.rerank_classifier(span_pooled_output).squeeze(-1) 402 | rerank_logits = reconstruct(rerank_logits, span_starts) 403 | 404 | hard_loss = distant_cross_entropy(rerank_logits, hard_labels) 405 | soft_loss_fct = MSELoss() 406 | soft_loss = soft_loss_fct(rerank_logits, soft_labels.to(dtype=rerank_logits.dtype)) 407 | rerank_loss = hard_loss + soft_loss 408 | return read_loss + rerank_loss 409 | 410 | else: 411 | raise Exception 412 | -------------------------------------------------------------------------------- 
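A quick usage sketch for the shape utilities defined at the top of `bert/custom_modeling.py` (`flatten` and `reconstruct`). The tensors here are toy stand-ins, and this snippet is illustrative rather than part of the repository:

```python
# Illustrative round-trip through the shape helpers in bert/custom_modeling.py:
# flatten merges the batch and sequence dimensions; reconstruct restores them
# using a reference tensor that still carries the original [N, M] shape.
import torch
from bert.custom_modeling import flatten, reconstruct

x = torch.randn(2, 3, 8)           # [batch=2, seq=3, hidden=8]
flat = flatten(x)                  # [6, 8]
scores = flat.sum(-1)              # per-position scores, [6]
restored = reconstruct(scores, x)  # back to [2, 3], using x as the shape reference
print(flat.shape, restored.shape)  # torch.Size([6, 8]) torch.Size([2, 3])
```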
/bert/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | 30 | def gelu(x): 31 | """Implementation of the gelu activation function. 32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 34 | """ 35 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 36 | 37 | 38 | class BertConfig(object): 39 | """Configuration class to store the configuration of a `BertModel`. 40 | """ 41 | def __init__(self, 42 | vocab_size, 43 | hidden_size=768, 44 | num_hidden_layers=12, 45 | num_attention_heads=12, 46 | intermediate_size=3072, 47 | hidden_act="gelu", 48 | hidden_dropout_prob=0.1, 49 | attention_probs_dropout_prob=0.1, 50 | max_position_embeddings=512, 51 | type_vocab_size=16, 52 | initializer_range=0.02): 53 | """Constructs BertConfig. 54 | 55 | Args: 56 | vocab_size: Vocabulary size of `input_ids` in `BertModel`. 57 | hidden_size: Size of the encoder layers and the pooler layer. 58 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 59 | num_attention_heads: Number of attention heads for each attention layer in 60 | the Transformer encoder. 61 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 62 | layer in the Transformer encoder. 63 | hidden_act: The non-linear activation function (function or string) in the 64 | encoder and pooler. 65 | hidden_dropout_prob: The dropout probability for all fully connected 66 | layers in the embeddings, encoder, and pooler. 67 | attention_probs_dropout_prob: The dropout ratio for the attention 68 | probabilities. 69 | max_position_embeddings: The maximum sequence length that this model might 70 | ever be used with. Typically set this to something large just in case 71 | (e.g., 512 or 1024 or 2048). 72 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 73 | `BertModel`. 74 | initializer_range: The stddev of the truncated_normal_initializer for 75 | initializing all weight matrices.
76 | """ 77 | self.vocab_size = vocab_size 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | 89 | @classmethod 90 | def from_dict(cls, json_object): 91 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 92 | config = BertConfig(vocab_size=None) 93 | for (key, value) in six.iteritems(json_object): 94 | config.__dict__[key] = value 95 | return config 96 | 97 | @classmethod 98 | def from_json_file(cls, json_file): 99 | """Constructs a `BertConfig` from a json file of parameters.""" 100 | with open(json_file, "r") as reader: 101 | text = reader.read() 102 | return cls.from_dict(json.loads(text)) 103 | 104 | def to_dict(self): 105 | """Serializes this instance to a Python dictionary.""" 106 | output = copy.deepcopy(self.__dict__) 107 | return output 108 | 109 | def to_json_string(self): 110 | """Serializes this instance to a JSON string.""" 111 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 112 | 113 | 114 | class BERTLayerNorm(nn.Module): 115 | def __init__(self, config, variance_epsilon=1e-12): 116 | """Construct a layernorm module in the TF style (epsilon inside the square root). 117 | """ 118 | super(BERTLayerNorm, self).__init__() 119 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 120 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 121 | self.variance_epsilon = variance_epsilon 122 | 123 | def forward(self, x): 124 | u = x.mean(-1, keepdim=True) 125 | s = (x - u).pow(2).mean(-1, keepdim=True) 126 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 127 | return self.gamma * x + self.beta 128 | 129 | 130 | class BERTEmbeddings(nn.Module): 131 | def __init__(self, config): 132 | super(BERTEmbeddings, self).__init__() 133 | """Construct the embedding module from word, position and token_type embeddings. 
134 | """ 135 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 136 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 137 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 138 | 139 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 140 | # any TensorFlow checkpoint file 141 | self.LayerNorm = BERTLayerNorm(config) 142 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 143 | 144 | def forward(self, input_ids, token_type_ids=None): 145 | seq_length = input_ids.size(1) 146 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 147 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 148 | if token_type_ids is None: 149 | token_type_ids = torch.zeros_like(input_ids) 150 | 151 | words_embeddings = self.word_embeddings(input_ids) 152 | position_embeddings = self.position_embeddings(position_ids) 153 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 154 | 155 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 156 | embeddings = self.LayerNorm(embeddings) 157 | embeddings = self.dropout(embeddings) 158 | return embeddings 159 | 160 | 161 | class BERTSelfAttention(nn.Module): 162 | def __init__(self, config): 163 | super(BERTSelfAttention, self).__init__() 164 | if config.hidden_size % config.num_attention_heads != 0: 165 | raise ValueError( 166 | "The hidden size (%d) is not a multiple of the number of attention " 167 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 168 | self.num_attention_heads = config.num_attention_heads 169 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 170 | self.all_head_size = self.num_attention_heads * self.attention_head_size 171 | 172 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 173 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 174 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 175 | 176 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 177 | 178 | def transpose_for_scores(self, x): 179 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 180 | x = x.view(*new_x_shape) 181 | return x.permute(0, 2, 1, 3) 182 | 183 | def forward(self, hidden_states, attention_mask): 184 | mixed_query_layer = self.query(hidden_states) # [N, L, H] 185 | mixed_key_layer = self.key(hidden_states) 186 | mixed_value_layer = self.value(hidden_states) 187 | 188 | query_layer = self.transpose_for_scores(mixed_query_layer) # [N, K, L, H//K] 189 | key_layer = self.transpose_for_scores(mixed_key_layer) 190 | value_layer = self.transpose_for_scores(mixed_value_layer) 191 | 192 | # Take the dot product between "query" and "key" to get the raw attention scores. 193 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [N, K, L, L] 194 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 195 | # Apply the attention mask (precomputed for all layers in BertModel forward() function) 196 | attention_scores = attention_scores + attention_mask 197 | 198 | # Normalize the attention scores to probabilities. 199 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 200 | 201 | # This is actually dropping out entire tokens to attend to, which might 202 | # seem a bit unusual, but is taken from the original Transformer paper.
203 | attention_probs = self.dropout(attention_probs) 204 | 205 | context_layer = torch.matmul(attention_probs, value_layer) # [N, K, L, H//K] 206 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [N, L, K, H//K] 207 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 208 | context_layer = context_layer.view(*new_context_layer_shape) # [N, L, H] 209 | return context_layer 210 | 211 | 212 | class BERTSelfOutput(nn.Module): 213 | def __init__(self, config): 214 | super(BERTSelfOutput, self).__init__() 215 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 216 | self.LayerNorm = BERTLayerNorm(config) 217 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 218 | 219 | def forward(self, hidden_states, input_tensor): 220 | hidden_states = self.dense(hidden_states) 221 | hidden_states = self.dropout(hidden_states) 222 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 223 | return hidden_states 224 | 225 | 226 | class BERTAttention(nn.Module): 227 | def __init__(self, config): 228 | super(BERTAttention, self).__init__() 229 | self.self = BERTSelfAttention(config) 230 | self.output = BERTSelfOutput(config) 231 | 232 | def forward(self, input_tensor, attention_mask): 233 | self_output = self.self(input_tensor, attention_mask) 234 | attention_output = self.output(self_output, input_tensor) 235 | return attention_output 236 | 237 | 238 | class BERTIntermediate(nn.Module): 239 | def __init__(self, config): 240 | super(BERTIntermediate, self).__init__() 241 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 242 | self.intermediate_act_fn = gelu 243 | 244 | def forward(self, hidden_states): 245 | hidden_states = self.dense(hidden_states) 246 | hidden_states = self.intermediate_act_fn(hidden_states) 247 | return hidden_states 248 | 249 | 250 | class BERTOutput(nn.Module): 251 | def __init__(self, config): 252 | super(BERTOutput, self).__init__() 253 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 254 | self.LayerNorm = BERTLayerNorm(config) 255 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 256 | 257 | def forward(self, hidden_states, input_tensor): 258 | hidden_states = self.dense(hidden_states) 259 | hidden_states = self.dropout(hidden_states) 260 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 261 | return hidden_states 262 | 263 | 264 | class BERTLayer(nn.Module): 265 | def __init__(self, config): 266 | super(BERTLayer, self).__init__() 267 | self.attention = BERTAttention(config) 268 | self.intermediate = BERTIntermediate(config) 269 | self.output = BERTOutput(config) 270 | 271 | def forward(self, hidden_states, attention_mask): 272 | attention_output = self.attention(hidden_states, attention_mask) 273 | intermediate_output = self.intermediate(attention_output) 274 | layer_output = self.output(intermediate_output, attention_output) 275 | return layer_output 276 | 277 | 278 | class BERTEncoder(nn.Module): 279 | def __init__(self, config): 280 | super(BERTEncoder, self).__init__() 281 | layer = BERTLayer(config) 282 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 283 | 284 | def forward(self, hidden_states, attention_mask): 285 | all_encoder_layers = [] 286 | for layer_module in self.layer: 287 | hidden_states = layer_module(hidden_states, attention_mask) 288 | all_encoder_layers.append(hidden_states) 289 | return all_encoder_layers 290 | 291 | 292 | class BERTPooler(nn.Module): 293 | def __init__(self, config): 294 | 
super(BERTPooler, self).__init__() 295 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 296 | self.activation = nn.Tanh() 297 | 298 | def forward(self, hidden_states): 299 | # We "pool" the model by simply taking the hidden state corresponding 300 | # to the first token. 301 | first_token_tensor = hidden_states[:, 0] 302 | pooled_output = self.dense(first_token_tensor) 303 | pooled_output = self.activation(pooled_output) 304 | return pooled_output 305 | 306 | 307 | class BertModel(nn.Module): 308 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 309 | 310 | Example usage: 311 | ```python 312 | # Already been converted into WordPiece token ids 313 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 314 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 315 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 316 | 317 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 318 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 319 | 320 | model = modeling.BertModel(config=config) 321 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 322 | ``` 323 | """ 324 | def __init__(self, config: BertConfig): 325 | """Constructor for BertModel. 326 | 327 | Args: 328 | config: `BertConfig` instance. 329 | """ 330 | super(BertModel, self).__init__() 331 | self.embeddings = BERTEmbeddings(config) 332 | self.encoder = BERTEncoder(config) 333 | self.pooler = BERTPooler(config) 334 | 335 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 336 | if attention_mask is None: 337 | attention_mask = torch.ones_like(input_ids) 338 | if token_type_ids is None: 339 | token_type_ids = torch.zeros_like(input_ids) 340 | 341 | # We create a 3D attention mask from a 2D tensor mask. 342 | # Sizes are [batch_size, 1, 1, to_seq_length] 343 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 344 | # This attention mask is simpler than the triangular masking of causal attention 345 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 346 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 347 | 348 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 349 | # masked positions, this operation will create a tensor which is 0.0 for 350 | # positions we want to attend and -10000.0 for masked positions. 351 | # Since we are adding it to the raw scores before the softmax, this is 352 | # effectively the same as removing these entirely. 353 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 354 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 355 | 356 | embedding_output = self.embeddings(input_ids, token_type_ids) 357 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 358 | sequence_output = all_encoder_layers[-1] 359 | pooled_output = self.pooler(sequence_output) 360 | return all_encoder_layers, pooled_output 361 | 362 | 363 | class BertForSequenceClassification(nn.Module): 364 | """BERT model for classification. 365 | This module is composed of the BERT model with a linear layer on top of 366 | the pooled output.
367 | 368 | Example usage: 369 | ```python 370 | # Already been converted into WordPiece token ids 371 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 372 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 373 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 374 | 375 | config = BertConfig(vocab_size=32000, hidden_size=512, 376 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 377 | 378 | num_labels = 2 379 | 380 | model = BertForSequenceClassification(config, num_labels) 381 | logits = model(input_ids, token_type_ids, input_mask) 382 | ``` 383 | """ 384 | def __init__(self, config, num_labels): 385 | super(BertForSequenceClassification, self).__init__() 386 | self.bert = BertModel(config) 387 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 388 | self.classifier = nn.Linear(config.hidden_size, num_labels) 389 | 390 | def init_weights(module): 391 | if isinstance(module, (nn.Linear, nn.Embedding)): 392 | # Slightly different from the TF version which uses truncated_normal for initialization 393 | # cf https://github.com/pytorch/pytorch/pull/5617 394 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 395 | elif isinstance(module, BERTLayerNorm): 396 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 397 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 398 | if isinstance(module, nn.Linear): 399 | module.bias.data.zero_() 400 | self.apply(init_weights) 401 | 402 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 403 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 404 | pooled_output = self.dropout(pooled_output) 405 | logits = self.classifier(pooled_output) 406 | 407 | if labels is not None: 408 | loss_fct = CrossEntropyLoss() 409 | loss = loss_fct(logits, labels) 410 | return loss, logits 411 | else: 412 | return logits 413 | 414 | 415 | class BertForQuestionAnswering(nn.Module): 416 | """BERT model for Question Answering (span extraction). 
417 | This module is composed of the BERT model with a linear layer on top of 418 | the sequence output that computes start_logits and end_logits. 419 | 420 | Example usage: 421 | ```python 422 | # Already been converted into WordPiece token ids 423 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 424 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 425 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 426 | 427 | config = BertConfig(vocab_size=32000, hidden_size=512, 428 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 429 | 430 | model = BertForQuestionAnswering(config) 431 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 432 | ``` 433 | """ 434 | def __init__(self, config): 435 | super(BertForQuestionAnswering, self).__init__() 436 | self.bert = BertModel(config) 437 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 438 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 439 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 440 | 441 | def init_weights(module): 442 | if isinstance(module, (nn.Linear, nn.Embedding)): 443 | # Slightly different from the TF version which uses truncated_normal for initialization 444 | # cf https://github.com/pytorch/pytorch/pull/5617 445 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 446 | elif isinstance(module, BERTLayerNorm): 447 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 448 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 449 | if isinstance(module, nn.Linear): 450 | module.bias.data.zero_() 451 | self.apply(init_weights) 452 | 453 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 454 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 455 | sequence_output = all_encoder_layers[-1] 456 | logits = self.qa_outputs(sequence_output) 457 | start_logits, end_logits = logits.split(1, dim=-1) 458 | start_logits = start_logits.squeeze(-1) 459 | end_logits = end_logits.squeeze(-1) 460 | 461 | if start_positions is not None and end_positions is not None: 462 | # If we are on multi-GPU, the split adds a dimension 463 | if len(start_positions.size()) > 1: 464 | start_positions = start_positions.squeeze(-1) 465 | if len(end_positions.size()) > 1: 466 | end_positions = end_positions.squeeze(-1) 467 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 468 | ignored_index = start_logits.size(1) 469 | start_positions.clamp_(0, ignored_index) 470 | end_positions.clamp_(0, ignored_index) 471 | 472 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 473 | start_loss = loss_fct(start_logits, start_positions) 474 | end_loss = loss_fct(end_logits, end_positions) 475 | total_loss = (start_loss + end_loss) / 2 476 | return total_loss 477 | else: 478 | return start_logits, end_logits -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | def warmup_cosine(x, warmup=0.002): 23 | if x < warmup: 24 | return x/warmup 25 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | return 1.0 31 | 32 | def warmup_linear(x, warmup=0.002): 33 | if x < warmup: 34 | return x/warmup 35 | return 1.0 - x 36 | 37 | SCHEDULES = { 38 | 'warmup_cosine':warmup_cosine, 39 | 'warmup_constant':warmup_constant, 40 | 'warmup_linear':warmup_linear, 41 | } 42 | 43 | 44 | class BERTAdam(Optimizer): 45 | """Implements BERT version of Adam algorithm with weight decay fix (and no bias correction). 46 | Params: 47 | lr: learning rate 48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 49 | t_total: total number of training steps for the learning 50 | rate schedule, -1 means constant learning rate. Default: -1 51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 52 | b1: Adam's b1. Default: 0.9 53 | b2: Adam's b2. Default: 0.999 54 | e: Adam's epsilon. Default: 1e-6 55 | weight_decay_rate: Weight decay. Default: 0.01 56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 57 | """ 58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 60 | max_grad_norm=1.0): 61 | if not lr >= 0.0: 62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 63 | if schedule not in SCHEDULES: 64 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 65 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 67 | if not 0.0 <= b1 < 1.0: 68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 69 | if not 0.0 <= b2 < 1.0: 70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 71 | if not e >= 0.0: 72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 75 | max_grad_norm=max_grad_norm) 76 | super(BERTAdam, self).__init__(params, defaults) 77 | 78 | def get_lr(self): 79 | lr = [] 80 | for group in self.param_groups: 81 | for p in group['params']: 82 | state = self.state[p] 83 | if len(state) == 0: 84 | return [0] 85 | if group['t_total'] != -1: 86 | schedule_fct = SCHEDULES[group['schedule']] 87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 88 | else: 89 | lr_scheduled = group['lr'] 90 | lr.append(lr_scheduled) 91 | return lr 92 | 93 | def step(self, closure=None): 94 | """Performs a single optimization step. 95 | 96 | Arguments: 97 | closure (callable, optional): A closure that reevaluates the model 98 | and returns the loss.
99 | """ 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | if p.grad is None: 107 | continue 108 | grad = p.grad.data 109 | if grad.is_sparse: 110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 111 | 112 | state = self.state[p] 113 | 114 | # State initialization 115 | if len(state) == 0: 116 | state['step'] = 0 117 | # Exponential moving average of gradient values 118 | state['next_m'] = torch.zeros_like(p.data) 119 | # Exponential moving average of squared gradient values 120 | state['next_v'] = torch.zeros_like(p.data) 121 | 122 | next_m, next_v = state['next_m'], state['next_v'] 123 | beta1, beta2 = group['b1'], group['b2'] 124 | 125 | # Add grad clipping 126 | if group['max_grad_norm'] > 0: 127 | clip_grad_norm_(p, group['max_grad_norm']) 128 | 129 | # Decay the first and second moment running average coefficient 130 | # In-place operations to update the averages at the same time 131 | next_m.mul_(beta1).add_(1 - beta1, grad) 132 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 133 | update = next_m / (next_v.sqrt() + group['e']) 134 | 135 | # Just adding the square of the weights to the loss function is *not* 136 | # the correct way of using L2 regularization/weight decay with Adam, 137 | # since that will interact with the m and v parameters in strange ways. 138 | # 139 | # Instead we want ot decay the weights in a manner that doesn't interact 140 | # with the m/v parameters. This is equivalent to adding the square 141 | # of the weights to the loss with plain (non-momentum) SGD. 142 | if group['weight_decay_rate'] > 0.0: 143 | update += group['weight_decay_rate'] * p.data 144 | 145 | if group['t_total'] != -1: 146 | schedule_fct = SCHEDULES[group['schedule']] 147 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 148 | else: 149 | lr_scheduled = group['lr'] 150 | 151 | update_with_lr = lr_scheduled * update 152 | p.data.add_(-update_with_lr) 153 | 154 | state['step'] += 1 155 | 156 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 157 | # bias_correction1 = 1 - beta1 ** state['step'] 158 | # bias_correction2 = 1 - beta2 ** state['step'] 159 | 160 | return loss 161 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | index = 0 73 | with open(vocab_file, "r", encoding="utf-8") as reader: 74 | while True: 75 | token = convert_to_unicode(reader.readline()) 76 | if not token: 77 | break 78 | token = token.strip() 79 | vocab[token] = index 80 | index += 1 81 | return vocab 82 | 83 | 84 | def convert_tokens_to_ids(vocab, tokens): 85 | """Converts a sequence of tokens into ids using the vocab.""" 86 | ids = [] 87 | for token in tokens: 88 | ids.append(vocab[token]) 89 | return ids 90 | 91 | 92 | def whitespace_tokenize(text): 93 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 94 | text = text.strip() 95 | if not text: 96 | return [] 97 | tokens = text.split() 98 | return tokens 99 | 100 | 101 | class FullTokenizer(object): 102 | """Runs end-to-end tokenziation.""" 103 | 104 | def __init__(self, vocab_file, do_lower_case=True): 105 | self.vocab = load_vocab(vocab_file) 106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 108 | 109 | def tokenize(self, text): 110 | split_tokens = [] 111 | for token in self.basic_tokenizer.tokenize(text): 112 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 113 | split_tokens.append(sub_token) 114 | 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | return convert_tokens_to_ids(self.vocab, tokens) 119 | 120 | 121 | class BasicTokenizer(object): 122 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 123 | 124 | def __init__(self, do_lower_case=True): 125 | """Constructs a BasicTokenizer. 126 | 127 | Args: 128 | do_lower_case: Whether to lower case the input. 
129 | """ 130 | self.do_lower_case = do_lower_case 131 | 132 | def tokenize(self, text): 133 | """Tokenizes a piece of text.""" 134 | text = convert_to_unicode(text) 135 | text = self._clean_text(text) 136 | # This was added on November 1st, 2018 for the multilingual and Chinese 137 | # models. This is also applied to the English models now, but it doesn't 138 | # matter since the English models were not trained on any Chinese data 139 | # and generally don't have any Chinese data in them (there are Chinese 140 | # characters in the vocabulary because Wikipedia does have some Chinese 141 | # words in the English Wikipedia.). 142 | text = self._tokenize_chinese_chars(text) 143 | orig_tokens = whitespace_tokenize(text) 144 | split_tokens = [] 145 | for token in orig_tokens: 146 | if self.do_lower_case: 147 | token = token.lower() 148 | token = self._run_strip_accents(token) 149 | split_tokens.extend(self._run_split_on_punc(token)) 150 | 151 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 152 | return output_tokens 153 | 154 | def _run_strip_accents(self, text): 155 | """Strips accents from a piece of text.""" 156 | text = unicodedata.normalize("NFD", text) 157 | output = [] 158 | for char in text: 159 | cat = unicodedata.category(char) 160 | if cat == "Mn": 161 | continue 162 | output.append(char) 163 | return "".join(output) 164 | 165 | def _run_split_on_punc(self, text): 166 | """Splits punctuation on a piece of text.""" 167 | chars = list(text) 168 | i = 0 169 | start_new_word = True 170 | output = [] 171 | while i < len(chars): 172 | char = chars[i] 173 | if _is_punctuation(char): 174 | output.append([char]) 175 | start_new_word = True 176 | else: 177 | if start_new_word: 178 | output.append([]) 179 | start_new_word = False 180 | output[-1].append(char) 181 | i += 1 182 | 183 | return ["".join(x) for x in output] 184 | 185 | def _tokenize_chinese_chars(self, text): 186 | """Adds whitespace around any CJK character.""" 187 | output = [] 188 | for char in text: 189 | cp = ord(char) 190 | if self._is_chinese_char(cp): 191 | output.append(" ") 192 | output.append(char) 193 | output.append(" ") 194 | else: 195 | output.append(char) 196 | return "".join(output) 197 | 198 | def _is_chinese_char(self, cp): 199 | """Checks whether CP is the codepoint of a CJK character.""" 200 | # This defines a "chinese character" as anything in the CJK Unicode block: 201 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 202 | # 203 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 204 | # despite its name. The modern Korean Hangul alphabet is a different block, 205 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 206 | # space-separated words, so they are not treated specially and handled 207 | # like the all of the other languages. 
208 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 209 | (cp >= 0x3400 and cp <= 0x4DBF) or # 210 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 211 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 212 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 213 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 214 | (cp >= 0xF900 and cp <= 0xFAFF) or # 215 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 216 | return True 217 | 218 | return False 219 | 220 | def _clean_text(self, text): 221 | """Performs invalid character removal and whitespace cleanup on text.""" 222 | output = [] 223 | for char in text: 224 | cp = ord(char) 225 | if cp == 0 or cp == 0xfffd or _is_control(char): 226 | continue 227 | if _is_whitespace(char): 228 | output.append(" ") 229 | else: 230 | output.append(char) 231 | return "".join(output) 232 | 233 | 234 | class WordpieceTokenizer(object): 235 | """Runs WordPiece tokenization.""" 236 | 237 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 238 | self.vocab = vocab 239 | self.unk_token = unk_token 240 | self.max_input_chars_per_word = max_input_chars_per_word 241 | 242 | def tokenize(self, text): 243 | """Tokenizes a piece of text into its word pieces. 244 | 245 | This uses a greedy longest-match-first algorithm to perform tokenization 246 | using the given vocabulary. 247 | 248 | For example: 249 | input = "unaffable" 250 | output = ["un", "##aff", "##able"] 251 | 252 | Args: 253 | text: A single token or whitespace separated tokens. This should have 254 | already been passed through `BasicTokenizer. 255 | 256 | Returns: 257 | A list of wordpiece tokens. 258 | """ 259 | 260 | text = convert_to_unicode(text) 261 | 262 | output_tokens = [] 263 | for token in whitespace_tokenize(text): 264 | chars = list(token) 265 | if len(chars) > self.max_input_chars_per_word: 266 | output_tokens.append(self.unk_token) 267 | continue 268 | 269 | is_bad = False 270 | start = 0 271 | sub_tokens = [] 272 | while start < len(chars): 273 | end = len(chars) 274 | cur_substr = None 275 | while start < end: 276 | substr = "".join(chars[start:end]) 277 | if start > 0: 278 | substr = "##" + substr 279 | if substr in self.vocab: 280 | cur_substr = substr 281 | break 282 | end -= 1 283 | if cur_substr is None: 284 | is_bad = True 285 | break 286 | sub_tokens.append(cur_substr) 287 | start = end 288 | 289 | if is_bad: 290 | output_tokens.append(self.unk_token) 291 | else: 292 | output_tokens.extend(sub_tokens) 293 | return output_tokens 294 | 295 | 296 | def _is_whitespace(char): 297 | """Checks whether `chars` is a whitespace character.""" 298 | # \t, \n, and \r are technically contorl characters but we treat them 299 | # as whitespace since they are generally considered as such. 300 | if char == " " or char == "\t" or char == "\n" or char == "\r": 301 | return True 302 | cat = unicodedata.category(char) 303 | if cat == "Zs": 304 | return True 305 | return False 306 | 307 | 308 | def _is_control(char): 309 | """Checks whether `chars` is a control character.""" 310 | # These are technically control characters but we count them as whitespace 311 | # characters. 312 | if char == "\t" or char == "\n" or char == "\r": 313 | return False 314 | cat = unicodedata.category(char) 315 | if cat.startswith("C"): 316 | return True 317 | return False 318 | 319 | 320 | def _is_punctuation(char): 321 | """Checks whether `chars` is a punctuation character.""" 322 | cp = ord(char) 323 | # We treat all non-letter/number ASCII as punctuation. 
324 | # Characters such as "^", "$", and "`" are not in the Unicode 325 | # Punctuation class but we treat them as punctuation anyways, for 326 | # consistency. 327 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 328 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 329 | return True 330 | cat = unicodedata.category(char) 331 | if cat.startswith("P"): 332 | return True 333 | return False 334 | -------------------------------------------------------------------------------- /image/framework.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huminghao16/RE3QA/14faa386b519bed7c94ddff399afdb2c9967de44/image/framework.PNG -------------------------------------------------------------------------------- /squad/convert_squad_open.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import numpy as np 4 | from os.path import relpath, join, exists, expanduser 5 | from sklearn.feature_extraction.text import TfidfVectorizer 6 | from sklearn.metrics import pairwise_distances 7 | from typing import List, TypeVar, Iterable 8 | from tqdm import tqdm 9 | 10 | import bert.tokenization as tokenization 11 | from triviaqa.evidence_corpus import MergeParagraphs 12 | from triviaqa.build_span_corpus import FastNormalizedAnswerDetector 13 | from squad.squad_document_utils import DocumentAndQuestion 14 | 15 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're', 16 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her', 17 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do', 18 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over', 19 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves', 20 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself', 21 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these', 22 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why', 23 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into', 24 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−', 25 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where', 26 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off', 27 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against', 28 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me', 29 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after', 30 | 'be'} 31 | 32 | class SquadOpenExample(object): 33 | def __init__(self, qas_id, question_text, answer_texts, doc_text): 34 | self.qas_id = qas_id 35 | self.question_text = question_text 36 | self.answer_texts = answer_texts 37 | self.doc_text = doc_text 38 | 39 | def __str__(self): 40 | return self.__repr__() 41 | 42 | def __repr__(self): 43 | s = "" 44 | s += "qas_id: %s" % self.qas_id 45 | s += ", question_text: %s" % self.question_text 46 | s += ", answer_texts: {}".format(self.answer_texts) 47 | s += ", doc_text: %s" % self.doc_text[:1000] 48 | return s 49 
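# Example (a minimal sketch, not part of this repo): what the rank() helper
# defined just below computes. It returns pairwise *cosine distances*, so
# smaller values mean a closer question/paragraph match, and np.argsort in
# main() puts the most relevant paragraphs first. The strings here are
# illustrative.
#
#   tfidf = TfidfVectorizer(strip_accents="unicode")
#   scores = rank(tfidf, ["who wrote the iliad"],
#                 ["The Iliad is attributed to Homer.",
#                  "Pizza originated in Naples."])
#   np.argsort(scores[0])  # -> array([0, 1]): the Homer paragraph ranks first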
| 50 | def rank(tfidf, questions: List[str], paragraphs: List[str]): 51 | para_features = tfidf.fit_transform(paragraphs) 52 | q_features = tfidf.transform(questions) 53 | scores = pairwise_distances(q_features, para_features, "cosine") 54 | return scores 55 | 56 | def main(): 57 | parse = argparse.ArgumentParser("Pre-tokenize the SQuAD open dev file") 58 | parse.add_argument("--input_file", type=str, default=join("data", "squad", "squad_dev_open.pkl")) 59 | # This is slow, using more processes is recommended 60 | parse.add_argument("--max_tokens", type=int, default=200, help="Number of maximal tokens in each merged paragraph") 61 | parse.add_argument("--n_to_select", type=int, default=30, help="Number of paragraphs to retrieve") 62 | parse.add_argument("--sort_passage", type=bool, default=True, help="Sort passage according to order") 63 | parse.add_argument("--debug", type=bool, default=False, help="Whether to run in debug mode") 64 | args = parse.parse_args() 65 | 66 | dev_examples = pickle.load(open(args.input_file, 'rb')) 67 | 68 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 69 | splitter = MergeParagraphs(args.max_tokens) 70 | tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words) 71 | detector = FastNormalizedAnswerDetector() 72 | 73 | ir_count, total_doc_length, pruned_doc_length = 0, 0, 0 74 | out = [] 75 | for example_ix, example in tqdm(enumerate(dev_examples), total=len(dev_examples)): 76 | paras = [x for x in example.doc_text.split("\n") if len(x) > 0] 77 | paragraphs = [tokenizer.tokenize(x) for x in paras] 78 | merged_paragraphs = splitter.merge(paragraphs) 79 | 80 | scores = rank(tfidf, [example.question_text], [" ".join(x) for x in merged_paragraphs]) 81 | para_scores = scores[0] 82 | para_ranks = np.argsort(para_scores) 83 | selection = [i for i in para_ranks[:args.n_to_select]] 84 | 85 | if args.sort_passage: 86 | selection = np.sort(selection) 87 | 88 | doc_tokens = [] 89 | for idx in selection: 90 | current_para = merged_paragraphs[idx] 91 | doc_tokens += current_para 92 | 93 | tokenized_answers = [tokenizer.tokenize(x) for x in example.answer_texts] 94 | detector.set_question(tokenized_answers) 95 | if len(detector.any_found(doc_tokens)) > 0: 96 | ir_count += 1 97 | 98 | total_doc_length += sum(len(para) for para in merged_paragraphs) 99 | pruned_doc_length += len(doc_tokens) 100 | 101 | out.append(DocumentAndQuestion(example_ix, example.qas_id, example.question_text, doc_tokens, 102 | '', 0, 0, True)) 103 | if args.debug and example_ix > 5: 104 | break 105 | print("Recall of answer existence in documents: {:.3f}".format(ir_count / len(out))) 106 | print("Average length of documents: {:.3f}".format(total_doc_length / len(out))) 107 | print("Average pruned length of documents: {:.3f}".format(pruned_doc_length / len(out))) 108 | output_file = join("data", "squad", "eval_open_{}paras_examples.pkl".format(args.n_to_select)) 109 | pickle.dump(out, open(output_file, 'wb')) 110 | 111 | if __name__ == "__main__": 112 | main() -------------------------------------------------------------------------------- /squad/squad_evaluate.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. 
[Changed name for external importing]""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def span_len(span): 12 | return span[1] - span[0] 13 | 14 | def span_overlap(s1, s2): 15 | start = max(s1[0], s2[0]) 16 | stop = min(s1[1], s2[1]) 17 | if stop > start: 18 | return start, stop 19 | return None 20 | 21 | def span_prec(true_span, pred_span): 22 | overlap = span_overlap(true_span, pred_span) 23 | if overlap is None: 24 | return 0. 25 | return span_len(overlap) / span_len(pred_span) 26 | 27 | def span_recall(true_span, pred_span): 28 | overlap = span_overlap(true_span, pred_span) 29 | if overlap is None: 30 | return 0. 31 | return span_len(overlap) / span_len(true_span) 32 | 33 | def span_f1(true_span, pred_span): 34 | p = span_prec(true_span, pred_span) 35 | r = span_recall(true_span, pred_span) 36 | if p == 0 or r == 0: 37 | return 0.0 38 | return 2. * p * r / (p + r) 39 | 40 | 41 | def normalize_answer(s): 42 | """Lower text and remove punctuation, articles and extra whitespace.""" 43 | def remove_articles(text): 44 | return re.sub(r'\b(a|an|the)\b', ' ', text) 45 | 46 | def white_space_fix(text): 47 | return ' '.join(text.split()) 48 | 49 | def remove_punc(text): 50 | exclude = set(string.punctuation) 51 | return ''.join(ch for ch in text if ch not in exclude) 52 | 53 | def lower(text): 54 | return text.lower() 55 | 56 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 57 | 58 | 59 | def f1_score(prediction, ground_truth): 60 | prediction_tokens = normalize_answer(prediction).split() 61 | ground_truth_tokens = normalize_answer(ground_truth).split() 62 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 63 | num_same = sum(common.values()) 64 | if num_same == 0: 65 | return 0 66 | precision = 1.0 * num_same / len(prediction_tokens) 67 | recall = 1.0 * num_same / len(ground_truth_tokens) 68 | f1 = (2 * precision * recall) / (precision + recall) 69 | return f1 70 | 71 | 72 | def exact_match_score(prediction, ground_truth): 73 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 74 | 75 | 76 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 77 | scores_for_ground_truths = [] 78 | for ground_truth in ground_truths: 79 | score = metric_fn(prediction, ground_truth) 80 | scores_for_ground_truths.append(score) 81 | return max(scores_for_ground_truths) 82 | 83 | 84 | def evaluate(dataset, predictions): 85 | f1 = exact_match = total = 0 86 | missing_count = 0 87 | for article in dataset: 88 | for paragraph in article['paragraphs']: 89 | for qa in paragraph['qas']: 90 | total += 1 91 | if qa['id'] not in predictions: 92 | missing_count += 1 93 | # message = 'Unanswered question ' + qa['id'] + \ 94 | # ' will receive score 0.' 
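                    # NOTE: unlike the official SQuAD script, unanswered ids
                    # are skipped here and excluded from the denominator
                    # (total - missing_count) below, rather than scored as 0.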
95 | # print(message, file=sys.stderr) 96 | continue 97 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 98 | prediction = predictions[qa['id']] 99 | exact_match += metric_max_over_ground_truths( 100 | exact_match_score, prediction, ground_truths) 101 | f1 += metric_max_over_ground_truths( 102 | f1_score, prediction, ground_truths) 103 | 104 | exact_match = 100.0 * exact_match / (total-missing_count) 105 | f1 = 100.0 * f1 / (total-missing_count) 106 | print("missing prediction on %d examples" % (missing_count)) 107 | return {'exact_match': exact_match, 'f1': f1} 108 | 109 | 110 | def merge_eval(main_eval, new_eval): 111 | for k in new_eval: 112 | main_eval['%s' % (k)] = new_eval[k] 113 | 114 | 115 | if __name__ == '__main__': 116 | expected_version = '1.1' 117 | parser = argparse.ArgumentParser( 118 | description='Evaluation for SQuAD ' + expected_version) 119 | parser.add_argument('dataset_file', help='Dataset file') 120 | parser.add_argument('prediction_file', help='Prediction File') 121 | args = parser.parse_args() 122 | with open(args.dataset_file) as dataset_file: 123 | dataset_json = json.load(dataset_file) 124 | # if (dataset_json['version'] != expected_version): 125 | # print('Evaluation expects v-' + expected_version + 126 | # ', but got dataset with v-' + dataset_json['version'], 127 | # file=sys.stderr) 128 | dataset = dataset_json['data'] 129 | with open(args.prediction_file) as prediction_file: 130 | predictions = json.load(prediction_file) 131 | print(json.dumps(evaluate(dataset, predictions))) 132 | 133 | # prediction = '1854–1855' 134 | # ground_truths = ['1854'] 135 | # print(metric_max_over_ground_truths( 136 | # f1_score, prediction, ground_truths)) 137 | -------------------------------------------------------------------------------- /squad/squad_open_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import math 4 | import collections 5 | import numpy as np 6 | from sklearn.feature_extraction.text import TfidfVectorizer 7 | from sklearn.metrics import pairwise_distances 8 | from typing import List, TypeVar, Iterable 9 | 10 | import bert.tokenization as tokenization 11 | 12 | T = TypeVar('T') 13 | 14 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're', 15 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her', 16 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do', 17 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over', 18 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves', 19 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself', 20 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these', 21 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why', 22 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into', 23 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−', 24 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where', 25 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off', 26 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', 
'|', 'against', 27 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me', 28 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after', 29 | 'be'} 30 | 31 | 32 | def flatten_iterable(listoflists: Iterable[Iterable[T]]) -> List[T]: 33 | return [item for sublist in listoflists for item in sublist] 34 | 35 | 36 | class Question(object): 37 | def __init__(self, 38 | qas_id, 39 | doc_index, 40 | para_index, 41 | question_text, 42 | answer_texts=None): 43 | self.qas_id = qas_id 44 | self.doc_index = doc_index 45 | self.para_index = para_index 46 | self.question_text = question_text 47 | self.answer_texts = answer_texts 48 | 49 | def __str__(self): 50 | return self.__repr__() 51 | 52 | def __repr__(self): 53 | s = "" 54 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 55 | s += "doc_index: %d" % (self.doc_index) 56 | s += "para_index: %d" % (self.para_index) 57 | s += ", question_text: %s" % ( 58 | tokenization.printable_text(self.question_text)) 59 | if self.answer_texts is not None: 60 | s += ", answer_texts: ".format(self.answer_texts) 61 | return s 62 | 63 | 64 | class Paragraph(object): 65 | def __init__(self, 66 | paragraph_id, 67 | paragraph_text): 68 | self.paragraph_id = paragraph_id 69 | self.paragraph_text = paragraph_text 70 | 71 | def __str__(self): 72 | return self.__repr__() 73 | 74 | def __repr__(self): 75 | s = "" 76 | s += "paragraph_id: %s" % (self.paragraph_id) 77 | return s 78 | 79 | 80 | class Document(object): 81 | def __init__(self, document_id: str, paragraphs: List[Paragraph]): 82 | self.document_id = document_id 83 | self.paragraphs = paragraphs 84 | 85 | def __str__(self): 86 | return self.__repr__() 87 | 88 | def __repr__(self): 89 | s = "" 90 | s += "document_id: %s" % (self.document_id) 91 | s += ", paragraph_num: %s" % (len(self.paragraphs)) 92 | return s 93 | 94 | def get_doc_text(self): 95 | all_doc_text = '' 96 | for idx, para in enumerate(self.paragraphs): 97 | if idx == 0: 98 | all_doc_text += para.paragraph_text 99 | else: 100 | all_doc_text += ' ' 101 | all_doc_text += para.paragraph_text 102 | return all_doc_text 103 | 104 | def tfidf_rank(questions: List[str], documents: List[str]): 105 | tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words) 106 | doc_features = tfidf.fit_transform(documents) 107 | q_features = tfidf.transform(questions) 108 | scores = pairwise_distances(q_features, doc_features, "cosine") 109 | return scores 110 | 111 | def read_squad_open_examples(input_file, n_to_select, is_training, debug=False): 112 | """Read a SQuAD json file into a list of SquadExample.""" 113 | with open(input_file, "r") as reader: 114 | input_data = json.load(reader)["data"] 115 | 116 | documents = [] 117 | questions = [] 118 | for article_ix, article in enumerate(input_data): 119 | document_id = "%s-%d" % (article['title'], article_ix) 120 | paragraphs = [] 121 | for paragraph_ix, paragraph in enumerate(article["paragraphs"]): 122 | paragraph_text = paragraph["context"] 123 | paragraphs.append(Paragraph(paragraph_ix, paragraph_text)) 124 | 125 | for qa in paragraph["qas"]: 126 | qas_id = qa["id"] 127 | question_text = qa["question"] 128 | answer_texts = [] 129 | for answer in qa["answers"]: 130 | answer_texts.append(answer["text"]) 131 | questions.append(Question(qas_id, article_ix, paragraph_ix, question_text, answer_texts)) 132 | 133 | documents.append(Document(document_id, paragraphs)) 134 | if (article_ix+1) == 10 and debug: 135 | break 136 | 137 
| scores = tfidf_rank([x.question_text for x in questions], [x.get_doc_text() for x in documents]) # [1177, 3] 138 | 139 | ir_count = 0 140 | for que_ix, question in enumerate(questions): 141 | doc_scores = scores[que_ix] 142 | doc_ranks = np.argsort(doc_scores) 143 | selection = [i for i in doc_ranks[:n_to_select]] 144 | rank = [i for i in np.arange(n_to_select)] 145 | 146 | if question.doc_index in selection: 147 | ir_count += 1 148 | 149 | if is_training and question.doc_index not in selection: 150 | selection[-1] = question.doc_index 151 | 152 | print("Retrieve {} questions from {} documents".format(len(questions), len(documents))) 153 | print("Recall of answer existence in documents: {:.3f}".format(ir_count / len(questions))) 154 | 155 | read_squad_open_examples("../data/squad/dev-v1.1.json", 5, False, False) -------------------------------------------------------------------------------- /triviaqa/ablate_triviaqa_unfiltered.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | from triviaqa.build_span_corpus import TriviaQaUnfilteredDataset, TriviaQaSampleUnfilteredDataset 6 | from triviaqa.preprocessed_corpus import preprocess_par, ExtractMultiParagraphsPerQuestion, TopTfIdf 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description='Train a model on TriviaQA open') 11 | parser.add_argument("--debug", default=False, action='store_true', help="Whether to run in debug mode.") 12 | parser.add_argument("--data_dir", default="data/triviaqa/unfiltered", type=str, help="Triviaqa wiki data dir") 13 | parser.add_argument('--n_processes', type=int, default=1, 14 | help="Number of processes (i.e., select which paragraphs to train on) " 15 | "the data with") 16 | parser.add_argument('--chunk_size', type=int, default=1000, 17 | help="Size of one chunk") 18 | parser.add_argument('--n_para_train', type=int, default=2, 19 | help="Num of selected paragraphs during training") 20 | parser.add_argument('--n_para_dev', type=int, default=4, 21 | help="Num of selected paragraphs during evaluation") 22 | parser.add_argument('--n_para_test', type=int, default=4, 23 | help="Num of selected paragraphs during testing") 24 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to process train set.") 25 | parser.add_argument("--do_dev", default=False, action='store_true', help="Whether to process dev set.") 26 | parser.add_argument("--do_test", default=False, action='store_true', help="Whether to process test set.") 27 | args = parser.parse_args() 28 | 29 | if args.debug: 30 | corpus = TriviaQaSampleUnfilteredDataset() 31 | else: 32 | corpus = TriviaQaUnfilteredDataset() 33 | 34 | if args.do_train: 35 | train_questions = corpus.get_train() # List[TriviaQaQuestion] 36 | train_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_train, is_training=True), 37 | intern=True, is_training=True) 38 | _train = preprocess_par(train_questions, corpus.evidence, train_preprocesser, args.n_processes, args.chunk_size, 39 | "train") 40 | print("Recall of answer existence in {} set: {:.3f}".format("train", _train.ir_count / len(_train.data))) 41 | print("Average number of documents in {} set: {:.3f}".format("train", _train.total_doc_num / len(_train.data))) 42 | print("Average length of documents in {} set: {:.3f}".format("train", _train.total_doc_length / len(_train.data))) 43 | print("Average pruned length of documents in {} set: {:.3f}".format("train", 
_train.pruned_doc_length / len(_train.data))) 44 | print("Number of examples: {}".format(len(_train.data))) 45 | 46 | train_examples_path = os.path.join(args.data_dir, "train_{}paras_examples.pkl".format(args.n_para_train)) 47 | pickle.dump(_train.data, open(train_examples_path, 'wb')) 48 | 49 | if args.do_dev: 50 | dev_questions = corpus.get_dev() 51 | dev_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_dev, is_training=False), 52 | intern=True, is_training=False) 53 | _dev = preprocess_par(dev_questions, corpus.evidence, dev_preprocesser, args.n_processes, args.chunk_size, "dev") 54 | print("Recall of answer existence in {} set: {:.3f}".format("dev", _dev.ir_count / len(_dev.data))) 55 | print("Average number of documents in {} set: {:.3f}".format("dev", _dev.total_doc_num / len(_dev.data))) 56 | print("Average length of documents in {} set: {:.3f}".format("dev", _dev.total_doc_length / len(_dev.data))) 57 | print("Average pruned length of documents in {} set: {:.3f}".format("dev", _dev.pruned_doc_length / len(_dev.data))) 58 | print("Number of examples: {}".format(len(_dev.data))) 59 | 60 | dev_examples_path = os.path.join(args.data_dir, "dev_{}paras_examples.pkl".format(args.n_para_dev)) 61 | pickle.dump(_dev.data, open(dev_examples_path, 'wb')) 62 | 63 | if args.do_test: 64 | test_questions = corpus.get_test() 65 | test_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_test, is_training=False), 66 | intern=True, is_training=False) 67 | _test = preprocess_par(test_questions, corpus.evidence, test_preprocesser, args.n_processes, 68 | args.chunk_size, "test") 69 | print("Recall of answer existence in {} set: {:.3f}".format("test", _test.ir_count / len(_test.data))) 70 | print("Average number of documents in {} set: {:.3f}".format("test", _test.total_doc_num / len(_test.data))) 71 | print("Average length of documents in {} set: {:.3f}".format("test", _test.total_doc_length / len(_test.data))) 72 | print("Average pruned length of documents in {} set: {:.3f}".format("test", _test.pruned_doc_length / len(_test.data))) 73 | print("Number of examples: {}".format(len(_test.data))) 74 | 75 | test_examples_path = os.path.join(args.data_dir, "test_{}paras_examples.pkl".format(args.n_para_test)) 76 | pickle.dump(_test.data, open(test_examples_path, 'wb')) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() -------------------------------------------------------------------------------- /triviaqa/ablate_triviaqa_wiki.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | from triviaqa.build_span_corpus import TriviaQaWikiDataset, TriviaQaSampleWikiDataset 6 | from triviaqa.preprocessed_corpus import preprocess_par, ExtractMultiParagraphsPerQuestion, TopTfIdf 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description='Train a model on TriviaQA web') 11 | parser.add_argument("--debug", default=False, action='store_true', help="Whether to run in debug mode.") 12 | parser.add_argument("--data_dir", default="data/triviaqa/wiki", type=str, help="Triviaqa wiki data dir") 13 | parser.add_argument('--n_processes', type=int, default=1, 14 | help="Number of processes (i.e., select which paragraphs to train on) " 15 | "the data with") 16 | parser.add_argument('--chunk_size', type=int, default=1000, 17 | help="Size of one chunk") 18 | parser.add_argument('--n_para_train', type=int, default=2, 19 | help="Num of selected paragraphs during 
training") 20 | parser.add_argument('--n_para_dev', type=int, default=4, 21 | help="Num of selected paragraphs during evaluation") 22 | parser.add_argument('--n_para_verified', type=int, default=4, 23 | help="Num of selected paragraphs during evaluation") 24 | parser.add_argument('--n_para_test', type=int, default=4, 25 | help="Num of selected paragraphs during testing") 26 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to process train set.") 27 | parser.add_argument("--do_dev", default=False, action='store_true', help="Whether to process dev set.") 28 | parser.add_argument("--do_verified", default=False, action='store_true', help="Whether to process verified set.") 29 | parser.add_argument("--do_test", default=False, action='store_true', help="Whether to process test set.") 30 | args = parser.parse_args() 31 | 32 | if args.debug: 33 | corpus = TriviaQaSampleWikiDataset() 34 | else: 35 | corpus = TriviaQaWikiDataset() 36 | 37 | if args.do_train: 38 | train_questions = corpus.get_train() # List[TriviaQaQuestion] 39 | train_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_train, is_training=True), 40 | intern=True, is_training=True) 41 | _train = preprocess_par(train_questions, corpus.evidence, train_preprocesser, args.n_processes, args.chunk_size, 42 | "train") 43 | print("Recall of answer existence in {} set: {:.3f}".format("train", _train.ir_count / len(_train.data))) 44 | print("Average number of documents in {} set: {:.3f}".format("train", _train.total_doc_num / len(_train.data))) 45 | print("Average length of documents in {} set: {:.3f}".format("train", _train.total_doc_length / len(_train.data))) 46 | print("Average pruned length of documents in {} set: {:.3f}".format("train", _train.pruned_doc_length / len(_train.data))) 47 | print("Number of examples: {}".format(len(_train.data))) 48 | 49 | train_examples_path = os.path.join(args.data_dir, "train_{}paras_examples.pkl".format(args.n_para_train)) 50 | pickle.dump(_train.data, open(train_examples_path, 'wb')) 51 | 52 | if args.do_dev: 53 | dev_questions = corpus.get_dev() 54 | dev_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_dev, is_training=False), 55 | intern=True, is_training=False) 56 | _dev = preprocess_par(dev_questions, corpus.evidence, dev_preprocesser, args.n_processes, args.chunk_size, "dev") 57 | print("Recall of answer existence in {} set: {:.3f}".format("dev", _dev.ir_count / len(_dev.data))) 58 | print("Average number of documents in {} set: {:.3f}".format("dev", _dev.total_doc_num / len(_dev.data))) 59 | print("Average length of documents in {} set: {:.3f}".format("dev", _dev.total_doc_length / len(_dev.data))) 60 | print("Average pruned length of documents in {} set: {:.3f}".format("dev", _dev.pruned_doc_length / len(_dev.data))) 61 | print("Number of examples: {}".format(len(_dev.data))) 62 | 63 | dev_examples_path = os.path.join(args.data_dir, "dev_{}paras_examples.pkl".format(args.n_para_dev)) 64 | pickle.dump(_dev.data, open(dev_examples_path, 'wb')) 65 | 66 | if args.do_verified: 67 | verified_questions = corpus.get_verified() 68 | verified_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_verified, is_training=False), 69 | intern=True, is_training=False) 70 | _verified = preprocess_par(verified_questions, corpus.evidence, verified_preprocesser, args.n_processes, 71 | args.chunk_size, "verified") 72 | print("Recall of answer existence in {} set: {:.3f}".format("verified", 
_verified.ir_count / len(_verified.data))) 73 | print("Average number of documents in {} set: {:.3f}".format("verified", _verified.total_doc_num / len(_verified.data))) 74 | print("Average length of documents in {} set: {:.3f}".format("verified", _verified.total_doc_length / len(_verified.data))) 75 | print("Average pruned length of documents in {} set: {:.3f}".format("verified", _verified.pruned_doc_length / len(_verified.data))) 76 | print("Number of examples: {}".format(len(_verified.data))) 77 | 78 | verified_examples_path = os.path.join(args.data_dir, "verified_{}paras_examples.pkl".format(args.n_para_verified)) 79 | pickle.dump(_verified.data, open(verified_examples_path, 'wb')) 80 | 81 | if args.do_test: 82 | test_questions = corpus.get_test() 83 | test_preprocesser = ExtractMultiParagraphsPerQuestion(TopTfIdf(n_to_select=args.n_para_test, is_training=False), 84 | intern=True, is_training=False) 85 | _test = preprocess_par(test_questions, corpus.evidence, test_preprocesser, args.n_processes, 86 | args.chunk_size, "test") 87 | print("Recall of answer existence in {} set: {:.3f}".format("test", _test.ir_count / len(_test.data))) 88 | print("Average number of documents in {} set: {:.3f}".format("test", _test.total_doc_num / len(_test.data))) 89 | print("Average length of documents in {} set: {:.3f}".format("test", _test.total_doc_length / len(_test.data))) 90 | print("Average pruned length of documents in {} set: {:.3f}".format("test", _test.pruned_doc_length / len(_test.data))) 91 | print("Number of examples: {}".format(len(_test.data))) 92 | 93 | test_examples_path = os.path.join(args.data_dir, "test_{}paras_examples.pkl".format(args.n_para_test)) 94 | pickle.dump(_test.data, open(test_examples_path, 'wb')) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() -------------------------------------------------------------------------------- /triviaqa/answer_detection.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | from typing import List 7 | 8 | from triviaqa.read_data import TriviaQaQuestion 9 | from triviaqa.triviaqa_eval import normalize_answer, f1_score 10 | from triviaqa.utils import flatten_iterable, split 11 | 12 | 13 | class FastNormalizedAnswerDetector(object): 14 | """ almost twice as fast and very,very close to NormalizedAnswerDetector's output """ 15 | 16 | def __init__(self): 17 | # These come from the TrivaQA official evaluation script 18 | self.skip = {"a", "an", "the", ""} 19 | self.strip = string.punctuation + "".join([u"‘", u"’", u"´", u"`", "_"]) 20 | 21 | self.answer_tokens = None 22 | 23 | def set_question(self, normalized_aliases): 24 | self.answer_tokens = normalized_aliases 25 | 26 | def any_found(self, para): # List[str] 27 | # Normalize the paragraph 28 | words = [w.lower().strip(self.strip) for w in para] 29 | occurances = [] 30 | for answer_ix, answer in enumerate(self.answer_tokens): 31 | # Locations where the first word occurs 32 | if len(answer) == 0: 33 | continue 34 | word_starts = [i for i, w in enumerate(words) if answer[0] == w] # [12, 50, 63 ...] 
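            # Worked example: answer = ["alexander", "great"] with a paragraph
            # containing "... Alexander the Great ...". "alexander" seeds a
            # start, the scan below skips "the" (an article in self.skip),
            # matches "great", and records the half-open span (start, start+3).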
35 | n_tokens = len(answer) # 2 36 | 37 | # Advance forward until we find all the words, skipping over articles 38 | for start in word_starts: 39 | end = start + 1 40 | ans_token = 1 41 | while ans_token < n_tokens and end < len(words): 42 | next = words[end] 43 | if answer[ans_token] == next: 44 | ans_token += 1 45 | end += 1 46 | elif next in self.skip: 47 | end += 1 48 | else: 49 | break 50 | if n_tokens == ans_token: 51 | occurances.append((start, end)) 52 | return list(set(occurances)) 53 | 54 | 55 | def compute_answer_spans(questions: List[TriviaQaQuestion], corpus, tokenizer, 56 | detector): 57 | 58 | for i, q in enumerate(questions): 59 | if i % 500 == 0: 60 | print("Completed question %d of %d (%.3f)" % (i, len(questions), i/len(questions))) 61 | q.question = tokenizer.tokenize(q.question) 62 | if q.answer is None: 63 | continue 64 | tokenized_aliases = [tokenizer.tokenize(x) for x in q.answer.all_answers] 65 | if len(tokenized_aliases) == 0: 66 | raise ValueError() 67 | detector.set_question(tokenized_aliases) 68 | for doc in q.all_docs: 69 | text = corpus.get_document(doc.doc_id) # List[List[str]] 70 | if text is None: 71 | raise ValueError() 72 | spans = [] 73 | offset = 0 74 | for para_ix, para in enumerate(text): 75 | for s, e in detector.any_found(para): 76 | spans.append((s+offset, e+offset-1)) # turn into inclusive span 77 | offset += len(para) 78 | if len(spans) == 0: 79 | spans = np.zeros((0, 2), dtype=np.int32) 80 | else: 81 | spans = np.array(spans, dtype=np.int32) 82 | doc.answer_spans = spans 83 | 84 | 85 | def _compute_answer_spans_chunk(questions, corpus, tokenizer, detector): 86 | compute_answer_spans(questions, corpus, tokenizer, detector) 87 | return questions 88 | 89 | 90 | def compute_answer_spans_par(questions: List[TriviaQaQuestion], corpus, 91 | tokenizer, detector, n_processes: int): 92 | if n_processes == 1: 93 | compute_answer_spans(questions, corpus, tokenizer, detector) 94 | return questions 95 | from multiprocessing import Pool 96 | with Pool(n_processes) as p: 97 | chunks = split(questions, n_processes) 98 | questions = flatten_iterable(p.starmap(_compute_answer_spans_chunk, 99 | [[c, corpus, tokenizer, detector] for c in chunks])) 100 | return questions -------------------------------------------------------------------------------- /triviaqa/build_span_corpus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pickle 4 | import unicodedata 5 | from itertools import islice 6 | from typing import List, Optional, Dict 7 | from os import mkdir 8 | from os.path import join, exists, expanduser 9 | import bert.tokenization as tokenization 10 | from triviaqa.configurable import Configurable 11 | from triviaqa.read_data import iter_trivia_question, TriviaQaQuestion 12 | from triviaqa.evidence_corpus import TriviaQaEvidenceCorpusTxt 13 | from triviaqa.answer_detection import compute_answer_spans_par, FastNormalizedAnswerDetector 14 | 15 | TRIVIA_QA = join(expanduser("~"), "data", "triviaqa") 16 | TRIVIA_QA_UNFILTERED = join(expanduser("~"), "data", "triviaqa-unfiltered") 17 | 18 | 19 | def build_dataset(name: str, tokenizer, train_files: Dict[str, str], 20 | answer_detector, n_process: int, prune_unmapped_docs=True, 21 | sample=None): 22 | out_dir = join("data", "triviaqa", name) 23 | if not exists(out_dir): 24 | mkdir(out_dir) 25 | 26 | file_map = {} # maps document_id -> filename 27 | 28 | for name, filename in train_files.items(): 29 | print("Loading %s questions" % name) 30 | 
if sample is None: 31 | questions = list(iter_trivia_question(filename, file_map, False)) 32 | else: 33 | if isinstance(sample, int): 34 | questions = list(islice(iter_trivia_question(filename, file_map, False), sample)) 35 | elif isinstance(sample, dict): 36 | questions = list(islice(iter_trivia_question(filename, file_map, False), sample[name])) 37 | else: 38 | raise ValueError() 39 | 40 | if prune_unmapped_docs: 41 | for q in questions: 42 | if q.web_docs is not None: 43 | q.web_docs = [x for x in q.web_docs if x.doc_id in file_map] 44 | q.entity_docs = [x for x in q.entity_docs if x.doc_id in file_map] 45 | 46 | print("Adding answers for %s question" % name) 47 | corpus = TriviaQaEvidenceCorpusTxt(file_map) 48 | questions = compute_answer_spans_par(questions, corpus, tokenizer, answer_detector, n_process) 49 | for q in questions: # Sanity check, we should have answers for everything (even if of size 0) 50 | if q.answer is None: 51 | continue 52 | for doc in q.all_docs: 53 | if doc.doc_id in file_map: 54 | if doc.answer_spans is None: 55 | raise RuntimeError() 56 | 57 | print("Saving %s question" % name) 58 | with open(join(out_dir, name + ".pkl"), "wb") as f: 59 | pickle.dump(questions, f) 60 | 61 | print("Dumping file mapping") 62 | with open(join(out_dir, "file_map.json"), "w") as f: 63 | json.dump(file_map, f) 64 | 65 | print("Complete") 66 | 67 | class TriviaQaSpanCorpus(Configurable): 68 | def __init__(self, corpus_name): 69 | self.corpus_name = corpus_name # web-sample 70 | self.dir = join("data", "triviaqa", corpus_name) 71 | with open(join(self.dir, "file_map.json"), "r") as f: 72 | file_map = json.load(f) 73 | for k, v in file_map.items(): 74 | file_map[k] = unicodedata.normalize("NFD", v) 75 | self.evidence = TriviaQaEvidenceCorpusTxt(file_map) # evidence_corpus.py 76 | 77 | def get_train(self) -> List[TriviaQaQuestion]: 78 | with open(join(self.dir, "train.pkl"), "rb") as f: 79 | return pickle.load(f) 80 | 81 | def get_dev(self) -> List[TriviaQaQuestion]: 82 | with open(join(self.dir, "dev.pkl"), "rb") as f: 83 | return pickle.load(f) 84 | 85 | def get_test(self) -> List[TriviaQaQuestion]: 86 | with open(join(self.dir, "test.pkl"), "rb") as f: 87 | return pickle.load(f) 88 | 89 | def get_verified(self) -> Optional[List[TriviaQaQuestion]]: 90 | verified_dir = join(self.dir, "verified.pkl") 91 | if not exists(verified_dir): 92 | return None 93 | with open(verified_dir, "rb") as f: 94 | return pickle.load(f) 95 | 96 | @property 97 | def name(self): 98 | return self.corpus_name 99 | 100 | class TriviaQaWebDataset(TriviaQaSpanCorpus): 101 | def __init__(self): 102 | super().__init__("web") 103 | 104 | class TriviaQaWikiDataset(TriviaQaSpanCorpus): 105 | def __init__(self): 106 | super().__init__("wiki") 107 | 108 | class TriviaQaUnfilteredDataset(TriviaQaSpanCorpus): 109 | def __init__(self): 110 | super().__init__("unfiltered") 111 | 112 | class TriviaQaSampleWebDataset(TriviaQaSpanCorpus): 113 | def __init__(self): 114 | super().__init__("web-sample") 115 | 116 | class TriviaQaSampleWikiDataset(TriviaQaSpanCorpus): 117 | def __init__(self): 118 | super().__init__("wiki-sample") 119 | 120 | class TriviaQaSampleUnfilteredDataset(TriviaQaSpanCorpus): 121 | def __init__(self): 122 | super().__init__("unfiltered-sample") 123 | 124 | def build_wiki_corpus(n_processes): 125 | build_dataset("wiki", tokenization.BasicTokenizer(do_lower_case=True), 126 | dict( 127 | verified=join(TRIVIA_QA, "qa", "verified-wikipedia-dev.json"), 128 | dev=join(TRIVIA_QA, "qa", "wikipedia-dev.json"), 129 | 
train=join(TRIVIA_QA, "qa", "wikipedia-train.json"), 130 | test=join(TRIVIA_QA, "qa", "wikipedia-test-without-answers.json") 131 | ), 132 | FastNormalizedAnswerDetector(), n_processes) 133 | 134 | def build_web_corpus(n_processes): 135 | build_dataset("web", tokenization.BasicTokenizer(do_lower_case=True), 136 | dict( 137 | verified=join(TRIVIA_QA, "qa", "verified-web-dev.json"), 138 | dev=join(TRIVIA_QA, "qa", "web-dev.json"), 139 | train=join(TRIVIA_QA, "qa", "web-train.json"), 140 | test=join(TRIVIA_QA, "qa", "web-test-without-answers.json") 141 | ), 142 | FastNormalizedAnswerDetector(), n_processes) 143 | 144 | def build_unfiltered_corpus(n_processes): 145 | build_dataset("unfiltered", tokenization.BasicTokenizer(do_lower_case=True), 146 | dict( 147 | dev=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-dev.json"), 148 | train=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-train.json"), 149 | test=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-test-without-answers.json") 150 | ), 151 | FastNormalizedAnswerDetector(), n_processes) 152 | 153 | def build_wiki_sample_corpus(n_processes): 154 | build_dataset("wiki-sample", tokenization.BasicTokenizer(do_lower_case=True), 155 | dict( 156 | verified=join(TRIVIA_QA, "qa", "verified-wikipedia-dev.json"), 157 | dev=join(TRIVIA_QA, "qa", "wikipedia-dev.json"), 158 | train=join(TRIVIA_QA, "qa", "wikipedia-train.json"), 159 | test=join(TRIVIA_QA, "qa", "wikipedia-test-without-answers.json") 160 | ), 161 | FastNormalizedAnswerDetector(), n_processes, sample=20) 162 | 163 | def build_web_sample_corpus(n_processes): 164 | build_dataset("web-sample", tokenization.BasicTokenizer(do_lower_case=True), 165 | dict( 166 | verified=join(TRIVIA_QA, "qa", "verified-web-dev.json"), 167 | dev=join(TRIVIA_QA, "qa", "web-dev.json"), 168 | train=join(TRIVIA_QA, "qa", "web-train.json"), 169 | test=join(TRIVIA_QA, "qa", "web-test-without-answers.json") 170 | ), 171 | FastNormalizedAnswerDetector(), n_processes, sample=20) 172 | 173 | def build_unfiltered_sample_corpus(n_processes): 174 | build_dataset("unfiltered-sample", tokenization.BasicTokenizer(do_lower_case=True), 175 | dict( 176 | dev=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-dev.json"), 177 | train=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-train.json"), 178 | test=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-test-without-answers.json") 179 | ), 180 | FastNormalizedAnswerDetector(), n_processes, sample=20) 181 | 182 | def main(): 183 | parser = argparse.ArgumentParser("Pre-procsess TriviaQA data") 184 | parser.add_argument("corpus", choices=["web", "wiki", "unfiltered", "web-sample", "wiki-sample", "unfiltered-sample"]) 185 | parser.add_argument("-n", "--n_processes", type=int, default=1, help="Number of processes to use") 186 | args = parser.parse_args() 187 | if args.corpus == "web": 188 | build_web_corpus(args.n_processes) 189 | elif args.corpus == "wiki": 190 | build_wiki_corpus(args.n_processes) 191 | elif args.corpus == "unfiltered": 192 | build_unfiltered_corpus(args.n_processes) 193 | elif args.corpus == "web-sample": 194 | build_web_sample_corpus(args.n_processes) 195 | elif args.corpus == "wiki-sample": 196 | build_wiki_sample_corpus(args.n_processes) 197 | elif args.corpus == "unfiltered-sample": 198 | build_unfiltered_sample_corpus(args.n_processes) 199 | else: 200 | raise RuntimeError() 201 | 202 | 203 | if __name__ == "__main__": 204 | main() -------------------------------------------------------------------------------- /triviaqa/configurable.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | from inspect import signature 4 | from warnings import warn 5 | 6 | import numpy as np 7 | from sklearn.base import BaseEstimator 8 | 9 | 10 | class Configuration(object): 11 | def __init__(self, name, version, params): 12 | if not isinstance(name, str): 13 | raise ValueError() 14 | if not isinstance(params, dict): 15 | raise ValueError() 16 | self.name = name 17 | self.version = version 18 | self.params = params 19 | 20 | def __str__(self): 21 | if len(self.params) == 0: 22 | return "%s-v%s" % (self.name, self.version) 23 | json_params = config_to_json(self.params) 24 | if len(json_params) < 200: 25 | return "%s-v%s: %s" % (self.name, self.version, json_params) 26 | else: 27 | return "%s-v%s {...}" % (self.name, self.version) 28 | 29 | def __eq__(self, other): 30 | return isinstance(other, Configuration) and \ 31 | self.name == other.name and \ 32 | self.version == other.version and \ 33 | self.params == other.params 34 | 35 | 36 | class Configurable(object): 37 | """ 38 | Configurable classes have names, versions, and a set of parameters that are either "simple" aka JSON serializable 39 | types or other Configurable objects. Configurable objects should also be serializable via pickle. 40 | Configurable classes are defined mainly to give us a human-readable way of reading of the `parameters` 41 | set for different objects and to attach version numbers to them. 42 | 43 | By default we follow the format sklearn uses for its `BaseEstimator` class, where parameters are automatically 44 | derived based on the constructor parameters. 45 | """ 46 | 47 | @classmethod 48 | def _get_param_names(cls): 49 | # fetch the constructor or the original constructor before 50 | init = cls.__init__ 51 | if init is object.__init__: 52 | # No explicit constructor to introspect 53 | return [] 54 | 55 | init_signature = signature(init) 56 | parameters = [p for p in init_signature.parameters.values() 57 | if p.name != 'self'] 58 | if any(p.kind == p.VAR_POSITIONAL for p in parameters): 59 | raise RuntimeError() 60 | return sorted([p.name for p in parameters]) 61 | 62 | @property 63 | def name(self): 64 | return self.__class__.__name__ 65 | 66 | @property 67 | def version(self): 68 | return 0 69 | 70 | def get_params(self): 71 | out = {} 72 | for key in self._get_param_names(): 73 | v = getattr(self, key, None) 74 | if isinstance(v, Configurable): 75 | out[key] = v.get_config() 76 | elif hasattr(v, "get_config"): # for keras objects 77 | out[key] = {"name": v.__class__.__name__, "config": v.get_config()} 78 | else: 79 | out[key] = v 80 | return out 81 | 82 | def get_config(self) -> Configuration: 83 | params = {k: describe(v) for k,v in self.get_params().items()} 84 | return Configuration(self.name, self.version, params) 85 | 86 | def __getstate__(self): 87 | state = dict(self.__dict__) 88 | if "version" in state: 89 | if state["version"] != self.version: 90 | raise RuntimeError() 91 | else: 92 | state["version"] = self.version 93 | return state 94 | 95 | def __setstate__(self, state): 96 | if "version" not in state: 97 | raise RuntimeError("Version should be in state (%s)" % self.__class__.__name__) 98 | if state["version"] != self.version: 99 | warn(("%s loaded with version %s, but class " + 100 | "version is %s") % (self.__class__.__name__, state["version"], self.version)) 101 | 102 | if "state" in state: 103 | self.__dict__ = state["state"] 104 | else: 105 | del 
state["version"] 106 | self.__dict__ = state 107 | 108 | 109 | def describe(obj): 110 | if isinstance(obj, Configurable): 111 | return obj.get_config() 112 | else: 113 | obj_type = type(obj) 114 | 115 | if obj_type in (list, set, frozenset, tuple): 116 | return obj_type([describe(e) for e in obj]) 117 | elif isinstance(obj, tuple): 118 | # Name tuple, convert to tuple 119 | return tuple(describe(e) for e in obj) 120 | elif obj_type in (dict, OrderedDict): 121 | output = OrderedDict() 122 | for k, v in obj.items(): 123 | if isinstance(k, Configurable): 124 | raise ValueError() 125 | output[k] = describe(v) 126 | return output 127 | else: 128 | return obj 129 | 130 | 131 | class EncodeDescription(json.JSONEncoder): 132 | """ Json encoder that encodes 'Configurable' objects as dictionaries and handles 133 | some numpy types. Note decoding this output will not reproduce the original input, 134 | for these types, this is only intended to be used to produce human readable output. 135 | '""" 136 | def default(self, obj): 137 | if isinstance(obj, np.integer): 138 | return int(obj) 139 | elif isinstance(obj, np.dtype): 140 | return str(obj) 141 | elif isinstance(obj, np.floating): 142 | return float(obj) 143 | elif isinstance(obj, np.bool_): 144 | return bool(obj) 145 | elif isinstance(obj, np.ndarray): 146 | return obj.tolist() 147 | elif isinstance(obj, BaseEstimator): # handle sklearn estimators 148 | return Configuration(obj.__class__.__name__, 0, obj.get_params()) 149 | elif isinstance(obj, Configuration): 150 | if "version" in obj.params or "name" in obj.params: 151 | raise ValueError() 152 | out = OrderedDict() 153 | out["name"] = obj.name 154 | if obj.version != 0: 155 | out["version"] = obj.version 156 | out.update(obj.params) 157 | return out 158 | elif isinstance(obj, Configurable): 159 | return obj.get_config() 160 | elif isinstance(obj, set): 161 | return sorted(obj) # Ensure deterministic order 162 | else: 163 | try: 164 | return super().default(obj) 165 | except TypeError: 166 | return str(obj) 167 | 168 | 169 | def config_to_json(data, indent=None): 170 | return json.dumps(data, sort_keys=False, cls=EncodeDescription, indent=indent) 171 | -------------------------------------------------------------------------------- /triviaqa/evidence_corpus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from os import walk, mkdir, makedirs 4 | from os.path import relpath, join, exists, expanduser 5 | from typing import Set 6 | from tqdm import tqdm 7 | from typing import List 8 | 9 | import bert.tokenization as tokenization 10 | from triviaqa.utils import split, flatten_iterable, group 11 | from triviaqa.read_data import normalize_wiki_filename 12 | 13 | TRIVIA_QA = join(expanduser("~"), "data", "triviaqa") 14 | 15 | class MergeParagraphs(object): 16 | def __init__(self, max_tokens: int): 17 | self.max_tokens = max_tokens 18 | 19 | def merge(self, paragraphs: List[List[str]]): 20 | all_paragraphs = [] 21 | 22 | on_paragraph = [] # text we have collect for the current paragraph 23 | cur_tokens = 0 # number of tokens in the current paragraph 24 | 25 | word_ix = 0 26 | for para in paragraphs: 27 | n_words = len(para) 28 | start_token = word_ix 29 | end_token = start_token + n_words 30 | word_ix = end_token 31 | 32 | if cur_tokens + n_words > self.max_tokens: 33 | if cur_tokens != 0: # end the current paragraph 34 | all_paragraphs.append(on_paragraph) 35 | on_paragraph = [] 36 | cur_tokens = 0 37 | 38 | if n_words >= 
self.max_tokens: # a paragraph that alone exceeds the limit is emitted whole; otherwise it starts a new merged paragraph 39 | all_paragraphs.append(para) 40 | else: 41 | on_paragraph += para 42 | cur_tokens = n_words 43 | else: 44 | on_paragraph += para 45 | cur_tokens += n_words 46 | 47 | if on_paragraph != []: 48 | all_paragraphs.append(on_paragraph) 49 | return all_paragraphs 50 | 51 | def _gather_files(input_root, output_dir, skip_dirs, wiki_only): 52 | if not exists(output_dir): 53 | mkdir(output_dir) 54 | 55 | all_files = [] 56 | for root, dirs, filenames in walk(input_root): 57 | if skip_dirs: # skip directories whose output already exists 58 | output = join(output_dir, relpath(root, input_root)) 59 | if exists(output): 60 | continue 61 | path = relpath(root, input_root) 62 | normalized_path = normalize_wiki_filename(path) 63 | if not exists(join(output_dir, normalized_path)): 64 | mkdir(join(output_dir, normalized_path)) 65 | all_files += [join(path, x) for x in filenames] 66 | if wiki_only: 67 | all_files = [x for x in all_files if "wikipedia/" in x] 68 | return all_files 69 | 70 | def build_tokenized_files(filenames, input_root, output_root, tokenizer, splitter, override=True) -> Set[str]: 71 | """ 72 | For each file in `filenames`, loads the text, tokenizes it with `tokenizer`, and 73 | saves the output to the same relative location in `output_root`. 74 | :return: a set of all the individual words seen 75 | """ 76 | voc = set() 77 | for filename in filenames: 78 | out_file = normalize_wiki_filename(filename[:filename.rfind(".")]) + ".txt" 79 | out_file = join(output_root, out_file) 80 | if not override and exists(out_file): 81 | continue 82 | with open(join(input_root, filename), "r") as in_file: 83 | text = in_file.read().strip() 84 | paras = [x for x in text.split("\n") if len(x) > 0] 85 | paragraphs = [tokenizer.tokenize(x) for x in paras] 86 | merged_paragraphs = splitter.merge(paragraphs) 87 | 88 | for para in merged_paragraphs: 89 | for word in para: 90 | voc.add(word) 91 | 92 | with open(out_file, "w") as f: 93 | f.write("\n\n".join(" ".join(para) for para in merged_paragraphs)) 94 | return voc 95 | 96 | def build_tokenized_corpus(input_root, tokenizer, splitter, output_dir, skip_dirs=False, 97 | n_processes=1, wiki_only=False): 98 | if not exists(output_dir): 99 | makedirs(output_dir) 100 | 101 | all_files = _gather_files(input_root, output_dir, skip_dirs, wiki_only) 102 | 103 | if n_processes == 1: 104 | voc = build_tokenized_files(tqdm(all_files, ncols=80), input_root, output_dir, tokenizer, splitter) 105 | else: 106 | voc = set() 107 | from multiprocessing import Pool 108 | with Pool(n_processes) as pool: 109 | chunks = split(all_files, n_processes) 110 | chunks = flatten_iterable(group(c, 500) for c in chunks) 111 | pbar = tqdm(total=len(chunks), ncols=80) 112 | for v in pool.imap_unordered(_build_tokenized_files_t, 113 | [[c, input_root, output_dir, tokenizer, splitter] for c in chunks]): 114 | voc.update(v) 115 | pbar.update(1) 116 | pbar.close() 117 | return voc 118 | def _build_tokenized_files_t(arg): 119 | return build_tokenized_files(*arg) 120 | 121 | class TriviaQaEvidenceCorpusTxt(object): 122 | """ 123 | Corpus of the tokenized text from the given TriviaQa evidence documents.
124 | Allows the text to be retrieved by document id 125 | """ 126 | 127 | _split_para = re.compile("\n\n+") 128 | 129 | def __init__(self, file_id_map=None): 130 | self.directory = join("data", "triviaqa/evidence") 131 | self.file_id_map = file_id_map 132 | 133 | def get_document(self, doc_id): 134 | if self.file_id_map is None: 135 | file_id = doc_id 136 | else: 137 | file_id = self.file_id_map.get(doc_id) 138 | 139 | if file_id is None: 140 | return None 141 | 142 | file_id = join(self.directory, file_id + ".txt") 143 | if not exists(file_id): 144 | return None 145 | 146 | with open(file_id, "r") as f: 147 | text = f.read() 148 | paragraphs = [] 149 | for para in self._split_para.split(text): 150 | paragraphs.append(para.split(" ")) 151 | return paragraphs # List[List[str]] 152 | 153 | def main(): 154 | parser = argparse.ArgumentParser("Pre-tokenize the TriviaQA evidence corpus") 155 | parser.add_argument("-o", "--output_dir", type=str, default=join("data", "triviaqa", "evidence")) 156 | parser.add_argument("-s", "--source", type=str, default=join(TRIVIA_QA, "evidence")) 157 | # This is slow, using more processes is recommended 158 | parser.add_argument("-n", "--n_processes", type=int, default=1, help="Number of processes to use") 159 | parser.add_argument("--max_tokens", type=int, default=200, help="Maximum number of tokens in each merged paragraph") 160 | parser.add_argument("--wiki_only", action="store_true") 161 | args = parser.parse_args() 162 | 163 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 164 | splitter = MergeParagraphs(args.max_tokens) 165 | build_tokenized_corpus(args.source, tokenizer, splitter, args.output_dir, 166 | n_processes=args.n_processes, wiki_only=args.wiki_only) 167 | 168 | if __name__ == "__main__": 169 | main() -------------------------------------------------------------------------------- /triviaqa/preprocessed_corpus.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import gzip 3 | import random 4 | import pickle 5 | from collections import Counter 6 | from threading import Lock 7 | from typing import List, Iterable, Optional 8 | 9 | import math 10 | import numpy as np 11 | from sklearn.feature_extraction.text import TfidfVectorizer 12 | from sklearn.metrics import pairwise_distances 13 | from tqdm import tqdm 14 | from triviaqa.utils import split, flatten_iterable, group 15 | from triviaqa.configurable import Configurable 16 | from triviaqa.read_data import TriviaQaQuestion 17 | from triviaqa.triviaqa_document_utils import ExtractedParagraphWithAnswers, DocParagraphWithAnswers, DocumentAndQuestion 18 | 19 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're', 20 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her', 21 | 'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do', 22 | 'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over', 23 | 'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves', 24 | 'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself', 25 | 'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these', 26 | 'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why', 27 | 'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more',
'into', 28 | '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−', 29 | 'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where', 30 | 'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off', 31 | '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against', 32 | '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me', 33 | 'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after', 34 | 'be'} 35 | 36 | class ParagraphsSet(object): 37 | def __init__(self, paragraphs: List[ExtractedParagraphWithAnswers], ir_hit: bool): 38 | self.paragraphs = paragraphs 39 | self.ir_hit = ir_hit 40 | 41 | class ParagraphFilter(Configurable): 42 | """ Selects and ranks paragraphs """ 43 | 44 | def prune(self, question, paragraphs: List[ExtractedParagraphWithAnswers]): 45 | raise NotImplementedError() 46 | 47 | class TopTfIdf(ParagraphFilter): 48 | def __init__(self, n_to_select: int, is_training: bool=False, sort_passage: bool=True): 49 | self.n_to_select = n_to_select 50 | self.is_training = is_training 51 | self.sort_passage = sort_passage 52 | 53 | def prune(self, question: List[str], paragraphs: List[ExtractedParagraphWithAnswers]): 54 | tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words) 55 | text = [] 56 | for para in paragraphs: 57 | text.append(" ".join(para.text)) 58 | try: 59 | para_features = tfidf.fit_transform(text) 60 | q_features = tfidf.transform([" ".join(question)]) 61 | except ValueError: 62 | return [] 63 | 64 | dists = pairwise_distances(q_features, para_features, "cosine").ravel() # [N] 65 | sorted_ix = np.lexsort(([x.start for x in paragraphs], dists)) # in case of ties, use the earlier paragraph [N] 66 | 67 | selection = [i for i in sorted_ix[:self.n_to_select]] 68 | selected_paras = [paragraphs[i] for i in selection] 69 | ir_hit = 0. if all(len(x.answer_spans) == 0 for x in selected_paras) else 1. 
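# ir_hit records whether retrieval succeeded for this question: 1.0 if at least one of the selected paragraphs contains an annotated answer span, 0.0 otherwise; summed over questions (FilteredData.ir_count below) it yields the IR hit rate reported during preprocessing.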
70 | 71 | if self.is_training and not ir_hit: 72 | gold_indexes = [i for i, x in enumerate(paragraphs) if len(x.answer_spans) != 0] 73 | gold_index = random.choice(gold_indexes) 74 | selection[-1] = gold_index 75 | 76 | if self.sort_passage: 77 | selection = np.sort(selection) 78 | 79 | return [paragraphs[i] for i in selection], ir_hit 80 | 81 | class ShallowOpenWebRanker(ParagraphFilter): 82 | # Hard-coded weights learned from a logistic regression classifier 83 | TFIDF_W = 5.13365065 84 | LOG_WORD_START_W = 0.46022765 85 | FIRST_W = -0.08611607 86 | LOWER_WORD_W = 0.0499123 87 | WORD_W = -0.15537181 88 | 89 | def __init__(self, n_to_select: int, is_training: bool=False, sort_passage: bool=True): 90 | self.n_to_select = n_to_select 91 | self.is_training = is_training 92 | self.sort_passage = sort_passage 93 | self._tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=stop_words) 94 | 95 | def score_paragraphs(self, question, paragraphs: List[ExtractedParagraphWithAnswers]): 96 | tfidf = self._tfidf 97 | text = [] 98 | for para in paragraphs: 99 | text.append(" ".join(para.text)) 100 | try: 101 | para_features = tfidf.fit_transform(text) 102 | q_features = tfidf.transform([" ".join(question)]) 103 | except ValueError: 104 | return [] 105 | 106 | q_words = {x for x in question if x.lower() not in stop_words} 107 | q_words_lower = {x.lower() for x in q_words} 108 | word_matches_features = np.zeros((len(paragraphs), 2)) 109 | for para_ix, para in enumerate(paragraphs): 110 | found = set() 111 | found_lower = set() 112 | for word in para.text: 113 | if word in q_words: 114 | found.add(word) 115 | elif word.lower() in q_words_lower: 116 | found_lower.add(word.lower()) 117 | word_matches_features[para_ix, 0] = len(found) 118 | word_matches_features[para_ix, 1] = len(found_lower) 119 | 120 | tfidf_dists = pairwise_distances(q_features, para_features, "cosine").ravel() 121 | starts = np.array([p.start for p in paragraphs]) 122 | log_word_start = np.log(starts / 200.0 + 1) 123 | first = starts == 0 124 | scores = tfidf_dists * self.TFIDF_W + self.LOG_WORD_START_W * log_word_start + self.FIRST_W * first + \ 125 | self.LOWER_WORD_W * word_matches_features[:, 1] + self.WORD_W * word_matches_features[:, 0] 126 | return scores 127 | 128 | def prune(self, question: List[str], paragraphs: List[ExtractedParagraphWithAnswers]): 129 | scores = self.score_paragraphs(question, paragraphs) 130 | sorted_ix = np.argsort(scores) 131 | 132 | selection = [i for i in sorted_ix[:self.n_to_select]] 133 | selected_paras = [paragraphs[i] for i in selection] 134 | ir_hit = 0. if all(len(x.answer_spans) == 0 for x in selected_paras) else 1.
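# Training-time fallback (same trick as TopTfIdf.prune above): if none of the top-ranked paragraphs carries a gold answer span, the last slot is overwritten with a randomly chosen gold paragraph so the reader still sees a positive example; this assumes at least one paragraph has answer spans when is_training is set.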
135 | 136 | if self.is_training and not ir_hit: 137 | gold_indexes = [i for i, x in enumerate(paragraphs) if len(x.answer_spans) != 0] 138 | gold_index = random.choice(gold_indexes) 139 | selection[-1] = gold_index 140 | 141 | if self.sort_passage: 142 | selection = np.sort(selection) 143 | 144 | return [paragraphs[i] for i in selection], ir_hit 145 | 146 | 147 | class Preprocessor(Configurable): 148 | 149 | def preprocess(self, questions: Iterable, evidence) -> object: 150 | """ Map elements to an unspecified intermediate format """ 151 | raise NotImplementedError() 152 | 153 | def finalize_chunk(self, x): 154 | """ Finalize the output from `preprocess`; in multi-processing scenarios this will still be run on 155 | the main thread, so it can be used for things like interning """ 156 | pass 157 | 158 | def _preprocess_and_count(questions: List, evidence, preprocessor: Preprocessor): 159 | count = len(questions) 160 | output = preprocessor.preprocess(questions, evidence) 161 | return output, count 162 | 163 | def preprocess_par(questions: List, evidence, preprocessor, 164 | n_processes=2, chunk_size=200, name=None): 165 | if chunk_size <= 0: 166 | raise ValueError("Chunk size must be >= 1, but got %s" % chunk_size) 167 | if n_processes is not None and n_processes <= 0: 168 | raise ValueError("n_processes must be >= 1 or None, but got %s" % n_processes) 169 | n_processes = min(len(questions), n_processes) 170 | 171 | if n_processes == 1: 172 | out = preprocessor.preprocess(tqdm(questions, desc=name, ncols=80), evidence) 173 | preprocessor.finalize_chunk(out) 174 | return out 175 | else: 176 | from multiprocessing import Pool 177 | chunks = split(questions, n_processes) 178 | chunks = flatten_iterable([group(c, chunk_size) for c in chunks]) 179 | print("Processing %d chunks with %d processes" % (len(chunks), n_processes)) 180 | pbar = tqdm(total=len(questions), desc=name, ncols=80) 181 | lock = Lock() 182 | 183 | def call_back(results): 184 | preprocessor.finalize_chunk(results[0]) 185 | with lock: # FIXME Even with the lock, the progress bar still jumps around 186 | pbar.update(results[1]) 187 | 188 | with Pool(n_processes) as pool: 189 | results = [pool.apply_async(_preprocess_and_count, [c, evidence, preprocessor], callback=call_back) 190 | for c in chunks] 191 | results = [r.get()[0] for r in results] 192 | 193 | pbar.close() 194 | output = results[0] 195 | for r in results[1:]: 196 | output += r 197 | return output 198 | 199 | class FilteredData(object): 200 | def __init__(self, data: List, true_len: int, ir_count: int, 201 | total_doc_num: int, total_doc_length: int, pruned_doc_length: int): 202 | self.data = data 203 | self.true_len = true_len 204 | self.ir_count = ir_count 205 | self.total_doc_num = total_doc_num 206 | self.total_doc_length = total_doc_length 207 | self.pruned_doc_length = pruned_doc_length 208 | 209 | def __add__(self, other): 210 | return FilteredData(self.data + other.data, self.true_len + other.true_len, self.ir_count + other.ir_count, 211 | self.total_doc_num + other.total_doc_num, self.total_doc_length + other.total_doc_length, 212 | self.pruned_doc_length + other.pruned_doc_length) 213 | 214 | def split_annotated(doc: List[List[str]], spans: np.ndarray): 215 | out = [] 216 | offset = 0 217 | for para in doc: 218 | para_start = offset 219 | para_end = para_start + len(para) 220 | para_spans = spans[np.logical_and(spans[:, 0] >= para_start, spans[:, 1] < para_end)] - para_start 221 | out.append(ExtractedParagraphWithAnswers(para, para_start, para_end,
para_spans)) 222 | offset += len(para) 223 | return out 224 | 225 | class ExtractMultiParagraphsPerQuestion(Preprocessor): 226 | def __init__(self, ranker: ParagraphFilter, intern: bool=False, is_training=False): 227 | self.ranker = ranker 228 | self.intern = intern 229 | self.is_training = is_training 230 | 231 | def preprocess(self, questions: List[TriviaQaQuestion], evidence): # TriviaQaEvidenceCorpusTxt evidence_corpus.py 232 | ir_count, total_doc_num, total_doc_length, pruned_doc_length = 0, 0, 0, 0 233 | 234 | instances = [] 235 | for q in questions: 236 | doc_paras = [] 237 | doc_count, doc_length = 0, 0 238 | for doc in q.all_docs: 239 | if self.is_training and len(doc.answer_spans) == 0: 240 | continue 241 | text = evidence.get_document(doc.doc_id) # List[List[str]] 242 | if text is None: 243 | raise ValueError("No evidence text found for document: " + doc.doc_id) 244 | if doc.answer_spans is not None: 245 | paras = split_annotated(text, doc.answer_spans) 246 | else: 247 | # this is kind of a hack to make the rest of the pipeline work, only 248 | # needed for test cases 249 | paras = split_annotated(text, np.zeros((0, 2), dtype=np.int32)) 250 | doc_paras.extend([DocParagraphWithAnswers(x.text, x.start, x.end, x.answer_spans, doc.doc_id) 251 | for x in paras]) # List[DocParagraphWithAnswers] 252 | doc_length += sum(len(para) for para in text) 253 | doc_count += 1 254 | 255 | if len(doc_paras) == 0: 256 | continue 257 | 258 | doc_paras, ir_hit = self.ranker.prune(q.question, doc_paras) # List[ExtractedParagraphWithAnswers] len=4 259 | total_doc_num += doc_count 260 | total_doc_length += doc_length 261 | ir_count += ir_hit 262 | 263 | # merge into a DocumentAndQuestion instance 264 | doc_tokens, start_positions, end_positions = [], [], [] 265 | for x in doc_paras: 266 | offset_doc = len(doc_tokens) 267 | doc_tokens += x.text 268 | if len(x.answer_spans) != 0: 269 | start_position = x.answer_spans[:, 0] + offset_doc 270 | end_position = x.answer_spans[:, 1] + offset_doc 271 | start_positions.extend(start_position) 272 | end_positions.extend(end_position) 273 | instance = DocumentAndQuestion(q.all_docs[0].doc_id, q.question_id, " ".join(q.question), doc_tokens, 274 | None if q.answer is None else q.answer.all_answers, start_positions, 275 | end_positions) 276 | pruned_doc_length += len(doc_tokens) 277 | 278 | instances.append(instance) 279 | return FilteredData(instances, len(questions), ir_count, total_doc_num, total_doc_length, pruned_doc_length) 280 | 281 | def finalize_chunk(self, f: FilteredData): 282 | if self.intern: 283 | for ins in f.data: 284 | ins.document_id = sys.intern(ins.document_id) 285 | ins.qas_id = sys.intern(ins.qas_id) 286 | ins.question_text = sys.intern(ins.question_text) 287 | 288 | 289 | # class ExtractMultiParagraphs(Preprocessor): 290 | # def __init__(self, ranker: ParagraphFilter, intern: bool=False, is_training=False): 291 | # self.ranker = ranker 292 | # self.intern = intern 293 | # self.is_training = is_training 294 | # 295 | # def preprocess(self, questions: List[TriviaQaQuestion], evidence): # TriviaQaEvidenceCorpusTxt evidence_corpus.py 296 | # true_len = 0 297 | # ir_count, ir_total, pruned_doc_length = 0, 0, 0 298 | # 299 | # instances = [] 300 | # for q in questions: 301 | # true_len += len(q.all_docs) 302 | # for doc in q.all_docs: 303 | # if self.is_training and len(doc.answer_spans) == 0: 304 | # continue 305 | # text = evidence.get_document(doc.doc_id) # List[List[str]] 306 | # if text is None: 307 | # raise ValueError("No evidence text found for document: " +
doc.doc_id) 308 | # if doc.answer_spans is not None: 309 | # paras = split_annotated(text, doc.answer_spans) 310 | # else: 311 | # # this is kind of a hack to make the rest of the pipeline work, only 312 | # # needed for test cases 313 | # paras = split_annotated(text, np.zeros((0, 2), dtype=np.int32)) 314 | # 315 | # if len(paras) == 0: 316 | # continue 317 | # 318 | # paras, ir_hit = self.ranker.prune(q.question, paras) # List[ExtractedParagraphWithAnswers] len=4 319 | # ir_count += ir_hit 320 | # ir_total += 1 321 | # 322 | # # merge into a DocumentAndQuestion instance 323 | # doc_tokens, start_positions, end_positions = [], [], [] 324 | # for x in paras: 325 | # offset_doc = len(doc_tokens) 326 | # doc_tokens += x.text 327 | # if len(x.answer_spans) != 0: 328 | # start_position = x.answer_spans[:, 0] + offset_doc 329 | # end_position = x.answer_spans[:, 1] + offset_doc 330 | # start_positions.extend(start_position) 331 | # end_positions.extend(end_position) 332 | # instance = DocumentAndQuestion(doc.doc_id, q.question_id, " ".join(q.question), doc_tokens, 333 | # q.answer.all_answers, start_positions, end_positions) 334 | # pruned_doc_length += len(doc_tokens) 335 | # 336 | # instances.append(instance) 337 | # return FilteredData(instances, true_len, ir_count, ir_total, pruned_doc_length) 338 | # 339 | # def finalize_chunk(self, f: FilteredData): 340 | # if self.intern: 341 | # for ins in f.data: 342 | # ins.document_id = sys.intern(ins.document_id) 343 | # ins.qas_id = sys.intern(ins.qas_id) 344 | # ins.question_text = sys.intern(ins.question_text) -------------------------------------------------------------------------------- /triviaqa/read_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unicodedata 3 | from os.path import join 4 | from typing import List 5 | 6 | from triviaqa.triviaqa_eval import normalize_answer as triviaqa_normalize_answer 7 | 8 | """ 9 | Read and represent trivia-qa data 10 | """ 11 | 12 | 13 | def normalize_wiki_filename(filename): 14 | """ 15 | Wiki filenames have been a pain, since the data seems to have filenames encoded in 16 | the incorrect case sometimes, and we have to be careful to keep a consistent unicode format.
17 | Our current solution is to require all filenames to be normalized like this 18 | """ 19 | return unicodedata.normalize("NFD", filename).lower() 20 | 21 | 22 | class WikipediaEntity(object): 23 | __slots__ = ["value", "normalized_value", "aliases", "normalized_aliases", 24 | "wiki_entity_name", "normalized_wiki_entity_name", "human_answers"] 25 | 26 | def __init__(self, value: str, normalized_value: str, aliases, normalized_aliases: List[str], 27 | wiki_entity_name: str, normalized_wiki_entity_name: str, human_answers): 28 | self.aliases = aliases 29 | self.value = value 30 | self.normalized_value = normalized_value 31 | self.normalized_aliases = normalized_aliases 32 | self.wiki_entity_name = wiki_entity_name 33 | self.normalized_wiki_entity_name = normalized_wiki_entity_name 34 | self.human_answers = human_answers 35 | 36 | @property 37 | def all_answers(self): 38 | if self.human_answers is None: 39 | return self.normalized_aliases 40 | else: 41 | # normalize to be consistent with the other normalized aliases 42 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers] 43 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0] 44 | 45 | def __repr__(self) -> str: 46 | return self.value 47 | 48 | 49 | class Numerical(object): 50 | __slots__ = ["number", "aliases", "normalized_aliases", "value", "unit", 51 | "normalized_value", "multiplier", "human_answers"] 52 | 53 | def __init__(self, number: float, aliases, normalized_aliases, value, unit, 54 | normalized_value, multiplier, human_answers): 55 | self.number = number 56 | self.aliases = aliases 57 | self.normalized_aliases = normalized_aliases 58 | self.value = value 59 | self.unit = unit 60 | self.normalized_value = normalized_value 61 | self.multiplier = multiplier 62 | self.human_answers = human_answers 63 | 64 | @property 65 | def all_answers(self): 66 | if self.human_answers is None: 67 | return self.normalized_aliases 68 | else: 69 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers] 70 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0] 71 | 72 | def __repr__(self) -> str: 73 | return self.value 74 | 75 | 76 | class FreeForm(object): 77 | __slots__ = ["value", "normalized_value", "aliases", "normalized_aliases", "human_answers"] 78 | 79 | def __init__(self, value, normalized_value, aliases, normalized_aliases, human_answers): 80 | self.value = value 81 | self.aliases = aliases 82 | self.normalized_value = normalized_value 83 | self.normalized_aliases = normalized_aliases 84 | self.human_answers = human_answers 85 | 86 | @property 87 | def all_answers(self): 88 | if self.human_answers is None: 89 | return self.normalized_aliases 90 | else: 91 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers] 92 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0] 93 | 94 | def __repr__(self) -> str: 95 | return self.value 96 | 97 | 98 | class Range(object): 99 | __slots__ = ["value", "normalized_value", "aliases", "normalized_aliases", 100 | "start", "end", "unit", "multiplier", "human_answers"] 101 | 102 | def __init__(self, value, normalized_value, aliases, normalized_aliases, 103 | start, end, unit, multiplier, human_answers): 104 | self.value = value 105 | self.normalized_value = normalized_value 106 | self.aliases = aliases 107 | self.normalized_aliases = normalized_aliases 108 | self.start = start 109 | self.end = end 110 | self.unit = unit 111 | self.multiplier = multiplier 112 | self.human_answers =
human_answers 113 | 114 | @property 115 | def all_answers(self): 116 | if self.human_answers is None: 117 | return self.normalized_aliases 118 | else: 119 | human_answers = [triviaqa_normalize_answer(x) for x in self.human_answers] 120 | return self.normalized_aliases + [x for x in human_answers if len(x) > 0] 121 | 122 | def __repr__(self) -> str: 123 | return self.value 124 | 125 | 126 | class TagMeEntityDoc(object): 127 | __slots__ = ["rho", "link_probability", "title", "trivia_qa_selected", "answer_spans"] 128 | 129 | def __init__(self, rho, link_probability, title): 130 | self.rho = rho 131 | self.link_probability = link_probability 132 | self.title = title 133 | self.trivia_qa_selected = False 134 | self.answer_spans = None 135 | 136 | @property 137 | def doc_id(self): 138 | return self.title 139 | 140 | def __repr__(self) -> str: 141 | return "TagMeEntityDoc(%s)" % self.title 142 | 143 | 144 | class SearchEntityDoc(object): 145 | __slots__ = ["title", "trivia_qa_selected", "answer_spans"] 146 | 147 | def __init__(self, title): 148 | self.title = title 149 | self.answer_spans = None 150 | self.trivia_qa_selected = False 151 | 152 | @property 153 | def doc_id(self): 154 | return self.title 155 | 156 | def __repr__(self) -> str: 157 | return "SearchEntityDoc(%s)" % self.title 158 | 159 | 160 | class SearchDoc(object): 161 | __slots__ = ["title", "description", "rank", "url", "trivia_qa_selected", "answer_spans"] 162 | 163 | def __init__(self, title, description, rank, url): 164 | self.title = title 165 | self.description = description 166 | self.rank = rank 167 | self.url = url 168 | self.answer_spans = None 169 | self.trivia_qa_selected = False 170 | 171 | @property 172 | def doc_id(self): 173 | return self.url 174 | 175 | def __repr__(self) -> str: 176 | return "SearchDoc(%s)" % self.title 177 | 178 | 179 | class TriviaQaQuestion(object): 180 | __slots__ = ["question", "question_id", "answer", "entity_docs", "web_docs"] 181 | 182 | def __init__(self, question, question_id, answer, entity_docs, web_docs): 183 | self.question = question 184 | self.question_id = question_id 185 | self.answer = answer 186 | self.entity_docs = entity_docs 187 | self.web_docs = web_docs 188 | 189 | @property 190 | def all_docs(self): 191 | if self.web_docs is not None: 192 | return self.web_docs + self.entity_docs 193 | else: 194 | return self.entity_docs 195 | 196 | def to_compressed_json(self): 197 | return [ 198 | self.question, 199 | self.question_id, 200 | [self.answer.__class__.__name__] + [getattr(self.answer, x) for x in self.answer.__slots__], 201 | [[doc.__class__.__name__] + [getattr(doc, x) for x in doc.__slots__] for doc in self.entity_docs], 202 | [[getattr(doc, x) for x in doc.__slots__] for doc in self.web_docs], 203 | ] 204 | 205 | @staticmethod 206 | def from_compressed_json(text): 207 | question, quid, answer, entity_docs, web_docs = json.loads(text) 208 | if answer[0] == "WikipediaEntity": 209 | answer = WikipediaEntity(*answer[1:]) 210 | elif answer[0] == "Numerical": 211 | answer = Numerical(*answer[1:]) 212 | elif answer[0] == "FreeForm": 213 | answer = FreeForm(*answer[1:]) 214 | elif answer[0] == "Range": 215 | answer = Range(*answer[1:]) 216 | else: 217 | raise ValueError() 218 | for i, doc in enumerate(entity_docs): 219 | if doc[0] == "TagMeEntityDoc": 220 | entity_docs[i] = TagMeEntityDoc(*doc[1:]) 221 | elif doc[0] == "SearchEntityDoc": 222 | entity_docs[i] = SearchEntityDoc(*doc[1:]) 223 | web_docs = [SearchDoc(*x) for x in web_docs] 224 | return TriviaQaQuestion(question, 
quid, answer, entity_docs, web_docs) 225 | 226 | 227 | def iter_question_json(filename): 228 | """ Iterates over trivia-qa questions in a JSON file, useful if the file is too large to be 229 | parsed all at once """ 230 | with open(filename, "r") as f: 231 | if f.readline().strip() != "{": 232 | raise ValueError() 233 | if "Data\": [" not in f.readline(): 234 | raise ValueError() 235 | line = f.readline() 236 | while line.strip() == "{": 237 | obj = [] 238 | line = f.readline() 239 | while not line.startswith(" }"): 240 | obj.append(line) 241 | line = f.readline() 242 | yield "{" + "".join(obj) + "}" 243 | if not line.startswith(" },"): 244 | # no comma means this was the last element of the data list 245 | return 246 | else: 247 | line = f.readline() 248 | else: 249 | raise ValueError() 250 | 251 | 252 | def build_questions(json_questions, title_to_file, require_filename): 253 | for q in json_questions: 254 | q = json.loads(q) 255 | ans = q.get("Answer") 256 | valid_attempt = q.get("QuestionVerifiedEvalAttempt", False) 257 | if valid_attempt and not q["QuestionPartOfVerifiedEval"]: 258 | continue # don't bother with questions in the verified set that were rejected 259 | if ans is not None: 260 | answer_type = ans["Type"] 261 | if answer_type == "WikipediaEntity": 262 | answer = WikipediaEntity(ans["Value"], ans["NormalizedValue"], ans["Aliases"], ans["NormalizedAliases"], 263 | ans["MatchedWikiEntityName"], ans["NormalizedMatchedWikiEntityName"], 264 | ans.get("HumanAnswers")) 265 | if not (len(ans) == 7 or (len(ans) == 8 and "HumanAnswers" in ans)): 266 | raise ValueError() 267 | elif answer_type == "Numerical": 268 | answer = Numerical(float(ans["Number"]), ans["Aliases"], ans["NormalizedAliases"], 269 | ans["Value"], ans["Unit"], ans["NormalizedValue"], 270 | ans["Multiplier"], ans.get("HumanAnswers")) 271 | if not (len(ans) == 8 or (len(ans) == 9 and "HumanAnswers" in ans)): 272 | raise ValueError() 273 | elif answer_type == "FreeForm": 274 | answer = FreeForm(ans["Value"], ans["NormalizedValue"], ans["Aliases"], 275 | ans["NormalizedAliases"], ans.get("HumanAnswers")) 276 | if not (len(ans) == 5 or (len(ans) == 6 and "HumanAnswers" in ans)): 277 | raise ValueError() 278 | elif answer_type == "Range": 279 | answer = Range(ans["Value"], ans["NormalizedValue"], ans["Aliases"], ans["NormalizedAliases"], 280 | float(ans["From"]), float(ans["To"]), ans["Unit"], 281 | ans["Multiplier"], ans.get("HumanAnswers")) 282 | if not (len(ans) == 9 or (len(ans) == 10 and "HumanAnswers" in ans)): 283 | if "Number" in ans: 284 | # This appears to be a bug; the number fields in these 285 | # cases seem to be meaningless (and VERY rare) 286 | pass 287 | else: 288 | raise ValueError() 289 | else: 290 | raise ValueError() 291 | else: 292 | answer = None 293 | 294 | entity_pages = [] 295 | for page in q["EntityPages"]: 296 | verified_attempt = page.get("DocVerifiedEvalAttempt", False) 297 | if verified_attempt and not page["DocPartOfVerifiedEval"]: 298 | continue 299 | title = page["Title"] 300 | if page["DocSource"] == "Search": 301 | entity_pages.append(SearchEntityDoc(title)) 302 | elif page["DocSource"] == "TagMe": 303 | entity_pages.append(TagMeEntityDoc(page.get("Rho"), page.get("LinkProbability"), title)) 304 | else: 305 | raise ValueError() 306 | filename = page.get("Filename") 307 | if filename is not None: 308 | filename = join("wikipedia", filename[:filename.rfind(".")]) 309 | filename = normalize_wiki_filename(filename) 310 | cur = title_to_file.get(title) 311 | if cur is None: 312 |
title_to_file[title] = filename 313 | elif cur != filename: 314 | raise ValueError() 315 | elif require_filename: 316 | raise ValueError() 317 | 318 | if "SearchResults" in q: 319 | web_pages = [] 320 | for page in q["SearchResults"]: 321 | verified_attempt = page.get("DocVerifiedEvalAttempt", False) 322 | if verified_attempt and not page["DocPartOfVerifiedEval"]: 323 | continue 324 | url = page["Url"] 325 | web_pages.append(SearchDoc(page["Title"], page["Description"], page["Rank"], url)) 326 | filename = page.get("Filename") 327 | if filename is not None: 328 | filename = join("web", filename[:filename.rfind(".")]) 329 | cur = title_to_file.get(url) 330 | if cur is None: 331 | title_to_file[url] = filename 332 | elif cur != filename: 333 | raise ValueError() 334 | elif require_filename: 335 | raise ValueError() 336 | else: 337 | web_pages = None 338 | 339 | yield TriviaQaQuestion(q["Question"], q["QuestionId"], answer, entity_pages, web_pages) 340 | 341 | 342 | def iter_trivia_question(filename, file_map, require_filename): 343 | return build_questions(iter_question_json(filename), file_map, require_filename) 344 | 345 | 346 | -------------------------------------------------------------------------------- /triviaqa/triviaqa_document_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import math 4 | import six 5 | import collections 6 | import numpy as np 7 | from typing import List 8 | import bert.tokenization as tokenization 9 | from squad.squad_utils import _improve_answer_span, _get_best_indexes, _compute_softmax, get_final_text 10 | from squad.squad_evaluate import exact_match_score, f1_score, span_f1 11 | 12 | 13 | class ExtractedParagraphWithAnswers(object): 14 | __slots__ = ["text", "start", "end", "answer_spans"] 15 | 16 | def __init__(self, text: List[str], start: int, end: int, answer_spans: np.ndarray): 17 | """ 18 | :param text: List of source paragraphs that have been merged to form `self` 19 | :param start: start token of this text in the source document 20 | :param end: end token of this text in the source document 21 | """ 22 | self.text = text 23 | self.start = start 24 | self.end = end 25 | self.answer_spans = answer_spans 26 | 27 | @property 28 | def n_context_words(self): 29 | return len(self.text) 30 | 31 | def __repr__(self): 32 | s = "" 33 | s += "text: %s ..." 
% (" ".join(self.text[:10])) 34 | s += ", start: %d" % (self.start) 35 | s += ", end: %d" % (self.end) 36 | s += ", answer_spans: {}".format(self.answer_spans) 37 | return s 38 | 39 | 40 | class DocParagraphWithAnswers(ExtractedParagraphWithAnswers): 41 | __slots__ = ["doc_id"] 42 | 43 | def __init__(self, text: List[str], start: int, end: int, answer_spans: np.ndarray, 44 | doc_id): 45 | super().__init__(text, start, end, answer_spans) 46 | self.doc_id = doc_id 47 | 48 | 49 | class DocumentAndQuestion(object): 50 | def __init__(self, 51 | document_id, 52 | qas_id, 53 | question_text, # str 54 | doc_tokens, 55 | orig_answer_texts=None, 56 | start_positions=None, 57 | end_positions=None): 58 | self.document_id = document_id 59 | self.qas_id = qas_id 60 | self.question_text = question_text 61 | self.doc_tokens = doc_tokens 62 | self.orig_answer_texts = orig_answer_texts 63 | self.start_positions = start_positions 64 | self.end_positions = end_positions 65 | 66 | def __str__(self): 67 | return self.__repr__() 68 | 69 | def __repr__(self): 70 | s = "" 71 | s += "document_id: %s" % (self.document_id) 72 | s += ", qas_id: %s" % (tokenization.printable_text(self.qas_id)) 73 | s += ", question_text: %s" % ( 74 | tokenization.printable_text(self.question_text)) 75 | s += ", doc_tokens: %s ..." % (" ".join(self.doc_tokens[:20])) 76 | s += ", length of doc_tokens: %d" % (len(self.doc_tokens)) 77 | if self.orig_answer_texts: 78 | s += ", orig_answer_texts: {}".format(self.orig_answer_texts) 79 | if self.start_positions and self.end_positions: 80 | s += ", start_positions: {}".format(self.start_positions) 81 | s += ", end_positions: {}".format(self.end_positions) 82 | s += ", token_answer: " 83 | for start, end in zip(self.start_positions, self.end_positions): 84 | s += "{}, ".format(" ".join(self.doc_tokens[start:(end+1)])) 85 | return s 86 | 87 | 88 | class InputFeatures(object): 89 | """A single set of features of data.""" 90 | 91 | def __init__(self, 92 | unique_id, 93 | example_index, 94 | doc_span_index, 95 | tokens, 96 | token_to_orig_map, 97 | input_ids, 98 | input_mask, 99 | segment_ids, 100 | start_positions=None, 101 | end_positions=None, 102 | start_indexes=None, 103 | end_indexes=None, 104 | is_impossible=None): 105 | self.unique_id = unique_id 106 | self.example_index = example_index 107 | self.doc_span_index = doc_span_index 108 | self.tokens = tokens 109 | self.token_to_orig_map = token_to_orig_map 110 | self.input_ids = input_ids 111 | self.input_mask = input_mask 112 | self.segment_ids = segment_ids 113 | self.start_positions = start_positions 114 | self.end_positions = end_positions 115 | self.start_indexes = start_indexes 116 | self.end_indexes = end_indexes 117 | self.is_impossible = is_impossible 118 | 119 | 120 | def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, 121 | max_query_length, verbose_logging=False, logger=None): 122 | """Loads a data file into a list of `InputBatch`s.""" 123 | 124 | unique_id = 1000000000 125 | 126 | features = [] 127 | for (example_index, example) in enumerate(examples): 128 | query_tokens = tokenizer.tokenize(example.question_text) 129 | 130 | if len(query_tokens) > max_query_length: 131 | query_tokens = query_tokens[0:max_query_length] 132 | 133 | tok_to_orig_index = [] 134 | orig_to_tok_index = [] 135 | all_doc_tokens = [] 136 | for (i, token) in enumerate(example.doc_tokens): 137 | orig_to_tok_index.append(len(all_doc_tokens)) 138 | sub_tokens = tokenizer.tokenize(token) 139 | for sub_token in sub_tokens: 140 | 
tok_to_orig_index.append(i) 141 | all_doc_tokens.append(sub_token) 142 | 143 | tok_start_positions = [] 144 | tok_end_positions = [] 145 | for start_position, end_position in \ 146 | zip(example.start_positions, example.end_positions): 147 | tok_start_position = orig_to_tok_index[start_position] 148 | if end_position < len(example.doc_tokens) - 1: 149 | tok_end_position = orig_to_tok_index[end_position + 1] - 1 150 | else: 151 | tok_end_position = len(all_doc_tokens) - 1 152 | tok_start_positions.append(tok_start_position) 153 | tok_end_positions.append(tok_end_position) 154 | 155 | # The -3 accounts for [CLS], [SEP] and [SEP] 156 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 157 | 158 | # We can have documents that are longer than the maximum sequence length. 159 | # To deal with this we do a sliding window approach, where we take chunks 160 | # of up to our max length with a stride of `doc_stride`. 161 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 162 | "DocSpan", ["start", "length"]) 163 | doc_spans = [] 164 | start_offset = 0 165 | while start_offset < len(all_doc_tokens): 166 | length = len(all_doc_tokens) - start_offset 167 | if length > max_tokens_for_doc: 168 | length = max_tokens_for_doc 169 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 170 | if start_offset + length == len(all_doc_tokens): 171 | break 172 | start_offset += min(length, doc_stride) 173 | 174 | for (doc_span_index, doc_span) in enumerate(doc_spans): 175 | tokens = [] 176 | token_to_orig_map = {} 177 | segment_ids = [] 178 | tokens.append("[CLS]") 179 | segment_ids.append(0) 180 | for token in query_tokens: 181 | tokens.append(token) 182 | segment_ids.append(0) 183 | tokens.append("[SEP]") 184 | segment_ids.append(0) 185 | 186 | for i in range(doc_span.length): 187 | split_token_index = doc_span.start + i 188 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 189 | tokens.append(all_doc_tokens[split_token_index]) 190 | segment_ids.append(1) 191 | tokens.append("[SEP]") 192 | segment_ids.append(1) 193 | 194 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 195 | 196 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 197 | # tokens are attended to. 198 | input_mask = [1] * len(input_ids) 199 | 200 | # Zero-pad up to the sequence length.
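# input_ids, input_mask and segment_ids are padded in lockstep below, so all three always end up exactly max_seq_length long (the asserts that follow check this invariant); padded positions get input_mask == 0 and are ignored by attention.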
201 | while len(input_ids) < max_seq_length: 202 | input_ids.append(0) 203 | input_mask.append(0) 204 | segment_ids.append(0) 205 | 206 | assert len(input_ids) == max_seq_length 207 | assert len(input_mask) == max_seq_length 208 | assert len(segment_ids) == max_seq_length 209 | 210 | # For distant supervision, we annotate the positions of all answer spans 211 | start_positions = [0] * len(input_ids) 212 | end_positions = [0] * len(input_ids) 213 | start_indexes, end_indexes = [], [] 214 | doc_start = doc_span.start 215 | doc_end = doc_span.start + doc_span.length - 1 216 | is_impossible = True 217 | for tok_start_position, tok_end_position in zip(tok_start_positions, tok_end_positions): 218 | if (tok_start_position >= doc_start and tok_end_position <= doc_end): 219 | doc_offset = len(query_tokens) + 2 220 | start_position = tok_start_position - doc_start + doc_offset 221 | end_position = tok_end_position - doc_start + doc_offset 222 | start_positions[start_position] = 1 223 | end_positions[end_position] = 1 224 | start_indexes.append(start_position) 225 | end_indexes.append(end_position) 226 | is_impossible = False 227 | 228 | if is_impossible: 229 | start_positions[0] = 1 230 | end_positions[0] = 1 231 | start_indexes.append(0) 232 | end_indexes.append(0) 233 | 234 | if example_index < 2 and verbose_logging: 235 | logger.info("*** Example ***") 236 | logger.info("unique_id: %s" % (unique_id)) 237 | logger.info("example_index: %s" % (example_index)) 238 | logger.info("doc_span_index: %s" % (doc_span_index)) 239 | logger.info("doc_span_start: %s" % (doc_span.start)) 240 | if is_impossible: 241 | logger.info("impossible example") 242 | else: 243 | logger.info("start_indexes: {}".format(start_indexes)) 244 | logger.info("end_indexes: {}".format(end_indexes)) 245 | 246 | features.append( 247 | InputFeatures( 248 | unique_id=unique_id, 249 | example_index=example_index, 250 | doc_span_index=doc_span_index, 251 | tokens=tokens, 252 | token_to_orig_map=token_to_orig_map, 253 | input_ids=input_ids, 254 | input_mask=input_mask, 255 | segment_ids=segment_ids, 256 | start_positions=start_positions, 257 | end_positions=end_positions, 258 | start_indexes=start_indexes, 259 | end_indexes=end_indexes, 260 | is_impossible=is_impossible)) 261 | unique_id += 1 262 | 263 | if len(features) % 5000 == 0: 264 | logger.info("Processing features: %d" % (len(features))) 265 | 266 | return features 267 | 268 | def annotate_candidates(all_examples, batch_features, batch_results, filter_type, is_training, n_best_size, 269 | max_answer_length, do_lower_case, verbose_logging, logger): 270 | """Annotate top-k candidate answers into features.""" 271 | unique_id_to_result = {} 272 | for result in batch_results: 273 | unique_id_to_result[result.unique_id] = result 274 | 275 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 276 | "PrelimPrediction", 277 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit", "rank_logit"]) 278 | 279 | batch_span_starts, batch_span_ends, batch_hard_labels, batch_soft_labels = [], [], [], [] 280 | for (feature_index, feature) in enumerate(batch_features): 281 | example = all_examples[feature.example_index] 282 | result = unique_id_to_result[feature.unique_id] 283 | 284 | prelim_predictions_per_feature = [] 285 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 286 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 287 | for start_index in start_indexes: 288 | for end_index in end_indexes: 289 | # We could 
hypothetically create invalid predictions, e.g., predict 290 | # that the start of the span is in the question. We throw out all 291 | # invalid predictions. 292 | if start_index >= len(feature.tokens): 293 | continue 294 | if end_index >= len(feature.tokens): 295 | continue 296 | if start_index not in feature.token_to_orig_map: 297 | continue 298 | if end_index not in feature.token_to_orig_map: 299 | continue 300 | if end_index < start_index: 301 | continue 302 | length = end_index - start_index + 1 303 | if length > max_answer_length: 304 | continue 305 | 306 | prelim_predictions_per_feature.append( 307 | _PrelimPrediction( 308 | feature_index=feature_index, 309 | start_index=start_index, 310 | end_index=end_index, 311 | start_logit=result.start_logits[start_index], 312 | end_logit=result.end_logits[end_index], 313 | rank_logit=result.rank_logit)) 314 | 315 | prelim_predictions_per_feature = sorted( 316 | prelim_predictions_per_feature, 317 | key=lambda x: (x.start_logit + x.end_logit + x.rank_logit), 318 | reverse=True) 319 | 320 | seen_predictions = {} 321 | span_starts, span_ends, hard_labels, soft_labels = [], [], [], [] 322 | 323 | if is_training: 324 | # add no-answer option into candidate answers 325 | span_starts.append(0) 326 | span_ends.append(0) 327 | if feature.is_impossible: 328 | hard_labels.append(1) 329 | soft_labels.append(1.) 330 | else: 331 | hard_labels.append(0) 332 | soft_labels.append(0.) 333 | 334 | for i, pred_i in enumerate(prelim_predictions_per_feature): 335 | if len(span_starts) >= int(n_best_size/4): 336 | break 337 | tok_tokens = feature.tokens[pred_i.start_index:(pred_i.end_index + 1)] 338 | orig_doc_start = feature.token_to_orig_map[pred_i.start_index] 339 | orig_doc_end = feature.token_to_orig_map[pred_i.end_index] 340 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 341 | tok_text = " ".join(tok_tokens) 342 | 343 | # De-tokenize WordPieces that have been split off. 344 | tok_text = tok_text.replace(" ##", "") 345 | tok_text = tok_text.replace("##", "") 346 | 347 | # Clean whitespace 348 | tok_text = tok_text.strip() 349 | tok_text = " ".join(tok_text.split()) 350 | orig_text = " ".join(orig_tokens) 351 | 352 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging, logger) 353 | if final_text in seen_predictions: 354 | continue 355 | seen_predictions[final_text] = True 356 | 357 | if is_training: 358 | if pred_i.start_index != 0 and pred_i.end_index != 0: 359 | span_starts.append(pred_i.start_index) 360 | span_ends.append(pred_i.end_index) 361 | if feature.is_impossible: 362 | hard_labels.append(0) 363 | soft_labels.append(0.) 
364 | else: 365 | max_em, max_f1 = 0, 0 366 | for orig_answer_text in example.orig_answer_texts: 367 | em = int(exact_match_score(final_text, orig_answer_text)) 368 | f1 = float(f1_score(final_text, orig_answer_text)) 369 | if em > max_em: 370 | max_em = em 371 | if f1 > max_f1: 372 | max_f1 = f1 373 | hard_labels.append(max_em) 374 | soft_labels.append(max_f1) 375 | else: 376 | span_starts.append(pred_i.start_index) 377 | span_ends.append(pred_i.end_index) 378 | 379 | # filter out redundant candidates 380 | if (i+1) < len(prelim_predictions_per_feature): 381 | indexes = [] 382 | for j, pred_j in enumerate(prelim_predictions_per_feature[(i+1):]): 383 | if filter_type == 'em': 384 | if pred_i.start_index == pred_j.start_index or pred_i.end_index == pred_j.end_index: 385 | indexes.append(i + j + 1) 386 | elif filter_type == 'f1': 387 | if span_f1([pred_i.start_index, pred_i.end_index], [pred_j.start_index, pred_j.end_index]) > 0: 388 | indexes.append(i + j + 1) 389 | elif filter_type == 'none': 390 | indexes = [] 391 | else: 392 | raise ValueError("Unknown filter_type: %s" % filter_type) 393 | for k, index in enumerate(indexes): prelim_predictions_per_feature.pop(index - k) 394 | 395 | # Pad to fixed length 396 | while len(span_starts) < int(n_best_size/4): 397 | span_starts.append(0) 398 | span_ends.append(0) 399 | if is_training: 400 | if feature.is_impossible: 401 | hard_labels.append(1) 402 | soft_labels.append(1.) 403 | else: 404 | hard_labels.append(0) 405 | soft_labels.append(0.) 406 | assert len(span_starts) == int(n_best_size/4) 407 | if is_training: 408 | assert len(hard_labels) == int(n_best_size/4) 409 | 410 | # Add ground truth answer spans if there is no positive label 411 | if is_training: 412 | if max(hard_labels) == 0: 413 | sample_ix = random.randrange(len(feature.start_indexes)) 414 | # keep the start/end pair from the same annotated span 415 | span_starts[-1] = feature.start_indexes[sample_ix] 416 | span_ends[-1] = feature.end_indexes[sample_ix] 417 | hard_labels[-1] = 1 418 | soft_labels[-1] = 1. 419 | 420 | batch_span_starts.append(span_starts) 421 | batch_span_ends.append(span_ends) 422 | batch_hard_labels.append(hard_labels) 423 | batch_soft_labels.append(soft_labels) 424 | return batch_span_starts, batch_span_ends, batch_hard_labels, batch_soft_labels 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | -------------------------------------------------------------------------------- /triviaqa/triviaqa_eval.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.0 of the TriviaQA dataset. 2 | Extended from the evaluation script for v1.1 of the SQuAD dataset.
3 | 4 | (Additionally condensed into a single file) 5 | """ 6 | from __future__ import print_function 7 | 8 | import json 9 | from collections import Counter 10 | import string 11 | import re 12 | import sys 13 | import argparse 14 | 15 | import unicodedata 16 | 17 | 18 | def normalize_answer(s): 19 | """Lower text and remove punctuation, articles and extra whitespace.""" 20 | 21 | def remove_articles(text): 22 | return re.sub(r'\b(a|an|the)\b', ' ', text) 23 | 24 | def white_space_fix(text): 25 | return ' '.join(text.split()) 26 | 27 | def handle_punc(text): 28 | exclude = set(string.punctuation + "".join([u"‘", u"’", u"´", u"`"])) 29 | return ''.join(ch if ch not in exclude else ' ' for ch in text) 30 | 31 | def lower(text): 32 | return text.lower() 33 | 34 | def replace_underscore(text): 35 | return text.replace('_', ' ') 36 | 37 | return white_space_fix(remove_articles(handle_punc(lower(replace_underscore(s))))).strip() 38 | 39 | 40 | def f1_score(prediction, ground_truth): 41 | prediction_tokens = normalize_answer(prediction).split() 42 | ground_truth_tokens = normalize_answer(ground_truth).split() 43 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 44 | num_same = sum(common.values()) 45 | if num_same == 0: 46 | return 0 47 | precision = 1.0 * num_same / len(prediction_tokens) 48 | recall = 1.0 * num_same / len(ground_truth_tokens) 49 | f1 = (2 * precision * recall) / (precision + recall) 50 | return f1 51 | 52 | 53 | def exact_match_score(prediction, ground_truth): 54 | return normalize_answer(prediction) == normalize_answer(ground_truth) 55 | 56 | 57 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 58 | scores_for_ground_truths = [] 59 | for ground_truth in ground_truths: 60 | score = metric_fn(prediction, ground_truth) 61 | scores_for_ground_truths.append(score) 62 | return max(scores_for_ground_truths) 63 | 64 | 65 | def get_ground_truths(answer): 66 | return answer['NormalizedAliases'] + [normalize_answer(ans) for ans in answer.get('HumanAnswers', [])] 67 | 68 | 69 | def get_file_contents(filename, encoding='utf-8'): 70 | with open(filename, encoding=encoding) as f: 71 | content = f.read() 72 | return content 73 | 74 | 75 | def read_json(filename, encoding='utf-8'): 76 | contents = get_file_contents(filename, encoding=encoding) 77 | return json.loads(contents) 78 | 79 | 80 | def is_exact_match(answer_object, prediction): 81 | ground_truths = get_ground_truths(answer_object) 82 | for ground_truth in ground_truths: 83 | if exact_match_score(prediction, ground_truth): 84 | return True 85 | return False 86 | 87 | 88 | def has_exact_match(ground_truths, candidates): 89 | for ground_truth in ground_truths: 90 | if ground_truth in candidates: 91 | return True 92 | return False 93 | 94 | 95 | def get_key_to_ground_truth(data): 96 | if data['Domain'] == 'Wikipedia': 97 | return {datum['QuestionId']: datum['Answer'] for datum in data['Data']} 98 | else: 99 | return get_qd_to_answer(data) 100 | 101 | 102 | def get_key_to_ground_truth_per_question(data): 103 | return {datum['QuestionId']: datum['Answer'] for datum in data['Data']} 104 | 105 | 106 | def get_question_doc_string(qid, doc_name): 107 | return '{}--{}'.format(qid, unicodedata.normalize("NFD", doc_name).lower()) 108 | 109 | 110 | def get_qd_to_answer(data): 111 | key_to_answer = {} 112 | for datum in data['Data']: 113 | for page in datum.get('EntityPages', []) + datum.get('SearchResults', []): 114 | qd_tuple = get_question_doc_string(datum['QuestionId'], page['Filename']) 115 | 
key_to_answer[qd_tuple] = datum['Answer'] 116 | return key_to_answer 117 | 118 | 119 | def evaluate_triviaqa(ground_truth, predicted_answers, qid_list=None): 120 | f1 = exact_match = common = 0 121 | missing_count = 0 122 | if qid_list is None: 123 | qid_list = ground_truth.keys() 124 | for qid in qid_list: 125 | if qid not in predicted_answers: 126 | missing_count += 1 127 | # message = 'Missed question {} will receive score 0.'.format(qid) 128 | # print(message, file=sys.stderr) 129 | continue 130 | if qid not in ground_truth: 131 | missing_count += 1 132 | continue 133 | common += 1 134 | prediction = predicted_answers[qid] 135 | ground_truths = get_ground_truths(ground_truth[qid]) 136 | em_for_this_question = metric_max_over_ground_truths( 137 | exact_match_score, prediction, ground_truths) 138 | exact_match += em_for_this_question 139 | f1_for_this_question = metric_max_over_ground_truths( 140 | f1_score, prediction, ground_truths) 141 | f1 += f1_for_this_question 142 | 143 | exact_match = 100.0 * exact_match / len(qid_list) 144 | f1 = 100.0 * f1 / len(qid_list) 145 | 146 | print("missing prediction on %d examples" % (missing_count)) 147 | return {'exact_match': exact_match, 'f1': f1, 'common': common, 'denominator': len(qid_list), 148 | 'pred_len': len(predicted_answers), 'gold_len': len(ground_truth)} 149 | 150 | 151 | def read_clean_part(datum): 152 | for key in ['EntityPages', 'SearchResults']: 153 | new_page_list = [] 154 | for page in datum.get(key, []): 155 | if page['DocPartOfVerifiedEval']: 156 | new_page_list.append(page) 157 | datum[key] = new_page_list 158 | assert len(datum['EntityPages']) + len(datum['SearchResults']) > 0 159 | return datum 160 | 161 | 162 | def read_triviaqa_data(qajson): 163 | data = read_json(qajson) 164 | # read only documents and questions that are a part of clean data set 165 | if data['VerifiedEval']: 166 | clean_data = [] 167 | for datum in data['Data']: 168 | if datum['QuestionPartOfVerifiedEval']: 169 | if data['Domain'] == 'Web': 170 | datum = read_clean_part(datum) 171 | clean_data.append(datum) 172 | data['Data'] = clean_data 173 | return data 174 | 175 | 176 | def get_args(): 177 | parser = argparse.ArgumentParser(description='Evaluation for TriviaQA') 178 | parser.add_argument('--dataset_file', help='Dataset file') 179 | parser.add_argument('--prediction_file', help='Prediction File') 180 | args = parser.parse_args() 181 | return args 182 | 183 | 184 | if __name__ == '__main__': 185 | expected_version = 1.0 186 | args = get_args() 187 | 188 | dataset_json = read_triviaqa_data(args.dataset_file) 189 | if dataset_json['Version'] != expected_version: 190 | print('Evaluation expects v-{} , but got dataset with v-{}'.format(expected_version,dataset_json['Version']), 191 | file=sys.stderr) 192 | key_to_ground_truth = get_key_to_ground_truth(dataset_json) 193 | predictions = read_json(args.prediction_file) 194 | eval_dict = evaluate_triviaqa(key_to_ground_truth, predictions) 195 | print(eval_dict) -------------------------------------------------------------------------------- /triviaqa/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, TypeVar, Iterable 2 | import collections 3 | import re 4 | import string 5 | T = TypeVar('T') 6 | 7 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're', 8 | 'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her', 9 | 'are', '(', 'an', 'so', 
--------------------------------------------------------------------------------
/triviaqa/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, TypeVar, Iterable
2 | import collections
3 | import re
4 | import string
5 | T = TypeVar('T')
6 | 
7 | stop_words = {'t', '–', 'there', 'but', 'needn', 'themselves', '’', '~', '$', 'few', '^', '₹', ']', 'we', 're',
8 |               'again', '?', 'they', 'ain', 'o', 'you', '+', 'has', 'by', 'than', 'whom', 'same', 'don', 'her',
9 |               'are', '(', 'an', 'so', 'the', 'been', 'wouldn', 'a', 'many', 'she', 'how', 'your', '°', 'do',
10 |               'shan', 'himself', 'between', 'ours', 'at', 'should', 'doesn', 'hasn', 'he', 'have', 'over',
11 |               'hadn', 'was', 'weren', 'down', 'above', '_', 'those', 'not', 'having', 'its', 'ourselves',
12 |               'for', 'when', 'if', ',', ';', 'about', 'theirs', 'him', '}', 'here', 'any', 'own', 'itself',
13 |               'very', 'on', 'myself', 'mustn', ')', 'because', 'now', '/', 'isn', 'to', 'just', 'these',
14 |               'i', 'further', 'mightn', 'll', '@', 'am', '”', 'below', 'shouldn', 'my', 'who', 'yours', 'why',
15 |               'such', '"', 'does', 'did', 'before', 'being', 'and', 'had', 'aren', '£', 'with', 'more', 'into',
16 |               '<', 'herself', 'which', '[', "'", 'of', 'haven', 'that', 'will', 'yourself', 'in', 'doing', '−',
17 |               'them', '‘', 'some', '`', 'while', 'each', 'it', 'through', 'all', 'their', ':', '\\', 'where',
18 |               'both', 'hers', '¢', '—', 'm', '.', 'from', 'or', 'other', 'too', 'couldn', 'as', 'our', 'off',
19 |               '%', '&', '-', '{', '=', 'didn', 'yourselves', 'under', 'y', 'ma', 'won', '!', '|', 'against',
20 |               '#', '¥', 'is', 'nor', 'up', 'most', 's', 'no', 'can', '>', '*', 'during', 'once', 'what', 'me',
21 |               'then', 'd', 'only', 'de', 've', 'were', '€', 'until', 'his', 'out', 'wasn', 'this', 'after',
22 |               'be', "' s", "' t"}
23 | 
24 | def flatten_iterable(listoflists: Iterable[Iterable[T]]) -> List[T]:
25 |     return [item for sublist in listoflists for item in sublist]
26 | 
27 | 
28 | def split(lst: List[T], n_groups) -> List[List[T]]:
29 |     """ partition `lst` into `n_groups` that are as evenly sized as possible """
30 |     per_group = len(lst) // n_groups
31 |     remainder = len(lst) % n_groups
32 |     groups = []
33 |     ix = 0
34 |     for _ in range(n_groups):
35 |         group_size = per_group
36 |         if remainder > 0:
37 |             remainder -= 1
38 |             group_size += 1
39 |         groups.append(lst[ix:ix + group_size])
40 |         ix += group_size
41 |     return groups
42 | 
43 | def group(lst: List[T], max_group_size) -> List[List[T]]:
44 |     """ partition `lst` into the minimal number of groups that are as evenly sized
45 |     as possible and are at most `max_group_size` in size """
46 |     if max_group_size is None:
47 |         return [lst]
48 |     n_groups = (len(lst)+max_group_size-1) // max_group_size
49 |     per_group = len(lst) // n_groups
50 |     remainder = len(lst) % n_groups
51 |     groups = []
52 |     ix = 0
53 |     for _ in range(n_groups):
54 |         group_size = per_group
55 |         if remainder > 0:
56 |             remainder -= 1
57 |             group_size += 1
58 |         groups.append(lst[ix:ix + group_size])
59 |         ix += group_size
60 |     return groups
61 | 
62 | 
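# --- Editor's note: the check below is an added illustration, not part of the
# original file. split() takes the number of groups directly, while group()
# derives it from `max_group_size`; both spread the remainder one element at a
# time over the leading groups.
def _demo_grouping():
    assert split(list(range(7)), 3) == [[0, 1, 2], [3, 4], [5, 6]]
    assert group(list(range(7)), 3) == [[0, 1, 2], [3, 4], [5, 6]]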
63 | def simple_normalize_answer(s):
64 |     """Lower text and collapse extra whitespace (punctuation and articles are kept)."""
65 |     def white_space_fix(text):
66 |         return ' '.join(text.split())
67 | 
68 |     def lower(text):
69 |         return text.lower()
70 | 
71 |     return white_space_fix(lower(s))
72 | 
73 | 
74 | def normalize_answer(s):
75 |     """Lower text and remove punctuation, articles and extra whitespace."""
76 |     def remove_articles(text):
77 |         regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
78 |         return re.sub(regex, ' ', text)
79 |     def white_space_fix(text):
80 |         return ' '.join(text.split())
81 |     def remove_punc(text):
82 |         exclude = set(string.punctuation)
83 |         return ''.join(ch for ch in text if ch not in exclude)
84 |     def lower(text):
85 |         return text.lower()
86 |     return white_space_fix(remove_articles(remove_punc(lower(s))))
87 | 
88 | 
89 | def get_tokens(s):
90 |     if not s: return []
91 |     return normalize_answer(s).split()
92 | 
93 | 
94 | def compute_f1(a_gold, a_pred):
95 |     gold_toks = get_tokens(a_gold)
96 |     pred_toks = get_tokens(a_pred)
97 |     common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
98 |     num_same = sum(common.values())
99 |     if len(gold_toks) == 0 or len(pred_toks) == 0:
100 |         # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
101 |         return int(gold_toks == pred_toks)
102 |     if num_same == 0:
103 |         return 0
104 |     precision = 1.0 * num_same / len(pred_toks)
105 |     recall = 1.0 * num_same / len(gold_toks)
106 |     f1 = (2 * precision * recall) / (precision + recall)
107 |     return f1
108 | 
109 | 
110 | def get_max_f1_span(words, answer, window_size):
111 |     max_f1 = 0
112 |     max_span = (0, 0)
113 |     # scan every span of at most window_size + 1 tokens for the best F1 vs `answer` (a token list)
114 |     for idx1 in range(len(words)):
115 |         for idx2 in range(min(window_size + 1, len(words) - idx1)):
116 |             candidate_answer = words[idx1: idx1 + idx2 + 1]
117 |             f1 = compute_f1(' '.join(answer), ' '.join(candidate_answer))
118 |             if f1 > max_f1:
119 |                 max_f1 = f1
120 |                 max_span = (idx1, idx1 + idx2)
121 |     return max_span, max_f1
--------------------------------------------------------------------------------