├── LICENSE ├── NOTICE ├── README.md ├── competition ├── Dockerfile ├── build.sh ├── submission.sh └── workspace │ ├── __init__.py │ ├── bert_tokenizer │ ├── __init__.py │ ├── file_utils.py │ ├── mobilebert-uncased │ │ ├── tokenizer.json │ │ └── vocab.txt │ ├── tokenization_bert.py │ ├── tokenization_utils.py │ └── tokenization_utils_base.py │ └── infer.py └── playground ├── build.sh ├── configs ├── Dockerfile-pytorch ├── Dockerfile-tfserving ├── requirements-pytorch.txt ├── requirements-tfserving.txt └── requirements-tfserving_faiss.txt ├── entrypoint.sh └── workspace ├── minimal_rnr ├── __init__.py ├── pytorch │ ├── __init__.py │ ├── inferencer.py │ └── model.py ├── tfserving │ ├── __init__.py │ ├── bert_tokenizer │ │ ├── __init__.py │ │ ├── file_utils.py │ │ ├── mobilebert-uncased │ │ │ ├── tokenizer.json │ │ │ └── vocab.txt │ │ ├── tokenization_bert.py │ │ ├── tokenization_utils.py │ │ └── tokenization_utils_base.py │ └── inferencer.py └── utils │ ├── __init__.py │ ├── demo.py │ ├── evaluation.py │ ├── inference.py │ ├── inferencer.py │ ├── logger.py │ └── static │ ├── examples.txt │ ├── files │ ├── all.js │ ├── bootstrap.min.js │ ├── icon.png │ ├── jquery-3.3.1.min.js │ ├── popper.min.js │ └── style.css │ └── index.html ├── run_pt_demo.py ├── run_pt_inference.py ├── run_tf_demo.py └── run_tf_inference.py /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "{}" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2021-present NAVER Corp. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Designing a Minimal Retrieve-and-Read System for Open-Domain Question Answering 2 | 3 | ## Abstract 4 | In open-domain question answering (QA), retrieve-and-read mechanism has the inherent benefit of interpretability and the easiness of adding, removing, or editing knowledge compared to the parametric approaches of closed-book QA models. 
However, it is also known to suffer from its large storage footprint due to its document corpus and index. Here, we discuss several orthogonal strategies to drastically reduce the footprint of a retrieve-and-read open-domain QA system by up to 160x. Our results indicate that retrieve-and-read can be a viable option even in a highly constrained serving environment such as edge devices, as we show that it can achieve better accuracy than a purely parametric model with comparable docker-level system size. 5 | 6 | - [Paper](https://arxiv.org/abs/2104.07242) (To appear in [NAACL 2021](https://2021.naacl.org/)) 7 | - Authors: [Sohee Yang](https://soheeyang.github.io/) and [Minjoon Seo](http://seominjoon.github.io/) 8 | - [Live Web Demo](http://52.156.155.214:8890) 9 | - The demo normally takes **1-3 seconds** for inference in the default setting, but occasionally becomes slow depending on the server condition. 10 | - BibTeX: 11 | 12 | ``` 13 | @inproceedings{mininalrnr, 14 | title={Designing a Minimal Retrieve-and-Read System for Open-Domain Question Answering}, 15 | author={Yang, Sohee and Seo, Minjoon}, 16 | booktitle={NAACL}, 17 | year={2021} 18 | } 19 | ``` 20 | 21 | This repository contains the code of the **Minimal Retrieve & Read QA System** that ranked first in the human (manual) evaluation and second in the automatic evaluation on the "Systems Under 500Mb Track" of the [NeurIPS 2020 EfficientQA competition](https://efficientqa.github.io/). 22 | 23 | ## Web Demo 24 | ![image](https://user-images.githubusercontent.com/28291528/114962230-4a672c00-9ea5-11eb-8634-9b563c9d0d9a.png) 25 | 26 | You can play with the QA system in the [Live Web Demo](http://52.156.155.214:8890). You can also dynamically change the inference setting by controlling the values of `top_k` and `passage_score_weight`. (The demo normally takes **1-3 seconds** for inference in the default setting, but occasionally becomes slow depending on the server condition.) 27 | - `top_k`: sets the number of passages to retrieve and pass to the reader. The value must be a positive integer. The default value is 50. For the live web demo, the value is limited to the range [1, 100] to prevent freezing from reading too many passages. 28 | - `passage_score_weight`: 29 | - When the default value `null` is used, the answer is extracted only from the passage with the highest ranking score. This is the setting used in DPR. 30 | - If a float value (λ ∈ [0, 1] highly recommended) is given, multiple passages are considered to select the answer. Specifically, answer spans from multiple passages are scored using the weighted sum of passage ranking scores and answer span scores. The weighted sum is calculated as (1 - λ)(log P_start + log P_end) + 2λ log P_rank. Please refer to the paper for more details. 31 | 32 | If you have Docker installed, you can also run the web demo on your local machine in five minutes using [this command](#b-local-web-demo-quickstart). 33 | 34 | ## A. Introduction 35 | 36 | This repository contains the code for an interactive web demo, code for inference on the questions in a file (and evaluation of the answers), links to the model graphs/checkpoints (models), links to the index and preprocessed corpus files (resources), and links to pre-built docker images. 37 | 38 | - The `competition` directory contains the code to build and run the minimal-sized docker container used for the EfficientQA competition. Typing `du -h /` in the launched container reports 484.68MB as its size. Please see [**E.
Competition Setting: Build & Run**](#e-competition-setting-build--run) for details. 39 | - The `playground` directory contains more practical, refactored code to play with: one can either run a web demo or run inference on a file using models built in different settings. Please see [**D. Playground: Build & Run**](#d-playground-build--run) for details. 40 | 41 | ## B. Local Web Demo Quickstart 42 | 43 | To run the web demo on your local machine, run the following using docker: 44 | 45 | ```bash 46 | docker run \ 47 | -v /etc/localtime:/etc/localtime:ro \ 48 | --oom-kill-disable=true \ 49 | --env MODE=demo \ 50 | -p 10001:10001 \ 51 | soheeyang/minimal-rnr-qa:effqa-tfserving \ 52 | /workspace/entrypoint.sh 53 | ``` 54 | 55 | Then, access [http://localhost:10001](http://localhost:10001) to play with the demo! 56 | 57 | ## C. Pre-Built Docker Images 58 | 59 | Available at [https://hub.docker.com/r/soheeyang/minimal-rnr-qa](https://hub.docker.com/r/soheeyang/minimal-rnr-qa) 60 | - soheeyang/minimal-rnr-qa:$DATASET-$MODEL_TYPE 61 | - $DATASET: [ effqa | nq | trivia ] 62 | - $MODEL_TYPE: [ tfserving | tfserving_faiss | pytorch ] 63 | - soheeyang/minimal-rnr-qa:$DATASET-competition 64 | - $DATASET: [ effqa | nq | trivia ] 65 | - `soheeyang/minimal-rnr-qa:effqa-competition` is the docker image used for the EfficientQA challenge 66 | 67 | The following are descriptions of the options for DATASET and MODEL_TYPE. 68 | - `$DATASET` selects the **model**; the model trained on this dataset is used. The value must be one of the following. 69 | - `effqa` trained on the Natural Questions (NQ) train set, validated on the EfficientQA dev set 70 | - `nq` trained on the NQ train set, validated on the NQ dev set 71 | - `trivia` trained on the TriviaQA (Trivia) train set, validated on the Trivia dev set 72 | - `$MODEL_TYPE` selects the type of the chosen model. The value must be one of the following. 73 | - `tfserving` TensorFlow (TF) graph for TF Serving. The index is fused into the graph to perform efficient passage retrieval without additional library dependencies. This is the setting used in the EfficientQA competition. CPU serving. Smallest system footprint. 74 | - `tfserving_faiss` TF graph for TF Serving, but without the index. It installs and makes use of FAISS to perform passage retrieval. CPU serving. 75 | - `pytorch` PyTorch checkpoint. It installs and makes use of FAISS to perform passage retrieval. The model code can be found at `playground/workspace/minimal_rnr/pytorch/model.py`. Supports serving on both CPU & GPU. Largest system footprint. 76 | 77 | ## D. Playground: Build & Run 78 | 79 | You can skip steps 1 and 2 if you use the [pre-built docker images](#c-pre-built-docker-images). 80 | 81 | ### 1. Download the code and necessary resources 82 | 83 | ```bash 84 | git clone https://github.com/clovaai/minimal-rnr-qa.git 85 | cd minimal-rnr-qa/playground 86 | 87 | wget https://dl.dropboxusercontent.com/s/l7034dttyp4bbf2/minrnr_playground_models.tar.gz 88 | wget https://dl.dropboxusercontent.com/s/51g36ytprbcl3mv/minrnr_playground_resources.tar.gz 89 | 90 | tar xvf minrnr_playground_models.tar.gz 91 | tar xvf minrnr_playground_resources.tar.gz 92 | ``` 93 | 94 | ### 2. Build docker image 95 | 96 | ```bash 97 | # inside minimal-rnr-qa/playground 98 | 99 | DATASET=effqa 100 | MODEL_TYPE=tfserving 101 | 102 | chmod a+x ./build.sh 103 | ./build.sh $DATASET $MODEL_TYPE 104 | ``` 105 | - This command builds a docker image tagged as `minimal-rnr-qa:$DATASET-$MODEL_TYPE`.
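Both run modes below accept the same two knobs, `TOP_K` (`top_k`) and `PASSAGE_W` (`passage_score_weight`). The snippet below is a minimal Python sketch of how a λ value combines the reader's span scores with the retriever's passage ranking score, following the weighted-sum formula described in the Web Demo section; the candidate-span dictionary layout is hypothetical and this is not the repository's actual implementation (see `playground/workspace/minimal_rnr/*/inferencer.py` for that).

```python
def select_answer(candidates, passage_score_weight=None):
    """Pick the best answer span from a list of candidate spans.

    Each candidate is assumed to look like (hypothetical layout):
    {"answer": str, "rank_log_prob": float,
     "start_log_prob": float, "end_log_prob": float}
    """
    if passage_score_weight is None:
        # DPR-style: keep only spans from the top-ranked passage.
        best_rank = max(c["rank_log_prob"] for c in candidates)
        candidates = [c for c in candidates if c["rank_log_prob"] == best_rank]
        return max(candidates, key=lambda c: c["start_log_prob"] + c["end_log_prob"])

    lam = passage_score_weight  # λ ∈ [0, 1]

    def score(c):
        # (1 - λ)(log P_start + log P_end) + 2λ log P_rank
        return (1 - lam) * (c["start_log_prob"] + c["end_log_prob"]) + 2 * lam * c["rank_log_prob"]

    return max(candidates, key=score)
```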
106 | 107 | ### 3-1. Run web demo 108 | 109 | ```bash 110 | docker run \ 111 | -v /etc/localtime:/etc/localtime:ro \ 112 | --oom-kill-disable=true \ 113 | --env MODE=demo \ 114 | -p 10001:10001 \ 115 | minimal-rnr-qa:$DATASET-$MODEL_TYPE \ 116 | /workspace/entrypoint.sh 117 | ``` 118 | 119 | - `-v /etc/localtime:/etc/localtime:ro` sets the timezone of the container to match the host's. 120 | - `--oom-kill-disable=true` prevents the container from being killed by the OOM killer. 121 | - `--env MODE=demo` [REQUIRED] runs a **web demo**. 122 | - `-p $HOST_PORT:10001` [REQUIRED] sets the port of the web page; it connects port 10001 of the container to a port of the host. 123 | - `minimal-rnr-qa:$DATASET-$MODEL_TYPE` [REQUIRED] Tag of the built image. 124 | - `/workspace/entrypoint.sh` [REQUIRED] Entrypoint of the container. 125 | 126 | ### 3-2. Run inference on a file 127 | 128 | #### Download input data 129 | The input files for the EfficientQA dev set, NQ dev & test sets, and Trivia dev & test sets can be downloaded at once. 130 | ```bash 131 | INPUT_DIR=/tmp/minimal-rnr-qa 132 | OUTPUT_DIR=/tmp/minimal-rnr-qa 133 | 134 | mkdir -p $INPUT_DIR 135 | mkdir -p $OUTPUT_DIR 136 | 137 | wget -P $INPUT_DIR https://dl.dropboxusercontent.com/s/juh12j1z0ct3zeu/minrnr_datasets.tar.gz 138 | tar xvf $INPUT_DIR/minrnr_datasets.tar.gz -C $INPUT_DIR --strip-components=1 139 | ``` 140 | 141 | #### Run inference 142 | ```bash 143 | INPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1.jsonl 144 | OUTPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1-predictions.jsonl 145 | TOP_K=80 146 | PASSAGE_W=0.8 147 | 148 | docker run \ 149 | -v /etc/localtime:/etc/localtime:ro \ 150 | --oom-kill-disable=true \ 151 | -v $INPUT_DIR:/input \ 152 | -v $OUTPUT_DIR:/output \ 153 | --env MODE=file \ 154 | --env TOP_K=$TOP_K \ 155 | --env PASSAGE_W=$PASSAGE_W \ 156 | minimal-rnr-qa:$DATASET-$MODEL_TYPE \ 157 | /workspace/entrypoint.sh \ 158 | /input/$INPUT_FILE_NAME \ 159 | /output/$OUTPUT_FILE_NAME 160 | ``` 161 | 162 | - `-v /etc/localtime:/etc/localtime:ro` sets the timezone of the container to match the host's. 163 | - `--oom-kill-disable=true` prevents the container from being killed by the OOM killer. 164 | - `-v $INPUT_DIR:/input` [REQUIRED] maps `$INPUT_DIR` of the host to `/input` in the container, where the data is read from. This directory must contain the file to run inference on. 165 | - `-v $OUTPUT_DIR:/output` [OPTIONAL] maps `$OUTPUT_DIR` of the host to `/output` in the container, where the prediction result file is written. If not specified, the output prediction file is written only inside the container. 166 | - `--env MODE=file` [REQUIRED] runs **inference on the given input file** and outputs the predictions. 167 | - `--env TOP_K=$INT_VALUE` [OPTIONAL] sets the number of passages to retrieve and pass to the reader. It must be a positive integer. The default value is 50. 168 | - `--env PASSAGE_W=$FLOAT_VALUE` [OPTIONAL] 169 | - If the option is not used (the default) or `null` is given as the value, the answer is extracted only from the passage with the highest ranking score. This is the setting used in DPR. 170 | - If a value is given, multiple passages are considered to select the answer. Specifically, answer spans from multiple passages are scored using the weighted sum of passage ranking scores and answer span scores. The given value for this option must be λ ∈ [0, 1], and the weighted sum is calculated as (1 - λ)(log P_start + log P_end) + 2λ log P_rank. This value may be tuned on the validation set to slightly raise the end-to-end question answering accuracy.
- `minimal-rnr-qa:$DATASET-$MODEL_TYPE` [REQUIRED] Tag of the built image. 172 | - `/workspace/entrypoint.sh` [REQUIRED] Entrypoint of the container. 173 | - `/input/$INPUT_FILE_NAME` [REQUIRED] Name of the file to run inference on. CSV or JSON Lines files are supported. 174 | - CSV files must consist of rows of question strings or `question\t["answer_1", ..., "answer_n"]`. 175 | - JSON Lines files must consist of rows of `{"question": ...}`, `{"question": ..., "answers": ...}`, or `{"question": ..., "answer": ...}`. 176 | - If answers exist, the Exact Match (EM) score is calculated and reported at the end of the inference. 177 | - `/output/$OUTPUT_FILE_NAME` [REQUIRED] Name of the output prediction result file. The file is in JSON Lines format. Please note that even if "answer" is given as the key for answers in the input file, it changes to "answers" in the prediction file for consistency and easier evaluation. 178 | 179 | ## E. Competition Setting: Build & Run 180 | 181 | You can skip steps 1 and 2 if you use the [pre-built docker images](#c-pre-built-docker-images). 182 | 183 | ### 1. Download the code and necessary resources 184 | 185 | ```bash 186 | git clone https://github.com/clovaai/minimal-rnr-qa.git 187 | cd minimal-rnr-qa/competition 188 | 189 | wget https://dl.dropboxusercontent.com/s/s5fa4rgf48bhhkb/minrnr_competition_resources.tar.gz 190 | wget https://dl.dropboxusercontent.com/s/utwzozvuret1sdo/minrnr_competition_models.tar.gz 191 | 192 | tar xvf minrnr_competition_models.tar.gz 193 | tar xvf minrnr_competition_resources.tar.gz 194 | ``` 195 | 196 | ### 2. Build docker image 197 | ```bash 198 | # inside minimal-rnr-qa/competition 199 | 200 | DATASET=effqa 201 | 202 | chmod a+x ./build.sh 203 | ./build.sh $DATASET 204 | ``` 205 | - Values for `$DATASET` 206 | - `effqa`: the model used in the challenge (Section 3 in the paper) 207 | - `nq`: trained on Natural Questions (Appendix A.5 in the paper) 208 | - `trivia`: trained on TriviaQA (Appendix A.5 in the paper) 209 | - This command builds a docker image tagged as `minimal-rnr-qa:$DATASET-competition`. 210 | 211 | ### 3. Prepare data (same as the above) 212 | 213 | The input files for the EfficientQA dev set, NQ dev & test sets, and Trivia dev & test sets can be downloaded at once. 214 | ```bash 215 | INPUT_DIR=/tmp/minimal-rnr-qa 216 | OUTPUT_DIR=/tmp/minimal-rnr-qa 217 | 218 | mkdir -p $INPUT_DIR 219 | mkdir -p $OUTPUT_DIR 220 | 221 | wget -P $INPUT_DIR https://dl.dropboxusercontent.com/s/juh12j1z0ct3zeu/minrnr_datasets.tar.gz 222 | tar xvf $INPUT_DIR/minrnr_datasets.tar.gz -C $INPUT_DIR --strip-components=1 223 | ``` 224 | 225 | ### 4. Run 226 | 227 | ```bash 228 | # The setting used for EfficientQA submission 229 | 230 | INPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1.jsonl 231 | OUTPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1-predictions.jsonl 232 | TOP_K=80 233 | PASSAGE_W=0.8 234 | 235 | docker run \ 236 | -v ${INPUT_DIR}:/input \ 237 | -v ${OUTPUT_DIR}:/output \ 238 | --env TOP_K=$TOP_K \ 239 | --env PASSAGE_W=$PASSAGE_W \ 240 | --network="none" \ 241 | --oom-kill-disable=true \ 242 | minimal-rnr-qa:$DATASET-competition \ 243 | /submission.sh \ 244 | /input/$INPUT_FILE_NAME \ 245 | /output/$OUTPUT_FILE_NAME 246 | ``` 247 | 248 | - Below are the parameters to reproduce each of the results of the last row in Table 3 (in the Appendix of the paper).
249 | - EfficientQA dev 250 | - DATASET=effqa / TOP_K=80 / PASSAGE_W=null 251 | - INPUT_FILE_NAME=NQ-open.efficientqa.dev.1.1.jsonl (from [this link](https://github.com/google-research-datasets/natural-questions/blob/master/nq_open/NQ-open.efficientqa.dev.1.1.jsonl)) 252 | - While 34.33 is reported in the paper, the value changed to 34.55 after we rebuilt the TensorFlow graph w.r.t. refactoring. The model supported here is the latter one. 253 | - NQ dev 254 | - DATASET=nq / TOP_K=100 / PASSAGE_W=null 255 | - INPUT_FILE_NAME=nq-dev.jsonl 256 | - NQ test 257 | - DATASET=nq / TOP_K=90 / PASSAGE_W=null 258 | - INPUT_FILE_NAME=nq-test.jsonl 259 | - Trivia dev 260 | - DATASET=trivia / TOP_K=100 / PASSAGE_W=null 261 | - INPUT_FILE_NAME=trivia-dev.jsonl 262 | - Trivia test 263 | - DATASET=trivia / TOP_K=100 / PASSAGE_W=null 264 | - INPUT_FILE_NAME=trivia-test.jsonl 265 | 266 | ## F. License 267 | 268 | ``` 269 | Copyright 2021-present NAVER Corp. 270 | 271 | Licensed under the Apache License, Version 2.0 (the "License"); 272 | you may not use this file except in compliance with the License. 273 | You may obtain a copy of the License at 274 | 275 | http://www.apache.org/licenses/LICENSE-2.0 276 | 277 | Unless required by applicable law or agreed to in writing, software 278 | distributed under the License is distributed on an "AS IS" BASIS, 279 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 280 | See the License for the specific language governing permissions and 281 | limitations under the License. 282 | ``` 283 | -------------------------------------------------------------------------------- /competition/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/serving:2.3.0 as build_image 2 | FROM python:3.6.11-slim-buster 3 | 4 | ARG DATASET 5 | 6 | # Install TF Serving pkg. 7 | COPY --from=build_image /usr/bin/tensorflow_model_server /usr/bin/tensorflow_model_server 8 | 9 | # Reduce size 10 | RUN apt-get update && apt-get install -y --no-install-recommends lzma bzip2 && apt-get clean && rm -rf /var/lib/apt/lists/* 11 | RUN bzip2 /usr/bin/tensorflow_model_server 12 | 13 | # Install python packages. 14 | RUN pip install --no-cache-dir absl-py requests 15 | RUN pip uninstall pip -y 16 | RUN gzip -r /usr/local/lib/python3.6 17 | 18 | # Delete unnecessary files 19 | RUN rm -rf /root/* && rm /usr/bin/perl* && rm -rf /usr/lib/x86_64-linux-gnu/perl* && rm -rf /var/cache 20 | 21 | COPY models/$DATASET /models/minimal-rnr-qa/1 22 | COPY resources/$DATASET /resources 23 | COPY workspace /workspace 24 | COPY submission.sh / 25 | 26 | RUN chmod a+x /submission.sh 27 | 28 | WORKDIR / 29 | -------------------------------------------------------------------------------- /competition/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # coding=utf-8 3 | # minimal-rnr-qa 4 | # Copyright 2021-present NAVER Corp. 5 | # Apache License v2.0 6 | 7 | if [ "$#" -ne 1 ]; then 8 | echo "Usage: ./build.sh DATASET" 9 | echo "DATASET: effqa | nq | trivia" 10 | exit 0 11 | fi 12 | 13 | DATASET=$1 # ["effqa", "nq", "trivia"] 14 | 15 | docker build -f Dockerfile \ 16 | --build-arg DATASET=$DATASET \ 17 | . 
-t minimal-rnr-qa:$DATASET-competition 18 | 19 | echo "minimal-rnr-qa:$DATASET"-competition -------------------------------------------------------------------------------- /competition/submission.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # coding=utf-8 3 | # minimal-rnr-qa 4 | # Copyright 2021-present NAVER Corp. 5 | # Apache License v2.0 6 | 7 | URL="http://127.0.0.1:8501/v1/models/minimal-rnr-qa:predict" 8 | 9 | # decompress tensorflow_model_server bzip2 10 | bzip2 -d /usr/bin/tensorflow_model_server.bz2 11 | chmod a+x /usr/bin/tensorflow_model_server 12 | 13 | # decompress python 14 | gzip -dr /usr/local/lib/python3.6 15 | 16 | # decompress resources lzma 17 | for f in /resources/*.lzma; do lzma -dc $f > "${f%.*}"; done 18 | for f in /resources/*.lzma; do rm $f; done 19 | 20 | # decompress models bzip2 21 | bzip2 -d /models/minimal-rnr-qa/1/variables/variables.data-00000-of-00001.bz2 22 | bzip2 -d /models/minimal-rnr-qa/1/saved_model.pb.bz2 23 | 24 | # server 25 | /usr/bin/tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=minimal-rnr-qa --model_base_path=/models/minimal-rnr-qa & 26 | 27 | # client 28 | cd /workspace 29 | python -u infer.py --url $URL --resources_path /resources --input_path $1 --output_path $2 `if [[ -n "${TOP_K}" ]]; then echo --top_k $TOP_K; fi` `if [[ -n "${PASSAGE_W}" ]]; then echo --passage_score_weight $PASSAGE_W; fi` 30 | -------------------------------------------------------------------------------- /competition/workspace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/competition/workspace/__init__.py -------------------------------------------------------------------------------- /competition/workspace/bert_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenization_bert import BertTokenizer 2 | 3 | __all__ = ['BertTokenizer'] -------------------------------------------------------------------------------- /competition/workspace/bert_tokenizer/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | 7 | import logging 8 | import os 9 | from functools import wraps 10 | from typing import Optional 11 | 12 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 13 | 14 | USE_TF = True 15 | USE_TORCH = False 16 | _torch_available = False 17 | _torch_tpu_available = False 18 | _psutil_available = False 19 | _py3nvml_available = False 20 | _has_apex = False 21 | 22 | try: 23 | import tensorflow as tf 24 | 25 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 26 | _tf_available = True # pylint: disable=invalid-name 27 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 28 | except (ImportError, AssertionError): 29 | _tf_available = False # pylint: disable=invalid-name 30 | 31 | WEIGHTS_NAME = "pytorch_model.bin" 32 | TF2_WEIGHTS_NAME = "tf_model.h5" 33 | TF_WEIGHTS_NAME = "model.ckpt" 34 | CONFIG_NAME = "config.json" 35 | MODEL_CARD_NAME = "modelcard.json" 36 | 37 | 38 | def is_tf_available(): 39 | return _tf_available 40 | 41 | 42 | def cached_path( 43 | url_or_filename, 44 | ) -> Optional[str]: 45 | if os.path.exists(url_or_filename): 46 | output_path = url_or_filename 47 | else: 48 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 49 | 50 | return output_path 51 | 52 | 53 | def tf_required(func): 54 | # Chose a different decorator name than in tests so it's clear they are not the same. 55 | @wraps(func) 56 | def wrapper(*args, **kwargs): 57 | if is_tf_available(): 58 | return func(*args, **kwargs) 59 | else: 60 | raise ImportError(f"Method `{func.__name__}` requires TF.") 61 | 62 | return wrapper -------------------------------------------------------------------------------- /competition/workspace/bert_tokenizer/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | 18 | import collections 19 | import logging 20 | import os 21 | import unicodedata 22 | from typing import List, Optional 23 | 24 | from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | "vocab_file": { 33 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 34 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 35 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 36 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 37 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 38 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 39 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 40 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 41 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 42 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 45 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 46 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", 47 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", 48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", 49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt", 50 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt", 51 | } 52 | } 53 | 54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 55 | "bert-base-uncased": 512, 56 | "bert-large-uncased": 512, 57 | "bert-base-cased": 512, 58 | "bert-large-cased": 512, 59 | "bert-base-multilingual-uncased": 512, 60 | "bert-base-multilingual-cased": 512, 61 | "bert-base-chinese": 512, 62 | "bert-base-german-cased": 512, 63 | "bert-large-uncased-whole-word-masking": 512, 64 | "bert-large-cased-whole-word-masking": 512, 65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 67 | "bert-base-cased-finetuned-mrpc": 512, 68 | "bert-base-german-dbmdz-cased": 512, 69 | 
"bert-base-german-dbmdz-uncased": 512, 70 | "TurkuNLP/bert-base-finnish-cased-v1": 512, 71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512, 72 | "wietsedv/bert-base-dutch-cased": 512, 73 | } 74 | 75 | PRETRAINED_INIT_CONFIGURATION = { 76 | "bert-base-uncased": {"do_lower_case": True}, 77 | "bert-large-uncased": {"do_lower_case": True}, 78 | "bert-base-cased": {"do_lower_case": False}, 79 | "bert-large-cased": {"do_lower_case": False}, 80 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 81 | "bert-base-multilingual-cased": {"do_lower_case": False}, 82 | "bert-base-chinese": {"do_lower_case": False}, 83 | "bert-base-german-cased": {"do_lower_case": False}, 84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, 92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, 93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, 94 | } 95 | 96 | 97 | def load_vocab(vocab_file): 98 | """Loads a vocabulary file into a dictionary.""" 99 | vocab = collections.OrderedDict() 100 | with open(vocab_file, "r", encoding="utf-8") as reader: 101 | tokens = reader.readlines() 102 | for index, token in enumerate(tokens): 103 | token = token.rstrip("\n") 104 | vocab[token] = index 105 | return vocab 106 | 107 | 108 | def whitespace_tokenize(text): 109 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 110 | text = text.strip() 111 | if not text: 112 | return [] 113 | tokens = text.split() 114 | return tokens 115 | 116 | 117 | class BertTokenizer(PreTrainedTokenizer): 118 | r""" 119 | Constructs a BERT tokenizer. Based on WordPiece. 120 | 121 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users 122 | should refer to the superclass for more information regarding methods. 123 | 124 | Args: 125 | vocab_file (:obj:`string`): 126 | File containing the vocabulary. 127 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 128 | Whether to lowercase the input when tokenizing. 129 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 130 | Whether to do basic tokenization before WordPiece. 131 | never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`): 132 | Collection of tokens which will never be split during tokenization. Only has an effect when 133 | :obj:`do_basic_tokenize=True` 134 | unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): 135 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 136 | token instead. 137 | sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): 138 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences 139 | for sequence classification or for a text and a question for question answering. 140 | It is also used as the last token of a sequence built with special tokens. 
141 | pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): 142 | The token used for padding, for example when batching sequences of different lengths. 143 | cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): 144 | The classifier token which is used when doing sequence classification (classification of the whole 145 | sequence instead of per-token classification). It is the first token of the sequence when built with 146 | special tokens. 147 | mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): 148 | The token used for masking values. This is the token used when training this model with masked language 149 | modeling. This is the token which the model will try to predict. 150 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 151 | Whether to tokenize Chinese characters. 152 | This should likely be deactivated for Japanese: 153 | see: https://github.com/huggingface/transformers/issues/328 154 | """ 155 | 156 | vocab_files_names = VOCAB_FILES_NAMES 157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 160 | 161 | def __init__( 162 | self, 163 | vocab_file, 164 | do_lower_case=True, 165 | do_basic_tokenize=True, 166 | never_split=None, 167 | unk_token="[UNK]", 168 | sep_token="[SEP]", 169 | pad_token="[PAD]", 170 | cls_token="[CLS]", 171 | mask_token="[MASK]", 172 | tokenize_chinese_chars=True, 173 | **kwargs 174 | ): 175 | super().__init__( 176 | unk_token=unk_token, 177 | sep_token=sep_token, 178 | pad_token=pad_token, 179 | cls_token=cls_token, 180 | mask_token=mask_token, 181 | **kwargs, 182 | ) 183 | 184 | if not os.path.isfile(vocab_file): 185 | raise ValueError( 186 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) 188 | ) 189 | self.vocab = load_vocab(vocab_file) 190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 191 | self.do_basic_tokenize = do_basic_tokenize 192 | if do_basic_tokenize: 193 | self.basic_tokenizer = BasicTokenizer( 194 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars 195 | ) 196 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 197 | 198 | @property 199 | def vocab_size(self): 200 | return len(self.vocab) 201 | 202 | def get_vocab(self): 203 | return dict(self.vocab, **self.added_tokens_encoder) 204 | 205 | def _tokenize(self, text): 206 | split_tokens = [] 207 | if self.do_basic_tokenize: 208 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 209 | 210 | # If the token is part of the never_split set 211 | if token in self.basic_tokenizer.never_split: 212 | split_tokens.append(token) 213 | else: 214 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 215 | else: 216 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 217 | return split_tokens 218 | 219 | def _convert_token_to_id(self, token): 220 | """ Converts a token (str) in an id using the vocab. 
""" 221 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 222 | 223 | def _convert_id_to_token(self, index): 224 | """Converts an index (integer) in a token (str) using the vocab.""" 225 | return self.ids_to_tokens.get(index, self.unk_token) 226 | 227 | def convert_tokens_to_string(self, tokens): 228 | """ Converts a sequence of tokens (string) in a single string. """ 229 | out_string = " ".join(tokens).replace(" ##", "").strip() 230 | return out_string 231 | 232 | def build_inputs_with_special_tokens( 233 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 234 | ) -> List[int]: 235 | """ 236 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 237 | by concatenating and adding special tokens. 238 | A BERT sequence has the following format: 239 | 240 | - single sequence: ``[CLS] X [SEP]`` 241 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 242 | 243 | Args: 244 | token_ids_0 (:obj:`List[int]`): 245 | List of IDs to which the special tokens will be added 246 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 247 | Optional second list of IDs for sequence pairs. 248 | 249 | Returns: 250 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 251 | """ 252 | if token_ids_1 is None: 253 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 254 | cls = [self.cls_token_id] 255 | sep = [self.sep_token_id] 256 | return cls + token_ids_0 + sep + token_ids_1 + sep 257 | 258 | def get_special_tokens_mask( 259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 260 | ) -> List[int]: 261 | """ 262 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 263 | special tokens using the tokenizer ``prepare_for_model`` method. 264 | 265 | Args: 266 | token_ids_0 (:obj:`List[int]`): 267 | List of ids. 268 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 269 | Optional second list of IDs for sequence pairs. 270 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 271 | Set to True if the token list is already formatted with special tokens for the model 272 | 273 | Returns: 274 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 275 | """ 276 | 277 | if already_has_special_tokens: 278 | if token_ids_1 is not None: 279 | raise ValueError( 280 | "You should not supply a second sequence if the provided sequence of " 281 | "ids is already formated with special tokens for the model." 282 | ) 283 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 284 | 285 | if token_ids_1 is not None: 286 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 287 | return [1] + ([0] * len(token_ids_0)) + [1] 288 | 289 | def create_token_type_ids_from_sequences( 290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 291 | ) -> List[int]: 292 | """ 293 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 294 | A BERT sequence pair mask has the following format: 295 | 296 | :: 297 | 298 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 299 | | first sequence | second sequence | 300 | 301 | if token_ids_1 is None, only returns the first portion of the mask (0's). 
302 | 303 | Args: 304 | token_ids_0 (:obj:`List[int]`): 305 | List of ids. 306 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 307 | Optional second list of IDs for sequence pairs. 308 | 309 | Returns: 310 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 311 | sequence(s). 312 | """ 313 | sep = [self.sep_token_id] 314 | cls = [self.cls_token_id] 315 | if token_ids_1 is None: 316 | return len(cls + token_ids_0 + sep) * [0] 317 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 318 | 319 | def save_vocabulary(self, vocab_path): 320 | """ 321 | Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. 322 | 323 | Args: 324 | vocab_path (:obj:`str`): 325 | The directory in which to save the vocabulary. 326 | 327 | Returns: 328 | :obj:`Tuple(str)`: Paths to the files saved. 329 | """ 330 | index = 0 331 | if os.path.isdir(vocab_path): 332 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) 333 | else: 334 | vocab_file = vocab_path 335 | with open(vocab_file, "w", encoding="utf-8") as writer: 336 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 337 | if index != token_index: 338 | logger.warning( 339 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 340 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 341 | ) 342 | index = token_index 343 | writer.write(token + "\n") 344 | index += 1 345 | return (vocab_file,) 346 | 347 | 348 | class BasicTokenizer(object): 349 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 350 | 351 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 352 | """ Constructs a BasicTokenizer. 353 | 354 | Args: 355 | **do_lower_case**: Whether to lower case the input. 356 | **never_split**: (`optional`) list of str 357 | Kept for backward compatibility purposes. 358 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 359 | List of token not to split. 360 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 361 | Whether to tokenize Chinese characters. 362 | This should likely be deactivated for Japanese: 363 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 364 | """ 365 | if never_split is None: 366 | never_split = [] 367 | self.do_lower_case = do_lower_case 368 | self.never_split = set(never_split) 369 | self.tokenize_chinese_chars = tokenize_chinese_chars 370 | 371 | def tokenize(self, text, never_split=None): 372 | """ Basic Tokenization of a piece of text. 373 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 374 | 375 | Args: 376 | **never_split**: (`optional`) list of str 377 | Kept for backward compatibility purposes. 378 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 379 | List of token not to split. 380 | """ 381 | # union() returns a new set by concatenating the two sets. 382 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split 383 | 384 | # This was added on November 1st, 2018 for the multilingual and Chinese 385 | # models. 
This is also applied to the English models now, but it doesn't 386 | # matter since the English models were not trained on any Chinese data 387 | # and generally don't have any Chinese data in them (there are Chinese 388 | # characters in the vocabulary because Wikipedia does have some Chinese 389 | # words in the English Wikipedia.). 390 | if self.tokenize_chinese_chars: 391 | text = self._tokenize_chinese_chars(text) 392 | orig_tokens = whitespace_tokenize(text) 393 | split_tokens = [] 394 | for token in orig_tokens: 395 | if self.do_lower_case and token not in never_split: 396 | token = token.lower() 397 | token = self._run_strip_accents(token) 398 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 399 | 400 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 401 | return output_tokens 402 | 403 | def _run_strip_accents(self, text): 404 | """Strips accents from a piece of text.""" 405 | text = unicodedata.normalize("NFD", text) 406 | output = [] 407 | for char in text: 408 | cat = unicodedata.category(char) 409 | if cat == "Mn": 410 | continue 411 | output.append(char) 412 | return "".join(output) 413 | 414 | def _run_split_on_punc(self, text, never_split=None): 415 | """Splits punctuation on a piece of text.""" 416 | if never_split is not None and text in never_split: 417 | return [text] 418 | chars = list(text) 419 | i = 0 420 | start_new_word = True 421 | output = [] 422 | while i < len(chars): 423 | char = chars[i] 424 | if _is_punctuation(char): 425 | output.append([char]) 426 | start_new_word = True 427 | else: 428 | if start_new_word: 429 | output.append([]) 430 | start_new_word = False 431 | output[-1].append(char) 432 | i += 1 433 | 434 | return ["".join(x) for x in output] 435 | 436 | def _tokenize_chinese_chars(self, text): 437 | """Adds whitespace around any CJK character.""" 438 | output = [] 439 | for char in text: 440 | cp = ord(char) 441 | if self._is_chinese_char(cp): 442 | output.append(" ") 443 | output.append(char) 444 | output.append(" ") 445 | else: 446 | output.append(char) 447 | return "".join(output) 448 | 449 | def _is_chinese_char(self, cp): 450 | """Checks whether CP is the codepoint of a CJK character.""" 451 | # This defines a "chinese character" as anything in the CJK Unicode block: 452 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 453 | # 454 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 455 | # despite its name. The modern Korean Hangul alphabet is a different block, 456 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 457 | # space-separated words, so they are not treated specially and handled 458 | # like the all of the other languages. 
459 | if ( 460 | (cp >= 0x4E00 and cp <= 0x9FFF) 461 | or (cp >= 0x3400 and cp <= 0x4DBF) # 462 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 463 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 464 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 465 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 466 | or (cp >= 0xF900 and cp <= 0xFAFF) 467 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 468 | ): # 469 | return True 470 | 471 | return False 472 | 473 | def _clean_text(self, text): 474 | """Performs invalid character removal and whitespace cleanup on text.""" 475 | output = [] 476 | for char in text: 477 | cp = ord(char) 478 | if cp == 0 or cp == 0xFFFD or _is_control(char): 479 | continue 480 | if _is_whitespace(char): 481 | output.append(" ") 482 | else: 483 | output.append(char) 484 | return "".join(output) 485 | 486 | 487 | class WordpieceTokenizer(object): 488 | """Runs WordPiece tokenization.""" 489 | 490 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 491 | self.vocab = vocab 492 | self.unk_token = unk_token 493 | self.max_input_chars_per_word = max_input_chars_per_word 494 | 495 | def tokenize(self, text): 496 | """Tokenizes a piece of text into its word pieces. 497 | 498 | This uses a greedy longest-match-first algorithm to perform tokenization 499 | using the given vocabulary. 500 | 501 | For example: 502 | input = "unaffable" 503 | output = ["un", "##aff", "##able"] 504 | 505 | Args: 506 | text: A single token or whitespace separated tokens. This should have 507 | already been passed through `BasicTokenizer`. 508 | 509 | Returns: 510 | A list of wordpiece tokens. 511 | """ 512 | 513 | output_tokens = [] 514 | for token in whitespace_tokenize(text): 515 | chars = list(token) 516 | if len(chars) > self.max_input_chars_per_word: 517 | output_tokens.append(self.unk_token) 518 | continue 519 | 520 | is_bad = False 521 | start = 0 522 | sub_tokens = [] 523 | while start < len(chars): 524 | end = len(chars) 525 | cur_substr = None 526 | while start < end: 527 | substr = "".join(chars[start:end]) 528 | if start > 0: 529 | substr = "##" + substr 530 | if substr in self.vocab: 531 | cur_substr = substr 532 | break 533 | end -= 1 534 | if cur_substr is None: 535 | is_bad = True 536 | break 537 | sub_tokens.append(cur_substr) 538 | start = end 539 | 540 | if is_bad: 541 | output_tokens.append(self.unk_token) 542 | else: 543 | output_tokens.extend(sub_tokens) 544 | return output_tokens 545 | 546 | 547 | -------------------------------------------------------------------------------- /competition/workspace/bert_tokenizer/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Tokenization classes for python tokenizers. 
16 | For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py 17 | """ 18 | 19 | import itertools 20 | import logging 21 | import re 22 | import unicodedata 23 | from typing import Dict, List, Optional, Tuple, Union 24 | 25 | from .tokenization_utils_base import ( 26 | BatchEncoding, 27 | EncodedInput, 28 | EncodedInputPair, 29 | PaddingStrategy, 30 | PreTokenizedInput, 31 | PreTokenizedInputPair, 32 | PreTrainedTokenizerBase, 33 | TensorType, 34 | TextInput, 35 | TextInputPair, 36 | TruncationStrategy, 37 | ) 38 | 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | def _is_whitespace(char): 44 | """Checks whether `chars` is a whitespace character.""" 45 | # \t, \n, and \r are technically contorl characters but we treat them 46 | # as whitespace since they are generally considered as such. 47 | if char == " " or char == "\t" or char == "\n" or char == "\r": 48 | return True 49 | cat = unicodedata.category(char) 50 | if cat == "Zs": 51 | return True 52 | return False 53 | 54 | 55 | def _is_control(char): 56 | """Checks whether `chars` is a control character.""" 57 | # These are technically control characters but we count them as whitespace 58 | # characters. 59 | if char == "\t" or char == "\n" or char == "\r": 60 | return False 61 | cat = unicodedata.category(char) 62 | if cat.startswith("C"): 63 | return True 64 | return False 65 | 66 | 67 | def _is_punctuation(char): 68 | """Checks whether `chars` is a punctuation character.""" 69 | cp = ord(char) 70 | # We treat all non-letter/number ASCII as punctuation. 71 | # Characters such as "^", "$", and "`" are not in the Unicode 72 | # Punctuation class but we treat them as punctuation anyways, for 73 | # consistency. 74 | if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): 75 | return True 76 | cat = unicodedata.category(char) 77 | if cat.startswith("P"): 78 | return True 79 | return False 80 | 81 | 82 | def _is_end_of_word(text): 83 | """Checks whether the last character in text is one of a punctuation, control or whitespace character.""" 84 | last_char = text[-1] 85 | return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char)) 86 | 87 | 88 | def _is_start_of_word(text): 89 | """Checks whether the first character in text is one of a punctuation, control or whitespace character.""" 90 | first_char = text[0] 91 | return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) 92 | 93 | 94 | class PreTrainedTokenizer(PreTrainedTokenizerBase): 95 | """ Base class for all slow tokenizers. 96 | 97 | Handle all the shared methods for tokenization and special tokens as well as methods 98 | downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. 99 | 100 | This class also contain the added tokens in a unified way on top of all tokenizers so we don't 101 | have to handle the specific vocabulary augmentation methods of the various underlying 102 | dictionary structures (BPE, sentencepiece...). 103 | 104 | Class attributes (overridden by derived classes): 105 | 106 | - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file 107 | required by the model, and as associated values, the filename for saving the associated file (string). 
108 | - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys 109 | being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the 110 | `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the 111 | associated pretrained vocabulary file. 112 | - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained 113 | models, and as associated values, the maximum length of the sequence inputs of this model, or None if the 114 | model has no maximum input size. 115 | - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the 116 | pretrained models, and as associated values, a dictionnary of specific arguments to pass to the 117 | ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the 118 | ``from_pretrained()`` method. 119 | 120 | Args: 121 | - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model. 122 | When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated 123 | model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`). 124 | no associated max_length can be found in ``max_model_input_sizes``. 125 | - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied. 126 | Should be selected between ['right', 'left'] 127 | - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the 128 | model ("token_type_ids", "attention_mask"...). 129 | - ``bos_token``: (`Optional`) string: a beginning of sentence token. 130 | Will be associated to ``self.bos_token`` and ``self.bos_token_id`` 131 | - ``eos_token``: (`Optional`) string: an end of sentence token. 132 | Will be associated to ``self.eos_token`` and ``self.eos_token_id`` 133 | - ``unk_token``: (`Optional`) string: an unknown token. 134 | Will be associated to ``self.unk_token`` and ``self.unk_token_id`` 135 | - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). 136 | Will be associated to ``self.sep_token`` and ``self.sep_token_id`` 137 | - ``pad_token``: (`Optional`) string: a padding token. 138 | Will be associated to ``self.pad_token`` and ``self.pad_token_id`` 139 | - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence 140 | leveraging self-attention along the full depth of the model). 141 | Will be associated to ``self.cls_token`` and ``self.cls_token_id`` 142 | - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language 143 | modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` 144 | - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. 145 | Adding all special tokens here ensure they won't be split by the tokenization process. 146 | Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` 147 | 148 | 149 | .. 
automethod:: __call__ 150 | """ 151 | 152 | def __init__(self, **kwargs): 153 | super().__init__(**kwargs) 154 | 155 | # Added tokens - We store this for both slow and fast tokenizers 156 | # until the serialization of Fast tokenizers is updated 157 | self.added_tokens_encoder: Dict[str, int] = {} 158 | self.added_tokens_decoder: Dict[int, str] = {} 159 | self.unique_no_split_tokens: List[str] = [] 160 | 161 | @property 162 | def is_fast(self) -> bool: 163 | return False 164 | 165 | @property 166 | def vocab_size(self) -> int: 167 | """ Size of the base vocabulary (without the added tokens) """ 168 | raise NotImplementedError 169 | 170 | def get_vocab(self): 171 | """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """ 172 | raise NotImplementedError() 173 | 174 | def get_added_vocab(self) -> Dict[str, int]: 175 | return self.added_tokens_encoder 176 | 177 | def __len__(self): 178 | """ Size of the full vocabulary with the added tokens """ 179 | return self.vocab_size + len(self.added_tokens_encoder) 180 | 181 | def _add_tokens(self, new_tokens, special_tokens=False) -> int: 182 | """ 183 | Add a list of new tokens to the tokenizer class. If the new tokens are not in the 184 | vocabulary, they are added to it with indices starting from length of the current vocabulary. 185 | 186 | Args: 187 | new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not 188 | already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). 189 | 190 | Returns: 191 | Number of tokens added to the vocabulary. 192 | 193 | Examples:: 194 | 195 | # Let's see how to increase the vocabulary of Bert model and tokenizer 196 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 197 | model = BertModel.from_pretrained('bert-base-uncased') 198 | 199 | num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) 200 | print('We have added', num_added_toks, 'tokens') 201 | model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 202 | """ 203 | new_tokens = [str(tok) for tok in new_tokens] 204 | 205 | tokens_to_add = [] 206 | for token in new_tokens: 207 | assert isinstance(token, str) 208 | if not special_tokens and self.init_kwargs.get("do_lower_case", False): 209 | token = token.lower() 210 | if ( 211 | token != self.unk_token 212 | and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) 213 | and token not in tokens_to_add 214 | ): 215 | tokens_to_add.append(token) 216 | if self.verbose: 217 | logger.info("Adding %s to the vocabulary", token) 218 | 219 | added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add)) 220 | added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} 221 | self.added_tokens_encoder.update(added_tok_encoder) 222 | self.added_tokens_decoder.update(added_tok_decoder) 223 | 224 | # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. 
for Albert) 225 | if special_tokens: 226 | self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens))) 227 | else: 228 | # Or on the newly added tokens 229 | self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add))) 230 | 231 | return len(tokens_to_add) 232 | 233 | def num_special_tokens_to_add(self, pair=False): 234 | """ 235 | Returns the number of added tokens when encoding a sequence with special tokens. 236 | 237 | Note: 238 | This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this 239 | inside your training loop. 240 | 241 | Args: 242 | pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the 243 | number of added tokens in the case of a single sequence if set to False. 244 | 245 | Returns: 246 | Number of tokens added to sequences 247 | """ 248 | token_ids_0 = [] 249 | token_ids_1 = [] 250 | return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) 251 | 252 | def tokenize(self, text: TextInput, **kwargs): 253 | """ Converts a string in a sequence of tokens (string), using the tokenizer. 254 | Split in words for word-based vocabulary or sub-words for sub-word-based 255 | vocabularies (BPE/SentencePieces/WordPieces). 256 | 257 | Take care of added tokens. 258 | 259 | Args: 260 | text (:obj:`string`): The sequence to be encoded. 261 | **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method. 262 | """ 263 | # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors 264 | all_special_tokens_extended = dict() 265 | 266 | text, kwargs = self.prepare_for_tokenization(text, **kwargs) 267 | 268 | if kwargs: 269 | logger.warning(f"Keyword arguments {kwargs} not recognized.") 270 | 271 | # TODO: should this be in the base class? 272 | if self.init_kwargs.get("do_lower_case", False): 273 | # convert non-special tokens to lowercase 274 | escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens] 275 | pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" 276 | text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) 277 | 278 | def split_on_token(tok, text): 279 | result = [] 280 | tok_extended = all_special_tokens_extended.get(tok, None) 281 | split_text = text.split(tok) 282 | full_word = "" 283 | for i, sub_text in enumerate(split_text): 284 | # AddedToken can control whitespace stripping around them. 285 | # We use them for GPT2 and Roberta to have different behavior depending on the special token 286 | # Cf. 
https://github.com/huggingface/transformers/pull/2778 287 | # and https://github.com/huggingface/transformers/issues/3788 288 | # We strip left and right by default 289 | if i < len(split_text) - 1: 290 | sub_text = sub_text.rstrip() 291 | if i > 0: 292 | sub_text = sub_text.lstrip() 293 | 294 | if i == 0 and not sub_text: 295 | result += [tok] 296 | elif i == len(split_text) - 1: 297 | if sub_text: 298 | result += [sub_text] 299 | else: 300 | pass 301 | else: 302 | if sub_text: 303 | result += [sub_text] 304 | result += [tok] 305 | return result 306 | 307 | def split_on_tokens(tok_list, text): 308 | if not text.strip(): 309 | return [] 310 | if not tok_list: 311 | return self._tokenize(text) 312 | 313 | tokenized_text = [] 314 | text_list = [text] 315 | for tok in tok_list: 316 | tokenized_text = [] 317 | for sub_text in text_list: 318 | if sub_text not in self.unique_no_split_tokens: 319 | tokenized_text += split_on_token(tok, sub_text) 320 | else: 321 | tokenized_text += [sub_text] 322 | text_list = tokenized_text 323 | 324 | return list( 325 | itertools.chain.from_iterable( 326 | ( 327 | self._tokenize(token) if token not in self.unique_no_split_tokens else [token] 328 | for token in tokenized_text 329 | ) 330 | ) 331 | ) 332 | 333 | no_split_token = self.unique_no_split_tokens 334 | tokenized_text = split_on_tokens(no_split_token, text) 335 | return tokenized_text 336 | 337 | def _tokenize(self, text, **kwargs): 338 | """ Converts a string in a sequence of tokens (string), using the tokenizer. 339 | Split in words for word-based vocabulary or sub-words for sub-word-based 340 | vocabularies (BPE/SentencePieces/WordPieces). 341 | 342 | Do NOT take care of added tokens. 343 | """ 344 | raise NotImplementedError 345 | 346 | def convert_tokens_to_ids(self, tokens): 347 | """ Converts a token string (or a sequence of tokens) in a single integer id 348 | (or a sequence of ids), using the vocabulary. 
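# Added annotation (illustrative, not part of this file): the effect of
# split_on_tokens above is that strings in unique_no_split_tokens (special
# tokens and tokens registered through _add_tokens) are never passed to
# _tokenize and therefore never broken into word pieces. Reusing the token
# from the _add_tokens docstring and the vocabulary path used by infer.py:
#
#     tokenizer = BertTokenizer.from_pretrained("bert_tokenizer/mobilebert-uncased")
#     tokenizer.add_tokens(["new_tok1"])
#     tokenizer.tokenize("hello new_tok1 world")
#     # -> ["hello", "new_tok1", "world"]; the surrounding word pieces depend
#     #    on the vocabulary, but "new_tok1" is kept intact.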
349 | """ 350 | if tokens is None: 351 | return None 352 | 353 | if isinstance(tokens, str): 354 | return self._convert_token_to_id_with_added_voc(tokens) 355 | 356 | ids = [] 357 | for token in tokens: 358 | ids.append(self._convert_token_to_id_with_added_voc(token)) 359 | return ids 360 | 361 | def _convert_token_to_id_with_added_voc(self, token): 362 | if token is None: 363 | return None 364 | 365 | if token in self.added_tokens_encoder: 366 | return self.added_tokens_encoder[token] 367 | return self._convert_token_to_id(token) 368 | 369 | def _convert_token_to_id(self, token): 370 | raise NotImplementedError 371 | 372 | def _encode_plus( 373 | self, 374 | text: Union[TextInput, PreTokenizedInput, EncodedInput], 375 | text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, 376 | add_special_tokens: bool = True, 377 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 378 | truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, 379 | max_length: Optional[int] = None, 380 | stride: int = 0, 381 | is_pretokenized: bool = False, 382 | pad_to_multiple_of: Optional[int] = None, 383 | return_tensors: Optional[Union[str, TensorType]] = None, 384 | return_token_type_ids: Optional[bool] = None, 385 | return_attention_mask: Optional[bool] = None, 386 | return_overflowing_tokens: bool = False, 387 | return_special_tokens_mask: bool = False, 388 | return_offsets_mapping: bool = False, 389 | return_length: bool = False, 390 | verbose: bool = True, 391 | **kwargs 392 | ) -> BatchEncoding: 393 | def get_input_ids(text): 394 | if isinstance(text, str): 395 | tokens = self.tokenize(text, **kwargs) 396 | return self.convert_tokens_to_ids(tokens) 397 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): 398 | if is_pretokenized: 399 | tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text))) 400 | return self.convert_tokens_to_ids(tokens) 401 | else: 402 | return self.convert_tokens_to_ids(text) 403 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): 404 | return text 405 | else: 406 | if is_pretokenized: 407 | raise ValueError( 408 | f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`." 409 | ) 410 | else: 411 | raise ValueError( 412 | f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." 413 | ) 414 | 415 | if return_offsets_mapping: 416 | raise NotImplementedError( 417 | "return_offset_mapping is not available when using Python tokenizers." 418 | "To use this feature, change your tokenizer to one deriving from " 419 | "transformers.PreTrainedTokenizerFast." 
420 | "More information on available tokenizers at " 421 | "https://github.com/huggingface/transformers/pull/2674" 422 | ) 423 | 424 | first_ids = get_input_ids(text) 425 | second_ids = get_input_ids(text_pair) if text_pair is not None else None 426 | 427 | return self.prepare_for_model( 428 | first_ids, 429 | pair_ids=second_ids, 430 | add_special_tokens=add_special_tokens, 431 | padding=padding_strategy.value, 432 | truncation=truncation_strategy.value, 433 | max_length=max_length, 434 | stride=stride, 435 | pad_to_multiple_of=pad_to_multiple_of, 436 | return_tensors=return_tensors, 437 | prepend_batch_axis=True, 438 | return_attention_mask=return_attention_mask, 439 | return_token_type_ids=return_token_type_ids, 440 | return_overflowing_tokens=return_overflowing_tokens, 441 | return_special_tokens_mask=return_special_tokens_mask, 442 | return_length=return_length, 443 | verbose=verbose, 444 | ) 445 | 446 | def _batch_encode_plus( 447 | self, 448 | batch_text_or_text_pairs: Union[ 449 | List[TextInput], 450 | List[TextInputPair], 451 | List[PreTokenizedInput], 452 | List[PreTokenizedInputPair], 453 | List[EncodedInput], 454 | List[EncodedInputPair], 455 | ], 456 | add_special_tokens: bool = True, 457 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 458 | truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, 459 | max_length: Optional[int] = None, 460 | stride: int = 0, 461 | is_pretokenized: bool = False, 462 | pad_to_multiple_of: Optional[int] = None, 463 | return_tensors: Optional[Union[str, TensorType]] = None, 464 | return_token_type_ids: Optional[bool] = None, 465 | return_attention_mask: Optional[bool] = None, 466 | return_overflowing_tokens: bool = False, 467 | return_special_tokens_mask: bool = False, 468 | return_offsets_mapping: bool = False, 469 | return_length: bool = False, 470 | verbose: bool = True, 471 | **kwargs 472 | ) -> BatchEncoding: 473 | def get_input_ids(text): 474 | if isinstance(text, str): 475 | tokens = self.tokenize(text, **kwargs) 476 | return self.convert_tokens_to_ids(tokens) 477 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): 478 | if is_pretokenized: 479 | tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text))) 480 | return self.convert_tokens_to_ids(tokens) 481 | else: 482 | return self.convert_tokens_to_ids(text) 483 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): 484 | return text 485 | else: 486 | raise ValueError( 487 | "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." 488 | ) 489 | 490 | if return_offsets_mapping: 491 | raise NotImplementedError( 492 | "return_offset_mapping is not available when using Python tokenizers." 493 | "To use this feature, change your tokenizer to one deriving from " 494 | "transformers.PreTrainedTokenizerFast." 
495 | ) 496 | 497 | input_ids = [] 498 | for ids_or_pair_ids in batch_text_or_text_pairs: 499 | if not isinstance(ids_or_pair_ids, (list, tuple)): 500 | ids, pair_ids = ids_or_pair_ids, None 501 | elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)): 502 | ids, pair_ids = ids_or_pair_ids, None 503 | else: 504 | ids, pair_ids = ids_or_pair_ids 505 | 506 | first_ids = get_input_ids(ids) 507 | second_ids = get_input_ids(pair_ids) if pair_ids is not None else None 508 | input_ids.append((first_ids, second_ids)) 509 | 510 | batch_outputs = self._batch_prepare_for_model( 511 | input_ids, 512 | add_special_tokens=add_special_tokens, 513 | padding_strategy=padding_strategy, 514 | truncation_strategy=truncation_strategy, 515 | max_length=max_length, 516 | stride=stride, 517 | pad_to_multiple_of=pad_to_multiple_of, 518 | return_attention_mask=return_attention_mask, 519 | return_token_type_ids=return_token_type_ids, 520 | return_overflowing_tokens=return_overflowing_tokens, 521 | return_special_tokens_mask=return_special_tokens_mask, 522 | return_length=return_length, 523 | return_tensors=return_tensors, 524 | verbose=verbose, 525 | ) 526 | 527 | return BatchEncoding(batch_outputs) 528 | 529 | def _batch_prepare_for_model( 530 | self, 531 | batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], 532 | add_special_tokens: bool = True, 533 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 534 | truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, 535 | max_length: Optional[int] = None, 536 | stride: int = 0, 537 | pad_to_multiple_of: Optional[int] = None, 538 | return_tensors: Optional[str] = None, 539 | return_token_type_ids: Optional[bool] = None, 540 | return_attention_mask: Optional[bool] = None, 541 | return_overflowing_tokens: bool = False, 542 | return_special_tokens_mask: bool = False, 543 | return_length: bool = False, 544 | verbose: bool = True, 545 | ) -> BatchEncoding: 546 | """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 
547 | It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and 548 | manages a moving window (with user defined stride) for overflowing tokens 549 | 550 | Args: 551 | batch_ids_pairs: list of tokenized input ids or input ids pairs 552 | """ 553 | 554 | batch_outputs = {} 555 | for first_ids, second_ids in batch_ids_pairs: 556 | outputs = self.prepare_for_model( 557 | first_ids, 558 | second_ids, 559 | add_special_tokens=add_special_tokens, 560 | padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward 561 | truncation=truncation_strategy.value, 562 | max_length=max_length, 563 | stride=stride, 564 | pad_to_multiple_of=None, # we pad in batch afterward 565 | return_attention_mask=False, # we pad in batch afterward 566 | return_token_type_ids=return_token_type_ids, 567 | return_overflowing_tokens=return_overflowing_tokens, 568 | return_special_tokens_mask=return_special_tokens_mask, 569 | return_length=return_length, 570 | return_tensors=None, # We convert the whole batch to tensors at the end 571 | prepend_batch_axis=False, 572 | verbose=verbose, 573 | ) 574 | 575 | for key, value in outputs.items(): 576 | if key not in batch_outputs: 577 | batch_outputs[key] = [] 578 | batch_outputs[key].append(value) 579 | 580 | batch_outputs = self.pad( 581 | batch_outputs, 582 | padding=padding_strategy.value, 583 | max_length=max_length, 584 | pad_to_multiple_of=pad_to_multiple_of, 585 | return_attention_mask=return_attention_mask, 586 | ) 587 | 588 | batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) 589 | 590 | return batch_outputs 591 | 592 | def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict): 593 | """ Performs any necessary transformations before tokenization. 594 | 595 | This method should pop the arguments from kwargs and return kwargs as well. 596 | We test kwargs at the end of the encoding process to be sure all the arguments have been used. 597 | """ 598 | return (text, kwargs) 599 | 600 | def get_special_tokens_mask( 601 | self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False 602 | ) -> List[int]: 603 | """ 604 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 605 | special tokens using the tokenizer ``prepare_for_model`` method. 606 | 607 | Args: 608 | token_ids_0: list of ids (must not contain special tokens) 609 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 610 | for sequence pairs 611 | already_has_special_tokens: (default False) Set to True if the token list is already formated with 612 | special tokens for the model 613 | 614 | Returns: 615 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 616 | """ 617 | return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) 618 | 619 | def convert_ids_to_tokens( 620 | self, ids: Union[int, List[int]], skip_special_tokens: bool = False 621 | ) -> Union[str, List[str]]: 622 | """ Converts a single index or a sequence of indices (integers) in a token " 623 | (resp.) a sequence of tokens (str), using the vocabulary and added tokens. 624 | 625 | Args: 626 | skip_special_tokens: Don't decode special tokens (self.all_special_tokens). 
Default: False 627 | """ 628 | if isinstance(ids, int): 629 | if ids in self.added_tokens_decoder: 630 | return self.added_tokens_decoder[ids] 631 | else: 632 | return self._convert_id_to_token(ids) 633 | tokens = [] 634 | for index in ids: 635 | index = int(index) 636 | if skip_special_tokens and index in self.all_special_ids: 637 | continue 638 | if index in self.added_tokens_decoder: 639 | tokens.append(self.added_tokens_decoder[index]) 640 | else: 641 | tokens.append(self._convert_id_to_token(index)) 642 | return tokens 643 | 644 | def _convert_id_to_token(self, index: int) -> str: 645 | raise NotImplementedError 646 | 647 | def convert_tokens_to_string(self, tokens: List[str]) -> str: 648 | """ Converts a sequence of tokens (string) in a single string. 649 | The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) 650 | but we often want to remove sub-word tokenization artifacts at the same time. 651 | """ 652 | return " ".join(self.convert_ids_to_tokens(tokens)) 653 | 654 | def decode( 655 | self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True 656 | ) -> str: 657 | filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) 658 | 659 | # To avoid mixing byte-level and unicode for byte-level BPT 660 | # we need to build string separatly for added tokens and byte-level tokens 661 | # cf. https://github.com/huggingface/transformers/issues/1133 662 | sub_texts = [] 663 | current_sub_text = [] 664 | for token in filtered_tokens: 665 | if skip_special_tokens and token in self.all_special_ids: 666 | continue 667 | if token in self.added_tokens_encoder: 668 | if current_sub_text: 669 | sub_texts.append(self.convert_tokens_to_string(current_sub_text)) 670 | current_sub_text = [] 671 | sub_texts.append(token) 672 | else: 673 | current_sub_text.append(token) 674 | if current_sub_text: 675 | sub_texts.append(self.convert_tokens_to_string(current_sub_text)) 676 | text = " ".join(sub_texts) 677 | 678 | if clean_up_tokenization_spaces: 679 | clean_text = self.clean_up_tokenization(text) 680 | return clean_text 681 | else: 682 | return text 683 | 684 | def save_vocabulary(self, save_directory) -> Tuple[str]: 685 | """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens 686 | and special token mappings. 687 | 688 | Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full 689 | Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` 690 | class method. 691 | """ 692 | raise NotImplementedError -------------------------------------------------------------------------------- /competition/workspace/infer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 
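# Added annotation (not part of infer.py): a small illustration of the slow
# tokenizer's decode() defined above in tokenization_utils.py. Assuming the
# standard BERT uncased WordPiece vocabulary, where id 101 is "[CLS]" and
# id 102 is "[SEP]":
#
#     hello_id = tokenizer.convert_tokens_to_ids("hello")
#     tokenizer.convert_ids_to_tokens([101, hello_id, 102])
#     # -> ["[CLS]", "hello", "[SEP]"]
#     tokenizer.decode([101, hello_id, 102], skip_special_tokens=True)
#     # -> "hello"
#
# decode() additionally runs clean_up_tokenization() (defined in
# tokenization_utils_base.py) unless clean_up_tokenization_spaces=False.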
4 | # Apache License v2.0 5 | 6 | import os 7 | import glob 8 | import json 9 | import requests 10 | from argparse import ArgumentParser 11 | from timeit import default_timer as timer 12 | 13 | from bert_tokenizer import BertTokenizer 14 | 15 | 16 | def read_txt(file_path): 17 | data = [] 18 | with open(file_path) as f: 19 | for line in f: 20 | ids = [int(x) for x in line.rstrip().split(" ")] 21 | data.append(ids) 22 | return data 23 | 24 | 25 | def get_api_input(signature_name, input_): 26 | if "token_type_ids" in input_: 27 | del input_["token_type_ids"] 28 | 29 | if type(input_) != dict: 30 | input_ = dict(input_) 31 | 32 | return json.dumps({ 33 | "signature_name": signature_name, 34 | "inputs": input_, 35 | }) 36 | 37 | 38 | def get_output(url, data): 39 | return requests.post(url, data).json()["outputs"] 40 | 41 | 42 | def get_score_input(token_logits, relevace_logits, attn_mask, passage_score_weight): 43 | input_ = { 44 | "token_logits": token_logits, 45 | "relevance_logits": relevace_logits, 46 | "attn_mask": attn_mask, 47 | "passage_score_weight": passage_score_weight, 48 | } 49 | return json.dumps({ 50 | "signature_name": "get_score", 51 | "inputs": input_, 52 | }) 53 | 54 | 55 | def get_retriever_input(question_str, tokenizer, top_k): 56 | input_ = tokenizer([question_str], max_length=256, truncation=True) 57 | input_.update({"top_k": top_k}) 58 | return get_api_input("retrieve", input_) 59 | 60 | 61 | def get_reader_input(question_str, tokenizer, retriever_output): 62 | input_ids = [] 63 | attention_mask = [] 64 | 65 | retrieved_titles = retriever_output["titles"] 66 | retrieved_docs = retriever_output["docs"] 67 | 68 | question = tokenizer.encode(question_str, max_length=256, truncation=True) 69 | for title, doc in zip(retrieved_titles, retrieved_docs): 70 | concat = question + title + [tokenizer.sep_token_id] + doc 71 | concat = concat[:350] 72 | input_ids.append(concat) 73 | max_len = max(len(ids) for ids in input_ids) 74 | for i in range(len(input_ids)): 75 | padding = [0] * (max_len - len(input_ids[i])) 76 | attention_mask.append([1] * len(input_ids[i]) + padding) 77 | input_ids[i] = input_ids[i] + padding 78 | return { 79 | "input_ids": input_ids, 80 | "attention_mask": attention_mask, 81 | } 82 | 83 | 84 | def is_sub_word_id(tokenizer, token_id): 85 | token = tokenizer.convert_ids_to_tokens([token_id])[0] 86 | return token.startswith("##") or token.startswith(" ##") 87 | 88 | 89 | def extend_span_to_full_words(tokenizer, tokens, span): 90 | start_index, end_index = span 91 | max_len = len(tokens) 92 | while start_index > 0 and is_sub_word_id(tokenizer, tokens[start_index]): 93 | start_index -= 1 94 | 95 | while end_index < max_len - 1 and is_sub_word_id(tokenizer, tokens[end_index + 1]): 96 | end_index += 1 97 | 98 | return start_index, end_index 99 | 100 | 101 | def get_answer_greedy(tokenizer, reader_input, reader_output): 102 | max_answer_length = 10 103 | input_ids = reader_input["input_ids"] 104 | relevance_logits = reader_output["relevance_logits"] 105 | if not isinstance(relevance_logits, (tuple, list)): 106 | relevance_logits = [relevance_logits] 107 | 108 | top_doc_idx = max(enumerate(relevance_logits), key=lambda x: x[1])[0] 109 | 110 | sequence_len = sum(id_ != 0 for id_ in input_ids[top_doc_idx]) 111 | passage_offset = input_ids[top_doc_idx].index(tokenizer.sep_token_id) + 1 112 | ctx_ids = input_ids[top_doc_idx][passage_offset:sequence_len] 113 | p_start_logits = reader_output["start_logits"][top_doc_idx][passage_offset:sequence_len] 114 | p_end_logits = 
reader_output["end_logits"][top_doc_idx][passage_offset:sequence_len] 115 | 116 | scores = [] 117 | for (i, s) in enumerate(p_start_logits): 118 | for (j, e) in enumerate(p_end_logits[i:i + max_answer_length]): 119 | scores.append(((i, i + j), s + e)) 120 | scores = sorted(scores, key=lambda x: x[1], reverse=True) 121 | 122 | chosen_span_intervals = [] 123 | 124 | answer = "" 125 | for (start_index, end_index), score in scores: 126 | assert start_index <= end_index 127 | length = end_index - start_index + 1 128 | assert length <= max_answer_length 129 | 130 | if any([start_index <= prev_start_index <= prev_end_index <= end_index or 131 | prev_start_index <= start_index <= end_index <= prev_end_index 132 | for (prev_start_index, prev_end_index) in chosen_span_intervals]): 133 | continue 134 | 135 | start_index, end_index = extend_span_to_full_words(tokenizer, ctx_ids, (start_index, end_index)) 136 | answer = tokenizer.decode(ctx_ids[start_index:end_index + 1], skip_special_tokens=True) 137 | break 138 | return answer 139 | 140 | 141 | def get_answer_deep(tokenizer, reader_input, reader_output, url, passage_score_weight): 142 | max_answer_length = 10 143 | input_ids = reader_input["input_ids"] 144 | attn_mask = reader_input["attention_mask"] 145 | relevance_logits = reader_output["relevance_logits"] 146 | if not isinstance(relevance_logits, (tuple, list)): 147 | relevance_logits = [relevance_logits] 148 | 149 | start_logits = get_output(url, get_score_input(reader_output["start_logits"], relevance_logits, attn_mask, passage_score_weight)) 150 | end_logits = get_output(url, get_score_input(reader_output["end_logits"], relevance_logits, attn_mask, passage_score_weight)) 151 | 152 | nbest = [] 153 | for passage_idx in range(len(input_ids)): 154 | sequence_len = sum(id_ != 0 for id_ in input_ids[passage_idx]) 155 | passage_offset = input_ids[passage_idx].index(tokenizer.sep_token_id) + 1 156 | ctx_ids = input_ids[passage_idx][passage_offset:sequence_len] 157 | 158 | p_start_logits = start_logits[passage_idx][passage_offset:sequence_len] 159 | p_end_logits = end_logits[passage_idx][passage_offset:sequence_len] 160 | 161 | scores = [] 162 | for (i, s) in enumerate(p_start_logits): 163 | for (j, e) in enumerate(p_end_logits[i:i + max_answer_length]): 164 | scores.append(((i, i + j), s + e)) 165 | 166 | scores = sorted(scores, key=lambda x: x[1], reverse=True) 167 | 168 | chosen_span_intervals = [] 169 | best_spans = [] 170 | 171 | for (start_index, end_index), score in scores: 172 | assert start_index <= end_index 173 | length = end_index - start_index + 1 174 | assert length <= max_answer_length 175 | 176 | if any([start_index <= prev_start_index <= prev_end_index <= end_index or 177 | prev_start_index <= start_index <= end_index <= prev_end_index 178 | for (prev_start_index, prev_end_index) in chosen_span_intervals]): 179 | continue 180 | 181 | start_index, end_index = extend_span_to_full_words(tokenizer, ctx_ids, (start_index, end_index)) 182 | 183 | title_offset = ctx_ids.index(tokenizer.sep_token_id) + 1 184 | 185 | context = ctx_ids[title_offset:] 186 | start_index -= title_offset 187 | end_index -= title_offset 188 | 189 | answer = tokenizer.decode(context[start_index:end_index + 1]) 190 | 191 | best_spans.append((answer, score)) 192 | if len(best_spans) == 5: 193 | break 194 | nbest.extend(best_spans) 195 | 196 | nbest = sorted(nbest, key=lambda x: x[1], reverse=True) 197 | return nbest[0][0] 198 | 199 | 200 | def get_retriever_output(retrieved_doc_ids, titles, docs): 201 | retrieved_titles 
= [titles[i] for i in retrieved_doc_ids] 202 | retrieved_docs = [docs[i] for i in retrieved_doc_ids] 203 | return { 204 | "titles": retrieved_titles, 205 | "docs": retrieved_docs, 206 | } 207 | 208 | 209 | def predict(url, tokenizer, titles, docs, question_str, top_k, passage_score_weight): 210 | question_str = question_str.lower() 211 | if question_str.endswith("?"): 212 | question_str = question_str[:-1] 213 | 214 | retrieved_doc_ids = get_output(url, get_retriever_input(question_str, tokenizer, top_k)) 215 | retriever_output = get_retriever_output(retrieved_doc_ids, titles, docs) 216 | reader_input = get_reader_input(question_str, tokenizer, retriever_output) 217 | 218 | reader_output = get_output(url, get_api_input("read", reader_input)) 219 | if passage_score_weight is not None: 220 | answer = get_answer_deep(tokenizer, reader_input, reader_output, url, passage_score_weight) 221 | else: 222 | answer = get_answer_greedy(tokenizer, reader_input, reader_output) 223 | return answer 224 | 225 | 226 | def main(args): 227 | print(args) 228 | 229 | url = args.url 230 | titles_path = sorted(glob.glob(os.path.join(args.resources_path, "*.titles.txt")))[0] 231 | docs_path = sorted(glob.glob(os.path.join(args.resources_path, "*.docs.txt")))[0] 232 | 233 | print(titles_path) 234 | print(docs_path) 235 | 236 | titles = read_txt(titles_path) 237 | docs = read_txt(docs_path) 238 | 239 | tokenizer = BertTokenizer.from_pretrained("bert_tokenizer/mobilebert-uncased") 240 | 241 | with open(args.input_path, "r", encoding="utf-8") as input_file, open(args.output_path, "w", encoding="utf-8") as output_file: 242 | for line in input_file: 243 | start = timer() 244 | question_str = json.loads(line.strip())["question"] 245 | answer = predict(url, tokenizer, titles, docs, question_str, args.top_k, args.passage_score_weight) 246 | out = {"question": question_str, "prediction": answer.strip()} 247 | output_file.write(json.dumps(out) + "\n") 248 | print(str(out) + " (%.4fs)" % (timer() - start)) 249 | 250 | 251 | if __name__ == "__main__": 252 | parser = ArgumentParser() 253 | 254 | parser.add_argument( 255 | "--url", 256 | default=None, 257 | type=str, 258 | required=True, 259 | ) 260 | parser.add_argument( 261 | "--resources_path", 262 | default=None, 263 | type=str, 264 | required=True, 265 | ) 266 | parser.add_argument( 267 | "--input_path", 268 | default=None, 269 | type=str, 270 | required=True, 271 | ) 272 | parser.add_argument( 273 | "--output_path", 274 | type=str, 275 | default="predictions.txt", 276 | required=True, 277 | ) 278 | parser.add_argument( 279 | "--top_k", 280 | type=int, 281 | default=100, 282 | ) 283 | parser.add_argument( 284 | "--passage_score_weight", 285 | type=float, 286 | default=None, 287 | ) 288 | 289 | args = parser.parse_args() 290 | main(args) 291 | -------------------------------------------------------------------------------- /playground/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # coding=utf-8 3 | # minimal-rnr-qa 4 | # Copyright 2021-present NAVER Corp. 
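# Added annotation (not part of build.sh): the competition-side infer.py above
# talks to a TensorFlow Serving REST endpoint. get_api_input() wraps the
# tokenized inputs into the standard TF Serving predict payload, e.g.
# (illustrative values):
#
#   {"signature_name": "retrieve",
#    "inputs": {"input_ids": [[...]], "attention_mask": [[...]], "top_k": 100}}
#
# and the script is pointed at that endpoint via --url, roughly:
#
#   python infer.py \
#     --url http://localhost:8501/v1/models/minimal-rnr-qa:predict \
#     --resources_path /resources --input_path questions.jsonl \
#     --output_path predictions.txt
#
# The host, port and model name here are assumptions (the playground
# entrypoint.sh below serves the model as "minimal-rnr-qa" on REST port 8501).
# Each line of --input_path must be a JSON object with a "question" field.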
5 | # Apache License v2.0 6 | 7 | 8 | if [ "$#" -ne 2 ]; then 9 | echo "Usage: ./build.sh DATASET MODEL_TYPE" 10 | echo "DATASET: effqa | nq | trivia" 11 | echo "MODEL_TYPE: tfserving | tfserving_faiss | pytorch" 12 | exit 0 13 | fi 14 | 15 | 16 | DATASET=$1 # ["effqa", "nq", "trivia"] 17 | MODEL_TYPE=$2 # ["tfserving", "tfserving_faiss", "pytorch"] 18 | 19 | if [ "$MODEL_TYPE" = "pytorch" ]; then 20 | DOCKERFILE="Dockerfile-pytorch" 21 | else 22 | DOCKERFILE="Dockerfile-tfserving" 23 | fi 24 | 25 | 26 | docker build -f configs/$DOCKERFILE \ 27 | --build-arg DATASET=$DATASET \ 28 | --build-arg MODEL_TYPE=$MODEL_TYPE \ 29 | . -t minimal-rnr-qa:$DATASET-$MODEL_TYPE 30 | 31 | echo "minimal-rnr-qa:$DATASET-$MODEL_TYPE" 32 | -------------------------------------------------------------------------------- /playground/configs/Dockerfile-pytorch: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime 2 | 3 | ARG DATASET 4 | ARG MODEL_TYPE 5 | 6 | # Install python packages 7 | RUN pip install --no-cache-dir absl-py requests regex 8 | 9 | COPY configs/requirements-$MODEL_TYPE.txt /requirements.txt 10 | RUN pip install --no-cache-dir -r /requirements.txt 11 | 12 | COPY models/$MODEL_TYPE/$DATASET /models/$DATASET 13 | COPY resources/$DATASET /resources/$DATASET 14 | 15 | ENV DATASET=$DATASET 16 | ENV MODEL_TYPE=$MODEL_TYPE 17 | 18 | COPY workspace /workspace 19 | COPY entrypoint.sh /workspace/entrypoint.sh 20 | RUN chmod a+x /workspace/entrypoint.sh 21 | 22 | WORKDIR /workspace 23 | -------------------------------------------------------------------------------- /playground/configs/Dockerfile-tfserving: -------------------------------------------------------------------------------- 1 | FROM tensorflow/serving:2.3.0 as build_image 2 | FROM python:3.6.11-slim-buster 3 | 4 | ARG DATASET 5 | ARG MODEL_TYPE 6 | 7 | # Install TF Serving 8 | COPY --from=build_image /usr/bin/tensorflow_model_server /usr/bin/tensorflow_model_server 9 | 10 | # Install python packages 11 | RUN pip install --no-cache-dir absl-py requests regex 12 | 13 | COPY configs/requirements-$MODEL_TYPE.txt /requirements.txt 14 | RUN pip install --no-cache-dir -r /requirements.txt 15 | 16 | COPY models/$MODEL_TYPE/$DATASET/1 /models/minimal-rnr-qa/1 17 | COPY resources/$DATASET /resources/$DATASET 18 | 19 | ENV DATASET=$DATASET 20 | ENV MODEL_TYPE=$MODEL_TYPE 21 | 22 | COPY workspace /workspace 23 | COPY entrypoint.sh /workspace/entrypoint.sh 24 | RUN chmod a+x /workspace/entrypoint.sh 25 | 26 | WORKDIR /workspace 27 | -------------------------------------------------------------------------------- /playground/configs/requirements-pytorch.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | transformers==3.0.2 3 | faiss 4 | faiss-cpu 5 | numpy 6 | -------------------------------------------------------------------------------- /playground/configs/requirements-tfserving.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/configs/requirements-tfserving.txt -------------------------------------------------------------------------------- /playground/configs/requirements-tfserving_faiss.txt: -------------------------------------------------------------------------------- 1 | faiss 2 | faiss-cpu 3 | numpy 
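# Added annotation (not one of the requirements files): running the build.sh
# shown above tags the playground image as minimal-rnr-qa:DATASET-MODEL_TYPE,
# for example:
#
#   ./build.sh nq tfserving      # -> minimal-rnr-qa:nq-tfserving
#   ./build.sh trivia pytorch    # -> minimal-rnr-qa:trivia-pytorch
#
# Note that the Dockerfiles COPY models/$MODEL_TYPE/$DATASET and
# resources/$DATASET, so both directories must already exist in the build
# context (the playground directory) before build.sh is invoked.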
-------------------------------------------------------------------------------- /playground/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # coding=utf-8 3 | # minimal-rnr-qa 4 | # Copyright 2021-present NAVER Corp. 5 | # Apache License v2.0 6 | 7 | 8 | if [ -z "$MODE" ]; then 9 | echo "--env MODE must be defined: demo | file" 10 | exit 0 11 | fi 12 | 13 | 14 | # launch tfserving 15 | if [ "$MODEL_TYPE" = "tfserving" ] || [ "$MODEL_TYPE" = "tfserving_faiss" ]; then 16 | /usr/bin/tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=minimal-rnr-qa --model_base_path=/models/minimal-rnr-qa & 17 | fi 18 | 19 | 20 | # set USE_FAISS 21 | if [ "$MODEL_TYPE" = "tfserving" ]; then 22 | USE_FAISS=false 23 | else 24 | USE_FAISS=true 25 | fi 26 | 27 | 28 | if [ "$MODE" = "demo" ]; then 29 | pip install --no-cache-dir flask tornado 30 | 31 | if [ "$MODEL_TYPE" = "tfserving" ] || [ "$MODEL_TYPE" = "tfserving_faiss" ]; then 32 | python -u run_tf_demo.py \ 33 | --dataset $DATASET \ 34 | `if [[ -n "${USE_FAISS}" ]]; then echo --use_faiss_index $USE_FAISS; fi` 35 | 36 | else 37 | python -u run_pt_demo.py \ 38 | --dataset $DATASET 39 | 40 | fi 41 | 42 | elif [ "$MODE" = "file" ]; then 43 | if [ "$MODEL_TYPE" = "tfserving" ] || [ "$MODEL_TYPE" = "tfserving_faiss" ]; then 44 | PYTHON_FILE="run_tf_inference.py" 45 | 46 | else 47 | PYTHON_FILE="run_pt_inference.py" 48 | fi 49 | 50 | python -u $PYTHON_FILE \ 51 | --dataset $DATASET \ 52 | `if [[ -n "${USE_FAISS}" ]]; then echo --use_faiss_index $USE_FAISS; fi` \ 53 | `if [[ -n "${TOP_K}" ]]; then echo --top_k $TOP_K; fi` \ 54 | `if [[ -n "${PASSAGE_W}" ]]; then echo --passage_score_weight $PASSAGE_W; fi` \ 55 | --input_path $1 --output_path $2 \ 56 | 57 | else 58 | echo "--env MODE must be: demo | file" 59 | exit 0 60 | fi 61 | -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/__init__.py -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/pytorch/__init__.py -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/pytorch/inferencer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 4 | # Apache License v2.0 5 | 6 | from minimal_rnr.utils.inferencer import MinimalRnR 7 | from minimal_rnr.pytorch.model import get_model_tokenizer_device 8 | 9 | 10 | class TorchMinimalRnR(MinimalRnR): 11 | def __init__(self, args): 12 | super(TorchMinimalRnR, self).__init__(args) 13 | 14 | assert args.use_faiss_index, "PyTorch model cannot be run without a faiss index." 
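# Added annotation (not part of this file): entrypoint.sh above is the
# playground launcher. MODEL_TYPE and DATASET are baked into the image as ENV
# by the Dockerfiles, while MODE selects the behaviour at run time, roughly:
#
#   MODE=demo ./entrypoint.sh                            # web demo
#   MODE=file ./entrypoint.sh input.jsonl output.jsonl   # batch inference ($1/$2)
#
# TOP_K and PASSAGE_W are optional environment variables that are forwarded to
# the inference script only when set.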
15 | import torch 16 | self.torch = torch 17 | 18 | self.model, self.tokenizer, self.device = get_model_tokenizer_device(args) 19 | 20 | def get_question_encoding(self, question): 21 | input_ = self.tokenizer([question], max_length=self.max_retriever_input_len, 22 | truncation=True, return_token_type_ids=True, return_tensors="pt") 23 | torch_input = self._add_prefix_and_to_device(input_, "retriever_") 24 | 25 | with self.torch.no_grad(): 26 | question_encoding = self.model(**torch_input) 27 | np_question_encoding = question_encoding.cpu().numpy().astype("float32") 28 | 29 | return np_question_encoding 30 | 31 | def get_retriever_output(self, question, top_k): 32 | raise Exception("get_retriever_output is use for running without faiss index,", 33 | "which is not supported for torch models") 34 | 35 | def get_reader_output(self, reader_input): 36 | torch_input = self._add_prefix_and_to_device(reader_input, "reader_", convert_to_tensor=True) 37 | 38 | with self.torch.no_grad(): 39 | start_logits, end_logits, relevance_logits = self.model(**torch_input) 40 | 41 | relevance_logits = relevance_logits.squeeze() 42 | if relevance_logits.dim() == 0: # for top_k=1 43 | relevance_logits = relevance_logits.unsqueeze(0) 44 | 45 | # returned as cuda tensors (on purpose) 46 | return { 47 | "start_logits": start_logits.squeeze(-1), 48 | "end_logits": end_logits.squeeze(-1), 49 | "relevance_logits": relevance_logits, 50 | } 51 | 52 | def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits, attn_mask, passage_score_weight): 53 | attn_mask = self.torch.tensor(attn_mask).float() 54 | 55 | relevance_logits = relevance_logits.unsqueeze(1) # [M, 1] 56 | masked_token_logits = token_logits - 1e10 * (1.0 - attn_mask) 57 | log_span_prob = token_logits - masked_token_logits.logsumexp(dim=1, keepdim=True) # [M, L] softmaxed over L 58 | log_passage_prob = relevance_logits - relevance_logits.logsumexp(dim=0, keepdim=True) # [M, 1] softmaxed over M 59 | weighted_logits = log_span_prob * (1 - passage_score_weight) + log_passage_prob * passage_score_weight 60 | 61 | return weighted_logits 62 | 63 | def _add_prefix_and_to_device(self, data, prefix, convert_to_tensor=False): 64 | wrap = lambda x: self.torch.tensor(x) if convert_to_tensor else x 65 | data = {prefix + k: wrap(v).to(device=self.device) for k, v in data.items()} 66 | return data 67 | 68 | def maybe_tensor_to_list(self, tensor): 69 | if not isinstance(tensor, (list, tuple)): # torch tensor 70 | return tensor.cpu().tolist() 71 | return tensor 72 | -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/pytorch/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ****************************************************************** 3 | # Copied and modified from https://github.com/facebookresearch/DPR * 4 | # ****************************************************************** 5 | # Copyright (c) Facebook, Inc. and its affiliates. 6 | # All rights reserved. 7 | # 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 
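# Added annotation (not part of this file): the score combination in
# get_passage_score_weighted_answer_token_logits above mixes, in log space, a
# per-passage span-token distribution with the passage relevance distribution:
#
#   score(t, m) = (1 - w) * log P(token t | passage m) + w * log P(passage m)
#
# where w is passage_score_weight. With w = 0 only the per-passage
# (masked and log-softmaxed) token logits matter; as w grows, spans taken from
# highly relevant passages are favoured even if their raw span scores are
# slightly lower.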
10 | 11 | from minimal_rnr.utils.inference import get_first_matched_file_path 12 | from minimal_rnr.utils.logger import get_logger 13 | 14 | 15 | def get_model_tokenizer_device(args): 16 | # not to import torch and transformers as default 17 | import torch 18 | from torch import Tensor as T 19 | from torch import nn 20 | from transformers import MobileBertModel, MobileBertConfig, AutoTokenizer 21 | 22 | 23 | def init_weights(modules): 24 | for module in modules: 25 | if isinstance(module, (nn.Linear, nn.Embedding)): 26 | module.weight.data.normal_(mean=0.0, std=0.02) 27 | elif isinstance(module, nn.LayerNorm): 28 | module.bias.data.zero_() 29 | module.weight.data.fill_(1.0) 30 | if isinstance(module, nn.Linear) and module.bias is not None: 31 | module.bias.data.zero_() 32 | 33 | 34 | class HFMobileBertEncoder(MobileBertModel): 35 | def __init__(self, config, project_dim: int = 0, ctx_bottleneck: bool = False): 36 | MobileBertModel.__init__(self, config) 37 | assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero' 38 | self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None 39 | self.decode_proj = nn.Sequential( 40 | nn.Tanh(), 41 | nn.Linear(project_dim, (config.hidden_size + project_dim) // 2), 42 | nn.Tanh(), 43 | nn.Linear((config.hidden_size + project_dim) // 2, config.hidden_size), 44 | ) if ctx_bottleneck else None 45 | self.init_weights() 46 | 47 | @classmethod 48 | def init_encoder(cls, cfg_name: str) -> MobileBertModel: 49 | cfg = MobileBertConfig.from_pretrained(cfg_name) 50 | return cls.from_pretrained(cfg_name, config=cfg) 51 | 52 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T): 53 | if self.config.output_hidden_states: 54 | sequence_output, pooled_output, hidden_states = super().forward(input_ids=input_ids, 55 | token_type_ids=token_type_ids, 56 | attention_mask=attention_mask) 57 | else: 58 | hidden_states = None 59 | sequence_output, pooled_output = super().forward( 60 | input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 61 | 62 | pooled_output = sequence_output[:, 0, :] 63 | return sequence_output, pooled_output, hidden_states 64 | 65 | def get_out_size(self): 66 | if self.encode_proj: 67 | return self.encode_proj.out_features 68 | return self.config.hidden_size 69 | 70 | 71 | class UnifiedRetrieverReader(nn.Module): 72 | def __init__(self, encoder: nn.Module): 73 | super(UnifiedRetrieverReader, self).__init__() 74 | 75 | self.emb_size = 128 76 | 77 | self.question_model = encoder 78 | hidden_size = encoder.config.hidden_size 79 | 80 | self.qa_outputs = nn.Linear(hidden_size, 2) 81 | self.qa_classifier = nn.Linear(hidden_size, 1) 82 | 83 | init_weights([self.qa_outputs, self.qa_classifier]) 84 | 85 | @staticmethod 86 | def get_representation(sub_model: nn.Module, ids, segments, attn_mask): 87 | sequence_output = None 88 | pooled_output = None 89 | hidden_states = None 90 | if ids is not None: 91 | sequence_output, pooled_output, hidden_states = sub_model(ids, segments, attn_mask) 92 | 93 | return sequence_output, pooled_output, hidden_states 94 | 95 | def forward( 96 | self, retriever_input_ids=None, retriever_token_type_ids=None, retriever_attention_mask=None, 97 | reader_input_ids=None, reader_attention_mask=None, reader_token_type_ids=None): 98 | 99 | if retriever_input_ids is not None: 100 | _, encoding, _ = self.get_representation( 101 | self.question_model, retriever_input_ids, retriever_token_type_ids, retriever_attention_mask) 102 | 103 | if self.emb_size is not 
None: 104 | return encoding[:, :self.emb_size] 105 | return encoding 106 | 107 | if reader_input_ids is not None: 108 | start_logits, end_logits, relevance_logits = self._read(reader_input_ids, reader_token_type_ids, reader_attention_mask) 109 | return start_logits, end_logits, relevance_logits 110 | 111 | def _read(self, input_ids, token_type_ids, attention_mask): 112 | sequence_output, _pooled_output, _hidden_states = self.question_model(input_ids, token_type_ids, attention_mask) 113 | logits = self.qa_outputs(sequence_output) 114 | start_logits, end_logits = logits.split(1, dim=-1) 115 | start_logits = start_logits.squeeze(-1) 116 | end_logits = end_logits.squeeze(-1) 117 | 118 | qa_classifier_input = sequence_output[:, 0, :] 119 | relevance_logits = self.qa_classifier(qa_classifier_input) 120 | return start_logits, end_logits, relevance_logits 121 | 122 | 123 | cfg_name = "google/mobilebert-uncased" 124 | question_encoder = HFMobileBertEncoder.init_encoder(cfg_name) 125 | model = UnifiedRetrieverReader(question_encoder) 126 | tokenizer = AutoTokenizer.from_pretrained(cfg_name, do_lower_case=True) 127 | 128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | 130 | model_file = get_first_matched_file_path(args.model_path, args.dataset, "*.bin") 131 | logger = get_logger("minimal-rnr-qa") 132 | logger.info(f"Loading model from {model_file}...") 133 | model.load_state_dict(torch.load(model_file, map_location=device)) 134 | model.to(device) 135 | model.eval() 136 | 137 | return model, tokenizer, device -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/tfserving/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/tfserving/__init__.py -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/tfserving/bert_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenization_bert import BertTokenizer 2 | 3 | __all__ = ['BertTokenizer'] -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/tfserving/bert_tokenizer/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
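# Added annotation (not part of this file): UnifiedRetrieverReader in model.py
# above shares a single MobileBERT encoder for both roles. Called with
# retriever_* tensors it returns the question encoding truncated to
# emb_size=128; called with reader_* tensors it returns
# (start_logits, end_logits, relevance_logits). A hedged sketch of the two
# calls (variable names are assumptions):
#
#     q_emb = model(retriever_input_ids=q_ids,        # -> [batch, 128]
#                   retriever_token_type_ids=q_types,
#                   retriever_attention_mask=q_mask)
#     start, end, rel = model(reader_input_ids=ids,   # -> [n_docs, seq_len] x 2,
#                             reader_attention_mask=mask)  # and [n_docs, 1]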
5 | """ 6 | 7 | import logging 8 | import os 9 | from functools import wraps 10 | from typing import Optional 11 | 12 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 13 | 14 | USE_TF = True 15 | USE_TORCH = False 16 | _torch_available = False 17 | _torch_tpu_available = False 18 | _psutil_available = False 19 | _py3nvml_available = False 20 | _has_apex = False 21 | 22 | try: 23 | import tensorflow as tf 24 | 25 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 26 | _tf_available = True # pylint: disable=invalid-name 27 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 28 | except (ImportError, AssertionError): 29 | _tf_available = False # pylint: disable=invalid-name 30 | 31 | WEIGHTS_NAME = "pytorch_model.bin" 32 | TF2_WEIGHTS_NAME = "tf_model.h5" 33 | TF_WEIGHTS_NAME = "model.ckpt" 34 | CONFIG_NAME = "config.json" 35 | MODEL_CARD_NAME = "modelcard.json" 36 | 37 | 38 | def is_tf_available(): 39 | return _tf_available 40 | 41 | 42 | def cached_path( 43 | url_or_filename, 44 | ) -> Optional[str]: 45 | if os.path.exists(url_or_filename): 46 | output_path = url_or_filename 47 | else: 48 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 49 | 50 | return output_path 51 | 52 | 53 | def tf_required(func): 54 | # Chose a different decorator name than in tests so it's clear they are not the same. 55 | @wraps(func) 56 | def wrapper(*args, **kwargs): 57 | if is_tf_available(): 58 | return func(*args, **kwargs) 59 | else: 60 | raise ImportError(f"Method `{func.__name__}` requires TF.") 61 | 62 | return wrapper -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/tfserving/bert_tokenizer/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | 18 | import collections 19 | import logging 20 | import os 21 | import unicodedata 22 | from typing import List, Optional 23 | 24 | from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | "vocab_file": { 33 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 34 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 35 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 36 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 37 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 38 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 39 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 40 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 41 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 42 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 45 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 46 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", 47 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", 48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", 49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt", 50 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt", 51 | } 52 | } 53 | 54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 55 | "bert-base-uncased": 512, 56 | "bert-large-uncased": 512, 57 | "bert-base-cased": 512, 58 | "bert-large-cased": 512, 59 | "bert-base-multilingual-uncased": 512, 60 | "bert-base-multilingual-cased": 512, 61 | "bert-base-chinese": 512, 62 | "bert-base-german-cased": 512, 63 | "bert-large-uncased-whole-word-masking": 512, 64 | "bert-large-cased-whole-word-masking": 512, 65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 67 | "bert-base-cased-finetuned-mrpc": 512, 68 | "bert-base-german-dbmdz-cased": 512, 69 | 
"bert-base-german-dbmdz-uncased": 512, 70 | "TurkuNLP/bert-base-finnish-cased-v1": 512, 71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512, 72 | "wietsedv/bert-base-dutch-cased": 512, 73 | } 74 | 75 | PRETRAINED_INIT_CONFIGURATION = { 76 | "bert-base-uncased": {"do_lower_case": True}, 77 | "bert-large-uncased": {"do_lower_case": True}, 78 | "bert-base-cased": {"do_lower_case": False}, 79 | "bert-large-cased": {"do_lower_case": False}, 80 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 81 | "bert-base-multilingual-cased": {"do_lower_case": False}, 82 | "bert-base-chinese": {"do_lower_case": False}, 83 | "bert-base-german-cased": {"do_lower_case": False}, 84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, 92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, 93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, 94 | } 95 | 96 | 97 | def load_vocab(vocab_file): 98 | """Loads a vocabulary file into a dictionary.""" 99 | vocab = collections.OrderedDict() 100 | with open(vocab_file, "r", encoding="utf-8") as reader: 101 | tokens = reader.readlines() 102 | for index, token in enumerate(tokens): 103 | token = token.rstrip("\n") 104 | vocab[token] = index 105 | return vocab 106 | 107 | 108 | def whitespace_tokenize(text): 109 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 110 | text = text.strip() 111 | if not text: 112 | return [] 113 | tokens = text.split() 114 | return tokens 115 | 116 | 117 | class BertTokenizer(PreTrainedTokenizer): 118 | r""" 119 | Constructs a BERT tokenizer. Based on WordPiece. 120 | 121 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users 122 | should refer to the superclass for more information regarding methods. 123 | 124 | Args: 125 | vocab_file (:obj:`string`): 126 | File containing the vocabulary. 127 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 128 | Whether to lowercase the input when tokenizing. 129 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 130 | Whether to do basic tokenization before WordPiece. 131 | never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`): 132 | Collection of tokens which will never be split during tokenization. Only has an effect when 133 | :obj:`do_basic_tokenize=True` 134 | unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): 135 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 136 | token instead. 137 | sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): 138 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences 139 | for sequence classification or for a text and a question for question answering. 140 | It is also used as the last token of a sequence built with special tokens. 
141 | pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): 142 | The token used for padding, for example when batching sequences of different lengths. 143 | cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): 144 | The classifier token which is used when doing sequence classification (classification of the whole 145 | sequence instead of per-token classification). It is the first token of the sequence when built with 146 | special tokens. 147 | mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): 148 | The token used for masking values. This is the token used when training this model with masked language 149 | modeling. This is the token which the model will try to predict. 150 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 151 | Whether to tokenize Chinese characters. 152 | This should likely be deactivated for Japanese: 153 | see: https://github.com/huggingface/transformers/issues/328 154 | """ 155 | 156 | vocab_files_names = VOCAB_FILES_NAMES 157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 160 | 161 | def __init__( 162 | self, 163 | vocab_file, 164 | do_lower_case=True, 165 | do_basic_tokenize=True, 166 | never_split=None, 167 | unk_token="[UNK]", 168 | sep_token="[SEP]", 169 | pad_token="[PAD]", 170 | cls_token="[CLS]", 171 | mask_token="[MASK]", 172 | tokenize_chinese_chars=True, 173 | **kwargs 174 | ): 175 | super().__init__( 176 | unk_token=unk_token, 177 | sep_token=sep_token, 178 | pad_token=pad_token, 179 | cls_token=cls_token, 180 | mask_token=mask_token, 181 | **kwargs, 182 | ) 183 | 184 | if not os.path.isfile(vocab_file): 185 | raise ValueError( 186 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) 188 | ) 189 | self.vocab = load_vocab(vocab_file) 190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 191 | self.do_basic_tokenize = do_basic_tokenize 192 | if do_basic_tokenize: 193 | self.basic_tokenizer = BasicTokenizer( 194 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars 195 | ) 196 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 197 | 198 | @property 199 | def vocab_size(self): 200 | return len(self.vocab) 201 | 202 | def get_vocab(self): 203 | return dict(self.vocab, **self.added_tokens_encoder) 204 | 205 | def _tokenize(self, text): 206 | split_tokens = [] 207 | if self.do_basic_tokenize: 208 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 209 | 210 | # If the token is part of the never_split set 211 | if token in self.basic_tokenizer.never_split: 212 | split_tokens.append(token) 213 | else: 214 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 215 | else: 216 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 217 | return split_tokens 218 | 219 | def _convert_token_to_id(self, token): 220 | """ Converts a token (str) in an id using the vocab. 
""" 221 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 222 | 223 | def _convert_id_to_token(self, index): 224 | """Converts an index (integer) in a token (str) using the vocab.""" 225 | return self.ids_to_tokens.get(index, self.unk_token) 226 | 227 | def convert_tokens_to_string(self, tokens): 228 | """ Converts a sequence of tokens (string) in a single string. """ 229 | out_string = " ".join(tokens).replace(" ##", "").strip() 230 | return out_string 231 | 232 | def build_inputs_with_special_tokens( 233 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 234 | ) -> List[int]: 235 | """ 236 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 237 | by concatenating and adding special tokens. 238 | A BERT sequence has the following format: 239 | 240 | - single sequence: ``[CLS] X [SEP]`` 241 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 242 | 243 | Args: 244 | token_ids_0 (:obj:`List[int]`): 245 | List of IDs to which the special tokens will be added 246 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 247 | Optional second list of IDs for sequence pairs. 248 | 249 | Returns: 250 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 251 | """ 252 | if token_ids_1 is None: 253 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 254 | cls = [self.cls_token_id] 255 | sep = [self.sep_token_id] 256 | return cls + token_ids_0 + sep + token_ids_1 + sep 257 | 258 | def get_special_tokens_mask( 259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 260 | ) -> List[int]: 261 | """ 262 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 263 | special tokens using the tokenizer ``prepare_for_model`` method. 264 | 265 | Args: 266 | token_ids_0 (:obj:`List[int]`): 267 | List of ids. 268 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 269 | Optional second list of IDs for sequence pairs. 270 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 271 | Set to True if the token list is already formatted with special tokens for the model 272 | 273 | Returns: 274 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 275 | """ 276 | 277 | if already_has_special_tokens: 278 | if token_ids_1 is not None: 279 | raise ValueError( 280 | "You should not supply a second sequence if the provided sequence of " 281 | "ids is already formated with special tokens for the model." 282 | ) 283 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 284 | 285 | if token_ids_1 is not None: 286 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 287 | return [1] + ([0] * len(token_ids_0)) + [1] 288 | 289 | def create_token_type_ids_from_sequences( 290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 291 | ) -> List[int]: 292 | """ 293 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 294 | A BERT sequence pair mask has the following format: 295 | 296 | :: 297 | 298 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 299 | | first sequence | second sequence | 300 | 301 | if token_ids_1 is None, only returns the first portion of the mask (0's). 
302 | 303 | Args: 304 | token_ids_0 (:obj:`List[int]`): 305 | List of ids. 306 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 307 | Optional second list of IDs for sequence pairs. 308 | 309 | Returns: 310 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 311 | sequence(s). 312 | """ 313 | sep = [self.sep_token_id] 314 | cls = [self.cls_token_id] 315 | if token_ids_1 is None: 316 | return len(cls + token_ids_0 + sep) * [0] 317 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 318 | 319 | def save_vocabulary(self, vocab_path): 320 | """ 321 | Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. 322 | 323 | Args: 324 | vocab_path (:obj:`str`): 325 | The directory in which to save the vocabulary. 326 | 327 | Returns: 328 | :obj:`Tuple(str)`: Paths to the files saved. 329 | """ 330 | index = 0 331 | if os.path.isdir(vocab_path): 332 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) 333 | else: 334 | vocab_file = vocab_path 335 | with open(vocab_file, "w", encoding="utf-8") as writer: 336 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 337 | if index != token_index: 338 | logger.warning( 339 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 340 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 341 | ) 342 | index = token_index 343 | writer.write(token + "\n") 344 | index += 1 345 | return (vocab_file,) 346 | 347 | 348 | class BasicTokenizer(object): 349 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 350 | 351 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 352 | """ Constructs a BasicTokenizer. 353 | 354 | Args: 355 | **do_lower_case**: Whether to lower case the input. 356 | **never_split**: (`optional`) list of str 357 | Kept for backward compatibility purposes. 358 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 359 | List of token not to split. 360 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 361 | Whether to tokenize Chinese characters. 362 | This should likely be deactivated for Japanese: 363 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 364 | """ 365 | if never_split is None: 366 | never_split = [] 367 | self.do_lower_case = do_lower_case 368 | self.never_split = set(never_split) 369 | self.tokenize_chinese_chars = tokenize_chinese_chars 370 | 371 | def tokenize(self, text, never_split=None): 372 | """ Basic Tokenization of a piece of text. 373 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 374 | 375 | Args: 376 | **never_split**: (`optional`) list of str 377 | Kept for backward compatibility purposes. 378 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 379 | List of token not to split. 380 | """ 381 | # union() returns a new set by concatenating the two sets. 382 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split 383 | 384 | # This was added on November 1st, 2018 for the multilingual and Chinese 385 | # models. 
This is also applied to the English models now, but it doesn't 386 | # matter since the English models were not trained on any Chinese data 387 | # and generally don't have any Chinese data in them (there are Chinese 388 | # characters in the vocabulary because Wikipedia does have some Chinese 389 | # words in the English Wikipedia.). 390 | if self.tokenize_chinese_chars: 391 | text = self._tokenize_chinese_chars(text) 392 | orig_tokens = whitespace_tokenize(text) 393 | split_tokens = [] 394 | for token in orig_tokens: 395 | if self.do_lower_case and token not in never_split: 396 | token = token.lower() 397 | token = self._run_strip_accents(token) 398 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 399 | 400 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 401 | return output_tokens 402 | 403 | def _run_strip_accents(self, text): 404 | """Strips accents from a piece of text.""" 405 | text = unicodedata.normalize("NFD", text) 406 | output = [] 407 | for char in text: 408 | cat = unicodedata.category(char) 409 | if cat == "Mn": 410 | continue 411 | output.append(char) 412 | return "".join(output) 413 | 414 | def _run_split_on_punc(self, text, never_split=None): 415 | """Splits punctuation on a piece of text.""" 416 | if never_split is not None and text in never_split: 417 | return [text] 418 | chars = list(text) 419 | i = 0 420 | start_new_word = True 421 | output = [] 422 | while i < len(chars): 423 | char = chars[i] 424 | if _is_punctuation(char): 425 | output.append([char]) 426 | start_new_word = True 427 | else: 428 | if start_new_word: 429 | output.append([]) 430 | start_new_word = False 431 | output[-1].append(char) 432 | i += 1 433 | 434 | return ["".join(x) for x in output] 435 | 436 | def _tokenize_chinese_chars(self, text): 437 | """Adds whitespace around any CJK character.""" 438 | output = [] 439 | for char in text: 440 | cp = ord(char) 441 | if self._is_chinese_char(cp): 442 | output.append(" ") 443 | output.append(char) 444 | output.append(" ") 445 | else: 446 | output.append(char) 447 | return "".join(output) 448 | 449 | def _is_chinese_char(self, cp): 450 | """Checks whether CP is the codepoint of a CJK character.""" 451 | # This defines a "chinese character" as anything in the CJK Unicode block: 452 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 453 | # 454 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 455 | # despite its name. The modern Korean Hangul alphabet is a different block, 456 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 457 | # space-separated words, so they are not treated specially and handled 458 | # like the all of the other languages. 
459 | if ( 460 | (cp >= 0x4E00 and cp <= 0x9FFF) 461 | or (cp >= 0x3400 and cp <= 0x4DBF) # 462 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 463 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 464 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 465 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 466 | or (cp >= 0xF900 and cp <= 0xFAFF) 467 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 468 | ): # 469 | return True 470 | 471 | return False 472 | 473 | def _clean_text(self, text): 474 | """Performs invalid character removal and whitespace cleanup on text.""" 475 | output = [] 476 | for char in text: 477 | cp = ord(char) 478 | if cp == 0 or cp == 0xFFFD or _is_control(char): 479 | continue 480 | if _is_whitespace(char): 481 | output.append(" ") 482 | else: 483 | output.append(char) 484 | return "".join(output) 485 | 486 | 487 | class WordpieceTokenizer(object): 488 | """Runs WordPiece tokenization.""" 489 | 490 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 491 | self.vocab = vocab 492 | self.unk_token = unk_token 493 | self.max_input_chars_per_word = max_input_chars_per_word 494 | 495 | def tokenize(self, text): 496 | """Tokenizes a piece of text into its word pieces. 497 | 498 | This uses a greedy longest-match-first algorithm to perform tokenization 499 | using the given vocabulary. 500 | 501 | For example: 502 | input = "unaffable" 503 | output = ["un", "##aff", "##able"] 504 | 505 | Args: 506 | text: A single token or whitespace separated tokens. This should have 507 | already been passed through `BasicTokenizer`. 508 | 509 | Returns: 510 | A list of wordpiece tokens. 511 | """ 512 | 513 | output_tokens = [] 514 | for token in whitespace_tokenize(text): 515 | chars = list(token) 516 | if len(chars) > self.max_input_chars_per_word: 517 | output_tokens.append(self.unk_token) 518 | continue 519 | 520 | is_bad = False 521 | start = 0 522 | sub_tokens = [] 523 | while start < len(chars): 524 | end = len(chars) 525 | cur_substr = None 526 | while start < end: 527 | substr = "".join(chars[start:end]) 528 | if start > 0: 529 | substr = "##" + substr 530 | if substr in self.vocab: 531 | cur_substr = substr 532 | break 533 | end -= 1 534 | if cur_substr is None: 535 | is_bad = True 536 | break 537 | sub_tokens.append(cur_substr) 538 | start = end 539 | 540 | if is_bad: 541 | output_tokens.append(self.unk_token) 542 | else: 543 | output_tokens.extend(sub_tokens) 544 | return output_tokens 545 | 546 | 547 | -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/tfserving/inferencer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 
4 | # Apache License v2.0 5 | 6 | 7 | import json 8 | import os 9 | 10 | import requests 11 | 12 | from minimal_rnr.tfserving.bert_tokenizer import BertTokenizer 13 | from minimal_rnr.utils.inferencer import MinimalRnR 14 | 15 | 16 | class TFServingMinimalRnR(MinimalRnR): 17 | def __init__(self, args): 18 | super(TFServingMinimalRnR, self).__init__(args) 19 | 20 | self.url = f"http://{args.tfserving_ip}:{args.tfserving_port}/v1/models/minimal-rnr-qa:predict" 21 | self.tokenizer = BertTokenizer.from_pretrained( 22 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "bert_tokenizer/mobilebert-uncased")) 23 | 24 | def get_question_encoding(self, question): 25 | return self._get_api_output(self._get_question_encoder_api_input(question)) 26 | 27 | def get_retriever_output(self, question, top_k): 28 | retrieved_doc_ids = self._get_api_output(self._get_retrieve_api_input(question, top_k)) 29 | return retrieved_doc_ids 30 | 31 | def get_reader_output(self, reader_input): 32 | reader_api_input = self._get_api_input("read", reader_input) 33 | reader_output = self._get_api_output(reader_api_input) 34 | 35 | if not isinstance(reader_output["relevance_logits"], (tuple, list)): # for top_k=1 36 | reader_output["relevance_logits"] = [reader_output["relevance_logits"]] 37 | 38 | return reader_output 39 | 40 | def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits, attn_mask, passage_score_weight): 41 | weighted_logits = self._get_api_output(self._get_score_api_input(token_logits, relevance_logits, attn_mask, passage_score_weight)) 42 | return weighted_logits 43 | 44 | def _get_question_encoder_api_input(self, question): 45 | input_ = self.tokenizer([question], max_length=self.max_retriever_input_len, truncation=True, return_token_type_ids=True) 46 | return self._get_api_input("encode", input_, preserve_token_type_ids=True) 47 | 48 | def _get_retrieve_api_input(self, question, top_k): 49 | input_ = self.tokenizer([question], max_length=self.max_retriever_input_len, truncation=True) 50 | input_["top_k"] = top_k 51 | 52 | return self._get_api_input("retrieve", input_) 53 | 54 | def _get_score_api_input(self, token_logits, relevace_logits, attn_mask, passage_score_weight): 55 | input_ = { 56 | "token_logits": token_logits, 57 | "relevance_logits": relevace_logits, 58 | "attn_mask": attn_mask, 59 | "passage_score_weight": passage_score_weight, 60 | } 61 | return json.dumps({ 62 | "signature_name": "get_score", 63 | "inputs": input_, 64 | }) 65 | 66 | def _get_api_input(self, signature_name, input_, preserve_token_type_ids=False): 67 | if not preserve_token_type_ids and "token_type_ids" in input_: 68 | del input_["token_type_ids"] 69 | 70 | if type(input_) != dict: 71 | input_ = dict(input_) 72 | 73 | return json.dumps({ 74 | "signature_name": signature_name, 75 | "inputs": input_, 76 | }) 77 | 78 | def _get_api_output(self, payload): 79 | response = requests.post(self.url, payload) 80 | if response.status_code != 200: 81 | response.raise_for_status() 82 | return response.json()["outputs"] -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/utils/__init__.py -------------------------------------------------------------------------------- 
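The TFServingMinimalRnR client above talks to the model server purely over TensorFlow Serving's REST predict endpoint, wrapping every call in a JSON payload of the form {"signature_name": ..., "inputs": ...}. Below is a minimal standalone sketch of the same "retrieve" call, assuming a serving instance that exposes the minimal-rnr-qa model on localhost:8501; the token ids are toy values standing in for real BertTokenizer output.

import json
import requests

# Assumed endpoint: TensorFlow Serving hosting the "minimal-rnr-qa" model on localhost:8501,
# mirroring the URL built in TFServingMinimalRnR.__init__.
url = "http://127.0.0.1:8501/v1/models/minimal-rnr-qa:predict"

payload = json.dumps({
    "signature_name": "retrieve",                # same signature used by get_retriever_output
    "inputs": {
        "input_ids": [[101, 2040, 2003, 102]],   # toy ids; real ones come from the MobileBERT vocab
        "attention_mask": [[1, 1, 1, 1]],
        "top_k": 5,                              # added to the inputs dict, as in _get_retrieve_api_input
    },
})

response = requests.post(url, data=payload)
response.raise_for_status()
retrieved_doc_ids = response.json()["outputs"]   # indices into the *.titles.txt / *.docs.txt resources
print(retrieved_doc_ids)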
/playground/workspace/minimal_rnr/utils/demo.py: -------------------------------------------------------------------------------- 1 | # Copied and modified from https://github.com/uwnlp/denspi 2 | 3 | from time import time 4 | 5 | from flask import Flask, request, jsonify 6 | from tornado.httpserver import HTTPServer 7 | from tornado.ioloop import IOLoop 8 | from tornado.wsgi import WSGIContainer 9 | 10 | from minimal_rnr.utils.logger import get_logger 11 | 12 | 13 | def run_app(args, minimal_rnr): 14 | logger = get_logger("minimal-rnr-qa") 15 | 16 | inference_api = minimal_rnr.get_inference_api() 17 | 18 | app = Flask(__name__, static_url_path='/static') 19 | app.config["JSONIFY_PRETTYPRINT_REGULAR"] = False 20 | 21 | def _search(query, top_k, passage_score_weight): 22 | start = time() 23 | result = inference_api(query, top_k, passage_score_weight) 24 | return {"ret": result, "time": int((time() - start))} 25 | 26 | @app.route("/") 27 | def index(): 28 | return app.send_static_file('index.html') 29 | 30 | @app.route("/files/") 31 | def static_files(path): 32 | return app.send_static_file('files/' + path) 33 | 34 | @app.route("/api", methods=["GET"]) 35 | def api(): 36 | logger.info(request.args) 37 | 38 | query = request.args["query"] 39 | top_k = int(request.args["top_k"]) 40 | 41 | if request.args["passage_score_weight"] == "null": 42 | passage_score_weight = None 43 | else: 44 | passage_score_weight = float(request.args["passage_score_weight"]) 45 | 46 | result = _search(query, top_k, passage_score_weight) 47 | logger.info(result) 48 | return jsonify(result) 49 | 50 | @app.route("/get_examples", methods=["GET"]) 51 | def get_examples(): 52 | with open(args.examples_path, "r") as fp: 53 | examples = [line.strip() for line in fp.readlines()] 54 | return jsonify(examples) 55 | 56 | @app.route("/quit") 57 | def quit(): 58 | raise KeyboardInterrupt 59 | 60 | logger.info("Warming up...") 61 | minimal_rnr.predict_answer("warmup", top_k=5, passage_score_weight=0.8) 62 | 63 | logger.info(f"Starting server at {args.demo_port}") 64 | http_server = HTTPServer(WSGIContainer(app)) 65 | http_server.listen(args.demo_port) 66 | IOLoop.instance().start() -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ****************************************************************** 3 | # Copied and modified from https://github.com/facebookresearch/FiD * 4 | # ****************************************************************** 5 | # Copyright (c) Facebook, Inc. and its affiliates. 6 | # All rights reserved. 7 | # 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 
10 | 11 | import string 12 | import unicodedata 13 | 14 | import regex 15 | 16 | 17 | def _normalize(text): 18 | # Copied from SQuAD evaluation script 19 | # https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 20 | return unicodedata.normalize('NFD', text) 21 | 22 | 23 | def normalize_answer(s): 24 | def remove_articles(text): 25 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 26 | 27 | def white_space_fix(text): 28 | return ' '.join(text.split()) 29 | 30 | def remove_punc(text): 31 | exclude = set(string.punctuation) 32 | return ''.join(ch for ch in text if ch not in exclude) 33 | 34 | def lower(text): 35 | return text.lower() 36 | 37 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 38 | 39 | 40 | def exact_match_score(prediction, ground_truth): 41 | return normalize_answer(prediction) == normalize_answer(ground_truth) 42 | 43 | 44 | def ems(prediction, ground_truths): 45 | return max([exact_match_score(prediction, gt) for gt in ground_truths]) -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/inference.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import glob 3 | import json 4 | import os 5 | import regex 6 | 7 | 8 | class SimpleTokenizer(object): 9 | # Copied and modified from https://github.com/facebookresearch/FiD 10 | 11 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 12 | NON_WS = r'[^\p{Z}\p{C}]' 13 | 14 | def __init__(self): 15 | self._regexp = regex.compile( 16 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 17 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 18 | ) 19 | 20 | def tokenize(self, text, uncased=False): 21 | matches = [m for m in self._regexp.finditer(text)] 22 | if uncased: 23 | tokens = [m.group().lower() for m in matches] 24 | else: 25 | tokens = [m.group() for m in matches] 26 | return tokens 27 | 28 | 29 | def extend_span_to_full_words(tokenizer, tokens, span): 30 | # Copied and modified from https://github.com/facebookresearch/DPR 31 | 32 | def is_sub_word_id(tokenizer, token_id): 33 | token = tokenizer.convert_ids_to_tokens([token_id])[0] 34 | return token.startswith("##") or token.startswith(" ##") 35 | 36 | start_index, end_index = span 37 | max_len = len(tokens) 38 | while start_index > 0 and is_sub_word_id(tokenizer, tokens[start_index]): 39 | start_index -= 1 40 | 41 | while end_index < max_len - 1 and is_sub_word_id(tokenizer, tokens[end_index + 1]): 42 | end_index += 1 43 | 44 | return start_index, end_index 45 | 46 | 47 | def read_txt(file_path): 48 | data = [] 49 | with open(file_path) as f: 50 | for line in f: 51 | ids = [int(x) for x in line.rstrip().split(" ")] 52 | data.append(ids) 53 | return data 54 | 55 | 56 | def get_first_matched_file_path(directory, dataset, file_pattern): 57 | return sorted(glob.glob(os.path.join(directory, dataset, file_pattern)))[0] 58 | 59 | 60 | def normalize_question(question): 61 | question = question.lower() 62 | if question.endswith("?"): 63 | question = question[:-1] 64 | return question 65 | 66 | 67 | def _read_through_csv_qa_file(file_path): 68 | with open(file_path, encoding="utf-8") as f: 69 | reader = csv.reader(f, delimiter='\t') 70 | for row in reader: 71 | question = normalize_question(row[0]) 72 | answers = None if len(row) < 2 else row[1] 73 | 74 | if answers is None: 75 | qa_dict = {"question": question} 76 | else: 77 | qa_dict = {"question": question, "answers": answers} 78 | yield qa_dict 79 | 80 | 81 | def 
_read_through_jsonl_qa_file(file_path): 82 | with open(file_path, encoding="utf-8") as f: 83 | for line in f: 84 | qa_dict = json.loads(line.strip()) 85 | qa_dict["question"] = normalize_question(qa_dict["question"]) 86 | 87 | if "answer" in qa_dict: 88 | qa_dict["answers"] = qa_dict["answer"] 89 | del qa_dict["answer"] 90 | 91 | yield qa_dict 92 | 93 | 94 | def read_through_qa_file(file_path): 95 | if file_path.endswith(".csv"): 96 | return _read_through_csv_qa_file(file_path) 97 | elif file_path.endswith(".jsonl") or file_path.endswith("jl"): 98 | return _read_through_jsonl_qa_file(file_path) 99 | else: 100 | raise ValueError(f"File {file_path} must be either csv or jsonlines file") -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/inferencer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 4 | # Apache License v2.0 5 | 6 | import abc 7 | import json 8 | from timeit import default_timer as timer 9 | 10 | from minimal_rnr.utils.inference import read_txt, get_first_matched_file_path, read_through_qa_file, \ 11 | extend_span_to_full_words 12 | from minimal_rnr.utils.logger import get_logger 13 | 14 | 15 | class MinimalRnR(object, metaclass=abc.ABCMeta): 16 | def __init__(self, args): 17 | self.logger = get_logger("minimal-rnr-qa") 18 | 19 | titles_file = get_first_matched_file_path(args.resources_path, args.dataset, "*.titles.txt") 20 | self.logger.info(f"Loading titles from {titles_file}...") 21 | self.all_titles = read_txt(titles_file) 22 | 23 | docs_file = get_first_matched_file_path(args.resources_path, args.dataset, "*.docs.txt") 24 | self.logger.info(f"Loading docs from {docs_file}...") 25 | self.all_docs = read_txt(docs_file) 26 | 27 | if args.use_faiss_index: 28 | import faiss 29 | index_file = get_first_matched_file_path(args.resources_path, args.dataset, "*.index") 30 | self.logger.info(f"Loading index from {index_file}...") 31 | self.index = faiss.read_index(index_file) 32 | 33 | import numpy as np 34 | self.np = np 35 | else: 36 | self.index = None 37 | 38 | self.tokenizer = None # must be overriden 39 | 40 | self.max_retriever_input_len = 256 41 | self.max_reader_input_len = 350 42 | self.max_answer_len = 10 43 | self.num_contexts = 10 44 | self.num_passage_answer_candidates = 5 45 | 46 | @abc.abstractmethod 47 | def get_question_encoding(self, question): 48 | pass 49 | 50 | @abc.abstractmethod 51 | def get_retriever_output(self, question, top_k): 52 | pass 53 | 54 | @abc.abstractmethod 55 | def get_reader_output(self, reader_input): 56 | pass 57 | 58 | @abc.abstractmethod 59 | def get_passage_score_weighted_answer_token_logits(self, token_logits, relevance_logits, attn_mask, passage_score_weight): 60 | pass 61 | 62 | def maybe_tensor_to_list(self, tensor): 63 | return tensor 64 | 65 | def get_inference_api(self): 66 | def api(question, top_k, passage_score_weight): 67 | return self.predict_answer(question, top_k, passage_score_weight, return_context=True) 68 | 69 | return api 70 | 71 | def inference_on_file(self, input_path, output_path, top_k=50, passage_score_weight=None): 72 | num_correct = 0 73 | total = 0 74 | 75 | evaluation_start = timer() 76 | with open(output_path, "w", encoding="utf-8") as f: 77 | for qa_dict in read_through_qa_file(input_path): 78 | start = timer() 79 | question = qa_dict["question"] 80 | prediction = self.predict_answer(question, top_k, passage_score_weight) 81 | 
qa_dict["prediction"] = prediction 82 | 83 | if "answers" in qa_dict: 84 | from minimal_rnr.utils.evaluation import ems 85 | correct = ems(prediction, qa_dict["answers"]) 86 | qa_dict["correct"] = correct 87 | 88 | total += 1 89 | num_correct += int(correct) 90 | 91 | f.write(json.dumps(qa_dict, ensure_ascii=False) + "\n") 92 | self.logger.info(str(qa_dict) + " (%.4fs)" % (timer() - start)) 93 | 94 | if total > 0: 95 | self.logger.info(f"EM: {100 * num_correct / total:.4f} ({num_correct} / {total})") 96 | self.logger.info(f"Evaluation took {timer() - evaluation_start:.4f}s.") 97 | 98 | def predict_answer(self, question, top_k=50, passage_score_weight=None, return_context=False): 99 | # retrieve 100 | start = timer() 101 | if self.index: 102 | retrieved_doc_ids = self._get_retriever_output_from_faiss_index(question, top_k) 103 | else: 104 | retrieved_doc_ids = self.get_retriever_output(question, top_k) 105 | 106 | self.logger.info(f" retrieve: {timer() - start:.4f}s") 107 | 108 | start = timer() 109 | title_doc_dict = self._get_title_doc_dict(retrieved_doc_ids) 110 | reader_input = self._get_reader_input(question, title_doc_dict) 111 | self.logger.info(f" convert: {timer() - start:.4f}s") 112 | 113 | # read 114 | start = timer() 115 | reader_output = self.get_reader_output(reader_input) 116 | self.logger.info(f" read: {timer() - start:.4f}s") 117 | 118 | start = timer() 119 | if passage_score_weight is not None: 120 | answer = self._get_answer_deep(reader_input, reader_output, title_doc_dict["titles"], passage_score_weight, return_context=return_context) 121 | else: 122 | answer = self._get_answer_greedy(reader_input, reader_output, title_doc_dict["titles"], return_context=return_context) 123 | self.logger.info(f" search: {timer() - start:.4f}s") 124 | return answer 125 | 126 | def _get_retriever_output_from_faiss_index(self, question, top_k): 127 | start = timer() 128 | question_encoding = self.get_question_encoding(question) 129 | self.logger.info(f" * encode: {timer() - start:.4f}s") 130 | 131 | start = timer() 132 | if not isinstance(question_encoding, self.np.ndarray): # tfserving_faiss 133 | question_encoding = self.np.asarray(question_encoding, dtype=self.np.float32) 134 | _, np_retrieved_doc_ids = self.index.search(question_encoding, top_k) 135 | self.logger.info(f" * faiss search: {timer() - start:.4f}s") 136 | retrieved_doc_ids = np_retrieved_doc_ids[0] 137 | return retrieved_doc_ids 138 | 139 | def _get_title_doc_dict(self, retrieved_doc_ids): 140 | retrieved_titles = [] 141 | retrieved_docs = [] 142 | 143 | for i in retrieved_doc_ids: 144 | retrieved_titles.append(self.all_titles[i]) 145 | retrieved_docs.append(self.all_docs[i]) 146 | 147 | return { 148 | "titles": retrieved_titles, 149 | "docs": retrieved_docs, 150 | } 151 | 152 | def _get_reader_input(self, question_str, title_doc_dict): 153 | input_ids = [] 154 | attention_mask = [] 155 | 156 | retrieved_titles = title_doc_dict["titles"] 157 | retrieved_docs = title_doc_dict["docs"] 158 | 159 | question = self.tokenizer.encode(question_str, max_length=self.max_retriever_input_len, truncation=True) 160 | 161 | # concat inputs 162 | for title, doc in zip(retrieved_titles, retrieved_docs): 163 | concat = question + title + [self.tokenizer.sep_token_id] + doc 164 | concat = concat[:self.max_reader_input_len] 165 | input_ids.append(concat) 166 | max_len = max(len(ids) for ids in input_ids) 167 | 168 | # pad inputs 169 | for i in range(len(input_ids)): 170 | padding = [self.tokenizer.pad_token_id] * (max_len - len(input_ids[i])) 
171 | attention_mask.append([1] * len(input_ids[i]) + padding) 172 | input_ids[i] = input_ids[i] + padding 173 | 174 | return { 175 | "input_ids": input_ids, 176 | "attention_mask": attention_mask, 177 | } 178 | 179 | def _get_answer_deep(self, reader_input, reader_output, retrieved_titles, passage_score_weight=None, return_context=False): 180 | input_ids = reader_input["input_ids"] 181 | attn_mask = reader_input["attention_mask"] 182 | 183 | _start_logits = reader_output["start_logits"] 184 | _end_logits = reader_output["end_logits"] 185 | _relevance_logits = reader_output["relevance_logits"] 186 | 187 | # weighted scores 188 | start_logits = self.get_passage_score_weighted_answer_token_logits(_start_logits, _relevance_logits, attn_mask, passage_score_weight) 189 | end_logits = self.get_passage_score_weighted_answer_token_logits(_end_logits, _relevance_logits, attn_mask, passage_score_weight) 190 | 191 | start_logits = self.maybe_tensor_to_list(start_logits) 192 | end_logits = self.maybe_tensor_to_list(end_logits) 193 | 194 | candidate_answers = [] 195 | candidate_contexts = [] 196 | 197 | for passage_idx in range(len(input_ids)): 198 | sequence_len = sum(id_ != 0 for id_ in input_ids[passage_idx]) 199 | passage_offset = input_ids[passage_idx].index(self.tokenizer.sep_token_id) + 1 200 | title_passage_ids = input_ids[passage_idx][passage_offset:sequence_len] 201 | 202 | p_start_logits = start_logits[passage_idx][passage_offset:sequence_len] 203 | p_end_logits = end_logits[passage_idx][passage_offset:sequence_len] 204 | 205 | scores = self._get_spans_sorted_with_scores(p_start_logits, p_end_logits) 206 | 207 | chosen_span_intervals = [] 208 | p_candidate_answers = [] 209 | p_candidate_contexts = [] 210 | 211 | for (start_index, end_index), score in scores: 212 | ret = self._get_answer_and_passage(start_index, end_index, chosen_span_intervals, title_passage_ids) 213 | if not ret: 214 | continue 215 | else: 216 | answer, passage, start_index, ent_index = ret 217 | title = retrieved_titles[passage_idx] 218 | 219 | if not return_context: 220 | p_candidate_answers.append((answer, score)) 221 | else: 222 | context = (self.tokenizer.decode(passage[:start_index]) 223 | + " " + answer + " " 224 | + self.tokenizer.decode(passage[end_index + 1:])) 225 | p_candidate_contexts.append(({"title": self.tokenizer.decode(title), 226 | "context": context}, score)) 227 | 228 | if max(len(p_candidate_answers), len(p_candidate_contexts)) == self.num_passage_answer_candidates: 229 | break 230 | 231 | if p_candidate_answers: 232 | candidate_answers.extend(p_candidate_answers) 233 | if p_candidate_contexts: 234 | candidate_contexts.extend(p_candidate_contexts) 235 | 236 | if not return_context: 237 | sorted_candidate_answers = sorted(candidate_answers, key=lambda x: x[1], reverse=True) 238 | return sorted_candidate_answers[0][0].strip() 239 | else: 240 | sorted_candidate_contexts = sorted(candidate_contexts, key=lambda x: x[1], reverse=True) 241 | return [context[0] for context in sorted_candidate_contexts][:self.num_contexts] 242 | 243 | def _get_answer_greedy(self, reader_input, reader_output, retrieved_titles, return_context=False): 244 | start_logits = reader_output["start_logits"] 245 | end_logits = reader_output["end_logits"] 246 | relevance_logits = reader_output["relevance_logits"] 247 | 248 | start_logits = self.maybe_tensor_to_list(start_logits) 249 | end_logits = self.maybe_tensor_to_list(end_logits) 250 | relevance_logits = self.maybe_tensor_to_list(relevance_logits) 251 | 252 | max_answer_length = 10 
253 | input_ids = reader_input["input_ids"] 254 | 255 | top_passage_idx = max(enumerate(relevance_logits), key=lambda x: x[1])[0] 256 | 257 | sequence_len = sum(id_ != 0 for id_ in input_ids[top_passage_idx]) 258 | passage_offset = input_ids[top_passage_idx].index(self.tokenizer.sep_token_id) + 1 259 | title_passage_ids = input_ids[top_passage_idx][passage_offset:sequence_len] 260 | p_start_logits = start_logits[top_passage_idx][passage_offset:sequence_len] 261 | p_end_logits = end_logits[top_passage_idx][passage_offset:sequence_len] 262 | 263 | scores = self._get_spans_sorted_with_scores(p_start_logits, p_end_logits) 264 | chosen_span_intervals = [] 265 | 266 | for (start_index, end_index), score in scores: 267 | assert start_index <= end_index 268 | length = end_index - start_index + 1 269 | assert length <= max_answer_length 270 | 271 | ret = self._get_answer_and_passage(start_index, end_index, chosen_span_intervals, title_passage_ids) 272 | if not ret: 273 | continue 274 | answer, passage, start_index, end_index = ret 275 | title = retrieved_titles[top_passage_idx] 276 | 277 | if not return_context: 278 | return answer.strip() 279 | else: 280 | context = (self.tokenizer.decode(passage[:start_index]) 281 | + " " + answer + " " 282 | + self.tokenizer.decode(passage[end_index + 1:])) 283 | return [{"title": self.tokenizer.decode(title), "context": context}] 284 | 285 | def _get_spans_sorted_with_scores(self, p_start_logits, p_end_logits): 286 | scores = [] 287 | for (i, s) in enumerate(p_start_logits): 288 | for (j, e) in enumerate(p_end_logits[i:i + self.max_answer_len]): 289 | scores.append(((i, i + j), s + e)) 290 | scores = sorted(scores, key=lambda x: x[1], reverse=True) 291 | return scores 292 | 293 | def _get_answer_and_passage(self, start_index, end_index, chosen_span_intervals, title_passage_ids): 294 | assert start_index <= end_index 295 | length = end_index - start_index + 1 296 | assert length <= self.max_answer_len 297 | 298 | if any([start_index <= prev_start_index <= prev_end_index <= end_index or 299 | prev_start_index <= start_index <= end_index <= prev_end_index 300 | for (prev_start_index, prev_end_index) in chosen_span_intervals]): 301 | return 302 | 303 | start_index, end_index = extend_span_to_full_words(self.tokenizer, title_passage_ids, (start_index, end_index)) 304 | 305 | title_offset = title_passage_ids.index(self.tokenizer.sep_token_id) + 1 306 | 307 | passage = title_passage_ids[title_offset:] 308 | start_index -= title_offset 309 | end_index -= title_offset 310 | 311 | answer = self.tokenizer.decode(passage[start_index:end_index + 1]) 312 | 313 | return answer, passage, start_index, end_index 314 | -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/logger.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 
4 | # Apache License v2.0 5 | 6 | import sys 7 | import logging 8 | 9 | 10 | def get_logger(name): 11 | if logging.getLogger(name).hasHandlers(): 12 | return logging.getLogger(name) 13 | 14 | # initialization 15 | formatter = logging.Formatter(fmt="[MinR&R %(asctime)s] %(message)s", 16 | datefmt="%Y-%m-%d %H:%M:%S") 17 | handler = logging.StreamHandler(stream=sys.stdout) 18 | handler.setFormatter(formatter) 19 | 20 | logger = logging.getLogger(name) 21 | logger.setLevel(logging.DEBUG) 22 | logger.addHandler(handler) 23 | logger.propagate = False 24 | return logger -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/static/examples.txt: -------------------------------------------------------------------------------- 1 | What is the purpose of life? 2 | What is Naver known for? 3 | Famous musician in South Korea 4 | Where can you find water in desert? 5 | Name three famous writers 6 | Why is regular exercise important? 7 | Which city is famous for coffee? 8 | who is the villain in Harry Potter? 9 | What does a developer do? 10 | When is National Liberation Day of Korea? 11 | What is another term for x-ray imaging? 12 | When did the movie Avengers come out? 13 | What is water consisted of? 14 | Who is the director of the movie Interstellar? 15 | In what year did Nikola Tesla emigrate to the United States? 16 | Who coined the term Deep Learning? 17 | The most beautiful poem in the world 18 | What is Machine Learning? 19 | Who is the main character in Frozen? -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/static/files/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/minimal-rnr-qa/9db881a031ec67a661b71f56598ae8720dc946eb/playground/workspace/minimal_rnr/utils/static/files/icon.png -------------------------------------------------------------------------------- /playground/workspace/minimal_rnr/utils/static/files/popper.min.js: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) Federico Zivolo 2018 3 | Distributed under the MIT License (license terms are at http://opensource.org/licenses/MIT). 
4 | */
[... body of popper.min.js omitted: minified, vendored Popper.js (MIT License, Copyright (C) Federico Zivolo 2018); the extracted text is character-damaged and contains no project-specific code ...]
--------------------------------------------------------------------------------
/playground/workspace/minimal_rnr/utils/static/index.html:
--------------------------------------------------------------------------------
[... demo page markup omitted: the HTML tags of index.html did not survive extraction. Recoverable content: page title "Minimal R&R QA"; a query box with top_k and passage_score_weight controls; a latency display; the corpus note "5.8% of Wikipedia EN (Dec. 20, 2018 dump)"; and the citation "Sohee Yang and Minjoon Seo. Designing a Minimal Retrieve-and-Read System for Open-Domain Question Answering. In NAACL 2021." Only bare line numbers survive from the remainder of the file. ...]
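The demo page's JavaScript (omitted above) drives the Flask /api route defined in utils/demo.py; the same query can be issued directly over HTTP. A minimal sketch, assuming a demo already running on the default --demo_port of 10001 and using one of the bundled example questions; passage_score_weight may be a float string or "null" for the greedy answer-extraction path.

import requests

# Assumed target: run_pt_demo.py / run_tf_demo.py serving on localhost:10001.
resp = requests.get(
    "http://localhost:10001/api",
    params={
        "query": "Who is the main character in Frozen?",  # taken from static/examples.txt
        "top_k": 10,
        "passage_score_weight": "0.8",                    # "null" selects the greedy reader path
    },
)
resp.raise_for_status()
result = resp.json()
# result["ret"] holds up to num_contexts {"title": ..., "context": ...} entries,
# result["time"] the server-side latency in whole seconds.
print(result)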
103 | 104 | 105 | 184 | 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /playground/workspace/run_pt_demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 4 | # Apache License v2.0 5 | 6 | import os 7 | import argparse 8 | import glob 9 | from distutils.util import strtobool 10 | 11 | from minimal_rnr.utils.demo import run_app 12 | from minimal_rnr.utils.logger import get_logger 13 | from minimal_rnr.pytorch.inferencer import TorchMinimalRnR 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--dataset", type=str, required=True) 19 | 20 | parser.add_argument("--model_path", default="/models") 21 | parser.add_argument("--resources_path", default="/resources") 22 | parser.add_argument("--use_faiss_index", default=True, type=strtobool) 23 | 24 | parser.add_argument("--demo_port", default=10001, type=int) 25 | parser.add_argument("--examples_path", default="/workspace/minimal_rnr/utils/static/examples.txt") 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def main(args): 31 | logger = get_logger("minimal-rnr-qa") 32 | logger.info(vars(args)) 33 | 34 | if args.use_faiss_index: 35 | assert glob.glob(os.path.join(args.resources_path, args.dataset, "*.index")), \ 36 | f"Index file does not exist in the path: {os.path.join(args.resources_path, args.dataset)}" 37 | 38 | minimal_rnr = TorchMinimalRnR(args) 39 | run_app(args, minimal_rnr) 40 | 41 | 42 | if __name__ == "__main__": 43 | args = get_args() 44 | main(args) 45 | -------------------------------------------------------------------------------- /playground/workspace/run_pt_inference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 4 | # Apache License v2.0 5 | 6 | import argparse 7 | from distutils.util import strtobool 8 | 9 | from minimal_rnr.utils.logger import get_logger 10 | from minimal_rnr.pytorch.inferencer import TorchMinimalRnR 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--dataset", type=str, required=True) 16 | 17 | parser.add_argument("--model_path", default="/models") 18 | parser.add_argument("--resources_path", default="/resources") 19 | parser.add_argument("--use_faiss_index", default=True, type=strtobool) 20 | 21 | parser.add_argument("--top_k", type=int, default=50) 22 | parser.add_argument("--passage_score_weight", type=float, default=None) 23 | 24 | parser.add_argument("--input_path", type=str, default=None, required=True) 25 | parser.add_argument("--output_path", type=str, default="predictions.jsonl", required=True) 26 | 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def main(args): 32 | logger = get_logger("minimal-rnr-qa") 33 | logger.info(vars(args)) 34 | 35 | minimal_rnr = TorchMinimalRnR(args) 36 | minimal_rnr.inference_on_file(args.input_path, args.output_path, args.top_k, args.passage_score_weight) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | main(args) 42 | -------------------------------------------------------------------------------- /playground/workspace/run_tf_demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 
4 | # Apache License v2.0 5 | 6 | import os 7 | import argparse 8 | import glob 9 | from distutils.util import strtobool 10 | 11 | from minimal_rnr.utils.logger import get_logger 12 | from minimal_rnr.utils.demo import run_app 13 | from minimal_rnr.tfserving.inferencer import TFServingMinimalRnR 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--dataset", type=str, required=True) 19 | 20 | parser.add_argument("--tfserving_ip", default="127.0.0.1") 21 | parser.add_argument("--tfserving_port", default=8501) 22 | 23 | parser.add_argument("--resources_path", default="/resources") 24 | parser.add_argument("--use_faiss_index", default=True, type=strtobool) 25 | 26 | parser.add_argument("--demo_port", default=10001, type=int) 27 | parser.add_argument("--examples_path", default="/workspace/minimal_rnr/utils/static/examples.txt") 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def main(args): 33 | logger = get_logger("minimal-rnr-qa") 34 | logger.info(vars(args)) 35 | 36 | if args.use_faiss_index: 37 | assert glob.glob(os.path.join(args.resources_path, args.dataset, "*.index")), \ 38 | f"Index file does not exist in the path: {os.path.join(args.resources_path, args.dataset)}" 39 | 40 | minimal_rnr = TFServingMinimalRnR(args) 41 | run_app(args, minimal_rnr) 42 | 43 | 44 | if __name__ == "__main__": 45 | args = get_args() 46 | main(args) 47 | -------------------------------------------------------------------------------- /playground/workspace/run_tf_inference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # minimal-rnr-qa 3 | # Copyright 2021-present NAVER Corp. 4 | # Apache License v2.0 5 | 6 | import argparse 7 | from distutils.util import strtobool 8 | 9 | from minimal_rnr.utils.logger import get_logger 10 | from minimal_rnr.tfserving.inferencer import TFServingMinimalRnR 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--dataset", type=str, required=True) 16 | 17 | parser.add_argument("--tfserving_ip", default="127.0.0.1") 18 | parser.add_argument("--tfserving_port", default=8501) 19 | 20 | parser.add_argument("--resources_path", default="/resources") 21 | parser.add_argument("--use_faiss_index", default=True, type=strtobool) 22 | 23 | parser.add_argument("--top_k", type=int, default=50) 24 | parser.add_argument("--passage_score_weight", type=float, default=None) 25 | 26 | parser.add_argument("--input_path", type=str, default=None, required=True) 27 | parser.add_argument("--output_path", type=str, default="predictions.jsonl", required=True) 28 | 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def main(args): 34 | logger = get_logger("minimal-rnr-qa") 35 | logger.info(vars(args)) 36 | 37 | minimal_rnr = TFServingMinimalRnR(args) 38 | minimal_rnr.inference_on_file(args.input_path, args.output_path, args.top_k, args.passage_score_weight) 39 | 40 | 41 | if __name__ == "__main__": 42 | args = get_args() 43 | main(args) 44 | --------------------------------------------------------------------------------
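Both inference entry points above read their questions through read_through_qa_file (utils/inference.py), which accepts a tab-separated .csv file or JSON Lines. A minimal sketch of the JSON-Lines form and of what inference_on_file writes back, with illustrative file names and answer strings:

import json

# Hypothetical input in the form accepted by _read_through_jsonl_qa_file:
# one JSON object per line; an "answer" key is also accepted and renamed to "answers".
examples = [
    {"question": "Who is the director of the movie Interstellar?", "answers": ["Christopher Nolan"]},
    {"question": "What is another term for x-ray imaging?"},  # no gold answers: only a prediction is written
]
with open("questions.jsonl", "w", encoding="utf-8") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")

# Running e.g. run_tf_inference.py with --input_path questions.jsonl --output_path predictions.jsonl
# emits one JSON object per question, roughly:
#   {"question": ..., "answers": [...], "prediction": "...", "correct": true}
# and logs an exact-match (EM) summary when gold answers are present.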