├── LICENSE
├── NOTICE
├── README.md
├── competition
│   ├── Dockerfile
│   ├── build.sh
│   ├── submission.sh
│   └── workspace
│       ├── __init__.py
│       ├── bert_tokenizer
│       │   ├── __init__.py
│       │   ├── file_utils.py
│       │   ├── mobilebert-uncased
│       │   │   ├── tokenizer.json
│       │   │   └── vocab.txt
│       │   ├── tokenization_bert.py
│       │   ├── tokenization_utils.py
│       │   └── tokenization_utils_base.py
│       └── infer.py
└── playground
    ├── build.sh
    ├── configs
    │   ├── Dockerfile-pytorch
    │   ├── Dockerfile-tfserving
    │   ├── requirements-pytorch.txt
    │   ├── requirements-tfserving.txt
    │   └── requirements-tfserving_faiss.txt
    ├── entrypoint.sh
    └── workspace
        ├── minimal_rnr
        │   ├── __init__.py
        │   ├── pytorch
        │   │   ├── __init__.py
        │   │   ├── inferencer.py
        │   │   └── model.py
        │   ├── tfserving
        │   │   ├── __init__.py
        │   │   ├── bert_tokenizer
        │   │   │   ├── __init__.py
        │   │   │   ├── file_utils.py
        │   │   │   ├── mobilebert-uncased
        │   │   │   │   ├── tokenizer.json
        │   │   │   │   └── vocab.txt
        │   │   │   ├── tokenization_bert.py
        │   │   │   ├── tokenization_utils.py
        │   │   │   └── tokenization_utils_base.py
        │   │   └── inferencer.py
        │   └── utils
        │       ├── __init__.py
        │       ├── demo.py
        │       ├── evaluation.py
        │       ├── inference.py
        │       ├── inferencer.py
        │       ├── logger.py
        │       └── static
        │           ├── examples.txt
        │           ├── files
        │           │   ├── all.js
        │           │   ├── bootstrap.min.js
        │           │   ├── icon.png
        │           │   ├── jquery-3.3.1.min.js
        │           │   ├── popper.min.js
        │           │   └── style.css
        │           └── index.html
        ├── run_pt_demo.py
        ├── run_pt_inference.py
        ├── run_tf_demo.py
        └── run_tf_inference.py
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "{}"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2021-present NAVER Corp.
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Designing a Minimal Retrieve-and-Read System for Open-Domain Question Answering
2 |
3 | ## Abstract
4 | In open-domain question answering (QA), retrieve-and-read mechanism has the inherent benefit of interpretability and the easiness of adding, removing, or editing knowledge compared to the parametric approaches of closed-book QA models. However, it is also known to suffer from its large storage footprint due to its document corpus and index. Here, we discuss several orthogonal strategies to drastically reduce the footprint of a retrieve-and-read open-domain QA system by up to 160x. Our results indicate that retrieve-and-read can be a viable option even in a highly constrained serving environment such as edge devices, as we show that it can achieve better accuracy than a purely parametric model with comparable docker-level system size.
5 |
6 | - [Paper](https://arxiv.org/abs/2104.07242) (To appear in [NAACL 2021](https://2021.naacl.org/))
7 | - Authors: [Sohee Yang](https://soheeyang.github.io/) and [Minjoon Seo](http://seominjoon.github.io/)
8 | - [Live Web Demo](http://52.156.155.214:8890)
9 |     - The demo normally takes **1-3 seconds** for inference in the default setting, but may occasionally slow down depending on the server load.
10 | - BibTeX:
11 |
12 | ```
13 | @inproceedings{minimalrnr,
14 | title={Designing a Minimal Retrieve-and-Read System for Open-Domain Question Answering},
15 | author={Yang, Sohee and Seo, Minjoon},
16 | booktitle={NAACL},
17 | year={2021}
18 | }
19 | ```
20 |
21 | This repository contains the code of the **Minimal Retrieve & Read QA System** that ranked first in the human (manual) evaluation and second in the automatic evaluation on the "Systems Under 500Mb" track of the [NeurIPS 2020 EfficientQA competition](https://efficientqa.github.io/).
22 |
23 | ## Web Demo
24 | 
25 |
26 | You can play with the QA system in the [Live Web Demo](http://52.156.155.214:8890). You can also dynamically change the inference setting by controlling the values of `top_k` and `passage_score_weight`. (The demo normally takes **1-3 seconds** for inference in the default setting, but may occasionally slow down depending on the server load.)
27 | - `top_k`: sets the number of passages to retrieve and pass to the reader. The value must be a positive integer and defaults to 50. For the live web demo, the value is limited to the range [1, 100] to keep the server from freezing while reading too many passages.
28 | - `passage_score_weight`:
29 |   - When the default value `null` is used, only the passage with the highest ranking score is used for answer extraction. This is the setting used in DPR.
30 |   - If a float value (λ ∈ [0, 1] is highly recommended) is given, multiple passages are considered when selecting the answer. Specifically, answer spans from multiple passages are scored using a weighted sum of the passage ranking score and the answer span scores, calculated as (1 − λ)(log P_start + log P_end) + 2λ log P_rank. Please refer to the paper for more details.
31 |
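As a rough illustration of this scoring rule, the sketch below combines the span and passage probabilities exactly as in the formula above (the function name and signature are illustrative only; this is not the actual code in `minimal_rnr`):

```python
import math

def combined_span_score(p_start, p_end, p_rank, passage_score_weight=None):
    """Score an answer span extracted from a retrieved passage.

    p_start, p_end: start/end probabilities of the span from the reader.
    p_rank: ranking probability of the passage containing the span.
    passage_score_weight: the lambda above; None reproduces the DPR-style
    setting in which only the span score is used.
    """
    span_score = math.log(p_start) + math.log(p_end)
    if passage_score_weight is None:
        return span_score
    lam = passage_score_weight
    return (1 - lam) * span_score + 2 * lam * math.log(p_rank)
```
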
32 | If you have Docker installed, you can also run the web demo on your local machine in five minutes using [this command](#b-local-web-demo-quickstart).
33 |
34 | ## A. Introduction
35 |
36 | This repository contains the code for an interactive web demo, code for running inference on the questions in a file (and evaluating the predicted answers), links to the model graphs/checkpoints (models), links to the index and preprocessed corpus files (resources), and links to pre-built docker images.
37 |
38 | - `competition` directory contains the code to build and run the minimal-sized docker container used for the EfficientQA competition. Typing `du -h /` in the launched container reports 484.68MB as its size. Please see [**E. Competition Setting: Build & Run**](#e-competition-setting-build--run) for details.
39 | - `playground` directory contains more practical, refactored code to play with: you can either run a web demo or run inference on a file using models built in different settings. Please see [**D. Playground: Build & Run**](#d-playground-build--run) for details.
40 |
41 | ## B. Local Web Demo Quickstart
42 |
43 | To run the web demo on your local machine, run the following using docker:
44 |
45 | ```bash
46 | docker run \
47 | -v /etc/localtime:/etc/localtime:ro \
48 | --oom-kill-disable=true \
49 | --env MODE=demo \
50 | -p 10001:10001 \
51 | soheeyang/minimal-rnr-qa:effqa-tfserving \
52 | /workspace/entrypoint.sh
53 | ```
54 |
55 | Then, access [http://localhost:10001](http://localhost:10001) to play with the demo!
56 |
57 | ## C. Pre-Built Docker Images
58 |
59 | Available in [https://hub.docker.com/r/soheeyang/minimal-rnr-qa](https://hub.docker.com/r/soheeyang/minimal-rnr-qa)
60 | - soheeyang/minimal-rnr-qa:$DATASET-$MODEL_TYPE
61 | - $DATASET: [ effqa | nq | trivia ]
62 | - $MODEL_TYPE: [ tfserving | tfserving_faiss | pytorch ]
63 | - soheeyang/minimal-rnr-qa:$DATASET-competition
64 | - $DATASET: [ effqa | nq | trivia ]
65 | - `soheeyang/minimal-rnr-qa:effqa-competition` is the docker container used for the EfficientQA challenge
66 |
67 | The following are descriptions of the options for `$DATASET` and `$MODEL_TYPE`.
68 | - `$DATASET` is used to select the **model**; the model trained on this dataset is selected. The value must be one of the following.
69 | - `effqa` trained on Natural Questions (NQ) train set, validation done on EfficientQA dev set
70 | - `nq` trained on NQ train set, validation done on NQ dev set
71 | - `trivia` trained on TriviaQA (Trivia) train set, validation done on Trivia dev set
72 | - `$MODEL_TYPE` is used to select the type of the chosen model. The value must be one of the following.
73 |   - `tfserving` TensorFlow (TF) graph for TF Serving. The index is fused into the graph to perform efficient passage retrieval without additional library dependencies. This is the setting used in the EfficientQA competition. CPU serving. Smallest system footprint.
74 |   - `tfserving_faiss` TF graph for TF Serving, but without the index. It installs and uses FAISS to perform passage retrieval. CPU serving.
75 |   - `pytorch` PyTorch checkpoint. It installs and uses FAISS to perform passage retrieval. The model code can be found at `playground/workspace/minimal_rnr/pytorch/model.py`. Supports serving on both CPU & GPU. Largest system footprint.
76 |
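For example, to try one of the pre-built playground images without building anything, pull it and reuse the run command from [B. Local Web Demo Quickstart](#b-local-web-demo-quickstart) (the `nq-pytorch` tag below is just one of the nine possible `$DATASET-$MODEL_TYPE` combinations):

```bash
docker pull soheeyang/minimal-rnr-qa:nq-pytorch

docker run \
    -v /etc/localtime:/etc/localtime:ro \
    --oom-kill-disable=true \
    --env MODE=demo \
    -p 10001:10001 \
    soheeyang/minimal-rnr-qa:nq-pytorch \
    /workspace/entrypoint.sh
```
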
77 | ## D. Playground: Build & Run
78 |
79 | You can skip steps 1 and 2 if you use the [pre-built docker images](#c-pre-built-docker-images).
80 |
81 | ### 1. Download the code and necessary resources
82 |
83 | ```bash
84 | git clone https://github.com/clovaai/minimal-rnr-qa.git
85 | cd minimal-rnr-qa/playground
86 |
87 | wget https://dl.dropboxusercontent.com/s/l7034dttyp4bbf2/minrnr_playground_models.tar.gz
88 | wget https://dl.dropboxusercontent.com/s/51g36ytprbcl3mv/minrnr_playground_resources.tar.gz
89 |
90 | tar xvf minrnr_playground_models.tar.gz
91 | tar xvf minrnr_playground_resources.tar.gz
92 | ```
93 |
94 | ### 2. Build docker image
95 |
96 | ```bash
97 | # inside minimal-rnr-qa/playground
98 |
99 | DATASET=effqa
100 | MODEL_TYPE=tfserving
101 |
102 | chmod a+x ./build.sh
103 | ./build.sh $DATASET $MODEL_TYPE
104 | ```
105 | - This command builds a docker image tagged as `minimal-rnr-qa:$DATASET-$MODEL_TYPE`.
106 |
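If the build succeeds, the new tag should appear in your local image list (shown here with the example values above; the image ID, timestamp, and size will of course differ):

```bash
docker images minimal-rnr-qa
# REPOSITORY        TAG                IMAGE ID       CREATED         SIZE
# minimal-rnr-qa    effqa-tfserving    <image id>     2 minutes ago   ...
```
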
107 | ### 3-1. Run web demo
108 |
109 | ```bash
110 | docker run \
111 | -v /etc/localtime:/etc/localtime:ro \
112 | --oom-kill-disable=true \
113 | --env MODE=demo \
114 | -p 10001:10001 \
115 | minimal-rnr-qa:$DATASET-$MODEL_TYPE \
116 | /workspace/entrypoint.sh
117 | ```
118 |
119 | - `-v /etc/localtime:/etc/localtime:ro` sets the timezone of the container to be the same as the host's.
120 | - `--oom-kill-disable=true` prevents the container from being killed by the OOM killer.
121 | - `--env MODE=demo` [REQUIRED] runs a **web demo**.
122 | - `-p $HOST_PORT:10001` [REQUIRED] maps port 10001 of the container, where the web page is served, to a port on the host.
123 | - `minimal-rnr-qa:$DATASET-$MODEL_TYPE` [REQUIRED] Tag of the built image.
124 | - `/workspace/entrypoint.sh` [REQUIRED] Entrypoint of the container.
125 |
126 | ### 3-2. Run inference on a file
127 |
128 | #### Download input data
129 | The input files for EfficientQA dev set, NQ dev & test set, and Trivia dev & test set can be downloaded at once.
130 | ```bash
131 | INPUT_DIR=/tmp/minimal-rnr-qa
132 | OUTPUT_DIR=/tmp/minimal-rnr-qa
133 |
134 | mkdir -p $INPUT_DIR
135 | mkdir -p $OUTPUT_DIR
136 |
137 | wget -P $INPUT_DIR https://dl.dropboxusercontent.com/s/juh12j1z0ct3zeu/minrnr_datasets.tar.gz
138 | tar xvf $INPUT_DIR/minrnr_datasets.tar.gz -C $INPUT_DIR --strip-components=1
139 | ```
140 |
141 | #### Run inference
142 | ```bash
143 | INPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1.jsonl
144 | OUTPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1-predictions.jsonl
145 | TOP_K=80
146 | PASSAGE_W=0.8
147 |
148 | docker run \
149 | -v /etc/localtime:/etc/localtime:ro \
150 | --oom-kill-disable=true \
151 | -v $INPUT_DIR:/input \
152 | -v $OUTPUT_DIR:/output \
153 | --env MODE=file \
154 | --env TOP_K=$TOP_K \
155 | --env PASSAGE_W=$PASSAGE_W \
156 | minimal-rnr-qa:$DATASET-$MODEL_TYPE \
157 | /workspace/entrypoint.sh \
158 | /input/$INPUT_FILE_NAME \
159 | /output/$OUTPUT_FILE_NAME
160 | ```
161 |
162 | - `-v /etc/localtime:/etc/localtime:ro` sets the timezone of the container to be the same as the host's.
163 | - `--oom-kill-disable=true` prevents the container from being killed by the OOM killer.
164 | - `-v $INPUT_DIR:/input` [REQUIRED] maps `$INPUT_DIR` of the host to `/input` in the container, where the data is read from. This directory must contain the file to run inference on.
165 | - `-v $OUTPUT_DIR:/output` [OPTIONAL] maps `$OUTPUT_DIR` of the host to `/output` in the container, where the prediction result file is written. If not specified, the output prediction file is written only inside the container.
166 | - `--env MODE=file` [REQUIRED] runs **inference on the given input file** and outputs the predictions.
167 | - `--env TOP_K=$INT_VALUE` [OPTIONAL] sets the number of passages to retrieve and pass to the reader. It must be an integer value and defaults to 50.
168 | - `--env PASSAGE_W=$FLOAT_VALUE` [OPTIONAL]
169 |   - If the option is not used (the default) or `null` is given as the value, only the passage with the highest ranking score is used for answer extraction. This is the setting used in DPR.
170 |   - If a value is given, multiple passages are considered when selecting the answer. Specifically, answer spans from multiple passages are scored using a weighted sum of the passage ranking score and the answer span scores. The given value must be λ ∈ [0, 1], and the weighted sum is calculated as (1 − λ)(log P_start + log P_end) + 2λ log P_rank. This value may be tuned on the validation set to slightly raise the end-to-end question answering accuracy.
171 | - `minimal-rnr-qa:$DATASET-$MODEL_TYPE` [REQUIRED] Tag of the built image.
172 | - `/workspace/entrypoint.sh` [REQUIRED] Entrypoint of the container.
173 | - `/input/$INPUT_FILE_NAME` [REQUIRED] Name of the file to run inference on. CSV or JSON Lines files are supported (see the example after this list).
174 |   - CSV files must consist of rows of question strings or rows of `question\t["answer_1", ..., "answer_n"]`.
175 |   - JSON Lines files must consist of rows of `{"question": ...}`, `{"question": ..., "answers": ...}`, or `{"question": ..., "answer": ...}`.
176 |   - If answers exist, the Exact Match (EM) score is calculated and reported at the end of the inference.
177 | - `/output/$OUTPUT_FILE_NAME` [REQUIRED] Name of the output prediction result file. The file is in JSON Lines format. Please note that even if "answer" is given as the key for answers in the input file, it changes to "answers" in the prediction file for consistency and easier evaluation.
178 |
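For instance, each of the following lines would be a valid row of a JSON Lines input file (the question and answer are illustrative only); because the second line carries `answers`, EM would be reported against it after inference:

```
{"question": "who wrote the declaration of independence"}
{"question": "who wrote the declaration of independence", "answers": ["Thomas Jefferson"]}
```
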
179 | ## E. Competition Setting: Build & Run
180 |
181 | You can skip steps 1 and 2 if you use the [pre-built docker images](#c-pre-built-docker-images).
182 |
183 | ### 1. Download the code and necessary resources
184 |
185 | ```bash
186 | git clone https://github.com/clovaai/minimal-rnr-qa.git
187 | cd minimal-rnr-qa/competition
188 |
189 | wget https://dl.dropboxusercontent.com/s/s5fa4rgf48bhhkb/minrnr_competition_resources.tar.gz
190 | wget https://dl.dropboxusercontent.com/s/utwzozvuret1sdo/minrnr_competition_models.tar.gz
191 |
192 | tar xvf minrnr_competition_models.tar.gz
193 | tar xvf minrnr_competition_resources.tar.gz
194 | ```
195 |
196 | ### 2. Build docker image
197 | ```bash
198 | # inside minimal-rnr-qa/competition
199 |
200 | DATASET=effqa
201 |
202 | chmod a+x ./build.sh
203 | ./build.sh $DATASET
204 | ```
205 | - Values for `$DATASET`
206 | - `effqa`: the model used in the challenge (Section 3 in the paper)
207 | - `nq`: trained on Natural Questions (Appendix A.5 in the paper)
208 | - `trivia`: trained on TriviaQA (Appendix A.5 in the paper)
209 | - This command builds a docker image tagged as `minimal-rnr-qa:$DATASET-competition`.
210 |
211 | ### 3. Prepare data (same as the above)
212 |
213 | The input files for EfficientQA dev set, NQ dev & test set, and Trivia dev & test set can be downloaded at once.
214 | ```bash
215 | INPUT_DIR=/tmp/minimal-rnr-qa
216 | OUTPUT_DIR=/tmp/minimal-rnr-qa
217 |
218 | mkdir -p $INPUT_DIR
219 | mkdir -p $OUTPUT_DIR
220 |
221 | wget -P $INPUT_DIR https://dl.dropboxusercontent.com/s/juh12j1z0ct3zeu/minrnr_datasets.tar.gz
222 | tar xvf $INPUT_DIR/minrnr_datasets.tar.gz -C $INPUT_DIR --strip-components=1
223 | ```
224 |
225 | ### 4. Run
226 |
227 | ```bash
228 | # The setting used for EfficientQA submission
229 |
230 | INPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1.jsonl
231 | OUTPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1-predictions.jsonl
232 | TOP_K=80
233 | PASSAGE_W=0.8
234 |
235 | docker run \
236 | -v ${INPUT_DIR}:/input \
237 | -v ${OUTPUT_DIR}:/output \
238 | --env TOP_K=$TOP_K \
239 | --env PASSAGE_W=$PASSAGE_W \
240 | --network="none" \
241 | --oom-kill-disable=true \
242 | minimal-rnr-qa:$DATASET-competition \
243 | /submission.sh \
244 | /input/$INPUT_FILE_NAME \
245 | /output/$OUTPUT_FILE_NAME
246 | ```
247 |
248 | - Below are the parameters to reproduce each of the results of the last row in Table 3 (in the Appendix of the paper).
249 | - EfficientQA dev
250 | - DATASET=effqa / TOP_K=80 / PASSAGE_W=null
251 | - INPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1.jsonl (from [this link](https://github.com/google-research-datasets/natural-questions/blob/master/nq_open/NQ-open.efficientqa.dev.1.1.jsonl))
252 |       - While 34.33 is reported in the paper, the value changed to 34.55 after we rebuilt the TensorFlow graph during refactoring. The model provided here is the latter one.
253 | - NQ dev
254 | - DATASET=nq / TOP_K=100 / PASSAGE_W=null
255 | - INPUT_FILE_NAME=nq-dev.jsonl
256 | - NQ test
257 | - DATASET=nq / TOP_K=90 / PASSAGE_W=null
258 | - INPUT_FILE_NAME=nq-test.jsonl
259 | - Trivia dev
260 | - DATASET=trivia / TOP_K=100 / PASSAGE_W=null
261 | - INPUT_FILE_NAME=trivia-dev.jsonl
262 | - Trivia test
263 | - DATASET=trivia / TOP_K=100 / PASSAGE_W=null
264 | - INPUT_FILE_NAME=trivia-test.jsonl
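
For instance, the NQ test row above can be reproduced with the same command as in step 4, changing only the parameters, assuming `minimal-rnr-qa:nq-competition` was built in step 2 (the output file name is arbitrary):

```bash
DATASET=nq
TOP_K=90
PASSAGE_W=null
INPUT_FILE_NAME=nq-test.jsonl
OUTPUT_FILE_NAME=nq-test-predictions.jsonl

docker run \
    -v ${INPUT_DIR}:/input \
    -v ${OUTPUT_DIR}:/output \
    --env TOP_K=$TOP_K \
    --env PASSAGE_W=$PASSAGE_W \
    --network="none" \
    --oom-kill-disable=true \
    minimal-rnr-qa:$DATASET-competition \
    /submission.sh \
    /input/$INPUT_FILE_NAME \
    /output/$OUTPUT_FILE_NAME
```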
265 |
266 | ## F. License
267 |
268 | ```
269 | Copyright 2021-present NAVER Corp.
270 |
271 | Licensed under the Apache License, Version 2.0 (the "License");
272 | you may not use this file except in compliance with the License.
273 | You may obtain a copy of the License at
274 |
275 | http://www.apache.org/licenses/LICENSE-2.0
276 |
277 | Unless required by applicable law or agreed to in writing, software
278 | distributed under the License is distributed on an "AS IS" BASIS,
279 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
280 | See the License for the specific language governing permissions and
281 | limitations under the License.
282 | ```
283 |
--------------------------------------------------------------------------------
/competition/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/serving:2.3.0 as build_image
2 | FROM python:3.6.11-slim-buster
3 |
4 | ARG DATASET
5 |
6 | # Install TF Serving pkg.
7 | COPY --from=build_image /usr/bin/tensorflow_model_server /usr/bin/tensorflow_model_server
8 |
9 | # Reduce size
10 | RUN apt-get update && apt-get install -y --no-install-recommends lzma bzip2 && apt-get clean && rm -rf /var/lib/apt/lists/*
11 | RUN bzip2 /usr/bin/tensorflow_model_server
12 |
13 | # Install python packages.
14 | RUN pip install --no-cache-dir absl-py requests
15 | RUN pip uninstall pip -y
16 | RUN gzip -r /usr/local/lib/python3.6
17 |
18 | # Delete unnecessary files
19 | RUN rm -rf /root/* && rm /usr/bin/perl* && rm -rf /usr/lib/x86_64-linux-gnu/perl* && rm -rf /var/cache
20 |
21 | COPY models/$DATASET /models/minimal-rnr-qa/1
22 | COPY resources/$DATASET /resources
23 | COPY workspace /workspace
24 | COPY submission.sh /
25 |
26 | RUN chmod a+x /submission.sh
27 |
28 | WORKDIR /
29 |
--------------------------------------------------------------------------------
/competition/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # coding=utf-8
3 | # minimal-rnr-qa
4 | # Copyright 2021-present NAVER Corp.
5 | # Apache License v2.0
6 |
7 | if [ "$#" -ne 1 ]; then
8 | echo "Usage: ./build.sh DATASET"
9 | echo "DATASET: effqa | nq | trivia"
10 |   exit 1
11 | fi
12 |
13 | DATASET=$1 # ["effqa", "nq", "trivia"]
14 |
15 | docker build -f Dockerfile \
16 | --build-arg DATASET=$DATASET \
17 | . -t minimal-rnr-qa:$DATASET-competition
18 |
19 | echo "minimal-rnr-qa:$DATASET"-competition
--------------------------------------------------------------------------------
/competition/submission.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # coding=utf-8
3 | # minimal-rnr-qa
4 | # Copyright 2021-present NAVER Corp.
5 | # Apache License v2.0
6 |
7 | URL="http://127.0.0.1:8501/v1/models/minimal-rnr-qa:predict"
8 |
9 | # decompress tensorflow_model_server bzip2
10 | bzip2 -d /usr/bin/tensorflow_model_server.bz2
11 | chmod a+x /usr/bin/tensorflow_model_server
12 |
13 | # decompress python
14 | gzip -dr /usr/local/lib/python3.6
15 |
16 | # decompress resources lzma
17 | for f in /resources/*.lzma; do lzma -dc $f > "${f%.*}"; done
18 | for f in /resources/*.lzma; do rm $f; done
19 |
20 | # decompress models bzip2
21 | bzip2 -d /models/minimal-rnr-qa/1/variables/variables.data-00000-of-00001.bz2
22 | bzip2 -d /models/minimal-rnr-qa/1/saved_model.pb.bz2
23 |
24 | # server
25 | /usr/bin/tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=minimal-rnr-qa --model_base_path=/models/minimal-rnr-qa &
26 |
27 | # client
28 | cd /workspace
29 | python -u infer.py --url $URL --resources_path /resources --input_path $1 --output_path $2 `if [[ -n "${TOP_K}" ]]; then echo --top_k $TOP_K; fi` `if [[ -n "${PASSAGE_W}" ]]; then echo --passage_score_weight $PASSAGE_W; fi`
30 |
--------------------------------------------------------------------------------
/competition/workspace/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/competition/workspace/__init__.py
--------------------------------------------------------------------------------
/competition/workspace/bert_tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenization_bert import BertTokenizer
2 |
3 | __all__ = ['BertTokenizer']
--------------------------------------------------------------------------------
/competition/workspace/bert_tokenizer/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import logging
8 | import os
9 | from functools import wraps
10 | from typing import Optional
11 |
12 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
13 |
14 | USE_TF = True
15 | USE_TORCH = False
16 | _torch_available = False
17 | _torch_tpu_available = False
18 | _psutil_available = False
19 | _py3nvml_available = False
20 | _has_apex = False
21 |
22 | try:
23 | import tensorflow as tf
24 |
25 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
26 | _tf_available = True # pylint: disable=invalid-name
27 | logger.info("TensorFlow version {} available.".format(tf.__version__))
28 | except (ImportError, AssertionError):
29 | _tf_available = False # pylint: disable=invalid-name
30 |
31 | WEIGHTS_NAME = "pytorch_model.bin"
32 | TF2_WEIGHTS_NAME = "tf_model.h5"
33 | TF_WEIGHTS_NAME = "model.ckpt"
34 | CONFIG_NAME = "config.json"
35 | MODEL_CARD_NAME = "modelcard.json"
36 |
37 |
38 | def is_tf_available():
39 | return _tf_available
40 |
41 |
42 | def cached_path(
43 | url_or_filename,
44 | ) -> Optional[str]:
45 | if os.path.exists(url_or_filename):
46 | output_path = url_or_filename
47 | else:
48 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
49 |
50 | return output_path
51 |
52 |
53 | def tf_required(func):
54 | # Chose a different decorator name than in tests so it's clear they are not the same.
55 | @wraps(func)
56 | def wrapper(*args, **kwargs):
57 | if is_tf_available():
58 | return func(*args, **kwargs)
59 | else:
60 | raise ImportError(f"Method `{func.__name__}` requires TF.")
61 |
62 | return wrapper
--------------------------------------------------------------------------------
/competition/workspace/bert_tokenizer/tokenization_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 |
18 | import collections
19 | import logging
20 | import os
21 | import unicodedata
22 | from typing import List, Optional
23 |
24 | from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25 |
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {
32 | "vocab_file": {
33 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51 | }
52 | }
53 |
54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55 | "bert-base-uncased": 512,
56 | "bert-large-uncased": 512,
57 | "bert-base-cased": 512,
58 | "bert-large-cased": 512,
59 | "bert-base-multilingual-uncased": 512,
60 | "bert-base-multilingual-cased": 512,
61 | "bert-base-chinese": 512,
62 | "bert-base-german-cased": 512,
63 | "bert-large-uncased-whole-word-masking": 512,
64 | "bert-large-cased-whole-word-masking": 512,
65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67 | "bert-base-cased-finetuned-mrpc": 512,
68 | "bert-base-german-dbmdz-cased": 512,
69 | "bert-base-german-dbmdz-uncased": 512,
70 | "TurkuNLP/bert-base-finnish-cased-v1": 512,
71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72 | "wietsedv/bert-base-dutch-cased": 512,
73 | }
74 |
75 | PRETRAINED_INIT_CONFIGURATION = {
76 | "bert-base-uncased": {"do_lower_case": True},
77 | "bert-large-uncased": {"do_lower_case": True},
78 | "bert-base-cased": {"do_lower_case": False},
79 | "bert-large-cased": {"do_lower_case": False},
80 | "bert-base-multilingual-uncased": {"do_lower_case": True},
81 | "bert-base-multilingual-cased": {"do_lower_case": False},
82 | "bert-base-chinese": {"do_lower_case": False},
83 | "bert-base-german-cased": {"do_lower_case": False},
84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94 | }
95 |
96 |
97 | def load_vocab(vocab_file):
98 | """Loads a vocabulary file into a dictionary."""
99 | vocab = collections.OrderedDict()
100 | with open(vocab_file, "r", encoding="utf-8") as reader:
101 | tokens = reader.readlines()
102 | for index, token in enumerate(tokens):
103 | token = token.rstrip("\n")
104 | vocab[token] = index
105 | return vocab
106 |
107 |
108 | def whitespace_tokenize(text):
109 | """Runs basic whitespace cleaning and splitting on a piece of text."""
110 | text = text.strip()
111 | if not text:
112 | return []
113 | tokens = text.split()
114 | return tokens
115 |
116 |
117 | class BertTokenizer(PreTrainedTokenizer):
118 | r"""
119 | Constructs a BERT tokenizer. Based on WordPiece.
120 |
121 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122 | should refer to the superclass for more information regarding methods.
123 |
124 | Args:
125 | vocab_file (:obj:`string`):
126 | File containing the vocabulary.
127 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128 | Whether to lowercase the input when tokenizing.
129 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130 | Whether to do basic tokenization before WordPiece.
131 | never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132 | Collection of tokens which will never be split during tokenization. Only has an effect when
133 | :obj:`do_basic_tokenize=True`
134 | unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136 | token instead.
137 | sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139 | for sequence classification or for a text and a question for question answering.
140 | It is also used as the last token of a sequence built with special tokens.
141 | pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142 | The token used for padding, for example when batching sequences of different lengths.
143 | cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144 | The classifier token which is used when doing sequence classification (classification of the whole
145 | sequence instead of per-token classification). It is the first token of the sequence when built with
146 | special tokens.
147 | mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148 | The token used for masking values. This is the token used when training this model with masked language
149 | modeling. This is the token which the model will try to predict.
150 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151 | Whether to tokenize Chinese characters.
152 | This should likely be deactivated for Japanese:
153 | see: https://github.com/huggingface/transformers/issues/328
154 | """
155 |
156 | vocab_files_names = VOCAB_FILES_NAMES
157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160 |
161 | def __init__(
162 | self,
163 | vocab_file,
164 | do_lower_case=True,
165 | do_basic_tokenize=True,
166 | never_split=None,
167 | unk_token="[UNK]",
168 | sep_token="[SEP]",
169 | pad_token="[PAD]",
170 | cls_token="[CLS]",
171 | mask_token="[MASK]",
172 | tokenize_chinese_chars=True,
173 | **kwargs
174 | ):
175 | super().__init__(
176 | unk_token=unk_token,
177 | sep_token=sep_token,
178 | pad_token=pad_token,
179 | cls_token=cls_token,
180 | mask_token=mask_token,
181 | **kwargs,
182 | )
183 |
184 | if not os.path.isfile(vocab_file):
185 | raise ValueError(
186 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188 | )
189 | self.vocab = load_vocab(vocab_file)
190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191 | self.do_basic_tokenize = do_basic_tokenize
192 | if do_basic_tokenize:
193 | self.basic_tokenizer = BasicTokenizer(
194 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195 | )
196 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197 |
198 | @property
199 | def vocab_size(self):
200 | return len(self.vocab)
201 |
202 | def get_vocab(self):
203 | return dict(self.vocab, **self.added_tokens_encoder)
204 |
205 | def _tokenize(self, text):
206 | split_tokens = []
207 | if self.do_basic_tokenize:
208 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209 |
210 | # If the token is part of the never_split set
211 | if token in self.basic_tokenizer.never_split:
212 | split_tokens.append(token)
213 | else:
214 | split_tokens += self.wordpiece_tokenizer.tokenize(token)
215 | else:
216 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
217 | return split_tokens
218 |
219 | def _convert_token_to_id(self, token):
220 |         """ Converts a token (str) to an id using the vocab. """
221 | return self.vocab.get(token, self.vocab.get(self.unk_token))
222 |
223 | def _convert_id_to_token(self, index):
224 |         """Converts an index (integer) to a token (str) using the vocab."""
225 | return self.ids_to_tokens.get(index, self.unk_token)
226 |
227 | def convert_tokens_to_string(self, tokens):
228 |         """ Converts a sequence of tokens (string) to a single string. """
229 | out_string = " ".join(tokens).replace(" ##", "").strip()
230 | return out_string
231 |
232 | def build_inputs_with_special_tokens(
233 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234 | ) -> List[int]:
235 | """
236 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks
237 | by concatenating and adding special tokens.
238 | A BERT sequence has the following format:
239 |
240 | - single sequence: ``[CLS] X [SEP]``
241 | - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242 |
243 | Args:
244 | token_ids_0 (:obj:`List[int]`):
245 | List of IDs to which the special tokens will be added
246 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247 | Optional second list of IDs for sequence pairs.
248 |
249 | Returns:
250 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251 | """
252 | if token_ids_1 is None:
253 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254 | cls = [self.cls_token_id]
255 | sep = [self.sep_token_id]
256 | return cls + token_ids_0 + sep + token_ids_1 + sep
257 |
258 | def get_special_tokens_mask(
259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260 | ) -> List[int]:
261 | """
262 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263 | special tokens using the tokenizer ``prepare_for_model`` method.
264 |
265 | Args:
266 | token_ids_0 (:obj:`List[int]`):
267 | List of ids.
268 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269 | Optional second list of IDs for sequence pairs.
270 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271 | Set to True if the token list is already formatted with special tokens for the model
272 |
273 | Returns:
274 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275 | """
276 |
277 | if already_has_special_tokens:
278 | if token_ids_1 is not None:
279 | raise ValueError(
280 | "You should not supply a second sequence if the provided sequence of "
281 |                     "ids is already formatted with special tokens for the model."
282 | )
283 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284 |
285 | if token_ids_1 is not None:
286 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287 | return [1] + ([0] * len(token_ids_0)) + [1]
288 |
289 | def create_token_type_ids_from_sequences(
290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291 | ) -> List[int]:
292 | """
293 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294 | A BERT sequence pair mask has the following format:
295 |
296 | ::
297 |
298 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299 | | first sequence | second sequence |
300 |
301 | if token_ids_1 is None, only returns the first portion of the mask (0's).
302 |
303 | Args:
304 | token_ids_0 (:obj:`List[int]`):
305 | List of ids.
306 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307 | Optional second list of IDs for sequence pairs.
308 |
309 | Returns:
310 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311 | sequence(s).
312 | """
313 | sep = [self.sep_token_id]
314 | cls = [self.cls_token_id]
315 | if token_ids_1 is None:
316 | return len(cls + token_ids_0 + sep) * [0]
317 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318 |
319 | def save_vocabulary(self, vocab_path):
320 | """
321 | Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
322 |
323 | Args:
324 | vocab_path (:obj:`str`):
325 | The directory in which to save the vocabulary.
326 |
327 | Returns:
328 | :obj:`Tuple(str)`: Paths to the files saved.
329 | """
330 | index = 0
331 | if os.path.isdir(vocab_path):
332 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333 | else:
334 | vocab_file = vocab_path
335 | with open(vocab_file, "w", encoding="utf-8") as writer:
336 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337 | if index != token_index:
338 | logger.warning(
339 | "Saving vocabulary to {}: vocabulary indices are not consecutive."
340 | " Please check that the vocabulary is not corrupted!".format(vocab_file)
341 | )
342 | index = token_index
343 | writer.write(token + "\n")
344 | index += 1
345 | return (vocab_file,)
346 |
347 |
348 | class BasicTokenizer(object):
349 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350 |
351 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352 | """ Constructs a BasicTokenizer.
353 |
354 | Args:
355 | **do_lower_case**: Whether to lower case the input.
356 | **never_split**: (`optional`) list of str
357 | Kept for backward compatibility purposes.
358 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359 |                 List of tokens not to split.
360 | **tokenize_chinese_chars**: (`optional`) boolean (default True)
361 | Whether to tokenize Chinese characters.
362 | This should likely be deactivated for Japanese:
363 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364 | """
365 | if never_split is None:
366 | never_split = []
367 | self.do_lower_case = do_lower_case
368 | self.never_split = set(never_split)
369 | self.tokenize_chinese_chars = tokenize_chinese_chars
370 |
371 | def tokenize(self, text, never_split=None):
372 | """ Basic Tokenization of a piece of text.
373 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
374 |
375 | Args:
376 | **never_split**: (`optional`) list of str
377 | Kept for backward compatibility purposes.
378 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379 |                 List of tokens not to split.
380 | """
381 | # union() returns a new set by concatenating the two sets.
382 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383 |
384 | # This was added on November 1st, 2018 for the multilingual and Chinese
385 | # models. This is also applied to the English models now, but it doesn't
386 | # matter since the English models were not trained on any Chinese data
387 | # and generally don't have any Chinese data in them (there are Chinese
388 | # characters in the vocabulary because Wikipedia does have some Chinese
389 | # words in the English Wikipedia.).
390 | if self.tokenize_chinese_chars:
391 | text = self._tokenize_chinese_chars(text)
392 | orig_tokens = whitespace_tokenize(text)
393 | split_tokens = []
394 | for token in orig_tokens:
395 | if self.do_lower_case and token not in never_split:
396 | token = token.lower()
397 | token = self._run_strip_accents(token)
398 | split_tokens.extend(self._run_split_on_punc(token, never_split))
399 |
400 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
401 | return output_tokens
402 |
403 | def _run_strip_accents(self, text):
404 | """Strips accents from a piece of text."""
405 | text = unicodedata.normalize("NFD", text)
406 | output = []
407 | for char in text:
408 | cat = unicodedata.category(char)
409 | if cat == "Mn":
410 | continue
411 | output.append(char)
412 | return "".join(output)
413 |
414 | def _run_split_on_punc(self, text, never_split=None):
415 | """Splits punctuation on a piece of text."""
416 | if never_split is not None and text in never_split:
417 | return [text]
418 | chars = list(text)
419 | i = 0
420 | start_new_word = True
421 | output = []
422 | while i < len(chars):
423 | char = chars[i]
424 | if _is_punctuation(char):
425 | output.append([char])
426 | start_new_word = True
427 | else:
428 | if start_new_word:
429 | output.append([])
430 | start_new_word = False
431 | output[-1].append(char)
432 | i += 1
433 |
434 | return ["".join(x) for x in output]
435 |
436 | def _tokenize_chinese_chars(self, text):
437 | """Adds whitespace around any CJK character."""
438 | output = []
439 | for char in text:
440 | cp = ord(char)
441 | if self._is_chinese_char(cp):
442 | output.append(" ")
443 | output.append(char)
444 | output.append(" ")
445 | else:
446 | output.append(char)
447 | return "".join(output)
448 |
449 | def _is_chinese_char(self, cp):
450 | """Checks whether CP is the codepoint of a CJK character."""
451 | # This defines a "chinese character" as anything in the CJK Unicode block:
452 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453 | #
454 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455 | # despite its name. The modern Korean Hangul alphabet is a different block,
456 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457 | # space-separated words, so they are not treated specially and handled
458 |         # like all of the other languages.
459 | if (
460 | (cp >= 0x4E00 and cp <= 0x9FFF)
461 | or (cp >= 0x3400 and cp <= 0x4DBF) #
462 | or (cp >= 0x20000 and cp <= 0x2A6DF) #
463 | or (cp >= 0x2A700 and cp <= 0x2B73F) #
464 | or (cp >= 0x2B740 and cp <= 0x2B81F) #
465 | or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466 | or (cp >= 0xF900 and cp <= 0xFAFF)
467 | or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468 | ): #
469 | return True
470 |
471 | return False
472 |
473 | def _clean_text(self, text):
474 | """Performs invalid character removal and whitespace cleanup on text."""
475 | output = []
476 | for char in text:
477 | cp = ord(char)
478 | if cp == 0 or cp == 0xFFFD or _is_control(char):
479 | continue
480 | if _is_whitespace(char):
481 | output.append(" ")
482 | else:
483 | output.append(char)
484 | return "".join(output)
485 |
486 |
487 | class WordpieceTokenizer(object):
488 | """Runs WordPiece tokenization."""
489 |
490 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491 | self.vocab = vocab
492 | self.unk_token = unk_token
493 | self.max_input_chars_per_word = max_input_chars_per_word
494 |
495 | def tokenize(self, text):
496 | """Tokenizes a piece of text into its word pieces.
497 |
498 | This uses a greedy longest-match-first algorithm to perform tokenization
499 | using the given vocabulary.
500 |
501 | For example:
502 | input = "unaffable"
503 | output = ["un", "##aff", "##able"]
504 |
505 | Args:
506 | text: A single token or whitespace separated tokens. This should have
507 | already been passed through `BasicTokenizer`.
508 |
509 | Returns:
510 | A list of wordpiece tokens.
511 | """
512 |
513 | output_tokens = []
514 | for token in whitespace_tokenize(text):
515 | chars = list(token)
516 | if len(chars) > self.max_input_chars_per_word:
517 | output_tokens.append(self.unk_token)
518 | continue
519 |
520 | is_bad = False
521 | start = 0
522 | sub_tokens = []
523 | while start < len(chars):
524 | end = len(chars)
525 | cur_substr = None
526 | while start < end:
527 | substr = "".join(chars[start:end])
528 | if start > 0:
529 | substr = "##" + substr
530 | if substr in self.vocab:
531 | cur_substr = substr
532 | break
533 | end -= 1
534 | if cur_substr is None:
535 | is_bad = True
536 | break
537 | sub_tokens.append(cur_substr)
538 | start = end
539 |
540 | if is_bad:
541 | output_tokens.append(self.unk_token)
542 | else:
543 | output_tokens.extend(sub_tokens)
544 | return output_tokens
545 |
546 |
547 |
--------------------------------------------------------------------------------
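The WordpieceTokenizer above applies greedy longest-match-first segmentation: for each whitespace-separated token it repeatedly takes the longest prefix found in the vocabulary (prefixing non-initial pieces with "##") and falls back to the unknown token when no prefix matches. A minimal standalone sketch of the same idea, using a hypothetical toy vocabulary rather than the shipped vocab.txt:

# Minimal sketch of greedy longest-match-first WordPiece segmentation.
# The vocabulary below is a toy example, not the repository's vocab.txt.
def wordpiece(token, vocab, unk_token="[UNK]"):
    pieces, start = [], 0
    while start < len(token):
        end = len(token)
        cur = None
        while start < end:  # shrink the candidate substring until it is in the vocab
            sub = token[start:end]
            if start > 0:
                sub = "##" + sub
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:  # no prefix matched: the whole token becomes [UNK]
            return [unk_token]
        pieces.append(cur)
        start = end
    return pieces

toy_vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
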
/competition/workspace/bert_tokenizer/tokenization_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Tokenization classes for python tokenizers.
16 | For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
17 | """
18 |
19 | import itertools
20 | import logging
21 | import re
22 | import unicodedata
23 | from typing import Dict, List, Optional, Tuple, Union
24 |
25 | from .tokenization_utils_base import (
26 | BatchEncoding,
27 | EncodedInput,
28 | EncodedInputPair,
29 | PaddingStrategy,
30 | PreTokenizedInput,
31 | PreTokenizedInputPair,
32 | PreTrainedTokenizerBase,
33 | TensorType,
34 | TextInput,
35 | TextInputPair,
36 | TruncationStrategy,
37 | )
38 |
39 |
40 | logger = logging.getLogger(__name__)
41 |
42 |
43 | def _is_whitespace(char):
44 |     """Checks whether `char` is a whitespace character."""
45 |     # \t, \n, and \r are technically control characters but we treat them
46 | # as whitespace since they are generally considered as such.
47 | if char == " " or char == "\t" or char == "\n" or char == "\r":
48 | return True
49 | cat = unicodedata.category(char)
50 | if cat == "Zs":
51 | return True
52 | return False
53 |
54 |
55 | def _is_control(char):
56 |     """Checks whether `char` is a control character."""
57 | # These are technically control characters but we count them as whitespace
58 | # characters.
59 | if char == "\t" or char == "\n" or char == "\r":
60 | return False
61 | cat = unicodedata.category(char)
62 | if cat.startswith("C"):
63 | return True
64 | return False
65 |
66 |
67 | def _is_punctuation(char):
68 |     """Checks whether `char` is a punctuation character."""
69 | cp = ord(char)
70 | # We treat all non-letter/number ASCII as punctuation.
71 | # Characters such as "^", "$", and "`" are not in the Unicode
72 | # Punctuation class but we treat them as punctuation anyways, for
73 | # consistency.
74 | if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
75 | return True
76 | cat = unicodedata.category(char)
77 | if cat.startswith("P"):
78 | return True
79 | return False
80 |
81 |
82 | def _is_end_of_word(text):
83 |     """Checks whether the last character in text is a punctuation, control, or whitespace character."""
84 | last_char = text[-1]
85 | return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
86 |
87 |
88 | def _is_start_of_word(text):
89 |     """Checks whether the first character in text is a punctuation, control, or whitespace character."""
90 | first_char = text[0]
91 | return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
92 |
93 |
94 | class PreTrainedTokenizer(PreTrainedTokenizerBase):
95 | """ Base class for all slow tokenizers.
96 |
97 |     Handles all the shared methods for tokenization and special tokens, as well as methods for
98 |     downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
99 | 
100 |     This class also contains the added tokens in a unified way on top of all tokenizers, so we don't
101 |     have to handle the specific vocabulary augmentation methods of the various underlying
102 |     dictionary structures (BPE, sentencepiece...).
103 |
104 | Class attributes (overridden by derived classes):
105 |
106 | - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
107 | required by the model, and as associated values, the filename for saving the associated file (string).
108 | - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
109 | being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
110 | `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
111 | associated pretrained vocabulary file.
112 | - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
113 | models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
114 | model has no maximum input size.
115 | - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
116 |         pretrained models, and as associated values, a dictionary of specific arguments to pass to the
117 |         ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the
118 | ``from_pretrained()`` method.
119 |
120 | Args:
121 | - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
122 | When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
123 |             model in ``max_model_input_sizes`` (see above). If no value is provided, will default to
124 |             VERY_LARGE_INTEGER (`int(1e30)`) when no associated max_length can be found in ``max_model_input_sizes``.
125 | - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
126 | Should be selected between ['right', 'left']
127 | - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
128 | model ("token_type_ids", "attention_mask"...).
129 | - ``bos_token``: (`Optional`) string: a beginning of sentence token.
130 | Will be associated to ``self.bos_token`` and ``self.bos_token_id``
131 | - ``eos_token``: (`Optional`) string: an end of sentence token.
132 | Will be associated to ``self.eos_token`` and ``self.eos_token_id``
133 | - ``unk_token``: (`Optional`) string: an unknown token.
134 | Will be associated to ``self.unk_token`` and ``self.unk_token_id``
135 | - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
136 | Will be associated to ``self.sep_token`` and ``self.sep_token_id``
137 | - ``pad_token``: (`Optional`) string: a padding token.
138 | Will be associated to ``self.pad_token`` and ``self.pad_token_id``
139 | - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
140 | leveraging self-attention along the full depth of the model).
141 | Will be associated to ``self.cls_token`` and ``self.cls_token_id``
142 | - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
143 | modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
144 | - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
145 |             Adding all special tokens here ensures they won't be split by the tokenization process.
146 | Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
147 |
148 |
149 | .. automethod:: __call__
150 | """
151 |
152 | def __init__(self, **kwargs):
153 | super().__init__(**kwargs)
154 |
155 | # Added tokens - We store this for both slow and fast tokenizers
156 | # until the serialization of Fast tokenizers is updated
157 | self.added_tokens_encoder: Dict[str, int] = {}
158 | self.added_tokens_decoder: Dict[int, str] = {}
159 | self.unique_no_split_tokens: List[str] = []
160 |
161 | @property
162 | def is_fast(self) -> bool:
163 | return False
164 |
165 | @property
166 | def vocab_size(self) -> int:
167 | """ Size of the base vocabulary (without the added tokens) """
168 | raise NotImplementedError
169 |
170 | def get_vocab(self):
171 | """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
172 | raise NotImplementedError()
173 |
174 | def get_added_vocab(self) -> Dict[str, int]:
175 | return self.added_tokens_encoder
176 |
177 | def __len__(self):
178 | """ Size of the full vocabulary with the added tokens """
179 | return self.vocab_size + len(self.added_tokens_encoder)
180 |
181 | def _add_tokens(self, new_tokens, special_tokens=False) -> int:
182 | """
183 | Add a list of new tokens to the tokenizer class. If the new tokens are not in the
184 |         vocabulary, they are added to it with indices starting from the length of the current vocabulary.
185 |
186 | Args:
187 | new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
188 |                 already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
189 |
190 | Returns:
191 | Number of tokens added to the vocabulary.
192 |
193 | Examples::
194 |
195 | # Let's see how to increase the vocabulary of Bert model and tokenizer
196 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
197 | model = BertModel.from_pretrained('bert-base-uncased')
198 |
199 | num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
200 | print('We have added', num_added_toks, 'tokens')
201 |             model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
202 | """
203 | new_tokens = [str(tok) for tok in new_tokens]
204 |
205 | tokens_to_add = []
206 | for token in new_tokens:
207 | assert isinstance(token, str)
208 | if not special_tokens and self.init_kwargs.get("do_lower_case", False):
209 | token = token.lower()
210 | if (
211 | token != self.unk_token
212 | and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
213 | and token not in tokens_to_add
214 | ):
215 | tokens_to_add.append(token)
216 | if self.verbose:
217 | logger.info("Adding %s to the vocabulary", token)
218 |
219 | added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
220 | added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
221 | self.added_tokens_encoder.update(added_tok_encoder)
222 | self.added_tokens_decoder.update(added_tok_decoder)
223 |
224 |         # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
225 | if special_tokens:
226 | self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
227 | else:
228 | # Or on the newly added tokens
229 | self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
230 |
231 | return len(tokens_to_add)
232 |
233 | def num_special_tokens_to_add(self, pair=False):
234 | """
235 | Returns the number of added tokens when encoding a sequence with special tokens.
236 |
237 | Note:
238 | This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
239 | inside your training loop.
240 |
241 | Args:
242 |             pair: If set to True, returns the number of added tokens for a sequence pair; if set to False,
243 |                 returns the number of added tokens for a single sequence.
244 |
245 | Returns:
246 | Number of tokens added to sequences
247 | """
248 | token_ids_0 = []
249 | token_ids_1 = []
250 | return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
251 |
252 | def tokenize(self, text: TextInput, **kwargs):
253 |         """ Converts a string into a sequence of tokens (strings), using the tokenizer.
254 |             Splits into words for word-based vocabularies or into sub-words for sub-word-based
255 |             vocabularies (BPE/SentencePieces/WordPieces).
256 | 
257 |             Takes care of added tokens.
258 |
259 | Args:
260 | text (:obj:`string`): The sequence to be encoded.
261 | **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
262 | """
263 | # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
264 | all_special_tokens_extended = dict()
265 |
266 | text, kwargs = self.prepare_for_tokenization(text, **kwargs)
267 |
268 | if kwargs:
269 | logger.warning(f"Keyword arguments {kwargs} not recognized.")
270 |
271 | # TODO: should this be in the base class?
272 | if self.init_kwargs.get("do_lower_case", False):
273 | # convert non-special tokens to lowercase
274 | escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
275 | pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
276 | text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
277 |
278 | def split_on_token(tok, text):
279 | result = []
280 | tok_extended = all_special_tokens_extended.get(tok, None)
281 | split_text = text.split(tok)
282 | full_word = ""
283 | for i, sub_text in enumerate(split_text):
284 | # AddedToken can control whitespace stripping around them.
285 | # We use them for GPT2 and Roberta to have different behavior depending on the special token
286 | # Cf. https://github.com/huggingface/transformers/pull/2778
287 | # and https://github.com/huggingface/transformers/issues/3788
288 | # We strip left and right by default
289 | if i < len(split_text) - 1:
290 | sub_text = sub_text.rstrip()
291 | if i > 0:
292 | sub_text = sub_text.lstrip()
293 |
294 | if i == 0 and not sub_text:
295 | result += [tok]
296 | elif i == len(split_text) - 1:
297 | if sub_text:
298 | result += [sub_text]
299 | else:
300 | pass
301 | else:
302 | if sub_text:
303 | result += [sub_text]
304 | result += [tok]
305 | return result
306 |
307 | def split_on_tokens(tok_list, text):
308 | if not text.strip():
309 | return []
310 | if not tok_list:
311 | return self._tokenize(text)
312 |
313 | tokenized_text = []
314 | text_list = [text]
315 | for tok in tok_list:
316 | tokenized_text = []
317 | for sub_text in text_list:
318 | if sub_text not in self.unique_no_split_tokens:
319 | tokenized_text += split_on_token(tok, sub_text)
320 | else:
321 | tokenized_text += [sub_text]
322 | text_list = tokenized_text
323 |
324 | return list(
325 | itertools.chain.from_iterable(
326 | (
327 | self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
328 | for token in tokenized_text
329 | )
330 | )
331 | )
332 |
333 | no_split_token = self.unique_no_split_tokens
334 | tokenized_text = split_on_tokens(no_split_token, text)
335 | return tokenized_text
336 |
337 | def _tokenize(self, text, **kwargs):
338 |         """ Converts a string into a sequence of tokens (strings), using the tokenizer.
339 |             Splits into words for word-based vocabularies or into sub-words for sub-word-based
340 |             vocabularies (BPE/SentencePieces/WordPieces).
341 | 
342 |             Does NOT take care of added tokens.
343 | """
344 | raise NotImplementedError
345 |
346 | def convert_tokens_to_ids(self, tokens):
347 |         """ Converts a token string (or a sequence of tokens) into a single integer id
348 |             (or a sequence of ids), using the vocabulary.
349 | """
350 | if tokens is None:
351 | return None
352 |
353 | if isinstance(tokens, str):
354 | return self._convert_token_to_id_with_added_voc(tokens)
355 |
356 | ids = []
357 | for token in tokens:
358 | ids.append(self._convert_token_to_id_with_added_voc(token))
359 | return ids
360 |
361 | def _convert_token_to_id_with_added_voc(self, token):
362 | if token is None:
363 | return None
364 |
365 | if token in self.added_tokens_encoder:
366 | return self.added_tokens_encoder[token]
367 | return self._convert_token_to_id(token)
368 |
369 | def _convert_token_to_id(self, token):
370 | raise NotImplementedError
371 |
372 | def _encode_plus(
373 | self,
374 | text: Union[TextInput, PreTokenizedInput, EncodedInput],
375 | text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
376 | add_special_tokens: bool = True,
377 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
378 | truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
379 | max_length: Optional[int] = None,
380 | stride: int = 0,
381 | is_pretokenized: bool = False,
382 | pad_to_multiple_of: Optional[int] = None,
383 | return_tensors: Optional[Union[str, TensorType]] = None,
384 | return_token_type_ids: Optional[bool] = None,
385 | return_attention_mask: Optional[bool] = None,
386 | return_overflowing_tokens: bool = False,
387 | return_special_tokens_mask: bool = False,
388 | return_offsets_mapping: bool = False,
389 | return_length: bool = False,
390 | verbose: bool = True,
391 | **kwargs
392 | ) -> BatchEncoding:
393 | def get_input_ids(text):
394 | if isinstance(text, str):
395 | tokens = self.tokenize(text, **kwargs)
396 | return self.convert_tokens_to_ids(tokens)
397 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
398 | if is_pretokenized:
399 | tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
400 | return self.convert_tokens_to_ids(tokens)
401 | else:
402 | return self.convert_tokens_to_ids(text)
403 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
404 | return text
405 | else:
406 | if is_pretokenized:
407 | raise ValueError(
408 | f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
409 | )
410 | else:
411 | raise ValueError(
412 | f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
413 | )
414 |
415 | if return_offsets_mapping:
416 | raise NotImplementedError(
417 | "return_offset_mapping is not available when using Python tokenizers."
418 | "To use this feature, change your tokenizer to one deriving from "
419 | "transformers.PreTrainedTokenizerFast."
420 | "More information on available tokenizers at "
421 | "https://github.com/huggingface/transformers/pull/2674"
422 | )
423 |
424 | first_ids = get_input_ids(text)
425 | second_ids = get_input_ids(text_pair) if text_pair is not None else None
426 |
427 | return self.prepare_for_model(
428 | first_ids,
429 | pair_ids=second_ids,
430 | add_special_tokens=add_special_tokens,
431 | padding=padding_strategy.value,
432 | truncation=truncation_strategy.value,
433 | max_length=max_length,
434 | stride=stride,
435 | pad_to_multiple_of=pad_to_multiple_of,
436 | return_tensors=return_tensors,
437 | prepend_batch_axis=True,
438 | return_attention_mask=return_attention_mask,
439 | return_token_type_ids=return_token_type_ids,
440 | return_overflowing_tokens=return_overflowing_tokens,
441 | return_special_tokens_mask=return_special_tokens_mask,
442 | return_length=return_length,
443 | verbose=verbose,
444 | )
445 |
446 | def _batch_encode_plus(
447 | self,
448 | batch_text_or_text_pairs: Union[
449 | List[TextInput],
450 | List[TextInputPair],
451 | List[PreTokenizedInput],
452 | List[PreTokenizedInputPair],
453 | List[EncodedInput],
454 | List[EncodedInputPair],
455 | ],
456 | add_special_tokens: bool = True,
457 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
458 | truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
459 | max_length: Optional[int] = None,
460 | stride: int = 0,
461 | is_pretokenized: bool = False,
462 | pad_to_multiple_of: Optional[int] = None,
463 | return_tensors: Optional[Union[str, TensorType]] = None,
464 | return_token_type_ids: Optional[bool] = None,
465 | return_attention_mask: Optional[bool] = None,
466 | return_overflowing_tokens: bool = False,
467 | return_special_tokens_mask: bool = False,
468 | return_offsets_mapping: bool = False,
469 | return_length: bool = False,
470 | verbose: bool = True,
471 | **kwargs
472 | ) -> BatchEncoding:
473 | def get_input_ids(text):
474 | if isinstance(text, str):
475 | tokens = self.tokenize(text, **kwargs)
476 | return self.convert_tokens_to_ids(tokens)
477 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
478 | if is_pretokenized:
479 | tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
480 | return self.convert_tokens_to_ids(tokens)
481 | else:
482 | return self.convert_tokens_to_ids(text)
483 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
484 | return text
485 | else:
486 | raise ValueError(
487 | "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
488 | )
489 |
490 | if return_offsets_mapping:
491 | raise NotImplementedError(
492 | "return_offset_mapping is not available when using Python tokenizers."
493 | "To use this feature, change your tokenizer to one deriving from "
494 | "transformers.PreTrainedTokenizerFast."
495 | )
496 |
497 | input_ids = []
498 | for ids_or_pair_ids in batch_text_or_text_pairs:
499 | if not isinstance(ids_or_pair_ids, (list, tuple)):
500 | ids, pair_ids = ids_or_pair_ids, None
501 | elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
502 | ids, pair_ids = ids_or_pair_ids, None
503 | else:
504 | ids, pair_ids = ids_or_pair_ids
505 |
506 | first_ids = get_input_ids(ids)
507 | second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
508 | input_ids.append((first_ids, second_ids))
509 |
510 | batch_outputs = self._batch_prepare_for_model(
511 | input_ids,
512 | add_special_tokens=add_special_tokens,
513 | padding_strategy=padding_strategy,
514 | truncation_strategy=truncation_strategy,
515 | max_length=max_length,
516 | stride=stride,
517 | pad_to_multiple_of=pad_to_multiple_of,
518 | return_attention_mask=return_attention_mask,
519 | return_token_type_ids=return_token_type_ids,
520 | return_overflowing_tokens=return_overflowing_tokens,
521 | return_special_tokens_mask=return_special_tokens_mask,
522 | return_length=return_length,
523 | return_tensors=return_tensors,
524 | verbose=verbose,
525 | )
526 |
527 | return BatchEncoding(batch_outputs)
528 |
529 | def _batch_prepare_for_model(
530 | self,
531 | batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
532 | add_special_tokens: bool = True,
533 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
534 | truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
535 | max_length: Optional[int] = None,
536 | stride: int = 0,
537 | pad_to_multiple_of: Optional[int] = None,
538 | return_tensors: Optional[str] = None,
539 | return_token_type_ids: Optional[bool] = None,
540 | return_attention_mask: Optional[bool] = None,
541 | return_overflowing_tokens: bool = False,
542 | return_special_tokens_mask: bool = False,
543 | return_length: bool = False,
544 | verbose: bool = True,
545 | ) -> BatchEncoding:
546 |         """ Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
547 |         It adds special tokens, truncates sequences if they overflow while taking the special tokens into account, and
548 |         manages a moving window (with a user-defined stride) for overflowing tokens.
549 |
550 | Args:
551 | batch_ids_pairs: list of tokenized input ids or input ids pairs
552 | """
553 |
554 | batch_outputs = {}
555 | for first_ids, second_ids in batch_ids_pairs:
556 | outputs = self.prepare_for_model(
557 | first_ids,
558 | second_ids,
559 | add_special_tokens=add_special_tokens,
560 | padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
561 | truncation=truncation_strategy.value,
562 | max_length=max_length,
563 | stride=stride,
564 | pad_to_multiple_of=None, # we pad in batch afterward
565 | return_attention_mask=False, # we pad in batch afterward
566 | return_token_type_ids=return_token_type_ids,
567 | return_overflowing_tokens=return_overflowing_tokens,
568 | return_special_tokens_mask=return_special_tokens_mask,
569 | return_length=return_length,
570 | return_tensors=None, # We convert the whole batch to tensors at the end
571 | prepend_batch_axis=False,
572 | verbose=verbose,
573 | )
574 |
575 | for key, value in outputs.items():
576 | if key not in batch_outputs:
577 | batch_outputs[key] = []
578 | batch_outputs[key].append(value)
579 |
580 | batch_outputs = self.pad(
581 | batch_outputs,
582 | padding=padding_strategy.value,
583 | max_length=max_length,
584 | pad_to_multiple_of=pad_to_multiple_of,
585 | return_attention_mask=return_attention_mask,
586 | )
587 |
588 | batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
589 |
590 | return batch_outputs
591 |
592 | def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
593 | """ Performs any necessary transformations before tokenization.
594 |
595 | This method should pop the arguments from kwargs and return kwargs as well.
596 | We test kwargs at the end of the encoding process to be sure all the arguments have been used.
597 | """
598 | return (text, kwargs)
599 |
600 | def get_special_tokens_mask(
601 | self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
602 | ) -> List[int]:
603 | """
604 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
605 | special tokens using the tokenizer ``prepare_for_model`` method.
606 |
607 | Args:
608 | token_ids_0: list of ids (must not contain special tokens)
609 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
610 | for sequence pairs
611 |             already_has_special_tokens: (default False) Set to True if the token list is already formatted with
612 | special tokens for the model
613 |
614 | Returns:
615 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
616 | """
617 | return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
618 |
619 | def convert_ids_to_tokens(
620 | self, ids: Union[int, List[int]], skip_special_tokens: bool = False
621 | ) -> Union[str, List[str]]:
622 |         """ Converts a single index or a sequence of indices (integers) to a token (str),
623 |             or respectively a sequence of tokens, using the vocabulary and added tokens.
624 |
625 | Args:
626 | skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
627 | """
628 | if isinstance(ids, int):
629 | if ids in self.added_tokens_decoder:
630 | return self.added_tokens_decoder[ids]
631 | else:
632 | return self._convert_id_to_token(ids)
633 | tokens = []
634 | for index in ids:
635 | index = int(index)
636 | if skip_special_tokens and index in self.all_special_ids:
637 | continue
638 | if index in self.added_tokens_decoder:
639 | tokens.append(self.added_tokens_decoder[index])
640 | else:
641 | tokens.append(self._convert_id_to_token(index))
642 | return tokens
643 |
644 | def _convert_id_to_token(self, index: int) -> str:
645 | raise NotImplementedError
646 |
647 | def convert_tokens_to_string(self, tokens: List[str]) -> str:
648 |         """ Converts a sequence of tokens (strings) into a single string.
649 |             The simplest way to do it is ``" ".join(tokens)``,
650 |             but we often want to remove sub-word tokenization artifacts at the same time.
651 | """
652 | return " ".join(self.convert_ids_to_tokens(tokens))
653 |
654 | def decode(
655 | self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
656 | ) -> str:
657 | filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
658 |
659 |         # To avoid mixing byte-level and unicode for byte-level BPE
660 |         # we need to build the string separately for added tokens and byte-level tokens
661 | # cf. https://github.com/huggingface/transformers/issues/1133
662 | sub_texts = []
663 | current_sub_text = []
664 | for token in filtered_tokens:
665 |             if skip_special_tokens and token in self.all_special_tokens:
666 | continue
667 | if token in self.added_tokens_encoder:
668 | if current_sub_text:
669 | sub_texts.append(self.convert_tokens_to_string(current_sub_text))
670 | current_sub_text = []
671 | sub_texts.append(token)
672 | else:
673 | current_sub_text.append(token)
674 | if current_sub_text:
675 | sub_texts.append(self.convert_tokens_to_string(current_sub_text))
676 | text = " ".join(sub_texts)
677 |
678 | if clean_up_tokenization_spaces:
679 | clean_text = self.clean_up_tokenization(text)
680 | return clean_text
681 | else:
682 | return text
683 |
684 | def save_vocabulary(self, save_directory) -> Tuple[str]:
685 | """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
686 | and special token mappings.
687 |
688 |             Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` to save the full
689 | Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
690 | class method.
691 | """
692 | raise NotImplementedError
--------------------------------------------------------------------------------
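PreTrainedTokenizer.tokenize above first isolates the added/special tokens listed in unique_no_split_tokens and only runs the model-specific _tokenize on the remaining text fragments. A simplified, self-contained sketch of that splitting step; the whitespace-splitting stand-in for _tokenize and the example token list are illustrative assumptions, not the library API:

# Simplified sketch of how text is split around "no-split" tokens before the
# model-specific _tokenize is applied; plain whitespace splitting stands in
# for _tokenize here, for illustration only.
def split_on_no_split_tokens(text, no_split_tokens, _tokenize=str.split):
    fragments = [text]
    for tok in no_split_tokens:
        new_fragments = []
        for frag in fragments:
            if frag in no_split_tokens:
                new_fragments.append(frag)
                continue
            parts = frag.split(tok)
            for i, part in enumerate(parts):
                part = part.strip()
                if part:
                    new_fragments.append(part)
                if i < len(parts) - 1:  # re-insert the special token between parts
                    new_fragments.append(tok)
        fragments = new_fragments
    tokens = []
    for frag in fragments:
        tokens += [frag] if frag in no_split_tokens else _tokenize(frag)
    return tokens

print(split_on_no_split_tokens("what is dpr [SEP] dense passage retrieval", ["[SEP]"]))
# ['what', 'is', 'dpr', '[SEP]', 'dense', 'passage', 'retrieval']
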
/competition/workspace/infer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # minimal-rnr-qa
3 | # Copyright 2021-present NAVER Corp.
4 | # Apache License v2.0
5 |
6 | import os
7 | import glob
8 | import json
9 | import requests
10 | from argparse import ArgumentParser
11 | from timeit import default_timer as timer
12 |
13 | from bert_tokenizer import BertTokenizer
14 |
15 |
16 | def read_txt(file_path):
17 | data = []
18 | with open(file_path) as f:
19 | for line in f:
20 | ids = [int(x) for x in line.rstrip().split(" ")]
21 | data.append(ids)
22 | return data
23 |
24 |
25 | def get_api_input(signature_name, input_):
26 | if "token_type_ids" in input_:
27 | del input_["token_type_ids"]
28 |
29 | if type(input_) != dict:
30 | input_ = dict(input_)
31 |
32 | return json.dumps({
33 | "signature_name": signature_name,
34 | "inputs": input_,
35 | })
36 |
37 |
38 | def get_output(url, data):
39 | return requests.post(url, data).json()["outputs"]
40 |
41 |
42 | def get_score_input(token_logits, relevance_logits, attn_mask, passage_score_weight):
43 |     input_ = {
44 |         "token_logits": token_logits,
45 |         "relevance_logits": relevance_logits,
46 | "attn_mask": attn_mask,
47 | "passage_score_weight": passage_score_weight,
48 | }
49 | return json.dumps({
50 | "signature_name": "get_score",
51 | "inputs": input_,
52 | })
53 |
54 |
55 | def get_retriever_input(question_str, tokenizer, top_k):
56 | input_ = tokenizer([question_str], max_length=256, truncation=True)
57 | input_.update({"top_k": top_k})
58 | return get_api_input("retrieve", input_)
59 |
60 |
61 | def get_reader_input(question_str, tokenizer, retriever_output):
62 | input_ids = []
63 | attention_mask = []
64 |
65 | retrieved_titles = retriever_output["titles"]
66 | retrieved_docs = retriever_output["docs"]
67 |
68 | question = tokenizer.encode(question_str, max_length=256, truncation=True)
69 | for title, doc in zip(retrieved_titles, retrieved_docs):
70 | concat = question + title + [tokenizer.sep_token_id] + doc
71 | concat = concat[:350]
72 | input_ids.append(concat)
73 | max_len = max(len(ids) for ids in input_ids)
74 | for i in range(len(input_ids)):
75 | padding = [0] * (max_len - len(input_ids[i]))
76 | attention_mask.append([1] * len(input_ids[i]) + padding)
77 | input_ids[i] = input_ids[i] + padding
78 | return {
79 | "input_ids": input_ids,
80 | "attention_mask": attention_mask,
81 | }
82 |
83 |
84 | def is_sub_word_id(tokenizer, token_id):
85 | token = tokenizer.convert_ids_to_tokens([token_id])[0]
86 | return token.startswith("##") or token.startswith(" ##")
87 |
88 |
89 | def extend_span_to_full_words(tokenizer, tokens, span):
90 | start_index, end_index = span
91 | max_len = len(tokens)
92 | while start_index > 0 and is_sub_word_id(tokenizer, tokens[start_index]):
93 | start_index -= 1
94 |
95 | while end_index < max_len - 1 and is_sub_word_id(tokenizer, tokens[end_index + 1]):
96 | end_index += 1
97 |
98 | return start_index, end_index
99 |
100 |
101 | def get_answer_greedy(tokenizer, reader_input, reader_output):
102 | max_answer_length = 10
103 | input_ids = reader_input["input_ids"]
104 | relevance_logits = reader_output["relevance_logits"]
105 | if not isinstance(relevance_logits, (tuple, list)):
106 | relevance_logits = [relevance_logits]
107 |
108 | top_doc_idx = max(enumerate(relevance_logits), key=lambda x: x[1])[0]
109 |
110 | sequence_len = sum(id_ != 0 for id_ in input_ids[top_doc_idx])
111 | passage_offset = input_ids[top_doc_idx].index(tokenizer.sep_token_id) + 1
112 | ctx_ids = input_ids[top_doc_idx][passage_offset:sequence_len]
113 | p_start_logits = reader_output["start_logits"][top_doc_idx][passage_offset:sequence_len]
114 | p_end_logits = reader_output["end_logits"][top_doc_idx][passage_offset:sequence_len]
115 |
116 | scores = []
117 | for (i, s) in enumerate(p_start_logits):
118 | for (j, e) in enumerate(p_end_logits[i:i + max_answer_length]):
119 | scores.append(((i, i + j), s + e))
120 | scores = sorted(scores, key=lambda x: x[1], reverse=True)
121 |
122 | chosen_span_intervals = []
123 |
124 | answer = ""
125 | for (start_index, end_index), score in scores:
126 | assert start_index <= end_index
127 | length = end_index - start_index + 1
128 | assert length <= max_answer_length
129 |
130 | if any([start_index <= prev_start_index <= prev_end_index <= end_index or
131 | prev_start_index <= start_index <= end_index <= prev_end_index
132 | for (prev_start_index, prev_end_index) in chosen_span_intervals]):
133 | continue
134 |
135 | start_index, end_index = extend_span_to_full_words(tokenizer, ctx_ids, (start_index, end_index))
136 | answer = tokenizer.decode(ctx_ids[start_index:end_index + 1], skip_special_tokens=True)
137 | break
138 | return answer
139 |
140 |
141 | def get_answer_deep(tokenizer, reader_input, reader_output, url, passage_score_weight):
142 | max_answer_length = 10
143 | input_ids = reader_input["input_ids"]
144 | attn_mask = reader_input["attention_mask"]
145 | relevance_logits = reader_output["relevance_logits"]
146 | if not isinstance(relevance_logits, (tuple, list)):
147 | relevance_logits = [relevance_logits]
148 |
149 | start_logits = get_output(url, get_score_input(reader_output["start_logits"], relevance_logits, attn_mask, passage_score_weight))
150 | end_logits = get_output(url, get_score_input(reader_output["end_logits"], relevance_logits, attn_mask, passage_score_weight))
151 |
152 | nbest = []
153 | for passage_idx in range(len(input_ids)):
154 | sequence_len = sum(id_ != 0 for id_ in input_ids[passage_idx])
155 | passage_offset = input_ids[passage_idx].index(tokenizer.sep_token_id) + 1
156 | ctx_ids = input_ids[passage_idx][passage_offset:sequence_len]
157 |
158 | p_start_logits = start_logits[passage_idx][passage_offset:sequence_len]
159 | p_end_logits = end_logits[passage_idx][passage_offset:sequence_len]
160 |
161 | scores = []
162 | for (i, s) in enumerate(p_start_logits):
163 | for (j, e) in enumerate(p_end_logits[i:i + max_answer_length]):
164 | scores.append(((i, i + j), s + e))
165 |
166 | scores = sorted(scores, key=lambda x: x[1], reverse=True)
167 |
168 | chosen_span_intervals = []
169 | best_spans = []
170 |
171 | for (start_index, end_index), score in scores:
172 | assert start_index <= end_index
173 | length = end_index - start_index + 1
174 | assert length <= max_answer_length
175 |
176 | if any([start_index <= prev_start_index <= prev_end_index <= end_index or
177 | prev_start_index <= start_index <= end_index <= prev_end_index
178 | for (prev_start_index, prev_end_index) in chosen_span_intervals]):
179 | continue
180 |
181 | start_index, end_index = extend_span_to_full_words(tokenizer, ctx_ids, (start_index, end_index))
182 |
183 | title_offset = ctx_ids.index(tokenizer.sep_token_id) + 1
184 |
185 | context = ctx_ids[title_offset:]
186 | start_index -= title_offset
187 | end_index -= title_offset
188 |
189 | answer = tokenizer.decode(context[start_index:end_index + 1])
190 |
191 | best_spans.append((answer, score))
192 | if len(best_spans) == 5:
193 | break
194 | nbest.extend(best_spans)
195 |
196 | nbest = sorted(nbest, key=lambda x: x[1], reverse=True)
197 | return nbest[0][0]
198 |
199 |
200 | def get_retriever_output(retrieved_doc_ids, titles, docs):
201 | retrieved_titles = [titles[i] for i in retrieved_doc_ids]
202 | retrieved_docs = [docs[i] for i in retrieved_doc_ids]
203 | return {
204 | "titles": retrieved_titles,
205 | "docs": retrieved_docs,
206 | }
207 |
208 |
209 | def predict(url, tokenizer, titles, docs, question_str, top_k, passage_score_weight):
210 | question_str = question_str.lower()
211 | if question_str.endswith("?"):
212 | question_str = question_str[:-1]
213 |
214 | retrieved_doc_ids = get_output(url, get_retriever_input(question_str, tokenizer, top_k))
215 | retriever_output = get_retriever_output(retrieved_doc_ids, titles, docs)
216 | reader_input = get_reader_input(question_str, tokenizer, retriever_output)
217 |
218 | reader_output = get_output(url, get_api_input("read", reader_input))
219 | if passage_score_weight is not None:
220 | answer = get_answer_deep(tokenizer, reader_input, reader_output, url, passage_score_weight)
221 | else:
222 | answer = get_answer_greedy(tokenizer, reader_input, reader_output)
223 | return answer
224 |
225 |
226 | def main(args):
227 | print(args)
228 |
229 | url = args.url
230 | titles_path = sorted(glob.glob(os.path.join(args.resources_path, "*.titles.txt")))[0]
231 | docs_path = sorted(glob.glob(os.path.join(args.resources_path, "*.docs.txt")))[0]
232 |
233 | print(titles_path)
234 | print(docs_path)
235 |
236 | titles = read_txt(titles_path)
237 | docs = read_txt(docs_path)
238 |
239 | tokenizer = BertTokenizer.from_pretrained("bert_tokenizer/mobilebert-uncased")
240 |
241 | with open(args.input_path, "r", encoding="utf-8") as input_file, open(args.output_path, "w", encoding="utf-8") as output_file:
242 | for line in input_file:
243 | start = timer()
244 | question_str = json.loads(line.strip())["question"]
245 | answer = predict(url, tokenizer, titles, docs, question_str, args.top_k, args.passage_score_weight)
246 | out = {"question": question_str, "prediction": answer.strip()}
247 | output_file.write(json.dumps(out) + "\n")
248 | print(str(out) + " (%.4fs)" % (timer() - start))
249 |
250 |
251 | if __name__ == "__main__":
252 | parser = ArgumentParser()
253 |
254 | parser.add_argument(
255 | "--url",
256 | default=None,
257 | type=str,
258 | required=True,
259 | )
260 | parser.add_argument(
261 | "--resources_path",
262 | default=None,
263 | type=str,
264 | required=True,
265 | )
266 | parser.add_argument(
267 | "--input_path",
268 | default=None,
269 | type=str,
270 | required=True,
271 | )
272 | parser.add_argument(
273 | "--output_path",
274 | type=str,
275 | default="predictions.txt",
276 | required=True,
277 | )
278 | parser.add_argument(
279 | "--top_k",
280 | type=int,
281 | default=100,
282 | )
283 | parser.add_argument(
284 | "--passage_score_weight",
285 | type=float,
286 | default=None,
287 | )
288 |
289 | args = parser.parse_args()
290 | main(args)
291 |
--------------------------------------------------------------------------------
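infer.py talks to the model through the TensorFlow Serving REST API: every request is a JSON body with a signature_name ("retrieve", "read", or "get_score") and an inputs dict, POSTed to the --url endpoint, and the result is read from the "outputs" field of the response. A small sketch of the payload built for the "retrieve" signature; the token ids and the localhost URL are illustrative placeholders (the real ids come from BertTokenizer and the real URL from the serving deployment):

import json

# Illustrative token ids for a lowercased question; in infer.py these come
# from BertTokenizer([question], max_length=256, truncation=True).
inputs = {
    "input_ids": [[101, 2040, 2003, 1996, 3539, 1997, 2605, 102]],
    "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
    "top_k": 100,
}

payload = json.dumps({"signature_name": "retrieve", "inputs": inputs})
print(payload)

# The request itself would be sent roughly like this (placeholder URL):
#   requests.post("http://localhost:8501/v1/models/minimal-rnr-qa:predict", payload).json()["outputs"]
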
/playground/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # coding=utf-8
3 | # minimal-rnr-qa
4 | # Copyright 2021-present NAVER Corp.
5 | # Apache License v2.0
6 |
7 |
8 | if [ "$#" -ne 2 ]; then
9 | echo "Usage: ./build.sh DATASET MODEL_TYPE"
10 | echo "DATASET: effqa | nq | trivia"
11 | echo "MODEL_TYPE: tfserving | tfserving_faiss | pytorch"
12 | exit 0
13 | fi
14 |
15 |
16 | DATASET=$1 # ["effqa", "nq", "trivia"]
17 | MODEL_TYPE=$2 # ["tfserving", "tfserving_faiss", "pytorch"]
18 |
19 | if [ "$MODEL_TYPE" = "pytorch" ]; then
20 | DOCKERFILE="Dockerfile-pytorch"
21 | else
22 | DOCKERFILE="Dockerfile-tfserving"
23 | fi
24 |
25 |
26 | docker build -f configs/$DOCKERFILE \
27 | --build-arg DATASET=$DATASET \
28 | --build-arg MODEL_TYPE=$MODEL_TYPE \
29 | . -t minimal-rnr-qa:$DATASET-$MODEL_TYPE
30 |
31 | echo "minimal-rnr-qa:$DATASET-$MODEL_TYPE"
32 |
--------------------------------------------------------------------------------
/playground/configs/Dockerfile-pytorch:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime
2 |
3 | ARG DATASET
4 | ARG MODEL_TYPE
5 |
6 | # Install python packages
7 | RUN pip install --no-cache-dir absl-py requests regex
8 |
9 | COPY configs/requirements-$MODEL_TYPE.txt /requirements.txt
10 | RUN pip install --no-cache-dir -r /requirements.txt
11 |
12 | COPY models/$MODEL_TYPE/$DATASET /models/$DATASET
13 | COPY resources/$DATASET /resources/$DATASET
14 |
15 | ENV DATASET=$DATASET
16 | ENV MODEL_TYPE=$MODEL_TYPE
17 |
18 | COPY workspace /workspace
19 | COPY entrypoint.sh /workspace/entrypoint.sh
20 | RUN chmod a+x /workspace/entrypoint.sh
21 |
22 | WORKDIR /workspace
23 |
--------------------------------------------------------------------------------
/playground/configs/Dockerfile-tfserving:
--------------------------------------------------------------------------------
1 | FROM tensorflow/serving:2.3.0 as build_image
2 | FROM python:3.6.11-slim-buster
3 |
4 | ARG DATASET
5 | ARG MODEL_TYPE
6 |
7 | # Install TF Serving
8 | COPY --from=build_image /usr/bin/tensorflow_model_server /usr/bin/tensorflow_model_server
9 |
10 | # Install python packages
11 | RUN pip install --no-cache-dir absl-py requests regex
12 |
13 | COPY configs/requirements-$MODEL_TYPE.txt /requirements.txt
14 | RUN pip install --no-cache-dir -r /requirements.txt
15 |
16 | COPY models/$MODEL_TYPE/$DATASET/1 /models/minimal-rnr-qa/1
17 | COPY resources/$DATASET /resources/$DATASET
18 |
19 | ENV DATASET=$DATASET
20 | ENV MODEL_TYPE=$MODEL_TYPE
21 |
22 | COPY workspace /workspace
23 | COPY entrypoint.sh /workspace/entrypoint.sh
24 | RUN chmod a+x /workspace/entrypoint.sh
25 |
26 | WORKDIR /workspace
27 |
--------------------------------------------------------------------------------
/playground/configs/requirements-pytorch.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | transformers==3.0.2
3 | faiss
4 | faiss-cpu
5 | numpy
6 |
--------------------------------------------------------------------------------
/playground/configs/requirements-tfserving.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/configs/requirements-tfserving.txt
--------------------------------------------------------------------------------
/playground/configs/requirements-tfserving_faiss.txt:
--------------------------------------------------------------------------------
1 | faiss
2 | faiss-cpu
3 | numpy
--------------------------------------------------------------------------------
/playground/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # coding=utf-8
3 | # minimal-rnr-qa
4 | # Copyright 2021-present NAVER Corp.
5 | # Apache License v2.0
6 |
7 |
8 | if [ -z "$MODE" ]; then
9 | echo "--env MODE must be defined: demo | file"
10 | exit 0
11 | fi
12 |
13 |
14 | # launch tfserving
15 | if [ "$MODEL_TYPE" = "tfserving" ] || [ "$MODEL_TYPE" = "tfserving_faiss" ]; then
16 | /usr/bin/tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=minimal-rnr-qa --model_base_path=/models/minimal-rnr-qa &
17 | fi
18 |
19 |
20 | # set USE_FAISS
21 | if [ "$MODEL_TYPE" = "tfserving" ]; then
22 | USE_FAISS=false
23 | else
24 | USE_FAISS=true
25 | fi
26 |
27 |
28 | if [ "$MODE" = "demo" ]; then
29 | pip install --no-cache-dir flask tornado
30 |
31 | if [ "$MODEL_TYPE" = "tfserving" ] || [ "$MODEL_TYPE" = "tfserving_faiss" ]; then
32 | python -u run_tf_demo.py \
33 | --dataset $DATASET \
34 | `if [[ -n "${USE_FAISS}" ]]; then echo --use_faiss_index $USE_FAISS; fi`
35 |
36 | else
37 | python -u run_pt_demo.py \
38 | --dataset $DATASET
39 |
40 | fi
41 |
42 | elif [ "$MODE" = "file" ]; then
43 | if [ "$MODEL_TYPE" = "tfserving" ] || [ "$MODEL_TYPE" = "tfserving_faiss" ]; then
44 | PYTHON_FILE="run_tf_inference.py"
45 |
46 | else
47 | PYTHON_FILE="run_pt_inference.py"
48 | fi
49 |
50 | python -u $PYTHON_FILE \
51 | --dataset $DATASET \
52 | `if [[ -n "${USE_FAISS}" ]]; then echo --use_faiss_index $USE_FAISS; fi` \
53 | `if [[ -n "${TOP_K}" ]]; then echo --top_k $TOP_K; fi` \
54 | `if [[ -n "${PASSAGE_W}" ]]; then echo --passage_score_weight $PASSAGE_W; fi` \
55 |         --input_path $1 --output_path $2
56 |
57 | else
58 | echo "--env MODE must be: demo | file"
59 | exit 0
60 | fi
61 |
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/__init__.py
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/pytorch/__init__.py
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/pytorch/inferencer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # minimal-rnr-qa
3 | # Copyright 2021-present NAVER Corp.
4 | # Apache License v2.0
5 |
6 | from minimal_rnr.utils.inferencer import MinimalRnR
7 | from minimal_rnr.pytorch.model import get_model_tokenizer_device
8 |
9 |
10 | class TorchMinimalRnR(MinimalRnR):
11 | def __init__(self, args):
12 | super(TorchMinimalRnR, self).__init__(args)
13 |
14 | assert args.use_faiss_index, "PyTorch model cannot be run without a faiss index."
15 | import torch
16 | self.torch = torch
17 |
18 | self.model, self.tokenizer, self.device = get_model_tokenizer_device(args)
19 |
20 | def get_question_encoding(self, question):
21 | input_ = self.tokenizer([question], max_length=self.max_retriever_input_len,
22 | truncation=True, return_token_type_ids=True, return_tensors="pt")
23 | torch_input = self._add_prefix_and_to_device(input_, "retriever_")
24 |
25 | with self.torch.no_grad():
26 | question_encoding = self.model(**torch_input)
27 | np_question_encoding = question_encoding.cpu().numpy().astype("float32")
28 |
29 | return np_question_encoding
30 |
31 | def get_retriever_output(self, question, top_k):
32 |         raise Exception("get_retriever_output is used for running without a faiss index, "
33 |                         "which is not supported for torch models")
34 |
35 | def get_reader_output(self, reader_input):
36 | torch_input = self._add_prefix_and_to_device(reader_input, "reader_", convert_to_tensor=True)
37 |
38 | with self.torch.no_grad():
39 | start_logits, end_logits, relevance_logits = self.model(**torch_input)
40 |
41 | relevance_logits = relevance_logits.squeeze()
42 | if relevance_logits.dim() == 0: # for top_k=1
43 | relevance_logits = relevance_logits.unsqueeze(0)
44 |
45 | # returned as cuda tensors (on purpose)
46 | return {
47 | "start_logits": start_logits.squeeze(-1),
48 | "end_logits": end_logits.squeeze(-1),
49 | "relevance_logits": relevance_logits,
50 | }
51 |
52 | def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits, attn_mask, passage_score_weight):
53 | attn_mask = self.torch.tensor(attn_mask).float()
54 |
55 | relevance_logits = relevance_logits.unsqueeze(1) # [M, 1]
56 | masked_token_logits = token_logits - 1e10 * (1.0 - attn_mask)
57 | log_span_prob = token_logits - masked_token_logits.logsumexp(dim=1, keepdim=True) # [M, L] softmaxed over L
58 | log_passage_prob = relevance_logits - relevance_logits.logsumexp(dim=0, keepdim=True) # [M, 1] softmaxed over M
59 | weighted_logits = log_span_prob * (1 - passage_score_weight) + log_passage_prob * passage_score_weight
60 |
61 | return weighted_logits
62 |
63 | def _add_prefix_and_to_device(self, data, prefix, convert_to_tensor=False):
64 | wrap = lambda x: self.torch.tensor(x) if convert_to_tensor else x
65 | data = {prefix + k: wrap(v).to(device=self.device) for k, v in data.items()}
66 | return data
67 |
68 | def maybe_tensor_to_list(self, tensor):
69 | if not isinstance(tensor, (list, tuple)): # torch tensor
70 | return tensor.cpu().tolist()
71 | return tensor
72 |
--------------------------------------------------------------------------------
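get_passage_score_weighted_answer_token_logits mixes two log-probabilities: a softmax over token positions within each passage (span score, padded positions masked out) and a softmax over passages (relevance score), interpolated by passage_score_weight. A toy reproduction of that computation on random tensors, assuming only that torch is installed; the shapes and weight value are illustrative:

import torch

M, L = 3, 6                      # M retrieved passages, L token positions each
token_logits = torch.randn(M, L)
relevance_logits = torch.randn(M, 1)
attn_mask = torch.ones(M, L)
passage_score_weight = 0.8

# log softmax over token positions, ignoring padded positions via the mask
masked_token_logits = token_logits - 1e10 * (1.0 - attn_mask)
log_span_prob = token_logits - masked_token_logits.logsumexp(dim=1, keepdim=True)

# log softmax over passages
log_passage_prob = relevance_logits - relevance_logits.logsumexp(dim=0, keepdim=True)

weighted = log_span_prob * (1 - passage_score_weight) + log_passage_prob * passage_score_weight
print(weighted.shape)  # torch.Size([3, 6])
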
/playground/workspace/minimal_rnr/pytorch/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # ******************************************************************
3 | # Copied and modified from https://github.com/facebookresearch/DPR *
4 | # ******************************************************************
5 | # Copyright (c) Facebook, Inc. and its affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the license found in the
9 | # LICENSE file in the root directory of this source tree.
10 |
11 | from minimal_rnr.utils.inference import get_first_matched_file_path
12 | from minimal_rnr.utils.logger import get_logger
13 |
14 |
15 | def get_model_tokenizer_device(args):
16 | # not to import torch and transformers as default
17 | import torch
18 | from torch import Tensor as T
19 | from torch import nn
20 | from transformers import MobileBertModel, MobileBertConfig, AutoTokenizer
21 |
22 |
23 | def init_weights(modules):
24 | for module in modules:
25 | if isinstance(module, (nn.Linear, nn.Embedding)):
26 | module.weight.data.normal_(mean=0.0, std=0.02)
27 | elif isinstance(module, nn.LayerNorm):
28 | module.bias.data.zero_()
29 | module.weight.data.fill_(1.0)
30 | if isinstance(module, nn.Linear) and module.bias is not None:
31 | module.bias.data.zero_()
32 |
33 |
34 | class HFMobileBertEncoder(MobileBertModel):
35 | def __init__(self, config, project_dim: int = 0, ctx_bottleneck: bool = False):
36 | MobileBertModel.__init__(self, config)
37 | assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
38 | self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
39 | self.decode_proj = nn.Sequential(
40 | nn.Tanh(),
41 | nn.Linear(project_dim, (config.hidden_size + project_dim) // 2),
42 | nn.Tanh(),
43 | nn.Linear((config.hidden_size + project_dim) // 2, config.hidden_size),
44 | ) if ctx_bottleneck else None
45 | self.init_weights()
46 |
47 | @classmethod
48 | def init_encoder(cls, cfg_name: str) -> MobileBertModel:
49 | cfg = MobileBertConfig.from_pretrained(cfg_name)
50 | return cls.from_pretrained(cfg_name, config=cfg)
51 |
52 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T):
53 | if self.config.output_hidden_states:
54 | sequence_output, pooled_output, hidden_states = super().forward(input_ids=input_ids,
55 | token_type_ids=token_type_ids,
56 | attention_mask=attention_mask)
57 | else:
58 | hidden_states = None
59 | sequence_output, pooled_output = super().forward(
60 | input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
61 |
62 | pooled_output = sequence_output[:, 0, :]
63 | return sequence_output, pooled_output, hidden_states
64 |
65 | def get_out_size(self):
66 | if self.encode_proj:
67 | return self.encode_proj.out_features
68 | return self.config.hidden_size
69 |
70 |
71 | class UnifiedRetrieverReader(nn.Module):
72 | def __init__(self, encoder: nn.Module):
73 | super(UnifiedRetrieverReader, self).__init__()
74 |
75 | self.emb_size = 128
76 |
77 | self.question_model = encoder
78 | hidden_size = encoder.config.hidden_size
79 |
80 | self.qa_outputs = nn.Linear(hidden_size, 2)
81 | self.qa_classifier = nn.Linear(hidden_size, 1)
82 |
83 | init_weights([self.qa_outputs, self.qa_classifier])
84 |
85 | @staticmethod
86 | def get_representation(sub_model: nn.Module, ids, segments, attn_mask):
87 | sequence_output = None
88 | pooled_output = None
89 | hidden_states = None
90 | if ids is not None:
91 | sequence_output, pooled_output, hidden_states = sub_model(ids, segments, attn_mask)
92 |
93 | return sequence_output, pooled_output, hidden_states
94 |
95 | def forward(
96 | self, retriever_input_ids=None, retriever_token_type_ids=None, retriever_attention_mask=None,
97 | reader_input_ids=None, reader_attention_mask=None, reader_token_type_ids=None):
98 |
99 | if retriever_input_ids is not None:
100 | _, encoding, _ = self.get_representation(
101 | self.question_model, retriever_input_ids, retriever_token_type_ids, retriever_attention_mask)
102 |
103 | if self.emb_size is not None:
104 | return encoding[:, :self.emb_size]
105 | return encoding
106 |
107 | if reader_input_ids is not None:
108 | start_logits, end_logits, relevance_logits = self._read(reader_input_ids, reader_token_type_ids, reader_attention_mask)
109 | return start_logits, end_logits, relevance_logits
110 |
111 | def _read(self, input_ids, token_type_ids, attention_mask):
112 | sequence_output, _pooled_output, _hidden_states = self.question_model(input_ids, token_type_ids, attention_mask)
113 | logits = self.qa_outputs(sequence_output)
114 | start_logits, end_logits = logits.split(1, dim=-1)
115 | start_logits = start_logits.squeeze(-1)
116 | end_logits = end_logits.squeeze(-1)
117 |
118 | qa_classifier_input = sequence_output[:, 0, :]
119 | relevance_logits = self.qa_classifier(qa_classifier_input)
120 | return start_logits, end_logits, relevance_logits
121 |
122 |
123 | cfg_name = "google/mobilebert-uncased"
124 | question_encoder = HFMobileBertEncoder.init_encoder(cfg_name)
125 | model = UnifiedRetrieverReader(question_encoder)
126 | tokenizer = AutoTokenizer.from_pretrained(cfg_name, do_lower_case=True)
127 |
128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129 |
130 | model_file = get_first_matched_file_path(args.model_path, args.dataset, "*.bin")
131 | logger = get_logger("minimal-rnr-qa")
132 | logger.info(f"Loading model from {model_file}...")
133 | model.load_state_dict(torch.load(model_file, map_location=device))
134 | model.to(device)
135 | model.eval()
136 |
137 | return model, tokenizer, device
--------------------------------------------------------------------------------
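The reader head in UnifiedRetrieverReader is just two linear layers on top of the encoder output: qa_outputs maps every token's hidden state to a start and an end logit, and qa_classifier maps the [CLS] hidden state to a single passage-relevance logit. A shape-only sketch with a random stand-in for the encoder output, assuming only torch; the dimensions are illustrative:

import torch
from torch import nn

M, L, H = 4, 350, 512                     # passages, max sequence length, hidden size (illustrative)
sequence_output = torch.randn(M, L, H)    # stand-in for the MobileBERT encoder output

qa_outputs = nn.Linear(H, 2)              # start/end logits per token
qa_classifier = nn.Linear(H, 1)           # passage relevance logit from the [CLS] position

logits = qa_outputs(sequence_output)      # [M, L, 2]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)   # [M, L]
end_logits = end_logits.squeeze(-1)       # [M, L]
relevance_logits = qa_classifier(sequence_output[:, 0, :])  # [M, 1]

print(start_logits.shape, end_logits.shape, relevance_logits.shape)
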
/playground/workspace/minimal_rnr/tfserving/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/tfserving/__init__.py
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/tfserving/bert_tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenization_bert import BertTokenizer
2 |
3 | __all__ = ['BertTokenizer']
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/tfserving/bert_tokenizer/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import logging
8 | import os
9 | from functools import wraps
10 | from typing import Optional
11 |
12 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
13 |
14 | USE_TF = True
15 | USE_TORCH = False
16 | _torch_available = False
17 | _torch_tpu_available = False
18 | _psutil_available = False
19 | _py3nvml_available = False
20 | _has_apex = False
21 |
22 | try:
23 | import tensorflow as tf
24 |
25 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
26 | _tf_available = True # pylint: disable=invalid-name
27 | logger.info("TensorFlow version {} available.".format(tf.__version__))
28 | except (ImportError, AssertionError):
29 | _tf_available = False # pylint: disable=invalid-name
30 |
31 | WEIGHTS_NAME = "pytorch_model.bin"
32 | TF2_WEIGHTS_NAME = "tf_model.h5"
33 | TF_WEIGHTS_NAME = "model.ckpt"
34 | CONFIG_NAME = "config.json"
35 | MODEL_CARD_NAME = "modelcard.json"
36 |
37 |
38 | def is_tf_available():
39 | return _tf_available
40 |
41 |
42 | def cached_path(
43 | url_or_filename,
44 | ) -> Optional[str]:
45 | if os.path.exists(url_or_filename):
46 | output_path = url_or_filename
47 | else:
48 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
49 |
50 | return output_path
51 |
52 |
53 | def tf_required(func):
54 | # Chose a different decorator name than in tests so it's clear they are not the same.
55 | @wraps(func)
56 | def wrapper(*args, **kwargs):
57 | if is_tf_available():
58 | return func(*args, **kwargs)
59 | else:
60 | raise ImportError(f"Method `{func.__name__}` requires TF.")
61 |
62 | return wrapper
--------------------------------------------------------------------------------
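The trimmed file_utils above keeps only what the vendored tokenizer needs: a local-only cached_path and a tf_required guard around TF-dependent code. A small illustrative sketch (the vocab path and the tf_version helper are made up for the example):

import os
from minimal_rnr.tfserving.bert_tokenizer.file_utils import cached_path, is_tf_available, tf_required

# cached_path here is local-only: it returns the path unchanged if the file exists
# and raises ValueError otherwise (no downloading, unlike the full transformers helper).
vocab_path = "bert_tokenizer/mobilebert-uncased/vocab.txt"  # assumed relative location
if os.path.exists(vocab_path):
    print(cached_path(vocab_path))

@tf_required
def tf_version():
    # Only callable when a TensorFlow >= 2 install was detected at import time.
    import tensorflow as tf
    return tf.__version__

if is_tf_available():
    print(tf_version())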
/playground/workspace/minimal_rnr/tfserving/bert_tokenizer/tokenization_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 |
18 | import collections
19 | import logging
20 | import os
21 | import unicodedata
22 | from typing import List, Optional
23 |
24 | from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25 |
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {
32 | "vocab_file": {
33 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51 | }
52 | }
53 |
54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55 | "bert-base-uncased": 512,
56 | "bert-large-uncased": 512,
57 | "bert-base-cased": 512,
58 | "bert-large-cased": 512,
59 | "bert-base-multilingual-uncased": 512,
60 | "bert-base-multilingual-cased": 512,
61 | "bert-base-chinese": 512,
62 | "bert-base-german-cased": 512,
63 | "bert-large-uncased-whole-word-masking": 512,
64 | "bert-large-cased-whole-word-masking": 512,
65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67 | "bert-base-cased-finetuned-mrpc": 512,
68 | "bert-base-german-dbmdz-cased": 512,
69 | "bert-base-german-dbmdz-uncased": 512,
70 | "TurkuNLP/bert-base-finnish-cased-v1": 512,
71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72 | "wietsedv/bert-base-dutch-cased": 512,
73 | }
74 |
75 | PRETRAINED_INIT_CONFIGURATION = {
76 | "bert-base-uncased": {"do_lower_case": True},
77 | "bert-large-uncased": {"do_lower_case": True},
78 | "bert-base-cased": {"do_lower_case": False},
79 | "bert-large-cased": {"do_lower_case": False},
80 | "bert-base-multilingual-uncased": {"do_lower_case": True},
81 | "bert-base-multilingual-cased": {"do_lower_case": False},
82 | "bert-base-chinese": {"do_lower_case": False},
83 | "bert-base-german-cased": {"do_lower_case": False},
84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94 | }
95 |
96 |
97 | def load_vocab(vocab_file):
98 | """Loads a vocabulary file into a dictionary."""
99 | vocab = collections.OrderedDict()
100 | with open(vocab_file, "r", encoding="utf-8") as reader:
101 | tokens = reader.readlines()
102 | for index, token in enumerate(tokens):
103 | token = token.rstrip("\n")
104 | vocab[token] = index
105 | return vocab
106 |
107 |
108 | def whitespace_tokenize(text):
109 | """Runs basic whitespace cleaning and splitting on a piece of text."""
110 | text = text.strip()
111 | if not text:
112 | return []
113 | tokens = text.split()
114 | return tokens
115 |
116 |
117 | class BertTokenizer(PreTrainedTokenizer):
118 | r"""
119 | Constructs a BERT tokenizer. Based on WordPiece.
120 |
121 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122 | should refer to the superclass for more information regarding methods.
123 |
124 | Args:
125 | vocab_file (:obj:`string`):
126 | File containing the vocabulary.
127 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128 | Whether to lowercase the input when tokenizing.
129 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130 | Whether to do basic tokenization before WordPiece.
131 | never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132 | Collection of tokens which will never be split during tokenization. Only has an effect when
133 | :obj:`do_basic_tokenize=True`
134 | unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136 | token instead.
137 | sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139 | for sequence classification or for a text and a question for question answering.
140 | It is also used as the last token of a sequence built with special tokens.
141 | pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142 | The token used for padding, for example when batching sequences of different lengths.
143 | cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144 | The classifier token which is used when doing sequence classification (classification of the whole
145 | sequence instead of per-token classification). It is the first token of the sequence when built with
146 | special tokens.
147 | mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148 | The token used for masking values. This is the token used when training this model with masked language
149 | modeling. This is the token which the model will try to predict.
150 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151 | Whether to tokenize Chinese characters.
152 | This should likely be deactivated for Japanese:
153 | see: https://github.com/huggingface/transformers/issues/328
154 | """
155 |
156 | vocab_files_names = VOCAB_FILES_NAMES
157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160 |
161 | def __init__(
162 | self,
163 | vocab_file,
164 | do_lower_case=True,
165 | do_basic_tokenize=True,
166 | never_split=None,
167 | unk_token="[UNK]",
168 | sep_token="[SEP]",
169 | pad_token="[PAD]",
170 | cls_token="[CLS]",
171 | mask_token="[MASK]",
172 | tokenize_chinese_chars=True,
173 | **kwargs
174 | ):
175 | super().__init__(
176 | unk_token=unk_token,
177 | sep_token=sep_token,
178 | pad_token=pad_token,
179 | cls_token=cls_token,
180 | mask_token=mask_token,
181 | **kwargs,
182 | )
183 |
184 | if not os.path.isfile(vocab_file):
185 | raise ValueError(
186 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188 | )
189 | self.vocab = load_vocab(vocab_file)
190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191 | self.do_basic_tokenize = do_basic_tokenize
192 | if do_basic_tokenize:
193 | self.basic_tokenizer = BasicTokenizer(
194 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195 | )
196 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197 |
198 | @property
199 | def vocab_size(self):
200 | return len(self.vocab)
201 |
202 | def get_vocab(self):
203 | return dict(self.vocab, **self.added_tokens_encoder)
204 |
205 | def _tokenize(self, text):
206 | split_tokens = []
207 | if self.do_basic_tokenize:
208 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209 |
210 | # If the token is part of the never_split set
211 | if token in self.basic_tokenizer.never_split:
212 | split_tokens.append(token)
213 | else:
214 | split_tokens += self.wordpiece_tokenizer.tokenize(token)
215 | else:
216 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
217 | return split_tokens
218 |
219 | def _convert_token_to_id(self, token):
220 |         """ Converts a token (str) to an id using the vocab. """
221 | return self.vocab.get(token, self.vocab.get(self.unk_token))
222 |
223 | def _convert_id_to_token(self, index):
224 |         """Converts an index (integer) to a token (str) using the vocab."""
225 | return self.ids_to_tokens.get(index, self.unk_token)
226 |
227 | def convert_tokens_to_string(self, tokens):
228 |         """ Converts a sequence of tokens (string) to a single string. """
229 | out_string = " ".join(tokens).replace(" ##", "").strip()
230 | return out_string
231 |
232 | def build_inputs_with_special_tokens(
233 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234 | ) -> List[int]:
235 | """
236 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks
237 | by concatenating and adding special tokens.
238 | A BERT sequence has the following format:
239 |
240 | - single sequence: ``[CLS] X [SEP]``
241 | - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242 |
243 | Args:
244 | token_ids_0 (:obj:`List[int]`):
245 | List of IDs to which the special tokens will be added
246 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247 | Optional second list of IDs for sequence pairs.
248 |
249 | Returns:
250 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251 | """
252 | if token_ids_1 is None:
253 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254 | cls = [self.cls_token_id]
255 | sep = [self.sep_token_id]
256 | return cls + token_ids_0 + sep + token_ids_1 + sep
257 |
258 | def get_special_tokens_mask(
259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260 | ) -> List[int]:
261 | """
262 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263 | special tokens using the tokenizer ``prepare_for_model`` method.
264 |
265 | Args:
266 | token_ids_0 (:obj:`List[int]`):
267 | List of ids.
268 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269 | Optional second list of IDs for sequence pairs.
270 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271 | Set to True if the token list is already formatted with special tokens for the model
272 |
273 | Returns:
274 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275 | """
276 |
277 | if already_has_special_tokens:
278 | if token_ids_1 is not None:
279 | raise ValueError(
280 | "You should not supply a second sequence if the provided sequence of "
281 |                     "ids is already formatted with special tokens for the model."
282 | )
283 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284 |
285 | if token_ids_1 is not None:
286 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287 | return [1] + ([0] * len(token_ids_0)) + [1]
288 |
289 | def create_token_type_ids_from_sequences(
290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291 | ) -> List[int]:
292 | """
293 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294 | A BERT sequence pair mask has the following format:
295 |
296 | ::
297 |
298 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299 | | first sequence | second sequence |
300 |
301 | if token_ids_1 is None, only returns the first portion of the mask (0's).
302 |
303 | Args:
304 | token_ids_0 (:obj:`List[int]`):
305 | List of ids.
306 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307 | Optional second list of IDs for sequence pairs.
308 |
309 | Returns:
310 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311 | sequence(s).
312 | """
313 | sep = [self.sep_token_id]
314 | cls = [self.cls_token_id]
315 | if token_ids_1 is None:
316 | return len(cls + token_ids_0 + sep) * [0]
317 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318 |
319 | def save_vocabulary(self, vocab_path):
320 | """
321 |         Save the tokenizer vocabulary (a vocab.txt file) to a directory or file path.
322 |
323 | Args:
324 | vocab_path (:obj:`str`):
325 | The directory in which to save the vocabulary.
326 |
327 | Returns:
328 | :obj:`Tuple(str)`: Paths to the files saved.
329 | """
330 | index = 0
331 | if os.path.isdir(vocab_path):
332 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333 | else:
334 | vocab_file = vocab_path
335 | with open(vocab_file, "w", encoding="utf-8") as writer:
336 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337 | if index != token_index:
338 | logger.warning(
339 | "Saving vocabulary to {}: vocabulary indices are not consecutive."
340 | " Please check that the vocabulary is not corrupted!".format(vocab_file)
341 | )
342 | index = token_index
343 | writer.write(token + "\n")
344 | index += 1
345 | return (vocab_file,)
346 |
347 |
348 | class BasicTokenizer(object):
349 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350 |
351 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352 | """ Constructs a BasicTokenizer.
353 |
354 | Args:
355 | **do_lower_case**: Whether to lower case the input.
356 | **never_split**: (`optional`) list of str
357 | Kept for backward compatibility purposes.
358 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359 |                 List of tokens not to split.
360 | **tokenize_chinese_chars**: (`optional`) boolean (default True)
361 | Whether to tokenize Chinese characters.
362 | This should likely be deactivated for Japanese:
363 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364 | """
365 | if never_split is None:
366 | never_split = []
367 | self.do_lower_case = do_lower_case
368 | self.never_split = set(never_split)
369 | self.tokenize_chinese_chars = tokenize_chinese_chars
370 |
371 | def tokenize(self, text, never_split=None):
372 | """ Basic Tokenization of a piece of text.
373 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
374 |
375 | Args:
376 | **never_split**: (`optional`) list of str
377 | Kept for backward compatibility purposes.
378 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379 |                 List of tokens not to split.
380 | """
381 | # union() returns a new set by concatenating the two sets.
382 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383 |
384 | # This was added on November 1st, 2018 for the multilingual and Chinese
385 | # models. This is also applied to the English models now, but it doesn't
386 | # matter since the English models were not trained on any Chinese data
387 | # and generally don't have any Chinese data in them (there are Chinese
388 | # characters in the vocabulary because Wikipedia does have some Chinese
389 | # words in the English Wikipedia.).
390 | if self.tokenize_chinese_chars:
391 | text = self._tokenize_chinese_chars(text)
392 | orig_tokens = whitespace_tokenize(text)
393 | split_tokens = []
394 | for token in orig_tokens:
395 | if self.do_lower_case and token not in never_split:
396 | token = token.lower()
397 | token = self._run_strip_accents(token)
398 | split_tokens.extend(self._run_split_on_punc(token, never_split))
399 |
400 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
401 | return output_tokens
402 |
403 | def _run_strip_accents(self, text):
404 | """Strips accents from a piece of text."""
405 | text = unicodedata.normalize("NFD", text)
406 | output = []
407 | for char in text:
408 | cat = unicodedata.category(char)
409 | if cat == "Mn":
410 | continue
411 | output.append(char)
412 | return "".join(output)
413 |
414 | def _run_split_on_punc(self, text, never_split=None):
415 | """Splits punctuation on a piece of text."""
416 | if never_split is not None and text in never_split:
417 | return [text]
418 | chars = list(text)
419 | i = 0
420 | start_new_word = True
421 | output = []
422 | while i < len(chars):
423 | char = chars[i]
424 | if _is_punctuation(char):
425 | output.append([char])
426 | start_new_word = True
427 | else:
428 | if start_new_word:
429 | output.append([])
430 | start_new_word = False
431 | output[-1].append(char)
432 | i += 1
433 |
434 | return ["".join(x) for x in output]
435 |
436 | def _tokenize_chinese_chars(self, text):
437 | """Adds whitespace around any CJK character."""
438 | output = []
439 | for char in text:
440 | cp = ord(char)
441 | if self._is_chinese_char(cp):
442 | output.append(" ")
443 | output.append(char)
444 | output.append(" ")
445 | else:
446 | output.append(char)
447 | return "".join(output)
448 |
449 | def _is_chinese_char(self, cp):
450 | """Checks whether CP is the codepoint of a CJK character."""
451 | # This defines a "chinese character" as anything in the CJK Unicode block:
452 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453 | #
454 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455 | # despite its name. The modern Korean Hangul alphabet is a different block,
456 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457 | # space-separated words, so they are not treated specially and handled
458 |         # like all of the other languages.
459 | if (
460 | (cp >= 0x4E00 and cp <= 0x9FFF)
461 | or (cp >= 0x3400 and cp <= 0x4DBF) #
462 | or (cp >= 0x20000 and cp <= 0x2A6DF) #
463 | or (cp >= 0x2A700 and cp <= 0x2B73F) #
464 | or (cp >= 0x2B740 and cp <= 0x2B81F) #
465 | or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466 | or (cp >= 0xF900 and cp <= 0xFAFF)
467 | or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468 | ): #
469 | return True
470 |
471 | return False
472 |
473 | def _clean_text(self, text):
474 | """Performs invalid character removal and whitespace cleanup on text."""
475 | output = []
476 | for char in text:
477 | cp = ord(char)
478 | if cp == 0 or cp == 0xFFFD or _is_control(char):
479 | continue
480 | if _is_whitespace(char):
481 | output.append(" ")
482 | else:
483 | output.append(char)
484 | return "".join(output)
485 |
486 |
487 | class WordpieceTokenizer(object):
488 | """Runs WordPiece tokenization."""
489 |
490 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491 | self.vocab = vocab
492 | self.unk_token = unk_token
493 | self.max_input_chars_per_word = max_input_chars_per_word
494 |
495 | def tokenize(self, text):
496 | """Tokenizes a piece of text into its word pieces.
497 |
498 | This uses a greedy longest-match-first algorithm to perform tokenization
499 | using the given vocabulary.
500 |
501 | For example:
502 | input = "unaffable"
503 | output = ["un", "##aff", "##able"]
504 |
505 | Args:
506 | text: A single token or whitespace separated tokens. This should have
507 | already been passed through `BasicTokenizer`.
508 |
509 | Returns:
510 | A list of wordpiece tokens.
511 | """
512 |
513 | output_tokens = []
514 | for token in whitespace_tokenize(text):
515 | chars = list(token)
516 | if len(chars) > self.max_input_chars_per_word:
517 | output_tokens.append(self.unk_token)
518 | continue
519 |
520 | is_bad = False
521 | start = 0
522 | sub_tokens = []
523 | while start < len(chars):
524 | end = len(chars)
525 | cur_substr = None
526 | while start < end:
527 | substr = "".join(chars[start:end])
528 | if start > 0:
529 | substr = "##" + substr
530 | if substr in self.vocab:
531 | cur_substr = substr
532 | break
533 | end -= 1
534 | if cur_substr is None:
535 | is_bad = True
536 | break
537 | sub_tokens.append(cur_substr)
538 | start = end
539 |
540 | if is_bad:
541 | output_tokens.append(self.unk_token)
542 | else:
543 | output_tokens.extend(sub_tokens)
544 | return output_tokens
545 |
546 |
547 |
--------------------------------------------------------------------------------
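For reference, a sketch of how this vendored BertTokenizer is used elsewhere in the playground: it is loaded from the bundled mobilebert-uncased vocabulary and then used for tokenize/encode. The directory argument below is an assumption about where the package lives on disk.

import os
from minimal_rnr.tfserving.bert_tokenizer import BertTokenizer

def load_local_tokenizer(package_dir):
    # `package_dir` should point at .../minimal_rnr/tfserving/bert_tokenizer (assumed path).
    tokenizer = BertTokenizer.from_pretrained(os.path.join(package_dir, "mobilebert-uncased"))
    tokens = tokenizer.tokenize("unaffable questions")  # WordPiece pieces; exact split depends on the vocab
    ids = tokenizer.encode("who wrote hamlet", max_length=256, truncation=True)  # [CLS] ... [SEP]
    return tokenizer, tokens, ids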
/playground/workspace/minimal_rnr/tfserving/inferencer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # minimal-rnr-qa
3 | # Copyright 2021-present NAVER Corp.
4 | # Apache License v2.0
5 |
6 |
7 | import json
8 | import os
9 |
10 | import requests
11 |
12 | from minimal_rnr.tfserving.bert_tokenizer import BertTokenizer
13 | from minimal_rnr.utils.inferencer import MinimalRnR
14 |
15 |
16 | class TFServingMinimalRnR(MinimalRnR):
17 | def __init__(self, args):
18 | super(TFServingMinimalRnR, self).__init__(args)
19 |
20 | self.url = f"http://{args.tfserving_ip}:{args.tfserving_port}/v1/models/minimal-rnr-qa:predict"
21 | self.tokenizer = BertTokenizer.from_pretrained(
22 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "bert_tokenizer/mobilebert-uncased"))
23 |
24 | def get_question_encoding(self, question):
25 | return self._get_api_output(self._get_question_encoder_api_input(question))
26 |
27 | def get_retriever_output(self, question, top_k):
28 | retrieved_doc_ids = self._get_api_output(self._get_retrieve_api_input(question, top_k))
29 | return retrieved_doc_ids
30 |
31 | def get_reader_output(self, reader_input):
32 | reader_api_input = self._get_api_input("read", reader_input)
33 | reader_output = self._get_api_output(reader_api_input)
34 |
35 | if not isinstance(reader_output["relevance_logits"], (tuple, list)): # for top_k=1
36 | reader_output["relevance_logits"] = [reader_output["relevance_logits"]]
37 |
38 | return reader_output
39 |
40 | def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits, attn_mask, passage_score_weight):
41 | weighted_logits = self._get_api_output(self._get_score_api_input(token_logits, relevance_logits, attn_mask, passage_score_weight))
42 | return weighted_logits
43 |
44 | def _get_question_encoder_api_input(self, question):
45 | input_ = self.tokenizer([question], max_length=self.max_retriever_input_len, truncation=True, return_token_type_ids=True)
46 | return self._get_api_input("encode", input_, preserve_token_type_ids=True)
47 |
48 | def _get_retrieve_api_input(self, question, top_k):
49 | input_ = self.tokenizer([question], max_length=self.max_retriever_input_len, truncation=True)
50 | input_["top_k"] = top_k
51 |
52 | return self._get_api_input("retrieve", input_)
53 |
54 |     def _get_score_api_input(self, token_logits, relevance_logits, attn_mask, passage_score_weight):
55 |         input_ = {
56 |             "token_logits": token_logits,
57 |             "relevance_logits": relevance_logits,
58 | "attn_mask": attn_mask,
59 | "passage_score_weight": passage_score_weight,
60 | }
61 | return json.dumps({
62 | "signature_name": "get_score",
63 | "inputs": input_,
64 | })
65 |
66 | def _get_api_input(self, signature_name, input_, preserve_token_type_ids=False):
67 | if not preserve_token_type_ids and "token_type_ids" in input_:
68 | del input_["token_type_ids"]
69 |
70 | if type(input_) != dict:
71 | input_ = dict(input_)
72 |
73 | return json.dumps({
74 | "signature_name": signature_name,
75 | "inputs": input_,
76 | })
77 |
78 | def _get_api_output(self, payload):
79 | response = requests.post(self.url, payload)
80 | if response.status_code != 200:
81 | response.raise_for_status()
82 | return response.json()["outputs"]
--------------------------------------------------------------------------------
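The inferencer above talks to TensorFlow Serving's REST API by POSTing JSON with a signature_name and an inputs dict. Below is a rough standalone reconstruction of the retrieve call it builds; the host/port and the exact signatures exposed depend on the TF Serving container the playground runs, so treat them as assumptions.

import json
import requests

def retrieve_via_tfserving(input_ids, attention_mask, top_k,
                           url="http://localhost:8501/v1/models/minimal-rnr-qa:predict"):
    # Rough equivalent of _get_retrieve_api_input + _get_api_output; url is an assumed default.
    payload = json.dumps({
        "signature_name": "retrieve",
        "inputs": {
            "input_ids": input_ids,            # e.g. [[101, 2040, 2626, 102]] (illustrative ids)
            "attention_mask": attention_mask,  # e.g. [[1, 1, 1, 1]]
            "top_k": top_k,
        },
    })
    response = requests.post(url, payload)
    response.raise_for_status()
    return response.json()["outputs"]          # retrieved document indices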
/playground/workspace/minimal_rnr/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/utils/__init__.py
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/utils/demo.py:
--------------------------------------------------------------------------------
1 | # Copied and modified from https://github.com/uwnlp/denspi
2 |
3 | from time import time
4 |
5 | from flask import Flask, request, jsonify
6 | from tornado.httpserver import HTTPServer
7 | from tornado.ioloop import IOLoop
8 | from tornado.wsgi import WSGIContainer
9 |
10 | from minimal_rnr.utils.logger import get_logger
11 |
12 |
13 | def run_app(args, minimal_rnr):
14 | logger = get_logger("minimal-rnr-qa")
15 |
16 | inference_api = minimal_rnr.get_inference_api()
17 |
18 | app = Flask(__name__, static_url_path='/static')
19 | app.config["JSONIFY_PRETTYPRINT_REGULAR"] = False
20 |
21 | def _search(query, top_k, passage_score_weight):
22 | start = time()
23 | result = inference_api(query, top_k, passage_score_weight)
24 | return {"ret": result, "time": int((time() - start))}
25 |
26 | @app.route("/")
27 | def index():
28 | return app.send_static_file('index.html')
29 |
30 |     @app.route("/files/<path:path>")
31 | def static_files(path):
32 | return app.send_static_file('files/' + path)
33 |
34 | @app.route("/api", methods=["GET"])
35 | def api():
36 | logger.info(request.args)
37 |
38 | query = request.args["query"]
39 | top_k = int(request.args["top_k"])
40 |
41 | if request.args["passage_score_weight"] == "null":
42 | passage_score_weight = None
43 | else:
44 | passage_score_weight = float(request.args["passage_score_weight"])
45 |
46 | result = _search(query, top_k, passage_score_weight)
47 | logger.info(result)
48 | return jsonify(result)
49 |
50 | @app.route("/get_examples", methods=["GET"])
51 | def get_examples():
52 | with open(args.examples_path, "r") as fp:
53 | examples = [line.strip() for line in fp.readlines()]
54 | return jsonify(examples)
55 |
56 | @app.route("/quit")
57 | def quit():
58 | raise KeyboardInterrupt
59 |
60 | logger.info("Warming up...")
61 | minimal_rnr.predict_answer("warmup", top_k=5, passage_score_weight=0.8)
62 |
63 | logger.info(f"Starting server at {args.demo_port}")
64 | http_server = HTTPServer(WSGIContainer(app))
65 | http_server.listen(args.demo_port)
66 | IOLoop.instance().start()
--------------------------------------------------------------------------------
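Once the demo server above is running, the /api route can be queried directly. A minimal client sketch (the port is whatever --demo_port was set to; 10001 is only a placeholder):

import requests

def query_demo(question, top_k=10, passage_score_weight=None, port=10001):
    # The route expects the literal string "null" to disable passage-score weighting.
    params = {
        "query": question,
        "top_k": top_k,
        "passage_score_weight": "null" if passage_score_weight is None else passage_score_weight,
    }
    response = requests.get(f"http://localhost:{port}/api", params=params)
    return response.json()  # {"ret": [...retrieved contexts...], "time": ...}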
/playground/workspace/minimal_rnr/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # ******************************************************************
3 | # Copied and modified from https://github.com/facebookresearch/FiD *
4 | # ******************************************************************
5 | # Copyright (c) Facebook, Inc. and its affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the license found in the
9 | # LICENSE file in the root directory of this source tree.
10 |
11 | import string
12 | import unicodedata
13 |
14 | import regex
15 |
16 |
17 | def _normalize(text):
18 | # Copied from SQuAD evaluation script
19 | # https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
20 | return unicodedata.normalize('NFD', text)
21 |
22 |
23 | def normalize_answer(s):
24 | def remove_articles(text):
25 | return regex.sub(r'\b(a|an|the)\b', ' ', text)
26 |
27 | def white_space_fix(text):
28 | return ' '.join(text.split())
29 |
30 | def remove_punc(text):
31 | exclude = set(string.punctuation)
32 | return ''.join(ch for ch in text if ch not in exclude)
33 |
34 | def lower(text):
35 | return text.lower()
36 |
37 | return white_space_fix(remove_articles(remove_punc(lower(s))))
38 |
39 |
40 | def exact_match_score(prediction, ground_truth):
41 | return normalize_answer(prediction) == normalize_answer(ground_truth)
42 |
43 |
44 | def ems(prediction, ground_truths):
45 | return max([exact_match_score(prediction, gt) for gt in ground_truths])
--------------------------------------------------------------------------------
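A tiny worked example of the normalization and exact-match scoring above (the strings are arbitrary):

from minimal_rnr.utils.evaluation import normalize_answer, exact_match_score, ems

# Lowercased, punctuation stripped, articles removed, whitespace collapsed.
assert normalize_answer("The  Origin of Species!") == "origin of species"

# ems() is True if the prediction matches any ground-truth answer after normalization.
assert exact_match_score("Charles Darwin", "charles darwin")
assert ems("Charles Darwin", ["charles darwin", "Darwin"])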
/playground/workspace/minimal_rnr/utils/inference.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import glob
3 | import json
4 | import os
5 | import regex
6 |
7 |
8 | class SimpleTokenizer(object):
9 | # Copied and modified from https://github.com/facebookresearch/FiD
10 |
11 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
12 | NON_WS = r'[^\p{Z}\p{C}]'
13 |
14 | def __init__(self):
15 | self._regexp = regex.compile(
16 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
17 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
18 | )
19 |
20 | def tokenize(self, text, uncased=False):
21 | matches = [m for m in self._regexp.finditer(text)]
22 | if uncased:
23 | tokens = [m.group().lower() for m in matches]
24 | else:
25 | tokens = [m.group() for m in matches]
26 | return tokens
27 |
28 |
29 | def extend_span_to_full_words(tokenizer, tokens, span):
30 | # Copied and modified from https://github.com/facebookresearch/DPR
31 |
32 | def is_sub_word_id(tokenizer, token_id):
33 | token = tokenizer.convert_ids_to_tokens([token_id])[0]
34 | return token.startswith("##") or token.startswith(" ##")
35 |
36 | start_index, end_index = span
37 | max_len = len(tokens)
38 | while start_index > 0 and is_sub_word_id(tokenizer, tokens[start_index]):
39 | start_index -= 1
40 |
41 | while end_index < max_len - 1 and is_sub_word_id(tokenizer, tokens[end_index + 1]):
42 | end_index += 1
43 |
44 | return start_index, end_index
45 |
46 |
47 | def read_txt(file_path):
48 | data = []
49 | with open(file_path) as f:
50 | for line in f:
51 | ids = [int(x) for x in line.rstrip().split(" ")]
52 | data.append(ids)
53 | return data
54 |
55 |
56 | def get_first_matched_file_path(directory, dataset, file_pattern):
57 | return sorted(glob.glob(os.path.join(directory, dataset, file_pattern)))[0]
58 |
59 |
60 | def normalize_question(question):
61 | question = question.lower()
62 | if question.endswith("?"):
63 | question = question[:-1]
64 | return question
65 |
66 |
67 | def _read_through_csv_qa_file(file_path):
68 | with open(file_path, encoding="utf-8") as f:
69 | reader = csv.reader(f, delimiter='\t')
70 | for row in reader:
71 | question = normalize_question(row[0])
72 | answers = None if len(row) < 2 else row[1]
73 |
74 | if answers is None:
75 | qa_dict = {"question": question}
76 | else:
77 | qa_dict = {"question": question, "answers": answers}
78 | yield qa_dict
79 |
80 |
81 | def _read_through_jsonl_qa_file(file_path):
82 | with open(file_path, encoding="utf-8") as f:
83 | for line in f:
84 | qa_dict = json.loads(line.strip())
85 | qa_dict["question"] = normalize_question(qa_dict["question"])
86 |
87 | if "answer" in qa_dict:
88 | qa_dict["answers"] = qa_dict["answer"]
89 | del qa_dict["answer"]
90 |
91 | yield qa_dict
92 |
93 |
94 | def read_through_qa_file(file_path):
95 | if file_path.endswith(".csv"):
96 | return _read_through_csv_qa_file(file_path)
97 | elif file_path.endswith(".jsonl") or file_path.endswith("jl"):
98 | return _read_through_jsonl_qa_file(file_path)
99 | else:
100 | raise ValueError(f"File {file_path} must be either csv or jsonlines file")
--------------------------------------------------------------------------------
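The readers above accept either a tab-separated .csv of question/answers rows or a .jsonl file with question and answer fields (the latter renamed to answers). A small usage sketch with a hypothetical input path:

import os

from minimal_rnr.utils.inference import normalize_question, read_through_qa_file

# .csv rows look like:    who wrote hamlet<TAB>["William Shakespeare"]
# .jsonl lines look like: {"question": "Who wrote Hamlet?", "answer": ["William Shakespeare"]}
assert normalize_question("Who wrote Hamlet?") == "who wrote hamlet"

input_path = "questions.jsonl"  # hypothetical file
if os.path.exists(input_path):
    for qa_dict in read_through_qa_file(input_path):
        print(qa_dict["question"], qa_dict.get("answers"))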
/playground/workspace/minimal_rnr/utils/inferencer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # minimal-rnr-qa
3 | # Copyright 2021-present NAVER Corp.
4 | # Apache License v2.0
5 |
6 | import abc
7 | import json
8 | from timeit import default_timer as timer
9 |
10 | from minimal_rnr.utils.inference import read_txt, get_first_matched_file_path, read_through_qa_file, \
11 | extend_span_to_full_words
12 | from minimal_rnr.utils.logger import get_logger
13 |
14 |
15 | class MinimalRnR(object, metaclass=abc.ABCMeta):
16 | def __init__(self, args):
17 | self.logger = get_logger("minimal-rnr-qa")
18 |
19 | titles_file = get_first_matched_file_path(args.resources_path, args.dataset, "*.titles.txt")
20 | self.logger.info(f"Loading titles from {titles_file}...")
21 | self.all_titles = read_txt(titles_file)
22 |
23 | docs_file = get_first_matched_file_path(args.resources_path, args.dataset, "*.docs.txt")
24 | self.logger.info(f"Loading docs from {docs_file}...")
25 | self.all_docs = read_txt(docs_file)
26 |
27 | if args.use_faiss_index:
28 | import faiss
29 | index_file = get_first_matched_file_path(args.resources_path, args.dataset, "*.index")
30 | self.logger.info(f"Loading index from {index_file}...")
31 | self.index = faiss.read_index(index_file)
32 |
33 | import numpy as np
34 | self.np = np
35 | else:
36 | self.index = None
37 |
38 |         self.tokenizer = None  # must be overridden
39 |
40 | self.max_retriever_input_len = 256
41 | self.max_reader_input_len = 350
42 | self.max_answer_len = 10
43 | self.num_contexts = 10
44 | self.num_passage_answer_candidates = 5
45 |
46 | @abc.abstractmethod
47 | def get_question_encoding(self, question):
48 | pass
49 |
50 | @abc.abstractmethod
51 | def get_retriever_output(self, question, top_k):
52 | pass
53 |
54 | @abc.abstractmethod
55 | def get_reader_output(self, reader_input):
56 | pass
57 |
58 | @abc.abstractmethod
59 | def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits, attn_mask, passage_score_weight):
60 | pass
61 |
62 | def maybe_tensor_to_list(self, tensor):
63 | return tensor
64 |
65 | def get_inference_api(self):
66 | def api(question, top_k, passage_score_weight):
67 | return self.predict_answer(question, top_k, passage_score_weight, return_context=True)
68 |
69 | return api
70 |
71 | def inference_on_file(self, input_path, output_path, top_k=50, passage_score_weight=None):
72 | num_correct = 0
73 | total = 0
74 |
75 | evaluation_start = timer()
76 | with open(output_path, "w", encoding="utf-8") as f:
77 | for qa_dict in read_through_qa_file(input_path):
78 | start = timer()
79 | question = qa_dict["question"]
80 | prediction = self.predict_answer(question, top_k, passage_score_weight)
81 | qa_dict["prediction"] = prediction
82 |
83 | if "answers" in qa_dict:
84 | from minimal_rnr.utils.evaluation import ems
85 | correct = ems(prediction, qa_dict["answers"])
86 | qa_dict["correct"] = correct
87 |
88 | total += 1
89 | num_correct += int(correct)
90 |
91 | f.write(json.dumps(qa_dict, ensure_ascii=False) + "\n")
92 | self.logger.info(str(qa_dict) + " (%.4fs)" % (timer() - start))
93 |
94 | if total > 0:
95 | self.logger.info(f"EM: {100 * num_correct / total:.4f} ({num_correct} / {total})")
96 | self.logger.info(f"Evaluation took {timer() - evaluation_start:.4f}s.")
97 |
98 | def predict_answer(self, question, top_k=50, passage_score_weight=None, return_context=False):
99 | # retrieve
100 | start = timer()
101 | if self.index:
102 | retrieved_doc_ids = self._get_retriever_output_from_faiss_index(question, top_k)
103 | else:
104 | retrieved_doc_ids = self.get_retriever_output(question, top_k)
105 |
106 | self.logger.info(f" retrieve: {timer() - start:.4f}s")
107 |
108 | start = timer()
109 | title_doc_dict = self._get_title_doc_dict(retrieved_doc_ids)
110 | reader_input = self._get_reader_input(question, title_doc_dict)
111 | self.logger.info(f" convert: {timer() - start:.4f}s")
112 |
113 | # read
114 | start = timer()
115 | reader_output = self.get_reader_output(reader_input)
116 | self.logger.info(f" read: {timer() - start:.4f}s")
117 |
118 | start = timer()
119 | if passage_score_weight is not None:
120 | answer = self._get_answer_deep(reader_input, reader_output, title_doc_dict["titles"], passage_score_weight, return_context=return_context)
121 | else:
122 | answer = self._get_answer_greedy(reader_input, reader_output, title_doc_dict["titles"], return_context=return_context)
123 | self.logger.info(f" search: {timer() - start:.4f}s")
124 | return answer
125 |
126 | def _get_retriever_output_from_faiss_index(self, question, top_k):
127 | start = timer()
128 | question_encoding = self.get_question_encoding(question)
129 | self.logger.info(f" * encode: {timer() - start:.4f}s")
130 |
131 | start = timer()
132 | if not isinstance(question_encoding, self.np.ndarray): # tfserving_faiss
133 | question_encoding = self.np.asarray(question_encoding, dtype=self.np.float32)
134 | _, np_retrieved_doc_ids = self.index.search(question_encoding, top_k)
135 | self.logger.info(f" * faiss search: {timer() - start:.4f}s")
136 | retrieved_doc_ids = np_retrieved_doc_ids[0]
137 | return retrieved_doc_ids
138 |
139 | def _get_title_doc_dict(self, retrieved_doc_ids):
140 | retrieved_titles = []
141 | retrieved_docs = []
142 |
143 | for i in retrieved_doc_ids:
144 | retrieved_titles.append(self.all_titles[i])
145 | retrieved_docs.append(self.all_docs[i])
146 |
147 | return {
148 | "titles": retrieved_titles,
149 | "docs": retrieved_docs,
150 | }
151 |
152 | def _get_reader_input(self, question_str, title_doc_dict):
153 | input_ids = []
154 | attention_mask = []
155 |
156 | retrieved_titles = title_doc_dict["titles"]
157 | retrieved_docs = title_doc_dict["docs"]
158 |
159 | question = self.tokenizer.encode(question_str, max_length=self.max_retriever_input_len, truncation=True)
160 |
161 | # concat inputs
162 | for title, doc in zip(retrieved_titles, retrieved_docs):
163 | concat = question + title + [self.tokenizer.sep_token_id] + doc
164 | concat = concat[:self.max_reader_input_len]
165 | input_ids.append(concat)
166 | max_len = max(len(ids) for ids in input_ids)
167 |
168 | # pad inputs
169 | for i in range(len(input_ids)):
170 | padding = [self.tokenizer.pad_token_id] * (max_len - len(input_ids[i]))
171 | attention_mask.append([1] * len(input_ids[i]) + padding)
172 | input_ids[i] = input_ids[i] + padding
173 |
174 | return {
175 | "input_ids": input_ids,
176 | "attention_mask": attention_mask,
177 | }
178 |
179 | def _get_answer_deep(self, reader_input, reader_output, retrieved_titles, passage_score_weight=None, return_context=False):
180 | input_ids = reader_input["input_ids"]
181 | attn_mask = reader_input["attention_mask"]
182 |
183 | _start_logits = reader_output["start_logits"]
184 | _end_logits = reader_output["end_logits"]
185 | _relevance_logits = reader_output["relevance_logits"]
186 |
187 | # weighted scores
188 | start_logits = self.get_passage_score_weighted_answer_token_logits(_start_logits, _relevance_logits, attn_mask, passage_score_weight)
189 | end_logits = self.get_passage_score_weighted_answer_token_logits(_end_logits, _relevance_logits, attn_mask, passage_score_weight)
190 |
191 | start_logits = self.maybe_tensor_to_list(start_logits)
192 | end_logits = self.maybe_tensor_to_list(end_logits)
193 |
194 | candidate_answers = []
195 | candidate_contexts = []
196 |
197 | for passage_idx in range(len(input_ids)):
198 | sequence_len = sum(id_ != 0 for id_ in input_ids[passage_idx])
199 | passage_offset = input_ids[passage_idx].index(self.tokenizer.sep_token_id) + 1
200 | title_passage_ids = input_ids[passage_idx][passage_offset:sequence_len]
201 |
202 | p_start_logits = start_logits[passage_idx][passage_offset:sequence_len]
203 | p_end_logits = end_logits[passage_idx][passage_offset:sequence_len]
204 |
205 | scores = self._get_spans_sorted_with_scores(p_start_logits, p_end_logits)
206 |
207 | chosen_span_intervals = []
208 | p_candidate_answers = []
209 | p_candidate_contexts = []
210 |
211 | for (start_index, end_index), score in scores:
212 | ret = self._get_answer_and_passage(start_index, end_index, chosen_span_intervals, title_passage_ids)
213 | if not ret:
214 | continue
215 | else:
216 |                     answer, passage, start_index, end_index = ret
217 | title = retrieved_titles[passage_idx]
218 |
219 | if not return_context:
220 | p_candidate_answers.append((answer, score))
221 | else:
222 | context = (self.tokenizer.decode(passage[:start_index])
223 | + " " + answer + " "
224 | + self.tokenizer.decode(passage[end_index + 1:]))
225 | p_candidate_contexts.append(({"title": self.tokenizer.decode(title),
226 | "context": context}, score))
227 |
228 | if max(len(p_candidate_answers), len(p_candidate_contexts)) == self.num_passage_answer_candidates:
229 | break
230 |
231 | if p_candidate_answers:
232 | candidate_answers.extend(p_candidate_answers)
233 | if p_candidate_contexts:
234 | candidate_contexts.extend(p_candidate_contexts)
235 |
236 | if not return_context:
237 | sorted_candidate_answers = sorted(candidate_answers, key=lambda x: x[1], reverse=True)
238 | return sorted_candidate_answers[0][0].strip()
239 | else:
240 | sorted_candidate_contexts = sorted(candidate_contexts, key=lambda x: x[1], reverse=True)
241 | return [context[0] for context in sorted_candidate_contexts][:self.num_contexts]
242 |
243 | def _get_answer_greedy(self, reader_input, reader_output, retrieved_titles, return_context=False):
244 | start_logits = reader_output["start_logits"]
245 | end_logits = reader_output["end_logits"]
246 | relevance_logits = reader_output["relevance_logits"]
247 |
248 | start_logits = self.maybe_tensor_to_list(start_logits)
249 | end_logits = self.maybe_tensor_to_list(end_logits)
250 | relevance_logits = self.maybe_tensor_to_list(relevance_logits)
251 |
252 | max_answer_length = 10
253 | input_ids = reader_input["input_ids"]
254 |
255 | top_passage_idx = max(enumerate(relevance_logits), key=lambda x: x[1])[0]
256 |
257 | sequence_len = sum(id_ != 0 for id_ in input_ids[top_passage_idx])
258 | passage_offset = input_ids[top_passage_idx].index(self.tokenizer.sep_token_id) + 1
259 | title_passage_ids = input_ids[top_passage_idx][passage_offset:sequence_len]
260 | p_start_logits = start_logits[top_passage_idx][passage_offset:sequence_len]
261 | p_end_logits = end_logits[top_passage_idx][passage_offset:sequence_len]
262 |
263 | scores = self._get_spans_sorted_with_scores(p_start_logits, p_end_logits)
264 | chosen_span_intervals = []
265 |
266 | for (start_index, end_index), score in scores:
267 | assert start_index <= end_index
268 | length = end_index - start_index + 1
269 | assert length <= max_answer_length
270 |
271 | ret = self._get_answer_and_passage(start_index, end_index, chosen_span_intervals, title_passage_ids)
272 | if not ret:
273 | continue
274 | answer, passage, start_index, end_index = ret
275 | title = retrieved_titles[top_passage_idx]
276 |
277 | if not return_context:
278 | return answer.strip()
279 | else:
280 | context = (self.tokenizer.decode(passage[:start_index])
281 | + " " + answer + " "
282 | + self.tokenizer.decode(passage[end_index + 1:]))
283 | return [{"title": self.tokenizer.decode(title), "context": context}]
284 |
285 | def _get_spans_sorted_with_scores(self, p_start_logits, p_end_logits):
286 | scores = []
287 | for (i, s) in enumerate(p_start_logits):
288 | for (j, e) in enumerate(p_end_logits[i:i + self.max_answer_len]):
289 | scores.append(((i, i + j), s + e))
290 | scores = sorted(scores, key=lambda x: x[1], reverse=True)
291 | return scores
292 |
293 | def _get_answer_and_passage(self, start_index, end_index, chosen_span_intervals, title_passage_ids):
294 | assert start_index <= end_index
295 | length = end_index - start_index + 1
296 | assert length <= self.max_answer_len
297 |
298 | if any([start_index <= prev_start_index <= prev_end_index <= end_index or
299 | prev_start_index <= start_index <= end_index <= prev_end_index
300 | for (prev_start_index, prev_end_index) in chosen_span_intervals]):
301 | return
302 |
303 | start_index, end_index = extend_span_to_full_words(self.tokenizer, title_passage_ids, (start_index, end_index))
304 |
305 | title_offset = title_passage_ids.index(self.tokenizer.sep_token_id) + 1
306 |
307 | passage = title_passage_ids[title_offset:]
308 | start_index -= title_offset
309 | end_index -= title_offset
310 |
311 | answer = self.tokenizer.decode(passage[start_index:end_index + 1])
312 |
313 | return answer, passage, start_index, end_index
314 |
--------------------------------------------------------------------------------
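To tie the pieces together: a backend subclasses MinimalRnR and supplies the four abstract hooks, while the base class above handles resource loading, reader-input construction, and answer-span search. The stub below only illustrates the contract and its return shapes; a real subclass (like the PyTorch or TF Serving inferencers) also sets self.tokenizer and returns actual model outputs.

from minimal_rnr.utils.inferencer import MinimalRnR

class DummyMinimalRnR(MinimalRnR):
    # Illustrative stub of the backend contract; values are placeholders, not real model outputs.

    def get_question_encoding(self, question):
        return [[0.0] * 128]                  # 1 x emb_size question vector (for FAISS search)

    def get_retriever_output(self, question, top_k):
        return list(range(top_k))             # indices into self.all_titles / self.all_docs

    def get_reader_output(self, reader_input):
        n = len(reader_input["input_ids"])
        m = len(reader_input["input_ids"][0])
        zeros = [[0.0] * m for _ in range(n)]
        return {"start_logits": zeros, "end_logits": zeros, "relevance_logits": [0.0] * n}

    def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits,
                                                       attn_mask, passage_score_weight):
        return token_logits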
/playground/workspace/minimal_rnr/utils/logger.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # minimal-rnr-qa
3 | # Copyright 2021-present NAVER Corp.
4 | # Apache License v2.0
5 |
6 | import sys
7 | import logging
8 |
9 |
10 | def get_logger(name):
11 | if logging.getLogger(name).hasHandlers():
12 | return logging.getLogger(name)
13 |
14 | # initialization
15 | formatter = logging.Formatter(fmt="[MinR&R %(asctime)s] %(message)s",
16 | datefmt="%Y-%m-%d %H:%M:%S")
17 | handler = logging.StreamHandler(stream=sys.stdout)
18 | handler.setFormatter(formatter)
19 |
20 | logger = logging.getLogger(name)
21 | logger.setLevel(logging.DEBUG)
22 | logger.addHandler(handler)
23 | logger.propagate = False
24 | return logger
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/utils/static/examples.txt:
--------------------------------------------------------------------------------
1 | What is the purpose of life?
2 | What is Naver known for?
3 | Famous musician in South Korea
4 | Where can you find water in desert?
5 | Name three famous writers
6 | Why is regular exercise important?
7 | Which city is famous for coffee?
8 | who is the villain in Harry Potter?
9 | What does a developer do?
10 | When is National Liberation Day of Korea?
11 | What is another term for x-ray imaging?
12 | When did the movie Avengers come out?
13 | What is water consisted of?
14 | Who is the director of the movie Interstellar?
15 | In what year did Nikola Tesla emigrate to the United States?
16 | Who coined the term Deep Learning?
17 | The most beautiful poem in the world
18 | What is Machine Learning?
19 | Who is the main character in Frozen?
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/utils/static/files/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/utils/static/files/icon.png
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/utils/static/files/popper.min.js:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/utils/static/files/popper.min.js
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/utils/static/index.html:
--------------------------------------------------------------------------------
(Demo front-end markup not reproduced; only text fragments survive: the page title "Minimal R&R QA", a "Latency:" readout, and the corpus note "5.8% of Wikipedia EN (Dec. 20, 2018 dump)".)
--------------------------------------------------------------------------------