├── .gitignore ├── LICENSE ├── README.md ├── baseline ├── README.md ├── mDPR │ ├── README.md │ ├── create_adverarial_data.py │ ├── create_mine_train_data.py │ ├── dense_retriever.py │ ├── dpr │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── qa_validation.py │ │ │ └── reader_data.py │ │ ├── indexer │ │ │ └── faiss_indexers.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── biencoder.py │ │ │ ├── fairseq_models.py │ │ │ ├── hf_models.py │ │ │ ├── pytext_models.py │ │ │ └── reader.py │ │ ├── options.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── data_utils.py │ │ │ ├── dist_utils.py │ │ │ ├── model_utils.py │ │ │ └── tokenizers.py │ ├── generate_dense_embeddings.py │ └── train_dense_encoder.py ├── mGEN │ ├── README.md │ ├── __init__.py │ ├── align_wikidata.py │ ├── callbacks_rag.py │ ├── convert_dpr_retrieval_results_to_seq2seq.py │ ├── eval_mgen.py │ ├── finetune_mgen.py │ ├── lightning_base.py │ ├── requirements.txt │ ├── utils.py │ └── utils_rag.py ├── run_evaluation.sh ├── run_evaluation_cora.sh └── wikipedia_preprocess │ ├── README.md │ ├── build_db.py │ ├── build_dpr_w100_data.py │ ├── create_w100_data_japanese.py │ ├── create_w100_data_khmer.py │ ├── create_w100_data_thai.py │ ├── doc_db.py │ └── utils.py ├── data ├── eval │ ├── mia2022_test_surprise_tagalog_without_answers.jsonl │ ├── mia2022_test_surprise_tamil_without_answers.jsonl │ ├── mia_2022_dev_xorqa.jsonl │ ├── mia_2022_test_xorqa_without_answers.jsonl │ ├── mkqa_dev.zip │ └── mkqa_test_without_answers.zip └── train │ └── mia_2022_train_data.jsonl.zip ├── eval_scripts ├── eval_mkqa_all.py ├── eval_xor_full.py └── requirements.txt └── sample_predictions ├── submission.json └── submission_test.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 mia-workshop 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /baseline/README.md: -------------------------------------------------------------------------------- 1 | # Baselines for MIA 2022 Shared Task 2 | 3 | Our primary baseline is the state-of-the-art [CORA](https://github.com/AkariAsai/CORA), which runs a multilingual DPR (mDPR) model to retrieve documents from many different languages and then generates the final answers in the target languages using a multilingual seq2seq generation model (mGEN). 4 | 5 | We have two versions: 6 | 1. **(Baseline 1) mDPR + mGEN (CORA w/o iterative training)**: We train mDPR and mGEN following the procedures in the CORA paper, using the MIA 2022 shared task official training and development data. We do not conduct iterative training. 7 | The experimental results can be reproduced by running `run_evaluation.sh`. This is our primary baseline. 8 | 9 | 2. **(Baseline 2) CORA (trained models)**: 10 | We run the models released in the CORA repository on our evaluation data. For the languages that are not originally covered by the CORA repository, we run mDPR to generate passage embeddings. 11 | The experimental results can be reproduced by running `run_evaluation_cora.sh`.
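Both baselines read the shared-task QA files under `data/` as JSON Lines. As a quick sanity check before running the pipeline, the sketch below prints a few development questions; it assumes each line carries the `id`, `lang`, `question`, and (for non-test splits) `answers` fields that `dense_retriever.py` expects.

```
import jsonlines

# Peek at the first three development examples (unzip the archives under data/ first if needed).
with jsonlines.open("data/eval/mia_2022_dev_xorqa.jsonl") as reader:
    for i, example in enumerate(reader):
        print(example["id"], example["lang"], example["question"], example.get("answers"))
        if i >= 2:
            break
```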
12 | 13 | ## Quick evaluation 14 | To reproduce the main results of Baseline 1, please run the following script. 15 | 16 | ``` 17 | bash run_evaluation.sh 18 | ``` 19 | ## Baseline predictions 20 | 21 | ### Intermediate results -- mDPR retrieval results 22 | 23 | To encourage those who are more interested in improving the answer generation / reader component after retrieval, we release the retrieval results for the MKQA and XOR-TyDi QA data. All of the retrieval results for the training and development sets can be downloaded by running the commands below. 24 | 25 | - Training data 26 | 27 | ``` 28 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_training_dpr_retrieval_results.json 29 | ``` 30 | 31 | - XOR QA Development data 32 | ``` 33 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_xorqa_development_dpr_retrieval_results.json 34 | ``` 35 | 36 | 37 | - MKQA development data 38 | ``` 39 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_non_iterative_baselines_mkqa_dev.zip 40 | unzip mia2022_non_iterative_baselines_mkqa_dev.zip 41 | ``` 42 | 43 | - Test data 44 | ``` 45 | wget https://nlp.cs.washington.edu/xorqa/cora/models/retriever_results_test_no_answers.zip 46 | unzip retriever_results_test_no_answers.zip 47 | ``` 48 | 49 | 50 | 51 | ### Final prediction results 52 | You can download the final predictions from the following links. 53 | 54 | - Baseline 1: [MIA2022_Baseline 1 sample_predictions](https://drive.google.com/drive/folders/14Xv6enk7j4d3QKTNbB5jGjaColNffwW_?usp=sharing). 55 | 56 | - Baseline 2: [MIA2022_Baseline 2 sample_predictions](https://drive.google.com/drive/folders/1ePQjLOWUNiF5mr6leAhw8OG-o1h55i75?usp=sharing). 57 | -------------------------------------------------------------------------------- /baseline/mDPR/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## mDPR 3 | This code is mostly the same as the original DPR repository, with some minor modifications. The code is based on [Dense Passage Retriever](https://github.com/facebookresearch/DPR), and we modified it to support a more recent version of Hugging Face Transformers. 4 | 5 | ### Installation 6 | Please install the dependencies by running the command below: 7 | 8 | ``` 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ### Download models 13 | 14 | - Baseline (1): mDPR trained on our training data (see details below) without the iterative training process.
15 | 16 | 17 | ``` 18 | mkdir models 19 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_all_langs_w100.tsv 20 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mDPR_biencoder_best.cpt 21 | unzip mGEN_model.zip 22 | mkdir embeddings 23 | cd embeddings 24 | for i in 0 1 2 3; 25 | do 26 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/wiki_emb_en_$i 27 | done 28 | for i in 0 1 2 3; 29 | do 30 | wget https://nlp.cs.washington.edu/xorqa/cora/models/embeddings_baseline1/wiki_emb_others_$i 31 | done 32 | ``` 33 | 34 | - Baseline (2): [CORA (Asai et al., 2021)](https://github.com/AkariAsai/CORA) public model 35 | 36 | ``` 37 | mkdir models 38 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_all_langs_w100.tsv 39 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mDPR_biencoder_best.cpt 40 | unzip mGEN_model.zip 41 | mkdir embeddings 42 | cd embeddings 43 | for i in 0 1 2 3 4 5 6 7; 44 | do 45 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_emb_en_$i 46 | done 47 | for i in 0 1 2 3 4 5 6 7; 48 | do 49 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_emb_xor_$i 50 | done 51 | for i in 0 1 2 3 4 5 6 7; 52 | do 53 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_others_emb__$i 54 | done 55 | for i in 0 1 2 3 4 5 6 7; 56 | do 57 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_others_emb_ms_tr_km_$i 58 | done 59 | ``` 60 | ### Data 61 | #### Training data using DPR's NQ training data + XOR-TyDi QA gold paragraph data 62 | 63 | ``` 64 | wget https://nlp.cs.washington.edu/xorqa/cora/data/base_mdpr_train_dev_data/mia2022_mdpr_train.json 65 | wget https://nlp.cs.washington.edu/xorqa/cora/data/base_mdpr_train_dev_data/mia2022_mdpr_xor_dev.json 66 | ``` 67 | 68 | The original data comes from: 69 | - [DPR's Natural Questions train data](https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz) 70 | - [XOR-TyDi QA's gold paragraph data](https://nlp.cs.washington.edu/xorqa/XORQA_site/data/trans_data_all_langs.zip) 71 | 72 | If you further augment your training data using these resources or the original Natural Questions / TyDi QA data, please make sure that you **do not** use any additional QA data from those datasets (i.e., questions whose question ids are not included in our official training data). 73 | 74 | #### Training data with adversarial paragraphs 75 | Recent work has shown that using a trained DPR model to mine harder negative passages can improve retrieval performance. See the detailed discussion at [the original DPR repository](https://github.com/facebookresearch/DPR#new-march-2021-retrieval-model). 76 | 77 | We create additional training and dev data sets by augmenting positive and negative passages from the top 50 retrieval results of our mDPR models. Please see the details of this process in [create_adverarial_data.py](create_adverarial_data.py). 78 | 79 | ``` 80 | wget https://nlp.cs.washington.edu/xorqa/cora/data/mia_adversarial_mdpr/mia_train_adversarial.json 81 | wget https://nlp.cs.washington.edu/xorqa/cora/data/mia_adversarial_mdpr/mia_xor_dev_adversarial.json 82 | ``` 83 | ### Training 84 | 1. Initial training 85 | 86 | We first train the DPR models using gold paragraph data from Natural Questions, XOR QA and TyDi QA.
87 | 88 | ``` 89 | python -m torch.distributed.launch \ 90 | --nproc_per_node=8 train_dense_encoder.py \ 91 | --max_grad_norm 2.0 \ 92 | --encoder_model_type hf_bert \ 93 | --pretrained_model_cfg bert-base-multilingual-uncased \ 94 | --seed 12345 --sequence_length 256 \ 95 | --warmup_steps 300 --batch_size 16 --do_lower_case \ 96 | --train_file /path/to/train/data \ 97 | --dev_file /path/to/eval/data \ 98 | --output_dir /path/to/output/dir \ 99 | --learning_rate 2e-05 --num_train_epochs 40 \ 100 | --dev_batch_size 6 --val_av_rank_start_epoch 30 101 | ``` 102 | 103 | 2. Generate Wikipedia embeddings 104 | After you train the DPR encoders, you need to generate Wikipedia passage embeddings. Please create a Wikipedia passage file following the instructions in the `wikipedia_preprocess` directory. The script to generate multilingual embeddings using 8 GPUs is as follows: 105 | 106 | ```sh 107 | for i in {0..7}; do 108 | export CUDA_VISIBLE_DEVICES=${i} 109 | nohup python generate_dense_embeddings.py --model_file /path/to/model/checkpoint --batch_size 64 --ctx_file /path/to/wikipedia/passage/file --shard_id ${i} --num_shards 8 --out_file ./embeddings_multilingual/wikipedia_split/wiki_emb > ./log/nohup.generate_wiki_emb.ser23_3_multi.${i} 2>&1 & 110 | done 111 | ``` 112 | Note that when you generate embeddings for the 13 target languages, you may run out of memory when loading the Wikipedia passage tsv file (the passage file is roughly 24GB and is loaded by each of the 8 GPU processes). 113 | We recommend generating the English embeddings first and then doing the same for the remaining languages. 114 | 115 | 3. Retrieve Wikipedia passages for train data questions 116 | Following prior work, we retrieve top passages for the train data questions and use them to train our generator. Once the passage embeddings are generated, you can retrieve the top passages by running the command below. 117 | 118 | ``` 119 | python dense_retriever.py \ 120 | --model_file /path/to/model/checkpoint \ 121 | --ctx_file /path/to/wikipedia/passage/file --n-docs 100 \ 122 | --qa_file /path/to/input/qa/file \ 123 | --encoded_ctx_file "{glob expression for generated files}" \ 124 | --out_file /path/to/prediction/outputs \ 125 | --validation_workers 4 --batch_size 64 126 | ``` 127 | 128 | After you train your generator, please run the script to create new mDPR train data and repeat the steps from step 1 using the new data. 129 | 130 | ### Evaluations 131 | You can run the evaluation using the same command as step 3 of training. 132 | For example, to run the evaluation on the XOR QA dev data, you can run the command below. 133 | 134 | ``` 135 | python dense_retriever.py \ 136 | --model_file ../models/mDPR_biencoder_best.cpt \ 137 | --ctx_file ../models/all_w100.tsv \ 138 | --qa_file ../data/xor_dev_full_v1_1.jsonl \ 139 | --encoded_ctx_file "../models/embeddings/wiki_emb_*" \ 140 | --out_file xor_dev_dpr_retrieval_results.json \ 141 | --n-docs 20 --validation_workers 1 --batch_size 256 --add_lang 142 | ``` 143 | Due to the large number of multilingual passage embeddings, retrieving passages takes more time than with English-only DPR. 144 | 145 | ### Retrieved results after mDPR initial training 146 | The top 50 passages retrieved by mDPR after initial training, for our training set and for the MKQA and XOR-TyDi QA development sets, are available at the following locations.
147 | 148 | - Training data 149 | ``` 150 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_training_dpr_retrieval_results.json 151 | ``` 152 | 153 | - XOR QA development data 154 | 155 | ``` 156 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_xorqa_development_dpr_retrieval_results.json 157 | ``` 158 | 159 | - MKQA development data 160 | The retrieval results for MKQA subsets are available here: 161 | 162 | ``` 163 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_non_iterative_baselines_mkqa_dev.zip 164 | unzip mia2022_non_iterative_baselines_mkqa_dev.zip 165 | ``` 166 | 167 | - Test data 168 | 169 | ``` 170 | wget https://nlp.cs.washington.edu/xorqa/cora/models/retriever_results_test_no_answers.zip 171 | unzip retriever_results_test_no_answers.zip 172 | ``` 173 | -------------------------------------------------------------------------------- /baseline/mDPR/create_adverarial_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import argparse 4 | from tqdm import tqdm 5 | import jsonlines 6 | import numpy as np 7 | from collections import Counter 8 | 9 | def read_jsonlines(eval_file_name): 10 | lines = [] 11 | print("loading examples from {0}".format(eval_file_name)) 12 | with jsonlines.open(eval_file_name) as reader: 13 | for obj in reader: 14 | lines.append(obj) 15 | return lines 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--qa_file", default=None, type=str) 20 | parser.add_argument("--orig_train_file", default=None, type=str) 21 | parser.add_argument("--train_retr_results", default=None, type=str) 22 | parser.add_argument("--output_fn", default=None, type=str) 23 | 24 | args = parser.parse_args() 25 | 26 | orig_open_data = read_jsonlines(args.qa_file) 27 | orig_train_data = json.load(open(args.orig_train_file)) 28 | train_retr_results = json.load(open(args.train_retr_results)) 29 | 30 | qid2orig_data = {item["q_id"]: item for item in orig_train_data} 31 | qid2retr_results = {item["q_id"]: item for item in train_retr_results} 32 | 33 | new_data = [] 34 | skip = 0 35 | p_count = [] 36 | n_count = [] 37 | for item in tqdm(orig_open_data): 38 | qid = item["id"] 39 | retr_results = qid2retr_results[qid] 40 | positives = [] 41 | negatives = [] 42 | for ctx in retr_results["ctxs"]: 43 | if ctx["has_answer"] is True: 44 | positives.append(ctx) 45 | else: 46 | negatives.append(ctx) 47 | new_train_sample = qid2orig_data[qid] if qid in qid2orig_data else {} 48 | 49 | if qid not in qid2orig_data: 50 | new_train_sample["question"] = item["question"] 51 | new_train_sample["answers"] = item["answers"] 52 | new_train_sample["q_id"] = item["id"] 53 | new_train_sample["negative_ctxs"] = [] 54 | new_train_sample["hard_negative_ctxs"] = [] 55 | new_train_sample["positive_ctxs"] = [] 56 | 57 | new_train_sample["positive_ctxs"] += positives 58 | hard_negatives_all = negatives + new_train_sample["hard_negative_ctxs"] 59 | sample_indices = random.sample(range(len(hard_negatives_all)), k=min(50, len(hard_negatives_all))) 60 | hard_negatives = [ctx for idx, ctx in enumerate(hard_negatives_all) if idx in sample_indices] 61 | new_train_sample["hard_negative_ctxs"] = hard_negatives 62 | 63 | if len(new_train_sample["positive_ctxs"]) == 0: 64 | skip += 1 65 | continue 66 | else: 67 | p_count.append(len(new_train_sample["positive_ctxs"])) 68 | n_count.append(len(new_train_sample["hard_negative_ctxs"])) 69 | assert "question" in new_train_sample 70 | 
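# Sanity checks below: every kept sample must carry its question, answers, and q_id.
# Each entry of --train_retr_results is assumed to follow the format written by dense_retriever.py's
# save_results(): {"q_id", "question", "answers", "lang", "ctxs": [{"id", "title", "text", "score",
# "has_answer"}, ...]}. Passages with has_answer=True were added as extra positives above, and up to 50
# of the remaining retrieved passages were sampled as hard negatives for the augmented DPR training file.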
assert "answers" in new_train_sample 71 | assert "q_id" in new_train_sample 72 | new_data.append(new_train_sample) 73 | with open(args.output_fn, 'w') as outfile: 74 | json.dump(new_data, outfile) 75 | print("processed {0} examples: {1} final examples.".format(len(orig_open_data),len(orig_open_data) - skip )) 76 | print("avg positive ctxs number: {0}, avg negative ctxs number:{1}".format(np.mean(p_count), np.mean(n_count))) 77 | 78 | if __name__=="__main__": 79 | main() -------------------------------------------------------------------------------- /baseline/mDPR/create_mine_train_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import argparse 4 | import csv 5 | import os 6 | from tqdm import tqdm 7 | import jsonlines 8 | import ast 9 | import re 10 | import numpy as np 11 | from collections import Counter 12 | 13 | def read_jsonlines(eval_file_name): 14 | lines = [] 15 | print("loading examples from {0}".format(eval_file_name)) 16 | with jsonlines.open(eval_file_name) as reader: 17 | for obj in reader: 18 | lines.append(obj) 19 | return lines 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--open_input_fn", default=None, type=str) 25 | parser.add_argument("--gold_data", default=None, type=str) 26 | parser.add_argument("--pred_result", default=None, type=str) 27 | parser.add_argument("--output_fn", default=None, type=str) 28 | 29 | args = parser.parse_args() 30 | 31 | orig_open_data = read_jsonlines(args.open_input_fn) 32 | qid2lang = {item["id"]:item["lang"] for item in orig_open_data} 33 | 34 | # load the original gold data; see the details of the input files in README.md. 35 | tsv_file = open(args.gold_data) 36 | read_tsv = csv.reader(tsv_file, delimiter="\t") 37 | orig_gold_data = [] 38 | for row in read_tsv: 39 | orig_id = row[2].split("_")[0] 40 | lang = qid2lang[orig_id] 41 | orig_gold_data.append( 42 | {"id": orig_id, "lang": lang, "answers": ast.literal_eval(row[1]), \ 43 | "title": row[3], "context":row[4], "question": row[5], "orig_id": orig_id}) 44 | 45 | # Load predictions 46 | if args.pred_result.endswith(".txt"): 47 | preds = open(args.pred_result).read().split("\n") 48 | elif args.pred_result.endswith(".json"): 49 | preds_orig = json.load(open(args.pred_result)) 50 | preds = [] 51 | for data in orig_gold_data: 52 | pred = preds_orig[data["id"]] 53 | preds.append(pred) 54 | else: 55 | raise NotImplementedError 56 | 57 | assert len(orig_gold_data) == len(preds) 58 | 59 | match_para_ids = {} 60 | for pred, data in tqdm(zip(preds, orig_gold_data)): 61 | orig_id = data["id"] 62 | match_para_ids.setdefault(orig_id, {"positive_ctxs": [], "negative_ctxs": [], "hard_negative_ctxs": [], 'matched_ctxs': [], "question": "{0} [{1}]".format(data["question"], data["lang"]), "lang": data["lang"], "answers": data["answers"] }) 63 | if pred in data["answers"]: 64 | ctx = data["context"] 65 | match_para_ids[orig_id]["positive_ctxs"].append( 66 | {"text": ctx, "title": data["title"]}) 67 | else: 68 | ctx = data["context"] 69 | if data["answers"][0] not in ctx: 70 | match_para_ids[orig_id]["hard_negative_ctxs"].append( 71 | {"text": ctx, "title": data["title"]}) 72 | else: 73 | match_para_ids[orig_id]["matched_ctxs"].append( 74 | {"text": ctx, "title": data["title"]}) 75 | 76 | dpr_data = [] 77 | new_positive_ids = [] 78 | for q_id, item in match_para_ids.items(): 79 | if len(item["positive_ctxs"]) > 0: 80 | item["q_id"] = q_id 81 | new_positive_ids.append(q_id) 82 | if
len(item["positive_ctxs"]) == 0 and len(item["matched_ctxs"]) > 0: 83 | item["positive_ctxs"].append(item["matched_ctxs"][0]) 84 | 85 | elif len(item["positive_ctxs"]) == 0 and len(item["matched_ctxs"]) == 0: 86 | print("examples are skipped") 87 | continue 88 | dpr_data.append(item) 89 | 90 | print(dpr_data[-1]) 91 | print(len(dpr_data)) 92 | print(len(new_positive_ids)) 93 | print("positive para num:{}".format(np.mean([len(item["positive_ctxs"]) for item in dpr_data]))) 94 | print("negative para num:{}".format(np.mean([len(item["hard_negative_ctxs"]) for item in dpr_data]))) 95 | 96 | with open(args.output_fn, 'w') as outfile: 97 | json.dump(dpr_data, outfile) 98 | 99 | 100 | if __name__=="__main__": 101 | main() -------------------------------------------------------------------------------- /baseline/mDPR/dense_retriever.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line tool to get dense results and validate them 10 | """ 11 | 12 | import argparse 13 | import os 14 | import csv 15 | import glob 16 | import json 17 | import gzip 18 | import logging 19 | import pickle 20 | import time 21 | import jsonlines 22 | from typing import List, Tuple, Dict, Iterator 23 | 24 | import numpy as np 25 | import torch 26 | from torch import Tensor as T 27 | from torch import nn 28 | 29 | from dpr.data.qa_validation import calculate_matches 30 | from dpr.models import init_biencoder_components 31 | from dpr.options import add_encoder_params, setup_args_gpu, print_args, set_encoder_params_from_state, \ 32 | add_tokenizer_params, add_cuda_params 33 | from dpr.utils.data_utils import Tensorizer 34 | from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint 35 | from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer 36 | 37 | logger = logging.getLogger() 38 | logger.setLevel(logging.INFO) 39 | if (logger.hasHandlers()): 40 | logger.handlers.clear() 41 | console = logging.StreamHandler() 42 | logger.addHandler(console) 43 | 44 | 45 | class DenseRetriever(object): 46 | """ 47 | Does passage retrieving over the provided index and question encoder 48 | """ 49 | def __init__(self, question_encoder: nn.Module, batch_size: int, tensorizer: Tensorizer, index: DenseIndexer): 50 | self.question_encoder = question_encoder 51 | self.batch_size = batch_size 52 | self.tensorizer = tensorizer 53 | self.index = index 54 | 55 | def generate_question_vectors(self, questions: List[str]) -> T: 56 | n = len(questions) 57 | bsz = self.batch_size 58 | query_vectors = [] 59 | 60 | self.question_encoder.eval() 61 | 62 | with torch.no_grad(): 63 | for j, batch_start in enumerate(range(0, n, bsz)): 64 | 65 | batch_token_tensors = [self.tensorizer.text_to_tensor(q) for q in 66 | questions[batch_start:batch_start + bsz]] 67 | 68 | q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda() 69 | q_seg_batch = torch.zeros_like(q_ids_batch).cuda() 70 | q_attn_mask = self.tensorizer.get_attn_mask(q_ids_batch) 71 | _, out, _ = self.question_encoder(q_ids_batch, q_seg_batch, q_attn_mask) 72 | 73 | query_vectors.extend(out.cpu().split(1, dim=0)) 74 | 75 | if len(query_vectors) % 100 == 0: 76 | logger.info('Encoded queries %d', len(query_vectors)) 77 | 78 | 
query_tensor = torch.cat(query_vectors, dim=0) 79 | 80 | logger.info('Total encoded queries tensor %s', query_tensor.size()) 81 | 82 | assert query_tensor.size(0) == len(questions) 83 | return query_tensor 84 | 85 | def index_encoded_data(self, vector_files: List[str], buffer_size: int = 50000): 86 | """ 87 | Indexes encoded passages takes form a list of files 88 | :param vector_files: file names to get passages vectors from 89 | :param buffer_size: size of a buffer (amount of passages) to send for the indexing at once 90 | :return: 91 | """ 92 | buffer = [] 93 | for i, item in enumerate(iterate_encoded_files(vector_files)): 94 | db_id, doc_vector = item 95 | buffer.append((db_id, doc_vector)) 96 | if 0 < buffer_size == len(buffer): 97 | self.index.index_data(buffer) 98 | buffer = [] 99 | self.index.index_data(buffer) 100 | logger.info('Data indexing completed.') 101 | 102 | def get_top_docs(self, query_vectors: np.array, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]: 103 | """ 104 | Does the retrieval of the best matching passages given the query vectors batch 105 | :param query_vectors: 106 | :param top_docs: 107 | :return: 108 | """ 109 | time0 = time.time() 110 | results = self.index.search_knn(query_vectors, top_docs) 111 | logger.info('index search time: %f sec.', time.time() - time0) 112 | return results 113 | 114 | 115 | def parse_qa_csv_file(location) -> Iterator[Tuple[str, List[str]]]: 116 | with open(location) as ifile: 117 | reader = csv.reader(ifile, delimiter='\t') 118 | for row in reader: 119 | question = row[0] 120 | answers = eval(row[1]) 121 | yield question, answers 122 | 123 | def read_jsonlines(eval_file_name): 124 | lines = [] 125 | print("loading examples from {0}".format(eval_file_name)) 126 | with jsonlines.open(eval_file_name) as reader: 127 | for obj in reader: 128 | lines.append(obj) 129 | return lines 130 | 131 | def parse_qa_jsonlines_file(location, add_lang=False) -> Iterator[Tuple[str, str, List[str], str]]: 132 | data = read_jsonlines(location) 133 | for row in data: 134 | question = row["question"] 135 | answers = row["answers"] 136 | q_id = row["id"] 137 | lang = row["lang"] 138 | if add_lang is True: 139 | question = "{0} [{1}]".format(question, lang) 140 | yield question, q_id, answers, lang 141 | 142 | 143 | def validate(passages: Dict[object, Tuple[str, str]], answers: List[List[str]], 144 | result_ctx_ids: List[Tuple[List[object], List[float]]], 145 | workers_num: int, match_type: str) -> List[List[bool]]: 146 | match_stats = calculate_matches(passages, answers, result_ctx_ids, workers_num, match_type) 147 | top_k_hits = match_stats.top_k_hits 148 | 149 | logger.info('Validation results: top k documents hits %s', top_k_hits) 150 | top_k_hits = [v / len(result_ctx_ids) for v in top_k_hits] 151 | logger.info('Validation results: top k documents hits accuracy %s', top_k_hits) 152 | return match_stats.questions_doc_hits 153 | 154 | 155 | def load_passages(ctx_file: str) -> Dict[object, Tuple[str, str]]: 156 | docs = {} 157 | logger.info('Reading data from: %s', ctx_file) 158 | if ctx_file.startswith(".gz"): 159 | with gzip.open(ctx_file) as tsvfile: 160 | reader = csv.reader(tsvfile, delimiter='\t', ) 161 | # file format: doc_id, doc_text, title 162 | for row in reader: 163 | if row[0] != 'id': 164 | docs[row[0]] = (row[1], row[2]) 165 | else: 166 | with open(ctx_file) as tsvfile: 167 | reader = csv.reader(tsvfile, delimiter='\t', ) 168 | # file format: doc_id, doc_text, title 169 | for row in reader: 170 | if row[0] != 'id': 171 | 
docs[row[0]] = (row[1], row[2]) 172 | return docs 173 | 174 | 175 | def save_results(passages: Dict[object, Tuple[str, str]], questions: List[str], q_ids: List[str], answers: List[List[str]], languages: List[str], 176 | top_passages_and_scores: List[Tuple[List[object], List[float]]], per_question_hits: List[List[bool]], 177 | out_file: str 178 | ): 179 | # join passages text with the result ids, their questions and assigning has|no answer labels 180 | merged_data = [] 181 | sqaud_style_data = {'data': [], 'version': 'v1.1'} 182 | assert len(per_question_hits) == len(questions) == len(answers) 183 | for i, q in enumerate(questions): 184 | q_answers = answers[i] 185 | q_id = q_ids[i] 186 | lang = languages[i] 187 | results_and_scores = top_passages_and_scores[i] 188 | hits = per_question_hits[i] 189 | docs = [passages[doc_id] for doc_id in results_and_scores[0]] 190 | scores = [str(score) for score in results_and_scores[1]] 191 | ctxs_num = len(hits) 192 | 193 | merged_data.append({ 194 | 'q_id': q_id, 195 | 'question': q, 196 | 'answers': q_answers, 197 | 'lang': lang, 198 | 'ctxs': [ 199 | { 200 | 'id': results_and_scores[0][c], 201 | 'title': docs[c][1], 202 | 'text': docs[c][0], 203 | 'score': scores[c], 204 | 'has_answer': hits[c], 205 | } for c in range(ctxs_num) 206 | ] 207 | }) 208 | 209 | with open(out_file, "w") as writer: 210 | writer.write(json.dumps(merged_data, indent=4) + "\n") 211 | 212 | # create XOR retrieve output format. 213 | xor_output_prediction_format = [] 214 | for example in merged_data: 215 | q_id = example["q_id"] 216 | ctxs = [ctx["text"] for ctx in example["ctxs"]] 217 | lang = example["lang"] 218 | xor_output_prediction_format.append({"id": q_id, "lang": lang, "ctxs" : ctxs}) 219 | 220 | with open("{}_xor_retrieve_results.json".format(out_file.split(".")[0]), 'w') as outfile: 221 | json.dump(xor_output_prediction_format, outfile) 222 | 223 | logger.info('Saved results * scores to %s', out_file) 224 | 225 | 226 | def iterate_encoded_files(vector_files: list) -> Iterator[Tuple[object, np.array]]: 227 | for i, file in enumerate(vector_files): 228 | logger.info('Reading file %s', file) 229 | with open(file, "rb") as reader: 230 | doc_vectors = pickle.load(reader) 231 | for doc in doc_vectors: 232 | db_id, doc_vector = doc 233 | yield db_id, doc_vector 234 | 235 | 236 | def main(args): 237 | saved_state = load_states_from_checkpoint(args.model_file) 238 | set_encoder_params_from_state(saved_state.encoder_params, args) 239 | 240 | tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True) 241 | 242 | encoder = encoder.question_model 243 | 244 | encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu, 245 | args.local_rank, 246 | args.fp16) 247 | encoder.eval() 248 | 249 | # load weights from the model file 250 | model_to_load = get_model_obj(encoder) 251 | logger.info('Loading saved model state ...') 252 | 253 | prefix_len = len('question_model.') 254 | question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if 255 | key.startswith('question_model.')} 256 | model_to_load.load_state_dict(question_encoder_state) 257 | vector_size = model_to_load.get_out_size() 258 | logger.info('Encoder vector_size=%d', vector_size) 259 | 260 | index_buffer_sz = args.index_buffer 261 | if args.hnsw_index: 262 | index = DenseHNSWFlatIndexer(vector_size) 263 | index_buffer_sz = -1 # encode all at once 264 | else: 265 | index = DenseFlatIndexer(vector_size) 266 | 267 | retriever 
= DenseRetriever(encoder, args.batch_size, tensorizer, index) 268 | 269 | 270 | # index all passages 271 | ctx_files_pattern = args.encoded_ctx_file 272 | input_paths = glob.glob(ctx_files_pattern) 273 | if args.remove_lang is not None: 274 | final_fps = [] 275 | 276 | for path in input_paths: 277 | basename = os.path.basename(path) 278 | to_be_removed = False 279 | for lang in args.remove_lang: 280 | if lang in basename: 281 | to_be_removed = True 282 | if to_be_removed is False: 283 | final_fps.append(path) 284 | input_paths = final_fps 285 | print("lang {} are removed from retrieval target".format(input_paths)) 286 | index_path = "_".join(input_paths[0].split("_")[:-1]) 287 | if args.save_or_load_index and os.path.exists(index_path): 288 | retriever.index.deserialize(index_path) 289 | else: 290 | logger.info('Reading all passages data from files: %s', input_paths) 291 | retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz) 292 | if args.save_or_load_index: 293 | retriever.index.serialize(index_path) 294 | # get questions & answers 295 | questions = [] 296 | question_answers = [] 297 | question_languages = [] 298 | q_ids = [] 299 | 300 | for ds_item in parse_qa_jsonlines_file(args.qa_file, args.add_lang): 301 | question, q_id, answers, language = ds_item 302 | questions.append(question) 303 | q_ids.append(q_id) 304 | question_answers.append(answers) 305 | question_languages.append(language) 306 | 307 | questions_tensor = retriever.generate_question_vectors(questions) 308 | 309 | # get top k results 310 | top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs) 311 | 312 | all_passages = load_passages(args.ctx_file) 313 | 314 | if len(all_passages) == 0: 315 | raise RuntimeError('No passages data found. Please specify ctx_file param properly.') 316 | 317 | questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers, 318 | args.match) 319 | 320 | if args.out_file: 321 | save_results(all_passages, questions, q_ids, question_answers, question_languages, top_ids_and_scores, questions_doc_hits, args.out_file) 322 | 323 | 324 | if __name__ == '__main__': 325 | parser = argparse.ArgumentParser() 326 | 327 | add_encoder_params(parser) 328 | add_tokenizer_params(parser) 329 | add_cuda_params(parser) 330 | 331 | parser.add_argument('--qa_file', required=True, type=str, default=None, 332 | help="Question and answers file of the format: question \\t ['answer1','answer2', ...]") 333 | parser.add_argument('--ctx_file', required=True, type=str, default=None, 334 | help="All passages file in the tsv format: id \\t passage_text \\t title") 335 | parser.add_argument('--encoded_ctx_file', type=str, default=None, 336 | help='Glob path to encoded passages (from generate_dense_embeddings tool)') 337 | parser.add_argument('--remove_lang', type=str, default=None, nargs="*", 338 | help='languages to be removed') 339 | parser.add_argument('--add_lang', action='store_true') 340 | parser.add_argument('--out_file', type=str, default=None, 341 | help='output .tsv file path to write results to ') 342 | parser.add_argument('--match', type=str, default='string', choices=['regex', 'string'], 343 | help="Answer matching logic type") 344 | parser.add_argument('--n-docs', type=int, default=200, help="Amount of top docs to return") 345 | parser.add_argument('--validation_workers', type=int, default=16, 346 | help="Number of parallel processes to validate results") 347 | parser.add_argument('--batch_size', type=int, default=32, help="Batch 
size for question encoder forward pass") 348 | parser.add_argument('--index_buffer', type=int, default=50000, 349 | help="Temporal memory data buffer size (in samples) for indexer") 350 | parser.add_argument("--hnsw_index", action='store_true', help='If enabled, use inference time efficient HNSW index') 351 | parser.add_argument("--save_or_load_index", action='store_true', help='If enabled, save index') 352 | 353 | args = parser.parse_args() 354 | 355 | assert args.model_file, 'Please specify --model_file checkpoint to init model weights' 356 | 357 | setup_args_gpu(args) 358 | print_args(args) 359 | main(args) 360 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mia-workshop/MIA-Shared-Task-2022/ea2491a3e3a588e270b2dd30ce7e044584c3eeb5/baseline/mDPR/dpr/__init__.py -------------------------------------------------------------------------------- /baseline/mDPR/dpr/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mia-workshop/MIA-Shared-Task-2022/ea2491a3e3a588e270b2dd30ce7e044584c3eeb5/baseline/mDPR/dpr/data/__init__.py -------------------------------------------------------------------------------- /baseline/mDPR/dpr/data/qa_validation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Set of utilities for Q&A results validation tasks - Retriver passage validation and Reader predicted answer validation 10 | """ 11 | 12 | import collections 13 | import logging 14 | import string 15 | import unicodedata 16 | from functools import partial 17 | from multiprocessing import Pool as ProcessPool 18 | from typing import Tuple, List, Dict 19 | 20 | import regex as re 21 | 22 | from dpr.utils.tokenizers import SimpleTokenizer 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 27 | 28 | 29 | def calculate_matches(all_docs: Dict[object, Tuple[str, str]], answers: List[List[str]], 30 | closest_docs: List[Tuple[List[object], List[float]]], workers_num: int, 31 | match_type: str) -> QAMatchStats: 32 | """ 33 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 34 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 35 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 36 | :param answers: list of answers's list. One list per question 37 | :param closest_docs: document ids of the top results along with their scores 38 | :param workers_num: amount of parallel threads to process data 39 | :param match_type: type of answer matching. Refer to has_answer code for available options 40 | :return: matching information tuple. 41 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 42 | valid matches across an entire dataset. 
43 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 44 | """ 45 | global dpr_all_documents 46 | dpr_all_documents = all_docs 47 | 48 | tok_opts = {} 49 | tokenizer = SimpleTokenizer(**tok_opts) 50 | 51 | processes = ProcessPool( 52 | processes=workers_num, 53 | ) 54 | 55 | logger.info('Matching answers in top docs...') 56 | 57 | get_score_partial = partial(check_answer, match_type=match_type, tokenizer=tokenizer) 58 | 59 | questions_answers_docs = zip(answers, closest_docs) 60 | 61 | scores = processes.map(get_score_partial, questions_answers_docs) 62 | 63 | logger.info('Per question validation results len=%d', len(scores)) 64 | 65 | n_docs = len(closest_docs[0][0]) 66 | top_k_hits = [0] * n_docs 67 | for question_hits in scores: 68 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 69 | if best_hit is not None: 70 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 71 | 72 | return QAMatchStats(top_k_hits, scores) 73 | 74 | 75 | def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: 76 | """Search through all the top docs to see if they have any of the answers.""" 77 | answers, (doc_ids, doc_scores) = questions_answers_docs 78 | 79 | global dpr_all_documents 80 | hits = [] 81 | 82 | for i, doc_id in enumerate(doc_ids): 83 | doc = dpr_all_documents[doc_id] 84 | text = doc[0] 85 | 86 | answer_found = False 87 | if text is None: # cannot find the document for some reason 88 | logger.warning("no doc in db") 89 | hits.append(False) 90 | continue 91 | 92 | if has_answer(answers, text, tokenizer, match_type): 93 | answer_found = True 94 | hits.append(answer_found) 95 | return hits 96 | 97 | 98 | def has_answer(answers, text, tokenizer, match_type) -> bool: 99 | """Check if a document contains an answer string. 100 | If `match_type` is string, token matching is done between the text and answer. 101 | If `match_type` is regex, we search the whole text with the regex. 
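For example (string matching is token-based and uncased): has_answer(["Paris"], "Paris is the capital of France.", SimpleTokenizer(), "string") returns True, because the answer tokens ["paris"] occur as a contiguous span in the tokenized text.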
102 | """ 103 | text = _normalize(text) 104 | 105 | if match_type == 'string': 106 | # Answer is a list of possible strings 107 | text = tokenizer.tokenize(text).words(uncased=True) 108 | 109 | for single_answer in answers: 110 | single_answer = _normalize(single_answer) 111 | single_answer = tokenizer.tokenize(single_answer) 112 | single_answer = single_answer.words(uncased=True) 113 | 114 | for i in range(0, len(text) - len(single_answer) + 1): 115 | if single_answer == text[i: i + len(single_answer)]: 116 | return True 117 | 118 | elif match_type == 'regex': 119 | # Answer is a regex 120 | for single_answer in answers: 121 | single_answer = _normalize(single_answer) 122 | if regex_match(text, single_answer): 123 | return True 124 | return False 125 | 126 | 127 | def regex_match(text, pattern): 128 | """Test if a regex pattern is contained within a text.""" 129 | try: 130 | pattern = re.compile( 131 | pattern, 132 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, 133 | ) 134 | except BaseException: 135 | return False 136 | return pattern.search(text) is not None 137 | 138 | 139 | # function for the reader model answer validation 140 | def exact_match_score(prediction, ground_truth): 141 | return _normalize_answer(prediction) == _normalize_answer(ground_truth) 142 | 143 | 144 | def _normalize_answer(s): 145 | def remove_articles(text): 146 | return re.sub(r'\b(a|an|the)\b', ' ', text) 147 | 148 | def white_space_fix(text): 149 | return ' '.join(text.split()) 150 | 151 | def remove_punc(text): 152 | exclude = set(string.punctuation) 153 | return ''.join(ch for ch in text if ch not in exclude) 154 | 155 | def lower(text): 156 | return text.lower() 157 | 158 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 159 | 160 | 161 | def _normalize(text): 162 | return unicodedata.normalize('NFD', text) 163 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/indexer/faiss_indexers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
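# Two indexers are defined below: DenseFlatIndexer wraps an exact inner-product index (faiss.IndexFlatIP),
# while DenseHNSWFlatIndexer wraps an approximate HNSW index that emulates inner-product search in L2 space
# by appending one auxiliary dimension to every vector.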
7 | 8 | """ 9 | FAISS-based index components for dense retriver 10 | """ 11 | 12 | import os 13 | import logging 14 | import pickle 15 | from typing import List, Tuple 16 | 17 | import faiss 18 | import numpy as np 19 | 20 | logger = logging.getLogger() 21 | 22 | 23 | class DenseIndexer(object): 24 | 25 | def __init__(self, buffer_size: int = 50000): 26 | self.buffer_size = buffer_size 27 | self.index_id_to_db_id = [] 28 | self.index = None 29 | 30 | def index_data(self, data: List[Tuple[object, np.array]]): 31 | raise NotImplementedError 32 | 33 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 34 | raise NotImplementedError 35 | 36 | def serialize(self, file: str): 37 | logger.info('Serializing index to %s', file) 38 | 39 | if os.path.isdir(file): 40 | index_file = os.path.join(file, "index.dpr") 41 | meta_file = os.path.join(file, "index_meta.dpr") 42 | else: 43 | index_file = file + '.index.dpr' 44 | meta_file = file + '.index_meta.dpr' 45 | 46 | faiss.write_index(self.index, index_file) 47 | with open(meta_file, mode='wb') as f: 48 | pickle.dump(self.index_id_to_db_id, f) 49 | 50 | def deserialize_from(self, file: str): 51 | logger.info('Loading index from %s', file) 52 | 53 | if os.path.isdir(file): 54 | index_file = os.path.join(file, "index.dpr") 55 | meta_file = os.path.join(file, "index_meta.dpr") 56 | else: 57 | index_file = file + '.index.dpr' 58 | meta_file = file + '.index_meta.dpr' 59 | 60 | self.index = faiss.read_index(index_file) 61 | logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal) 62 | 63 | with open(meta_file, "rb") as reader: 64 | self.index_id_to_db_id = pickle.load(reader) 65 | assert len( 66 | self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size' 67 | 68 | def _update_id_mapping(self, db_ids: List): 69 | self.index_id_to_db_id.extend(db_ids) 70 | 71 | 72 | class DenseFlatIndexer(DenseIndexer): 73 | 74 | def __init__(self, vector_sz: int, buffer_size: int = 50000): 75 | super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) 76 | self.index = faiss.IndexFlatIP(vector_sz) 77 | 78 | def index_data(self, data: List[Tuple[object, np.array]]): 79 | n = len(data) 80 | # indexing in batches is beneficial for many faiss index types 81 | for i in range(0, n, self.buffer_size): 82 | db_ids = [t[0] for t in data[i:i + self.buffer_size]] 83 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]] 84 | vectors = np.concatenate(vectors, axis=0) 85 | self._update_id_mapping(db_ids) 86 | self.index.add(vectors) 87 | 88 | indexed_cnt = len(self.index_id_to_db_id) 89 | logger.info('Total data indexed %d', indexed_cnt) 90 | 91 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 92 | scores, indexes = self.index.search(query_vectors, top_docs) 93 | # convert to external ids 94 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 95 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 96 | return result 97 | 98 | 99 | class DenseHNSWFlatIndexer(DenseIndexer): 100 | """ 101 | Efficient index for retrieval. 
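Stores vectors with one extra dimension so that maximum inner-product search can be reduced to L2 search; as a consequence, all data must be indexed in a single index_data() call.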
Note: default settings are for hugh accuracy but also high RAM usage 102 | """ 103 | 104 | def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512 105 | , ef_search: int = 128, ef_construction: int = 200): 106 | super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) 107 | 108 | # IndexHNSWFlat supports L2 similarity only 109 | # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension 110 | index = faiss.IndexHNSWFlat(vector_sz + 1, store_n) 111 | index.hnsw.efSearch = ef_search 112 | index.hnsw.efConstruction = ef_construction 113 | self.index = index 114 | self.phi = 0 115 | 116 | def index_data(self, data: List[Tuple[object, np.array]]): 117 | n = len(data) 118 | 119 | # max norm is required before putting all vectors in the index to convert inner product similarity to L2 120 | if self.phi > 0: 121 | raise RuntimeError('DPR HNSWF index needs to index all data at once,' 122 | 'results will be unpredictable otherwise.') 123 | phi = 0 124 | for i, item in enumerate(data): 125 | id, doc_vector = item 126 | norms = (doc_vector ** 2).sum() 127 | phi = max(phi, norms) 128 | logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi)) 129 | self.phi = 0 130 | 131 | # indexing in batches is beneficial for many faiss index types 132 | for i in range(0, n, self.buffer_size): 133 | db_ids = [t[0] for t in data[i:i + self.buffer_size]] 134 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]] 135 | 136 | norms = [(doc_vector ** 2).sum() for doc_vector in vectors] 137 | aux_dims = [np.sqrt(phi - norm) for norm in norms] 138 | hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in 139 | enumerate(vectors)] 140 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) 141 | 142 | self._update_id_mapping(db_ids) 143 | self.index.add(hnsw_vectors) 144 | logger.info('data indexed %d', len(self.index_id_to_db_id)) 145 | 146 | indexed_cnt = len(self.index_id_to_db_id) 147 | logger.info('Total data indexed %d', indexed_cnt) 148 | 149 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 150 | 151 | aux_dim = np.zeros(len(query_vectors), dtype='float32') 152 | query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 153 | logger.info('query_hnsw_vectors %s', query_nhsw_vectors.shape) 154 | scores, indexes = self.index.search(query_nhsw_vectors, top_docs) 155 | # convert to external ids 156 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 157 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 158 | return result 159 | 160 | def deserialize_from(self, file: str): 161 | super(DenseHNSWFlatIndexer, self).deserialize_from(file) 162 | # to trigger warning on subsequent indexing 163 | self.phi = 1 164 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
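# This module lazily maps encoder type names ('hf_bert', 'pytext_bert', 'fairseq_roberta') to the factories
# that build the corresponding biencoder, reader and tensorizer components; see the *_INITIALIZERS dicts below.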
7 | 8 | import importlib 9 | 10 | """ 11 | 'Router'-like set of methods for component initialization with lazy imports 12 | """ 13 | 14 | 15 | def init_hf_bert_biencoder(args, **kwargs): 16 | if importlib.util.find_spec("transformers") is None: 17 | raise RuntimeError('Please install transformers lib') 18 | from .hf_models import get_bert_biencoder_components 19 | return get_bert_biencoder_components(args, **kwargs) 20 | 21 | 22 | def init_hf_bert_reader(args, **kwargs): 23 | if importlib.util.find_spec("transformers") is None: 24 | raise RuntimeError('Please install transformers lib') 25 | from .hf_models import get_bert_reader_components 26 | return get_bert_reader_components(args, **kwargs) 27 | 28 | 29 | def init_pytext_bert_biencoder(args, **kwargs): 30 | if importlib.util.find_spec("pytext") is None: 31 | raise RuntimeError('Please install pytext lib') 32 | from .pytext_models import get_bert_biencoder_components 33 | return get_bert_biencoder_components(args, **kwargs) 34 | 35 | 36 | def init_fairseq_roberta_biencoder(args, **kwargs): 37 | if importlib.util.find_spec("fairseq") is None: 38 | raise RuntimeError('Please install fairseq lib') 39 | from .fairseq_models import get_roberta_biencoder_components 40 | return get_roberta_biencoder_components(args, **kwargs) 41 | 42 | 43 | def init_hf_bert_tenzorizer(args, **kwargs): 44 | if importlib.util.find_spec("transformers") is None: 45 | raise RuntimeError('Please install transformers lib') 46 | from .hf_models import get_bert_tensorizer 47 | return get_bert_tensorizer(args) 48 | 49 | 50 | def init_hf_roberta_tenzorizer(args, **kwargs): 51 | if importlib.util.find_spec("transformers") is None: 52 | raise RuntimeError('Please install transformers lib') 53 | from .hf_models import get_roberta_tensorizer 54 | return get_roberta_tensorizer(args) 55 | 56 | 57 | BIENCODER_INITIALIZERS = { 58 | 'hf_bert': init_hf_bert_biencoder, 59 | 'pytext_bert': init_pytext_bert_biencoder, 60 | 'fairseq_roberta': init_fairseq_roberta_biencoder, 61 | } 62 | 63 | READER_INITIALIZERS = { 64 | 'hf_bert': init_hf_bert_reader, 65 | } 66 | 67 | TENSORIZER_INITIALIZERS = { 68 | 'hf_bert': init_hf_bert_tenzorizer, 69 | 'hf_roberta': init_hf_roberta_tenzorizer, 70 | 'pytext_bert': init_hf_bert_tenzorizer, # using HF's code as of now 71 | 'fairseq_roberta': init_hf_roberta_tenzorizer, # using HF's code as of now 72 | } 73 | 74 | 75 | def init_comp(initializers_dict, type, args, **kwargs): 76 | if type in initializers_dict: 77 | return initializers_dict[type](args, **kwargs) 78 | else: 79 | raise RuntimeError('unsupported model type: {}'.format(type)) 80 | 81 | 82 | def init_biencoder_components(encoder_type: str, args, **kwargs): 83 | return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) 84 | 85 | 86 | def init_reader_components(encoder_type: str, args, **kwargs): 87 | return init_comp(READER_INITIALIZERS, encoder_type, args, **kwargs) 88 | 89 | 90 | def init_tenzorizer(encoder_type: str, args, **kwargs): 91 | return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) 92 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/models/biencoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
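# BiEncoder couples a question encoder with a context encoder; BiEncoderNllLoss scores each question against
# all in-batch passages with a dot product and applies a negative log-likelihood loss over the positive passages.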
7 | 8 | """ 9 | BiEncoder component + loss function for 'all-in-batch' training 10 | """ 11 | 12 | import collections 13 | import logging 14 | import random 15 | from typing import Tuple, List 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from torch import Tensor as T 21 | from torch import nn 22 | 23 | from dpr.utils.data_utils import Tensorizer 24 | from dpr.utils.data_utils import normalize_question 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | BiEncoderBatch = collections.namedtuple('BiENcoderInput', 29 | ['question_ids', 'question_segments', 'context_ids', 'ctx_segments', 30 | 'is_positive', 'hard_negatives']) 31 | 32 | 33 | def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T: 34 | """ 35 | calculates q->ctx scores for every row in ctx_vector 36 | :param q_vector: 37 | :param ctx_vector: 38 | :return: 39 | """ 40 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 41 | r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) 42 | return r 43 | 44 | 45 | def cosine_scores(q_vector: T, ctx_vectors: T): 46 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 47 | return F.cosine_similarity(q_vector, ctx_vectors, dim=1) 48 | 49 | 50 | class BiEncoder(nn.Module): 51 | """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders. 52 | """ 53 | 54 | def __init__(self, question_model: nn.Module, ctx_model: nn.Module, fix_q_encoder: bool = False, 55 | fix_ctx_encoder: bool = False): 56 | super(BiEncoder, self).__init__() 57 | self.question_model = question_model 58 | self.ctx_model = ctx_model 59 | self.fix_q_encoder = fix_q_encoder 60 | self.fix_ctx_encoder = fix_ctx_encoder 61 | 62 | @staticmethod 63 | def get_representation(sub_model: nn.Module, ids: T, segments: T, attn_mask: T, fix_encoder: bool = False) -> ( 64 | T, T, T): 65 | sequence_output = None 66 | pooled_output = None 67 | hidden_states = None 68 | if ids is not None: 69 | if fix_encoder: 70 | with torch.no_grad(): 71 | sequence_output, pooled_output, hidden_states = sub_model(ids, segments, attn_mask) 72 | 73 | if sub_model.training: 74 | sequence_output.requires_grad_(requires_grad=True) 75 | pooled_output.requires_grad_(requires_grad=True) 76 | else: 77 | sequence_output, pooled_output, hidden_states = sub_model(ids, segments, attn_mask) 78 | 79 | return sequence_output, pooled_output, hidden_states 80 | 81 | def forward(self, question_ids: T, question_segments: T, question_attn_mask: T, context_ids: T, ctx_segments: T, 82 | ctx_attn_mask: T) -> Tuple[T, T]: 83 | 84 | _q_seq, q_pooled_out, _q_hidden = self.get_representation(self.question_model, question_ids, question_segments, 85 | question_attn_mask, self.fix_q_encoder) 86 | _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(self.ctx_model, context_ids, ctx_segments, 87 | ctx_attn_mask, self.fix_ctx_encoder) 88 | 89 | return q_pooled_out, ctx_pooled_out 90 | 91 | @classmethod 92 | def create_biencoder_input(cls, 93 | samples: List, 94 | tensorizer: Tensorizer, 95 | insert_title: bool, 96 | num_hard_negatives: int = 0, 97 | num_other_negatives: int = 0, 98 | shuffle: bool = True, 99 | shuffle_positives: bool = False, 100 | ) -> BiEncoderBatch: 101 | """ 102 | Creates a batch of the biencoder training tuple. 
103 | :param samples: list of data items (from json) to create the batch for 104 | :param tensorizer: components to create model input tensors from a text sequence 105 | :param insert_title: enables title insertion at the beginning of the context sequences 106 | :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools) 107 | :param num_other_negatives: amount of other negatives per question (taken from samples' pools) 108 | :param shuffle: shuffles negative passages pools 109 | :param shuffle_positives: shuffles positive passages pools 110 | :return: BiEncoderBatch tuple 111 | """ 112 | question_tensors = [] 113 | ctx_tensors = [] 114 | positive_ctx_indices = [] 115 | hard_neg_ctx_indices = [] 116 | 117 | for sample in samples: 118 | # ctx+ & [ctx-] composition 119 | # as of now, take the first(gold) ctx+ only 120 | if shuffle and shuffle_positives: 121 | positive_ctxs = sample['positive_ctxs'] 122 | positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] 123 | else: 124 | positive_ctx = sample['positive_ctxs'][0] 125 | 126 | neg_ctxs = sample['negative_ctxs'] 127 | hard_neg_ctxs = sample['hard_negative_ctxs'] 128 | question = normalize_question(sample['question']) 129 | 130 | if shuffle: 131 | random.shuffle(neg_ctxs) 132 | random.shuffle(hard_neg_ctxs) 133 | 134 | neg_ctxs = neg_ctxs[0:num_other_negatives] 135 | hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] 136 | 137 | all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs 138 | hard_negatives_start_idx = 1 139 | hard_negatives_end_idx = 1 + len(hard_neg_ctxs) 140 | 141 | current_ctxs_len = len(ctx_tensors) 142 | 143 | sample_ctxs_tensors = [tensorizer.text_to_tensor(ctx['text'], title=ctx['title'] if insert_title else None) 144 | for 145 | ctx in all_ctxs] 146 | 147 | ctx_tensors.extend(sample_ctxs_tensors) 148 | positive_ctx_indices.append(current_ctxs_len) 149 | hard_neg_ctx_indices.append( 150 | [i for i in 151 | range(current_ctxs_len + hard_negatives_start_idx, current_ctxs_len + hard_negatives_end_idx)]) 152 | 153 | question_tensors.append(tensorizer.text_to_tensor(question)) 154 | 155 | ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) 156 | questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) 157 | 158 | ctx_segments = torch.zeros_like(ctxs_tensor) 159 | question_segments = torch.zeros_like(questions_tensor) 160 | 161 | return BiEncoderBatch(questions_tensor, question_segments, ctxs_tensor, ctx_segments, positive_ctx_indices, 162 | hard_neg_ctx_indices) 163 | 164 | 165 | class BiEncoderNllLoss(object): 166 | 167 | def calc(self, q_vectors: T, ctx_vectors: T, positive_idx_per_question: list, 168 | hard_negatice_idx_per_question: list = None) -> Tuple[T, int]: 169 | """ 170 | Computes nll loss for the given lists of question and ctx vectors. 171 | Note that although hard_negatice_idx_per_question in not currently in use, one can use it for the 172 | loss modifications. For example - weighted NLL with different factors for hard vs regular negatives. 
173 | :return: a tuple of loss value and amount of correct predictions per batch 174 | """ 175 | scores = self.get_scores(q_vectors, ctx_vectors) 176 | 177 | if len(q_vectors.size()) > 1: 178 | q_num = q_vectors.size(0) 179 | scores = scores.view(q_num, -1) 180 | 181 | softmax_scores = F.log_softmax(scores, dim=1) 182 | 183 | loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question).to(softmax_scores.device), 184 | reduction='mean') 185 | 186 | max_score, max_idxs = torch.max(softmax_scores, 1) 187 | correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum() 188 | return loss, correct_predictions_count 189 | 190 | @staticmethod 191 | def get_scores(q_vector: T, ctx_vectors: T) -> T: 192 | f = BiEncoderNllLoss.get_similarity_function() 193 | return f(q_vector, ctx_vectors) 194 | 195 | @staticmethod 196 | def get_similarity_function(): 197 | return dot_product_scores 198 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/models/fairseq_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Encoder model wrappers based on Fairseq code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | from fairseq.models.roberta.hub_interface import RobertaHubInterface 16 | from fairseq.models.roberta.model import RobertaModel as FaiseqRobertaModel 17 | from fairseq.optim.adam import FairseqAdam 18 | from torch import Tensor as T 19 | from torch import nn 20 | 21 | from dpr.models.hf_models import get_roberta_tensorizer 22 | from .biencoder import BiEncoder 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def get_roberta_biencoder_components(args, inference_only: bool = False, **kwargs): 28 | question_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) 29 | ctx_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) 30 | biencoder = BiEncoder(question_encoder, ctx_encoder) 31 | optimizer = get_fairseq_adamw_optimizer(biencoder, args) if not inference_only else None 32 | 33 | tensorizer = get_roberta_tensorizer(args) 34 | 35 | return tensorizer, biencoder, optimizer 36 | 37 | 38 | def get_fairseq_adamw_optimizer(model: nn.Module, args): 39 | setattr(args, 'lr', [args.learning_rate]) 40 | return FairseqAdam(args, model.parameters()).optimizer 41 | 42 | 43 | class RobertaEncoder(nn.Module): 44 | 45 | def __init__(self, fairseq_roberta_hub: RobertaHubInterface): 46 | super(RobertaEncoder, self).__init__() 47 | self.fairseq_roberta = fairseq_roberta_hub 48 | 49 | @classmethod 50 | def from_pretrained(cls, pretrained_dir_path: str): 51 | model = FaiseqRobertaModel.from_pretrained(pretrained_dir_path) 52 | return cls(model) 53 | 54 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: 55 | roberta_out = self.fairseq_roberta.extract_features(input_ids) 56 | cls_out = roberta_out[:, 0, :] 57 | return roberta_out, cls_out, None 58 | 59 | def get_out_size(self): 60 | raise NotImplementedError 61 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/models/hf_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Encoder model wrappers based on HuggingFace code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | import torch 16 | from torch import Tensor as T 17 | from torch import nn 18 | from transformers.modeling_bert import BertModel 19 | from transformers import AdamW 20 | from transformers import AutoTokenizer, AutoConfig 21 | 22 | from dpr.utils.data_utils import Tensorizer 23 | from .biencoder import BiEncoder 24 | from .reader import Reader 25 | 26 | logger = logging.getLogger(__name__) 27 | logging.basicConfig(level=logging.ERROR) 28 | 29 | 30 | def get_bert_biencoder_components(args, inference_only: bool = False, **kwargs): 31 | dropout = args.dropout if hasattr(args, 'dropout') else 0.0 32 | question_encoder = HFBertEncoder.init_encoder(args.pretrained_model_cfg, 33 | projection_dim=args.projection_dim, dropout=dropout, **kwargs) 34 | ctx_encoder = HFBertEncoder.init_encoder(args.pretrained_model_cfg, 35 | projection_dim=args.projection_dim, dropout=dropout, **kwargs) 36 | 37 | fix_ctx_encoder = args.fix_ctx_encoder if hasattr(args, 'fix_ctx_encoder') else False 38 | print("fix context encoder: {}".format(fix_ctx_encoder)) 39 | biencoder = BiEncoder(question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder) 40 | 41 | optimizer = get_optimizer(biencoder, 42 | learning_rate=args.learning_rate, 43 | adam_eps=args.adam_eps, weight_decay=args.weight_decay, 44 | ) if not inference_only else None 45 | 46 | tensorizer = get_bert_tensorizer(args) 47 | 48 | return tensorizer, biencoder, optimizer 49 | 50 | 51 | def get_bert_reader_components(args, inference_only: bool = False, **kwargs): 52 | dropout = args.dropout if hasattr(args, 'dropout') else 0.0 53 | encoder = HFBertEncoder.init_encoder(args.pretrained_model_cfg, 54 | projection_dim=args.projection_dim, dropout=dropout) 55 | 56 | hidden_size = encoder.config.hidden_size 57 | reader = Reader(encoder, hidden_size) 58 | 59 | optimizer = get_optimizer(reader, 60 | learning_rate=args.learning_rate, 61 | adam_eps=args.adam_eps, weight_decay=args.weight_decay, 62 | ) if not inference_only else None 63 | 64 | tensorizer = get_bert_tensorizer(args) 65 | return tensorizer, reader, optimizer 66 | 67 | 68 | def get_bert_tensorizer(args, tokenizer=None): 69 | if not tokenizer: 70 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_cfg, do_lower_case=args.do_lower_case, use_fast=False) 71 | return BertTensorizer(tokenizer, args.sequence_length) 72 | 73 | 74 | def get_roberta_tensorizer(args, tokenizer=None): 75 | if not tokenizer: 76 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_cfg, do_lower_case=args.do_lower_case, use_fast=False) 77 | return RobertaTensorizer(tokenizer, args.sequence_length) 78 | 79 | 80 | def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, 81 | weight_decay: float = 0.0, ) -> torch.optim.Optimizer: 82 | no_decay = ['bias', 'LayerNorm.weight'] 83 | 84 | optimizer_grouped_parameters = [ 85 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 86 | 'weight_decay': weight_decay}, 87 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 88 | ] 89 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, 
eps=adam_eps) 90 | return optimizer 91 | 92 | class HFBertEncoder(BertModel): 93 | 94 | def __init__(self, config, project_dim: int = 0): 95 | BertModel.__init__(self, config) 96 | assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero' 97 | self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None 98 | self.init_weights() 99 | 100 | @classmethod 101 | def init_encoder(cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, **kwargs) -> BertModel: 102 | cfg = AutoConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased') 103 | if dropout != 0: 104 | cfg.attention_probs_dropout_prob = dropout 105 | cfg.hidden_dropout_prob = dropout 106 | return cls.from_pretrained(cfg_name, config=cfg, project_dim=projection_dim, **kwargs) 107 | 108 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: 109 | if self.config.output_hidden_states: 110 | sequence_output, pooled_output, hidden_states = super().forward(input_ids=input_ids, 111 | token_type_ids=token_type_ids, 112 | attention_mask=attention_mask) 113 | else: 114 | hidden_states = None 115 | sequence_output, pooled_output = super().forward(input_ids=input_ids, token_type_ids=token_type_ids, 116 | attention_mask=attention_mask) 117 | 118 | pooled_output = sequence_output[:, 0, :] 119 | if self.encode_proj: 120 | pooled_output = self.encode_proj(pooled_output) 121 | return sequence_output, pooled_output, hidden_states 122 | 123 | def get_out_size(self): 124 | if self.encode_proj: 125 | return self.encode_proj.out_features 126 | return self.config.hidden_size 127 | 128 | 129 | class BertTensorizer(Tensorizer): 130 | def __init__(self, tokenizer: AutoTokenizer, max_length: int, pad_to_max: bool = True): 131 | self.tokenizer = tokenizer 132 | self.max_length = max_length 133 | self.pad_to_max = pad_to_max 134 | 135 | def text_to_tensor(self, text: str, title: str = None, add_special_tokens: bool = True): 136 | if isinstance(text, list) and len(text) == 1: 137 | text = text[0] 138 | text = text.strip() 139 | 140 | # tokenizer automatic padding is explicitly disabled since its inconsistent behavior 141 | # FIXME: temporary enabling the tokenizer's truncation. 
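        # (Editorial note, not in the original file.) The two branches below encode
        # either a (title, text) pair or the text alone. Padding is then applied
        # manually up to self.max_length with the tokenizer's pad id, and when the
        # sequence is truncated the last token is reset to [SEP] so the input stays
        # well-formed for BERT-style models.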
142 | if title: 143 | token_ids = self.tokenizer.encode(title, text_pair=text, add_special_tokens=add_special_tokens, 144 | max_length=self.max_length, 145 | pad_to_max_length=False, truncation=True) 146 | else: 147 | token_ids = self.tokenizer.encode(text, add_special_tokens=add_special_tokens, max_length=self.max_length, 148 | pad_to_max_length=False, truncation=True) 149 | 150 | seq_len = self.max_length 151 | if self.pad_to_max and len(token_ids) < seq_len: 152 | token_ids = token_ids + [self.tokenizer.pad_token_id] * (seq_len - len(token_ids)) 153 | if len(token_ids) > seq_len: 154 | token_ids = token_ids[0:seq_len] 155 | token_ids[-1] = self.tokenizer.sep_token_id 156 | 157 | return torch.tensor(token_ids) 158 | 159 | def get_pair_separator_ids(self) -> T: 160 | return torch.tensor([self.tokenizer.sep_token_id]) 161 | 162 | def get_pad_id(self) -> int: 163 | return self.tokenizer.pad_token_type_id 164 | 165 | def get_attn_mask(self, tokens_tensor: T) -> T: 166 | return tokens_tensor != self.get_pad_id() 167 | 168 | def is_sub_word_id(self, token_id: int): 169 | token = self.tokenizer.convert_ids_to_tokens([token_id])[0] 170 | return token.startswith("##") or token.startswith(" ##") 171 | 172 | def to_string(self, token_ids, skip_special_tokens=True): 173 | return self.tokenizer.decode(token_ids, skip_special_tokens=True) 174 | 175 | def set_pad_to_max(self, do_pad: bool): 176 | self.pad_to_max = do_pad 177 | 178 | 179 | class RobertaTensorizer(BertTensorizer): 180 | def __init__(self, tokenizer, max_length: int, pad_to_max: bool = True): 181 | super(RobertaTensorizer, self).__init__(tokenizer, max_length, pad_to_max=pad_to_max) 182 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/models/pytext_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
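# --- Editorial usage sketch (not part of the original sources) ---------------
# Exercising the BertTensorizer defined in hf_models.py above. The model name is
# an example; requires the transformers version pinned by this repository.
from transformers import AutoTokenizer

from dpr.models.hf_models import BertTensorizer

tok = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased', use_fast=False)
tensorizer = BertTensorizer(tok, max_length=32)

ids = tensorizer.text_to_tensor('where is the eiffel tower', title='Eiffel Tower')
attn_mask = tensorizer.get_attn_mask(ids)  # True for real tokens, False for padding
# `ids` is a 1-D tensor padded (or truncated, with a trailing [SEP]) to max_length.
# ------------------------------------------------------------------------------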
7 | 8 | """ 9 | Encoder model wrappers based on HuggingFace code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | import torch 16 | from pytext.models.representations.transformer_sentence_encoder import TransformerSentenceEncoder 17 | from pytext.optimizer.optimizers import AdamW 18 | from torch import Tensor as T 19 | from torch import nn 20 | 21 | from .biencoder import BiEncoder 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def get_bert_biencoder_components(args, inference_only: bool = False): 27 | # since bert tokenizer is the same in HF and pytext/fairseq, just use HF's implementation here for now 28 | from .hf_models import get_tokenizer, BertTensorizer 29 | 30 | tokenizer = get_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) 31 | 32 | question_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, 33 | projection_dim=args.projection_dim, dropout=args.dropout, 34 | vocab_size=tokenizer.vocab_size, 35 | padding_idx=tokenizer.pad_token_type_id 36 | ) 37 | 38 | ctx_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, 39 | projection_dim=args.projection_dim, dropout=args.dropout, 40 | vocab_size=tokenizer.vocab_size, 41 | padding_idx=tokenizer.pad_token_type_id 42 | ) 43 | 44 | biencoder = BiEncoder(question_encoder, ctx_encoder) 45 | 46 | optimizer = get_optimizer(biencoder, 47 | learning_rate=args.learning_rate, 48 | adam_eps=args.adam_eps, weight_decay=args.weight_decay, 49 | ) if not inference_only else None 50 | 51 | tensorizer = BertTensorizer(tokenizer, args.sequence_length) 52 | return tensorizer, biencoder, optimizer 53 | 54 | 55 | def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, 56 | weight_decay: float = 0.0) -> torch.optim.Optimizer: 57 | cfg = AdamW.Config() 58 | cfg.lr = learning_rate 59 | cfg.weight_decay = weight_decay 60 | cfg.eps = adam_eps 61 | optimizer = AdamW.from_config(cfg, model) 62 | return optimizer 63 | 64 | 65 | def get_pytext_bert_base_cfg(): 66 | cfg = TransformerSentenceEncoder.Config() 67 | cfg.embedding_dim = 768 68 | cfg.ffn_embedding_dim = 3072 69 | cfg.num_encoder_layers = 12 70 | cfg.num_attention_heads = 12 71 | cfg.num_segments = 2 72 | cfg.use_position_embeddings = True 73 | cfg.offset_positions_by_padding = True 74 | cfg.apply_bert_init = True 75 | cfg.encoder_normalize_before = True 76 | cfg.activation_fn = "gelu" 77 | cfg.projection_dim = 0 78 | cfg.max_seq_len = 512 79 | cfg.multilingual = False 80 | cfg.freeze_embeddings = False 81 | cfg.n_trans_layers_to_freeze = 0 82 | cfg.use_torchscript = False 83 | return cfg 84 | 85 | 86 | class PytextBertEncoder(TransformerSentenceEncoder): 87 | 88 | def __init__(self, config: TransformerSentenceEncoder.Config, 89 | padding_idx: int, 90 | vocab_size: int, 91 | projection_dim: int = 0, 92 | *args, 93 | **kwarg 94 | ): 95 | 96 | TransformerSentenceEncoder.__init__(self, config, False, padding_idx, vocab_size, *args, **kwarg) 97 | 98 | assert config.embedding_dim > 0, 'Encoder hidden_size can\'t be zero' 99 | self.encode_proj = nn.Linear(config.embedding_dim, projection_dim) if projection_dim != 0 else None 100 | 101 | @classmethod 102 | def init_encoder(cls, pretrained_file: str = None, projection_dim: int = 0, dropout: float = 0.1, 103 | vocab_size: int = 0, 104 | padding_idx: int = 0, **kwargs): 105 | cfg = get_pytext_bert_base_cfg() 106 | 107 | if dropout != 0: 108 | cfg.dropout = dropout 109 | cfg.attention_dropout = dropout 110 | cfg.activation_dropout = dropout 111 | 112 | 
encoder = cls(cfg, padding_idx, vocab_size, projection_dim, **kwargs) 113 | 114 | if pretrained_file: 115 | logger.info('Loading pre-trained pytext encoder state from %s', pretrained_file) 116 | state = torch.load(pretrained_file) 117 | encoder.load_state_dict(state) 118 | return encoder 119 | 120 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: 121 | pooled_output = super().forward((input_ids, attention_mask, token_type_ids, None))[0] 122 | if self.encode_proj: 123 | pooled_output = self.encode_proj(pooled_output) 124 | 125 | return None, pooled_output, None 126 | 127 | def get_out_size(self): 128 | if self.encode_proj: 129 | return self.encode_proj.out_features 130 | return self.representation_dim 131 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line arguments utils 10 | """ 11 | 12 | import argparse 13 | import logging 14 | import os 15 | import random 16 | import socket 17 | 18 | import numpy as np 19 | import torch 20 | 21 | logger = logging.getLogger() 22 | 23 | 24 | def add_tokenizer_params(parser: argparse.ArgumentParser): 25 | parser.add_argument("--do_lower_case", action='store_true', 26 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 27 | 28 | 29 | def add_encoder_params(parser: argparse.ArgumentParser): 30 | """ 31 | Common parameters to initialize an encoder-based model 32 | """ 33 | parser.add_argument("--pretrained_model_cfg", default=None, type=str, help="config name for model initialization") 34 | parser.add_argument("--encoder_model_type", default=None, type=str, 35 | help="model type. 
One of [hf_bert, pytext_bert, fairseq_roberta]") 36 | parser.add_argument('--pretrained_file', type=str, help="Some encoders need to be initialized from a file") 37 | parser.add_argument("--model_file", default=None, type=str, 38 | help="Saved bi-encoder checkpoint file to initialize the model") 39 | parser.add_argument("--projection_dim", default=0, type=int, 40 | help="Extra linear layer on top of standard bert/roberta encoder") 41 | parser.add_argument("--sequence_length", type=int, default=512, help="Max length of the encoder input sequence") 42 | parser.add_argument("--fix_ctx_encoder", action="store_true", help="fix context encoder.") 43 | 44 | 45 | def add_training_params(parser: argparse.ArgumentParser): 46 | """ 47 | Common parameters for training 48 | """ 49 | add_cuda_params(parser) 50 | parser.add_argument("--train_file", default=None, type=str, help="File pattern for the train set") 51 | parser.add_argument("--dev_file", default=None, type=str, help="") 52 | 53 | parser.add_argument("--batch_size", default=2, type=int, help="Amount of questions per batch") 54 | parser.add_argument("--dev_batch_size", type=int, default=4, 55 | help="amount of questions per batch for dev set validation") 56 | parser.add_argument('--seed', type=int, default=0, help="random seed for initialization and dataset shuffling") 57 | 58 | parser.add_argument("--adam_eps", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 59 | parser.add_argument("--adam_betas", default='(0.9, 0.999)', type=str, help="Betas for Adam optimizer.") 60 | 61 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 62 | parser.add_argument("--log_batch_step", default=100, type=int, help="") 63 | parser.add_argument("--train_rolling_loss_step", default=100, type=int, help="") 64 | parser.add_argument("--weight_decay", default=0.0, type=float, help="") 65 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 66 | 67 | parser.add_argument("--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps.") 68 | parser.add_argument("--dropout", default=0.1, type=float, help="") 69 | 70 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 71 | help="Number of updates steps to accumulate before performing a backward/update pass.") 72 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 73 | help="Total number of training epochs to perform.") 74 | 75 | 76 | def add_cuda_params(parser: argparse.ArgumentParser): 77 | parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") 78 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") 79 | parser.add_argument('--fp16', action='store_true', 80 | help="Whether to use 16-bit float precision instead of 32-bit") 81 | 82 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 83 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 84 | "See details at https://nvidia.github.io/apex/amp.html") 85 | 86 | 87 | def add_reader_preprocessing_params(parser: argparse.ArgumentParser): 88 | parser.add_argument("--gold_passages_src", type=str, 89 | help="File with the original dataset passages (json format). Required for train set") 90 | parser.add_argument("--gold_passages_src_dev", type=str, 91 | help="File with the original dataset passages (json format). 
Required for dev set") 92 | parser.add_argument("--num_workers", type=int, default=16, 93 | help="number of parallel processes to binarize reader data") 94 | 95 | 96 | def get_encoder_checkpoint_params_names(): 97 | return ['do_lower_case', 'pretrained_model_cfg', 'encoder_model_type', 98 | 'pretrained_file', 99 | 'projection_dim', 'sequence_length'] 100 | 101 | 102 | def get_encoder_params_state(args): 103 | """ 104 | Selects the param values to be saved in a checkpoint, so that a trained model faile can be used for downstream 105 | tasks without the need to specify these parameter again 106 | :return: Dict of params to memorize in a checkpoint 107 | """ 108 | params_to_save = get_encoder_checkpoint_params_names() 109 | 110 | r = {} 111 | for param in params_to_save: 112 | r[param] = getattr(args, param) 113 | return r 114 | 115 | 116 | def set_encoder_params_from_state(state, args): 117 | if not state: 118 | return 119 | params_to_save = get_encoder_checkpoint_params_names() 120 | 121 | override_params = [(param, state[param]) for param in params_to_save if param in state and state[param]] 122 | for param, value in override_params: 123 | if hasattr(args, param): 124 | logger.warning('Overriding args parameter value from checkpoint state. Param = %s, value = %s', param, 125 | value) 126 | setattr(args, param, value) 127 | return args 128 | 129 | 130 | def set_seed(args): 131 | seed = args.seed 132 | random.seed(seed) 133 | np.random.seed(seed) 134 | torch.manual_seed(seed) 135 | if args.n_gpu > 0: 136 | torch.cuda.manual_seed_all(seed) 137 | 138 | 139 | def setup_args_gpu(args): 140 | """ 141 | Setup arguments CUDA, GPU & distributed training 142 | """ 143 | 144 | if args.local_rank == -1 or args.no_cuda: # single-node multi-gpu (or cpu) mode 145 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 146 | args.n_gpu = torch.cuda.device_count() 147 | else: # distributed mode 148 | torch.cuda.set_device(args.local_rank) 149 | device = torch.device("cuda", args.local_rank) 150 | torch.distributed.init_process_group(backend="nccl") 151 | args.n_gpu = 1 152 | args.device = device 153 | ws = os.environ.get('WORLD_SIZE') 154 | 155 | args.distributed_world_size = int(ws) if ws else 1 156 | 157 | logger.info( 158 | 'Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d', socket.gethostname(), 159 | args.local_rank, device, 160 | args.n_gpu, 161 | args.distributed_world_size) 162 | logger.info("16-bits training: %s ", args.fp16) 163 | 164 | 165 | def print_args(args): 166 | logger.info(" **************** CONFIGURATION **************** ") 167 | for key, val in sorted(vars(args).items()): 168 | keystr = "{}".format(key) + (" " * (30 - len(key))) 169 | logger.info("%s --> %s", keystr, val) 170 | logger.info(" **************** CONFIGURATION **************** ") 171 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mia-workshop/MIA-Shared-Task-2022/ea2491a3e3a588e270b2dd30ce7e044584c3eeb5/baseline/mDPR/dpr/utils/__init__.py -------------------------------------------------------------------------------- /baseline/mDPR/dpr/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for general purpose data processing 10 | """ 11 | 12 | import json 13 | import logging 14 | import math 15 | import pickle 16 | import random 17 | from typing import List, Iterator, Callable 18 | 19 | from torch import Tensor as T 20 | 21 | logger = logging.getLogger() 22 | 23 | 24 | def read_serialized_data_from_files(paths: List[str]) -> List: 25 | results = [] 26 | for i, path in enumerate(paths): 27 | with open(path, "rb") as reader: 28 | logger.info('Reading file %s', path) 29 | data = pickle.load(reader) 30 | results.extend(data) 31 | logger.info('Aggregated data size: {}'.format(len(results))) 32 | logger.info('Total data size: {}'.format(len(results))) 33 | return results 34 | 35 | 36 | def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List: 37 | results = [] 38 | if upsample_rates is None: 39 | upsample_rates = [1] * len(paths) 40 | 41 | assert len(upsample_rates) == len(paths), 'up-sample rates parameter doesn\'t match input files amount' 42 | 43 | for i, path in enumerate(paths): 44 | with open(path, 'r', encoding="utf-8") as f: 45 | logger.info('Reading file %s' % path) 46 | data = json.load(f) 47 | upsample_factor = int(upsample_rates[i]) 48 | data = data * upsample_factor 49 | results.extend(data) 50 | logger.info('Aggregated data size: {}'.format(len(results))) 51 | return results 52 | 53 | 54 | class ShardedDataIterator(object): 55 | """ 56 | General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of 57 | the data. 58 | Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. 59 | It fills the extra sample by just taking first samples in a shard. 60 | It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). 
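    (Editorial example, added for clarity.) With 1,000 samples, num_shards=4 and
    batch_size=16, each shard is assigned ceil(1000 / 4) = 250 samples and runs
    int(250 / 16) = 15 iterations per epoch, or ceil(250 / 16) = 16 iterations when
    strict_batch_size=True (short batches are then topped up with samples from the
    start of the shard).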
61 | """ 62 | def __init__(self, data: list, shard_id: int = 0, num_shards: int = 1, batch_size: int = 1, shuffle=True, 63 | shuffle_seed: int = 0, offset: int = 0, 64 | strict_batch_size: bool = False 65 | ): 66 | 67 | self.data = data 68 | total_size = len(data) 69 | 70 | self.shards_num = max(num_shards, 1) 71 | self.shard_id = max(shard_id, 0) 72 | 73 | samples_per_shard = math.ceil(total_size / self.shards_num) 74 | 75 | self.shard_start_idx = self.shard_id * samples_per_shard 76 | 77 | self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size) 78 | 79 | if strict_batch_size: 80 | self.max_iterations = math.ceil(samples_per_shard / batch_size) 81 | else: 82 | self.max_iterations = int(samples_per_shard / batch_size) 83 | 84 | logger.debug( 85 | 'samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d', samples_per_shard, 86 | self.shard_start_idx, 87 | self.shard_end_idx, 88 | self.max_iterations) 89 | 90 | self.iteration = offset # to track in-shard iteration status 91 | self.shuffle = shuffle 92 | self.batch_size = batch_size 93 | self.shuffle_seed = shuffle_seed 94 | self.strict_batch_size = strict_batch_size 95 | 96 | def total_data_len(self) -> int: 97 | return len(self.data) 98 | 99 | def iterate_data(self, epoch: int = 0) -> Iterator[List]: 100 | if self.shuffle: 101 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 102 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 103 | epoch_rnd.shuffle(self.data) 104 | 105 | # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations 106 | 107 | max_iterations = self.max_iterations - self.iteration 108 | 109 | shard_samples = self.data[self.shard_start_idx:self.shard_end_idx] 110 | for i in range(self.iteration * self.batch_size, len(shard_samples), self.batch_size): 111 | items = shard_samples[i:i + self.batch_size] 112 | if self.strict_batch_size and len(items) < self.batch_size: 113 | logger.debug('Extending batch to max size') 114 | items.extend(shard_samples[0:self.batch_size - len(items)]) 115 | self.iteration += 1 116 | yield items 117 | 118 | # some shards may done iterating while the others are at the last batch. Just return the first batch 119 | while self.iteration < max_iterations: 120 | logger.debug('Fulfilling non complete shard='.format(self.shard_id)) 121 | self.iteration += 1 122 | batch = shard_samples[0:self.batch_size] 123 | yield batch 124 | 125 | logger.debug('Finished iterating, iteration={}, shard={}'.format(self.iteration, self.shard_id)) 126 | # reset the iteration status 127 | self.iteration = 0 128 | 129 | def get_iteration(self) -> int: 130 | return self.iteration 131 | 132 | def apply(self, visitor_func: Callable): 133 | for sample in self.data: 134 | visitor_func(sample) 135 | 136 | 137 | def normalize_question(question: str) -> str: 138 | if question[-1] == '?': 139 | question = question[:-1] 140 | return question 141 | 142 | 143 | class Tensorizer(object): 144 | """ 145 | Component for all text to model input data conversions and related utility methods 146 | """ 147 | 148 | # Note: title, if present, is supposed to be put before text (i.e. 
optional title + document body) 149 | def text_to_tensor(self, text: str, title: str = None, add_special_tokens: bool = True): 150 | raise NotImplementedError 151 | 152 | def get_pair_separator_ids(self) -> T: 153 | raise NotImplementedError 154 | 155 | def get_pad_id(self) -> int: 156 | raise NotImplementedError 157 | 158 | def get_attn_mask(self, tokens_tensor: T): 159 | raise NotImplementedError 160 | 161 | def is_sub_word_id(self, token_id: int): 162 | raise NotImplementedError 163 | 164 | def to_string(self, token_ids, skip_special_tokens=True): 165 | raise NotImplementedError 166 | 167 | def set_pad_to_max(self, pad: bool): 168 | raise NotImplementedError 169 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for distributed model training 10 | """ 11 | 12 | import pickle 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | def get_rank(): 19 | return dist.get_rank() 20 | 21 | 22 | def get_world_size(): 23 | return dist.get_world_size() 24 | 25 | 26 | def get_default_group(): 27 | return dist.group.WORLD 28 | 29 | 30 | def all_reduce(tensor, group=None): 31 | if group is None: 32 | group = get_default_group() 33 | return dist.all_reduce(tensor, group=group) 34 | 35 | 36 | def all_gather_list(data, group=None, max_size=16384): 37 | """Gathers arbitrary data from all nodes into a list. 38 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 39 | data. Note that *data* must be picklable. 
40 | Args: 41 | data (Any): data from the local worker to be gathered on other workers 42 | group (optional): group of the collective 43 | """ 44 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 45 | 46 | enc = pickle.dumps(data) 47 | enc_size = len(enc) 48 | 49 | if enc_size + SIZE_STORAGE_BYTES > max_size: 50 | raise ValueError( 51 | 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) 52 | 53 | rank = get_rank() 54 | world_size = get_world_size() 55 | buffer_size = max_size * world_size 56 | 57 | if not hasattr(all_gather_list, '_buffer') or \ 58 | all_gather_list._buffer.numel() < buffer_size: 59 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 60 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 61 | 62 | buffer = all_gather_list._buffer 63 | buffer.zero_() 64 | cpu_buffer = all_gather_list._cpu_buffer 65 | 66 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( 67 | 256 ** SIZE_STORAGE_BYTES) 68 | 69 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 70 | 71 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 72 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 73 | 74 | start = rank * max_size 75 | size = enc_size + SIZE_STORAGE_BYTES 76 | buffer[start: start + size].copy_(cpu_buffer[:size]) 77 | 78 | all_reduce(buffer, group=group) 79 | 80 | try: 81 | result = [] 82 | for i in range(world_size): 83 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 84 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 85 | if size > 0: 86 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) 87 | return result 88 | except pickle.UnpicklingError: 89 | raise Exception( 90 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 91 | 'workers to enter the function together, so this error usually indicates ' 92 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 93 | 'sync if one of them runs out of memory, or if there are other conditions ' 94 | 'in your training script that can cause one worker to finish an epoch ' 95 | 'while other workers are still iterating over their portions of the data.' 96 | ) 97 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
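# --- Editorial usage sketch (not part of the original sources) ---------------
# Gathering per-rank statistics with all_gather_list from dpr/utils/dist_utils.py
# above. Assumes torch.distributed has already been initialised with the NCCL
# backend (as done by setup_args_gpu in distributed mode) and that CUDA is available.
from dpr.utils.dist_utils import all_gather_list, get_rank

local_stats = {'rank': get_rank(), 'correct_predictions': 42}  # any picklable object
global_stats = all_gather_list(local_stats, max_size=16384)
# `global_stats` is a list with one entry per worker, ordered by rank.
# ------------------------------------------------------------------------------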
7 | 8 | import collections 9 | import glob 10 | import logging 11 | import os 12 | from typing import List 13 | 14 | import torch 15 | from torch import nn 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.serialization import default_restore_location 18 | 19 | logger = logging.getLogger() 20 | 21 | CheckpointState = collections.namedtuple("CheckpointState", 22 | ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 23 | 'encoder_params']) 24 | 25 | 26 | def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1, 27 | local_rank: int = -1, 28 | fp16: bool = False, 29 | fp16_opt_level: str = "O1") -> (nn.Module, torch.optim.Optimizer): 30 | model.to(device) 31 | if fp16: 32 | try: 33 | import apex 34 | from apex import amp 35 | apex.amp.register_half_function(torch, "einsum") 36 | except ImportError: 37 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 38 | 39 | model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) 40 | 41 | if n_gpu > 1: 42 | model = torch.nn.DataParallel(model) 43 | 44 | if local_rank != -1: 45 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], 46 | output_device=local_rank, 47 | find_unused_parameters=True) 48 | return model, optimizer 49 | 50 | 51 | def move_to_cuda(sample): 52 | if len(sample) == 0: 53 | return {} 54 | 55 | def _move_to_cuda(maybe_tensor): 56 | if torch.is_tensor(maybe_tensor): 57 | return maybe_tensor.cuda() 58 | elif isinstance(maybe_tensor, dict): 59 | return { 60 | key: _move_to_cuda(value) 61 | for key, value in maybe_tensor.items() 62 | } 63 | elif isinstance(maybe_tensor, list): 64 | return [_move_to_cuda(x) for x in maybe_tensor] 65 | elif isinstance(maybe_tensor, tuple): 66 | return [_move_to_cuda(x) for x in maybe_tensor] 67 | else: 68 | return maybe_tensor 69 | 70 | return _move_to_cuda(sample) 71 | 72 | 73 | def move_to_device(sample, device): 74 | if len(sample) == 0: 75 | return {} 76 | 77 | def _move_to_device(maybe_tensor, device): 78 | if torch.is_tensor(maybe_tensor): 79 | return maybe_tensor.to(device) 80 | elif isinstance(maybe_tensor, dict): 81 | return { 82 | key: _move_to_device(value, device) 83 | for key, value in maybe_tensor.items() 84 | } 85 | elif isinstance(maybe_tensor, list): 86 | return [_move_to_device(x, device) for x in maybe_tensor] 87 | elif isinstance(maybe_tensor, tuple): 88 | return [_move_to_device(x, device) for x in maybe_tensor] 89 | else: 90 | return maybe_tensor 91 | 92 | return _move_to_device(sample, device) 93 | 94 | 95 | def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1): 96 | """ Create a schedule with a learning rate that decreases linearly after 97 | linearly increasing during a warmup period. 
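    (Editorial example, added for clarity.) With warmup_steps=100 and
    training_steps=1000, the multiplier returned by lr_lambda rises linearly from
    0 to 1 over the first 100 steps and then decays linearly back to 0 at step 1000.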
98 | """ 99 | 100 | def lr_lambda(current_step): 101 | if current_step < warmup_steps: 102 | return float(current_step) / float(max(1, warmup_steps)) 103 | return max( 104 | 0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps)) 105 | ) 106 | 107 | return LambdaLR(optimizer, lr_lambda, last_epoch) 108 | 109 | 110 | def init_weights(modules: List): 111 | for module in modules: 112 | if isinstance(module, (nn.Linear, nn.Embedding)): 113 | module.weight.data.normal_(mean=0.0, std=0.02) 114 | elif isinstance(module, nn.LayerNorm): 115 | module.bias.data.zero_() 116 | module.weight.data.fill_(1.0) 117 | if isinstance(module, nn.Linear) and module.bias is not None: 118 | module.bias.data.zero_() 119 | 120 | 121 | def get_model_obj(model: nn.Module): 122 | return model.module if hasattr(model, 'module') else model 123 | 124 | 125 | def get_model_file(args, file_prefix) -> str: 126 | if args.model_file and os.path.exists(args.model_file): 127 | return args.model_file 128 | 129 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + '*')) if args.output_dir else [] 130 | logger.info('Checkpoint files %s', out_cp_files) 131 | model_file = None 132 | 133 | if len(out_cp_files) > 0: 134 | model_file = max(out_cp_files, key=os.path.getctime) 135 | return model_file 136 | 137 | 138 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 139 | logger.info('Reading saved model from %s', model_file) 140 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu')) 141 | logger.info('model_state_dict keys %s', state_dict.keys()) 142 | return CheckpointState(**state_dict) 143 | -------------------------------------------------------------------------------- /baseline/mDPR/dpr/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
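# --- Editorial usage sketch (not part of the original sources) ---------------
# Restoring a trained bi-encoder checkpoint with the helpers in model_utils.py
# above. The checkpoint path is a placeholder, and `biencoder` is assumed to have
# been built with init_biencoder_components as in the earlier sketch.
from dpr.utils.model_utils import get_model_obj, load_states_from_checkpoint

saved_state = load_states_from_checkpoint('/path/to/dpr_biencoder_checkpoint.cp')
model_to_load = get_model_obj(biencoder)  # unwraps DataParallel / DDP if present
model_to_load.load_state_dict(saved_state.model_dict)
# ------------------------------------------------------------------------------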
7 | 8 | 9 | """ 10 | Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency 11 | """ 12 | 13 | import copy 14 | import logging 15 | 16 | import regex 17 | import spacy 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Tokens(object): 23 | """A class to represent a list of tokenized text.""" 24 | TEXT = 0 25 | TEXT_WS = 1 26 | SPAN = 2 27 | POS = 3 28 | LEMMA = 4 29 | NER = 5 30 | 31 | def __init__(self, data, annotators, opts=None): 32 | self.data = data 33 | self.annotators = annotators 34 | self.opts = opts or {} 35 | 36 | def __len__(self): 37 | """The number of tokens.""" 38 | return len(self.data) 39 | 40 | def slice(self, i=None, j=None): 41 | """Return a view of the list of tokens from [i, j).""" 42 | new_tokens = copy.copy(self) 43 | new_tokens.data = self.data[i: j] 44 | return new_tokens 45 | 46 | def untokenize(self): 47 | """Returns the original text (with whitespace reinserted).""" 48 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 49 | 50 | def words(self, uncased=False): 51 | """Returns a list of the text of each token 52 | 53 | Args: 54 | uncased: lower cases text 55 | """ 56 | if uncased: 57 | return [t[self.TEXT].lower() for t in self.data] 58 | else: 59 | return [t[self.TEXT] for t in self.data] 60 | 61 | def offsets(self): 62 | """Returns a list of [start, end) character offsets of each token.""" 63 | return [t[self.SPAN] for t in self.data] 64 | 65 | def pos(self): 66 | """Returns a list of part-of-speech tags of each token. 67 | Returns None if this annotation was not included. 68 | """ 69 | if 'pos' not in self.annotators: 70 | return None 71 | return [t[self.POS] for t in self.data] 72 | 73 | def lemmas(self): 74 | """Returns a list of the lemmatized text of each token. 75 | Returns None if this annotation was not included. 76 | """ 77 | if 'lemma' not in self.annotators: 78 | return None 79 | return [t[self.LEMMA] for t in self.data] 80 | 81 | def entities(self): 82 | """Returns a list of named-entity-recognition tags of each token. 83 | Returns None if this annotation was not included. 84 | """ 85 | if 'ner' not in self.annotators: 86 | return None 87 | return [t[self.NER] for t in self.data] 88 | 89 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 90 | """Returns a list of all ngrams from length 1 to n. 
91 | 92 | Args: 93 | n: upper limit of ngram length 94 | uncased: lower cases text 95 | filter_fn: user function that takes in an ngram list and returns 96 | True or False to keep or not keep the ngram 97 | as_string: return the ngram as a string vs list 98 | """ 99 | 100 | def _skip(gram): 101 | if not filter_fn: 102 | return False 103 | return filter_fn(gram) 104 | 105 | words = self.words(uncased) 106 | ngrams = [(s, e + 1) 107 | for s in range(len(words)) 108 | for e in range(s, min(s + n, len(words))) 109 | if not _skip(words[s:e + 1])] 110 | 111 | # Concatenate into strings 112 | if as_strings: 113 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 114 | 115 | return ngrams 116 | 117 | def entity_groups(self): 118 | """Group consecutive entity tokens with the same NER tag.""" 119 | entities = self.entities() 120 | if not entities: 121 | return None 122 | non_ent = self.opts.get('non_ent', 'O') 123 | groups = [] 124 | idx = 0 125 | while idx < len(entities): 126 | ner_tag = entities[idx] 127 | # Check for entity tag 128 | if ner_tag != non_ent: 129 | # Chomp the sequence 130 | start = idx 131 | while (idx < len(entities) and entities[idx] == ner_tag): 132 | idx += 1 133 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 134 | else: 135 | idx += 1 136 | return groups 137 | 138 | 139 | class Tokenizer(object): 140 | """Base tokenizer class. 141 | Tokenizers implement tokenize, which should return a Tokens class. 142 | """ 143 | 144 | def tokenize(self, text): 145 | raise NotImplementedError 146 | 147 | def shutdown(self): 148 | pass 149 | 150 | def __del__(self): 151 | self.shutdown() 152 | 153 | 154 | class SimpleTokenizer(Tokenizer): 155 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 156 | NON_WS = r'[^\p{Z}\p{C}]' 157 | 158 | def __init__(self, **kwargs): 159 | """ 160 | Args: 161 | annotators: None or empty set (only tokenizes). 162 | """ 163 | self._regexp = regex.compile( 164 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 165 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 166 | ) 167 | if len(kwargs.get('annotators', {})) > 0: 168 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 169 | (type(self).__name__, kwargs.get('annotators'))) 170 | self.annotators = set() 171 | 172 | def tokenize(self, text): 173 | data = [] 174 | matches = [m for m in self._regexp.finditer(text)] 175 | for i in range(len(matches)): 176 | # Get text 177 | token = matches[i].group() 178 | 179 | # Get whitespace 180 | span = matches[i].span() 181 | start_ws = span[0] 182 | if i + 1 < len(matches): 183 | end_ws = matches[i + 1].span()[0] 184 | else: 185 | end_ws = span[1] 186 | 187 | # Format data 188 | data.append(( 189 | token, 190 | text[start_ws: end_ws], 191 | span, 192 | )) 193 | return Tokens(data, self.annotators) 194 | 195 | 196 | class SpacyTokenizer(Tokenizer): 197 | 198 | def __init__(self, **kwargs): 199 | """ 200 | Args: 201 | annotators: set that can include pos, lemma, and ner. 202 | model: spaCy model to use (either path, or keyword like 'en'). 203 | """ 204 | model = kwargs.get('model', 'en') 205 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 206 | nlp_kwargs = {'parser': False} 207 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 208 | nlp_kwargs['tagger'] = False 209 | if 'ner' not in self.annotators: 210 | nlp_kwargs['entity'] = False 211 | self.nlp = spacy.load(model, **nlp_kwargs) 212 | 213 | def tokenize(self, text): 214 | # We don't treat new lines as tokens. 
215 | clean_text = text.replace('\n', ' ') 216 | tokens = self.nlp.tokenizer(clean_text) 217 | if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 218 | self.nlp.tagger(tokens) 219 | if 'ner' in self.annotators: 220 | self.nlp.entity(tokens) 221 | 222 | data = [] 223 | for i in range(len(tokens)): 224 | # Get whitespace 225 | start_ws = tokens[i].idx 226 | if i + 1 < len(tokens): 227 | end_ws = tokens[i + 1].idx 228 | else: 229 | end_ws = tokens[i].idx + len(tokens[i].text) 230 | 231 | data.append(( 232 | tokens[i].text, 233 | text[start_ws: end_ws], 234 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 235 | tokens[i].tag_, 236 | tokens[i].lemma_, 237 | tokens[i].ent_type_, 238 | )) 239 | 240 | # Set special option for non-entity tag: '' vs 'O' in spaCy 241 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 242 | -------------------------------------------------------------------------------- /baseline/mDPR/generate_dense_embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line tool that produces embeddings for a large documents base based on the pretrained ctx & question encoders 10 | Supposed to be used in a 'sharded' way to speed up the process. 11 | """ 12 | import os 13 | import pathlib 14 | 15 | import argparse 16 | import csv 17 | import logging 18 | import pickle 19 | from typing import List, Tuple 20 | 21 | import numpy as np 22 | import torch 23 | from torch import nn 24 | 25 | from dpr.models import init_biencoder_components 26 | from dpr.options import add_encoder_params, setup_args_gpu, print_args, set_encoder_params_from_state, \ 27 | add_tokenizer_params, add_cuda_params 28 | from dpr.utils.data_utils import Tensorizer 29 | from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint,move_to_device 30 | 31 | logger = logging.getLogger() 32 | logger.setLevel(logging.INFO) 33 | if (logger.hasHandlers()): 34 | logger.handlers.clear() 35 | console = logging.StreamHandler() 36 | logger.addHandler(console) 37 | 38 | def gen_ctx_vectors(ctx_rows: List[Tuple[object, str, str]], model: nn.Module, tensorizer: Tensorizer, 39 | insert_title: bool = True) -> List[Tuple[object, np.array]]: 40 | n = len(ctx_rows) 41 | bsz = args.batch_size 42 | total = 0 43 | results = [] 44 | for j, batch_start in enumerate(range(0, n, bsz)): 45 | 46 | batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in 47 | ctx_rows[batch_start:batch_start + bsz]] 48 | 49 | ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),args.device) 50 | ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),args.device) 51 | ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),args.device) 52 | with torch.no_grad(): 53 | _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask) 54 | out = out.cpu() 55 | 56 | ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]] 57 | 58 | assert len(ctx_ids) == out.size(0) 59 | 60 | total += len(ctx_ids) 61 | 62 | results.extend([ 63 | (ctx_ids[i], out[i].view(-1).numpy()) 64 | for i in range(out.size(0)) 65 | ]) 66 | 67 | if total % 10 == 0: 68 | logger.info('Encoded passages %d', total) 69 | 70 | return 
results 71 | 72 | 73 | def main(args): 74 | saved_state = load_states_from_checkpoint(args.model_file) 75 | set_encoder_params_from_state(saved_state.encoder_params, args) 76 | print_args(args) 77 | 78 | tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True) 79 | 80 | encoder = encoder.ctx_model 81 | 82 | encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu, 83 | args.local_rank, 84 | args.fp16, 85 | args.fp16_opt_level) 86 | encoder.eval() 87 | 88 | # load weights from the model file 89 | model_to_load = get_model_obj(encoder) 90 | logger.info('Loading saved model state ...') 91 | logger.debug('saved model keys =%s', saved_state.model_dict.keys()) 92 | 93 | prefix_len = len('ctx_model.') 94 | ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if 95 | key.startswith('ctx_model.')} 96 | model_to_load.load_state_dict(ctx_state) 97 | 98 | logger.info('reading data from file=%s', args.ctx_file) 99 | 100 | rows = [] 101 | with open(args.ctx_file) as tsvfile: 102 | reader = csv.reader(tsvfile, delimiter='\t') 103 | # file format: doc_id, doc_text, title 104 | rows.extend([(row[0], row[1], row[2]) for row in reader if row[0] != 'id']) 105 | 106 | shard_size = int(len(rows) / args.num_shards) 107 | start_idx = args.shard_id * shard_size 108 | end_idx = start_idx + shard_size 109 | 110 | logger.info('Producing encodings for passages range: %d to %d (out of total %d)', start_idx, end_idx, len(rows)) 111 | rows = rows[start_idx:end_idx] 112 | 113 | data = gen_ctx_vectors(rows, encoder, tensorizer, True) 114 | 115 | file = args.out_file + '_' + str(args.shard_id) 116 | pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True) 117 | logger.info('Writing results to %s' % file) 118 | with open(file, mode='wb') as f: 119 | pickle.dump(data, f) 120 | 121 | logger.info('Total passages processed %d. Written to %s', len(data), file) 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | 127 | add_encoder_params(parser) 128 | add_tokenizer_params(parser) 129 | add_cuda_params(parser) 130 | 131 | parser.add_argument('--ctx_file', type=str, default=None, help='Path to passages set .tsv file') 132 | parser.add_argument('--out_file', required=True, type=str, default=None, 133 | help='output .tsv file path to write results to ') 134 | parser.add_argument('--shard_id', type=int, default=0, help="Number(0-based) of data shard to process") 135 | parser.add_argument('--num_shards', type=int, default=1, help="Total amount of data shards") 136 | parser.add_argument('--batch_size', type=int, default=32, help="Batch size for the passage encoder forward pass") 137 | args = parser.parse_args() 138 | 139 | assert args.model_file, 'Please specify --model_file checkpoint to init model weights' 140 | 141 | setup_args_gpu(args) 142 | 143 | 144 | main(args) 145 | -------------------------------------------------------------------------------- /baseline/mGEN/README.md: -------------------------------------------------------------------------------- 1 | ## mGEN 2 | This directory contains the code for the mGEN component. The code is originally based on [the transformers' implementation of RAG](https://github.com/huggingface/transformers/tree/v4.2.1/examples/research_projects/rag). 
3 | 4 | ### Installation 5 | Please install the dependencies by running the command below: 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ### Data 12 | To lift the burden of training and running computationally expensive retrieval models, we release the retrieval results (top 50 passages) for the training, dev and test sets (the test-set results will be released together with the official test data). 13 | You can also use the retrieval results of your own retriever(s). 14 | 15 | Here, we first download the mDPR retrieval results for the official training and dev sets and convert the data format. 16 | 17 | #### mDPR retrieval results 18 | 19 | - Training data 20 | ``` 21 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_training_dpr_retrieval_results.json 22 | ``` 23 | 24 | - XOR QA development data 25 | 26 | ``` 27 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_xorqa_development_dpr_retrieval_results.json 28 | ``` 29 | 30 | - MKQA development data 31 | The retrieval results for the MKQA subsets are available here: 32 | 33 | ``` 34 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_non_iterative_baselines_mkqa_dev.zip 35 | unzip mia2022_non_iterative_baselines_mkqa_dev.zip 36 | ``` 37 | 38 | #### Data format 39 | 40 | Our fine-tuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files: 41 | 42 | ```bash 43 | train.source 44 | train.target 45 | val.source 46 | val.target 47 | test.source 48 | test.target 49 | ``` 50 | Each line contains one source/target sentence. 51 | 52 | #### Convert mDPR output to mGEN train data format 53 | This script converts the DPR output files into the mGEN training data format. Please set the file names for the train, dev and test data (`--train_fp`, `--dev_fp`, and `--test_fp`) and the output directory name (`--output_dir`). You can choose the number of top DPR-retrieved passages to keep (`--top_n`). 54 | 55 | ``` 56 | python3 convert_dpr_retrieval_results_to_seq2seq.py \ 57 | --train_fp /path/to/dpr/output/train/data.json --dev_fp /path/to/dpr/output/dev/data.json \ 58 | --output_dir /path/to/mgen/data/dir \ 59 | --top_n 15 --add_lang 60 | ``` 61 | 62 | - Augment training data with WikiData 63 | CORA introduces a simple WikiData-based data augmentation approach for languages not covered by the human-annotated training data. 64 | In particular, this approach retrieves Wikipedia entities in many languages corresponding to the original English answers in Natural Questions, and automatically generates cross-lingual mGEN training data by replacing the English answers with their target-language equivalents and appending language tags to the questions. 65 | 66 | Please follow the steps below if you want to try this data augmentation approach. 67 | 68 | 1. Retrieve corresponding Wikipedia entities 69 | ``` 70 | python align_wikidata.py --input_fp /path/to/input/qa/data.jsonl --output_fp /path/to/output/entity/file/name.json --sample_num 10000 71 | ``` 72 | The API can get slow, so for our baselines we sample 10k NQ questions and retrieve the corresponding entities. We obtained entities for about 6k questions. 73 | 74 | 2. 
Augment data with the Wikipedia entity file 75 | 76 | ``` 77 | python convert_dpr_retrieval_results_to_seq2seq.py \ 78 | --train_fp /path/to/dpr/output/train/data.json \ 79 | --dev_fp /path/to/dpr/output/dev/data.json \ 80 | --output_dir /path/to/mgen/data/dir \ 81 | --top_n 15 --add_lang \ 82 | --ent_fp /path/to/output/entity/file/name.json 83 | ``` 84 | 85 | You can download the processed mGEN training data: 86 | ``` 87 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia_shared_task_mgen_nq_added.zip 88 | ``` 89 | ### Training 90 | Please specify the `model_type`, `model_name_or_path` and `gpus` (the number of GPUs to be used during fine-tuning). 91 | 92 | - Train an `mt5-base` based model 93 | 94 | ```sh 95 | python finetune_mgen.py \ 96 | --data_dir /path/to/your/data/dir \ 97 | --output_dir /path/to/output/dir \ 98 | --model_name_or_path /path/to/previous_best_checkpoint \ 99 | --model_type mt5 --gpus 8 \ 100 | --do_train \ 101 | --do_predict \ 102 | --train_batch_size 4 \ 103 | --eval_batch_size 1 \ 104 | --max_source_length 1000 \ 105 | --max_target_length 20 \ 106 | --val_max_target_length 25 \ 107 | --test_max_target_length 25 \ 108 | --label_smoothing 0.1 \ 109 | --dropout 0.1 \ 110 | --num_train_epochs 50 \ 111 | --warmup_steps 500 \ 112 | --learning_rate 3e-05 \ 113 | --weight_decay 0.001 \ 114 | --adam_epsilon 1e-08 \ 115 | --max_grad_norm 0.1 116 | ``` 117 | 118 | - Train an `mt5-large` based model. We train our mGEN on 8 GPUs with 24GB memory each, and found that we could not train the model even with `train_batch_size==1` when using the Adam optimizer. To fine-tune an mt5-large based model, you have to set the `--adafactor` option. 119 | 120 | ```sh 121 | python finetune_mgen.py \ 122 | --data_dir /path/to/your/data/dir \ 123 | --output_dir /path/to/model/output/dir \ 124 | --model_name_or_path /path/to/previous_best_checkpoint \ 125 | --model_type mt5 --gpus 8 \ 126 | --do_train \ 127 | --do_predict \ 128 | --train_batch_size 1 \ 129 | --eval_batch_size 1 \ 130 | --max_source_length 800 \ 131 | --max_target_length 20 \ 132 | --val_max_target_length 25 \ 133 | --test_max_target_length 25 \ 134 | --label_smoothing 0.1 \ 135 | --dropout 0.1 \ 136 | --num_train_epochs 50 \ 137 | --warmup_steps 500 \ 138 | --learning_rate 3e-05 \ 139 | --weight_decay 0.001 \ 140 | --adam_epsilon 1e-08 \ 141 | --max_grad_norm 0.1 \ 142 | --adafactor 143 | ``` 144 | 145 | ### Evaluation 146 | 147 | 1. Run mDPR 148 | To evaluate your trained mGEN model, you first need to retrieve passages using mDPR. Please follow the instructions in the [mDPR](../mDPR) directory. 149 | 150 | 2. Convert DPR output 151 | Please convert the DPR output file as described above. 152 | 153 | 3. Run mGEN 154 | Please run the mGEN evaluation by running [`eval_mgen.py`](eval_mgen.py). 
155 | 156 | ``` 157 | CUDA_VISIBLE_DEVICES=0 python eval_mgen.py \ 158 | --model_name_or_path /path/to/model/output/dir \ 159 | --evaluation_set /path/to/your/data/dir/val.source \ 160 | --gold_data_path /path/to/your/data/dir/gold_para_qa_data_dev.tsv \ 161 | --predictions_path mgen_output.txt \ 162 | --gold_data_mode qa \ 163 | --model_type mt5 \ 164 | --max_length 20 \ 165 | --eval_batch_size 8 166 | ``` 167 | 168 | 169 | -------------------------------------------------------------------------------- /baseline/mGEN/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | sys.path.insert(1, os.path.dirname(os.path.realpath(__file__))) 6 | -------------------------------------------------------------------------------- /baseline/mGEN/align_wikidata.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | from collections import Counter 3 | from tqdm import tqdm 4 | import wptools 5 | import json 6 | import argparse 7 | import random 8 | 9 | 10 | def read_jsonlines(eval_file_name): 11 | lines = [] 12 | print("loading examples from {0}".format(eval_file_name)) 13 | with jsonlines.open(eval_file_name) as reader: 14 | for obj in reader: 15 | lines.append(obj) 16 | return lines 17 | 18 | 19 | def wikidata_alignment(answer): 20 | page = wptools.page(answer) 21 | answer_dict = {} 22 | try: 23 | page.get_more() 24 | for item in page.data["languages"]: 25 | answer_dict[item["lang"]] = item["title"] 26 | return answer_dict 27 | except: 28 | print("cannot find the answer") 29 | return None 30 | 31 | 32 | def postprocess(answer_string): 33 | if "(" in answer_string: 34 | return answer_string.split("(")[0] 35 | else: 36 | return answer_string 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--input_fp", default=None, type=str) 42 | parser.add_argument("--dpr_data", action="store_true") 43 | parser.add_argument("--output_fp", default=None, type=str) 44 | parser.add_argument("--sample_num", default=None, type=int) 45 | 46 | args = parser.parse_args() 47 | 48 | if args.dpr_data is True: 49 | # read input from DPR format file to align the English gold articles to the corresponding ones in the other languages. 50 | input_data = json.load(open(args.input_fp)) 51 | else: 52 | # read input data in the xor qa format. 53 | input_data = read_jsonlines(args.input_fp) 54 | if args.sample_num is not None: 55 | input_data = random.sample(input_data, k=args.sample_num) 56 | output_data = {} 57 | print("original input data num:{}".format(len(input_data))) 58 | 59 | for idx, item in tqdm(enumerate(input_data)): 60 | if args.dpr_data is True: 61 | answers = [item["positive_ctxs"][0]["title"]] 62 | q_id = idx 63 | else: 64 | answers = item["answers"] 65 | q_id = item["id"] 66 | 67 | # remove this all digit cases? 
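# Each unique, non-numeric answer below is looked up on Wikidata to collect its labels in other languages; purely numeric answers are kept as-is under the "numeric" key.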
68 | for answer in list(set(answers)): 69 | if str(answer).isdigit() == False: 70 | translated_answers = wikidata_alignment(answer) 71 | if translated_answers is not None: 72 | output_data.setdefault(q_id, {}) 73 | for lang, answer in translated_answers.items(): 74 | translated_answer = postprocess(answer) 75 | output_data[q_id][lang] = translated_answer 76 | else: 77 | translated_answer = str(answer) 78 | output_data.setdefault(q_id, {}) 79 | output_data[q_id]["numeric"] = translated_answer 80 | print("found aligned answers for {} questions".format(len(output_data))) 81 | 82 | print("final data num: {}".format(len(output_data))) 83 | with open(args.output_fp, 'w') as outfile: 84 | json.dump(output_data, outfile) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /baseline/mGEN/callbacks_rag.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | from utils_rag import save_json 12 | 13 | def count_trainable_parameters(model): 14 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 15 | params = sum([np.prod(p.size()) for p in model_parameters]) 16 | return params 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def get_checkpoint_callback(output_dir, metric): 23 | """Saves the best model by validation EM score.""" 24 | if metric == "rouge2": 25 | exp = "{val_avg_rouge2:.4f}-{step_count}" 26 | elif metric == "bleu": 27 | exp = "{val_avg_bleu:.4f}-{step_count}" 28 | elif metric == "em": 29 | exp = "{val_avg_em:.4f}-{step_count}" 30 | else: 31 | raise NotImplementedError( 32 | f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." 33 | ) 34 | 35 | checkpoint_callback = ModelCheckpoint( 36 | filepath=os.path.join(output_dir, exp), 37 | monitor=f"val_{metric}", 38 | mode="max", 39 | save_top_k=3, 40 | # maybe save a checkpoint every time val is run, not just end of epoch. 41 | period=1, 42 | ) 43 | return checkpoint_callback 44 | 45 | 46 | def get_early_stopping_callback(metric, patience): 47 | return EarlyStopping( 48 | monitor=f"val_{metric}", # does this need avg? 49 | mode="min" if "loss" in metric else "max", 50 | patience=patience, 51 | verbose=True, 52 | ) 53 | 54 | 55 | class Seq2SeqLoggingCallback(pl.Callback): 56 | def on_batch_end(self, trainer, pl_module): 57 | lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate( 58 | pl_module.trainer.optimizers[0].param_groups)} 59 | pl_module.logger.log_metrics(lrs) 60 | 61 | @rank_zero_only 62 | def _write_logs( 63 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True 64 | ) -> None: 65 | logger.info( 66 | f"***** {type_path} results at step {trainer.global_step:05d} *****") 67 | metrics = trainer.callback_metrics 68 | trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in [ 69 | "log", "progress_bar", "preds"]}) 70 | # Log results 71 | od = Path(pl_module.hparams.output_dir) 72 | if type_path == "test": 73 | results_file = od / "test_results.txt" 74 | generations_file = od / "test_generations.txt" 75 | else: 76 | # this never gets hit. 
I prefer not to save intermediate generations, and results are in metrics.json 77 | # If people want this it will be easy enough to add back. 78 | results_file = od / \ 79 | f"{type_path}_results/{trainer.global_step:05d}.txt" 80 | generations_file = od / \ 81 | f"{type_path}_generations/{trainer.global_step:05d}.txt" 82 | results_file.parent.mkdir(exist_ok=True) 83 | generations_file.parent.mkdir(exist_ok=True) 84 | with open(results_file, "a+") as writer: 85 | for key in sorted(metrics): 86 | if key in ["log", "progress_bar", "preds"]: 87 | continue 88 | val = metrics[key] 89 | if isinstance(val, torch.Tensor): 90 | val = val.item() 91 | msg = f"{key}: {val:.6f}\n" 92 | writer.write(msg) 93 | 94 | if not save_generations: 95 | return 96 | 97 | if "preds" in metrics: 98 | content = "\n".join(metrics["preds"]) 99 | generations_file.open("w+").write(content) 100 | 101 | @rank_zero_only 102 | def on_train_start(self, trainer, pl_module): 103 | try: 104 | npars = pl_module.model.model.num_parameters() 105 | except AttributeError: 106 | npars = pl_module.model.num_parameters() 107 | 108 | n_trainable_pars = count_trainable_parameters(pl_module) 109 | # mp stands for million parameters 110 | trainer.logger.log_metrics( 111 | {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) 112 | 113 | @rank_zero_only 114 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 115 | save_json(pl_module.metrics, pl_module.metrics_save_path) 116 | return self._write_logs(trainer, pl_module, "test") 117 | 118 | @rank_zero_only 119 | def on_validation_end(self, trainer: pl.Trainer, pl_module): 120 | save_json(pl_module.metrics, pl_module.metrics_save_path) 121 | # Uncommenting this will save val generations 122 | # return self._write_logs(trainer, pl_module, "valid") 123 | -------------------------------------------------------------------------------- /baseline/mGEN/convert_dpr_retrieval_results_to_seq2seq.py: -------------------------------------------------------------------------------- 1 | from enum import auto 2 | import json 3 | import random 4 | import argparse 5 | import csv 6 | import os 7 | from tqdm import tqdm 8 | import jsonlines 9 | 10 | target_langs = ['ar', 'bn', 'fi', 'ja', 'ko', 'ru', 'te', 'en', 'es', 'km', 'ms', 'ru', 'sv', 'tr', 'zh_cn'] 11 | 12 | def read_jsonlines(eval_file_name): 13 | lines = [] 14 | print("loading examples from {0}".format(eval_file_name)) 15 | with jsonlines.open(eval_file_name) as reader: 16 | for obj in reader: 17 | lines.append(obj) 18 | return lines 19 | 20 | def load_dpr_results(pred_results, top_n=5, split="train", align_dict=None): 21 | q_c_a = [] 22 | has_answer = 0 23 | auto_nq_count = 0 24 | for item in tqdm(pred_results): 25 | question = item["question"] 26 | answers = item["answers"] 27 | ctxs = item["ctxs"] 28 | lang = item["lang"] 29 | qid = item["q_id"] 30 | for ctx in ctxs: 31 | if ctx["has_answer"] == True: 32 | has_answer += 1 33 | break 34 | if split == "train": 35 | has_answer_context = [] 36 | has_no_answer_context = [] 37 | for ctx in ctxs: 38 | if ctx["has_answer"] is True: 39 | has_answer_context.append(ctx) 40 | else: 41 | has_no_answer_context.append(ctx) 42 | if len(has_answer_context) > 3: 43 | has_answer_context = random.sample(has_answer_context, k=3) 44 | negative_context_num = top_n - len(has_answer_context) 45 | has_no_answer_context = has_no_answer_context[:negative_context_num] 46 | 47 | paragraphs = [item for item in has_answer_context] 48 | paragraphs += [item for item in 
has_no_answer_context] 49 | random.shuffle(paragraphs) 50 | else: 51 | paragraphs = [item for item in ctxs[:top_n]] 52 | 53 | context = "" 54 | for idx, para in enumerate(paragraphs): 55 | if len(context) > 0 and context[-1] != " ": 56 | context += " " 57 | context += "<{0}: {1}> ".format(idx, para["title"]) 58 | context += para["text"] 59 | 60 | 61 | q_c_a.append({"question": question, "answers": answers, 62 | "context": context, "lang": lang}) 63 | if split == 'train' and align_dict is not None and qid in align_dict: 64 | answer_entities = align_dict[qid] 65 | for tgt_lang in target_langs: 66 | if tgt_lang in answer_entities and random.random() > 0.5: 67 | q_c_a.append({"question": question, "answers": [answer_entities[tgt_lang]], 68 | "context": context, "lang": tgt_lang}) 69 | auto_nq_count += 1 70 | print("Generated {0} train data; {1} data includes answer string.".format( 71 | len(q_c_a), has_answer)) 72 | print("added automatically generated data {0}".format(auto_nq_count)) 73 | return q_c_a 74 | 75 | 76 | def main(): 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument("--train_fp", default=None, type=str) 79 | parser.add_argument("--dev_fp", default=None, type=str) 80 | parser.add_argument("--test_fp", default=None, type=str) 81 | parser.add_argument("--ent_fp", default=None, type=str) 82 | parser.add_argument("--output_dir", default=None, type=str) 83 | parser.add_argument("--top_n", default=5, type=int) 84 | parser.add_argument("--add_lang", action="store_true") 85 | 86 | args = parser.parse_args() 87 | 88 | if not os.path.exists(args.output_dir): 89 | os.makedirs(args.output_dir) 90 | 91 | if args.train_fp is not None: 92 | train_data = json.load(open(args.train_fp)) 93 | 94 | if args.dev_fp is not None: 95 | dev_data = json.load(open(args.dev_fp)) 96 | if args.test_fp is not None: 97 | test_data = json.load(open(args.test_fp)) 98 | 99 | if args.train_fp is not None: 100 | if args.ent_fp is not None: 101 | align_dict = json.load(open(args.ent_fp)) 102 | s2s_train = load_dpr_results(train_data, top_n=args.top_n, align_dict=align_dict) 103 | else: 104 | s2s_train = load_dpr_results(train_data, top_n=args.top_n) 105 | source_f_train = open(os.path.join( 106 | args.output_dir, "train.source"), "w") 107 | target_f_train = open(os.path.join( 108 | args.output_dir, "train.target"), "w") 109 | 110 | for item in s2s_train: 111 | if args.add_lang: 112 | source_f_train.write(": {0} [{1}]

:{2}".format( 113 | item["question"], item["lang"], item["context"]).replace("\n", "") + "\n") 114 | else: 115 | source_f_train.write(": {0}

:{1}".format( 116 | item["question"], item["context"]).replace("\n", "") + "\n") 117 | target_f_train.write(item["answers"][0].replace("\n", "") + "\n") 118 | 119 | source_f_train.close() 120 | target_f_train.close() 121 | 122 | if args.dev_fp is not None: 123 | s2s_dev = load_dpr_results(dev_data, top_n=args.top_n, split="dev") 124 | source_f_val = open(os.path.join(args.output_dir, "val.source"), "w") 125 | target_f_val = open(os.path.join(args.output_dir, "val.target"), "w") 126 | 127 | for item in s2s_dev: 128 | if args.add_lang: 129 | if args.top_n == 0: 130 | source_f_val.write(": {0} [{1}]".format( 131 | item["question"], item["lang"]).replace("\n", "") + "\n") 132 | 133 | else: 134 | source_f_val.write(": {0} [{1}]

:{2}".format( 135 | item["question"], item["lang"], item["context"]).replace("\n", "") + "\n") 136 | else: 137 | if args.top_n == 0: 138 | source_f_val.write(": {0}".format( 139 | item["question"]).replace("\n", "") + "\n") 140 | else: 141 | source_f_val.write(": {0}

:{1}".format( 142 | item["question"], item["context"]).replace("\n", "") + "\n") 143 | target_f_val.write(item["answers"][0].replace("\n", "") + "\n") 144 | 145 | source_f_val.close() 146 | target_f_val.close() 147 | 148 | with open(os.path.join(args.output_dir, "gold_para_qa_data_dev.tsv"), "w") as out_file: 149 | tsv_writer = csv.writer(out_file, delimiter='\t') 150 | for item in s2s_dev: 151 | if args.add_lang: 152 | tsv_writer.writerow([": {0} [{1}]

:{2}".format( 153 | item["question"], item["lang"], item["context"]), item["answers"]]) 154 | else: 155 | tsv_writer.writerow([": {0}

:{1}".format( 156 | item["question"], item["context"]), item["answers"]]) 157 | 158 | if args.test_fp is not None: 159 | s2s_test = load_dpr_results(test_data, top_n=args.top_n, split="test") 160 | source_f_test = open(os.path.join(args.output_dir, "test.source"), "w") 161 | 162 | for item in s2s_test: 163 | source_f_test.write(": {0}

:{1}".format( 164 | item["question"], item["context"]).replace("\n", "") + "\n") 165 | 166 | source_f_test.close() 167 | 168 | with open(os.path.join(args.output_dir, "gold_para_qa_data_test.tsv"), "w") as out_file: 169 | tsv_writer = csv.writer(out_file, delimiter='\t') 170 | for item in s2s_test: 171 | tsv_writer.writerow([": {0}

:{1}".format( 172 | item["question"], item["context"])]) 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /baseline/mGEN/eval_mgen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import logging 4 | import os 5 | import sys 6 | 7 | import pandas as pd 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from transformers import BartForConditionalGeneration, MT5ForConditionalGeneration, AutoTokenizer 12 | from transformers import logging as transformers_logging 13 | 14 | 15 | sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip 16 | from utils import exact_match_score, f1_score # noqa: E402 # isort:skip 17 | from utils import metric_max_over_ground_truths, get_scores, get_precision_at_k # noqa: E402 # isort:skip 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | logging.basicConfig(level=logging.INFO) 22 | 23 | transformers_logging.set_verbosity_info() 24 | 25 | 26 | def evaluate_batch_e2e(args, model, tokenizer, questions): 27 | with torch.no_grad(): 28 | inputs_dict = tokenizer.batch_encode_plus( 29 | questions, return_tensors="pt", padding=True, truncation=True 30 | ) 31 | 32 | input_ids = inputs_dict.input_ids.to(args.device) 33 | attention_mask = inputs_dict.attention_mask.to(args.device) 34 | outputs = model.generate( 35 | input_ids, 36 | attention_mask=attention_mask, 37 | num_beams=args.num_beams, 38 | min_length=args.min_length, 39 | max_length=args.max_length, 40 | early_stopping=False, 41 | num_return_sequences=1, 42 | # BART likes to repeat BOS tokens, dont allow it to generate more than one, 43 | bad_words_ids=[[0, 0]], 44 | output_scores=args.output_scores, 45 | return_dict_in_generate=args.output_scores 46 | ) 47 | if args.output_scores is True: 48 | sequences_scores = outputs["sequences_scores"] 49 | answers = tokenizer.batch_decode( 50 | outputs["sequences"], skip_special_tokens=True) 51 | 52 | if args.print_predictions: 53 | for q, a in zip(questions, answers): 54 | logger.info("Q: {} - A: {}".format(q, a)) 55 | 56 | return answers, sequences_scores 57 | else: 58 | answers = tokenizer.batch_decode(outputs, skip_special_tokens=True) 59 | if args.print_predictions: 60 | for q, a in zip(questions, answers): 61 | logger.info("Q: {} - A: {}".format(q, a)) 62 | return answers 63 | 64 | 65 | def get_args(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument( 68 | "--model_type", 69 | choices=["mt5", "bart"], 70 | type=str, 71 | help="model type", 72 | ) 73 | parser.add_argument( 74 | "--model_name_or_path", 75 | default=None, 76 | type=str, 77 | required=True, 78 | help="Path to pretrained checkpoints or model identifier from huggingface.co/models", 79 | ) 80 | parser.add_argument("--k", default=1, type=int, 81 | help="k for the precision@k calculation") 82 | parser.add_argument( 83 | "--evaluation_set", 84 | default=None, 85 | type=str, 86 | required=True, 87 | help="Path to a file containing evaluation samples", 88 | ) 89 | parser.add_argument( 90 | "--gold_data_path", 91 | default=None, 92 | type=str, 93 | required=True, 94 | help="Path to a tab-separated file with gold samples", 95 | ) 96 | parser.add_argument( 97 | "--gold_data_mode", 98 | default="qa", 99 | type=str, 100 | choices=["qa", "ans"], 101 | help="Format of the gold data file" 102 | "qa - a single line in the following format: question [tab] answer_list" 103 | "ans - a single line of the gold file contains the 
expected answer string", 104 | ) 105 | parser.add_argument( 106 | "--predictions_path", 107 | type=str, 108 | default="predictions.txt", 109 | help="Name of the predictions file, to be stored in the checkpoints directory", 110 | ) 111 | parser.add_argument( 112 | "--eval_all_checkpoints", 113 | action="store_true", 114 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 115 | ) 116 | parser.add_argument( 117 | "--eval_batch_size", 118 | default=8, 119 | type=int, 120 | help="Batch size per GPU/CPU for evaluation.", 121 | ) 122 | parser.add_argument( 123 | "--recalculate", 124 | help="Recalculate predictions even if the prediction file exists", 125 | action="store_true", 126 | ) 127 | parser.add_argument( 128 | "--num_beams", 129 | default=4, 130 | type=int, 131 | help="Number of beams to be used when generating answers", 132 | ) 133 | parser.add_argument("--min_length", default=1, type=int, 134 | help="Min length of the generated answers") 135 | parser.add_argument("--max_length", default=50, type=int, 136 | help="Max length of the generated answers") 137 | 138 | parser.add_argument( 139 | "--print_predictions", 140 | action="store_true", 141 | help="If True, prints predictions while evaluating.", 142 | ) 143 | parser.add_argument( 144 | "--print_docs", 145 | action="store_true", 146 | help="If True, prints docs retried while generating.", 147 | ) 148 | parser.add_argument( 149 | "--output_scores", 150 | action="store_true", 151 | help="If True, output the prediction scores", 152 | ) 153 | args = parser.parse_args() 154 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 155 | return args 156 | 157 | 158 | def main(args): 159 | model_kwargs = {} 160 | if args.model_type == "bart": 161 | model_class = BartForConditionalGeneration 162 | elif args.model_type == "mt5": 163 | model_class = MT5ForConditionalGeneration 164 | else: 165 | raise NotImplementedError 166 | 167 | checkpoints = ( 168 | [f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()] 169 | if args.eval_all_checkpoints 170 | else [args.model_name_or_path] 171 | ) 172 | 173 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 174 | 175 | score_fn = get_scores 176 | evaluate_batch_fn = evaluate_batch_e2e 177 | 178 | for checkpoint in checkpoints: 179 | if os.path.exists(args.predictions_path) and (not args.recalculate): 180 | logger.info("Calculating metrics based on an existing predictions file: {}".format( 181 | args.predictions_path)) 182 | score_fn(args, args.predictions_path, args.gold_data_path) 183 | continue 184 | 185 | logger.info("***** Running evaluation for {} *****".format(checkpoint)) 186 | logger.info(" Batch size = %d", args.eval_batch_size) 187 | logger.info(" Predictions will be stored under {}".format( 188 | args.predictions_path)) 189 | 190 | model = model_class.from_pretrained(checkpoint, **model_kwargs) 191 | tokenizer = AutoTokenizer.from_pretrained( 192 | checkpoint, local_files_only=True) 193 | model.to(args.device) 194 | 195 | with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file, open("{}_score".format(args.predictions_path), "w") as preds_file_score: 196 | questions = [] 197 | for line in tqdm(eval_file): 198 | questions.append(line.strip()) 199 | if len(questions) == args.eval_batch_size: 200 | if args.output_scores is True: 201 | answers, scores = evaluate_batch_fn( 202 | args, model, tokenizer, questions) 203 | print(scores) 204 | for score in 
list(scores): 205 | preds_file_score.write(str(float(score))) 206 | preds_file_score.write("\n") 207 | preds_file_score.flush() 208 | else: 209 | answers = evaluate_batch_fn( 210 | args, model, tokenizer, questions) 211 | 212 | preds_file.write("\n".join(answers) + "\n") 213 | preds_file.flush() 214 | 215 | questions = [] 216 | if len(questions) > 0: 217 | if args.output_scores is True: 218 | answers, scores = evaluate_batch_fn( 219 | args, model, tokenizer, questions) 220 | for score in list(scores): 221 | preds_file_score.write(str(float(score))) 222 | preds_file_score.write("\n") 223 | preds_file_score.flush() 224 | else: 225 | answers = evaluate_batch_fn( 226 | args, model, tokenizer, questions) 227 | preds_file.write("\n".join(answers)) 228 | preds_file.flush() 229 | 230 | score_fn(args, args.predictions_path, args.gold_data_path) 231 | 232 | 233 | if __name__ == "__main__": 234 | args = get_args() 235 | main(args) 236 | -------------------------------------------------------------------------------- /baseline/mGEN/requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-cpu >= 1.6.3 2 | datasets >= 1.0.1 3 | psutil >= 5.7.0 4 | torch >= 1.4.0 5 | transformers==4.2.1 6 | pytorch-lightning==1.0.4 7 | -------------------------------------------------------------------------------- /baseline/mGEN/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import os 5 | import pickle 6 | import re 7 | import socket 8 | import string 9 | from collections import Counter 10 | from logging import getLogger 11 | from pathlib import Path 12 | from typing import Callable, Dict, Iterable, List 13 | import pandas as pd 14 | import ast 15 | 16 | import git 17 | import torch 18 | from torch.utils.data import Dataset 19 | 20 | from transformers import BartTokenizer, T5Tokenizer 21 | 22 | 23 | def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 24 | extra_kw = {"add_prefix_space": True} if isinstance( 25 | tokenizer, BartTokenizer) and not line.startswith(" ") else {} 26 | tokenizer.padding_side = padding_side 27 | return tokenizer( 28 | [line], 29 | max_length=max_length, 30 | padding="max_length" if pad_to_max_length else None, 31 | truncation=True, 32 | return_tensors=return_tensors, 33 | add_special_tokens=True, 34 | **extra_kw, 35 | ) 36 | 37 | 38 | def trim_batch( 39 | input_ids, 40 | pad_token_id, 41 | attention_mask=None, 42 | ): 43 | """Remove columns that are populated exclusively by pad_token_id""" 44 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 45 | if attention_mask is None: 46 | return input_ids[:, keep_column_mask] 47 | else: 48 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 49 | 50 | 51 | class Seq2SeqDataset(Dataset): 52 | def __init__( 53 | self, 54 | tokenizer, 55 | data_dir, 56 | max_source_length, 57 | max_target_length, 58 | type_path="train", 59 | n_obs=None, 60 | src_lang=None, 61 | tgt_lang=None, 62 | prefix="", 63 | ): 64 | super().__init__() 65 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 66 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 67 | self.src_lens = self.get_char_lens(self.src_file) 68 | self.max_source_length = max_source_length 69 | self.max_target_length = max_target_length 70 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 71 | self.tokenizer = tokenizer 72 | 
self.prefix = prefix 73 | if n_obs is not None: 74 | self.src_lens = self.src_lens[:n_obs] 75 | self.src_lang = src_lang 76 | self.tgt_lang = tgt_lang 77 | 78 | def __len__(self): 79 | return len(self.src_lens) 80 | 81 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 82 | index = index + 1 # linecache starts at 1 83 | source_line = self.prefix + \ 84 | linecache.getline(str(self.src_file), index).rstrip("\n") 85 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 86 | assert source_line, f"empty source line for index {index}" 87 | assert tgt_line, f"empty tgt line for index {index}" 88 | 89 | # Need to add eos token manually for T5 90 | if isinstance(self.tokenizer, T5Tokenizer): 91 | source_line += self.tokenizer.eos_token 92 | tgt_line += self.tokenizer.eos_token 93 | 94 | # Pad source and target to the right 95 | source_tokenizer = self.tokenizer 96 | target_tokenizer = self.tokenizer 97 | 98 | if self.src_lang is not None: 99 | source_tokenizer.set_src_lang_special_tokens(self.src_lang) 100 | source_inputs = encode_line( 101 | source_tokenizer, source_line, self.max_source_length, "right") 102 | if self.tgt_lang is not None: 103 | target_tokenizer.set_tgt_lang_special_tokens(self.tgt_lang) 104 | target_inputs = encode_line( 105 | target_tokenizer, tgt_line, self.max_target_length, "right") 106 | 107 | source_ids = source_inputs["input_ids"].squeeze() 108 | target_ids = target_inputs["input_ids"].squeeze() 109 | src_mask = source_inputs["attention_mask"].squeeze() 110 | 111 | if self.src_lang is not None: 112 | source_tokenizer.set_src_lang_special_tokens(self.src_lang) 113 | 114 | return { 115 | "input_ids": source_ids, 116 | "attention_mask": src_mask, 117 | "decoder_input_ids": target_ids, 118 | } 119 | 120 | @staticmethod 121 | def get_char_lens(data_file): 122 | return [len(x) for x in Path(data_file).open().readlines()] 123 | 124 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 125 | input_ids = torch.stack([x["input_ids"] for x in batch]) 126 | masks = torch.stack([x["attention_mask"] for x in batch]) 127 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 128 | tgt_pad_token_id = self.tokenizer.pad_token_id 129 | src_pad_token_id = self.tokenizer.pad_token_id 130 | y = trim_batch(target_ids, tgt_pad_token_id) 131 | source_ids, source_mask = trim_batch( 132 | input_ids, src_pad_token_id, attention_mask=masks) 133 | batch = { 134 | "input_ids": source_ids, 135 | "attention_mask": source_mask, 136 | "decoder_input_ids": y, 137 | } 138 | return batch 139 | 140 | 141 | logger = getLogger(__name__) 142 | 143 | 144 | def flatten_list(summary_ids: List[List]): 145 | return [x for x in itertools.chain.from_iterable(summary_ids)] 146 | 147 | 148 | def save_git_info(folder_path: str) -> None: 149 | """Save git information to output_dir/git_log.json""" 150 | repo_infos = get_git_info() 151 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 152 | 153 | 154 | def save_json(content, path, indent=4, **json_dump_kwargs): 155 | with open(path, "w") as f: 156 | json.dump(content, f, indent=indent, **json_dump_kwargs) 157 | 158 | 159 | def load_json(path): 160 | with open(path) as f: 161 | return json.load(f) 162 | 163 | 164 | def get_git_info(): 165 | repo = git.Repo(search_parent_directories=True) 166 | repo_infos = { 167 | "repo_id": str(repo), 168 | "repo_sha": str(repo.head.object.hexsha), 169 | "repo_branch": str(repo.active_branch), 170 | "hostname": str(socket.gethostname()), 171 | } 172 | return repo_infos 173 | 174 | 
175 | def lmap(f: Callable, x: Iterable) -> List: 176 | """list(map(f, x))""" 177 | return list(map(f, x)) 178 | 179 | 180 | def pickle_save(obj, path): 181 | """pickle.dump(obj, path)""" 182 | with open(path, "wb") as f: 183 | return pickle.dump(obj, f) 184 | 185 | 186 | def normalize_answer(s): 187 | """Lower text and remove punctuation, articles and extra whitespace.""" 188 | 189 | def remove_articles(text): 190 | return re.sub(r"\b(a|an|the)\b", " ", text) 191 | 192 | def white_space_fix(text): 193 | return " ".join(text.split()) 194 | 195 | def remove_punc(text): 196 | exclude = set(string.punctuation) 197 | return "".join(ch for ch in text if ch not in exclude) 198 | 199 | def lower(text): 200 | return text.lower() 201 | 202 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 203 | 204 | 205 | def f1_score(prediction, ground_truth): 206 | prediction_tokens = normalize_answer(prediction).split() 207 | ground_truth_tokens = normalize_answer(ground_truth).split() 208 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 209 | num_same = sum(common.values()) 210 | if num_same == 0: 211 | return 0 212 | precision = 1.0 * num_same / len(prediction_tokens) 213 | recall = 1.0 * num_same / len(ground_truth_tokens) 214 | f1 = (2 * precision * recall) / (precision + recall) 215 | return f1 216 | 217 | 218 | def exact_match_score(prediction, ground_truth): 219 | return normalize_answer(prediction) == normalize_answer(ground_truth) 220 | 221 | 222 | def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict: 223 | assert len(output_lns) == len(reference_lns) 224 | em = 0 225 | for hypo, pred in zip(output_lns, reference_lns): 226 | em += exact_match_score(hypo, pred) 227 | if len(output_lns) > 0: 228 | em /= len(output_lns) 229 | return {"em": em} 230 | 231 | 232 | def is_rag_model(model_prefix): 233 | return model_prefix.startswith("rag") 234 | 235 | 236 | def set_extra_model_params(extra_params, hparams, config): 237 | equivalent_param = {p: p for p in extra_params} 238 | # T5 models don't have `dropout` param, they have `dropout_rate` instead 239 | equivalent_param["dropout"] = "dropout_rate" 240 | for p in extra_params: 241 | if getattr(hparams, p, None): 242 | if not hasattr(config, p) and not hasattr(config, equivalent_param[p]): 243 | logger.info("config doesn't have a `{}` attribute".format(p)) 244 | delattr(hparams, p) 245 | continue 246 | set_p = p if hasattr(config, p) else equivalent_param[p] 247 | setattr(config, set_p, getattr(hparams, p)) 248 | delattr(hparams, p) 249 | return hparams, config 250 | 251 | 252 | def infer_model_type(model_name_or_path): 253 | if "token" in model_name_or_path: 254 | return "rag_token" 255 | if "sequence" in model_name_or_path: 256 | return "rag_sequence" 257 | if "bart" in model_name_or_path: 258 | return "bart" 259 | return None 260 | 261 | 262 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 263 | return max(metric_fn(prediction, gt) for gt in ground_truths) 264 | 265 | 266 | def get_scores(args, preds_path, gold_data_path): 267 | hypos = [line.strip() for line in open(preds_path, "r").readlines()] 268 | answers = [] 269 | 270 | if args.gold_data_mode == "qa": 271 | data = pd.read_csv(gold_data_path, sep="\t", header=None) 272 | for answer_list in data[1]: 273 | ground_truths = ast.literal_eval(answer_list) 274 | answers.append(ground_truths) 275 | else: 276 | references = [line.strip() 277 | for line in open(gold_data_path, "r").readlines()] 278 | answers = [[reference] 
for reference in references] 279 | 280 | f1 = em = total = 0 281 | for prediction, ground_truths in zip(hypos, answers): 282 | total += 1 283 | em += metric_max_over_ground_truths(exact_match_score, 284 | prediction, ground_truths) 285 | f1 += metric_max_over_ground_truths(f1_score, 286 | prediction, ground_truths) 287 | 288 | em = 100.0 * em / total 289 | f1 = 100.0 * f1 / total 290 | 291 | logger.info(f"F1: {f1:.2f}") 292 | logger.info(f"EM: {em:.2f}") 293 | 294 | 295 | def get_precision_at_k(args, preds_path, gold_data_path): 296 | k = args.k 297 | hypos = [line.strip() for line in open(preds_path, "r").readlines()] 298 | references = [line.strip() 299 | for line in open(gold_data_path, "r").readlines()] 300 | 301 | em = total = 0 302 | for hypo, reference in zip(hypos, references): 303 | hypo_provenance = set(hypo.split("\t")[:k]) 304 | ref_provenance = set(reference.split("\t")) 305 | total += 1 306 | em += len(hypo_provenance & ref_provenance) / k 307 | 308 | em = 100.0 * em / total 309 | logger.info(f"Precision@{k}: {em: .2f}") 310 | -------------------------------------------------------------------------------- /baseline/mGEN/utils_rag.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import os 5 | import pickle 6 | import re 7 | import socket 8 | import string 9 | from collections import Counter 10 | from logging import getLogger 11 | from pathlib import Path 12 | from typing import Callable, Dict, Iterable, List 13 | 14 | import git 15 | import torch 16 | from torch.utils.data import Dataset 17 | 18 | from transformers import BartTokenizer, RagTokenizer, T5Tokenizer 19 | 20 | 21 | def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 22 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} 23 | tokenizer.padding_side = padding_side 24 | return tokenizer( 25 | [line], 26 | max_length=max_length, 27 | padding="max_length" if pad_to_max_length else None, 28 | truncation=True, 29 | return_tensors=return_tensors, 30 | add_special_tokens=True, 31 | **extra_kw, 32 | ) 33 | 34 | 35 | def trim_batch( 36 | input_ids, 37 | pad_token_id, 38 | attention_mask=None, 39 | ): 40 | """Remove columns that are populated exclusively by pad_token_id""" 41 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 42 | if attention_mask is None: 43 | return input_ids[:, keep_column_mask] 44 | else: 45 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 46 | 47 | 48 | class Seq2SeqDataset(Dataset): 49 | def __init__( 50 | self, 51 | tokenizer, 52 | data_dir, 53 | max_source_length, 54 | max_target_length, 55 | type_path="train", 56 | n_obs=None, 57 | src_lang=None, 58 | tgt_lang=None, 59 | prefix="", 60 | ): 61 | super().__init__() 62 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 63 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 64 | self.src_lens = self.get_char_lens(self.src_file) 65 | self.max_source_length = max_source_length 66 | self.max_target_length = max_target_length 67 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 68 | self.tokenizer = tokenizer 69 | self.prefix = prefix 70 | if n_obs is not None: 71 | self.src_lens = self.src_lens[:n_obs] 72 | self.src_lang = src_lang 73 | self.tgt_lang = tgt_lang 74 | 75 | def __len__(self): 76 | return len(self.src_lens) 77 | 78 | def __getitem__(self, index) 
-> Dict[str, torch.Tensor]: 79 | index = index + 1 # linecache starts at 1 80 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 81 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 82 | assert source_line, f"empty source line for index {index}" 83 | assert tgt_line, f"empty tgt line for index {index}" 84 | 85 | # Need to add eos token manually for T5 86 | if isinstance(self.tokenizer, T5Tokenizer): 87 | source_line += self.tokenizer.eos_token 88 | tgt_line += self.tokenizer.eos_token 89 | 90 | # Pad source and target to the right 91 | source_tokenizer = ( 92 | self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer 93 | ) 94 | target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer 95 | 96 | source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right") 97 | target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right") 98 | 99 | source_ids = source_inputs["input_ids"].squeeze() 100 | target_ids = target_inputs["input_ids"].squeeze() 101 | src_mask = source_inputs["attention_mask"].squeeze() 102 | return { 103 | "input_ids": source_ids, 104 | "attention_mask": src_mask, 105 | "decoder_input_ids": target_ids, 106 | } 107 | 108 | @staticmethod 109 | def get_char_lens(data_file): 110 | return [len(x) for x in Path(data_file).open().readlines()] 111 | 112 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 113 | input_ids = torch.stack([x["input_ids"] for x in batch]) 114 | masks = torch.stack([x["attention_mask"] for x in batch]) 115 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 116 | tgt_pad_token_id = ( 117 | self.tokenizer.generator.pad_token_id 118 | if isinstance(self.tokenizer, RagTokenizer) 119 | else self.tokenizer.pad_token_id 120 | ) 121 | src_pad_token_id = ( 122 | self.tokenizer.question_encoder.pad_token_id 123 | if isinstance(self.tokenizer, RagTokenizer) 124 | else self.tokenizer.pad_token_id 125 | ) 126 | y = trim_batch(target_ids, tgt_pad_token_id) 127 | source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks) 128 | batch = { 129 | "input_ids": source_ids, 130 | "attention_mask": source_mask, 131 | "decoder_input_ids": y, 132 | } 133 | return batch 134 | 135 | 136 | logger = getLogger(__name__) 137 | 138 | 139 | def flatten_list(summary_ids: List[List]): 140 | return [x for x in itertools.chain.from_iterable(summary_ids)] 141 | 142 | 143 | def save_git_info(folder_path: str) -> None: 144 | """Save git information to output_dir/git_log.json""" 145 | repo_infos = get_git_info() 146 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 147 | 148 | 149 | def save_json(content, path, indent=4, **json_dump_kwargs): 150 | with open(path, "w") as f: 151 | json.dump(content, f, indent=indent, **json_dump_kwargs) 152 | 153 | 154 | def load_json(path): 155 | with open(path) as f: 156 | return json.load(f) 157 | 158 | 159 | def get_git_info(): 160 | repo = git.Repo(search_parent_directories=True) 161 | repo_infos = { 162 | "repo_id": str(repo), 163 | "repo_sha": str(repo.head.object.hexsha), 164 | "repo_branch": str(repo.active_branch), 165 | "hostname": str(socket.gethostname()), 166 | } 167 | return repo_infos 168 | 169 | 170 | def lmap(f: Callable, x: Iterable) -> List: 171 | """list(map(f, x))""" 172 | return list(map(f, x)) 173 | 174 | 175 | def pickle_save(obj, path): 176 | """pickle.dump(obj, path)""" 177 | 
with open(path, "wb") as f: 178 | return pickle.dump(obj, f) 179 | 180 | 181 | def normalize_answer(s): 182 | """Lower text and remove punctuation, articles and extra whitespace.""" 183 | 184 | def remove_articles(text): 185 | return re.sub(r"\b(a|an|the)\b", " ", text) 186 | 187 | def white_space_fix(text): 188 | return " ".join(text.split()) 189 | 190 | def remove_punc(text): 191 | exclude = set(string.punctuation) 192 | return "".join(ch for ch in text if ch not in exclude) 193 | 194 | def lower(text): 195 | return text.lower() 196 | 197 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 198 | 199 | 200 | def f1_score(prediction, ground_truth): 201 | prediction_tokens = normalize_answer(prediction).split() 202 | ground_truth_tokens = normalize_answer(ground_truth).split() 203 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 204 | num_same = sum(common.values()) 205 | if num_same == 0: 206 | return 0 207 | precision = 1.0 * num_same / len(prediction_tokens) 208 | recall = 1.0 * num_same / len(ground_truth_tokens) 209 | f1 = (2 * precision * recall) / (precision + recall) 210 | return f1 211 | 212 | 213 | def exact_match_score(prediction, ground_truth): 214 | return normalize_answer(prediction) == normalize_answer(ground_truth) 215 | 216 | 217 | def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict: 218 | assert len(output_lns) == len(reference_lns) 219 | em = 0 220 | for hypo, pred in zip(output_lns, reference_lns): 221 | em += exact_match_score(hypo, pred) 222 | if len(output_lns) > 0: 223 | em /= len(output_lns) 224 | return {"em": em} 225 | 226 | 227 | def is_rag_model(model_prefix): 228 | return model_prefix.startswith("rag") 229 | 230 | 231 | def set_extra_model_params(extra_params, hparams, config): 232 | equivalent_param = {p: p for p in extra_params} 233 | # T5 models don't have `dropout` param, they have `dropout_rate` instead 234 | equivalent_param["dropout"] = "dropout_rate" 235 | for p in extra_params: 236 | if getattr(hparams, p, None): 237 | if not hasattr(config, p) and not hasattr(config, equivalent_param[p]): 238 | logger.info("config doesn't have a `{}` attribute".format(p)) 239 | delattr(hparams, p) 240 | continue 241 | set_p = p if hasattr(config, p) else equivalent_param[p] 242 | setattr(config, set_p, getattr(hparams, p)) 243 | delattr(hparams, p) 244 | return hparams, config -------------------------------------------------------------------------------- /baseline/run_evaluation.sh: -------------------------------------------------------------------------------- 1 | # download models 2 | mkdir models 3 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_all_langs_w100.tsv 4 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mgen_mia_train_data_non_iterative_augmented.zip 5 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mDPR_mia_train_data_non_iterative_biencoder_best.cpt 6 | unzip mgen_mia_train_data_non_iterative_augmented.zip 7 | mkdir embeddings 8 | cd embeddings 9 | for i in 0 1 2 3; 10 | do 11 | wget https://nlp.cs.washington.edu/xorqa/cora/models/embeddings_baseline1/wiki_emb_$i 12 | done 13 | for i in 0 1 2 3; 14 | do 15 | wget https://nlp.cs.washington.edu/xorqa/cora/models/embeddings_baseline1/wiki_emb_others_$i 16 | done 17 | cd ../.. 
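# Note: the mDPR step below reads the embedding shards via the --encoded_ctx_file glob ("../models/embeddings_baseline1/wiki_*"); if the shards downloaded above ended up in a different directory (e.g. ./embeddings/), move them or adjust the glob accordingly.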
18 | 19 | # Run mDPR 20 | pip install transformers==3.0.2 21 | cd mDPR 22 | python dense_retriever.py \ 23 | --model_file ../models/mDPR_mia_train_data_non_iterative_biencoder_best.cpt \ 24 | --ctx_file ../models/mia2022_shared_task_all_langs_w100.tsv \ 25 | --qa_file ../data/eval/mia_2022_dev_xorqa.jsonl \ 26 | --encoded_ctx_file "../models/embeddings_baseline1/wiki_*" \ 27 | --out_file xor_dev_dpr_retrieval_results.json \ 28 | --n-docs 20 --validation_workers 1 --batch_size 256 29 | cd .. 30 | 31 | # Convert data 32 | cd mGEN 33 | python3 convert_dpr_retrieval_results_to_seq2seq.py \ 34 | --dev_fp ../mDPR/xor_dev_dpr_retrieval_results.json \ 35 | --output_dir xorqa_dev_final_retriever_results \ 36 | --top_n 15 --add_lang 37 | 38 | # Run mGEN 39 | pip install transformers==4.2.1 40 | CUDA_VISIBLE_DEVICES=0 python eval_mgen.py \ 41 | --model_name_or_path mgen_mia_train_data_non_iterative_augmented \ 42 | --evaluation_set xorqa_dev_final_retriever_results/val.source \ 43 | --gold_data_path xorqa_dev_final_retriever_results/gold_para_qa_data_dev.tsv \ 44 | --predictions_path xor_dev_final_results.txt \ 45 | --gold_data_mode qa \ 46 | --model_type mt5 \ 47 | --max_length 20 \ 48 | --eval_batch_size 4 49 | cd .. 50 | 51 | # Run evaluation 52 | cd eval_scripts 53 | python eval_xor_full.py --data_file ../data/eval/mia_2022_dev_xorqa.jsonl --pred_file ../mGEN/xor_dev_final_results.txt --txt_file 54 | -------------------------------------------------------------------------------- /baseline/run_evaluation_cora.sh: -------------------------------------------------------------------------------- 1 | # download models 2 | mkdir models 3 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_all_langs_w100.tsv 4 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mGEN_model.zip 5 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mDPR_biencoder_best.cpt 6 | unzip mGEN_model.zip 7 | mkdir embeddings 8 | cd embeddings 9 | for i in 0 1 2 3 4 5 6 7; 10 | do 11 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_emb_en_$i 12 | done 13 | for i in 0 1 2 3 4 5 6 7; 14 | do 15 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_emb_xor_$i 16 | done 17 | for i in 0 1 2 3 4 5 6 7; 18 | do 19 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_others_emb__$i 20 | done 21 | for i in 0 1 2 3 4 5 6 7; 22 | do 23 | wget https://nlp.cs.washington.edu/xorqa/cora/models/mia2022_shared_task_embeddings/embeddings/wiki_others_emb_ms_tr_km_$i 24 | done 25 | cd ../.. 26 | 27 | # Run mDPR 28 | pip install transformers==3.0.2 29 | cd mDPR 30 | python dense_retriever.py \ 31 | --model_file ../models/mDPR_biencoder_best.cpt \ 32 | --ctx_file ../models/mia2022_shared_task_all_langs_w100.tsv \ 33 | --qa_file ../data/eval/mia_2022_dev_xorqa.jsonl \ 34 | --encoded_ctx_file "../models/embeddings/wiki_*" \ 35 | --out_file xor_dev_dpr_retrieval_results.json \ 36 | --n-docs 20 --validation_workers 1 --batch_size 256 37 | cd .. 
38 | 39 | # Convert data 40 | cd mGEN 41 | python3 convert_dpr_retrieval_results_to_seq2seq.py \ 42 | --dev_fp ../mDPR/xor_dev_dpr_retrieval_results.json \ 43 | --output_dir xorqa_dev_final_retriever_results \ 44 | --top_n 15 45 | 46 | # Run mGEN 47 | pip install transformers==4.2.1 48 | CUDA_VISIBLE_DEVICES=0 python eval_mgen.py \ 49 | --model_name_or_path mGEN_model \ 50 | --evaluation_set xorqa_dev_final_retriever_results/val.source \ 51 | --gold_data_path xorqa_dev_final_retriever_results/gold_para_qa_data_dev.tsv \ 52 | --predictions_path xor_dev_final_results.txt \ 53 | --gold_data_mode qa \ 54 | --model_type mt5 \ 55 | --max_length 20 \ 56 | --eval_batch_size 4 57 | cd .. 58 | 59 | # Run evaluation 60 | cd eval_scripts 61 | python eval_xor_full.py --data_file ../data/eval/mia_2022_dev_xorqa.jsonl --pred_file ../mGEN/xor_dev_final_results.txt --txt_file -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/README.md: -------------------------------------------------------------------------------- 1 | ## Wikipedia preprocessing code 2 | This directory contains the code to preprocess Wikipedias. 3 | First you need to download the Wikipedia dumps following [1. Download Wikipedia dumps](#1-download-wikipedia-dumps), preprocess and store the data into a sqlite DB file ([2. Store data into database](#2-store-data-into-database)), and then create a context file by splitting each article into 100-token chunks and writing them to a tsv file ([3. Create a DPR context file](#3-create-a-dpr-context-file)). 4 | 5 | ### 1. Download Wikipedia dumps 6 | First, you need to download a Wikipedia dump from [the Wikimedia website](https://dumps.wikimedia.org/). The site only keeps the most recent dumps, so if you are looking for dumps from certain timestamps, you have to check [the archive](https://archive.org/details/wikimediadownloads). 7 | 8 | e.g., all of the related dumps for Japanese Wikipedia 20190201 can be seen and downloaded [here](https://archive.org/download/jawiki-20190201). `jawiki-20190201-pages-articles-multistream.xml.bz2` includes the article text. 9 | 10 | ### Run Wikiextractor to extract plain text 11 | We usually run [Wikiextractor](https://github.com/attardi/wikiextractor) to preprocess and extract plain text data from the Wikipedia dump. 12 | 13 | ``` 14 | git clone https://github.com/attardi/wikiextractor.git 15 | cd wikiextractor 16 | python WikiExtractor.py /path/to/your/xxwiki-20190201-pages-articles-multistream.xml.bz2 --filter_disambig_pages --json -o /path/to/output/directory -s 17 | ``` 18 | 19 | You can add the `-c` (`--compress`) option to compress the output files using bzip. 20 | 21 | ### 2. Store data into database 22 | You can store the processed text data into a sqlite database. 23 | 24 | ``` 25 | python build_db.py /path/to/preprocessed/data/dir /path/to/db/file.db 26 | ``` 27 | 28 | ### 3. Create a DPR context file 29 | DPR first splits each article into 100-token chunks instead of using the original paragraphs or articles as is. Run the command below to generate a tsv file where each line contains a 100-token Wikipedia passage. 30 | 31 | ``` 32 | python build_dpr_w100_data.py --db_path /path/to/db/file.db --tsv_path /path/to/output/file.tsv 33 | ``` 34 | 35 | Japanese and Thai do not use white spaces for segmentation. For those languages, you need to run the special scripts below, which tokenize the input sequences and generate 100-token document chunks as in other languages. 
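For reference, the whitespace-based chunking performed by `build_dpr_w100_data.py` roughly corresponds to the minimal sketch below (the helper name is illustrative only; the actual script additionally reads articles from the sqlite DB and assigns sequential passage ids). The language-specific scripts listed after it follow the same pattern but replace the whitespace `split()` with a language-specific tokenizer (e.g. spaCy for Japanese, khmer-nltk for Khmer).

```python
def split_into_w100_chunks(text, chunk_size=100, min_tokens=20):
    """Whitespace-based ~100-token chunking, mirroring build_dpr_w100_data.py (simplified sketch)."""
    tokens = text.split()
    if len(tokens) < min_tokens:
        return []  # very short articles are skipped entirely
    n_full = len(tokens) // chunk_size
    chunks = [" ".join(tokens[i * chunk_size:(i + 1) * chunk_size]) for i in range(n_full)]
    # the trailing remainder is kept only if it is longer than min_tokens
    if len(tokens) % chunk_size > min_tokens:
        chunks.append(" ".join(tokens[n_full * chunk_size:]))
    return chunks
```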
36 | 37 | - For Japanese: [create_w100_data_japanese.py](create_w100_data_japanese.py) 38 | - For Thai: [create_w100_data_thai.py](create_w100_data_thai.py]) 39 | 40 | ## References 41 | - [List of Wikipedias](https://en.wikipedia.org/wiki/List_of_Wikipedias): you can check the statistics of each Wikipedia from the **Details table** section. 42 | - [Wikimedia Archive](https://archive.org/details/wikimediadownloads?and[]=year%3A%222019%22) 43 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/build_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # The codes are started from DrQA (https://github.com/facebookresearch/DrQA) library. 3 | """A script to read in and store documents in a sqlite database.""" 4 | 5 | import argparse 6 | import sqlite3 7 | import json 8 | import os 9 | import logging 10 | import importlib.util 11 | import glob 12 | import csv 13 | from utils import process_jsonlines 14 | 15 | from multiprocessing import Pool as ProcessPool 16 | from tqdm import tqdm 17 | 18 | logger = logging.getLogger() 19 | logger.setLevel(logging.INFO) 20 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 21 | console = logging.StreamHandler() 22 | console.setFormatter(fmt) 23 | logger.addHandler(console) 24 | 25 | 26 | # ------------------------------------------------------------------------------ 27 | # Import helper 28 | # ------------------------------------------------------------------------------ 29 | 30 | 31 | PREPROCESS_FN = None 32 | 33 | 34 | def init(filename): 35 | global PREPROCESS_FN 36 | if filename: 37 | PREPROCESS_FN = import_module(filename).preprocess 38 | 39 | 40 | def import_module(filename): 41 | """Import a module given a full path to the file.""" 42 | spec = importlib.util.spec_from_file_location('doc_filter', filename) 43 | module = importlib.util.module_from_spec(spec) 44 | spec.loader.exec_module(module) 45 | return module 46 | 47 | 48 | # ------------------------------------------------------------------------------ 49 | # Store corpus. 50 | # ------------------------------------------------------------------------------ 51 | 52 | 53 | def iter_files(path): 54 | """Walk through all files located under a root path.""" 55 | if os.path.isfile(path): 56 | yield path 57 | elif os.path.isdir(path): 58 | for dirpath, _, filenames in os.walk(path): 59 | for f in filenames: 60 | yield os.path.join(dirpath, f) 61 | else: 62 | raise RuntimeError('Path %s is invalid' % path) 63 | 64 | 65 | def get_contents(filename): 66 | """Parse the contents of a file. Each line is a JSON encoded document.""" 67 | global PREPROCESS_FN 68 | documents = [] 69 | extracted_items = process_jsonlines(filename) 70 | for extracted_item in extracted_items: 71 | wiki_id = extracted_item["wiki_id"] 72 | title = extracted_item["title"] 73 | text = extracted_item["text"] 74 | 75 | documents.append((title, text, wiki_id)) 76 | return documents 77 | 78 | 79 | def store_contents(wiki_dir, save_path, preprocess, num_workers=None, lang=None): 80 | """Preprocess and store a corpus of documents in sqlite. 81 | Args: 82 | data_path: Root path to directory (or directory of directories) of files 83 | containing json encoded documents (must have `id` and `text` fields). 84 | save_path: Path to output sqlite db. 85 | preprocess: Path to file defining a custom `preprocess` function. Takes 86 | in and outputs a structured doc. 
87 | num_workers: Number of parallel processes to use when reading docs. 88 | """ 89 | filenames = [f for f in glob.glob( 90 | wiki_dir + "/*/wiki_*", recursive=True) if ".bz2" not in f] 91 | if os.path.isfile(save_path): 92 | raise RuntimeError('%s already exists! Not overwriting.' % save_path) 93 | 94 | logger.info('Reading into database...') 95 | conn = sqlite3.connect(save_path) 96 | c = conn.cursor() 97 | c.execute( 98 | "CREATE TABLE documents (id PRIMARY KEY, text, wiki_id);") 99 | 100 | workers = ProcessPool(num_workers, initializer=init, 101 | initargs=(preprocess,)) 102 | count = 0 103 | content_processing_method = get_contents 104 | 105 | with tqdm(total=len(filenames)) as pbar: 106 | for pairs in tqdm(workers.imap_unordered(content_processing_method, filenames)): 107 | count += len(pairs) 108 | c.executemany( 109 | "INSERT OR REPLACE INTO documents VALUES (?,?,?)", pairs) 110 | pbar.update() 111 | 112 | logger.info('Read %d docs.' % count) 113 | logger.info('Committing...') 114 | conn.commit() 115 | conn.close() 116 | 117 | # ------------------------------------------------------------------------------ 118 | # Main. 119 | # ------------------------------------------------------------------------------ 120 | 121 | 122 | if __name__ == '__main__': 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('wiki_dir', type=str, help='/path/to/data') 125 | parser.add_argument('save_path', type=str, help='/path/to/saved/db.db') 126 | parser.add_argument('--preprocess', type=str, default=None, 127 | help=('File path to a python module that defines ' 128 | 'a `preprocess` function')) 129 | parser.add_argument('--num-workers', type=int, default=None, 130 | help='Number of CPU processes (for tokenizing, etc)') 131 | parser.add_argument('--lang', type=str, default=None, 132 | help='language_code') 133 | args = parser.parse_args() 134 | 135 | store_contents( 136 | args.wiki_dir, args.save_path, args.preprocess, args.num_workers) 137 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/build_dpr_w100_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from doc_db import DocDB 4 | import re 5 | import json 6 | from tqdm import tqdm 7 | import csv 8 | 9 | 10 | def collect_text_data(db_path, start_idx, separate=False): 11 | count = start_idx 12 | if separate is True: 13 | lang_doc_data = {} 14 | for path in db_path: 15 | db = DocDB(path) 16 | doc_ids = db.get_doc_ids() 17 | if separate is True: 18 | lang_doc_data[path] = [] 19 | else: 20 | doc_data = [] 21 | for doc_id in tqdm(doc_ids): 22 | sections_paras = db.get_doc_text_section_separations(doc_id) 23 | title = doc_id 24 | if "_0" in doc_id: 25 | title = doc_id.split("_0")[0] 26 | para_text = "" 27 | for section in sections_paras: 28 | paragraphs = section["paragraphs"] 29 | for para_idx, para in enumerate(paragraphs): 30 | para_text += para 31 | para_text += " " 32 | if len(para_text) > 0 and para_text[-1] == " ": 33 | para_text = para_text[:-1] 34 | para_tokens = para_text.split() 35 | if len(para_tokens) < 20: 36 | continue 37 | for i in range(len(para_tokens) // 100): 38 | w100_para_text = " ".join(para_tokens[100*i:100*(i+1)]) 39 | if separate is True: 40 | lang_doc_data[path].append( 41 | {"title": title, "id": count, "text": w100_para_text}) 42 | else: 43 | doc_data.append( 44 | {"title": title, "id": count, "text": w100_para_text}) 45 | count += 1 46 | # store the last part if the 
remaining part is longer than 20. 47 | if len(para_tokens) % 100 > 20: 48 | w100_para_text = " ".join( 49 | para_tokens[100*(len(para_tokens) // 100):]) 50 | if separate is True: 51 | lang_doc_data[path].append( 52 | {"title": title, "id": count, "text": w100_para_text}) 53 | else: 54 | doc_data.append( 55 | {"title": title, "id": count, "text": w100_para_text}) 56 | count += 1 57 | if separate is True: 58 | print("collected {0} data from {1}".format( 59 | len(lang_doc_data[path]), path)) 60 | print("collected {} data".format(count)) 61 | 62 | if separate is True: 63 | return lang_doc_data 64 | else: 65 | return doc_data 66 | 67 | 68 | def write_para_data_to_tsv(input_data, output_fn): 69 | with open(output_fn, 'wt') as tsv_file: 70 | tsv_writer = csv.writer(tsv_file, delimiter='\t', lineterminator='\n') 71 | for item in tqdm(input_data): 72 | tsv_writer.writerow([item["id"], item["text"], item["title"]]) 73 | print("wrote full data to {}".format(output_fn)) 74 | 75 | 76 | def write_para_data_to_tsvs(input_data, output_dir): 77 | for k, para_data in input_data.items(): 78 | db_path_name = os.path.basename(k) 79 | lang_code = db_path_name[:2] 80 | 81 | with open(os.path.join(output_dir, "{}_wiki_w100.tsv".format(lang_code)), 'wt') as tsv_file: 82 | tsv_writer = csv.writer( 83 | tsv_file, delimiter='\t', lineterminator='\n') 84 | for item in tqdm(para_data): 85 | tsv_writer.writerow([item["id"], item["text"], item["title"]]) 86 | print("wrote {0} {1}'s data to {2}".format(len(para_data), lang_code, os.path.join( 87 | output_dir, "{}_wiki_w100.tsv".format(lang_code)))) 88 | print(para_data[-1]) 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--db_path', type=str, required=True, nargs='+', 94 | help='Path to sqlite db holding document texts') 95 | parser.add_argument('--tsv_path', type=str, required=True, 96 | help='output tsv file name') 97 | parser.add_argument('--separate', action="store_true") 98 | parser.add_argument('--tsv_dir', type=str) 99 | parser.add_argument('--start_idx', type=int, default=0) 100 | 101 | args = parser.parse_args() 102 | input_data = collect_text_data( 103 | args.db_path, args.start_idx, separate=args.separate) 104 | if args.separate is True: 105 | if not os.path.exists(args.tsv_dir): 106 | os.makedirs(args.tsv_dir) 107 | write_para_data_to_tsvs(input_data, args.tsv_dir) 108 | else: 109 | write_para_data_to_tsv(input_data, args.tsv_path) 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/create_w100_data_japanese.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from doc_db import DocDB 4 | import re 5 | import json 6 | from tqdm import tqdm 7 | import csv 8 | import spacy 9 | 10 | nlp = spacy.load("ja_core_news_sm") 11 | 12 | 13 | def tokenize_japanese_text(text): 14 | tokens = [] 15 | doc = nlp(text) 16 | for token in doc: 17 | tokens.append(token.text) 18 | return tokens 19 | 20 | 21 | def collect_text_data(db_path, start_idx): 22 | db = DocDB(db_path) 23 | doc_ids = db.get_doc_ids() 24 | count = start_idx 25 | doc_data = [] 26 | for doc_id in tqdm(doc_ids): 27 | sections_paras = db.get_doc_text_section_separations(doc_id) 28 | title = doc_id.split("_0")[0] 29 | para_tokens = [] 30 | for section in sections_paras: 31 | paragraphs = section["paragraphs"] 32 | for para_idx, para in enumerate(paragraphs): 33 | para_tokens += 
tokenize_japanese_text(para) 34 | # skip articles whose para token is less than 20. 35 | if len(para_tokens) < 20: 36 | continue 37 | for i in range(len(para_tokens) // 100): 38 | w100_para_text = "".join(para_tokens[100*i:100*(i+1)]) 39 | doc_data.append( 40 | {"title": title, "id": count, "text": w100_para_text}) 41 | count += 1 42 | if len(para_tokens) % 100 > 20: 43 | w100_para_text = " ".join( 44 | para_tokens[100*(len(para_tokens) // 100):]) 45 | doc_data.append( 46 | {"title": title, "id": count, "text": w100_para_text}) 47 | count += 1 48 | 49 | print("collected {} data".format(count)) 50 | 51 | return doc_data, count 52 | 53 | 54 | def write_para_data_to_tsv(input_data, output_fn): 55 | with open(output_fn, 'wt') as tsv_file: 56 | tsv_writer = csv.writer(tsv_file, delimiter='\t', lineterminator='\n') 57 | for item in tqdm(input_data): 58 | tsv_writer.writerow([item["id"], item["text"], item["title"]]) 59 | print("wrote full data to {}".format(output_fn)) 60 | 61 | 62 | def main(): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--db_path', type=str, required=True, 65 | help='Path to sqlite db holding document texts') 66 | parser.add_argument('--tsv_path', type=str, required=True, 67 | help='output tsv file name') 68 | parser.add_argument('--start_idx', type=int, default=0) 69 | 70 | args = parser.parse_args() 71 | input_data, count = collect_text_data(args.db_path, args.start_idx) 72 | write_para_data_to_tsv(input_data, args.tsv_path) 73 | 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/create_w100_data_khmer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from doc_db import DocDB 4 | import re 5 | import json 6 | from tqdm import tqdm 7 | import csv 8 | 9 | from khmernltk import word_tokenize as km_tokenizer 10 | 11 | 12 | def tokenize_km_text(text): 13 | tokens = km_tokenizer(text) 14 | tokens = [token for token in tokens if token != " "] 15 | return tokens 16 | 17 | 18 | def collect_text_data(db_path, start_idx): 19 | db = DocDB(db_path) 20 | doc_ids = db.get_doc_ids() 21 | count = start_idx 22 | doc_data = [] 23 | for doc_id in tqdm(doc_ids): 24 | sections_paras = db.get_doc_text_section_separations(doc_id) 25 | title = doc_id.split("_0")[0] 26 | para_tokens = [] 27 | for section in sections_paras: 28 | paragraphs = section["paragraphs"] 29 | for para_idx, para in enumerate(paragraphs): 30 | para_tokens += tokenize_km_text(para) 31 | # skip articles whose para token is less than 20.
32 | if len(para_tokens) < 20: 33 | continue 34 | for i in range(len(para_tokens) // 100): 35 | w100_para_text = " ".join(para_tokens[100*i:100*(i+1)]) 36 | doc_data.append( 37 | {"title": title, "id": count, "text": w100_para_text}) 38 | count += 1 39 | if len(para_tokens) % 100 > 20: 40 | w100_para_text = " ".join( 41 | para_tokens[100*(len(para_tokens) // 100):]) 42 | doc_data.append( 43 | {"title": title, "id": count, "text": w100_para_text}) 44 | count += 1 45 | 46 | print("collected {} data".format(count)) 47 | 48 | return doc_data, count 49 | 50 | 51 | def write_para_data_to_tsv(input_data, output_fn): 52 | with open(output_fn, 'wt') as tsv_file: 53 | tsv_writer = csv.writer(tsv_file, delimiter='\t', lineterminator='\n') 54 | for item in tqdm(input_data): 55 | tsv_writer.writerow([item["id"], item["text"], item["title"]]) 56 | print("wrote full data to {}".format(output_fn)) 57 | 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--db_path', type=str, required=True, 62 | help='Path to sqlite db holding document texts') 63 | parser.add_argument('--tsv_path', type=str, required=True, 64 | help='output tsv file name') 65 | parser.add_argument('--start_idx', type=int, default=27019719) 66 | 67 | args = parser.parse_args() 68 | input_data, count = collect_text_data(args.db_path, args.start_idx) 69 | write_para_data_to_tsv(input_data, args.tsv_path) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/create_w100_data_thai.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from doc_db import DocDB 4 | import re 5 | import json 6 | from tqdm import tqdm 7 | import csv 8 | 9 | from pythainlp.tokenize import word_tokenize 10 | 11 | 12 | def tokenize_thai_text(text): 13 | tokens = [] 14 | tokens = word_tokenize(text, engine="newmm") 15 | return tokens 16 | 17 | 18 | def collect_text_data(db_path, start_idx): 19 | db = DocDB(db_path) 20 | doc_ids = db.get_doc_ids() 21 | count = start_idx 22 | doc_data = [] 23 | for doc_id in tqdm(doc_ids): 24 | sections_paras = db.get_doc_text_section_separations(doc_id) 25 | title = doc_id.split("_0")[0] 26 | para_tokens = [] 27 | for section in sections_paras: 28 | paragraphs = section["paragraphs"] 29 | for para_idx, para in enumerate(paragraphs): 30 | para_tokens += tokenize_thai_text(para) 31 | # skip articles whose para token is less than 20. 
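# Note: the loop below then splits each remaining article into non-overlapping 100-token windows (the DPR-style w100 passages) and keeps the trailing chunk only when it is longer than 20 tokens.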
32 | if len(para_tokens) < 20: 33 | continue 34 | for i in range(len(para_tokens) // 100): 35 | w100_para_text = "".join(para_tokens[100*i:100*(i+1)]) 36 | doc_data.append( 37 | {"title": title, "id": count, "text": w100_para_text}) 38 | count += 1 39 | if len(para_tokens) % 100 > 20: 40 | w100_para_text = " ".join( 41 | para_tokens[100*(len(para_tokens) // 100):]) 42 | doc_data.append( 43 | {"title": title, "id": count, "text": w100_para_text}) 44 | count += 1 45 | 46 | print("collected {} data".format(count)) 47 | 48 | return doc_data, count 49 | 50 | 51 | def write_para_data_to_tsv(input_data, output_fn): 52 | with open(output_fn, 'wt') as tsv_file: 53 | tsv_writer = csv.writer(tsv_file, delimiter='\t', lineterminator='\n') 54 | for item in tqdm(input_data): 55 | tsv_writer.writerow([item["id"], item["text"], item["title"]]) 56 | print("wrote full data to {}".format(output_fn)) 57 | 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--db_path', type=str, required=True, 62 | help='Path to sqlite db holding document texts') 63 | parser.add_argument('--tsv_path', type=str, required=True, 64 | help='output tsv file name') 65 | parser.add_argument('--start_idx', type=int, default=27019719) 66 | 67 | args = parser.parse_args() 68 | input_data, count = collect_text_data(args.db_path, args.start_idx) 69 | write_para_data_to_tsv(input_data, args.tsv_path) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/doc_db.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import argparse 3 | import time 4 | import re 5 | from utils import remove_tags, normalize 6 | 7 | 8 | class DocDB(object): 9 | """Sqlite backed document storage. 10 | 11 | Implements get_doc_text(doc_id). 12 | """ 13 | 14 | def __init__(self, db_path=None): 15 | self.path = db_path 16 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 17 | 18 | def __enter__(self): 19 | return self 20 | 21 | def __exit__(self, *args): 22 | self.close() 23 | 24 | def close(self): 25 | """Close the connection to the database.""" 26 | self.connection.close() 27 | 28 | def get_doc_ids(self): 29 | """Fetch all ids of docs stored in the db.""" 30 | cursor = self.connection.cursor() 31 | cursor.execute("SELECT id FROM documents") 32 | results = [r[0] for r in cursor.fetchall()] 33 | cursor.close() 34 | return results 35 | 36 | def get_doc_ids_lang(self, lang): 37 | """Fetch all ids of docs stored in the db.""" 38 | cursor = self.connection.cursor() 39 | cursor.execute("SELECT id FROM documents where lang = ?", (lang,)) 40 | results = [r[0] for r in cursor.fetchall()] 41 | cursor.close() 42 | return results 43 | 44 | def get_doc_text(self, doc_id): 45 | """Fetch the raw text of the doc for 'doc_id'.""" 46 | cursor = self.connection.cursor() 47 | cursor.execute( 48 | "SELECT text FROM documents WHERE id = ?", 49 | (doc_id,) 50 | ) 51 | result = cursor.fetchone() 52 | cursor.close() 53 | return result if result is None else result[0] 54 | 55 | def get_doc_text_section_separations(self, doc_id): 56 | # WIP: we might have better formats to keep the information. 
57 | """ 58 | fetch all of the paragraphs with section level separations 59 | e.g., 60 | >>> sectioned_paragraphs = db.get_doc_text_hyper_linked_titles_for_articles("Tokyo Imperial Palace_0") 61 | >>> sectioned_paragraphs[0] 62 | {"section_name":"Early life and sumo background.", 63 | "parent_section_name": None:, 64 | "paragraphs": ["Tatsu Ryōya was born in Kanazawa, Ishikawa and is the youngest of three children. 65 | His father was a truck driver. He was a normal-sized baby but grew quickly so that when 66 | attending kindergarten he had difficulty fitting into the uniform. He first began 67 | practicing sumo whilst in the first grade of elementary school.", 68 | "By the age of thirteen, when he ended his 69 | first year at junior high school he stood , and weighed . 70 | After competing successfully in junior high school sumo he gave up formal education 71 | at the age of fifteen and entered the Takadagawa stable to pursue a professional career." 72 | "type": "child"} 73 | """ 74 | cursor = self.connection.cursor() 75 | cursor.execute( 76 | "SELECT text FROM documents WHERE id = ?", 77 | (doc_id,) 78 | ) 79 | result = cursor.fetchone() 80 | cursor.close() 81 | if result is None: 82 | return [] 83 | else: 84 | output_data = [] 85 | section_separated_context = result[0].split("Section::::") 86 | 87 | parent_section = "" 88 | for s_idx, section in enumerate(section_separated_context): 89 | # the first sections are likely to be introductory paragraphs. 90 | if s_idx == 0 and len(section.split("\n\n")) > 1 and len(section.split("\n\n")[1]) > 0: 91 | section_name = "Introduction" 92 | parent_section = "Introduction" 93 | output_data.append( 94 | {"section_name": section_name, "paragraphs": section.split("\n\n")[1:], 95 | "type": "intro", "parent_section_name": parent_section}) 96 | else: 97 | section_name = re.compile( 98 | "(.*)\n").search(section).group(1) 99 | section_text = re.sub("(.*)\n", "", section, 1) 100 | if len(section_text) == 0: 101 | # this is section header 102 | parent_section = section_name 103 | output_data.append({"section_name": section_name, "paragraphs": [], 104 | "type": "parent", "parent_section_name": None}) 105 | else: 106 | output_data.append({"section_name": section_name, "paragraphs": [para for para in section_text.split("\n\n") if len(para) > 10], 107 | "type": "child", "parent_section_name": parent_section}) 108 | return output_data 109 | -------------------------------------------------------------------------------- /baseline/wikipedia_preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | import jsonlines 3 | import re 4 | from urllib.parse import unquote 5 | import regex 6 | import numpy as np 7 | import scipy.sparse as sp 8 | from sklearn.utils import murmurhash3_32 9 | 10 | import logging 11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 12 | datefmt='%m/%d/%Y %H:%M:%S', 13 | level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def normalize(text): 18 | """Resolve different type of unicode encodings / capitarization in HotpotQA data.""" 19 | text = unicodedata.normalize('NFD', text) 20 | return text[0].capitalize() + text[1:] 21 | 22 | 23 | def make_wiki_id(title, para_index): 24 | title_id = "{0}_{1}".format(normalize(title), para_index) 25 | return title_id 26 | 27 | 28 | def find_hyper_linked_titles(text_w_links): 29 | titles = re.findall(r'href=[\'"]?([^\'" >]+)', text_w_links) 30 | titles = [unquote(title) for title 
in titles] 31 | titles = [title[0].capitalize() + title[1:] for title in titles] 32 | return titles 33 | 34 | 35 | TAG_RE = re.compile(r'<[^>]+>') 36 | 37 | 38 | def remove_tags(text): 39 | return TAG_RE.sub('', text) 40 | 41 | 42 | def process_jsonlines(filename): 43 | # item should be nested list 44 | extracted_items = [] 45 | with jsonlines.open(filename) as reader: 46 | for obj in reader: 47 | wiki_id = obj["id"] 48 | title = obj["title"] 49 | text = obj["text"] 50 | 51 | extracted_items.append({"wiki_id": wiki_id, "title": title, 52 | "text": text}) 53 | 54 | return extracted_items 55 | 56 | # ------------------------------------------------------------------------------ 57 | # Sparse matrix saving/loading helpers. 58 | # ------------------------------------------------------------------------------ 59 | 60 | 61 | def save_sparse_csr(filename, matrix, metadata=None): 62 | data = { 63 | 'data': matrix.data, 64 | 'indices': matrix.indices, 65 | 'indptr': matrix.indptr, 66 | 'shape': matrix.shape, 67 | 'metadata': metadata, 68 | } 69 | np.savez(filename, **data) 70 | 71 | 72 | def load_sparse_csr(filename): 73 | loader = np.load(filename, allow_pickle=True) 74 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 75 | loader['indptr']), shape=loader['shape']) 76 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 77 | 78 | # ------------------------------------------------------------------------------ 79 | # Token hashing. 80 | # ------------------------------------------------------------------------------ 81 | 82 | 83 | def hash(token, num_buckets): 84 | """Unsigned 32 bit murmurhash for feature hashing.""" 85 | return murmurhash3_32(token, positive=True) % num_buckets 86 | 87 | # ------------------------------------------------------------------------------ 88 | # Text cleaning. 
89 | # ------------------------------------------------------------------------------ 90 | 91 | 92 | STOPWORDS = { 93 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 94 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 95 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 96 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 97 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 98 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 99 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 100 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 101 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 102 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 103 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 104 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 105 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 106 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 107 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 108 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 109 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 110 | } 111 | 112 | 113 | def filter_word(text): 114 | """Take out english stopwords, punctuation, and compound endings.""" 115 | text = normalize(text) 116 | if regex.match(r'^\p{P}+$', text): 117 | return True 118 | if text.lower() in STOPWORDS: 119 | return True 120 | return False 121 | 122 | 123 | def filter_ngram(gram, mode='any'): 124 | """Decide whether to keep or discard an n-gram. 125 | Args: 126 | gram: list of tokens (length N) 127 | mode: Option to throw out ngram if 128 | 'any': any single token passes filter_word 129 | 'all': all tokens pass filter_word 130 | 'ends': book-ended by filterable tokens 131 | """ 132 | filtered = [filter_word(w) for w in gram] 133 | if mode == 'any': 134 | return any(filtered) 135 | elif mode == 'all': 136 | return all(filtered) 137 | elif mode == 'ends': 138 | return filtered[0] or filtered[-1] 139 | else: 140 | raise ValueError('Invalid mode: %s' % mode) 141 | 142 | 143 | def get_field(d, field_list): 144 | """get the subfield associated to a list of elastic fields 145 | E.g. 
['file', 'filename'] to d['file']['filename'] 146 | """ 147 | if isinstance(field_list, str): 148 | return d[field_list] 149 | else: 150 | idx = d.copy() 151 | for field in field_list: 152 | idx = idx[field] 153 | return idx 154 | 155 | 156 | def load_para_collections_from_tfidf_id_intro_only(tfidf_id, db): 157 | if "_0" not in tfidf_id: 158 | tfidf_id = "{0}_0".format(tfidf_id) 159 | if db.get_doc_text(tfidf_id) is None: 160 | logger.warning("{0} is missing".format(tfidf_id)) 161 | return [] 162 | return [[tfidf_id, db.get_doc_text(tfidf_id).split("\t")]] 163 | 164 | 165 | def load_linked_titles_from_tfidf_id(tfidf_id, db): 166 | para_titles = db.get_paras_with_article(tfidf_id) 167 | linked_titles_all = [] 168 | for para_title in para_titles: 169 | linked_title_per_para = db.get_hyper_linked(para_title) 170 | if len(linked_title_per_para) > 0: 171 | linked_titles_all += linked_title_per_para.split("\t") 172 | return linked_titles_all 173 | 174 | 175 | def load_para_and_linked_titles_dict_from_tfidf_id(tfidf_id, db): 176 | """ 177 | load paragraphs and hyperlinked titles from DB. 178 | This method is mainly for Natural Questions Open benchmark. 179 | """ 180 | # will be fixed in the later version; current tfidf weights use indexed titles as keys. 181 | if "_0" not in tfidf_id: 182 | tfidf_id = "{0}_0".format(tfidf_id) 183 | paras, linked_titles = db.get_doc_text_hyper_linked_titles_for_articles( 184 | tfidf_id) 185 | if len(paras) == 0: 186 | logger.warning("{0} is missing".format(tfidf_id)) 187 | return [], [] 188 | 189 | paras_dict = {} 190 | linked_titles_dict = {} 191 | article_name = tfidf_id.split("_0")[0] 192 | # store the para_dict and linked_titles_dict; skip the first para (title) 193 | for para_idx, (para, linked_title_list) in enumerate(zip(paras[1:], linked_titles[1:])): 194 | paras_dict["{0}_{1}".format(article_name, para_idx)] = para 195 | linked_titles_dict["{0}_{1}".format( 196 | article_name, para_idx)] = linked_title_list 197 | 198 | return paras_dict, linked_titles_dict 199 | 200 | 201 | def prune_top_k_paragraphs(question_text, paragraphs, tfidf_vectorizer, pruning_l=10): 202 | para_titles, para_text = list(paragraphs.keys()), list(paragraphs.values()) 203 | # prune top l paragraphs using the question as query to reduce the search space. 204 | top_tfidf_para_indices = tfidf_vectorizer.prune( 205 | question_text, para_text)[:pruning_l] 206 | para_title_text_pairs_pruned = {} 207 | 208 | # store the selected paras into dictionary. 
209 | for idx in top_tfidf_para_indices: 210 | para_title_text_pairs_pruned[para_titles[idx]] = para_text[idx] 211 | 212 | return para_title_text_pairs_pruned 213 | -------------------------------------------------------------------------------- /data/eval/mkqa_dev.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mia-workshop/MIA-Shared-Task-2022/ea2491a3e3a588e270b2dd30ce7e044584c3eeb5/data/eval/mkqa_dev.zip -------------------------------------------------------------------------------- /data/eval/mkqa_test_without_answers.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mia-workshop/MIA-Shared-Task-2022/ea2491a3e3a588e270b2dd30ce7e044584c3eeb5/data/eval/mkqa_test_without_answers.zip -------------------------------------------------------------------------------- /data/train/mia_2022_train_data.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mia-workshop/MIA-Shared-Task-2022/ea2491a3e3a588e270b2dd30ce7e044584c3eeb5/data/train/mia_2022_train_data.jsonl.zip -------------------------------------------------------------------------------- /eval_scripts/eval_mkqa_all.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | from statistics import mean 4 | import os 5 | from tqdm import tqdm 6 | from nltk.translate import bleu 7 | import MeCab 8 | from collections import Counter 9 | import string 10 | import re 11 | import argparse 12 | import sys 13 | from pythainlp.tokenize import word_tokenize as th_tokenizer 14 | from khmernltk import word_tokenize as km_tokenizer 15 | import jieba.posseg as pseg 16 | 17 | 18 | 19 | wakati = MeCab.Tagger("-Owakati") 20 | 21 | lang_dic = {'telugu': 'te', 'swahili': 'sw', 'thai': 'th', 'finnish': 'fi', 'indonesian': 'id', 22 | 'japanese': 'ja', 'russian': 'ru', 'arabic': 'ar', 'english': 'en', 'bengali': 'bn', 23 | "korean": "ko", "spanish": "es", "hebrew": "he", "swedish": "sv", "danish": "da", "german": "de", 24 | "hungarian": "hu", "italian": "it", "khmer": "km", "malay": "ms", "dutch": "nl", 25 | "norwegian": "no", "portuguese": "pt", "turkish": "tr", "vietnamese": "vi", "french": "fr", "polish": "pl", 26 | "chinese (simplified)": "zh_cn", "chinese (hong kong)": 'zh_hk', "chinese (traditional)": "zh_tw", "tamil": "ta", "tagalog": "tl"} 27 | 28 | langs = ['tr', 'hu', 'zh_hk', 'nl', 'ms', 'zh_cn', 'ja', 'de', 'ru', 'pl', 'fi', 'pt', 'km', 29 | 'it', 'fr', 'he', 'vi', 'zh_tw', 'no', 'da', 'th', 'sv', 'es', 'ar', 'en', 'ko', 'en', "ta", "tl"] 30 | 31 | def tokenize_th_text(text): 32 | tokens = th_tokenizer(text, engine="newmm") 33 | tokens = [token for token in tokens if token != " "] 34 | return " ".join(tokens) 35 | 36 | def tokenize_zh_text(text): 37 | tokens = pseg.cut(text) 38 | tokens = [w.word for w in tokens] 39 | tokens = [token for token in tokens if token != " "] 40 | return " ".join(tokens) 41 | 42 | def tokenize_km_text(text): 43 | tokens = km_tokenizer(text) 44 | tokens = [token for token in tokens if token != " "] 45 | return " ".join(tokens) 46 | 47 | def read_jsonlines(eval_file_name): 48 | lines = [] 49 | print("loading examples from {0}".format(eval_file_name)) 50 | with jsonlines.open(eval_file_name) as reader: 51 | for obj in reader: 52 | lines.append(obj) 53 | return lines 54 | 55 | 56 | def load_tydi_answer(tydi_eval_open_domain_data): 57 | answer_dict 
= {} 58 | eval_data = read_jsonlines(tydi_eval_open_domain_data) 59 | for item in eval_data: 60 | answer_dict[item["id"]] = item["answers"] 61 | return answer_dict 62 | 63 | 64 | def normalize_answer(s): 65 | # TODO: should we keep those counter removal? 66 | def remove_counter(text): 67 | return text.replace("年", "").replace("歳", "").replace("人", "").replace("년", "") 68 | 69 | def white_space_fix(text): 70 | return ' '.join(text.split()) 71 | 72 | def remove_punc(text): 73 | exclude = set(string.punctuation) 74 | return ''.join(ch for ch in text if ch not in exclude) 75 | 76 | def lower(text): 77 | return text.lower() 78 | 79 | return white_space_fix(remove_counter(remove_punc(lower(s)))) 80 | 81 | 82 | def exact_match_score(prediction, ground_truth): 83 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 84 | 85 | 86 | def f1_score(prediction, ground_truth): 87 | prediction_tokens = normalize_answer(prediction).split() 88 | ground_truth_tokens = normalize_answer(ground_truth).split() 89 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 90 | num_same = sum(common.values()) 91 | 92 | if num_same == 0: 93 | return 0 94 | precision = 1.0 * num_same / len(prediction_tokens) 95 | recall = 1.0 * num_same / len(ground_truth_tokens) 96 | f1 = (2 * precision * recall) / (precision + recall) 97 | return f1 98 | 99 | 100 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 101 | scores_for_ground_truths = [] 102 | for ground_truth in ground_truths: 103 | score = metric_fn(prediction, ground_truth) 104 | scores_for_ground_truths.append(score) 105 | return max(scores_for_ground_truths) 106 | 107 | 108 | # 3. XOR-Full Evaluation 109 | def calculate_f1_em_bleu(dataset, predictions): 110 | lang_dict = {lang: {"count": 0, "f1": 0, "bleu": 0, "em": 0} 111 | for lang in lang_dic.values()} 112 | 113 | for qa in dataset: 114 | lang = qa["lang"] 115 | gts = qa["answers"] 116 | if gts[0] == "No Answer": 117 | continue 118 | q_id = qa["id"] 119 | 120 | lang_dict[lang]["count"] += 1 121 | if q_id not in predictions: 122 | print("no answers") 123 | continue 124 | pred = predictions[q_id] 125 | if isinstance(gts, str): 126 | gts = [gts] 127 | 128 | final_gts = [] 129 | # for the languages where white spaces are not widely used for word tokenization, we use the same word tokenizers on both targets and predictions and calculate word-level F1. 
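# Note: the per-language tokenizers (wakati.parse, tokenize_zh_text, tokenize_th_text, tokenize_km_text) return whitespace-joined token strings, so the whitespace-based f1_score / exact_match_score defined above can be applied uniformly.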
130 | if lang == "ja": 131 | for gt in gts: 132 | gt = wakati.parse(gt) 133 | final_gts.append(gt) 134 | final_pred = wakati.parse(pred.replace("・", " ").replace("、", ",")) 135 | elif lang == "zh_cn" or lang == "zh_hk" or lang == "zh_tw": 136 | for gt in gts: 137 | gt = tokenize_zh_text(gt) 138 | final_gts.append(gt) 139 | final_pred = tokenize_zh_text(pred) 140 | elif lang == "th": 141 | for gt in gts: 142 | gt = tokenize_th_text(gt) 143 | final_gts.append(gt) 144 | final_pred = tokenize_th_text(pred) 145 | elif lang == "km": 146 | for gt in gts: 147 | gt = tokenize_km_text(gt) 148 | final_gts.append(gt) 149 | final_pred = tokenize_km_text(pred) 150 | else: 151 | final_gts = gts 152 | final_pred = pred 153 | lang_dict[lang]["f1"] += metric_max_over_ground_truths( 154 | f1_score, final_pred, final_gts) 155 | lang_dict[lang]["bleu"] += bleu(final_gts, pred) 156 | lang_dict[lang]["em"] += metric_max_over_ground_truths( 157 | exact_match_score, final_pred, final_gts) 158 | # finalize scores 159 | for lang, scores in lang_dict.items(): 160 | if scores["count"] == 0: 161 | continue 162 | for score_key in scores: 163 | if "count" != score_key: 164 | lang_dict[lang][score_key] = scores[score_key]/scores["count"] 165 | return lang_dict 166 | 167 | 168 | def main(): 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--data_dir", 171 | default=None, type=str) 172 | parser.add_argument("--pred_dir", 173 | default=None, type=str) 174 | parser.add_argument("--txt_file", action="store_true") 175 | parser.add_argument("--target", type=str, nargs='+') 176 | 177 | args = parser.parse_args() 178 | # load dpr results 179 | results_all = {} 180 | for lang in tqdm(langs): 181 | if lang not in args.target: 182 | continue 183 | dataset = read_jsonlines(os.path.join(args.data_dir, "mkqa-{}.jsonl".format(lang))) 184 | # fix file path 185 | # need to fix the file path 186 | predictions = json.load(open(os.path.join(args.pred_dir, "mkqa_pred_{}.json".format(lang)))) 187 | 188 | results = calculate_f1_em_bleu(dataset, predictions) 189 | results_all[lang] = results[lang] 190 | 191 | f1_total, em_total, bleu_total = 0.0, 0.0, 0.0 192 | total_num = 0 193 | lang_count = 0 194 | 195 | for lang in results_all: 196 | if results_all[lang]["count"] == 0: 197 | continue 198 | lang_count += 1 199 | f1_total += results_all[lang]["f1"] 200 | em_total += results_all[lang]["em"] 201 | bleu_total += results_all[lang]["bleu"] 202 | total_num += results_all[lang]["count"] 203 | print("Evaluating the performance on {0} for {1} examples".format( 204 | lang, results_all[lang]["count"])) 205 | print("F1: {0}, EM:{1}, BLEU:{2}".format( 206 | results_all[lang]["f1"] * 100, results_all[lang]["em"] * 100, results_all[lang]["bleu"] * 100)) 207 | print("avg f1: {}".format(f1_total / lang_count * 100)) 208 | print("avg em: {}".format(em_total / lang_count * 100)) 209 | 210 | 211 | if __name__ == "__main__": 212 | main() 213 | -------------------------------------------------------------------------------- /eval_scripts/eval_xor_full.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import jsonlines 3 | import json 4 | from statistics import mean 5 | import requests 6 | import csv 7 | import urllib 8 | import glob 9 | import os 10 | from tqdm import tqdm 11 | from nltk.translate import bleu 12 | import MeCab 13 | from collections import Counter 14 | import string 15 | import re 16 | import argparse 17 | 18 | wakati = MeCab.Tagger("-Owakati") 19 | 20 
| lang_dic = {'telugu': 'te', 'swahili': 'sw', 'thai': 'th', 'finnish': 'fi', 'indonesian': 'id', 21 | 'japanese': 'ja', 'russian': 'ru', 'arabic': 'ar', 'english': 'en', 'bengali': 'bn', 22 | "korean": "ko", "spanish": "es", "hebrew": "he", "swedish": "sv", "danish": "da", "german": "de", 23 | "hungarian": "hu", "italian": "it", "khmer": "km", "malay": "ms", "dutch": "nl", 24 | "norwegian": "no", "portuguese": "pt", "turkish": "tr", "vietnamese": "vi", "french": "fr", "polish": "pl", 25 | "chinese (simplified)": "zh_cn", "chinese (hong kong)": 'zh_hk', "chinese (traditional)": "zh_tw", "tamil": "ta", "tagalog": "tl"} 26 | 27 | 28 | def read_jsonlines(eval_file_name): 29 | lines = [] 30 | print("loading examples from {0}".format(eval_file_name)) 31 | with jsonlines.open(eval_file_name) as reader: 32 | for obj in reader: 33 | lines.append(obj) 34 | return lines 35 | 36 | 37 | def load_tydi_answer(tydi_eval_open_domain_data): 38 | answer_dict = {} 39 | eval_data = read_jsonlines(tydi_eval_open_domain_data) 40 | for item in eval_data: 41 | answer_dict[item["id"]] = item["answers"] 42 | return answer_dict 43 | 44 | 45 | def normalize_answer(s): 46 | # TODO: should we keep those counter removal? 47 | def remove_counter(text): 48 | return text.replace("年", "").replace("歳", "").replace("人", "").replace("년", "") 49 | 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | 53 | def remove_punc(text): 54 | exclude = set(string.punctuation) 55 | return ''.join(ch for ch in text if ch not in exclude) 56 | 57 | def lower(text): 58 | return text.lower() 59 | 60 | return white_space_fix(remove_counter(remove_punc(lower(s)))) 61 | 62 | 63 | def exact_match_score(prediction, ground_truth): 64 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 65 | 66 | 67 | def f1_score(prediction, ground_truth): 68 | prediction_tokens = normalize_answer(prediction).split() 69 | ground_truth_tokens = normalize_answer(ground_truth).split() 70 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 71 | num_same = sum(common.values()) 72 | 73 | if num_same == 0: 74 | return 0 75 | precision = 1.0 * num_same / len(prediction_tokens) 76 | recall = 1.0 * num_same / len(ground_truth_tokens) 77 | f1 = (2 * precision * recall) / (precision + recall) 78 | return f1 79 | 80 | 81 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 82 | scores_for_ground_truths = [] 83 | for ground_truth in ground_truths: 84 | score = metric_fn(prediction, ground_truth) 85 | scores_for_ground_truths.append(score) 86 | return max(scores_for_ground_truths) 87 | 88 | 89 | # 3. XOR-Full Evaluation 90 | def calculate_f1_em_bleu(dataset, predictions): 91 | lang_dict = {lang: {"count": 0, "f1": 0, "bleu": 0, "em": 0} 92 | for lang in lang_dic.values()} 93 | 94 | for qa in dataset: 95 | lang = qa["lang"] 96 | q_id = qa["id"] 97 | gts = qa["answers"] 98 | if gts[0] == "No Answer": 99 | continue 100 | lang_dict[lang]["count"] += 1 101 | if q_id not in predictions: 102 | print(q_id) 103 | print("no answers") 104 | continue 105 | pred = predictions[q_id] 106 | if isinstance(gts, str): 107 | gts = [gts] 108 | 109 | final_gts = [] 110 | # for japanese, we need to tokenize the input as there are no white spaces. 
111 | if lang == "ja": 112 | for gt in gts: 113 | gt = wakati.parse(gt) 114 | final_gts.append(gt) 115 | final_pred = wakati.parse(pred.replace("・", " ").replace("、", ",")) 116 | else: 117 | final_gts = gts 118 | final_pred = pred 119 | lang_dict[lang]["f1"] += metric_max_over_ground_truths( 120 | f1_score, final_pred, final_gts) 121 | lang_dict[lang]["bleu"] += bleu(final_gts, pred) 122 | lang_dict[lang]["em"] += metric_max_over_ground_truths( 123 | exact_match_score, final_pred, final_gts) 124 | # finalize scores 125 | for lang, scores in lang_dict.items(): 126 | if scores["count"] == 0: 127 | continue 128 | for score_key in scores: 129 | if "count" != score_key: 130 | lang_dict[lang][score_key] = scores[score_key]/scores["count"] 131 | return lang_dict 132 | 133 | 134 | def main(): 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument("--data_file", 137 | default=None, type=str) 138 | parser.add_argument("--pred_file", 139 | default=None, type=str) 140 | 141 | args = parser.parse_args() 142 | 143 | dataset = read_jsonlines(args.data_file) 144 | with open(args.pred_file) as prediction_file: 145 | predictions = json.load(prediction_file) 146 | 147 | results = calculate_f1_em_bleu(dataset, predictions) 148 | 149 | f1_total, em_total, bleu_total = 0.0, 0.0, 0.0 150 | total_num = 0 151 | lang_count = 0 152 | for lang in results: 153 | if results[lang]["count"] == 0: 154 | continue 155 | lang_count += 1 156 | f1_total += results[lang]["f1"] 157 | em_total += results[lang]["em"] 158 | bleu_total += results[lang]["bleu"] 159 | total_num += results[lang]["count"] 160 | print("Evaluating the performance on {0} for {1} examples".format( 161 | lang, results[lang]["count"])) 162 | print("F1: {0}, EM:{1}, BLEU:{2}".format( 163 | results[lang]["f1"] * 100, results[lang]["em"] * 100, results[lang]["bleu"] * 100)) 164 | print("avg f1: {}".format(f1_total / lang_count * 100)) 165 | print("avg em: {}".format(em_total / lang_count * 100)) 166 | print("avg bleu: {}".format(bleu_total / lang_count * 100)) 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /eval_scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | jieba 2 | khmer-nltk 3 | pythainlp 4 | mecab-python3 5 | --------------------------------------------------------------------------------