├── .gitignore
├── LICENSE
├── README.md
├── build_tfidf.py
├── environment.yml
├── eval_utils.py
├── evaluate-v1.1.py
├── file_utils.py
├── input_examples.txt
├── local_dump.py
├── mips_phrase.py
├── modeling.py
├── optimization.py
├── post.py
├── pre.py
├── run_index.py
├── run_server.py
├── simple_tokenizer.py
├── static
├── examples.txt
├── files
│ ├── all.js
│ ├── bootstrap.min.js
│ ├── jquery-3.3.1.min.js
│ ├── pichu.png
│ ├── pika.png
│ ├── popper.min.js
│ └── style.css
├── index.html
└── preview.png
├── tfidf_doc_ranker.py
├── tfidf_util.py
├── tokenization.py
├── tokenizer_util.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | out
3 | test
4 | __pycache__
5 | logs/
6 | pred/
7 | models/
8 | dumps*/
9 | data/
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 University of Washington
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Sparc
3 |
4 |
5 |
6 |
7 |
8 |
C ontextualized Spar se Representations for Real-Time Open-Domain Question Answering
9 |
10 |
11 |
12 |
13 |
14 |
15 | This repository provides author's implementation of [Contextualized Sparse Representation for Real-Time Open-Domain Question Answering](https://arxiv.org/abs/1911.02896). You can train and evaluate DenSPI+Sparc described in our paper and make your own Sparc vector.
16 |
17 | ## Environment
18 | Please install the Conda environment as follows:
19 | ```bash
20 | $ conda env create -f environment.yml
21 | $ conda activate sparc
22 | ```
23 | Note that this repository is mostly based on [DenSPI](https://github.com/uwnlp/denspi) and [DrQA](https://github.com/facebookresearch/DrQA).
24 |
25 | ## Resources
26 | We use [SQuAD v1.1](https://github.com/rajpurkar/SQuAD-explorer/tree/master/dataset) for training DenSPI+Sparc. Please download them in `$DATA_DIR`.
27 | ```bash
28 | $ mkdir $DATA_DIR
29 | $ wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v1.1.json -O $DATA_DIR/train-v1.1.json
30 | $ wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v1.1.json -O $DATA_DIR/dev-v1.1.json
31 | ```
32 |
33 | DenSPI is based on BERT. Please download pre-trained weights of BERT under `$BERT_DIR`.
34 | ```bash
35 | $ mkdir $BERT_DIR
36 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin -O $BERT_DIR/pytorch_model_base_uncased.bin
37 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json -O $BERT_DIR/bert_config_base_uncased.json
38 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin -O $BERT_DIR/pytorch_model_large_uncased.bin
39 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json -O $BERT_DIR/bert_config_large_uncased.json
40 | # Vocabulary is the same for BERT-base and BERT-large.
41 | $ wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt -O $BERT_DIR/vocab.txt
42 | ```
43 |
44 | ## Model
45 | To train DenSPI+Sparc on SQuAD, use `train.py`. Trained models will be saved in `$OUT_DIR1`.
46 | ```bash
47 | $ mkdir $OUT_DIR1
48 | # Train with BERT-base
49 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR1 --bert_model_option 'base_uncased' --train_file train-v1.1.json --predict_file dev-v1.1.json --do_train --do_predict --do_eval
50 | # Train with BERT-large (use smaller train_batch_size for 12GB GPUs)
51 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR1 --bert_model_option 'large_uncased' --parallel --train_file train-v1.1.json --predict_file dev-v1.1.json --do_train --do_predict --do_eval --train_batch_size 6
52 | ```
53 |
54 | The result will look like (in case of BERT-base):
55 | ```bash
56 | 04/28/2020 06:32:59 - INFO - post - num vecs=45059736, num_words=1783576, nvpw=25.2637
57 | 04/28/2020 06:33:01 - INFO - __main__ - [Validation] loss: 8.700, b'{"exact_match": 75.10879848628193, "f1": 83.42143097917004}\n'
58 | ```
59 |
60 | To use DenSPI+Sparc in an open-domain setting, you have to additionally train it with negative samples. In case of DenSPI+Sparc with BERT-base (same for BERT-large except `--bert_model_option` and `--parallel` arguments), commands for training on negative samples are:
61 | ```bash
62 | $ mkdir $OUT_DIR2
63 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR --bert_model_option 'base_uncased' --train_file train-v1.1.json --predict_file dev-v1.1.json --do_train_neg --do_predict --do_eval --do_load --load_dir $OUT_DIR1 --load_epoch 3
64 | ```
65 |
66 | Finally, train the phrase classifer as:
67 | ```bash
68 | $ mkdir $OUT_DIR3
69 | # Train only 1 epoch for phrase classifier
70 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR --bert_model_option 'base_uncased' --train_file train-v1.1.json --predict_file dev-v1.1.json --num_train_epochs 1 --do_train_filter --do_predict --do_eval --do_load --load_dir $OUT_DIR2 --load_epoch 3
71 | ```
72 |
73 | We also provide a pretrained DenSPI+Sparc as follows:
74 | * DenSPI+Sparc pre-trained on SQuAD - [link](https://drive.google.com/file/d/1lObQ2lX8bWwJRzUuEqH6kpPdSTmS_Zxw/view?usp=sharing)
75 |
76 |
77 | ## Sparc Embedding
78 | Given the pre-trained DenSPI+Sparc, you can get Sparc embedding with following commands. Example below assumes using our pre-trained weight ([`denspi_sparc.zip`](https://drive.google.com/file/d/1lObQ2lX8bWwJRzUuEqH6kpPdSTmS_Zxw/view?usp=sharing) unzipped in `denspi_sparc` folder). If you want to use your own model, please modify `MODEL_DIR` accordingly.
79 |
80 | For any type of text you want to embed, put them in each line of `input_examples.txt`. Then run:
81 | ```bash
82 | $ export DATA_DIR=.
83 | $ export MODEL_DIR=denspi_sparc
84 | $ python train.py --data_dir $DATA_DIR --metadata_dir $BERT_DIR --output_dir $OUT_DIR --predict_file input_examples.txt --parallel --bert_model_option 'large_uncased' --do_load --load_dir $MODEL_DIR --load_epoch 1 --do_embed --dump_file output.json
85 | ```
86 |
87 | The result file `$OUT_DIR/output.json` will show Sparc embedding of the input text ([CLS] representation, sorted by scores). For instance:
88 | ```json
89 | {
90 | "out": [
91 | {
92 | "text": "They defeated the Arizona Cardinals 49-15 in the NFC Championship Game and advanced to their second Super Bowl appearance since the franchise was founded in 1995.",
93 | "sparc": {
94 | "uni": {
95 | "1995": {
96 | "score": 1.6841894388198853,
97 | "vocab": "2786"
98 | },
99 | "second": {
100 | "score": 1.6321970224380493,
101 | "vocab": "2117"
102 | },
103 | "49": {
104 | "score": 1.6075607538223267,
105 | "vocab": "4749"
106 | },
107 | "arizona": {
108 | "score": 1.1734912395477295,
109 | "vocab": "5334"
110 | },
111 | },
112 | "bi": {
113 | "arizona cardinals": {
114 | "score": 1.3190401792526245,
115 | "vocab": "5334, 9310"
116 | },
117 | "nfc championship": {
118 | "score": 1.1005975008010864,
119 | "vocab": "22309, 2528"
120 | },
121 | "49 -": {
122 | "score": 1.0863999128341675,
123 | "vocab": "4749, 1011"
124 | },
125 | "the arizona": {
126 | "score": 0.9722453951835632,
127 | "vocab": "1996, 5334"
128 | },
129 | }
130 | }
131 | }
132 | ]
133 | }
134 | ```
135 | Note that each text is segmented by the BERT tokenizer (`"vocab"` denotes the BERT vocab index).
136 |
137 | To see how Sparc changes for each phrase, set `start_index` in [here](https://github.com/jhyuklee/sparc/blob/750bf1a2b79f0e074edb77ef535c7e2861ffd8fd/post.py#L371) to the target token position. For instance, setting `start_index = 17` to embed Sparc of `415,000` of the following text gives you (some n-grams are omitted):
138 |
139 | ```json
140 | "text": "Between 1991 and 2000, the total area of forest lost in the Amazon rose from 415,000 to 587,000 square kilometres.",
141 | "sparc": {
142 | "uni": {
143 | "1991": {
144 | "score": 1.182684063911438,
145 | "vocab": "2889"
146 | },
147 | "2000": {
148 | "score": 0.41507360339164734,
149 | "vocab": "2456"
150 | },
151 | ```
152 | whereas setting `start_index = 21` to embed Sparc of `587,000` gives you:
153 | ```json
154 | "text": "Between 1991 and 2000, the total area of forest lost in the Amazon rose from 415,000 to 587,000 square kilometres.",
155 | "sparc": {
156 | "uni": {
157 | "2000": {
158 | "score": 1.1923936605453491,
159 | "vocab": "2456"
160 | },
161 | "1991": {
162 | "score": 0.7090237140655518,
163 | "vocab": "2889"
164 | },
165 | ```
166 |
167 | ## Phrase Index
168 | For now, please see [the original DenSPI repository](https://github.com/uwnlp/denspi) or [the recent application of DenSPI in COVID-19 domain](https://github.com/dmis-lab/covidAsk) for building phrase index using DenSPI+Sparc.
169 | The main changes in phrase indexing are in `post.py` and `mips_phrase.py` where Sparc is used for the open-domain QA inference (See [here](https://github.com/jhyuklee/sparc/blob/885729372706e227fa9c566ca51bd88de984710a/mips_phrase.py#L390-L410)).
170 |
171 | ## Reference
172 | ```bibtex
173 | @inproceedings{lee2020contextualized,
174 | title={Contextualized Sparse Representations for Real-Time Open-Domain Question Answering},
175 | author={Lee, Jinhyuk and Seo, Minjoon and Hajishirzi, Hannaneh and Kang, Jaewoo},
176 | booktitle={ACL},
177 | year={2020}
178 | }
179 | ```
180 |
--------------------------------------------------------------------------------
/build_tfidf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """A script to build the tf-idf document matrices for retrieval."""
8 |
9 | import numpy as np
10 | import scipy.sparse as sp
11 | import argparse
12 | import os
13 | import math
14 | import logging
15 | import json
16 | import copy
17 | import pandas as pd
18 |
19 | from multiprocessing import Pool as ProcessPool
20 | from multiprocessing.util import Finalize
21 | from functools import partial
22 | from collections import Counter
23 |
24 | import tfidf_util
25 | from simple_tokenizer import SimpleTokenizer
26 |
27 | logger = logging.getLogger()
28 | logger.setLevel(logging.INFO)
29 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
30 | console = logging.StreamHandler()
31 | console.setFormatter(fmt)
32 | logger.addHandler(console)
33 |
34 |
35 | # ------------------------------------------------------------------------------
36 | # Multiprocessing functions
37 | # ------------------------------------------------------------------------------
38 |
39 | DOC2IDX = None
40 | PROCESS_TOK = None
41 | PROCESS_DB = None
42 |
43 |
44 | def init(tokenizer_class, db):
45 | global PROCESS_TOK, PROCESS_DB
46 | PROCESS_TOK = tokenizer_class()
47 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
48 | PROCESS_DB = db
49 |
50 |
51 | def fetch_text(doc_id):
52 | global PROCESS_DB
53 | return PROCESS_DB[doc_id]
54 |
55 |
56 | def tokenize(text):
57 | global PROCESS_TOK
58 | return PROCESS_TOK.tokenize(text)
59 |
60 |
61 | # ------------------------------------------------------------------------------
62 | # Build article --> word count sparse matrix.
63 | # ------------------------------------------------------------------------------
64 |
65 |
66 | def count(ngram, hash_size, doc_id):
67 | """Fetch the text of a document and compute hashed ngrams counts."""
68 | global DOC2IDX
69 | row, col, data = [], [], []
70 | # Tokenize
71 | tokens = tokenize(tfidf_util.normalize(fetch_text(doc_id)))
72 |
73 | # Get ngrams from tokens, with stopword/punctuation filtering.
74 | ngrams = tokens.ngrams(
75 | n=ngram, uncased=True, filter_fn=tfidf_util.filter_ngram
76 | )
77 |
78 | # Hash ngrams and count occurences
79 | counts = Counter([tfidf_util.hash(gram, hash_size) for gram in ngrams])
80 |
81 | # Return in sparse matrix data format.
82 | row.extend(counts.keys())
83 | col.extend([DOC2IDX[doc_id]] * len(counts))
84 | data.extend(counts.values())
85 | return row, col, data
86 |
87 |
88 | def get_count_matrix(args, file_path):
89 | """Form a sparse word to document count matrix (inverted index).
90 |
91 | M[i, j] = # times word i appears in document j.
92 | """
93 | # Map doc_ids to indexes
94 | global DOC2IDX
95 | doc_ids = {}
96 | doc_metas = {}
97 | nan_cnt = 0
98 | for filename in sorted(os.listdir(file_path)):
99 | print(filename)
100 | with open(os.path.join(file_path, filename), 'r') as f:
101 | articles = json.load(f)['data']
102 | for article in articles:
103 | title = article['title']
104 | kk = 0
105 | while title in doc_ids:
106 | title += f'_{kk}'
107 | kk += 1
108 | doc_ids[title] = ' '.join([par['context'] for par in article['paragraphs']])
109 |
110 | # Keep metadata
111 | doc_meta = {}
112 | for key, val in article.items():
113 | if key != 'paragraphs':
114 | doc_meta[key] = val if val == val else 'NaN'
115 | else:
116 | doc_meta[key] = []
117 | for para in val:
118 | para_meta = {}
119 | for para_key, para_val in para.items():
120 | if para_key != 'context':
121 | para_meta[para_key] = para_val if para_val == para_val else 'NaN'
122 | doc_meta[key].append(para_meta)
123 | if not pd.isnull(article.get('pubmed_id', np.nan)):
124 | doc_metas[str(article['pubmed_id'])] = doc_meta # For BEST (might be duplicate)
125 | else:
126 | nan_cnt += 1
127 | doc_metas[article['title']] = doc_meta
128 |
129 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)}
130 | print('doc ids:', len(DOC2IDX))
131 | print('doc metas:', len(doc_metas), 'with nan', str(nan_cnt))
132 | # assert len(doc_ids)*2 == len(doc_metas) + nan_cnt
133 |
134 | # Setup worker pool
135 | tok_class = SimpleTokenizer
136 | workers = ProcessPool(
137 | args.num_workers,
138 | initializer=init,
139 | initargs=(tok_class, doc_ids)
140 | )
141 | doc_ids = list(doc_ids.keys())
142 |
143 | # Compute the count matrix in steps (to keep in memory)
144 | logger.info('Mapping...')
145 | row, col, data = [], [], []
146 | step = max(int(len(doc_ids) / 10), 1)
147 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)]
148 | _count = partial(count, args.ngram, args.hash_size)
149 | for i, batch in enumerate(batches):
150 | logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25)
151 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch):
152 | row.extend(b_row)
153 | col.extend(b_col)
154 | data.extend(b_data)
155 | workers.close()
156 | workers.join()
157 |
158 | logger.info('Creating sparse matrix...')
159 | count_matrix = sp.csr_matrix(
160 | (data, (row, col)), shape=(args.hash_size, len(doc_ids))
161 | )
162 | count_matrix.sum_duplicates()
163 | return count_matrix, (DOC2IDX, doc_ids, doc_metas)
164 |
165 |
166 | # ------------------------------------------------------------------------------
167 | # Transform count matrix to different forms.
168 | # ------------------------------------------------------------------------------
169 |
170 |
171 | def get_tfidf_matrix(cnts):
172 | """Convert the word count matrix into tfidf one.
173 |
174 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
175 | * tf = term frequency in document
176 | * N = number of documents
177 | * Nt = number of occurences of term in all documents
178 | """
179 | Ns = get_doc_freqs(cnts)
180 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5))
181 | idfs[idfs < 0] = 0
182 | idfs = sp.diags(idfs, 0)
183 | tfs = cnts.log1p()
184 | tfidfs = idfs.dot(tfs)
185 | return tfidfs
186 |
187 |
188 | def get_doc_freqs(cnts):
189 | """Return word --> # of docs it appears in."""
190 | binary = (cnts > 0).astype(int)
191 | freqs = np.array(binary.sum(1)).squeeze()
192 | return freqs
193 |
194 |
195 | # ------------------------------------------------------------------------------
196 | # Main.
197 | # ------------------------------------------------------------------------------
198 |
199 |
200 | if __name__ == '__main__':
201 | parser = argparse.ArgumentParser()
202 | parser.add_argument('file_path', type=str, default=None,
203 | help='Path to document texts')
204 | parser.add_argument('out_dir', type=str, default=None,
205 | help='Directory for saving output files')
206 | parser.add_argument('--ngram', type=int, default=2,
207 | help=('Use up to N-size n-grams '
208 | '(e.g. 2 = unigrams + bigrams)'))
209 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)),
210 | help='Number of buckets to use for hashing ngrams')
211 | parser.add_argument('--num-workers', type=int, default=None,
212 | help='Number of CPU processes (for tokenizing, etc)')
213 | args = parser.parse_args()
214 |
215 | logging.info('Counting words...')
216 | count_matrix, doc_dict = get_count_matrix(
217 | args, args.file_path
218 | )
219 |
220 | logger.info('Making tfidf vectors...')
221 | tfidf = get_tfidf_matrix(count_matrix)
222 |
223 | logger.info('Getting word-doc frequencies...')
224 | freqs = get_doc_freqs(count_matrix)
225 |
226 | basename = os.path.splitext(os.path.basename(args.file_path))[0]
227 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=simple' %
228 | (args.ngram, args.hash_size))
229 | filename = os.path.join(args.out_dir, basename)
230 |
231 | logger.info('Saving to %s.npz' % filename)
232 | metadata = {
233 | 'doc_freqs': freqs,
234 | 'tokenizer': 'simple',
235 | 'hash_size': args.hash_size,
236 | 'ngram': args.ngram,
237 | 'doc_dict': doc_dict
238 | }
239 | tfidf_util.save_sparse_csr(filename, tfidf, metadata)
240 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: sparc
2 | channels:
3 | - cyclus
4 | - pytorch
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - blas=1.0=mkl
10 | - ca-certificates=2020.1.1=0
11 | - certifi=2019.11.28=py36_0
12 | - cudatoolkit=10.0.130=0
13 | - curl=7.65.2=hbc83047_0
14 | - faiss-gpu=1.6.1=py36h1a5d453_0
15 | - htop=2.2.0=hf8c457e_1000
16 | - intel-openmp=2019.4=243
17 | - java-jdk=8.45.14=0
18 | - krb5=1.16.1=h173b8e3_7
19 | - libcurl=7.65.2=h20c2e04_0
20 | - libedit=3.1.20181209=hc058e9b_0
21 | - libffi=3.2.1=hd88cf55_4
22 | - libgcc-ng=9.1.0=hdf63c60_0
23 | - libgfortran-ng=7.3.0=hdf63c60_0
24 | - libssh2=1.8.2=h1ba5d50_0
25 | - libstdcxx-ng=9.1.0=hdf63c60_0
26 | - mkl=2020.0=166
27 | - mkl-service=2.3.0=py36he904b0f_0
28 | - mkl_fft=1.0.15=py36ha843d7b_0
29 | - mkl_random=1.1.0=py36hd6b4f25_0
30 | - ncurses=6.1=he6710b0_1
31 | - numpy-base=1.18.1=py36hde5b4d6_1
32 | - openssl=1.1.1e=h7b6447c_0
33 | - python=3.6.9=h265db76_0
34 | - readline=7.0=h7b6447c_5
35 | - setuptools=41.0.1=py36_0
36 | - sqlite=3.29.0=h7b6447c_0
37 | - tk=8.6.8=hbc83047_0
38 | - wheel=0.33.4=py36_0
39 | - xz=5.2.4=h14c3975_4
40 | - zlib=1.2.11=h7b6447c_3
41 | - pip:
42 | - backcall==0.1.0
43 | - chardet==3.0.4
44 | - click==7.0
45 | - decorator==4.4.2
46 | - elasticsearch==7.0.2
47 | - flask==1.0.2
48 | - flask-cors==3.0.7
49 | - h5py==2.9.0
50 | - idna==2.8
51 | - ipython==7.13.0
52 | - ipython-genutils==0.2.0
53 | - itsdangerous==1.1.0
54 | - jedi==0.16.0
55 | - jinja2==2.10.1
56 | - joblib==0.13.2
57 | - markupsafe==1.1.1
58 | - nltk==3.4.4
59 | - numpy==1.17.0
60 | - pandas==0.23.0
61 | - parso==0.6.2
62 | - pexpect==4.2.1
63 | - pickleshare==0.7.5
64 | - pip==20.0.2
65 | - prettytable==0.7.2
66 | - prompt-toolkit==3.0.4
67 | - ptyprocess==0.6.0
68 | - pygments==2.6.1
69 | - python-dateutil==2.8.1
70 | - pytz==2019.3
71 | - regex==2019.6.8
72 | - requests==2.22.0
73 | - requests-futures==0.9.9
74 | - scikit-learn==0.21.3
75 | - scipy==1.1.0
76 | - six==1.12.0
77 | - sklearn==0.0
78 | - termcolor==1.1.0
79 | - torch==1.1.0
80 | - tornado==6.0.1
81 | - tqdm==4.31.1
82 | - traitlets==4.3.3
83 | - ujson==2.0.2
84 | - urllib3==1.25.3
85 | - wcwidth==0.1.8
86 | - werkzeug==0.15.5
87 | prefix: /home/jinhyuk/miniconda3/envs/sparc
88 |
89 |
--------------------------------------------------------------------------------
/eval_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ujson as json
3 | import re
4 | import string
5 | import unicodedata
6 | from collections import Counter
7 | import pickle
8 | from IPython import embed
9 |
10 | def normalize_answer(s):
11 |
12 | def remove_articles(text):
13 | return re.sub(r'\b(a|an|the)\b', ' ', text)
14 |
15 | def white_space_fix(text):
16 | return ' '.join(text.split())
17 |
18 | def remove_punc(text):
19 | exclude = set(string.punctuation)
20 | return ''.join(ch for ch in text if ch not in exclude)
21 |
22 | def lower(text):
23 | return text.lower()
24 |
25 | return white_space_fix(remove_articles(remove_punc(lower(s))))
26 |
27 |
28 | def f1_score(prediction, ground_truth):
29 | normalized_prediction = normalize_answer(prediction)
30 | normalized_ground_truth = normalize_answer(ground_truth)
31 |
32 | ZERO_METRIC = (0, 0, 0)
33 |
34 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
35 | return ZERO_METRIC
36 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
37 | return ZERO_METRIC
38 |
39 | prediction_tokens = normalized_prediction.split()
40 | ground_truth_tokens = normalized_ground_truth.split()
41 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
42 | num_same = sum(common.values())
43 | if num_same == 0:
44 | return ZERO_METRIC
45 | precision = 1.0 * num_same / len(prediction_tokens)
46 | recall = 1.0 * num_same / len(ground_truth_tokens)
47 | f1 = (2 * precision * recall) / (precision + recall)
48 | return f1, precision, recall
49 |
50 |
51 | def exact_match_score(prediction, ground_truth):
52 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
53 |
54 |
55 | def drqa_normalize(text):
56 | """Resolve different type of unicode encodings."""
57 | return unicodedata.normalize('NFD', text)
58 |
59 |
60 | def drqa_exact_match_score(prediction, ground_truth):
61 | """Check if the prediction is a (soft) exact match with the ground truth."""
62 | return normalize_answer(prediction) == normalize_answer(ground_truth)
63 |
64 |
65 | def drqa_regex_match_score(prediction, pattern):
66 | """Check if the prediction matches the given regular expression."""
67 | try:
68 | compiled = re.compile(
69 | pattern,
70 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE
71 | )
72 | except BaseException as e:
73 | # logger.warn('Regular expression failed to compile: %s' % pattern)
74 | # print('re failed to compile: [%s] due to [%s]' % (pattern, e))
75 | return False
76 | return compiled.match(prediction) is not None
77 |
78 |
79 | def drqa_metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
80 | """Given a prediction and multiple valid answers, return the score of
81 | the best prediction-answer_n pair given a metric function.
82 | """
83 | scores_for_ground_truths = []
84 | for ground_truth in ground_truths:
85 | score = metric_fn(prediction, ground_truth)
86 | scores_for_ground_truths.append(score)
87 | return max(scores_for_ground_truths)
88 |
89 |
90 | def update_answer(metrics, prediction, gold):
91 | em = exact_match_score(prediction, gold)
92 | f1, prec, recall = f1_score(prediction, gold)
93 | metrics['em'] += em
94 | metrics['f1'] += f1
95 | metrics['prec'] += prec
96 | metrics['recall'] += recall
97 | return em, prec, recall
98 |
99 |
100 | def update_sp(metrics, prediction, gold):
101 | cur_sp_pred = set(map(tuple, prediction))
102 | gold_sp_pred = set(map(tuple, gold))
103 | tp, fp, fn = 0, 0, 0
104 | for e in cur_sp_pred:
105 | if e in gold_sp_pred:
106 | tp += 1
107 | else:
108 | fp += 1
109 | for e in gold_sp_pred:
110 | if e not in cur_sp_pred:
111 | fn += 1
112 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
113 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
114 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
115 | em = 1.0 if fp + fn == 0 else 0.0
116 | metrics['sp_em'] += em
117 | metrics['sp_f1'] += f1
118 | metrics['sp_prec'] += prec
119 | metrics['sp_recall'] += recall
120 | return em, prec, recall
121 |
122 |
123 | def eval(prediction_file, gold_file):
124 | with open(prediction_file) as f:
125 | prediction = json.load(f)
126 | with open(gold_file) as f:
127 | gold = json.load(f)
128 |
129 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
130 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
131 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
132 |
133 | for dp in gold:
134 | cur_id = dp['_id']
135 | em, prec, recall = update_answer(
136 | metrics, prediction['answer'][cur_id], dp['answer'])
137 |
138 | N = len(gold)
139 | for k in metrics.keys():
140 | metrics[k] /= N
141 |
142 | print(metrics)
143 |
144 |
145 | def analyze(prediction_file, gold_file):
146 | with open(prediction_file) as f:
147 | prediction = json.load(f)
148 | with open(gold_file) as f:
149 | gold = json.load(f)
150 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
151 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
152 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
153 |
154 | for dp in gold:
155 | cur_id = dp['_id']
156 |
157 | em, prec, recall = update_answer(
158 | metrics, prediction['answer'][cur_id], dp['answer'])
159 | if (prec + recall == 0):
160 | f1 = 0
161 | else:
162 | f1 = 2 * prec * recall / (prec+recall)
163 |
164 | print (dp['answer'], prediction['answer'][cur_id])
165 | print (f1, em)
166 | a = input()
167 |
168 |
169 | if __name__ == '__main__':
170 | #eval(sys.argv[1], sys.argv[2])
171 | analyze(sys.argv[1], sys.argv[2])
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
/evaluate-v1.1.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import sys
9 |
10 |
11 | def normalize_answer(s):
12 | """Lower text and remove punctuation, articles and extra whitespace."""
13 | def remove_articles(text):
14 | return re.sub(r'\b(a|an|the)\b', ' ', text)
15 |
16 | def white_space_fix(text):
17 | return ' '.join(text.split())
18 |
19 | def remove_punc(text):
20 | exclude = set(string.punctuation)
21 | return ''.join(ch for ch in text if ch not in exclude)
22 |
23 | def lower(text):
24 | return text.lower()
25 |
26 | return white_space_fix(remove_articles(remove_punc(lower(s))))
27 |
28 |
29 | def f1_score(prediction, ground_truth):
30 | prediction_tokens = normalize_answer(prediction).split()
31 | ground_truth_tokens = normalize_answer(ground_truth).split()
32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
33 | num_same = sum(common.values())
34 | if num_same == 0:
35 | return 0
36 | precision = 1.0 * num_same / len(prediction_tokens)
37 | recall = 1.0 * num_same / len(ground_truth_tokens)
38 | f1 = (2 * precision * recall) / (precision + recall)
39 | return f1
40 |
41 |
42 | def exact_match_score(prediction, ground_truth):
43 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
44 |
45 |
46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
47 | scores_for_ground_truths = []
48 | for ground_truth in ground_truths:
49 | score = metric_fn(prediction, ground_truth)
50 | scores_for_ground_truths.append(score)
51 | return max(scores_for_ground_truths)
52 |
53 |
54 | def evaluate(dataset, predictions):
55 | count = 0
56 | f1 = exact_match = total = 0
57 | for article in dataset:
58 | for paragraph in article['paragraphs']:
59 | for qa in paragraph['qas']:
60 | total += 1
61 | if str(qa['id']) not in predictions:
62 | message = 'Unanswered question ' + str(qa['id']) + \
63 | ' will receive score 0.'
64 | count += 1
65 | # print(message, file=sys.stderr)
66 | continue
67 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
68 | prediction = predictions[str(qa['id'])]
69 | exact_match += metric_max_over_ground_truths(
70 | exact_match_score, prediction, ground_truths)
71 | f1 += metric_max_over_ground_truths(
72 | f1_score, prediction, ground_truths)
73 |
74 | exact_match = 100.0 * exact_match / total
75 | f1 = 100.0 * f1 / total
76 | if count:
77 | print('There are %d unanswered question(s)' % count)
78 |
79 | return {'exact_match': exact_match, 'f1': f1}
80 |
81 |
82 | if __name__ == '__main__':
83 | expected_version = '1.1'
84 | parser = argparse.ArgumentParser(
85 | description='Evaluation for SQuAD ' + expected_version)
86 | parser.add_argument('dataset_file', help='Dataset file')
87 | parser.add_argument('prediction_file', help='Prediction File')
88 | args = parser.parse_args()
89 | with open(args.dataset_file) as dataset_file:
90 | dataset_json = json.load(dataset_file)
91 | if (dataset_json['version'] != expected_version):
92 | print('Evaluation expects v-' + expected_version +
93 | ', but got dataset with v-' + dataset_json['version'],
94 | file=sys.stderr)
95 | dataset = dataset_json['data']
96 | with open(args.prediction_file) as prediction_file:
97 | predictions = json.load(prediction_file)
98 | print(json.dumps(evaluate(dataset, predictions)))
99 |
--------------------------------------------------------------------------------
/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 |
18 | import boto3
19 | import requests
20 | from botocore.exceptions import ClientError
21 | from tqdm import tqdm
22 |
23 | try:
24 | from urllib.parse import urlparse
25 | except ImportError:
26 | from urlparse import urlparse
27 |
28 | try:
29 | from pathlib import Path
30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31 | Path.home() / '.pytorch_pretrained_bert'))
32 | except (AttributeError, ImportError):
33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35 |
36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37 |
38 |
39 | def url_to_filename(url, etag=None):
40 | """
41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited
43 | by a period.
44 | """
45 | url_bytes = url.encode('utf-8')
46 | url_hash = sha256(url_bytes)
47 | filename = url_hash.hexdigest()
48 |
49 | if etag:
50 | etag_bytes = etag.encode('utf-8')
51 | etag_hash = sha256(etag_bytes)
52 | filename += '.' + etag_hash.hexdigest()
53 |
54 | return filename
55 |
56 |
57 | def filename_to_url(filename, cache_dir=None):
58 | """
59 | Return the url and etag (which may be ``None``) stored for `filename`.
60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 | """
62 | if cache_dir is None:
63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 | cache_dir = str(cache_dir)
66 |
67 | cache_path = os.path.join(cache_dir, filename)
68 | if not os.path.exists(cache_path):
69 | raise EnvironmentError("file {} not found".format(cache_path))
70 |
71 | meta_path = cache_path + '.json'
72 | if not os.path.exists(meta_path):
73 | raise EnvironmentError("file {} not found".format(meta_path))
74 |
75 | with open(meta_path, encoding="utf-8") as meta_file:
76 | metadata = json.load(meta_file)
77 | url = metadata['url']
78 | etag = metadata['etag']
79 |
80 | return url, etag
81 |
82 |
83 | def cached_path(url_or_filename, cache_dir=None):
84 | """
85 | Given something that might be a URL (or might be a local path),
86 | determine which. If it's a URL, download the file and cache it, and
87 | return the path to the cached file. If it's already a local path,
88 | make sure the file exists and then return the path.
89 | """
90 | if cache_dir is None:
91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 | url_or_filename = str(url_or_filename)
94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 | cache_dir = str(cache_dir)
96 |
97 | parsed = urlparse(url_or_filename)
98 |
99 | if parsed.scheme in ('http', 'https', 's3'):
100 | # URL, so get it from the cache (downloading if necessary)
101 | return get_from_cache(url_or_filename, cache_dir)
102 | elif os.path.exists(url_or_filename):
103 | # File, and it exists.
104 | return url_or_filename
105 | elif parsed.scheme == '':
106 | # File, but it doesn't exist.
107 | raise EnvironmentError("file {} not found".format(url_or_filename))
108 | else:
109 | # Something unknown
110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111 |
112 |
113 | def split_s3_path(url):
114 | """Split a full s3 path into the bucket name and path."""
115 | parsed = urlparse(url)
116 | if not parsed.netloc or not parsed.path:
117 | raise ValueError("bad s3 path {}".format(url))
118 | bucket_name = parsed.netloc
119 | s3_path = parsed.path
120 | # Remove '/' at beginning of path.
121 | if s3_path.startswith("/"):
122 | s3_path = s3_path[1:]
123 | return bucket_name, s3_path
124 |
125 |
126 | def s3_request(func):
127 | """
128 | Wrapper function for s3 requests in order to create more helpful error
129 | messages.
130 | """
131 |
132 | @wraps(func)
133 | def wrapper(url, *args, **kwargs):
134 | try:
135 | return func(url, *args, **kwargs)
136 | except ClientError as exc:
137 | if int(exc.response["Error"]["Code"]) == 404:
138 | raise EnvironmentError("file {} not found".format(url))
139 | else:
140 | raise
141 |
142 | return wrapper
143 |
144 |
145 | @s3_request
146 | def s3_etag(url):
147 | """Check ETag on S3 object."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_object = s3_resource.Object(bucket_name, s3_path)
151 | return s3_object.e_tag
152 |
153 |
154 | @s3_request
155 | def s3_get(url, temp_file):
156 | """Pull a file directly from S3."""
157 | s3_resource = boto3.resource("s3")
158 | bucket_name, s3_path = split_s3_path(url)
159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160 |
161 |
162 | def http_get(url, temp_file):
163 | req = requests.get(url, stream=True)
164 | content_length = req.headers.get('Content-Length')
165 | total = int(content_length) if content_length is not None else None
166 | progress = tqdm(unit="B", total=total)
167 | for chunk in req.iter_content(chunk_size=1024):
168 | if chunk: # filter out keep-alive new chunks
169 | progress.update(len(chunk))
170 | temp_file.write(chunk)
171 | progress.close()
172 |
173 |
174 | def get_from_cache(url, cache_dir=None):
175 | """
176 | Given a URL, look for the corresponding dataset in the local cache.
177 | If it's not there, download it. Then return the path to the cached file.
178 | """
179 | if cache_dir is None:
180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182 | cache_dir = str(cache_dir)
183 |
184 | if not os.path.exists(cache_dir):
185 | os.makedirs(cache_dir)
186 |
187 | # Get eTag to add to filename, if it exists.
188 | if url.startswith("s3://"):
189 | etag = s3_etag(url)
190 | else:
191 | response = requests.head(url, allow_redirects=True)
192 | if response.status_code != 200:
193 | raise IOError("HEAD request failed for url {} with status code {}"
194 | .format(url, response.status_code))
195 | etag = response.headers.get("ETag")
196 |
197 | filename = url_to_filename(url, etag)
198 |
199 | # get cache path to put the file
200 | cache_path = os.path.join(cache_dir, filename)
201 |
202 | if not os.path.exists(cache_path):
203 | # Download to temporary file, then copy to cache dir once finished.
204 | # Otherwise you get corrupt cache entries if the download gets interrupted.
205 | with tempfile.NamedTemporaryFile() as temp_file:
206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207 |
208 | # GET file object
209 | if url.startswith("s3://"):
210 | s3_get(url, temp_file)
211 | else:
212 | http_get(url, temp_file)
213 |
214 | # we are copying the file before closing it, so flush to avoid truncation
215 | temp_file.flush()
216 | # shutil.copyfileobj() starts at the current position, so go to the start
217 | temp_file.seek(0)
218 |
219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220 | with open(cache_path, 'wb') as cache_file:
221 | shutil.copyfileobj(temp_file, cache_file)
222 |
223 | logger.info("creating metadata file for %s", cache_path)
224 | meta = {'url': url, 'etag': etag}
225 | meta_path = cache_path + '.json'
226 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
227 | json.dump(meta, meta_file)
228 |
229 | logger.info("removing temp file %s", temp_file.name)
230 |
231 | return cache_path
232 |
233 |
234 | def read_set_from_file(filename):
235 | '''
236 | Extract a de-duped collection (set) of text from a file.
237 | Expected file format is one item per line.
238 | '''
239 | collection = set()
240 | with open(filename, 'r', encoding='utf-8') as file_:
241 | for line in file_:
242 | collection.add(line.rstrip())
243 | return collection
244 |
245 |
246 | def get_file_extension(path, dot=True, lower=True):
247 | ext = os.path.splitext(path)[1]
248 | ext = ext if dot else ext[1:]
249 | return ext.lower() if lower else ext
250 |
--------------------------------------------------------------------------------
/input_examples.txt:
--------------------------------------------------------------------------------
1 | They defeated the Arizona Cardinals 49-15 in the NFC Championship Game and advanced to their second Super Bowl appearance since the franchise was founded in 1995.
2 | Between 1991 and 2000, the total area of forest lost in the Amazon rose from 415,000 to 587,000 square kilometres.
3 | Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season.
4 | The normal force is due to repulsive forces of interaction between atoms at close contact.
5 | The service continued until the closure of BSkyB's analogue service on 27 September 2001, due to the launch and expansion of the Sky Digital platform.
6 | BSkyB launched its HDTV service, Sky+ HD, on 22 May 2006.
7 | Who was the chief executive officer when the service began?
8 | Which NFL team represented the AFC at Super Bowl 50?
9 | When will Ford's manufacturing plants close?
10 | How many dairy cows are there in Australia?
11 |
--------------------------------------------------------------------------------
/local_dump.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import subprocess
5 |
6 |
7 | def run_dump_phrase(args):
8 |
9 | parallel = '--parallel' if args.parallel else ''
10 | do_case = '--do_case' if args.do_case else ''
11 | use_biobert = '--use_biobert' if args.use_biobert else ''
12 | append_title = '--append_title' if args.append_title else ''
13 | def get_cmd(start_doc, end_doc):
14 | return ["python", "run_natkb.py",
15 | "--metadata_dir", f"{args.metadata_dir}",
16 | "--data_dir", f"{args.phrase_data_dir}",
17 | "--predict_file", f"{start_doc}:{end_doc}",
18 | "--bert_model_option", f"{args.bert_model_option}",
19 | "--do_dump",
20 | "--use_sparse",
21 | "--filter_threshold", f"{args.filter_threshold:.2f}",
22 | "--dump_dir", f"{args.phrase_dump_dir}",
23 | "--dump_file", f"{start_doc}-{end_doc}.hdf5",
24 | "--max_seq_length", "512",
25 | "--load_dir", f"{args.load_dir}",
26 | "--load_epoch", f"{args.load_epoch}"] + \
27 | ([f"{parallel}"] if len(parallel) > 0 else []) + \
28 | ([f"{do_case}"] if len(do_case) > 0 else []) + \
29 | ([f"{use_biobert}"] if len(use_biobert) > 0 else []) + \
30 | ([f"{append_title}"] if len(append_title) > 0 else [])
31 |
32 |
33 | num_docs = args.end - args.start
34 | num_gpus = args.num_gpus
35 | num_docs_per_gpu = int(math.ceil(num_docs / num_gpus))
36 | start_docs = list(range(args.start, args.end, num_docs_per_gpu))
37 | end_docs = start_docs[1:] + [args.end]
38 |
39 | print(start_docs)
40 | print(end_docs)
41 |
42 | for device_idx, (start_doc, end_doc) in enumerate(zip(start_docs, end_docs)):
43 | print(get_cmd(start_doc, end_doc))
44 | subprocess.Popen(get_cmd(start_doc, end_doc))
45 |
46 |
47 | def get_args():
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('--dump_dir', default=None)
50 | parser.add_argument('--metadata_dir', default='/home/jinhyuk/models/bert')
51 | parser.add_argument('--data_dir', default='/home/jinhyuk/data/covid-19/preprocessed/2020-03-19/dump')
52 | parser.add_argument('--data_name', default='BioASQ')
53 | parser.add_argument('--load_dir', default='/home/jinhyuk/models/KR94373_piqa-nfs_1173')
54 | parser.add_argument('--load_epoch', default='1')
55 | parser.add_argument('--bert_model_option', default='large_uncased')
56 | parser.add_argument('--append_title', default=False, action='store_true')
57 | parser.add_argument('--parallel', default=False, action='store_true')
58 | parser.add_argument('--do_case', default=False, action='store_true')
59 | parser.add_argument('--use_biobert', default=False, action='store_true')
60 | parser.add_argument('--filter_threshold', default=-1e9, type=float)
61 | parser.add_argument('--num_gpus', default=1, type=int)
62 | parser.add_argument('--start', default=0, type=int)
63 | parser.add_argument('--end', default=8, type=int)
64 | args = parser.parse_args()
65 |
66 | if args.dump_dir is None:
67 | args.dump_dir = os.path.join('dump/%s_%s' % (os.path.basename(args.load_dir),
68 | os.path.basename(args.data_name)))
69 | if not os.path.exists(args.dump_dir):
70 | os.makedirs(args.dump_dir)
71 |
72 | args.phrase_data_dir = os.path.join(args.data_dir, args.data_name)
73 | args.phrase_dump_dir = os.path.join(args.dump_dir, 'phrase')
74 | if not os.path.exists(args.phrase_dump_dir):
75 | os.makedirs(args.phrase_dump_dir)
76 |
77 | return args
78 |
79 |
80 | def main():
81 | args = get_args()
82 | run_dump_phrase(args)
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
87 |
--------------------------------------------------------------------------------
/mips_phrase.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | import logging
6 | from collections import namedtuple, Counter
7 | from time import time
8 |
9 | import h5py
10 | import numpy as np
11 | import faiss
12 | import torch
13 | from tqdm import tqdm
14 |
15 | from scipy.sparse import vstack
16 |
17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
18 | level=logging.INFO)
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class MIPS(object):
23 | def __init__(self, phrase_dump_dir, tfidf_dump_dir, start_index_path, idx2id_path, max_norm_path,
24 | doc_rank_fn, cuda=False, dump_only=False):
25 |
26 | # If dump dir is a file, use it as a dump.
27 | if os.path.isdir(phrase_dump_dir):
28 | self.phrase_dump_paths = sorted(
29 | [os.path.join(phrase_dump_dir, name) for name in os.listdir(phrase_dump_dir) if 'hdf5' in name]
30 | )
31 | dump_names = [os.path.splitext(os.path.basename(path))[0] for path in self.phrase_dump_paths]
32 | self.dump_ranges = [list(map(int, name.split('-'))) for name in dump_names]
33 | else:
34 | self.phrase_dump_paths = [phrase_dump_dir]
35 | self.phrase_dumps = [h5py.File(path, 'r') for path in self.phrase_dump_paths]
36 |
37 | # Load tfidf dump
38 | assert os.path.isdir(tfidf_dump_dir), tfidf_dump_dir
39 | self.tfidf_dump_paths = sorted(
40 | [os.path.join(tfidf_dump_dir, name) for name in os.listdir(tfidf_dump_dir) if 'hdf5' in name]
41 | )
42 | tfidf_dump_names = [os.path.splitext(os.path.basename(path))[0] for path in self.tfidf_dump_paths]
43 | if '-' in tfidf_dump_names[0]: # Range check
44 | tfidf_dump_ranges = [list(map(int, name.split('_')[0].split('-'))) for name in tfidf_dump_names]
45 | assert tfidf_dump_ranges == self.dump_ranges
46 | self.tfidf_dumps = [h5py.File(path, 'r') for path in self.tfidf_dump_paths]
47 | logger.info(f'using doc ranker functions: {doc_rank_fn["index"]}')
48 | self.doc_rank_fn = doc_rank_fn
49 | if dump_only:
50 | return
51 |
52 | # Read index
53 | logger.info(f'Reading {start_index_path}')
54 | self.start_index = faiss.read_index(start_index_path, faiss.IO_FLAG_ONDISK_SAME_DIR)
55 | self.idx_f = self.load_idx_f(idx2id_path)
56 | with open(max_norm_path, 'r') as fp:
57 | self.max_norm = json.load(fp)
58 |
59 | # Options
60 | self.num_docs_list = []
61 | self.cuda = cuda
62 | if self.cuda:
63 | assert torch.cuda.is_available(), f"Cuda availability {torch.cuda.is_available()}"
64 | self.device = torch.device('cuda')
65 | else:
66 | self.device = torch.device("cpu")
67 |
68 | def close(self):
69 | for phrase_dump in self.phrase_dumps:
70 | phrase_dump.close()
71 | for tfidf_dump in self.tfidf_dumps:
72 | tfidf_dump.close()
73 |
74 | def load_idx_f(self, idx2id_path):
75 | idx_f = {}
76 | types = ['doc', 'word']
77 | with h5py.File(idx2id_path, 'r', driver='core', backing_store=False) as f:
78 | for key in tqdm(f, desc='loading idx2id'):
79 | idx_f_cur = {}
80 | for type_ in types:
81 | idx_f_cur[type_] = f[key][type_][:]
82 | idx_f[key] = idx_f_cur
83 | return idx_f
84 |
85 | def get_idxs(self, I):
86 | offsets = (I / 1e8).astype(np.int64) * int(1e8)
87 | idxs = I % int(1e8)
88 | doc = np.array(
89 | [[self.idx_f[str(offset)]['doc'][idx] for offset, idx in zip(oo, ii)] for oo, ii in zip(offsets, idxs)])
90 | word = np.array([[self.idx_f[str(offset)]['word'][idx] for offset, idx in zip(oo, ii)] for oo, ii in
91 | zip(offsets, idxs)])
92 | return doc, word
93 |
94 | def get_doc_group(self, doc_idx):
95 | if len(self.phrase_dumps) == 1:
96 | return self.phrase_dumps[0][str(doc_idx)]
97 | for dump_range, dump in zip(self.dump_ranges, self.phrase_dumps):
98 | if dump_range[0] * 1000 <= int(doc_idx) < dump_range[1] * 1000:
99 | if str(doc_idx) not in dump:
100 | raise ValueError('%d not found in dump list' % int(doc_idx))
101 | return dump[str(doc_idx)]
102 |
103 | # Just check last
104 | if str(doc_idx) not in self.phrase_dumps[-1]:
105 | raise ValueError('%d not found in dump list' % int(doc_idx))
106 | else:
107 | return self.phrase_dumps[-1][str(doc_idx)]
108 |
109 | def get_tfidf_group(self, doc_idx):
110 | if len(self.tfidf_dumps) == 1:
111 | return self.tfidf_dumps[0][str(doc_idx)]
112 | for dump_range, dump in zip(self.dump_ranges, self.tfidf_dumps):
113 | if dump_range[0] * 1000 <= int(doc_idx) < dump_range[1] * 1000:
114 | return dump[str(doc_idx)]
115 |
116 | # Just check last
117 | if str(doc_idx) not in self.tfidf_dumps[-1]:
118 | raise ValueError('%d not found in dump list' % int(doc_idx))
119 | else:
120 | return self.tfidf_dumps[-1][str(doc_idx)]
121 |
122 | def int8_to_float(self, num, offset, factor):
123 | return num.astype(np.float32) / factor + offset
124 |
125 | def adjust(self, each):
126 | last = each['context'].rfind(' [PAR] ', 0, each['start_pos'])
127 | last = 0 if last == -1 else last + len(' [PAR] ')
128 | next = each['context'].find(' [PAR] ', each['end_pos'])
129 | next = len(each['context']) if next == -1 else next
130 | each['context'] = each['context'][last:next]
131 | each['start_pos'] -= last
132 | each['end_pos'] -= last
133 | return each
134 |
135 | def scale_l2_to_ip(self, l2_scores, max_norm=None, query_norm=None):
136 | """
137 | sqrt(m^2 + q^2 - 2qx) -> m^2 + q^2 - 2qx -> qx - 0.5 (q^2 + m^2)
138 | Note that faiss index returns squared euclidean distance, so no need to square it again.
139 | """
140 | if max_norm is None:
141 | return -0.5 * l2_scores
142 | assert query_norm is not None
143 | return -0.5 * (l2_scores - query_norm ** 2 - max_norm ** 2)
144 |
145 | def dequant(self, group, input_, attr='dense'):
146 | if 'offset' not in group.attrs:
147 | return input_
148 |
149 | if attr == 'dense':
150 | return self.int8_to_float(input_, group.attrs['offset'], group.attrs['scale'])
151 | elif attr == 'sparse':
152 | return self.int8_to_float(input_, group.attrs['sparse_offset'], group.attrs['sparse_scale'])
153 | else:
154 | raise NotImplementedError()
155 |
156 | def sparse_bmm(self, q_ids, q_vals, p_ids, p_vals):
157 | """
158 | Efficient batch inner product after slicing (matrix x matrix)
159 | """
160 | q_max = max([len(q) for q in q_ids])
161 | p_max = max([len(p) for p in p_ids])
162 | factor = len(p_ids)//len(q_ids)
163 | assert q_max == max([len(q) for q in q_vals]) and p_max == max([len(p) for p in p_vals])
164 | with torch.no_grad():
165 | q_ids_pad = torch.LongTensor([q_id.tolist() + [0]*(q_max-len(q_id)) for q_id in q_ids]).to(self.device)
166 | q_ids_pad = q_ids_pad.repeat(1, factor).view(len(p_ids), -1) # Repeat for p
167 | q_vals_pad = torch.FloatTensor([q_val.tolist() + [0]*(q_max-len(q_val)) for q_val in q_vals]).to(self.device)
168 | q_vals_pad = q_vals_pad.repeat(1, factor).view(len(p_vals), -1) # Repeat for p
169 | p_ids_pad = torch.LongTensor([p_id.tolist() + [0]*(p_max-len(p_id)) for p_id in p_ids]).to(self.device)
170 | p_vals_pad = torch.FloatTensor([p_val.tolist() + [0]*(p_max-len(p_val)) for p_val in p_vals]).to(self.device)
171 | id_map = q_ids_pad.unsqueeze(1)
172 | id_map_ = p_ids_pad.unsqueeze(2)
173 | match = (id_map == id_map_).to(torch.float32)
174 | val_map = q_vals_pad.unsqueeze(1)
175 | val_map_ = p_vals_pad.unsqueeze(2)
176 | sp_scores = ((val_map * val_map_) * match).sum([1, 2])
177 | return sp_scores.cpu().numpy()
178 |
179 | def search_dense(self, q_texts, query_start, start_top_k, nprobe, sparse_weight=0.05):
180 | batch_size = query_start.shape[0]
181 | self.start_index.nprobe = nprobe
182 |
183 | # Query concatenation for l2 to ip
184 | query_start = np.concatenate([np.zeros([batch_size, 1]).astype(np.float32), query_start], axis=1)
185 |
186 | # Search with faiss
187 | start_scores, I = self.start_index.search(query_start, start_top_k)
188 | query_norm = np.linalg.norm(query_start, ord=2, axis=1)
189 | start_scores = self.scale_l2_to_ip(start_scores, max_norm=self.max_norm, query_norm=np.expand_dims(query_norm, 1))
190 |
191 | # Get idxs from resulting I
192 | doc_idxs, start_idxs = self.get_idxs(I)
193 |
194 | # For record
195 | num_docs = sum([len(set(doc_idx.flatten().tolist())) for doc_idx in doc_idxs]) / batch_size
196 | self.num_docs_list.append(num_docs)
197 |
198 | # Doc-level sparse score
199 | b_doc_scores = self.doc_rank_fn['index'](q_texts, doc_idxs.tolist()) # Index
200 | for b_idx in range(batch_size):
201 | start_scores[b_idx] += np.array(b_doc_scores[b_idx]) * sparse_weight
202 |
203 | return (doc_idxs, start_idxs), start_scores
204 |
205 | def search_sparse(self, q_texts, query_start, doc_top_k, start_top_k, sparse_weight=0.05):
206 | batch_size = query_start.shape[0]
207 |
208 | # Reduce search space by doc scores
209 | top_doc_idxs, top_doc_scores = self.doc_rank_fn['top_docs'](q_texts, doc_top_k) # Top docs
210 |
211 | # For each item, add start scores
212 | b_doc_idxs = []
213 | b_start_idxs = []
214 | b_scores = []
215 | max_phrases = 0
216 | for b_idx in range(batch_size):
217 | doc_idxs = []
218 | start_idxs = []
219 | scores = []
220 | for doc_idx, doc_score in zip(top_doc_idxs[b_idx], top_doc_scores[b_idx]):
221 | try:
222 | doc_group = self.get_doc_group(doc_idx)
223 | except ValueError:
224 | continue
225 | start = self.dequant(doc_group, doc_group['start'][:])
226 | cur_scores = np.sum(query_start[b_idx] * start, 1)
227 | for i, cur_score in enumerate(cur_scores):
228 | doc_idxs.append(doc_idx)
229 | start_idxs.append(i)
230 | scores.append(cur_score + sparse_weight * doc_score)
231 | max_phrases = len(scores) if len(scores) > max_phrases else max_phrases
232 |
233 | b_doc_idxs.append(doc_idxs)
234 | b_start_idxs.append(start_idxs)
235 | b_scores.append(scores)
236 |
237 | # If start_top_k is larger than nonnegative doc_idxs, we need to cut them later
238 | for doc_idxs, start_idxs, scores in zip(b_doc_idxs, b_start_idxs, b_scores):
239 | doc_idxs += [-1] * (max_phrases - len(doc_idxs))
240 | start_idxs += [-1] * (max_phrases - len(start_idxs))
241 | scores += [-10**9] * (max_phrases - len(scores))
242 |
243 | doc_idxs, start_idxs, scores = np.stack(b_doc_idxs), np.stack(b_start_idxs), np.stack(b_scores)
244 | return (doc_idxs, start_idxs), scores
245 |
246 | def batch_par_scores(self, q_texts, q_sparses, doc_idxs, start_idxs, sparse_weight=0.05, mid_top_k=100):
247 | # Reshape for sparse
248 | num_queries = len(q_texts)
249 | doc_idxs = np.reshape(doc_idxs, [-1])
250 | start_idxs = np.reshape(start_idxs, [-1])
251 |
252 | default_doc = [doc_idx for doc_idx in doc_idxs if doc_idx >= 0][0]
253 | groups = [self.get_doc_group(doc_idx) if doc_idx >= 0 else self.get_doc_group(default_doc)
254 | for doc_idx in doc_idxs]
255 |
256 | # Calculate paragraph start end location in sparse vector
257 | para_lens = [group['len_per_para'][:] for group in groups]
258 | f2o_start = [group['f2o_start'][:] for group in groups]
259 | para_bounds = [[(sum(para_len[:para_idx]), sum(para_len[:para_idx+1])) for
260 | para_idx in range(len(para_len))] for para_len in para_lens]
261 | para_idxs = []
262 | for para_bound, start_idx, f2o in zip(para_bounds, start_idxs, f2o_start):
263 | para_bound = np.array(para_bound)
264 | curr_idx = ((f2o[start_idx] >= para_bound[:,0]) & (f2o[start_idx] < para_bound[:,1])).nonzero()[0][0]
265 | para_idxs.append(curr_idx)
266 | para_startend = [para_bound[para_idx] for para_bound, para_idx in zip(para_bounds, para_idxs)]
267 |
268 | # 1) TF-IDF based paragraph score
269 | q_spvecs = self.doc_rank_fn['spvec'](q_texts) # Spvec
270 | qtf_ids = [np.array(q) for q in q_spvecs[1]]
271 | qtf_vals = [np.array(q) for q in q_spvecs[0]]
272 | tfidf_groups = [self.get_tfidf_group(doc_idx) if doc_idx >= 0 else self.get_tfidf_group(default_doc)
273 | for doc_idx in doc_idxs]
274 | tfidf_groups = [group[str(para_idx)] for group, para_idx in zip(tfidf_groups, para_idxs)]
275 | ptf_ids = [data['idxs'][:] for data in tfidf_groups]
276 | ptf_vals = [data['vals'][:] for data in tfidf_groups]
277 | tf_scores = self.sparse_bmm(qtf_ids, qtf_vals, ptf_ids, ptf_vals) * sparse_weight
278 |
279 | # 2) Sparse vectors based paragraph score
280 | q_ids, q_unis, q_bis = q_sparses
281 | q_ids = [np.array(q) for q in q_ids]
282 | q_unis = [np.array(q) for q in q_unis]
283 | q_bis = [np.array(q)[:-1] for q in q_bis]
284 | p_ids_tmp = [group['input_ids'][:] for group in groups]
285 | p_unis_tmp = [group['sparse'][:, :] for group in groups]
286 | p_bis_tmp = [group['sparse_bi'][:, :] for group in groups]
287 | p_ids = [sparse_id[p_se[0]:p_se[1]]
288 | for sparse_id, p_se in zip(p_ids_tmp, para_startend)]
289 | p_unis = [self.dequant(groups[0], sparse_val[start_idx,:p_se[1]-p_se[0]], attr='sparse')
290 | for sparse_val, p_se, start_idx in zip(p_unis_tmp, para_startend, start_idxs)]
291 | p_bis = [self.dequant(groups[0], sparse_bi_val[start_idx,:p_se[1]-p_se[0]-1], attr='sparse')
292 | for sparse_bi_val, p_se, start_idx in zip(p_bis_tmp, para_startend, start_idxs)]
293 | sp_scores = self.sparse_bmm(q_ids, q_unis, p_ids, p_unis)
294 |
295 | # For bigram
296 | MAXV = 30522
297 | q_bids = [np.array([a*MAXV+b for a, b in zip(q_id[:-1], q_id[1:])]) for q_id in q_ids]
298 | p_bids = [np.array([a*MAXV+b for a, b in zip(p_id[:-1], p_id[1:])]) for p_id in p_ids]
299 | sp_scores += self.sparse_bmm(q_bids, q_bis, p_bids, p_bis)
300 |
301 | return np.reshape(tf_scores + sp_scores, [num_queries, -1])
302 |
303 | def search_start(self, query_start, sparse_query, q_texts=None,
304 | nprobe=16, doc_top_k=5, start_top_k=100, mid_top_k=20, top_k=5,
305 | search_strategy='dense_first', sparse_weight=0.05, no_para=False):
306 |
307 | assert self.start_index is not None
308 | query_start = query_start.astype(np.float32)
309 | batch_size = query_start.shape[0]
310 | # start_time = time()
311 |
312 | # 1) Branch based on the strategy (start_top_k) + doc_score
313 | if search_strategy == 'dense_first':
314 | (doc_idxs, start_idxs), start_scores = self.search_dense(
315 | q_texts, query_start, start_top_k, nprobe, sparse_weight
316 | )
317 | elif search_strategy == 'sparse_first':
318 | (doc_idxs, start_idxs), start_scores = self.search_sparse(
319 | q_texts, query_start, doc_top_k, start_top_k, sparse_weight
320 | )
321 | elif search_strategy == 'hybrid':
322 | (doc_idxs, start_idxs), start_scores = self.search_dense(
323 | q_texts, query_start, start_top_k, nprobe, sparse_weight
324 | )
325 | (doc_idxs_, start_idxs_), start_scores_ = self.search_sparse(
326 | q_texts, query_start, doc_top_k, start_top_k, sparse_weight
327 | )
328 |
329 | # There could be a duplicate but it's difficult to remove
330 | doc_idxs = np.concatenate([doc_idxs, doc_idxs_], -1)
331 | start_idxs = np.concatenate([start_idxs, start_idxs_], -1)
332 | start_scores = np.concatenate([start_scores, start_scores_], -1)
333 | else:
334 | raise ValueError(search_strategy)
335 |
336 | # 2) Rerank and reduce (mid_top_k)
337 | rerank_idxs = np.argsort(start_scores, axis=1)[:,-mid_top_k:][:,::-1]
338 | doc_idxs = doc_idxs.tolist()
339 | start_idxs = start_idxs.tolist()
340 | start_scores = start_scores.tolist()
341 | for b_idx in range(batch_size):
342 | doc_idxs[b_idx] = np.array(doc_idxs[b_idx])[rerank_idxs[b_idx]]
343 | start_idxs[b_idx] = np.array(start_idxs[b_idx])[rerank_idxs[b_idx]]
344 | start_scores[b_idx] = np.array(start_scores[b_idx])[rerank_idxs[b_idx]]
345 |
346 | # logger.info(f'1st rerank ({start_top_k} => {mid_top_k}), {np.array(start_scores).shape}, {time()-start_time}')
347 | # start_time = time()
348 |
349 | # Para-level sparse score
350 | if not no_para:
351 | par_scores = self.batch_par_scores(q_texts, sparse_query, doc_idxs, start_idxs, sparse_weight, mid_top_k)
352 | start_scores = np.stack(start_scores) + par_scores
353 | start_scores = [s for s in start_scores]
354 |
355 | # 3) Rerank and reduce (top_k)
356 | rerank_idxs = np.argsort(start_scores, axis=1)[:,-top_k:][:,::-1]
357 | for b_idx in range(batch_size):
358 | doc_idxs[b_idx] = doc_idxs[b_idx][rerank_idxs[b_idx]]
359 | start_idxs[b_idx] = start_idxs[b_idx][rerank_idxs[b_idx]]
360 | start_scores[b_idx] = start_scores[b_idx][rerank_idxs[b_idx]]
361 |
362 | doc_idxs = np.stack(doc_idxs)
363 | start_idxs = np.stack(start_idxs)
364 | start_scores = np.stack(start_scores)
365 |
366 | # logger.info(f'2nd rerank ({mid_top_k} => {top_k}), {start_scores.shape}, {time()-start_time}')
367 | return start_scores, doc_idxs, start_idxs
368 |
369 | def search_end(self, query, doc_idxs, start_idxs, start_scores=None, top_k=5, max_answer_length=20):
370 | # Reshape for end
371 | num_queries = query.shape[0]
372 | query = np.reshape(np.tile(np.expand_dims(query, 1), [1, top_k, 1]), [-1, query.shape[1]])
373 | q_idxs = np.reshape(np.tile(np.expand_dims(np.arange(num_queries), 1), [1, top_k]), [-1])
374 | doc_idxs = np.reshape(doc_idxs, [-1])
375 | start_idxs = np.reshape(start_idxs, [-1])
376 | start_scores = np.reshape(start_scores, [-1])
377 |
378 | # Get query_end and groups
379 | bs = int((query.shape[1] - 1) / 2) # Boundary of start
380 | query_end, query_span_logit = query[:,bs:2*bs], query[:,-1:]
381 | default_doc = [doc_idx for doc_idx in doc_idxs if doc_idx >= 0][0]
382 | groups = [self.get_doc_group(doc_idx) if doc_idx >= 0 else self.get_doc_group(default_doc)
383 | for doc_idx in doc_idxs]
384 | ends = [group['end'][:] for group in groups]
385 | spans = [group['span_logits'][:] for group in groups]
386 | default_end = np.zeros(bs).astype(np.float32)
387 |
388 | # Calculate end
389 | end_idxs = [group['start2end'][start_idx, :max_answer_length]
390 | for group, start_idx in zip(groups, start_idxs)] # [Q, L]
391 | end_mask = -1e9 * (np.array(end_idxs) < 0) # [Q, L]
392 |
393 | end = np.stack([[each_end[each_end_idx, :] if each_end.size > 0 else default_end
394 | for each_end_idx in each_end_idxs]
395 | for each_end, each_end_idxs in zip(ends, end_idxs)], 0) # [Q, L, d]
396 | end = self.dequant(groups[0], end)
397 | span = np.stack([[each_span[start_idx, i] for i in range(len(each_end_idxs))]
398 | for each_span, start_idx, each_end_idxs in zip(spans, start_idxs, end_idxs)], 0) # [Q, L]
399 |
400 | with torch.no_grad():
401 | end = torch.FloatTensor(end).to(self.device)
402 | query_end = torch.FloatTensor(query_end).to(self.device)
403 | end_scores = (query_end.unsqueeze(1) * end).sum(2).cpu().numpy()
404 | span_scores = query_span_logit * span # [Q, L]
405 | scores = np.expand_dims(start_scores, 1) + end_scores + span_scores + end_mask # [Q, L]
406 | pred_end_idxs = np.stack([each[idx] for each, idx in zip(end_idxs, np.argmax(scores, 1))], 0) # [Q]
407 | max_scores = np.max(scores, 1)
408 |
409 | # Get answers
410 | out = [{'context': group.attrs['context'], 'title': group.attrs['title'], 'doc_idx': doc_idx,
411 | 'start_pos': group['word2char_start'][group['f2o_start'][start_idx]].item(),
412 | 'end_pos': (group['word2char_end'][group['f2o_end'][end_idx]].item() if len(group['word2char_end']) > 0
413 | else group['word2char_start'][group['f2o_start'][start_idx]].item() + 1),
414 | 'start_idx': start_idx, 'end_idx': end_idx, 'score': score}
415 | for doc_idx, group, start_idx, end_idx, score in zip(doc_idxs.tolist(), groups, start_idxs.tolist(),
416 | pred_end_idxs.tolist(), max_scores.tolist())]
417 | for each in out:
418 | each['answer'] = each['context'][each['start_pos']:each['end_pos']]
419 | out = [self.adjust(each) for each in out]
420 |
421 | # Sort output
422 | new_out = [[] for _ in range(num_queries)]
423 | for idx, each_out in zip(q_idxs, out):
424 | new_out[idx].append(each_out)
425 | for i in range(len(new_out)):
426 | new_out[i] = sorted(new_out[i], key=lambda each_out: -each_out['score'])
427 | new_out[i] = list(filter(lambda x: x['score'] > -1e5, new_out[i])) # In case of no output but masks
428 | return new_out
429 |
430 | def filter_results(self, results):
431 | out = []
432 | for result in results:
433 | c = Counter(result['context'])
434 | if c['?'] > 3:
435 | continue
436 | if c['!'] > 5:
437 | continue
438 | out.append(result)
439 | return out
440 |
441 | def search(self, query, sparse_query, q_texts=None,
442 | nprobe=256, doc_top_k=5, start_top_k=1000, mid_top_k=100, top_k=10,
443 | search_strategy='dense_first', filter_=False, aggregate=False, return_idxs=False,
444 | max_answer_length=20, sparse_weight=0.05, no_para=False):
445 |
446 | # Search start
447 | start_scores, doc_idxs, start_idxs = self.search_start(
448 | query[:, :int((query.shape[1] -1) / 2)],
449 | sparse_query,
450 | q_texts=q_texts,
451 | nprobe=nprobe,
452 | doc_top_k=doc_top_k,
453 | start_top_k=start_top_k,
454 | mid_top_k=mid_top_k,
455 | top_k=top_k,
456 | search_strategy=search_strategy,
457 | sparse_weight=sparse_weight,
458 | no_para=no_para,
459 | )
460 |
461 | # start_time = time()
462 | # Search end
463 | outs = self.search_end(
464 | query, doc_idxs, start_idxs, start_scores=start_scores,
465 | top_k=top_k, max_answer_length=max_answer_length
466 | )
467 | # logger.info(f'last rerank ({top_k}), {len(outs)}, {time()-start_time}')
468 |
469 | if filter_:
470 | outs = [self.filter_results(results) for results in outs]
471 | if return_idxs:
472 | return [[(out_['doc_idx'], out_['start_idx'], out_['end_idx'], out_['answer']) for out_ in out ] for out in outs]
473 | if doc_idxs.shape[1] != top_k:
474 | logger.info(f"Warning.. {doc_idxs.shape[1]} only retrieved")
475 | top_k = doc_idxs.shape[1]
476 |
477 | return outs
478 |
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.nn.utils import clip_grad_norm_
21 |
22 | def warmup_cosine(x, warmup=0.002):
23 | if x < warmup:
24 | return x/warmup
25 | return 0.5 * (1.0 + torch.cos(math.pi * x))
26 |
27 | def warmup_constant(x, warmup=0.002):
28 | if x < warmup:
29 | return x/warmup
30 | return 1.0
31 |
32 | def warmup_linear(x, warmup=0.002):
33 | if x < warmup:
34 | return x/warmup
35 | return 1.0 - x
36 |
37 | SCHEDULES = {
38 | 'warmup_cosine':warmup_cosine,
39 | 'warmup_constant':warmup_constant,
40 | 'warmup_linear':warmup_linear,
41 | }
42 |
43 |
44 | class BERTAdam(Optimizer):
45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ).
46 | Params:
47 | lr: learning rate
48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
49 | t_total: total number of training steps for the learning
50 | rate schedule, -1 means constant learning rate. Default: -1
51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
52 | b1: Adams b1. Default: 0.9
53 | b2: Adams b2. Default: 0.999
54 | e: Adams epsilon. Default: 1e-6
55 | weight_decay_rate: Weight decay. Default: 0.01
56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
57 | """
58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
60 | max_grad_norm=1.0):
61 | if not lr >= 0.0:
62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
63 | if schedule not in SCHEDULES:
64 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
65 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
67 | if not 0.0 <= b1 < 1.0:
68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
69 | if not 0.0 <= b2 < 1.0:
70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
71 | if not e >= 0.0:
72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
75 | max_grad_norm=max_grad_norm)
76 | super(BERTAdam, self).__init__(params, defaults)
77 |
78 | def get_lr(self):
79 | lr = []
80 | for group in self.param_groups:
81 | for p in group['params']:
82 | state = self.state[p]
83 | if len(state) == 0:
84 | return [0]
85 | if group['t_total'] != -1:
86 | schedule_fct = SCHEDULES[group['schedule']]
87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
88 | else:
89 | lr_scheduled = group['lr']
90 | lr.append(lr_scheduled)
91 | return lr
92 |
93 | def step(self, closure=None):
94 | """Performs a single optimization step.
95 |
96 | Arguments:
97 | closure (callable, optional): A closure that reevaluates the model
98 | and returns the loss.
99 | """
100 | loss = None
101 | if closure is not None:
102 | loss = closure()
103 |
104 | for group in self.param_groups:
105 | for p in group['params']:
106 | if p.grad is None:
107 | continue
108 | grad = p.grad.data
109 | if grad.is_sparse:
110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
111 |
112 | state = self.state[p]
113 |
114 | # State initialization
115 | if len(state) == 0:
116 | state['step'] = 0
117 | # Exponential moving average of gradient values
118 | state['next_m'] = torch.zeros_like(p.data)
119 | # Exponential moving average of squared gradient values
120 | state['next_v'] = torch.zeros_like(p.data)
121 |
122 | next_m, next_v = state['next_m'], state['next_v']
123 | beta1, beta2 = group['b1'], group['b2']
124 |
125 | # Add grad clipping
126 | if group['max_grad_norm'] > 0:
127 | clip_grad_norm_(p, group['max_grad_norm'])
128 |
129 | # Decay the first and second moment running average coefficient
130 | # In-place operations to update the averages at the same time
131 | next_m.mul_(beta1).add_(1 - beta1, grad)
132 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
133 | update = next_m / (next_v.sqrt() + group['e'])
134 |
135 | # Just adding the square of the weights to the loss function is *not*
136 | # the correct way of using L2 regularization/weight decay with Adam,
137 | # since that will interact with the m and v parameters in strange ways.
138 | #
139 | # Instead we want ot decay the weights in a manner that doesn't interact
140 | # with the m/v parameters. This is equivalent to adding the square
141 | # of the weights to the loss with plain (non-momentum) SGD.
142 | if group['weight_decay_rate'] > 0.0:
143 | update += group['weight_decay_rate'] * p.data
144 |
145 | if group['t_total'] != -1:
146 | schedule_fct = SCHEDULES[group['schedule']]
147 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
148 | else:
149 | lr_scheduled = group['lr']
150 |
151 | update_with_lr = lr_scheduled * update
152 | p.data.add_(-update_with_lr)
153 |
154 | state['step'] += 1
155 |
156 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
157 | # bias_correction1 = 1 - beta1 ** state['step']
158 | # bias_correction2 = 1 - beta2 ** state['step']
159 |
160 | return loss
161 |
--------------------------------------------------------------------------------
/post.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | import logging
4 | import os
5 | import shutil
6 | import torch
7 | import math
8 | import pandas as pd
9 | import numpy as np
10 | import six
11 | import h5py
12 | from multiprocessing import Process
13 | from time import time
14 | from scipy.sparse import csr_matrix, save_npz, hstack, vstack
15 | from termcolor import colored, cprint
16 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
17 | from eval_utils import normalize_answer, f1_score, exact_match_score
18 |
19 | from multiprocessing import Queue
20 | from multiprocessing.pool import ThreadPool
21 | from threading import Thread
22 |
23 | from tqdm import tqdm as tqdm_
24 | from decimal import *
25 |
26 | import tokenization
27 |
28 | QuestionResult = collections.namedtuple("QuestionResult",
29 | ['qas_id', 'start', 'end', 'sparse', 'input_ids'])
30 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
31 | "NbestPrediction", ["text", "logit", "no_answer_logit"])
32 |
33 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
34 | datefmt='%m/%d/%Y %H:%M:%S',
35 | level=logging.INFO)
36 | logger = logging.getLogger(__name__)
37 |
38 | # For debugging
39 | quant_stat = {}
40 | b_quant_stat = {}
41 | ranker = None
42 |
43 |
44 | def tqdm(*args, mininterval=5.0, **kwargs):
45 | return tqdm_(*args, mininterval=mininterval, **kwargs)
46 |
47 |
48 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
49 | orig_answer_text):
50 | """Returns tokenized answer spans that better match the annotated answer."""
51 |
52 | # The SQuAD annotations are character based. We first project them to
53 | # whitespace-tokenized words. But then after WordPiece tokenization, we can
54 | # often find a "better match". For example:
55 | #
56 | # Question: What year was John Smith born?
57 | # Context: The leader was John Smith (1895-1943).
58 | # Answer: 1895
59 | #
60 | # The original whitespace-tokenized answer will be "(1895-1943).". However
61 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
62 | # the exact answer, 1895.
63 | #
64 | # However, this is not always possible. Consider the following:
65 | #
66 | # Question: What country is the top exporter of electornics?
67 | # Context: The Japanese electronics industry is the lagest in the world.
68 | # Answer: Japan
69 | #
70 | # In this case, the annotator chose "Japan" as a character sub-span of
71 | # the word "Japanese". Since our WordPiece tokenizer does not split
72 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
73 | # in SQuAD, but does happen.
74 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
75 |
76 | for new_start in range(input_start, input_end + 1):
77 | for new_end in range(input_end, new_start - 1, -1):
78 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
79 | if text_span == tok_answer_text:
80 | return (new_start, new_end)
81 |
82 | return (input_start, input_end)
83 |
84 |
85 | def _check_is_max_context(doc_spans, cur_span_index, position):
86 | """Check if this is the 'max context' doc span for the token."""
87 |
88 | # Because of the sliding window approach taken to scoring documents, a single
89 | # token can appear in multiple documents. E.g.
90 | # Doc: the man went to the store and bought a gallon of milk
91 | # Span A: the man went to the
92 | # Span B: to the store and bought
93 | # Span C: and bought a gallon of
94 | # ...
95 | #
96 | # Now the word 'bought' will have two scores from spans B and C. We only
97 | # want to consider the score with "maximum context", which we define as
98 | # the *minimum* of its left and right context (the *sum* of left and
99 | # right context will always be the same, of course).
100 | #
101 | # In the example the maximum context for 'bought' would be span C since
102 | # it has 1 left context and 3 right context, while span B has 4 left context
103 | # and 0 right context.
104 | best_score = None
105 | best_span_index = None
106 | for (span_index, doc_span) in enumerate(doc_spans):
107 | end = doc_span.start + doc_span.length - 1
108 | if position < doc_span.start:
109 | continue
110 | if position > end:
111 | continue
112 | num_left_context = position - doc_span.start
113 | num_right_context = end - position
114 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
115 | if best_score is None or score > best_score:
116 | best_score = score
117 | best_span_index = span_index
118 |
119 | return cur_span_index == best_span_index
120 |
121 |
122 | def get_metadata(id2example, features, results, max_answer_length, do_lower_case, verbose_logging):
123 | start = np.concatenate([result.start[1:len(feature.tokens) - 1] for feature, result in zip(features, results)],
124 | axis=0)
125 | end = np.concatenate([result.end[1:len(feature.tokens) - 1] for feature, result in zip(features, results)], axis=0)
126 |
127 | input_ids = None
128 | sparse_map = None
129 | sparse_bi_map = None
130 | sparse_tri_map = None
131 | len_per_para = []
132 | if results[0].start_sp is not None:
133 | input_ids = np.concatenate([f.input_ids[1:len(f.tokens) - 1] for f in features], axis=0)
134 | sparse_features = None # uni
135 | sparse_bi_features = None
136 | sparse_tri_features = None
137 | if '1' in results[0].start_sp:
138 | sparse_features = [result.start_sp['1'][1:len(feature.tokens)-1, 1:len(feature.tokens)-1]
139 | for feature, result in zip(features, results)]
140 | if '2' in results[0].start_sp:
141 | sparse_bi_features = [result.start_sp['2'][1:len(feature.tokens)-1, 1:len(feature.tokens)-1]
142 | for feature, result in zip(features, results)]
143 | if '3' in results[0].start_sp:
144 | sparse_tri_features = [result.start_sp['3'][1:len(feature.tokens)-1, 1:len(feature.tokens)-1]
145 | for feature, result in zip(features, results)]
146 |
147 | map_size = max([k.shape[0] for k in sparse_features])
148 | sparse_map = np.zeros((input_ids.shape[0], map_size), dtype=np.float32)
149 | if sparse_bi_features is not None:
150 | sparse_bi_map = np.zeros((input_ids.shape[0], map_size), dtype=np.float32)
151 | if sparse_tri_features is not None:
152 | sparse_tri_map = np.zeros((input_ids.shape[0], map_size), dtype=np.float32)
153 |
154 | curr_size = 0
155 | for sidx, sparse_feature in enumerate(sparse_features):
156 | sparse_map[curr_size:curr_size + sparse_feature.shape[0],:sparse_feature.shape[1]] += sparse_feature
157 | if sparse_bi_features is not None:
158 | assert sparse_bi_features[sidx].shape == sparse_feature.shape
159 | sparse_bi_map[curr_size:curr_size + sparse_bi_features[sidx].shape[0],:sparse_bi_features[sidx].shape[1]] += \
160 | sparse_bi_features[sidx]
161 | if sparse_tri_features is not None:
162 | assert sparse_tri_features[sidx].shape == sparse_feature.shape
163 | sparse_tri_map[curr_size:curr_size + sparse_tri_features[sidx].shape[0],:sparse_tri_features[sidx].shape[1]] += \
164 | sparse_tri_features[sidx]
165 | curr_size += sparse_feature.shape[0]
166 | len_per_para.append(sparse_feature.shape[0])
167 |
168 | assert input_ids.shape[0] == start.shape[0] and curr_size == sparse_map.shape[0]
169 |
170 | fs = np.concatenate([result.filter_start_logits[1:len(feature.tokens) - 1]
171 | for feature, result in zip(features, results)],
172 | axis=0)
173 | fe = np.concatenate([result.filter_end_logits[1:len(feature.tokens) - 1]
174 | for feature, result in zip(features, results)],
175 | axis=0)
176 |
177 | span_logits = np.zeros([np.shape(start)[0], max_answer_length], dtype=start.dtype)
178 | start2end = -1 * np.ones([np.shape(start)[0], max_answer_length], dtype=np.int32)
179 | idx = 0
180 | for feature, result in zip(features, results):
181 | for i in range(1, len(feature.tokens) - 1):
182 | for j in range(i, min(i + max_answer_length, len(feature.tokens) - 1)):
183 | span_logits[idx, j - i] = result.span_logits[i, j]
184 | start2end[idx, j - i] = idx + j - i
185 | idx += 1
186 |
187 | word2char_start = np.zeros([start.shape[0]], dtype=np.int32)
188 | word2char_end = np.zeros([start.shape[0]], dtype=np.int32)
189 |
190 | sep = ' [PAR] '
191 | full_text = ""
192 | prev_example = None
193 |
194 | word_pos = 0
195 | for feature in features:
196 | example = id2example[feature.unique_id]
197 | if prev_example is not None and feature.doc_span_index == 0:
198 | full_text = full_text + ' '.join(prev_example.doc_words) + sep
199 |
200 | for i in range(1, len(feature.tokens) - 1):
201 | _, start_pos, _ = get_final_text_(example, feature, i, min(len(feature.tokens) - 2, i + 1), do_lower_case,
202 | verbose_logging)
203 | _, _, end_pos = get_final_text_(example, feature, max(1, i - 1), i, do_lower_case,
204 | verbose_logging)
205 | start_pos += len(full_text)
206 | end_pos += len(full_text)
207 | word2char_start[word_pos] = start_pos
208 | word2char_end[word_pos] = end_pos
209 | word_pos += 1
210 | prev_example = example
211 | full_text = full_text + ' '.join(prev_example.doc_words)
212 |
213 | metadata = {'did': prev_example.doc_idx, 'context': full_text, 'title': prev_example.title,
214 | 'start': start, 'end': end, 'span_logits': span_logits,
215 | 'start2end': start2end,
216 | 'word2char_start': word2char_start, 'word2char_end': word2char_end,
217 | 'filter_start': fs, 'filter_end': fe, 'input_ids': input_ids,
218 | 'sparse': sparse_map, 'sparse_bi': sparse_bi_map, 'sparse_tri': sparse_tri_map,
219 | 'len_per_para': len_per_para}
220 |
221 | return metadata
222 |
223 |
224 | def filter_metadata(metadata, threshold):
225 | start_idxs, = np.where(metadata['filter_start'] > threshold)
226 | end_idxs, = np.where(metadata['filter_end'] > threshold)
227 | end_long2short = {long: short for short, long in enumerate(end_idxs)}
228 | metadata['start'] = metadata['start'][start_idxs]
229 | metadata['end'] = metadata['end'][end_idxs]
230 | metadata['sparse'] = metadata['sparse'][start_idxs]
231 | if metadata['sparse_bi'] is not None:
232 | metadata['sparse_bi'] = metadata['sparse_bi'][start_idxs]
233 | if metadata['sparse_tri'] is not None:
234 | metadata['sparse_tri'] = metadata['sparse_tri'][start_idxs]
235 | metadata['f2o_start'] = start_idxs
236 | metadata['f2o_end'] = end_idxs
237 | metadata['span_logits'] = metadata['span_logits'][start_idxs]
238 | metadata['start2end'] = metadata['start2end'][start_idxs]
239 | for i, each in enumerate(metadata['start2end']):
240 | for j, long in enumerate(each.tolist()):
241 | metadata['start2end'][i, j] = end_long2short[long] if long in end_long2short else -1
242 |
243 | return metadata
244 |
245 |
246 | def compress_metadata(metadata, dense_offset, dense_scale, sparse_offset, sparse_scale):
247 | for key in ['start', 'end']:
248 | if key in metadata:
249 | metadata[key] = float_to_int8(metadata[key], dense_offset, dense_scale)
250 | for key in ['sparse', 'sparse_bi', 'sparse_tri']:
251 | if key in metadata and metadata[key] is not None:
252 | metadata[key] = float_to_int8(metadata[key], sparse_offset, sparse_scale)
253 | return metadata
254 |
255 |
256 | def pool_func(item):
257 | metadata_ = get_metadata(*item[:-1])
258 | metadata_ = filter_metadata(metadata_, item[-1])
259 | return metadata_
260 |
261 |
262 | def write_hdf5(all_examples, all_features, all_results,
263 | max_answer_length, do_lower_case, hdf5_path, filter_threshold, verbose_logging,
264 | dense_offset=None, dense_scale=None, sparse_offset=None, sparse_scale=None, use_sparse=False):
265 | assert len(all_examples) > 0
266 |
267 | id2feature = {feature.unique_id: feature for feature in all_features}
268 | id2example = {id_: all_examples[id2feature[id_].example_index] for id_ in id2feature}
269 |
270 | def add(inqueue_, outqueue_):
271 | for item in iter(inqueue_.get, None):
272 | args = list(item[:3]) + [max_answer_length, do_lower_case, verbose_logging, filter_threshold]
273 | out = pool_func(args)
274 | outqueue_.put(out)
275 |
276 | outqueue_.put(None)
277 |
278 | def write(outqueue_):
279 | with h5py.File(hdf5_path) as f:
280 | while True:
281 | metadata = outqueue_.get()
282 | if metadata:
283 | did = str(metadata['did'])
284 | if did in f:
285 | logger.info('%s exists; replacing' % did)
286 | del f[did]
287 | dg = f.create_group(did)
288 |
289 | dg.attrs['context'] = metadata['context']
290 | dg.attrs['title'] = metadata['title']
291 | if dense_offset is not None:
292 | metadata = compress_metadata(metadata, dense_offset, dense_scale, sparse_offset, sparse_scale)
293 | dg.attrs['offset'] = dense_offset
294 | dg.attrs['scale'] = dense_scale
295 | dg.attrs['sparse_offset'] = sparse_offset
296 | dg.attrs['sparse_scale'] = sparse_scale
297 | dg.create_dataset('start', data=metadata['start'])
298 | dg.create_dataset('end', data=metadata['end'])
299 | if metadata['sparse'] is not None:
300 | dg.create_dataset('sparse', data=metadata['sparse'])
301 | if metadata['sparse_bi'] is not None:
302 | dg.create_dataset('sparse_bi', data=metadata['sparse_bi'])
303 | if metadata['sparse_tri'] is not None:
304 | dg.create_dataset('sparse_tri', data=metadata['sparse_tri'])
305 | dg.create_dataset('input_ids', data=metadata['input_ids'])
306 | dg.create_dataset('len_per_para', data=metadata['len_per_para'])
307 | dg.create_dataset('span_logits', data=metadata['span_logits'])
308 | dg.create_dataset('start2end', data=metadata['start2end'])
309 | dg.create_dataset('word2char_start', data=metadata['word2char_start'])
310 | dg.create_dataset('word2char_end', data=metadata['word2char_end'])
311 | dg.create_dataset('f2o_start', data=metadata['f2o_start'])
312 | dg.create_dataset('f2o_end', data=metadata['f2o_end'])
313 |
314 | else:
315 | break
316 |
317 | features = []
318 | results = []
319 | inqueue = Queue(maxsize=500)
320 | outqueue = Queue(maxsize=500)
321 | write_p = Thread(target=write, args=(outqueue,))
322 | p = Thread(target=add, args=(inqueue, outqueue))
323 | write_p.start()
324 | p.start()
325 |
326 | start_time = time()
327 | for count, result in enumerate(tqdm(all_results, total=len(all_features))):
328 | example = id2example[result.unique_id]
329 | feature = id2feature[result.unique_id]
330 | condition = len(features) > 0 and example.par_idx == 0 and feature.doc_span_index == 0
331 |
332 | if condition:
333 | in_ = (id2example, features, results)
334 | logger.info('inqueue size: %d, outqueue size: %d' % (inqueue.qsize(), outqueue.qsize()))
335 | inqueue.put(in_)
336 | # add(id2example, features, results)
337 | features = [feature]
338 | results = [result]
339 | else:
340 | features.append(feature)
341 | results.append(result)
342 | if count % 500 == 0:
343 | logger.info('%d/%d at %.1f' % (count + 1, len(all_features), time() - start_time))
344 | in_ = (id2example, features, results)
345 | inqueue.put(in_)
346 | inqueue.put(None)
347 | p.join()
348 | write_p.join()
349 |
350 | import collections
351 | b_stats = collections.OrderedDict(sorted(b_quant_stat.items()))
352 | stats = collections.OrderedDict(sorted(quant_stat.items()))
353 | for k, v in b_stats.items():
354 | print(k, v)
355 | for k, v in stats.items():
356 | print(k, v)
357 |
358 |
359 | def write_embed(all_examples, all_features, all_results, max_answer_length, do_lower_case, embed_path):
360 | assert len(all_examples) > 0
361 |
362 | id2feature = {feature.unique_id: feature for feature in all_features}
363 | id2example = {id_: all_examples[id2feature[id_].example_index] for id_ in id2feature}
364 | features = []
365 | results = []
366 | outs = []
367 |
368 | def write(features, results):
369 | out_json = {'uni': {}, 'bi': {}}
370 | for par_index in range(len(features)):
371 | start_index = 0
372 | for vidx, (v1, v2, v3) in enumerate(zip(
373 | features[par_index].tokens[1:-1],
374 | results[par_index].start_sp['1'][start_index][1:len(features[par_index].tokens)-1],
375 | features[par_index].input_ids[1:-1]
376 | )):
377 | if vidx == start_index-1:
378 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'green', end=' ')
379 | continue
380 | if v1 not in out_json['uni']:
381 | out_json['uni'][v1] = {'score': v2.item(), 'vocab': str(v3)}
382 | else:
383 | out_json['uni'][v1]['score'] += v2.item()
384 | if v2 > 1.0:
385 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'red', end=' ')
386 | else:
387 | print('{}({:.3f}, {})'.format(v1, v2, vidx), end=' ')
388 | print()
389 |
390 | for vidx, (v1, v2, v3) in enumerate(zip(
391 | zip(features[par_index].tokens[1:-2], features[par_index].tokens[2:-1]),
392 | results[par_index].start_sp['2'][start_index][1:len(features[par_index].tokens)-2],
393 | zip(features[par_index].input_ids[1:-2], features[par_index].input_ids[2:-1])
394 | )):
395 | v1 = ' '.join(v1)
396 | if vidx == start_index-1:
397 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'green', end=' ')
398 | continue
399 | if v1 not in out_json['bi']:
400 | out_json['bi'][v1] = {'score': v2.item(), 'vocab': ', '.join([str(k) for k in v3])}
401 | else:
402 | out_json['bi'][v1]['score'] += v2.item()
403 | if v2 > 1.0:
404 | cprint('{}({:.3f}, {})'.format(v1, v2, vidx), 'red', end=' ')
405 | else:
406 | print('{}({:.3f}, {})'.format(v1, v2, vidx), end=' ')
407 | print()
408 | print()
409 | for key, val in out_json.items():
410 | out_json[key] = dict(sorted(out_json[key].items(), key=lambda x: x[1]['score'], reverse=True))
411 | out_json[key] = dict(filter(lambda x: x[1]['score'] > 0, out_json[key].items()))
412 | return out_json
413 |
414 | for count, result in enumerate(tqdm(all_results, total=len(all_features))):
415 | example = id2example[result.unique_id]
416 | feature = id2feature[result.unique_id]
417 | condition = len(features) > 0 and example.par_idx == 0 and feature.doc_span_index == 0
418 |
419 | if condition:
420 | outs.append({'text': prev_example.paragraph_text, 'sparc': write(features, results)})
421 | features = [feature]
422 | results = [result]
423 | else:
424 | features.append(feature)
425 | results.append(result)
426 | prev_example = example
427 | outs.append({'text': prev_example.paragraph_text, 'sparc': write(features, results)})
428 |
429 | with open(embed_path, 'w') as f:
430 | json.dump({'out': outs}, f, indent=4)
431 |
432 |
433 | def write_predictions(all_examples, all_features, all_results,
434 | max_answer_length, do_lower_case, output_prediction_file,
435 | output_score_file, verbose_logging, threshold):
436 |
437 | id2feature = {feature.unique_id: feature for feature in all_features}
438 | id2example = {id_: all_examples[id2feature[id_].example_index] for id_ in id2feature}
439 |
440 | token_count = 0
441 | vec_count = 0
442 | predictions = {}
443 | scores = {}
444 | loss = []
445 |
446 | for result in tqdm(all_results, total=len(all_features), desc='[Evaluation]'):
447 | loss += [result.loss]
448 | feature = id2feature[result.unique_id]
449 | example = id2example[result.unique_id]
450 | id_ = example.qas_id
451 |
452 | # Initial setting
453 | token_count += len(feature.tokens)
454 |
455 | for start_index in range(len(feature.tokens)):
456 | for end_index in range(start_index, min(len(feature.tokens), start_index + max_answer_length - 1)):
457 | if start_index not in feature.token_to_word_map:
458 | continue
459 | if end_index not in feature.token_to_word_map:
460 | continue
461 | if not feature.token_is_max_context.get(start_index, False):
462 | continue
463 | filter_start_logit = result.filter_start_logits[start_index]
464 | filter_end_logit = result.filter_end_logits[end_index]
465 |
466 | # Filter based on threshold (default: -2)
467 | if filter_start_logit < threshold or filter_end_logit < threshold:
468 | # orig_text, start_pos, end_pos = get_final_text_(example, feature, start_index, end_index,
469 | # do_lower_case, verbose_logging)
470 | # print('Filter: %s (%.2f, %.2f)'% (orig_text[start_pos:end_pos], filter_start_logit, filter_end_logit))
471 | continue
472 | else:
473 | # orig_text, start_pos, end_pos = get_final_text_(example, feature, start_index, end_index,
474 | # do_lower_case, verbose_logging)
475 | # print('Saved: %s (%.2f, %.2f)'% (orig_text[start_pos:end_pos], filter_start_logit, filter_end_logit))
476 | pass
477 |
478 | vec_count += 1
479 | score = result.all_logits[start_index, end_index]
480 |
481 | if id_ not in scores or score > scores[id_]:
482 | orig_text, start_pos, end_pos = get_final_text_(example, feature, start_index, end_index,
483 | do_lower_case, verbose_logging)
484 | # print('Saved: %s (%.2f, %.2f)'% (orig_text[start_pos:end_pos], filter_start_logit, filter_end_logit))
485 | phrase = orig_text[start_pos:end_pos]
486 | predictions[id_] = phrase
487 | scores[id_] = score.item()
488 |
489 | if id_ not in predictions:
490 | assert id_ not in scores
491 | logger.info('for %s, no answer found'% id_)
492 |
493 | logger.info('num vecs=%d, num_words=%d, nvpw=%.4f' % (vec_count, token_count, vec_count / token_count))
494 |
495 | with open(output_prediction_file, 'w') as fp:
496 | json.dump(predictions, fp)
497 |
498 | with open(output_score_file, 'w') as fp:
499 | json.dump({k: -v for (k, v) in scores.items()}, fp)
500 |
501 | return sum(loss) / len(loss)
502 |
503 |
504 | def get_question_results(question_examples, query_eval_features, question_dataloader, device, model):
505 | id2feature = {feature.unique_id: feature for feature in query_eval_features}
506 | id2example = {id_: question_examples[id2feature[id_].example_index] for id_ in id2feature}
507 | for (input_ids_, input_mask_, example_indices) in question_dataloader:
508 | input_ids_ = input_ids_.to(device)
509 | input_mask_ = input_mask_.to(device)
510 | with torch.no_grad():
511 | batch_start, batch_end, batch_sps, batch_eps = model(query_ids=input_ids_,
512 | query_mask=input_mask_)
513 | for i, example_index in enumerate(example_indices):
514 | start = batch_start[i].detach().cpu().numpy().astype(np.float16)
515 | end = batch_end[i].detach().cpu().numpy().astype(np.float16)
516 | sparse = None
517 | if len(batch_sps) > 0:
518 | sparse = {ng: bb_ssp[i].detach().cpu().numpy().astype(np.float16) for ng, bb_ssp in batch_sps.items()}
519 | query_eval_feature = query_eval_features[example_index.item()]
520 | unique_id = int(query_eval_feature.unique_id)
521 | qas_id = id2example[unique_id].qas_id
522 | yield QuestionResult(qas_id=qas_id,
523 | start=start,
524 | end=end,
525 | sparse=sparse,
526 | input_ids=query_eval_feature.input_ids[1:len(query_eval_feature.tokens_)-1])
527 |
528 |
529 | def write_question_results(question_results, question_features, path):
530 | with h5py.File(path, 'w') as f:
531 | for question_result, question_feature in tqdm(zip(question_results, question_features)):
532 | dummy_ones = np.ones((question_result.start.shape[0], 1))
533 | data = np.concatenate([question_result.start, question_result.end, dummy_ones], -1)
534 | f.create_dataset(question_result.qas_id, data=data)
535 |
536 |
537 | def convert_question_features_to_dataloader(query_eval_features, fp16, local_rank, predict_batch_size):
538 | all_input_ids_ = torch.tensor([f.input_ids for f in query_eval_features], dtype=torch.long)
539 | all_input_mask_ = torch.tensor([f.input_mask for f in query_eval_features], dtype=torch.long)
540 | all_example_index_ = torch.arange(all_input_ids_.size(0), dtype=torch.long)
541 | if fp16:
542 | all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_))
543 |
544 | question_data = TensorDataset(all_input_ids_, all_input_mask_, all_example_index_)
545 |
546 | if local_rank == -1:
547 | question_sampler = SequentialSampler(question_data)
548 | else:
549 | question_sampler = DistributedSampler(question_data)
550 | question_dataloader = DataLoader(question_data, sampler=question_sampler, batch_size=predict_batch_size)
551 | return question_dataloader
552 |
553 |
554 | def get_final_text_(example, feature, start_index, end_index, do_lower_case, verbose_logging):
555 | tok_tokens = feature.tokens[start_index:(end_index + 1)]
556 | orig_doc_start = feature.token_to_word_map[start_index]
557 | orig_doc_end = feature.token_to_word_map[end_index]
558 | orig_words = example.doc_words[orig_doc_start:(orig_doc_end + 1)]
559 | tok_text = " ".join(tok_tokens)
560 |
561 | # De-tokenize WordPieces that have been split off.
562 | tok_text = tok_text.replace(" ##", "")
563 | tok_text = tok_text.replace("##", "")
564 |
565 | # Clean whitespace
566 | tok_text = tok_text.strip()
567 | tok_text = " ".join(tok_text.split())
568 | orig_text = " ".join(orig_words)
569 | full_text = " ".join(example.doc_words)
570 |
571 | start_pos, end_pos = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) # TODO: need to check
572 | offset = sum(len(word) + 1 for word in example.doc_words[:orig_doc_start])
573 |
574 | return full_text, offset + start_pos, offset + end_pos
575 |
576 |
577 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
578 | """Project the tokenized prediction back to the original text."""
579 |
580 | # When we created the data, we kept track of the alignment between original
581 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
582 | # now `orig_text` contains the span of our original text corresponding to the
583 | # span that we predicted.
584 | #
585 | # However, `orig_text` may contain extra characters that we don't want in
586 | # our prediction.
587 | #
588 | # For example, let's say:
589 | # pred_text = steve smith
590 | # orig_text = Steve Smith's
591 | #
592 | # We don't want to return `orig_text` because it contains the extra "'s".
593 | #
594 | # We don't want to return `pred_text` because it's already been normalized
595 | # (the SQuAD eval script also does punctuation stripping/lower casing but
596 | # our tokenizer does additional normalization like stripping accent
597 | # characters).
598 | #
599 | # What we really want to return is "Steve Smith".
600 | #
601 | # Therefore, we have to apply a semi-complicated alignment heruistic between
602 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This
603 | # can fail in certain cases in which case we just return `orig_text`.
604 | default_out = 0, len(orig_text)
605 |
606 | def _strip_spaces(text):
607 | ns_chars = []
608 | ns_to_s_map = collections.OrderedDict()
609 | for (i, c) in enumerate(text):
610 | if c == " ":
611 | continue
612 | ns_to_s_map[len(ns_chars)] = i
613 | ns_chars.append(c)
614 | ns_text = "".join(ns_chars)
615 | return (ns_text, ns_to_s_map)
616 |
617 | # We first tokenize `orig_text`, strip whitespace from the result
618 | # and `pred_text`, and check if they are the same length. If they are
619 | # NOT the same length, the heuristic has failed. If they are the same
620 | # length, we assume the characters are one-to-one aligned.
621 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
622 |
623 | tok_text = " ".join(tokenizer.tokenize(orig_text))
624 |
625 | start_position = tok_text.find(pred_text)
626 | if start_position == -1:
627 | if verbose_logging:
628 | logger.info(
629 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
630 | return default_out
631 | end_position = start_position + len(pred_text) - 1
632 |
633 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
634 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
635 |
636 | if len(orig_ns_text) != len(tok_ns_text):
637 | if verbose_logging:
638 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
639 | orig_ns_text, tok_ns_text)
640 | return default_out
641 |
642 | # We then project the characters in `pred_text` back to `orig_text` using
643 | # the character-to-character alignment.
644 | tok_s_to_ns_map = {}
645 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
646 | tok_s_to_ns_map[tok_index] = i
647 |
648 | orig_start_position = None
649 | if start_position in tok_s_to_ns_map:
650 | ns_start_position = tok_s_to_ns_map[start_position]
651 | if ns_start_position in orig_ns_to_s_map:
652 | orig_start_position = orig_ns_to_s_map[ns_start_position]
653 |
654 | if orig_start_position is None:
655 | if verbose_logging:
656 | logger.info("Couldn't map start position")
657 | return default_out
658 |
659 | orig_end_position = None
660 | if end_position in tok_s_to_ns_map:
661 | ns_end_position = tok_s_to_ns_map[end_position]
662 | if ns_end_position in orig_ns_to_s_map:
663 | orig_end_position = orig_ns_to_s_map[ns_end_position]
664 |
665 | if orig_end_position is None:
666 | if verbose_logging:
667 | logger.info("Couldn't map end position")
668 | return default_out
669 |
670 | # output_text = orig_text[orig_start_position:(orig_end_position + 1)]
671 | return orig_start_position, orig_end_position + 1
672 |
673 |
674 | def float_to_int8(num, offset, factor, keep_zeros=False):
675 | out = (num - offset) * factor
676 | out = out.clip(-128, 127)
677 | if keep_zeros:
678 | out = out * (num != 0.0).astype(np.int8)
679 | out = np.round(out).astype(np.int8)
680 | return out
681 |
682 |
683 | def int8_to_float(num, offset, factor, keep_zeros=False):
684 | if not keep_zeros:
685 | return num.astype(np.float32) / factor + offset
686 | else:
687 | return (num.astype(np.float32) / factor + offset) * (num != 0.0).astype(np.float32)
688 |
689 |
--------------------------------------------------------------------------------
/run_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 | import faiss
7 | import h5py
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from post import int8_to_float
12 |
13 |
14 | def get_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('dump_dir')
17 | parser.add_argument('stage')
18 |
19 | parser.add_argument('--dump_paths', default=None,
20 | help='Relative to `dump_dir/phrase`. '
21 | 'If specified, creates subindex dir and save there with same name')
22 | parser.add_argument('--subindex_name', default='index', help='used only if dump_path is specified.')
23 | parser.add_argument('--offset', default=0, type=int)
24 |
25 | # relative paths in dump_dir/index_name
26 | parser.add_argument('--quantizer_path', default='quantizer.faiss')
27 | parser.add_argument('--max_norm_path', default='max_norm.json')
28 | parser.add_argument('--trained_index_path', default='trained.faiss')
29 | parser.add_argument('--index_path', default='index.faiss')
30 | parser.add_argument('--idx2id_path', default='idx2id.hdf5')
31 | parser.add_argument('--inv_path', default='merged.invdata')
32 |
33 | # Adding options
34 | parser.add_argument('--add_all', default=False, action='store_true')
35 |
36 | # coarse, fine, add
37 | parser.add_argument('--num_clusters', type=int, default=16384)
38 | parser.add_argument('--hnsw', default=False, action='store_true')
39 | parser.add_argument('--fine_quant', default='SQ8',
40 | help='SQ8|SQ4|PQ# where # is number of bytes per vector (for SQ it would be 480 Bytes)')
41 | # stable params
42 | parser.add_argument('--max_norm', default=None, type=float)
43 | parser.add_argument('--max_norm_cf', default=1.0, type=float)
44 | parser.add_argument('--norm_th', default=999, type=float)
45 | parser.add_argument('--para', default=False, action='store_true')
46 | parser.add_argument('--doc_sample_ratio', default=0.2, type=float)
47 | parser.add_argument('--vec_sample_ratio', default=0.2, type=float)
48 |
49 | parser.add_argument('--fs', default='local')
50 | parser.add_argument('--cuda', default=False, action='store_true')
51 | parser.add_argument('--num_dummy_zeros', default=0, type=int)
52 | parser.add_argument('--replace', default=False, action='store_true')
53 | parser.add_argument('--num_docs_per_add', default=1000, type=int)
54 |
55 | args = parser.parse_args()
56 |
57 | coarse = 'hnsw' if args.hnsw else 'flat'
58 | args.index_name = '%d_%s_%s' % (args.num_clusters, coarse, args.fine_quant)
59 |
60 | if args.fs == 'nfs':
61 | from nsml import NSML_NFS_OUTPUT
62 | args.dump_dir = os.path.join(NSML_NFS_OUTPUT, args.dump_dir)
63 | elif args.fs == 'nsml':
64 | pass
65 |
66 | args.index_dir = os.path.join(args.dump_dir, args.index_name)
67 |
68 | args.quantizer_path = os.path.join(args.index_dir, args.quantizer_path)
69 | args.max_norm_path = os.path.join(args.index_dir, args.max_norm_path)
70 | args.trained_index_path = os.path.join(args.index_dir, args.trained_index_path)
71 | args.inv_path = os.path.join(args.index_dir, args.inv_path)
72 |
73 | args.subindex_dir = os.path.join(args.index_dir, args.subindex_name)
74 | if args.dump_paths is None:
75 | args.index_path = os.path.join(args.index_dir, args.index_path)
76 | args.idx2id_path = os.path.join(args.index_dir, args.idx2id_path)
77 | else:
78 | args.dump_paths = [os.path.join(args.dump_dir, 'phrase', path) for path in args.dump_paths.split(',')]
79 | args.index_path = os.path.join(args.subindex_dir, '%d.faiss' % args.offset)
80 | args.idx2id_path = os.path.join(args.subindex_dir, '%d.hdf5' % args.offset)
81 |
82 | return args
83 |
84 |
85 | def sample_data(dump_paths, para=False, doc_sample_ratio=0.2, vec_sample_ratio=0.2, seed=29,
86 | max_norm=None, max_norm_cf=1.3, num_dummy_zeros=0, norm_th=999):
87 | vecs = []
88 | random.seed(seed)
89 | np.random.seed(seed)
90 | print('sampling from:')
91 | for dump_path in dump_paths:
92 | print(dump_path)
93 | dumps = [h5py.File(dump_path, 'r') for dump_path in dump_paths]
94 | for i, f in enumerate(tqdm(dumps)):
95 | doc_ids = list(f.keys())
96 | sampled_doc_ids = random.sample(doc_ids, int(doc_sample_ratio * len(doc_ids)))
97 | for doc_id in tqdm(sampled_doc_ids, desc='sampling from %d' % i):
98 | doc_group = f[doc_id]
99 | if para:
100 | groups = doc_group.values()
101 | else:
102 | groups = [doc_group]
103 | for group in groups:
104 | num_vecs, d = group['start'].shape
105 | if num_vecs == 0: continue
106 | sampled_vec_idxs = np.random.choice(num_vecs, int(vec_sample_ratio * num_vecs))
107 | cur_vecs = int8_to_float(group['start'][:],
108 | group.attrs['offset'], group.attrs['scale'])[sampled_vec_idxs]
109 | cur_vecs = cur_vecs[np.linalg.norm(cur_vecs, axis=1) <= norm_th]
110 | vecs.append(cur_vecs)
111 | out = np.concatenate(vecs, 0)
112 | for dump in dumps:
113 | dump.close()
114 |
115 | norms = np.linalg.norm(out, axis=1, keepdims=True)
116 | if max_norm is None:
117 | max_norm = max_norm_cf * np.max(norms)
118 | consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2))
119 | out = np.concatenate([consts, out], axis=1)
120 | if num_dummy_zeros > 0:
121 | out = np.concatenate([out, np.zeros([out.shape[0], num_dummy_zeros], dtype=out.dtype)], axis=1)
122 | return out, max_norm
123 |
124 |
125 | def train_coarse_quantizer(data, quantizer_path, num_clusters, hnsw=False, niter=10, cuda=False):
126 | d = data.shape[1]
127 |
128 | index_flat = faiss.IndexFlatL2(d)
129 | # make it into a gpu index
130 | if cuda:
131 | res = faiss.StandardGpuResources()
132 | index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)
133 | clus = faiss.Clustering(d, num_clusters)
134 | clus.verbose = True
135 | clus.niter = niter
136 | clus.train(data, index_flat)
137 | centroids = faiss.vector_float_to_array(clus.centroids)
138 | centroids = centroids.reshape(num_clusters, d)
139 |
140 | if hnsw:
141 | quantizer = faiss.IndexHNSWFlat(d, 32)
142 | quantizer.hnsw.efSearch = 128
143 | quantizer.train(centroids)
144 | quantizer.add(centroids)
145 | else:
146 | quantizer = faiss.IndexFlatL2(d)
147 | quantizer.add(centroids)
148 |
149 | faiss.write_index(quantizer, quantizer_path)
150 |
151 |
152 | def train_index(data, quantizer_path, trained_index_path, fine_quant='SQ8', cuda=False):
153 | quantizer = faiss.read_index(quantizer_path)
154 | if fine_quant == 'SQ8':
155 | trained_index = faiss.IndexIVFScalarQuantizer(quantizer, quantizer.d, quantizer.ntotal, faiss.METRIC_L2)
156 | elif fine_quant.startswith('PQ'):
157 | m = int(fine_quant[2:])
158 | trained_index = faiss.IndexIVFPQ(quantizer, quantizer.d, quantizer.ntotal, m, 8)
159 | else:
160 | raise ValueError(fine_quant)
161 |
162 | if cuda:
163 | if fine_quant.startswith('PQ'):
164 | print('PQ not supported on GPU; keeping CPU.')
165 | else:
166 | res = faiss.StandardGpuResources()
167 | gpu_index = faiss.index_cpu_to_gpu(res, 0, trained_index)
168 | gpu_index.train(data)
169 | trained_index = faiss.index_gpu_to_cpu(gpu_index)
170 | else:
171 | trained_index.train(data)
172 | faiss.write_index(trained_index, trained_index_path)
173 |
174 |
175 | def add_with_offset(index, data, offset, valids=None):
176 | ids = np.arange(data.shape[0]) + offset + index.ntotal
177 | if valids is not None:
178 | data = data[valids]
179 | ids = ids[valids]
180 | index.add_with_ids(data, ids)
181 |
182 |
183 | def add_to_index(dump_paths, trained_index_path, target_index_path, idx2id_path, max_norm, para=False,
184 | num_docs_per_add=1000, num_dummy_zeros=0, cuda=False, fine_quant='SQ8', offset=0, norm_th=999,
185 | ignore_ids=None):
186 | idx2doc_id = []
187 | idx2para_id = []
188 | idx2word_id = []
189 | dumps = [h5py.File(dump_path, 'r') for dump_path in dump_paths]
190 | print('reading %s' % trained_index_path)
191 | start_index = faiss.read_index(trained_index_path)
192 |
193 | if cuda:
194 | if fine_quant.startswith('PQ'):
195 | print('PQ not supported on GPU; keeping CPU.')
196 | else:
197 | res = faiss.StandardGpuResources()
198 | start_index = faiss.index_cpu_to_gpu(res, 0, start_index)
199 |
200 | print('adding following dumps:')
201 | for dump_path in dump_paths:
202 | print(dump_path)
203 | if para:
204 | for di, phrase_dump in enumerate(tqdm(dumps, desc='dumps')):
205 | starts = []
206 | for i, (doc_idx, doc_group) in enumerate(tqdm(phrase_dump.items(), desc='faiss indexing')):
207 | for para_idx, group in doc_group.items():
208 | num_vecs = group['start'].shape[0]
209 | start = int8_to_float(group['start'][:], group.attrs['offset'], group.attrs['scale'])
210 | norms = np.linalg.norm(start, axis=1, keepdims=True)
211 | consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2))
212 | start = np.concatenate([consts, start], axis=1)
213 | if num_dummy_zeros > 0:
214 | start = np.concatenate(
215 | [start, np.zeros([start.shape[0], num_dummy_zeros], dtype=start.dtype)], axis=1)
216 | starts.append(start)
217 | idx2doc_id.extend([int(doc_idx)] * num_vecs)
218 | idx2para_id.extend([int(para_idx)] * num_vecs)
219 | idx2word_id.extend(list(range(num_vecs)))
220 | if len(starts) > 0 and i % num_docs_per_add == 0:
221 | print('concatenating')
222 | concat = np.concatenate(starts, axis=0)
223 | print('adding')
224 | add_with_offset(start_index, concat, offset)
225 | # start_index.add(concat)
226 | print('done')
227 | starts = []
228 | if i % 100 == 0:
229 | print('%d/%d' % (i + 1, len(phrase_dump.keys())))
230 | print('adding leftover')
231 | add_with_offset(start_index, np.concatenate(starts, axis=0), offset)
232 | # start_index.add(np.concatenate(starts, axis=0)) # leftover
233 | print('done')
234 | else:
235 | for di, phrase_dump in enumerate(tqdm(dumps, desc='dumps')):
236 | starts = []
237 | valids = []
238 | for i, (doc_idx, doc_group) in enumerate(tqdm(phrase_dump.items(), desc='adding %d' % di)):
239 | if ignore_ids is not None and doc_idx in ignore_ids:
240 | continue
241 | num_vecs = doc_group['start'].shape[0]
242 | start = int8_to_float(doc_group['start'][:], doc_group.attrs['offset'],
243 | doc_group.attrs['scale'])
244 | valid = np.linalg.norm(start, axis=1) <= norm_th
245 | norms = np.linalg.norm(start, axis=1, keepdims=True)
246 | consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2))
247 | start = np.concatenate([consts, start], axis=1)
248 | if num_dummy_zeros > 0:
249 | start = np.concatenate([start, np.zeros([start.shape[0], num_dummy_zeros], dtype=start.dtype)],
250 | axis=1)
251 | starts.append(start)
252 | valids.append(valid)
253 | idx2doc_id.extend([int(doc_idx)] * num_vecs)
254 | idx2word_id.extend(range(num_vecs))
255 | if len(starts) > 0 and i % num_docs_per_add == 0:
256 | print('adding at %d' % (i+1))
257 | add_with_offset(start_index, np.concatenate(starts, axis=0), offset, np.concatenate(valids))
258 | # start_index.add(np.concatenate(starts, axis=0))
259 | starts = []
260 | valids = []
261 | if i % 100 == 0:
262 | # print('%d/%d' % (i + 1, len(phrase_dump.keys())))
263 | continue
264 | print('final adding at %d' % (i+1))
265 | add_with_offset(start_index, np.concatenate(starts, axis=0), offset, np.concatenate(valids))
266 | # start_index.add(np.concatenate(starts, axis=0)) # leftover
267 |
268 | for dump in dumps:
269 | dump.close()
270 |
271 | if cuda and not fine_quant.startswith('PQ'):
272 | print('moving back to cpu')
273 | start_index = faiss.index_gpu_to_cpu(start_index)
274 |
275 | print('index ntotal: %d' % start_index.ntotal)
276 | idx2doc_id = np.array(idx2doc_id, dtype=np.int32)
277 | idx2para_id = np.array(idx2para_id, dtype=np.int32)
278 | idx2word_id = np.array(idx2word_id, dtype=np.int32)
279 |
280 | print('writing index and metadata')
281 | with h5py.File(idx2id_path, 'w') as f:
282 | g = f.create_group(str(offset))
283 | g.create_dataset('doc', data=idx2doc_id)
284 | g.create_dataset('para', data=idx2para_id)
285 | g.create_dataset('word', data=idx2word_id)
286 | g.attrs['offset'] = offset
287 | faiss.write_index(start_index, target_index_path)
288 | print('done')
289 |
290 |
291 | def merge_indexes(subindex_dir, trained_index_path, target_index_path, target_idx2id_path, target_inv_path):
292 | # target_inv_path = merged_index.ivfdata
293 | names = os.listdir(subindex_dir)
294 | idx2id_paths = [os.path.join(subindex_dir, name) for name in names if name.endswith('.hdf5')]
295 | index_paths = [os.path.join(subindex_dir, name) for name in names if name.endswith('.faiss')]
296 | print(len(idx2id_paths))
297 | print(len(index_paths))
298 |
299 | print('copying idx2id')
300 | with h5py.File(target_idx2id_path, 'w') as out:
301 | for idx2id_path in tqdm(idx2id_paths, desc='copying idx2id'):
302 | with h5py.File(idx2id_path, 'r') as in_:
303 | for key, g in in_.items():
304 | offset = str(g.attrs['offset'])
305 | assert key == offset
306 | group = out.create_group(offset)
307 | group.create_dataset('doc', data=in_[key]['doc'])
308 | group.create_dataset('para', data=in_[key]['para'])
309 | group.create_dataset('word', data=in_[key]['word'])
310 |
311 | print('loading invlists')
312 | ivfs = []
313 | for index_path in tqdm(index_paths, desc='loading invlists'):
314 | # the IO_FLAG_MMAP is to avoid actually loading the data thus
315 | # the total size of the inverted lists can exceed the
316 | # available RAM
317 | index = faiss.read_index(index_path,
318 | faiss.IO_FLAG_MMAP)
319 | ivfs.append(index.invlists)
320 |
321 | # avoid that the invlists get deallocated with the index
322 | index.own_invlists = False
323 |
324 | # construct the output index
325 | index = faiss.read_index(trained_index_path)
326 |
327 | # prepare the output inverted lists. They will be written
328 | # to merged_index.ivfdata
329 | invlists = faiss.OnDiskInvertedLists(
330 | index.nlist, index.code_size,
331 | target_inv_path)
332 |
333 | # merge all the inverted lists
334 | print('merging')
335 | ivf_vector = faiss.InvertedListsPtrVector()
336 | for ivf in tqdm(ivfs):
337 | ivf_vector.push_back(ivf)
338 |
339 | print("merge %d inverted lists " % ivf_vector.size())
340 | ntotal = invlists.merge_from(ivf_vector.data(), ivf_vector.size())
341 | print(ntotal)
342 |
343 | # now replace the inverted lists in the output index
344 | index.ntotal = ntotal
345 | index.replace_invlists(invlists)
346 |
347 | print('writing index')
348 | faiss.write_index(index, target_index_path)
349 |
350 |
351 | def run_index(args):
352 | phrase_path = os.path.join(args.dump_dir, 'phrase.hdf5')
353 | if os.path.exists(phrase_path):
354 | dump_paths = [phrase_path]
355 | else:
356 | dump_names = os.listdir(os.path.join(args.dump_dir, 'phrase'))
357 | dump_paths = [os.path.join(args.dump_dir, 'phrase', name) for name in dump_names if name.endswith('.hdf5')]
358 |
359 | data = None
360 | if args.stage in ['all', 'coarse']:
361 | if args.replace or not os.path.exists(args.quantizer_path):
362 | if not os.path.exists(args.index_dir):
363 | os.makedirs(args.index_dir)
364 | data, max_norm = sample_data(dump_paths, max_norm=args.max_norm, para=args.para,
365 | doc_sample_ratio=args.doc_sample_ratio, vec_sample_ratio=args.vec_sample_ratio,
366 | max_norm_cf=args.max_norm_cf, num_dummy_zeros=args.num_dummy_zeros,
367 | norm_th=args.norm_th)
368 | with open(args.max_norm_path, 'w') as fp:
369 | json.dump(max_norm, fp)
370 | train_coarse_quantizer(data, args.quantizer_path, args.num_clusters, cuda=args.cuda)
371 |
372 | if args.stage in ['all', 'fine']:
373 | if args.replace or not os.path.exists(args.trained_index_path):
374 | with open(args.max_norm_path, 'r') as fp:
375 | max_norm = json.load(fp)
376 | if data is None:
377 | data, _ = sample_data(dump_paths, max_norm=max_norm, para=args.para,
378 | doc_sample_ratio=args.doc_sample_ratio, vec_sample_ratio=args.vec_sample_ratio,
379 | num_dummy_zeros=args.num_dummy_zeros, norm_th=args.norm_th)
380 | train_index(data, args.quantizer_path, args.trained_index_path, fine_quant=args.fine_quant, cuda=args.cuda)
381 |
382 | if args.stage in ['all', 'add']:
383 | if args.replace or not os.path.exists(args.index_path):
384 | with open(args.max_norm_path, 'r') as fp:
385 | max_norm = json.load(fp)
386 | if args.dump_paths is not None:
387 | dump_paths = args.dump_paths
388 | if not os.path.exists(args.subindex_dir):
389 | os.makedirs(args.subindex_dir)
390 | add_to_index(dump_paths, args.trained_index_path, args.index_path, args.idx2id_path,
391 | max_norm=max_norm, para=args.para, num_dummy_zeros=args.num_dummy_zeros, cuda=args.cuda,
392 | num_docs_per_add=args.num_docs_per_add, offset=args.offset, norm_th=args.norm_th,
393 | fine_quant=args.fine_quant)
394 |
395 | if args.stage == 'merge':
396 | if args.replace or not os.path.exists(args.index_path):
397 | merge_indexes(args.subindex_dir, args.trained_index_path, args.index_path, args.idx2id_path, args.inv_path)
398 |
399 | if args.stage == 'move':
400 | index = faiss.read_index(args.trained_index_path)
401 | invlists = faiss.OnDiskInvertedLists(
402 | index.nlist, index.code_size,
403 | args.inv_path)
404 | index.replace_invlists(invlists)
405 | faiss.write_index(index, args.index_path)
406 |
407 |
408 | def main():
409 | args = get_args()
410 | run_index(args)
411 |
412 |
413 | if __name__ == '__main__':
414 | main()
415 |
--------------------------------------------------------------------------------
/run_server.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import torch
4 | import tokenization
5 | import os
6 | import random
7 | import numpy as np
8 | import requests
9 | import logging
10 | import math
11 | import ssl
12 | import copy
13 | from time import time
14 | from flask import Flask, request, jsonify, render_template, redirect
15 | from flask_cors import CORS
16 | from tornado.wsgi import WSGIContainer
17 | from tornado.httpserver import HTTPServer
18 | from tornado.ioloop import IOLoop
19 | from requests_futures.sessions import FuturesSession
20 | from tqdm import tqdm
21 | from collections import namedtuple
22 |
23 | from modeling import BertConfig
24 | from modeling import DenSPI
25 | from tfidf_doc_ranker import TfidfDocRanker
26 | from utils import check_diff
27 | from pre import SquadExample, convert_questions_to_features
28 | from post import convert_question_features_to_dataloader, get_question_results
29 | from mips_phrase import MIPS
30 |
31 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
32 | level=logging.INFO)
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | class DenSPIServer(object):
37 | def __init__(self, args):
38 | self.args = args
39 | # IP and Ports
40 | self.base_ip = args.base_ip
41 | self.query_port = args.query_port
42 | self.doc_port = args.doc_port
43 | self.index_port = args.index_port
44 |
45 | # Saved objects
46 | self.mips = None
47 |
48 | def load_query_encoder(self, device, args):
49 | # Configure paths for query encoder serving
50 | vocab_path = os.path.join(args.metadata_dir, args.vocab_name)
51 | bert_config_path = os.path.join(
52 | args.metadata_dir, args.bert_config_name.replace(".json", "") + "_" + args.bert_model_option + ".json"
53 | )
54 |
55 | # Load pretrained QueryEncoder
56 | bert_config = BertConfig.from_json_file(bert_config_path)
57 | model = DenSPI(bert_config)
58 | if args.parallel:
59 | model = torch.nn.DataParallel(model)
60 | state = torch.load(args.query_encoder_path, map_location='cpu')
61 | model.load_state_dict(state['model'], strict=False)
62 | check_diff(model.state_dict(), state['model'])
63 | model.to(device)
64 | logger.info('Model loaded from %s' % args.query_encoder_path)
65 | logger.info('Number of model parameters: {:,}'.format(sum(p.numel() for p in model.parameters())))
66 |
67 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path, do_lower_case=not args.do_case)
68 | return model, tokenizer
69 |
70 | def get_question_dataloader(self, questions, tokenizer, batch_size):
71 | question_examples = [SquadExample(qas_id='qs', question_text=q) for q in questions]
72 | query_features = convert_questions_to_features(
73 | examples=question_examples,
74 | tokenizer=tokenizer,
75 | max_query_length=64
76 | )
77 | question_dataloader = convert_question_features_to_dataloader(
78 | query_features,
79 | fp16=False, local_rank=-1,
80 | predict_batch_size=batch_size
81 | )
82 | return question_dataloader, question_examples, query_features
83 |
84 | def serve_query_encoder(self, query_port, args):
85 | device = 'cuda' if args.cuda else 'cpu'
86 | query_encoder, tokenizer = self.load_query_encoder(device, args)
87 |
88 | # Define query to vector function
89 | def query2vec(queries):
90 | question_dataloader, question_examples, query_features = self.get_question_dataloader(
91 | queries, tokenizer, batch_size=24
92 | )
93 | query_encoder.eval()
94 | question_results = get_question_results(
95 | question_examples, query_features, question_dataloader, device, query_encoder
96 | )
97 | outs = []
98 | for qr_idx, question_result in enumerate(question_results):
99 | for ngram in question_result.sparse.keys():
100 | question_result.sparse[ngram] = question_result.sparse[ngram].tolist()
101 | out = (
102 | question_result.start.tolist(), question_result.end.tolist(),
103 | question_result.sparse, question_result.input_ids
104 | )
105 | outs.append(out)
106 | return outs
107 |
108 | # Serve query encoder
109 | app = Flask(__name__)
110 | app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
111 | CORS(app)
112 |
113 | @app.route('/batch_api', methods=['POST'])
114 | def batch_api():
115 | batch_query = json.loads(request.form['query'])
116 | outs = query2vec(batch_query)
117 | return jsonify(outs)
118 |
119 | logger.info(f'Starting QueryEncoder server at {self.get_address(query_port)}')
120 | http_server = HTTPServer(WSGIContainer(app))
121 | http_server.listen(query_port)
122 | IOLoop.instance().start()
123 |
124 | def load_phrase_index(self, args, dump_only=False):
125 | if self.mips is not None:
126 | return self.mips
127 |
128 | # Configure paths for index serving
129 | phrase_dump_dir = os.path.join(args.dump_dir, args.phrase_dir)
130 | tfidf_dump_dir = os.path.join(args.dump_dir, args.tfidf_dir)
131 | index_dir = os.path.join(args.dump_dir, args.index_dir)
132 | index_path = os.path.join(index_dir, args.index_name)
133 | idx2id_path = os.path.join(index_dir, args.idx2id_name)
134 | max_norm_path = os.path.join(index_dir, 'max_norm.json')
135 |
136 | # Load mips
137 | mips_init = MIPS
138 | mips = mips_init(
139 | phrase_dump_dir=phrase_dump_dir,
140 | tfidf_dump_dir=tfidf_dump_dir,
141 | start_index_path=index_path,
142 | idx2id_path=idx2id_path,
143 | max_norm_path=max_norm_path,
144 | doc_rank_fn={
145 | 'index': self.get_doc_scores, 'top_docs': self.get_top_docs, 'spvec': self.get_q_spvecs
146 | },
147 | cuda=args.cuda, dump_only=dump_only
148 | )
149 | return mips
150 |
151 | def serve_phrase_index(self, index_port, args):
152 | args.examples_path = os.path.join('static', args.examples_path)
153 |
154 | # Load mips
155 | self.mips = self.load_phrase_index(args)
156 | app = Flask(__name__, static_url_path='/static')
157 | app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
158 | CORS(app)
159 |
160 | def batch_search(batch_query, max_answer_length=20, start_top_k=1000, mid_top_k=100, top_k=10, doc_top_k=5,
161 | nprobe=64, sparse_weight=0.05, search_strategy='hybrid'):
162 | t0 = time()
163 | outs, _ = self.embed_query(batch_query)()
164 | start = np.concatenate([out[0] for out in outs], 0)
165 | end = np.concatenate([out[1] for out in outs], 0)
166 | sparse_uni = [out[2]['1'][1:len(out[3])+1] for out in outs]
167 | sparse_bi = [out[2]['2'][1:len(out[3])+1] for out in outs]
168 | input_ids = [out[3] for out in outs]
169 | query_vec = np.concatenate([start, end, [[1]]*len(outs)], 1)
170 |
171 | rets = self.mips.search(
172 | query_vec, (input_ids, sparse_uni, sparse_bi), q_texts=batch_query, nprobe=nprobe,
173 | doc_top_k=doc_top_k, start_top_k=start_top_k, mid_top_k=mid_top_k, top_k=top_k,
174 | search_strategy=search_strategy, filter_=args.filter, max_answer_length=max_answer_length,
175 | sparse_weight=sparse_weight
176 | )
177 | t1 = time()
178 | out = {'ret': rets, 'time': int(1000 * (t1 - t0))}
179 | return out
180 |
181 | @app.route('/')
182 | def index():
183 | return app.send_static_file('index.html')
184 |
185 | @app.route('/files/')
186 | def static_files(path):
187 | return app.send_static_file('files/' + path)
188 |
189 | # This one uses a default hyperparameters
190 | @app.route('/api', methods=['GET'])
191 | def api():
192 | query = request.args['query']
193 | strat = request.args['strat']
194 | out = batch_search(
195 | [query],
196 | max_answer_length=args.max_answer_length,
197 | top_k=args.top_k,
198 | nprobe=args.nprobe,
199 | search_strategy=strat,
200 | doc_top_k=args.doc_top_k
201 | )
202 | out['ret'] = out['ret'][0]
203 | return jsonify(out)
204 |
205 | @app.route('/batch_api', methods=['POST'])
206 | def batch_api():
207 | batch_query = json.loads(request.form['query'])
208 | max_answer_length = int(request.form['max_answer_length'])
209 | start_top_k = int(request.form['start_top_k'])
210 | mid_top_k = int(request.form['mid_top_k'])
211 | top_k = int(request.form['top_k'])
212 | doc_top_k = int(request.form['doc_top_k'])
213 | nprobe = int(request.form['nprobe'])
214 | sparse_weight = float(request.form['sparse_weight'])
215 | strat = request.form['strat']
216 | out = batch_search(
217 | batch_query,
218 | max_answer_length=max_answer_length,
219 | start_top_k=start_top_k,
220 | mid_top_k=mid_top_k,
221 | top_k=top_k,
222 | doc_top_k=doc_top_k,
223 | nprobe=nprobe,
224 | sparse_weight=sparse_weight,
225 | search_strategy=strat,
226 | )
227 | return jsonify(out)
228 |
229 | @app.route('/get_examples', methods=['GET'])
230 | def get_examples():
231 | with open(args.examples_path, 'r') as fp:
232 | examples = [line.strip() for line in fp.readlines()]
233 | return jsonify(examples)
234 |
235 | if self.query_port is None:
236 | logger.info('You must set self.query_port for querying. You can use self.update_query_port() later on.')
237 | logger.info(f'Starting Index server at {self.get_address(index_port)}')
238 | http_server = HTTPServer(WSGIContainer(app))
239 | http_server.listen(index_port)
240 | IOLoop.instance().start()
241 |
242 | def serve_doc_ranker(self, doc_port, args):
243 | doc_ranker_path = os.path.join(args.dump_dir, args.doc_ranker_name)
244 | doc_ranker = TfidfDocRanker(doc_ranker_path, strict=False)
245 | app = Flask(__name__)
246 | app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
247 | CORS(app)
248 |
249 | @app.route('/doc_index', methods=['POST'])
250 | def doc_index():
251 | batch_query = json.loads(request.form['query'])
252 | doc_idxs = json.loads(request.form['doc_idxs'])
253 | outs = doc_ranker.batch_doc_scores(batch_query, doc_idxs)
254 | logger.info(f'Returning {len(outs)} from batch_doc_scores')
255 | return jsonify(outs)
256 |
257 | @app.route('/top_docs', methods=['POST'])
258 | def top_docs():
259 | batch_query = json.loads(request.form['query'])
260 | top_k = int(request.form['top_k'])
261 | batch_results = doc_ranker.batch_closest_docs(batch_query, k=top_k)
262 | top_idxs = [b[0] for b in batch_results]
263 | top_scores = [b[1].tolist() for b in batch_results]
264 | logger.info(f'Returning from batch_doc_scores')
265 | return jsonify([top_idxs, top_scores])
266 |
267 | @app.route('/text2spvec', methods=['POST'])
268 | def text2spvec():
269 | batch_query = json.loads(request.form['query'])
270 | q_spvecs = [doc_ranker.text2spvec(q, val_idx=True) for q in batch_query]
271 | q_vals = [q_spvec[0].tolist() for q_spvec in q_spvecs]
272 | q_idxs = [q_spvec[1].tolist() for q_spvec in q_spvecs]
273 | logger.info(f'Returning {len(q_vals), len(q_idxs)} q_spvecs')
274 | return jsonify([q_vals, q_idxs])
275 |
276 | logger.info(f'Starting DocRanker server at {self.get_address(doc_port)}')
277 | http_server = HTTPServer(WSGIContainer(app))
278 | http_server.listen(doc_port)
279 | IOLoop.instance().start()
280 |
281 | def get_address(self, port):
282 | assert self.base_ip is not None and len(port) > 0
283 | return self.base_ip + ':' + port
284 |
285 | def embed_query(self, batch_query):
286 | emb_session = FuturesSession()
287 | r = emb_session.post(self.get_address(self.query_port) + '/batch_api', data={'query': json.dumps(batch_query)})
288 | def map_():
289 | result = r.result()
290 | emb = result.json()
291 | return emb, result.elapsed.total_seconds() * 1000
292 | return map_
293 |
294 | def query(self, query, search_strategy='hybrid'):
295 | params = {'query': query, 'strat': search_strategy}
296 | res = requests.get(self.get_address(self.index_port) + '/api', params=params)
297 | if res.status_code != 200:
298 | logger.info('Wrong behavior %d' % res.status_code)
299 | try:
300 | outs = json.loads(res.text)
301 | except Exception as e:
302 | logger.info(f'no response or error for q {query}')
303 | logger.info(res.text)
304 | return outs
305 |
306 | def batch_query(self, batch_query, max_answer_length=20, start_top_k=1000, mid_top_k=100, top_k=10, doc_top_k=5,
307 | nprobe=64, sparse_weight=0.05, search_strategy='hybrid'):
308 | post_data = {
309 | 'query': json.dumps(batch_query),
310 | 'max_answer_length': max_answer_length,
311 | 'start_top_k': start_top_k,
312 | 'mid_top_k': mid_top_k,
313 | 'top_k': top_k,
314 | 'doc_top_k': doc_top_k,
315 | 'nprobe': nprobe,
316 | 'sparse_weight': sparse_weight,
317 | 'strat': search_strategy,
318 | }
319 | res = requests.post(self.get_address(self.index_port) + '/batch_api', data=post_data)
320 | if res.status_code != 200:
321 | logger.info('Wrong behavior %d' % res.status_code)
322 | try:
323 | outs = json.loads(res.text)
324 | except Exception as e:
325 | logger.info(f'no response or error for q {batch_query}')
326 | logger.info(res.text)
327 | return outs
328 |
329 | def get_doc_scores(self, batch_query, doc_idxs):
330 | post_data = {
331 | 'query': json.dumps(batch_query),
332 | 'doc_idxs': json.dumps(doc_idxs)
333 | }
334 | res = requests.post(self.get_address(self.doc_port) + '/doc_index', data=post_data)
335 | if res.status_code != 200:
336 | logger.info('Wrong behavior %d' % res.status_code)
337 | try:
338 | result = json.loads(res.text)
339 | except Exception as e:
340 | logger.info(f'no response or error for {doc_idxs}')
341 | logger.info(res.text)
342 | return result
343 |
344 | def get_top_docs(self, batch_query, top_k):
345 | post_data = {
346 | 'query': json.dumps(batch_query),
347 | 'top_k': top_k
348 | }
349 | res = requests.post(self.get_address(self.doc_port) + '/top_docs', data=post_data)
350 | if res.status_code != 200:
351 | logger.info('Wrong behavior %d' % res.status_code)
352 | try:
353 | result = json.loads(res.text)
354 | except Exception as e:
355 | logger.info(f'no response or error for {top_k}')
356 | logger.info(res.text)
357 | return result
358 |
359 | def get_q_spvecs(self, batch_query):
360 | post_data = {'query': json.dumps(batch_query)}
361 | res = requests.post(self.get_address(self.doc_port) + '/text2spvec', data=post_data)
362 | if res.status_code != 200:
363 | logger.info('Wrong behavior %d' % res.status_code)
364 | try:
365 | result = json.loads(res.text)
366 | except Exception as e:
367 | logger.info(f'no response or error for q {batch_query}')
368 | logger.info(res.text)
369 | return result
370 |
371 |
372 | if __name__ == '__main__':
373 | parser = argparse.ArgumentParser()
374 | # QueryEncoder
375 | parser.add_argument('--metadata_dir', default='/nvme/jinhyuk/denspi/bert', type=str)
376 | parser.add_argument("--vocab_name", default='vocab.txt', type=str)
377 | parser.add_argument("--bert_config_name", default='bert_config.json', type=str)
378 | parser.add_argument("--bert_model_option", default='large_uncased', type=str)
379 | parser.add_argument("--parallel", default=False, action='store_true')
380 | parser.add_argument("--do_case", default=False, action='store_true')
381 | parser.add_argument("--query_encoder_path", default='/nvme/jinhyuk/denspi/KR94373_piqa-nfs_1173/1/model.pt', type=str)
382 | parser.add_argument("--query_port", default='-1', type=str)
383 |
384 | # DocRanker
385 | parser.add_argument('--doc_ranker_name', default='docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz')
386 | parser.add_argument('--doc_port', default='-1', type=str)
387 |
388 | # PhraseIndex
389 | parser.add_argument('--dump_dir', default='/nvme/jinhyuk/denspi/1173_wikipedia_filtered')
390 | parser.add_argument('--phrase_dir', default='phrase')
391 | parser.add_argument('--tfidf_dir', default='tfidf')
392 | parser.add_argument('--index_dir', default='1048576_hnsw_SQ8')
393 | parser.add_argument('--index_name', default='index.faiss')
394 | parser.add_argument('--idx2id_name', default='idx2id.hdf5')
395 | parser.add_argument('--index_port', default='-1', type=str)
396 |
397 | # These can be dynamically changed.
398 | parser.add_argument('--max_answer_length', default=20, type=int)
399 | parser.add_argument('--start_top_k', default=1000, type=int)
400 | parser.add_argument('--mid_top_k', default=100, type=int)
401 | parser.add_argument('--top_k', default=10, type=int)
402 | parser.add_argument('--doc_top_k', default=5, type=int)
403 | parser.add_argument('--nprobe', default=256, type=int)
404 | parser.add_argument('--sparse_weight', default=0.05, type=float)
405 | parser.add_argument('--search_strategy', default='hybrid')
406 | parser.add_argument('--filter', default=False, action='store_true')
407 | parser.add_argument('--no_para', default=False, action='store_true')
408 |
409 | # Serving options
410 | parser.add_argument('--examples_path', default='examples.txt')
411 |
412 | # Run mode
413 | parser.add_argument('--base_ip', default='http://163.152.163.248')
414 | parser.add_argument('--run_mode', default='batch_query')
415 | parser.add_argument('--cuda', default=False, action='store_true')
416 | parser.add_argument('--draft', default=False, action='store_true')
417 | parser.add_argument('--seed', default=1992, type=int)
418 | args = parser.parse_args()
419 |
420 | # Seed for reproducibility
421 | random.seed(args.seed)
422 | np.random.seed(args.seed)
423 | torch.manual_seed(args.seed)
424 | if torch.cuda.is_available():
425 | torch.cuda.manual_seed_all(args.seed)
426 |
427 | server = DenSPIServer(args)
428 |
429 | # Set ports
430 | # server.query_port = '9010'
431 | # server.doc_port = '9020'
432 | # sersver.index_port = '10001'
433 |
434 | if args.run_mode == 'q_serve':
435 | logger.info(f'Query address: {server.get_address(server.query_port)}')
436 | server.serve_query_encoder(args.query_port, args)
437 |
438 | elif args.run_mode == 'd_serve':
439 | logger.info(f'Doc address: {server.get_address(server.doc_port)}')
440 | server.serve_doc_ranker(args.doc_port, args)
441 |
442 | elif args.run_mode == 'p_serve':
443 | logger.info(f'Query address: {server.get_address(server.query_port)}')
444 | logger.info(f'Doc address: {server.get_address(server.doc_port)}')
445 | logger.info(f'Index address: {server.get_address(server.index_port)}')
446 | server.serve_phrase_index(args.index_port, args)
447 |
448 | elif args.run_mode == 'query':
449 | logger.info(f'Index address: {server.get_address(server.index_port)}')
450 | query = 'Name three famous writers'
451 | result = server.query(query)
452 | logger.info(f'Answers to a question: {query}')
453 | logger.info(f'{[r["answer"] for r in result["ret"]]}')
454 |
455 | elif args.run_mode == 'batch_query':
456 | logger.info(f'Index address: {server.get_address(server.index_port)}')
457 | queries= [
458 | 'Name three famous writers',
459 | 'Who was defeated by computer in chess game?'
460 | ]
461 | result = server.batch_query(
462 | queries,
463 | max_answer_length=args.max_answer_length,
464 | start_top_k=args.start_top_k,
465 | mid_top_k=args.mid_top_k,
466 | top_k=args.top_k,
467 | doc_top_k=args.doc_top_k,
468 | nprobe=args.nprobe,
469 | sparse_weight=args.sparse_weight,
470 | search_strategy=args.search_strategy,
471 | )
472 | for query, result in zip(queries, result['ret']):
473 | logger.info(f'Answers to a question: {query}')
474 | logger.info(f'{[r["answer"] for r in result]}')
475 |
476 | elif args.run_mode == 'get_doc_scores':
477 | logger.info(f'Doc address: {server.get_address(server.doc_port)}')
478 | queries = [
479 | 'What was the Yuan\'s paper money called?',
480 | 'What makes a successful startup??',
481 | 'On which date was Genghis Khan\'s palace rediscovered by archeaologists?',
482 | 'To-y is a _ .'
483 | ]
484 | result = server.get_doc_scores(queries, [[36], [2], [31], [22222]])
485 | logger.info(result)
486 | result = server.get_top_docs(queries, 5)
487 | logger.info(result)
488 | result = server.get_q_spvecs(queries)
489 | logger.info(result)
490 |
491 | else:
492 | raise NotImplementedError
493 |
--------------------------------------------------------------------------------
/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Basic tokenizer that splits text into alpha-numeric tokens and
8 | non-whitespace tokens.
9 | """
10 |
11 | import regex
12 | import logging
13 | from tokenizer_util import Tokens, Tokenizer
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class SimpleTokenizer(Tokenizer):
19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
20 | NON_WS = r'[^\p{Z}\p{C}]'
21 |
22 | def __init__(self, **kwargs):
23 | """
24 | Args:
25 | annotators: None or empty set (only tokenizes).
26 | """
27 | self._regexp = regex.compile(
28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
30 | )
31 | if len(kwargs.get('annotators', {})) > 0:
32 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
33 | (type(self).__name__, kwargs.get('annotators')))
34 | self.annotators = set()
35 |
36 | def tokenize(self, text):
37 | data = []
38 | matches = [m for m in self._regexp.finditer(text)]
39 | for i in range(len(matches)):
40 | # Get text
41 | token = matches[i].group()
42 |
43 | # Get whitespace
44 | span = matches[i].span()
45 | start_ws = span[0]
46 | if i + 1 < len(matches):
47 | end_ws = matches[i + 1].span()[0]
48 | else:
49 | end_ws = span[1]
50 |
51 | # Format data
52 | data.append((
53 | token,
54 | text[start_ws: end_ws],
55 | span,
56 | ))
57 | return Tokens(data, self.annotators)
58 |
--------------------------------------------------------------------------------
/static/examples.txt:
--------------------------------------------------------------------------------
1 | What is the GDP of South Korea in 1950?
2 | What is the GDP of South Korea in 2010?
3 | Who is the fourth president of USA?
4 | Who is the seventh president of USA?
5 | -------------------------------------------------
6 | What is South Korea known for?
7 | What tends to lead to more money?
8 | Who was defeated by computer in chess game?
9 | Name three famous writers
10 | What makes a successful startup?
11 | Why did Oracle sue Google?
12 | Where can you find water in desert?
13 | What does AMI stand for?
14 | How heavy was the apollo 11?
15 | What is water consisted of?
16 | What makes a man great?
17 | Which city is famous for coffee?
18 | On which date was Genghis Khan's palace rediscovered by archeaologists?
19 | What is another term for x-ray imaging?
20 | Who scolded Luther about his rudeness?
21 | What was the Yuan's paper money called?
22 |
--------------------------------------------------------------------------------
/static/files/pichu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhyuklee/sparc/bee309bdffd73d162c23a7c3c0b63cebe4aaea97/static/files/pichu.png
--------------------------------------------------------------------------------
/static/files/pika.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhyuklee/sparc/bee309bdffd73d162c23a7c3c0b63cebe4aaea97/static/files/pika.png
--------------------------------------------------------------------------------
/static/files/popper.min.js:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) Federico Zivolo 2018
3 | Distributed under the MIT License (license terms are at http://opensource.org/licenses/MIT).
4 | */(function(e,t){'object'==typeof exports&&'undefined'!=typeof module?module.exports=t():'function'==typeof define&&define.amd?define(t):e.Popper=t()})(this,function(){'use strict';function e(e){return e&&'[object Function]'==={}.toString.call(e)}function t(e,t){if(1!==e.nodeType)return[];var o=getComputedStyle(e,null);return t?o[t]:o}function o(e){return'HTML'===e.nodeName?e:e.parentNode||e.host}function n(e){if(!e)return document.body;switch(e.nodeName){case'HTML':case'BODY':return e.ownerDocument.body;case'#document':return e.body;}var i=t(e),r=i.overflow,p=i.overflowX,s=i.overflowY;return /(auto|scroll|overlay)/.test(r+s+p)?e:n(o(e))}function r(e){return 11===e?re:10===e?pe:re||pe}function p(e){if(!e)return document.documentElement;for(var o=r(10)?document.body:null,n=e.offsetParent;n===o&&e.nextElementSibling;)n=(e=e.nextElementSibling).offsetParent;var i=n&&n.nodeName;return i&&'BODY'!==i&&'HTML'!==i?-1!==['TD','TABLE'].indexOf(n.nodeName)&&'static'===t(n,'position')?p(n):n:e?e.ownerDocument.documentElement:document.documentElement}function s(e){var t=e.nodeName;return'BODY'!==t&&('HTML'===t||p(e.firstElementChild)===e)}function d(e){return null===e.parentNode?e:d(e.parentNode)}function a(e,t){if(!e||!e.nodeType||!t||!t.nodeType)return document.documentElement;var o=e.compareDocumentPosition(t)&Node.DOCUMENT_POSITION_FOLLOWING,n=o?e:t,i=o?t:e,r=document.createRange();r.setStart(n,0),r.setEnd(i,0);var l=r.commonAncestorContainer;if(e!==l&&t!==l||n.contains(i))return s(l)?l:p(l);var f=d(e);return f.host?a(f.host,t):a(e,d(t).host)}function l(e){var t=1=o.clientWidth&&n>=o.clientHeight}),l=0a[e]&&!t.escapeWithReference&&(n=J(f[o],a[e]-('right'===e?f.width:f.height))),ae({},o,n)}};return l.forEach(function(e){var t=-1===['left','top'].indexOf(e)?'secondary':'primary';f=le({},f,m[t](e))}),e.offsets.popper=f,e},priority:['left','right','top','bottom'],padding:5,boundariesElement:'scrollParent'},keepTogether:{order:400,enabled:!0,fn:function(e){var t=e.offsets,o=t.popper,n=t.reference,i=e.placement.split('-')[0],r=Z,p=-1!==['top','bottom'].indexOf(i),s=p?'right':'bottom',d=p?'left':'top',a=p?'width':'height';return o[s]r(n[s])&&(e.offsets.popper[d]=r(n[s])),e}},arrow:{order:500,enabled:!0,fn:function(e,o){var n;if(!q(e.instance.modifiers,'arrow','keepTogether'))return e;var i=o.element;if('string'==typeof i){if(i=e.instance.popper.querySelector(i),!i)return e;}else if(!e.instance.popper.contains(i))return console.warn('WARNING: `arrow.element` must be child of its popper element!'),e;var r=e.placement.split('-')[0],p=e.offsets,s=p.popper,d=p.reference,a=-1!==['left','right'].indexOf(r),l=a?'height':'width',f=a?'Top':'Left',m=f.toLowerCase(),h=a?'left':'top',c=a?'bottom':'right',u=S(i)[l];d[c]-us[c]&&(e.offsets.popper[m]+=d[m]+u-s[c]),e.offsets.popper=g(e.offsets.popper);var b=d[m]+d[l]/2-u/2,y=t(e.instance.popper),w=parseFloat(y['margin'+f],10),E=parseFloat(y['border'+f+'Width'],10),v=b-e.offsets.popper[m]-w-E;return v=$(J(s[l]-u,v),0),e.arrowElement=i,e.offsets.arrow=(n={},ae(n,m,Q(v)),ae(n,h,''),n),e},element:'[x-arrow]'},flip:{order:600,enabled:!0,fn:function(e,t){if(W(e.instance.modifiers,'inner'))return e;if(e.flipped&&e.placement===e.originalPlacement)return e;var o=v(e.instance.popper,e.instance.reference,t.padding,t.boundariesElement,e.positionFixed),n=e.placement.split('-')[0],i=T(n),r=e.placement.split('-')[1]||'',p=[];switch(t.behavior){case he.FLIP:p=[n,i];break;case he.CLOCKWISE:p=z(n);break;case he.COUNTERCLOCKWISE:p=z(n,!0);break;default:p=t.behavior;}return p.forEach(function(s,d){if(n!==s||p.length===d+1)return e;n=e.placement.split('-')[0],i=T(n);var a=e.offsets.popper,l=e.offsets.reference,f=Z,m='left'===n&&f(a.right)>f(l.left)||'right'===n&&f(a.left)f(l.top)||'bottom'===n&&f(a.top)f(o.right),g=f(a.top)f(o.bottom),b='left'===n&&h||'right'===n&&c||'top'===n&&g||'bottom'===n&&u,y=-1!==['top','bottom'].indexOf(n),w=!!t.flipVariations&&(y&&'start'===r&&h||y&&'end'===r&&c||!y&&'start'===r&&g||!y&&'end'===r&&u);(m||b||w)&&(e.flipped=!0,(m||b)&&(n=p[d+1]),w&&(r=G(r)),e.placement=n+(r?'-'+r:''),e.offsets.popper=le({},e.offsets.popper,C(e.instance.popper,e.offsets.reference,e.placement)),e=P(e.instance.modifiers,e,'flip'))}),e},behavior:'flip',padding:5,boundariesElement:'viewport'},inner:{order:700,enabled:!1,fn:function(e){var t=e.placement,o=t.split('-')[0],n=e.offsets,i=n.popper,r=n.reference,p=-1!==['left','right'].indexOf(o),s=-1===['top','left'].indexOf(o);return i[p?'left':'top']=r[o]-(s?i[p?'width':'height']:0),e.placement=T(t),e.offsets.popper=g(i),e}},hide:{order:800,enabled:!0,fn:function(e){if(!q(e.instance.modifiers,'hide','preventOverflow'))return e;var t=e.offsets.reference,o=D(e.instance.modifiers,function(e){return'preventOverflow'===e.name}).boundaries;if(t.bottomo.right||t.top>o.bottom||t.right
2 |
3 |
4 | DenSPI + Sparc
5 |
6 |
7 |
8 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | DenSPI + Sparc
28 |
36 |
37 |
38 |
43 |
44 |
45 |
69 |
70 |
71 |
Latency:
72 |
73 |
77 | Wikipedia EN (Dec 2016 dump)
78 |
79 |
80 |
81 |
82 |
83 |
84 | Dense-First Search
85 |
86 |
87 |
88 |
89 |
90 | Sparse-First Search
91 |
92 |
93 |
94 |
95 |
96 | Hybrid
97 |
98 |
99 |
100 |
101 |
106 |
107 |
108 |
109 |
117 |
118 |
119 |
205 |
206 |
207 |
208 |
209 |
--------------------------------------------------------------------------------
/static/preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhyuklee/sparc/bee309bdffd73d162c23a7c3c0b63cebe4aaea97/static/preview.png
--------------------------------------------------------------------------------
/tfidf_doc_ranker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Rank documents with TF-IDF scores"""
8 |
9 | import logging
10 | import numpy as np
11 | import scipy.sparse as sp
12 |
13 | from multiprocessing.pool import ThreadPool
14 | from functools import partial
15 |
16 | import tfidf_util as utils
17 | from simple_tokenizer import SimpleTokenizer
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class TfidfDocRanker(object):
23 | """Loads a pre-weighted inverted index of token/document terms.
24 | Scores new queries by taking sparse dot products.
25 | """
26 |
27 | def __init__(self, tfidf_path=None, strict=True):
28 | """
29 | Args:
30 | tfidf_path: path to saved model file
31 | strict: fail on empty queries or continue (and return empty result)
32 | """
33 | # Load from disk
34 | tfidf_path = tfidf_path
35 | logger.info('Loading %s' % tfidf_path)
36 | matrix, metadata = utils.load_sparse_csr(tfidf_path)
37 | self.doc_mat = matrix
38 | self.ngrams = metadata['ngram']
39 | self.hash_size = metadata['hash_size']
40 | self.tokenizer = SimpleTokenizer()
41 | self.doc_freqs = metadata['doc_freqs'].squeeze()
42 | self.doc_dict = metadata['doc_dict']
43 | self.num_docs = len(self.doc_dict[0])
44 | self.strict = strict
45 |
46 | def get_doc_index(self, doc_id):
47 | """Convert doc_id --> doc_index"""
48 | return self.doc_dict[0][doc_id]
49 |
50 | def get_doc_id(self, doc_index):
51 | """Convert doc_index --> doc_id"""
52 | return self.doc_dict[1][doc_index]
53 |
54 | def closest_docs(self, query, k=1):
55 | """Closest docs by dot product between query and documents
56 | in tfidf weighted word vector space.
57 | """
58 | spvec = self.text2spvec(query)
59 | res = spvec * self.doc_mat
60 |
61 | if len(res.data) <= k:
62 | o_sort = np.argsort(-res.data)
63 | else:
64 | o = np.argpartition(-res.data, k)[0:k]
65 | o_sort = o[np.argsort(-res.data[o])]
66 |
67 | doc_scores = res.data[o_sort]
68 | # doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
69 | doc_ids = [int(i) for i in res.indices[o_sort]] # TODO if none is returned?
70 | return doc_ids, doc_scores
71 |
72 | def batch_closest_docs(self, queries, k=1, num_workers=None):
73 | """Process a batch of closest_docs requests multithreaded.
74 | Note: we can use plain threads here as scipy is outside of the GIL.
75 | """
76 | with ThreadPool(num_workers) as threads:
77 | closest_docs = partial(self.closest_docs, k=k)
78 | results = threads.map(closest_docs, queries)
79 | return results
80 |
81 | def doc_scores(self, query_doc):
82 | """Get doc scores by dot product between query and documents
83 | in tfidf weighted word vector space.
84 | """
85 | query, doc_idx = query_doc
86 | spvec = self.text2spvec(query)
87 | res = spvec * self.doc_mat
88 | scores = res[0,doc_idx].toarray().tolist()[0]
89 | return scores
90 |
91 | def batch_doc_scores(self, queries, doc_idxs, num_workers=None):
92 | """Process a batch of doc_scores requests multithreaded.
93 | Note: we can use plain threads here as scipy is outside of the GIL.
94 | """
95 | with ThreadPool(num_workers) as threads:
96 | results = threads.map(self.doc_scores, zip(queries, doc_idxs))
97 | return results
98 |
99 | def parse(self, query):
100 | """Parse the query into tokens (either ngrams or tokens)."""
101 | tokens = self.tokenizer.tokenize(query)
102 | return tokens.ngrams(n=self.ngrams, uncased=True,
103 | filter_fn=utils.filter_ngram)
104 |
105 | def text2spvec(self, query, val_idx=False):
106 | """Create a sparse tfidf-weighted word vector from query.
107 |
108 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
109 | """
110 | # Get hashed ngrams
111 | words = self.parse(utils.normalize(query))
112 | wids = [utils.hash(w, self.hash_size) for w in words]
113 |
114 | if len(wids) == 0:
115 | if self.strict:
116 | raise RuntimeError('No valid word in: %s' % query)
117 | else:
118 | logger.warning('No valid word in: %s' % query)
119 | if val_idx:
120 | return np.array([]), np.array([])
121 | else:
122 | return sp.csr_matrix((1, self.hash_size))
123 |
124 | # Count TF
125 | wids_unique, wids_counts = np.unique(wids, return_counts=True)
126 | tfs = np.log1p(wids_counts)
127 |
128 | # Count IDF
129 | Ns = self.doc_freqs[wids_unique]
130 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
131 | idfs[idfs < 0] = 0
132 |
133 | # TF-IDF
134 | data = np.multiply(tfs, idfs)
135 |
136 | if val_idx:
137 | return data, wids_unique
138 |
139 | # One row, sparse csr matrix
140 | indptr = np.array([0, len(wids_unique)])
141 | spvec = sp.csr_matrix(
142 | (data, wids_unique, indptr), shape=(1, self.hash_size)
143 | )
144 |
145 | return spvec
146 |
--------------------------------------------------------------------------------
/tfidf_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Various retriever utilities."""
8 |
9 | import regex
10 | import unicodedata
11 | import numpy as np
12 | import scipy.sparse as sp
13 | from sklearn.utils import murmurhash3_32
14 |
15 |
16 | # ------------------------------------------------------------------------------
17 | # Sparse matrix saving/loading helpers.
18 | # ------------------------------------------------------------------------------
19 |
20 |
21 | def save_sparse_csr(filename, matrix, metadata=None):
22 | data = {
23 | 'data': matrix.data,
24 | 'indices': matrix.indices,
25 | 'indptr': matrix.indptr,
26 | 'shape': matrix.shape,
27 | 'metadata': metadata,
28 | }
29 | np.savez(filename, **data)
30 |
31 |
32 | def load_sparse_csr(filename):
33 | loader = np.load(filename, allow_pickle=True)
34 | matrix = sp.csr_matrix((loader['data'], loader['indices'],
35 | loader['indptr']), shape=loader['shape'])
36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None
37 |
38 |
39 | # ------------------------------------------------------------------------------
40 | # Token hashing.
41 | # ------------------------------------------------------------------------------
42 |
43 |
44 | def hash(token, num_buckets):
45 | """Unsigned 32 bit murmurhash for feature hashing."""
46 | return murmurhash3_32(token, positive=True) % num_buckets
47 |
48 |
49 | # ------------------------------------------------------------------------------
50 | # Text cleaning.
51 | # ------------------------------------------------------------------------------
52 |
53 |
54 | STOPWORDS = {
55 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your',
56 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she',
57 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their',
58 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
59 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
60 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an',
61 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of',
62 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through',
63 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down',
64 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then',
65 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
66 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
67 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can',
68 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've',
69 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven',
70 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren',
71 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``"
72 | }
73 |
74 |
75 | def normalize(text):
76 | """Resolve different type of unicode encodings."""
77 | return unicodedata.normalize('NFD', text)
78 |
79 |
80 | def filter_word(text):
81 | """Take out english stopwords, punctuation, and compound endings."""
82 | text = normalize(text)
83 | if regex.match(r'^\p{P}+$', text):
84 | return True
85 | if text.lower() in STOPWORDS:
86 | return True
87 | return False
88 |
89 |
90 | def filter_ngram(gram, mode='any'):
91 | """Decide whether to keep or discard an n-gram.
92 |
93 | Args:
94 | gram: list of tokens (length N)
95 | mode: Option to throw out ngram if
96 | 'any': any single token passes filter_word
97 | 'all': all tokens pass filter_word
98 | 'ends': book-ended by filterable tokens
99 | """
100 | filtered = [filter_word(w) for w in gram]
101 | if mode == 'any':
102 | return any(filtered)
103 | elif mode == 'all':
104 | return all(filtered)
105 | elif mode == 'ends':
106 | return filtered[0] or filtered[-1]
107 | else:
108 | raise ValueError('Invalid mode: %s' % mode)
109 |
110 | def get_field(d, field_list):
111 | """get the subfield associated to a list of elastic fields
112 | E.g. ['file', 'filename'] to d['file']['filename']
113 | """
114 | if isinstance(field_list, str):
115 | return d[field_list]
116 | else:
117 | idx = d.copy()
118 | for field in field_list:
119 | idx = idx[field]
120 | return idx
121 |
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import six
24 |
25 |
26 | def convert_to_unicode(text):
27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
28 | if six.PY3:
29 | if isinstance(text, str):
30 | return text
31 | elif isinstance(text, bytes):
32 | return text.decode("utf-8", "ignore")
33 | else:
34 | raise ValueError("Unsupported string type: %s" % (type(text)))
35 | elif six.PY2:
36 | if isinstance(text, str):
37 | return text.decode("utf-8", "ignore")
38 | elif isinstance(text, unicode):
39 | return text
40 | else:
41 | raise ValueError("Unsupported string type: %s" % (type(text)))
42 | else:
43 | raise ValueError("Not running on Python2 or Python 3?")
44 |
45 |
46 | def printable_text(text):
47 | """Returns text encoded in a way suitable for print or `tf.logging`."""
48 |
49 | # These functions want `str` for both Python2 and Python3, but in one case
50 | # it's a Unicode string and in the other it's a byte string.
51 | if six.PY3:
52 | if isinstance(text, str):
53 | return text
54 | elif isinstance(text, bytes):
55 | return text.decode("utf-8", "ignore")
56 | else:
57 | raise ValueError("Unsupported string type: %s" % (type(text)))
58 | elif six.PY2:
59 | if isinstance(text, str):
60 | return text
61 | elif isinstance(text, unicode):
62 | return text.encode("utf-8")
63 | else:
64 | raise ValueError("Unsupported string type: %s" % (type(text)))
65 | else:
66 | raise ValueError("Not running on Python2 or Python 3?")
67 |
68 |
69 | def load_vocab(vocab_file):
70 | """Loads a vocabulary file into a dictionary."""
71 | vocab = collections.OrderedDict()
72 | index = 0
73 | with open(vocab_file, "r") as reader:
74 | while True:
75 | token = convert_to_unicode(reader.readline())
76 | if not token:
77 | break
78 | token = token.strip()
79 | vocab[token] = index
80 | index += 1
81 | return vocab
82 |
83 |
84 | def convert_tokens_to_ids(vocab, tokens):
85 | """Converts a sequence of tokens into ids using the vocab."""
86 | ids = []
87 | for token in tokens:
88 | ids.append(vocab[token])
89 | return ids
90 |
91 |
92 | def whitespace_tokenize(text):
93 | """Runs basic whitespace cleaning and splitting on a peice of text."""
94 | text = text.strip()
95 | if not text:
96 | return []
97 | tokens = text.split()
98 | return tokens
99 |
100 |
101 | class FullTokenizer(object):
102 | """Runs end-to-end tokenziation."""
103 |
104 | def __init__(self, vocab_file, do_lower_case=True):
105 | self.vocab = load_vocab(vocab_file)
106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
108 |
109 | def tokenize(self, text, basic_done=False):
110 | split_tokens = []
111 | if basic_done:
112 | assert type(text)==list
113 | else:
114 | text = self.basic_tokenizer.tokenize(text)
115 |
116 | for token in text:
117 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
118 | split_tokens.append(sub_token)
119 |
120 | return split_tokens
121 |
122 | def convert_tokens_to_ids(self, tokens):
123 | return convert_tokens_to_ids(self.vocab, tokens)
124 |
125 |
126 | class BasicTokenizer(object):
127 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
128 |
129 | def __init__(self, do_lower_case=True):
130 | """Constructs a BasicTokenizer.
131 |
132 | Args:
133 | do_lower_case: Whether to lower case the input.
134 | """
135 | self.do_lower_case = do_lower_case
136 |
137 | def tokenize(self, text):
138 | """Tokenizes a piece of text."""
139 | text = convert_to_unicode(text)
140 | text = self._clean_text(text)
141 | # This was added on November 1st, 2018 for the multilingual and Chinese
142 | # models. This is also applied to the English models now, but it doesn't
143 | # matter since the English models were not trained on any Chinese data
144 | # and generally don't have any Chinese data in them (there are Chinese
145 | # characters in the vocabulary because Wikipedia does have some Chinese
146 | # words in the English Wikipedia.).
147 | text = self._tokenize_chinese_chars(text)
148 | orig_tokens = whitespace_tokenize(text)
149 | split_tokens = []
150 | for token in orig_tokens:
151 | if self.do_lower_case:
152 | token = token.lower()
153 | token = self._run_strip_accents(token)
154 | split_tokens.extend(self._run_split_on_punc(token))
155 |
156 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
157 | return output_tokens
158 |
159 | def _run_strip_accents(self, text):
160 | """Strips accents from a piece of text."""
161 | text = unicodedata.normalize("NFD", text)
162 | output = []
163 | for char in text:
164 | cat = unicodedata.category(char)
165 | if cat == "Mn":
166 | continue
167 | output.append(char)
168 | return "".join(output)
169 |
170 | def _run_split_on_punc(self, text):
171 | """Splits punctuation on a piece of text."""
172 | chars = list(text)
173 | i = 0
174 | start_new_word = True
175 | output = []
176 | while i < len(chars):
177 | char = chars[i]
178 | if _is_punctuation(char):
179 | output.append([char])
180 | start_new_word = True
181 | else:
182 | if start_new_word:
183 | output.append([])
184 | start_new_word = False
185 | output[-1].append(char)
186 | i += 1
187 |
188 | return ["".join(x) for x in output]
189 |
190 | def _tokenize_chinese_chars(self, text):
191 | """Adds whitespace around any CJK character."""
192 | output = []
193 | for char in text:
194 | cp = ord(char)
195 | if self._is_chinese_char(cp):
196 | output.append(" ")
197 | output.append(char)
198 | output.append(" ")
199 | else:
200 | output.append(char)
201 | return "".join(output)
202 |
203 | def _is_chinese_char(self, cp):
204 | """Checks whether CP is the codepoint of a CJK character."""
205 | # This defines a "chinese character" as anything in the CJK Unicode block:
206 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
207 | #
208 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
209 | # despite its name. The modern Korean Hangul alphabet is a different block,
210 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
211 | # space-separated words, so they are not treated specially and handled
212 | # like the all of the other languages.
213 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
214 | (cp >= 0x3400 and cp <= 0x4DBF) or #
215 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
216 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
217 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
218 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
219 | (cp >= 0xF900 and cp <= 0xFAFF) or #
220 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
221 | return True
222 |
223 | return False
224 |
225 | def _clean_text(self, text):
226 | """Performs invalid character removal and whitespace cleanup on text."""
227 | output = []
228 | for char in text:
229 | cp = ord(char)
230 | if cp == 0 or cp == 0xfffd or _is_control(char):
231 | continue
232 | if _is_whitespace(char):
233 | output.append(" ")
234 | else:
235 | output.append(char)
236 | return "".join(output)
237 |
238 |
239 | class WordpieceTokenizer(object):
240 | """Runs WordPiece tokenization."""
241 |
242 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
243 | self.vocab = vocab
244 | self.unk_token = unk_token
245 | self.max_input_chars_per_word = max_input_chars_per_word
246 |
247 | def tokenize(self, text):
248 | """Tokenizes a piece of text into its word pieces.
249 |
250 | This uses a greedy longest-match-first algorithm to perform tokenization
251 | using the given vocabulary.
252 |
253 | For example:
254 | input = "unaffable"
255 | output = ["un", "##aff", "##able"]
256 |
257 | Args:
258 | text: A single token or whitespace separated tokens. This should have
259 | already been passed through `BasicTokenizer.
260 |
261 | Returns:
262 | A list of wordpiece tokens.
263 | """
264 |
265 | text = convert_to_unicode(text)
266 |
267 | output_tokens = []
268 | for token in whitespace_tokenize(text):
269 | chars = list(token)
270 | if len(chars) > self.max_input_chars_per_word:
271 | output_tokens.append(self.unk_token)
272 | continue
273 |
274 | is_bad = False
275 | start = 0
276 | sub_tokens = []
277 | while start < len(chars):
278 | end = len(chars)
279 | cur_substr = None
280 | while start < end:
281 | substr = "".join(chars[start:end])
282 | if start > 0:
283 | substr = "##" + substr
284 | if substr in self.vocab:
285 | cur_substr = substr
286 | break
287 | end -= 1
288 | if cur_substr is None:
289 | is_bad = True
290 | break
291 | sub_tokens.append(cur_substr)
292 | start = end
293 |
294 | if is_bad:
295 | output_tokens.append(self.unk_token)
296 | else:
297 | output_tokens.extend(sub_tokens)
298 | return output_tokens
299 |
300 |
301 | def _is_whitespace(char):
302 | """Checks whether `chars` is a whitespace character."""
303 | # \t, \n, and \r are technically contorl characters but we treat them
304 | # as whitespace since they are generally considered as such.
305 | if char == " " or char == "\t" or char == "\n" or char == "\r":
306 | return True
307 | cat = unicodedata.category(char)
308 | if cat == "Zs":
309 | return True
310 | return False
311 |
312 |
313 | def _is_control(char):
314 | """Checks whether `chars` is a control character."""
315 | # These are technically control characters but we count them as whitespace
316 | # characters.
317 | if char == "\t" or char == "\n" or char == "\r":
318 | return False
319 | cat = unicodedata.category(char)
320 | if cat.startswith("C"):
321 | return True
322 | return False
323 |
324 |
325 | def _is_punctuation(char):
326 | """Checks whether `chars` is a punctuation character."""
327 | cp = ord(char)
328 | # We treat all non-letter/number ASCII as punctuation.
329 | # Characters such as "^", "$", and "`" are not in the Unicode
330 | # Punctuation class but we treat them as punctuation anyways, for
331 | # consistency.
332 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
333 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
334 | return True
335 | cat = unicodedata.category(char)
336 | if cat.startswith("P"):
337 | return True
338 | return False
339 |
--------------------------------------------------------------------------------
/tokenizer_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Base tokenizer/tokens classes and utilities."""
8 |
9 | import copy
10 |
11 |
12 | class Tokens(object):
13 | """A class to represent a list of tokenized text."""
14 | TEXT = 0
15 | TEXT_WS = 1
16 | SPAN = 2
17 | POS = 3
18 | LEMMA = 4
19 | NER = 5
20 |
21 | def __init__(self, data, annotators, opts=None):
22 | self.data = data
23 | self.annotators = annotators
24 | self.opts = opts or {}
25 |
26 | def __len__(self):
27 | """The number of tokens."""
28 | return len(self.data)
29 |
30 | def slice(self, i=None, j=None):
31 | """Return a view of the list of tokens from [i, j)."""
32 | new_tokens = copy.copy(self)
33 | new_tokens.data = self.data[i: j]
34 | return new_tokens
35 |
36 | def untokenize(self):
37 | """Returns the original text (with whitespace reinserted)."""
38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
39 |
40 | def words(self, uncased=False):
41 | """Returns a list of the text of each token
42 |
43 | Args:
44 | uncased: lower cases text
45 | """
46 | if uncased:
47 | return [t[self.TEXT].lower() for t in self.data]
48 | else:
49 | return [t[self.TEXT] for t in self.data]
50 |
51 | def offsets(self):
52 | """Returns a list of [start, end) character offsets of each token."""
53 | return [t[self.SPAN] for t in self.data]
54 |
55 | def pos(self):
56 | """Returns a list of part-of-speech tags of each token.
57 | Returns None if this annotation was not included.
58 | """
59 | if 'pos' not in self.annotators:
60 | return None
61 | return [t[self.POS] for t in self.data]
62 |
63 | def lemmas(self):
64 | """Returns a list of the lemmatized text of each token.
65 | Returns None if this annotation was not included.
66 | """
67 | if 'lemma' not in self.annotators:
68 | return None
69 | return [t[self.LEMMA] for t in self.data]
70 |
71 | def entities(self):
72 | """Returns a list of named-entity-recognition tags of each token.
73 | Returns None if this annotation was not included.
74 | """
75 | if 'ner' not in self.annotators:
76 | return None
77 | return [t[self.NER] for t in self.data]
78 |
79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
80 | """Returns a list of all ngrams from length 1 to n.
81 |
82 | Args:
83 | n: upper limit of ngram length
84 | uncased: lower cases text
85 | filter_fn: user function that takes in an ngram list and returns
86 | True or False to keep or not keep the ngram
87 | as_string: return the ngram as a string vs list
88 | """
89 | def _skip(gram):
90 | if not filter_fn:
91 | return False
92 | return filter_fn(gram)
93 |
94 | words = self.words(uncased)
95 | ngrams = [(s, e + 1)
96 | for s in range(len(words))
97 | for e in range(s, min(s + n, len(words)))
98 | if not _skip(words[s:e + 1])]
99 |
100 | # Concatenate into strings
101 | if as_strings:
102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
103 |
104 | return ngrams
105 |
106 | def entity_groups(self):
107 | """Group consecutive entity tokens with the same NER tag."""
108 | entities = self.entities()
109 | if not entities:
110 | return None
111 | non_ent = self.opts.get('non_ent', 'O')
112 | groups = []
113 | idx = 0
114 | while idx < len(entities):
115 | ner_tag = entities[idx]
116 | # Check for entity tag
117 | if ner_tag != non_ent:
118 | # Chomp the sequence
119 | start = idx
120 | while (idx < len(entities) and entities[idx] == ner_tag):
121 | idx += 1
122 | groups.append((self.slice(start, idx).untokenize(), ner_tag))
123 | else:
124 | idx += 1
125 | return groups
126 |
127 |
128 | class Tokenizer(object):
129 | """Base tokenizer class.
130 | Tokenizers implement tokenize, which should return a Tokens class.
131 | """
132 | def tokenize(self, text):
133 | raise NotImplementedError
134 |
135 | def shutdown(self):
136 | pass
137 |
138 | def __del__(self):
139 | self.shutdown()
140 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
3 | level=logging.INFO)
4 | logger = logging.getLogger(__name__)
5 |
6 |
7 | def check_diff(model_a, model_b):
8 | a_set = set([a for a in model_a.keys()])
9 | b_set = set([b for b in model_b.keys()])
10 | if a_set != b_set:
11 | logger.info('load with different params =>')
12 | if len(a_set - b_set) > 0:
13 | logger.info('Loaded weight does not have ' + str(a_set - b_set))
14 | if len(b_set - a_set) > 0:
15 | logger.info('Model code does not have: ' + str(b_set - a_set))
16 |
--------------------------------------------------------------------------------