├── .gitignore ├── LICENSE ├── README.md ├── build_tfidf.py ├── environment.yml ├── eval_utils.py ├── evaluate-v1.1.py ├── file_utils.py ├── input_examples.txt ├── local_dump.py ├── mips_phrase.py ├── modeling.py ├── optimization.py ├── post.py ├── pre.py ├── run_index.py ├── run_server.py ├── simple_tokenizer.py ├── static ├── examples.txt ├── files │ ├── all.js │ ├── bootstrap.min.js │ ├── jquery-3.3.1.min.js │ ├── pichu.png │ ├── pika.png │ ├── popper.min.js │ └── style.css ├── index.html └── preview.png ├── tfidf_doc_ranker.py ├── tfidf_util.py ├── tokenization.py ├── tokenizer_util.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | out 3 | test 4 | __pycache__ 5 | logs/ 6 | pred/ 7 | models/ 8 | dumps*/ 9 | data/ 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 University of Washington 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

Sparc 3 | 4 | GitHub 5 | 6 |

7 |
8 |

Contextualized Sparse Representations for Real-Time Open-Domain Question Answering 9 |

10 | 11 |
12 | Sparc Demo 13 |

14 | 15 | This repository provides the authors' implementation of [Contextualized Sparse Representations for Real-Time Open-Domain Question Answering](https://arxiv.org/abs/1911.02896). You can train and evaluate DenSPI+Sparc described in our paper and make your own Sparc vector. 16 | 17 | ## Environment 18 | Please install the Conda environment as follows: 19 | ```bash 20 | $ conda env create -f environment.yml 21 | $ conda activate sparc 22 | ``` 23 | Note that this repository is mostly based on [DenSPI](https://github.com/uwnlp/denspi) and [DrQA](https://github.com/facebookresearch/DrQA). 24 | 25 | ## Resources 26 | We use [SQuAD v1.1](https://github.com/rajpurkar/SQuAD-explorer/tree/master/dataset) for training DenSPI+Sparc. Please download the train and dev files into `$DATA_DIR`. 27 | ```bash 28 | $ mkdir $DATA_DIR 29 | $ wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v1.1.json -O $DATA_DIR/train-v1.1.json 30 | $ wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v1.1.json -O $DATA_DIR/dev-v1.1.json 31 | ``` 32 | 33 | DenSPI is based on BERT. Please download the pre-trained BERT weights under `$BERT_DIR`. 34 | ```bash 35 | $ mkdir $BERT_DIR 36 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin -O $BERT_DIR/pytorch_model_base_uncased.bin 37 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json -O $BERT_DIR/bert_config_base_uncased.json 38 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin -O $BERT_DIR/pytorch_model_large_uncased.bin 39 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json -O $BERT_DIR/bert_config_large_uncased.json 40 | # Vocabulary is the same for BERT-base and BERT-large. 
41 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt -O $BERT_DIR/vocab.txt 42 | ``` 43 | 44 | ## Model 45 | To train DenSPI+Sparc on SQuAD, use `train.py`. Trained models will be saved in `$OUT_DIR1`. 46 | ```bash 47 | $ mkdir $OUT_DIR1 48 | # Train with BERT-base 49 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR1 --bert_model_option 'base_uncased' --train_file train-v1.1.json --predict_file dev-v1.1.json --do_train --do_predict --do_eval 50 | # Train with BERT-large (use smaller train_batch_size for 12GB GPUs) 51 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR1 --bert_model_option 'large_uncased' --parallel --train_file train-v1.1.json --predict_file dev-v1.1.json --do_train --do_predict --do_eval --train_batch_size 6 52 | ``` 53 | 54 | The result will look like (in case of BERT-base): 55 | ```bash 56 | 04/28/2020 06:32:59 - INFO - post - num vecs=45059736, num_words=1783576, nvpw=25.2637 57 | 04/28/2020 06:33:01 - INFO - __main__ - [Validation] loss: 8.700, b'{"exact_match": 75.10879848628193, "f1": 83.42143097917004}\n' 58 | ``` 59 | 60 | To use DenSPI+Sparc in an open-domain setting, you have to additionally train it with negative samples. 
In the case of DenSPI+Sparc with BERT-base (the same applies to BERT-large except for the `--bert_model_option` and `--parallel` arguments), the commands for training on negative samples are: 61 | ```bash 62 | $ mkdir $OUT_DIR2 63 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR2 --bert_model_option 'base_uncased' --train_file train-v1.1.json --predict_file dev-v1.1.json --do_train_neg --do_predict --do_eval --do_load --load_dir $OUT_DIR1 --load_epoch 3 64 | ``` 65 | 66 | Finally, train the phrase classifier as: 67 | ```bash 68 | $ mkdir $OUT_DIR3 69 | # Train only 1 epoch for phrase classifier 70 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR3 --bert_model_option 'base_uncased' --train_file train-v1.1.json --predict_file dev-v1.1.json --num_train_epochs 1 --do_train_filter --do_predict --do_eval --do_load --load_dir $OUT_DIR2 --load_epoch 3 71 | ``` 72 | 73 | We also provide a pretrained DenSPI+Sparc as follows: 74 | * DenSPI+Sparc pre-trained on SQuAD - [link](https://drive.google.com/file/d/1lObQ2lX8bWwJRzUuEqH6kpPdSTmS_Zxw/view?usp=sharing) 75 | 76 | 77 | ## Sparc Embedding 78 | Given the pre-trained DenSPI+Sparc, you can get Sparc embeddings with the following commands. The example below assumes you are using our pre-trained weights ([`denspi_sparc.zip`](https://drive.google.com/file/d/1lObQ2lX8bWwJRzUuEqH6kpPdSTmS_Zxw/view?usp=sharing) unzipped in the `denspi_sparc` folder). If you want to use your own model, please modify `MODEL_DIR` accordingly. 79 | 80 | For any type of text you want to embed, put each text on its own line of `input_examples.txt`. Then run: 81 | ```bash 82 | $ export DATA_DIR=. 
83 | $ export MODEL_DIR=denspi_sparc 84 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR --predict_file input_examples.txt --parallel --bert_model_option 'large_uncased' --do_load --load_dir $MODEL_DIR --load_epoch 1 --do_embed --dump_file output.json 85 | ``` 86 | 87 | The result file `$OUT_DIR/output.json` will show Sparc embedding of the input text ([CLS] representation, sorted by scores). For instance: 88 | ```json 89 | { 90 | "out": [ 91 | { 92 | "text": "They defeated the Arizona Cardinals 49-15 in the NFC Championship Game and advanced to their second Super Bowl appearance since the franchise was founded in 1995.", 93 | "sparc": { 94 | "uni": { 95 | "1995": { 96 | "score": 1.6841894388198853, 97 | "vocab": "2786" 98 | }, 99 | "second": { 100 | "score": 1.6321970224380493, 101 | "vocab": "2117" 102 | }, 103 | "49": { 104 | "score": 1.6075607538223267, 105 | "vocab": "4749" 106 | }, 107 | "arizona": { 108 | "score": 1.1734912395477295, 109 | "vocab": "5334" 110 | }, 111 | }, 112 | "bi": { 113 | "arizona cardinals": { 114 | "score": 1.3190401792526245, 115 | "vocab": "5334, 9310" 116 | }, 117 | "nfc championship": { 118 | "score": 1.1005975008010864, 119 | "vocab": "22309, 2528" 120 | }, 121 | "49 -": { 122 | "score": 1.0863999128341675, 123 | "vocab": "4749, 1011" 124 | }, 125 | "the arizona": { 126 | "score": 0.9722453951835632, 127 | "vocab": "1996, 5334" 128 | }, 129 | } 130 | } 131 | } 132 | ] 133 | } 134 | ``` 135 | Note that each text is segmented by the BERT tokenizer (`"vocab"` denotes the BERT vocab index). 136 | 137 | To see how Sparc changes for each phrase, set `start_index` in [here](https://github.com/jhyuklee/sparc/blob/750bf1a2b79f0e074edb77ef535c7e2861ffd8fd/post.py#L371) to the target token position. 
For instance, setting `start_index = 17` to embed Sparc of `415,000` of the following text gives you (some n-grams are omitted): 138 | 139 | ```json 140 | "text": "Between 1991 and 2000, the total area of forest lost in the Amazon rose from 415,000 to 587,000 square kilometres.", 141 | "sparc": { 142 | "uni": { 143 | "1991": { 144 | "score": 1.182684063911438, 145 | "vocab": "2889" 146 | }, 147 | "2000": { 148 | "score": 0.41507360339164734, 149 | "vocab": "2456" 150 | }, 151 | ``` 152 | whereas setting `start_index = 21` to embed Sparc of `587,000` gives you: 153 | ```json 154 | "text": "Between 1991 and 2000, the total area of forest lost in the Amazon rose from 415,000 to 587,000 square kilometres.", 155 | "sparc": { 156 | "uni": { 157 | "2000": { 158 | "score": 1.1923936605453491, 159 | "vocab": "2456" 160 | }, 161 | "1991": { 162 | "score": 0.7090237140655518, 163 | "vocab": "2889" 164 | }, 165 | ``` 166 | 167 | ## Phrase Index 168 | For now, please see [the original DenSPI repository](https://github.com/uwnlp/denspi) or [the recent application of DenSPI in COVID-19 domain](https://github.com/dmis-lab/covidAsk) for building phrase index using DenSPI+Sparc. 169 | The main changes in phrase indexing are in `post.py` and `mips_phrase.py` where Sparc is used for the open-domain QA inference (See [here](https://github.com/jhyuklee/sparc/blob/885729372706e227fa9c566ca51bd88de984710a/mips_phrase.py#L390-L410)). 
170 | 171 | ## Reference 172 | ```bibtex 173 | @inproceedings{lee2020contextualized, 174 | title={Contextualized Sparse Representations for Real-Time Open-Domain Question Answering}, 175 | author={Lee, Jinhyuk and Seo, Minjoon and Hajishirzi, Hannaneh and Kang, Jaewoo}, 176 | booktitle={ACL}, 177 | year={2020} 178 | } 179 | ``` 180 | -------------------------------------------------------------------------------- /build_tfidf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to build the tf-idf document matrices for retrieval.""" 8 | 9 | import numpy as np 10 | import scipy.sparse as sp 11 | import argparse 12 | import os 13 | import math 14 | import logging 15 | import json 16 | import copy 17 | import pandas as pd 18 | 19 | from multiprocessing import Pool as ProcessPool 20 | from multiprocessing.util import Finalize 21 | from functools import partial 22 | from collections import Counter 23 | 24 | import tfidf_util 25 | from simple_tokenizer import SimpleTokenizer 26 | 27 | logger = logging.getLogger() 28 | logger.setLevel(logging.INFO) 29 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 30 | console = logging.StreamHandler() 31 | console.setFormatter(fmt) 32 | logger.addHandler(console) 33 | 34 | 35 | # ------------------------------------------------------------------------------ 36 | # Multiprocessing functions 37 | # ------------------------------------------------------------------------------ 38 | 39 | DOC2IDX = None 40 | PROCESS_TOK = None 41 | PROCESS_DB = None 42 | 43 | 44 | def init(tokenizer_class, db): 45 | global PROCESS_TOK, PROCESS_DB 46 | PROCESS_TOK = tokenizer_class() 47 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 
48 | PROCESS_DB = db 49 | 50 | 51 | def fetch_text(doc_id): 52 | global PROCESS_DB 53 | return PROCESS_DB[doc_id] 54 | 55 | 56 | def tokenize(text): 57 | global PROCESS_TOK 58 | return PROCESS_TOK.tokenize(text) 59 | 60 | 61 | # ------------------------------------------------------------------------------ 62 | # Build article --> word count sparse matrix. 63 | # ------------------------------------------------------------------------------ 64 | 65 | 66 | def count(ngram, hash_size, doc_id): 67 | """Fetch the text of a document and compute hashed ngrams counts.""" 68 | global DOC2IDX 69 | row, col, data = [], [], [] 70 | # Tokenize 71 | tokens = tokenize(tfidf_util.normalize(fetch_text(doc_id))) 72 | 73 | # Get ngrams from tokens, with stopword/punctuation filtering. 74 | ngrams = tokens.ngrams( 75 | n=ngram, uncased=True, filter_fn=tfidf_util.filter_ngram 76 | ) 77 | 78 | # Hash ngrams and count occurences 79 | counts = Counter([tfidf_util.hash(gram, hash_size) for gram in ngrams]) 80 | 81 | # Return in sparse matrix data format. 82 | row.extend(counts.keys()) 83 | col.extend([DOC2IDX[doc_id]] * len(counts)) 84 | data.extend(counts.values()) 85 | return row, col, data 86 | 87 | 88 | def get_count_matrix(args, file_path): 89 | """Form a sparse word to document count matrix (inverted index). 90 | 91 | M[i, j] = # times word i appears in document j. 
92 | """ 93 | # Map doc_ids to indexes 94 | global DOC2IDX 95 | doc_ids = {} 96 | doc_metas = {} 97 | nan_cnt = 0 98 | for filename in sorted(os.listdir(file_path)): 99 | print(filename) 100 | with open(os.path.join(file_path, filename), 'r') as f: 101 | articles = json.load(f)['data'] 102 | for article in articles: 103 | title = article['title'] 104 | kk = 0 105 | while title in doc_ids: 106 | title += f'_{kk}' 107 | kk += 1 108 | doc_ids[title] = ' '.join([par['context'] for par in article['paragraphs']]) 109 | 110 | # Keep metadata 111 | doc_meta = {} 112 | for key, val in article.items(): 113 | if key != 'paragraphs': 114 | doc_meta[key] = val if val == val else 'NaN' 115 | else: 116 | doc_meta[key] = [] 117 | for para in val: 118 | para_meta = {} 119 | for para_key, para_val in para.items(): 120 | if para_key != 'context': 121 | para_meta[para_key] = para_val if para_val == para_val else 'NaN' 122 | doc_meta[key].append(para_meta) 123 | if not pd.isnull(article.get('pubmed_id', np.nan)): 124 | doc_metas[str(article['pubmed_id'])] = doc_meta # For BEST (might be duplicate) 125 | else: 126 | nan_cnt += 1 127 | doc_metas[article['title']] = doc_meta 128 | 129 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)} 130 | print('doc ids:', len(DOC2IDX)) 131 | print('doc metas:', len(doc_metas), 'with nan', str(nan_cnt)) 132 | # assert len(doc_ids)*2 == len(doc_metas) + nan_cnt 133 | 134 | # Setup worker pool 135 | tok_class = SimpleTokenizer 136 | workers = ProcessPool( 137 | args.num_workers, 138 | initializer=init, 139 | initargs=(tok_class, doc_ids) 140 | ) 141 | doc_ids = list(doc_ids.keys()) 142 | 143 | # Compute the count matrix in steps (to keep in memory) 144 | logger.info('Mapping...') 145 | row, col, data = [], [], [] 146 | step = max(int(len(doc_ids) / 10), 1) 147 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)] 148 | _count = partial(count, args.ngram, args.hash_size) 149 | for i, batch in enumerate(batches): 150 | 
logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25) 151 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch): 152 | row.extend(b_row) 153 | col.extend(b_col) 154 | data.extend(b_data) 155 | workers.close() 156 | workers.join() 157 | 158 | logger.info('Creating sparse matrix...') 159 | count_matrix = sp.csr_matrix( 160 | (data, (row, col)), shape=(args.hash_size, len(doc_ids)) 161 | ) 162 | count_matrix.sum_duplicates() 163 | return count_matrix, (DOC2IDX, doc_ids, doc_metas) 164 | 165 | 166 | # ------------------------------------------------------------------------------ 167 | # Transform count matrix to different forms. 168 | # ------------------------------------------------------------------------------ 169 | 170 | 171 | def get_tfidf_matrix(cnts): 172 | """Convert the word count matrix into tfidf one. 173 | 174 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 175 | * tf = term frequency in document 176 | * N = number of documents 177 | * Nt = number of occurences of term in all documents 178 | """ 179 | Ns = get_doc_freqs(cnts) 180 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5)) 181 | idfs[idfs < 0] = 0 182 | idfs = sp.diags(idfs, 0) 183 | tfs = cnts.log1p() 184 | tfidfs = idfs.dot(tfs) 185 | return tfidfs 186 | 187 | 188 | def get_doc_freqs(cnts): 189 | """Return word --> # of docs it appears in.""" 190 | binary = (cnts > 0).astype(int) 191 | freqs = np.array(binary.sum(1)).squeeze() 192 | return freqs 193 | 194 | 195 | # ------------------------------------------------------------------------------ 196 | # Main. 
197 | # ------------------------------------------------------------------------------ 198 | 199 | 200 | if __name__ == '__main__': 201 | parser = argparse.ArgumentParser() 202 | parser.add_argument('file_path', type=str, default=None, 203 | help='Path to document texts') 204 | parser.add_argument('out_dir', type=str, default=None, 205 | help='Directory for saving output files') 206 | parser.add_argument('--ngram', type=int, default=2, 207 | help=('Use up to N-size n-grams ' 208 | '(e.g. 2 = unigrams + bigrams)')) 209 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)), 210 | help='Number of buckets to use for hashing ngrams') 211 | parser.add_argument('--num-workers', type=int, default=None, 212 | help='Number of CPU processes (for tokenizing, etc)') 213 | args = parser.parse_args() 214 | 215 | logging.info('Counting words...') 216 | count_matrix, doc_dict = get_count_matrix( 217 | args, args.file_path 218 | ) 219 | 220 | logger.info('Making tfidf vectors...') 221 | tfidf = get_tfidf_matrix(count_matrix) 222 | 223 | logger.info('Getting word-doc frequencies...') 224 | freqs = get_doc_freqs(count_matrix) 225 | 226 | basename = os.path.splitext(os.path.basename(args.file_path))[0] 227 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=simple' % 228 | (args.ngram, args.hash_size)) 229 | filename = os.path.join(args.out_dir, basename) 230 | 231 | logger.info('Saving to %s.npz' % filename) 232 | metadata = { 233 | 'doc_freqs': freqs, 234 | 'tokenizer': 'simple', 235 | 'hash_size': args.hash_size, 236 | 'ngram': args.ngram, 237 | 'doc_dict': doc_dict 238 | } 239 | tfidf_util.save_sparse_csr(filename, tfidf, metadata) 240 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sparc 2 | channels: 3 | - cyclus 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - 
blas=1.0=mkl 10 | - ca-certificates=2020.1.1=0 11 | - certifi=2019.11.28=py36_0 12 | - cudatoolkit=10.0.130=0 13 | - curl=7.65.2=hbc83047_0 14 | - faiss-gpu=1.6.1=py36h1a5d453_0 15 | - htop=2.2.0=hf8c457e_1000 16 | - intel-openmp=2019.4=243 17 | - java-jdk=8.45.14=0 18 | - krb5=1.16.1=h173b8e3_7 19 | - libcurl=7.65.2=h20c2e04_0 20 | - libedit=3.1.20181209=hc058e9b_0 21 | - libffi=3.2.1=hd88cf55_4 22 | - libgcc-ng=9.1.0=hdf63c60_0 23 | - libgfortran-ng=7.3.0=hdf63c60_0 24 | - libssh2=1.8.2=h1ba5d50_0 25 | - libstdcxx-ng=9.1.0=hdf63c60_0 26 | - mkl=2020.0=166 27 | - mkl-service=2.3.0=py36he904b0f_0 28 | - mkl_fft=1.0.15=py36ha843d7b_0 29 | - mkl_random=1.1.0=py36hd6b4f25_0 30 | - ncurses=6.1=he6710b0_1 31 | - numpy-base=1.18.1=py36hde5b4d6_1 32 | - openssl=1.1.1e=h7b6447c_0 33 | - python=3.6.9=h265db76_0 34 | - readline=7.0=h7b6447c_5 35 | - setuptools=41.0.1=py36_0 36 | - sqlite=3.29.0=h7b6447c_0 37 | - tk=8.6.8=hbc83047_0 38 | - wheel=0.33.4=py36_0 39 | - xz=5.2.4=h14c3975_4 40 | - zlib=1.2.11=h7b6447c_3 41 | - pip: 42 | - backcall==0.1.0 43 | - chardet==3.0.4 44 | - click==7.0 45 | - decorator==4.4.2 46 | - elasticsearch==7.0.2 47 | - flask==1.0.2 48 | - flask-cors==3.0.7 49 | - h5py==2.9.0 50 | - idna==2.8 51 | - ipython==7.13.0 52 | - ipython-genutils==0.2.0 53 | - itsdangerous==1.1.0 54 | - jedi==0.16.0 55 | - jinja2==2.10.1 56 | - joblib==0.13.2 57 | - markupsafe==1.1.1 58 | - nltk==3.4.4 59 | - numpy==1.17.0 60 | - pandas==0.23.0 61 | - parso==0.6.2 62 | - pexpect==4.2.1 63 | - pickleshare==0.7.5 64 | - pip==20.0.2 65 | - prettytable==0.7.2 66 | - prompt-toolkit==3.0.4 67 | - ptyprocess==0.6.0 68 | - pygments==2.6.1 69 | - python-dateutil==2.8.1 70 | - pytz==2019.3 71 | - regex==2019.6.8 72 | - requests==2.22.0 73 | - requests-futures==0.9.9 74 | - scikit-learn==0.21.3 75 | - scipy==1.1.0 76 | - six==1.12.0 77 | - sklearn==0.0 78 | - termcolor==1.1.0 79 | - torch==1.1.0 80 | - tornado==6.0.1 81 | - tqdm==4.31.1 82 | - traitlets==4.3.3 83 | - ujson==2.0.2 84 | 
- urllib3==1.25.3 85 | - wcwidth==0.1.8 86 | - werkzeug==0.15.5 87 | prefix: /home/jinhyuk/miniconda3/envs/sparc 88 | 89 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ujson as json 3 | import re 4 | import string 5 | import unicodedata 6 | from collections import Counter 7 | import pickle 8 | from IPython import embed 9 | 10 | def normalize_answer(s): 11 | 12 | def remove_articles(text): 13 | return re.sub(r'\b(a|an|the)\b', ' ', text) 14 | 15 | def white_space_fix(text): 16 | return ' '.join(text.split()) 17 | 18 | def remove_punc(text): 19 | exclude = set(string.punctuation) 20 | return ''.join(ch for ch in text if ch not in exclude) 21 | 22 | def lower(text): 23 | return text.lower() 24 | 25 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 26 | 27 | 28 | def f1_score(prediction, ground_truth): 29 | normalized_prediction = normalize_answer(prediction) 30 | normalized_ground_truth = normalize_answer(ground_truth) 31 | 32 | ZERO_METRIC = (0, 0, 0) 33 | 34 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 35 | return ZERO_METRIC 36 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 37 | return ZERO_METRIC 38 | 39 | prediction_tokens = normalized_prediction.split() 40 | ground_truth_tokens = normalized_ground_truth.split() 41 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 42 | num_same = sum(common.values()) 43 | if num_same == 0: 44 | return ZERO_METRIC 45 | precision = 1.0 * num_same / len(prediction_tokens) 46 | recall = 1.0 * num_same / len(ground_truth_tokens) 47 | f1 = (2 * precision * recall) / (precision + recall) 48 | return f1, precision, recall 49 | 50 | 51 | def exact_match_score(prediction, ground_truth): 52 | return 
def drqa_regex_match_score(prediction, pattern):
    """Check if the prediction matches the given regular expression.

    The pattern is compiled case-insensitively with the UNICODE and
    MULTILINE flags and matched at the beginning of ``prediction``
    (``re.match`` semantics, not a search).

    Args:
        prediction: Text produced by the model.
        pattern: Regular expression describing the accepted answers.

    Returns:
        True when ``pattern`` compiles and matches at the start of
        ``prediction``; False when it does not match or cannot be compiled.
    """
    try:
        compiled = re.compile(
            pattern,
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE
        )
    except (re.error, TypeError, OverflowError, RecursionError):
        # A pattern that cannot be compiled counts as "no match".  Only the
        # errors re.compile raises for bad patterns are swallowed, so that
        # KeyboardInterrupt / SystemExit still propagate.
        return False
    return compiled.match(prediction) is not None
def update_sp(metrics, prediction, gold):
    """Score predicted supporting facts against the gold ones.

    Both ``prediction`` and ``gold`` are iterables of sequences; every item
    is turned into a tuple and duplicates are ignored.  The exact-match,
    F1, precision and recall of this example are added in place to
    ``metrics['sp_em']``, ``metrics['sp_f1']``, ``metrics['sp_prec']`` and
    ``metrics['sp_recall']``.

    Returns:
        The tuple ``(em, prec, recall)`` for this example.
    """
    predicted = {tuple(item) for item in prediction}
    expected = {tuple(item) for item in gold}

    tp = len(predicted & expected)   # predicted and correct
    fp = len(predicted - expected)   # predicted but wrong
    fn = len(expected - predicted)   # correct but not predicted

    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0

    for key, value in (('sp_em', em), ('sp_f1', f1), ('sp_prec', prec), ('sp_recall', recall)):
        metrics[key] += value
    return em, prec, recall
| prediction = json.load(f) 148 | with open(gold_file) as f: 149 | gold = json.load(f) 150 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 151 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 152 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 153 | 154 | for dp in gold: 155 | cur_id = dp['_id'] 156 | 157 | em, prec, recall = update_answer( 158 | metrics, prediction['answer'][cur_id], dp['answer']) 159 | if (prec + recall == 0): 160 | f1 = 0 161 | else: 162 | f1 = 2 * prec * recall / (prec+recall) 163 | 164 | print (dp['answer'], prediction['answer'][cur_id]) 165 | print (f1, em) 166 | a = input() 167 | 168 | 169 | if __name__ == '__main__': 170 | #eval(sys.argv[1], sys.argv[2]) 171 | analyze(sys.argv[1], sys.argv[2]) 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /evaluate-v1.1.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. 
""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | count = 0 56 | f1 = exact_match = total = 0 57 | for article in dataset: 58 | for paragraph in article['paragraphs']: 59 | for qa in paragraph['qas']: 60 | total += 1 61 | if str(qa['id']) not in predictions: 62 | message = 'Unanswered question ' + str(qa['id']) + \ 63 | ' will receive score 0.' 
64 | count += 1 65 | # print(message, file=sys.stderr) 66 | continue 67 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 68 | prediction = predictions[str(qa['id'])] 69 | exact_match += metric_max_over_ground_truths( 70 | exact_match_score, prediction, ground_truths) 71 | f1 += metric_max_over_ground_truths( 72 | f1_score, prediction, ground_truths) 73 | 74 | exact_match = 100.0 * exact_match / total 75 | f1 = 100.0 * f1 / total 76 | if count: 77 | print('There are %d unanswered question(s)' % count) 78 | 79 | return {'exact_match': exact_match, 'f1': f1} 80 | 81 | 82 | if __name__ == '__main__': 83 | expected_version = '1.1' 84 | parser = argparse.ArgumentParser( 85 | description='Evaluation for SQuAD ' + expected_version) 86 | parser.add_argument('dataset_file', help='Dataset file') 87 | parser.add_argument('prediction_file', help='Prediction File') 88 | args = parser.parse_args() 89 | with open(args.dataset_file) as dataset_file: 90 | dataset_json = json.load(dataset_file) 91 | if (dataset_json['version'] != expected_version): 92 | print('Evaluation expects v-' + expected_version + 93 | ', but got dataset with v-' + dataset_json['version'], 94 | file=sys.stderr) 95 | dataset = dataset_json['data'] 96 | with open(args.prediction_file) as prediction_file: 97 | predictions = json.load(prediction_file) 98 | print(json.dumps(evaluate(dataset, predictions))) 99 | -------------------------------------------------------------------------------- /file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
def url_to_filename(url, etag=None):
    """
    Build a deterministic cache filename for `url`.

    The name is the hex SHA-256 digest of the UTF-8 encoded url.  When a
    non-empty `etag` is given, the hex SHA-256 digest of the etag is
    appended after a period.
    """
    pieces = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        pieces.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(pieces)
61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.

    The wrapped callable must take the S3 url as its first positional
    argument.  A ``ClientError`` reporting a missing object is re-raised as
    ``EnvironmentError("file <url> not found")``; every other
    ``ClientError`` propagates unchanged.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # The error code is a string and is not always numeric
            # ("AccessDenied", "NoSuchBucket", ...).  Compare it as text:
            # calling int() on it would raise ValueError here and hide the
            # original ClientError.
            error_code = str(exc.response.get("Error", {}).get("Code", ""))
            if error_code in ("404", "NoSuchKey"):
                raise EnvironmentError("file {} not found".format(url))
            raise

    return wrapper
keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
def get_file_extension(path, dot=True, lower=True):
    """Return the extension of `path` as given by ``os.path.splitext``.

    `dot` keeps (True) or strips (False) the leading period; `lower`
    lower-cases the result when True.  A path without an extension yields
    an empty string.
    """
    _, suffix = os.path.splitext(path)
    if not dot:
        suffix = suffix[1:]
    if lower:
        suffix = suffix.lower()
    return suffix
def run_dump_phrase(args):
    """Start one phrase-dump subprocess per document shard.

    The half-open document range ``[args.start, args.end)`` is cut into
    contiguous shards of ``ceil((end - start) / num_gpus)`` documents (the
    last shard may be shorter) and ``python run_natkb.py ... --do_dump`` is
    launched once per shard.  The children are started with
    ``subprocess.Popen`` and are not waited for.

    Args:
        args: Parsed command-line namespace (see ``get_args``).

    Raises:
        ValueError: If ``args.num_gpus`` is smaller than 1 or the document
            range is empty or reversed.
    """
    # Fail early with a clear message instead of a ZeroDivisionError, a
    # "range() arg 3 must not be zero" error, or a nonsensical shard.
    if args.num_gpus < 1:
        raise ValueError('num_gpus must be at least 1, got {}'.format(args.num_gpus))
    if args.end <= args.start:
        raise ValueError('empty document range: start={}, end={}'.format(args.start, args.end))

    # Optional boolean switches, forwarded only when enabled.
    switches = [
        flag for flag, enabled in (
            ('--parallel', args.parallel),
            ('--do_case', args.do_case),
            ('--use_biobert', args.use_biobert),
            ('--append_title', args.append_title),
        ) if enabled
    ]

    def get_cmd(start_doc, end_doc):
        return ["python", "run_natkb.py",
                "--metadata_dir", f"{args.metadata_dir}",
                "--data_dir", f"{args.phrase_data_dir}",
                "--predict_file", f"{start_doc}:{end_doc}",
                "--bert_model_option", f"{args.bert_model_option}",
                "--do_dump",
                "--use_sparse",
                "--filter_threshold", f"{args.filter_threshold:.2f}",
                "--dump_dir", f"{args.phrase_dump_dir}",
                "--dump_file", f"{start_doc}-{end_doc}.hdf5",
                "--max_seq_length", "512",
                "--load_dir", f"{args.load_dir}",
                "--load_epoch", f"{args.load_epoch}"] + switches

    num_docs = args.end - args.start
    num_docs_per_gpu = int(math.ceil(num_docs / args.num_gpus))
    start_docs = list(range(args.start, args.end, num_docs_per_gpu))
    end_docs = start_docs[1:] + [args.end]

    print(start_docs)
    print(end_docs)

    for start_doc, end_doc in zip(start_docs, end_docs):
        cmd = get_cmd(start_doc, end_doc)
        print(cmd)
        subprocess.Popen(cmd)
os.makedirs(args.dump_dir) 71 | 72 | args.phrase_data_dir = os.path.join(args.data_dir, args.data_name) 73 | args.phrase_dump_dir = os.path.join(args.dump_dir, 'phrase') 74 | if not os.path.exists(args.phrase_dump_dir): 75 | os.makedirs(args.phrase_dump_dir) 76 | 77 | return args 78 | 79 | 80 | def main(): 81 | args = get_args() 82 | run_dump_phrase(args) 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /mips_phrase.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import logging 6 | from collections import namedtuple, Counter 7 | from time import time 8 | 9 | import h5py 10 | import numpy as np 11 | import faiss 12 | import torch 13 | from tqdm import tqdm 14 | 15 | from scipy.sparse import vstack 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 18 | level=logging.INFO) 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class MIPS(object): 23 | def __init__(self, phrase_dump_dir, tfidf_dump_dir, start_index_path, idx2id_path, max_norm_path, 24 | doc_rank_fn, cuda=False, dump_only=False): 25 | 26 | # If dump dir is a file, use it as a dump. 
27 | if os.path.isdir(phrase_dump_dir): 28 | self.phrase_dump_paths = sorted( 29 | [os.path.join(phrase_dump_dir, name) for name in os.listdir(phrase_dump_dir) if 'hdf5' in name] 30 | ) 31 | dump_names = [os.path.splitext(os.path.basename(path))[0] for path in self.phrase_dump_paths] 32 | self.dump_ranges = [list(map(int, name.split('-'))) for name in dump_names] 33 | else: 34 | self.phrase_dump_paths = [phrase_dump_dir] 35 | self.phrase_dumps = [h5py.File(path, 'r') for path in self.phrase_dump_paths] 36 | 37 | # Load tfidf dump 38 | assert os.path.isdir(tfidf_dump_dir), tfidf_dump_dir 39 | self.tfidf_dump_paths = sorted( 40 | [os.path.join(tfidf_dump_dir, name) for name in os.listdir(tfidf_dump_dir) if 'hdf5' in name] 41 | ) 42 | tfidf_dump_names = [os.path.splitext(os.path.basename(path))[0] for path in self.tfidf_dump_paths] 43 | if '-' in tfidf_dump_names[0]: # Range check 44 | tfidf_dump_ranges = [list(map(int, name.split('_')[0].split('-'))) for name in tfidf_dump_names] 45 | assert tfidf_dump_ranges == self.dump_ranges 46 | self.tfidf_dumps = [h5py.File(path, 'r') for path in self.tfidf_dump_paths] 47 | logger.info(f'using doc ranker functions: {doc_rank_fn["index"]}') 48 | self.doc_rank_fn = doc_rank_fn 49 | if dump_only: 50 | return 51 | 52 | # Read index 53 | logger.info(f'Reading {start_index_path}') 54 | self.start_index = faiss.read_index(start_index_path, faiss.IO_FLAG_ONDISK_SAME_DIR) 55 | self.idx_f = self.load_idx_f(idx2id_path) 56 | with open(max_norm_path, 'r') as fp: 57 | self.max_norm = json.load(fp) 58 | 59 | # Options 60 | self.num_docs_list = [] 61 | self.cuda = cuda 62 | if self.cuda: 63 | assert torch.cuda.is_available(), f"Cuda availability {torch.cuda.is_available()}" 64 | self.device = torch.device('cuda') 65 | else: 66 | self.device = torch.device("cpu") 67 | 68 | def close(self): 69 | for phrase_dump in self.phrase_dumps: 70 | phrase_dump.close() 71 | for tfidf_dump in self.tfidf_dumps: 72 | tfidf_dump.close() 73 | 74 | def 
load_idx_f(self, idx2id_path): 75 | idx_f = {} 76 | types = ['doc', 'word'] 77 | with h5py.File(idx2id_path, 'r', driver='core', backing_store=False) as f: 78 | for key in tqdm(f, desc='loading idx2id'): 79 | idx_f_cur = {} 80 | for type_ in types: 81 | idx_f_cur[type_] = f[key][type_][:] 82 | idx_f[key] = idx_f_cur 83 | return idx_f 84 | 85 | def get_idxs(self, I): 86 | offsets = (I / 1e8).astype(np.int64) * int(1e8) 87 | idxs = I % int(1e8) 88 | doc = np.array( 89 | [[self.idx_f[str(offset)]['doc'][idx] for offset, idx in zip(oo, ii)] for oo, ii in zip(offsets, idxs)]) 90 | word = np.array([[self.idx_f[str(offset)]['word'][idx] for offset, idx in zip(oo, ii)] for oo, ii in 91 | zip(offsets, idxs)]) 92 | return doc, word 93 | 94 | def get_doc_group(self, doc_idx): 95 | if len(self.phrase_dumps) == 1: 96 | return self.phrase_dumps[0][str(doc_idx)] 97 | for dump_range, dump in zip(self.dump_ranges, self.phrase_dumps): 98 | if dump_range[0] * 1000 <= int(doc_idx) < dump_range[1] * 1000: 99 | if str(doc_idx) not in dump: 100 | raise ValueError('%d not found in dump list' % int(doc_idx)) 101 | return dump[str(doc_idx)] 102 | 103 | # Just check last 104 | if str(doc_idx) not in self.phrase_dumps[-1]: 105 | raise ValueError('%d not found in dump list' % int(doc_idx)) 106 | else: 107 | return self.phrase_dumps[-1][str(doc_idx)] 108 | 109 | def get_tfidf_group(self, doc_idx): 110 | if len(self.tfidf_dumps) == 1: 111 | return self.tfidf_dumps[0][str(doc_idx)] 112 | for dump_range, dump in zip(self.dump_ranges, self.tfidf_dumps): 113 | if dump_range[0] * 1000 <= int(doc_idx) < dump_range[1] * 1000: 114 | return dump[str(doc_idx)] 115 | 116 | # Just check last 117 | if str(doc_idx) not in self.tfidf_dumps[-1]: 118 | raise ValueError('%d not found in dump list' % int(doc_idx)) 119 | else: 120 | return self.tfidf_dumps[-1][str(doc_idx)] 121 | 122 | def int8_to_float(self, num, offset, factor): 123 | return num.astype(np.float32) / factor + offset 124 | 125 | def adjust(self, 
each): 126 | last = each['context'].rfind(' [PAR] ', 0, each['start_pos']) 127 | last = 0 if last == -1 else last + len(' [PAR] ') 128 | next = each['context'].find(' [PAR] ', each['end_pos']) 129 | next = len(each['context']) if next == -1 else next 130 | each['context'] = each['context'][last:next] 131 | each['start_pos'] -= last 132 | each['end_pos'] -= last 133 | return each 134 | 135 | def scale_l2_to_ip(self, l2_scores, max_norm=None, query_norm=None): 136 | """ 137 | sqrt(m^2 + q^2 - 2qx) -> m^2 + q^2 - 2qx -> qx - 0.5 (q^2 + m^2) 138 | Note that faiss index returns squared euclidean distance, so no need to square it again. 139 | """ 140 | if max_norm is None: 141 | return -0.5 * l2_scores 142 | assert query_norm is not None 143 | return -0.5 * (l2_scores - query_norm ** 2 - max_norm ** 2) 144 | 145 | def dequant(self, group, input_, attr='dense'): 146 | if 'offset' not in group.attrs: 147 | return input_ 148 | 149 | if attr == 'dense': 150 | return self.int8_to_float(input_, group.attrs['offset'], group.attrs['scale']) 151 | elif attr == 'sparse': 152 | return self.int8_to_float(input_, group.attrs['sparse_offset'], group.attrs['sparse_scale']) 153 | else: 154 | raise NotImplementedError() 155 | 156 | def sparse_bmm(self, q_ids, q_vals, p_ids, p_vals): 157 | """ 158 | Efficient batch inner product after slicing (matrix x matrix) 159 | """ 160 | q_max = max([len(q) for q in q_ids]) 161 | p_max = max([len(p) for p in p_ids]) 162 | factor = len(p_ids)//len(q_ids) 163 | assert q_max == max([len(q) for q in q_vals]) and p_max == max([len(p) for p in p_vals]) 164 | with torch.no_grad(): 165 | q_ids_pad = torch.LongTensor([q_id.tolist() + [0]*(q_max-len(q_id)) for q_id in q_ids]).to(self.device) 166 | q_ids_pad = q_ids_pad.repeat(1, factor).view(len(p_ids), -1) # Repeat for p 167 | q_vals_pad = torch.FloatTensor([q_val.tolist() + [0]*(q_max-len(q_val)) for q_val in q_vals]).to(self.device) 168 | q_vals_pad = q_vals_pad.repeat(1, factor).view(len(p_vals), -1) # 
Repeat for p 169 | p_ids_pad = torch.LongTensor([p_id.tolist() + [0]*(p_max-len(p_id)) for p_id in p_ids]).to(self.device) 170 | p_vals_pad = torch.FloatTensor([p_val.tolist() + [0]*(p_max-len(p_val)) for p_val in p_vals]).to(self.device) 171 | id_map = q_ids_pad.unsqueeze(1) 172 | id_map_ = p_ids_pad.unsqueeze(2) 173 | match = (id_map == id_map_).to(torch.float32) 174 | val_map = q_vals_pad.unsqueeze(1) 175 | val_map_ = p_vals_pad.unsqueeze(2) 176 | sp_scores = ((val_map * val_map_) * match).sum([1, 2]) 177 | return sp_scores.cpu().numpy() 178 | 179 | def search_dense(self, q_texts, query_start, start_top_k, nprobe, sparse_weight=0.05): 180 | batch_size = query_start.shape[0] 181 | self.start_index.nprobe = nprobe 182 | 183 | # Query concatenation for l2 to ip 184 | query_start = np.concatenate([np.zeros([batch_size, 1]).astype(np.float32), query_start], axis=1) 185 | 186 | # Search with faiss 187 | start_scores, I = self.start_index.search(query_start, start_top_k) 188 | query_norm = np.linalg.norm(query_start, ord=2, axis=1) 189 | start_scores = self.scale_l2_to_ip(start_scores, max_norm=self.max_norm, query_norm=np.expand_dims(query_norm, 1)) 190 | 191 | # Get idxs from resulting I 192 | doc_idxs, start_idxs = self.get_idxs(I) 193 | 194 | # For record 195 | num_docs = sum([len(set(doc_idx.flatten().tolist())) for doc_idx in doc_idxs]) / batch_size 196 | self.num_docs_list.append(num_docs) 197 | 198 | # Doc-level sparse score 199 | b_doc_scores = self.doc_rank_fn['index'](q_texts, doc_idxs.tolist()) # Index 200 | for b_idx in range(batch_size): 201 | start_scores[b_idx] += np.array(b_doc_scores[b_idx]) * sparse_weight 202 | 203 | return (doc_idxs, start_idxs), start_scores 204 | 205 | def search_sparse(self, q_texts, query_start, doc_top_k, start_top_k, sparse_weight=0.05): 206 | batch_size = query_start.shape[0] 207 | 208 | # Reduce search space by doc scores 209 | top_doc_idxs, top_doc_scores = self.doc_rank_fn['top_docs'](q_texts, doc_top_k) # Top docs 210 | 
211 | # For each item, add start scores 212 | b_doc_idxs = [] 213 | b_start_idxs = [] 214 | b_scores = [] 215 | max_phrases = 0 216 | for b_idx in range(batch_size): 217 | doc_idxs = [] 218 | start_idxs = [] 219 | scores = [] 220 | for doc_idx, doc_score in zip(top_doc_idxs[b_idx], top_doc_scores[b_idx]): 221 | try: 222 | doc_group = self.get_doc_group(doc_idx) 223 | except ValueError: 224 | continue 225 | start = self.dequant(doc_group, doc_group['start'][:]) 226 | cur_scores = np.sum(query_start[b_idx] * start, 1) 227 | for i, cur_score in enumerate(cur_scores): 228 | doc_idxs.append(doc_idx) 229 | start_idxs.append(i) 230 | scores.append(cur_score + sparse_weight * doc_score) 231 | max_phrases = len(scores) if len(scores) > max_phrases else max_phrases 232 | 233 | b_doc_idxs.append(doc_idxs) 234 | b_start_idxs.append(start_idxs) 235 | b_scores.append(scores) 236 | 237 | # If start_top_k is larger than nonnegative doc_idxs, we need to cut them later 238 | for doc_idxs, start_idxs, scores in zip(b_doc_idxs, b_start_idxs, b_scores): 239 | doc_idxs += [-1] * (max_phrases - len(doc_idxs)) 240 | start_idxs += [-1] * (max_phrases - len(start_idxs)) 241 | scores += [-10**9] * (max_phrases - len(scores)) 242 | 243 | doc_idxs, start_idxs, scores = np.stack(b_doc_idxs), np.stack(b_start_idxs), np.stack(b_scores) 244 | return (doc_idxs, start_idxs), scores 245 | 246 | def batch_par_scores(self, q_texts, q_sparses, doc_idxs, start_idxs, sparse_weight=0.05, mid_top_k=100): 247 | # Reshape for sparse 248 | num_queries = len(q_texts) 249 | doc_idxs = np.reshape(doc_idxs, [-1]) 250 | start_idxs = np.reshape(start_idxs, [-1]) 251 | 252 | default_doc = [doc_idx for doc_idx in doc_idxs if doc_idx >= 0][0] 253 | groups = [self.get_doc_group(doc_idx) if doc_idx >= 0 else self.get_doc_group(default_doc) 254 | for doc_idx in doc_idxs] 255 | 256 | # Calculate paragraph start end location in sparse vector 257 | para_lens = [group['len_per_para'][:] for group in groups] 258 | f2o_start = 
[group['f2o_start'][:] for group in groups] 259 | para_bounds = [[(sum(para_len[:para_idx]), sum(para_len[:para_idx+1])) for 260 | para_idx in range(len(para_len))] for para_len in para_lens] 261 | para_idxs = [] 262 | for para_bound, start_idx, f2o in zip(para_bounds, start_idxs, f2o_start): 263 | para_bound = np.array(para_bound) 264 | curr_idx = ((f2o[start_idx] >= para_bound[:,0]) & (f2o[start_idx] < para_bound[:,1])).nonzero()[0][0] 265 | para_idxs.append(curr_idx) 266 | para_startend = [para_bound[para_idx] for para_bound, para_idx in zip(para_bounds, para_idxs)] 267 | 268 | # 1) TF-IDF based paragraph score 269 | q_spvecs = self.doc_rank_fn['spvec'](q_texts) # Spvec 270 | qtf_ids = [np.array(q) for q in q_spvecs[1]] 271 | qtf_vals = [np.array(q) for q in q_spvecs[0]] 272 | tfidf_groups = [self.get_tfidf_group(doc_idx) if doc_idx >= 0 else self.get_tfidf_group(default_doc) 273 | for doc_idx in doc_idxs] 274 | tfidf_groups = [group[str(para_idx)] for group, para_idx in zip(tfidf_groups, para_idxs)] 275 | ptf_ids = [data['idxs'][:] for data in tfidf_groups] 276 | ptf_vals = [data['vals'][:] for data in tfidf_groups] 277 | tf_scores = self.sparse_bmm(qtf_ids, qtf_vals, ptf_ids, ptf_vals) * sparse_weight 278 | 279 | # 2) Sparse vectors based paragraph score 280 | q_ids, q_unis, q_bis = q_sparses 281 | q_ids = [np.array(q) for q in q_ids] 282 | q_unis = [np.array(q) for q in q_unis] 283 | q_bis = [np.array(q)[:-1] for q in q_bis] 284 | p_ids_tmp = [group['input_ids'][:] for group in groups] 285 | p_unis_tmp = [group['sparse'][:, :] for group in groups] 286 | p_bis_tmp = [group['sparse_bi'][:, :] for group in groups] 287 | p_ids = [sparse_id[p_se[0]:p_se[1]] 288 | for sparse_id, p_se in zip(p_ids_tmp, para_startend)] 289 | p_unis = [self.dequant(groups[0], sparse_val[start_idx,:p_se[1]-p_se[0]], attr='sparse') 290 | for sparse_val, p_se, start_idx in zip(p_unis_tmp, para_startend, start_idxs)] 291 | p_bis = [self.dequant(groups[0], 
def search_start(self, query_start, sparse_query, q_texts=None,
                 nprobe=16, doc_top_k=5, start_top_k=100, mid_top_k=20, top_k=5,
                 search_strategy='dense_first', sparse_weight=0.05, no_para=False):
    """Retrieve and rerank start-position candidates for a batch of queries.

    Args:
        query_start: 2-D array of start query vectors, one row per query.
        sparse_query: sparse query representation, forwarded untouched to
            ``self.batch_par_scores``.
        q_texts: query texts, forwarded to the dense/sparse searchers.
        nprobe: forwarded to ``self.search_dense``.
        doc_top_k: forwarded to ``self.search_sparse``.
        start_top_k: number of candidates requested from the first stage.
        mid_top_k: number of candidates kept after the first rerank.
        top_k: number of candidates kept after the final rerank.
        search_strategy: ``'dense_first'``, ``'sparse_first'`` or ``'hybrid'``.
        sparse_weight: weight forwarded to the searchers and paragraph scorer.
        no_para: when true, skip the paragraph-level sparse rescoring.

    Returns:
        ``(start_scores, doc_idxs, start_idxs)``, three arrays whose rows are
        sorted by descending score and hold at most ``top_k`` columns.

    Raises:
        ValueError: if ``search_strategy`` is not one of the known values.
    """
    assert self.start_index is not None
    query_start = query_start.astype(np.float32)

    # 1) Branch based on the strategy (start_top_k) + doc_score
    if search_strategy == 'dense_first':
        (doc_idxs, start_idxs), start_scores = self.search_dense(
            q_texts, query_start, start_top_k, nprobe, sparse_weight
        )
    elif search_strategy == 'sparse_first':
        (doc_idxs, start_idxs), start_scores = self.search_sparse(
            q_texts, query_start, doc_top_k, start_top_k, sparse_weight
        )
    elif search_strategy == 'hybrid':
        (doc_idxs, start_idxs), start_scores = self.search_dense(
            q_texts, query_start, start_top_k, nprobe, sparse_weight
        )
        (doc_idxs_, start_idxs_), start_scores_ = self.search_sparse(
            q_texts, query_start, doc_top_k, start_top_k, sparse_weight
        )
        # There could be a duplicate but it's difficult to remove
        doc_idxs = np.concatenate([doc_idxs, doc_idxs_], -1)
        start_idxs = np.concatenate([start_idxs, start_idxs_], -1)
        start_scores = np.concatenate([start_scores, start_scores_], -1)
    else:
        raise ValueError(search_strategy)

    doc_idxs = np.asarray(doc_idxs)
    start_idxs = np.asarray(start_idxs)
    # float64 matches what the former list round trip produced.
    start_scores = np.asarray(start_scores, dtype=np.float64)

    # 2) Rerank and reduce (mid_top_k): best score first in every row.
    keep = np.argsort(start_scores, axis=1)[:, -mid_top_k:][:, ::-1]
    doc_idxs = np.take_along_axis(doc_idxs, keep, axis=1)
    start_idxs = np.take_along_axis(start_idxs, keep, axis=1)
    start_scores = np.take_along_axis(start_scores, keep, axis=1)

    # Para-level sparse score
    if not no_para:
        par_scores = self.batch_par_scores(
            q_texts, sparse_query, doc_idxs, start_idxs, sparse_weight, mid_top_k
        )
        start_scores = start_scores + par_scores

    # 3) Rerank and reduce (top_k)
    keep = np.argsort(start_scores, axis=1)[:, -top_k:][:, ::-1]
    doc_idxs = np.take_along_axis(doc_idxs, keep, axis=1)
    start_idxs = np.take_along_axis(start_idxs, keep, axis=1)
    start_scores = np.take_along_axis(start_scores, keep, axis=1)

    return start_scores, doc_idxs, start_idxs


def search_end(self, query, doc_idxs, start_idxs, start_scores=None, top_k=5, max_answer_length=20):
    """Pick the best end position for every start candidate and build answers.

    Args:
        query: 2-D array of full query vectors laid out as
            ``[start | end | span_logit]`` (the last column is the span logit).
        doc_idxs: candidate document indices, ``top_k`` per query; negative
            values mark padding.
        start_idxs: candidate start indices aligned with ``doc_idxs``.
        start_scores: scores aligned with ``doc_idxs``; treated as zeros when
            omitted.
        top_k: number of candidates per query in ``doc_idxs``.
        max_answer_length: number of end candidates read from ``start2end``.

    Returns:
        One list per query holding answer dicts sorted by descending score;
        candidates whose score is at or below ``-1e5`` are dropped.
    """
    # Reshape for end: repeat every query once per candidate so that row r
    # of `query` lines up with flattened candidate r.
    num_queries = query.shape[0]
    query = np.repeat(query, top_k, axis=0)
    q_idxs = np.repeat(np.arange(num_queries), top_k)
    doc_idxs = np.reshape(doc_idxs, [-1])
    start_idxs = np.reshape(start_idxs, [-1])
    if start_scores is None:
        start_scores = np.zeros(doc_idxs.shape[0], dtype=np.float32)
    else:
        start_scores = np.reshape(start_scores, [-1])

    # Padding entries borrow the first valid document; with no valid
    # document at all there is nothing to answer.
    valid_docs = [doc_idx for doc_idx in doc_idxs if doc_idx >= 0]
    if not valid_docs:
        return [[] for _ in range(num_queries)]
    default_doc = valid_docs[0]

    # Get query_end and groups
    bs = int((query.shape[1] - 1) / 2)  # Boundary of start
    query_end, query_span_logit = query[:, bs:2 * bs], query[:, -1:]
    groups = [self.get_doc_group(doc_idx) if doc_idx >= 0 else self.get_doc_group(default_doc)
              for doc_idx in doc_idxs]
    ends = [group['end'][:] for group in groups]
    spans = [group['span_logits'][:] for group in groups]
    default_end = np.zeros(bs).astype(np.float32)

    # Calculate end
    end_idxs = [group['start2end'][start_idx, :max_answer_length]
                for group, start_idx in zip(groups, start_idxs)]  # [Q, L]
    end_mask = -1e9 * (np.array(end_idxs) < 0)  # [Q, L]

    end = np.stack([[each_end[each_end_idx, :] if each_end.size > 0 else default_end
                     for each_end_idx in each_end_idxs]
                    for each_end, each_end_idxs in zip(ends, end_idxs)], 0)  # [Q, L, d]
    end = self.dequant(groups[0], end)
    span = np.stack([[each_span[start_idx, i] for i in range(len(each_end_idxs))]
                     for each_span, start_idx, each_end_idxs in zip(spans, start_idxs, end_idxs)], 0)  # [Q, L]

    with torch.no_grad():
        end = torch.FloatTensor(end).to(self.device)
        query_end = torch.FloatTensor(query_end).to(self.device)
        end_scores = (query_end.unsqueeze(1) * end).sum(2).cpu().numpy()
    span_scores = query_span_logit * span  # [Q, L]
    scores = np.expand_dims(start_scores, 1) + end_scores + span_scores + end_mask  # [Q, L]
    pred_end_idxs = np.stack([each[idx] for each, idx in zip(end_idxs, np.argmax(scores, 1))], 0)  # [Q]
    max_scores = np.max(scores, 1)

    # Get answers
    out = [{'context': group.attrs['context'], 'title': group.attrs['title'], 'doc_idx': doc_idx,
            'start_pos': group['word2char_start'][group['f2o_start'][start_idx]].item(),
            'end_pos': (group['word2char_end'][group['f2o_end'][end_idx]].item()
                        if len(group['word2char_end']) > 0
                        else group['word2char_start'][group['f2o_start'][start_idx]].item() + 1),
            'start_idx': start_idx, 'end_idx': end_idx, 'score': score}
           for doc_idx, group, start_idx, end_idx, score in zip(doc_idxs.tolist(), groups, start_idxs.tolist(),
                                                                pred_end_idxs.tolist(), max_scores.tolist())]
    for each in out:
        each['answer'] = each['context'][each['start_pos']:each['end_pos']]
    out = [self.adjust(each) for each in out]

    # Sort output
    new_out = [[] for _ in range(num_queries)]
    for idx, each_out in zip(q_idxs, out):
        new_out[idx].append(each_out)
    for i, per_query in enumerate(new_out):
        ranked = sorted(per_query, key=lambda each_out: -each_out['score'])
        # In case of no output but masks
        new_out[i] = [each_out for each_out in ranked if each_out['score'] > -1e5]
    return new_out


def filter_results(self, results):
    """Drop results whose context has more than three '?' or five '!'."""
    kept = []
    for result in results:
        counts = Counter(result['context'])
        if counts['?'] > 3 or counts['!'] > 5:
            continue
        kept.append(result)
    return kept


def search(self, query, sparse_query, q_texts=None,
           nprobe=256, doc_top_k=5, start_top_k=1000, mid_top_k=100, top_k=10,
           search_strategy='dense_first', filter_=False, aggregate=False, return_idxs=False,
           max_answer_length=20, sparse_weight=0.05, no_para=False):
    """Run the full start -> end phrase search for a batch of queries.

    Args:
        query: 2-D array of full query vectors (``[start | end | span_logit]``).
        sparse_query: sparse query representation, forwarded to ``search_start``.
        q_texts: query texts, forwarded to ``search_start``.
        nprobe, doc_top_k, start_top_k, mid_top_k, top_k, search_strategy,
            sparse_weight, no_para: forwarded to ``search_start``.
        filter_: when true, apply ``filter_results`` to every result list.
        aggregate: accepted for interface compatibility; not used here.
        return_idxs: when true, return ``(doc_idx, start_idx, end_idx, answer)``
            tuples instead of the answer dicts.
        max_answer_length: forwarded to ``search_end``.

    Returns:
        One list of results per query.
    """
    # Search start
    start_dim = int((query.shape[1] - 1) / 2)
    start_scores, doc_idxs, start_idxs = self.search_start(
        query[:, :start_dim],
        sparse_query,
        q_texts=q_texts,
        nprobe=nprobe,
        doc_top_k=doc_top_k,
        start_top_k=start_top_k,
        mid_top_k=mid_top_k,
        top_k=top_k,
        search_strategy=search_strategy,
        sparse_weight=sparse_weight,
        no_para=no_para,
    )

    # search_end repeats each query top_k times, so top_k has to match the
    # number of candidates actually retrieved *before* it is called.
    if doc_idxs.shape[1] != top_k:
        logger.info(f"Warning.. {doc_idxs.shape[1]} only retrieved")
        top_k = doc_idxs.shape[1]

    # Search end
    outs = self.search_end(
        query, doc_idxs, start_idxs, start_scores=start_scores,
        top_k=top_k, max_answer_length=max_answer_length
    )

    if filter_:
        outs = [self.filter_results(results) for results in outs]
    if return_idxs:
        return [[(out_['doc_idx'], out_['start_idx'], out_['end_idx'], out_['answer']) for out_ in out]
                for out in outs]
    return outs
"""PyTorch optimization for BERT model."""

import math
import torch
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_


def warmup_cosine(x, warmup=0.002):
    """Linear warmup up to ``warmup``, then ``0.5 * (1 + cos(pi * x))``.

    ``x`` is the training progress as a plain Python number, so the cosine
    comes from ``math`` (``torch.cos`` only accepts tensors).
    """
    if x < warmup:
        return x / warmup
    return 0.5 * (1.0 + math.cos(math.pi * x))


def warmup_constant(x, warmup=0.002):
    """Linear warmup up to ``warmup``, then a constant multiplier of 1."""
    if x < warmup:
        return x / warmup
    return 1.0


def warmup_linear(x, warmup=0.002):
    """Linear warmup up to ``warmup``, then linear decay ``1 - x``."""
    if x < warmup:
        return x / warmup
    return 1.0 - x


# Maps the `schedule` option of BERTAdam to its multiplier function.
SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}


class BERTAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix (and no bias correction).

    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate. Default: -1
        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay_rate: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """

    def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
                 max_grad_norm=1.0):
        if not lr >= 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
                        max_grad_norm=max_grad_norm)
        super(BERTAdam, self).__init__(params, defaults)

    @staticmethod
    def _scheduled_lr(group, step):
        """Return the learning rate of ``group`` after ``step`` updates."""
        if group['t_total'] == -1:
            return group['lr']
        schedule_fct = SCHEDULES[group['schedule']]
        return group['lr'] * schedule_fct(step / group['t_total'], group['warmup'])

    def get_lr(self):
        """Return the scheduled learning rate of every parameter.

        Returns ``[0]`` as soon as a parameter without optimizer state is
        met, i.e. before the first call to ``step``.
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr.append(self._scheduled_lr(group, state['step']))
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['b1'], group['b2']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']

                # Add grad clipping (applied to each parameter on its own)
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time.
                # Keyword alpha/value replaces the deprecated positional-scalar form.
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay_rate'] > 0.0:
                    update += group['weight_decay_rate'] * p.data

                lr_scheduled = self._scheduled_lr(group, state['step'])
                p.data.add_(-(lr_scheduled * update))

                state['step'] += 1

                # NOTE: unlike standard Adam, no bias correction is applied.

        return loss
%H:%M:%S', 35 | level=logging.INFO) 36 | logger = logging.getLogger(__name__) 37 | 38 | # For debugging 39 | quant_stat = {} 40 | b_quant_stat = {} 41 | ranker = None 42 | 43 | 44 | def tqdm(*args, mininterval=5.0, **kwargs): 45 | return tqdm_(*args, mininterval=mininterval, **kwargs) 46 | 47 | 48 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 49 | orig_answer_text): 50 | """Returns tokenized answer spans that better match the annotated answer.""" 51 | 52 | # The SQuAD annotations are character based. We first project them to 53 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 54 | # often find a "better match". For example: 55 | # 56 | # Question: What year was John Smith born? 57 | # Context: The leader was John Smith (1895-1943). 58 | # Answer: 1895 59 | # 60 | # The original whitespace-tokenized answer will be "(1895-1943).". However 61 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 62 | # the exact answer, 1895. 63 | # 64 | # However, this is not always possible. Consider the following: 65 | # 66 | # Question: What country is the top exporter of electornics? 67 | # Context: The Japanese electronics industry is the lagest in the world. 68 | # Answer: Japan 69 | # 70 | # In this case, the annotator chose "Japan" as a character sub-span of 71 | # the word "Japanese". Since our WordPiece tokenizer does not split 72 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 73 | # in SQuAD, but does happen. 
74 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 75 | 76 | for new_start in range(input_start, input_end + 1): 77 | for new_end in range(input_end, new_start - 1, -1): 78 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 79 | if text_span == tok_answer_text: 80 | return (new_start, new_end) 81 | 82 | return (input_start, input_end) 83 | 84 | 85 | def _check_is_max_context(doc_spans, cur_span_index, position): 86 | """Check if this is the 'max context' doc span for the token.""" 87 | 88 | # Because of the sliding window approach taken to scoring documents, a single 89 | # token can appear in multiple documents. E.g. 90 | # Doc: the man went to the store and bought a gallon of milk 91 | # Span A: the man went to the 92 | # Span B: to the store and bought 93 | # Span C: and bought a gallon of 94 | # ... 95 | # 96 | # Now the word 'bought' will have two scores from spans B and C. We only 97 | # want to consider the score with "maximum context", which we define as 98 | # the *minimum* of its left and right context (the *sum* of left and 99 | # right context will always be the same, of course). 100 | # 101 | # In the example the maximum context for 'bought' would be span C since 102 | # it has 1 left context and 3 right context, while span B has 4 left context 103 | # and 0 right context. 
104 | best_score = None 105 | best_span_index = None 106 | for (span_index, doc_span) in enumerate(doc_spans): 107 | end = doc_span.start + doc_span.length - 1 108 | if position < doc_span.start: 109 | continue 110 | if position > end: 111 | continue 112 | num_left_context = position - doc_span.start 113 | num_right_context = end - position 114 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 115 | if best_score is None or score > best_score: 116 | best_score = score 117 | best_span_index = span_index 118 | 119 | return cur_span_index == best_span_index 120 | 121 | 122 | def get_metadata(id2example, features, results, max_answer_length, do_lower_case, verbose_logging): 123 | start = np.concatenate([result.start[1:len(feature.tokens) - 1] for feature, result in zip(features, results)], 124 | axis=0) 125 | end = np.concatenate([result.end[1:len(feature.tokens) - 1] for feature, result in zip(features, results)], axis=0) 126 | 127 | input_ids = None 128 | sparse_map = None 129 | sparse_bi_map = None 130 | sparse_tri_map = None 131 | len_per_para = [] 132 | if results[0].start_sp is not None: 133 | input_ids = np.concatenate([f.input_ids[1:len(f.tokens) - 1] for f in features], axis=0) 134 | sparse_features = None # uni 135 | sparse_bi_features = None 136 | sparse_tri_features = None 137 | if '1' in results[0].start_sp: 138 | sparse_features = [result.start_sp['1'][1:len(feature.tokens)-1, 1:len(feature.tokens)-1] 139 | for feature, result in zip(features, results)] 140 | if '2' in results[0].start_sp: 141 | sparse_bi_features = [result.start_sp['2'][1:len(feature.tokens)-1, 1:len(feature.tokens)-1] 142 | for feature, result in zip(features, results)] 143 | if '3' in results[0].start_sp: 144 | sparse_tri_features = [result.start_sp['3'][1:len(feature.tokens)-1, 1:len(feature.tokens)-1] 145 | for feature, result in zip(features, results)] 146 | 147 | map_size = max([k.shape[0] for k in sparse_features]) 148 | sparse_map = 
np.zeros((input_ids.shape[0], map_size), dtype=np.float32) 149 | if sparse_bi_features is not None: 150 | sparse_bi_map = np.zeros((input_ids.shape[0], map_size), dtype=np.float32) 151 | if sparse_tri_features is not None: 152 | sparse_tri_map = np.zeros((input_ids.shape[0], map_size), dtype=np.float32) 153 | 154 | curr_size = 0 155 | for sidx, sparse_feature in enumerate(sparse_features): 156 | sparse_map[curr_size:curr_size + sparse_feature.shape[0],:sparse_feature.shape[1]] += sparse_feature 157 | if sparse_bi_features is not None: 158 | assert sparse_bi_features[sidx].shape == sparse_feature.shape 159 | sparse_bi_map[curr_size:curr_size + sparse_bi_features[sidx].shape[0],:sparse_bi_features[sidx].shape[1]] += \ 160 | sparse_bi_features[sidx] 161 | if sparse_tri_features is not None: 162 | assert sparse_tri_features[sidx].shape == sparse_feature.shape 163 | sparse_tri_map[curr_size:curr_size + sparse_tri_features[sidx].shape[0],:sparse_tri_features[sidx].shape[1]] += \ 164 | sparse_tri_features[sidx] 165 | curr_size += sparse_feature.shape[0] 166 | len_per_para.append(sparse_feature.shape[0]) 167 | 168 | assert input_ids.shape[0] == start.shape[0] and curr_size == sparse_map.shape[0] 169 | 170 | fs = np.concatenate([result.filter_start_logits[1:len(feature.tokens) - 1] 171 | for feature, result in zip(features, results)], 172 | axis=0) 173 | fe = np.concatenate([result.filter_end_logits[1:len(feature.tokens) - 1] 174 | for feature, result in zip(features, results)], 175 | axis=0) 176 | 177 | span_logits = np.zeros([np.shape(start)[0], max_answer_length], dtype=start.dtype) 178 | start2end = -1 * np.ones([np.shape(start)[0], max_answer_length], dtype=np.int32) 179 | idx = 0 180 | for feature, result in zip(features, results): 181 | for i in range(1, len(feature.tokens) - 1): 182 | for j in range(i, min(i + max_answer_length, len(feature.tokens) - 1)): 183 | span_logits[idx, j - i] = result.span_logits[i, j] 184 | start2end[idx, j - i] = idx + j - i 185 | idx += 
def filter_metadata(metadata, threshold):
    """Keep only start/end vectors whose filter logit exceeds ``threshold``.

    ``metadata`` is updated in place and returned. ``f2o_start`` / ``f2o_end``
    record, for every kept row, its index before filtering, and ``start2end``
    is rewritten to point at filtered end rows (``-1`` where the end position
    was removed or was already ``-1``).
    """
    start_idxs, = np.where(metadata['filter_start'] > threshold)
    end_idxs, = np.where(metadata['filter_end'] > threshold)

    metadata['start'] = metadata['start'][start_idxs]
    metadata['end'] = metadata['end'][end_idxs]
    # Every sparse map is optional (None when the model emits no sparse output).
    for key in ('sparse', 'sparse_bi', 'sparse_tri'):
        if metadata[key] is not None:
            metadata[key] = metadata[key][start_idxs]
    metadata['f2o_start'] = start_idxs
    metadata['f2o_end'] = end_idxs
    metadata['span_logits'] = metadata['span_logits'][start_idxs]

    # Remap original end indices to filtered ones through a lookup table.
    start2end = metadata['start2end'][start_idxs]
    num_ends = np.asarray(metadata['filter_end']).shape[0]
    long2short = np.full(num_ends, -1, dtype=start2end.dtype)
    long2short[end_idxs] = np.arange(len(end_idxs), dtype=start2end.dtype)
    in_range = (start2end >= 0) & (start2end < num_ends)
    remapped = np.full_like(start2end, -1)
    remapped[in_range] = long2short[start2end[in_range]]
    metadata['start2end'] = remapped

    return metadata


def compress_metadata(metadata, dense_offset, dense_scale, sparse_offset, sparse_scale):
    """Quantize the dense and sparse arrays of ``metadata`` with ``float_to_int8``.

    Missing keys and ``None`` sparse maps are left alone; ``metadata`` is
    updated in place and returned.
    """
    for key in ['start', 'end']:
        if key in metadata:
            metadata[key] = float_to_int8(metadata[key], dense_offset, dense_scale)
    for key in ['sparse', 'sparse_bi', 'sparse_tri']:
        if key in metadata and metadata[key] is not None:
            metadata[key] = float_to_int8(metadata[key], sparse_offset, sparse_scale)
    return metadata


def pool_func(item):
    """Build and filter one document's metadata.

    ``item`` holds the ``get_metadata`` arguments followed by the filter
    threshold as its last element.
    """
    metadata_ = get_metadata(*item[:-1])
    metadata_ = filter_metadata(metadata_, item[-1])
    return metadata_


def write_hdf5(all_examples, all_features, all_results,
               max_answer_length, do_lower_case, hdf5_path, filter_threshold, verbose_logging,
               dense_offset=None, dense_scale=None, sparse_offset=None, sparse_scale=None, use_sparse=False):
    """Group results per document and dump their phrase vectors to HDF5.

    A producer thread turns every document's features/results into metadata
    (``pool_func``) and a consumer thread writes one HDF5 group per document,
    named after ``metadata['did']``; an existing group with that name is
    replaced. When ``dense_offset`` is given the vectors are quantized first
    and the offsets/scales are stored as group attributes. ``use_sparse`` is
    accepted but not used by this function.
    """
    assert len(all_examples) > 0

    id2feature = {feature.unique_id: feature for feature in all_features}
    id2example = {id_: all_examples[id2feature[id_].example_index] for id_ in id2feature}

    def add(inqueue_, outqueue_):
        # Producer: consume document batches until the None sentinel.
        for item in iter(inqueue_.get, None):
            args = list(item[:3]) + [max_answer_length, do_lower_case, verbose_logging, filter_threshold]
            outqueue_.put(pool_func(args))
        outqueue_.put(None)

    def write(outqueue_):
        # Consumer: 'a' keeps existing groups and creates the file if needed;
        # h5py 3 would otherwise open the file read-only.
        with h5py.File(hdf5_path, 'a') as f:
            while True:
                metadata = outqueue_.get()
                if metadata is None:
                    break
                did = str(metadata['did'])
                if did in f:
                    logger.info('%s exists; replacing' % did)
                    del f[did]
                dg = f.create_group(did)

                dg.attrs['context'] = metadata['context']
                dg.attrs['title'] = metadata['title']
                if dense_offset is not None:
                    metadata = compress_metadata(metadata, dense_offset, dense_scale, sparse_offset, sparse_scale)
                    dg.attrs['offset'] = dense_offset
                    dg.attrs['scale'] = dense_scale
                    dg.attrs['sparse_offset'] = sparse_offset
                    dg.attrs['sparse_scale'] = sparse_scale
                dg.create_dataset('start', data=metadata['start'])
                dg.create_dataset('end', data=metadata['end'])
                # Sparse maps and input ids only exist for sparse-enabled models.
                for key in ('sparse', 'sparse_bi', 'sparse_tri', 'input_ids'):
                    if metadata[key] is not None:
                        dg.create_dataset(key, data=metadata[key])
                for key in ('len_per_para', 'span_logits', 'start2end',
                            'word2char_start', 'word2char_end', 'f2o_start', 'f2o_end'):
                    dg.create_dataset(key, data=metadata[key])

    features = []
    results = []
    inqueue = Queue(maxsize=500)
    outqueue = Queue(maxsize=500)
    write_p = Thread(target=write, args=(outqueue,))
    p = Thread(target=add, args=(inqueue, outqueue))
    write_p.start()
    p.start()

    start_time = time()
    for count, result in enumerate(tqdm(all_results, total=len(all_features))):
        example = id2example[result.unique_id]
        feature = id2feature[result.unique_id]
        # A first paragraph / first doc span marks the start of a new document.
        starts_new_doc = len(features) > 0 and example.par_idx == 0 and feature.doc_span_index == 0

        if starts_new_doc:
            logger.info('inqueue size: %d, outqueue size: %d' % (inqueue.qsize(), outqueue.qsize()))
            inqueue.put((id2example, features, results))
            features = [feature]
            results = [result]
        else:
            features.append(feature)
            results.append(result)
        if count % 500 == 0:
            logger.info('%d/%d at %.1f' % (count + 1, len(all_features), time() - start_time))
    # Flush the last document, then tell the producer to stop.
    inqueue.put((id2example, features, results))
    inqueue.put(None)
    p.join()
    write_p.join()

    # Debug output of the module-level quantization statistics, sorted by key.
    for k, v in sorted(b_quant_stat.items()):
        print(k, v)
    for k, v in sorted(quant_stat.items()):
        print(k, v)
not in out_json['uni']: 381 | out_json['uni'][v1] = {'score': v2.item(), 'vocab': str(v3)} 382 | else: 383 | out_json['uni'][v1]['score'] += v2.item() 384 | if v2 > 1.0: 385 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'red', end=' ') 386 | else: 387 | print('{}({:.3f}, {})'.format(v1, v2, vidx), end=' ') 388 | print() 389 | 390 | for vidx, (v1, v2, v3) in enumerate(zip( 391 | zip(features[par_index].tokens[1:-2], features[par_index].tokens[2:-1]), 392 | results[par_index].start_sp['2'][start_index][1:len(features[par_index].tokens)-2], 393 | zip(features[par_index].input_ids[1:-2], features[par_index].input_ids[2:-1]) 394 | )): 395 | v1 = ' '.join(v1) 396 | if vidx == start_index-1: 397 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'green', end=' ') 398 | continue 399 | if v1 not in out_json['bi']: 400 | out_json['bi'][v1] = {'score': v2.item(), 'vocab': ', '.join([str(k) for k in v3])} 401 | else: 402 | out_json['bi'][v1]['score'] += v2.item() 403 | if v2 > 1.0: 404 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'red', end=' ') 405 | else: 406 | print('{}({:.3f}, {})'.format(v1, v2, vidx), end=' ') 407 | print() 408 | print() 409 | for key, val in out_json.items(): 410 | out_json[key] = dict(sorted(out_json[key].items(), key=lambda x: x[1]['score'], reverse=True)) 411 | out_json[key] = dict(filter(lambda x: x[1]['score'] > 0, out_json[key].items())) 412 | return out_json 413 | 414 | for count, result in enumerate(tqdm(all_results, total=len(all_features))): 415 | example = id2example[result.unique_id] 416 | feature = id2feature[result.unique_id] 417 | condition = len(features) > 0 and example.par_idx == 0 and feature.doc_span_index == 0 418 | 419 | if condition: 420 | outs.append({'text': prev_example.paragraph_text, 'sparc': write(features, results)}) 421 | features = [feature] 422 | results = [result] 423 | else: 424 | features.append(feature) 425 | results.append(result) 426 | prev_example = example 427 | outs.append({'text': prev_example.paragraph_text, 
'sparc': write(features, results)}) 428 | 429 | with open(embed_path, 'w') as f: 430 | json.dump({'out': outs}, f, indent=4) 431 | 432 | 433 | def write_predictions(all_examples, all_features, all_results, 434 | max_answer_length, do_lower_case, output_prediction_file, 435 | output_score_file, verbose_logging, threshold): 436 | 437 | id2feature = {feature.unique_id: feature for feature in all_features} 438 | id2example = {id_: all_examples[id2feature[id_].example_index] for id_ in id2feature} 439 | 440 | token_count = 0 441 | vec_count = 0 442 | predictions = {} 443 | scores = {} 444 | loss = [] 445 | 446 | for result in tqdm(all_results, total=len(all_features), desc='[Evaluation]'): 447 | loss += [result.loss] 448 | feature = id2feature[result.unique_id] 449 | example = id2example[result.unique_id] 450 | id_ = example.qas_id 451 | 452 | # Initial setting 453 | token_count += len(feature.tokens) 454 | 455 | for start_index in range(len(feature.tokens)): 456 | for end_index in range(start_index, min(len(feature.tokens), start_index + max_answer_length - 1)): 457 | if start_index not in feature.token_to_word_map: 458 | continue 459 | if end_index not in feature.token_to_word_map: 460 | continue 461 | if not feature.token_is_max_context.get(start_index, False): 462 | continue 463 | filter_start_logit = result.filter_start_logits[start_index] 464 | filter_end_logit = result.filter_end_logits[end_index] 465 | 466 | # Filter based on threshold (default: -2) 467 | if filter_start_logit < threshold or filter_end_logit < threshold: 468 | # orig_text, start_pos, end_pos = get_final_text_(example, feature, start_index, end_index, 469 | # do_lower_case, verbose_logging) 470 | # print('Filter: %s (%.2f, %.2f)'% (orig_text[start_pos:end_pos], filter_start_logit, filter_end_logit)) 471 | continue 472 | else: 473 | # orig_text, start_pos, end_pos = get_final_text_(example, feature, start_index, end_index, 474 | # do_lower_case, verbose_logging) 475 | # print('Saved: %s (%.2f, 
%.2f)'% (orig_text[start_pos:end_pos], filter_start_logit, filter_end_logit)) 476 | pass 477 | 478 | vec_count += 1 479 | score = result.all_logits[start_index, end_index] 480 | 481 | if id_ not in scores or score > scores[id_]: 482 | orig_text, start_pos, end_pos = get_final_text_(example, feature, start_index, end_index, 483 | do_lower_case, verbose_logging) 484 | # print('Saved: %s (%.2f, %.2f)'% (orig_text[start_pos:end_pos], filter_start_logit, filter_end_logit)) 485 | phrase = orig_text[start_pos:end_pos] 486 | predictions[id_] = phrase 487 | scores[id_] = score.item() 488 | 489 | if id_ not in predictions: 490 | assert id_ not in scores 491 | logger.info('for %s, no answer found'% id_) 492 | 493 | logger.info('num vecs=%d, num_words=%d, nvpw=%.4f' % (vec_count, token_count, vec_count / token_count)) 494 | 495 | with open(output_prediction_file, 'w') as fp: 496 | json.dump(predictions, fp) 497 | 498 | with open(output_score_file, 'w') as fp: 499 | json.dump({k: -v for (k, v) in scores.items()}, fp) 500 | 501 | return sum(loss) / len(loss) 502 | 503 | 504 | def get_question_results(question_examples, query_eval_features, question_dataloader, device, model): 505 | id2feature = {feature.unique_id: feature for feature in query_eval_features} 506 | id2example = {id_: question_examples[id2feature[id_].example_index] for id_ in id2feature} 507 | for (input_ids_, input_mask_, example_indices) in question_dataloader: 508 | input_ids_ = input_ids_.to(device) 509 | input_mask_ = input_mask_.to(device) 510 | with torch.no_grad(): 511 | batch_start, batch_end, batch_sps, batch_eps = model(query_ids=input_ids_, 512 | query_mask=input_mask_) 513 | for i, example_index in enumerate(example_indices): 514 | start = batch_start[i].detach().cpu().numpy().astype(np.float16) 515 | end = batch_end[i].detach().cpu().numpy().astype(np.float16) 516 | sparse = None 517 | if len(batch_sps) > 0: 518 | sparse = {ng: bb_ssp[i].detach().cpu().numpy().astype(np.float16) for ng, bb_ssp in 
batch_sps.items()} 519 | query_eval_feature = query_eval_features[example_index.item()] 520 | unique_id = int(query_eval_feature.unique_id) 521 | qas_id = id2example[unique_id].qas_id 522 | yield QuestionResult(qas_id=qas_id, 523 | start=start, 524 | end=end, 525 | sparse=sparse, 526 | input_ids=query_eval_feature.input_ids[1:len(query_eval_feature.tokens_)-1]) 527 | 528 | 529 | def write_question_results(question_results, question_features, path): 530 | with h5py.File(path, 'w') as f: 531 | for question_result, question_feature in tqdm(zip(question_results, question_features)): 532 | dummy_ones = np.ones((question_result.start.shape[0], 1)) 533 | data = np.concatenate([question_result.start, question_result.end, dummy_ones], -1) 534 | f.create_dataset(question_result.qas_id, data=data) 535 | 536 | 537 | def convert_question_features_to_dataloader(query_eval_features, fp16, local_rank, predict_batch_size): 538 | all_input_ids_ = torch.tensor([f.input_ids for f in query_eval_features], dtype=torch.long) 539 | all_input_mask_ = torch.tensor([f.input_mask for f in query_eval_features], dtype=torch.long) 540 | all_example_index_ = torch.arange(all_input_ids_.size(0), dtype=torch.long) 541 | if fp16: 542 | all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_)) 543 | 544 | question_data = TensorDataset(all_input_ids_, all_input_mask_, all_example_index_) 545 | 546 | if local_rank == -1: 547 | question_sampler = SequentialSampler(question_data) 548 | else: 549 | question_sampler = DistributedSampler(question_data) 550 | question_dataloader = DataLoader(question_data, sampler=question_sampler, batch_size=predict_batch_size) 551 | return question_dataloader 552 | 553 | 554 | def get_final_text_(example, feature, start_index, end_index, do_lower_case, verbose_logging): 555 | tok_tokens = feature.tokens[start_index:(end_index + 1)] 556 | orig_doc_start = feature.token_to_word_map[start_index] 557 | orig_doc_end = 
feature.token_to_word_map[end_index] 558 | orig_words = example.doc_words[orig_doc_start:(orig_doc_end + 1)] 559 | tok_text = " ".join(tok_tokens) 560 | 561 | # De-tokenize WordPieces that have been split off. 562 | tok_text = tok_text.replace(" ##", "") 563 | tok_text = tok_text.replace("##", "") 564 | 565 | # Clean whitespace 566 | tok_text = tok_text.strip() 567 | tok_text = " ".join(tok_text.split()) 568 | orig_text = " ".join(orig_words) 569 | full_text = " ".join(example.doc_words) 570 | 571 | start_pos, end_pos = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) # TODO: need to check 572 | offset = sum(len(word) + 1 for word in example.doc_words[:orig_doc_start]) 573 | 574 | return full_text, offset + start_pos, offset + end_pos 575 | 576 | 577 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 578 | """Project the tokenized prediction back to the original text.""" 579 | 580 | # When we created the data, we kept track of the alignment between original 581 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 582 | # now `orig_text` contains the span of our original text corresponding to the 583 | # span that we predicted. 584 | # 585 | # However, `orig_text` may contain extra characters that we don't want in 586 | # our prediction. 587 | # 588 | # For example, let's say: 589 | # pred_text = steve smith 590 | # orig_text = Steve Smith's 591 | # 592 | # We don't want to return `orig_text` because it contains the extra "'s". 593 | # 594 | # We don't want to return `pred_text` because it's already been normalized 595 | # (the SQuAD eval script also does punctuation stripping/lower casing but 596 | # our tokenizer does additional normalization like stripping accent 597 | # characters). 598 | # 599 | # What we really want to return is "Steve Smith". 
600 | # 601 | # Therefore, we have to apply a semi-complicated alignment heuristic between 602 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 603 | # can fail in certain cases, in which case we fall back to the span covering all of `orig_text`. 604 | default_out = 0, len(orig_text) 605 | 606 | def _strip_spaces(text): 607 | ns_chars = [] 608 | ns_to_s_map = collections.OrderedDict() 609 | for (i, c) in enumerate(text): 610 | if c == " ": 611 | continue 612 | ns_to_s_map[len(ns_chars)] = i 613 | ns_chars.append(c) 614 | ns_text = "".join(ns_chars) 615 | return (ns_text, ns_to_s_map) 616 | 617 | # We first tokenize `orig_text`, strip whitespace from the result 618 | # and `pred_text`, and check if they are the same length. If they are 619 | # NOT the same length, the heuristic has failed. If they are the same 620 | # length, we assume the characters are one-to-one aligned. 621 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 622 | 623 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 624 | 625 | start_position = tok_text.find(pred_text) 626 | if start_position == -1: 627 | if verbose_logging: 628 | logger.info( 629 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 630 | return default_out 631 | end_position = start_position + len(pred_text) - 1 632 | 633 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 634 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 635 | 636 | if len(orig_ns_text) != len(tok_ns_text): 637 | if verbose_logging: 638 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'", 639 | orig_ns_text, tok_ns_text) 640 | return default_out 641 | 642 | # We then project the characters in `pred_text` back to `orig_text` using 643 | # the character-to-character alignment. 
644 | tok_s_to_ns_map = {} 645 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 646 | tok_s_to_ns_map[tok_index] = i 647 | 648 | orig_start_position = None 649 | if start_position in tok_s_to_ns_map: 650 | ns_start_position = tok_s_to_ns_map[start_position] 651 | if ns_start_position in orig_ns_to_s_map: 652 | orig_start_position = orig_ns_to_s_map[ns_start_position] 653 | 654 | if orig_start_position is None: 655 | if verbose_logging: 656 | logger.info("Couldn't map start position") 657 | return default_out 658 | 659 | orig_end_position = None 660 | if end_position in tok_s_to_ns_map: 661 | ns_end_position = tok_s_to_ns_map[end_position] 662 | if ns_end_position in orig_ns_to_s_map: 663 | orig_end_position = orig_ns_to_s_map[ns_end_position] 664 | 665 | if orig_end_position is None: 666 | if verbose_logging: 667 | logger.info("Couldn't map end position") 668 | return default_out 669 | 670 | # output_text = orig_text[orig_start_position:(orig_end_position + 1)] 671 | return orig_start_position, orig_end_position + 1 672 | 673 | 674 | def float_to_int8(num, offset, factor, keep_zeros=False): 675 | out = (num - offset) * factor 676 | out = out.clip(-128, 127) 677 | if keep_zeros: 678 | out = out * (num != 0.0).astype(np.int8) 679 | out = np.round(out).astype(np.int8) 680 | return out 681 | 682 | 683 | def int8_to_float(num, offset, factor, keep_zeros=False): 684 | if not keep_zeros: 685 | return num.astype(np.float32) / factor + offset 686 | else: 687 | return (num.astype(np.float32) / factor + offset) * (num != 0.0).astype(np.float32) 688 | 689 | -------------------------------------------------------------------------------- /run_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import faiss 7 | import h5py 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from post import int8_to_float 12 | 13 | 14 | def get_args(): 15 | parser = 
argparse.ArgumentParser() 16 | parser.add_argument('dump_dir') 17 | parser.add_argument('stage') 18 | 19 | parser.add_argument('--dump_paths', default=None, 20 | help='Relative to `dump_dir/phrase`. ' 21 | 'If specified, creates subindex dir and save there with same name') 22 | parser.add_argument('--subindex_name', default='index', help='used only if dump_path is specified.') 23 | parser.add_argument('--offset', default=0, type=int) 24 | 25 | # relative paths in dump_dir/index_name 26 | parser.add_argument('--quantizer_path', default='quantizer.faiss') 27 | parser.add_argument('--max_norm_path', default='max_norm.json') 28 | parser.add_argument('--trained_index_path', default='trained.faiss') 29 | parser.add_argument('--index_path', default='index.faiss') 30 | parser.add_argument('--idx2id_path', default='idx2id.hdf5') 31 | parser.add_argument('--inv_path', default='merged.invdata') 32 | 33 | # Adding options 34 | parser.add_argument('--add_all', default=False, action='store_true') 35 | 36 | # coarse, fine, add 37 | parser.add_argument('--num_clusters', type=int, default=16384) 38 | parser.add_argument('--hnsw', default=False, action='store_true') 39 | parser.add_argument('--fine_quant', default='SQ8', 40 | help='SQ8|SQ4|PQ# where # is number of bytes per vector (for SQ it would be 480 Bytes)') 41 | # stable params 42 | parser.add_argument('--max_norm', default=None, type=float) 43 | parser.add_argument('--max_norm_cf', default=1.0, type=float) 44 | parser.add_argument('--norm_th', default=999, type=float) 45 | parser.add_argument('--para', default=False, action='store_true') 46 | parser.add_argument('--doc_sample_ratio', default=0.2, type=float) 47 | parser.add_argument('--vec_sample_ratio', default=0.2, type=float) 48 | 49 | parser.add_argument('--fs', default='local') 50 | parser.add_argument('--cuda', default=False, action='store_true') 51 | parser.add_argument('--num_dummy_zeros', default=0, type=int) 52 | parser.add_argument('--replace', default=False, 
action='store_true') 53 | parser.add_argument('--num_docs_per_add', default=1000, type=int) 54 | 55 | args = parser.parse_args() 56 | 57 | coarse = 'hnsw' if args.hnsw else 'flat' 58 | args.index_name = '%d_%s_%s' % (args.num_clusters, coarse, args.fine_quant) 59 | 60 | if args.fs == 'nfs': 61 | from nsml import NSML_NFS_OUTPUT 62 | args.dump_dir = os.path.join(NSML_NFS_OUTPUT, args.dump_dir) 63 | elif args.fs == 'nsml': 64 | pass 65 | 66 | args.index_dir = os.path.join(args.dump_dir, args.index_name) 67 | 68 | args.quantizer_path = os.path.join(args.index_dir, args.quantizer_path) 69 | args.max_norm_path = os.path.join(args.index_dir, args.max_norm_path) 70 | args.trained_index_path = os.path.join(args.index_dir, args.trained_index_path) 71 | args.inv_path = os.path.join(args.index_dir, args.inv_path) 72 | 73 | args.subindex_dir = os.path.join(args.index_dir, args.subindex_name) 74 | if args.dump_paths is None: 75 | args.index_path = os.path.join(args.index_dir, args.index_path) 76 | args.idx2id_path = os.path.join(args.index_dir, args.idx2id_path) 77 | else: 78 | args.dump_paths = [os.path.join(args.dump_dir, 'phrase', path) for path in args.dump_paths.split(',')] 79 | args.index_path = os.path.join(args.subindex_dir, '%d.faiss' % args.offset) 80 | args.idx2id_path = os.path.join(args.subindex_dir, '%d.hdf5' % args.offset) 81 | 82 | return args 83 | 84 | 85 | def sample_data(dump_paths, para=False, doc_sample_ratio=0.2, vec_sample_ratio=0.2, seed=29, 86 | max_norm=None, max_norm_cf=1.3, num_dummy_zeros=0, norm_th=999): 87 | vecs = [] 88 | random.seed(seed) 89 | np.random.seed(seed) 90 | print('sampling from:') 91 | for dump_path in dump_paths: 92 | print(dump_path) 93 | dumps = [h5py.File(dump_path, 'r') for dump_path in dump_paths] 94 | for i, f in enumerate(tqdm(dumps)): 95 | doc_ids = list(f.keys()) 96 | sampled_doc_ids = random.sample(doc_ids, int(doc_sample_ratio * len(doc_ids))) 97 | for doc_id in tqdm(sampled_doc_ids, desc='sampling from %d' % i): 98 | 
doc_group = f[doc_id] 99 | if para: 100 | groups = doc_group.values() 101 | else: 102 | groups = [doc_group] 103 | for group in groups: 104 | num_vecs, d = group['start'].shape 105 | if num_vecs == 0: continue 106 | sampled_vec_idxs = np.random.choice(num_vecs, int(vec_sample_ratio * num_vecs)) 107 | cur_vecs = int8_to_float(group['start'][:], 108 | group.attrs['offset'], group.attrs['scale'])[sampled_vec_idxs] 109 | cur_vecs = cur_vecs[np.linalg.norm(cur_vecs, axis=1) <= norm_th] 110 | vecs.append(cur_vecs) 111 | out = np.concatenate(vecs, 0) 112 | for dump in dumps: 113 | dump.close() 114 | 115 | norms = np.linalg.norm(out, axis=1, keepdims=True) 116 | if max_norm is None: 117 | max_norm = max_norm_cf * np.max(norms) 118 | consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2)) 119 | out = np.concatenate([consts, out], axis=1) 120 | if num_dummy_zeros > 0: 121 | out = np.concatenate([out, np.zeros([out.shape[0], num_dummy_zeros], dtype=out.dtype)], axis=1) 122 | return out, max_norm 123 | 124 | 125 | def train_coarse_quantizer(data, quantizer_path, num_clusters, hnsw=False, niter=10, cuda=False): 126 | d = data.shape[1] 127 | 128 | index_flat = faiss.IndexFlatL2(d) 129 | # make it into a gpu index 130 | if cuda: 131 | res = faiss.StandardGpuResources() 132 | index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat) 133 | clus = faiss.Clustering(d, num_clusters) 134 | clus.verbose = True 135 | clus.niter = niter 136 | clus.train(data, index_flat) 137 | centroids = faiss.vector_float_to_array(clus.centroids) 138 | centroids = centroids.reshape(num_clusters, d) 139 | 140 | if hnsw: 141 | quantizer = faiss.IndexHNSWFlat(d, 32) 142 | quantizer.hnsw.efSearch = 128 143 | quantizer.train(centroids) 144 | quantizer.add(centroids) 145 | else: 146 | quantizer = faiss.IndexFlatL2(d) 147 | quantizer.add(centroids) 148 | 149 | faiss.write_index(quantizer, quantizer_path) 150 | 151 | 152 | def train_index(data, quantizer_path, trained_index_path, fine_quant='SQ8', 
cuda=False): 153 | quantizer = faiss.read_index(quantizer_path) 154 | if fine_quant == 'SQ8': 155 | trained_index = faiss.IndexIVFScalarQuantizer(quantizer, quantizer.d, quantizer.ntotal, faiss.METRIC_L2) 156 | elif fine_quant.startswith('PQ'): 157 | m = int(fine_quant[2:]) 158 | trained_index = faiss.IndexIVFPQ(quantizer, quantizer.d, quantizer.ntotal, m, 8) 159 | else: 160 | raise ValueError(fine_quant) 161 | 162 | if cuda: 163 | if fine_quant.startswith('PQ'): 164 | print('PQ not supported on GPU; keeping CPU.') 165 | else: 166 | res = faiss.StandardGpuResources() 167 | gpu_index = faiss.index_cpu_to_gpu(res, 0, trained_index) 168 | gpu_index.train(data) 169 | trained_index = faiss.index_gpu_to_cpu(gpu_index) 170 | else: 171 | trained_index.train(data) 172 | faiss.write_index(trained_index, trained_index_path) 173 | 174 | 175 | def add_with_offset(index, data, offset, valids=None): 176 | ids = np.arange(data.shape[0]) + offset + index.ntotal 177 | if valids is not None: 178 | data = data[valids] 179 | ids = ids[valids] 180 | index.add_with_ids(data, ids) 181 | 182 | 183 | def add_to_index(dump_paths, trained_index_path, target_index_path, idx2id_path, max_norm, para=False, 184 | num_docs_per_add=1000, num_dummy_zeros=0, cuda=False, fine_quant='SQ8', offset=0, norm_th=999, 185 | ignore_ids=None): 186 | idx2doc_id = [] 187 | idx2para_id = [] 188 | idx2word_id = [] 189 | dumps = [h5py.File(dump_path, 'r') for dump_path in dump_paths] 190 | print('reading %s' % trained_index_path) 191 | start_index = faiss.read_index(trained_index_path) 192 | 193 | if cuda: 194 | if fine_quant.startswith('PQ'): 195 | print('PQ not supported on GPU; keeping CPU.') 196 | else: 197 | res = faiss.StandardGpuResources() 198 | start_index = faiss.index_cpu_to_gpu(res, 0, start_index) 199 | 200 | print('adding following dumps:') 201 | for dump_path in dump_paths: 202 | print(dump_path) 203 | if para: 204 | for di, phrase_dump in enumerate(tqdm(dumps, desc='dumps')): 205 | starts = [] 206 | 
for i, (doc_idx, doc_group) in enumerate(tqdm(phrase_dump.items(), desc='faiss indexing')): 207 | for para_idx, group in doc_group.items(): 208 | num_vecs = group['start'].shape[0] 209 | start = int8_to_float(group['start'][:], group.attrs['offset'], group.attrs['scale']) 210 | norms = np.linalg.norm(start, axis=1, keepdims=True) 211 | consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2)) 212 | start = np.concatenate([consts, start], axis=1) 213 | if num_dummy_zeros > 0: 214 | start = np.concatenate( 215 | [start, np.zeros([start.shape[0], num_dummy_zeros], dtype=start.dtype)], axis=1) 216 | starts.append(start) 217 | idx2doc_id.extend([int(doc_idx)] * num_vecs) 218 | idx2para_id.extend([int(para_idx)] * num_vecs) 219 | idx2word_id.extend(list(range(num_vecs))) 220 | if len(starts) > 0 and i % num_docs_per_add == 0: 221 | print('concatenating') 222 | concat = np.concatenate(starts, axis=0) 223 | print('adding') 224 | add_with_offset(start_index, concat, offset) 225 | # start_index.add(concat) 226 | print('done') 227 | starts = [] 228 | if i % 100 == 0: 229 | print('%d/%d' % (i + 1, len(phrase_dump.keys()))) 230 | print('adding leftover') 231 | add_with_offset(start_index, np.concatenate(starts, axis=0), offset) 232 | # start_index.add(np.concatenate(starts, axis=0)) # leftover 233 | print('done') 234 | else: 235 | for di, phrase_dump in enumerate(tqdm(dumps, desc='dumps')): 236 | starts = [] 237 | valids = [] 238 | for i, (doc_idx, doc_group) in enumerate(tqdm(phrase_dump.items(), desc='adding %d' % di)): 239 | if ignore_ids is not None and doc_idx in ignore_ids: 240 | continue 241 | num_vecs = doc_group['start'].shape[0] 242 | start = int8_to_float(doc_group['start'][:], doc_group.attrs['offset'], 243 | doc_group.attrs['scale']) 244 | valid = np.linalg.norm(start, axis=1) <= norm_th 245 | norms = np.linalg.norm(start, axis=1, keepdims=True) 246 | consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2)) 247 | start = np.concatenate([consts, start], 
axis=1) 248 | if num_dummy_zeros > 0: 249 | start = np.concatenate([start, np.zeros([start.shape[0], num_dummy_zeros], dtype=start.dtype)], 250 | axis=1) 251 | starts.append(start) 252 | valids.append(valid) 253 | idx2doc_id.extend([int(doc_idx)] * num_vecs) 254 | idx2word_id.extend(range(num_vecs)) 255 | if len(starts) > 0 and i % num_docs_per_add == 0: 256 | print('adding at %d' % (i+1)) 257 | add_with_offset(start_index, np.concatenate(starts, axis=0), offset, np.concatenate(valids)) 258 | # start_index.add(np.concatenate(starts, axis=0)) 259 | starts = [] 260 | valids = [] 261 | if i % 100 == 0: 262 | # print('%d/%d' % (i + 1, len(phrase_dump.keys()))) 263 | continue 264 | print('final adding at %d' % (i+1)) 265 | add_with_offset(start_index, np.concatenate(starts, axis=0), offset, np.concatenate(valids)) 266 | # start_index.add(np.concatenate(starts, axis=0)) # leftover 267 | 268 | for dump in dumps: 269 | dump.close() 270 | 271 | if cuda and not fine_quant.startswith('PQ'): 272 | print('moving back to cpu') 273 | start_index = faiss.index_gpu_to_cpu(start_index) 274 | 275 | print('index ntotal: %d' % start_index.ntotal) 276 | idx2doc_id = np.array(idx2doc_id, dtype=np.int32) 277 | idx2para_id = np.array(idx2para_id, dtype=np.int32) 278 | idx2word_id = np.array(idx2word_id, dtype=np.int32) 279 | 280 | print('writing index and metadata') 281 | with h5py.File(idx2id_path, 'w') as f: 282 | g = f.create_group(str(offset)) 283 | g.create_dataset('doc', data=idx2doc_id) 284 | g.create_dataset('para', data=idx2para_id) 285 | g.create_dataset('word', data=idx2word_id) 286 | g.attrs['offset'] = offset 287 | faiss.write_index(start_index, target_index_path) 288 | print('done') 289 | 290 | 291 | def merge_indexes(subindex_dir, trained_index_path, target_index_path, target_idx2id_path, target_inv_path): 292 | # target_inv_path = merged_index.ivfdata 293 | names = os.listdir(subindex_dir) 294 | idx2id_paths = [os.path.join(subindex_dir, name) for name in names if 
name.endswith('.hdf5')] 295 | index_paths = [os.path.join(subindex_dir, name) for name in names if name.endswith('.faiss')] 296 | print(len(idx2id_paths)) 297 | print(len(index_paths)) 298 | 299 | print('copying idx2id') 300 | with h5py.File(target_idx2id_path, 'w') as out: 301 | for idx2id_path in tqdm(idx2id_paths, desc='copying idx2id'): 302 | with h5py.File(idx2id_path, 'r') as in_: 303 | for key, g in in_.items(): 304 | offset = str(g.attrs['offset']) 305 | assert key == offset 306 | group = out.create_group(offset) 307 | group.create_dataset('doc', data=in_[key]['doc']) 308 | group.create_dataset('para', data=in_[key]['para']) 309 | group.create_dataset('word', data=in_[key]['word']) 310 | 311 | print('loading invlists') 312 | ivfs = [] 313 | for index_path in tqdm(index_paths, desc='loading invlists'): 314 | # the IO_FLAG_MMAP is to avoid actually loading the data thus 315 | # the total size of the inverted lists can exceed the 316 | # available RAM 317 | index = faiss.read_index(index_path, 318 | faiss.IO_FLAG_MMAP) 319 | ivfs.append(index.invlists) 320 | 321 | # avoid that the invlists get deallocated with the index 322 | index.own_invlists = False 323 | 324 | # construct the output index 325 | index = faiss.read_index(trained_index_path) 326 | 327 | # prepare the output inverted lists. 
They will be written 328 | # to merged_index.ivfdata 329 | invlists = faiss.OnDiskInvertedLists( 330 | index.nlist, index.code_size, 331 | target_inv_path) 332 | 333 | # merge all the inverted lists 334 | print('merging') 335 | ivf_vector = faiss.InvertedListsPtrVector() 336 | for ivf in tqdm(ivfs): 337 | ivf_vector.push_back(ivf) 338 | 339 | print("merge %d inverted lists " % ivf_vector.size()) 340 | ntotal = invlists.merge_from(ivf_vector.data(), ivf_vector.size()) 341 | print(ntotal) 342 | 343 | # now replace the inverted lists in the output index 344 | index.ntotal = ntotal 345 | index.replace_invlists(invlists) 346 | 347 | print('writing index') 348 | faiss.write_index(index, target_index_path) 349 | 350 | 351 | def run_index(args): 352 | phrase_path = os.path.join(args.dump_dir, 'phrase.hdf5') 353 | if os.path.exists(phrase_path): 354 | dump_paths = [phrase_path] 355 | else: 356 | dump_names = os.listdir(os.path.join(args.dump_dir, 'phrase')) 357 | dump_paths = [os.path.join(args.dump_dir, 'phrase', name) for name in dump_names if name.endswith('.hdf5')] 358 | 359 | data = None 360 | if args.stage in ['all', 'coarse']: 361 | if args.replace or not os.path.exists(args.quantizer_path): 362 | if not os.path.exists(args.index_dir): 363 | os.makedirs(args.index_dir) 364 | data, max_norm = sample_data(dump_paths, max_norm=args.max_norm, para=args.para, 365 | doc_sample_ratio=args.doc_sample_ratio, vec_sample_ratio=args.vec_sample_ratio, 366 | max_norm_cf=args.max_norm_cf, num_dummy_zeros=args.num_dummy_zeros, 367 | norm_th=args.norm_th) 368 | with open(args.max_norm_path, 'w') as fp: 369 | json.dump(max_norm, fp) 370 | train_coarse_quantizer(data, args.quantizer_path, args.num_clusters, cuda=args.cuda) 371 | 372 | if args.stage in ['all', 'fine']: 373 | if args.replace or not os.path.exists(args.trained_index_path): 374 | with open(args.max_norm_path, 'r') as fp: 375 | max_norm = json.load(fp) 376 | if data is None: 377 | data, _ = sample_data(dump_paths, 
max_norm=max_norm, para=args.para, 378 | doc_sample_ratio=args.doc_sample_ratio, vec_sample_ratio=args.vec_sample_ratio, 379 | num_dummy_zeros=args.num_dummy_zeros, norm_th=args.norm_th) 380 | train_index(data, args.quantizer_path, args.trained_index_path, fine_quant=args.fine_quant, cuda=args.cuda) 381 | 382 | if args.stage in ['all', 'add']: 383 | if args.replace or not os.path.exists(args.index_path): 384 | with open(args.max_norm_path, 'r') as fp: 385 | max_norm = json.load(fp) 386 | if args.dump_paths is not None: 387 | dump_paths = args.dump_paths 388 | if not os.path.exists(args.subindex_dir): 389 | os.makedirs(args.subindex_dir) 390 | add_to_index(dump_paths, args.trained_index_path, args.index_path, args.idx2id_path, 391 | max_norm=max_norm, para=args.para, num_dummy_zeros=args.num_dummy_zeros, cuda=args.cuda, 392 | num_docs_per_add=args.num_docs_per_add, offset=args.offset, norm_th=args.norm_th, 393 | fine_quant=args.fine_quant) 394 | 395 | if args.stage == 'merge': 396 | if args.replace or not os.path.exists(args.index_path): 397 | merge_indexes(args.subindex_dir, args.trained_index_path, args.index_path, args.idx2id_path, args.inv_path) 398 | 399 | if args.stage == 'move': 400 | index = faiss.read_index(args.trained_index_path) 401 | invlists = faiss.OnDiskInvertedLists( 402 | index.nlist, index.code_size, 403 | args.inv_path) 404 | index.replace_invlists(invlists) 405 | faiss.write_index(index, args.index_path) 406 | 407 | 408 | def main(): 409 | args = get_args() 410 | run_index(args) 411 | 412 | 413 | if __name__ == '__main__': 414 | main() 415 | -------------------------------------------------------------------------------- /run_server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import torch 4 | import tokenization 5 | import os 6 | import random 7 | import numpy as np 8 | import requests 9 | import logging 10 | import math 11 | import ssl 12 | import copy 13 | from 
class DenSPIServer(object):
    """HTTP front-end for the query encoder, the document ranker and the phrase index.

    Each ``serve_*`` method starts one blocking Flask/Tornado server. The
    remaining methods are thin HTTP clients that talk to those servers using
    ``base_ip`` and the configured ports.
    """

    def __init__(self, args):
        """Keep the parsed arguments and the endpoint configuration.

        Args:
            args: namespace providing ``base_ip``, ``query_port``, ``doc_port``
                and ``index_port``.
        """
        self.args = args
        # IP and Ports
        self.base_ip = args.base_ip
        self.query_port = args.query_port
        self.doc_port = args.doc_port
        self.index_port = args.index_port

        # Saved objects
        self.mips = None

    def load_query_encoder(self, device, args):
        """Build the DenSPI query encoder and its tokenizer.

        Args:
            device: device string the model is moved to (e.g. ``'cuda'``/``'cpu'``).
            args: namespace with ``metadata_dir``, ``vocab_name``,
                ``bert_config_name``, ``bert_model_option``, ``parallel``,
                ``query_encoder_path`` and ``do_case``.

        Returns:
            Tuple ``(model, tokenizer)``.
        """
        # Configure paths for query encoder serving
        vocab_path = os.path.join(args.metadata_dir, args.vocab_name)
        bert_config_path = os.path.join(
            args.metadata_dir, args.bert_config_name.replace(".json", "") + "_" + args.bert_model_option + ".json"
        )

        # Load pretrained QueryEncoder
        bert_config = BertConfig.from_json_file(bert_config_path)
        model = DenSPI(bert_config)
        if args.parallel:
            model = torch.nn.DataParallel(model)
        # Weights are always materialized on CPU first and moved afterwards.
        state = torch.load(args.query_encoder_path, map_location='cpu')
        model.load_state_dict(state['model'], strict=False)
        check_diff(model.state_dict(), state['model'])
        model.to(device)
        logger.info('Model loaded from %s' % args.query_encoder_path)
        logger.info('Number of model parameters: {:,}'.format(sum(p.numel() for p in model.parameters())))

        tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path, do_lower_case=not args.do_case)
        return model, tokenizer

    def get_question_dataloader(self, questions, tokenizer, batch_size):
        """Turn raw question strings into a dataloader for the query encoder.

        Args:
            questions: iterable of question strings.
            tokenizer: tokenizer passed to ``convert_questions_to_features``.
            batch_size: prediction batch size of the dataloader.

        Returns:
            Tuple ``(question_dataloader, question_examples, query_features)``.
        """
        question_examples = [SquadExample(qas_id='qs', question_text=q) for q in questions]
        query_features = convert_questions_to_features(
            examples=question_examples,
            tokenizer=tokenizer,
            max_query_length=64
        )
        question_dataloader = convert_question_features_to_dataloader(
            query_features,
            fp16=False, local_rank=-1,
            predict_batch_size=batch_size
        )
        return question_dataloader, question_examples, query_features

    def serve_query_encoder(self, query_port, args):
        """Start the blocking query-encoder server (``POST /batch_api``).

        Args:
            query_port: port the server listens on.
            args: namespace forwarded to ``load_query_encoder``; ``args.cuda``
                selects the device.
        """
        device = 'cuda' if args.cuda else 'cpu'
        query_encoder, tokenizer = self.load_query_encoder(device, args)

        # Define query to vector function
        def query2vec(queries):
            question_dataloader, question_examples, query_features = self.get_question_dataloader(
                queries, tokenizer, batch_size=24
            )
            query_encoder.eval()
            question_results = get_question_results(
                question_examples, query_features, question_dataloader, device, query_encoder
            )
            outs = []
            for question_result in question_results:
                # Replace every sparse value by a plain list; the result is sent through jsonify.
                for ngram in question_result.sparse.keys():
                    question_result.sparse[ngram] = question_result.sparse[ngram].tolist()
                outs.append((
                    question_result.start.tolist(), question_result.end.tolist(),
                    question_result.sparse, question_result.input_ids
                ))
            return outs

        # Serve query encoder
        app = Flask(__name__)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        @app.route('/batch_api', methods=['POST'])
        def batch_api():
            batch_query = json.loads(request.form['query'])
            outs = query2vec(batch_query)
            return jsonify(outs)

        logger.info(f'Starting QueryEncoder server at {self.get_address(query_port)}')
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(query_port)
        IOLoop.instance().start()

    def load_phrase_index(self, args, dump_only=False):
        """Return the cached MIPS object, or construct one from the dump paths.

        Args:
            args: namespace with ``dump_dir``, ``phrase_dir``, ``tfidf_dir``,
                ``index_dir``, ``index_name``, ``idx2id_name`` and ``cuda``.
            dump_only: forwarded unchanged to ``MIPS``.

        Returns:
            ``self.mips`` when it is already set, otherwise a new ``MIPS``
            instance (the caller is responsible for storing it).
        """
        if self.mips is not None:
            return self.mips

        # Configure paths for index serving
        phrase_dump_dir = os.path.join(args.dump_dir, args.phrase_dir)
        tfidf_dump_dir = os.path.join(args.dump_dir, args.tfidf_dir)
        index_dir = os.path.join(args.dump_dir, args.index_dir)
        index_path = os.path.join(index_dir, args.index_name)
        idx2id_path = os.path.join(index_dir, args.idx2id_name)
        max_norm_path = os.path.join(index_dir, 'max_norm.json')

        # Load mips; document ranking is delegated to the DocRanker server via these callbacks.
        mips = MIPS(
            phrase_dump_dir=phrase_dump_dir,
            tfidf_dump_dir=tfidf_dump_dir,
            start_index_path=index_path,
            idx2id_path=idx2id_path,
            max_norm_path=max_norm_path,
            doc_rank_fn={
                'index': self.get_doc_scores, 'top_docs': self.get_top_docs, 'spvec': self.get_q_spvecs
            },
            cuda=args.cuda, dump_only=dump_only
        )
        return mips

    def serve_phrase_index(self, index_port, args):
        """Start the blocking phrase-index server and the demo web page.

        Routes: ``/``, ``/files/<path>``, ``GET /api``, ``POST /batch_api``
        and ``GET /get_examples``.

        Args:
            index_port: port the server listens on.
            args: namespace forwarded to ``load_phrase_index``; also supplies the
                default search hyperparameters used by ``/api``. Note that
                ``args.examples_path`` is rewritten in place to live under
                ``static``.
        """
        args.examples_path = os.path.join('static', args.examples_path)

        # Load mips
        self.mips = self.load_phrase_index(args)
        app = Flask(__name__, static_url_path='/static')
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        def batch_search(batch_query, max_answer_length=20, start_top_k=1000, mid_top_k=100, top_k=10, doc_top_k=5,
                         nprobe=64, sparse_weight=0.05, search_strategy='hybrid'):
            t0 = time()
            # embed_query returns a callable; calling it waits for the encoder response.
            outs, _ = self.embed_query(batch_query)()
            start = np.concatenate([out[0] for out in outs], 0)
            end = np.concatenate([out[1] for out in outs], 0)
            sparse_uni = [out[2]['1'][1:len(out[3])+1] for out in outs]
            sparse_bi = [out[2]['2'][1:len(out[3])+1] for out in outs]
            input_ids = [out[3] for out in outs]
            query_vec = np.concatenate([start, end, [[1]]*len(outs)], 1)

            rets = self.mips.search(
                query_vec, (input_ids, sparse_uni, sparse_bi), q_texts=batch_query, nprobe=nprobe,
                doc_top_k=doc_top_k, start_top_k=start_top_k, mid_top_k=mid_top_k, top_k=top_k,
                search_strategy=search_strategy, filter_=args.filter, max_answer_length=max_answer_length,
                sparse_weight=sparse_weight
            )
            t1 = time()
            # 'time' is the wall-clock latency in milliseconds.
            out = {'ret': rets, 'time': int(1000 * (t1 - t0))}
            return out

        @app.route('/')
        def index():
            return app.send_static_file('index.html')

        # The view needs ``path``, so the rule has to capture it with a path converter.
        @app.route('/files/<path:path>')
        def static_files(path):
            return app.send_static_file('files/' + path)

        # This one uses a default hyperparameters
        @app.route('/api', methods=['GET'])
        def api():
            query = request.args['query']
            strat = request.args['strat']
            out = batch_search(
                [query],
                max_answer_length=args.max_answer_length,
                top_k=args.top_k,
                nprobe=args.nprobe,
                search_strategy=strat,
                doc_top_k=args.doc_top_k
            )
            out['ret'] = out['ret'][0]
            return jsonify(out)

        @app.route('/batch_api', methods=['POST'])
        def batch_api():
            batch_query = json.loads(request.form['query'])
            max_answer_length = int(request.form['max_answer_length'])
            start_top_k = int(request.form['start_top_k'])
            mid_top_k = int(request.form['mid_top_k'])
            top_k = int(request.form['top_k'])
            doc_top_k = int(request.form['doc_top_k'])
            nprobe = int(request.form['nprobe'])
            sparse_weight = float(request.form['sparse_weight'])
            strat = request.form['strat']
            out = batch_search(
                batch_query,
                max_answer_length=max_answer_length,
                start_top_k=start_top_k,
                mid_top_k=mid_top_k,
                top_k=top_k,
                doc_top_k=doc_top_k,
                nprobe=nprobe,
                sparse_weight=sparse_weight,
                search_strategy=strat,
            )
            return jsonify(out)

        @app.route('/get_examples', methods=['GET'])
        def get_examples():
            with open(args.examples_path, 'r') as fp:
                examples = [line.strip() for line in fp]
            return jsonify(examples)

        if self.query_port is None:
            logger.info('You must set self.query_port for querying. You can use self.update_query_port() later on.')
        logger.info(f'Starting Index server at {self.get_address(index_port)}')
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(index_port)
        IOLoop.instance().start()

    def serve_doc_ranker(self, doc_port, args):
        """Start the blocking TF-IDF document-ranker server.

        Routes: ``POST /doc_index``, ``POST /top_docs`` and ``POST /text2spvec``.

        Args:
            doc_port: port the server listens on.
            args: namespace with ``dump_dir`` and ``doc_ranker_name``.
        """
        doc_ranker_path = os.path.join(args.dump_dir, args.doc_ranker_name)
        doc_ranker = TfidfDocRanker(doc_ranker_path, strict=False)
        app = Flask(__name__)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        @app.route('/doc_index', methods=['POST'])
        def doc_index():
            batch_query = json.loads(request.form['query'])
            doc_idxs = json.loads(request.form['doc_idxs'])
            outs = doc_ranker.batch_doc_scores(batch_query, doc_idxs)
            logger.info(f'Returning {len(outs)} from batch_doc_scores')
            return jsonify(outs)

        @app.route('/top_docs', methods=['POST'])
        def top_docs():
            batch_query = json.loads(request.form['query'])
            top_k = int(request.form['top_k'])
            batch_results = doc_ranker.batch_closest_docs(batch_query, k=top_k)
            top_idxs = [b[0] for b in batch_results]
            top_scores = [b[1].tolist() for b in batch_results]
            logger.info(f'Returning {len(top_idxs)} from batch_closest_docs')
            return jsonify([top_idxs, top_scores])

        @app.route('/text2spvec', methods=['POST'])
        def text2spvec():
            batch_query = json.loads(request.form['query'])
            q_spvecs = [doc_ranker.text2spvec(q, val_idx=True) for q in batch_query]
            q_vals = [q_spvec[0].tolist() for q_spvec in q_spvecs]
            q_idxs = [q_spvec[1].tolist() for q_spvec in q_spvecs]
            logger.info(f'Returning {len(q_vals), len(q_idxs)} q_spvecs')
            return jsonify([q_vals, q_idxs])

        logger.info(f'Starting DocRanker server at {self.get_address(doc_port)}')
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(doc_port)
        IOLoop.instance().start()

    def get_address(self, port):
        """Return ``'<base_ip>:<port>'``.

        Args:
            port: port as a string or an integer.

        Raises:
            ValueError: if ``base_ip`` is unset or ``port`` is ``None``/empty.
        """
        if self.base_ip is None:
            raise ValueError('base_ip must be set before building an address')
        if port is None or str(port) == '':
            raise ValueError('port must be a non-empty string or an integer')
        return self.base_ip + ':' + str(port)

    @staticmethod
    def _decode_response(res, context):
        """Decode the JSON body of an HTTP response.

        Args:
            res: response object exposing ``status_code`` and ``text``.
            context: text appended to the log message when decoding fails.

        Returns:
            The decoded JSON payload.

        Raises:
            ValueError: if the body is not valid JSON. The failure is logged and
                re-raised instead of falling through to an unbound variable.
        """
        if res.status_code != 200:
            logger.info('Wrong behavior %d' % res.status_code)
        try:
            return json.loads(res.text)
        except ValueError:
            logger.info(f'no response or error for {context}')
            logger.info(res.text)
            raise

    def embed_query(self, batch_query):
        """Asynchronously post ``batch_query`` to the query-encoder server.

        Returns:
            A zero-argument callable that waits for the response and returns
            ``(embeddings, elapsed_milliseconds)``.
        """
        emb_session = FuturesSession()
        r = emb_session.post(self.get_address(self.query_port) + '/batch_api', data={'query': json.dumps(batch_query)})

        def map_():
            result = r.result()
            emb = result.json()
            return emb, result.elapsed.total_seconds() * 1000
        return map_

    def query(self, query, search_strategy='hybrid'):
        """Send one question to the index server's ``/api`` route.

        Returns:
            The decoded JSON response.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        params = {'query': query, 'strat': search_strategy}
        res = requests.get(self.get_address(self.index_port) + '/api', params=params)
        return self._decode_response(res, f'q {query}')

    def batch_query(self, batch_query, max_answer_length=20, start_top_k=1000, mid_top_k=100, top_k=10, doc_top_k=5,
                    nprobe=64, sparse_weight=0.05, search_strategy='hybrid'):
        """Send a batch of questions to the index server's ``/batch_api`` route.

        Returns:
            The decoded JSON response.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        post_data = {
            'query': json.dumps(batch_query),
            'max_answer_length': max_answer_length,
            'start_top_k': start_top_k,
            'mid_top_k': mid_top_k,
            'top_k': top_k,
            'doc_top_k': doc_top_k,
            'nprobe': nprobe,
            'sparse_weight': sparse_weight,
            'strat': search_strategy,
        }
        res = requests.post(self.get_address(self.index_port) + '/batch_api', data=post_data)
        return self._decode_response(res, f'q {batch_query}')

    def get_doc_scores(self, batch_query, doc_idxs):
        """Ask the doc-ranker server (``/doc_index``) to score ``doc_idxs`` per query.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        post_data = {
            'query': json.dumps(batch_query),
            'doc_idxs': json.dumps(doc_idxs)
        }
        res = requests.post(self.get_address(self.doc_port) + '/doc_index', data=post_data)
        return self._decode_response(res, f'{doc_idxs}')

    def get_top_docs(self, batch_query, top_k):
        """Ask the doc-ranker server (``/top_docs``) for the ``top_k`` documents per query.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        post_data = {
            'query': json.dumps(batch_query),
            'top_k': top_k
        }
        res = requests.post(self.get_address(self.doc_port) + '/top_docs', data=post_data)
        return self._decode_response(res, f'{top_k}')

    def get_q_spvecs(self, batch_query):
        """Ask the doc-ranker server (``/text2spvec``) for sparse vectors of the queries.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        post_data = {'query': json.dumps(batch_query)}
        res = requests.post(self.get_address(self.doc_port) + '/text2spvec', data=post_data)
        return self._decode_response(res, f'q {batch_query}')
if __name__ == '__main__':
    # Command-line entry point: parse options, seed the RNGs, then either start
    # one of the three servers or act as a client against running servers.
    parser = argparse.ArgumentParser()

    # QueryEncoder
    parser.add_argument('--metadata_dir', type=str, default='/nvme/jinhyuk/denspi/bert')
    parser.add_argument('--vocab_name', type=str, default='vocab.txt')
    parser.add_argument('--bert_config_name', type=str, default='bert_config.json')
    parser.add_argument('--bert_model_option', type=str, default='large_uncased')
    for switch in ('--parallel', '--do_case'):
        parser.add_argument(switch, default=False, action='store_true')
    parser.add_argument(
        '--query_encoder_path', type=str,
        default='/nvme/jinhyuk/denspi/KR94373_piqa-nfs_1173/1/model.pt',
    )
    parser.add_argument('--query_port', type=str, default='-1')

    # DocRanker
    parser.add_argument('--doc_ranker_name', default='docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz')
    parser.add_argument('--doc_port', type=str, default='-1')

    # PhraseIndex
    parser.add_argument('--dump_dir', default='/nvme/jinhyuk/denspi/1173_wikipedia_filtered')
    parser.add_argument('--phrase_dir', default='phrase')
    parser.add_argument('--tfidf_dir', default='tfidf')
    parser.add_argument('--index_dir', default='1048576_hnsw_SQ8')
    parser.add_argument('--index_name', default='index.faiss')
    parser.add_argument('--idx2id_name', default='idx2id.hdf5')
    parser.add_argument('--index_port', type=str, default='-1')

    # Search hyperparameters; these can be dynamically changed per request.
    for int_option, int_default in (
        ('--max_answer_length', 20),
        ('--start_top_k', 1000),
        ('--mid_top_k', 100),
        ('--top_k', 10),
        ('--doc_top_k', 5),
        ('--nprobe', 256),
    ):
        parser.add_argument(int_option, default=int_default, type=int)
    parser.add_argument('--sparse_weight', type=float, default=0.05)
    parser.add_argument('--search_strategy', default='hybrid')
    for switch in ('--filter', '--no_para'):
        parser.add_argument(switch, default=False, action='store_true')

    # Serving options
    parser.add_argument('--examples_path', default='examples.txt')

    # Run mode
    parser.add_argument('--base_ip', default='http://163.152.163.248')
    parser.add_argument('--run_mode', default='batch_query')
    for switch in ('--cuda', '--draft'):
        parser.add_argument(switch, default=False, action='store_true')
    parser.add_argument('--seed', type=int, default=1992)
    args = parser.parse_args()

    # Seed every random number generator for reproducibility
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    server = DenSPIServer(args)

    # Ports may also be overridden by hand, e.g.:
    #   server.query_port = '9010'
    #   server.doc_port = '9020'
    #   server.index_port = '10001'

    mode = args.run_mode

    if mode == 'q_serve':
        logger.info(f'Query address: {server.get_address(server.query_port)}')
        server.serve_query_encoder(args.query_port, args)

    elif mode == 'd_serve':
        logger.info(f'Doc address: {server.get_address(server.doc_port)}')
        server.serve_doc_ranker(args.doc_port, args)

    elif mode == 'p_serve':
        logger.info(f'Query address: {server.get_address(server.query_port)}')
        logger.info(f'Doc address: {server.get_address(server.doc_port)}')
        logger.info(f'Index address: {server.get_address(server.index_port)}')
        server.serve_phrase_index(args.index_port, args)

    elif mode == 'query':
        logger.info(f'Index address: {server.get_address(server.index_port)}')
        query = 'Name three famous writers'
        result = server.query(query)
        logger.info(f'Answers to a question: {query}')
        logger.info(f'{[r["answer"] for r in result["ret"]]}')

    elif mode == 'batch_query':
        logger.info(f'Index address: {server.get_address(server.index_port)}')
        queries = [
            'Name three famous writers',
            'Who was defeated by computer in chess game?'
        ]
        search_options = {
            option: getattr(args, option)
            for option in (
                'max_answer_length', 'start_top_k', 'mid_top_k', 'top_k',
                'doc_top_k', 'nprobe', 'sparse_weight', 'search_strategy',
            )
        }
        batch_result = server.batch_query(queries, **search_options)
        for question, answers in zip(queries, batch_result['ret']):
            logger.info(f'Answers to a question: {question}')
            logger.info(f'{[r["answer"] for r in answers]}')

    elif mode == 'get_doc_scores':
        logger.info(f'Doc address: {server.get_address(server.doc_port)}')
        queries = [
            'What was the Yuan\'s paper money called?',
            'What makes a successful startup??',
            'On which date was Genghis Khan\'s palace rediscovered by archeaologists?',
            'To-y is a _ .'
        ]
        # Exercise the three DocRanker endpoints one after another.
        logger.info(server.get_doc_scores(queries, [[36], [2], [31], [22222]]))
        logger.info(server.get_top_docs(queries, 5))
        logger.info(server.get_q_spvecs(queries))

    else:
        raise NotImplementedError
class SimpleTokenizer(Tokenizer):
    """Regex tokenizer yielding alpha-numeric runs and single non-whitespace characters."""

    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE
        )
        # ``annotators=None`` is a documented value, so test truthiness rather than len().
        annotators = kwargs.get('annotators')
        if annotators:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, annotators))
        self.annotators = set()

    def tokenize(self, text):
        """Split ``text`` into tokens.

        Each entry handed to ``Tokens`` is ``(token, text_with_ws, span)`` where
        ``text_with_ws`` runs from the token start to the start of the next
        token (or to the token end for the last one) and ``span`` is the
        ``(start, end)`` character offsets of the token.

        Returns:
            A ``Tokens`` object built from those entries and ``self.annotators``.
        """
        matches = list(self._regexp.finditer(text))
        data = []
        for i, match in enumerate(matches):
            span = match.span()
            # Trailing whitespace belongs to the current token.
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]
            data.append((
                match.group(),
                text[span[0]: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
22 | -------------------------------------------------------------------------------- /static/files/pichu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhyuklee/sparc/bee309bdffd73d162c23a7c3c0b63cebe4aaea97/static/files/pichu.png -------------------------------------------------------------------------------- /static/files/pika.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhyuklee/sparc/bee309bdffd73d162c23a7c3c0b63cebe4aaea97/static/files/pika.png -------------------------------------------------------------------------------- /static/files/popper.min.js: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) Federico Zivolo 2018 3 | Distributed under the MIT License (license terms are at http://opensource.org/licenses/MIT). 4 | */(function(e,t){'object'==typeof exports&&'undefined'!=typeof module?module.exports=t():'function'==typeof define&&define.amd?define(t):e.Popper=t()})(this,function(){'use strict';function e(e){return e&&'[object Function]'==={}.toString.call(e)}function t(e,t){if(1!==e.nodeType)return[];var o=getComputedStyle(e,null);return t?o[t]:o}function o(e){return'HTML'===e.nodeName?e:e.parentNode||e.host}function n(e){if(!e)return document.body;switch(e.nodeName){case'HTML':case'BODY':return e.ownerDocument.body;case'#document':return e.body;}var i=t(e),r=i.overflow,p=i.overflowX,s=i.overflowY;return /(auto|scroll|overlay)/.test(r+s+p)?e:n(o(e))}function r(e){return 11===e?re:10===e?pe:re||pe}function p(e){if(!e)return document.documentElement;for(var o=r(10)?document.body:null,n=e.offsetParent;n===o&&e.nextElementSibling;)n=(e=e.nextElementSibling).offsetParent;var i=n&&n.nodeName;return 
i&&'BODY'!==i&&'HTML'!==i?-1!==['TD','TABLE'].indexOf(n.nodeName)&&'static'===t(n,'position')?p(n):n:e?e.ownerDocument.documentElement:document.documentElement}function s(e){var t=e.nodeName;return'BODY'!==t&&('HTML'===t||p(e.firstElementChild)===e)}function d(e){return null===e.parentNode?e:d(e.parentNode)}function a(e,t){if(!e||!e.nodeType||!t||!t.nodeType)return document.documentElement;var o=e.compareDocumentPosition(t)&Node.DOCUMENT_POSITION_FOLLOWING,n=o?e:t,i=o?t:e,r=document.createRange();r.setStart(n,0),r.setEnd(i,0);var l=r.commonAncestorContainer;if(e!==l&&t!==l||n.contains(i))return s(l)?l:p(l);var f=d(e);return f.host?a(f.host,t):a(e,d(t).host)}function l(e){var t=1=o.clientWidth&&n>=o.clientHeight}),l=0a[e]&&!t.escapeWithReference&&(n=J(f[o],a[e]-('right'===e?f.width:f.height))),ae({},o,n)}};return l.forEach(function(e){var t=-1===['left','top'].indexOf(e)?'secondary':'primary';f=le({},f,m[t](e))}),e.offsets.popper=f,e},priority:['left','right','top','bottom'],padding:5,boundariesElement:'scrollParent'},keepTogether:{order:400,enabled:!0,fn:function(e){var t=e.offsets,o=t.popper,n=t.reference,i=e.placement.split('-')[0],r=Z,p=-1!==['top','bottom'].indexOf(i),s=p?'right':'bottom',d=p?'left':'top',a=p?'width':'height';return o[s]r(n[s])&&(e.offsets.popper[d]=r(n[s])),e}},arrow:{order:500,enabled:!0,fn:function(e,o){var n;if(!q(e.instance.modifiers,'arrow','keepTogether'))return e;var i=o.element;if('string'==typeof i){if(i=e.instance.popper.querySelector(i),!i)return e;}else if(!e.instance.popper.contains(i))return console.warn('WARNING: `arrow.element` must be child of its popper element!'),e;var r=e.placement.split('-')[0],p=e.offsets,s=p.popper,d=p.reference,a=-1!==['left','right'].indexOf(r),l=a?'height':'width',f=a?'Top':'Left',m=f.toLowerCase(),h=a?'left':'top',c=a?'bottom':'right',u=S(i)[l];d[c]-us[c]&&(e.offsets.popper[m]+=d[m]+u-s[c]),e.offsets.popper=g(e.offsets.popper);var 
b=d[m]+d[l]/2-u/2,y=t(e.instance.popper),w=parseFloat(y['margin'+f],10),E=parseFloat(y['border'+f+'Width'],10),v=b-e.offsets.popper[m]-w-E;return v=$(J(s[l]-u,v),0),e.arrowElement=i,e.offsets.arrow=(n={},ae(n,m,Q(v)),ae(n,h,''),n),e},element:'[x-arrow]'},flip:{order:600,enabled:!0,fn:function(e,t){if(W(e.instance.modifiers,'inner'))return e;if(e.flipped&&e.placement===e.originalPlacement)return e;var o=v(e.instance.popper,e.instance.reference,t.padding,t.boundariesElement,e.positionFixed),n=e.placement.split('-')[0],i=T(n),r=e.placement.split('-')[1]||'',p=[];switch(t.behavior){case he.FLIP:p=[n,i];break;case he.CLOCKWISE:p=z(n);break;case he.COUNTERCLOCKWISE:p=z(n,!0);break;default:p=t.behavior;}return p.forEach(function(s,d){if(n!==s||p.length===d+1)return e;n=e.placement.split('-')[0],i=T(n);var a=e.offsets.popper,l=e.offsets.reference,f=Z,m='left'===n&&f(a.right)>f(l.left)||'right'===n&&f(a.left)f(l.top)||'bottom'===n&&f(a.top)f(o.right),g=f(a.top)f(o.bottom),b='left'===n&&h||'right'===n&&c||'top'===n&&g||'bottom'===n&&u,y=-1!==['top','bottom'].indexOf(n),w=!!t.flipVariations&&(y&&'start'===r&&h||y&&'end'===r&&c||!y&&'start'===r&&g||!y&&'end'===r&&u);(m||b||w)&&(e.flipped=!0,(m||b)&&(n=p[d+1]),w&&(r=G(r)),e.placement=n+(r?'-'+r:''),e.offsets.popper=le({},e.offsets.popper,C(e.instance.popper,e.offsets.reference,e.placement)),e=P(e.instance.modifiers,e,'flip'))}),e},behavior:'flip',padding:5,boundariesElement:'viewport'},inner:{order:700,enabled:!1,fn:function(e){var t=e.placement,o=t.split('-')[0],n=e.offsets,i=n.popper,r=n.reference,p=-1!==['left','right'].indexOf(o),s=-1===['top','left'].indexOf(o);return i[p?'left':'top']=r[o]-(s?i[p?'width':'height']:0),e.placement=T(t),e.offsets.popper=g(i),e}},hide:{order:800,enabled:!0,fn:function(e){if(!q(e.instance.modifiers,'hide','preventOverflow'))return e;var t=e.offsets.reference,o=D(e.instance.modifiers,function(e){return'preventOverflow'===e.name}).boundaries;if(t.bottomo.right||t.top>o.bottom||t.right 2 | 3 | 4 
| DenSPI + Sparc 5 | 6 | 7 | 8 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 37 |
38 | 43 | 44 | 45 |
46 |
47 | 50 | 58 |
59 | 60 | 63 |
64 | 67 |
68 |
69 | 70 |
71 |
Latency:
72 |
73 | 77 | Wikipedia EN (Dec 2016 dump) 78 |
79 |
80 | 81 | 82 |
83 | 84 | 85 |
86 | 87 | 88 |
89 | 90 | 91 |
92 | 93 | 94 |
95 | 96 | 97 |
98 | 99 |
100 | 101 |
102 |
    103 |
  • 104 |
105 |
106 | 107 |
108 | 109 |
110 |
111 | 112 | Contextualized Sparse Representations for Real-Time Open-Domain Question Answering
113 | Jinhyuk Lee, Minjoon Seo, Hannaneh Hajishirzi, Jaewoo Kang
114 |
115 |
116 |
class TfidfDocRanker(object):
    """Loads a pre-weighted inverted index of token/document terms.
    Scores new queries by taking sparse dot products.
    """

    def __init__(self, tfidf_path=None, strict=True):
        """
        Args:
            tfidf_path: path to saved model file
            strict: fail on empty queries or continue (and return empty result)
        """
        # Load from disk
        logger.info('Loading %s', tfidf_path)
        matrix, metadata = utils.load_sparse_csr(tfidf_path)
        self.doc_mat = matrix
        self.ngrams = metadata['ngram']
        self.hash_size = metadata['hash_size']
        self.tokenizer = SimpleTokenizer()
        self.doc_freqs = metadata['doc_freqs'].squeeze()
        # doc_dict[0] is indexed by doc_id, doc_dict[1] by doc_index (see the two getters).
        self.doc_dict = metadata['doc_dict']
        self.num_docs = len(self.doc_dict[0])
        self.strict = strict

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index"""
        return self.doc_dict[0][doc_id]

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id"""
        return self.doc_dict[1][doc_index]

    def closest_docs(self, query, k=1):
        """Closest docs by dot product between query and documents
        in tfidf weighted word vector space.

        Args:
            query: query string.
            k: maximum number of documents to return.

        Returns:
            Tuple ``(doc_ids, doc_scores)`` sorted by decreasing score, where
            ``doc_ids`` are integer column indices of the document matrix.
        """
        spvec = self.text2spvec(query)
        res = spvec * self.doc_mat

        if len(res.data) <= k:
            o_sort = np.argsort(-res.data)
        else:
            # Select the k best candidates first, then order only those.
            o = np.argpartition(-res.data, k)[0:k]
            o_sort = o[np.argsort(-res.data[o])]

        doc_scores = res.data[o_sort]
        # Indices are returned as-is; they are not mapped through get_doc_id.
        doc_ids = [int(i) for i in res.indices[o_sort]]  # TODO if none is returned?
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.
        Note: we can use plain threads here as scipy is outside of the GIL.

        Returns:
            List with one ``closest_docs`` result per query, in input order.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    def doc_scores(self, query_doc):
        """Get doc scores by dot product between query and documents
        in tfidf weighted word vector space.

        Args:
            query_doc: pair ``(query, doc_idx)``; ``doc_idx`` selects the
                columns of the score row that are returned.

        Returns:
            List of scores, one per selected document.
        """
        query, doc_idx = query_doc
        spvec = self.text2spvec(query)
        res = spvec * self.doc_mat
        scores = res[0, doc_idx].toarray().tolist()[0]
        return scores

    def batch_doc_scores(self, queries, doc_idxs, num_workers=None):
        """Process a batch of doc_scores requests multithreaded.
        Note: we can use plain threads here as scipy is outside of the GIL.

        ``queries`` and ``doc_idxs`` are paired with ``zip``, so extra entries
        of the longer one are ignored.
        """
        with ThreadPool(num_workers) as threads:
            results = threads.map(self.doc_scores, zip(queries, doc_idxs))
        return results

    def parse(self, query):
        """Parse the query into tokens (either ngrams or tokens)."""
        tokens = self.tokenizer.tokenize(query)
        return tokens.ngrams(n=self.ngrams, uncased=True,
                             filter_fn=utils.filter_ngram)

    def text2spvec(self, query, val_idx=False):
        """Create a sparse tfidf-weighted word vector from query.

        tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))

        Args:
            query: query string.
            val_idx: when True return ``(values, indices)`` arrays instead of a
                sparse matrix.

        Returns:
            A ``1 x hash_size`` CSR matrix, or ``(values, indices)`` if
            ``val_idx`` is set. For a query without valid words in non-strict
            mode the result is empty.

        Raises:
            RuntimeError: in strict mode when the query has no valid word.
        """
        # Get hashed ngrams
        words = self.parse(utils.normalize(query))
        wids = [utils.hash(w, self.hash_size) for w in words]

        if len(wids) == 0:
            if self.strict:
                raise RuntimeError('No valid word in: %s' % query)
            logger.warning('No valid word in: %s', query)
            if val_idx:
                return np.array([]), np.array([])
            return sp.csr_matrix((1, self.hash_size))

        # Count TF
        wids_unique, wids_counts = np.unique(wids, return_counts=True)
        tfs = np.log1p(wids_counts)

        # Count IDF; negative values are clamped to zero.
        Ns = self.doc_freqs[wids_unique]
        idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
        idfs[idfs < 0] = 0

        # TF-IDF
        data = np.multiply(tfs, idfs)

        if val_idx:
            return data, wids_unique

        # One row, sparse csr matrix
        indptr = np.array([0, len(wids_unique)])
        spvec = sp.csr_matrix(
            (data, wids_unique, indptr), shape=(1, self.hash_size)
        )

        return spvec
# ------------------------------------------------------------------------------


def save_sparse_csr(filename, matrix, metadata=None):
    """Write a CSR matrix and optional metadata to ``filename`` with ``np.savez``.

    Args:
        filename: target path (or file object) handed to ``np.savez``.
        matrix: matrix exposing ``data``, ``indices``, ``indptr`` and ``shape``.
        metadata: arbitrary object stored next to the matrix (may be None).
    """
    data = {
        'data': matrix.data,
        'indices': matrix.indices,
        'indptr': matrix.indptr,
        'shape': matrix.shape,
        'metadata': metadata,
    }
    np.savez(filename, **data)


def load_sparse_csr(filename):
    """Load a matrix written by ``save_sparse_csr``.

    Args:
        filename: path of the ``.npz`` archive.

    Returns:
        Tuple ``(matrix, metadata)``; ``metadata`` is None when the archive has
        no ``metadata`` entry.
    """
    # NOTE: allow_pickle=True unpickles the metadata entry; only load trusted files.
    # The context manager closes the archive once every array has been read.
    with np.load(filename, allow_pickle=True) as loader:
        matrix = sp.csr_matrix(
            (loader['data'], loader['indices'], loader['indptr']),
            shape=tuple(int(dim) for dim in loader['shape']),
        )
        metadata = loader['metadata'].item(0) if 'metadata' in loader else None
    return matrix, metadata


# ------------------------------------------------------------------------------
# Token hashing.
# ------------------------------------------------------------------------------


def hash(token, num_buckets):
    """Unsigned 32 bit murmurhash for feature hashing.

    The name shadows the ``hash`` builtin inside this module; it is kept
    because callers use it under this name.
    """
    return murmurhash3_32(token, positive=True) % num_buckets


# ------------------------------------------------------------------------------
# Text cleaning.
# ------------------------------------------------------------------------------


STOPWORDS = {
    'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your',
    'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she',
    'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their',
    'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
    'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an',
    'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of',
    'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through',
    'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down',
    'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then',
    'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
    'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
    'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can',
    'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've',
    'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven',
    'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren',
    'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``"
}


def normalize(text):
    """Resolve different type of unicode encodings (NFD normalization)."""
    return unicodedata.normalize('NFD', text)


def filter_word(text):
    """Take out english stopwords, punctuation, and compound endings.

    Returns:
        True when ``text`` should be filtered out: it consists only of
        punctuation or its lowercase form is in ``STOPWORDS``.
    """
    text = normalize(text)
    if regex.match(r'^\p{P}+$', text):
        return True
    if text.lower() in STOPWORDS:
        return True
    return False


def filter_ngram(gram, mode='any'):
    """Decide whether to keep or discard an n-gram.

    Args:
        gram: list of tokens (length N)
        mode: Option to throw out ngram if
          'any': any single token passes filter_word
          'all': all tokens pass filter_word
          'ends': book-ended by filterable tokens

    Returns:
        True when the n-gram should be thrown out.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    filtered = [filter_word(w) for w in gram]
    if mode == 'any':
        return any(filtered)
    elif mode == 'all':
        return all(filtered)
    elif mode == 'ends':
        return filtered[0] or filtered[-1]
    else:
        raise ValueError('Invalid mode: %s' % mode)


def get_field(d, field_list):
    """get the subfield associated to a list of elastic fields
    E.g. ['file', 'filename'] to d['file']['filename']

    Args:
        d: container supporting ``[]`` lookup.
        field_list: a single key given as a string, or a sequence of keys that
            is followed one level at a time.

    Returns:
        The value found at the end of the path. The container is only read, so
        no copy is taken; for an empty sequence ``d`` itself is returned.
    """
    if isinstance(field_list, str):
        return d[field_list]
    value = d
    for field in field_list:
        value = value[field]
    return value
"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import io
import unicodedata
import six


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Args:
        text: A text string or a utf-8 encoded byte string.

    Returns:
        The text as a unicode string; undecodable bytes are dropped.

    Raises:
        ValueError: if `text` is neither a text nor a byte string, or the
            interpreter is neither Python 2 nor Python 3.
    """
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`.

    Args:
        text: A text string or a byte string.

    Returns:
        A native `str`: unicode text on Python 3, a utf-8 byte string on
        Python 2.

    Raises:
        ValueError: if `text` is neither a text nor a byte string, or the
            interpreter is neither Python 2 nor Python 3.
    """

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary.

    Args:
        vocab_file: Path of a utf-8 text file holding one token per line.

    Returns:
        An OrderedDict mapping each whitespace-stripped line to its
        zero-based line number. If a token occurs more than once, the last
        line number wins.
    """
    vocab = collections.OrderedDict()
    # Decode explicitly as utf-8: the builtin open() would use the locale's
    # default encoding on Python 3, which corrupts or rejects non-ASCII tokens
    # on non-utf-8 systems. io.open behaves the same on Python 2 and 3.
    with io.open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[convert_to_unicode(line).strip()] = index
    return vocab


def convert_tokens_to_ids(vocab, tokens):
    """Converts a sequence of tokens into ids using the vocab.

    Raises:
        KeyError: if a token is missing from `vocab`.
    """
    return [vocab[token] for token in tokens]


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text.

    Returns:
        The whitespace-separated pieces of `text`; an empty list when `text`
        is empty or contains only whitespace.
    """
    text = text.strip()
    if not text:
        return []
    return text.split()


class FullTokenizer(object):
    """Runs end-to-end tokenization (basic tokenization, then WordPiece)."""

    def __init__(self, vocab_file, do_lower_case=True):
        """Loads the vocabulary and builds the two tokenizer stages.

        Args:
            vocab_file: Path of the vocabulary file read by `load_vocab`.
            do_lower_case: Forwarded to `BasicTokenizer`.
        """
        self.vocab = load_vocab(vocab_file)
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text, basic_done=False):
        """Splits `text` into WordPiece tokens.

        Args:
            text: A string, or a list of already basic-tokenized tokens when
                `basic_done` is True.
            basic_done: Skip the `BasicTokenizer` stage because `text` is
                already a list of tokens.

        Returns:
            A flat list of WordPiece tokens.

        Raises:
            TypeError: if `basic_done` is True but `text` is not a list.
        """
        if basic_done:
            # An explicit check (not `assert`) so it still runs under
            # `python -O`; otherwise a plain string would silently be
            # iterated character by character.
            if not isinstance(text, list):
                raise TypeError(
                    "text must be a list of tokens when basic_done=True, "
                    "got %s" % type(text))
        else:
            text = self.basic_tokenizer.tokenize(text)

        split_tokens = []
        for token in text:
            split_tokens.extend(self.wordpiece_tokenizer.tokenize(token))
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Maps `tokens` to ids using this tokenizer's vocabulary."""
        return convert_tokens_to_ids(self.vocab, tokens)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
          do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text.

        The text is cleaned, CJK characters are isolated with spaces, and the
        result is split on whitespace. Each token is then optionally
        lower-cased and accent-stripped, and finally split on punctuation.

        Returns:
            A list of token strings.
        """
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            # Accents are only stripped together with lower-casing.
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text.

        The text is NFD-decomposed and every character of Unicode category
        "Mn" (nonspacing mark) is dropped.
        """
        text = unicodedata.normalize("NFD", text)
        return "".join(
            char for char in text if unicodedata.category(char) != "Mn")

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text.

        Returns:
            A list in which every punctuation character is its own item and
            each run of other characters forms one item.
        """
        output = []
        start_new_word = True
        for char in text:
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
                (cp >= 0x3400 and cp <= 0x4DBF) or  #
                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or  #
                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text.

        NUL, U+FFFD and control characters are dropped; every whitespace
        character is replaced by a plain space.
        """
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        """Stores the vocabulary and tokenization limits.

        Args:
            vocab: Container of known word pieces (checked with `in`).
            unk_token: Token emitted for words that cannot be split.
            max_input_chars_per_word: Words longer than this are mapped to
                `unk_token` without attempting a split.
        """
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
          input = "unaffable"
          output = ["un", "##aff", "##able"]

        Args:
          text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.

        Returns:
          A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                # Try the longest remaining substring first and shrink it
                # from the right until a vocabulary entry is found.
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        # Pieces that continue a word carry the "##" prefix.
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    # No piece matches at this position: the whole word
                    # becomes the unknown token.
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

# --------------------------------------------------------------------------------
# /tokenizer_util.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base tokenizer/tokens classes and utilities."""

import copy


class Tokens(object):
    """A class to represent a list of tokenized text.

    Each entry of `data` is an indexable record; the class constants below
    are the positions of the individual fields inside such a record.
    """
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        """Stores the token records.

        Args:
            data: List of per-token records.
            annotators: Container of annotation names ('pos', 'lemma',
                'ner') that are available in `data`.
            opts: Optional dict of options; defaults to an empty dict.
        """
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a shallow copy of this object restricted to tokens [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True to discard the ngram or False to keep it
            as_strings: return each ngram as a space-joined string; when
              False, return (start, end) token index pairs (end exclusive)
        """
        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = [' '.join(words[s:e]) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag.

        Returns:
            A list of (text, tag) tuples, one per run of identical tags other
            than the non-entity tag (option 'non_ent', default 'O'). Returns
            None when NER tags are unavailable or there are no tokens.
        """
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """
    def tokenize(self, text):
        """Tokenize `text`; must be overridden by subclasses."""
        raise NotImplementedError

    def shutdown(self):
        """Release tokenizer resources; a no-op in the base class."""
        pass

    def __del__(self):
        # Run shutdown() when the tokenizer object is destroyed.
        self.shutdown()

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def check_diff(model_a, model_b):
    """Log the differences between the key sets of two mappings.

    Nothing is logged when both key sets are equal, and nothing is returned.

    Args:
        model_a: Mapping whose extra keys are reported as missing from the
            loaded weights (presumably the model's own state dict; confirm
            against the caller).
        model_b: Mapping whose extra keys are reported as missing from the
            model code (presumably the loaded checkpoint; confirm against
            the caller).
    """
    a_set = set(model_a.keys())
    b_set = set(model_b.keys())
    if a_set != b_set:
        logger.info('load with different params =>')
    only_in_a = a_set - b_set
    if only_in_a:
        logger.info('Loaded weight does not have %s', only_in_a)
    only_in_b = b_set - a_set
    if only_in_b:
        logger.info('Model code does not have: %s', only_in_b)

# --------------------------------------------------------------------------------