├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── datasets ├── WAFER.md └── download_data.py ├── img └── side_logo.png └── projects └── verify_wikipedia ├── .gitignore ├── README.md ├── conf ├── base_model │ ├── hf_bert.yaml │ ├── hf_bert_large.yaml │ ├── hf_roberta.yaml │ └── hf_roberta_large.yaml ├── datasets │ ├── evaluation.yaml │ └── retrieval.yaml ├── input_transform │ └── biencoder.yaml ├── kilt │ ├── bm25_ccnet.json │ └── wafer.json ├── retrieval.yaml └── retriever │ ├── dpr_remote.yaml │ └── reranker.yaml ├── dpr ├── __init__.py ├── dataset │ ├── __init__.py │ ├── input_transform.py │ ├── retrieval.py │ ├── training.py │ └── utils.py ├── dense_retriever.py ├── indexer │ ├── client.py │ └── faiss_indexers.py ├── models │ ├── __init__.py │ ├── biencoder.py │ ├── hf_models.py │ ├── ranker_loss.py │ └── reranker.py ├── options.py └── utils │ ├── __init__.py │ ├── conf_utils.py │ ├── data_utils.py │ ├── dist_utils.py │ └── model_utils.py ├── evaluation ├── __init__.py ├── dataset.py └── retrievers │ ├── __init__.py │ ├── base_retriever.py │ ├── dpr.py │ ├── dpr_remote.py │ └── reranker.py ├── img └── architecture.jpg ├── internal └── eval │ └── retrieval │ └── retrieve │ ├── configs │ ├── best_dpr_remote.yaml │ └── best_reranker.yaml │ ├── retrieve.sh │ ├── retrieve_gpu_sharded.sh │ ├── slurm │ ├── sbatch_retrieve_gpu_sharded.sh │ └── wrapped_retrieve.sh │ └── start_distributed_faiss.sh ├── misc ├── __init__.py └── utils.py └── scripts ├── GAR_urltitle_generator.py ├── __init__.py └── retrieve_rerank.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | # Developer tools 5 | .vscode/ 6 | /node_modules/ 7 | 8 | # Installation files 9 | *.egg-info/ 10 | eggs/ 11 | .eggs/ 12 | *.egg 13 | 14 | .DS_Store 15 | 16 | # web 17 | web/side/api/__pycache__/api.cpython-37.pyc 18 | web/side2/api/__pycache__/api.cpython-37.pyc 19 | web/side/next.config.js 20 | web/side/prisma/dev.db 21 | web/side/prisma/dev.db-journal 22 | 23 | # data 24 | data/ 25 | 26 | # src 27 | src/ 28 | 29 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "projects/distributed-faiss"] 2 | path = projects/distributed-faiss 3 | url = git@github.com:facebookresearch/distributed-faiss.git 4 | [submodule "projects/KILT"] 5 | path = projects/KILT 6 | url = git@github.com:facebookresearch/KILT.git 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to side 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to side, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![side logo](img/side_logo.png) 2 | -------------------------------------------------------------------------------- 3 | 4 | Welcome! This is an effort to improve citations on Wikipedia. 5 | To read more about the overarching goals of this effort, please see our article on [ML-Assisted Wikipedia Editing](https://meta.wikimedia.org/wiki/Research:Machine_Learning_Assisted_Wikipedia_Editing). 6 | 7 | The Side system is described in the following paper: 8 | 9 | ```bibtex 10 | @inproceedings{petroni-etal-2022-improving, 11 | title = "Improving Wikipedia Verifiability with AI", 12 | author = {Petroni, Fabio and Broscheit, Samuel and Piktus, Aleksandra and Lewis, Patrick and Izacard, Gautier and Hosseini, Lucas and Dwivedi-Yu, Jane and Lomeli, Maria and Schick, Timo and Mazaré, Pierre-Emmanuel and Joulin, Armand and Grave, Edouard e Riedel, Sebastian}, 13 | url = "https://openreview.net/forum?id=qfTqRtkDbWZ", 14 | year = "2022" 15 | } 16 | ``` 17 | 18 | [https://openreview.net/forum?id=qfTqRtkDbWZ](https://openreview.net/forum?id=qfTqRtkDbWZ) 19 | 20 | Read our blog post [https://tech.fb.com/artificial-intelligence/2022/07/how-ai-could-help-make-wikipedia-entries-more-accurate/](https://tech.fb.com/artificial-intelligence/2022/07/how-ai-could-help-make-wikipedia-entries-more-accurate/). 21 | 22 | **Check out the Side demo at [https://verifier.sideeditor.com](https://verifier.sideeditor.com)!** 23 | 24 | Code and models in [projects/verify_wikipedia](projects/verify_wikipedia). 25 | 26 | ## License 27 | Side is MIT licensed. See the [LICENSE](LICENSE) file for details. 28 | -------------------------------------------------------------------------------- /datasets/WAFER.md: -------------------------------------------------------------------------------- 1 | # WAFER data 2 | 3 | | split | size | file | 4 | | ------------- | ------------- | ------------- | 5 | | [train](http://dl.fbaipublicfiles.com/side/wafer-train.jsonl.tar.gz) | 3805958 | [wafer-train.jsonl.tar.gz](http://dl.fbaipublicfiles.com/side/wafer-train.jsonl.tar.gz) | 6 | | [dev](http://dl.fbaipublicfiles.com/side/wafer-dev.jsonl.tar.gz) | 4545 | [wafer-dev.jsonl.tar.gz](http://dl.fbaipublicfiles.com/side/wafer-dev.jsonl.tar.gz) | 7 | | [test](http://dl.fbaipublicfiles.com/side/wafer-test.jsonl.tar.gz) | 4568 | [wafer-test.jsonl.tar.gz](http://dl.fbaipublicfiles.com/side/wafer-test.jsonl.tar.gz) | 8 | | [fail-dev](http://dl.fbaipublicfiles.com/side/wafer-fail-dev.jsonl.tar.gz) | 725 | [wafer-fail-dev.jsonl.tar.gz](http://dl.fbaipublicfiles.com/side/wafer-fail-dev.jsonl.tar.gz) | 9 | | [fail-test](http://dl.fbaipublicfiles.com/side/wafer-fail-test.jsonl.tar.gz) | 730 | [wafer-fail-test.jsonl.tar.gz](http://dl.fbaipublicfiles.com/side/wafer-fail-test.jsonl.tar.gz) | 10 | 11 | 12 | ### Download script 13 | 14 | ```bash 15 | python data/download_data.py --dest_dir data --dataset wafer 16 | ``` 17 | 18 | ### Structure of each record 19 | 20 | ```python 21 | { 22 | 'id': # unique id 23 | 'input': # in-context claim with [CIT] tag 24 | 'output': [ # list of valid urls for the citation 25 | { 26 | 'answer': # a url or {Failed verification} 27 | 'provenance': [ 28 | { 29 | 'url': # *mandatory* 30 | 'chunk_id': # from the Sphere retrieval engine 31 | 'title': # not provided, use the script to get this 32 | 'text': # not provided, use the script to get this 33 | } 34 | 
] 35 | } 36 | ] 37 | 'meta': 38 | { 39 | 'wikipedia_id': # KILT wikipedia_id 40 | 'wikipedia_title': # KILT wikipedia_title 41 | 'wikipedia_section': # KILT wikipedia_section 42 | 'cit_paragraph_id': # KILT cit_paragraph_id 43 | 'cit_offset': # KILT cit_offset 44 | 'sentences': [] # sentences before claim - sentences[-1] contains [CIT] 45 | 'featured': # featured flag, assume false if not present 46 | } 47 | } 48 | ``` -------------------------------------------------------------------------------- /datasets/download_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os.path 9 | import requests 10 | 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | import tarfile 14 | 15 | SIDE_URL = "http://dl.fbaipublicfiles.com/side/" 16 | 17 | # wafer constants 18 | WAFER_FILENAMES = [ 19 | "wafer-train.jsonl.tar.gz", 20 | "wafer-dev.jsonl.tar.gz", 21 | "wafer-test.jsonl.tar.gz", 22 | "wafer-fail-dev.jsonl.tar.gz", 23 | "wafer-fail-test.jsonl.tar.gz", 24 | ] 25 | 26 | 27 | def download_file(url, file_path, overwrite): 28 | file_name = url.split("/")[-1] 29 | r = requests.get(url, stream=True) 30 | 31 | # Total size in bytes. 32 | total_size = int(r.headers.get("content-length", 0)) 33 | 34 | if not overwrite and os.path.isfile(file_path): 35 | current_size = os.path.getsize(file_path) 36 | if total_size == current_size: 37 | print(" - Skipping " + file_name + " - already exists.") 38 | return 39 | 40 | block_size = 1024 # 1 Kibibyte 41 | t = tqdm( 42 | total=total_size, 43 | unit="iB", 44 | unit_scale=True, 45 | desc=" - Downloading " + file_name + ": ", 46 | ) 47 | with open(file_path, "wb") as f: 48 | for data in r.iter_content(block_size): 49 | t.update(len(data)) 50 | f.write(data) 51 | t.close() 52 | 53 | 54 | def download_wafer(dest_dir, overwrite): 55 | Path(dest_dir).mkdir(parents=True, exist_ok=True) 56 | print("Downloading WAFER data:") 57 | 58 | for filename in WAFER_FILENAMES: 59 | print("Downloading", filename) 60 | download_file( 61 | SIDE_URL + filename, 62 | dest_dir + "/" + filename, 63 | overwrite, 64 | ) 65 | 66 | print("Extracting", filename) 67 | my_tar = tarfile.open(dest_dir + "/" + filename) 68 | my_tar.extractall(dest_dir + "/") 69 | my_tar.close() 70 | 71 | print("Removing", filename) 72 | os.remove(dest_dir + "/" + filename) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument( 78 | "--dest_dir", 79 | required=True, 80 | type=str, 81 | help="The path to a directory where the dataset files should be stored", 82 | ) 83 | parser.add_argument( 84 | "--dataset", 85 | required=True, 86 | choices=["wafer"], 87 | type=str, 88 | help="The dataset to download.", 89 | ) 90 | parser.add_argument( 91 | "--overwrite", 92 | action="store_true", 93 | help="If flag is set, existing files will be overwritten; otherwise the download is skipped.", 94 | ) 95 | args = parser.parse_args() 96 | 97 | if args.dataset == "wafer": 98 | download_wafer(args.dest_dir, args.overwrite) 99 | else: 100 | raise ValueError 101 | -------------------------------------------------------------------------------- /img/side_logo.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/img/side_logo.png -------------------------------------------------------------------------------- /projects/verify_wikipedia/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Developer tools 6 | .vscode/ 7 | 8 | # Installation files 9 | *.egg-info/ 10 | eggs/ 11 | .eggs/ 12 | *.egg 13 | *.mmap 14 | .idea/ 15 | 16 | # files 17 | outputs/ 18 | logs/ 19 | logs/* 20 | 21 | # models 22 | models 23 | models/ 24 | 25 | # src 26 | src/ -------------------------------------------------------------------------------- /projects/verify_wikipedia/README.md: -------------------------------------------------------------------------------- 1 | # Quick start 2 | 3 | This is a quick-start tutorial to run inference on pre-trained models. 4 | 5 | ## Setup 6 | 7 | Install the project with: 8 | 9 | **NOTE: The `side` conda environment is essential for some slurm scripts!** 10 | 11 | ```bash 12 | conda create -n side -y python=3.8 && conda activate side 13 | cd projects 14 | #git submodule add git@github.com:facebookresearch/distributed-faiss.git 15 | cd distributed-faiss 16 | pip install -e . 17 | cd .. 18 | #git submodule add git@github.com:facebookresearch/KILT.git 19 | cd KILT 20 | pip install -e . 21 | cd .. 22 | cd verify_wikipedia 23 | pip install -r requirements.txt 24 | pip install -e git+https://github.com/fairinternal/fairseq-py#egg=fairseq 25 | pip install hydra-core==1.1.1 26 | ``` 27 | 28 | ## Download WAFER data 29 | 30 | Follow [this README](../../datasets/WAFER.md). 31 | 32 | 33 | ## Architecture 34 | ![side architecture](./img/architecture.jpg) 35 | 36 | ## Downloading Index and Models 37 | 38 | Create a `models` folder under `verify_wikipedia` and extract the following three archives into it: 39 | 40 | - [dense](https://dl.fbaipublicfiles.com/side/dense.tar.gz) [1.2 TB] 41 | - [sparse](https://dl.fbaipublicfiles.com/side/sparse.tar.gz) [780 GB] 42 | - [verifier](https://dl.fbaipublicfiles.com/side/verifier.tar.gz) [8 GB] 43 | 44 | 45 | ## Retrieval Engine 46 | 47 | ### Dense Retrieval 48 | 49 | #### Setting the discovery-config path 50 | 51 | For communicating with the distributed server, a "discovery-config" file containing information about the nodes and ports used by the distributed server is required. That file needs to be updated each time the server is started, 52 | potentially causing a ``PermissionError`` if you use the script as is. To fix this, you need to: 53 | 54 | 1. Replace the default ``retriever.rpc_retriever_cfg_file`` in ``best_dpr_remote.yaml`` with some file that you have 55 | write access to. 56 | 2. Replace the default ``--discovery-config`` in ``start_distributed_faiss.sh`` with the very same file. 57 | 58 | #### Running retrieval with DPR 59 | 60 | 1. First start the distributed server (16 nodes) with distributed_faiss in a separate screen/tmux session (interrupting the script does not 61 | cancel the Slurm job; it has to be cancelled manually with `scancel`): 62 | 63 | ```bash 64 | screen -S distributed_faiss 65 | conda activate side 66 | bash internal/eval/retrieval/retrieve/start_distributed_faiss.sh models/dense/index/ 67 | ``` 68 | 69 | Tips: `Ctrl + a d` to detach from the screen session; `screen -r distributed_faiss` to reconnect. 70 | 71 | 2.
As soon as the distributed server is running (16 nodes are shown in `squeue`), retrieval for *evaluation data* can be 72 | done locally (on a devfair machine) with `retrieve.sh`, but some devfair machines might not have free resources. The 73 | safest way is therefore to run inference in sharded mode, which is also faster: 74 | 75 | ```bash 76 | bash internal/eval/retrieval/retrieve/retrieve_gpu_sharded.sh internal/eval/retrieval/retrieve/configs/best_dpr_remote.yaml 77 | ``` 78 | 79 | Alternatively, it can be executed locally (for debugging purposes) with: 80 | 81 | ```bash 82 | bash internal/eval/retrieval/retrieve/retrieve.sh internal/eval/retrieval/retrieve/configs/best_dpr_remote.yaml 83 | ``` 84 | 85 | If you want to use a custom dataset (say, `custom_test`), add a configuration for it in [conf/datasets/evaluation.yaml](conf/datasets/evaluation.yaml), and then specify the new `evaluation_dataset` on the command line (or in the configuration script): 86 | ```bash 87 | bash internal/eval/retrieval/retrieve/retrieve.sh internal/eval/retrieval/retrieve/configs/best_dpr_remote.yaml evaluation_datasets=[custom_test] 88 | ``` 89 | 90 | #### Evaluate dense retrieval 91 | 92 | ```bash 93 | python ../KILT/kilt/eval_retrieval.py outputs/predictions/evaluation.retrievers.dpr_remote.DPR//wafer-dev.jsonl ../../data/wafer-dev.jsonl --rank_keys url --ks 1,100 94 | ``` 95 | 96 | ### Sparse Retrieval 97 | 98 | #### Running retrieval with GAR 99 | 100 | 1. First augment each query with generated expansions. Note: for this to work you should have the `["meta"]["sentences"]` field populated as described in [this README](../../datasets/WAFER.md). 101 | 102 | ```bash 103 | mkdir -p outputs/predictions/expanded_inputs/ 104 | python scripts/GAR_urltitle_generator.py --test_filename ../../data/wafer-dev.jsonl --out_filename outputs/predictions/expanded_inputs/wafer-dev.jsonl 105 | ``` 106 | 107 | 2. Run BM25 retrieval with the expanded queries using the KILT library. You should modify [conf/kilt/wafer.json](conf/kilt/wafer.json) to point to the desired prediction file. 108 | 109 | ```bash 110 | module unload java 111 | export JAVA_HOME=/private/home/fabiopetroni/jdk-11.0.9 112 | PATH=${PATH}:${JAVA_HOME}/bin 113 | mkdir -p outputs/predictions/GAR/ 114 | python ../KILT/scripts/execute_retrieval.py --test_config conf/kilt/wafer.json --model_name bm25 --model_configuration conf/kilt/bm25_ccnet.json --output_folder outputs/predictions/GAR/ 115 | ``` 116 | 117 | #### Evaluate sparse retrieval 118 | 119 | ```bash 120 | python ../KILT/kilt/eval_retrieval.py outputs/predictions/GAR/wafer-dev.jsonl ../../data/wafer-dev.jsonl --rank_keys url --ks 1,100 121 | ``` 122 | 123 | ## Verification Engine 124 | 125 | First edit `internal/eval/retrieval/retrieve/configs/best_reranker.yaml` to point to the dense (i.e., DPR) and sparse (i.e., GAR) prediction files. 126 | 127 | Running the verification engine can be quite slow, so it is best done sharded.
To do so, call, for example: 128 | 129 | ```bash 130 | bash internal/eval/retrieval/retrieve/retrieve_gpu_sharded.sh internal/eval/retrieval/retrieve/configs/best_reranker.yaml 131 | ``` 132 | 133 | #### Evaluate the verification engine 134 | 135 | ```bash 136 | python ../KILT/kilt/eval_retrieval.py outputs/predictions/evaluation.retrievers.reranker.Reranker/wafer-dev.jsonl ../../data/wafer-dev.jsonl --rank_keys url --ks 1,100 137 | ``` 138 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/base_model/hf_bert.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # model type. One of [hf_bert, pytext_bert, fairseq_roberta] 8 | encoder_model_type: hf_bert 9 | 10 | # HuggingFace's config name for model initialization 11 | pretrained_model_cfg: bert-base-uncased 12 | 13 | # Some encoders need to be initialized from a file 14 | pretrained_file: 15 | 16 | # Extra linear layer on top of standard bert/roberta encoder 17 | projection_dim: 0 18 | 19 | # Max length of the encoder input sequence 20 | sequence_length: 256 21 | 22 | dropout: 0.1 23 | 24 | # if False, the model won't load pre-trained BERT weights 25 | pretrained: True 26 | 27 | share_parameters: False -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/base_model/hf_bert_large.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # model type. One of [hf_bert, pytext_bert, fairseq_roberta] 8 | encoder_model_type: hf_bert 9 | 10 | # HuggingFace's config name for model initialization 11 | pretrained_model_cfg: bert-large-uncased 12 | 13 | # Some encoders need to be initialized from a file 14 | pretrained_file: 15 | 16 | # Extra linear layer on top of standard bert/roberta encoder 17 | projection_dim: 0 18 | 19 | # Max length of the encoder input sequence 20 | sequence_length: 256 21 | 22 | dropout: 0.1 23 | 24 | # if False, the model won't load pre-trained BERT weights 25 | pretrained: True 26 | 27 | share_parameters: False -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/base_model/hf_roberta.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # model type.
One of [hf_bert, pytext_bert, fairseq_roberta] 8 | encoder_model_type: hf_roberta 9 | 10 | # HuggingFace's config name for model initialization 11 | pretrained_model_cfg: roberta-base 12 | 13 | # Some encoders need to be initialized from a file 14 | pretrained_file: 15 | 16 | # Extra linear layer on top of standard bert/roberta encoder 17 | projection_dim: 0 18 | 19 | # Max length of the encoder input sequence 20 | sequence_length: 256 21 | 22 | dropout: 0.1 23 | 24 | # if False, the model won't load pre-trained BERT weights 25 | pretrained: True 26 | 27 | share_parameters: False -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/base_model/hf_roberta_large.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # model type. One of [hf_bert, pytext_bert, fairseq_roberta] 8 | encoder_model_type: hf_roberta 9 | 10 | # HuggingFace's config name for model initialization 11 | pretrained_model_cfg: roberta-large 12 | 13 | # Some encoders need to be initialized from a file 14 | pretrained_file: 15 | 16 | # Extra linear layer on top of standard bert/roberta encoder 17 | projection_dim: 0 18 | 19 | # Max length of the encoder input sequence 20 | sequence_length: 256 21 | 22 | dropout: 0.1 23 | 24 | # if False, the model won't load pre-trained BERT weights 25 | pretrained: True 26 | 27 | share_parameters: False -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/datasets/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
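# NOTE: the entries below define the evaluation datasets that can be referenced from
# `evaluation_datasets` in conf/retrieval.yaml. As described in the quick-start README, a
# custom dataset (e.g. `custom_test`) can be registered with an entry of the same shape as
# `wafer_dev`; the file path and name in this commented sketch are hypothetical placeholders:
#
# custom_test:
#   _target_: evaluation.dataset.WaferCCNetTaskDataset
#   file: ../../../../../../data/custom-test.jsonl
#   task_family: Citations
#   name: custom-test-kiltweb
#   question_transform_type: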
6 | 7 | 8 | wafer_dev: 9 | _target_: evaluation.dataset.WaferCCNetTaskDataset 10 | file: ../../../../../../data/wafer-dev.jsonl 11 | task_family: Citations 12 | name: wafer-dev-kiltweb 13 | question_transform_type: 14 | 15 | wafer_test: 16 | _target_: evaluation.dataset.WaferCCNetTaskDataset 17 | file: ../../../../../../data/wafer-dev.jsonl 18 | task_family: Citations 19 | name: wafer-test-kiltweb 20 | question_transform_type: 21 | 22 | 23 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/datasets/retrieval.yaml: -------------------------------------------------------------------------------- 1 | wafer_ccnet: 2 | files: 3 | _target_: dpr.dataset.retrieval.CCNETMultiCsvCtxSrc 4 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/input_transform/biencoder.yaml: -------------------------------------------------------------------------------- 1 | do_lower_case: True 2 | 3 | question: 4 | type: only_before_cit+title 5 | dropout: 0.0 6 | 7 | passage: 8 | type: text+title -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/kilt/bm25_ccnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "index": "models/sparse/cc_net_bm25", 3 | "num_threads": 20, 4 | "k": 100, 5 | "Xms": "128g", 6 | "Xmx": "512g" 7 | } 8 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/kilt/wafer.json: -------------------------------------------------------------------------------- 1 | { 2 | "WAFER": { 3 | "test": "outputs/predictions/expanded_inputs/wafer-test.jsonl", 4 | "dev": "outputs/predictions/expanded_inputs/wafer-dev.jsonl" 5 | } 6 | } -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/retrieval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - datasets: evaluation 3 | - retriever: default 4 | 5 | retriever_config: 6 | 7 | n_docs: 100 8 | 9 | evaluation_datasets: 10 | - wafer_dev 11 | - wafer_test 12 | 13 | output_folder: ${retriever.output_folder} 14 | output_folder_in_checkpoint_dir: false 15 | output_suffix: 16 | task_id: ${now:%Y-%m-%d_%H-%M-%S} 17 | 18 | override_output_dir: outputs/predictions/${retriever._target_}/${task_id}/ 19 | 20 | hydra: 21 | run: 22 | dir: ${override_output_dir} 23 | 24 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/retriever/dpr_remote.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datasets: retrieval 3 | 4 | ctx_src: wafer_ccnet 5 | 6 | # Batch size to generate query embeddings 7 | batch_size: 1024 8 | 9 | # DPR index 10 | rpc_hnsw: False 11 | hnsw_index: False 12 | 13 | checkpoint_load_suffix: best_validation_acc # e.g. 
"0", "1", "best_validation_acc" 14 | checkpoint_dir: 15 | 16 | out_file: ${.checkpoint_dir}/ctxt_embeddings__${.checkpoint_load_suffix}__${.ctx_src}/shard_ 17 | model_file: ${.checkpoint_dir}/outputs/checkpoint.${.checkpoint_load_suffix} 18 | output_folder: 19 | #output_folder: ${.checkpoint_dir}/predictions/${.checkpoint_load_suffix}__${.ctx_src} 20 | 21 | # Remote DPR 22 | 23 | # Discovery config path of the remote distributed-faiss server 24 | rpc_retriever_cfg_file: /checkpoint/fabiopetroni/WAI/Samuel/distributed-faiss-data/45680605/discovery-config.txt 25 | 26 | # ID of the index to load 27 | rpc_index_id: hnswsq_default_index 28 | 29 | _target_: evaluation.retrievers.dpr_remote.DPR -------------------------------------------------------------------------------- /projects/verify_wikipedia/conf/retriever/reranker.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /datasets: retrieval 3 | 4 | ctx_src: 5 | 6 | # Batch size to generate query embeddings 7 | batch_size: 128 8 | 9 | checkpoint_load_suffix: best_validation_acc # e.g. "0", "1", "best_validation_acc" 10 | checkpoint_dir: 11 | 12 | retrieved_data_nr_ctxs: 100 13 | retrieved_data_files: [] 14 | 15 | model_file: ${.checkpoint_dir}/outputs/checkpoint.${.checkpoint_load_suffix} 16 | #output_folder: ${.checkpoint_dir}/predictions/${.checkpoint_load_suffix}__${.ctx_src} 17 | output_folder: 18 | 19 | _target_: evaluation.retrievers.reranker.Reranker -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/dpr/__init__.py -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/dpr/dataset/__init__.py -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/dataset/input_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import collections 8 | import json 9 | import logging 10 | from typing import List 11 | 12 | import torch 13 | 14 | from dpr.dataset.utils import normalize_passage 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | Passage = collections.namedtuple("Passage", ["text", "title"]) 20 | 21 | 22 | class QueryContextsRaw(object): 23 | query: str 24 | positive_passages: List[Passage] 25 | negative_passages: List[Passage] 26 | hard_negative_passages: List[Passage] 27 | 28 | 29 | WaferPassage = collections.namedtuple( 30 | "WaferPassage", 31 | [ 32 | "text", 33 | "title", 34 | "url", 35 | "passage_id", 36 | "bm25_hit", 37 | "rank", 38 | ], 39 | ) 40 | 41 | 42 | class WaferQueryContextsRawText(QueryContextsRaw): 43 | query: str 44 | positive_passages: List[WaferPassage] 45 | negative_passages: List[WaferPassage] 46 | negative_doc_passages: List[WaferPassage] 47 | hard_negative_passages: List[WaferPassage] 48 | hard_negative_doc_passages: List[WaferPassage] 49 | 50 | 51 | class SchemaPreprocessor: 52 | def __init__(self, cfg, tensorizer): 53 | self.cfg = cfg 54 | self.tensorizer = tensorizer 55 | 56 | def preprocess_query(self, question: str, training: bool = False) -> torch.Tensor: 57 | raise NotImplementedError 58 | 59 | def preprocess_passage(self, passage: Passage, training: bool = False) -> torch.Tensor: 60 | raise NotImplementedError 61 | 62 | def passage_struct_from_dict(self, ctx: dict): 63 | raise NotImplementedError 64 | 65 | 66 | class WaferPreprocessor(SchemaPreprocessor): 67 | def preprocess_query(self, query, training: bool = False): 68 | 69 | # extract title, section, text from the query string 70 | SEP_split = query.split("[SEP]") 71 | if len(SEP_split) == 1: 72 | # logger.warning(f"Text did not have three SEP sections: {len(SEP_split)}") 73 | title, section, text_cit = "", "", SEP_split[0] 74 | elif len(SEP_split) == 2: 75 | # logger.warning(f"Text did not have three SEP sections: {len(SEP_split)}") 76 | title, section, text_cit = SEP_split[0], "", SEP_split[1] 77 | elif len(SEP_split) == 3: 78 | title, section, text_cit = SEP_split 79 | CIT_split = text_cit.split("[CIT]") 80 | 81 | if len(CIT_split) == 1: 82 | logger.warning(f"Text did not have CIT section: {len(CIT_split)}") 83 | before_cit, after_cit = CIT_split[0], "" 84 | else: 85 | before_cit, after_cit = text_cit.split("[CIT]") 86 | 87 | # do some normalization and some hacks 88 | title_text = " ".join(title.split()) 89 | section_text = " ".join(section.split()) 90 | # double cit because BertTokenizer replaces last token with [SEP] 91 | # now only get tokens in front of citation 92 | before_cit_text = " ".join(before_cit.split()) 93 | 94 | # get question transform type; default: only use context left of citation 95 | question_transform_type = None 96 | if not hasattr(self.cfg, "input_transform") or ( 97 | self.cfg.input_transform.question.type is None or len(self.cfg.input_transform.question.type) == 0 98 | ): 99 | question_transform_type = "only_before_cit" 100 | else: 101 | question_transform_type = self.cfg.input_transform.question.type 102 | 103 | fields = { 104 | "title": None, 105 | "section": None, 106 | "bc_text": None, 107 | "ac_text": None, 108 | } 109 | 110 | if question_transform_type == "only_before_cit" or question_transform_type == "only_before_cit+title": 111 | fields["title"] = title_text 112 | fields["bc_text"] = before_cit_text 113 | elif question_transform_type == "only_before_cit+dropout_title+section": 114 | if training: 115 | dropout_field = torch.rand(1) 116 | if dropout_field[0] < 
self.cfg.input_transform.question.dropout: 117 | fields["title"] = self.tensorizer.get_mask_token() 118 | else: 119 | fields["title"] = title_text 120 | else: 121 | fields["title"] = title_text 122 | fields["section"] = section_text 123 | fields["bc_text"] = before_cit_text 124 | elif question_transform_type == "only_before_cit+dropout_title": 125 | if training: 126 | dropout_field = torch.rand(1) 127 | if dropout_field[0] < self.cfg.input_transform.question.dropout: 128 | fields["title"] = self.tensorizer.get_mask_token() 129 | else: 130 | fields["title"] = title_text 131 | else: 132 | fields["title"] = title_text 133 | fields["bc_text"] = before_cit_text 134 | elif question_transform_type == "only_before_cit+dropout_title+dropout_section": 135 | if training: 136 | dropout_field = torch.rand(1) 137 | if dropout_field[0] < self.cfg.input_transform.question.dropout: 138 | fields["section"] = self.tensorizer.get_mask_token() 139 | else: 140 | fields["section"] = section_text 141 | else: 142 | fields["section"] = section_text 143 | fields["title"] = title_text 144 | fields["bc_text"] = before_cit_text 145 | elif ( 146 | question_transform_type == "only_before_cit+section" 147 | or question_transform_type == "only_before_cit+title+section" 148 | or question_transform_type == "only_before_cit+section+title" 149 | ): 150 | fields["title"] = title_text 151 | fields["section"] = section_text 152 | fields["bc_text"] = before_cit_text 153 | elif question_transform_type == "only_before_cit+no_title": 154 | fields["bc_text"] = before_cit_text 155 | elif question_transform_type == "only_before_cit_30w" or question_transform_type == "only_before_cit_30w+title": 156 | fields["title"] = title_text 157 | fields["bc_text"] = " ".join(list(before_cit_text.split())[-30:]) 158 | elif question_transform_type == "only_before_cit_30w+no_title": 159 | fields["bc_text"] = " ".join(list(before_cit_text.split())[-30:]) 160 | else: 161 | raise Exception(f"Unknown question_transform_type: {question_transform_type}") 162 | 163 | field_tensors = { 164 | "title": None, 165 | "section": None, 166 | "bc_text": None, 167 | "ac_text": None, 168 | } 169 | 170 | self.tensorizer.set_pad_to_max(False) 171 | 172 | for k in field_tensors.keys(): 173 | if fields[k] is not None: 174 | field_tensors[k] = self.tensorizer.text_to_tensor( 175 | fields[k], apply_max_len=False, add_special_tokens=False 176 | ) 177 | 178 | fields_max_perc_size = { 179 | "title": 0.2, 180 | "section": 0.2, 181 | "bc_text": None, 182 | "ac_text": 0.2, 183 | } 184 | for k in field_tensors.keys(): 185 | if fields_max_perc_size[k]: 186 | fields_max_perc_size[k] = int(fields_max_perc_size[k] * self.tensorizer.max_length) 187 | 188 | SEP_LEN = 1 189 | CLS_LEN = 1 190 | 191 | for k in ["ac_text", "section", "title"]: 192 | if ( 193 | sum([len(t) + SEP_LEN for t in field_tensors.values() if t is not None]) + CLS_LEN 194 | > self.tensorizer.max_length 195 | ): 196 | if field_tensors[k] is not None and len(field_tensors[k]) + 1 > fields_max_perc_size[k]: 197 | field_tensors[k] = field_tensors[k][: (fields_max_perc_size[k] - SEP_LEN)] 198 | 199 | if ( 200 | sum([len(t) + SEP_LEN for t in field_tensors.values() if t is not None]) + CLS_LEN 201 | > self.tensorizer.max_length 202 | ): 203 | blocked_size = sum( 204 | [ 205 | len(field_tensors[k]) + SEP_LEN 206 | for k in ["ac_text", "section", "title"] 207 | if field_tensors[k] is not None 208 | ] 209 | ) 210 | free_size = self.tensorizer.max_length - blocked_size - SEP_LEN - CLS_LEN 211 | field_tensors["bc_text"] = 
field_tensors["bc_text"][-free_size:] 212 | 213 | tags = { 214 | "[CLS]": torch.tensor([self.tensorizer.get_cls_id()]), 215 | "[SEP]": torch.tensor([self.tensorizer.get_sep_id()]), 216 | "[CIT]": self.tensorizer.text_to_tensor("[CIT]", apply_max_len=False, add_special_tokens=False)[0:1], 217 | } 218 | 219 | ordered_fields_and_tags = [ 220 | ("title", tags["[SEP]"]), 221 | ("section", tags["[SEP]"]), 222 | ("bc_text", tags["[CIT]"]), 223 | ("ac_text", tags["[SEP]"]), 224 | ] 225 | 226 | tensor_list = list() 227 | tensor_list.append(tags["[CLS]"]) 228 | for field_name, tag in ordered_fields_and_tags: 229 | if field_tensors[field_name] is not None: 230 | tensor_list.append(field_tensors[field_name]) 231 | tensor_list.append(tag) 232 | 233 | pad_size = self.tensorizer.max_length - sum([len(t) for t in tensor_list]) 234 | if pad_size > 0: 235 | tensor_list.append((torch.ones(pad_size, dtype=torch.int) * self.tensorizer.get_pad_id()).long()) 236 | 237 | self.tensorizer.set_pad_to_max(True) 238 | assert ( 239 | sum([len(t) for t in tensor_list]) == self.tensorizer.max_length 240 | ), f"{[len(t) for t in tensor_list]} {sum([len(t) for t in tensor_list])}!={self.tensorizer.max_length}" 241 | return torch.cat(tensor_list, dim=0).long() 242 | 243 | def preprocess_passage(self, ctx: WaferPassage, training: bool = False) -> torch.Tensor: 244 | 245 | try: 246 | text = json.loads(ctx.text)["contents"] # for BM25 pyserini which can have json packed into the text 247 | except: 248 | text = ctx.text 249 | 250 | if hasattr(self.cfg, "normalize_passage"): 251 | if self.cfg.normalize: 252 | text = normalize_passage(text) 253 | else: 254 | text = normalize_passage(text) 255 | 256 | # handle incomplete or missing configs 257 | if hasattr(self.cfg, "passage_transform_type"): 258 | if self.cfg.input_transform.passage.type is None or len(self.cfg.input_transform.passage.type) == 0: 259 | passage_transform_type = "title+text" 260 | else: 261 | passage_transform_type = self.cfg.input_transform.passage.type 262 | else: 263 | passage_transform_type = "title+text" 264 | 265 | fields = { 266 | "title": None, 267 | "text": None, 268 | "url": None, 269 | "rank": None, 270 | } 271 | 272 | field_tensors = { 273 | "title": None, 274 | "text": None, 275 | "url": None, 276 | "rank": None, 277 | } 278 | 279 | if passage_transform_type == "text": 280 | fields["text"] = text 281 | elif passage_transform_type == "title+text": 282 | fields["title"] = ctx.title 283 | fields["text"] = text 284 | 285 | self.tensorizer.set_pad_to_max(False) 286 | 287 | # create token tensors 288 | for k in field_tensors.keys(): 289 | if fields[k] is not None: 290 | field_tensors[k] = self.tensorizer.text_to_tensor( 291 | fields[k], apply_max_len=False, add_special_tokens=False 292 | ) 293 | 294 | tags = { 295 | "[CLS]": torch.tensor([self.tensorizer.get_cls_id()]), 296 | "[SEP]": torch.tensor([self.tensorizer.get_sep_id()]), 297 | } 298 | 299 | ordered_fields_and_tags = [ 300 | ("title", tags["[SEP]"]), 301 | ("text", tags["[SEP]"]), 302 | ("url", tags["[SEP]"]), 303 | ("rank", tags["[SEP]"]), 304 | ] 305 | 306 | tensor_list = list() 307 | tensor_list.append(tags["[CLS]"]) 308 | for field_name, tag in ordered_fields_and_tags: 309 | if field_tensors[field_name] is not None: 310 | tensor_list.append(field_tensors[field_name]) 311 | tensor_list.append(tag) 312 | 313 | pad_size = self.tensorizer.max_length - sum([len(t) for t in tensor_list]) 314 | if pad_size > 0: 315 | tensor_list.append((torch.ones(pad_size, dtype=torch.int) * 
self.tensorizer.get_pad_id()).long()) 316 | 317 | self.tensorizer.set_pad_to_max(True) 318 | return torch.cat(tensor_list, dim=0)[: self.tensorizer.max_length].long() 319 | 320 | def passage_struct_from_dict(self, ctx: dict): 321 | return WaferPassage( 322 | ctx["text"], 323 | ctx["title"], 324 | ctx["url"], 325 | ctx["passage_id"] if "passage_id" in ctx else None, 326 | ctx["bm25_hit"] 327 | if "bm25_hit" in ctx 328 | else ctx["system"] 329 | if "system" in ctx 330 | else ctx["system_hit"] 331 | if "system_hit" in ctx 332 | else None, 333 | ctx["rank"] if "rank" in ctx else "[RANK:0]", 334 | ) 335 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/dataset/retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import collections 8 | import csv 9 | import logging 10 | import os 11 | import pickle 12 | import sys 13 | import zlib 14 | from pathlib import Path 15 | from typing import Dict, List 16 | 17 | import filelock 18 | import hydra 19 | import torch 20 | from omegaconf import DictConfig 21 | 22 | from dpr.dataset.input_transform import Passage, WaferPassage, WaferPreprocessor 23 | from dpr.dataset.utils import get_file_len, find_or_download_files 24 | from dpr.utils.data_utils import Tensorizer, DEFAULT_SELECTOR 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class ContextSource(torch.utils.data.Dataset): 30 | def __init__( 31 | self, 32 | file: str, 33 | selector: DictConfig = None, 34 | encoder_type: str = None, 35 | ): 36 | if selector: 37 | self.selector = hydra.utils.instantiate(selector) 38 | else: 39 | self.selector = DEFAULT_SELECTOR 40 | self.file = file 41 | self.data_files = [] 42 | 43 | def get_data_paths(self): 44 | self.data_files = find_or_download_files(self.file) 45 | assert len(self.data_files) > 0, f"self.file={self.file} ||| self.data_files={self.data_files}" 46 | self.file = self.data_files[0] 47 | 48 | def get_meta(self, ctxt): 49 | raise NotImplementedError 50 | 51 | def load_data_to(self, ctxs: Dict[object, Passage]): 52 | raise NotImplementedError 53 | 54 | 55 | class WaferCsvCtxSrc(ContextSource, WaferPreprocessor): 56 | def __init__( 57 | self, 58 | cfg: DictConfig, 59 | tensorizer: Tensorizer, 60 | file: str, 61 | id_prefix: str = None, 62 | selector: DictConfig = None, 63 | ): 64 | ContextSource.__init__( 65 | self, 66 | file=file, 67 | selector=selector, 68 | ) 69 | WaferPreprocessor.__init__( 70 | self, 71 | cfg=cfg, 72 | tensorizer=tensorizer, 73 | ) 74 | self.id_prefix = id_prefix 75 | 76 | def __len__(self): 77 | os.makedirs(f"{Path.home()}/.cache/files/", exist_ok=True) 78 | lock = filelock.FileLock(f"{Path.home()}/.cache/files/{self.file.replace('/', '_')}.lock") 79 | with lock: 80 | if not os.path.exists(f"{Path.home()}/.cache/files/{self.file.replace('/', '_')}.size"): 81 | size = get_file_len(self.file) 82 | with open(f"{Path.home()}/.cache/files/{self.file.replace('/', '_')}.size", "w") as f: 83 | f.writelines(f"{size}") 84 | else: 85 | with open(f"{Path.home()}/.cache/files/{self.file.replace('/', '_')}.size") as f: 86 | size = int(next(f)) 87 | return size 88 | 89 | def load_data_to(self, ctxs: Dict[object, WaferPassage], start=0, end=sys.maxsize): 90 | super().get_data_paths() 91 | logger.info("Reading file 
%s", self.file) 92 | with open(self.file) as fd: 93 | if start > 0: 94 | for i, line in enumerate(fd): 95 | if i + 1 == start: 96 | break 97 | reader = csv.reader(fd, delimiter="\t") 98 | for i, row in enumerate(reader): 99 | if start + i == end: 100 | break 101 | if row[0] == "id": 102 | continue 103 | if self.id_prefix: 104 | sample_id = self.id_prefix + str(row[0]) 105 | else: 106 | sample_id = row[0] 107 | 108 | ctxs[sample_id] = WaferPassage( 109 | text=row[1].strip('"'), 110 | title=row[2], 111 | url=row[3], 112 | passage_id=row[4], 113 | bm25_hit=row[5], 114 | rank=None, 115 | ) 116 | 117 | def get_meta(self, ctxt: WaferPassage): 118 | return ( 119 | zlib.compress(ctxt.text.encode()), 120 | zlib.compress(ctxt.title.encode()), 121 | zlib.compress(ctxt.url.encode()), 122 | ctxt.passage_id, 123 | ctxt.bm25_hit, 124 | ) 125 | 126 | 127 | CCNETCsvCtxSrcPassage = collections.namedtuple("CCNETCsvCtxSrcPassage", ["text", "title", "url", "is_wiki"]) 128 | 129 | 130 | class CCNETCsvCtxSrc(WaferCsvCtxSrc): 131 | def __init__( 132 | self, 133 | chunk_id__to__url__map, 134 | cfg, 135 | tensorizer, 136 | file, 137 | id_prefix, 138 | selector, 139 | ): 140 | WaferCsvCtxSrc.__init__( 141 | self, 142 | cfg=cfg, 143 | tensorizer=tensorizer, 144 | file=file, 145 | id_prefix=id_prefix, 146 | selector=selector, 147 | ) 148 | self.chunk_id__to__url__map = chunk_id__to__url__map 149 | 150 | def load_data_to(self, ctxs: Dict[object, CCNETCsvCtxSrcPassage], start=0, end=sys.maxsize): 151 | super().get_data_paths() 152 | logger.info("Reading file %s", self.file) 153 | with open(self.file) as fd: 154 | if start > 0: 155 | for i, line in enumerate(fd): 156 | if i + 1 == start: 157 | break 158 | reader = csv.reader(fd, delimiter="\t") 159 | for i, row in enumerate(reader): 160 | if start + i == end: 161 | break 162 | if row[0] == "id": 163 | continue 164 | if self.id_prefix: 165 | sample_id = self.id_prefix + str(row[0]) 166 | else: 167 | sample_id = row[0] 168 | ctxs[sample_id] = CCNETCsvCtxSrcPassage( 169 | text=row[1].strip('"'), 170 | title=row[2], 171 | url=self.chunk_id__to__url__map(sample_id), 172 | is_wiki=row[4], 173 | ) 174 | 175 | def get_meta(self, ctxt: CCNETCsvCtxSrcPassage): 176 | return ( 177 | zlib.compress(ctxt.text.encode()), 178 | zlib.compress(ctxt.title.encode()), 179 | zlib.compress(ctxt.url.encode()), 180 | ctxt.is_wiki, 181 | ) 182 | 183 | 184 | # TODO: create multiple files retriever data? 
185 | class CCNETMultiCsvCtxSrc(ContextSource, WaferPreprocessor): 186 | def __init__( 187 | self, 188 | cfg: DictConfig, 189 | tensorizer: Tensorizer, 190 | files: str, 191 | id_prefix: str = None, 192 | selector: DictConfig = None, 193 | ): 194 | WaferPreprocessor.__init__( 195 | self, 196 | cfg=cfg, 197 | tensorizer=tensorizer, 198 | ) 199 | if selector: 200 | self.selector = hydra.utils.instantiate(selector) 201 | else: 202 | self.selector = DEFAULT_SELECTOR 203 | 204 | # TODO remove or make configurable 205 | self.chunk_id__to__url__map = None 206 | 207 | def lazy_pickle_dict(key): 208 | if self.chunk_id__to__url__map is None: 209 | with open("/checkpoint/piktus/sphere/chunk_id_url.pkl", "rb") as f: 210 | self.chunk_id__to__url__map = pickle.load(f) 211 | return self.chunk_id__to__url__map[key] 212 | 213 | self.ctx_srcs = list() 214 | if files: 215 | for file in files: 216 | self.ctx_srcs.append( 217 | CCNETCsvCtxSrc( 218 | chunk_id__to__url__map=lazy_pickle_dict, 219 | cfg=cfg, 220 | tensorizer=tensorizer, 221 | file=file, 222 | id_prefix=id_prefix, 223 | selector=selector, 224 | ) 225 | ) 226 | 227 | def __len__(self): 228 | size = 0 229 | for ctx_src in self.ctx_srcs: 230 | size += len(ctx_src) 231 | return size 232 | 233 | def load_data_to(self, ctxs: Dict[object, WaferPassage], start=0, end=sys.maxsize): 234 | dataset_start = 0 235 | dataset_end = 0 236 | for ctx_src in self.ctx_srcs: 237 | dataset_end += len(ctx_src) 238 | if start < dataset_end: 239 | if end <= dataset_end: 240 | ctx_src.load_data_to(ctxs, start - dataset_start, end - dataset_start) 241 | break 242 | else: 243 | ctx_src.load_data_to(ctxs, start - dataset_start, dataset_end - dataset_start) 244 | start = dataset_end 245 | dataset_start = dataset_end 246 | 247 | def get_meta(self, ctxt: CCNETCsvCtxSrcPassage): 248 | return self.ctx_srcs[0].get_meta(ctxt) 249 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/dataset/training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
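# NOTE: JsonWaferDataset below reads DPR-style JSON training files; each record looks
# roughly like
#
#     {
#       "question": "<title> [SEP] <section> [SEP] <text containing [CIT]>",
#       "positive_ctxs": [{"text": ..., "title": ..., "url": ..., "passage_id": ...}],
#       "negative_ctxs": [...], "negative_doc_ctxs": [...],
#       "hard_negative_ctxs": [...], "hard_negative_doc_ctxs": [...]
#     }
#
# Records without positive contexts are filtered out. create_model_input() trims each
# negative pool to the sizes configured in train_iterator/valid_iterator, then removes as
# many negatives as there are positives (so the total number of contexts per instance is
# determined by the configured number of negatives) and concatenates all contexts into a
# single batch tensor.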
6 | 7 | import logging 8 | import os 9 | import random 10 | import socket 11 | import tempfile 12 | from typing import List, Union 13 | 14 | import filelock 15 | import torch 16 | from omegaconf import DictConfig 17 | 18 | from dpr.models.biencoder import QueryContextsBatch 19 | from dpr.dataset.input_transform import SchemaPreprocessor, WaferQueryContextsRawText, WaferPreprocessor 20 | from dpr.dataset.utils import find_or_download_files 21 | from dpr.utils.data_utils import read_data_from_json_files, Dataset, Tensorizer 22 | from dpr.utils.dist_utils import infer_slurm_init 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class JsonDataset(Dataset): 28 | def __init__( 29 | self, 30 | file: str, 31 | ): 32 | super().__init__() 33 | self.file = file 34 | self.data_files = [] 35 | 36 | def load_data(self, start_pos: int = -1, end_pos: int = -1, sharding_fn=None): 37 | # lock and wait for each process to load its shard to avoid swapping in DDP training on single nodes 38 | hostname = socket.gethostbyaddr(socket.gethostname())[0] 39 | max_concurrency = 4 40 | _, local_rank, world_size, _ = infer_slurm_init() 41 | lock = filelock.FileLock( 42 | os.path.join(f"{tempfile.gettempdir()}", f".{hostname}.lock.{int(local_rank % max_concurrency)}") 43 | ) 44 | with lock: 45 | if not self.data: 46 | self._load_all_data() 47 | start_pos, end_pos = sharding_fn(len(self.data)) 48 | if start_pos >= 0 and end_pos >= 0: 49 | logger.info("Selecting subset range from %d to %d", start_pos, end_pos) 50 | self.data = self.data[start_pos:end_pos] 51 | 52 | def _load_all_data(self): 53 | self.data_files = find_or_download_files(self.file) 54 | logger.info("Data files: %s", self.data_files) 55 | data = read_data_from_json_files(self.data_files) 56 | # filter those without positive ctx 57 | self.data = [r for r in data if len(r["positive_ctxs"]) > 0] 58 | logger.info("Total cleaned data size: %d", len(self.data)) 59 | 60 | def create_model_input( 61 | self, 62 | samples: List[WaferQueryContextsRawText], 63 | dataset: Union[Dataset, SchemaPreprocessor], 64 | cfg, 65 | is_training: bool, 66 | ): 67 | raise NotImplementedError 68 | 69 | def prepare_model_inputs(self, batch: dict): 70 | raise NotImplementedError 71 | 72 | 73 | class JsonWaferDataset(JsonDataset, WaferPreprocessor): 74 | def __init__( 75 | self, 76 | cfg: DictConfig, 77 | tensorizer: Tensorizer, 78 | file: str, 79 | shuffle_negatives: bool = False, 80 | ): 81 | JsonDataset.__init__( 82 | self, 83 | file=file, 84 | ) 85 | WaferPreprocessor.__init__( 86 | self, 87 | cfg=cfg, 88 | tensorizer=tensorizer, 89 | ) 90 | self.shuffle_negatives = shuffle_negatives 91 | 92 | def __getitem__(self, index) -> WaferQueryContextsRawText: 93 | json_sample: dict = self.data[index] 94 | r = WaferQueryContextsRawText() 95 | r.query = json_sample["question"] 96 | 97 | positive_ctxs = json_sample["positive_ctxs"] 98 | 99 | negative_ctxs = json_sample["negative_ctxs"] if "negative_ctxs" in json_sample else [] 100 | negative_doc_ctxs = json_sample["negative_doc_ctxs"] if "negative_doc_ctxs" in json_sample else [] 101 | hard_negative_ctxs = json_sample["hard_negative_ctxs"] if "hard_negative_ctxs" in json_sample else [] 102 | hard_negative_doc_ctxs = ( 103 | json_sample["hard_negative_doc_ctxs"] if "hard_negative_doc_ctxs" in json_sample else [] 104 | ) 105 | 106 | for ctx in positive_ctxs + negative_ctxs + negative_doc_ctxs + hard_negative_ctxs + hard_negative_doc_ctxs: 107 | if "title" not in ctx: 108 | ctx["title"] = None 109 | 110 | r.positive_passages = 
[self.passage_struct_from_dict(ctx) for ctx in positive_ctxs] 111 | r.negative_passages = [self.passage_struct_from_dict(ctx) for ctx in negative_ctxs] 112 | r.negative_doc_passages = [self.passage_struct_from_dict(ctx) for ctx in negative_doc_ctxs] 113 | r.hard_negative_passages = [self.passage_struct_from_dict(ctx) for ctx in hard_negative_ctxs] 114 | r.hard_negative_doc_passages = [self.passage_struct_from_dict(ctx) for ctx in hard_negative_doc_ctxs] 115 | return r 116 | 117 | def create_model_input( 118 | self, 119 | samples: List[WaferQueryContextsRawText], 120 | dataset: Union[Dataset, SchemaPreprocessor], 121 | cfg, 122 | is_training: bool, 123 | ) -> QueryContextsBatch: 124 | 125 | question_tensors = [] 126 | ctx_tensors = [] 127 | positive_and_negative_ctx_lens = [] 128 | neg_ctx_indices = [] 129 | 130 | for sample in samples: 131 | 132 | question = sample.query 133 | 134 | positive_ctx = sample.positive_passages 135 | neg_ctxs = sample.negative_passages 136 | neg_doc_ctxs = sample.negative_doc_passages 137 | hard_neg_ctxs = sample.hard_negative_passages 138 | hard_neg_doc_ctxs = sample.hard_negative_doc_passages 139 | 140 | # 141 | # Set the number of positive and negative contexts 142 | # 143 | 144 | if is_training: 145 | 146 | num_positives = cfg.train_iterator.num_positives 147 | other_negatives = cfg.train_iterator.other_negatives 148 | other_doc_negatives = cfg.train_iterator.other_doc_negatives 149 | hard_negatives = cfg.train_iterator.hard_negatives 150 | 151 | else: 152 | 153 | # if validation then reuse the train config if nothing is defined 154 | 155 | if cfg.valid_iterator.num_positives is not None: 156 | num_positives = cfg.valid_iterator.num_positives 157 | else: 158 | num_positives = cfg.train_iterator.num_positives 159 | 160 | if cfg.valid_iterator.other_negatives is not None: 161 | other_negatives = cfg.valid_iterator.other_negatives 162 | else: 163 | other_negatives = cfg.train_iterator.other_negatives 164 | 165 | if cfg.valid_iterator.other_doc_negatives is not None: 166 | other_doc_negatives = cfg.valid_iterator.other_doc_negatives 167 | else: 168 | other_doc_negatives = cfg.train_iterator.other_doc_negatives 169 | 170 | if cfg.valid_iterator.hard_negatives is not None: 171 | hard_negatives = cfg.valid_iterator.hard_negatives 172 | else: 173 | hard_negatives = cfg.train_iterator.hard_negatives 174 | 175 | # 176 | # Trim the available contexts down to the defined number of contexts 177 | # 178 | 179 | # positive contexts 180 | 181 | positive_ctx = positive_ctx[0:num_positives] 182 | 183 | # randomly sampled negative single contexts (one contexts from multiple documents) 184 | 185 | if is_training and cfg.train_iterator.shuffle_negatives and len(neg_ctxs) > 0: 186 | random.shuffle(neg_ctxs) 187 | if cfg.model_class in { 188 | "BERTRerankerCrossEncoder", 189 | "ColbertRerankerBiEncoder", 190 | }: 191 | neg_ctxs = ( 192 | neg_ctxs * 5 193 | ) # TODO: hack for WAFER reranker dev to make up for instances with too few negatives 194 | neg_ctxs = neg_ctxs[0:other_negatives] 195 | else: 196 | neg_ctxs = neg_ctxs[0:other_negatives] 197 | 198 | # randomly sampled negative single contexts (multiple contexts from one document) 199 | 200 | if is_training and cfg.train_iterator.shuffle_negatives and len(neg_doc_ctxs) > 0: 201 | random.shuffle(neg_doc_ctxs) 202 | neg_doc_ctxs = neg_doc_ctxs[0:other_doc_negatives] 203 | 204 | # hard sampled negative single contexts (one contexts from multiple documents) 205 | 206 | if is_training and cfg.train_iterator.shuffle_negatives and 
len(hard_neg_ctxs) > 0: 207 | random.shuffle(hard_neg_ctxs) 208 | hard_neg_ctxs = hard_neg_ctxs[0:hard_negatives] 209 | 210 | # hard sampled negative document contexts (multiple contexts from one document) 211 | 212 | if is_training and cfg.train_iterator.shuffle_negatives and len(hard_neg_doc_ctxs) > 0: 213 | random.shuffle(hard_neg_doc_ctxs) 214 | hard_neg_doc_ctxs = hard_neg_doc_ctxs[0 : cfg.train_iterator.hard_doc_negatives] 215 | 216 | # 217 | # Trim all negatives (other, hard, ...) by the number of positives 218 | # Thus the total number of contexts per instance is defined by the total number of negative contexts 219 | # 220 | ns = [neg_ctxs, neg_doc_ctxs, hard_neg_ctxs, hard_neg_doc_ctxs] 221 | p = len(positive_ctx) 222 | while True: 223 | for i in range(len(ns)): 224 | if p > 0 and len(ns[i]) > 0: 225 | ns[i] = ns[i][:-1] 226 | p -= 1 227 | if p == 0 or sum([len(n) for n in ns]) == 0: 228 | break 229 | if p == 0 or sum([len(n) for n in ns]) == 0: 230 | break 231 | 232 | # 233 | # Combine all contexts 234 | # 235 | 236 | neg_ctxs, neg_doc_ctxs, hard_neg_ctxs, hard_neg_doc_ctxs = ns 237 | all_ctxs = positive_ctx + neg_ctxs + neg_doc_ctxs + hard_neg_ctxs + hard_neg_doc_ctxs 238 | 239 | # 240 | # Compute offsets 241 | # 242 | 243 | current_ctxs_len = len(ctx_tensors) 244 | positive_ctx_len = len(positive_ctx) 245 | neg_ctxs_len = len(neg_ctxs) + len(neg_doc_ctxs) + len(hard_neg_ctxs) + len(hard_neg_doc_ctxs) 246 | 247 | negatives_start_idx = positive_ctx_len 248 | negatives_end_idx = positive_ctx_len + neg_ctxs_len 249 | 250 | # 251 | # Preprocess contexts 252 | # 253 | 254 | sample_ctxs_tensors = [dataset.preprocess_passage(ctx, is_training) for ctx in all_ctxs] 255 | 256 | # 257 | # Finalize 258 | # 259 | 260 | ctx_tensors.extend(sample_ctxs_tensors) 261 | 262 | positive_and_negative_ctx_lens.append([positive_ctx_len, neg_ctxs_len]) 263 | neg_ctx_indices.append( 264 | [ 265 | i 266 | for i in range( 267 | current_ctxs_len + negatives_start_idx, 268 | current_ctxs_len + negatives_end_idx, 269 | ) 270 | ] 271 | ) 272 | 273 | question_tensors.append(dataset.preprocess_query(question, is_training)) 274 | 275 | # 276 | # Create batch tensor 277 | # 278 | 279 | ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) 280 | questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) 281 | 282 | # 283 | # Create batch tensor 284 | # 285 | 286 | if cfg.n_gpu > 0: 287 | divider = cfg.n_gpu 288 | else: 289 | divider = 1 290 | truncate_by = ctxs_tensor.size(0) % divider 291 | if truncate_by > 0 and cfg.model_class in { 292 | "BERTRerankerCrossEncoder", 293 | "ColbertRerankerBiEncoder", 294 | }: 295 | ctxs_tensor = ctxs_tensor[:-truncate_by] 296 | 297 | ctx_segments = torch.zeros_like(ctxs_tensor) 298 | question_segments = torch.zeros_like(questions_tensor) 299 | 300 | ctxt_total_len = sum([sum(l) for l in positive_and_negative_ctx_lens]) 301 | positive_ctx_indices = [] 302 | offset = 0 303 | 304 | for positive_and_negative_ctx_len in positive_and_negative_ctx_lens: 305 | pos_len, neg_len = positive_and_negative_ctx_len 306 | next_offset = offset + sum(positive_and_negative_ctx_len) 307 | if ( 308 | cfg.model_class 309 | in { 310 | "BERTRerankerCrossEncoder", 311 | "ColbertRerankerBiEncoder", 312 | } 313 | or cfg.model_class == "DocumentBiEncoder" 314 | and self.cfg.distributed_world_size > 1 315 | ): 316 | positive_ctx_indices.append([True] * pos_len + [False] * neg_len) 317 | else: 318 | positive_ctx_indices.append( 319 | [False] * offset + [True] * pos_len + 
[False] * neg_len + [False] * (ctxt_total_len - next_offset) 320 | ) 321 | offset = next_offset 322 | 323 | positive_ctx_indices = torch.tensor(positive_ctx_indices, dtype=bool) 324 | 325 | if truncate_by > 0 and cfg.model_class in { 326 | "BERTRerankerCrossEncoder", 327 | "ColbertRerankerBiEncoder", 328 | }: 329 | positive_ctx_indices = positive_ctx_indices[:, -truncate_by] 330 | 331 | return QueryContextsBatch( 332 | questions_tensor, 333 | question_segments, 334 | ctxs_tensor, 335 | ctx_segments, 336 | positive_ctx_indices, 337 | neg_ctx_indices, 338 | "question", 339 | ) 340 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import glob 8 | import os 9 | from typing import List 10 | 11 | 12 | def normalize_passage(ctx_text: str): 13 | ctx_text = ctx_text.replace("\n", " ").replace("’", "'") 14 | if ctx_text.startswith('"'): 15 | ctx_text = ctx_text[1:] 16 | if ctx_text.endswith('"'): 17 | ctx_text = ctx_text[:-1] 18 | return ctx_text 19 | 20 | 21 | def find_or_download_files(source_name) -> List[str]: 22 | if os.path.exists(source_name) or glob.glob(source_name): 23 | return glob.glob(source_name) 24 | # else: 25 | # # try to use data downloader 26 | # from dpr.data.download_data import download 27 | # return download(source_name) 28 | 29 | 30 | def get_file_len(filename): 31 | """ 32 | From https://stackoverflow.com/questions/845058/how-to-get-line-count-of-a-large-file-cheaply-in-python/68385697#68385697 33 | 34 | :param filename: 35 | :return: 36 | """ 37 | 38 | def _make_gen(reader): 39 | b = reader(2**16) 40 | while b: 41 | yield b 42 | b = reader(2**16) 43 | 44 | with open(filename, "rb") as f: 45 | count = sum(buf.count(b"\n") for buf in _make_gen(f.raw.read)) 46 | return count 47 | 48 | 49 | def resolve_file(cfg, file_name, error_not_exists=True): 50 | if file_name is None: 51 | if error_not_exists: 52 | print(f"File was 'None'.") 53 | exit() 54 | else: 55 | return None 56 | if not os.path.exists(file_name): 57 | if file_name in cfg.create_data.datasets: 58 | return cfg.create_data.datasets[file_name]["file"] 59 | else: 60 | if error_not_exists: 61 | print(f"File {file_name} does not exist and is not an entry in datasets.") 62 | exit() 63 | return file_name 64 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/dense_retriever.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
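For orientation, the following is a minimal usage sketch of the retriever classes defined in this module. It assumes you already have a loaded question encoder, a preprocessor-style qa_src (providing the tensorizer/selector the code below expects), a list of questions, and files with pre-encoded passage vectors; question_encoder, qa_src, questions, vector_files, and vector_dim are placeholder names, not part of this module, and a GPU is assumed because query encoding moves batches to CUDA.

import numpy as np

from dpr.dense_retriever import LocalFaissRetriever
from dpr.indexer.faiss_indexers import DenseFlatIndexer


def retrieve_top_docs(question_encoder, qa_src, questions, vector_files, vector_dim=768, top_k=100):
    # Build a local flat index from pre-encoded passage vectors.
    index = DenseFlatIndexer()
    index.init_index(vector_dim)
    retriever = LocalFaissRetriever(question_encoder, batch_size=32, qa_src=qa_src, index=index)
    retriever.index_encoded_data(vector_files, buffer_size=50000)
    # Encode the questions and search the index.
    query_vectors = retriever.generate_question_vectors(questions).numpy()
    return retriever.get_top_docs(query_vectors, top_k)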
7 | 8 | """ 9 | Command line tool to get dense results and validate them 10 | """ 11 | 12 | import logging 13 | import pickle 14 | import time 15 | from typing import List, Tuple, Union, Iterator 16 | 17 | import numpy as np 18 | import torch 19 | import tqdm 20 | from torch import Tensor as T 21 | from torch import nn 22 | 23 | from dpr.indexer.faiss_indexers import ( 24 | DenseIndexer, 25 | ) 26 | from dpr.models.biencoder import ( 27 | BiEncoder, 28 | ) 29 | from dpr.options import setup_logger 30 | from dpr.dataset.input_transform import SchemaPreprocessor 31 | from dpr.dataset.retrieval import ContextSource 32 | 33 | logger = logging.getLogger() 34 | setup_logger(logger) 35 | 36 | 37 | def generate_question_vectors( 38 | question_encoder: torch.nn.Module, 39 | qa_src: Union[ContextSource, SchemaPreprocessor], 40 | questions: List[str], 41 | bsz: int, 42 | ) -> T: 43 | n = len(questions) 44 | query_vectors = [] 45 | selector = qa_src.selector 46 | with torch.no_grad(): 47 | for j, batch_start in enumerate(range(0, n, bsz)): 48 | batch_questions = questions[batch_start : batch_start + bsz] 49 | batch_tensors = [qa_src.preprocess_query(q) for q in batch_questions] 50 | 51 | q_ids_batch = torch.stack(batch_tensors, dim=0).cuda() 52 | q_seg_batch = torch.zeros_like(q_ids_batch).cuda() 53 | q_attn_mask = qa_src.tensorizer.get_attn_mask(q_ids_batch) 54 | 55 | if selector: 56 | rep_positions = selector.get_positions(q_ids_batch, qa_src.tensorizer) 57 | 58 | _, out, _ = BiEncoder.get_representation( 59 | question_encoder, 60 | q_ids_batch, 61 | q_seg_batch, 62 | q_attn_mask, 63 | representation_token_pos=rep_positions, 64 | ) 65 | else: 66 | _, out, _ = question_encoder(q_ids_batch, q_seg_batch, q_attn_mask) 67 | 68 | query_vectors.extend(out.cpu().split(1, dim=0)) 69 | 70 | if len(query_vectors) % 100 == 0: 71 | logger.info("Encoded queries %d", len(query_vectors)) 72 | 73 | query_tensor = torch.cat(query_vectors, dim=0) 74 | logger.info("Total encoded queries tensor %s", query_tensor.size()) 75 | assert query_tensor.size(0) == len(questions) 76 | return query_tensor 77 | 78 | 79 | class DenseRetriever(object): 80 | def __init__( 81 | self, 82 | question_encoder: nn.Module, 83 | batch_size: int, 84 | qa_src: Union[ContextSource, SchemaPreprocessor], 85 | ): 86 | self.question_encoder = question_encoder 87 | self.batch_size = batch_size 88 | self.qa_src = qa_src 89 | 90 | def generate_question_vectors(self, questions: List[str]) -> T: 91 | 92 | bsz = self.batch_size 93 | self.question_encoder.eval() 94 | return generate_question_vectors( 95 | self.question_encoder, 96 | self.qa_src, 97 | questions, 98 | bsz, 99 | ) 100 | 101 | 102 | class LocalFaissRetriever(DenseRetriever): 103 | """ 104 | Does passage retrieving over the provided index and question encoder 105 | """ 106 | 107 | def __init__( 108 | self, 109 | question_encoder: nn.Module, 110 | batch_size: int, 111 | qa_src: SchemaPreprocessor, 112 | index: DenseIndexer, 113 | ): 114 | super().__init__(question_encoder, batch_size, qa_src) 115 | self.index = index 116 | 117 | def index_encoded_data( 118 | self, 119 | vector_files: List[str], 120 | buffer_size: int, 121 | path_id_prefixes: List = None, 122 | ): 123 | """ 124 | Indexes encoded passages takes form a list of files 125 | :param vector_files: file names to get passages vectors from 126 | :param buffer_size: size of a buffer (amount of passages) to send for the indexing at once 127 | :return: 128 | """ 129 | buffer = [] 130 | for i, item in 
enumerate(iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes)): 131 | buffer.append(item) 132 | if 0 < buffer_size == len(buffer): 133 | self.index.index_data(buffer) 134 | buffer = [] 135 | self.index.index_data(buffer) 136 | logger.info("Data indexing completed.") 137 | 138 | def get_top_docs(self, query_vectors: np.array, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]: 139 | """ 140 | Does the retrieval of the best matching passages given the query vectors batch 141 | :param query_vectors: 142 | :param top_docs: 143 | :return: 144 | """ 145 | time0 = time.time() 146 | results = self.index.search_knn(query_vectors, top_docs) 147 | logger.info("index search time: %f sec.", time.time() - time0) 148 | # self.index = None 149 | return results 150 | 151 | 152 | class DenseRPCRetriever(DenseRetriever): 153 | def __init__( 154 | self, 155 | question_encoder: nn.Module, 156 | qa_src: SchemaPreprocessor, 157 | batch_size: int, 158 | index_cfg_path: str, 159 | dim: int, 160 | use_l2_conversion: bool = False, 161 | nprobe: int = 256, 162 | ): 163 | # from distributed_faiss.client import IndexClient // using a copy because of a small bugfix TODO make PR to df 164 | from dpr.indexer.client import IndexClient 165 | from distributed_faiss.index_cfg import IndexCfg 166 | 167 | # super().__init__(question_encoder, batch_size, tensorizer) 168 | super().__init__(question_encoder, batch_size, qa_src) 169 | self.dim = dim 170 | self.index_id = "dr" 171 | self.nprobe = nprobe 172 | logger.info("Connecting to index server ...") 173 | self.index_client = IndexClient(index_cfg_path) 174 | self.use_l2_conversion = use_l2_conversion 175 | logger.info("Connected") 176 | 177 | def load_index(self, index_id): 178 | from distributed_faiss.index_cfg import IndexCfg 179 | 180 | self.index_id = index_id 181 | 182 | logger.info("Loading remote index %s", index_id) 183 | 184 | idx_cfg = IndexCfg() 185 | idx_cfg.nprobe = self.nprobe 186 | if self.use_l2_conversion: 187 | idx_cfg.metric = "l2" 188 | 189 | self.index_client.load_index(self.index_id, cfg=idx_cfg, force_reload=False) 190 | logger.info("Index loaded") 191 | self._wait_index_ready(index_id) 192 | 193 | def index_encoded_data( 194 | self, 195 | vector_files: List[str], 196 | buffer_size: int = 1000, 197 | path_id_prefixes: List = None, 198 | ): 199 | """ 200 | Indexes encoded passages takes form a list of files 201 | :param vector_files: file names to get passages vectors from 202 | :param buffer_size: size of a buffer (amount of passages) to send for the indexing at once 203 | :return: 204 | """ 205 | from distributed_faiss.index_cfg import IndexCfg 206 | 207 | buffer = [] 208 | idx_cfg = IndexCfg() 209 | 210 | idx_cfg.dim = self.dim 211 | logger.info("Index train num=%d", idx_cfg.train_num) 212 | idx_cfg.faiss_factory = "flat" 213 | index_id = self.index_id 214 | self.index_client.create_index(index_id, idx_cfg) 215 | 216 | def send_buf_data(buffer, index_client): 217 | buffer_vectors = [np.reshape(encoded_item[1], (1, -1)) for encoded_item in buffer] 218 | buffer_vectors = np.concatenate(buffer_vectors, axis=0) 219 | meta = [encoded_item[0] for encoded_item in buffer] 220 | index_client.add_index_data(index_id, buffer_vectors, meta) 221 | 222 | for i, item in enumerate(iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes)): 223 | buffer.append(item) 224 | if 0 < buffer_size == len(buffer): 225 | send_buf_data(buffer, self.index_client) 226 | buffer = [] 227 | if buffer: 228 | send_buf_data(buffer, 
self.index_client) 229 | logger.info("Embeddings sent.") 230 | self._wait_index_ready(index_id) 231 | 232 | def get_top_docs( 233 | self, query_vectors: np.array, top_docs: int = 100, search_batch: int = 512, filter_pos=3, filter_value=True 234 | ) -> List[Tuple[List[object], List[float]]]: 235 | """ 236 | Does the retrieval of the best matching passages given the query vectors batch 237 | :param query_vectors: 238 | :param top_docs: 239 | :return: 240 | """ 241 | if self.use_l2_conversion: 242 | aux_dim = np.zeros(len(query_vectors), dtype="float32") 243 | query_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 244 | logger.info("query_hnsw_vectors %s", query_vectors.shape) 245 | self.index_client.cfg.metric = "l2" 246 | 247 | results = [] 248 | for i in tqdm.tqdm(range(0, query_vectors.shape[0], search_batch)): 249 | time0 = time.time() 250 | query_batch = query_vectors[i : i + search_batch] 251 | logger.info("query_batch: %s", query_batch.shape) 252 | # scores, meta = self.index_client.search(query_batch, top_docs, self.index_id) 253 | 254 | scores, meta = self.index_client.search_with_filter( 255 | query_batch, top_docs, self.index_id, filter_pos=filter_pos, filter_value=filter_value 256 | ) 257 | 258 | logger.info("index search time: %f sec.", time.time() - time0) 259 | results.extend([(meta[q], scores[q]) for q in range(len(scores))]) 260 | return results 261 | 262 | def _wait_index_ready(self, index_id: str): 263 | from distributed_faiss.index_state import IndexState 264 | 265 | # TODO: move this method into IndexClient class 266 | while self.index_client.get_state(index_id) != IndexState.TRAINED: 267 | time.sleep(10) 268 | logger.info( 269 | "Remote Index is ready. Index data size %d", 270 | self.index_client.get_ntotal(index_id), 271 | ) 272 | 273 | 274 | def get_all_passages(ctx_sources): 275 | try: 276 | 277 | all_passages = {} 278 | for ctx_src in ctx_sources: 279 | ctx_src.load_data_to(all_passages) 280 | logger.info("Loaded ctx data: %d", len(all_passages)) 281 | 282 | if len(all_passages) == 0: 283 | raise RuntimeError("No passages data found. Please specify ctx_file param properly.") 284 | 285 | return all_passages 286 | 287 | except NotImplementedError: 288 | 289 | for ctx_src in ctx_sources: 290 | ctx_src.load_data() 291 | logger.info(f"Loaded ctx data: {ctx_src}") 292 | 293 | 294 | def iterate_encoded_files(vector_files: list, path_id_prefixes: List = None) -> Iterator[Tuple]: 295 | for i, file in enumerate(vector_files): 296 | logger.info("Reading file %s", file) 297 | id_prefix = None 298 | if path_id_prefixes: 299 | id_prefix = path_id_prefixes[i] 300 | with open(file, "rb") as reader: 301 | doc_vectors = pickle.load(reader) 302 | for doc in doc_vectors: 303 | doc = list(doc) 304 | if id_prefix and not str(doc[0]).startswith(id_prefix): 305 | doc[0] = id_prefix + str(doc[0]) 306 | yield doc 307 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/indexer/client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import itertools 4 | 5 | # Copyright (c) Facebook, Inc. and its affiliates. 6 | # 7 | # This source code is licensed under the MIT license found in the 8 | # LICENSE file in the root directory of this source tree. 
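As orientation for the IndexClient defined below, here is a hedged sketch of a typical client session. It assumes distributed-faiss servers are already running and have written the server list file (first line: number of servers, then one host,port pair per line), and that an index was previously populated under the id "dr", the id used by DenseRPCRetriever; server_list_path and the query array are placeholders.

import numpy as np

from distributed_faiss.index_cfg import IndexCfg
from dpr.indexer.client import IndexClient


def search_remote_index(server_list_path: str, query_vectors: np.ndarray, top_k: int = 100):
    client = IndexClient(server_list_path)
    cfg = IndexCfg()
    cfg.nprobe = 256  # same default that DenseRPCRetriever uses
    client.load_index("dr", cfg=cfg, force_reload=False)
    scores, meta = client.search(query_vectors, top_k, "dr")
    client.close()
    return scores, meta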
9 | import logging 10 | import os 11 | import random 12 | import time 13 | from multiprocessing.dummy import Pool as ThreadPool 14 | from typing import List, Tuple 15 | 16 | import faiss 17 | import numpy as np 18 | 19 | from distributed_faiss.index_cfg import IndexCfg 20 | from distributed_faiss import rpc 21 | from distributed_faiss.index_state import IndexState 22 | 23 | logger = logging.getLogger() 24 | 25 | from typing import Optional 26 | 27 | 28 | class ResultHeap: 29 | """Accumulate query results from a sliced dataset. The final result will 30 | be in self.D, self.I. Using faiss.float_maxheap_array_t()""" 31 | 32 | def __init__(self, nq, k): 33 | "nq: number of query vectors, k: number of results per query" 34 | self.I = np.zeros((nq, k), dtype="int64") 35 | self.D = np.zeros((nq, k), dtype="float32") 36 | self.nq, self.k = nq, k 37 | heaps = faiss.float_maxheap_array_t() 38 | 39 | heaps.k = k 40 | heaps.nh = nq 41 | heaps.val = faiss.swig_ptr(self.D) 42 | heaps.ids = faiss.swig_ptr(self.I) 43 | heaps.heapify() 44 | self.heaps = heaps 45 | 46 | def add_result(self, D, I): 47 | """D, I do not need to be in a particular order (heap or sorted)""" 48 | assert D.shape == (self.nq, self.k) 49 | assert I.shape == (self.nq, self.k) 50 | self.heaps.addn_with_ids(self.k, faiss.swig_ptr(D), faiss.swig_ptr(I), self.k) 51 | 52 | def finalize(self): 53 | self.heaps.reorder() 54 | 55 | 56 | class IndexClient: 57 | """ 58 | Manages a set of distance sub-indexes. The sub_indexes search a 59 | subset of the inverted lists. Searches are merged afterwards 60 | """ 61 | 62 | def __init__(self, server_list_path: str, cfg_path: Optional[str] = None): 63 | """connect to a series of (host, port) pairs""" 64 | machine_ports = IndexClient.read_server_list(server_list_path) 65 | self.sub_indexes = IndexClient.setup_connection(machine_ports) 66 | self.num_indexes = len(self.sub_indexes) 67 | 68 | # index_rank_to_id is a map between logical node id and sub-indexes list id 69 | # it is useful for debuggind 70 | # might be useful for a better load balancing upon index restart with failed/lagged behind nodes 71 | index_ranks = [idx.get_rank() for idx in self.sub_indexes] 72 | logger.info("index_ranks %s", index_ranks) 73 | self.index_rank_to_id = {index_rank: index_id for index_id, index_rank in enumerate(index_ranks)} 74 | logger.info("index_rank_to_id %s", self.index_rank_to_id) 75 | 76 | # pool of threads. Each thread manages one sub-index. 77 | self.pool = ThreadPool(self.num_indexes) 78 | self.verbose = False 79 | self.cur_server_ids = {} 80 | 81 | random.seed(time.time()) 82 | self.cfg = IndexCfg.from_json(cfg_path) if cfg_path is not None else None 83 | 84 | @staticmethod 85 | def read_server_list( 86 | server_list_path, 87 | initial_timeout=0.1, 88 | backoff_factor=1.5, 89 | total_max_timeout=7200, 90 | ) -> List[Tuple[str, int]]: 91 | time_waited = 0 92 | while True: 93 | with open(server_list_path) as f: 94 | res = [] # list of [(hostname, port), ...] 95 | for idx, line in enumerate(f): 96 | if idx == 0: 97 | num_servers = int(line) 98 | continue 99 | res.append((str(line.split(",")[0]), int(line.split(",")[1]))) 100 | 101 | msg = f"{num_servers} != {len(res)} in server list {server_list_path}." 
102 | if num_servers != len(res): 103 | print(msg + f" Waiting {round(initial_timeout * 100) / 100} seconds for servers to load...") 104 | time.sleep(initial_timeout) 105 | 106 | if time_waited + initial_timeout >= total_max_timeout: 107 | break 108 | time_waited += initial_timeout 109 | initial_timeout *= backoff_factor 110 | 111 | assert num_servers == len(res), msg + f" Timed out after waiting {round(time_waited * 100) / 100} seconds" 112 | return res 113 | 114 | @staticmethod 115 | def setup_connection(machine_ports) -> List[rpc.Client]: 116 | sub_indexes = [] 117 | for idx, machine_port in enumerate(machine_ports): 118 | sub_indexes.append(rpc.Client(idx, machine_port[0], machine_port[1], False)) 119 | return sub_indexes 120 | 121 | def drop_index(self, index_id: str): 122 | self.pool.map(lambda idx: idx.drop_index(index_id), self.sub_indexes) 123 | 124 | def save_index(self, index_id: str): 125 | self.pool.map(lambda idx: idx.save_index(index_id), self.sub_indexes) 126 | 127 | def load_index( 128 | self, 129 | index_id: str, 130 | cfg: Optional[IndexCfg] = None, 131 | force_reload: bool = True, 132 | ) -> bool: 133 | def setup_cfg(cfg: Optional[IndexCfg]): 134 | if cfg is None: 135 | config_paths = self.pool.map(lambda idx: idx.get_config_path(index_id), self.sub_indexes) 136 | if len(config_paths) > 0 and os.path.isfile(config_paths[0]): 137 | cfg = IndexCfg.from_json(config_paths[0]) 138 | else: 139 | cfg = IndexCfg() 140 | return cfg 141 | 142 | if force_reload: 143 | logger.info("Forced index reload") 144 | self.pool.map(lambda idx: idx.drop_index(index_id), self.sub_indexes) 145 | all_loaded = self.pool.map(lambda idx: idx.load_index(index_id, cfg), self.sub_indexes) 146 | # TODO: remove cfg as client instance attribute, it should be specific index only cfg 147 | self.cfg = setup_cfg(cfg) 148 | logger.info(f"Index cfg {self.cfg}") 149 | 150 | if all(all_loaded): 151 | logger.info(f"The index {index_id} is loaded.") 152 | return True 153 | if any(all_loaded): 154 | logger.warning(f"Some server nodes can't load index: {all_loaded}") 155 | return False 156 | 157 | def create_index(self, index_id: str, cfg: Optional[IndexCfg] = None): 158 | if cfg is not None: 159 | self.cfg = cfg 160 | if self.cfg is None: 161 | self.cfg = IndexCfg() 162 | return self.pool.map(lambda idx: idx.create_index(index_id, self.cfg), self.sub_indexes) 163 | 164 | def add_index_data( 165 | self, 166 | index_id: str, 167 | embeddings: np.array, 168 | metadata: Optional[List[object]] = None, 169 | train_async_if_triggered: bool = True, 170 | ) -> None: 171 | """ 172 | Randomly select index of server to write data to first time 173 | writing to it. later use round-robin to select index server 174 | to balance the load. 
175 | """ 176 | if index_id not in self.cur_server_ids: 177 | self.cur_server_ids[index_id] = random.randint(0, self.num_indexes - 1) 178 | cur_server_id = self.cur_server_ids[index_id] 179 | self.sub_indexes[cur_server_id].add_index_data(index_id, embeddings, metadata, train_async_if_triggered) 180 | self.cur_server_ids[index_id] = (self.cur_server_ids[index_id] + 1) % self.num_indexes 181 | 182 | def sync_train(self, index_id: str) -> None: 183 | self.pool.map(lambda idx: idx.sync_train(index_id), self.sub_indexes) 184 | 185 | def async_train(self, index_id: str): 186 | self.pool.map(lambda idx: idx.sync_train(index_id), self.sub_indexes) 187 | 188 | def search(self, query, topk: int, index_id: str, return_embeddings: bool = False) -> Tuple[np.ndarray, List]: 189 | """Call idx.search on each rpc client and then aggregates results using heap.""" 190 | 191 | q_size = query.shape[0] 192 | maximize_metric: bool = self.cfg.metric == "dot" 193 | results = self.pool.imap(lambda idx: idx.search(index_id, query, topk, return_embeddings), self.sub_indexes) 194 | return self._aggregate_results(results, topk, q_size, maximize_metric, return_embeddings) 195 | 196 | # TODO: make filter a generic custom function that takes meta and sample's score and return bool value 197 | def search_with_filter( 198 | self, 199 | query: np.array, 200 | top_k: int, 201 | index_id: str, 202 | filter_pos: int = -1, 203 | filter_value=None, 204 | ) -> Tuple[np.array, List[List[object]]]: 205 | 206 | # TODO: get from cfg ? 207 | filter_top_factor = 3 # search for x times more results 208 | actual_top_k = filter_top_factor * top_k if filter_pos >= 0 else top_k 209 | (scores, meta) = self.search(query, actual_top_k, index_id) 210 | 211 | if filter_pos < 0: 212 | return scores, meta 213 | 214 | def _do_filter(scores: np.array, results_meta: List[List[Tuple]]): 215 | re_query_ids = [] 216 | new_results = [] 217 | new_scores = [] 218 | 219 | for i, meta_list in enumerate(results_meta): 220 | sample_filtered_meta = [] 221 | sample_filtered_scores = [] 222 | for j, meta in enumerate(meta_list): 223 | if not meta: 224 | logger.warning("No meta for j=%d, score=%s", j, scores[i, j]) 225 | continue 226 | if len(meta) > filter_pos and meta[filter_pos] != filter_value: 227 | sample_filtered_meta.append(meta) 228 | sample_filtered_scores.append(scores[i, j]) 229 | if len(sample_filtered_meta) >= top_k: 230 | break 231 | 232 | if len(sample_filtered_meta) < top_k: 233 | re_query_ids.append(i) 234 | 235 | new_results.append(sample_filtered_meta) 236 | if len(sample_filtered_scores) > 0: 237 | new_scores.append(np.concatenate([s.reshape(-1, 1) for s in sample_filtered_scores], axis=0)) 238 | # TODO: filtered scores list may be of different lengths so we can't concatenate them here generally 239 | # new_scores = np.concatenate(new_scores, axis=0) 240 | return new_scores, new_results, re_query_ids 241 | 242 | new_scores, new_results_meta, re_query_ids = _do_filter(scores, meta) 243 | logger.info(f"{len(re_query_ids)} samples return less than {top_k} results after filtering") 244 | # TODO: search again with larger top_k for queries in re_query_ids 245 | 246 | return new_scores, new_results_meta 247 | 248 | @staticmethod 249 | def _aggregate_results( 250 | results: List[Tuple], 251 | topk: int, 252 | q_size: int, 253 | maximize_metric: bool, 254 | return_embeddings: bool, 255 | ): 256 | meta = [] 257 | embs = [] 258 | cur_idx = 0 259 | 260 | def to_matrix(l, n): 261 | return [l[i : i + n] for i in range(0, len(l), n)] 262 | 263 | res_heap = 
ResultHeap(q_size, topk) 264 | for DI, MetaI, e in results: 265 | # Two hacks: 266 | # 1) for DOT, search for -D because we want to find max dot product, not min 267 | # 2) create new indexed to later map metadata to the best indexes 268 | merged_meta = list(itertools.chain(*MetaI)) 269 | meta.extend(merged_meta) 270 | if return_embeddings: 271 | merged_embs = list(itertools.chain(*e)) 272 | embs.extend(merged_embs) 273 | Ii = np.reshape(np.arange(cur_idx, cur_idx + q_size * topk), (q_size, topk)) 274 | if maximize_metric: 275 | res_heap.add_result(-DI, Ii) 276 | else: 277 | res_heap.add_result(DI, Ii) 278 | cur_idx += q_size * topk 279 | res_heap.finalize() 280 | ids = np.reshape(res_heap.I, (-1,)).tolist() 281 | selected_meta = [meta[i] for i in ids] 282 | if return_embeddings: 283 | selected_embs = [embs[i] for i in ids] 284 | 285 | return ( 286 | ( 287 | res_heap.D, 288 | to_matrix(selected_meta, res_heap.D.shape[1]), 289 | to_matrix(selected_embs, res_heap.D.shape[1]), 290 | ) 291 | if return_embeddings 292 | else (res_heap.D, to_matrix(selected_meta, res_heap.D.shape[1])) 293 | ) 294 | 295 | def get_centroids(self, index_id: str): 296 | return self.pool.map(lambda idx: idx.get_centroids(index_id), self.sub_indexes) 297 | 298 | def set_nprobe(self, index_id: str, nprobe: int): 299 | return self.pool.map(lambda idx: idx.set_nprobe(index_id, nprobe), self.sub_indexes) 300 | 301 | def get_state(self, index_id: str) -> IndexState: 302 | states = self.pool.map(lambda idx: idx.get_state(index_id), self.sub_indexes) 303 | logging.info("Index nodes states %s", states) 304 | return IndexState.get_aggregated_states(states) 305 | 306 | def add_buffer_to_index( 307 | self, 308 | index_id: str, 309 | ): 310 | self.pool.map(lambda idx: idx.add_buffer_to_index(index_id), self.sub_indexes) 311 | 312 | def get_ntotal(self, index_id: str) -> None: 313 | return sum(self.pool.map(lambda idx: idx.get_ntotal(index_id), self.sub_indexes)) 314 | 315 | def get_ids(self, index_id: str) -> set: 316 | id_set_list = self.pool.map(lambda idx: idx.get_ids(index_id), self.sub_indexes) 317 | for i in range(len(id_set_list)): 318 | logging.info("ids i=%s len=%s", i, len(id_set_list[i])) 319 | return set().union(*id_set_list) 320 | 321 | def set_omp_num_threads(self, num_threads: int) -> None: 322 | self.pool.map(lambda idx: idx.set_omp_num_threads(num_threads), self.sub_indexes) 323 | 324 | def close(self): 325 | [index_conn.close() for index_conn in self.sub_indexes] 326 | 327 | def get_num_servers(self): 328 | return self.num_indexes 329 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/indexer/faiss_indexers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
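To make the indexer API below concrete, here is a small self-contained sketch that builds a flat index from synthetic data, searches it, and serializes it to disk; the ids, dimensions, and output path are illustrative only.

import numpy as np

from dpr.indexer.faiss_indexers import DenseFlatIndexer

dim = 768
# (passage_id, vector) pairs, as produced by a context encoder
data = [(f"wiki:{i}", np.random.rand(dim).astype("float32")) for i in range(1000)]

indexer = DenseFlatIndexer(buffer_size=50000)
indexer.init_index(dim)
indexer.index_data(data)  # adds vectors in buffered batches and records the id mapping

queries = np.random.rand(4, dim).astype("float32")
results = indexer.search_knn(queries, top_docs=10)  # one (db_ids, scores) pair per query
indexer.serialize("/tmp/flat_index")  # writes /tmp/flat_index.index.dpr and /tmp/flat_index.index_meta.dpr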
7 | 8 | """ 9 | FAISS-based index components for dense retriever 10 | """ 11 | 12 | import faiss 13 | import logging 14 | import numpy as np 15 | import os 16 | import pickle 17 | 18 | from typing import List, Tuple 19 | 20 | logger = logging.getLogger() 21 | 22 | 23 | class DenseIndexer(object): 24 | def __init__(self, buffer_size: int = 50000): 25 | self.buffer_size = buffer_size 26 | self.index_id_to_db_id = [] 27 | self.index = None 28 | 29 | def init_index(self, vector_sz: int): 30 | raise NotImplementedError 31 | 32 | def index_data(self, data: List[Tuple[object, np.array]]): 33 | raise NotImplementedError 34 | 35 | def get_index_name(self): 36 | raise NotImplementedError 37 | 38 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 39 | raise NotImplementedError 40 | 41 | def serialize(self, file: str): 42 | logger.info("Serializing index to %s", file) 43 | 44 | if os.path.isdir(file): 45 | index_file = os.path.join(file, "index.dpr") 46 | meta_file = os.path.join(file, "index_meta.dpr") 47 | else: 48 | index_file = file + ".index.dpr" 49 | meta_file = file + ".index_meta.dpr" 50 | 51 | faiss.write_index(self.index, index_file) 52 | with open(meta_file, mode="wb") as f: 53 | pickle.dump(self.index_id_to_db_id, f) 54 | 55 | def get_files(self, path: str): 56 | if os.path.isdir(path): 57 | index_file = os.path.join(path, "index.dpr") 58 | meta_file = os.path.join(path, "index_meta.dpr") 59 | else: 60 | index_file = path + ".{}.dpr".format(self.get_index_name()) 61 | meta_file = path + ".{}_meta.dpr".format(self.get_index_name()) 62 | return index_file, meta_file 63 | 64 | def index_exists(self, path: str): 65 | index_file, meta_file = self.get_files(path) 66 | return os.path.isfile(index_file) and os.path.isfile(meta_file) 67 | 68 | def deserialize(self, path: str): 69 | logger.info("Loading index from %s", path) 70 | index_file, meta_file = self.get_files(path) 71 | 72 | self.index = faiss.read_index(index_file) 73 | logger.info("Loaded index of type %s and size %d", type(self.index), self.index.ntotal) 74 | 75 | with open(meta_file, "rb") as reader: 76 | self.index_id_to_db_id = pickle.load(reader) 77 | assert ( 78 | len(self.index_id_to_db_id) == self.index.ntotal 79 | ), "Deserialized index_id_to_db_id should match faiss index size" 80 | 81 | def _update_id_mapping(self, db_ids: List) -> int: 82 | self.index_id_to_db_id.extend(db_ids) 83 | return len(self.index_id_to_db_id) 84 | 85 | 86 | class DenseFlatIndexer(DenseIndexer): 87 | def __init__(self, buffer_size: int = 50000): 88 | super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) 89 | 90 | def init_index(self, vector_sz: int): 91 | self.index = faiss.IndexFlatIP(vector_sz) 92 | 93 | def index_data(self, data: List[Tuple[object, np.array]]): 94 | n = len(data) 95 | # indexing in batches is beneficial for many faiss index types 96 | for i in range(0, n, self.buffer_size): 97 | db_ids = [t[0] for t in data[i : i + self.buffer_size]] 98 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]] 99 | vectors = np.concatenate(vectors, axis=0) 100 | total_data = self._update_id_mapping(db_ids) 101 | self.index.add(vectors) 102 | logger.info("data indexed %d", total_data) 103 | 104 | indexed_cnt = len(self.index_id_to_db_id) 105 | logger.info("Total data indexed %d", indexed_cnt) 106 | 107 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 108 | scores, indexes = self.index.search(query_vectors, 
top_docs) 109 | # convert to external ids 110 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 111 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 112 | return result 113 | 114 | def get_index_name(self): 115 | return "flat_index" 116 | 117 | 118 | class DenseHNSWFlatIndexer(DenseIndexer): 119 | """ 120 | Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage 121 | """ 122 | 123 | def __init__( 124 | self, 125 | buffer_size: int = 1e9, 126 | store_n: int = 512, 127 | ef_search: int = 128, 128 | ef_construction: int = 200, 129 | ): 130 | super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) 131 | self.store_n = store_n 132 | self.ef_search = ef_search 133 | self.ef_construction = ef_construction 134 | self.phi = 0 135 | 136 | def init_index(self, vector_sz: int): 137 | # IndexHNSWFlat supports L2 similarity only 138 | # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension 139 | index = faiss.IndexHNSWFlat(vector_sz + 1, self.store_n) 140 | index.hnsw.efSearch = self.ef_search 141 | index.hnsw.efConstruction = self.ef_construction 142 | self.index = index 143 | 144 | def index_data(self, data: List[Tuple[object, np.array]]): 145 | n = len(data) 146 | 147 | # max norm is required before putting all vectors in the index to convert inner product similarity to L2 148 | if self.phi > 0: 149 | raise RuntimeError( 150 | "DPR HNSWF index needs to index all data at once," "results will be unpredictable otherwise." 151 | ) 152 | phi = 0 153 | for i, item in enumerate(data): 154 | id, doc_vector = item[0:2] 155 | norms = (doc_vector**2).sum() 156 | phi = max(phi, norms) 157 | logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi)) 158 | self.phi = phi 159 | 160 | # indexing in batches is beneficial for many faiss index types 161 | bs = int(self.buffer_size) 162 | for i in range(0, n, bs): 163 | db_ids = [t[0] for t in data[i : i + bs]] 164 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + bs]] 165 | 166 | norms = [(doc_vector**2).sum() for doc_vector in vectors] 167 | aux_dims = [np.sqrt(phi - norm) for norm in norms] 168 | hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in enumerate(vectors)] 169 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) 170 | self.train(hnsw_vectors) 171 | 172 | self._update_id_mapping(db_ids) 173 | self.index.add(hnsw_vectors) 174 | logger.info("data indexed %d", len(self.index_id_to_db_id)) 175 | indexed_cnt = len(self.index_id_to_db_id) 176 | logger.info("Total data indexed %d", indexed_cnt) 177 | 178 | def train(self, vectors: np.array): 179 | pass 180 | 181 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 182 | 183 | aux_dim = np.zeros(len(query_vectors), dtype="float32") 184 | query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 185 | logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape) 186 | scores, indexes = self.index.search(query_nhsw_vectors, top_docs) 187 | # convert to external ids 188 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 189 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 190 | return result 191 | 192 | def deserialize(self, file: str): 193 | super(DenseHNSWFlatIndexer, self).deserialize(file) 194 | # to trigger exception on subsequent indexing 195 | self.phi = 1 196 | 197 | def 
get_index_name(self): 198 | return "hnsw_index" 199 | 200 | 201 | class DenseHNSWSQIndexer(DenseHNSWFlatIndexer): 202 | """ 203 | Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage 204 | """ 205 | 206 | def __init__( 207 | self, 208 | buffer_size: int = 1e10, 209 | store_n: int = 128, 210 | ef_search: int = 128, 211 | ef_construction: int = 200, 212 | ): 213 | super(DenseHNSWSQIndexer, self).__init__( 214 | buffer_size=buffer_size, 215 | store_n=store_n, 216 | ef_search=ef_search, 217 | ef_construction=ef_construction, 218 | ) 219 | 220 | def init_index(self, vector_sz: int): 221 | # IndexHNSWFlat supports L2 similarity only 222 | # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension 223 | index = faiss.IndexHNSWSQ(vector_sz + 1, faiss.ScalarQuantizer.QT_8bit, self.store_n) 224 | index.hnsw.efSearch = self.ef_search 225 | index.hnsw.efConstruction = self.ef_construction 226 | self.index = index 227 | 228 | def train(self, vectors: np.array): 229 | self.index.train(vectors) 230 | 231 | def get_index_name(self): 232 | return "hnswsq_index" 233 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import importlib 9 | 10 | """ 11 | 'Router'-like set of methods for component initialization with lazy imports 12 | """ 13 | 14 | 15 | def init_hf_bert_biencoder(args, **kwargs): 16 | # if importlib.util.find_spec("transformers") is None: 17 | # raise RuntimeError("Please install transformers lib") 18 | from .hf_models import get_hf_biencoder_components 19 | 20 | return get_hf_biencoder_components("bert", args, **kwargs) 21 | 22 | 23 | def init_hf_roberta_biencoder(args, **kwargs): 24 | # if importlib.util.find_spec("transformers") is None: 25 | # raise RuntimeError("Please install transformers lib") 26 | from .hf_models import get_hf_biencoder_components 27 | 28 | return get_hf_biencoder_components("roberta", args, **kwargs) 29 | 30 | 31 | def init_hf_deberta_biencoder(args, **kwargs): 32 | # if importlib.util.find_spec("transformers") is None: 33 | # raise RuntimeError("Please install transformers lib") 34 | from .hf_models import get_hf_biencoder_components 35 | 36 | return get_hf_biencoder_components("deberta", args, **kwargs) 37 | 38 | 39 | def init_hf_bert_tenzorizer(args, **kwargs): 40 | if importlib.util.find_spec("transformers") is None: 41 | raise RuntimeError("Please install transformers lib") 42 | from .hf_models import get_bert_tensorizer 43 | 44 | return get_bert_tensorizer(args) 45 | 46 | 47 | def init_hf_roberta_tenzorizer(args, **kwargs): 48 | if importlib.util.find_spec("transformers") is None: 49 | raise RuntimeError("Please install transformers lib") 50 | from .hf_models import get_roberta_tensorizer 51 | 52 | return get_roberta_tensorizer(args.encoder.pretrained_model_cfg, args.do_lower_case, args.encoder.sequence_length) 53 | 54 | 55 | BIENCODER_INITIALIZERS = { 56 | "hf_bert": init_hf_bert_biencoder, 57 | "hf_roberta": init_hf_roberta_biencoder, 58 | "hf_deberta": init_hf_deberta_biencoder, 59 | } 60 | 61 | TENSORIZER_INITIALIZERS = { 62 | "hf_bert": init_hf_bert_tenzorizer, 63 | 
"hf_roberta": init_hf_roberta_tenzorizer, 64 | "pytext_bert": init_hf_bert_tenzorizer, # using HF's code as of now 65 | } 66 | 67 | 68 | def init_comp(initializers_dict, type, args, **kwargs): 69 | if type in initializers_dict: 70 | return initializers_dict[type](args, **kwargs) 71 | else: 72 | raise RuntimeError("unsupported model type: {}".format(type)) 73 | 74 | 75 | def init_biencoder_components(encoder_type: str, args, **kwargs): 76 | return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) 77 | 78 | 79 | def init_tenzorizer(encoder_type: str, args, **kwargs): 80 | return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) 81 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/models/biencoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | BiEncoder component + loss function for 'all-in-batch' training 10 | """ 11 | import collections 12 | import logging 13 | from typing import Tuple 14 | 15 | import torch 16 | from torch import Tensor as T 17 | from torch import nn 18 | 19 | from dpr.models.ranker_loss import PassageLevelRankerNllLoss, RankerLoss, DocumentLevelRankerNllLoss 20 | from dpr.utils.model_utils import CheckpointState 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | QueryContextsBatch = collections.namedtuple( 26 | "QueryContextsBatch", 27 | [ 28 | "question_ids", 29 | "question_segments", 30 | "context_ids", 31 | "ctx_segments", 32 | "is_positive", 33 | "hard_negatives", 34 | "encoder_type", 35 | ], 36 | ) 37 | 38 | 39 | class BiEncoder(nn.Module): 40 | """Bi-Encoder model component. Encapsulates query/question and context/passage encoders.""" 41 | 42 | def __init__( 43 | self, 44 | cfg, 45 | question_model: nn.Module, 46 | ctx_model: nn.Module, 47 | fix_q_encoder: bool = False, 48 | fix_ctx_encoder: bool = False, 49 | ): 50 | super(BiEncoder, self).__init__() 51 | self.cfg = cfg 52 | self.question_model = question_model 53 | self.ctx_model = ctx_model 54 | self.fix_q_encoder = fix_q_encoder 55 | self.fix_ctx_encoder = fix_ctx_encoder 56 | logger.info("!!! 
fix_ctx_encoder=%s", fix_ctx_encoder) 57 | 58 | @staticmethod 59 | def get_representation( 60 | sub_model: nn.Module, 61 | ids: T, 62 | segments: T, 63 | attn_mask: T, 64 | fix_encoder: bool = False, 65 | representation_token_pos=0, 66 | ) -> (T, T, T): 67 | sequence_output = None 68 | pooled_output = None 69 | hidden_states = None 70 | if ids is not None: 71 | if fix_encoder: 72 | with torch.no_grad(): 73 | sequence_output, pooled_output, hidden_states = sub_model( 74 | ids, 75 | segments, 76 | attn_mask, 77 | representation_token_pos=representation_token_pos, 78 | ) 79 | 80 | if sub_model.training: 81 | sequence_output.requires_grad_(requires_grad=True) 82 | pooled_output.requires_grad_(requires_grad=True) 83 | else: 84 | sequence_output, pooled_output, hidden_states = sub_model( 85 | ids, 86 | segments, 87 | attn_mask, 88 | representation_token_pos=representation_token_pos, 89 | ) 90 | 91 | return sequence_output, pooled_output, hidden_states 92 | 93 | def forward( 94 | self, 95 | question_ids: T, 96 | question_segments: T, 97 | question_attn_mask: T, 98 | context_ids: T, 99 | ctx_segments: T, 100 | ctx_attn_mask: T, 101 | encoder_type: str = None, 102 | representation_token_pos=0, 103 | ) -> Tuple[T, T]: 104 | 105 | q_encoder = self.question_model if encoder_type is None or encoder_type == "question" else self.ctx_model 106 | _q_seq, q_pooled_out, _q_hidden = self.get_representation( 107 | q_encoder, 108 | question_ids, 109 | question_segments, 110 | question_attn_mask, 111 | self.fix_q_encoder, 112 | representation_token_pos=representation_token_pos, 113 | ) 114 | 115 | ctx_encoder = self.ctx_model if encoder_type is None or encoder_type == "ctx" else self.question_model 116 | _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation( 117 | ctx_encoder, context_ids, ctx_segments, ctx_attn_mask, self.fix_ctx_encoder 118 | ) 119 | 120 | return q_pooled_out, ctx_pooled_out 121 | 122 | def load_state(self, saved_state: CheckpointState, strict: bool = True): 123 | self.load_state_dict(saved_state.model_dict, strict=strict) 124 | 125 | def get_state_dict(self): 126 | return self.state_dict() 127 | 128 | def get_loss_function(self) -> RankerLoss: 129 | return PassageLevelRankerNllLoss(self.cfg) 130 | 131 | def prepare_model_inputs(self, batch: dict, dataset): 132 | return { 133 | "question_ids": batch["question_ids"], 134 | "question_segments": batch["question_segments"], 135 | "question_attn_mask": dataset.tensorizer.get_attn_mask(batch["question_ids"]), 136 | "context_ids": batch["context_ids"], 137 | "ctx_segments": batch["ctx_segments"], 138 | "ctx_attn_mask": dataset.tensorizer.get_attn_mask(batch["context_ids"]), 139 | } 140 | 141 | 142 | class DocumentBiEncoder(BiEncoder): 143 | def __init__( 144 | self, 145 | cfg, 146 | question_model: nn.Module, 147 | ctx_model: nn.Module, 148 | fix_q_encoder: bool = False, 149 | fix_ctx_encoder: bool = False, 150 | ): 151 | super().__init__( 152 | cfg=cfg, 153 | question_model=question_model, 154 | ctx_model=ctx_model, 155 | fix_q_encoder=fix_q_encoder, 156 | fix_ctx_encoder=fix_ctx_encoder, 157 | ) 158 | 159 | def get_loss_function(self) -> RankerLoss: 160 | return DocumentLevelRankerNllLoss(self.cfg) 161 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/models/ranker_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import Tensor as T 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T: 19 | """ 20 | calculates q->ctx scores for every row in ctx_vector 21 | :param q_vector: 22 | :param ctx_vector: 23 | :return: 24 | """ 25 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 26 | r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) 27 | return r 28 | 29 | 30 | def cosine_scores(q_vector: T, ctx_vectors: T): 31 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 32 | return F.cosine_similarity(q_vector, ctx_vectors, dim=1) 33 | 34 | 35 | class RankerLoss(object): 36 | def __init__(self, cfg): 37 | self.cfg = cfg 38 | 39 | def calc( 40 | self, 41 | q_vectors: T, 42 | ctx_vectors: T, 43 | positive_idx_per_question: list, 44 | global_step: int, 45 | hard_negative_idx_per_question: list = None, 46 | loss_scale: float = None, 47 | ) -> Tuple[T, int]: 48 | raise NotImplementedError 49 | 50 | @staticmethod 51 | def get_scores(q_vector: T, ctx_vectors: T) -> T: 52 | raise NotImplementedError 53 | 54 | @staticmethod 55 | def get_similarity_function(): 56 | raise NotImplementedError 57 | 58 | 59 | class PassageLevelRankerNllLoss(RankerLoss): 60 | def __init__(self, cfg): 61 | super().__init__(cfg) 62 | self.cfg = cfg 63 | 64 | def calc( 65 | self, 66 | q_vectors: T, 67 | ctx_vectors: T, 68 | positive_idx_per_question: list, 69 | global_step: int, 70 | hard_negative_idx_per_question: list = None, 71 | loss_scale: float = None, 72 | ) -> Tuple[T, int]: 73 | """ 74 | Computes nll loss for the given lists of question and ctx vectors. 75 | :return: a tuple of loss value and amount of correct predictions per batch 76 | """ 77 | scores = self.get_scores(q_vectors, ctx_vectors) 78 | if isinstance(positive_idx_per_question, list): 79 | positive_idx_per_question = torch.tensor(positive_idx_per_question) 80 | 81 | if len(q_vectors.size()) > 1: 82 | q_num = q_vectors.size(0) 83 | scores = scores.view(q_num, -1) 84 | 85 | softmax_scores = F.log_softmax(scores, dim=1) 86 | 87 | loss = F.nll_loss( 88 | softmax_scores, 89 | positive_idx_per_question.to(softmax_scores.device), 90 | reduction="mean", 91 | ) 92 | 93 | max_score, max_idxs = torch.max(softmax_scores, 1) 94 | correct_predictions_count = (max_idxs == positive_idx_per_question.to(max_idxs.device)).sum() 95 | 96 | if loss_scale: 97 | loss.mul_(loss_scale) 98 | 99 | return loss, correct_predictions_count 100 | 101 | @staticmethod 102 | def get_scores(q_vector: T, ctx_vectors: T) -> T: 103 | f = PassageLevelRankerNllLoss.get_similarity_function() 104 | return f(q_vector, ctx_vectors) 105 | 106 | @staticmethod 107 | def get_similarity_function(): 108 | return dot_product_scores 109 | 110 | 111 | class DocumentLevelRankerNllLoss(PassageLevelRankerNllLoss): 112 | def calc( 113 | self, 114 | q_vectors: T, 115 | ctx_vectors: T, 116 | positive_mask: list, 117 | global_step: int, 118 | hard_negative_idx_per_question: list = None, 119 | loss_scale: float = None, 120 | ) -> Tuple[T, int]: 121 | """ 122 | Computes nll loss for the given lists of question and ctx vectors. 
Difference to the PassageLevelRankerNllLoss 123 | is that there are multiple positive passages per training instance and we do not have supervision on which 124 | of the positive passage is the correct one that contains the relevant text. This class contains a few different 125 | strategies to handle this kind of case. They mostly revolve around using the score of the question and passage 126 | encoder to weight the softmax loss according to the model's current belief, i.e. the expectation maximisation 127 | algorithm (EM). 128 | 129 | "hard_em": pick the loss with the positive passage that receives the highest scores amongst all positive passages 130 | "first_passage": pick the first positive passage (in document order) 131 | "random_passage": pick a random passage 132 | "soft_em": weight the losses of the positive passage according to their scores amongst all positive passages 133 | "random": randomly weight all the passages 134 | "uniform": uniformly weight all the passages 135 | 136 | :return: a tuple of loss value and amount of correct predictions per batch 137 | """ 138 | scores = self.get_scores(q_vectors, ctx_vectors) 139 | 140 | if len(q_vectors.size()) > 1: 141 | q_num = q_vectors.size(0) 142 | scores = scores.view(q_num, -1) 143 | 144 | loss, correct_predictions_count = None, None 145 | 146 | if self.cfg.loss_strategy in {"hard_em", "first_passage", "random_passage"}: 147 | 148 | positive_idx_per_question = None 149 | 150 | if self.cfg.loss_strategy == "hard_em": 151 | e_step_scores = scores.clone().detach() 152 | e_step_scores[~positive_mask] = -1.0e6 153 | positive_idx_per_question = e_step_scores.max(-1).indices 154 | 155 | elif self.cfg.loss_strategy == "first_passage": 156 | true_coords = positive_mask.nonzero() 157 | first_pass_idx = torch.ones_like(positive_mask, dtype=torch.long) * 10_000 158 | first_pass_idx[true_coords[:, 0], true_coords[:, 1]] = true_coords[:, 1] 159 | positive_idx_per_question = first_pass_idx.min(1).indices 160 | 161 | elif self.cfg.loss_strategy == "random_passage": 162 | e_step_scores = torch.randn(scores.size()) 163 | e_step_scores[~positive_mask] = -1.0e6 164 | positive_idx_per_question = e_step_scores.max(-1).indices 165 | 166 | positive_mask[torch.arange(len(positive_idx_per_question)), positive_idx_per_question] = False 167 | scores_mask = torch.ones_like(scores, requires_grad=False) 168 | scores_mask[positive_mask] = -1.0e6 169 | scores_with_mask = scores + scores_mask 170 | softmax_scores = torch.log_softmax(scores_with_mask, dim=1) 171 | 172 | loss = -(softmax_scores[torch.arange(len(positive_idx_per_question)), positive_idx_per_question]).sum() 173 | 174 | elif self.cfg.loss_strategy in { 175 | "soft_em", 176 | "random", 177 | "uniform", 178 | }: 179 | 180 | # positive_mask = torch.tensor([ 181 | # [True, True, False, False, False, False, False, ], 182 | # [False, False, False, True, True, True, False, ], 183 | # ]) 184 | 185 | true_coords = positive_mask.nonzero() 186 | # true_coords = 187 | # tensor([[0, 0], 188 | # [0, 1], 189 | # [1, 3], 190 | # [1, 4], 191 | # [1, 5]]) 192 | 193 | positive_mask_scores = positive_mask[true_coords[:, 0]] 194 | positive_mask_scores[torch.arange(positive_mask_scores.size(0)), true_coords[:, 1]] = False 195 | # Mask out the positive scores for each instance 196 | # positive_mask_scores = 197 | # tensor([[False, True, False, False, False, False, False], 198 | # [True, False, False, False, False, False, False], 199 | # [False, False, False, False, True, True, False], 200 | # [False, False, False, True, 
False, True, False], 201 | # [False, False, False, True, True, False, False]]) 202 | 203 | scores_softmax_mask = torch.ones_like(positive_mask_scores, dtype=torch.float, device=scores.device) 204 | scores_softmax_mask[positive_mask_scores] = 0 205 | 206 | scores_with_mask = scores[true_coords[:, 0]] * scores_softmax_mask 207 | log_softmax_scores = torch.log_softmax(scores_with_mask, dim=-1) 208 | # log_softmax_scores = torch.log( 209 | # tensor([[0.2626, 0.0000, 0.0705, 0.2840, 0.0632, 0.0606, 0.2590], 210 | # [0.0000, 0.2288, 0.0738, 0.2971, 0.0661, 0.0634, 0.2709], 211 | # [0.0506, 0.3419, 0.5314, 0.0520, 0.0000, 0.0000, 0.0241], 212 | # [0.0469, 0.3170, 0.4926, 0.0000, 0.1211, 0.0000, 0.0224], 213 | # [0.0438, 0.2960, 0.4600, 0.0000, 0.0000, 0.1793, 0.0209]])) 214 | log_softmax_scores = log_softmax_scores[ 215 | torch.arange(positive_mask_scores.size(0), device=positive_mask.device), true_coords[:, 1] 216 | ] 217 | 218 | if self.cfg.loss_strategy == "soft_em": 219 | e_step_scores = scores.detach().clone() 220 | # apply temperature 221 | if global_step and self.cfg.loss_strategy == "soft_em" and self.cfg.soft_em.warmup_steps > 0: 222 | global_step_clamped = min(global_step, self.cfg.soft_em.warmup_steps) 223 | temperature = (self.cfg.soft_em.end_temperature - self.cfg.soft_em.start_temperature) * ( 224 | global_step_clamped / self.cfg.soft_em.warmup_steps 225 | ) + self.cfg.soft_em.start_temperature 226 | e_step_scores = e_step_scores * temperature 227 | softmax_mask = torch.zeros_like(positive_mask, dtype=torch.float, device=e_step_scores.device) 228 | softmax_mask[~positive_mask] = -1.0e6 229 | e_step_probs = torch.softmax(e_step_scores + softmax_mask, dim=-1) 230 | # e_step_probs = 231 | # tensor([[0.6502, 0.3498, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], 232 | # [0.0000, 0.0000, 0.0000, 0.1621, 0.7200, 0.1179, 0.0000]]) 233 | weighted_logsoftmax = log_softmax_scores * e_step_probs[true_coords[:, 0], true_coords[:, 1]] 234 | loss = -(weighted_logsoftmax).sum() 235 | elif self.cfg.loss_strategy == "random": 236 | random_weights = torch.randn_like(scores[true_coords[:, 0], true_coords[:, 1]].view(-1, 1)) 237 | random_weights = torch.softmax(random_weights, dim=-1) 238 | weighted_logsoftmax = log_softmax_scores * random_weights 239 | loss = -(weighted_logsoftmax).sum() 240 | elif self.cfg.loss_strategy == "uniform": 241 | uniform = torch.ones_like(scores) 242 | softmax_mask = torch.zeros_like(positive_mask, dtype=torch.float, device=uniform.device) 243 | softmax_mask[~positive_mask] = -1.0e6 244 | uniform = torch.softmax(uniform + softmax_mask, dim=-1) 245 | # e_step_probs = 246 | # tensor([[0.6502, 0.3498, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], 247 | # [0.0000, 0.0000, 0.0000, 0.1621, 0.7200, 0.1179, 0.0000]]) 248 | weighted_logsoftmax = log_softmax_scores * uniform[true_coords[:, 0], true_coords[:, 1]] 249 | loss = -(weighted_logsoftmax).sum() 250 | 251 | # weighted_logsoftmax[torch.arange(positive_mask_all_scores.size(0)), true_coords[:, 1]] = 252 | # tensor([-65024.1797, -34975.6328, -16210.2422, -71997.6328, -11791.9492]) 253 | 254 | else: 255 | raise Exception(f"Unkownn loss_strategy {self.cfg.loss_strategy}") 256 | 257 | # if hasattr(self.cfg, "regularize") and self.cfg.regularize.passages == "spread": 258 | # ctx_vectors * ctx_vectors.t() 259 | 260 | # is the highest scored true passage also the highest scored overall passage? 
261 | max_score, max_idxs = torch.max(scores, 1) 262 | e_step_scores = scores.clone().detach() 263 | e_step_scores[~positive_mask] = -1.0e6 264 | positive_idx_per_question = e_step_scores.max(-1).indices 265 | correct_predictions_count = (max_idxs == positive_idx_per_question.to(max_idxs.device)).sum() 266 | 267 | if loss_scale: 268 | loss.mul_(loss_scale) 269 | 270 | return loss, correct_predictions_count 271 | 272 | 273 | class RerankerNllLoss(DocumentLevelRankerNllLoss): 274 | """ 275 | RerankerNllLoss expects the forward function to return a placeholder output for q_vectors and a tensor with 276 | one score per query-context pair. The q_vectors output is only used to track how many questions the batch contained. 277 | TODO: It would be better to create a RankerOutput class that keeps track of this 278 | """ 279 | 280 | @staticmethod 281 | def get_scores(q_vector: T, ctx_vectors: T) -> T: 282 | f = RerankerNllLoss.get_similarity_function() 283 | return f(ctx_vectors) 284 | 285 | @staticmethod 286 | def get_similarity_function(): 287 | def similarity_function(ctx_vectors): 288 | """ 289 | Passes the reranker scores through unchanged: the model's forward pass already produces one score per 290 | query-context pair in ctx_vectors, so no similarity computation is needed here. 291 | :param ctx_vectors: precomputed reranker scores 292 | :return: the scores, unchanged 293 | """ 294 | # q_pooled_out, None 295 | 296 | return ctx_vectors 297 | 298 | return similarity_function 299 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/models/reranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
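Before the class definitions, a brief self-contained sketch (with random tensors, purely illustrative) of the ColBERT-style late-interaction scoring that ColbertRerankerBiEncoder applies below: token-level dot products for every query/context token pair, a max over context tokens, then a mean over query tokens.

import torch

# toy shapes: 2 queries with 5 tokens, 6 contexts with 7 tokens, hidden size 8
q_seq = torch.randn(2, 5, 8)
ctx_seq = torch.randn(6, 7, 8)

interactions = torch.einsum("qih,cjh->qcij", q_seq, ctx_seq)  # dot product for every token pair
scores = interactions.max(-1).values.mean(-1)  # MaxSim over context tokens, averaged over query tokens
print(scores.shape)  # torch.Size([2, 6]): one relevance score per (query, context) pair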
7 | 8 | """ 9 | BiEncoder component + loss function for 'all-in-batch' training 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | import torch 16 | from torch import Tensor as T 17 | from torch import nn 18 | 19 | from dpr.models.biencoder import RankerLoss, DocumentBiEncoder 20 | from dpr.models.ranker_loss import RerankerNllLoss 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class ColbertRerankerBiEncoder(DocumentBiEncoder): 26 | def __init__( 27 | self, 28 | cfg, 29 | question_model: nn.Module, 30 | ctx_model: nn.Module, 31 | fix_q_encoder: bool = False, 32 | fix_ctx_encoder: bool = False, 33 | ): 34 | super().__init__( 35 | cfg=cfg, 36 | question_model=question_model, 37 | ctx_model=ctx_model, 38 | fix_q_encoder=fix_q_encoder, 39 | fix_ctx_encoder=fix_ctx_encoder, 40 | ) 41 | 42 | def get_interactions( 43 | self, 44 | question_ids: T, 45 | question_segments: T, 46 | question_attn_mask: T, 47 | context_ids: T, 48 | ctx_segments: T, 49 | ctx_attn_mask: T, 50 | encoder_type: str = None, 51 | representation_token_pos=0, 52 | ): 53 | q_encoder = self.question_model if encoder_type is None or encoder_type == "question" else self.ctx_model 54 | _q_seq, q_pooled_out, _q_hidden = self.get_representation( 55 | q_encoder, 56 | question_ids, 57 | question_segments, 58 | question_attn_mask, 59 | self.fix_q_encoder, 60 | representation_token_pos=representation_token_pos, 61 | ) 62 | 63 | ctx_encoder = self.ctx_model if encoder_type is None or encoder_type == "ctx" else self.question_model 64 | _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation( 65 | ctx_encoder, context_ids, ctx_segments, ctx_attn_mask, self.fix_ctx_encoder 66 | ) 67 | 68 | return torch.einsum("qih,cjh->qcij", _q_seq, _ctx_seq) 69 | 70 | def forward( 71 | self, 72 | question_ids: T, 73 | question_segments: T, 74 | question_attn_mask: T, 75 | context_ids: T, 76 | ctx_segments: T, 77 | ctx_attn_mask: T, 78 | encoder_type: str = None, 79 | representation_token_pos=0, 80 | ) -> Tuple[T, T]: 81 | 82 | q_b_size, q_len = question_ids.size() 83 | ctx_b_size, c_len = context_ids.size() 84 | one_ctx_size = ctx_b_size // q_b_size 85 | 86 | r = self.get_interactions( 87 | question_ids, 88 | question_segments, 89 | question_attn_mask, 90 | context_ids, 91 | ctx_segments, 92 | ctx_attn_mask, 93 | encoder_type, 94 | representation_token_pos, 95 | ) 96 | r = r.max(-1).values 97 | r = r.mean(-1) 98 | r = r.view(q_b_size, -1) 99 | 100 | one_ctx_size = ctx_b_size // q_b_size 101 | rows = torch.arange(q_b_size).view(-1, 1).repeat(1, one_ctx_size).view(-1) 102 | cols = torch.arange(q_b_size * one_ctx_size).view(-1) 103 | 104 | return torch.zeros(q_b_size, device=r.device), r[rows, cols].view(q_b_size, -1) 105 | 106 | def get_loss_function(self) -> RankerLoss: 107 | return RerankerNllLoss(self.cfg) 108 | 109 | 110 | class BERTRerankerCrossEncoder(DocumentBiEncoder): 111 | def __init__( 112 | self, 113 | cfg, 114 | question_model: nn.Module, 115 | ctx_model: nn.Module, 116 | fix_q_encoder: bool = False, 117 | fix_ctx_encoder: bool = False, 118 | ): 119 | super().__init__( 120 | cfg=cfg, 121 | question_model=question_model, 122 | ctx_model=ctx_model, 123 | fix_q_encoder=fix_q_encoder, 124 | fix_ctx_encoder=fix_ctx_encoder, 125 | ) 126 | if question_model.encode_proj is None: 127 | raise Exception('Set "encoder.projection_dim: 1" for BERTRerankerCrossEncoder') 128 | 129 | def forward( 130 | self, 131 | question_ids: T, 132 | question_segments: T, 133 | question_attn_mask: T, 134 | context_ids: T, 135 | 
ctx_segments: T, 136 | ctx_attn_mask: T, 137 | encoder_type: str = None, 138 | representation_token_pos=0, 139 | ) -> Tuple[T, T]: 140 | 141 | q_encoder = self.question_model if encoder_type is None or encoder_type == "question" else self.ctx_model 142 | q_b_size, q_len = question_ids.size() 143 | ctx_b_size, c_len = context_ids.size() 144 | 145 | _, all_cross_pooled_out, _ = self.get_representation( 146 | q_encoder, 147 | context_ids, 148 | ctx_segments, 149 | ctx_attn_mask, 150 | self.fix_q_encoder, 151 | representation_token_pos=representation_token_pos, 152 | ) 153 | 154 | return torch.zeros(q_b_size, device=all_cross_pooled_out.device), all_cross_pooled_out.view(q_b_size, -1) 155 | 156 | def get_loss_function(self) -> RankerLoss: 157 | return RerankerNllLoss(self.cfg) 158 | 159 | def prepare_model_inputs(self, batch: dict, dataset): 160 | 161 | question_ids = batch["question_ids"] 162 | question_segments = batch["question_segments"] 163 | question_attn_mask = dataset.tensorizer.get_attn_mask(batch["question_ids"]) 164 | context_ids = batch["context_ids"] 165 | ctx_attn_mask = dataset.tensorizer.get_attn_mask(batch["context_ids"]) 166 | 167 | q_b_size, q_len = question_ids.size() 168 | ctx_b_size, c_len = context_ids.size() 169 | 170 | all_cross_ids = question_ids.repeat(1, ctx_b_size // q_b_size).view(-1, q_len).repeat(1, 2) 171 | all_cross_ids[:, q_len:] = dataset.tensorizer.get_pad_id() 172 | 173 | cols = torch.arange(1, c_len, device=question_ids.device).view(1, -1).repeat(ctx_b_size, 1) + ( 174 | (all_cross_ids != dataset.tensorizer.get_pad_id()).sum(-1).view(-1, 1) - 1 175 | ) 176 | rows = torch.arange(ctx_b_size, device=question_ids.device).view(-1, 1).repeat(1, c_len - 1) 177 | 178 | all_cross_ids[rows, cols] = context_ids[:, 1:] 179 | 180 | all_cross_segments = question_segments.repeat(1, ctx_b_size // q_b_size).view(-1, q_len).repeat(1, 2) 181 | all_cross_segments[:, q_len:] = 0 182 | all_cross_segments[rows, cols] = 1 183 | 184 | all_cross_attn_mask = question_attn_mask.repeat(1, ctx_b_size // q_b_size).view(-1, q_len).repeat(1, 2) 185 | all_cross_attn_mask[:, q_len:] = 0 186 | all_cross_attn_mask[rows, cols] = ctx_attn_mask[:, 1:] 187 | 188 | return { 189 | "question_ids": question_ids, 190 | "question_segments": question_segments, 191 | "question_attn_mask": question_attn_mask, 192 | "context_ids": all_cross_ids, 193 | "ctx_segments": all_cross_segments, 194 | "ctx_attn_mask": all_cross_attn_mask, 195 | } 196 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
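# Standalone sketch (assumed toy shapes) of the late-interaction scoring used by
# ColbertRerankerBiEncoder above: every query token is matched to its best context
# token (max over context tokens), and the per-token maxima are averaged over the
# query tokens to give one score per (question, context) pair.
def _late_interaction_score_sketch():
    import torch

    q_seq = torch.randn(2, 5, 8)  # [questions, query_tokens, hidden]
    c_seq = torch.randn(6, 9, 8)  # [contexts,  ctx_tokens,   hidden]
    interactions = torch.einsum("qih,cjh->qcij", q_seq, c_seq)  # [2, 6, 5, 9]
    scores = interactions.max(-1).values.mean(-1)  # max over ctx tokens, mean over query tokens
    assert scores.shape == (2, 6)
    return scores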
7 | 8 | """ 9 | Command line arguments utils 10 | """ 11 | import json 12 | import logging 13 | import random 14 | 15 | import numpy as np 16 | import torch 17 | from omegaconf import DictConfig 18 | 19 | logger = logging.getLogger() 20 | 21 | 22 | def set_cfg_params_from_state(state: dict, cfg: DictConfig): 23 | """ 24 | Overrides some of the encoder config parameters from a give state object 25 | """ 26 | if not state: 27 | return 28 | 29 | cfg.input_transform.do_lower_case = state["do_lower_case"] 30 | 31 | if "encoder" in state: 32 | saved_encoder_params = state["encoder"] 33 | # TODO: try to understand why cfg.base_model = state["encoder"] doesn't work 34 | 35 | # HACK 36 | conf_str = ( 37 | str(saved_encoder_params) 38 | .strip() 39 | .replace("'", '"') 40 | .replace("None", "null") 41 | .replace("True", "true") 42 | .replace("False", "false") 43 | ) 44 | print(conf_str) 45 | saved_encoder_params_json = json.loads(conf_str) 46 | 47 | for k, v in saved_encoder_params_json.items(): 48 | if hasattr(cfg.base_model, k): 49 | setattr(cfg.base_model, k, v) 50 | else: # 'old' checkpoints backward compatibility support 51 | return 52 | 53 | 54 | def get_encoder_params_state_from_cfg(cfg: DictConfig): 55 | """ 56 | Selects the param values to be saved in a checkpoint, so that a trained model can be used for downstream 57 | tasks without the need to specify these parameter again 58 | :return: Dict of params to memorize in a checkpoint 59 | """ 60 | return { 61 | "do_lower_case": cfg.input_transform.do_lower_case, 62 | "encoder": cfg.base_model, 63 | } 64 | 65 | 66 | def set_seed(args): 67 | seed = args.seed 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | torch.manual_seed(seed) 71 | if args.n_gpu > 0: 72 | torch.cuda.manual_seed_all(seed) 73 | 74 | 75 | def setup_logger(logger): 76 | logger.setLevel(logging.INFO) 77 | if logger.hasHandlers(): 78 | logger.handlers.clear() 79 | log_formatter = logging.Formatter("[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s") 80 | console = logging.StreamHandler() 81 | console.setFormatter(log_formatter) 82 | logger.addHandler(console) 83 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/dpr/utils/__init__.py -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/utils/conf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import glob 8 | import logging 9 | import os 10 | 11 | import hydra 12 | from omegaconf import DictConfig 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class DatasetsCfg(object): 18 | def __init__(self, cfg: DictConfig, tensorizer): 19 | ds_cfg = cfg.datasets 20 | self.train_datasets_names = cfg.train_datasets 21 | logger.info("train_datasets: %s", self.train_datasets_names) 22 | self.train_datasets = _init_datasets(self.train_datasets_names, ds_cfg, cfg=cfg, tensorizer=tensorizer) 23 | 24 | self.dev_datasets_names = cfg.dev_datasets 25 | logger.info("dev_datasets: %s", self.dev_datasets_names) 26 | self.dev_datasets = _init_datasets(self.dev_datasets_names, ds_cfg, cfg=cfg, tensorizer=tensorizer) 27 | self.sampling_rates = cfg.multi_dataset_train_sampling_rates 28 | 29 | 30 | def _init_datasets(datasets_names, ds_cfg: DictConfig, cfg: DictConfig, tensorizer): 31 | if isinstance(datasets_names, str): 32 | return [_init_dataset(datasets_names, ds_cfg, cfg=cfg, tensorizer=tensorizer)] 33 | elif datasets_names: 34 | return [_init_dataset(ds_name, ds_cfg, cfg=cfg, tensorizer=tensorizer) for ds_name in datasets_names] 35 | else: 36 | return [] 37 | 38 | 39 | def _init_dataset(name: str, ds_cfg: DictConfig, cfg: DictConfig, tensorizer): 40 | if glob.glob(name): 41 | files = glob.glob(name) 42 | return [_init_dataset(f, ds_cfg, cfg=cfg, tensorizer=tensorizer) for f in files] 43 | # try to find in cfg 44 | if name not in ds_cfg: 45 | raise RuntimeError("Can't find dataset location/config for: {}".format(name)) 46 | return hydra.utils.instantiate(ds_cfg[name], cfg, tensorizer) 47 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
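# Dataset resolution in _init_dataset() in conf_utils.py above: a name that matches
# files on disk (glob) yields one dataset per matched file, so the glob branch
# returns a nested list; otherwise the name must be a key under cfg.datasets (see
# conf/datasets/*.yaml) and is instantiated via
# hydra.utils.instantiate(ds_cfg[name], cfg, tensorizer). Unknown names raise a
# RuntimeError.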
7 | 8 | """ 9 | Utilities for general purpose data processing 10 | """ 11 | import itertools 12 | import json 13 | import logging 14 | import math 15 | import random 16 | from typing import List, Iterator, Callable, Tuple 17 | 18 | import hydra 19 | import jsonlines 20 | import torch 21 | from omegaconf import DictConfig 22 | from tokenizers import Tokenizer 23 | from torch import Tensor as T 24 | 25 | logger = logging.getLogger() 26 | 27 | 28 | def read_data_from_json_files(paths: List[str], return_size=False) -> List: 29 | results = [] 30 | for i, path in enumerate(paths): 31 | if path.endswith(".jsonl"): 32 | logger.info("Reading jsonl file %s" % path) 33 | with open(path, "r", encoding="utf-8") as f: 34 | for line in f: 35 | if return_size: 36 | results.append(1) 37 | else: 38 | results.append(json.loads(line)) 39 | else: 40 | with open(path, "r", encoding="utf-8") as f: 41 | logger.info("Reading file %s" % path) 42 | data = json.load(f) 43 | if return_size: 44 | results.extend(len(data)) 45 | else: 46 | results.extend(data) 47 | logger.info("Aggregated data size: {}".format(len(results))) 48 | if return_size: 49 | return len(results) 50 | else: 51 | return results 52 | 53 | 54 | def read_data_from_jsonl_files(paths: List[str]) -> List: 55 | results = [] 56 | for i, path in enumerate(paths): 57 | logger.info("Reading file %s" % path) 58 | with jsonlines.open(path, mode="r") as jsonl_reader: 59 | data = [r for r in jsonl_reader] 60 | results.extend(data) 61 | logger.info("Aggregated data size: {}".format(len(results))) 62 | return results 63 | 64 | 65 | class Tensorizer(object): 66 | """ 67 | Component for all text to model input data conversions and related utility methods 68 | """ 69 | 70 | def __init__(self, max_length): 71 | self.max_length = max_length 72 | self.tokenizer: Tokenizer = None 73 | 74 | # Note: title, if present, is supposed to be put before text (i.e. 
optional title + document body) 75 | def text_to_tensor( 76 | self, 77 | text: str, 78 | title: str = None, 79 | add_special_tokens: bool = True, 80 | apply_max_len: bool = True, 81 | ): 82 | raise NotImplementedError 83 | 84 | def get_pair_separator_ids(self) -> T: 85 | raise NotImplementedError 86 | 87 | def get_sep_id(self) -> int: 88 | raise NotImplementedError 89 | 90 | def get_sep_token(self) -> str: 91 | raise NotImplementedError 92 | 93 | def get_pad_id(self) -> int: 94 | raise NotImplementedError 95 | 96 | def get_pad_token(self) -> str: 97 | raise NotImplementedError 98 | 99 | def get_mask_id(self) -> int: 100 | raise NotImplementedError 101 | 102 | def get_mask_token(self) -> str: 103 | raise NotImplementedError 104 | 105 | def get_cls_id(self) -> int: 106 | raise NotImplementedError 107 | 108 | def get_cls_token(self) -> str: 109 | raise NotImplementedError 110 | 111 | def get_attn_mask(self, tokens_tensor: T): 112 | raise NotImplementedError 113 | 114 | def is_sub_word_id(self, token_id: int): 115 | raise NotImplementedError 116 | 117 | def to_string(self, token_ids, skip_special_tokens=True): 118 | raise NotImplementedError 119 | 120 | def set_pad_to_max(self, pad: bool): 121 | raise NotImplementedError 122 | 123 | def get_token_id(self, token: str) -> int: 124 | raise NotImplementedError 125 | 126 | 127 | class RepTokenSelector(object): 128 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 129 | raise NotImplementedError 130 | 131 | 132 | class RepStaticPosTokenSelector(RepTokenSelector): 133 | def __init__(self, static_position: int = 0): 134 | self.static_position = static_position 135 | 136 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 137 | return self.static_position 138 | 139 | 140 | DEFAULT_SELECTOR = RepStaticPosTokenSelector() 141 | 142 | 143 | class Dataset(torch.utils.data.Dataset): 144 | def __init__( 145 | self, 146 | selector: DictConfig = None, 147 | encoder_type: str = None, 148 | ): 149 | if selector: 150 | self.selector = hydra.utils.instantiate(selector) 151 | else: 152 | self.selector = DEFAULT_SELECTOR 153 | self.encoder_type = encoder_type 154 | self.data = [] 155 | 156 | def load_data(self, start_pos: int = -1, end_pos: int = -1, sharding_fn=None): 157 | raise NotImplementedError 158 | 159 | def calc_total_data_len(self): 160 | raise NotImplementedError 161 | 162 | def __len__(self): 163 | return len(self.data) 164 | 165 | def __getitem__(self, index): 166 | raise NotImplementedError 167 | 168 | 169 | # TODO: to be fully replaced with LocalSharded{...}. Keeping it only for old results reproduction compatibility 170 | class ShardedDataIterator(object): 171 | """ 172 | General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of 173 | the data. 174 | Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. 175 | It fills the extra sample by just taking first samples in a shard. 176 | It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). 
177 | """ 178 | 179 | def __init__( 180 | self, 181 | dataset: Dataset, 182 | shard_id: int = 0, 183 | num_shards: int = 1, 184 | batch_size: int = 1, 185 | shuffle=True, 186 | shuffle_seed: int = 0, 187 | offset: int = 0, 188 | strict_batch_size: bool = False, 189 | ): 190 | 191 | self.dataset = dataset 192 | self.shard_id = shard_id 193 | self.num_shards = num_shards 194 | self.iteration = offset # to track in-shard iteration status 195 | self.shuffle = shuffle 196 | self.batch_size = batch_size 197 | self.shuffle_seed = shuffle_seed 198 | self.strict_batch_size = strict_batch_size 199 | self.shard_start_idx = -1 200 | self.shard_end_idx = -1 201 | self.samples_per_shard = -1 202 | self.max_iterations = 0 203 | 204 | def calculate_shards_fn(self, total_size): 205 | 206 | logger.info("Calculating shard positions") 207 | shards_num = max(self.num_shards, 1) 208 | shard_id = max(self.shard_id, 0) 209 | 210 | if self.strict_batch_size: 211 | self.samples_per_shard = int(total_size / shards_num) 212 | else: 213 | self.samples_per_shard = math.ceil(total_size / shards_num) 214 | 215 | self.shard_start_idx = shard_id * self.samples_per_shard 216 | self.shard_end_idx = min(self.shard_start_idx + self.samples_per_shard, total_size) 217 | 218 | return self.shard_start_idx, self.shard_end_idx 219 | 220 | def set_max_iteration(self): 221 | if self.strict_batch_size: 222 | self.max_iterations = math.ceil(self.samples_per_shard / self.batch_size) 223 | else: 224 | self.max_iterations = int(self.samples_per_shard / self.batch_size) 225 | logger.info( 226 | "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d", 227 | self.samples_per_shard, 228 | self.shard_start_idx, 229 | self.shard_end_idx, 230 | self.max_iterations, 231 | ) 232 | 233 | def load_data(self): 234 | self.dataset.load_data(sharding_fn=self.calculate_shards_fn) 235 | self.set_max_iteration() 236 | logger.info("Sharded dataset data %d", len(self.dataset)) 237 | 238 | def total_data_len(self) -> int: 239 | return len(self.dataset) 240 | 241 | def iterations_num(self) -> int: 242 | return self.max_iterations - self.iteration 243 | 244 | def max_iterations_num(self) -> int: 245 | return self.max_iterations 246 | 247 | def get_iteration(self) -> int: 248 | return self.iteration 249 | 250 | def apply(self, visitor_func: Callable): 251 | for sample in self.dataset: 252 | visitor_func(sample) 253 | 254 | def get_shard_indices(self, epoch: int): 255 | indices = list(range(len(self.dataset))) 256 | if self.shuffle: 257 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 258 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 259 | epoch_rnd.shuffle(indices) 260 | shard_indices = indices[self.shard_start_idx : self.shard_end_idx] 261 | return shard_indices 262 | 263 | # TODO: merge with iterate_ds_sampled_data 264 | def iterate_ds_data(self, epoch: int = 0) -> Iterator[List]: 265 | # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations 266 | max_iterations = self.max_iterations - self.iteration 267 | shard_indices = self.get_shard_indices(epoch) 268 | 269 | for i in range(self.iteration * self.batch_size, len(shard_indices), self.batch_size): 270 | items_idxs = shard_indices[i : i + self.batch_size] 271 | if self.strict_batch_size and len(items_idxs) < self.batch_size: 272 | logger.debug("Extending batch to max size") 273 | items_idxs.extend(shard_indices[0 : self.batch_size - len(items)]) 274 | self.iteration += 1 275 | items = 
[self.dataset[idx] for idx in items_idxs] 276 | yield items 277 | 278 | # some shards may done iterating while the others are at the last batch. Just return the first batch 279 | while self.iteration < max_iterations: 280 | logger.debug("Fulfilling non complete shard=".format(self.shard_id)) 281 | self.iteration += 1 282 | items_idxs = shard_indices[0 : self.batch_size] 283 | items = [self.dataset[idx] for idx in items_idxs] 284 | yield items 285 | 286 | logger.info("Finished iterating, iteration={}, shard={}".format(self.iteration, self.shard_id)) 287 | # reset the iteration status 288 | self.iteration = 0 289 | 290 | def iterate_ds_sampled_data(self, num_iterations: int, epoch: int = 0) -> Iterator[List]: 291 | self.iteration = 0 292 | shard_indices = self.get_shard_indices(epoch) 293 | cycle_it = itertools.cycle(shard_indices) 294 | for i in range(num_iterations): 295 | items_idxs = [next(cycle_it) for _ in range(self.batch_size)] 296 | self.iteration += 1 297 | items = [self.dataset[idx] for idx in items_idxs] 298 | yield items 299 | 300 | logger.info("Finished iterating, iteration={}, shard={}".format(self.iteration, self.shard_id)) 301 | # TODO: reset the iteration status? 302 | self.iteration = 0 303 | 304 | def get_dataset(self) -> Dataset: 305 | return self.dataset 306 | 307 | 308 | class LocalShardedDataIterator(ShardedDataIterator): 309 | # uses only one shard after the initial dataset load to reduce memory footprint 310 | def load_data(self): 311 | self.dataset.load_data(sharding_fn=self.calculate_shards_fn) 312 | self.set_max_iteration() 313 | logger.info("Sharded dataset data %d", len(self.dataset)) 314 | 315 | def get_shard_indices(self, epoch: int): 316 | indices = list(range(len(self.dataset))) 317 | if self.shuffle: 318 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 319 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 320 | epoch_rnd.shuffle(indices) 321 | shard_indices = indices 322 | return shard_indices 323 | 324 | 325 | class MultiSetDataIterator(object): 326 | """ 327 | Iterator over multiple data sources. Useful when all samples form a single batch should be from the same dataset. 
328 | """ 329 | 330 | def __init__( 331 | self, 332 | datasets: List[ShardedDataIterator], 333 | shuffle_seed: int = 0, 334 | shuffle=True, 335 | sampling_rates: List = [], 336 | rank: int = 0, 337 | ): 338 | # randomized data loading to avoid file system congestion 339 | ds_list_copy = [ds for ds in datasets] 340 | rnd = random.Random(rank) 341 | rnd.shuffle(ds_list_copy) 342 | [ds.load_data() for ds in ds_list_copy] 343 | 344 | self.iterables = datasets 345 | data_lengths = [it.total_data_len() for it in datasets] 346 | self.total_data = sum(data_lengths) 347 | logger.info("rank=%d; Multi set data sizes %s", rank, data_lengths) 348 | logger.info("rank=%d; Multi set total data %s", rank, self.total_data) 349 | logger.info("rank=%d; Multi set sampling_rates %s", rank, sampling_rates) 350 | self.shuffle_seed = shuffle_seed 351 | self.shuffle = shuffle 352 | self.iteration = 0 353 | self.rank = rank 354 | 355 | if sampling_rates: 356 | self.max_its_pr_ds = [int(ds.max_iterations_num() * sampling_rates[i]) for i, ds in enumerate(datasets)] 357 | else: 358 | self.max_its_pr_ds = [ds.max_iterations_num() for ds in datasets] 359 | 360 | self.max_iterations = sum(self.max_its_pr_ds) 361 | logger.info("rank=%d; Multi set max_iterations per dataset %s", rank, self.max_its_pr_ds) 362 | logger.info("rank=%d; Multi set max_iterations %d", rank, self.max_iterations) 363 | 364 | def total_data_len(self) -> int: 365 | return self.total_data 366 | 367 | def get_max_iterations(self): 368 | return self.max_iterations 369 | 370 | def iterate_ds_data(self, epoch: int = 0) -> Iterator[Tuple[List, int]]: 371 | 372 | logger.info("rank=%d; Iteration start", self.rank) 373 | logger.info( 374 | "rank=%d; Multi set iteration: iteration ptr per set: %s", 375 | self.rank, 376 | [it.get_iteration() for it in self.iterables], 377 | ) 378 | 379 | data_src_indices = [] 380 | iterators = [] 381 | for source, src_its in enumerate(self.max_its_pr_ds): 382 | logger.info( 383 | "rank=%d; Multi set iteration: source %d, batches to be taken: %s", 384 | self.rank, 385 | source, 386 | src_its, 387 | ) 388 | data_src_indices.extend([source] * src_its) 389 | 390 | iterators.append(self.iterables[source].iterate_ds_sampled_data(src_its, epoch=epoch)) 391 | 392 | if self.shuffle: 393 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 394 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 395 | epoch_rnd.shuffle(data_src_indices) 396 | 397 | logger.info("rank=%d; data_src_indices len=%d", self.rank, len(data_src_indices)) 398 | for i, source_idx in enumerate(data_src_indices): 399 | it = iterators[source_idx] 400 | next_item = next(it, None) 401 | if next_item is not None: 402 | self.iteration += 1 403 | yield (next_item, source_idx) 404 | else: 405 | logger.warning("rank=%d; Next item in the source %s is None", self.rank, source_idx) 406 | 407 | logger.info("rank=%d; last iteration %d", self.rank, self.iteration) 408 | 409 | logger.info( 410 | "rank=%d; Multi set iteration finished: iteration per set: %s", 411 | self.rank, 412 | [it.iteration for it in self.iterables], 413 | ) 414 | [next(it, None) for it in iterators] 415 | 416 | # TODO: clear iterators in some non-hacky way 417 | for it in self.iterables: 418 | it.iteration = 0 419 | logger.info( 420 | "rank=%d; Multi set iteration finished after next: iteration per set: %s", 421 | self.rank, 422 | [it.iteration for it in self.iterables], 423 | ) 424 | # reset the iteration status 425 | self.iteration = 0 426 | 427 | def 
get_iteration(self) -> int: 428 | return self.iteration 429 | 430 | def get_dataset(self, ds_id: int) -> Dataset: 431 | return self.iterables[ds_id].get_dataset() 432 | 433 | def get_datasets(self) -> List[Dataset]: 434 | return [it.get_dataset() for it in self.iterables] 435 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for distributed model training 10 | """ 11 | import logging 12 | import os 13 | import pickle 14 | import socket 15 | import subprocess 16 | from typing import Tuple 17 | import torch 18 | import torch.distributed as dist 19 | 20 | logger = logging.getLogger() 21 | 22 | 23 | def get_rank(): 24 | return dist.get_rank() 25 | 26 | 27 | def get_world_size(): 28 | return dist.get_world_size() 29 | 30 | 31 | def get_default_group(): 32 | return dist.group.WORLD 33 | 34 | 35 | def all_reduce(tensor, group=None): 36 | if group is None: 37 | group = get_default_group() 38 | return dist.all_reduce(tensor, group=group) 39 | 40 | 41 | def all_gather_list(data, group=None, max_size=16384): 42 | """Gathers arbitrary data from all nodes into a list. 43 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 44 | data. Note that *data* must be picklable. 45 | Args: 46 | data (Any): data from the local worker to be gathered on other workers 47 | group (optional): group of the collective 48 | """ 49 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 50 | 51 | enc = pickle.dumps(data) 52 | enc_size = len(enc) 53 | 54 | if enc_size + SIZE_STORAGE_BYTES > max_size: 55 | raise ValueError( 56 | "encoded data exceeds max_size, this can be fixed by increasing buffer size: {}".format(enc_size) 57 | ) 58 | 59 | rank = get_rank() 60 | world_size = get_world_size() 61 | buffer_size = max_size * world_size 62 | 63 | if not hasattr(all_gather_list, "_buffer") or all_gather_list._buffer.numel() < buffer_size: 64 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 65 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 66 | 67 | buffer = all_gather_list._buffer 68 | buffer.zero_() 69 | cpu_buffer = all_gather_list._cpu_buffer 70 | 71 | assert enc_size < 256**SIZE_STORAGE_BYTES, "Encoded object size should be less than {} bytes".format( 72 | 256**SIZE_STORAGE_BYTES 73 | ) 74 | 75 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder="big") 76 | 77 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 78 | cpu_buffer[SIZE_STORAGE_BYTES : enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 79 | 80 | start = rank * max_size 81 | size = enc_size + SIZE_STORAGE_BYTES 82 | buffer[start : start + size].copy_(cpu_buffer[:size]) 83 | 84 | all_reduce(buffer, group=group) 85 | 86 | try: 87 | result = [] 88 | for i in range(world_size): 89 | out_buffer = buffer[i * max_size : (i + 1) * max_size] 90 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder="big") 91 | if size > 0: 92 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES : size + SIZE_STORAGE_BYTES].tolist()))) 93 | return result 94 | except pickle.UnpicklingError: 95 | raise Exception( 96 | "Unable to 
unpickle data from other workers. all_gather_list requires all " 97 | "workers to enter the function together, so this error usually indicates " 98 | "that the workers have fallen out of sync somehow. Workers can fall out of " 99 | "sync if one of them runs out of memory, or if there are other conditions " 100 | "in your training script that can cause one worker to finish an epoch " 101 | "while other workers are still iterating over their portions of the data." 102 | ) 103 | 104 | 105 | def setup_cfg_gpu(cfg): 106 | """ 107 | Setup params for CUDA, GPU & distributed training 108 | """ 109 | 110 | logger.info("CFG's local_rank=%s", cfg.local_rank) 111 | ws = os.environ.get("WORLD_SIZE") 112 | cfg.distributed_world_size = int(ws) if ws else 1 113 | logger.info("Env WORLD_SIZE=%s", ws) 114 | 115 | nnodes = int(os.environ.get("SLURM_NNODES")) if os.environ.get("SLURM_NNODES") else 1 116 | ntasks = int(os.environ.get("SLURM_NTASKS")) if os.environ.get("SLURM_NTASKS") else 1 117 | 118 | if nnodes * ntasks > 1 and not cfg.distributed_port: 119 | logger.info("distributed_port not specified setting to '29500' ...") 120 | cfg.distributed_port = 29500 121 | 122 | # if cfg.distributed_port and cfg.distributed_port > 0: 123 | if nnodes > 1: 124 | # logger.info("distributed_port is specified, trying to init distributed mode from SLURM params ...") 125 | 126 | init_method, local_rank, world_size, device = infer_slurm_init(cfg) 127 | 128 | logger.info( 129 | "Inferred params from SLURM: init_method=%s | local_rank=%s | world_size=%s", 130 | init_method, 131 | local_rank, 132 | world_size, 133 | ) 134 | 135 | cfg.local_rank = local_rank 136 | cfg.distributed_world_size = world_size 137 | cfg.n_gpu = 1 138 | 139 | torch.cuda.set_device(device) 140 | device = str(torch.device("cuda", device)) 141 | 142 | torch.distributed.init_process_group( 143 | backend="nccl", 144 | init_method=init_method, 145 | world_size=world_size, 146 | rank=local_rank, 147 | ) 148 | 149 | elif cfg.local_rank == -1 or cfg.no_cuda: # single-node multi-gpu (or cpu) mode 150 | device = str(torch.device("cuda" if torch.cuda.is_available() and not cfg.no_cuda else "cpu")) 151 | cfg.n_gpu = torch.cuda.device_count() 152 | else: # distributed mode 153 | torch.cuda.set_device(cfg.local_rank) 154 | device = str(torch.device("cuda", cfg.local_rank)) 155 | torch.distributed.init_process_group(backend="nccl") 156 | cfg.n_gpu = 1 157 | 158 | cfg.device = device 159 | 160 | logger.info( 161 | "Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d", 162 | socket.gethostname(), 163 | cfg.local_rank, 164 | cfg.device, 165 | cfg.n_gpu, 166 | cfg.distributed_world_size, 167 | ) 168 | logger.info("16-bits training: %s ", cfg.fp16) 169 | return cfg 170 | 171 | 172 | def infer_slurm_init(cfg=None) -> Tuple[str, int, int, int]: 173 | 174 | # if cfg.distributed_port 175 | 176 | node_list = os.environ.get("SLURM_STEP_NODELIST") 177 | if node_list is None: 178 | node_list = os.environ.get("SLURM_JOB_NODELIST") 179 | logger.info("SLURM_JOB_NODELIST: %s", node_list) 180 | 181 | if node_list is None: 182 | logger.warning("Can't find SLURM node_list from env parameters") 183 | return "", 0, 1, -1 184 | 185 | local_rank = None 186 | world_size = None 187 | distributed_init_method = None 188 | device_id = None 189 | try: 190 | hostnames = subprocess.check_output(["scontrol", "show", "hostnames", node_list]) 191 | if cfg: 192 | distributed_init_method = "tcp://{host}:{port}".format( 193 | host=hostnames.split()[0].decode("utf-8"), 194 | 
port=cfg.distributed_port, 195 | ) 196 | nnodes = int(os.environ.get("SLURM_NNODES")) 197 | logger.info("SLURM_NNODES: %s", nnodes) 198 | ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") 199 | if ntasks_per_node is not None: 200 | ntasks_per_node = int(ntasks_per_node) 201 | logger.info("SLURM_NTASKS_PER_NODE: %s", ntasks_per_node) 202 | else: 203 | ntasks = int(os.environ.get("SLURM_NTASKS")) 204 | logger.info("SLURM_NTASKS: %s", ntasks) 205 | assert ntasks % nnodes == 0 206 | ntasks_per_node = int(ntasks / nnodes) 207 | 208 | if ntasks_per_node == 1: 209 | gpus_per_node = torch.cuda.device_count() 210 | node_id = int(os.environ.get("SLURM_NODEID")) 211 | local_rank = node_id * gpus_per_node 212 | world_size = nnodes * gpus_per_node 213 | logger.info("node_id: %s", node_id) 214 | else: 215 | world_size = ntasks_per_node * nnodes 216 | proc_id = os.environ.get("SLURM_PROCID") 217 | local_id = os.environ.get("SLURM_LOCALID") 218 | logger.info("SLURM_PROCID %s", proc_id) 219 | logger.info("SLURM_LOCALID %s", local_id) 220 | local_rank = int(proc_id) 221 | device_id = int(local_id) 222 | 223 | except subprocess.CalledProcessError as e: # scontrol failed 224 | raise e 225 | except FileNotFoundError: # Slurm is not installed 226 | pass 227 | return distributed_init_method, local_rank, world_size, device_id 228 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/dpr/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
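# Worked example (assumed SLURM values) for infer_slurm_init() in dist_utils.py above:
# with SLURM_NNODES=2 and SLURM_NTASKS_PER_NODE=8,
#   world_size = ntasks_per_node * nnodes = 16
#   local_rank = int(SLURM_PROCID)   # global rank of this process, 0..15
#   device_id  = int(SLURM_LOCALID)  # GPU index on the local node, 0..7
# With a single task per node instead, world_size = nnodes * torch.cuda.device_count()
# and local_rank = SLURM_NODEID * gpus_per_node, with the device id left unset.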
7 | 8 | import collections 9 | import glob 10 | import logging 11 | import os 12 | from typing import List 13 | 14 | import torch 15 | from torch import nn 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.serialization import default_restore_location 18 | 19 | logger = logging.getLogger() 20 | 21 | CheckpointState = collections.namedtuple( 22 | "CheckpointState", 23 | [ 24 | "model_dict", 25 | "optimizer_dict", 26 | "scheduler_dict", 27 | "offset", 28 | "epoch", 29 | "best_validation_loss_result", 30 | "best_validation_acc_result", 31 | "encoder_params", 32 | ], 33 | ) 34 | 35 | 36 | def setup_fp16_and_distributed_mode( 37 | model: nn.Module, 38 | optimizer: torch.optim.Optimizer, 39 | device: object, 40 | n_gpu: int = 1, 41 | local_rank: int = -1, 42 | fp16: bool = False, 43 | fp16_opt_level: str = "O1", 44 | ) -> (nn.Module, torch.optim.Optimizer): 45 | 46 | model.to(device) 47 | 48 | if n_gpu > 1: 49 | model = torch.nn.DataParallel(model) 50 | 51 | if local_rank != -1: 52 | model = torch.nn.parallel.DistributedDataParallel( 53 | model, 54 | device_ids=[device if device else local_rank], 55 | output_device=local_rank, 56 | find_unused_parameters=True, 57 | gradient_as_bucket_view=True, 58 | ) 59 | return model, optimizer 60 | 61 | 62 | def move_to_cuda(sample): 63 | if len(sample) == 0: 64 | return {} 65 | 66 | def _move_to_cuda(maybe_tensor): 67 | if torch.is_tensor(maybe_tensor): 68 | return maybe_tensor.cuda() 69 | elif isinstance(maybe_tensor, dict): 70 | return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()} 71 | elif isinstance(maybe_tensor, list): 72 | return [_move_to_cuda(x) for x in maybe_tensor] 73 | elif isinstance(maybe_tensor, tuple): 74 | return [_move_to_cuda(x) for x in maybe_tensor] 75 | else: 76 | return maybe_tensor 77 | 78 | return _move_to_cuda(sample) 79 | 80 | 81 | def move_to_device(sample, device): 82 | if len(sample) == 0: 83 | return {} 84 | 85 | def _move_to_device(maybe_tensor, device): 86 | if torch.is_tensor(maybe_tensor): 87 | return maybe_tensor.to(device) 88 | elif isinstance(maybe_tensor, dict): 89 | return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()} 90 | elif isinstance(maybe_tensor, list): 91 | return [_move_to_device(x, device) for x in maybe_tensor] 92 | elif isinstance(maybe_tensor, tuple): 93 | return [_move_to_device(x, device) for x in maybe_tensor] 94 | else: 95 | return maybe_tensor 96 | 97 | return _move_to_device(sample, device) 98 | 99 | 100 | def get_schedule_linear( 101 | optimizer, 102 | warmup_steps, 103 | total_training_steps, 104 | steps_shift=0, 105 | last_epoch=-1, 106 | ): 107 | 108 | """Create a schedule with a learning rate that decreases linearly after 109 | linearly increasing during a warmup period. 
110 | """ 111 | 112 | def lr_lambda(current_step): 113 | current_step += steps_shift 114 | if current_step < warmup_steps: 115 | return float(current_step) / float(max(1, warmup_steps)) 116 | return max( 117 | 1e-7, 118 | float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)), 119 | ) 120 | 121 | return LambdaLR(optimizer, lr_lambda, last_epoch) 122 | 123 | 124 | def init_weights(modules: List): 125 | for module in modules: 126 | if isinstance(module, (nn.Linear, nn.Embedding)): 127 | module.weight.data.normal_(mean=0.0, std=0.02) 128 | elif isinstance(module, nn.LayerNorm): 129 | module.bias.data.zero_() 130 | module.weight.data.fill_(1.0) 131 | if isinstance(module, nn.Linear) and module.bias is not None: 132 | module.bias.data.zero_() 133 | 134 | 135 | def get_model_obj(model: nn.Module): 136 | return model.module if hasattr(model, "module") else model 137 | 138 | 139 | def get_model_file(args, file_prefix) -> str: 140 | if args.resume.model_file and os.path.exists(args.resume.model_file): 141 | return args.resume.model_file 142 | 143 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + "*[0-9]")) if args.output_dir else [] 144 | logger.info("Checkpoint files %s", out_cp_files) 145 | model_file = None 146 | 147 | if len(out_cp_files) > 0: 148 | model_file = max(out_cp_files, key=os.path.getctime) 149 | return model_file 150 | 151 | 152 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 153 | logger.info("Reading saved model from %s", model_file) 154 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, "cpu")) 155 | logger.info("model_state_dict keys %s", state_dict.keys()) 156 | return CheckpointState(**state_dict) 157 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/evaluation/__init__.py -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | class TaskDataset: 8 | def __init__( 9 | self, 10 | file, 11 | task_family, 12 | name, 13 | question_transform_type, 14 | ): 15 | self.file = file 16 | self.task_family = task_family 17 | self.name = name 18 | self.question_transform_type = question_transform_type 19 | 20 | def validate_datapoint(self, datapoint, logger): 21 | raise NotImplementedError 22 | 23 | def transform_query(self, datapoint, question_transform_type): 24 | return datapoint 25 | 26 | 27 | class WaferCCNetTaskDataset(TaskDataset): 28 | def validate_datapoint(self, datapoint, logger): 29 | 30 | # input is a string 31 | if not isinstance(datapoint["input"], str): 32 | if logger: 33 | logger.warning("[{}] input is not a string {}".format(datapoint["id"], datapoint["input"])) 34 | return False 35 | 36 | # output is not empty 37 | if "output" in datapoint: 38 | if len(datapoint["output"]) == 0: 39 | if logger: 40 | logger.warning("[{}] empty output".format(datapoint["id"])) 41 | return False 42 | 43 | for output in datapoint["output"]: 44 | # answer is a string 45 | if "answer" in output: 46 | if not isinstance(output["answer"], str): 47 | if logger: 48 | logger.warning("[{}] answer is not a string {}".format(datapoint["id"], output["answer"])) 49 | return False 50 | 51 | # provenance is not empty 52 | if len(output["provenance"]) == 0: 53 | if logger: 54 | logger.warning("[{}] empty provenance".format(datapoint["id"])) 55 | return False 56 | 57 | if "provenance" in output: 58 | for provenance in output["provenance"]: 59 | # title is provided 60 | if "url" not in provenance or not isinstance(provenance["url"], str): 61 | if logger: 62 | logger.warning("datapoint with malformed or missing url") 63 | return False 64 | 65 | return True 66 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/retrievers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/evaluation/retrievers/__init__.py -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/retrievers/base_retriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | from abc import ABC, abstractmethod 10 | 11 | 12 | class Retriever(ABC): 13 | def __init__(self, name): 14 | self.name = name 15 | 16 | @classmethod 17 | def from_default_config(cls, name): 18 | raise NotImplementedError 19 | 20 | def get_queries_data(self, queries_data): 21 | return queries_data 22 | 23 | def preprocess_instance(self, instance): 24 | return instance 25 | 26 | @abstractmethod 27 | def set_queries_data(self, queries_data): 28 | """ 29 | fed all data to the retriever, that will take care of batchify it 30 | each element in queries_data has an id and a query 31 | 32 | Args: 33 | queries_data (list): list of dicts with two fields: (1) 'id' -> id of the query: (2) 'query' -> text of the query 34 | 35 | Example: 36 | queries_data = [ 37 | {'id': '-4203908294749842710', 'query': 'what is the definition of bcc in email'}, 38 | ... 
39 | ] 40 | """ 41 | raise NotImplementedError 42 | 43 | @abstractmethod 44 | def run(self): 45 | """ 46 | get the retrieved documents for all the fed data 47 | return all_doc_id, all_scores, all_query_id, provenance 48 | 49 | Returns 50 | ------- 51 | provenance: dictionary with retrieval result, the keys should match the query id in input 52 | 53 | Example: 54 | provenance: { 55 | '-4203908294749842710': [ 56 | {"score": "179.01215", "text": "permit the use of a program-external editor. The email clients will perform formatting according to RFC 5322 for headers and body, and MIME for non-textual content and attachments. Headers include the destination fields, \"To\", \"Cc\" (short for \"Carbon copy\"), and \"Bcc\" (\"Blind carbon copy\"), and the originator fields \"From\" which is the message's author(s), \"Sender\" in case there are more authors, and \"Reply-To\"", "wikipedia_title": "Email client", "wikipedia_id": "43478"}, 57 | {"score": "184.6643", "text": "this example, the conversation parts are prefixed with \"S:\" and \"C:\", for \"server\" and \"client\", respectively; these labels are not part of the exchange.) After the message sender (SMTP client) establishes a reliable communications channel to the message receiver (SMTP server), the session is opened with a greeting by the server, usually containing its fully qualified domain name (FQDN), in this case \"smtp.example.com\". The client initiates its dialog by responding with a", "wikipedia_title": "Simple Mail Transfer Protocol", "wikipedia_id": "27675"}, 58 | ... 59 | ], 60 | ... 61 | } 62 | """ 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/retrievers/dpr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import glob 7 | import logging 8 | 9 | import hydra 10 | from omegaconf import OmegaConf 11 | 12 | from dpr.dense_retriever import LocalFaissRetriever, get_all_passages 13 | from dpr.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer 14 | from dpr.models import init_biencoder_components 15 | from dpr.options import set_cfg_params_from_state 16 | from dpr.utils.dist_utils import setup_cfg_gpu 17 | from dpr.utils.model_utils import ( 18 | setup_fp16_and_distributed_mode, 19 | get_model_obj, 20 | load_states_from_checkpoint, 21 | ) 22 | from evaluation.retrievers.base_retriever import Retriever 23 | 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | 27 | 28 | class DPR(Retriever): 29 | def __init__(self, name, cfg): 30 | super().__init__(name) 31 | print(OmegaConf.to_yaml(cfg)) 32 | 33 | _ckpt_cfg = OmegaConf.load(f"{cfg.retriever.checkpoint_dir}/.hydra/config.yaml") 34 | self.checkpoint_cfg = setup_cfg_gpu(_ckpt_cfg) 35 | self.cfg = cfg 36 | 37 | assert int(self.cfg.n_docs) > 0 38 | 39 | logger.info("CFG (after gpu configuration):") 40 | logger.info("%s", OmegaConf.to_yaml(cfg)) 41 | 42 | saved_state = load_states_from_checkpoint(f"{cfg.retriever.model_file}.generate_embeddings") 43 | set_cfg_params_from_state(saved_state.encoder_params, self.checkpoint_cfg) 44 | 45 | tensorizer, encoder, _ = init_biencoder_components( 46 | self.checkpoint_cfg.base_model.encoder_model_type, self.checkpoint_cfg, inference_only=True 47 | ) 48 | 49 | encoder = encoder.question_model 50 | encoder, _ = setup_fp16_and_distributed_mode( 51 | encoder, 52 | None, 53 | self.checkpoint_cfg.device, 54 | self.checkpoint_cfg.n_gpu, 55 | self.checkpoint_cfg.local_rank, 56 | self.checkpoint_cfg.fp16, 57 | ) 58 | encoder.eval() 59 | 60 | # load weights from the model file 61 | model_to_load = get_model_obj(encoder) 62 | logger.info("Loading saved model state ...") 63 | 64 | encoder_prefix = "question_model." 
65 | prefix_len = len(encoder_prefix) 66 | 67 | logger.info("Encoder state prefix %s", encoder_prefix) 68 | question_encoder_state = { 69 | key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if key.startswith(encoder_prefix) 70 | } 71 | model_to_load.load_state_dict(question_encoder_state) 72 | vector_size = model_to_load.get_out_size() 73 | logger.info("Encoder vector_size=%d", vector_size) 74 | 75 | ctx_sources = [] 76 | ctx_src = hydra.utils.instantiate( 77 | cfg.retriever.datasets[cfg.retriever.ctx_src], self.checkpoint_cfg, tensorizer=tensorizer 78 | ) 79 | ctx_sources.append(ctx_src) 80 | self.all_passages = get_all_passages(ctx_sources) 81 | 82 | if self.cfg.retriever.hnsw_index: 83 | index = DenseHNSWFlatIndexer(vector_size) 84 | index.deserialize(cfg.retriever.hnsw_index_path) 85 | self.retriever = LocalFaissRetriever(encoder, cfg.retriever.batch_size, ctx_src, index) 86 | else: 87 | logger.info("Reading from encoded vectors=%s", f"{cfg.retriever.out_file}_[0-9]*") 88 | # index all passages 89 | index = DenseFlatIndexer() 90 | index.init_index(vector_sz=vector_size) 91 | self.retriever = LocalFaissRetriever(encoder, cfg.retriever.batch_size, ctx_src, index) 92 | self.retriever.index_encoded_data( 93 | vector_files=glob.glob(f"{cfg.retriever.out_file}_[0-9]*"), 94 | buffer_size=index.buffer_size, 95 | ) 96 | 97 | def set_queries_data(self, queries_data): 98 | # get questions & answers 99 | self.questions = [x["query"].strip() for x in queries_data] 100 | self.query_ids = [x["id"] for x in queries_data] 101 | 102 | def run(self): 103 | 104 | questions_tensor = self.retriever.generate_question_vectors(self.questions) 105 | 106 | top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), int(self.cfg.n_docs)) 107 | 108 | provenance = {} 109 | 110 | for record, query_id in zip(top_ids_and_scores, self.query_ids): 111 | top_ids, scores = record 112 | element = [] 113 | 114 | # sort by score in descending order 115 | for score, id in sorted(zip(scores, top_ids), reverse=True): 116 | 117 | text = self.all_passages[id][0] 118 | index = self.all_passages[id][1] 119 | url = self.all_passages[id][2] 120 | 121 | wikipedia_id = None 122 | 123 | element.append( 124 | { 125 | "score": str(score), 126 | "text": str(text), 127 | "wikipedia_title": str(index), 128 | "wikipedia_id": str(wikipedia_id), 129 | "chunk_id": id, 130 | "url": url, 131 | } 132 | ) 133 | 134 | assert query_id not in provenance 135 | provenance[query_id] = element 136 | 137 | return provenance 138 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/retrievers/dpr_remote.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
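# Note on the checkpoint loading in DPR.__init__ in dpr.py above (toy example):
# parameters stored under the biencoder prefix "question_model." are remapped so
# they can be loaded into the standalone question encoder, e.g.
#   {"question_model.encoder.layer.0.weight": w, "ctx_model.encoder.layer.0.weight": v}
#     -> {"encoder.layer.0.weight": w}        # ctx_model.* entries are dropped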
6 | import logging 7 | import os 8 | import pickle 9 | import zlib 10 | from pathlib import Path 11 | 12 | import hydra 13 | from omegaconf import OmegaConf 14 | from tqdm import tqdm 15 | 16 | from dpr.dense_retriever import DenseRPCRetriever 17 | from dpr.models import init_biencoder_components 18 | from dpr.options import set_cfg_params_from_state 19 | from dpr.utils.dist_utils import setup_cfg_gpu 20 | from dpr.utils.model_utils import ( 21 | setup_fp16_and_distributed_mode, 22 | get_model_obj, 23 | load_states_from_checkpoint, 24 | ) 25 | from evaluation.retrievers.base_retriever import Retriever 26 | 27 | logger = logging.getLogger() 28 | logger.setLevel(logging.INFO) 29 | 30 | 31 | class DPR(Retriever): 32 | def __init__(self, name, cfg): 33 | super().__init__(name) 34 | 35 | self.cfg = cfg 36 | _ckpt_cfg = OmegaConf.load(f"{cfg.retriever.checkpoint_dir}/.hydra/config.yaml") 37 | self.checkpoint_cfg = setup_cfg_gpu(_ckpt_cfg) 38 | 39 | logger.info("CFG (after gpu configuration):") 40 | logger.info("%s", OmegaConf.to_yaml(cfg)) 41 | 42 | if os.path.exists(f"{cfg.retriever.model_file}.generate_embeddings"): 43 | saved_state = load_states_from_checkpoint(f"{cfg.retriever.model_file}.generate_embeddings") 44 | elif os.path.exists(f"{cfg.retriever.model_file}"): 45 | saved_state = load_states_from_checkpoint(f"{cfg.retriever.model_file}") 46 | else: 47 | raise Exception( 48 | f"Checkpoint {cfg.retriever.model_file}.generate_embeddings or {cfg.retriever.model_file} doesn't exist." 49 | ) 50 | 51 | set_cfg_params_from_state(saved_state.encoder_params, self.checkpoint_cfg) 52 | 53 | tensorizer, encoder, _ = init_biencoder_components( 54 | self.checkpoint_cfg.base_model.encoder_model_type, self.checkpoint_cfg, inference_only=True 55 | ) 56 | 57 | encoder = encoder.question_model 58 | encoder, _ = setup_fp16_and_distributed_mode( 59 | encoder, 60 | None, 61 | self.checkpoint_cfg.device, 62 | self.checkpoint_cfg.n_gpu, 63 | self.checkpoint_cfg.local_rank, 64 | self.checkpoint_cfg.fp16, 65 | ) 66 | encoder.eval() 67 | 68 | # load weights from the model file 69 | model_to_load = get_model_obj(encoder) 70 | logger.info("Loading saved model state ...") 71 | 72 | encoder_prefix = "question_model." 
73 | prefix_len = len(encoder_prefix) 74 | 75 | logger.info("Encoder state prefix %s", encoder_prefix) 76 | question_encoder_state = { 77 | key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if key.startswith(encoder_prefix) 78 | } 79 | model_to_load.load_state_dict(question_encoder_state) 80 | vector_size = model_to_load.get_out_size() 81 | logger.info("Encoder vector_size=%d", vector_size) 82 | 83 | ctx_src = hydra.utils.instantiate( 84 | cfg.retriever.datasets[cfg.retriever.ctx_src], self.checkpoint_cfg, tensorizer=tensorizer 85 | ) 86 | 87 | self.retriever = DenseRPCRetriever( 88 | question_encoder=encoder, 89 | qa_src=ctx_src, 90 | batch_size=cfg.retriever.batch_size, 91 | index_cfg_path=cfg.retriever.rpc_retriever_cfg_file, 92 | dim=vector_size, 93 | use_l2_conversion=True, 94 | # nprobe=128, 95 | ) 96 | self.retriever.load_index(cfg.retriever.rpc_index_id) 97 | 98 | def set_queries_data(self, queries_data): 99 | # get questions & answers 100 | self.questions = [x["query"].strip() for x in queries_data] 101 | self.query_ids = [x["id"] for x in queries_data] 102 | 103 | def run(self): 104 | 105 | questions_tensor = self.retriever.generate_question_vectors(self.questions) 106 | top_ids_and_scores = self.retriever.get_top_docs( 107 | questions_tensor.numpy(), 108 | self.cfg.n_docs, 109 | filter_pos=4, 110 | filter_value="1", 111 | ) 112 | 113 | provenance = {} 114 | 115 | for record, query_id in tqdm(zip(top_ids_and_scores, self.query_ids)): 116 | docs_meta, scores = record 117 | element = [] 118 | 119 | for score, meta in zip(scores, docs_meta): 120 | chunk_id, text, title, url = meta[:4] 121 | element.append( 122 | { 123 | "score": str(score), 124 | "chunk_id": str(chunk_id), 125 | "text": str(zlib.decompress(text).decode()), 126 | "wikipedia_title": str(zlib.decompress(title).decode()), 127 | "url": str(zlib.decompress(url).decode()), 128 | } 129 | ) 130 | 131 | assert query_id not in provenance 132 | provenance[query_id] = element 133 | 134 | return provenance 135 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/evaluation/retrievers/reranker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
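# The remote retriever's run() in dpr_remote.py above receives passage text, title
# and url as zlib-compressed bytes and decompresses them before building provenance
# entries. Standalone round-trip for reference (assumed sample string):
def _zlib_roundtrip_sketch():
    import zlib

    raw = "Simple Mail Transfer Protocol"
    return zlib.decompress(zlib.compress(raw.encode())).decode() == raw  # True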
6 | import logging 7 | import os.path 8 | from collections import OrderedDict 9 | from pathlib import Path 10 | import torch 11 | import json 12 | 13 | import hydra 14 | from omegaconf import OmegaConf 15 | import tqdm 16 | 17 | from dpr.models import init_biencoder_components 18 | from dpr.models.biencoder import RankerLoss 19 | from dpr.options import set_cfg_params_from_state 20 | from dpr.dataset.input_transform import SchemaPreprocessor 21 | from dpr.utils.dist_utils import setup_cfg_gpu 22 | from dpr.utils.model_utils import ( 23 | setup_fp16_and_distributed_mode, 24 | get_model_obj, 25 | load_states_from_checkpoint, 26 | move_to_device, 27 | ) 28 | from evaluation.retrievers.base_retriever import Retriever 29 | from scripts.train import gather_to_main_device 30 | 31 | logger = logging.getLogger() 32 | logger.setLevel(logging.INFO) 33 | 34 | 35 | def get_ctxt_encoder_input_in_dpr_schema(VALID_SPLIT_DATA_FILES__RETRIEVERS, gold_queries_data, nr_ctxs, strict=False): 36 | 37 | out = OrderedDict() 38 | 39 | for i, instance in enumerate(gold_queries_data): 40 | if instance["id"] not in out: 41 | out[instance["id"]] = { 42 | "id": instance["id"], 43 | "question": instance["query"], 44 | "ctxs": list(), 45 | "provenances": list(), 46 | "evaluation": set(), 47 | } 48 | 49 | for retriever_sys, valid_split_data_file in VALID_SPLIT_DATA_FILES__RETRIEVERS: 50 | with open(valid_split_data_file) as f: 51 | 52 | for i, line in enumerate(tqdm.tqdm(f)): 53 | 54 | jline_sys = json.loads(line) 55 | 56 | if jline_sys["id"] not in out: 57 | # logger.info(f"Skipping instance '{jline_sys['id']}' which is not in gold") 58 | continue 59 | 60 | for output in jline_sys["output"]: 61 | if "provenance" not in output: 62 | continue 63 | if len(output["provenance"]) == 0: 64 | continue 65 | if len(output["provenance"]) == 0 or "text" not in output["provenance"][0]: 66 | continue 67 | 68 | # rand_nr = random.randint(5, nr_ctxs) 69 | 70 | for i, provenance in enumerate(output["provenance"][:nr_ctxs]): 71 | text, url = ( 72 | provenance["text"], 73 | provenance["url"], 74 | ) 75 | 76 | if "wikipedia_title" in provenance: 77 | title = provenance["wikipedia_title"] 78 | else: 79 | title = provenance["title"] 80 | 81 | if "chunk_id" in provenance: 82 | chunk_id = provenance["chunk_id"] 83 | elif "sha" in provenance: 84 | chunk_id = provenance["sha"] 85 | else: 86 | chunk_id = None 87 | 88 | out[jline_sys["id"]]["ctxs"].append( 89 | { 90 | "id": chunk_id, 91 | "doc_id": 0, 92 | "passage_id": 0, 93 | "url": url, 94 | "title": title, 95 | "text": text, 96 | "system_hit": retriever_sys, 97 | "rank": f"[RANK:{i}]", 98 | } 99 | ) 100 | provenance["retriever_sys"] = retriever_sys 101 | out[jline_sys["id"]]["provenances"].append(provenance) 102 | 103 | out[jline_sys["id"]]["evaluation"].add(retriever_sys) 104 | 105 | before_filter = len(out) 106 | logger.info(f"query_ctx size before filter: {before_filter}") 107 | out = list(filter(lambda x: len(x["evaluation"]) == len(VALID_SPLIT_DATA_FILES__RETRIEVERS), list(out.values()))) 108 | after_filter = len(out) 109 | logger.info(f"query_ctxy size after filter: {after_filter}") 110 | 111 | if strict: 112 | assert after_filter == before_filter 113 | 114 | return out 115 | 116 | 117 | class Reranker(Retriever): 118 | def __init__(self, name, cfg): 119 | super().__init__(name) 120 | 121 | self.cfg = cfg 122 | _ckpt_cfg = OmegaConf.load(f"{cfg.retriever.checkpoint_dir}/.hydra/config.yaml") 123 | self.checkpoint_cfg = setup_cfg_gpu(_ckpt_cfg) 124 | 125 | logger.info("CFG (after gpu 
configuration):") 126 | logger.info("%s", OmegaConf.to_yaml(cfg)) 127 | 128 | saved_state = load_states_from_checkpoint(f"{cfg.retriever.model_file}") 129 | set_cfg_params_from_state(saved_state.encoder_params, self.checkpoint_cfg) 130 | 131 | tensorizer, reranker, _ = init_biencoder_components( 132 | self.checkpoint_cfg.base_model.encoder_model_type, self.checkpoint_cfg, inference_only=True 133 | ) 134 | 135 | self.tensorizer = tensorizer 136 | self.ctx_src: SchemaPreprocessor = hydra.utils.instantiate( 137 | cfg.retriever.datasets[cfg.retriever.ctx_src], self.checkpoint_cfg, tensorizer=tensorizer 138 | ) 139 | self.loss_func: RankerLoss = reranker.get_loss_function() 140 | self.biencoder_prepare_model_inputs = reranker.prepare_model_inputs 141 | 142 | self.reranker, _ = setup_fp16_and_distributed_mode( 143 | reranker, 144 | None, 145 | self.checkpoint_cfg.device, 146 | self.checkpoint_cfg.n_gpu, 147 | self.checkpoint_cfg.local_rank, 148 | self.checkpoint_cfg.fp16, 149 | ) 150 | self.reranker.eval() 151 | 152 | # load weights from the model file 153 | logger.info("Loading saved model state ...") 154 | reranker = get_model_obj(self.reranker) 155 | reranker.load_state(saved_state, strict=True) 156 | 157 | self.query_ctxts = None 158 | self.annotate_interactions = self.cfg.retriever.retrieved_data_nr_ctxs == 1 159 | 160 | self.cfg.retriever.retrieved_data_files = [ 161 | s.strip().split(":") for s in self.cfg.retriever.retrieved_data_files 162 | ] 163 | 164 | # self.annotate_interactions = False 165 | 166 | def get_queries_data(self, gold_queries_data): 167 | return get_ctxt_encoder_input_in_dpr_schema( 168 | self.cfg.retriever.retrieved_data_files, 169 | gold_queries_data=gold_queries_data, 170 | nr_ctxs=self.cfg.retriever.retrieved_data_nr_ctxs, 171 | ) 172 | 173 | def set_queries_data(self, queries_data): 174 | self.query_ctxts = queries_data 175 | 176 | def run(self): 177 | 178 | result = {} 179 | 180 | def collate(batch): 181 | """ 182 | Collates a list of batch items into processed tensors for a bienencoder model to consume. 
183 | :param batch: 184 | :return: 185 | """ 186 | batch_inputs = list() 187 | batch_query_ctxs = list() 188 | orig_lens = list() 189 | 190 | for query_ctxs in batch: 191 | question_ids = self.ctx_src.preprocess_query(query_ctxs["question"], training=False) 192 | ctx_ids = torch.stack( 193 | [ 194 | self.ctx_src.preprocess_passage(self.ctx_src.passage_struct_from_dict(p), training=False) 195 | for p in query_ctxs["ctxs"] 196 | ], 197 | dim=0, 198 | ) 199 | ctx_ids_size = ctx_ids.size(0) 200 | orig_lens.append(ctx_ids_size) 201 | should_ctx_ids_size = ( 202 | len(self.cfg.retriever.retrieved_data_files) * self.cfg.retriever.retrieved_data_nr_ctxs 203 | ) 204 | if should_ctx_ids_size > ctx_ids_size: 205 | padding = torch.zeros_like(ctx_ids[0 : (should_ctx_ids_size - ctx_ids_size), :]) 206 | ctx_ids = torch.cat([ctx_ids, padding], dim=0) 207 | inputs = { 208 | "question_ids": question_ids, 209 | "question_segments": torch.zeros_like(question_ids), 210 | "context_ids": ctx_ids, 211 | "ctx_segments": torch.zeros_like(ctx_ids), 212 | } 213 | batch_inputs.append(inputs) 214 | batch_query_ctxs.append(query_ctxs) 215 | 216 | # pad the the last batch to play well with gather() from multi-GPU 217 | while len(batch_inputs) < self.checkpoint_cfg.n_gpu * multiplier: 218 | batch_inputs.append( 219 | { 220 | "question_ids": torch.zeros_like(batch_inputs[-1]["question_ids"]), 221 | "question_segments": torch.zeros_like(batch_inputs[-1]["question_segments"]), 222 | "context_ids": torch.zeros_like(batch_inputs[-1]["context_ids"]), 223 | "ctx_segments": torch.zeros_like(batch_inputs[-1]["ctx_segments"]), 224 | } 225 | ) 226 | 227 | splits = [len(bi["context_ids"]) for bi in batch_inputs] 228 | batch_inputs_merged = { 229 | k: torch.cat([bi[k].view(-1, bi[k].size(-1)) for bi in batch_inputs], dim=0) 230 | for k in batch_inputs[0].keys() 231 | } 232 | batch_inputs_merged = self.biencoder_prepare_model_inputs(batch_inputs_merged, self.ctx_src) 233 | 234 | return orig_lens, splits, batch_inputs_merged, batch_query_ctxs 235 | 236 | multiplier = 1 # TODO: hardcoded nr of batch items per GPU (with nr_of_ctxs passages, atm hardcoded 100) 237 | loader = torch.utils.data.DataLoader( 238 | self.query_ctxts, batch_size=max(1, self.checkpoint_cfg.n_gpu * multiplier), collate_fn=collate 239 | ) 240 | 241 | for orig_lens, splits, batch_inputs_merged, batch_query_ctxs in tqdm.tqdm(loader): 242 | 243 | batch_inputs_merged = move_to_device(batch_inputs_merged, self.checkpoint_cfg.device) 244 | 245 | interactions = None 246 | 247 | with torch.no_grad(): 248 | q_out, ctx_out = self.reranker( 249 | **batch_inputs_merged, 250 | encoder_type="question", 251 | representation_token_pos=0, 252 | ) 253 | if self.annotate_interactions and hasattr(get_model_obj(self.reranker), "get_interactions"): 254 | interactions = get_model_obj(self.reranker).get_interactions( 255 | **batch_inputs_merged, 256 | encoder_type="question", 257 | representation_token_pos=0, 258 | ) 259 | interactions = interactions.detach().cpu().tolist() 260 | 261 | q_out, ctx_out, _, _ = gather_to_main_device( 262 | self.checkpoint_cfg, 263 | q_out, 264 | ctx_out, 265 | ) 266 | 267 | if len(q_out.size()) > 1: 268 | q_out_splits = torch.split(q_out, 1) 269 | ctx_out_splits = torch.split(ctx_out, splits) 270 | else: 271 | # crossencoder sends fake output with size len(q_out.size()) == 1 272 | q_out_splits = torch.split(q_out, 1) 273 | if ctx_out.size(0) * ctx_out.size(1) == sum(splits): 274 | ctx_out_splits = ( 275 | torch.split(ctx_out.view(ctx_out.size(0) * 
ctx_out.size(1), -1), splits) 276 | if len(ctx_out.size()) > 1 277 | else [None] * len(splits) 278 | ) 279 | else: 280 | logger.warning( 281 | f"{ctx_out.size(0)}*{ctx_out.size(1)} = {ctx_out.size(0) * ctx_out.size(1)} != {sum(splits)} {splits}" 282 | ) 283 | continue 284 | 285 | assert len(q_out_splits) == len(ctx_out_splits) 286 | scores_chunks = list() 287 | 288 | for q_out, ctx_out in zip(q_out_splits, ctx_out_splits): 289 | loss_func_get_scores = self.loss_func.get_scores(q_out, ctx_out).view(-1) 290 | scores_chunks.append(loss_func_get_scores) 291 | 292 | # loss_func_get_scores = self.loss_func.get_scores(q_out, ctx_out).view(-1).cpu() 293 | # print(loss_func_get_scores.size()) 294 | # scores_chunks = [sc.tolist() for sc in torch.split(loss_func_get_scores, 1)] 295 | 296 | for scores_chunk, query_ctxs, orig_len in zip(scores_chunks, batch_query_ctxs, orig_lens): 297 | outputs = list() 298 | assert orig_len == len( 299 | query_ctxs["provenances"] 300 | ), f"scores_chunk={scores_chunk.size()} orig_len={orig_len} len(provenances)={len(query_ctxs['provenances'])}" 301 | scores_chunk = scores_chunk[:orig_len] 302 | # sort by score in descending order 303 | for score, provenance_id in sorted(zip(scores_chunk, list(range(len(scores_chunk)))), reverse=True): 304 | provenance = query_ctxs["provenances"][provenance_id] 305 | provenance["score"] = score.item() 306 | if interactions: 307 | provenance["question_ids"] = batch_inputs_merged["question_ids"].detach().cpu().tolist() 308 | provenance["context_ids"] = batch_inputs_merged["context_ids"].detach().cpu().tolist() 309 | provenance["interactions"] = interactions 310 | outputs.append(provenance) 311 | result[query_ctxs["id"]] = outputs 312 | 313 | return result 314 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/img/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/img/architecture.jpg -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/configs/best_dpr_remote.yaml: -------------------------------------------------------------------------------- 1 | retriever: dpr_remote 2 | 3 | retriever.ctx_src: wafer_ccnet 4 | retriever.batch_size: 1024 5 | retriever.checkpoint_dir: ../../../../models/dense/biencoder/ 6 | retriever.rpc_retriever_cfg_file: ../../../../models/dense/index/discovery-config.txt 7 | 8 | n_docs: 100 9 | evaluation_datasets: 10 | - wafer_dev 11 | - wafer_test 12 | 13 | 14 | # Uncomment to write the prediction files into the corresponding checkpoint directory 15 | #output_folder: ${retriever.checkpoint_dir}/predictions/${retriever.checkpoint_load_suffix}__${retriever.ctx_src} 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/configs/best_reranker.yaml: -------------------------------------------------------------------------------- 1 | retriever: reranker 2 | 3 | retriever.ctx_src: wafer_ccnet 4 | retriever.batch_size: 2 5 | retriever.checkpoint_dir: ../../../../models/verifier/ 6 | retriever.retrieved_data_nr_ctxs: 100 7 | retriever.retrieved_data_files: 8 | - dpr:../../../../outputs/predictions/evaluation.retrievers.dpr_remote.DPR/wafer-test.jsonl 9 | - 
gar:../../../../outputs/predictions/GAR/wafer-test.jsonl 10 | 11 | evaluation_datasets: 12 | - wafer_test 13 | 14 | # Uncomment to write the prediction files into the corresponding checkpoint directory 15 | #output_folder: ${retriever.checkpoint_dir}/predictions/${retriever.checkpoint_load_suffix}__${retriever.ctx_src} 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/retrieve.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | BASEDIR=$( pwd ) 8 | 9 | module purge 10 | module unload cuda 11 | module load cuda 12 | module load NCCL 13 | module load cudnn 14 | module unload java 15 | export JAVA_HOME=/private/home/fabiopetroni/jdk-11.0.9 16 | PATH=${PATH}:${JAVA_HOME}/bin 17 | 18 | export PYTHONPATH=$BASEDIR:$BASEDIR/../distributed-faiss:$BASEDIR/../KILT 19 | 20 | # execute retrieval 21 | CUDA_VISIBLE_DEVICES=0 python $BASEDIR/scripts/retrieve_rerank.py $@ 22 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/retrieve_gpu_sharded.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | EXPERIMENT_YAML=$(realpath $1) 8 | BASEDIR=$( pwd ) 9 | 10 | # generate embeddings 11 | sbatch $BASEDIR/internal/eval/retrieval/retrieve/slurm/sbatch_retrieve_gpu_sharded.sh \ 12 | $BASEDIR \ 13 | $EXPERIMENT_YAML 14 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/slurm/sbatch_retrieve_gpu_sharded.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # Usage: sbatch launch_distributed.sh EXPERIMENT.yaml 8 | #SBATCH --job-name=kilt 9 | #SBATCH --partition=learnlab 10 | #SBATCH --array=0-3 11 | #SBATCH --nodes=1 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gpus-per-node=2 14 | #SBATCH --constraint=volta32gb 15 | #SBATCH --cpus-per-task=4 16 | #SBATCH --mem=240G 17 | #SBATCH --time=72:00:00 18 | #SBATCH --open-mode=append 19 | #SBATCH --comment="kilt job" 20 | #SBATCH --output=logs/retrieve_%A_%a.log 21 | 22 | BASEDIR=$1 23 | echo "BASEDIR = $BASEDIR" 24 | WRAPPED_SCRIPT=$BASEDIR/internal/eval/retrieval/retrieve/slurm/wrapped_retrieve.sh 25 | echo "WRAPPED_SCRIPT = $WRAPPED_SCRIPT" 26 | shift; 27 | echo "Calling command $WRAPPER" 28 | srun --label $WRAPPED_SCRIPT "$@" 29 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/slurm/wrapped_retrieve.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | . /usr/share/modules/init/sh 9 | eval "$(conda shell.bash hook)" 10 | conda activate side 11 | 12 | BASEDIR=$( pwd ) 13 | echo "BASEDIR=$BASEDIR" 14 | 15 | module purge 16 | module unload cuda 17 | module load cuda 18 | module load NCCL 19 | module load cudnn 20 | module unload java 21 | export JAVA_HOME=/private/home/fabiopetroni/jdk-11.0.9 22 | PATH=${PATH}:${JAVA_HOME}/bin 23 | 24 | export PYTHONPATH=$BASEDIR:$BASEDIR/../distributed-faiss:$BASEDIR/../KILT 25 | 26 | # execute retrieval 27 | python $BASEDIR/scripts/retrieve_rerank.py $@ task_id=$SLURM_ARRAY_JOB_ID 28 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/internal/eval/retrieval/retrieve/start_distributed_faiss.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | DIST_FAISS_DATA_DIR=$1 8 | 9 | mkdir -p $DIST_FAISS_DATA_DIR 10 | 11 | echo "$DIST_FAISS_DATA_DIR" 12 | 13 | BASEDIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )/../../../../../ 14 | 15 | echo "$BASEDIR" 16 | 17 | export PYTHONPATH=$BASEDIR:$BASEDIR/distributed-faiss 18 | 19 | python $BASEDIR/distributed-faiss/scripts/server_launcher.py \ 20 | --discovery-config $DIST_FAISS_DATA_DIR/discovery-config.txt \ 21 | --log-dir $DIST_FAISS_DATA_DIR/logs \ 22 | --save-dir $DIST_FAISS_DATA_DIR/index \ 23 | --partition learnlab \ 24 | --num-servers 32 \ 25 | --num-servers-per-node 2 \ 26 | --timeout-min 4320 \ 27 | --mem-gb 500 \ 28 | --base-port 12034 29 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/misc/__init__.py -------------------------------------------------------------------------------- /projects/verify_wikipedia/misc/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
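#
# Module summary: shared helpers for the evaluation scripts.
#
#   load_data(filename)               - reads a JSONL file into a list of dicts.
#   load_options_from_argv_yaml(...)  - rewrites sys.argv so that a plain *.yaml argument is
#                                       expanded into Hydra-style "key=value" overrides before
#                                       @hydra.main() parses the command line.
#   flatten_yaml_dict(...)            - recursively flattens a nested dict/DictConfig into dotted
#                                       keys with stringified values, e.g.
#                                       {"retriever": {"batch_size": 2}} -> {"retriever.batch_size": "2"}.
#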
6 | 7 | 8 | import json 9 | import sys 10 | 11 | from omegaconf import OmegaConf, DictConfig 12 | 13 | 14 | def load_data(filename): 15 | data = [] 16 | with open(filename, "r") as fin: 17 | lines = fin.readlines() 18 | for line in lines: 19 | data.append(json.loads(line)) 20 | return data 21 | 22 | 23 | def load_options_from_argv_yaml(constrain_to=None): 24 | # convert the cli params into Hydra format 25 | config_file = None 26 | hydra_formatted_args = list() 27 | if constrain_to: 28 | constrain_to = flatten_yaml_dict(constrain_to, return_dict=True) 29 | for arg in sys.argv: 30 | if len(arg.split("=")) == 1 and arg.endswith(".yaml"): 31 | config_file = OmegaConf.load(arg) 32 | for k, v in flatten_yaml_dict(yd=config_file, return_dict=True).items(): 33 | if constrain_to: 34 | if k in constrain_to: 35 | hydra_formatted_args.append(k + "=" + str(v).replace("'", "").replace('"', "")) 36 | else: 37 | hydra_formatted_args.append(k + "=" + str(v)) 38 | else: 39 | hydra_formatted_args.append(arg) 40 | sys.argv = hydra_formatted_args 41 | return config_file 42 | 43 | 44 | def flatten_yaml_dict(yd=None, prefix: str = "", level: int = 0, return_dict: bool = True): 45 | result_args = list() 46 | if len(prefix) == 0: 47 | for k, v in yd.items(): 48 | if isinstance(v, dict) or isinstance(v, DictConfig): 49 | result_args.extend(flatten_yaml_dict(v, k, level=0, return_dict=False)) 50 | # elif isinstance(v, str): 51 | else: 52 | result_args.append(k + "=" + str(v).replace("'", "").replace('"', "")) 53 | # else: 54 | # print(type(v)) 55 | else: 56 | for k, v in yd.items(): 57 | if isinstance(v, dict) or isinstance(v, DictConfig): 58 | result_args.extend(flatten_yaml_dict(v, prefix + "." + k, level=level + 1)) 59 | # elif isinstance(v, str): 60 | else: 61 | result_args.append(prefix + "." + k + "=" + str(v).replace("'", "").replace('"', "")) 62 | # else: 63 | # print(type(v)) 64 | if level > 0: 65 | return result_args 66 | elif return_dict: 67 | return dict([(kv.split("=")[0], kv.split("=")[1]) for kv in result_args]) 68 | else: 69 | return result_args 70 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/scripts/GAR_urltitle_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
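#
# Module summary: offline query-expansion step for the GAR (generation-augmented retrieval)
# baseline. A GENRE seq2seq model generates a short title/URL-style string for each input
# record; the last sentence of the claim, the top generation, and the record's wikipedia_title
# are then concatenated into a new "input" used as the retrieval query.
#
# Example invocation (sketch; the two file names are placeholders, only --model_path and
# --checkpoint_file have defaults in this script):
#
#   python scripts/GAR_urltitle_generator.py \
#       --test_filename <wafer-test.jsonl> \
#       --out_filename <outputs/predictions/GAR/wafer-test.jsonl>
#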
6 | 7 | import sys 8 | import json 9 | from tqdm.auto import tqdm 10 | import pickle 11 | import json 12 | import argparse 13 | 14 | from genre.fairseq_model import GENRE 15 | 16 | 17 | def load_data(filename): 18 | data = [] 19 | with open(filename, "r") as fin: 20 | lines = fin.readlines() 21 | for line in lines: 22 | data.append(json.loads(line)) 23 | return data 24 | 25 | 26 | def store_data(filename, data): 27 | with open(filename, "w+") as outfile: 28 | for idx, element in enumerate(data): 29 | # print(round(idx * 100 / len(data), 2), "%", end="\r") 30 | # sys.stdout.flush() 31 | json.dump(element, outfile) 32 | outfile.write("\n") 33 | 34 | 35 | def main(args, test_data): 36 | print( 37 | "loading model {} from {}".format( 38 | args.checkpoint_file, 39 | args.model_path, 40 | ) 41 | ) 42 | model = ( 43 | GENRE.from_pretrained( 44 | args.model_path, checkpoint_file=args.checkpoint_file, arg_overrides=True 45 | ) 46 | .eval() 47 | .to("cuda:1") 48 | ) 49 | 50 | genre_output = [] 51 | buffer = [] 52 | batch_size = 100 53 | beam_size = 15 54 | for record in tqdm(test_data): 55 | s = record["input"] 56 | 57 | buffer.append(s) 58 | 59 | if len(buffer) == batch_size: 60 | a = model.sample( 61 | buffer, 62 | beam=beam_size, 63 | max_len_b=15, 64 | # prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 65 | ) 66 | genre_output.extend(a) 67 | buffer = [] 68 | 69 | if len(buffer) > 0: 70 | a = model.sample( 71 | buffer, 72 | beam=beam_size, 73 | max_len_b=15, 74 | # prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 75 | ) 76 | genre_output.extend(a) 77 | buffer = [] 78 | 79 | new_data = [] 80 | print("storing output file in {}".format(args.out_filename)) 81 | for record, output in zip(test_data, genre_output): 82 | for x in output: 83 | if type(x["score"]) != float: 84 | x["score"] = x["score"].item() 85 | 86 | # append title prediction to input 87 | record["input"] = ( 88 | record["meta"]["sentences"][-1] 89 | + " " 90 | + output[0]["text"].strip() 91 | + " " 92 | + record["meta"]["wikipedia_title"] 93 | ) 94 | record["output"] = output 95 | new_data.append(record) 96 | 97 | store_data(args.out_filename, new_data) 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser() 102 | 103 | parser.add_argument( 104 | "--model_path", 105 | dest="model_path", 106 | type=str, 107 | default="models/sparse/title_generator/", 108 | help="model path.", 109 | ) 110 | 111 | parser.add_argument( 112 | "--checkpoint_file", 113 | dest="checkpoint_file", 114 | type=str, 115 | default="checkpoint_best.pt", 116 | help="checkpoint file.", 117 | ) 118 | 119 | parser.add_argument( 120 | "--test_filename", 121 | dest="test_filename", 122 | type=str, 123 | required=True, 124 | help="test filename.", 125 | ) 126 | 127 | parser.add_argument( 128 | "--out_filename", 129 | dest="out_filename", 130 | type=str, 131 | required=True, 132 | help="output filename.", 133 | ) 134 | 135 | args = parser.parse_args() 136 | 137 | print("lading test data from {}".format(args.test_filename)) 138 | test_data = load_data(args.test_filename) 139 | 140 | main(args, test_data) 141 | -------------------------------------------------------------------------------- /projects/verify_wikipedia/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/side/6dc4ee909655872dc0c822d9cdd0fd62fffa21d4/projects/verify_wikipedia/scripts/__init__.py -------------------------------------------------------------------------------- 
/projects/verify_wikipedia/scripts/retrieve_rerank.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import glob 8 | import json 9 | import logging 10 | import math 11 | import os 12 | import sys 13 | 14 | import filelock 15 | import hydra 16 | from omegaconf import OmegaConf 17 | 18 | from evaluation.retrievers.base_retriever import Retriever 19 | from misc import utils 20 | from misc.utils import load_options_from_argv_yaml 21 | 22 | 23 | def setup_logger(logger): 24 | logger.setLevel(logging.INFO) 25 | if logger.hasHandlers(): 26 | logger.handlers.clear() 27 | log_formatter = logging.Formatter("[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s") 28 | console = logging.StreamHandler() 29 | console.setFormatter(log_formatter) 30 | logger.addHandler(console) 31 | 32 | 33 | logger = logging.getLogger() 34 | setup_logger(logger) 35 | 36 | 37 | def output_file_name(output_folder, dataset_file, output_suffix=""): 38 | if not output_suffix: 39 | output_suffix = "" 40 | basename = os.path.basename(dataset_file) 41 | output_file = os.path.join(output_folder, basename) + output_suffix 42 | if not os.path.exists(os.path.dirname(output_file)): 43 | os.makedirs(os.path.dirname(output_file)) 44 | return output_file 45 | 46 | 47 | def apply_ranker( 48 | test_config, 49 | ranker: Retriever, 50 | logger, 51 | debug=False, 52 | output_folder="", 53 | num_shards=1, 54 | shard_id=0, 55 | ): 56 | 57 | for dataset in test_config.evaluation_datasets: 58 | dataset = hydra.utils.instantiate(test_config.datasets[dataset]) 59 | 60 | logger.info("TASK: {}".format(dataset.task_family)) 61 | logger.info("DATASET: {}".format(dataset.name)) 62 | 63 | output_file = output_file_name(output_folder, dataset.file, test_config.output_suffix) 64 | if os.path.exists(output_file): 65 | logger.info("Skip output file {} that already exists.".format(output_file)) 66 | continue 67 | 68 | raw_data = utils.load_data(dataset.file) 69 | validated_data = {} 70 | queries_data = list() 71 | for instance in raw_data: 72 | if dataset.validate_datapoint(instance, logger=logger): 73 | instance = ranker.preprocess_instance(instance) 74 | if instance["id"] in validated_data: 75 | raise ValueError("ids are not unique in input data!") 76 | validated_data[instance["id"]] = instance 77 | queries_data.append({"query": instance["input"], "id": instance["id"]}) 78 | 79 | queries_data = ranker.get_queries_data(queries_data) 80 | 81 | if debug: 82 | # just consider the top10 datapoints 83 | queries_data = queries_data[:10] 84 | print("query_data: {}", format(queries_data)) 85 | 86 | if num_shards > 1: 87 | len_all_query_ctxts = len(queries_data) 88 | shard_size = math.ceil(len_all_query_ctxts / num_shards) 89 | start_idx = shard_id * shard_size 90 | end_idx = start_idx + shard_size 91 | queries_data = queries_data[start_idx:end_idx] 92 | logger.info(f"sharded query_ctxy size: {len(queries_data)}") 93 | else: 94 | logger.info(f"query_ctxy size: {len(queries_data)}") 95 | 96 | ranker.set_queries_data(queries_data) 97 | 98 | # get predictions 99 | provenance = ranker.run() 100 | 101 | if len(provenance) != len(queries_data): 102 | logger.warning( 103 | "different numbers of queries: {} and predictions: {}".format(len(queries_data), len(provenance)) 104 | ) 105 | 106 | # write prediction 
files
107 |         if provenance:
108 |             logger.info("writing prediction file to {}".format(output_file))
109 | 
110 |             predictions = []
111 |             for query_id in provenance.keys():
112 |                 if query_id in validated_data:
113 |                     instance = validated_data[query_id]
114 |                     new_output = [{"provenance": provenance[query_id]}]
115 |                     # append the answers
116 |                     if "output" in instance:
117 |                         for o in instance["output"]:
118 |                             if "answer" in o:
119 |                                 new_output.append({"answer": o["answer"]})
120 |                     instance["output"] = new_output
121 |                     predictions.append(instance)
122 | 
123 |             if not os.path.exists(output_folder):
124 |                 os.makedirs(output_folder, exist_ok=True)
125 | 
126 |             if num_shards > 1:
127 | 
128 |                 lock = filelock.FileLock(output_file_name(output_folder, dataset.file, ".lock"))
129 |                 with lock:
130 |                     with open(output_file, "w+") as outfile:
131 |                         for p in predictions:
132 |                             json.dump(p, outfile)
133 |                             outfile.write("\n")
134 |                     output_files = glob.glob(output_file_name(output_folder, dataset.file, ".[0-9]*-[0-9]*"))
135 |                     if len(output_files) == num_shards:
136 |                         output_files.sort()  # sort in place so the shard files are concatenated in order
137 |                         with open(output_file_name(output_folder, dataset.file, ""), "w") as cat:
138 |                             for output_file in output_files:
139 |                                 cat.writelines(open(output_file))
140 | 
141 |             else:
142 | 
143 |                 with open(output_file, "w+") as outfile:
144 |                     for p in predictions:
145 |                         json.dump(p, outfile)
146 |                         outfile.write("\n")
147 | 
148 | 
149 | @hydra.main(config_path="../conf", config_name="retrieval")
150 | def main(cfg):
151 | 
152 |     logger.info("loading {} ...".format(OmegaConf.to_yaml(cfg)))
153 | 
154 |     shard_id = int(os.environ.get("SLURM_ARRAY_TASK_ID")) if os.environ.get("SLURM_ARRAY_TASK_ID") else 1
155 |     num_shards = int(os.environ.get("SLURM_ARRAY_TASK_COUNT")) if os.environ.get("SLURM_ARRAY_TASK_COUNT") else 1
156 |     task_id = int(os.environ.get("SLURM_ARRAY_JOB_ID")) if os.environ.get("SLURM_ARRAY_JOB_ID") else cfg.task_id
157 | 
158 |     if num_shards > 1:
159 |         cfg.output_suffix = f".{shard_id:02}-{num_shards}"
160 | 
161 |     module_name = ".".join(cfg.retriever._target_.split(".")[:-1])
162 |     clazz_name = cfg.retriever._target_.split(".")[-1]
163 |     retriever_module = __import__(name=module_name, fromlist=clazz_name)
164 |     retriever_clazz = getattr(retriever_module, clazz_name)
165 |     retriever = retriever_clazz(clazz_name, cfg)
166 | 
167 |     if cfg.output_folder is None or len(cfg.output_folder) == 0:
168 |         cfg.output_folder = "./"
169 | 
170 |     if cfg.output_folder_in_checkpoint_dir:
171 |         output_folder = f"{cfg.output_folder}__{task_id}"
172 |     else:
173 |         output_folder = cfg.output_folder
174 | 
175 |     apply_ranker(
176 |         test_config=cfg,
177 |         ranker=retriever,
178 |         logger=logger,
179 |         output_folder=output_folder,
180 |         num_shards=num_shards,
181 |         shard_id=shard_id,
182 |     )
183 | 
184 |     OmegaConf.save(cfg, f"{output_folder}/retriever_cfg.yaml")
185 |     if hasattr(retriever, "checkpoint_cfg"):
186 |         OmegaConf.save(retriever.checkpoint_cfg, f"{output_folder}/checkpoint_cfg.yaml")
187 | 
188 | 
189 | if __name__ == "__main__":
190 |     load_options_from_argv_yaml()
191 |     logger.info("Sys.argv: %s", sys.argv)
192 |     main()
193 | 
--------------------------------------------------------------------------------
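# Example end-to-end usage (sketch). The paths follow the defaults in the YAML configs above but
# are assumptions about the local checkout (models/, data/ and outputs/ are not part of this
# listing). Run from projects/verify_wikipedia, with the distributed-faiss and KILT submodules
# available as set up by the retrieve.sh wrapper.

# 1) Start the distributed-faiss index servers; they write <DIST_FAISS_DATA_DIR>/discovery-config.txt,
#    which retriever.rpc_retriever_cfg_file in best_dpr_remote.yaml must point to.
bash internal/eval/retrieval/retrieve/start_distributed_faiss.sh <DIST_FAISS_DATA_DIR>

# 2) Dense retrieval with the remote DPR index (use retrieve_gpu_sharded.sh instead to shard the
#    queries over a SLURM array).
bash internal/eval/retrieval/retrieve/retrieve.sh internal/eval/retrieval/retrieve/configs/best_dpr_remote.yaml

# 3) Rerank the candidates; best_reranker.yaml expects the prediction files listed under
#    retriever.retrieved_data_files (e.g. the DPR output from step 2 and a GAR run).
bash internal/eval/retrieval/retrieve/retrieve.sh internal/eval/retrieval/retrieve/configs/best_reranker.yaml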