├── .gitignore
├── LICENSE.txt
├── README.md
├── config
│   ├── bertserini-ds
│   │   └── bertserini-ds-new-ranker.yaml
│   ├── bertserini
│   │   └── bertserini.yaml
│   ├── multipassage-bert
│   │   └── multipassage-bert.yaml
│   └── sparta
│       ├── cmrc2018-sparta-rbt.yaml
│       ├── drcd-dev-sparta-rbt.yaml
│       ├── nq-dev-sparta-spanbert.yaml
│       ├── nq-test-sparta-spanbert.yaml
│       ├── online-eval-zh-cmrc2018-sparta-rbt.yaml
│       └── squad-sparta-spanbert.yaml
├── demo.py
├── example.py
├── imgs
│   ├── demo.gif
│   └── logo.png
├── reports
│   ├── bertserini-ds-new-ranker.txt
│   ├── bertserini.txt
│   ├── cmrc2018-sparta-rbt.txt
│   ├── drcd-dev-sparta-rbt.txt
│   ├── multipassage-bert.txt
│   ├── nq-dev-sparta-spanbert.txt
│   ├── nq-test-sparta-spanbert.txt
│   ├── online-eval-zh-cmrc2018-sparta-rbt.txt
│   └── squad-sparta-spanbert.txt
├── requirements.txt
└── soco_openqa
    ├── __init__.py
    ├── cloud_bucket.py
    ├── config.py
    ├── demo
    │   ├── __init__.py
    │   ├── helper.py
    │   ├── qa.py
    │   ├── ranker.py
    │   └── reader.py
    ├── evaluation.py
    ├── helper.py
    ├── pipeline.py
    ├── ranker.py
    ├── reader.py
    └── soco_mrc
        ├── models
        │   └── bert_model.py
        ├── mrc_model.py
        └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | data/
106 | data-local
107 | data
108 | .idea/
109 | prod_models/
110 | logs/
111 | temp/
112 | mlruns/
113 | cache/
114 | resources-local/
115 | resources/
116 | resources
117 | 
118 | 
119 | .vscode/
120 | 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
A Simple and Fair Evaluation Library for Open-domain Question Answering
20 | 
21 | 
22 | # Introduction
23 | Open-domain QA evaluation usually means days of tedious work indexing large-scale data and building the pipeline. That's why we created SF-QA [[EACL 2021 Demo paper]](https://www.aclweb.org/anthology/2021.eacl-demos.2/).
24 | 
25 | We make the evaluation process simple and fast by establishing a standard evaluation system with a shared knowledge source, so researchers can focus on their proposed method instead of spending significant time
26 | reproducing existing approaches.
27 | 
28 | SF-QA helps you evaluate your open-domain QA system without building the entire pipeline. It supports:
29 | - Efficient Reader Comparison
30 | - Reproducible Research
31 | - Knowledge Source for Applications
32 | 
33 | # Features
34 | ✨ **Easy evaluation framework:** built especially for open-domain QA
35 | ✨ **Pre-trained Wiki dataset:** no need to train it yourself
36 | ✨ **Scalable:** set your own configuration and evaluate at open-domain scale
37 | ✨ **Open source:** everyone can contribute
38 | 
39 | 
40 | 
41 | 
42 | # Installation
43 | SF-QA requires Python 3, PyTorch 1.6.0, Transformers 2.11.0, and Elasticsearch 7.5.0.
44 | 
45 | ```
46 | pip install -r requirements.txt
47 | ```
48 | 
49 | # How to Use SF-QA
50 | 
51 | 
52 | ## Usage 1: Ask questions and interact with the open-QA demo
53 | 
54 | ### 1. In `demo.py`, set the reader model and ranker index you want to try from the following lists:
55 | 
56 | * Available reader models:
57 |     * English:
58 |         - squad-ds-context-global-norm-2016sparta-from-pt
59 |         - squad-chunk-global-norm-2016bm25-bert-base-uncased
60 |     * Chinese:
61 |         - cmrc2018-ds-context-global-norm-2018sparta-from-pt-v1
62 |         - drcd-ds-context-global-norm-2017sparta-from-pt-v1
63 | 
64 | * Available ranker indexes:
65 |     * English:
66 |         - sparta-en-wiki-2016
67 |         - bm25-en-wiki-2016
68 |     * Chinese:
69 |         - sparta-zh-wiki-2020
70 | 
71 | ### 2. Run the demo
72 | 
73 | ```
74 | python demo.py
75 | ```
76 | 
77 | 
78 | 
79 | ## Usage 2: Reproduce previous research results
80 | Run the evaluation script with a config file from the `./config` folder:
81 | ```
82 | python example.py --config ./config/sparta/squad-sparta-spanbert.yaml
83 | ```
84 | 
85 | ```yaml
86 | data:
87 |   lang: en
88 |   name: squad
89 |   split: dev-v1.1.json
90 | ranker:
91 |   use_cached: True
92 |   cached_ranker_file: squad_dev_wiki_sparta_2016sparta.json
93 | reader:
94 |   model_id: squad-ds-context-global-norm-2016sparta-from-pt
95 | param:
96 |   n_gpu: 1
97 |   score_weight: 0.8
98 |   top_k: 50
99 | 
100 | ```
101 | 
102 | 
103 | 
104 | ---
105 | 
106 | ## SF-QA APIs
107 | - POST
108 |   - https://api.soco.ai/v1/sfqa/query
109 | - Header
110 |   - Authorization: soco_research
111 |   - Content-Type: application/json
112 | 
113 | - Body
114 | ```
115 | {
116 |     "lang": "en",
117 |     "index": "wiki-frame-2016",
118 |     "model_id": "spanbert-large-squad2",
119 |     "query": "What Swiss city was the center of the Calvinist movement?",
120 |     "params": {
121 |         "top_k": 10,
122 |         "n_best": 2,
123 |         "ranker_only": true
124 |     }
125 | }
126 | ```
127 | 
128 | - Response
129 | ```
130 | {
131 |     "result": [
132 |         {
133 |             "value": "Geneva",
134 |             "score": 8.959972752954101,
135 |             "prob": 0.981730729341507,
136 |             "source": {
137 |                 "context": "Huguenot. The nickname may have been a combined reference to the Swiss politician Besançon Hugues (died 1532) and the religiously conflicted nature of Swiss republicanism in his time, using a clever derogatory pun on the name \"Hugues\" by way of the Dutch word \"Huisgenoten\" (literally \"housemates\"), referring to the connotations of a somewhat related word in German \"Eidgenosse\" (\"Confederates\" as in \"a citizen of one of the states of the Swiss Confederacy\"). Geneva was John Calvin's adopted home and the centre of the Calvinist movement. In Geneva, Hugues, though Catholic, was a leader of the \"Confederate Party\", so called because it favoured independence from the Duke of Savoy through an alliance between the city-state of Geneva and the Swiss Confederation.",
138 |                 "url": null,
139 |                 "title": null
140 |             }
141 |         },
142 |         ...
143 |     ]
144 | }
145 | ```
146 | When `ranker_only` is true, the API skips the reader and returns only the ranker's sentence-level answer in `value`, together with its context:
147 | ```
148 | {
149 |     "result": [
150 |         {
151 |             "score": 17.94549,
152 |             "answer": {
153 |                 "answer_start": 317,
154 |                 "context": "Canton of Geneva. As is the case in several other Swiss cantons (e.g. Ticino, Neuchâtel, and Jura), this canton is referred to as a republic within the Swiss Confederation. The canton of Geneva is located in the southwestern corner of Switzerland; and is considered one of the most cosmopolitan areas of the country. As a center of the Calvinist Reformation, the city of Geneva has had a great influence on the canton, which essentially consists of the city and its hinterlands. Geneva was a Prince-Bishopric of the Holy Roman Empire from 1154, but from 1290, secular authority over the citizens was divided from the bishop's authority, at first only lower jurisdiction, the office of vidame given to François de Candie in 1314, but from 1387 the bishops granted the citizens of Geneva full communal self-government.",
155 |                 "displayable": false,
156 |                 "value": "As a center of the Calvinist Reformation, the city of Geneva has had a great influence on the canton, which essentially consists of the city and its hinterlands.",
157 |                 "add_to_qa_index": true
158 |             },
159 |             "meta": {
160 |                 "2016": true,
161 |                 "doc_id": "Canton of Geneva",
162 |                 "chunk_id": "91a26e05-0c9c-470b-a9e1-9e8635af265c",
163 |                 "chunk_type": "content"
164 |             }
165 |         },
166 | ```
167 | 
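For a quick smoke test, the request above can be sent from Python. The following is a minimal sketch that assumes only the endpoint, header, and body fields documented above; the `requests` package is its only dependency:

```python
import requests

# SF-QA query endpoint and research authorization header, as documented above
URL = "https://api.soco.ai/v1/sfqa/query"
HEADERS = {"Authorization": "soco_research", "Content-Type": "application/json"}

body = {
    "lang": "en",
    "index": "wiki-frame-2016",
    "model_id": "spanbert-large-squad2",
    "query": "What Swiss city was the center of the Calvinist movement?",
    "params": {"top_k": 10, "n_best": 2, "ranker_only": False},
}

response = requests.post(URL, json=body, headers=HEADERS)
response.raise_for_status()

# With ranker_only set to false, each hit carries the extracted answer
# span in "value" plus its score and source passage.
for hit in response.json()["result"]:
    print(hit["value"], hit["score"])
```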
168 | # Download Data
169 | 
170 | ## Preprocessed Wikipedia dump
171 | The processed Wikipedia dumps can be downloaded from the following links:
172 | - Paragraph Level: [wiki_2016](https://sfqa.s3.us-east-2.amazonaws.com/wikidump/2016/wiki_2016_paragraphs.jsonl.bz2)
173 | - Sentence Level: [wiki_2016](https://sfqa.s3.us-east-2.amazonaws.com/wikidump/2016/wiki_2016_frames.jsonl.bz2)
174 | - Phrase Level: use either the paragraph-level or the sentence-level wiki to retrieve the n-best passages, then extract the 1-best phrase with a machine reader.
175 | 
176 | ## Cached Retrieval Results
177 | - [Wiki_2016_BM25](https://sfqa.s3.us-east-2.amazonaws.com/data/wiki-frame-2016_sent_bm25_context.json)
178 | - [Wiki_2016_SPARTA](https://sfqa.s3.us-east-2.amazonaws.com/data/wiki-frame-2016_sparta_context.json)
179 | 
180 | Please contact contact@soco.ai to contribute your cached retrieval results.
181 | 
182 | 
183 | # How to Contribute
184 | Contributions to this project are welcome!
185 | To contribute via pull request, follow these steps: 186 | 187 | - Create an issue describing the feature you want to work on 188 | - Write your code, tests and documentation, and format them with black 189 | - Create a pull request describing your changes 190 | 191 | 192 | # License 193 | This project is licensed under the Apache License, Version 2.0. 194 | -------------------------------------------------------------------------------- /config/bertserini-ds/bertserini-ds-new-ranker.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: en 3 | name: squad 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: squad_dev_bm25_wiki_2016_en_ctx_bm25_serini.json 8 | reader: 9 | model_id: squad-ds-para-serini-global-norm-2016bm25-bert-base-cased-from-pt 10 | param: 11 | n_gpu: 1 12 | score_weight: 0.9 13 | top_k: 10 14 | -------------------------------------------------------------------------------- /config/bertserini/bertserini.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: en 3 | name: squad 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: squad_dev_2016_wiki_chunk_ctx_bm25.json 8 | reader: 9 | model_id: squad-chunk-global-norm-2016bm25-bert-base-uncased 10 | param: 11 | score_weight: 0.8 12 | top_k: 50 13 | -------------------------------------------------------------------------------- /config/multipassage-bert/multipassage-bert.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: en 3 | name: squad 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: squad_dev_bm25_wiki_2016_en_ctx_bm25_serini.json 8 | reader: 9 | model_id: squad-chunk-global-norm-2016bm25-bert-large-reranker 10 | use_reranker: True 11 | rerank_size: 30 12 | param: 13 | n_gpu: 1 14 | score_weight: 0.9 15 | top_k: 100 16 | -------------------------------------------------------------------------------- /config/sparta/cmrc2018-sparta-rbt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: zh 3 | name: cmrc2018 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: cmrc2018_dev_zh-wiki-frame-2018_sparta.json 8 | reader: 9 | model_id: cmrc2018-ds-context-global-norm-2018sparta-from-pt-v1 10 | param: 11 | score_weight: 0.8 12 | top_k: 50 -------------------------------------------------------------------------------- /config/sparta/drcd-dev-sparta-rbt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: zh 3 | name: drcd 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: drcd_dev_zh-wiki-frame-2017_sparta.json 8 | reader: 9 | model_id: drcd-ds-context-global-norm-2017sparta-from-pt-v1 10 | param: 11 | score_weight: 0.9 12 | top_k: 50 -------------------------------------------------------------------------------- /config/sparta/nq-dev-sparta-spanbert.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: en 3 | name: nq-open 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: nq-open_dev_wiki-frame-2018_sparta.json 8 | reader: 9 | model_id: nq-ds-context-global-norm-2018sparta-from-pt-v1 10 | param: 11 | n_gpu: 2 12 | score_weight: 0.8 13 | top_k: 100 14 | 
-------------------------------------------------------------------------------- /config/sparta/nq-test-sparta-spanbert.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: en 3 | name: nq-open 4 | split: test-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: nq-open_test_wiki-frame-2018_sparta.json 8 | reader: 9 | model_id: nq-ds-context-global-norm-2018sparta-from-pt-v1 10 | param: 11 | n_gpu: 1 12 | score_weight: 0.8 13 | top_k: 100 14 | -------------------------------------------------------------------------------- /config/sparta/online-eval-zh-cmrc2018-sparta-rbt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: zh 3 | name: cmrc2018 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: False 7 | model: 8 | name: sparta 9 | es_index_name: zh-wiki-frame-2020 10 | reader: 11 | model_id: cmrc2018-ds-context-global-norm-2018sparta-from-pt-v1 12 | param: 13 | score_weight: 0.8 14 | top_k: 10 15 | -------------------------------------------------------------------------------- /config/sparta/squad-sparta-spanbert.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | lang: en 3 | name: squad 4 | split: dev-v1.1.json 5 | ranker: 6 | use_cached: True 7 | cached_ranker_file: squad_dev_wiki_sparta_2016sparta.json 8 | reader: 9 | model_id: squad-ds-context-global-norm-2016sparta-from-pt 10 | param: 11 | n_gpu: 2 12 | score_weight: 0.8 13 | top_k: 10 14 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from soco_openqa.demo import Reader, Ranker, QA, display 2 | 3 | reader = Reader(model='squad-ds-context-global-norm-2016sparta-from-pt') 4 | 5 | ranker = Ranker(index='sparta-en-wiki-2016') 6 | 7 | qa = QA(reader, ranker) 8 | 9 | while True: 10 | q = input('Enter a query: ') 11 | results = qa.query(q, num_results=3) 12 | display(results) 13 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | 5 | import soco_openqa.helper as helper 6 | from soco_openqa.pipeline import OpenQA 7 | from soco_openqa.config import get_config 8 | from soco_openqa.evaluation import evaluate 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--config", required=True, type=str, help='yaml file name') 17 | args = parser.parse_args() 18 | 19 | config = get_config(args.config) 20 | data = helper.load_json(file_dir=config.data.name, file_name=config.data.split) 21 | 22 | # Initiate evaluation pipeline 23 | qa = OpenQA(config) 24 | predictions = qa.predict(data) 25 | 26 | # Evaluate predictions 27 | results = evaluate(config.data.lang, data, predictions, config) 28 | logger.info(results) 29 | 30 | # Save locally 31 | helper.save_logs(config.dump(), results, save_name=config.config_name) 32 | -------------------------------------------------------------------------------- /imgs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soco-ai/SF-QA/7996a6f6fbf4370eb7913658d6f8d61466a5ccac/imgs/demo.gif 
-------------------------------------------------------------------------------- /imgs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soco-ai/SF-QA/7996a6f6fbf4370eb7913658d6f8d61466a5ccac/imgs/logo.png -------------------------------------------------------------------------------- /reports/bertserini-ds-new-ranker.txt: -------------------------------------------------------------------------------- 1 | ********************{ 2 | "data": { 3 | "lang": "en", 4 | "name": "squad/dev-v1.1" 5 | }, 6 | "ranker": { 7 | "use_cached": true, 8 | "cached_ranker_id": "squad_dev_bm25_wiki_2016_en_ctx_bm25" 9 | }, 10 | "reader": { 11 | "model_id": "squad-ds-para-v3-global-norm-2016bm25-bert-base-uncased" 12 | }, 13 | "param": { 14 | "score_weight": 1, 15 | "top_k": 100 16 | }, 17 | "config_name": "bertserini-ds-new-ranker" 18 | } 19 | exact_match=39.480 f1=46.051 R@1=29.678 R@5=49.716 R@10=57.275 R@50=72.914 R@100=78.013 20 | 21 | ********************{ 22 | "data": { 23 | "lang": "en", 24 | "name": "squad/dev-v1.1" 25 | }, 26 | "ranker": { 27 | "use_cached": true, 28 | "cached_ranker_id": "squad_dev_bm25_wiki_2016_en_ctx_bm25_v3" 29 | }, 30 | "reader": { 31 | "model_id": "squad-ds-para-global-norm-2016bm25-bert-base-cased" 32 | }, 33 | "param": { 34 | "score_weight": 1, 35 | "top_k": 100 36 | }, 37 | "config_name": "bertserini-ds-new-ranker" 38 | } 39 | exact_match=39.470 f1=46.092 R@1=32.942 R@5=53.179 R@10=60.653 R@50=75.459 R@100=80.331 40 | 41 | ********************{ 42 | "data": { 43 | "lang": "en", 44 | "name": "squad/dev-v1.1" 45 | }, 46 | "ranker": { 47 | "use_cached": true, 48 | "cached_ranker_id": "squad_dev_bm25_wiki_2016_en_ctx_bm25_v3" 49 | }, 50 | "reader": { 51 | "model_id": "squad-ds-para-v3-global-norm-2016bm25-bert-base-cased-from-pt" 52 | }, 53 | "param": { 54 | "score_weight": 0.8, 55 | "top_k": 100 56 | }, 57 | "config_name": "bertserini-ds-new-ranker" 58 | } 59 | exact_match=40.937 f1=48.452 R@1=32.942 R@5=53.179 R@10=60.653 R@50=75.459 R@100=80.331 60 | 61 | ********************{ 62 | "data": { 63 | "lang": "en", 64 | "name": "squad/dev-v1.1" 65 | }, 66 | "ranker": { 67 | "use_cached": true, 68 | "cached_ranker_id": "squad_dev_bm25_wiki_2016_en_ctx_bm25_serini" 69 | }, 70 | "reader": { 71 | "model_id": "squad-ds-para-serini-global-norm-2016bm25-bert-base-cased-from-pt" 72 | }, 73 | "param": { 74 | "score_weight": 0.9, 75 | "top_k": 100 76 | }, 77 | "config_name": "bertserini-ds-new-ranker" 78 | } 79 | exact_match=51.599 f1=59.154 R@1=43.198 R@5=64.598 R@10=71.523 R@50=82.904 R@100=86.604 80 | 81 | ******************** 82 | 2021-01-04T18:12:11 83 | config_name: bertserini-ds-new-ranker 84 | data: 85 | lang: en 86 | name: squad 87 | split: dev-v1.1.json 88 | param: 89 | n_gpu: 1 90 | score_weight: 0.9 91 | top_k: 10 92 | ranker: 93 | cached_ranker_file: squad_dev_bm25_wiki_2016_en_ctx_bm25_serini.json 94 | model: 95 | es_index_name: null 96 | name: null 97 | use_cached: true 98 | reader: 99 | model_id: squad-ds-para-serini-global-norm-2016bm25-bert-base-cased-from-pt 100 | rerank_size: null 101 | use_reranker: false 102 | 103 | exact_match=45.629 f1=52.962 R@1=43.198 R@5=64.598 R@10=71.523 104 | 105 | -------------------------------------------------------------------------------- /reports/bertserini.txt: -------------------------------------------------------------------------------- 1 | ********************{ 2 | "data": { 3 | "lang": "en", 4 | "name": "squad/dev-v1.1" 5 | }, 6 | "ranker": { 7 | 
"use_cached": true, 8 | "cached_ranker_id": "squad_dev_2016_wiki_chunk_ctx_bm25" 9 | }, 10 | "reader": { 11 | "model_id": "squad-chunk-global-norm-2016bm25-bert-base-uncased" 12 | }, 13 | "param": { 14 | "score_weight": 0.8, 15 | "top_k": 100 16 | }, 17 | "config_name": "bertserini" 18 | } 19 | exact_match=41.173 f1=48.598 R@1=41.854 R@5=62.923 R@10=70.350 R@50=83.103 R@100=83.103 20 | 21 | -------------------------------------------------------------------------------- /reports/cmrc2018-sparta-rbt.txt: -------------------------------------------------------------------------------- 1 | ********************{ 2 | "data": { 3 | "lang": "zh", 4 | "name": "cmrc2018/dev-v1.1" 5 | }, 6 | "ranker": { 7 | "use_cached": true, 8 | "cached_ranker_id": "cmrc2018_dev_zh-wiki-frame-2018_sparta" 9 | }, 10 | "reader": { 11 | "model_id": "cmrc2018-ds-context-global-norm-2018sparta-from-pt-v1" 12 | }, 13 | "param": { 14 | "score_weight": 0.8, 15 | "top_k": 50 16 | }, 17 | "config_name": "cmrc2018-sparta-rbt" 18 | } 19 | exact_match=63.063 f1=80.202 R@1=70.053 R@5=83.318 R@10=87.232 R@50=93.476 R@100=93.476 20 | 21 | -------------------------------------------------------------------------------- /reports/drcd-dev-sparta-rbt.txt: -------------------------------------------------------------------------------- 1 | ********************{ 2 | "data": { 3 | "lang": "zh", 4 | "name": "drcd/dev-v1.1" 5 | }, 6 | "ranker": { 7 | "use_cached": true, 8 | "cached_ranker_id": "drcd_dev_zh-wiki-frame-2017_sparta" 9 | }, 10 | "reader": { 11 | "model_id": "drcd-ds-context-global-norm-2017sparta-from-pt-v1" 12 | }, 13 | "param": { 14 | "score_weight": 0.8, 15 | "top_k": 50 16 | }, 17 | "config_name": "drcd-dev-sparta-rbt" 18 | } 19 | exact_match=61.975 f1=73.576 R@1=60.471 R@5=74.461 R@10=78.547 R@50=86.237 R@100=86.237 20 | 21 | ********************{ 22 | "data": { 23 | "lang": "zh", 24 | "name": "drcd/dev-v1.1" 25 | }, 26 | "ranker": { 27 | "use_cached": true, 28 | "cached_ranker_id": "drcd_dev_zh-wiki-frame-2017_sparta" 29 | }, 30 | "reader": { 31 | "model_id": "drcd-ds-context-global-norm-2017sparta-from-pt-v1" 32 | }, 33 | "param": { 34 | "score_weight": 0.9, 35 | "top_k": 50 36 | }, 37 | "config_name": "drcd-dev-sparta-rbt" 38 | } 39 | exact_match=63.025 f1=74.535 R@1=60.471 R@5=74.461 R@10=78.547 R@50=86.237 R@100=86.237 40 | 41 | ******************** 42 | config_name: drcd-dev-sparta-rbt 43 | data: 44 | lang: zh 45 | name: drcd/dev-v1.1 46 | param: 47 | n_gpu: 2 48 | score_weight: 0.9 49 | top_k: 1 50 | ranker: 51 | cached_ranker_id: drcd_dev_zh-wiki-frame-2017_sparta 52 | use_cached: true 53 | reader: 54 | model_id: drcd-ds-context-global-norm-2017sparta-from-pt-v1 55 | rerank_size: null 56 | use_reranker: false 57 | 58 | exact_match=50.341 f1=61.127 R@1=60.471 R@5=60.471 R@10=60.471 R@50=60.471 R@100=60.471 59 | 60 | ******************** 61 | config_name: drcd-dev-sparta-rbt 62 | data: 63 | lang: zh 64 | name: drcd/dev-v1.1 65 | param: 66 | n_gpu: 2 67 | score_weight: 0.9 68 | top_k: 1 69 | ranker: 70 | cached_ranker_id: drcd_dev_zh-wiki-frame-2017_sparta 71 | model: 72 | es_index_name: null 73 | name: null 74 | use_cached: true 75 | reader: 76 | model_id: drcd-ds-context-global-norm-2017sparta-from-pt-v1 77 | rerank_size: null 78 | use_reranker: false 79 | 80 | exact_match=50.341 f1=61.127 R@1=60.471 R@5=60.471 R@10=60.471 R@50=60.471 R@100=60.471 81 | 82 | ******************** 83 | config_name: drcd-dev-sparta-rbt 84 | data: 85 | lang: zh 86 | name: drcd/dev-v1.1 87 | param: 88 | n_gpu: 2 89 | score_weight: 0.9 
90 | top_k: 1 91 | ranker: 92 | cached_ranker_id: drcd_dev_zh-wiki-frame-2017_sparta 93 | model: 94 | es_index_name: null 95 | name: null 96 | use_cached: true 97 | reader: 98 | model_id: drcd-ds-context-global-norm-2017sparta-from-pt-v1 99 | rerank_size: null 100 | use_reranker: false 101 | 102 | exact_match=50.341 f1=61.127 R@1=60.471 103 | 104 | ******************** 105 | 2021-01-03T02:18:17 106 | config_name: drcd-dev-sparta-rbt 107 | data: 108 | lang: zh 109 | name: drcd/dev-v1.1 110 | param: 111 | n_gpu: 2 112 | score_weight: 0.9 113 | top_k: 1 114 | ranker: 115 | cached_ranker_id: drcd_dev_zh-wiki-frame-2017_sparta 116 | model: 117 | es_index_name: null 118 | name: null 119 | use_cached: true 120 | reader: 121 | model_id: drcd-ds-context-global-norm-2017sparta-from-pt-v1 122 | rerank_size: null 123 | use_reranker: false 124 | 125 | exact_match=50.341 f1=61.127 R@1=60.471 126 | 127 | ******************** 128 | 2021-01-03T23:23:15 129 | config_name: drcd-dev-sparta-rbt 130 | data: 131 | lang: zh 132 | name: drcd 133 | split: dev-v1.1.json 134 | param: 135 | n_gpu: 2 136 | score_weight: 0.9 137 | top_k: 1 138 | ranker: 139 | cached_ranker_file: drcd_dev_zh-wiki-frame-2017_sparta.json 140 | model: 141 | es_index_name: null 142 | name: null 143 | use_cached: true 144 | reader: 145 | model_id: drcd-ds-context-global-norm-2017sparta-from-pt-v1 146 | rerank_size: null 147 | use_reranker: false 148 | 149 | exact_match=50.341 f1=61.127 R@1=60.471 150 | 151 | -------------------------------------------------------------------------------- /reports/multipassage-bert.txt: -------------------------------------------------------------------------------- 1 | 2 | ********************{ 3 | "data": { 4 | "lang": "en", 5 | "name": "squad/dev-v1.1" 6 | }, 7 | "ranker": { 8 | "use_cached": true, 9 | "cached_ranker_id": "squad_dev_bm25_wiki_2016_en_ctx_bm25_serini" 10 | }, 11 | "reader": { 12 | "model_id": "squad-chunk-global-norm-2016bm25-bert-large-reranker", 13 | "use_reranker": true, 14 | "rerank_size": 30 15 | }, 16 | "param": { 17 | "score_weight": 0.9, 18 | "top_k": 100 19 | }, 20 | "config_name": "multipassage-bert" 21 | } 22 | exact_match=53.179 f1=60.712 R@1=43.198 R@5=64.598 R@10=71.523 R@50=82.904 R@100=86.604 23 | 24 | ******************** 25 | 2021-01-04T23:03:33 26 | config_name: multipassage-bert 27 | data: 28 | lang: en 29 | name: squad 30 | split: dev-v1.1.json 31 | param: 32 | n_gpu: 1 33 | score_weight: 0.9 34 | top_k: 100 35 | ranker: 36 | cached_ranker_file: squad_dev_bm25_wiki_2016_en_ctx_bm25_serini.json 37 | model: 38 | es_index_name: null 39 | name: null 40 | use_cached: true 41 | reader: 42 | model_id: squad-chunk-global-norm-2016bm25-bert-large-reranker 43 | rerank_size: 30 44 | use_reranker: true 45 | 46 | exact_match=53.377 f1=60.907 R@1=43.198 R@5=64.598 R@10=71.523 R@50=82.904 R@100=86.604 47 | 48 | -------------------------------------------------------------------------------- /reports/nq-dev-sparta-spanbert.txt: -------------------------------------------------------------------------------- 1 | ********************{ 2 | "data": { 3 | "lang": "en", 4 | "name": "nq-open/dev-v1.1" 5 | }, 6 | "ranker": { 7 | "use_cached": true, 8 | "cached_ranker_id": "nq-open_dev_wiki-frame-2018_sparta" 9 | }, 10 | "reader": { 11 | "model_id": "nq-ds-context-global-norm-2018sparta-from-pt-v1" 12 | }, 13 | "param": { 14 | "n_gpu": 2, 15 | "score_weight": 0.8, 16 | "top_k": 50 17 | }, 18 | "config_name": "nq-dev-sparta-spanbert" 19 | } 20 | exact_match=36.782 f1=47.441 R@1=27.761 R@5=49.572 
R@10=57.086 R@50=69.190 R@100=69.190 21 | 22 | -------------------------------------------------------------------------------- /reports/nq-test-sparta-spanbert.txt: -------------------------------------------------------------------------------- 1 | ********************{ 2 | "data": { 3 | "lang": "en", 4 | "name": "nq-open/test-v1.1" 5 | }, 6 | "ranker": { 7 | "use_cached": true, 8 | "cached_ranker_id": "nq-open_test_wiki-frame-2018_sparta" 9 | }, 10 | "reader": { 11 | "model_id": "nq-ds-context-global-norm-2018sparta-from-pt-v1" 12 | }, 13 | "param": { 14 | "n_gpu": 2, 15 | "score_weight": 0.8, 16 | "top_k": 50 17 | }, 18 | "config_name": "nq-test-sparta-spanbert" 19 | } 20 | exact_match=37.452 f1=46.237 R@1=28.643 R@5=51.828 R@10=59.778 R@50=73.296 R@100=73.296 21 | 22 | ******************** 23 | 2021-01-03T01:31:40 24 | config_name: nq-test-sparta-spanbert 25 | data: 26 | lang: en 27 | name: nq-open/test-v1.1 28 | param: 29 | n_gpu: 1 30 | score_weight: 0.8 31 | top_k: 10 32 | ranker: 33 | cached_ranker_id: nq-open_test_wiki-frame-2018_sparta 34 | model: 35 | es_index_name: null 36 | name: null 37 | use_cached: true 38 | reader: 39 | model_id: nq-ds-context-global-norm-2018sparta-from-pt-v1 40 | rerank_size: null 41 | use_reranker: false 42 | 43 | exact_match=35.263 f1=44.140 R@1=28.643 R@5=51.828 R@10=59.778 44 | 45 | -------------------------------------------------------------------------------- /reports/online-eval-zh-cmrc2018-sparta-rbt.txt: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "lang": "zh", 4 | "name": "cmrc2018/dev-v1.1" 5 | }, 6 | "ranker": { 7 | "use_cached": false, 8 | "model": { 9 | "name": "sparta", 10 | "es_index_name": "zh-wiki-frame-2020" 11 | } 12 | }, 13 | "reader": { 14 | "model_id": "cmrc2018-ds-context-global-norm-2018sparta-from-pt-v1" 15 | }, 16 | "param": { 17 | "score_weight": 0.8, 18 | "top_k": 10 19 | }, 20 | "config_name": "online-eval-zh-cmrc2018-sparta-rbt" 21 | } 22 | exact_match=55.483 f1=72.940 R@1=64.368 R@5=78.565 R@10=82.355 23 | -------------------------------------------------------------------------------- /reports/squad-sparta-spanbert.txt: -------------------------------------------------------------------------------- 1 | ******************** 2 | 2021-01-04T00:57:17 3 | config_name: squad-sparta-spanbert 4 | data: 5 | lang: en 6 | name: squad 7 | split: dev-v1.1.json 8 | param: 9 | n_gpu: 1 10 | score_weight: 0.8 11 | top_k: 1 12 | ranker: 13 | cached_ranker_file: squad_dev_wiki_sparta_2016sparta.json 14 | model: 15 | es_index_name: null 16 | name: null 17 | use_cached: true 18 | reader: 19 | model_id: squad-ds-context-global-norm-2016sparta-from-pt 20 | rerank_size: null 21 | use_reranker: false 22 | 23 | exact_match=41.845 f1=47.967 R@1=50.757 24 | 25 | ******************** 26 | 2021-01-04T14:58:04 27 | config_name: squad-sparta-spanbert 28 | data: 29 | lang: en 30 | name: squad 31 | split: dev-v1.1.json 32 | param: 33 | n_gpu: 2 34 | score_weight: 0.8 35 | top_k: 1 36 | ranker: 37 | cached_ranker_file: squad_dev_wiki_sparta_2016sparta.json 38 | model: 39 | es_index_name: null 40 | name: null 41 | use_cached: true 42 | reader: 43 | model_id: squad-ds-context-global-norm-2016sparta-from-pt 44 | rerank_size: null 45 | use_reranker: false 46 | 47 | exact_match=41.845 f1=47.967 R@1=50.757 48 | 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
torch>=1.6.0 2 | numpy>=1.16.3 3 | elasticsearch >= 7.0.3 4 | soco-encoders>=0.2.1 5 | soco-mrc==0.0.9 6 | soco-device==0.0.5.2 7 | pyyaml==5.3.1 8 | yacs==0.1.8 9 | transformers==2.11.0 10 | rich==3.7.4.3 11 | Cython==0.29.21 12 | -------------------------------------------------------------------------------- /soco_openqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soco-ai/SF-QA/7996a6f6fbf4370eb7913658d6f8d61466a5ccac/soco_openqa/__init__.py -------------------------------------------------------------------------------- /soco_openqa/cloud_bucket.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import oss2 3 | import os 4 | import time 5 | from tqdm import tqdm 6 | 7 | 8 | class CloudBucket(object): 9 | roles = {'reader-s3': ('AKIAIXQ4BV6ORFZ3JKPA', 'D+9W2Zii2Hvpo6G0IJ77tofKUM59ZasC/HE3Kf/w'), 10 | 'reader-oss': ('LTAINYtHd5knNlAg', 'HIg89gNwNYpbEKoC38SdnYCn3Db8cI')} 11 | 12 | def __init__(self, region='us', permission='reader'): 13 | self.region = region.lower() 14 | self.permission = permission.lower() 15 | if permission == 'reader': 16 | s3_key = self.roles['reader-s3'] 17 | oss_key = self.roles['reader-oss'] 18 | else: 19 | raise Exception("Unknown permission {}".format(permission)) 20 | 21 | self._s3 = boto3.resource('s3', aws_access_key_id=s3_key[0], aws_secret_access_key=s3_key[1]) 22 | self._s3_bucket = self._s3.Bucket('sfqa') 23 | self._oss_bucket = oss2.Bucket(oss2.Auth(oss_key[0], oss_key[1]), 24 | str('http://oss-cn-hongkong.aliyuncs.com'), 25 | str('convmind-models')) 26 | 27 | def safe_mkdir(self, path): 28 | if not os.path.exists(path): 29 | os.makedirs(path, exist_ok=True) 30 | 31 | def download(self, folder_dir, files, local_dir): 32 | print('start downloading {} to {} ...'.format(files, local_dir)) 33 | start = time.time() 34 | for f in tqdm(files, desc='downloading files'): 35 | try: 36 | if self.region.lower() == 'cn': 37 | self._oss_bucket.get_object_to_file('{}/{}'.format(folder_dir, f), os.path.join(local_dir, f)) 38 | else: 39 | self._s3_bucket.download_file('{}/{}'.format(folder_dir, f), os.path.join(local_dir, f)) 40 | except: 41 | print("Failed to download {}".format(f)) 42 | 43 | end = time.time() 44 | print(("download models using {:.4f} sec".format(end - start))) 45 | 46 | def download_dir(self, folder_dir, dirs, local_dir): 47 | start = time.time() 48 | for d in dirs: 49 | cur_remote_dir = os.path.join(folder_dir, d) 50 | cur_local_dir = os.path.join(local_dir, d) 51 | 52 | self.safe_mkdir(cur_local_dir) 53 | 54 | files = [] 55 | if self.region.lower() == 'cn': 56 | for o in self._oss_bucket.list_objects(prefix=cur_remote_dir).object_list: 57 | f_name = o.key.split('/')[-1] 58 | if f_name != '': 59 | files.append(o.key.split('/')[-1]) 60 | else: 61 | for o in self._s3_bucket.objects.filter(Prefix=cur_remote_dir): 62 | f_name = o.key.split('/')[-1] 63 | if f_name != '': 64 | files.append(o.key.split('/')[-1]) 65 | 66 | self.download(cur_remote_dir, files, cur_local_dir) 67 | 68 | end = time.time() 69 | print(("download models using {:.4f} sec".format(end - start))) 70 | 71 | 72 | def download_model(self, asset_dir, asset_id, local_dir='resources'): 73 | self.safe_mkdir(local_dir) 74 | 75 | if not os.path.exists(os.path.join(local_dir, asset_id, 'config.json')): 76 | self.download_dir(asset_dir, [asset_id], local_dir) 77 | else: 78 | pass 79 | 80 | def download_file(self, file_dir, file_name, cloud_base_dir='data', 
local_base_dir='data'): 81 | """ 82 | download a single file given file dir and file name 83 | 84 | :param file_dir: [description] 85 | :type file_dir: [type] 86 | :param file_name: [description] 87 | :type file_name: [type] 88 | :param cloud_dir: [description], defaults to 'data' 89 | :type cloud_dir: str, optional 90 | :param local_dir: [description], defaults to 'data' 91 | :type local_dir: str, optional 92 | """ 93 | local_dir = os.path.join(local_base_dir, file_dir) if local_base_dir else file_dir 94 | cloud_dir = os.path.join(cloud_base_dir, file_dir) if cloud_base_dir else file_dir 95 | 96 | self.safe_mkdir(local_dir) 97 | if not os.path.exists(os.path.join(local_dir, file_name)): 98 | self.download(cloud_dir, [file_name], local_dir) 99 | else: 100 | pass 101 | 102 | 103 | if __name__ == '__main__': 104 | b = CloudBucket('cn') 105 | for d in b._oss_bucket.list_objects(prefix='mrc-models').object_list: 106 | print(d.key) 107 | -------------------------------------------------------------------------------- /soco_openqa/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from yacs.config import CfgNode as CN 3 | import soco_openqa.helper as helper 4 | 5 | _C = CN() 6 | 7 | _C.config_name = None 8 | 9 | # dataset 10 | _C.data = CN() 11 | _C.data.lang = 'en' 12 | _C.data.name = None 13 | _C.data.split = None 14 | 15 | # ranker 16 | _C.ranker = CN() 17 | _C.ranker.use_cached = True 18 | _C.ranker.cached_ranker_file = None 19 | 20 | 21 | ## online ranker if use_cached = False 22 | _C.ranker.model = CN() 23 | _C.ranker.model.name = None 24 | _C.ranker.model.es_index_name = None 25 | 26 | 27 | # reader 28 | _C.reader = CN() 29 | _C.reader.model_id = None 30 | _C.reader.use_reranker = False 31 | _C.reader.rerank_size = None 32 | 33 | # experiment setting 34 | _C.param = CN() 35 | _C.param.score_weight = 0.8 36 | _C.param.top_k = 50 37 | _C.param.n_gpu = 2 38 | 39 | 40 | def get_config(config_file): 41 | 42 | config = _C.clone() 43 | config.merge_from_file(config_file) 44 | 45 | # add config file name for saving log 46 | config_name = helper.get_name_from_path(config_file) 47 | config.merge_from_list(['config_name', config_name]) 48 | config.freeze() 49 | 50 | return config 51 | -------------------------------------------------------------------------------- /soco_openqa/demo/__init__.py: -------------------------------------------------------------------------------- 1 | from soco_openqa.demo.ranker import Ranker 2 | from soco_openqa.demo.reader import Reader 3 | from soco_openqa.demo.qa import QA 4 | from soco_openqa.demo.helper import display 5 | -------------------------------------------------------------------------------- /soco_openqa/demo/helper.py: -------------------------------------------------------------------------------- 1 | import re 2 | from rich.console import Console 3 | from rich.table import Table 4 | 5 | 6 | def _normalize_text(text): 7 | return re.sub('\s+', ' ', text) 8 | 9 | def display(results): 10 | 11 | console = Console() 12 | table = Table(show_header=True, header_style="bold magenta", show_lines=True) 13 | table.add_column("doc_id",style="dim",width=12) 14 | table.add_column("passage") 15 | table.add_column("answer",width=30) 16 | table.add_column("score",style="dim",width=15) 17 | 18 | for r in results: 19 | doc_id = str(r['source']['doc_id']) 20 | passage = str(r['source']['context']) 21 | passage = _normalize_text(passage) 22 | answer_span = r['answer_span'] 23 | answer = str(r['value']) 24 | score = 
'{:.4f}'.format(r['score']) 25 | passage_with_ans = '{}[red]{}[/red]{}'.format(passage[:answer_span[0]], passage[answer_span[0]:answer_span[1]], passage[answer_span[1]:]) 26 | 27 | table.add_row( 28 | doc_id, passage_with_ans, answer, score, 29 | ) 30 | 31 | console.print(table) 32 | -------------------------------------------------------------------------------- /soco_openqa/demo/qa.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(level=logging.INFO) 3 | logger = logging.getLogger(__name__) 4 | 5 | class QA: 6 | def __init__(self, Reader, Ranker): 7 | self.Reader = Reader 8 | self.Ranker = Ranker 9 | 10 | def query(self, query, num_results=10): 11 | logger.info('Start ranking...') 12 | top_passages = self.Ranker.query(query) 13 | logger.info('Start reading...') 14 | results = self.Reader.predict(query, top_passages) 15 | 16 | return results[:num_results] 17 | -------------------------------------------------------------------------------- /soco_openqa/demo/ranker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from requests.auth import HTTPBasicAuth 4 | 5 | host_url = "https://api.soco.ai/v1/sfqa/query" 6 | 7 | INDEX_MAP = { 8 | 'sparta-en-wiki-2016': {'lang': 'en', 'index': 'wiki-frame-2016'}, 9 | 'sparta-zh-wiki-2020': {'lang': 'zh', 'index': 'zh-wiki-frame-2020'}, 10 | 'bm25-en-wiki-2016': {'lang': 'en', 'index': 'bm25_wiki_2016_en'}, 11 | } 12 | 13 | class Ranker: 14 | def __init__(self, index): 15 | if index not in INDEX_MAP: 16 | raise ValueError('{} not existed, try one from {}'.format(index, INDEX_MAP.keys())) 17 | self.index = INDEX_MAP[index] 18 | 19 | 20 | def query(self, query): 21 | """ 22 | query api and get ranker results 23 | 24 | :param query: a string of natural language question 25 | :type query: str 26 | :return: a list of dictionaries containing topn retrieved answers 27 | :rtype: list 28 | """ 29 | headers = {"Accept": "application/json", "Authorization": 'soco_research'} 30 | json_body = { 31 | "lang": self.index['lang'], 32 | "index": self.index['index'], 33 | "model_id": "", 34 | "query": query, 35 | "params": { 36 | "top_k": 50, 37 | "n_best": 50, 38 | "ranker_only":True 39 | } 40 | } 41 | 42 | r = requests.post(url=host_url, json=json_body, headers=headers) 43 | r.raise_for_status() 44 | res = r.json()['result'] 45 | # clean res 46 | for ans in res: 47 | ans['answer'] = ans['answer']['context'] 48 | 49 | return res 50 | 51 | 52 | if __name__ == '__main__': 53 | ranker = Ranker('sparta-en-wiki-2016') 54 | res = ranker.query('when was microsoft founded?') 55 | print(res) -------------------------------------------------------------------------------- /soco_openqa/demo/reader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from soco_openqa.soco_mrc.mrc_model import MrcModel 3 | from collections import defaultdict 4 | 5 | 6 | 7 | class Reader: 8 | def __init__(self, model): 9 | self.model_id = model 10 | self.reader = MrcModel('us', n_gpu=1) 11 | self.thresh = 0.8 12 | 13 | def predict(self, query, top_passages): 14 | batch = [{'q': query, 'doc': p['answer']} for p in top_passages] 15 | preds = self.reader.batch_predict( 16 | self.model_id, 17 | batch, 18 | merge_pred=True, 19 | stride=128, 20 | batch_size=50 21 | ) 22 | 23 | candidates = defaultdict(list) 24 | for a_id, a in enumerate(preds): 25 | if a.get('missing_warning'): 26 | continue 27 | score 
= self.thresh * (a['score']) + (1 - self.thresh) * (top_passages[a_id]['score']) 28 | candidates[a['value']].append({'combined_score': score, 29 | 'reader_score':a['score'], 30 | 'ranker_score':top_passages[a_id]['score'], 31 | 'idx': a_id, 32 | 'prob': a['prob'], 33 | 'answer_span': a['answer_span']}) 34 | 35 | # get best passages with best answer 36 | answers = [] 37 | for k, v in candidates.items(): 38 | combined_scores = [x['combined_score'] for x in v] 39 | reader_scores = [x['reader_score'] for x in v] 40 | ranker_scores = [x['ranker_score'] for x in v] 41 | idxes = [x['idx'] for x in v] 42 | best_idx = int(np.argmax(combined_scores)) 43 | best_a_id = idxes[best_idx] 44 | answers.append({'value': k, 45 | 'score': combined_scores[best_idx], 46 | 'reader_score': reader_scores[best_idx], 47 | 'ranker_score': ranker_scores[best_idx], 48 | 'prob': v[best_idx]['prob'], 49 | 'answer_span': v[best_idx]['answer_span'], 50 | "source": { 51 | 'context': top_passages[best_a_id]['answer'], 52 | 'url': top_passages[best_a_id].get('meta', {}).get('url'), 53 | 'doc_id': top_passages[best_a_id].get('meta', {}).get('doc_id') 54 | } 55 | }) 56 | 57 | answers = sorted(answers, key=lambda x: x['score'], reverse=True) 58 | return answers 59 | 60 | -------------------------------------------------------------------------------- /soco_openqa/evaluation.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, OrderedDict, defaultdict 2 | import string 3 | import re 4 | import argparse 5 | import json 6 | import random 7 | import sys 8 | import nltk 9 | import logging 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class EvalMrcEn: 16 | def evaluate(self, dataset, predictions, config): 17 | f1 = exact_match = total = 0 18 | recall_grid = defaultdict(int) 19 | 20 | for article in dataset: 21 | for paragraph in article['paragraphs']: 22 | for qa in paragraph['qas']: 23 | f1, exact_match, total, recall_grid = self._single_evaluate(qa, predictions, config, f1, exact_match, total, recall_grid) 24 | 25 | exact_match = 100.0 * exact_match / total 26 | f1 = 100.0 * f1 / total 27 | for r, v in recall_grid.items(): 28 | recall_grid[r] = v * 100.0 / total 29 | 30 | res = {'exact_match': exact_match, 'f1': f1} 31 | res.update(recall_grid) 32 | 33 | return res 34 | 35 | 36 | def _normalize_answer(self, s): 37 | """Lower text and remove punctuation, articles and extra whitespace.""" 38 | def remove_articles(text): 39 | return re.sub(r'\b(a|an|the)\b', ' ', text) 40 | 41 | def white_space_fix(text): 42 | return ' '.join(text.split()) 43 | 44 | def remove_punc(text): 45 | exclude = set(string.punctuation) 46 | return ''.join(ch for ch in text if ch not in exclude) 47 | 48 | def lower(text): 49 | return text.lower() 50 | 51 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 52 | 53 | 54 | def _f1_score(self, prediction, ground_truth): 55 | prediction_tokens = self._normalize_answer(prediction).split() 56 | ground_truth_tokens = self._normalize_answer(ground_truth).split() 57 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 58 | num_same = sum(common.values()) 59 | if num_same == 0: 60 | return 0 61 | precision = 1.0 * num_same / len(prediction_tokens) 62 | recall = 1.0 * num_same / len(ground_truth_tokens) 63 | f1 = (2 * precision * recall) / (precision + recall) 64 | return f1 65 | 66 | 67 | def _exact_match_score(self, prediction, ground_truth): 68 | return 
(self._normalize_answer(prediction) == self._normalize_answer(ground_truth)) 69 | 70 | 71 | def _metric_max_over_ground_truths(self, metric_fn, prediction, ground_truths): 72 | scores_for_ground_truths = [] 73 | for ground_truth in ground_truths: 74 | score = metric_fn(prediction, ground_truth) 75 | scores_for_ground_truths.append(score) 76 | return max(scores_for_ground_truths) 77 | 78 | def _metric_recall(self, passages, ground_truth, at_n): 79 | context_hit = False 80 | 81 | passages = passages[0:at_n] 82 | for p in passages: 83 | for a in ground_truth: 84 | if a.lower() in p['answer'].lower(): 85 | context_hit = True 86 | 87 | if context_hit: 88 | break 89 | return context_hit 90 | 91 | 92 | def _single_evaluate(self, qa, predictions, config, f1, exact_match, total, recall_grid): 93 | total += 1 94 | qa['id'] = str(qa['id']) 95 | if qa['id'] not in predictions: 96 | message = 'Unanswered question ' + qa['id'] + \ 97 | ' will receive score 0.' 98 | print(message, file=sys.stderr) 99 | return f1, exact_match, total, recall_grid 100 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 101 | 102 | best_prediction = predictions[qa['id']]['answer'] 103 | best_exact_match = self._metric_max_over_ground_truths( 104 | self._exact_match_score, best_prediction, ground_truths) 105 | exact_match += best_exact_match 106 | f1 += self._metric_max_over_ground_truths( 107 | self._f1_score, best_prediction, ground_truths) 108 | 109 | for n in [1, 5, 10, 50, 100]: 110 | if n > config.param.top_k: 111 | break 112 | c_h = self._metric_recall(predictions[qa['id']]['passages'], ground_truths, n) 113 | recall_grid['R@{}'.format(n)] += int(c_h) 114 | 115 | return f1, exact_match, total, recall_grid 116 | 117 | 118 | class EvalMrcZh: 119 | # -*- coding: utf-8 -*- 120 | ''' 121 | Evaluation script for CMRC 2018 122 | version: v5 - special 123 | Note: 124 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 125 | v5: formatted output, add usage description 126 | v4: fixed segmentation issues 127 | ''' 128 | def evaluate(self, ground_truth_file, prediction_file, config): 129 | f1 = 0 130 | em = 0 131 | total_count = 0 132 | recall_grid = defaultdict(int) 133 | for instance in ground_truth_file: 134 | for para in instance["paragraphs"]: 135 | for qas in para['qas']: 136 | f1, em, total_count, recall_grid = self._single_evaluate(qas, prediction_file, config, f1, em, total_count, recall_grid) 137 | 138 | f1_score = 100.0 * f1 / total_count 139 | em_score = 100.0 * em / total_count 140 | 141 | for r, v in recall_grid.items(): 142 | recall_grid[r] = v * 100.0 / total_count 143 | res = {'exact_match': em_score, 'f1': f1_score} 144 | res.update(recall_grid) 145 | 146 | return res 147 | 148 | 149 | def _mixed_segmentation(self, in_str, rm_punc=False): 150 | in_str = str(in_str).lower().strip() 151 | segs_out = [] 152 | temp_str = "" 153 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 154 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 155 | '「','」','(',')','-','~','『','』'] 156 | for char in in_str: 157 | if rm_punc and char in sp_char: 158 | continue 159 | if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char: 160 | if temp_str != "": 161 | ss = nltk.word_tokenize(temp_str) 162 | segs_out.extend(ss) 163 | temp_str = "" 164 | segs_out.append(char) 165 | else: 166 | temp_str += char 167 | 168 | # handle the last segment 169 | if temp_str != "": 170 | ss = nltk.word_tokenize(temp_str) 171 | segs_out.extend(ss) 172 | 173 | return segs_out 174 | 175 | 176 | # remove punctuation 177 | def _remove_punctuation(self, in_str): 178 | in_str = str(in_str).lower().strip() 179 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 180 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 181 | '「','」','(',')','-','~','『','』'] 182 | out_segs = [] 183 | for char in in_str: 184 | if char in sp_char: 185 | continue 186 | else: 187 | out_segs.append(char) 188 | return ''.join(out_segs) 189 | 190 | 191 | # find the longest common substring 192 | def _find_lcs(self, s1, s2): 193 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 194 | mmax = 0 195 | p = 0 196 | for i in range(len(s1)): 197 | for j in range(len(s2)): 198 | if s1[i] == s2[j]: 199 | m[i+1][j+1] = m[i][j]+1 200 | if m[i+1][j+1] > mmax: 201 | mmax = m[i+1][j+1] 202 | p = i+1 203 | return s1[p-mmax:p], mmax 204 | 205 | 206 | def _metric_recall(self, passages, ground_truth, at_n): 207 | context_hit = False 208 | 209 | passages = passages[0:at_n] 210 | for p in passages: 211 | for a in ground_truth: 212 | if a.lower() in p['answer'].lower(): 213 | context_hit = True 214 | 215 | if context_hit: 216 | break 217 | return context_hit 218 | 219 | 220 | def _single_evaluate(self, qas, prediction_file, config, f1, em, total_count, recall_grid): 221 | total_count += 1 222 | query_id = qas['id'].strip() 223 | query_text = qas['question'].strip() 224 | answers = [x["text"] for x in qas['answers']] 225 | 226 | if query_id not in prediction_file: 227 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 228 | return f1, em, total_count, recall_grid 229 | 230 | prediction = str(prediction_file[query_id]['answer']) 231 | f1 += self._calc_f1_score(answers, prediction) 232 | em += self._calc_em_score(answers, prediction) 233 | 234 | 235 | for n in [1, 5, 10, 50, 100]: 236 | if n > config.param.top_k: 237 | break 238 | c_h = self._metric_recall(prediction_file[query_id]['passages'], answers, n) 239 | recall_grid['R@{}'.format(n)] += int(c_h) 240 | 241 | return f1, em, total_count, recall_grid 242 | 243 | 244 | def _calc_f1_score(self, answers, prediction): 245 | f1_scores = [] 246 | for ans in answers: 247 | ans_segs = self._mixed_segmentation(ans, rm_punc=True) 248 | prediction_segs = self._mixed_segmentation(prediction, rm_punc=True) 249 | lcs, lcs_len = self._find_lcs(ans_segs, prediction_segs) 250 | if lcs_len == 0: 251 | f1_scores.append(0) 252 | continue 253 | precision = 1.0*lcs_len/len(prediction_segs) 254 | recall = 1.0*lcs_len/len(ans_segs) 255 | f1 = (2*precision*recall)/(precision+recall) 256 | f1_scores.append(f1) 257 | return max(f1_scores) 258 | 259 | 260 | def _calc_em_score(self, answers, prediction): 261 | em = 0 262 | for ans in answers: 263 | ans_ = self._remove_punctuation(ans) 264 | prediction_ = self._remove_punctuation(prediction) 265 | if ans_ == prediction_: 266 | em = 1 267 | break 268 | return em 269 | 270 | 271 | def evaluate(lang, data, predictions, config): 272 | if lang == 'en': 273 | eval_func = EvalMrcEn() 274 | elif lang == 'zh': 275 | eval_func = EvalMrcZh() 276 | else: 277 | raise ValueError('lang {} not recognized'.format(lang)) 278 | 279 | results = eval_func.evaluate(data['data'], predictions, config) 280 | return results -------------------------------------------------------------------------------- /soco_openqa/helper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | import os 5 | from soco_device import DeviceCheck 6 | from soco_openqa.cloud_bucket import CloudBucket 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def
get_name_from_path(path): 13 | def strip_path(path): 14 | return path.rsplit('/', 1)[-1] 15 | 16 | def remove_ext(name): 17 | return name.rsplit('.', 1)[0] 18 | 19 | return remove_ext(strip_path(path)) 20 | 21 | 22 | def find_device(n_gpu): 23 | ''' 24 | Return the number of available GPU devices for the requested count. 25 | ''' 26 | device_check = DeviceCheck() 27 | device_name, device_ids = device_check.get_device(n_gpu=n_gpu) 28 | n_gpu = len(device_ids) 29 | 30 | return n_gpu 31 | 32 | 33 | def load_json(file_dir, file_name, region='us'): 34 | cloud_bucket = CloudBucket(region) 35 | cloud_bucket.download_file(file_dir=file_dir, file_name=file_name) 36 | res = json.load(open(os.path.join('data', file_dir, file_name))) 37 | return res 38 | 39 | 40 | def save_logs(config, results, save_path='reports', save_name='eval_results'): 41 | os.makedirs(save_path, exist_ok=True) 42 | save_name = '{}/{}.txt'.format(save_path, save_name) 43 | results = '\n' + ' '.join(['{}={:.3f}'.format(k, v) for k, v in results.items()]) + '\n\n' 44 | 45 | curr_time = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S%z') 46 | with open(save_name, 'a') as f: 47 | f.write('{}\n'.format('*'*20)) 48 | f.write('{}\n'.format(curr_time)) 49 | f.write(config) 50 | f.write(results) 51 | logger.info('Saved results to {}'.format(save_name)) 52 | 53 | 54 | def load_jsonl(file_name): 55 | data = [] 56 | with open(file_name, 'r', encoding='utf-8') as f: 57 | for line in f: 58 | data.append(json.loads(line.rstrip('\r\n'))) 59 | 60 | logger.info('Loaded {} lines from {}'.format(len(data), file_name)) 61 | return data 62 | 63 | 64 | def sparta_postprocess(ranker_id): 65 | data = load_jsonl('./data/{}.jsonl'.format(ranker_id)) 66 | 67 | # convert a list of dicts to {id1: [{'answer': 1, 'score': 1}, {'answer': 2, 'score': 2}, ...], id2: [], ...} 68 | id_to_data_map = dict() 69 | for d in data: 70 | if d['id'] not in id_to_data_map: 71 | id_to_data_map[d['id']] = [] 72 | for response in d['responses']: 73 | res_map = dict() 74 | res_map['score'] = response['_score'] 75 | res_map['answer'] = response['_source']['context'].get('context', '') 76 | id_to_data_map[d['id']].append(res_map) 77 | else: 78 | raise ValueError('{} already in map'.format(d['id'])) 79 | 80 | # save processed version to json 81 | json.dump(id_to_data_map, open('./data/{}.json'.format(ranker_id), 'w'), ensure_ascii=False) 82 | 83 | return id_to_data_map 84 | 85 | 86 | class QueryGenerator(object): 87 | @classmethod 88 | def sent_bm25(cls, query): 89 | es_query = {"query": {'match': {'q': {'query': query}}}} 90 | return es_query 91 | 92 | @classmethod 93 | def context_bm25(cls, query): 94 | es_query = {"query": {'match': {'context.context': {'query': query}}}} 95 | return es_query 96 | 97 | @classmethod 98 | def tscore_search(cls, query, query_embedded, alpha_bm25=0.0, max_l2r=-1): 99 | # convert to string only 100 | query_embedded = [t if type(t) is str else '__'.join(t) for t in query_embedded if t] 101 | 102 | main_query = [{'rank_feature': {'field': 'term_scores.{}'.format(t), 103 | "log": {"scaling_factor": 1.0} 104 | }} for t in query_embedded] 105 | es_query = {"query": {"bool": {"should": main_query}}} 106 | 107 | if alpha_bm25 > 0: 108 | window_size = max(100, max_l2r) 109 | es_query['rescore'] = { 110 | "window_size": window_size, 111 | "query": { 112 | "score_mode": "total", 113 | "rescore_query": { 114 | "match": {"q": {"query": query}}, 115 | }, 116 | "query_weight": 1.0, 117 | "rescore_query_weight": alpha_bm25 118 | }} 119 | 120 | return es_query 121 | 122 |
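A minimal sketch of how QueryGenerator.tscore_search is meant to be called; the question and token list below are illustrative, not taken from the repo:

    from soco_openqa.helper import QueryGenerator

    # one rank_feature clause is emitted per query token, scored against the
    # precomputed SPARTA term_scores field stored in the Elasticsearch index
    es_query = QueryGenerator.tscore_search(
        query='when was microsoft founded?',
        query_embedded=['when', 'was', 'microsoft', 'founded'],
        alpha_bm25=0.5,  # > 0 also attaches a BM25 rescore window on the 'q' field
        max_l2r=100,
    )
    # es_query['query']['bool']['should'][2] ==
    #   {'rank_feature': {'field': 'term_scores.microsoft',
    #                     'log': {'scaling_factor': 1.0}}}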
-------------------------------------------------------------------------------- /soco_openqa/pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from os.path import isfile 4 | 5 | from tqdm import tqdm 6 | 7 | import soco_openqa.helper as helper 8 | from soco_openqa.ranker import BM25Ranker, SpartaRanker 9 | from soco_openqa.reader import Reader 10 | from soco_openqa.cloud_bucket import CloudBucket 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | ranker_map = {'sparta': SpartaRanker, 'bm25': BM25Ranker} 16 | 17 | class OpenQA(object): 18 | def __init__(self, config): 19 | self.use_cached_rank = True if config.ranker.use_cached else False 20 | self.top_k = config.param.top_k 21 | self.ranker = self._load_ranker(config) 22 | self.reader = Reader(config) 23 | self.cloud_bucket = CloudBucket('us') 24 | 25 | def _load_ranker(self, config): 26 | if self.use_cached_rank: 27 | logger.info('Use cached ranker file') 28 | data_name = config.data.name 29 | ranker_file = config.ranker.cached_ranker_file 30 | ranker = helper.load_json(file_dir=data_name, file_name=ranker_file) 31 | 32 | else: 33 | ranker_name = config.ranker.model.name 34 | ranker = ranker_map[ranker_name](config) 35 | 36 | return ranker 37 | 38 | 39 | def predict(self, data): 40 | predictions = dict() 41 | no_ans_cnt = 0 42 | logger.info("Execution started") 43 | for d in tqdm(data['data']): 44 | for p in tqdm(d['paragraphs']): 45 | for qa in p['qas']: 46 | if len(predictions) > 0 and len(predictions) % 1000 == 0: 47 | logger.info("{}: {} no answer".format(len(predictions), no_ans_cnt)) 48 | _id = str(qa['id']) 49 | query = qa['question'] 50 | if self.use_cached_rank: 51 | top_passages = self.ranker[_id][:self.top_k] 52 | else: 53 | top_passages = self.ranker.rank(query) 54 | 55 | try: 56 | answers = self.reader.predict(query, top_passages) 57 | if len(answers) > 0: 58 | predictions[_id] = {'answer': answers[0]['value'], 'passages': top_passages} 59 | else: 60 | raise ValueError("No Answer") 61 | except ValueError as e: 62 | no_ans_cnt += 1 63 | predictions[_id] = {'answer': 'NO_ANSWER', 'passages': top_passages} 64 | 65 | return predictions -------------------------------------------------------------------------------- /soco_openqa/ranker.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import urllib3 5 | from elasticsearch import Elasticsearch, RequestsHttpConnection 6 | from soco_encoders.model_loaders import EncoderLoader 7 | 8 | from soco_openqa.helper import QueryGenerator as QG 9 | 10 | urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) 11 | 12 | logger = logging.getLogger('elasticsearch') 13 | logger.setLevel(logging.WARNING) 14 | 15 | class RankerBases(object): 16 | def rank(self, query, size=10): 17 | raise NotImplementedError 18 | 19 | 20 | class BM25Ranker(RankerBases): 21 | def __init__(self, config): 22 | pass 23 | 24 | 25 | class SpartaRanker(RankerBases): 26 | def __init__(self, config): 27 | self.index = config.ranker.model.es_index_name 28 | self.lang = config.data.lang 29 | self.top_k = config.param.top_k 30 | self.tokenizers = dict() 31 | 32 | es_url = 'https://elastic:13-socoES@search-new-wiki-qsd6ejfqfwyva7sok6ag32s72u.us-east-2.es.amazonaws.com' 33 | 34 | es = Elasticsearch( 35 | hosts=[es_url], 36 | ca_certs=False, 37 | verify_certs=False, 38 | connection_class=RequestsHttpConnection 
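# ca_certs=False / verify_certs=False skip TLS certificate verification for this
# hard-coded research endpoint (the InsecureRequestWarning is silenced above);
# re-enable verification when pointing this at your own cluster.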
39 | ) 40 | # print("Create ES client {}".format(es_url)) 41 | self.es = es 42 | 43 | 44 | def _get_query(self, query, query_embedded): 45 | es_query = QG.tscore_search(query, query_embedded) 46 | es_query['size'] = self.top_k 47 | es_query['_source'] = {'excludes': ['embedding_vector*', 'term_scores']} 48 | return es_query 49 | 50 | def _load_tokenizer(self): 51 | if self.lang not in self.tokenizers: 52 | if self.lang == 'zh': 53 | self.tokenizers[self.lang] = EncoderLoader.load_tokenizer('bert-base-chinese-zh_v4-10K') 54 | elif self.lang == 'en': 55 | self.tokenizers[self.lang] = EncoderLoader.load_tokenizer('bert-base-uncased') 56 | else: 57 | raise NotImplementedError 58 | 59 | return self.tokenizers[self.lang] 60 | 61 | 62 | def _postprocess(self, res): 63 | res = [{'score': p['_score'], 64 | 'answer': p['_source']['answer']['context']} 65 | for p in res['hits']['hits']] 66 | 67 | return res 68 | 69 | 70 | def rank(self, query): 71 | """ 72 | search inside 73 | """ 74 | tokenizer = self._load_tokenizer() 75 | tokens = tokenizer.tokenize(query, mode='all') 76 | es_query = self._get_query(query, tokens) 77 | res = self.es.search(index=self.index, body=es_query, request_timeout=500) 78 | 79 | return self._postprocess(res) 80 | -------------------------------------------------------------------------------- /soco_openqa/reader.py: -------------------------------------------------------------------------------- 1 | import soco_openqa.helper as helper 2 | from soco_openqa.soco_mrc.mrc_model import MrcModel, MrcRerankerModel 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | 7 | class Reader(object): 8 | def __init__(self, config): 9 | gpu_request = config.param.n_gpu 10 | n_gpu = helper.find_device(gpu_request) 11 | print('number of gpus: {}'.format(n_gpu)) 12 | 13 | if config.reader.use_reranker: 14 | self.reader = MrcRerankerModel('us', n_gpu=n_gpu) 15 | self.use_reranker = True 16 | self.rerank_size = config.reader.rerank_size 17 | else: 18 | self.reader = MrcModel('us', n_gpu=n_gpu) 19 | self.use_reranker = False 20 | self.model_id = config.reader.model_id 21 | self.thresh = config.param.score_weight 22 | 23 | def predict(self, query, top_passages): 24 | batch = [{'q': query, 'doc': p['answer']} for p in top_passages] 25 | preds = self.reader.batch_predict( 26 | self.model_id, 27 | batch, 28 | merge_pred=True, 29 | stride=128, 30 | batch_size=50 31 | ) 32 | 33 | # combine with ranking score 34 | if not self.use_reranker: 35 | candidates = defaultdict(list) 36 | for a_id, a in enumerate(preds): 37 | if a.get('missing_warning'): 38 | continue 39 | score = self.thresh * (a['score']) + (1 - self.thresh) * (top_passages[a_id]['score']) 40 | candidates[a['value']].append(score) 41 | 42 | candidates = [{'value': k, 'score': np.max(v)} for k, v in candidates.items()] 43 | 44 | answers = sorted(candidates, key=lambda x: x['score'], reverse=True) 45 | else: 46 | # first sort by cls_score, then get topn and rank by combined score 47 | candidates = defaultdict(lambda: defaultdict(list)) 48 | for a_id, a in enumerate(preds): 49 | if a.get('missing_warning'): 50 | continue 51 | score = self.thresh * (a['score']*a['cls_prob']) + (1 - self.thresh) * (top_passages[a_id]['score']) 52 | 53 | candidates[a['value']]['score'].append(score) 54 | candidates[a['value']]['cls_prob'].append(a['cls_prob']) 55 | 56 | candidates = [{'value': k, 'score': np.max(v['score']), 'cls_prob':np.max(v['cls_prob'])} for k, v in candidates.items()] 57 | 58 | rerank_candidates = sorted(candidates, key=lambda 
x: x['cls_prob'], reverse=True)[:self.rerank_size] 59 | answers = sorted(rerank_candidates, key=lambda x: x['score'], reverse=True) 60 | 61 | return answers 62 | -------------------------------------------------------------------------------- /soco_openqa/soco_mrc/models/bert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import BertPreTrainedModel 3 | from transformers import BertModel 4 | from torch import nn 5 | from torch.nn import CrossEntropyLoss, MSELoss, NLLLoss 6 | 7 | 8 | class BertForQuestionAnsweringWithReranker(BertPreTrainedModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self.bert = BertModel(config) 12 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 13 | 14 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 15 | self.classifier = nn.Linear(config.hidden_size, 2) 16 | 17 | self.init_weights() 18 | 19 | def forward( 20 | self, 21 | input_ids=None, 22 | token_type_ids=None, 23 | attention_mask=None, 24 | 25 | ): 26 | sequence_output, pooled_output = self.bert( 27 | input_ids, 28 | attention_mask=attention_mask, 29 | token_type_ids=token_type_ids, 30 | ) 31 | 32 | logits = self.qa_outputs(sequence_output) 33 | start_logits, end_logits = logits.split(1, dim=-1) 34 | start_logits = start_logits.squeeze(-1) 35 | end_logits = end_logits.squeeze(-1) 36 | cls_logits = self.classifier(pooled_output) 37 | 38 | return start_logits, end_logits, cls_logits 39 | -------------------------------------------------------------------------------- /soco_openqa/soco_mrc/mrc_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoModelForQuestionAnswering, AutoTokenizer 3 | import os 4 | import torch 5 | import re 6 | from collections import namedtuple 7 | import logging 8 | from soco_device import DeviceCheck 9 | from soco_openqa.soco_mrc import util 10 | from soco_openqa.soco_mrc.models.bert_model import BertForQuestionAnsweringWithReranker 11 | from soco_openqa.cloud_bucket import CloudBucket 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | logging.getLogger("transformers").setLevel(logging.ERROR) 16 | logging.getLogger("tokenization_utils").setLevel(logging.ERROR) 17 | 18 | MODEL_MAP = { 19 | 'MrcModel': AutoModelForQuestionAnswering, 20 | 'MrcRerankerModel': BertForQuestionAnsweringWithReranker, 21 | } 22 | 23 | 24 | 25 | class ModelBase(object): 26 | def __init__(self, 27 | region, 28 | n_gpu=0, 29 | fp16=False, 30 | quantize=False, 31 | multiprocess=False 32 | ): 33 | logger.info("Op in {} region".format(region)) 34 | self.n_gpu_request = n_gpu 35 | self.region = region 36 | self.fp16 = fp16 37 | self.quantize = quantize 38 | self.multiprocess = multiprocess 39 | self.cloud_bucket = CloudBucket(region) 40 | self._models = dict() 41 | self.max_input_length = 512 42 | self.device_check = DeviceCheck() 43 | 44 | 45 | def _load_model(self, model_id): 46 | # a naive check. 
if too big, just reset 47 | if len(self._models) > 20: 48 | self._models = dict() 49 | 50 | if model_id not in self._models: 51 | path = os.path.join('resources', model_id) 52 | self.cloud_bucket.download_model('mrc-models', model_id) 53 | model_class = self.__class__.__name__ 54 | model = MODEL_MAP[model_class].from_pretrained(path) 55 | tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True) 56 | 57 | device_name, device_ids = self.device_check.get_device_by_model(model_id, n_gpu=self.n_gpu_request) 58 | self.n_gpu_allocate = len(device_ids) 59 | device = '{}:{}'.format(device_name, device_ids[0]) if self.n_gpu_allocate == 1 else device_name 60 | 61 | if self.fp16 and 'cuda' in device: 62 | logger.info('Use fp16') 63 | model.half() 64 | if self.quantize and device == 'cpu': 65 | logger.info('Use quantization') 66 | model = torch.quantization.quantize_dynamic( 67 | model, {torch.nn.Linear}, dtype=torch.qint8) 68 | model.to(device) 69 | 70 | # multi gpu inference 71 | if self.n_gpu_allocate > 1 and not isinstance(model, torch.nn.DataParallel): 72 | model = torch.nn.DataParallel(model, device_ids=device_ids) 73 | 74 | self._models[model_id] = (tokenizer, model, device) 75 | 76 | else: 77 | # if loaded as cpu, check if gpu is available 78 | _, _, device = self._models[model_id] 79 | if self.n_gpu_request > 0 and device == 'cpu': 80 | device_name, device_ids = self.device_check.get_device_by_model(model_id, n_gpu=1) 81 | new_device = '{}:{}'.format(device_name, device_ids[0]) if len(device_ids) == 1 else device_name 82 | if new_device != device: 83 | logger.info('Reloading') 84 | self._models.pop(model_id) 85 | self._load_model(model_id) 86 | 87 | return self._models[model_id] 88 | 89 | 90 | def batch_predict(self, model_id, data, **kwargs): 91 | raise NotImplementedError 92 | 93 | def _get_param(self, kwargs): 94 | batch_size = kwargs.pop('batch_size', 10) 95 | merge_pred = kwargs.pop('merge_pred', False) 96 | stride = kwargs.pop('stride', 0) 97 | 98 | return batch_size, merge_pred, stride 99 | 100 | 101 | 102 | class MrcModel(ModelBase): 103 | def batch_predict(self, model_id, data, **kwargs): 104 | batch_size, merge_pred, stride = self._get_param(kwargs) 105 | 106 | tokenizer, model, device = self._load_model(model_id) 107 | 108 | features = util.convert_examples_to_features( 109 | tokenizer, 110 | data, 111 | self.max_input_length, 112 | merge_pred, 113 | stride) 114 | 115 | results = [] 116 | for batch in util.chunks(features, batch_size): 117 | padded = util.pad_batch(batch) 118 | input_ids, token_type_ids, attn_masks = padded 119 | 120 | with torch.no_grad(): 121 | start_scores, end_scores = model(torch.tensor(input_ids).to(device), 122 | token_type_ids=torch.tensor(token_type_ids).to(device), 123 | attention_mask=torch.tensor(attn_masks).to(device)) 124 | start_probs = torch.softmax(start_scores, dim=1) 125 | end_probs = torch.softmax(end_scores, dim=1) 126 | 127 | for b_id in range(len(batch)): 128 | all_tokens = tokenizer.convert_ids_to_tokens(input_ids[b_id]) 129 | legal_length = batch[b_id]['length'] 130 | b_start_score = start_scores[b_id][0:legal_length] 131 | b_end_score = end_scores[b_id][0:legal_length] 132 | token2char = batch[b_id]['offset_mapping'] 133 | for t_id in range(legal_length): 134 | if token2char[t_id] is None or token2char[t_id] == (0, 0): 135 | b_start_score[t_id] = -10000 136 | b_end_score[t_id] = -10000 137 | 138 | _, top_start_id = torch.topk(b_start_score, 2, dim=0) 139 | _, top_end_id = torch.topk(b_end_score, 2, dim=0) 140 | 141 | s_prob = 
start_probs[b_id, top_start_id[0]].item() 142 | e_prob = end_probs[b_id, top_end_id[0]].item() 143 | s_logit = start_scores[b_id, top_start_id[0]].item() 144 | e_logit = end_scores[b_id, top_end_id[0]].item() 145 | 146 | prob = (s_prob + e_prob) / 2 147 | score = (s_logit + e_logit) / 2 148 | 149 | doc = batch[b_id]['doc'] 150 | doc_offset = input_ids[b_id].index(102) 151 | 152 | res = all_tokens[top_start_id[0]:top_end_id[0] + 1] 153 | char_offset = token2char[doc_offset + 1][0] 154 | example_idx = batch[b_id]['example_idx'] 155 | 156 | if not res or res[0] == "[CLS]" or res[0] == '[SEP]' or top_start_id[0].item() <= doc_offset: 157 | prediction = {'missing_warning': True, 158 | 'prob': prob, 159 | 'start_end_prob': [s_prob, e_prob], 160 | 'score': score, 161 | 'start_end_score': [s_logit, e_logit], 162 | 'value': "", 163 | 'answer_start': -1, 164 | 'example_idx': example_idx} 165 | else: 166 | if not merge_pred: 167 | start_map = token2char[top_start_id[0].item()] 168 | end_map = token2char[top_end_id[0].item()] 169 | span = [start_map[0] - char_offset, end_map[1] - char_offset] 170 | ans = doc[span[0]: span[1]] 171 | else: 172 | base_idx = batch[b_id]['base_idx'] 173 | orig_doc = batch[b_id]['orig_doc'] 174 | orig_token2char = batch[b_id]['orig_offset_mapping'] 175 | 176 | # map token index, then use offset mapping to map to original position 177 | orig_start_map = orig_token2char[top_start_id[0].item() + base_idx - doc_offset - 1] 178 | orig_end_map = orig_token2char[top_end_id[0].item() + base_idx - doc_offset - 1] 179 | span = [orig_start_map[0], orig_end_map[1]] 180 | ans = orig_doc[span[0]: span[1]] 181 | try: 182 | start_map = token2char[top_start_id[0].item()] 183 | end_map = token2char[top_end_id[0].item()] 184 | debug_span = [start_map[0] - char_offset, end_map[1] - char_offset] 185 | debug_ans = doc[debug_span[0]: debug_span[1]] 186 | assert debug_ans == ans 187 | except Exception as e: 188 | print(e) 189 | print('chunk ans: {} '.format(debug_ans)) 190 | print('doc ans: {} '.format(ans)) 191 | print('chunk span: {} vs doc span: {}'.format(debug_span, span)) 192 | 193 | prediction = {'value': ans, 194 | 'answer_start': span[0], 195 | 'answer_span': span, 196 | 'prob': prob, 197 | 'start_end_prob': [s_prob, e_prob], 198 | 'score': score, 199 | 'start_end_score': [s_logit, e_logit], 200 | 'tokens': res, 201 | 'example_idx': example_idx} 202 | 203 | results.append(prediction) 204 | 205 | 206 | # merge predictions 207 | if merge_pred: 208 | results = util.merge_predictions(results) 209 | 210 | return results 211 | 212 | 213 | class MrcRerankerModel(ModelBase): 214 | 215 | def batch_predict(self, model_id, data, **kwargs): 216 | batch_size, merge_pred, stride = self._get_param(kwargs) 217 | 218 | tokenizer, model, device = self._load_model(model_id) 219 | 220 | features = util.convert_examples_to_features( 221 | tokenizer, 222 | data, 223 | self.max_input_length, 224 | merge_pred, 225 | stride) 226 | 227 | results = [] 228 | for batch in util.chunks(features, batch_size): 229 | padded = util.pad_batch(batch) 230 | input_ids, token_type_ids, attn_masks = padded 231 | 232 | with torch.no_grad(): 233 | start_scores, end_scores, cls_scores = model(torch.tensor(input_ids).to(device), 234 | token_type_ids=torch.tensor(token_type_ids).to(device), 235 | attention_mask=torch.tensor(attn_masks).to(device)) 236 | 237 | start_probs = torch.softmax(start_scores, dim=1) 238 | end_probs = torch.softmax(end_scores, dim=1) 239 | cls_probs = torch.softmax(cls_scores, dim=1) 240 | 241 | for b_id in 
range(len(batch)): 242 | all_tokens = tokenizer.convert_ids_to_tokens(input_ids[b_id]) 243 | legal_length = batch[b_id]['length'] 244 | b_start_score = start_scores[b_id][0:legal_length] 245 | b_end_score = end_scores[b_id][0:legal_length] 246 | token2char = batch[b_id]['offset_mapping'] 247 | for t_id in range(legal_length): 248 | if token2char[t_id] is None or token2char[t_id] == (0, 0): 249 | b_start_score[t_id] = -10000 250 | b_end_score[t_id] = -10000 251 | 252 | _, top_start_id = torch.topk(b_start_score, 2, dim=0) 253 | _, top_end_id = torch.topk(b_end_score, 2, dim=0) 254 | 255 | s_prob = start_probs[b_id, top_start_id[0]].item() 256 | e_prob = end_probs[b_id, top_end_id[0]].item() 257 | s_logit = start_scores[b_id, top_start_id[0]].item() 258 | e_logit = end_scores[b_id, top_end_id[0]].item() 259 | # get has answer confidence 260 | cls_score = cls_scores[b_id][1].item() 261 | cls_prob = cls_probs[b_id][1].item() 262 | 263 | prob = (s_prob + e_prob) / 2 264 | score = (s_logit + e_logit) / 2 265 | 266 | doc = batch[b_id]['doc'] 267 | doc_offset = input_ids[b_id].index(102) 268 | 269 | res = all_tokens[top_start_id[0]:top_end_id[0] + 1] 270 | char_offset = token2char[doc_offset + 1][0] 271 | example_idx = batch[b_id]['example_idx'] 272 | 273 | if not res or res[0] == "[CLS]" or res[0] == '[SEP]' or top_start_id[0].item() <= doc_offset: 274 | prediction = {'missing_warning': True, 275 | 'prob': prob, 276 | 'start_end_prob': [s_prob, e_prob], 277 | 'score': score, 278 | 'cls_score': cls_score, 279 | 'cls_prob': cls_prob, 280 | 'start_end_score': [s_logit, e_logit], 281 | 'value': "", 282 | 'answer_start': -1, 283 | 'example_idx': example_idx} 284 | else: 285 | if not merge_pred: 286 | start_map = token2char[top_start_id[0].item()] 287 | end_map = token2char[top_end_id[0].item()] 288 | span = [start_map[0] - char_offset, end_map[1] - char_offset] 289 | ans = doc[span[0]: span[1]] 290 | else: 291 | base_idx = batch[b_id]['base_idx'] 292 | orig_doc = batch[b_id]['orig_doc'] 293 | orig_token2char = batch[b_id]['orig_offset_mapping'] 294 | 295 | # map token index, then use offset mapping to map to original position 296 | orig_start_map = orig_token2char[top_start_id[0].item() + base_idx - doc_offset - 1] 297 | orig_end_map = orig_token2char[top_end_id[0].item() + base_idx - doc_offset - 1] 298 | span = [orig_start_map[0], orig_end_map[1]] 299 | ans = orig_doc[span[0]: span[1]] 300 | try: 301 | start_map = token2char[top_start_id[0].item()] 302 | end_map = token2char[top_end_id[0].item()] 303 | debug_span = [start_map[0] - char_offset, end_map[1] - char_offset] 304 | debug_ans = doc[debug_span[0]: debug_span[1]] 305 | assert debug_ans == ans 306 | except Exception as e: 307 | print(e) 308 | print('chunk ans: {} '.format(debug_ans)) 309 | print('doc ans: {} '.format(ans)) 310 | print('chunk span: {} vs doc span: {}'.format(debug_span, span)) 311 | 312 | prediction = {'value': ans, 313 | 'answer_start': span[0], 314 | 'answer_span': span, 315 | 'prob': prob, 316 | 'start_end_prob': [s_prob, e_prob], 317 | 'score': score, 318 | 'cls_score': cls_score, 319 | 'cls_prob': cls_prob, 320 | 'start_end_score': [s_logit, e_logit], 321 | 'tokens': res, 322 | 'example_idx': example_idx} 323 | 324 | results.append(prediction) 325 | 326 | # merge predictions 327 | if merge_pred: 328 | results = util.merge_predictions(results) 329 | 330 | return results 331 | -------------------------------------------------------------------------------- /soco_openqa/soco_mrc/util.py: 
-------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import re 4 | import string 5 | from collections import Counter, OrderedDict 6 | import nltk 7 | from typing import Any, Callable, Dict, Generator, Sequence 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def chunks(l: Sequence, n: int = 5) -> Generator[Sequence, None, None]: 12 | """Yield successive n-sized chunks from l.""" 13 | for i in range(0, len(l), n): 14 | yield l[i:i + n] 15 | 16 | def stride_chunks(l: Sequence, win_len: int, stride_len: int): 17 | s_id = 0 18 | e_id = min(len(l), win_len) 19 | 20 | while True: 21 | yield s_id, l[s_id:e_id] 22 | 23 | if e_id == len(l): 24 | break 25 | 26 | s_id = s_id + stride_len 27 | e_id = min(s_id + win_len, len(l)) 28 | 29 | 30 | class DefaultListOrderedDict(OrderedDict): 31 | def __missing__(self,k): 32 | self[k] = [] 33 | return self[k] 34 | 35 | 36 | def convert_examples_to_features(tokenizer, data, max_input_length, merge_pred=False, stride=0): 37 | features = [] 38 | for example_idx, d in enumerate(data): 39 | doc = _normalize_text(d['doc']) 40 | q = _normalize_text(d['q']) 41 | if not merge_pred: 42 | temp = tokenizer.encode_plus(q, doc, return_offsets_mapping=True, truncation=False) 43 | # cut by max_input_length 44 | input_ids = temp.data['input_ids'] 45 | if len(input_ids) > max_input_length: 46 | logger.info("Input length {} is too big. Cap to {}".format(len(input_ids), max_input_length)) 47 | for k, v in temp.data.items(): 48 | temp.data[k] = cap_to(v, max_input_length) 49 | 50 | temp['doc'] = doc 51 | temp['q'] = q 52 | temp['example_idx'] = example_idx 53 | features.append(temp) 54 | else: 55 | seq_pair_added_toks = tokenizer.model_max_length - tokenizer.max_len_sentences_pair 56 | q_toks = tokenizer.tokenize(q) 57 | q_enc = tokenizer.encode_plus(q, return_offsets_mapping=True) 58 | window_len = max_input_length - len(q_toks) - seq_pair_added_toks 59 | doc_enc = tokenizer.encode_plus(doc, add_special_tokens=False, return_offsets_mapping=True, truncation=False) 60 | for base_idx, chunk_mapping in stride_chunks(doc_enc['offset_mapping'], window_len, stride): 61 | chunk_st = chunk_mapping[0][0] 62 | chunk_ed = chunk_mapping[-1][1] 63 | chunk = doc[chunk_st: chunk_ed] 64 | 65 | new_dict = {} 66 | # add last [SEP] 67 | chunk_input_ids = doc_enc['input_ids'][base_idx:base_idx+len(chunk_mapping)] + [102] 68 | chunk_token_type_ids = [1]*len(chunk_mapping) + [1] 69 | chunk_attention_mask = [1]*len(chunk_mapping) + [1] 70 | tmp_chunk_offset_mapping = doc_enc['offset_mapping'][base_idx:base_idx+len(chunk_mapping)] 71 | base_offset = tmp_chunk_offset_mapping[0][0] 72 | chunk_offset_mapping = [] 73 | for offset in tmp_chunk_offset_mapping: 74 | chunk_offset_mapping.append((offset[0]-base_offset, offset[1]-base_offset)) 75 | chunk_offset_mapping.append((0, 0)) 76 | 77 | new_dict['input_ids'] = q_enc['input_ids'] + chunk_input_ids 78 | new_dict['token_type_ids'] = q_enc['token_type_ids'] + chunk_token_type_ids 79 | new_dict['attention_mask'] = q_enc['attention_mask'] + chunk_attention_mask 80 | new_dict['offset_mapping'] = q_enc['offset_mapping'] + chunk_offset_mapping 81 | 82 | new_dict['doc'] = chunk 83 | new_dict['orig_doc'] = doc 84 | new_dict['q'] = q 85 | new_dict['example_idx'] = example_idx 86 | new_dict['base_idx'] = base_idx 87 | new_dict['orig_offset_mapping'] = doc_enc['offset_mapping'] 88 | 89 | features.append(new_dict) 90 | 91 | return features 92 | 93 | def merge_predictions(results, strategy='max'): 94 | """ 95 | 
merge chunk predictions indicated by example_idx 96 | :param results: batched chunk predictions 97 | :type results: list 98 | :param strategy: 'max' or 'merge', indicating how to merge results; defaults to 'max' 99 | 'max': only keep the result with the highest probability 100 | 'merge': keep all predicted results 101 | :type strategy: str, optional 102 | :return: a list of dictionaries, one merged prediction per example 103 | :rtype: list 104 | """ 105 | 106 | idx_res_map = OrderedDict() 107 | for r in results: 108 | example_idx = r.get('example_idx') 109 | if example_idx not in idx_res_map: 110 | idx_res_map[example_idx] = r 111 | else: 112 | if strategy == 'max': # use max prob answer 113 | if r['prob'] > idx_res_map[example_idx]['prob'] and not r.get('missing_warning'): 114 | idx_res_map[example_idx] = r 115 | elif strategy == 'merge': # merge all chunk answers 116 | if not r.get('missing_warning'): 117 | if idx_res_map[example_idx].get('missing_warning'): 118 | idx_res_map[example_idx] = r 119 | continue 120 | idx_res_map[example_idx]['value_type'] = r.pop('value_type') 121 | idx_res_map[example_idx]['example_idx'] = r.pop('example_idx') 122 | # only keep values that are different 123 | keep_idx = [i for i, v in enumerate(r.get('value')) if v not in idx_res_map[example_idx].get('value')] 124 | for k in r.keys(): 125 | idx_res_map[example_idx][k].extend([v for i, v in enumerate(r[k]) if i in keep_idx]) 126 | 127 | results = [v for v in idx_res_map.values()] 128 | return results 129 | 130 | 131 | 132 | def pad_batch(batch): 133 | max_len = max([len(f['input_ids']) for f in batch]) 134 | for f in batch: 135 | f_len = len(f['input_ids']) 136 | f['length'] = f_len 137 | f['input_ids'] = f['input_ids'] + [0] * (max_len - f_len) 138 | f['token_type_ids'] = f['token_type_ids'] + [0] * (max_len - f_len) 139 | f['attention_mask'] = f['attention_mask'] + [0] * (max_len - f_len) 140 | 141 | input_ids = [f['input_ids'] for f in batch] 142 | token_type_ids = [f['token_type_ids'] for f in batch] 143 | attn_masks = [f['attention_mask'] for f in batch] 144 | 145 | return input_ids, token_type_ids, attn_masks 146 | 147 | 148 | def _normalize_text(text): 149 | return re.sub(r'\s+', ' ', text) 150 | 151 | def cap_to(seq, max_len): 152 | prefix = seq[0:-1][0:max_len - 1] 153 | return prefix + [seq[-1]] 154 | 155 | def get_span_from_ohe(bio_labels): 156 | left = 0 157 | right = 1 158 | found_st = False 159 | found_ed = False 160 | span_indexes = [] 161 | 162 | while right < len(bio_labels): 163 | if not found_st and not found_ed: 164 | if bio_labels[right] == 0: 165 | right += 1 166 | continue 167 | else: 168 | found_st = True 169 | left = right 170 | if found_st: 171 | if bio_labels[right] == 1: 172 | right += 1 173 | continue 174 | else: 175 | span_indexes.append((left, right-1)) 176 | left = right 177 | right = left + 1 178 | found_st = False 179 | found_ed = False 180 | 181 | if set(bio_labels[left:right]) == {1}: 182 | span_indexes.append((left, right-1)) 183 | 184 | return span_indexes 185 | 186 | def get_ans_span(res): 187 | if not res: 188 | return "" 189 | 190 | for i, t in enumerate(res): 191 | if t.startswith("##"): 192 | res[i - 1] += t[2:] 193 | res[i] = "" 194 | 195 | value = " ".join([x for x in res if x != ""]) 196 | return value 197 | 198 | def is_whitespace(c): 199 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 200 | return True 201 | return False 202 | 203 | def token2char(orig_str, tokens): 204 | norm_tokens = [t.replace('##', '') for t in tokens] 205 | 206 | token_id = 0 207 | 
token_char_id = 0 208 | token2char_map = {} # token_id -> [start, end] 209 | 210 | token2char_map[token_id] = [0, None] 211 | for c_id, c in enumerate(orig_str): 212 | if is_whitespace(c): 213 | token2char_map[token_id][1] = c_id 214 | token_id += 1 215 | token_char_id = 0 216 | token2char_map[token_id] = [c_id+1, None] 217 | continue 218 | 219 | if token_char_id < len(norm_tokens[token_id]) and c == norm_tokens[token_id][token_char_id]: 220 | token_char_id += 1 221 | else: 222 | token2char_map[token_id][1] = c_id 223 | token_id += 1 224 | token_char_id = 0 225 | token2char_map[token_id] = [c_id, None] 226 | 227 | if c == norm_tokens[token_id][token_char_id]: 228 | token_char_id += 1 229 | 230 | token2char_map[token_id][1] = c_id+1 231 | # print(token2char_map) 232 | return token2char_map 233 | --------------------------------------------------------------------------------
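Putting the demo pieces together, a minimal end-to-end sketch (this assumes network access to the SOCO API and a reader model downloadable from the mrc-models bucket; the model id below is a hypothetical placeholder, not a name from the repo):

    from soco_openqa.demo.ranker import Ranker
    from soco_openqa.demo.reader import Reader
    from soco_openqa.demo.qa import QA

    # retrieve candidate passages, then extract and rescore answer spans
    ranker = Ranker('sparta-en-wiki-2016')
    reader = Reader('squad-spanbert-large')  # hypothetical model id
    qa = QA(Reader=reader, Ranker=ranker)

    for ans in qa.query('when was microsoft founded?', num_results=3):
        print(ans['value'], '{:.4f}'.format(ans['score']))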