├── kilt ├── __init__.py ├── configs │ ├── __init__.py │ ├── retriever │ │ ├── __init__.py │ │ ├── default_drqa.json │ │ ├── blink_biencoder.json │ │ ├── default_dpr.json │ │ └── default_blink.json │ ├── mapping │ │ └── dev_natural_questions.json │ ├── dev_data.json │ ├── train_data.json │ ├── test_data.json │ └── all_data.json ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── hotpotqa_ks.py │ ├── zero_shot_re.py │ ├── triviaqa.py │ ├── hotpotqa.py │ ├── entity_linking.py │ ├── fact_verification.py │ └── natural_questions.py ├── readers │ ├── t5 │ │ ├── __init__.py │ │ ├── README.md │ │ ├── evaluate_kilt_task.py │ │ ├── data.py │ │ ├── finetune.py │ │ └── base_transformer.py │ └── fid │ │ ├── README.md │ │ ├── preprocess.py │ │ └── postprocess.py ├── retrievers │ ├── __init__.py │ ├── README.md │ ├── DrQA_tfidf.py │ ├── base_retriever.py │ ├── BM25_connector.py │ ├── DPR_connector.py │ ├── DPR_distr_connector.py │ └── BLINK_connector.py ├── dataset_mapper.py ├── retrieval.py ├── knowledge_source.py └── eval_downstream.py ├── tests ├── __init__.py ├── test_data │ ├── __init__.py │ ├── guess2_2.jsonl │ ├── gold2.jsonl │ ├── guess2_1.jsonl │ ├── guess3_1.jsonl │ ├── gold3.jsonl │ ├── guess3_1.json │ ├── guess1_1.jsonl │ └── gold1.jsonl ├── README.md ├── test_eval_retrieval.py └── test_eval_downstream.py ├── img ├── KILT_logo.png └── infographic_e.jpg ├── scripts ├── map_datasets.py ├── README.md ├── download_all_kilt_data.py ├── get_triviaqa_input.py ├── execute_retrieval.py ├── map_TAC-KBP2010_to_KILT.py └── create_kilt_data_paragraphs.py ├── setup.py ├── LICENSE ├── .gitignore ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md └── README.md /kilt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kilt/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kilt/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kilt/readers/t5/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kilt/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kilt/configs/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ```bash 2 | python -m unittest 3 | ``` -------------------------------------------------------------------------------- /img/KILT_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/KILT/HEAD/img/KILT_logo.png -------------------------------------------------------------------------------- /img/infographic_e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/KILT/HEAD/img/infographic_e.jpg -------------------------------------------------------------------------------- /kilt/configs/retriever/default_drqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "retriever_model": "models/kilt_db_simple.npz", 3 | "num_threads": 4 4 | } -------------------------------------------------------------------------------- /tests/test_data/guess2_2.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "output": [{"provenance": [{"wikipedia_id":5,"section":5,"start_paragraph_id":5,"end_paragraph_id":5}]}]} -------------------------------------------------------------------------------- /kilt/configs/mapping/dev_natural_questions.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_file": "original_data/v1.0-simplified_simplified-nq-dev.jsonl", 3 | "output_file": "output/nq-dev-kilt.jsonl", 4 | "log_file": "output/nq-dev-kilt.log" 5 | } -------------------------------------------------------------------------------- /tests/test_data/gold2.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "output": [{"answer": "test1", "provenance": [{"wikipedia_id":1,"section":1,"start_paragraph_id":"1","end_paragraph_id":1},{"wikipedia_id":"2","section":2,"start_paragraph_id":2,"end_paragraph_id":2}]},{"answer": "test2", "provenance": [{"wikipedia_id":5,"section":5,"start_paragraph_id":"5","end_paragraph_id":5}]},{"answer": "test3", "provenance": [{"wikipedia_id":7,"section":7,"start_paragraph_id":7,"end_paragraph_id":7},{"wikipedia_id":9,"section":9,"start_paragraph_id":"9","end_paragraph_id":9},{"wikipedia_id":11,"section":11,"start_paragraph_id":11,"end_paragraph_id":11}]}]} -------------------------------------------------------------------------------- /kilt/configs/retriever/blink_biencoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "test_entities": null, 3 | "test_mentions": null, 4 | "interactive": false, 5 | "biencoder_model": "models/biencoder_wiki_large.bin", 6 | "biencoder_config": "models/biencoder_wiki_large.json", 7 | "entity_catalogue": "models/entity.jsonl", 8 | "entity_encoding": "models/all_entities_large.t7", 9 | "crossencoder_model": "models/crossencoder_wiki_large.bin", 10 | "crossencoder_config": "models/crossencoder_wiki_large.json", 11 | "wikipedia_title2id": "models/Wikipedia_title2id.p", 12 | "fast": true, 13 | "output_path": "logs/" 14 | } -------------------------------------------------------------------------------- /kilt/configs/retriever/default_dpr.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_file": "models/dpr_multi_set_hf_bert.0", 3 | "encoded_ctx_file": "models/kilt_passages_2048_0.pkl", 4 | "encoder_model_type": null, 5 | "pretrained_model_cfg": null, 6 | "projection_dim": 0, 7 | "device": "cuda", 8 | "n_gpu": 2, 9 | "local_rank": -1, 10 | "fp16": false, 11 | "index_buffer": 50000, 12 | "hnsw_index": false, 13 | "batch_size": 512, 14 | "qa_file": "", 15 | "out_file": "output/info.log", 16 | "n_docs": 50, 17 | "ctx_file": "models/kilt_w100_title.tsv", 18 | "KILT_mapping": "models/mapping_KILT_title.p", 19 | "hnsw_index_path": null 20 | } -------------------------------------------------------------------------------- /tests/test_data/guess2_1.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "output": [{"provenance": [{"wikipedia_id":2,"section":2,"start_paragraph_id":2,"end_paragraph_id":2},{"wikipedia_id":5,"section":5,"start_paragraph_id":5,"end_paragraph_id":5},{"wikipedia_id":"7","section":7,"start_paragraph_id":"7","end_paragraph_id":7},{"wikipedia_id":33,"section":33,"start_paragraph_id":33,"end_paragraph_id":"33"},{"wikipedia_id":1,"section":"1","start_paragraph_id":1,"end_paragraph_id":1},{"wikipedia_id":9,"section":9,"start_paragraph_id":9,"end_paragraph_id":"9"},{"wikipedia_id":44,"section":44,"start_paragraph_id":44,"end_paragraph_id":44},{"wikipedia_id":"88","section":88,"start_paragraph_id":88,"end_paragraph_id":88}]}]} -------------------------------------------------------------------------------- /tests/test_data/guess3_1.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "output": [{"answer": "the #### transcript is a written version of each day 's cnn student news program use this transcript to he lp students with reading comprehension and vocabulary use the weekly newsquiz to test your knowledge of storie s you saw on cnn student news", "provenance": [{"wikipedia_id": 1}]}]} 2 | {"id": 2, "output": [{"answer": "the #### transcript is a written version of each day 's cnn student news program use this transcript to he lp students with reading comprehension and vocabulary use the weekly newsquiz to test your knowledge of storie s you saw on cnn student news", "provenance": [{"wikipedia_id": "2"}]}]} 3 | -------------------------------------------------------------------------------- /kilt/configs/retriever/default_blink.json: -------------------------------------------------------------------------------- 1 | { 2 | "test_entities": null, 3 | "test_mentions": null, 4 | "interactive": false, 5 | "biencoder_model": "models/biencoder_wiki_large.bin", 6 | "biencoder_config": "models/biencoder_wiki_large.json", 7 | "entity_catalogue": "models/entity.jsonl", 8 | "entity_encoding": "models/all_entities_large.t7", 9 | "crossencoder_model": "models/crossencoder_wiki_large.bin", 10 | "crossencoder_config": "models/crossencoder_wiki_large.json", 11 | "wikipedia_title2id": "models/Wikipedia_title2id.p", 12 | "fast": false, 13 | "output_path": "logs/", 14 | "top_k": 100, 15 | "faiss_index": null, 16 | "index_path": null 17 | } -------------------------------------------------------------------------------- /kilt/configs/dev_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "Fact Checking": { 3 | "FEVER": "data/fever-dev-kilt.jsonl" 4 | }, 5 | "Entity Linking": { 6 | "AIDA-YAGO2": "data/aidayago2-dev-kilt.jsonl", 7 | "WNED": "data/wned-dev-kilt.jsonl", 8 | "CWEB": "data/cweb-dev-kilt.jsonl" 9 | }, 10 | "Slot Filling": { 11 | "Zero Shot RE": "data/structured_zeroshot-dev-kilt.jsonl", 12 | "T-REx": "data/trex-dev-kilt.jsonl" 13 | }, 14 | "Open Domain QA": { 15 | "Natural Questions": "data/nq-dev-kilt.jsonl", 16 | "HotpotQA": "data/hotpotqa-dev-kilt.jsonl", 17 | "TriviaQA": "data/triviaqa-dev-kilt.jsonl", 18 | "ELI5": "data/eli5-dev-kilt.jsonl" 19 | }, 20 | "Dialogue": { 21 | "Wizard of Wikipedia": "data/wow-dev-kilt.jsonl" 22 | } 23 | } -------------------------------------------------------------------------------- /kilt/configs/train_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "Entity Linking": { 3 | "AIDA-YAGO2": "data/aidayago2-train-kilt.jsonl", 4 | "WNED": "data/wned-train-kilt.jsonl", 5 | "CWEB": "data/cweb-train-kilt.jsonl" 6 | }, 7 | "Fact Checking": { 8 | "FEVER": "data/fever-train-kilt.jsonl" 9 | }, 10 | "Slot Filling": { 11 | "Zero Shot RE": "data/structured_zeroshot-train-kilt.jsonl", 12 | "T-REx": "data/trex-train-kilt.jsonl" 13 | }, 14 | "Open Domain QA": { 15 | "Natural Questions": "data/nq-train-kilt.jsonl", 16 | "HotpotQA": "data/hotpotqa-train-kilt.jsonl", 17 | "TriviaQA": "data/triviaqa-train-kilt.jsonl", 18 | "ELI5": "data/eli5-train-kilt.jsonl" 19 | }, 20 | "Dialogue": { 21 | "Wizard of Wikipedia": "data/wow-train-kilt.jsonl" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/map_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from kilt import dataset_mapper 8 | from kilt.datasets import ( 9 | base_dataset, 10 | entity_linking, 11 | fact_verification, 12 | natural_questions, 13 | zero_shot_re, 14 | hotpotqa, 15 | wizard, 16 | ) 17 | 18 | 19 | if __name__ == "__main__": 20 | datasets = [] 21 | 22 | # NQ dev set 23 | datasets.append( 24 | natural_questions.NaturalQuestionsDataset.from_config_file( 25 | "dev_natural_questions", "kilt/configs/mapping/dev_natural_questions.json" 26 | ) 27 | ) 28 | 29 | for dataset in datasets: 30 | dataset_mapper.map_dataset(dataset=dataset) 31 | -------------------------------------------------------------------------------- /tests/test_data/gold3.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "output": [{"answer": "this page includes the show transcript use the transcript to help students with reading comprehension and vocabulary at the bottom of the page , comment for a chance to be mentioned on cnn student news . you must be a teac her or a student age # # or older to request a mention on the cnn student news roll call . the weekly newsquiz tests students ' knowledge of even ts in the news", "provenance": [{"wikipedia_id": "1"}]}]} 2 | {"id": 2, "output": [{"answer": "this page includes the show transcript use the transcript to help students with reading comprehension and vocabulary at the bottom of the page , comment for a chance to be mentioned on cnn student news . you must be a teac her or a student age # # or older to request a mention on the cnn student news roll call . the weekly newsquiz tests students ' knowledge of even ts in the news", "provenance": [{"wikipedia_id": "1"}]}]} 3 | -------------------------------------------------------------------------------- /kilt/configs/test_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "Fact Checking": { 3 | "FEVER": "data/fever-test_without_answers-kilt.jsonl" 4 | }, 5 | "Entity Linking": { 6 | "AIDA-YAGO2": "data/aidayago2-test_without_answers-kilt.jsonl", 7 | "WNED": "data/wned-test_without_answers-kilt.jsonl", 8 | "CWEB": "data/cweb-test_without_answers-kilt.jsonl" 9 | }, 10 | "Slot Filling": { 11 | "Zero Shot RE": "data/structured_zeroshot-test_without_answers-kilt.jsonl", 12 | "T-REx": "data/trex-test_without_answers-kilt.jsonl" 13 | }, 14 | "Open Domain QA": { 15 | "Natural Questions": "data/nq-test_without_answers-kilt.jsonl", 16 | "HotpotQA": "data/hotpotqa-test_without_answers-kilt.jsonl", 17 | "TriviaQA": "data/triviaqa-test_without_answers-kilt.jsonl", 18 | "ELI5": "data/eli5-test_without_answers-kilt.jsonl" 19 | }, 20 | "Dialogue": { 21 | "Wizard of Wikipedia": "data/wow-test_without_answers-kilt.jsonl" 22 | } 23 | } -------------------------------------------------------------------------------- /tests/test_data/guess3_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": 1, 3 | "output": [ 4 | { 5 | "answer": "the #### transcript is a written version of each day 's cnn student news program use this transcript to he lp students with reading comprehension and vocabulary use the weekly newsquiz to test your knowledge of storie s you saw on cnn student news", 6 | "provenance": [ 7 | { 8 | "wikipedia_id": 1 9 | } 10 | ] 11 | } 12 | ] 13 | } 14 | { 15 | "id": 2, 16 | "output": [ 17 | { 18 | "answer": "the #### transcript is a written version of each day 's cnn student news program use this transcript to he lp students with reading comprehension and vocabulary use the weekly newsquiz to test your knowledge of storie s you saw on cnn student news", 19 | "provenance": [ 20 | { 21 | "wikipedia_id": "2" 22 | } 23 | ] 24 | } 25 | ] 26 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import setuptools 9 | 10 | with open("README.md", "r") as fh: 11 | long_description = fh.read() 12 | 13 | setuptools.setup( 14 | name="kilt", 15 | version="0.1.0", 16 | description="Knowledge Intensive Language Tasks", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | packages=setuptools.find_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ], 25 | python_requires=">=3.7", 26 | install_requires=[ 27 | "bs4", 28 | "flair", 29 | "jsonlines", 30 | "nltk", 31 | "prettytable", 32 | "pymongo", 33 | "pytest", 34 | "rouge", 35 | "spacy>=2.1.8", 36 | "torch", 37 | "tqdm", 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_data/guess1_1.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]} 2 | {"id": 2, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "a suicide", "provenance": [{"wikipedia_id": "999", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]} 3 | {"id": 3, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "999", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]} 4 | -------------------------------------------------------------------------------- /kilt/readers/fid/README.md: -------------------------------------------------------------------------------- 1 | # Fusion-in-Decoder 2 | 3 | ### Install Fusion-in-Decoder 4 | 5 | `pip install -e git+https://github.com/facebookresearch/FiD#egg=FiD` 6 | 7 | ### Convert KILT data format to FiD format 8 | 9 | ```shell 10 | python preprocess.py input_data.jsonl outputpath 11 | ``` 12 | 13 | ### Train FiD 14 | 15 | ```shell 16 | python src/fid/train_reader.py \ 17 | --use_checkpoint \ 18 | --train_data train_data.json \ 19 | --eval_data eval_data.json \ 20 | --model_size base \ 21 | --per_gpu_batch_size 1 \ 22 | --n_context 100 \ 23 | --name my_experiment \ 24 | --checkpoint_dir checkpoint \ 25 | ``` 26 | 27 | ### Eval FiD 28 | 29 | ```shell 30 | python src/fid/test_reader.py \ 31 | --model_path checkpoint/my_experiment/checkpoint/best_dev \ 32 | --eval_data eval_data.json \ 33 | --per_gpu_batch_size 1 \ 34 | --n_context 100 \ 35 | --name my_test \ 36 | --checkpoint_dir checkpoint \ 37 | ``` 38 | 39 | ### Convert to KILT format for eval 40 | 41 | ```shell 42 | python postprocess.py checkpoint/my_test/final_output.json my_kilt_output.jsonl initial_input_data.jsonl 43 | ``` 44 | 45 | -------------------------------------------------------------------------------- /tests/test_data/gold1.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 1, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]} 2 | {"id": 2, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]} 3 | {"id": 3, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "it was suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | original_data 3 | logs 4 | output 5 | src 6 | predictions 7 | models 8 | NHKB/ 9 | kilt_internal/ 10 | checkpoint/ 11 | fid_data/ 12 | 13 | text_utils.py 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | .vscode/ 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *,cover 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # IPython checkpoints 75 | .ipynb_checkpoints 76 | notebook/.ipynb_checkpoints 77 | 78 | # Mac os x stuff 79 | .DS_Store -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to KILT 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to KILT, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /kilt/readers/fid/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | import sys 10 | from tqdm.auto import tqdm 11 | 12 | 13 | def convert_kilt(inputpath, outputpath): 14 | data = [] 15 | inputdata = open(inputpath, "r") 16 | for example in tqdm(inputdata): 17 | d = {} 18 | ex = json.loads(example) 19 | d["question"] = ex["input"] 20 | answers = set() 21 | for a in ex["output"]: 22 | if "answer" in a: 23 | answers.add(a["answer"]) 24 | d["answers"] = list(answers) 25 | d["id"] = ex["id"] 26 | passages = [] 27 | for c in ex["output"][0]["provenance"]: 28 | p = {"text": c["text"], "title": ""} 29 | if "wikipedia_title" in c: 30 | p["title"] = c["wikipedia_title"] 31 | if "wikipedia_id" in c: 32 | p["wikipedia_id"] = c["wikipedia_id"] 33 | passages.append(p) 34 | d["ctxs"] = passages 35 | data.append(d) 36 | with open(outputpath, "w") as fout: 37 | json.dump(data, fout) 38 | 39 | 40 | if __name__ == "__main__": 41 | inputpath = sys.argv[1] 42 | outputpath = sys.argv[2] 43 | convert_kilt(inputpath, outputpath) 44 | -------------------------------------------------------------------------------- /kilt/readers/fid/postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | 9 | import sys 10 | import json 11 | 12 | def convert_to_kilt(inputpath, outputpath, datapath): 13 | data = [] 14 | with open(datapath, 'r') as fin: 15 | for k, example in enumerate(fin): 16 | example = json.loads(example) 17 | data.append(example) 18 | datadict = {ex['id']:ex for ex in data} 19 | outfile = open(outputpath, 'w') 20 | with open(inputpath, 'r') as f: 21 | lines = f.readlines() 22 | for line in lines: 23 | d = {} 24 | try: 25 | id, answer = line.split('\t') 26 | except ValueError: 27 | print('error') 28 | id = int(id) 29 | answer = answer.split('\n')[0] 30 | if id in d: 31 | print('key already in dict', d[id], answer) 32 | d['id'] = id 33 | 34 | d['output'] = [{'answer': answer}] 35 | 36 | json.dump(d, outfile) 37 | outfile.write('\n') 38 | 39 | 40 | if __name__ == '__main__': 41 | inputpath = sys.argv[1] 42 | outputpath = sys.argv[2] 43 | datapath = sys.argv[3] 44 | convert_to_kilt(inputpath, outputpath, datapath) 45 | -------------------------------------------------------------------------------- /kilt/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import importlib.resources 8 | import json 9 | 10 | from abc import ABC, abstractmethod 11 | 12 | from kilt.configs import mapping 13 | 14 | 15 | class Dataset(ABC): 16 | def __init__(self, name): 17 | self.name = name 18 | self.output_file = None 19 | self.max_chunks = None 20 | 21 | @classmethod 22 | def from_default_config(cls, name): 23 | config = json.loads( 24 | importlib.resources.read_text( 25 | mapping, "default_{name}.json".format(name=name) 26 | ) 27 | ) 28 | return cls(name, **config) 29 | 30 | @classmethod 31 | def from_config_file(cls, name, config_file): 32 | with open(config_file, "r") as cf: 33 | config = json.load(cf) 34 | return cls(name, **config) 35 | 36 | @classmethod 37 | def from_config_string(cls, name, config_string): 38 | config = json.loads(config_string) 39 | return cls(name, **config) 40 | 41 | def get_chunks(self, num_chunks): 42 | """ 43 | Retruns a list of chunks of the dataset. 44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def process_chunk(self, chunk, ks, chunk_id): 49 | """ 50 | Processes a single chunk of the dataset. Maps each line in the 51 | chunk into the kilt format. Returns a list of mapped entries and 52 | optionally metadata. 53 | """ 54 | pass 55 | 56 | @abstractmethod 57 | def postprocess_metadata(self, metadata): 58 | pass 59 | -------------------------------------------------------------------------------- /kilt/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | import sys 10 | import multiprocessing 11 | 12 | from multiprocessing.pool import ThreadPool 13 | 14 | from kilt.knowledge_source import KnowledgeSource 15 | 16 | 17 | def run_thread(args): 18 | dataset = args["dataset"] 19 | return dataset.process_chunk(args["chunk"], args["ks"], args["id"]) 20 | 21 | 22 | def map_dataset(dataset): 23 | print("Processing {} dataset.".format(dataset.name)) 24 | ks = KnowledgeSource() 25 | 26 | num_threads = ( 27 | min(dataset.max_chunks, int(multiprocessing.cpu_count())) 28 | if dataset.max_chunks and dataset.max_chunks > 0 29 | else int(multiprocessing.cpu_count()) 30 | ) 31 | print("num_threads", num_threads) 32 | pool = ThreadPool(num_threads) 33 | chunks = dataset.get_chunks(num_threads) 34 | results = pool.map( 35 | run_thread, 36 | [ 37 | {"id": id, "chunk": chunk, "ks": ks, "dataset": dataset} 38 | for id, chunk in enumerate(chunks) 39 | ], 40 | ) 41 | 42 | kilt_data = [] 43 | metadata = [] 44 | for x in results: 45 | kd, meta = x 46 | kilt_data.extend(kd) 47 | metadata.append(meta) 48 | 49 | pool.terminate() 50 | pool.join() 51 | 52 | dataset.postprocess_metadata(metadata) 53 | 54 | with open(dataset.output_file, "w+") as outfile: 55 | for idx, data in enumerate(kilt_data): 56 | print(round(idx * 100 / len(kilt_data), 2), "%", end="\r") 57 | sys.stdout.flush() 58 | json.dump(data, outfile) 59 | outfile.write("\n") 60 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | 3 | ## execute retrieval 4 | 5 | This script executes retrieval from a given model and store retrieved augmented KILT files 6 | 7 | options: 8 | * __--model_name/-m__ : retriever model name in {drqa,solr,dpr,blink,bm25} 9 | * __--model_configuration/-c__ : model configuration 10 | * __--output_folder/-o__ : output folder 11 | 12 | 13 | example 14 | ``` 15 | python scripts/execute_retrieval.py -m dpr -c kilt/configs/retriever/default_dpr.json -o predictions/dpr/ 16 | ``` 17 | 18 | To setup the retrievers see [kilt/readers/README.md](this README). 19 | 20 | 21 | ## create kilt data paragraphs 22 | The script `scripts/create_kilt_data_paragraphs.py` create chunks of all wikipedia in the format: 23 | ``` 24 | {'_id': str, 25 | 'wikipedia_id': str, 26 | 'wikipedia_title': str, 27 | 'text': str, 28 | 'anchors': [{'text': str, 29 | 'href': str, 30 | 'source': {'paragraph_id': int, 'start': int, 'end': int}, 31 | 'start': int, 32 | 'end': int}, ...], 33 | 'categories': str, 34 | 'history': {'revid': int, 35 | 'timestamp': str, 36 | 'parentid': int, 37 | 'pre_dump': bool, 38 | 'pageid': int, 39 | 'url': str}, 40 | 'sources': [{'paragraph_id': int, 'start': 0, 'end': int}, ...], 41 | 'section': str} 42 | ``` 43 | It creates a `jsonl` file(s) where for each line there is a consecutive number (ID) and a `json` dictionary. 44 | 45 | The script can launch 3 invididual steps that has to be run in order. Here an example. First preprocess uses `threads` to split the Knowledge Bases in even parts and it saves them into `folder`. 46 | ```bash 47 | python create_kilt_data_paragraphs \ 48 | --step preprocess \ 49 | --folder "./kilt_data" \ 50 | --threads 32 51 | ``` 52 | 53 | Then, the following creates chunks of size `chunk_size`. `rank`is the id of the portion of the dataset to compute. 54 | ```bash 55 | python create_kilt_data_paragraphs \ 56 | --step main \ 57 | --chunk_size 100 58 | --folder "./kilt_data" \ 59 | --rank 60 | ``` 61 | 62 | Finally, we can merge all files with 63 | ```bash 64 | python create_kilt_data_paragraphs \ 65 | --step merge \ 66 | --folder "./kilt_data" \ 67 | --threads 32 68 | ``` 69 | -------------------------------------------------------------------------------- /kilt/retrievers/README.md: -------------------------------------------------------------------------------- 1 | # DrQA tf-idf 2 | 3 | ## install 4 | ```bash 5 | pip install -e git+https://github.com/facebookresearch/DrQA#egg=DrQA 6 | pip install pexpect==4.8 7 | ``` 8 | 9 | ## download models 10 | 11 | Download the following files in the `models` folder. 12 | 13 | - [kilt_db_simple.npz](http://dl.fbaipublicfiles.com/KILT/kilt_db_simple.npz) 14 | 15 | ## run 16 | ```bash 17 | python scripts/execute_retrieval.py -m drqa -o predictions/drqa 18 | ``` 19 | 20 | # DPR 21 | 22 | ## install 23 | ```bash 24 | pip install -e git+https://github.com/facebookresearch/DPR.git#egg=DPR 25 | ``` 26 | 27 | ## download models 28 | 29 | Download the following files in the `models` folder. 30 | 31 | - [dpr_multi_set_hf_bert.0](http://dl.fbaipublicfiles.com/KILT/dpr_multi_set_hf_bert.0) 32 | - [kilt_passages_2048_0.pkl](http://dl.fbaipublicfiles.com/KILT/kilt_passages_2048_0.pkl) 33 | - [kilt_w100_title.tsv](http://dl.fbaipublicfiles.com/KILT/kilt_w100_title.tsv) 34 | - [mapping_KILT_title.p](http://dl.fbaipublicfiles.com/KILT/mapping_KILT_title.p) 35 | 36 | ## run 37 | ```bash 38 | python scripts/execute_retrieval.py -m dpr -o predictions/dpr 39 | ``` 40 | 41 | # DPR distributed 42 | 43 | Please follow instructions in the [Sphere](https://github.com/facebookresearch/Sphere) repository. 44 | 45 | # BLINK 46 | 47 | ## install 48 | ```bash 49 | pip install -e git+https://github.com/facebookresearch/BLINK.git#egg=BLINK 50 | pip install flair 51 | ``` 52 | 53 | ## download models 54 | 55 | Download files in the `models` folder using the following script: [download_models.sh](https://github.com/facebookresearch/BLINK/blob/master/download_blink_models.sh) 56 | 57 | And this file: 58 | - [Wikipedia_title2id.p](http://dl.fbaipublicfiles.com/KILT/Wikipedia_title2id.p) 59 | 60 | ## run 61 | ```bash 62 | python scripts/execute_retrieval.py -m blink -o predictions/blink 63 | ``` 64 | 65 | # BM25 66 | Follow instructions in [`pyserini`](https://github.com/castorini/pyserini#installation) to download JAVA. 67 | ## install 68 | ```bash 69 | pip install jnius 70 | pip install pyserini==0.9.4.0 71 | ``` 72 | 73 | ## run 74 | ```bash 75 | python scripts/execute_retrieval.py -m bm25 -o predictions/bm25 76 | ``` 77 | -------------------------------------------------------------------------------- /tests/test_eval_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import unittest 9 | import importlib.resources 10 | 11 | import kilt.eval_downstream 12 | import kilt.eval_retrieval 13 | import tests.test_data as test_data 14 | 15 | 16 | class TestEvalRetrieval(unittest.TestCase): 17 | def test_calculate_metrics(self): 18 | 19 | with importlib.resources.open_text(test_data, "gold2.jsonl") as gold_file: 20 | with importlib.resources.open_text( 21 | test_data, "guess2_1.jsonl" 22 | ) as guess_file: 23 | 24 | for rank_keys in [ 25 | ["wikipedia_id"], 26 | ["wikipedia_id", "section"], 27 | ["wikipedia_id", "start_paragraph_id", "end_paragraph_id"], 28 | ]: 29 | 30 | result = kilt.eval_retrieval.evaluate( 31 | gold_file.name, guess_file.name, ks=[1, 5], rank_keys=rank_keys, 32 | ) 33 | self.assertEqual(result["Rprec"], 1 / 2) 34 | self.assertEqual(result["precision@1"], 1) 35 | self.assertEqual(result["precision@5"], 2 / 5) 36 | self.assertEqual(result["recall@5"], 2 / 3) 37 | self.assertEqual(result["success_rate@5"], 1) 38 | 39 | with importlib.resources.open_text( 40 | test_data, "guess2_2.jsonl" 41 | ) as guess_file: 42 | 43 | for rank_keys in [ 44 | ["wikipedia_id"], 45 | ["wikipedia_id", "section"], 46 | ["wikipedia_id", "start_paragraph_id", "end_paragraph_id"], 47 | ]: 48 | 49 | result = kilt.eval_retrieval.evaluate( 50 | gold_file.name, guess_file.name, ks=[1, 5], rank_keys=rank_keys, 51 | ) 52 | self.assertEqual(result["Rprec"], 1) 53 | self.assertEqual(result["precision@1"], 1) 54 | self.assertEqual(result["precision@5"], 1 / 5) 55 | self.assertEqual(result["recall@5"], 1 / 3) 56 | self.assertEqual(result["success_rate@5"], 1) 57 | 58 | -------------------------------------------------------------------------------- /kilt/configs/all_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "Fact Checking Test": { 3 | "FEVER": "data/fever-test_without_answers-kilt.jsonl" 4 | }, 5 | "Entity Linking Test": { 6 | "AIDA-YAGO2": "data/aidayago2-test_without_answers-kilt.jsonl", 7 | "WNED": "data/wned-test_without_answers-kilt.jsonl", 8 | "CWEB": "data/cweb-test_without_answers-kilt.jsonl" 9 | }, 10 | "Slot Filling Test": { 11 | "Zero Shot RE": "data/structured_zeroshot-test_without_answers-kilt.jsonl", 12 | "T-REx": "data/trex-test_without_answers-kilt.jsonl" 13 | }, 14 | "Open Domain QA Test": { 15 | "Natural Questions": "data/nq-test_without_answers-kilt.jsonl", 16 | "HotpotQA": "data/hotpotqa-test_without_answers-kilt.jsonl", 17 | "TriviaQA": "data/triviaqa-test_without_answers-kilt.jsonl", 18 | "ELI5": "data/eli5-test_without_answers-kilt.jsonl" 19 | }, 20 | "Dialogue Test Dev": { 21 | "Wizard of Wikipedia": "data/wow-test_without_answers-kilt.jsonl" 22 | }, 23 | "Fact Checking Dev": { 24 | "FEVER": "data/fever-dev-kilt.jsonl" 25 | }, 26 | "Entity Linking Dev": { 27 | "AIDA-YAGO2": "data/aidayago2-dev-kilt.jsonl", 28 | "WNED": "data/wned-dev-kilt.jsonl", 29 | "CWEB": "data/cweb-dev-kilt.jsonl" 30 | }, 31 | "Slot Filling Dev": { 32 | "Zero Shot RE": "data/structured_zeroshot-dev-kilt.jsonl", 33 | "T-REx": "data/trex-dev-kilt.jsonl" 34 | }, 35 | "Open Domain QA Dev": { 36 | "Natural Questions": "data/nq-dev-kilt.jsonl", 37 | "HotpotQA": "data/hotpotqa-dev-kilt.jsonl", 38 | "TriviaQA": "data/triviaqa-dev-kilt.jsonl", 39 | "ELI5": "data/eli5-dev-kilt.jsonl" 40 | }, 41 | "Dialogue Dev Train": { 42 | "Wizard of Wikipedia": "data/wow-dev-kilt.jsonl" 43 | }, 44 | "Entity Linking Train": { 45 | "AIDA-YAGO2": "data/aidayago2-train-kilt.jsonl", 46 | "WNED": "data/wned-train-kilt.jsonl", 47 | "CWEB": "data/cweb-train-kilt.jsonl" 48 | }, 49 | "Fact Checking Train": { 50 | "FEVER": "data/fever-train-kilt.jsonl" 51 | }, 52 | "Slot Filling Train": { 53 | "Zero Shot RE": "data/structured_zeroshot-train-kilt.jsonl", 54 | "T-REx": "data/trex-train-kilt.jsonl" 55 | }, 56 | "Open Domain QA Train": { 57 | "Natural Questions": "data/nq-train-kilt.jsonl", 58 | "HotpotQA": "data/hotpotqa-train-kilt.jsonl", 59 | "TriviaQA": "data/triviaqa-train-kilt.jsonl", 60 | "ELI5": "data/eli5-train-kilt.jsonl" 61 | }, 62 | "Dialogue Train": { 63 | "Wizard of Wikipedia": "data/wow-train-kilt.jsonl" 64 | } 65 | } -------------------------------------------------------------------------------- /kilt/datasets/hotpotqa_ks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import multiprocessing 8 | from multiprocessing.pool import ThreadPool 9 | import os 10 | from kilt.kilt_utils import chunk_it 11 | import bz2 12 | import json 13 | 14 | STEPS = 10 15 | 16 | 17 | def run_thread(arguments): 18 | thread_id = arguments["id"] 19 | filenames = arguments["filenames"] 20 | verbose = arguments["verbose"] 21 | 22 | output_dict = {} 23 | 24 | steps = int(len(filenames) / STEPS) 25 | 26 | for file_id, filename in enumerate(filenames): 27 | 28 | if verbose: 29 | try: 30 | if file_id % steps == 0: 31 | percentage = file_id * 100 / len(filenames) 32 | print( 33 | "t{} [{}/{}] {:.2f}%".format( 34 | thread_id, file_id, len(filenames), percentage 35 | ), 36 | flush=True, 37 | ) 38 | except: 39 | pass 40 | with bz2.open( 41 | filename, 42 | mode="r", 43 | compresslevel=9, 44 | encoding=None, 45 | errors=None, 46 | newline=None, 47 | ) as f: 48 | for line in f: 49 | data = json.loads(line) 50 | output_dict[data["title"]] = data 51 | 52 | return output_dict 53 | 54 | 55 | def load_ks(ks_directory, verbose=False): 56 | NUM_TREADS = int(multiprocessing.cpu_count()) 57 | 58 | if verbose: 59 | print(f"loading hotpotqa knowledge source with {NUM_TREADS} threads") 60 | pool = ThreadPool(NUM_TREADS) 61 | 62 | filenames = [] 63 | directories = [ 64 | os.path.join(ks_directory, o) 65 | for o in os.listdir(ks_directory) 66 | if os.path.isdir(os.path.join(ks_directory, o)) 67 | ] 68 | for directory in directories: 69 | onlyfiles = [ 70 | f 71 | for f in os.listdir(directory) 72 | if os.path.isfile(os.path.join(directory, f)) 73 | ] 74 | for filetto in onlyfiles: 75 | filename = "{}/{}".format(directory, filetto) 76 | filenames.append(filename) 77 | 78 | arguments = [ 79 | {"id": i, "filenames": chunk, "verbose": verbose} 80 | for i, chunk in enumerate(chunk_it(filenames, NUM_TREADS)) 81 | ] 82 | 83 | results = pool.map(run_thread, arguments) 84 | output_dict = {} 85 | for x in results: 86 | output_dict.update(x) 87 | pool.terminate() 88 | pool.join() 89 | 90 | return output_dict 91 | -------------------------------------------------------------------------------- /kilt/readers/t5/README.md: -------------------------------------------------------------------------------- 1 | # T5 2 | 3 | ### aidayago2 4 | ``` 5 | python ../finetune.py \ 6 | --data_dir=${DATA_DIR} \ 7 | --dataset=${DATASET} \ 8 | --model_name_or_path=t5-base \ 9 | --learning_rate=1e-3 \ 10 | --num_train_epoch=1000 \ 11 | --output_dir=$OUTPUT_DIR \ 12 | --n_gpu=8 \ 13 | --do_train 14 | ``` 15 | 16 | ### eli5 17 | ``` 18 | python ../finetune.py \ 19 | --data_dir=${DATA_DIR} \ 20 | --dataset=${DATASET} \ 21 | --model_name_or_path=t5-base \ 22 | --learning_rate=1e-3 \ 23 | --num_train_epoch=500 \ 24 | --output_dir=$OUTPUT_DIR \ 25 | --n_gpu=8 \ 26 | --do_train 27 | ``` 28 | 29 | ### FEVER 30 | ``` 31 | python ../finetune.py \ 32 | --data_dir=${DATA_DIR} \ 33 | --dataset=${DATASET} \ 34 | --model_name_or_path=t5-base \ 35 | --learning_rate=1e-3 \ 36 | --num_train_epoch=1000 \ 37 | --output_dir=$OUTPUT_DIR \ 38 | --n_gpu=4 \ 39 | --do_train 40 | ``` 41 | 42 | ### HotpotQA 43 | ``` 44 | python ../finetune.py \ 45 | --data_dir=${DATA_DIR} \ 46 | --dataset=${DATASET} \ 47 | --model_name_or_path=t5-base \ 48 | --learning_rate=1e-3 \ 49 | --num_train_epoch=1000 \ 50 | --output_dir=$OUTPUT_DIR \ 51 | --n_gpu=4 \ 52 | --do_train 53 | ``` 54 | 55 | ### Natural Questions 56 | ``` 57 | python ../finetune.py \ 58 | --data_dir=${DATA_DIR} \ 59 | --dataset=${DATASET} \ 60 | --model_name_or_path=t5-base \ 61 | --learning_rate=1e-3 \ 62 | --num_train_epoch=500 \ 63 | --output_dir=$OUTPUT_DIR \ 64 | --n_gpu=4 \ 65 | --do_train 66 | ``` 67 | 68 | ### T-REx 69 | ``` 70 | python ../finetune.py \ 71 | --data_dir=${DATA_DIR} \ 72 | --dataset=${DATASET} \ 73 | --model_name_or_path=t5-base \ 74 | --learning_rate=1e-3 \ 75 | --num_train_epoch=500 \ 76 | --output_dir=$OUTPUT_DIR \ 77 | --n_gpu=4 \ 78 | --do_train 79 | ``` 80 | 81 | ### TriviaQA 82 | ``` 83 | python ../finetune.py \ 84 | --data_dir=${DATA_DIR} \ 85 | --dataset=${DATASET} \ 86 | --model_name_or_path=t5-base \ 87 | --learning_rate=1e-3 \ 88 | --num_train_epoch=2100 \ 89 | --output_dir=$OUTPUT_DIR \ 90 | --n_gpu=4 \ 91 | --do_train 92 | ``` 93 | 94 | ### Wizard of Wikipedia 95 | ``` 96 | python ../finetune.py \ 97 | --data_dir=${DATA_DIR} \ 98 | --dataset=${DATASET} \ 99 | --model_name_or_path=t5-base \ 100 | --learning_rate=1e-3 \ 101 | --num_train_epoch=1000 \ 102 | --output_dir=$OUTPUT_DIR \ 103 | --n_gpu=8 \ 104 | --do_train 105 | ``` 106 | 107 | ### Zeroshot RE 108 | ``` 109 | python ../finetune.py \ 110 | --data_dir=${DATA_DIR} \ 111 | --dataset=${DATASET} \ 112 | --model_name_or_path=t5-base \ 113 | --learning_rate=1e-3 \ 114 | --num_train_epoch=1000 \ 115 | --n_gpu=4 \ 116 | --output_dir=$OUTPUT_DIR \ 117 | --do_train 118 | ``` -------------------------------------------------------------------------------- /scripts/download_all_kilt_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import requests 9 | from tqdm.auto import tqdm 10 | 11 | urls = [ 12 | "http://dl.fbaipublicfiles.com/KILT/fever-train-kilt.jsonl", 13 | "http://dl.fbaipublicfiles.com/KILT/fever-dev-kilt.jsonl", 14 | "http://dl.fbaipublicfiles.com/KILT/fever-test_without_answers-kilt.jsonl", 15 | "http://dl.fbaipublicfiles.com/KILT/nq-train-kilt.jsonl", 16 | "http://dl.fbaipublicfiles.com/KILT/nq-dev-kilt.jsonl", 17 | "http://dl.fbaipublicfiles.com/KILT/nq-test_without_answers-kilt.jsonl", 18 | "http://dl.fbaipublicfiles.com/KILT/hotpotqa-train-kilt.jsonl", 19 | "http://dl.fbaipublicfiles.com/KILT/hotpotqa-dev-kilt.jsonl", 20 | "http://dl.fbaipublicfiles.com/KILT/hotpotqa-test_without_answers-kilt.jsonl", 21 | "http://dl.fbaipublicfiles.com/KILT/triviaqa-train_id-kilt.jsonl", 22 | "http://dl.fbaipublicfiles.com/KILT/triviaqa-dev_id-kilt.jsonl", 23 | "http://dl.fbaipublicfiles.com/KILT/triviaqa-test_id_without_answers-kilt.jsonl", 24 | "http://dl.fbaipublicfiles.com/KILT/eli5-train-kilt.jsonl", 25 | "http://dl.fbaipublicfiles.com/KILT/eli5-dev-kilt.jsonl", 26 | "http://dl.fbaipublicfiles.com/KILT/eli5-test_without_answers-kilt.jsonl", 27 | "http://dl.fbaipublicfiles.com/KILT/trex-train-kilt.jsonl", 28 | "http://dl.fbaipublicfiles.com/KILT/trex-dev-kilt.jsonl", 29 | "http://dl.fbaipublicfiles.com/KILT/trex-test_without_answers-kilt.jsonl", 30 | "http://dl.fbaipublicfiles.com/KILT/structured_zeroshot-train-kilt.jsonl", 31 | "http://dl.fbaipublicfiles.com/KILT/structured_zeroshot-dev-kilt.jsonl", 32 | "http://dl.fbaipublicfiles.com/KILT/structured_zeroshot-test_without_answers-kilt.jsonl", 33 | "http://dl.fbaipublicfiles.com/KILT/aidayago2-train-kilt.jsonl", 34 | "http://dl.fbaipublicfiles.com/KILT/aidayago2-dev-kilt.jsonl", 35 | "http://dl.fbaipublicfiles.com/KILT/aidayago2-test_without_answers-kilt.jsonl", 36 | "http://dl.fbaipublicfiles.com/KILT/wned-dev-kilt.jsonl", 37 | "http://dl.fbaipublicfiles.com/KILT/wned-test_without_answers-kilt.jsonl", 38 | "http://dl.fbaipublicfiles.com/KILT/cweb-dev-kilt.jsonl", 39 | "http://dl.fbaipublicfiles.com/KILT/cweb-test_without_answers-kilt.jsonl", 40 | "http://dl.fbaipublicfiles.com/KILT/wow-train-kilt.jsonl", 41 | "http://dl.fbaipublicfiles.com/KILT/wow-dev-kilt.jsonl", 42 | "http://dl.fbaipublicfiles.com/KILT/wow-test_without_answers-kilt.jsonl", 43 | ] 44 | 45 | 46 | for url in urls: 47 | base = url.split("/")[-1] 48 | filename = f"data/{base}" 49 | r = requests.get(url, stream=True) 50 | # Total size in bytes. 51 | total_size = int(r.headers.get("content-length", 0)) 52 | block_size = 1024 # 1 Kibibyte 53 | t = tqdm(total=total_size, unit="iB", unit_scale=True, desc=base) 54 | with open(filename, "wb") as f: 55 | for data in r.iter_content(block_size): 56 | t.update(len(data)) 57 | f.write(data) 58 | t.close() 59 | -------------------------------------------------------------------------------- /kilt/retrievers/DrQA_tfidf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import multiprocessing 9 | from multiprocessing.pool import ThreadPool 10 | 11 | from tqdm import tqdm 12 | from drqa import retriever 13 | 14 | import kilt.kilt_utils as utils 15 | from kilt.retrievers.base_retriever import Retriever 16 | 17 | 18 | def _get_predictions_thread(arguments): 19 | 20 | id = arguments["id"] 21 | queries_data = arguments["queries_data"] 22 | topk = arguments["topk"] 23 | ranker = arguments["ranker"] 24 | logger = arguments["logger"] 25 | 26 | if id == 0: 27 | iter_ = tqdm(queries_data) 28 | else: 29 | iter_ = queries_data 30 | 31 | result_doc_ids = [] 32 | result_doc_scores = [] 33 | result_query_id = [] 34 | 35 | for query_element in iter_: 36 | 37 | query = ( 38 | query_element["query"] 39 | .replace(utils.ENT_END, "") 40 | .replace(utils.ENT_START, "") 41 | .strip() 42 | ) 43 | result_query_id.append(query_element["id"]) 44 | 45 | doc_ids = [] 46 | doc_scores = [] 47 | try: 48 | doc_ids, doc_scores = ranker.closest_docs(query, topk) 49 | except RuntimeError as e: 50 | if logger: 51 | logger.warning("RuntimeError: {}".format(e)) 52 | 53 | result_doc_ids.append(doc_ids) 54 | result_doc_scores.append(doc_scores) 55 | 56 | return result_doc_ids, result_doc_scores, result_query_id 57 | 58 | 59 | class DrQA(Retriever): 60 | def __init__(self, name, retriever_model, num_threads): 61 | super().__init__(name) 62 | 63 | self.num_threads = min(num_threads, int(multiprocessing.cpu_count())) 64 | 65 | # initialize a ranker per thread 66 | self.arguments = [] 67 | for id in tqdm(range(self.num_threads)): 68 | self.arguments.append( 69 | { 70 | "id": id, 71 | "ranker": retriever.get_class("tfidf")(tfidf_path=retriever_model), 72 | } 73 | ) 74 | 75 | def feed_data(self, queries_data, logger=None): 76 | 77 | chunked_queries = utils.chunk_it(queries_data, self.num_threads) 78 | 79 | for idx, arg in enumerate(self.arguments): 80 | arg["queries_data"] = chunked_queries[idx] 81 | arg["logger"] = logger 82 | 83 | def run(self): 84 | pool = ThreadPool(self.num_threads) 85 | results = pool.map(_get_predictions_thread, self.arguments) 86 | 87 | provenance = {} 88 | 89 | for x in results: 90 | i, s, q = x 91 | for query_id, doc_ids in zip(q, i): 92 | provenance[query_id] = [] 93 | for d_id in doc_ids: 94 | provenance[query_id].append({"wikipedia_id": str(d_id).strip()}) 95 | 96 | pool.terminate() 97 | pool.join() 98 | 99 | return provenance 100 | -------------------------------------------------------------------------------- /tests/test_eval_downstream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import unittest 9 | import importlib.resources 10 | 11 | import kilt.eval_downstream 12 | import kilt.eval_retrieval 13 | import tests.test_data as test_data 14 | 15 | 16 | class TestEvalDownstream(unittest.TestCase): 17 | def test_calculate_metrics(self): 18 | 19 | with importlib.resources.open_text(test_data, "gold1.jsonl") as gold_file: 20 | with importlib.resources.open_text( 21 | test_data, "guess1_1.jsonl" 22 | ) as guess_file: 23 | result = kilt.eval_downstream.evaluate(gold_file.name, guess_file.name) 24 | 25 | # kilt 26 | self.assertEqual(result["kilt"]["KILT-em"], 1 / 3) 27 | self.assertEqual(result["kilt"]["KILT-f1"], 1 / 3) 28 | self.assertEqual(result["kilt"]["KILT-rougel"], 0.3333333316666667) 29 | 30 | # downsream 31 | self.assertEqual(result["downstream"]["em"], 2 / 3) 32 | self.assertEqual(result["downstream"]["f1"], 0.8333333333333334) 33 | self.assertEqual(result["downstream"]["rougel"], 0.7222222178240741) 34 | 35 | # retrieval page level 36 | self.assertEqual(result["retrieval"]["Rprec"], 1 / 3) 37 | self.assertEqual(result["retrieval"]["recall@5"], 1 / 3) 38 | 39 | with importlib.resources.open_text(test_data, "gold1.jsonl") as guess_file: 40 | result = kilt.eval_downstream.evaluate(gold_file.name, guess_file.name) 41 | 42 | # kilt 43 | self.assertEqual(result["kilt"]["KILT-em"], 1) 44 | self.assertEqual(result["kilt"]["KILT-f1"], 1) 45 | self.assertEqual(result["kilt"]["KILT-rougel"], 0.999999995) 46 | 47 | # downsream 48 | self.assertEqual(result["downstream"]["em"], 1) 49 | self.assertEqual(result["downstream"]["f1"], 1) 50 | self.assertEqual(result["downstream"]["rougel"], 0.999999995) 51 | 52 | # retrieval page level 53 | self.assertEqual(result["retrieval"]["Rprec"], 1) 54 | self.assertEqual(result["retrieval"]["recall@5"], 1) 55 | 56 | with importlib.resources.open_text(test_data, "gold3.jsonl") as gold_file: 57 | with importlib.resources.open_text( 58 | test_data, "guess3_1.jsonl" 59 | ) as guess_file: 60 | 61 | result = kilt.eval_downstream.evaluate(gold_file.name, guess_file.name) 62 | 63 | # kilt 64 | self.assertEqual(result["kilt"]["KILT-em"], 0) 65 | self.assertEqual(result["kilt"]["KILT-f1"], 0.25510204081632654) 66 | self.assertEqual(result["kilt"]["KILT-rougel"], 0.22352940932318338) 67 | 68 | # downsream 69 | self.assertEqual(result["downstream"]["em"], 0) 70 | self.assertEqual(result["downstream"]["f1"], 0.5102040816326531) 71 | self.assertEqual(result["downstream"]["rougel"], 0.44705881864636676) 72 | -------------------------------------------------------------------------------- /kilt/retrievers/base_retriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | from abc import ABC, abstractmethod 10 | from kilt.configs import retriever 11 | 12 | 13 | class Retriever(ABC): 14 | def __init__(self, name): 15 | self.name = name 16 | 17 | @classmethod 18 | def from_default_config(cls, name): 19 | import importlib.resources 20 | 21 | config = json.loads( 22 | importlib.resources.read_text( 23 | retriever, "default_{name}.json".format(name=name) 24 | ) 25 | ) 26 | return cls(name, **config) 27 | 28 | @classmethod 29 | def from_config_file(cls, name, config_file): 30 | with open(config_file, "r") as cf: 31 | config = json.load(cf) 32 | return cls(name, **config) 33 | 34 | @classmethod 35 | def from_config_string(cls, name, config_string): 36 | config = json.loads(config_string) 37 | return cls(name, **config) 38 | 39 | @abstractmethod 40 | def feed_data(self, queries_data, logger=None): 41 | """ 42 | fed all data to the retriever, that will take care of batchify it 43 | each element in queries_data has an id and a query 44 | 45 | Args: 46 | queries_data (list): list of dicts with two fields: (1) 'id' -> id of the query: (2) 'query' -> text of the query 47 | 48 | Example: 49 | queries_data = [ 50 | {'id': '-4203908294749842710', 'query': 'what is the definition of bcc in email'}, 51 | ... 52 | ] 53 | """ 54 | raise NotImplementedError 55 | 56 | @abstractmethod 57 | def run(self): 58 | """ 59 | get the retrieved documents for all the fed data 60 | return all_doc_id, all_scores, all_query_id, provenance 61 | 62 | Returns 63 | ------- 64 | provenance: dictionary with retrieval result, the keys should match the query id in input 65 | 66 | Example: 67 | provenance: { 68 | '-4203908294749842710': [ 69 | {"score": "179.01215", "text": "permit the use of a program-external editor. The email clients will perform formatting according to RFC 5322 for headers and body, and MIME for non-textual content and attachments. Headers include the destination fields, \"To\", \"Cc\" (short for \"Carbon copy\"), and \"Bcc\" (\"Blind carbon copy\"), and the originator fields \"From\" which is the message's author(s), \"Sender\" in case there are more authors, and \"Reply-To\"", "wikipedia_title": "Email client", "wikipedia_id": "43478"}, 70 | {"score": "184.6643", "text": "this example, the conversation parts are prefixed with \"S:\" and \"C:\", for \"server\" and \"client\", respectively; these labels are not part of the exchange.) After the message sender (SMTP client) establishes a reliable communications channel to the message receiver (SMTP server), the session is opened with a greeting by the server, usually containing its fully qualified domain name (FQDN), in this case \"smtp.example.com\". The client initiates its dialog by responding with a", "wikipedia_title": "Simple Mail Transfer Protocol", "wikipedia_id": "27675"}, 71 | ... 72 | ], 73 | ... 74 | } 75 | """ 76 | raise NotImplementedError 77 | -------------------------------------------------------------------------------- /kilt/retrievers/BM25_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import multiprocessing 9 | from multiprocessing.pool import ThreadPool 10 | import json 11 | 12 | from tqdm import tqdm 13 | import jnius_config 14 | 15 | import kilt.kilt_utils as utils 16 | from kilt.retrievers.base_retriever import Retriever 17 | 18 | 19 | def _run_thread(arguments): 20 | idz = arguments["id"] 21 | index = arguments["index"] 22 | k = arguments["k"] 23 | data = arguments["data"] 24 | 25 | # BM25 parameters #TODO 26 | # bm25_a = arguments["bm25_a"] 27 | # bm25_b = arguments["bm25_b"] 28 | # searcher.set_bm25(bm25_a, bm25_b) 29 | 30 | from pyserini.search import SimpleSearcher 31 | 32 | searcher = SimpleSearcher(index) 33 | 34 | _iter = data 35 | if idz == 0: 36 | _iter = tqdm(data) 37 | 38 | provenance = {} 39 | for x in _iter: 40 | query_id = x["id"] 41 | query = ( 42 | x["query"].replace(utils.ENT_END, "").replace(utils.ENT_START, "").strip() 43 | ) 44 | 45 | hits = searcher.search(query, k) 46 | 47 | element = [] 48 | for y in hits: 49 | try: 50 | doc_data = json.loads(str(y.docid).strip()) 51 | doc_data["score"] = y.score 52 | doc_data["text"] = str(y.raw).strip() 53 | element.append(doc_data) 54 | except Exception as e: 55 | print(e) 56 | element.append( 57 | { 58 | "score": y.score, 59 | "text": str(y.raw).strip(), 60 | "title": y.docid, 61 | } 62 | ) 63 | provenance[query_id] = element 64 | 65 | return provenance 66 | 67 | 68 | class BM25(Retriever): 69 | def __init__(self, name, index, k, num_threads, Xms=None, Xmx=None): 70 | super().__init__(name) 71 | 72 | if Xms and Xmx: 73 | # to solve Insufficient memory for the Java Runtime Environment 74 | jnius_config.add_options( 75 | "-Xms{}".format(Xms), "-Xmx{}".format(Xmx), "-XX:-UseGCOverheadLimit" 76 | ) 77 | print("Configured options:", jnius_config.get_options()) 78 | 79 | self.num_threads = min(num_threads, int(multiprocessing.cpu_count())) 80 | 81 | # initialize a ranker per thread 82 | self.arguments = [] 83 | for id in tqdm(range(self.num_threads)): 84 | self.arguments.append( 85 | { 86 | "id": id, 87 | "index": index, 88 | "k": k, 89 | } 90 | ) 91 | 92 | def feed_data(self, queries_data, logger=None): 93 | 94 | chunked_queries = utils.chunk_it(queries_data, self.num_threads) 95 | 96 | for idx, arg in enumerate(self.arguments): 97 | arg["data"] = chunked_queries[idx] 98 | 99 | def run(self): 100 | pool = ThreadPool(self.num_threads) 101 | results = pool.map(_run_thread, self.arguments) 102 | 103 | provenance = {} 104 | for x in results: 105 | provenance.update(x) 106 | pool.terminate() 107 | pool.join() 108 | 109 | return provenance 110 | -------------------------------------------------------------------------------- /scripts/get_triviaqa_input.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import sys 9 | import requests 10 | import tarfile 11 | import os 12 | import json 13 | 14 | from tqdm.auto import tqdm 15 | 16 | from kilt import kilt_utils 17 | 18 | url = "http://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz" 19 | tar_filename = "triviaqa-rc.tar.gz" 20 | trivia_path = "triviaqa-rc/" 21 | members = [ 22 | "qa/wikipedia-train.json", 23 | "qa/wikipedia-dev.json", 24 | "qa/wikipedia-test-without-answers.json", 25 | ] 26 | base = "data/" 27 | input_files = [ 28 | base + "triviaqa-train_id-kilt.jsonl", 29 | base + "triviaqa-dev_id-kilt.jsonl", 30 | base + "triviaqa-test_id_without_answers-kilt.jsonl", 31 | ] 32 | output_files = [ 33 | base + "triviaqa-train-kilt.jsonl", 34 | base + "triviaqa-dev-kilt.jsonl", 35 | base + "triviaqa-test_without_answers-kilt.jsonl", 36 | ] 37 | 38 | 39 | def decompress(tar_file, path, members=None): 40 | """ 41 | Extracts `tar_file` and puts the `members` to `path`. 42 | If members is None, all members on `tar_file` will be extracted. 43 | """ 44 | tar = tarfile.open(tar_file, mode="r:gz") 45 | if members is None: 46 | members = tar.getmembers() 47 | # with progress bar 48 | # set the progress bar 49 | progress = tqdm(members) 50 | for member in progress: 51 | tar.extract(member, path=path) 52 | # set the progress description of the progress bar 53 | progress.set_description(f"Extracting {str(member)}") 54 | # or use this 55 | # tar.extractall(members=members, path=path) 56 | # close the file 57 | tar.close() 58 | 59 | 60 | print("1. download TriviaQA original tar.gz file") 61 | # Streaming, so we can iterate over the response. 62 | r = requests.get(url, stream=True) 63 | # Total size in bytes. 64 | total_size = int(r.headers.get("content-length", 0)) 65 | block_size = 1024 # 1 Kibibyte 66 | t = tqdm(total=total_size, unit="iB", unit_scale=True) 67 | with open(tar_filename, "wb") as f: 68 | for data in r.iter_content(block_size): 69 | t.update(len(data)) 70 | f.write(data) 71 | t.close() 72 | if total_size != 0 and t.n != total_size: 73 | print("ERROR, something went wrong") 74 | 75 | 76 | print("2. extract tar.gz file") 77 | decompress(tar_filename, trivia_path, members=members) 78 | 79 | print("3. remove tar.gz file") 80 | os.remove(tar_filename) 81 | 82 | print("4. getting original questions") 83 | id2input = {} 84 | for member in members: 85 | print(member) 86 | filename = trivia_path + member 87 | with open(filename, "r") as fin: 88 | data = json.load(fin) 89 | for element in data["Data"]: 90 | e_id = element["QuestionId"] 91 | e_input = element["Question"] 92 | assert e_id not in id2input 93 | id2input[e_id] = e_input 94 | os.remove(filename) 95 | 96 | print("5. remove original TriviaQA data") 97 | os.rmdir(trivia_path + "qa/") 98 | os.rmdir(trivia_path) 99 | 100 | print("6. update kilt files") 101 | for in_file, out_file in zip(input_files, output_files): 102 | data = kilt_utils.load_data(in_file) 103 | for element in data: 104 | element["input"] = id2input[element["id"]] 105 | kilt_utils.store_data(out_file, data) 106 | os.remove(in_file) 107 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /kilt/retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | import os 10 | import os.path 11 | from os import path 12 | 13 | from kilt import kilt_utils as utils 14 | 15 | 16 | def generate_output_file(output_folder, dataset_file): 17 | basename = os.path.basename(dataset_file) 18 | output_file = os.path.join(output_folder, basename) 19 | if not os.path.exists(os.path.dirname(output_file)): 20 | os.makedirs(os.path.dirname(output_file)) 21 | return output_file 22 | 23 | 24 | def run( 25 | test_config_json, 26 | ranker, 27 | model_name, 28 | logger, 29 | topk=100, 30 | debug=False, 31 | output_folder="", 32 | ): 33 | 34 | for task_family, datasets in test_config_json.items(): 35 | logger.info("TASK: {}".format(task_family)) 36 | 37 | for dataset_name, dataset_file in datasets.items(): 38 | logger.info("DATASET: {}".format(dataset_name)) 39 | 40 | if dataset_file: 41 | 42 | output_file = generate_output_file(output_folder, dataset_file) 43 | if path.exists(output_file): 44 | logger.info( 45 | "Skip output file {} that already exists.".format(output_file) 46 | ) 47 | continue 48 | 49 | raw_data = utils.load_data(dataset_file) 50 | 51 | # consider only valid data - filter out invalid 52 | validated_data = {} 53 | query_data = [] 54 | for element in raw_data: 55 | #if utils.validate_datapoint(element, logger=None): 56 | if element["id"] in validated_data: 57 | raise ValueError("ids are not unique in input data!") 58 | validated_data[element["id"]] = element 59 | query_data.append( 60 | {"query": element["input"], "id": element["id"]} 61 | ) 62 | 63 | if debug: 64 | # just consider the top10 datapoints 65 | query_data = query_data[:10] 66 | print("query_data: {}", format(query_data)) 67 | 68 | # get predictions 69 | ranker.feed_data(query_data) 70 | provenance = ranker.run() 71 | 72 | if len(provenance) != len(query_data): 73 | logger.warning( 74 | "different numbers of queries: {} and predicions: {}".format( 75 | len(query_data), len(provenance) 76 | ) 77 | ) 78 | 79 | # write prediction files 80 | if provenance: 81 | logger.info("writing prediction file to {}".format(output_file)) 82 | 83 | predictions = [] 84 | for query_id in provenance.keys(): 85 | element = validated_data[query_id] 86 | new_output = [{"provenance": provenance[query_id]}] 87 | # append the answers 88 | if "output" in element: 89 | for o in element["output"]: 90 | if "answer" in o: 91 | new_output.append({"answer": o["answer"]}) 92 | element["output"] = new_output 93 | predictions.append(element) 94 | 95 | with open(output_file, "w+") as outfile: 96 | for p in predictions: 97 | json.dump(p, outfile) 98 | outfile.write("\n") 99 | -------------------------------------------------------------------------------- /kilt/knowledge_source.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pymongo import MongoClient 8 | import requests 9 | from urllib.parse import unquote 10 | import urllib.request 11 | from bs4 import BeautifulSoup 12 | import urllib.parse as urlparse 13 | from urllib.parse import parse_qs 14 | 15 | DEFAULT_MONGO_CONNECTION_STRING = "mongodb://127.0.0.1:27017/admin" 16 | 17 | 18 | def _get_pageid_from_api(title, client=None): 19 | pageid = None 20 | 21 | title_html = title.strip().replace(" ", "%20") 22 | url = ( 23 | "https://en.wikipedia.org/w/api.php?action=query&titles={}&format=json".format( 24 | title_html 25 | ) 26 | ) 27 | 28 | try: 29 | # Package the request, send the request and catch the response: r 30 | r = requests.get(url) 31 | 32 | # Decode the JSON data into a dictionary: json_data 33 | json_data = r.json() 34 | 35 | if len(json_data["query"]["pages"]) > 1: 36 | print("WARNING: more than one result returned from wikipedia api") 37 | 38 | for _, v in json_data["query"]["pages"].items(): 39 | pageid = v["pageid"] 40 | 41 | except Exception as e: 42 | # print("Exception: {}".format(e)) 43 | pass 44 | 45 | return pageid 46 | 47 | 48 | def _read_url(url): 49 | with urllib.request.urlopen(url) as response: 50 | html = response.read() 51 | soup = BeautifulSoup(html, features="html.parser") 52 | title = soup.title.string.replace(" - Wikipedia", "").strip() 53 | return title 54 | 55 | 56 | def _get_title_from_wikipedia_url(url, client=None): 57 | title = None 58 | try: 59 | title = _read_url(url) 60 | except Exception: 61 | try: 62 | # try adding https 63 | title = _read_url("https://" + url) 64 | except Exception: 65 | # print("Exception: {}".format(e)) 66 | pass 67 | return title 68 | 69 | 70 | class KnowledgeSource: 71 | def __init__( 72 | self, 73 | mongo_connection_string=None, 74 | database="kilt", 75 | collection="knowledgesource", 76 | ): 77 | if not mongo_connection_string: 78 | mongo_connection_string = DEFAULT_MONGO_CONNECTION_STRING 79 | self.client = MongoClient(mongo_connection_string) 80 | self.db = self.client[database][collection] 81 | 82 | def get_all_pages_cursor(self): 83 | cursor = self.db.find({}) 84 | return cursor 85 | 86 | def get_num_pages(self): 87 | return self.db.estimated_document_count() 88 | 89 | def get_page_by_id(self, wikipedia_id): 90 | page = self.db.find_one({"_id": str(wikipedia_id)}) 91 | return page 92 | 93 | def get_page_by_title(self, wikipedia_title, attempt=0): 94 | page = self.db.find_one({"wikipedia_title": str(wikipedia_title)}) 95 | return page 96 | 97 | def get_page_from_url(self, url): 98 | page = None 99 | 100 | # 1. try to look for title in the url 101 | parsed = urlparse.urlparse(url) 102 | record = parse_qs(parsed.query) 103 | if "title" in record: 104 | title = record["title"][0].replace("_", " ") 105 | page = self.get_page_by_title(title) 106 | 107 | # 2. try another way to look for title in the url 108 | if page == None: 109 | title = url.split("/")[-1].replace("_", " ") 110 | page = self.get_page_by_title(title) 111 | 112 | # 3. try to retrieve the current wikipedia_id from the url 113 | if page == None: 114 | title = _get_title_from_wikipedia_url(url, client=self.client) 115 | if title: 116 | pageid = _get_pageid_from_api(title, client=self.client) 117 | if pageid: 118 | page = self.get_page_by_id(pageid) 119 | 120 | return page 121 | -------------------------------------------------------------------------------- /scripts/execute_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import argparse 9 | 10 | from kilt import retrieval 11 | from kilt import kilt_utils as utils 12 | 13 | 14 | def execute( 15 | logger, test_config_json, retriever, log_directory, model_name, output_folder 16 | ): 17 | 18 | # run evaluation 19 | retrieval.run( 20 | test_config_json, retriever, model_name, logger, output_folder=output_folder 21 | ) 22 | 23 | 24 | def main(args): 25 | 26 | # load configs 27 | with open(args.test_config, "r") as fin: 28 | test_config_json = json.load(fin) 29 | 30 | # create a new directory to log and store results 31 | log_directory = utils.create_logdir_with_timestamp(args.logdir) 32 | logger = None 33 | 34 | logger = utils.init_logging(log_directory, args.model_name, logger) 35 | logger.info("loading {} ...".format(args.model_name)) 36 | 37 | if args.model_name == "drqa": 38 | # DrQA tf-idf 39 | from kilt.retrievers import DrQA_tfidf 40 | 41 | if args.model_configuration: 42 | retriever = DrQA_tfidf.DrQA.from_config_file( 43 | args.model_name, args.model_configuration 44 | ) 45 | else: 46 | retriever = DrQA_tfidf.DrQA.from_default_config(args.model_name) 47 | elif args.model_name == "dpr": 48 | # DPR 49 | from kilt.retrievers import DPR_connector 50 | 51 | if args.model_configuration: 52 | retriever = DPR_connector.DPR.from_config_file( 53 | args.model_name, args.model_configuration 54 | ) 55 | else: 56 | retriever = DPR_connector.DPR.from_default_config(args.model_name) 57 | elif args.model_name == "dpr_distr": 58 | # DPR distributed 59 | from kilt.retrievers import DPR_distr_connector 60 | 61 | if args.model_configuration: 62 | retriever = DPR_distr_connector.DPR.from_config_file( 63 | args.model_name, args.model_configuration 64 | ) 65 | else: 66 | raise "No default configuration for DPR distributed!" 67 | elif args.model_name == "blink": 68 | # BLINK 69 | from kilt.retrievers import BLINK_connector 70 | 71 | if args.model_configuration: 72 | retriever = BLINK_connector.BLINK.from_config_file( 73 | args.model_name, args.model_configuration 74 | ) 75 | else: 76 | retriever = BLINK_connector.BLINK.from_default_config(args.model_name) 77 | elif args.model_name == "bm25": 78 | # BM25 79 | from kilt.retrievers import BM25_connector 80 | 81 | if args.model_configuration: 82 | retriever = BM25_connector.BM25.from_config_file( 83 | args.model_name, args.model_configuration 84 | ) 85 | else: 86 | retriever = BM25_connector.BM25.from_default_config(args.model_name) 87 | else: 88 | raise ValueError("unknown retriever model") 89 | 90 | execute( 91 | logger, 92 | test_config_json, 93 | retriever, 94 | log_directory, 95 | args.model_name, 96 | args.output_folder, 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | 102 | parser = argparse.ArgumentParser() 103 | 104 | parser.add_argument( 105 | "--test_config", 106 | dest="test_config", 107 | type=str, 108 | default="kilt/configs/test_data.json", 109 | help="Test Configuration.", 110 | ) 111 | 112 | parser.add_argument( 113 | "--logdir", 114 | dest="logdir", 115 | type=str, 116 | default="logs/ranking/", 117 | help="logdir", 118 | ) 119 | 120 | parser.add_argument( 121 | "--model_name", 122 | "-m", 123 | dest="model_name", 124 | type=str, 125 | required=True, 126 | help="retriever model name in {drqa,solr,dpr,blink,bm25}", 127 | ) 128 | 129 | parser.add_argument( 130 | "--model_configuration", 131 | "-c", 132 | dest="model_configuration", 133 | type=str, 134 | default=None, 135 | help="model configuration", 136 | ) 137 | 138 | parser.add_argument( 139 | "--output_folder", 140 | "-o", 141 | dest="output_folder", 142 | type=str, 143 | required=True, 144 | help="output folder", 145 | ) 146 | 147 | args = parser.parse_args() 148 | 149 | main(args) 150 | -------------------------------------------------------------------------------- /kilt/datasets/zero_shot_re.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import spacy 8 | import uuid 9 | 10 | import kilt.kilt_utils as utils 11 | from kilt.datasets.base_dataset import Dataset 12 | 13 | 14 | class ZeroShotREDataset(Dataset): 15 | def __init__(self, name, input_file, output_file, max_chunks): 16 | super().__init__(name) 17 | self.input_file = input_file 18 | self.output_file = output_file 19 | self.max_chunks = max_chunks 20 | self.nlp = spacy.load("en_core_web_sm") 21 | 22 | def get_uuid(self): 23 | return str(uuid.uuid4()) 24 | 25 | def map_datapoint( 26 | self, 27 | wikidata_relation, 28 | question_template, 29 | wikipedia_title, 30 | sentence, 31 | answer_spans, 32 | ks, 33 | entry_id, 34 | ): 35 | kilt_entry = {} 36 | kilt_entry["id"] = entry_id 37 | kilt_entry["input"] = question_template.replace("XXX", wikipedia_title).replace( 38 | " auther", " author" # to fix typo in templates 39 | ) 40 | kilt_entry["output"] = [] 41 | kilt_entry["meta"] = { 42 | "wikidata_relation": wikidata_relation, 43 | "question_template": question_template, 44 | } 45 | print("Getting wiki page for", wikipedia_title) 46 | pages = ks.get_pages_by_title(wikipedia_title) 47 | 48 | if len(pages) <= 0: 49 | kilt_entry["output"] = [ 50 | {"answer": answer_span, "provenance": []} 51 | for answer_span in answer_spans 52 | ] 53 | return kilt_entry 54 | print("matching answer") 55 | # We take the first returned page from the list. 56 | paragraph_id, start_character, end_character, bleu = utils.match_answer( 57 | sentence, pages[0], nlp=self.nlp, debug=False 58 | ) 59 | print("done matching answer") 60 | 61 | for answer_span in answer_spans: 62 | output = {"answer": answer_span, "provenance": []} 63 | output["provenance"].append( 64 | { 65 | "wikipedia_id": pages[0]["wikipedia_id"], 66 | "title": pages[0]["wikipedia_title"], 67 | "start_paragraph_id": paragraph_id, 68 | "start_character": start_character, 69 | "end_paragraph_id": paragraph_id, 70 | "end_character": end_character, 71 | "bleu_score": bleu, 72 | "meta": {}, 73 | } 74 | ) 75 | kilt_entry["output"].append(output) 76 | return kilt_entry 77 | 78 | def get_chunks(self, num_chunks): 79 | data = [] 80 | with open(self.input_file, "r") as fin: 81 | data = fin.readlines() 82 | return utils.chunk_it(data, num_chunks) 83 | 84 | def process_chunk(self, chunk, ks, chunk_id): 85 | kilt_data = [] 86 | missing_pages = 0 87 | negative_samples = 0 88 | for i, line in enumerate(chunk): 89 | print("Processed {} lines for chunk {}".format(i, chunk_id)) 90 | print("Processing:", line) 91 | fields = line.strip().split("\t") 92 | # Leave out negative samples (samples where one can't infer the 93 | # answer from the provided sentence). 94 | if len(fields) <= 4: 95 | negative_samples += 1 96 | continue 97 | wikidata_relation, question_template, wikipedia_title, sentence = fields[ 98 | 0:4 99 | ] 100 | answer_spans = fields[4:] 101 | kilt_entry = self.map_datapoint( 102 | wikidata_relation, 103 | question_template, 104 | wikipedia_title, 105 | sentence, 106 | answer_spans, 107 | ks, 108 | self.get_uuid(), 109 | ) 110 | if kilt_entry is None: 111 | missing_pages += 1 112 | continue 113 | kilt_data.append(kilt_entry) 114 | return kilt_data, [missing_pages, negative_samples] 115 | 116 | def postprocess_metadata(self, metadata): 117 | missing_pages = 0 118 | negative_samples = 0 119 | for m, n in metadata: 120 | missing_pages += m 121 | negative_samples += n 122 | print( 123 | "{} samples with missing pages, {} samples with no answer spans.".format( 124 | missing_pages, negative_samples 125 | ) 126 | ) 127 | -------------------------------------------------------------------------------- /kilt/readers/t5/evaluate_kilt_task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import glob 10 | import os 11 | from pathlib import Path 12 | 13 | import torch 14 | from rouge_score import rouge_scorer, scoring 15 | from tqdm import tqdm 16 | 17 | from finetune import Seq2seqTransformer 18 | 19 | 20 | def chunks(lst, n): 21 | """Yield successive n-sized chunks from lst.""" 22 | for i in range(0, len(lst), n): 23 | yield lst[i : i + n] 24 | 25 | 26 | def generate_answers(lns, output_file_path, model, tokenizer, batch_size, device): 27 | output_file = Path(output_file_path).open("w") 28 | 29 | model.to(device) 30 | 31 | # update config with specific params 32 | task_specific_params = model.config.task_specific_params 33 | if task_specific_params is not None: 34 | model.config.update(task_specific_params.get("nq", {})) 35 | 36 | for batch in tqdm(list(chunks(lns, batch_size))): 37 | batch = [model.config.prefix + text for text in batch] 38 | 39 | dct = tokenizer.batch_encode_plus( 40 | batch, max_length=64, return_tensors="pt", pad_to_max_length=True 41 | ) 42 | input_ids = dct["input_ids"].to(device) 43 | attention_mask = dct["attention_mask"].to(device) 44 | 45 | answers = model.generate(input_ids=input_ids, attention_mask=attention_mask) 46 | dec = [ 47 | tokenizer.decode( 48 | g, skip_special_tokens=True, clean_up_tokenization_spaces=False 49 | ) 50 | for g in answers 51 | ] 52 | 53 | for hypothesis in dec: 54 | output_file.write(hypothesis + "\n") 55 | output_file.flush() 56 | 57 | 58 | def calculate_rouge(output_lns, reference_lns, score_path): 59 | score_file = Path(score_path).open("w") 60 | scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) 61 | aggregator = scoring.BootstrapAggregator() 62 | 63 | for reference_ln, output_ln in zip(reference_lns, output_lns): 64 | scores = scorer.score(reference_ln, output_ln) 65 | aggregator.add_scores(scores) 66 | 67 | result = aggregator.aggregate() 68 | score_file.write( 69 | "ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format( 70 | result["rouge1"], result["rouge2"], result["rougeL"] 71 | ) 72 | ) 73 | 74 | 75 | def run_generate(): 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument( 78 | "model_size", 79 | type=str, 80 | help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.", 81 | default="t5-base", 82 | ) 83 | parser.add_argument( 84 | "input_path", 85 | type=str, 86 | help="like nqa/test_articles_questions.txt", 87 | ) 88 | parser.add_argument( 89 | "output_path", 90 | type=str, 91 | help="where to save summaries", 92 | ) 93 | parser.add_argument( 94 | "reference_path", type=str, help="like nqa/test_reference_answers.txt" 95 | ) 96 | parser.add_argument( 97 | "score_path", 98 | type=str, 99 | help="where to save the rouge score", 100 | ) 101 | parser.add_argument( 102 | "--batch_size", 103 | type=int, 104 | default=8, 105 | required=False, 106 | help="batch size: how many to summarize at a time", 107 | ) 108 | parser.add_argument( 109 | "--no_cuda", 110 | default=False, 111 | type=bool, 112 | help="Whether to force the execution on CPU.", 113 | ) 114 | 115 | args = parser.parse_args() 116 | args.device = torch.device( 117 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 118 | ) 119 | 120 | source_lns = [x.rstrip() for x in open(args.input_path).readlines()] 121 | sq2sq = Seq2seqTransformer(args) 122 | checkpoints = list( 123 | sorted( 124 | glob.glob( 125 | os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True 126 | ) 127 | ) 128 | ) 129 | 130 | model = sq2sq.load_from_checkpoint(checkpoints[-1]).model 131 | tokenizer = sq2sq.tokenizer 132 | generate_answers( 133 | source_lns, args.output_path, model, tokenizer, args.batch_size, args.device 134 | ) 135 | output_lns = [x.rstrip() for x in open(args.output_path).readlines()] 136 | reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] 137 | 138 | calculate_rouge(output_lns, reference_lns, args.score_path) 139 | 140 | 141 | if __name__ == "__main__": 142 | run_generate() 143 | -------------------------------------------------------------------------------- /kilt/retrievers/DPR_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | import argparse 10 | import glob 11 | import pickle 12 | 13 | from dpr.utils.model_utils import ( 14 | load_states_from_checkpoint, 15 | setup_for_distributed_mode, 16 | get_model_obj, 17 | ) 18 | from dpr.options import set_encoder_params_from_state 19 | from dpr.models import init_biencoder_components 20 | from dense_retriever import ( 21 | DenseRetriever, 22 | parse_qa_csv_file, 23 | load_passages, 24 | iterate_encoded_files, 25 | ) 26 | from dpr.indexer.faiss_indexers import ( 27 | DenseIndexer, 28 | DenseHNSWFlatIndexer, 29 | DenseFlatIndexer, 30 | ) 31 | 32 | from kilt.configs import retriever 33 | import kilt.kilt_utils as utils 34 | from kilt.retrievers.base_retriever import Retriever 35 | 36 | 37 | class DPR(Retriever): 38 | def __init__(self, name, **config): 39 | super().__init__(name) 40 | 41 | self.args = argparse.Namespace(**config) 42 | saved_state = load_states_from_checkpoint(self.args.model_file) 43 | set_encoder_params_from_state(saved_state.encoder_params, self.args) 44 | tensorizer, encoder, _ = init_biencoder_components( 45 | self.args.encoder_model_type, self.args, inference_only=True 46 | ) 47 | encoder = encoder.question_model 48 | encoder, _ = setup_for_distributed_mode( 49 | encoder, 50 | None, 51 | self.args.device, 52 | self.args.n_gpu, 53 | self.args.local_rank, 54 | self.args.fp16, 55 | ) 56 | encoder.eval() 57 | 58 | # load weights from the model file 59 | model_to_load = get_model_obj(encoder) 60 | 61 | prefix_len = len("question_model.") 62 | question_encoder_state = { 63 | key[prefix_len:]: value 64 | for (key, value) in saved_state.model_dict.items() 65 | if key.startswith("question_model.") 66 | } 67 | model_to_load.load_state_dict(question_encoder_state, strict=False) 68 | vector_size = model_to_load.get_out_size() 69 | 70 | # index all passages 71 | ctx_files_pattern = self.args.encoded_ctx_file 72 | input_paths = glob.glob(ctx_files_pattern) 73 | 74 | index_buffer_sz = self.args.index_buffer 75 | if self.args.hnsw_index: 76 | index = DenseHNSWFlatIndexer(vector_size) 77 | index.deserialize_from(self.args.hnsw_index_path) 78 | else: 79 | index = DenseFlatIndexer(vector_size) 80 | index.index_data(input_paths) 81 | 82 | self.retriever = DenseRetriever( 83 | encoder, self.args.batch_size, tensorizer, index 84 | ) 85 | 86 | # not needed for now 87 | self.all_passages = load_passages(self.args.ctx_file) 88 | 89 | self.KILT_mapping = None 90 | if self.args.KILT_mapping: 91 | self.KILT_mapping = pickle.load(open(self.args.KILT_mapping, "rb")) 92 | 93 | def feed_data( 94 | self, 95 | queries_data, 96 | ent_start_token=utils.ENT_START, 97 | ent_end_token=utils.ENT_START, 98 | logger=None, 99 | ): 100 | 101 | # get questions & answers 102 | self.questions = [ 103 | x["query"].replace(ent_start_token, "").replace(ent_end_token, "").strip() 104 | for x in queries_data 105 | ] 106 | self.query_ids = [x["id"] for x in queries_data] 107 | 108 | def run(self): 109 | 110 | questions_tensor = self.retriever.generate_question_vectors(self.questions) 111 | top_ids_and_scores = self.retriever.get_top_docs( 112 | questions_tensor.numpy(), self.args.n_docs 113 | ) 114 | 115 | provenance = {} 116 | 117 | for record, query_id in zip(top_ids_and_scores, self.query_ids): 118 | top_ids, scores = record 119 | element = [] 120 | 121 | # sort by score in descending order 122 | for score, id in sorted(zip(scores, top_ids), reverse=True): 123 | 124 | text = self.all_passages[id][0] 125 | index = self.all_passages[id][1] 126 | 127 | wikipedia_id = None 128 | if self.KILT_mapping: 129 | # passages indexed by wikipedia title - mapping needed 130 | title = index 131 | if title in self.KILT_mapping: 132 | wikipedia_id = self.KILT_mapping[title] 133 | else: 134 | # passages indexed by wikipedia id 135 | wikipedia_id = index 136 | 137 | element.append( 138 | { 139 | "score": str(score), 140 | "text": str(text), 141 | "wikipedia_title": str(index), 142 | "wikipedia_id": str(wikipedia_id), 143 | } 144 | ) 145 | 146 | assert query_id not in provenance 147 | provenance[query_id] = element 148 | 149 | return provenance 150 | -------------------------------------------------------------------------------- /kilt/retrievers/DPR_distr_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import logging 9 | import pickle 10 | import zlib 11 | from omegaconf import OmegaConf 12 | from tqdm import tqdm 13 | 14 | from dpr.models import init_biencoder_components 15 | from dpr.options import setup_cfg_gpu, set_cfg_params_from_state 16 | from dpr.utils.model_utils import ( 17 | setup_for_distributed_mode, 18 | get_model_obj, 19 | load_states_from_checkpoint, 20 | ) 21 | from dense_retriever import DenseRPCRetriever 22 | 23 | import kilt.kilt_utils as utils 24 | from kilt.retrievers.base_retriever import Retriever 25 | 26 | 27 | logger = logging.getLogger() 28 | logger.setLevel(logging.INFO) 29 | 30 | 31 | class DPR(Retriever): 32 | def __init__(self, name, cfg): 33 | super().__init__(name) 34 | 35 | cfg = setup_cfg_gpu(cfg) 36 | 37 | logger.info("CFG (after gpu configuration):") 38 | logger.info("%s", OmegaConf.to_yaml(cfg)) 39 | 40 | saved_state = load_states_from_checkpoint(cfg.model_file) 41 | set_cfg_params_from_state(saved_state.encoder_params, cfg) 42 | 43 | tensorizer, encoder, _ = init_biencoder_components( 44 | cfg.encoder.encoder_model_type, cfg, inference_only=True 45 | ) 46 | 47 | encoder = encoder.question_model 48 | 49 | encoder, _ = setup_for_distributed_mode( 50 | encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16 51 | ) 52 | encoder.eval() 53 | 54 | # load weights from the model file 55 | model_to_load = get_model_obj(encoder) 56 | logger.info("Loading saved model state ...") 57 | 58 | encoder_prefix = "question_model." 59 | prefix_len = len(encoder_prefix) 60 | 61 | logger.info("Encoder state prefix %s", encoder_prefix) 62 | question_encoder_state = { 63 | key[prefix_len:]: value 64 | for (key, value) in saved_state.model_dict.items() 65 | if key.startswith(encoder_prefix) 66 | and key != "question_model.embeddings.position_ids" 67 | } 68 | model_to_load.load_state_dict(question_encoder_state, strict=False) 69 | vector_size = model_to_load.get_out_size() 70 | logger.info("Encoder vector_size=%d", vector_size) 71 | 72 | self.retriever = DenseRPCRetriever( 73 | encoder, 74 | cfg.batch_size, 75 | tensorizer, 76 | cfg.rpc_retriever_cfg_file, 77 | vector_size, 78 | use_l2_conversion=cfg.use_l2_conversion, 79 | ) 80 | self.retriever.load_index(cfg.rpc_index_id) 81 | 82 | self.KILT_mapping = None 83 | if cfg.KILT_mapping: 84 | self.KILT_mapping = dict(pickle.load(open(cfg.KILT_mapping, "rb"))) 85 | 86 | self.rpc_meta_compressed = cfg.rpc_meta_compressed 87 | self.cfg = cfg 88 | 89 | @classmethod 90 | def from_config_file(cls, name, config_file): 91 | cfg = OmegaConf.load(config_file) 92 | return cls(name, cfg) 93 | 94 | @classmethod 95 | def process_query(cls, x, ent_start_token, ent_end_token): 96 | return x["query"].replace(ent_start_token, "").replace( 97 | ent_end_token, "" 98 | ).strip() + ("?" if not x["query"].endswith("?") else "") 99 | 100 | def feed_data( 101 | self, 102 | queries_data, 103 | ent_start_token=utils.ENT_START, 104 | ent_end_token=utils.ENT_START, 105 | logger=None, 106 | ): 107 | 108 | # get questions & answers 109 | self.questions = [ 110 | DPR.process_query(x, ent_start_token, ent_end_token) for x in queries_data 111 | ] 112 | self.query_ids = [x["id"] for x in queries_data] 113 | 114 | def run(self): 115 | 116 | dup_multiplier = 1 117 | questions_tensor = self.retriever.generate_question_vectors(self.questions) 118 | top_ids_and_scores = self.retriever.get_top_docs( 119 | questions_tensor.numpy(), dup_multiplier * self.cfg.n_docs, search_batch=256 120 | ) 121 | 122 | provenance = {} 123 | 124 | for record, query_id in tqdm(zip(top_ids_and_scores, self.query_ids)): 125 | element = [] 126 | docs_meta, scores = record 127 | 128 | cnt = 0 129 | for score, meta in zip(scores, docs_meta): 130 | if cnt >= self.cfg.n_docs: 131 | break 132 | doc_id, text, title = meta[:3] 133 | wikipedia_id = ( 134 | self.KILT_mapping[int(doc_id)] 135 | if self.KILT_mapping and (int(doc_id) in self.KILT_mapping) 136 | else None 137 | ) 138 | 139 | element.append( 140 | { 141 | "score": str(score), 142 | "text": str(zlib.decompress(text).decode()) 143 | if self.rpc_meta_compressed 144 | else text, 145 | "wikipedia_title": str(zlib.decompress(title).decode()) 146 | if self.rpc_meta_compressed 147 | else title, 148 | "wikipedia_id": str(wikipedia_id), 149 | "doc_id": str(doc_id), 150 | } 151 | ) 152 | cnt += 1 153 | 154 | assert query_id not in provenance 155 | provenance[query_id] = element 156 | 157 | return provenance 158 | -------------------------------------------------------------------------------- /kilt/datasets/triviaqa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | import json 11 | import spacy 12 | import sys 13 | import re 14 | import kilt.kilt_utils as utils 15 | from kilt.datasets.base_dataset import Dataset 16 | from kilt import knowledge_source # remove later 17 | 18 | 19 | class TriviaQADataset(Dataset): 20 | def __init__(self, name, input_file, output_file, log_file): 21 | super().__init__(name) 22 | self.input_file = "/private/home/angelafan/robocheckers/KIB/wikipedia-test.json" 23 | self.output_file = output_file 24 | self.log_file = log_file 25 | self.nlp = spacy.load("en_core_web_sm") 26 | 27 | def get_chunks(self, num_chunks): 28 | with open(self.input_file, "r", encoding='utf-8') as infile: 29 | all_data = json.load(infile) 30 | 31 | all_data = all_data['Data'] 32 | n = len(all_data) 33 | print("{} examples in the dataset".format(n)) 34 | return utils.chunk_it(all_data, num_chunks) 35 | 36 | def process_chunk(self, chunk, ks, chunk_id=-1): 37 | missing_pages = 0.0 38 | short_exact_match = 0.0 39 | short_fuzzy_match = 0.0 40 | n = len(chunk) 41 | kilt_data = [] 42 | 43 | for idx, datapoint in enumerate(chunk): 44 | 45 | print( 46 | "t: {}, p: {:.2f} %, mp: {:.1f}, exact: {:.1f}, fuzzy: {:.1f}".format( 47 | chunk_id, 48 | round(idx * 100 / n, 2), 49 | missing_pages, 50 | short_exact_match, 51 | short_fuzzy_match, 52 | ), 53 | end="\r", 54 | ) 55 | sys.stdout.flush() 56 | 57 | # answer 58 | answers = datapoint["Answer"]["Aliases"] 59 | normalized_answers = datapoint["Answer"]["NormalizedAliases"] 60 | question = datapoint["Question"] 61 | wikipedia_pages = datapoint["EntityPages"] 62 | wiki_titles = [i["Title"] for i in wikipedia_pages] 63 | dataset_id = datapoint["QuestionId"] 64 | 65 | # group by question, 66 | for answer_index, answer in enumerate(answers): 67 | for title in wiki_titles: 68 | page = ks.get_pages_by_title(title) 69 | if not page: 70 | missing_pages += 1 # metric will be inflated since its on each unfetchable page 71 | else: 72 | page = page[0] 73 | kilt_record = { 74 | # original data point id if available otherwise unique id 75 | "id": dataset_id, 76 | # question / claim / sentence 77 | # dialogue history goes here 78 | "input": question, 79 | } 80 | 81 | local_sem = 0.0 82 | local_sfm = 0.0 83 | 84 | answer_span = answer 85 | 86 | ( 87 | paragraph_id, 88 | start_character, 89 | end_character, 90 | bleu, 91 | ) = utils.match_answer( 92 | answer_span, page, nlp=self.nlp, debug=False 93 | ) 94 | 95 | kilt_record_output = { 96 | # answer in textual form 97 | "answer": answer_span, 98 | "provenance": [ 99 | # list of relevant WikipediaPages / Spans as provenance for the answer from the ks 100 | { 101 | "wikipedia_id": page[ 102 | "wikipedia_id" 103 | ], # *mandatory* - ID Wikipedia Page 104 | "title": page[ 105 | "wikipedia_title" 106 | ], # *mandatory* - Title Wikipedia Page 107 | "start_paragraph_id": paragraph_id, # start paragraph id with relevant info 108 | "start_character": start_character, 109 | "end_paragraph_id": paragraph_id, # end paragraph id 110 | "end_character": end_character, 111 | "bleu_score": bleu, # 1.0 when gold data is exactly matched, lower for fuzzy matches 112 | "normalized_aliases": normalized_answers 113 | } 114 | ], 115 | } 116 | 117 | 118 | if bleu == 1: 119 | local_sem += 1 120 | elif bleu < 1 and bleu >= 0: 121 | local_sfm += 1 122 | else: 123 | print("ERROR: invalid bleu: {}".format(bleu)) 124 | sys.exit(-1) 125 | 126 | # update kilt data 127 | kilt_record["output"] = kilt_record_output 128 | kilt_data.append(kilt_record) 129 | 130 | metadata = [missing_pages] 131 | return kilt_data, metadata 132 | 133 | def postprocess_metadata(self, metadata): 134 | missing_pages = 0.0 135 | short_exact_match = 0.0 136 | short_fuzzy_match = 0.0 137 | for met in metadata: 138 | if met == []: 139 | continue 140 | mp, sem, sfm = met 141 | missing_pages += mp 142 | short_exact_match += sem 143 | short_fuzzy_match += sfm 144 | 145 | print("Print stats") 146 | msg = "\n n: {:.1f}, missing pages: {:.1f}, short exact match: {:.1f}, short fuzzy match: {:.1f}".format( 147 | 0, missing_pages, short_exact_match, short_fuzzy_match 148 | ) 149 | print(msg) 150 | 151 | f = open(self.log_file, "w+") 152 | f.write(msg) 153 | f.close() 154 | -------------------------------------------------------------------------------- /kilt/retrievers/BLINK_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import logging 10 | import pickle 11 | 12 | import blink.main_dense as main_dense 13 | from flair.models import SequenceTagger 14 | from flair.data import Sentence 15 | 16 | from kilt.retrievers.base_retriever import Retriever 17 | import kilt.kilt_utils as utils 18 | 19 | class BLINK(Retriever): 20 | def __init__(self, name, **config): 21 | super().__init__(name) 22 | 23 | 24 | self.args = argparse.Namespace(**config) 25 | 26 | self.logger = logging.getLogger("KILT") 27 | 28 | self.models = main_dense.load_models(self.args, logger=self.logger) 29 | 30 | self.ner_model = SequenceTagger.load("ner") 31 | 32 | self.cache_pages = {} 33 | 34 | self.Wikipedia_title2id = pickle.load(open(self.args.wikipedia_title2id, "rb")) 35 | 36 | def feed_data( 37 | self, 38 | queries_data, 39 | ent_start_token=utils.ENT_START, 40 | ent_end_token=utils.ENT_END, 41 | logger=None, 42 | ): 43 | if logger: 44 | self.logger = logger 45 | 46 | wikipedia_id2local_id = self.models[8] 47 | 48 | self.test_data = [] 49 | for element in queries_data: 50 | 51 | query = element["query"] 52 | 53 | if ent_start_token in query and ent_end_token in query: 54 | split1 = query.split(ent_start_token) 55 | assert len(split1) == 2 56 | left = split1[0] 57 | split2 = split1[1].split(ent_end_token) 58 | assert len(split2) == 2 59 | mention = split2[0] 60 | right = split2[1] 61 | 62 | record = { 63 | "id": element["id"], 64 | "label": "unknown", 65 | "label_id": -1, 66 | "context_left": left.strip().lower(), 67 | "mention": mention.strip().lower(), 68 | "context_right": right.strip().lower(), 69 | } 70 | self.test_data.append(record) 71 | else: 72 | 73 | # Apply a NER system 74 | sent = Sentence(query, use_tokenizer=True) 75 | self.ner_model.predict(sent) 76 | sent_mentions = sent.to_dict(tag_type="ner")["entities"] 77 | 78 | if len(sent_mentions) == 0: 79 | # no mention 80 | record = { 81 | "id": element["id"], 82 | "label": "unknown", 83 | "label_id": -1, 84 | "context_left": query.strip().lower(), 85 | "mention": "", 86 | "context_right": "", 87 | } 88 | self.test_data.append(record) 89 | 90 | else: 91 | # create a record for each mention detected 92 | for hit in sent_mentions: 93 | left = query[: int(hit["start_pos"])].strip() 94 | mention = hit["text"] 95 | right = query[int(hit["end_pos"]) :].strip() 96 | 97 | record = { 98 | "id": element["id"], 99 | "label": "unknown", 100 | "label_id": -1, 101 | "context_left": left.strip().lower(), 102 | "mention": mention.strip().lower(), 103 | "context_right": right.strip().lower(), 104 | } 105 | self.test_data.append(record) 106 | 107 | def run(self): 108 | ( 109 | biencoder_accuracy, 110 | recall_at, 111 | crossencoder_normalized_accuracy, 112 | overall_unormalized_accuracy, 113 | num_datapoints, 114 | predictions, 115 | scores, 116 | ) = main_dense.run( 117 | self.args, self.logger, *self.models, test_data=self.test_data 118 | ) 119 | 120 | # aggregate multiple records for the same datapoint 121 | print("aggregate multiple records for the same datapoint", flush=True) 122 | id_2_results = {} 123 | for r, p, s in zip(self.test_data, predictions, scores): 124 | 125 | if r["id"] not in id_2_results: 126 | id_2_results[r["id"]] = {"predictions": [], "scores": []} 127 | id_2_results[r["id"]]["predictions"].extend(p) 128 | id_2_results[r["id"]]["scores"].extend(s) 129 | 130 | provenance = {} 131 | 132 | for id, results in id_2_results.items(): 133 | 134 | element = [] 135 | 136 | # merge predictions when multiple entities are found 137 | sorted_titles = [] 138 | sorted_scores = [] 139 | for y, x in sorted( 140 | zip(results["scores"], results["predictions"]), reverse=True 141 | ): 142 | if x not in sorted_titles: 143 | sorted_titles.append(x) 144 | sorted_scores.append(y) 145 | 146 | local_doc_id = [] 147 | for e_title, score in zip(sorted_titles, sorted_scores): 148 | 149 | if e_title not in self.Wikipedia_title2id: 150 | print( 151 | "WARNING: title: {} not recognized".format(e_title), flush=True 152 | ) 153 | else: 154 | 155 | """ 156 | if e_title in self.cache_pages: 157 | page = self.cache_pages[e_title] 158 | else: 159 | page = self.ks.get_page_by_title(e_title) 160 | self.cache_pages[e_title] = page 161 | 162 | wikipedia_id = page["wikipedia_id"] 163 | """ 164 | 165 | wikipedia_id = self.Wikipedia_title2id[e_title] 166 | 167 | element.append( 168 | { 169 | "score": str(score), 170 | # "text": page["text"], 171 | "wikipedia_title": str(e_title), 172 | "wikipedia_id": str(wikipedia_id), 173 | } 174 | ) 175 | provenance[id] = element 176 | 177 | return provenance 178 | -------------------------------------------------------------------------------- /scripts/map_TAC-KBP2010_to_KILT.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | from tqdm.auto import tqdm 10 | import pickle 11 | import argparse 12 | 13 | from kilt.knowledge_source import KnowledgeSource 14 | 15 | 16 | def write_output(filename, data): 17 | with open(filename, "w+") as outfile: 18 | for idx, element in enumerate(data): 19 | # print(round(idx * 100 / len(data), 2), "%", end="\r") 20 | # sys.stdout.flush() 21 | json.dump(element, outfile) 22 | outfile.write("\n") 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument( 28 | "--train", 29 | dest="train_mentions_filename", 30 | type=str, 31 | default="data/tac_kbp_2010/train.jsonl", 32 | help="train file TAC-KBP2010", 33 | ) 34 | 35 | parser.add_argument( 36 | "--test", 37 | dest="test_mentions_filename", 38 | type=str, 39 | default="data/tac_kbp_2010/test.jsonl", 40 | help="train file TAC-KBP2010", 41 | ) 42 | 43 | parser.add_argument( 44 | "--entities", 45 | dest="test_entities_filename", 46 | type=str, 47 | default="data/tac_kbp_2010/tac_kbp_ref_know_base/entity.jsonl", 48 | help="knowledge source file TAC-KBP2010", 49 | ) 50 | 51 | parser.add_argument( 52 | "--out_test", 53 | dest="out_test", 54 | type=str, 55 | default="data/tac_kbp_2010/tackbp2010-test-kilt.jsonl", 56 | help="output file for TAC-KBP2010 test in KILT format", 57 | ) 58 | 59 | parser.add_argument( 60 | "--out_train", 61 | dest="out_train", 62 | type=str, 63 | default="data/tac_kbp_2010/tackbp2010-train-kilt.jsonl", 64 | help="output file for TAC-KBP2010 train in KILT format", 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | ent_start_token = "[START_ENT]" 70 | ent_end_token = "[END_ENT]" 71 | ks = KnowledgeSource() 72 | kb2id = {} 73 | 74 | manual_labels_correspondance = { 75 | "E0431500": 19457, # Myanmar 76 | "E0633385": 109495, # Key West 77 | "E0277953": 8725021, # Aarti Agarwal 78 | "E0526355": 30875653, # Bob Casey Jr. 79 | "E0508649": 41709552, # American Eagle (airline brand) 80 | "E0504008": 504790, # New York Daily News 81 | "E0398776": 99689, # National Express 82 | "E0343020": 402982, # Reliance Industries Limited 83 | "E0131583": 77825, # TNT (American TV network) 84 | "E0586856": 12710981, # List of Dirty Sexy Money characters 85 | "E0439840": 1114732, # Palestine (region) 86 | "E0655951": 607797, # Miami Herald 87 | "E0681609": 7761399, # Chad Johnson 88 | "E0233160": 27169389, # Ronald Reagan UCLA Medical Center 89 | "E0465278": 7554772, # Randalls 90 | "E0435757": 2118244, # Bago, Myanmar 91 | "E0194326": 14141082, # Belmond Limited 92 | "E0029703": 30858216, # Aaj News 93 | "E0071026": 27885464, # Public Security Police Force of Macau 94 | "E0513036": 14331070, # Senvion 95 | "E03912200": None, # Nepal Cable Television Association 96 | "E0436955": None, # PAS 97 | } 98 | 99 | labels = {} 100 | with open(args.train_mentions_filename, "r") as fin: 101 | lines = fin.readlines() 102 | for line in lines: 103 | data = json.loads(line) 104 | label_id = str(data["label_id"]).strip() 105 | if label_id not in labels: 106 | labels[label_id] = False 107 | 108 | with open(args.test_mentions_filename, "r") as fin: 109 | lines = fin.readlines() 110 | for line in lines: 111 | data = json.loads(line) 112 | label_id = str(data["label_id"]).strip() 113 | if label_id not in labels: 114 | labels[label_id] = False 115 | 116 | print("labels:", len(labels)) 117 | missing_pages = 0 118 | with open(args.test_entities_filename, "r") as fin: 119 | lines = fin.readlines() 120 | for line in tqdm(lines): 121 | entity = json.loads(line) 122 | title = entity["title"] 123 | kb_idx = str(entity["kb_idx"]).strip() 124 | 125 | if kb_idx in labels: 126 | labels[kb_idx] = True 127 | title = title.replace("&", "&") 128 | page = ks.get_page_by_title(title) 129 | if page: 130 | kb2id[kb_idx] = page["wikipedia_id"] 131 | else: 132 | missing_pages += 1 133 | 134 | c = 0 135 | for label, found in labels.items(): 136 | if not found: 137 | if ( 138 | label in manual_labels_correspondance 139 | and manual_labels_correspondance[label] 140 | ): 141 | kb2id[label] = manual_labels_correspondance[label] 142 | else: 143 | c += 1 144 | print(f"missing {c}/{len(labels)} labels in ks") 145 | 146 | for idx, filename in enumerate( 147 | [args.test_mentions_filename, args.train_mentions_filename] 148 | ): 149 | kilt_records = [] 150 | missing = 0 151 | with open(filename, "r") as fin: 152 | lines = fin.readlines() 153 | for line in lines: 154 | data = json.loads(line) 155 | label_id = str(data["label_id"]).strip() 156 | if label_id in kb2id: 157 | wikipedia_id = kb2id[label_id] 158 | page = ks.get_page_by_id(wikipedia_id) 159 | 160 | input_text = ( 161 | str(data["context_left"]).strip() 162 | + " " 163 | + ent_start_token 164 | + " " 165 | + str(data["mention"]).strip() 166 | + " " 167 | + ent_end_token 168 | + " " 169 | + str(data["context_right"]).strip() 170 | ) 171 | 172 | # rename 173 | data["left_context"] = data.pop("context_left") 174 | data["right_context"] = data.pop("context_right") 175 | 176 | kilt_records.append( 177 | { 178 | "id": data["query_id"], 179 | "input": input_text, 180 | "output": [ 181 | { 182 | "answer": page["wikipedia_title"], 183 | "provenance": [ 184 | { 185 | "wikipedia_id": wikipedia_id, 186 | "title": page["wikipedia_title"], 187 | } 188 | ], 189 | } 190 | ], 191 | "meta": data, 192 | } 193 | ) 194 | else: 195 | missing += 1 196 | 197 | if idx == 1: 198 | print("missing {}/{} points in train".format(missing, len(lines))) 199 | write_output(args.out_train, kilt_records) 200 | elif idx == 0: 201 | print("missing {}/{} points in test".format(missing, len(lines))) 202 | write_output(args.out_test, kilt_records) 203 | else: 204 | print("ERROR") 205 | -------------------------------------------------------------------------------- /kilt/datasets/hotpotqa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import sys 9 | import os 10 | 11 | import spacy 12 | import pprint 13 | 14 | import kilt.kilt_utils as utils 15 | from kilt.datasets.base_dataset import Dataset 16 | from kilt.datasets.hotpotqa_ks import load_ks 17 | 18 | 19 | class HotpotQADataset(Dataset): 20 | def __init__( 21 | self, 22 | name, 23 | input_file, 24 | output_file, 25 | log_file, 26 | ks_directory, 27 | get_only_original_evidence, 28 | max_chunks=None, 29 | debug=False, 30 | ): 31 | super().__init__(name) 32 | self.input_file = input_file 33 | self.output_file = output_file 34 | self.log_file = log_file 35 | self.hotpotqa_ks = load_ks(ks_directory, verbose=True) 36 | self.nlp = spacy.load("en_core_web_sm") 37 | self.max_chunks = max_chunks 38 | self.debug = debug 39 | self.get_only_original_evidence = get_only_original_evidence 40 | 41 | def get_chunks(self, num_chunks): 42 | all_data = [] 43 | with open(self.input_file, "r") as fin: 44 | lines = fin.readlines() 45 | assert len(lines) == 1 46 | line = lines[0] 47 | all_data = json.loads(line) 48 | 49 | n = len(all_data) 50 | print("{} examples in the dataset".format(n)) 51 | return utils.chunk_it(all_data, num_chunks) 52 | 53 | def process_chunk(self, chunk, ks, chunk_id=-1): 54 | 55 | missing_pages = 0.0 56 | exact_match = 0.0 57 | fuzzy_match = 0.0 58 | n = len(chunk) 59 | kilt_data = [] 60 | metadata = [] 61 | for idx, datapoint in enumerate(chunk): 62 | print( 63 | "t: {}, p: {:.2f} %, mp: {:.1f}, exact: {:.1f}, fuzzy: {:.1f}".format( 64 | chunk_id, 65 | round(idx * 100 / n, 2), 66 | missing_pages, 67 | exact_match, 68 | fuzzy_match, 69 | ), 70 | end="\r", 71 | ) 72 | sys.stdout.flush() 73 | 74 | kilt_record = { 75 | # original data point id if available otherwise unique id 76 | "id": datapoint["_id"], 77 | # question / claim / sentence 78 | "input": datapoint["question"], 79 | # dataset/task specific 80 | "meta": {"level": datapoint["level"], "type": datapoint["type"],}, 81 | } 82 | kilt_record_provenance = [] 83 | 84 | local_missing_page = False 85 | local_exact_match = True 86 | for evidence in datapoint["supporting_facts"]: 87 | title = evidence[0] 88 | sent_id = evidence[1] 89 | text = "" 90 | try: 91 | text = self.hotpotqa_ks[title]["text"][sent_id] 92 | except IndexError as e: 93 | print( 94 | "\nIndexError: {}\ntitle:{}\nsent_id:{}\n".format( 95 | e, title, sent_id 96 | ) 97 | ) 98 | 99 | if self.get_only_original_evidence: 100 | kilt_record_provenance.append( 101 | {"text": text, "title": title, "sent_id": sent_id} 102 | ) 103 | 104 | else: 105 | pages = ks.get_pages_by_title(title) 106 | if len(pages) == 0: 107 | local_missing_page = True 108 | break 109 | 110 | bleu = -1 111 | paragraph_id = -1 112 | start_character = -1 113 | end_character = -1 114 | for page in pages: 115 | # it is unlikely, but there could be multiple pages for a title (e.g., disambiguation) 116 | if text and len(text) > 0: 117 | ( 118 | local_paragraph_id, 119 | local_start_character, 120 | local_end_character, 121 | local_bleu, 122 | ) = utils.match_answer( 123 | text, page, nlp=self.nlp, debug=False 124 | ) 125 | 126 | if local_bleu > bleu: 127 | paragraph_id = local_paragraph_id 128 | start_character = local_start_character 129 | end_character = local_end_character 130 | bleu = local_bleu 131 | 132 | if bleu != 1.0: 133 | local_exact_match = False 134 | 135 | kilt_record_provenance.append( 136 | # list of relevant WikipediaPages / Spans as provenance for the answer from the ks 137 | { 138 | "wikipedia_id": page[ 139 | "wikipedia_id" 140 | ], # *mandatory* - ID Wikipedia Page 141 | "title": page[ 142 | "wikipedia_title" 143 | ], # *mandatory* - Title Wikipedia Page 144 | "start_paragraph_id": paragraph_id, # start paragraph id with relevant info 145 | "start_character": start_character, 146 | "end_paragraph_id": paragraph_id, # end paragraph id 147 | "end_character": end_character, 148 | "bleu_score": bleu, # 1.0 when gold data is exactly matched, lower for fuzzy matches 149 | } 150 | ) 151 | 152 | if local_missing_page: 153 | missing_pages += 1 154 | continue 155 | if local_exact_match: 156 | exact_match += 1 157 | else: 158 | fuzzy_match += 1 159 | 160 | kilt_record["output"] = [ 161 | {"answer": datapoint["answer"], "provenance": kilt_record_provenance} 162 | ] 163 | kilt_data.append(kilt_record) 164 | 165 | if self.debug: 166 | pp = pprint.PrettyPrinter(indent=4) 167 | print("original datapoint:") 168 | pp.pprint(datapoint) 169 | input("...") 170 | print("kilt record:") 171 | pp.pprint(kilt_record) 172 | input("...") 173 | 174 | metadata = [missing_pages, exact_match, fuzzy_match] 175 | return kilt_data, metadata 176 | 177 | def postprocess_metadata(self, metadata): 178 | missing_pages = 0.0 179 | short_exact_match = 0.0 180 | short_fuzzy_match = 0.0 181 | for met in metadata: 182 | if met == []: 183 | continue 184 | mp, sem, sfm = met 185 | missing_pages += mp 186 | short_exact_match += sem 187 | short_fuzzy_match += sfm 188 | 189 | print("Print stats") 190 | msg = "\n n: {:.1f}, missing pages: {:.1f}, short exact match: {:.1f}, short fuzzy match: {:.1f}".format( 191 | 0, missing_pages, short_exact_match, short_fuzzy_match 192 | ) 193 | print(msg) 194 | 195 | f = open(self.log_file, "w+") 196 | f.write(msg) 197 | f.close() 198 | -------------------------------------------------------------------------------- /kilt/datasets/entity_linking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import random 9 | import sys 10 | import uuid 11 | 12 | import uuid 13 | from tqdm import tqdm 14 | 15 | import kilt.kilt_utils as utils 16 | from kilt.knowledge_source import KnowledgeSource 17 | from kilt.datasets.base_dataset import Dataset 18 | 19 | 20 | def convert_to_KILT_format( 21 | questions, 22 | ks, 23 | id_filter_positive, 24 | id_filter_negative, 25 | max_input_lenght=256, 26 | ent_start_token="[START_ENT]", 27 | ent_end_token="[END_ENT]", 28 | ): 29 | data = [] 30 | for q in questions: 31 | 32 | if id_filter_positive: 33 | if id_filter_positive not in q["id"]: 34 | continue 35 | 36 | if id_filter_negative: 37 | if id_filter_negative in q["id"]: 38 | continue 39 | 40 | page = ks.get_page_from_url(q["Wikipedia_URL"]) 41 | 42 | if page: 43 | left_context = q["left_context"].copy() 44 | right_context = q["right_context"].copy() 45 | 46 | left = " ".join(left_context).strip() 47 | text_mention = q["mention"].strip() 48 | right = " ".join(right_context).strip() 49 | 50 | # create input text 51 | # balance left and right context 52 | input_text = ( 53 | left 54 | + " " 55 | + ent_start_token 56 | + " " 57 | + text_mention 58 | + " " 59 | + ent_end_token 60 | + " " 61 | + right 62 | ) 63 | tokens = input_text.split() 64 | while ( 65 | len(tokens) >= max_input_lenght - 2 66 | ): # 2 = ent_start_token + ent_end_token 67 | offset = max(1, int((len(tokens) - max_input_lenght) / 2)) 68 | len_left = len(left.split()) 69 | len_right = len(right.split()) 70 | if len_left > len_right: 71 | left_context = left_context[offset:] 72 | left = " ".join(left_context).strip() 73 | else: 74 | right_context = right_context[:offset] 75 | right = " ".join(right_context).strip() 76 | # udpate tokens 77 | input_text = ( 78 | left 79 | + " " 80 | + ent_start_token 81 | + " " 82 | + text_mention 83 | + " " 84 | + ent_end_token 85 | + " " 86 | + right 87 | ) 88 | tokens = input_text.split() 89 | 90 | datapoint = { 91 | "id": str(uuid.uuid4()) + "_" + str(q["id"]), 92 | "input": input_text, 93 | "output": [ 94 | { 95 | "answer": page["wikipedia_title"], 96 | "provenance": [ 97 | # list of relevant WikipediaPages / Spans as provenance for the answer from the ks 98 | { 99 | "wikipedia_id": page["wikipedia_id"], 100 | "title": page["wikipedia_title"], 101 | } 102 | ], 103 | } 104 | ], 105 | "meta": { 106 | "left_context": " ".join(q["left_context"]).strip(), 107 | "mention": text_mention, 108 | "right_context": " ".join(q["right_context"]).strip(), 109 | }, # dataset/task specific 110 | } 111 | data.append(datapoint) 112 | return data 113 | 114 | 115 | class EntityLinkingDataset(Dataset): 116 | def __init__( 117 | self, 118 | name, 119 | input_file, 120 | output_file, 121 | id_filter_positive, 122 | id_filter_negative, 123 | max_chunks, 124 | ): 125 | super().__init__(name) 126 | self.input_file = input_file 127 | self.output_file = output_file 128 | self.ks = KnowledgeSource() 129 | self.id_filter_positive = id_filter_positive 130 | self.id_filter_negative = id_filter_negative 131 | self.max_chunks = max_chunks 132 | 133 | def get_chunks(self, num_chunks): 134 | 135 | data = [] 136 | with open(self.input_file, "r") as fin: 137 | data = fin.readlines() 138 | 139 | # a single chunk for entity linking 140 | return [data] 141 | 142 | def process_chunk(self, lines, ks, chunk_id=-1): 143 | 144 | kilt_records = [] 145 | 146 | # left context so far in the document 147 | left_context = [] 148 | 149 | # working datapoints for the document 150 | document_questions = [] 151 | 152 | # is the entity open 153 | open_entity = False 154 | 155 | # question id in the document 156 | question_i = 0 157 | 158 | for line in tqdm(lines): 159 | 160 | if "-DOCSTART-" in line: 161 | # new document is starting 162 | 163 | doc_id = line.split("(")[-1][:-2] 164 | 165 | # END DOCUMENT 166 | 167 | # check end of entity 168 | if open_entity: 169 | open_entity = False 170 | 171 | """ 172 | #DEBUG 173 | for q in document_questions: 174 | pp.pprint(q) 175 | input("...") 176 | """ 177 | 178 | # add sentence_questions to kilt_records 179 | kilt_records.extend( 180 | convert_to_KILT_format( 181 | document_questions, 182 | self.ks, 183 | self.id_filter_positive, 184 | self.id_filter_negative, 185 | ) 186 | ) 187 | 188 | # reset 189 | left_context = [] 190 | document_questions = [] 191 | question_i = 0 192 | 193 | else: 194 | split = line.split("\t") 195 | token = split[0].strip() 196 | 197 | if len(split) >= 5: 198 | B_I = split[1] 199 | mention = split[2] 200 | #  YAGO2_entity = split[3] 201 | Wikipedia_URL = split[4] 202 | Wikipedia_ID = split[5] 203 | # Freee_base_id = split[6] 204 | 205 | if B_I == "I": 206 | pass 207 | 208 | elif B_I == "B": 209 | 210 | q = { 211 | "id": "{}:{}".format(doc_id, question_i), 212 | "mention": mention, 213 | "Wikipedia_URL": Wikipedia_URL, 214 | "Wikipedia_ID": Wikipedia_ID, 215 | "left_context": left_context.copy(), 216 | "right_context": [], 217 | } 218 | document_questions.append(q) 219 | open_entity = True 220 | question_i += 1 221 | 222 | else: 223 | print("Invalid B_I {}", format(B_I)) 224 | sys.exit(-1) 225 | 226 | # print(token,B_I,mention,Wikipedia_URL,Wikipedia_ID) 227 | else: 228 | if open_entity: 229 | open_entity = False 230 | 231 | left_context.append(token) 232 | 233 | for q in document_questions[:-1]: 234 | q["right_context"].append(token) 235 | 236 | if len(document_questions) > 0 and not open_entity: 237 | document_questions[-1]["right_context"].append(token) 238 | 239 | # FINAL SENTENCE 240 | if open_entity: 241 | open_entity = False 242 | 243 | # add sentence_questions to kilt_records 244 | kilt_records.extend( 245 | convert_to_KILT_format( 246 | document_questions, 247 | self.ks, 248 | self.id_filter_positive, 249 | self.id_filter_negative, 250 | ) 251 | ) 252 | 253 | return kilt_records, [] # no metadata 254 | 255 | def postprocess_metadata(self, metadata): 256 | pass 257 | -------------------------------------------------------------------------------- /kilt/eval_downstream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import pprint 10 | import re 11 | import string 12 | from rouge import Rouge 13 | 14 | from collections import Counter 15 | 16 | import kilt.eval_retrieval as retrieval_metrics 17 | from kilt import kilt_utils 18 | 19 | # utility to get gold answers 20 | def get_gold_answers(gold): 21 | ground_truths = set() 22 | for item in gold["output"]: 23 | if "answer" in item and item["answer"] and len(item["answer"].strip()) > 0: 24 | ground_truths.add(item["answer"].strip()) 25 | return ground_truths 26 | 27 | 28 | # utility to get gold titles 29 | def get_gold_titles(gold): 30 | titles = set() 31 | for item in gold["output"]: 32 | if "provenance" in item: 33 | for provenance in item["provenance"]: 34 | if ( 35 | "title" in provenance 36 | and provenance["title"] 37 | and len(provenance["title"].strip()) > 0 38 | ): 39 | titles.add(provenance["title"].strip()) 40 | return titles 41 | 42 | 43 | # utility to get max 44 | def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 45 | scores_for_ground_truths = [] 46 | for ground_truth in ground_truths: 47 | score = metric_fn(prediction, ground_truth) 48 | scores_for_ground_truths.append(score) 49 | return max(scores_for_ground_truths) 50 | 51 | 52 | # answer nomalization 53 | def normalize_answer(s): 54 | """Lower text and remove punctuation, articles and extra whitespace.""" 55 | 56 | def remove_articles(text): 57 | return re.sub(r"\b(a|an|the)\b", " ", text) 58 | 59 | def white_space_fix(text): 60 | return " ".join(text.split()) 61 | 62 | def remove_punc(text): 63 | exclude = set(string.punctuation) 64 | return "".join(ch for ch in text if ch not in exclude) 65 | 66 | def lower(text): 67 | return text.lower() 68 | 69 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 70 | 71 | 72 | # F1 score definition 73 | def _f1_score(prediction, ground_truth): 74 | prediction_tokens = normalize_answer(prediction).split() 75 | ground_truth_tokens = normalize_answer(ground_truth).split() 76 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 77 | num_same = sum(common.values()) 78 | if num_same == 0: 79 | return 0 80 | precision = 1.0 * num_same / len(prediction_tokens) 81 | recall = 1.0 * num_same / len(ground_truth_tokens) 82 | f1 = (2 * precision * recall) / (precision + recall) 83 | return f1 84 | 85 | 86 | # EM score definition 87 | def _exact_match_score(prediction, ground_truth): 88 | return normalize_answer(prediction) == normalize_answer(ground_truth) 89 | 90 | 91 | # ROUGEL score definition 92 | def _rougel_score(prediction, ground_truth): 93 | rouge = Rouge() 94 | # no normalization 95 | try: 96 | scores = rouge.get_scores(prediction, ground_truth, avg=True) 97 | except ValueError: # "Hypothesis is empty." 98 | return 0.0 99 | return scores["rouge-l"]["f"] 100 | 101 | 102 | def _calculate_metrics(gold_records, guess_records): 103 | 104 | assert len(gold_records) == len( 105 | guess_records 106 | ), "different size gold: {} guess: {}".format(len(gold_records), len(guess_records)) 107 | 108 | total_count = 0 109 | 110 | # downstream metrics 111 | accuracy = 0 112 | normalized_em = 0 113 | normalized_f1 = 0 114 | rougel = 0 115 | 116 | # kilt metrics 117 | kilt_accuracy = 0 118 | kilt_em = 0 119 | kilt_f1 = 0 120 | kilt_rougel = 0 121 | 122 | for guess_item, gold_item in zip(guess_records, gold_records): 123 | 124 | # check ids 125 | assert ( 126 | str(gold_item["id"]).strip() == str(guess_item["id"]).strip() 127 | ), "Items must have same order with same IDs" 128 | 129 | total_count += 1 130 | # check if each output of guess file exist in set of candidate answers 131 | gold_candidate_answers = get_gold_answers(gold_item) 132 | 133 | conditions = (len(guess_item["output"]) == 1) and ( 134 | "answer" in guess_item["output"][0] 135 | ) 136 | assert ( 137 | conditions 138 | ), f"you should provide exactly one valid answer for {guess_item['id']}" 139 | guess_answer = str(guess_item["output"][0]["answer"]).strip() 140 | 141 | if len(guess_answer) == 0: 142 | # empty answer 143 | continue 144 | 145 | # 0. accuracy = strict exact match 146 | local_accuracy = 0 147 | if guess_answer in gold_candidate_answers: 148 | local_accuracy = 1 149 | accuracy += local_accuracy 150 | 151 | # 1. normalized exact match 152 | local_em = _metric_max_over_ground_truths( 153 | _exact_match_score, guess_answer, gold_candidate_answers 154 | ) 155 | normalized_em += local_em 156 | 157 | # 2. normalized f1 158 | local_f1 = _metric_max_over_ground_truths( 159 | _f1_score, guess_answer, gold_candidate_answers 160 | ) 161 | normalized_f1 += local_f1 162 | 163 | # 3. rougel 164 | local_rougel = _metric_max_over_ground_truths( 165 | _rougel_score, guess_answer, gold_candidate_answers 166 | ) 167 | rougel += local_rougel 168 | 169 | # KILT-metrics 170 | Rprec = retrieval_metrics.rprecision( 171 | guess_item, gold_item, rank_keys=["wikipedia_id"] 172 | ) 173 | if Rprec == 1: 174 | # 1. KILT-AC 175 | kilt_accuracy += local_accuracy 176 | 177 | # 2. KILT-EM 178 | kilt_em += local_em 179 | 180 | # 3. KILT-F1 181 | kilt_f1 += local_f1 182 | 183 | # 4. KILT-RL 184 | kilt_rougel += local_rougel 185 | 186 | if total_count > 0: 187 | accuracy /= total_count 188 | normalized_em /= total_count 189 | normalized_f1 /= total_count 190 | rougel /= total_count 191 | kilt_accuracy /= total_count 192 | kilt_em /= total_count 193 | kilt_f1 /= total_count 194 | kilt_rougel /= total_count 195 | 196 | return { 197 | "kilt": { 198 | "KILT-accuracy": kilt_accuracy, 199 | "KILT-em": kilt_em, 200 | "KILT-f1": kilt_f1, 201 | "KILT-rougel": kilt_rougel, 202 | }, 203 | "downstream": { 204 | "accuracy": accuracy, 205 | "em": normalized_em, 206 | "f1": normalized_f1, 207 | "rougel": rougel, 208 | }, 209 | } 210 | 211 | 212 | def validate_input(gold_records, guess_records): 213 | 214 | if len(gold_records) != len(guess_records): 215 | print( 216 | "WARNING: DIFFERENT SIZE gold: {} guess: {}".format( 217 | len(gold_records), len(guess_records) 218 | ) 219 | ) 220 | 221 | # align order 222 | gold_ids = [] 223 | for gold in gold_records: 224 | assert str(gold["id"]).strip() not in gold_ids, "Gold IDs should be unique" 225 | gold_ids.append(str(gold["id"]).strip()) 226 | 227 | id2guess_record = {} 228 | for guess in guess_records: 229 | assert ( 230 | str(guess["id"]).strip() not in id2guess_record 231 | ), "Prediction IDs should be unique" 232 | id2guess_record[str(guess["id"]).strip()] = guess 233 | 234 | guess_records = [] 235 | for id in gold_ids: 236 | if id in id2guess_record: 237 | guess_records.append(id2guess_record[id]) 238 | else: 239 | raise ValueError("ERROR: no prediction provided for id: {}".format(id)) 240 | 241 | return gold_records, guess_records 242 | 243 | 244 | def evaluate(gold, guess): 245 | pp = pprint.PrettyPrinter(indent=4) 246 | 247 | gold_records = kilt_utils.load_data(gold) 248 | guess_records = kilt_utils.load_data(guess) 249 | 250 | # 0. validate input 251 | gold_records, guess_records = validate_input(gold_records, guess_records) 252 | 253 | # 1. downstream + kilt 254 | result = _calculate_metrics(gold_records, guess_records) 255 | 256 | # 2. retrieval performance 257 | retrieval_results = retrieval_metrics.compute( 258 | gold_records, guess_records, ks=[1, 5], rank_keys=["wikipedia_id"] 259 | ) 260 | result["retrieval"] = { 261 | "Rprec": retrieval_results["Rprec"], 262 | "recall@5": retrieval_results["recall@5"], 263 | } 264 | 265 | pp.pprint(result) 266 | return result 267 | 268 | 269 | if __name__ == "__main__": 270 | parser = argparse.ArgumentParser() 271 | parser.add_argument("guess", help="Guess KILT file") 272 | parser.add_argument("gold", help="Gold KILT file") 273 | 274 | args = parser.parse_args() 275 | evaluate(args.gold, args.guess) 276 | -------------------------------------------------------------------------------- /scripts/create_kilt_data_paragraphs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import multiprocessing 9 | from multiprocessing.pool import ThreadPool 10 | import sys 11 | import argparse 12 | import pickle 13 | import json 14 | import os 15 | import spacy 16 | from tqdm import tqdm, trange 17 | 18 | import kilt.kilt_utils as utils 19 | from kilt.knowledge_source import KnowledgeSource 20 | 21 | 22 | def create_chunk(document, buffer, paragraph_id, paragraph, section): 23 | start = buffer[0].idx 24 | end = buffer[-1].idx + len(buffer[-1]) 25 | 26 | anchors = [ 27 | { 28 | "text": anchor["text"], 29 | "href": anchor["href"], 30 | "source": { 31 | "paragraph_id": anchor["paragraph_id"], 32 | "start": anchor["start"], 33 | "end": anchor["end"], 34 | }, 35 | "start": anchor["start"] - start, 36 | "end": anchor["end"] - start, 37 | } 38 | for anchor in document["anchors"] 39 | if anchor["paragraph_id"] == paragraph_id 40 | and anchor["start"] >= start 41 | and anchor["end"] <= end 42 | ] 43 | 44 | return { 45 | "_id": document["_id"], 46 | "wikipedia_id": document["wikipedia_id"], 47 | "wikipedia_title": document["wikipedia_title"], 48 | "text": paragraph.text[start : end + 1].strip(), 49 | "tmp_len": len(buffer), 50 | "anchors": anchors, 51 | "categories": document["categories"], 52 | "history": document["history"], 53 | "sources": [{"paragraph_id": paragraph_id, "start": start, "end": end,}], 54 | "section": section, 55 | } 56 | 57 | 58 | def run_thread(args): 59 | documents = args["documents"] 60 | nlp = args["nlp"] 61 | id = args["id"] 62 | rank = args["rank"] 63 | chunk_size = args["chunk_size"] 64 | 65 | if id == 0 and rank == 0: 66 | iter_ = tqdm(documents) 67 | else: 68 | iter_ = documents 69 | 70 | # initialization 71 | output = [] 72 | 73 | for document in iter_: 74 | 75 | # initialization 76 | buffer = [] 77 | section = "Section::::Abstract" 78 | 79 | # loop paragrpahs removing first (title) 80 | for paragraph_id, paragraph in enumerate(nlp.pipe(document["text"][1:]), 1): 81 | 82 | # if section then save name and move on 83 | if "Section::::" in paragraph.text: 84 | section = paragraph.text.strip() 85 | continue 86 | 87 | for sentence in paragraph.sents: 88 | if buffer and len(buffer) + len(sentence) >= chunk_size: 89 | # create new chunk 90 | new_chunk = create_chunk( 91 | document, buffer, paragraph_id, paragraph, section 92 | ) 93 | output.append(new_chunk) 94 | buffer = [] 95 | 96 | for token in sentence: 97 | word = token.text.strip() 98 | if word and len(word) > 0: 99 | buffer.append(token) 100 | 101 | if buffer: 102 | # create new chunk 103 | new_chunk = create_chunk( 104 | document, buffer, paragraph_id, paragraph, section 105 | ) 106 | 107 | # conditions on merging with previous chunk 108 | if ( 109 | output 110 | and document["wikipedia_id"] == output[-1]["wikipedia_id"] 111 | and section == output[-1]["section"] 112 | and len(buffer) + output[-1]["tmp_len"] < chunk_size 113 | ): 114 | 115 | # adjusting anchors offsets 116 | for anchor in new_chunk["anchors"]: 117 | anchor["start"] += len(output[-1]["text"]) + 1 118 | anchor["end"] += len(output[-1]["text"]) + 1 119 | 120 | # appending new data 121 | output[-1]["text"] += " " + new_chunk["text"] 122 | output[-1]["anchors"] += new_chunk["anchors"] 123 | output[-1]["sources"] += new_chunk["sources"] 124 | output[-1]["tmp_len"] += new_chunk["tmp_len"] + 1 125 | else: 126 | output.append(new_chunk) 127 | buffer = [] 128 | 129 | for out in output: 130 | del out["tmp_len"] 131 | 132 | return output 133 | 134 | 135 | def store_chunks(documents, num_threads, folder): 136 | for id, chunk in enumerate(utils.chunk_it(documents, num_threads)): 137 | out_filename = os.path.join(folder, "documents_{}.p".format(id)) 138 | pickle.dump(chunk, open(out_filename, "wb")) 139 | 140 | 141 | def load_chunk(id, folder): 142 | in_filename = os.path.join(folder, "documents_{}.p".format(id)) 143 | return pickle.load(open(in_filename, "rb")) 144 | 145 | 146 | def load_all_documents_from_ks(cursor, steps, n): 147 | documents = [] 148 | j = 0 149 | for document in cursor: 150 | if j % steps == 0: 151 | sys.stdout.write("{}/{} \r".format(j, n)) 152 | sys.stdout.flush() 153 | documents.append(document) 154 | j += 1 155 | return documents 156 | 157 | 158 | def preprocess_data(num_threads, folder): 159 | 160 | ks = KnowledgeSource() 161 | n = ks.get_num_pages() 162 | steps = int(n / 100) 163 | 164 | cursor = ks.get_all_pages_cursor() 165 | 166 | print("LOADING ALL DOCUMENTS", flush=True) 167 | ducuments = load_all_documents_from_ks(cursor, steps, n) 168 | store_chunks(ducuments, num_threads, folder) 169 | 170 | 171 | def main(rank, num_threads, folder, chunk_size): 172 | 173 | print("loading chunk {}".format(rank), flush=True) 174 | documents = load_chunk(rank, folder) 175 | 176 | arguments = [ 177 | { 178 | "rank": rank, 179 | "id": id, 180 | "documents": chunk, 181 | "nlp": spacy.load("en_core_web_sm"), 182 | "chunk_size": chunk_size, 183 | } 184 | for id, chunk in enumerate(utils.chunk_it(documents, num_threads)) 185 | ] 186 | 187 | print("starting {} threads in {}".format(num_threads, rank)) 188 | pool = ThreadPool(num_threads) 189 | results = pool.map(run_thread, arguments) 190 | 191 | f = open(os.path.join(folder, "kilt_{}.jsonl".format(rank)), "w+",) 192 | 193 | i = 1 194 | for output in results: 195 | for msg in output: 196 | f.write("{}\t{}\n".format(i, json.dumps(msg))) 197 | i += 1 198 | f.close() 199 | pool.terminate() 200 | pool.join() 201 | print("done {}".format(rank)) 202 | 203 | 204 | def merge_files(num_threads, folder): 205 | 206 | f = open(os.path.join(folder, "kilt.jsonl"), "w+") 207 | i = 1 208 | for rank in trange(num_threads): 209 | filename = os.path.join(folder, "kilt_{}.jsonl".format(rank)) 210 | print("reading {}".format(filename), flush=True) 211 | with open(filename, "r") as fin: 212 | lines = fin.readlines() 213 | for line in tqdm(lines): 214 | elements = line.split("\t") 215 | if len(elements) != 2: 216 | print( 217 | "ERROR: len(elements)!=2 -> {}".format(len(elements)), 218 | flush=True, 219 | ) 220 | else: 221 | f.write("{}\t{}\n".format(i, elements[1].strip())) 222 | i += 1 223 | f.close() 224 | print("done") 225 | 226 | 227 | if __name__ == "__main__": 228 | parser = argparse.ArgumentParser() 229 | 230 | parser.add_argument( 231 | "--step", 232 | type=str, 233 | choices=["preprocess", "main", "merge"], 234 | help="step to exectue", 235 | ) 236 | 237 | parser.add_argument( 238 | "--chunk_size", default=100, type=int, help="chunk max token size", 239 | ) 240 | 241 | parser.add_argument( 242 | "--folder", type=str, help="path where to save and load files", 243 | ) 244 | 245 | parser.add_argument( 246 | "--rank", default=None, type=int, help="rank in a distributed execution", 247 | ) 248 | 249 | parser.add_argument( 250 | "--threads", default=None, type=int, help="number of threads", 251 | ) 252 | 253 | args = parser.parse_args() 254 | 255 | if args.threads == None: 256 | args.threads = int(multiprocessing.cpu_count()) 257 | 258 | # step 1 259 | if args.step == "preprocess": 260 | preprocess_data(num_threads=args.threads, folder=args.folder) 261 | # step 2 262 | elif args.step == "main": 263 | main( 264 | rank=args.rank, 265 | num_threads=args.threads, 266 | folder=args.folder, 267 | chunk_size=args.chunk_size, 268 | ) 269 | # step 3 270 | elif args.step == "merge": 271 | merge_files(num_threads=args.threads, folder=args.folder) 272 | -------------------------------------------------------------------------------- /kilt/readers/t5/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import configparser 9 | import fcntl 10 | import gzip 11 | import json 12 | import os 13 | import pathlib 14 | 15 | import torch.utils.data 16 | from transformers.tokenization_utils import trim_batch 17 | 18 | dataset_task_map = {'nq': "Question Answering", "aidayago2": "Entity Linking", "cweb": "Entity Linking", 19 | "fever": "Fact Checking", "hotpotqa": "Question Answering", 20 | "triviaqa": "Question Answering", "wned": "Entity Linking", "wow": "Dialogue", 21 | "zeroshot": "Relation Extraction", "trex":"Slot Filling", "eli5":"Question Answering"} 22 | 23 | dataset_config = configparser.ConfigParser() 24 | location = os.path.join(pathlib.Path(__file__).parent, 'config_file') 25 | dataset_config.read(location) 26 | 27 | 28 | def encode_seq(tokenizer, seqs, max_length, out_dir, dataset, side='source', type_path='train', pad_to_max_length=True, 29 | return_tensors="pt"): 30 | examples = [] 31 | lengths = [] 32 | 33 | output_file = os.path.join(out_dir, dataset + "-" + type_path + "-" + side + ".encoded") 34 | with open(output_file, "w") as f_out: 35 | texts = [] 36 | for text in seqs: 37 | 38 | if dataset_task_map[dataset] == 'Entity Linking' and side == 'source': 39 | length = int(int(dataset_config[dataset]['source_length']) / 2) 40 | mention_start = text.find('[START_ENT]') 41 | mention_end = text.find('[END_ENT]') 42 | left = text[0:mention_start] 43 | right = text[mention_end + len('[END_ENT]'):] 44 | 45 | left_ids = tokenizer.encode(left) 46 | right_ids = tokenizer.encode(right) 47 | left = tokenizer.decode(left_ids[max(0, len(left_ids) - length):len(left_ids)]) 48 | right = tokenizer.decode(right_ids[0:min(len(right_ids), length)]) 49 | text = left + ' ' + text[mention_start:mention_end] + '[END_ENT] ' + right 50 | 51 | if dataset == 'wow' and side == 'source': 52 | text = text.replace('\n', '[SEP]') 53 | 54 | if dataset == 'fever' and side == 'target': 55 | if text == "REFUTES": 56 | text = "" 57 | if text == "SUPPORTS": 58 | text = "" 59 | 60 | txt = text if side == 'target' else \ 61 | dataset_task_map[dataset] + ": " + text 62 | txt = txt + tokenizer.eos_token 63 | texts.append(txt) 64 | 65 | if dataset == 'wow' and side == 'source': 66 | tokenized = tokenizer.batch_encode_plus( 67 | texts, add_special_tokens=True, max_length=max_length, pad_to_max_length='left', 68 | return_tensors=return_tensors, 69 | ) 70 | else: 71 | tokenized = tokenizer.batch_encode_plus( 72 | texts, add_special_tokens=True, max_length=max_length, pad_to_max_length=pad_to_max_length, 73 | return_tensors=return_tensors, 74 | ) 75 | 76 | #lengths.append(tokenized["input_ids"].size()[1]) 77 | 78 | for input in tokenized["input_ids"]: 79 | tokens = tokenizer.convert_ids_to_tokens(input) 80 | f_out.write(' | '.join(tokens) + "\n") 81 | 82 | 83 | 84 | return tokenized 85 | 86 | 87 | class KiltDataset(torch.utils.data.Dataset): 88 | def __init__( 89 | self, 90 | tokenizer, 91 | data_dir, 92 | dataset, 93 | type_path, 94 | max_source_length, 95 | max_target_length, 96 | output_dir 97 | ): 98 | super().__init__() 99 | self.tokenizer = tokenizer 100 | 101 | self.source = [] 102 | self.target = [] 103 | 104 | # self.ids, raw_sources, raw_targets, self.id_targets = nq_jsonl_to_tsv(data_dir, type_path) 105 | 106 | self.ids, raw_sources, raw_targets, self.id_targets = kilt_to_seq2seq(data_dir, dataset, type_path) 107 | 108 | self.source = encode_seq(tokenizer, raw_sources, max_source_length, output_dir, dataset, 'source', type_path) 109 | self.target = encode_seq(tokenizer, raw_targets, max_target_length, output_dir, dataset, 'target', type_path) 110 | 111 | def __len__(self): 112 | return len(self.source["input_ids"]) 113 | 114 | def __getitem__(self, index): 115 | 116 | source_ids = self.source["input_ids"][index].squeeze() 117 | target_ids = self.target["input_ids"][index].squeeze() 118 | src_mask = self.source["attention_mask"][index].squeeze() 119 | q_id = self.ids[index] 120 | return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "id": q_id} 121 | 122 | @staticmethod 123 | def trim_seq2seq_batch(batch, pad_token_id): 124 | target_ids = trim_batch(batch["target_ids"], pad_token_id) 125 | source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) 126 | return source_ids, source_mask, target_ids 127 | 128 | 129 | def kilt_to_seq2seq(data_dir, dataset, type_path): 130 | data_file = pathlib.Path(os.path.join(data_dir, dataset + '-' + type_path + "-kilt.jsonl")) 131 | sources = [] 132 | targets = [] 133 | ids = [] 134 | id_targets = {} 135 | if not data_file.exists(): 136 | return ids, sources, targets 137 | 138 | with open(data_file, "r") as f: 139 | 140 | for line in f.readlines(): 141 | qa = json.loads(line) 142 | q_id = qa['id'] 143 | question = qa['input'] 144 | output = qa['output'] 145 | if len(output) == 0: 146 | continue 147 | answers = set() 148 | id_targets[q_id] = [] 149 | 150 | for out in output: 151 | if 'answer' not in out.keys(): 152 | continue 153 | 154 | answer = out['answer'] 155 | answers.add(answer) 156 | id_targets[q_id].append(answer) 157 | if type_path == 'test': 158 | sources.append(question) 159 | targets.append(answers.pop()) 160 | ids.append(q_id) 161 | else: 162 | for answer in answers: 163 | sources.append(question) 164 | targets.append(answer) 165 | ids.append(q_id) 166 | return ids, sources, targets, id_targets 167 | 168 | 169 | def seq2seq_to_kilt(ids, sources, targets, output_dir, dataset, type_path): 170 | data_file = os.path.join(output_dir, dataset + '-' + type_path + "-kilt.jsonl") 171 | 172 | with open(data_file, "a+") as output_file: 173 | data = [] 174 | for q_id, s, t in zip(ids, sources, targets): 175 | qa = {"id": q_id, 'input': s, 'output': []} 176 | a = {'answer': t, 'provenance': []} 177 | qa['output'].append(a) 178 | data.append(json.dumps(qa)) 179 | fcntl.flock(output_file, fcntl.LOCK_EX) 180 | if os.stat(data_file).st_size > 0: 181 | output_file.write('\n') 182 | output_file.write('\n'.join(data)) 183 | fcntl.flock(output_file, fcntl.LOCK_UN) 184 | 185 | 186 | def nq_jsonl_to_tsv(data_dir, type_path): 187 | def extract_answer(answer_tokens, span): 188 | """Reconstruct answer from token span and remove extra spaces.""" 189 | start, end = span["start_token"], span["end_token"] 190 | ans = " ".join(answer_tokens[start:end]) 191 | # Remove incorrect spacing around punctuation. 192 | ans = ans.replace(" ,", ",").replace(" .", ".").replace(" %", "%") 193 | ans = ans.replace(" - ", "-").replace(" : ", ":").replace(" / ", "/") 194 | ans = ans.replace("( ", "(").replace(" )", ")") 195 | ans = ans.replace("`` ", "\"").replace(" ''", "\"") 196 | ans = ans.replace(" 's", "'s").replace("s ' ", "s' ") 197 | return ans 198 | 199 | count = 0 200 | ids = [] 201 | sources = [] 202 | targets = [] 203 | id_targets = {} 204 | in_fname = data_dir + '/' + type_path + '.jsonl.gz' 205 | 206 | for line in gzip.open(in_fname, "rb"): 207 | ex = json.loads(line) 208 | 209 | # Remove any examples with more than one answer. 210 | 211 | # Questions in NQ do not include a question mark. 212 | q_id = ex['annotations'][0]['annotation_id'] 213 | question = ex["question_text"] + "?" 214 | answers = [] 215 | for answer_span in ex['annotations'][0]['short_answers']: 216 | tokens = [] 217 | # Handle the two document formats in NQ (tokens or text). 218 | if "document_tokens" in ex: 219 | tokens = [t["token"] for t in ex["document_tokens"]] 220 | elif "document_text" in ex: 221 | tokens = ex["document_text"].split(" ") 222 | answer = extract_answer(tokens, answer_span) 223 | # Write this line as \t 224 | sources.append(question) 225 | targets.append(answer) 226 | answers.append(answer) 227 | ids.append(q_id) 228 | id_targets[q_id] = answers 229 | count += 1 230 | 231 | return ids, sources, targets, id_targets 232 | 233 | 234 | if __name__ == "__main__": 235 | pass 236 | -------------------------------------------------------------------------------- /kilt/datasets/fact_verification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import spacy 9 | import sys 10 | import unicodedata 11 | 12 | 13 | import pprint 14 | 15 | pp = pprint.PrettyPrinter(indent=4) 16 | 17 | import kilt.kilt_utils as utils 18 | from kilt.datasets.base_dataset import Dataset 19 | 20 | 21 | class FactVerificationDataset(Dataset): 22 | def __init__( 23 | self, name, claims_input_file, evidence_directory_path, output_file, log_file 24 | ): 25 | super().__init__(name) 26 | self.claims_input_file = claims_input_file 27 | self.evidence_directory_path = evidence_directory_path 28 | self.output_file = output_file 29 | self.log_file = log_file 30 | self.nlp = spacy.load("en_core_web_sm") 31 | 32 | def _normalize(self, text): 33 | replacements = { 34 | "-LRB-": "(", 35 | "-RRB-": ")", 36 | "-LSB-": "[", 37 | "-RSB-": "]", 38 | "-LCB-": "{", 39 | "-RCB-": "}", 40 | "-COLON-": ":", 41 | } 42 | for key, val in replacements.items(): 43 | text = text.replace(key, val) 44 | return text 45 | 46 | def get_chunks(self, num_chunks): 47 | 48 | # Read claims, create a set of wiki pages to 49 | # find the evidence sentences in 50 | page_to_evidence_sents = {} 51 | 52 | with open(self.claims_input_file, "r") as infile: 53 | for line in infile: 54 | claim = json.loads(line) 55 | 56 | if "verifiable" in claim and claim["verifiable"] == "NOT VERIFIABLE": 57 | continue 58 | 59 | evidence_sets = claim["evidence"] 60 | for evidence_set in evidence_sets: 61 | 62 | for evidence in evidence_set: 63 | if evidence[2]: 64 | page_id = unicodedata.normalize("NFKD", evidence[2]) 65 | else: 66 | # those can be filtered out/ignored. They’re an artefact of merging some of the duplicates where annotators disagreed over the label. 67 | break 68 | 69 | sent_id = int(evidence[3]) 70 | 71 | if page_id not in page_to_evidence_sents: 72 | page_to_evidence_sents[page_id] = {} 73 | 74 | page_to_evidence_sents[page_id][sent_id] = None 75 | 76 | for idx in range(1, 110): 77 | filename = self.evidence_directory_path + f"/wiki-{idx:03}.jsonl" 78 | print(f"processing filename {filename}") 79 | with open(filename, "r") as fin: 80 | for line in fin: 81 | wiki_page = json.loads(line.strip()) 82 | page_id = wiki_page["id"] 83 | if page_id not in page_to_evidence_sents: 84 | continue 85 | lines = wiki_page["lines"].split("\n") 86 | sentences = [] 87 | for l in lines: 88 | line_fields = l.split("\t") 89 | # skip empty sentences 90 | if len(line_fields) < 2 or line_fields[1] == "": 91 | continue 92 | # skip sentences where first element is not number 93 | if not line_fields[0].isdigit(): 94 | continue 95 | 96 | sent_text = line_fields[1] 97 | 98 | # there is no id, so the new line character is 99 | # likely a formatting error, will ignore and 100 | # append the normalized text to the previous 101 | # sentence. 102 | if line_fields[0] == "": 103 | sentences[-1]["text"] += " " + sent_text 104 | else: 105 | sentences.append( 106 | { 107 | "id": line_fields[0], 108 | "text": sent_text, 109 | } 110 | ) 111 | 112 | for sentence in sentences: 113 | sent_id = int(sentence["id"]) 114 | sent_text = sentence["text"] 115 | if sent_id in page_to_evidence_sents[page_id]: 116 | page_to_evidence_sents[page_id][sent_id] = sent_text 117 | 118 | data = [] 119 | for page_id in page_to_evidence_sents: 120 | for sent_id in page_to_evidence_sents[page_id]: 121 | sent_text = page_to_evidence_sents[page_id][sent_id] 122 | data.append( 123 | { 124 | "page_id": page_id, 125 | "sent_id": sent_id, 126 | "text": sent_text, 127 | } 128 | ) 129 | 130 | n = len(data) 131 | print("{} examples in the dataset".format(n)) 132 | return utils.chunk_it(data, num_chunks) 133 | 134 | def process_chunk(self, chunk, ks, chunk_id=-1): 135 | missing_pages = 0.0 136 | exact_match = 0.0 137 | fuzzy_match = 0.0 138 | n = len(chunk) 139 | kilt_data = [] 140 | metadata = [] 141 | 142 | for idx, datapoint in enumerate(chunk): 143 | print( 144 | "t: {}, p: {:.2f} %, mp: {:.1f}, exact: {:.1f}, fuzzy: {:.1f}".format( 145 | chunk_id, 146 | round(idx * 100 / n, 2), 147 | missing_pages, 148 | exact_match, 149 | fuzzy_match, 150 | ), 151 | end="\r", 152 | ) 153 | sys.stdout.flush() 154 | 155 | page_id = datapoint["page_id"] 156 | sent_id = datapoint["sent_id"] 157 | text = datapoint["text"] 158 | 159 | if not text or text == None or len(text) == 0: 160 | continue 161 | 162 | url = "https://en.wikipedia.org/wiki/" + self._normalize( 163 | datapoint["page_id"] 164 | ) 165 | page = ks.get_page_from_url(url) 166 | if not page: 167 | missing_pages += 1 168 | else: 169 | # get and validate evidence sentence 170 | 171 | local_sem = 0.0 172 | local_sfm = 0.0 173 | 174 | kilt_record = { 175 | # original data point id if available otherwise unique id 176 | "page_id": page_id, 177 | "sentence_id": sent_id, 178 | "evidence_text": text, 179 | } 180 | 181 | kilt_record_output = [] 182 | 183 | paragraph_id, start_character, end_character, bleu = utils.match_answer( 184 | text, page, nlp=self.nlp, debug=False 185 | ) 186 | 187 | kilt_record_output.append( 188 | { 189 | # answer in textual form 190 | "answer": text, 191 | "provenance": [ 192 | # list of relevant WikipediaPages / Spans as provenance for the answer from the ks 193 | { 194 | "wikipedia_id": page[ 195 | "wikipedia_id" 196 | ], # *mandatory* - ID Wikipedia Page 197 | "title": page[ 198 | "wikipedia_title" 199 | ], # *mandatory* - Title Wikipedia Page 200 | "start_paragraph_id": paragraph_id, # start paragraph id with relevant info 201 | "start_character": start_character, 202 | "end_paragraph_id": paragraph_id, # end paragraph id 203 | "end_character": end_character, 204 | "bleu_score": bleu, # 1.0 when gold data is exactly matched, lower for fuzzy matches 205 | "meta": { # dataset/task specific 206 | "fever_page_id": page_id, 207 | "fever_sentence_id": sent_id, 208 | }, 209 | } 210 | ], 211 | } 212 | ) 213 | 214 | if bleu == 1: 215 | local_sem += 1 216 | elif bleu < 1 and bleu >= 0: 217 | local_sfm += 1 218 | else: 219 | print("ERROR: invalid bleu: {}".format(bleu)) 220 | sys.exit(-1) 221 | 222 | # update kilt data 223 | kilt_record["output"] = kilt_record_output 224 | kilt_data.append(kilt_record) 225 | 226 | exact_match += local_sem # / len(short_answers) 227 | fuzzy_match += local_sfm # / len(short_answers) 228 | 229 | metadata = [missing_pages, exact_match, fuzzy_match] 230 | 231 | return kilt_data, metadata 232 | 233 | def postprocess_metadata(self, metadata): 234 | missing_pages = 0.0 235 | exact_match = 0.0 236 | fuzzy_match = 0.0 237 | for met in metadata: 238 | if met == []: 239 | continue 240 | mp, sem, sfm = met 241 | missing_pages += mp 242 | exact_match += sem 243 | fuzzy_match += sfm 244 | 245 | print("Print stats") 246 | msg = "\n n: {:.1f}, missing pages: {:.1f}, exact match: {:.1f}, fuzzy match: {:.1f}".format( 247 | 0, missing_pages, exact_match, fuzzy_match 248 | ) 249 | print(msg) 250 | 251 | f = open(self.log_file, "w+") 252 | f.write(msg) 253 | f.close() 254 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![KILT logo](./img/KILT_logo.png) 2 | -------------------------------------------------------------------------------- 3 | 4 | # A Benchmark for Knowledge Intensive Language Tasks 5 | 6 | [http://kiltbenchmark.com/](http://kiltbenchmark.com) 7 | 8 | 9 | The KILT benchmark is described in the following paper: 10 | 11 | ```bibtex 12 | @inproceedings{petroni-etal-2021-kilt, 13 | title = "{KILT}: a Benchmark for Knowledge Intensive Language Tasks", 14 | author = {Petroni, Fabio and Piktus, Aleksandra and 15 | Fan, Angela and Lewis, Patrick and 16 | Yazdani, Majid and De Cao, Nicola and 17 | Thorne, James and Jernite, Yacine and 18 | Karpukhin, Vladimir and Maillard, Jean and 19 | Plachouras, Vassilis and Rockt{\"a}schel, Tim and 20 | Riedel, Sebastian}, 21 | booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association 22 | for Computational Linguistics: Human Language Technologies", 23 | month = jun, 24 | year = "2021", 25 | address = "Online", 26 | publisher = "Association for Computational Linguistics", 27 | url = "https://aclanthology.org/2021.naacl-main.200", 28 | doi = "10.18653/v1/2021.naacl-main.200", 29 | pages = "2523--2544", 30 | } 31 | ``` 32 | 33 | [https://arxiv.org/abs/2009.02252](https://arxiv.org/abs/2009.02252) 34 | 35 | 36 | ## Setup the env 37 | 38 | ```bash 39 | conda create -n kilt37 -y python=3.7 && conda activate kilt37 40 | pip install -e . 41 | ``` 42 | 43 | ## KILT knowledge source 44 | 45 | The KILT knowledge source can be downloaded here: [kilt_knowledgesource.json](http://dl.fbaipublicfiles.com/KILT/kilt_knowledgesource.json) (34.76GiB).
46 | It is based on the [2019/08/01 Wikipedia dump](http://dl.fbaipublicfiles.com/BLINK/enwiki-pages-articles.xml.bz2).
47 | We use [mongoDB](https://www.mongodb.com) to index the knowledge base (but you can use any json-based db).
48 | To import the knowledge source in mongoDB run: 49 | 50 | ```bash 51 | wget http://dl.fbaipublicfiles.com/KILT/kilt_knowledgesource.json 52 | mongoimport --db kilt --collection knowledgesource --file kilt_knowledgesource.json 53 | ``` 54 | 55 | 56 | ### Structure of each record 57 | 58 | ```python 59 | { 60 | 'wikipedia_title': 'Email marketing', 61 | 'wikipedia_id': 1101759, 62 | 'text': ['p1', 'p2',...., 'pn'], # list of paragraph text 63 | 'anchors': [{"text":,"href":,"paragraph_id":,"start":,"end":} ] , 64 | 'categories': 'comma separated list of categories' 65 | 'history': # some info from wikipedia, including original url 66 | 'wikidata_info': # wikidata info 67 | } 68 | ``` 69 | 70 | ### Query the knowledge source 71 | 72 | ```python 73 | from kilt.knowledge_source import KnowledgeSource 74 | 75 | # get the knowledge souce 76 | ks = KnowledgeSource() 77 | 78 | # count entries - 5903530 79 | ks.get_num_pages() 80 | 81 | # get page by id 82 | page = ks.get_page_by_id(27097632) 83 | 84 | # get pages by title 85 | page = ks.get_page_by_title("Michael Jordan") 86 | ``` 87 | 88 | 89 | ## KILT data 90 | 91 | Examples: 92 | ![KILT example](./img/infographic_e.jpg) 93 | 94 | ### download the data 95 | 96 | ```bash 97 | mkdir data 98 | python scripts/download_all_kilt_data.py 99 | python scripts/get_triviaqa_input.py 100 | ``` 101 | 102 | You can also download and use the KILT data through [the HuggingFace's nlp library](https://huggingface.co/datasets?search=kilt). 103 | 104 | Note that we release only the input for the test sets, without answers. 105 | Test answers are used for [the KILT challenge on EvalAI](https://evalai.cloudcv.org/web/challenges/challenge-page/689/overview) where participants can upload their models’ predictions and be listed on the public leaderboard (there are strict submission limits to discourage overfitting on test data). 106 | 107 | ### KILT data format 108 | 109 | ```python 110 | {'id': # original data point id if available otherwise unique id 111 | 'input': # question / claim / sentence / etc 112 | 'output': [ # each element might contain an answer, a provenance or both 113 | { 114 | 'answer': # answer in textual form 115 | 'provenance': [ 116 | # evidence set for the answer from the KILT ks 117 | { 118 | 'wikipedia_id': # *mandatory* 119 | 'title': 120 | 'section': 121 | 'start_paragraph_id': 122 | 'start_character': 123 | 'end_paragraph_id': 124 | 'end_character': 125 | 'bleu_score': # wrt original evidence 126 | 'meta': # dataset/task specific 127 | } 128 | ] 129 | } 130 | ] 131 | 'meta': # dataset/task specific 132 | } 133 | ``` 134 | 135 | ### KILT data catalogue 136 | 137 | | dataset | task | train | dev | test | 138 | | ------------- | ------------- | ------------- | ------------- | ------------- | 139 | | [FEVER](https://fever.ai) | Fact Checking | [fever-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/fever-train-kilt.jsonl)
(104,966 lines, 38.9MiB) | [fever-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/fever-dev-kilt.jsonl)
(10,444 lines, 6.17MiB) | [fever-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/fever-test_without_answers-kilt.jsonl)
(10,100 lines, 839kiB) | 140 | | [AIDA CoNLL-YAGO](https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/research/ambiverse-nlu/aida/downloads) | Entity Linking | [aidayago2-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/aidayago2-train-kilt.jsonl)
(18,395 lines, 70.1MiB) | [aidayago2-dev-kilt.jsonl]( http://dl.fbaipublicfiles.com/KILT/aidayago2-dev-kilt.jsonl)
(4,784 lines, 21.1MiB) | [aidayago2-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/aidayago2-test_without_answers-kilt.jsonl)
(4,463 lines, 14.4MiB) | 141 | | [WNED-WIKI](https://github.com/U-Alberta/wned) | Entity Linking | - | [wned-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/wned-dev-kilt.jsonl)
(3,396 lines, 12.9MiB) | [wned-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/wned-test_without_answers-kilt.jsonl)
(3,376 lines, 13.3MiB) | 142 | | [WNED-CWEB](https://github.com/U-Alberta/wned) | Entity Linking | - | [cweb-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/cweb-dev-kilt.jsonl)
(5,599 lines, 90.2MiB) | [cweb-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/cweb-test_without_answers-kilt.jsonl)
(5,543 lines, 100MiB) | 143 | | [T-REx](https://hadyelsahar.github.io/t-rex) | Slot Filling | [trex-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/trex-train-kilt.jsonl)
(2,284,168 lines, 1.75GiB) | [trex-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/trex-dev-kilt.jsonl)
(5,000 lines, 3.80MiB) | [trex-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/trex-test_without_answers-kilt.jsonl)
(5,000 lines, 896kiB) | 144 | | [Zero-Shot RE](http://nlp.cs.washington.edu/zeroshot) | Slot Filling | [structured_zeroshot-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/structured_zeroshot-train-kilt.jsonl)
(147,909 lines, 71.4MiB) | [structured_zeroshot-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/structured_zeroshot-dev-kilt.jsonl)
(3,724 lines, 2.27MiB) | [structured_zeroshot-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/structured_zeroshot-test_without_answers-kilt.jsonl)
(4,966 lines, 1.22MiB) | 145 | | [Natural Questions](https://ai.google.com/research/NaturalQuestions) | Open Domain QA | [nq-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/nq-train-kilt.jsonl)
(87,372 lines, 51.9MiB) | [nq-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/nq-dev-kilt.jsonl)
(2,837 lines, 7.94MiB) | [nq-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/nq-test_without_answers-kilt.jsonl)
(1,444 lines, 334kiB) | 146 | | [HotpotQA](https://hotpotqa.github.io) | Open Domain QA | [hotpotqa-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/hotpotqa-train-kilt.jsonl)
(88,869 lines, 52.8MiB) | [hotpotqa-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/hotpotqa-dev-kilt.jsonl)
(5,600 lines, 3.97MiB) | [hotpotqa-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/hotpotqa-test_without_answers-kilt.jsonl)
(5,569 lines, 778kiB) | 147 | | [TriviaQA](http://nlp.cs.washington.edu/triviaqa) | Open Domain QA | [triviaqa-train_id-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/triviaqa-train_id-kilt.jsonl)*
(61,844 lines, 102MiB) | [triviaqa-dev_id-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/triviaqa-dev_id-kilt.jsonl)*
(5,359 lines, 9.81MiB) | [triviaqa-test_id_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/triviaqa-test_id_without_answers-kilt.jsonl)*
(6,586 lines, 123kiB) | 148 | | [ELI5](https://facebookresearch.github.io/ELI5/explore.html) | Open Domain QA | [eli5-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/eli5-train-kilt.jsonl)
(272,634 lines, 548MiB) | [eli5-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/eli5-dev-kilt.jsonl)
(1,507 lines, 14.1MiB) | [eli5-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/eli5-test_without_answers-kilt.jsonl)
(600 lines, 99kiB) | 149 | | [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia) | Dialogue | [wow-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/wow-train-kilt.jsonl)
(63,734 lines, 48.9MiB) | [wow-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/wow-dev-kilt.jsonl)
(3,054 lines, 2.42MiB) | [wow-test_without_answers-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/wow-test_without_answers-kilt.jsonl)
(2,944 lines, 1.29MiB)| 150 | 151 | * run `python scripts/get_triviaqa_input.py` to get the question associated with each id 152 | 153 | ### Additional data 154 | 155 | For Entity Linking, in addition to the AIDA CoNLL-YAGO train set, the whole knowledge source can be used as training data by exploiting hyperlinks. To facilitate experimentation, we release such data in KILT format following the splits of [BLINK](https://github.com/facebookresearch/BLINK): 156 | - [blink-train-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/blink-train-kilt.jsonl) (9M lines) 157 | - [blink-dev-kilt.jsonl](http://dl.fbaipublicfiles.com/KILT/blink-dev-kilt.jsonl) (10,000 lines) 158 | 159 | We also provide a [script](scripts/map_TAC-KBP2010_to_KILT.py) to map the TAC-KBP 2010 dataset to the knowledge source and format of KILT. 160 | 161 | ## Run the retrieval evaluation 162 | 163 | Please follow [this README](kilt/retrievers/README.md). 164 | 165 | 166 | ## Mapping scripts 167 | 168 | Mapping scripts are located in `kilt/datasets/`. 169 | See `scripts/map_datasets.py` for an example. 170 | 171 | 172 | ## Troubleshooting 173 | 174 | If the module cannot be found, preface the python command with `PYTHONPATH=.` 175 | 176 | If the experiments fail on GPU memory allocation, try reducing batch size. 177 | 178 | 179 | ## License 180 | KILT is MIT licensed. See the [LICENSE](LICENSE) file for details. 181 | -------------------------------------------------------------------------------- /kilt/datasets/natural_questions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | import json 11 | import spacy 12 | import sys 13 | import re 14 | import kilt.kilt_utils as utils 15 | from kilt.datasets.base_dataset import Dataset 16 | 17 | 18 | class NaturalQuestionsDataset(Dataset): 19 | def __init__(self, name, input_file, output_file, log_file): 20 | super().__init__(name) 21 | self.input_file = input_file 22 | self.output_file = output_file 23 | self.log_file = log_file 24 | self.nlp = spacy.load("en_core_web_sm") 25 | 26 | def get_chunks(self, num_chunks): 27 | all_data = [] 28 | with open(self.input_file, "r") as infile: 29 | for line in infile: 30 | data = json.loads(line) 31 | all_data.append(data) 32 | 33 | n = len(all_data) 34 | print("{} examples in the dataset".format(n)) 35 | return utils.chunk_it(all_data, num_chunks) 36 | 37 | def process_chunk(self, chunk, ks, chunk_id=-1): 38 | missing_pages = 0.0 39 | short_exact_match = 0.0 40 | short_fuzzy_match = 0.0 41 | n = len(chunk) 42 | kilt_data = [] 43 | metadata = [] 44 | 45 | for idx, datapoint in enumerate(chunk): 46 | 47 | # from standard to simplified format 48 | if "document_text" not in datapoint: 49 | # wget https://raw.githubusercontent.com/google-research-datasets/natural-questions/master/text_utils.py 50 | from text_utils import simplify_nq_example 51 | datapoint = simplify_nq_example(datapoint) 52 | 53 | print( 54 | "t: {}, p: {:.2f} %, mp: {:.1f}, exact: {:.1f}, fuzzy: {:.1f}".format( 55 | chunk_id, 56 | round(idx * 100 / n, 2), 57 | missing_pages, 58 | short_exact_match, 59 | short_fuzzy_match, 60 | ), 61 | end="\r", 62 | ) 63 | sys.stdout.flush() 64 | 65 | url = datapoint["document_url"] 66 | page = ks.get_page_from_url(url) 67 | 68 | if not page: 69 | print("ERROR, not page!") 70 | missing_pages += 1 71 | else: 72 | # get and validate annotations 73 | annotations = datapoint["annotations"] 74 | 75 | kilt_record = { 76 | # original data point id if available otherwise unique id 77 | "id": datapoint["example_id"], 78 | # question / claim / sentence 79 | "input": datapoint["question_text"], 80 | } 81 | 82 | kilt_record_output = [] 83 | local_sem = 0.0 84 | local_sfm = 0.0 85 | 86 | for annotation in annotations: 87 | 88 | if "short_answers" in annotation: 89 | short_answers = annotation["short_answers"] 90 | 91 | # scan all possible short answers 92 | for answer_index in range(len(short_answers)): 93 | s = short_answers[answer_index]["start_token"] 94 | e = short_answers[answer_index]["end_token"] 95 | short_answer = datapoint["document_text"].split()[s:e] 96 | answer_span = " ".join(short_answer).strip() 97 | 98 | ( 99 | paragraph_id, 100 | start_character, 101 | end_character, 102 | bleu, 103 | ) = utils.match_answer( 104 | answer_span, page, nlp=self.nlp, debug=False 105 | ) 106 | 107 | kilt_record_output.append( 108 | { 109 | # answer in textual form 110 | "answer": answer_span, 111 | "provenance": [ 112 | # list of relevant WikipediaPages / Spans as provenance for the answer from the ks 113 | { 114 | "wikipedia_id": page[ 115 | "wikipedia_id" 116 | ], # *mandatory* - ID Wikipedia Page 117 | "title": page[ 118 | "wikipedia_title" 119 | ], # *mandatory* - Title Wikipedia Page 120 | "start_paragraph_id": paragraph_id, # start paragraph id with relevant info 121 | "start_character": start_character, 122 | "end_paragraph_id": paragraph_id, # end paragraph id 123 | "end_character": end_character, 124 | "bleu_score": bleu, # 1.0 when gold data is exactly matched, lower for fuzzy matches 125 | "meta": { # dataset/task specific 126 | "yes_no_answer": annotations[0][ 127 | "yes_no_answer" 128 | ], 129 | "annotation_id": annotations[0][ 130 | "annotation_id" 131 | ], 132 | }, 133 | } 134 | ], 135 | } 136 | ) 137 | 138 | if bleu == 1: 139 | local_sem += 1 140 | elif bleu < 1 and bleu >= 0: 141 | local_sfm += 1 142 | else: 143 | print("ERROR: invalid bleu: {}".format(bleu)) 144 | sys.exit(-1) 145 | 146 | if "long_answer" in annotation: 147 | 148 | long_answer = annotation["long_answer"] 149 | 150 | s = long_answer["start_token"] 151 | e = long_answer["end_token"] 152 | long_answer = datapoint["document_text"].split()[s:e] 153 | answer_span = " ".join(long_answer).strip() 154 | 155 | ( 156 | paragraph_id, 157 | start_character, 158 | end_character, 159 | bleu, 160 | ) = utils.match_answer( 161 | answer_span, page, nlp=self.nlp, debug=False 162 | ) 163 | 164 | kilt_record_output.append( 165 | { 166 | # answer in textual form 167 | "answer": answer_span, 168 | "provenance": [ 169 | # list of relevant WikipediaPages / Spans as provenance for the answer from the ks 170 | { 171 | "wikipedia_id": page[ 172 | "wikipedia_id" 173 | ], # *mandatory* - ID Wikipedia Page 174 | "title": page[ 175 | "wikipedia_title" 176 | ], # *mandatory* - Title Wikipedia Page 177 | "start_paragraph_id": paragraph_id, # start paragraph id with relevant info 178 | "start_character": start_character, 179 | "end_paragraph_id": paragraph_id, # end paragraph id 180 | "end_character": end_character, 181 | "bleu_score": bleu, # 1.0 when gold data is exactly matched, lower for fuzzy matches 182 | "meta": { # dataset/task specific 183 | "yes_no_answer": annotations[0][ 184 | "yes_no_answer" 185 | ], 186 | "annotation_id": annotations[0][ 187 | "annotation_id" 188 | ], 189 | }, 190 | } 191 | ], 192 | } 193 | ) 194 | 195 | if bleu == 1: 196 | local_sem += 1 197 | elif bleu < 1 and bleu >= 0: 198 | local_sfm += 1 199 | else: 200 | print("ERROR: invalid bleu: {}".format(bleu)) 201 | sys.exit(-1) 202 | 203 | # update kilt data 204 | kilt_record["output"] = kilt_record_output 205 | kilt_data.append(kilt_record) 206 | 207 | # average by answers per single question 208 | # if len(short_answers) > 0: 209 | # short_exact_match += local_sem / len(short_answers) 210 | # short_fuzzy_match += local_sfm / len(short_answers) 211 | 212 | metadata = [missing_pages, short_exact_match, short_fuzzy_match] 213 | return kilt_data, metadata 214 | 215 | def postprocess_metadata(self, metadata): 216 | missing_pages = 0.0 217 | short_exact_match = 0.0 218 | short_fuzzy_match = 0.0 219 | for met in metadata: 220 | if met == []: 221 | continue 222 | mp, sem, sfm = met 223 | missing_pages += mp 224 | short_exact_match += sem 225 | short_fuzzy_match += sfm 226 | 227 | print("Print stats") 228 | msg = "\n n: {:.1f}, missing pages: {:.1f}, short exact match: {:.1f}, short fuzzy match: {:.1f}".format( 229 | 0, missing_pages, short_exact_match, short_fuzzy_match 230 | ) 231 | print(msg) 232 | 233 | f = open(self.log_file, "w+") 234 | f.write(msg) 235 | f.close() -------------------------------------------------------------------------------- /kilt/readers/t5/finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import glob 10 | import logging 11 | import os 12 | import time 13 | 14 | import torch 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data.dataset import ConcatDataset 17 | from transformers import get_linear_schedule_with_warmup 18 | from transformers.tokenization_utils import trim_batch 19 | 20 | from base_transformer import BaseTransformer, add_generic_args, generic_train 21 | from data import KiltDataset, seq2seq_to_kilt, dataset_config 22 | from eval_downstream import normalize_answer 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class Seq2seqTransformer(BaseTransformer): 28 | 29 | def __init__(self, hparams): 30 | super().__init__(hparams, num_labels=None) 31 | self.lr_scheduler = None 32 | self.devsets = {} 33 | self.em = -1 34 | self.dataset_list = self.hparams.dataset.split(',') 35 | self.eval_batch_size = 100000 36 | self.train_batch_size = 100000 37 | self.source_length = -1 38 | self.target_length = -1 39 | 40 | special_tokens = [] 41 | 42 | for i in range(0, 101): 43 | special_tokens.append('') 44 | 45 | special_tokens.extend(['[START_ENT]', '[END_ENT]', 'Question Answering:', 'Entity Linking:', 46 | 'Fact Checking:', 'Dialogue:', 'Relation Extraction:', '[SEP]']) # 47 | self.tokenizer.add_special_tokens( 48 | {'additional_special_tokens': special_tokens}) 49 | 50 | fevers_classes = ["", ""] 51 | 52 | self.tokenizer.add_tokens(fevers_classes) 53 | 54 | self.model.resize_token_embeddings(len(self.tokenizer)) 55 | 56 | self.bad_words = [[self.tokenizer.convert_tokens_to_ids(bad_word)] for bad_word in 57 | self.tokenizer.additional_special_tokens] 58 | 59 | for d in self.dataset_list: 60 | train_batch = int(dataset_config[d]['train_batch']) 61 | eval_batch = int(dataset_config[d]['eval_batch']) 62 | source_length = int(dataset_config[d]['source_length']) 63 | target_length = int(dataset_config[d]['target_length']) 64 | if train_batch < self.train_batch_size: 65 | self.train_batch_size = train_batch 66 | if eval_batch < self.eval_batch_size: 67 | self.eval_batch_size = eval_batch 68 | if source_length > self.source_length: 69 | self.source_length = source_length 70 | if target_length > self.target_length: 71 | self.target_length = target_length 72 | 73 | self.data_dir = self.hparams.data_dir 74 | self.output_dir = self.hparams.output_dir 75 | 76 | def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None): 77 | return self.model( 78 | input_ids, attention_mask=attention_mask, lm_labels=lm_labels, 79 | ) 80 | 81 | def _step(self, batch): 82 | pad_token_id = self.tokenizer.pad_token_id 83 | source_ids, source_mask, y, ids = batch["source_ids"], batch["source_mask"], batch["target_ids"], batch["ids"] 84 | 85 | lm_labels = y.clone() 86 | lm_labels[y == pad_token_id] = -100 87 | 88 | outputs = self(source_ids, attention_mask=source_mask, lm_labels=lm_labels, ) 89 | 90 | loss = outputs[0] 91 | 92 | return loss 93 | 94 | def training_step(self, batch, batch_idx): 95 | loss = self._step(batch) 96 | 97 | tensorboard_logs = {"train_loss": loss} 98 | return {"loss": loss, "log": tensorboard_logs} 99 | 100 | def validation_step(self, batch, batch_idx): 101 | pad_token_id = self.tokenizer.pad_token_id 102 | 103 | source_ids, source_mask, y = KiltDataset.trim_seq2seq_batch(batch, pad_token_id) 104 | generated_ids = self.model.generate( 105 | input_ids=source_ids, 106 | attention_mask=source_mask, 107 | num_beams=1, 108 | max_length=self.target_length, 109 | repetition_penalty=1, 110 | length_penalty=1.0, 111 | early_stopping=True, 112 | use_cache=True, 113 | do_sample=False, 114 | top_p=0.95, 115 | top_k=50, 116 | bad_words_ids=self.bad_words 117 | ) 118 | 119 | preds = [self.tokenizer.decode(g) for g in generated_ids] 120 | target = [self.tokenizer.decode(t) for t in y] 121 | loss = self._step(batch) 122 | sources = [self.tokenizer.decode(s) for s in source_ids] 123 | 124 | return {"val_loss": loss, 'sources': sources, "preds": preds, "target": target, "ids": batch["ids"]} 125 | 126 | def validation_end(self, outputs): 127 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 128 | tensorboard_logs = {"val_loss": avg_loss} 129 | 130 | preds = [] 131 | ids = [] 132 | sources = [] 133 | for batch in outputs: 134 | sources.extend(batch['sources']) 135 | preds.extend(batch['preds']) 136 | ids.extend(batch["ids"]) 137 | em = 0 138 | for q_id, pred in set(zip(ids, preds)): 139 | targets = [normalize_answer(x) for x in self.devsets[q_id]] 140 | 141 | if normalize_answer(pred) in targets: 142 | em = em + 1 143 | if em > self.em: 144 | self.em = em 145 | self.trainer.save_checkpoint(self.output_dir + '/' + "best_em.ckpt") 146 | seq2seq_to_kilt(set(ids), set(sources), set(preds), self.hparams.output_dir, 147 | self.hparams.dataset, 'dev') 148 | return {"avg_val_loss": avg_loss, "log": tensorboard_logs, "EM": em} 149 | 150 | def test_step(self, batch, batch_idx): 151 | pad_token_id = self.tokenizer.pad_token_id 152 | 153 | source_ids, source_mask, y = KiltDataset.trim_seq2seq_batch(batch, pad_token_id) 154 | # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_kilt_task.py 155 | 156 | generated_ids = self.model.generate( 157 | input_ids=source_ids, 158 | attention_mask=source_mask, 159 | num_beams=1, 160 | max_length=self.target_length, 161 | repetition_penalty=1, 162 | length_penalty=1.0, 163 | early_stopping=True, 164 | use_cache=True, 165 | do_sample=False, 166 | top_p=0.95, 167 | top_k=50, 168 | bad_words_ids=self.bad_words 169 | ) 170 | preds = [self.tokenizer.decode(g) for g in generated_ids] 171 | target = [self.tokenizer.decode(t) for t in y] 172 | loss = self._step(batch) 173 | sources = [self.tokenizer.decode(s) for s in source_ids] 174 | return {"val_loss": loss, 'sources': sources, "preds": preds, "target": target, "ids": batch["ids"]} 175 | 176 | def test_end(self, outputs): 177 | sources = [] 178 | preds = [] 179 | ids = [] 180 | for batch in outputs: 181 | sources.extend(batch['sources']) 182 | preds.extend(batch['preds']) 183 | ids.extend(batch["ids"]) 184 | 185 | seq2seq_to_kilt(ids, sources, preds, self.hparams.output_dir, self.hparams.dataset, 'test') 186 | 187 | return self.test_epoch_end(outputs) 188 | 189 | def test_epoch_end(self, outputs): 190 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 191 | tensorboard_logs = {"val_loss": avg_loss} 192 | return {"avg_val_loss": avg_loss, "log": tensorboard_logs} 193 | 194 | def collate_fn(self, batch): 195 | input_ids = torch.stack([x["source_ids"] for x in batch]) 196 | masks = torch.stack([x["source_mask"] for x in batch]) 197 | target_ids = torch.stack([x["target_ids"] for x in batch]) 198 | ids = [x["id"] for x in batch] 199 | pad_token_id = self.tokenizer.pad_token_id 200 | y = trim_batch(target_ids, pad_token_id) 201 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 202 | return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y, "ids": ids} 203 | 204 | def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: 205 | datasets = [] 206 | for d in self.dataset_list: 207 | datasets.append( 208 | KiltDataset(self.tokenizer, self.data_dir, d, type_path, self.source_length, self.target_length, 209 | self.output_dir)) 210 | if type_path == 'dev': 211 | for x in datasets: 212 | self.devsets.update(x.id_targets) 213 | concat_dataset = ConcatDataset(datasets) 214 | dataloader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn) 215 | 216 | print(type_path, dataloader.batch_size, concat_dataset.__len__()) 217 | return dataloader 218 | 219 | def train_dataloader(self) -> DataLoader: 220 | dataloader = self.get_dataloader("train", batch_size=self.train_batch_size, shuffle=True) 221 | t_total = ( 222 | (len(dataloader.dataset) // (self.train_batch_size * max(1, self.hparams.n_gpu))) 223 | // self.hparams.gradient_accumulation_steps 224 | * float(self.hparams.num_train_epochs) 225 | ) 226 | scheduler = get_linear_schedule_with_warmup( 227 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total 228 | ) 229 | self.lr_scheduler = scheduler 230 | 231 | return dataloader 232 | 233 | def val_dataloader(self) -> DataLoader: 234 | return self.get_dataloader("dev", batch_size=self.eval_batch_size) 235 | 236 | def test_dataloader(self) -> DataLoader: 237 | return self.get_dataloader("test", batch_size=self.eval_batch_size) 238 | 239 | @staticmethod 240 | def add_model_specific_args(arg_parser, root_dir): 241 | BaseTransformer.add_model_specific_args(arg_parser, root_dir) 242 | 243 | arg_parser.add_argument( 244 | "--data_dir", 245 | default=None, 246 | type=str, 247 | required=True, 248 | help="The input data dir. Should contain the dataset files for the task.", 249 | ) 250 | arg_parser.add_argument("--dataset", required=True, type=str) 251 | 252 | return arg_parser 253 | 254 | 255 | def main(arguments): 256 | # If output_dir not provided, a folder will be generated in pwd 257 | if not arguments.output_dir: 258 | arguments.output_dir = os.path.join("./results", f"{arguments.task}_{time.strftime('%Y%m%d_%H%M%S')}", ) 259 | os.makedirs(arguments.output_dir) 260 | model = Seq2seqTransformer(arguments) 261 | trainer = generic_train(model, arguments) 262 | 263 | if arguments.do_predict: 264 | checkpoints = list( 265 | sorted(glob.glob(os.path.join(arguments.output_dir, "*.ckpt"), recursive=True))) 266 | model = model.load_from_checkpoint(checkpoints[-1]) 267 | model.hparams.dataset = arguments.dataset 268 | model.dataset_list = arguments.dataset.split(',') 269 | 270 | trainer.test(model) 271 | 272 | 273 | if __name__ == "__main__": 274 | 275 | parser = argparse.ArgumentParser() 276 | add_generic_args(parser) 277 | parser = Seq2seqTransformer.add_model_specific_args(parser, os.getcwd()) 278 | args = parser.parse_args() 279 | 280 | main(args) 281 | -------------------------------------------------------------------------------- /kilt/readers/t5/base_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import logging 10 | import os 11 | import random 12 | 13 | import numpy as np 14 | import pytorch_lightning as pl 15 | import torch 16 | from transformers import ( 17 | AdamW, 18 | AutoConfig, 19 | AutoModelWithLMHead, 20 | AutoTokenizer, 21 | get_linear_schedule_with_warmup, 22 | ) 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def set_seed(args: argparse.Namespace): 28 | random.seed(args.seed) 29 | np.random.seed(args.seed) 30 | torch.manual_seed(args.seed) 31 | if args.n_gpu > 0: 32 | torch.cuda.manual_seed_all(args.seed) 33 | 34 | 35 | class BaseTransformer(pl.LightningModule): 36 | def __init__(self, hparams: argparse.Namespace, num_labels=None, **config_kwargs): 37 | "Initialize a model." 38 | 39 | super().__init__() 40 | self.hparams = hparams 41 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None 42 | self.config = AutoConfig.from_pretrained( 43 | self.hparams.config_name 44 | if self.hparams.config_name 45 | else self.hparams.model_name_or_path, 46 | **({"num_labels": num_labels} if num_labels is not None else {}), 47 | cache_dir=cache_dir, 48 | **config_kwargs, 49 | ) 50 | self.tokenizer = AutoTokenizer.from_pretrained( 51 | self.hparams.tokenizer_name 52 | if self.hparams.tokenizer_name 53 | else self.hparams.model_name_or_path, 54 | cache_dir=cache_dir, 55 | ) 56 | self.model = AutoModelWithLMHead.from_pretrained( 57 | self.hparams.model_name_or_path, 58 | config=self.config, 59 | cache_dir=cache_dir, 60 | ) 61 | 62 | def is_logger(self): 63 | return self.trainer.proc_rank <= 0 64 | 65 | def configure_optimizers(self): 66 | "Prepare optimizer and schedule (linear warmup and decay)" 67 | 68 | model = self.model 69 | no_decay = ["bias", "LayerNorm.weight"] 70 | optimizer_grouped_parameters = [ 71 | dict( 72 | params=[ 73 | p 74 | for n, p in model.named_parameters() 75 | if not any(nd in n for nd in no_decay) 76 | ], 77 | weight_decay=self.hparams.weight_decay, 78 | ), 79 | { 80 | "params": [ 81 | p 82 | for n, p in model.named_parameters() 83 | if any(nd in n for nd in no_decay) 84 | ], 85 | "weight_decay": 0.0, 86 | }, 87 | ] 88 | optimizer = AdamW( 89 | optimizer_grouped_parameters, 90 | lr=self.hparams.learning_rate, 91 | eps=self.hparams.adam_epsilon, 92 | ) 93 | self.opt = optimizer 94 | return [optimizer] 95 | 96 | def optimizer_step( 97 | self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None 98 | ): 99 | if self.trainer.use_tpu: 100 | xm.optimizer_step(optimizer) 101 | else: 102 | optimizer.step() 103 | optimizer.zero_grad() 104 | self.lr_scheduler.step() 105 | 106 | def get_tqdm_dict(self): 107 | avg_loss = getattr(self.trainer, "avg_loss", 0.0) 108 | tqdm_dict = { 109 | "loss": "{:.3f}".format(avg_loss), 110 | "lr": self.lr_scheduler.get_last_lr()[-1], 111 | } 112 | return tqdm_dict 113 | 114 | def test_step(self, batch, batch_nb): 115 | return self.validation_step(batch, batch_nb) 116 | 117 | def test_end(self, outputs): 118 | return self.validation_end(outputs) 119 | 120 | def train_dataloader(self): 121 | train_batch_size = self.hparams.train_batch_size 122 | dataloader = self.load_dataset("train", train_batch_size) 123 | 124 | t_total = ( 125 | (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu))) 126 | // self.hparams.gradient_accumulation_steps 127 | * float(self.hparams.num_train_epochs) 128 | ) 129 | scheduler = get_linear_schedule_with_warmup( 130 | self.opt, 131 | num_warmup_steps=self.hparams.warmup_steps, 132 | num_training_steps=t_total, 133 | ) 134 | self.lr_scheduler = scheduler 135 | return dataloader 136 | 137 | def val_dataloader(self): 138 | return self.load_dataset("dev", self.hparams.eval_batch_size) 139 | 140 | def test_dataloader(self): 141 | return self.load_dataset("test", self.hparams.eval_batch_size) 142 | 143 | def _feature_file(self, mode): 144 | return os.path.join( 145 | self.hparams.data_dir, 146 | "cached_{}_{}_{}".format( 147 | mode, 148 | list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), 149 | str(self.hparams.max_seq_length), 150 | ), 151 | ) 152 | 153 | @staticmethod 154 | def add_model_specific_args(parser, root_dir): 155 | parser.add_argument( 156 | "--model_name_or_path", 157 | default=None, 158 | type=str, 159 | required=True, 160 | help="Path to pretrained model or model identifier from huggingface.co/models", 161 | ) 162 | parser.add_argument( 163 | "--config_name", 164 | default="", 165 | type=str, 166 | help="Pretrained config name or path if not the same as model_name", 167 | ) 168 | parser.add_argument( 169 | "--tokenizer_name", 170 | default="", 171 | type=str, 172 | help="Pretrained tokenizer name or path if not the same as model_name", 173 | ) 174 | parser.add_argument( 175 | "--cache_dir", 176 | default="", 177 | type=str, 178 | help="Where do you want to store the pre-trained models downloaded from s3", 179 | ) 180 | parser.add_argument( 181 | "--learning_rate", 182 | default=5e-5, 183 | type=float, 184 | help="The initial learning rate for Adam.", 185 | ) 186 | parser.add_argument( 187 | "--weight_decay", 188 | default=0.0, 189 | type=float, 190 | help="Weight decay if we apply some.", 191 | ) 192 | parser.add_argument( 193 | "--adam_epsilon", 194 | default=1e-8, 195 | type=float, 196 | help="Epsilon for Adam optimizer.", 197 | ) 198 | parser.add_argument( 199 | "--warmup_steps", 200 | default=0, 201 | type=int, 202 | help="Linear warmup over warmup_steps.", 203 | ) 204 | parser.add_argument( 205 | "--num_train_epochs", 206 | default=3, 207 | type=int, 208 | help="Total number of training epochs to perform.", 209 | ) 210 | 211 | 212 | class LoggingCallback(pl.Callback): 213 | def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 214 | logger.info("***** Validation results *****") 215 | if pl_module.is_logger(): 216 | metrics = trainer.callback_metrics 217 | # Log results 218 | for key in sorted(metrics): 219 | if key not in ["log", "progress_bar"]: 220 | logger.info("{} = {}\n".format(key, str(metrics[key]))) 221 | 222 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 223 | logger.info("***** Test results *****") 224 | 225 | if pl_module.is_logger(): 226 | metrics = trainer.callback_metrics 227 | 228 | # Log and save results to file 229 | output_test_results_file = os.path.join( 230 | pl_module.hparams.output_dir, "test_results.txt" 231 | ) 232 | with open(output_test_results_file, "w") as writer: 233 | for key in sorted(metrics): 234 | if key not in ["log", "progress_bar"]: 235 | logger.info("{} = {}\n".format(key, str(metrics[key]))) 236 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 237 | 238 | 239 | def add_generic_args(parser): 240 | parser.add_argument( 241 | "--output_dir", 242 | default=None, 243 | type=str, 244 | required=True, 245 | help="The output directory where the model predictions and checkpoints will be written.", 246 | ) 247 | 248 | parser.add_argument( 249 | "--fp16", 250 | action="store_true", 251 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 252 | ) 253 | 254 | parser.add_argument( 255 | "--fp16_opt_level", 256 | type=str, 257 | default="O1", 258 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 259 | "See details at https://nvidia.github.io/apex/amp.html", 260 | ) 261 | 262 | parser.add_argument("--n_gpu", type=int, default=1) 263 | parser.add_argument("--n_tpu_cores", type=int, default=0) 264 | parser.add_argument( 265 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 266 | ) 267 | parser.add_argument( 268 | "--do_train", action="store_true", help="Whether to run training." 269 | ) 270 | parser.add_argument( 271 | "--do_predict", 272 | action="store_true", 273 | help="Whether to run predictions on the test set.", 274 | ) 275 | parser.add_argument( 276 | "--gradient_accumulation_steps", 277 | type=int, 278 | default=1, 279 | help="Number of updates steps to accumulate before performing a backward/update pass.", 280 | ) 281 | 282 | parser.add_argument( 283 | "--seed", type=int, default=42, help="random seed for initialization" 284 | ) 285 | 286 | 287 | def generic_train(model: BaseTransformer, args: argparse.Namespace): 288 | # init model 289 | set_seed(args) 290 | 291 | if ( 292 | os.path.exists(args.output_dir) 293 | and os.listdir(args.output_dir) 294 | and args.do_train 295 | ): 296 | raise ValueError( 297 | "Output directory ({}) already exists and is not empty.".format( 298 | args.output_dir 299 | ) 300 | ) 301 | 302 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 303 | filepath=args.output_dir, 304 | prefix="checkpoint", 305 | monitor="val_loss", 306 | mode="min", 307 | save_top_k=5, 308 | ) 309 | 310 | train_params = dict( 311 | accumulate_grad_batches=args.gradient_accumulation_steps, 312 | gpus=args.n_gpu, 313 | max_epochs=args.num_train_epochs, 314 | early_stop_callback=False, 315 | gradient_clip_val=args.max_grad_norm, 316 | checkpoint_callback=checkpoint_callback, 317 | callbacks=[LoggingCallback()], 318 | ) 319 | 320 | if args.fp16: 321 | train_params["use_amp"] = args.fp16 322 | train_params["amp_level"] = args.fp16_opt_level 323 | 324 | if args.n_tpu_cores > 0: 325 | global xm 326 | 327 | train_params["num_tpu_cores"] = args.n_tpu_cores 328 | train_params["gpus"] = 0 329 | 330 | if args.n_gpu > 1: 331 | train_params["distributed_backend"] = "ddp" 332 | 333 | trainer = pl.Trainer(**train_params) 334 | 335 | if args.do_train: 336 | trainer.fit(model) 337 | 338 | return trainer 339 | --------------------------------------------------------------------------------