├── model
│   ├── SEED_Encoder
│   │   ├── __init__.py
│   │   ├── setup.py
│   │   ├── config_decoder_1_attn_8.json
│   │   ├── config_decoder_3_attn_2.json
│   │   ├── SEED-Encoder.md
│   │   ├── configuration_seed_encoder.py
│   │   ├── modeling_seed_encoder.py
│   │   └── tokenization_seed_encoder.py
│   └── models.py
├── CODE_OF_CONDUCT.md
├── setup.py
├── commands
│   ├── run_train_warmup.sh
│   ├── run_train_dpr.sh
│   ├── run_inference.sh
│   ├── run_ann_data_gen_dpr.sh
│   ├── run_ann_data_gen.sh
│   ├── data_download.sh
│   └── run_train.sh
├── LICENSE
├── .gitignore
├── SECURITY.md
├── utils
│   ├── lamb.py
│   ├── eval_mrr.py
│   ├── msmarco_eval.py
│   ├── dpr_utils.py
│   └── util.py
├── data
│   ├── process_fn.py
│   ├── msmarco_data.py
│   └── DPR_data.py
├── README.md
├── evaluation
│   └── Calculate Metrics.ipynb
└── drivers
    └── run_ann_data_gen_dpr.py

/model/SEED_Encoder/__init__.py:
--------------------------------------------------------------------------------
from .tokenization_seed_encoder import *
from .configuration_seed_encoder import *
from .modeling_seed_encoder import *
--------------------------------------------------------------------------------
/model/SEED_Encoder/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

with open('SEED-Encoder.md') as f:
    readme = f.read()

setup(
    name='SEED-Encoder',
    long_description=readme,
    install_requires=[
        'scikit-learn',
        'pandas',
        'tensorboardX',
        'tqdm',
        'tokenizers==0.9.2',
        'six',
    ],
)
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

with open('README.md') as f:
    readme = f.read()

setup(
    name='ANCE',
    version='0.1.0',
    description='Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval',
    url='https://github.com/microsoft/ANCE',
    classifiers=[
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    license="MIT",
    long_description=readme,
    install_requires=[
        'transformers==2.3.0',
        'pytrec-eval',
        'faiss-cpu',
        'wget',
    ],
)
--------------------------------------------------------------------------------
/commands/run_train_warmup.sh:
--------------------------------------------------------------------------------
# This script is for training the warmup checkpoint for ANCE
data_dir="../data/raw_data/"
output_dir=""
cmd="python3 -m torch.distributed.launch --nproc_per_node=1 ../drivers/run_warmup.py --train_model_type rdot_nll \
--model_name_or_path roberta-base \
--task_name MSMarco --do_train --evaluate_during_training --data_dir ${data_dir} --max_seq_length 128 --per_gpu_eval_batch_size=256 \
--per_gpu_train_batch_size=32 --learning_rate 2e-4 --logging_steps 1000 --num_train_epochs 2.0 --output_dir ${output_dir} \
--warmup_steps 1000 --overwrite_output_dir --save_steps 30000 --gradient_accumulation_steps 1 --expected_train_size 35000000 --logging_steps_per_eval 20 \
--fp16 --optimizer lamb --log_dir ~/tensorboard/${DLWS_JOB_ID}/logs/OSpass "

echo $cmd
eval $cmd
--------------------------------------------------------------------------------
/model/SEED_Encoder/config_decoder_1_attn_8.json:
--------------------------------------------------------------------------------
{
    "architectures": [
        "SEEDEncoderForMaskedLM"
    ],
    "pad_token_id" : 1,
    "vocab_size" : 32769,
    "encoder_layers" : 12,
    "encoder_embed_dim" : 768,
    "encoder_ffn_embed_dim" : 3072,
    "encoder_attention_heads" : 12,

    "dropout" : 0.1,
    "attention_dropout" : 0.1,
    "activation_dropout" : 0.0,
    "encoder_layerdrop" : 0.0,
    "max_positions" : 512,
    "activation_fn" : "gelu",
    "quant_noise_pq" : 0.0,
    "quant_noise_pq_block_size" : 8,


    "train_ratio" : "0.5:0.5",
    "decoder_atten_window" : 8,
    "pooler_activation_fn" : "tanh",
    "pooler_dropout" : 0.0,



    "decoder_layers" : 1,

    "decoder_embed_dim" : 768,
    "decoder_ffn_embed_dim" : 3072,
    "decoder_attention_heads" : 12,

    "attention_dropout" : 0.1,
    "activation_dropout" : 0.0,

    "adaptive_softmax_dropout" : 0
}
--------------------------------------------------------------------------------
/model/SEED_Encoder/config_decoder_3_attn_2.json:
--------------------------------------------------------------------------------
{
    "architectures": [
3 | "SEEDEncoderForMaskedLM" 4 | ], 5 | "pad_token_id" : 1, 6 | "vocab_size" : 32769, 7 | "encoder_layers" : 12, 8 | "encoder_embed_dim" : 768, 9 | "encoder_ffn_embed_dim" : 3072, 10 | "encoder_attention_heads" : 12, 11 | 12 | "dropout" : 0.1, 13 | "attention_dropout" : 0.1, 14 | "activation_dropout" : 0.0, 15 | "encoder_layerdrop" : 0.0, 16 | "max_positions" : 512, 17 | "activation_fn" : "gelu", 18 | "quant_noise_pq" : 0.0, 19 | "quant_noise_pq_block_size" : 8, 20 | 21 | 22 | "train_ratio" : "0.5:0.5", 23 | "decoder_atten_window" : 2, 24 | "pooler_activation_fn" : "tanh", 25 | "pooler_dropout" : 0.0, 26 | 27 | 28 | 29 | "decoder_layers" : 3, 30 | 31 | "decoder_embed_dim" : 768, 32 | "decoder_ffn_embed_dim" : 3072, 33 | "decoder_attention_heads" : 12, 34 | 35 | "attention_dropout" : 0.1, 36 | "activation_dropout" : 0.0, 37 | 38 | "adaptive_softmax_dropout" : 0 39 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/commands/run_train_dpr.sh:
--------------------------------------------------------------------------------
gpu_no=8

# model type
model_type="dpr"
seq_length=256
triplet="--triplet --optimizer lamb" # set this to empty for a non-triplet model

# hyper parameters
batch_size=16
gradient_accumulation_steps=1
learning_rate=1e-5
warmup_steps=1000

# input/output directories
base_data_dir="../data/QA_NQ_data/"
job_name="ann_NQ_test"
model_dir="${base_data_dir}${job_name}/"
model_ann_data_dir="${model_dir}ann_data/"
pretrained_checkpoint_dir="../../../DPR/checkpoint/retriever/multiset/bert-base-encoder.cp"

train_cmd="\
sudo python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_dpr.py --model_type $model_type \
--model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco $triplet --data_dir $base_data_dir \
--ann_dir $model_ann_data_dir --max_seq_length $seq_length --per_gpu_train_batch_size=$batch_size \
--gradient_accumulation_steps $gradient_accumulation_steps --learning_rate $learning_rate --output_dir $model_dir \
--warmup_steps $warmup_steps --logging_steps 100 --save_steps 1000 --log_dir "~/tensorboard/${DLWS_JOB_ID}/logs/${job_name}" \
"

echo $train_cmd
eval $train_cmd

echo "copy current script to model directory"
sudo cp $0 $model_dir
--------------------------------------------------------------------------------
/commands/run_inference.sh:
--------------------------------------------------------------------------------
# # Passage ANCE(FirstP)
gpu_no=4
seq_length=512
model_type=rdot_nll
tokenizer_type="roberta-base"
base_data_dir="../data/raw_data/"
preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}_dev/"
job_name="OSPass512"
pretrained_checkpoint_dir=""

# # Document ANCE(FirstP)
# gpu_no=4
# seq_length=512
# model_type=rdot_nll
# tokenizer_type="roberta-base"
# base_data_dir="../data/raw_data/"
# preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
# job_name="OSDoc512"
# pretrained_checkpoint_dir=""

# # Document ANCE(MaxP)
# gpu_no=4
# seq_length=2048
# model_type=rdot_nll_multi_chunk
# tokenizer_type="roberta-base"
# base_data_dir="../data/raw_data/"
# preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
# job_name="OSDoc2048"
# pretrained_checkpoint_dir=""

##################################### Inference ################################
model_dir="${base_data_dir}${job_name}/"
model_ann_data_dir="${model_dir}ann_data_inf/"

initial_data_gen_cmd="\
python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $pretrained_checkpoint_dir \
--init_model_dir $pretrained_checkpoint_dir --model_type $model_type --output_dir $model_ann_data_dir \
--cache_dir "${model_ann_data_dir}cache/" --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
--per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 --end_output_num 0 --inference \
"

echo $initial_data_gen_cmd
eval $initial_data_gen_cmd
--------------------------------------------------------------------------------
/commands/run_ann_data_gen_dpr.sh:
--------------------------------------------------------------------------------
# tokenization
wiki_dir="../../../DPR/data/wikipedia_split/" # path for psgs_w100.tsv downloaded with DPR code
ans_dir="../../../DPR/data/retriever/qas/" # path for DPR question&answer csv files
question_dir="../../../DPR/data/retriever/" # path for DPR training data
data_type=0 # 0 is nq, 1 is trivia, 2 is both
out_data_dir="../data/QA_NQ_data/" # change this for different data_type

tokenization_cmd="\
python ../data/DPR_data.py --wiki_dir $wiki_dir --question_dir $question_dir --data_type $data_type --answer_dir $ans_dir \
--out_data_dir $out_data_dir \
"

echo $tokenization_cmd
eval $tokenization_cmd


gpu_no=8

# model type
model_type="dpr"
seq_length=256

# ann parameters
batch_size=16
ann_topk=200
ann_negative_sample=100

# input/output directories
base_data_dir="${out_data_dir}"
job_name="ann_NQ_test"
model_dir="${base_data_dir}${job_name}/"
model_ann_data_dir="${model_dir}ann_data/"
pretrained_checkpoint_dir="../../../DPR/checkpoint/retriever/multiset/bert-base-encoder.cp"
passage_path="../../../DPR/data/wikipedia_split/"
test_qa_path="../../../DPR/data/retriever/qas/"
trivia_test_qa_path="../../../DPR/data/retriever/qas/"


data_gen_cmd="\
sudo python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen_dpr.py --training_dir $model_dir \
--init_model_dir $pretrained_checkpoint_dir --model_type $model_type --output_dir $model_ann_data_dir \
--cache_dir "${model_ann_data_dir}cache/" --data_dir $base_data_dir --max_seq_length $seq_length \
--per_gpu_eval_batch_size $batch_size --topk_training $ann_topk --negative_sample $ann_negative_sample \
--passage_path $passage_path --test_qa_path $test_qa_path --trivia_test_qa_path $trivia_test_qa_path \
"

echo $data_gen_cmd
eval $data_gen_cmd
--------------------------------------------------------------------------------
/commands/run_ann_data_gen.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# This script is for generating ANN data for a model in training
#
# For the overall design of the ann driver, check run_train.sh
#
# This script continuously generates ann data using the latest model from model_dir
# For training, run this script after initial ann data is created from run_train.sh
# Make sure the parameters used here are consistent with the training script

# # Passage ANCE(FirstP)
# gpu_no=4
# seq_length=512
# model_type=rdot_nll
# tokenizer_type="roberta-base"
# base_data_dir="../data/raw_data/"
# preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
# job_name="OSPass512"


# # Document ANCE(FirstP)
# gpu_no=4
# seq_length=512
# model_type=rdot_nll
# tokenizer_type="roberta-base"
# base_data_dir="../data/raw_data/"
# preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
# job_name="OSDoc512"
job_name="OSDoc512" 29 | 30 | # # Document ANCE(MaxP) 31 | gpu_no=4 32 | seq_length=2048 33 | model_type=rdot_nll_multi_chunk 34 | tokenizer_type="roberta-base" 35 | base_data_dir="../data/raw_data/" 36 | preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/" 37 | job_name="OSDoc2048" 38 | 39 | ##################################### Inital ANN Data generation ################################ 40 | model_dir="${base_data_dir}${job_name}/" 41 | model_ann_data_dir="${model_dir}ann_data/" 42 | pretrained_checkpoint_dir="warmup checkpoint path" 43 | 44 | initial_data_gen_cmd="\ 45 | python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \ 46 | --init_model_dir $pretrained_checkpoint_dir --model_type $model_type --output_dir $model_ann_data_dir \ 47 | --cache_dir "${model_ann_data_dir}cache/" --data_dir $preprocessed_data_dir --max_seq_length $seq_length \ 48 | --per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 \ 49 | " 50 | 51 | echo $initial_data_gen_cmd 52 | eval $initial_data_gen_cmd 53 | -------------------------------------------------------------------------------- /commands/data_download.sh: -------------------------------------------------------------------------------- 1 | mkdir ../data/raw_data/ 2 | cd ../data/raw_data/ 3 | 4 | # download MSMARCO passage data 5 | wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz 6 | tar -zxvf collectionandqueries.tar.gz 7 | rm collectionandqueries.tar.gz 8 | 9 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz 10 | gunzip msmarco-passagetest2019-top1000.tsv.gz 11 | 12 | wget https://msmarco.blob.core.windows.net/msmarcoranking/top1000.dev.tar.gz 13 | tar -zxvf top1000.dev.tar.gz 14 | rm top1000.dev.tar.gz 15 | 16 | wget https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz 17 | tar -zxvf triples.train.small.tar.gz 18 | rm triples.train.small.tar.gz 19 | 20 | # download MSMARCO doc data 21 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz 22 | gunzip msmarco-docs.tsv.gz 23 | 24 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz 25 | gunzip msmarco-doctrain-queries.tsv.gz 26 | 27 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz 28 | gunzip msmarco-doctrain-qrels.tsv.gz 29 | 30 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz 31 | gunzip msmarco-test2019-queries.tsv.gz 32 | 33 | wget https://trec.nist.gov/data/deep/2019qrels-docs.txt 34 | 35 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctest2019-top100.gz 36 | gunzip msmarco-doctest2019-top100.gz 37 | 38 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-top100.gz 39 | gunzip msmarco-docdev-top100.gz 40 | 41 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz 42 | gunzip msmarco-docdev-queries.tsv.gz 43 | 44 | 45 | # clone DPR repo and download NQ and TriviaQA datasets 46 | cd ../../../ 47 | git clone https://github.com/facebookresearch/DPR 48 | cd DPR 49 | python data/download_data.py --resource data.wikipedia_split.psgs_w100 50 | python data/download_data.py --resource data.retriever.nq 51 | python data/download_data.py --resource data.retriever.trivia 52 | python data/download_data.py --resource data.retriever.qas.nq 53 | python data/download_data.py 
python data/download_data.py --resource checkpoint.retriever.multiset.bert-base-encoder


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
commands/runs/
runs/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------


## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

--------------------------------------------------------------------------------
/commands/run_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# This script is for training with the updated ann driver
#
# The design for this ann driver is to have 2 separate processes for training: one for passage/query
# inference using the trained checkpoint to generate ann data and calculate NDCG, another for training the model
# using the ann data generated. Data between processes is shared on a common directory, model_dir for checkpoints
# and model_ann_data_dir for ann data.
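#
# For illustration only (a sketch, not part of the original instructions): the two
# processes are this script and run_ann_data_gen.sh launched as two separate jobs
# that share model_dir and model_ann_data_dir, e.g.:
#
#   bash run_train.sh          # job 1: preprocess, create initial ann data, then train
#   bash run_ann_data_gen.sh   # job 2: keep regenerating ann data from the newest checkpoint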
#
# This script initializes the training and starts the model training process
# It first preprocesses the msmarco data into an indexable cache, then generates a single initial ann data
# version to train on, after which it starts training on the generated ann data, continuously looking for the
# newest ann data generated in model_ann_data_dir
#
# To start training, you'll need to run this script first
# after initial ann data is created (you can tell by either finding "successfully created
# initial ann training data" in console output or if you start seeing a new model on tensorboard),
# start run_ann_data_gen.sh in another dlts job (or the same dlts job using split GPU)
#
# Note if the preprocess directory or ann data directory already exists, those steps will be skipped
# and training will start immediately

# # Passage ANCE(FirstP)
# gpu_no=4
# seq_length=512
# model_type=rdot_nll
# tokenizer_type="roberta-base"
# base_data_dir="../data/raw_data/"
# preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
# job_name="OSPass512"
# pretrained_checkpoint_dir="warmup or trained checkpoint path"
# data_type=1
# warmup_steps=5000
# per_gpu_train_batch_size=8
# gradient_accumulation_steps=2
# learning_rate=1e-6

# # Document ANCE(FirstP)
# gpu_no=4
# seq_length=512
# tokenizer_type="roberta-base"
# model_type=rdot_nll
# base_data_dir="../data/raw_data/"
# preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
# job_name="OSDoc512"
# pretrained_checkpoint_dir="warmup or trained checkpoint path"
# data_type=0
# warmup_steps=3000
# per_gpu_train_batch_size=8
# gradient_accumulation_steps=2
# learning_rate=5e-6

# # Document ANCE(MaxP)
gpu_no=8
seq_length=2048
tokenizer_type="roberta-base"
model_type=rdot_nll_multi_chunk
base_data_dir="../data/raw_data/"
preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
job_name="OSDoc2048"
pretrained_checkpoint_dir="warmup or trained checkpoint path"
data_type=0
warmup_steps=500
per_gpu_train_batch_size=2
gradient_accumulation_steps=8
learning_rate=1e-5

##################################### Data Preprocessing ################################
model_dir="${base_data_dir}${job_name}/"
model_ann_data_dir="${model_dir}ann_data/"

preprocess_cmd="\
python ../data/msmarco_data.py --data_dir $base_data_dir --out_data_dir $preprocessed_data_dir --model_type $model_type \
--model_name_or_path roberta-base --max_seq_length $seq_length --data_type $data_type\
"

echo $preprocess_cmd
eval $preprocess_cmd

if [[ $? = 0 ]]; then
    echo "successfully created preprocessed data"
else
    echo "preprocessing failed"
    echo "failure: $?"
    exit 1
fi

##################################### Initial ANN Data generation ################################
initial_data_gen_cmd="\
python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
--init_model_dir $pretrained_checkpoint_dir --model_type $model_type --output_dir $model_ann_data_dir \
--cache_dir "${model_ann_data_dir}cache/" --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
--per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 --end_output_num 0 \
"

echo $initial_data_gen_cmd
eval $initial_data_gen_cmd

if [[ $? = 0 ]]; then
    echo "successfully created initial ann training data"
else
    echo "initial data generation failed"
    echo "failure: $?"
    exit 1
fi

############################################# Training ########################################
train_cmd="\
python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann.py --model_type $model_type \
--model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
--ann_dir $model_ann_data_dir --max_seq_length $seq_length --per_gpu_train_batch_size=$per_gpu_train_batch_size \
--gradient_accumulation_steps $gradient_accumulation_steps --learning_rate $learning_rate --output_dir $model_dir \
--warmup_steps $warmup_steps --logging_steps 100 --save_steps 10000 --optimizer lamb --single_warmup \
"

echo $train_cmd
eval $train_cmd
--------------------------------------------------------------------------------
/utils/lamb.py:
--------------------------------------------------------------------------------
"""Lamb optimizer."""

import collections
import math

import torch
from tensorboardX import SummaryWriter
from torch.optim import Optimizer


def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
    """Log a histogram of trust ratio scalars across layers."""
    results = collections.defaultdict(list)
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
                if i in state:
                    results[i].append(state[i])

    for k, v in results.items():
        event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)

class Lamb(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Paper v3 does not use debiasing.
                # Apply bias to lr to avoid broadcast.
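                # What follows is the LAMB update: the Adam-style step is rescaled
                # by a per-layer trust ratio weight_norm / adam_norm (with the
                # weight norm clamped to [0, 10] below), so the final update is
                # p <- p - lr * trust_ratio * adam_step.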
                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(group['weight_decay'], p.data)

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(-step_size * trust_ratio, adam_step)

        return loss
--------------------------------------------------------------------------------
/data/process_fn.py:
--------------------------------------------------------------------------------
import torch


def pad_ids(input_ids, attention_mask, token_type_ids, max_length, pad_token, mask_padding_with_zero, pad_token_segment_id, pad_on_left=False):
    padding_length = max_length - len(input_ids)
    if pad_on_left:
        input_ids = ([pad_token] * padding_length) + input_ids
        attention_mask = ([0 if mask_padding_with_zero else 1]
                          * padding_length) + attention_mask
        token_type_ids = ([pad_token_segment_id] *
                          padding_length) + token_type_ids
    else:
        input_ids += [pad_token] * padding_length
        attention_mask += [0 if mask_padding_with_zero else 1] * padding_length
        token_type_ids += [pad_token_segment_id] * padding_length

    return input_ids, attention_mask, token_type_ids


def dual_process_fn(line, i, tokenizer, args):
    features = []
    cells = line.split("\t")
    if len(cells) == 2:
        # this is for training and validation
        # id, passage = line
        mask_padding_with_zero = True
        pad_token_segment_id = 0
        pad_on_left = False

        text = cells[1].strip()
        input_id_a = tokenizer.encode(
            text, add_special_tokens=True, max_length=args.max_seq_length,)
        token_type_ids_a = [0] * len(input_id_a)
        attention_mask_a = [
            1 if mask_padding_with_zero else 0] * len(input_id_a)
        input_id_a, attention_mask_a, token_type_ids_a = pad_ids(
            input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id, mask_padding_with_zero, pad_token_segment_id, pad_on_left)
        features += [torch.tensor(input_id_a, dtype=torch.int), torch.tensor(
            attention_mask_a, dtype=torch.bool), torch.tensor(token_type_ids_a, dtype=torch.uint8)]
        qid = int(cells[0])
        features.append(qid)
    else:
        raise Exception(
            "Line doesn't have correct length: {0}. Expected 2.".format(str(len(cells))))
Expected 2.".format(str(len(cells)))) 45 | return [features] 46 | 47 | 48 | def triple_process_fn(line, i, tokenizer, args): 49 | features = [] 50 | cells = line.split("\t") 51 | if len(cells) == 3: 52 | # this is for training and validation 53 | # query, positive_passage, negative_passage = line 54 | mask_padding_with_zero = True 55 | pad_token_segment_id = 0 56 | pad_on_left = False 57 | 58 | for text in cells: 59 | input_id_a = tokenizer.encode( 60 | text.strip(), add_special_tokens=True, max_length=args.max_seq_length,) 61 | token_type_ids_a = [0] * len(input_id_a) 62 | attention_mask_a = [ 63 | 1 if mask_padding_with_zero else 0] * len(input_id_a) 64 | input_id_a, attention_mask_a, token_type_ids_a = pad_ids( 65 | input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id, mask_padding_with_zero, pad_token_segment_id, pad_on_left) 66 | features += [torch.tensor(input_id_a, dtype=torch.int), 67 | torch.tensor(attention_mask_a, dtype=torch.bool)] 68 | else: 69 | raise Exception( 70 | "Line doesn't have correct length: {0}. Expected 3.".format(str(len(cells)))) 71 | return [features] 72 | 73 | 74 | def triple2dual_process_fn(line, i, tokenizer, args): 75 | ret = [] 76 | cells = line.split("\t") 77 | if len(cells) == 3: 78 | # this is for training and validation 79 | # query, positive_passage, negative_passage = line 80 | # return 2 entries per line, 1 pos + 1 neg 81 | mask_padding_with_zero = True 82 | pad_token_segment_id = 0 83 | pad_on_left = False 84 | pos_feats = [] 85 | neg_feats = [] 86 | 87 | for i, text in enumerate(cells): 88 | input_id_a = tokenizer.encode( 89 | text.strip(), add_special_tokens=True, max_length=args.max_seq_length,) 90 | token_type_ids_a = [0] * len(input_id_a) 91 | attention_mask_a = [ 92 | 1 if mask_padding_with_zero else 0] * len(input_id_a) 93 | input_id_a, attention_mask_a, token_type_ids_a = pad_ids( 94 | input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id, mask_padding_with_zero, pad_token_segment_id, pad_on_left) 95 | if i == 0: 96 | pos_feats += [torch.tensor(input_id_a, dtype=torch.int), 97 | torch.tensor(attention_mask_a, dtype=torch.bool)] 98 | neg_feats += [torch.tensor(input_id_a, dtype=torch.int), 99 | torch.tensor(attention_mask_a, dtype=torch.bool)] 100 | elif i == 1: 101 | pos_feats += [torch.tensor(input_id_a, dtype=torch.int), 102 | torch.tensor(attention_mask_a, dtype=torch.bool), 1] 103 | else: 104 | neg_feats += [torch.tensor(input_id_a, dtype=torch.int), 105 | torch.tensor(attention_mask_a, dtype=torch.bool), 0] 106 | ret = [pos_feats, neg_feats] 107 | else: 108 | raise Exception( 109 | "Line doesn't have correct length: {0}. Expected 3.".format(str(len(cells)))) 110 | return ret 111 | 112 | -------------------------------------------------------------------------------- /model/SEED_Encoder/SEED-Encoder.md: -------------------------------------------------------------------------------- 1 | This repository provides the fine-tuning stage on Marco ranking task for [SEED-Encoder](https://arxiv.org/abs/2102.09206) and is based on ANCE (https://github.com/microsoft/ANCE). 

# Requirements and Installation

* [PyTorch](http://pytorch.org/) version >= 1.4.0
* Python version >= 3.6

## Requirements

To install requirements, run the following commands:

```setup
cd SEED_Encoder
python setup.py install
```



# Fine-tuning for SEED-Encoder
* We follow the ranking experiments in ANCE ([Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval](https://arxiv.org/pdf/2007.00808.pdf)) as our downstream tasks.




## Our Checkpoints
[Pretrained SEED-Encoder with 3-layer decoder, attention span = 2](https://fastbertjp.blob.core.windows.net/release-model/SEED-Encoder-3-decoder-layers.tar)

[Pretrained SEED-Encoder with 1-layer decoder, attention span = 8](https://fastbertjp.blob.core.windows.net/release-model/SEED-Encoder-1-decoder-layer.tar)

[SEED-Encoder warmup checkpoint](https://fastbertjp.blob.core.windows.net/release-model/SEED-Encoder-warmup-90000.tar)

[ANCE finetuned SEED-Encoder checkpoint on passage ranking task](https://fastbertjp.blob.core.windows.net/release-model/SEED-Encoder-pass-440000.tar)

[ANCE finetuned SEED-Encoder checkpoint on document ranking task](https://fastbertjp.blob.core.windows.net/release-model/SEED-Encoder-doc-800000.tar)




## Data Preprocessing

    seq_length=512
    tokenizer_type="seed-encoder"
    base_data_dir={}
    data_type={}
    model_path={}
    preprocessed_data_dir=${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/

    preprocess_cmd="\
    python ../data/msmarco_data.py --data_dir $base_data_dir --out_data_dir $preprocessed_data_dir --model_name_or_path $model_path --model_type seeddot_nll --max_seq_length $seq_length --data_type $data_type "

    echo $preprocess_cmd
    eval $preprocess_cmd


## Warmup for Training

    DATA_DIR=../../data/raw_data
    SAVE_DIR=../../temp/
    LOAD_DIR=$your_path/SEED-Encoder-1-decoder-layer/

    python3 -m torch.distributed.launch --nproc_per_node=8 ../drivers/run_warmup.py \
    --train_model_type seeddot_nll --model_name_or_path $LOAD_DIR --task_name MSMarco --do_train \
    --evaluate_during_training --data_dir $DATA_DIR \
    --max_seq_length 128 --per_gpu_eval_batch_size=256 --per_gpu_train_batch_size=32 --learning_rate 2e-4 --logging_steps 100 --num_train_epochs 2.0 \
    --output_dir $SAVE_DIR --warmup_steps 1000 --overwrite_output_dir --save_steps 20000 --gradient_accumulation_steps 1 --expected_train_size 35000000 \
    --logging_steps_per_eval 100 --fp16 --optimizer lamb --log_dir $SAVE_DIR/log --do_lower_case


    DATA_DIR=../../data/raw_data
    SAVE_DIR=../../temp/
    LOAD_DIR=$your_path/SEED-Encoder-3-decoder-layers/

    python3 -m torch.distributed.launch --nproc_per_node=8 ../drivers/run_warmup.py \
    --train_model_type seeddot_nll --model_name_or_path $LOAD_DIR --task_name MSMarco --do_train \
    --evaluate_during_training --data_dir $DATA_DIR \
    --max_seq_length 128 --per_gpu_eval_batch_size=256 --per_gpu_train_batch_size=32 --learning_rate 2e-4 --logging_steps 100 --num_train_epochs 2.0 \
    --output_dir $SAVE_DIR --warmup_steps 1000 --overwrite_output_dir --save_steps 20000 --gradient_accumulation_steps 1 --expected_train_size 35000000 \
    --logging_steps_per_eval 100 --fp16 --optimizer lamb --log_dir $SAVE_DIR/log --do_lower_case

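Before launching the ANCE runs below, it can help to sanity-check that a released config file loads with the bundled classes. The following is a minimal sketch rather than part of the original pipeline: the import path follows `model/SEED_Encoder/__init__.py`, and the printed attribute names assume the config class forwards the JSON keys as attributes.

```python
# Minimal sketch: load one of the released SEED-Encoder config files.
# Assumes this is run from the repository root so `model.SEED_Encoder` is importable.
from model.SEED_Encoder import SEEDEncoderConfig

config = SEEDEncoderConfig.from_json_file(
    "model/SEED_Encoder/config_decoder_3_attn_2.json")

# Attribute names mirror the JSON keys (assumed to be set by the config's __init__).
print(config.encoder_layers, config.decoder_layers, config.decoder_atten_window)
```
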
## ANCE Training (passage; you may first need to run the `run_ann_data_gen.py` command below to generate the initial training data)

    gpu_no=4
    seq_length=512
    tokenizer_type={}
    model_type=seeddot_nll
    base_data_dir={}
    preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
    job_name=21_09_04_try
    pretrained_checkpoint_dir=${your_model_dir}/SEED-Encoder-pass-440000/
    data_type=1
    warmup_steps=5000
    per_gpu_train_batch_size=16
    gradient_accumulation_steps=1
    learning_rate=1e-6

    model_dir="${base_data_dir}${job_name}/"
    model_ann_data_dir="${model_dir}ann_data/"


    CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
    --init_model_dir $pretrained_checkpoint_dir --model_type $model_type --output_dir $model_ann_data_dir \
    --cache_dir "${model_dir}cache/" --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
    --per_gpu_eval_batch_size 64 --topk_training 200 --negative_sample 20


    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$gpu_no --master_addr 127.0.0.2 --master_port 35000 ../drivers/run_ann.py --model_type $model_type \
    --model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
    --ann_dir $model_ann_data_dir --max_seq_length $seq_length --per_gpu_train_batch_size=$per_gpu_train_batch_size \
    --gradient_accumulation_steps $gradient_accumulation_steps --learning_rate $learning_rate --output_dir $model_dir \
    --warmup_steps $warmup_steps --logging_steps 100 --save_steps 100000000 --optimizer lamb --single_warmup --cache_dir "${model_dir}cache/" --do_lower_case



## ANCE Training (document)

    gpu_no=4
    seq_length=512
    tokenizer_type={}
    model_type=seeddot_nll
    base_data_dir={}
    preprocessed_data_dir="${base_data_dir}ann_data_${tokenizer_type}_${seq_length}/"
    job_name=21_09_04_try2
    pretrained_checkpoint_dir=${your_model_dir}/SEED-Encoder-doc-800000/
    data_type=0
    warmup_steps=3000
    per_gpu_train_batch_size=4
    gradient_accumulation_steps=4
    learning_rate=5e-6

    model_dir="${base_data_dir}${job_name}/"
    model_ann_data_dir="${model_dir}ann_data/"



    CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
    --init_model_dir $pretrained_checkpoint_dir --model_type $model_type --output_dir $model_ann_data_dir \
    --cache_dir "${model_dir}cache/" --data_dir $preprocessed_data_dir --max_seq_length $seq_length \
    --per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20



    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$gpu_no --master_addr 127.0.0.2 --master_port 35000 ../drivers/run_ann.py --model_type $model_type \
    --model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
    --ann_dir $model_ann_data_dir --max_seq_length $seq_length --per_gpu_train_batch_size=$per_gpu_train_batch_size \
    --gradient_accumulation_steps $gradient_accumulation_steps --learning_rate $learning_rate --output_dir $model_dir \
    --warmup_steps $warmup_steps --logging_steps 100 --save_steps 100000000 --optimizer lamb --single_warmup --cache_dir "${model_dir}cache/" --do_lower_case
--cache_dir "${model_dir}cache/" --do_lower_case 153 | 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /utils/eval_mrr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path += ["../"] 3 | from utils.msmarco_eval import quality_checks_qids, compute_metrics, load_reference 4 | import torch.distributed as dist 5 | import gzip 6 | import faiss 7 | import numpy as np 8 | from data.process_fn import dual_process_fn 9 | from tqdm import tqdm 10 | import torch 11 | import os 12 | from utils.util import concat_key, is_first_worker, all_gather, StreamingDataset 13 | from torch.utils.data import DataLoader 14 | 15 | 16 | def embedding_inference(args, path, model, fn, bz, num_workers=2, is_query=True): 17 | f = open(path, encoding="utf-8") 18 | model = model.module if hasattr(model, "module") else model 19 | sds = StreamingDataset(f, fn) 20 | loader = DataLoader(sds, batch_size=bz, num_workers=1) 21 | emb_list, id_list = [], [] 22 | model.eval() 23 | for i, batch in tqdm(enumerate(loader), desc="Eval", disable=args.local_rank not in [-1, 0]): 24 | batch = tuple(t.to(args.device) for t in batch) 25 | with torch.no_grad(): 26 | inputs = {"input_ids": batch[0].long( 27 | ), "attention_mask": batch[1].long()} 28 | idx = batch[3].long() 29 | if is_query: 30 | embs = model.query_emb(**inputs) 31 | else: 32 | embs = model.body_emb(**inputs) 33 | if len(embs.shape) == 3: 34 | B, C, E = embs.shape 35 | # [b1c1, b1c2, b1c3, b1c4, b2c1 ....] 36 | embs = embs.view(B*C, -1) 37 | idx = idx.repeat_interleave(C) 38 | 39 | assert embs.shape[0] == idx.shape[0] 40 | emb_list.append(embs.detach().cpu().numpy()) 41 | id_list.append(idx.detach().cpu().numpy()) 42 | f.close() 43 | emb_arr = np.concatenate(emb_list, axis=0) 44 | id_arr = np.concatenate(id_list, axis=0) 45 | 46 | return emb_arr, id_arr 47 | 48 | 49 | def parse_top_dev(input_path, qid_col, pid_col): 50 | ret = {} 51 | with open(input_path, encoding="utf-8") as f: 52 | for line in f: 53 | cells = line.strip().split("\t") 54 | qid = int(cells[qid_col]) 55 | pid = int(cells[pid_col]) 56 | if qid not in ret: 57 | ret[qid] = [] 58 | ret[qid].append(pid) 59 | return ret 60 | 61 | 62 | def search_knn(xq, xb, k, distance_type=faiss.METRIC_L2): 63 | """ wrapper around the faiss knn functions without index """ 64 | nq, d = xq.shape 65 | nb, d2 = xb.shape 66 | assert d == d2 67 | 68 | I = np.empty((nq, k), dtype='int64') 69 | D = np.empty((nq, k), dtype='float32') 70 | 71 | if distance_type == faiss.METRIC_L2: 72 | heaps = faiss.float_maxheap_array_t() 73 | heaps.k = k 74 | heaps.nh = nq 75 | heaps.val = faiss.swig_ptr(D) 76 | heaps.ids = faiss.swig_ptr(I) 77 | faiss.knn_L2sqr( 78 | faiss.swig_ptr(xq), faiss.swig_ptr(xb), 79 | d, nq, nb, heaps 80 | ) 81 | elif distance_type == faiss.METRIC_INNER_PRODUCT: 82 | heaps = faiss.float_minheap_array_t() 83 | heaps.k = k 84 | heaps.nh = nq 85 | heaps.val = faiss.swig_ptr(D) 86 | heaps.ids = faiss.swig_ptr(I) 87 | faiss.knn_inner_product( 88 | faiss.swig_ptr(xq), faiss.swig_ptr(xb), 89 | d, nq, nb, heaps 90 | ) 91 | return D, I 92 | 93 | 94 | def get_topk_restricted(q_emb, psg_emb_arr, pid_dict, psg_ids, pid_subset, top_k): 95 | subset_ix = np.array([pid_dict[x] 96 | for x in pid_subset if x != -1 and x in pid_dict]) 97 | if len(subset_ix) == 0: 98 | _D = np.ones((top_k,))*-128 99 | _I = (np.ones((top_k,))*-1).astype(int) 100 | return _D, _I 101 | else: 102 | sub_emb = psg_emb_arr[subset_ix] 103 | _D, _I 
                            distance_type=faiss.METRIC_INNER_PRODUCT)
        return _D.squeeze(), psg_ids[subset_ix[_I]].squeeze()  # (top_k,)


def passage_dist_eval(args, model, tokenizer):
    base_path = args.data_dir
    passage_path = os.path.join(base_path, "collection.tsv")
    queries_path = os.path.join(base_path, "queries.dev.small.tsv")

    def fn(line, i):
        return dual_process_fn(line, i, tokenizer, args)

    top1000_path = os.path.join(base_path, "top1000.dev")
    top1k_qid_pid = parse_top_dev(top1000_path, qid_col=0, pid_col=1)

    mrr_ref_path = os.path.join(base_path, "qrels.dev.small.tsv")
    ref_dict = load_reference(mrr_ref_path)

    reranking_mrr, full_ranking_mrr = combined_dist_eval(
        args, model, queries_path, passage_path, fn, fn, top1k_qid_pid, ref_dict)
    return reranking_mrr, full_ranking_mrr


def combined_dist_eval(args, model, queries_path, passage_path, query_fn, psg_fn, topk_dev_qid_pid, ref_dict):
    # get query/psg embeddings here
    eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    query_embs, query_ids = embedding_inference(
        args, queries_path, model, query_fn, eval_batch_size, 1, True)
    query_pkl = {"emb": query_embs, "id": query_ids}
    all_query_list = all_gather(query_pkl)
    query_embs = concat_key(all_query_list, "emb")
    query_ids = concat_key(all_query_list, "id")
    print(query_embs.shape, query_ids.shape)
    psg_embs, psg_ids = embedding_inference(
        args, passage_path, model, psg_fn, eval_batch_size, 2, False)
    print(psg_embs.shape)

    top_k = 100
    D, I = search_knn(query_embs, psg_embs, top_k,
                      distance_type=faiss.METRIC_INNER_PRODUCT)
    I = psg_ids[I]

    # compute reranking and full ranking mrr here
    # topk_dev_qid_pid is used for computing reranking mrr
    pid_dict = dict([(p, i) for i, p in enumerate(psg_ids)])
    arr_data = []
    d_data = []
    for i, qid in enumerate(query_ids):
        q_emb = query_embs[i:i+1]
        pid_subset = topk_dev_qid_pid[qid]
        ds, top_pids = get_topk_restricted(
            q_emb, psg_embs, pid_dict, psg_ids, pid_subset, 10)
        arr_data.append(top_pids)
        d_data.append(ds)
    _D = np.array(d_data)
    _I = np.array(arr_data)

    # reranking mrr
    reranking_mrr = compute_mrr(_D, _I, query_ids, ref_dict)
    D2 = D[:, :100]
    I2 = I[:, :100]
    # full mrr
    full_ranking_mrr = compute_mrr(D2, I2, query_ids, ref_dict)
    del psg_embs
    torch.cuda.empty_cache()
    dist.barrier()
    return reranking_mrr, full_ranking_mrr


def compute_mrr(D, I, qids, ref_dict):
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    if is_first_worker():
        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # we pad with negative pids and distance -128 - if they make it to the top we have a problem
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :10]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                candidate_dict[qid] = [0]*1000
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    j += 1
                    seen_pids.add(pid)

        allowed, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)

        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)
    return mrr
--------------------------------------------------------------------------------
/utils/msmarco_eval.py:
--------------------------------------------------------------------------------
"""
This is the official eval script open-sourced on the MS MARCO site (not written or owned by us)

This module computes evaluation metrics for the MS MARCO dataset on the ranking task.
Command line:
python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>

Creation Date : 06/12/2018
Last Modified : 1/21/2019
Authors : Daniel Campos , Rutger van Haasteren
"""
import sys
import statistics

from collections import Counter

MaxMRRRank = 10

def load_reference_from_stream(f):
    """Load reference relevant passages
    Args:f (stream): stream to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    qids_to_relevant_passageids = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            if qid in qids_to_relevant_passageids:
                pass
            else:
                qids_to_relevant_passageids[qid] = []
            qids_to_relevant_passageids[qid].append(int(l[2]))
        except:
            raise IOError('\"%s\" is not valid format' % l)
    return qids_to_relevant_passageids

def load_reference(path_to_reference):
    """Load reference relevant passages
    Args:path_to_reference (str): path to a file to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    with open(path_to_reference,'r') as f:
        qids_to_relevant_passageids = load_reference_from_stream(f)
    return qids_to_relevant_passageids

def load_candidate_from_stream(f):
    """Load candidate data from a stream.
    Args:f (stream): stream to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
    """
    qid_to_ranked_candidate_passages = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            pid = int(l[1])
            rank = int(l[2])
            if qid in qid_to_ranked_candidate_passages:
                pass
            else:
                # By default, all PIDs in the list of 1000 are 0. Only override those that are given
                tmp = [0] * 1000
                qid_to_ranked_candidate_passages[qid] = tmp
            qid_to_ranked_candidate_passages[qid][rank-1]=pid
        except:
            raise IOError('\"%s\" is not valid format' % l)
    return qid_to_ranked_candidate_passages

def load_candidate(path_to_candidate):
    """Load candidate data from a file.
    Args:path_to_candidate (str): path to file to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
    """

    with open(path_to_candidate,'r') as f:
        qid_to_ranked_candidate_passages = load_candidate_from_stream(f)
    return qid_to_ranked_candidate_passages

def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the dictionaries

    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        bool,str: Boolean whether allowed, message to be shown in case of a problem
    """
    message = ''
    allowed = True

    # Create sets of the QIDs for the submitted and reference queries
    candidate_set = set(qids_to_ranked_candidate_passages.keys())
    ref_set = set(qids_to_relevant_passageids.keys())

    # Check that we do not have multiple passages per query
    for qid in qids_to_ranked_candidate_passages:
        # Remove all zeros from the candidates
        duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])

        if len(duplicate_pids-set([0])) > 0:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                    qid=qid, pid=list(duplicate_pids)[0])
            allowed = False

    return allowed, message

def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR metric
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        dict: dictionary of metrics {'MRR': <MRR Score>}
    """
    all_scores = {}
    MRR = 0
    qids_with_relevant_passages = 0
    ranking = []
    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            for i in range(0,MaxMRRRank):
                if candidate_pid[i] in target_pid:
                    MRR += 1/(i + 1)
                    ranking.pop()
                    ranking.append(i+1)
                    break
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    MRR = MRR/len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores

def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Compute MRR metric
    Args:
    p_path_to_reference_file (str): path to reference file.
        Reference file should contain lines in the following format:
            QUERYID\tPASSAGEID
            Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
    p_path_to_candidate_file (str): path to candidate file.
        Candidate file should contain lines in the following format:
            QUERYID\tPASSAGEID1\tRank
        If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
If this flag is used the expected format is
152 | QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
153 | Where the values are separated by tabs and ranked in order of relevance
154 | Returns:
155 | dict: dictionary of metrics {'MRR': <MRR Score>}
156 | """
157 | 
158 | qids_to_relevant_passageids = load_reference(path_to_reference)
159 | qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
160 | if perform_checks:
161 | allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
162 | if message != '': print(message)
163 | 
164 | return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
165 | 
166 | def main():
167 | """Command line:
168 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
169 | """
170 | print("Eval Started")
171 | if len(sys.argv) == 3:
172 | path_to_reference = sys.argv[1]
173 | path_to_candidate = sys.argv[2]
174 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
175 | print('#####################')
176 | for metric in sorted(metrics):
177 | print('{}: {}'.format(metric, metrics[metric]))
178 | print('#####################')
179 | 
180 | else:
181 | print('Usage: msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>')
182 | exit()
183 | 
184 | if __name__ == '__main__':
185 | main()
--------------------------------------------------------------------------------
/model/SEED_Encoder/configuration_seed_encoder.py:
--------------------------------------------------------------------------------
1 | from transformers.configuration_utils import PretrainedConfig
2 | #from transformers.utils import logging
3 | #logger = logging.get_logger(__name__)
4 | 
5 | import logging
6 | logger = logging.getLogger(__name__)
7 | 
8 | # DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
9 | # "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/config.json",
10 | # "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/config.json",
11 | # "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/config.json",
12 | # "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/config.json",
13 | # }
14 | 
15 | 
16 | class SEEDEncoderConfig(PretrainedConfig):
17 | r"""
18 | This is the configuration class to store the configuration of a :class:`SEEDEncoderModel`. It is used to
19 | instantiate a SEED-Encoder model according to the specified arguments, defining the encoder and the
20 | auxiliary weak-decoder architecture. Instantiating a configuration with the defaults will yield the base
21 | SEED-Encoder configuration used in this repo (a 12-layer encoder paired with a 3-layer decoder).
22 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the
23 | model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
24 | Arguments:
25 | vocab_size (:obj:`int`, `optional`, defaults to 32769):
26 | Vocabulary size of the SEED-Encoder model, i.e. the number of different tokens that can be
27 | represented by the :obj:`input_ids` passed to the model.
28 | pad_token_id (:obj:`int`, `optional`, defaults to 1):
29 | The id of the token used to pad :obj:`input_ids`.
30 | encoder_layers (:obj:`int`, `optional`, defaults to 12):
31 | Number of hidden layers in the Transformer encoder.
32 | encoder_embed_dim (:obj:`int`, `optional`, defaults to 768):
33 | Dimensionality of the encoder layers and the pooler layer.
34 | encoder_ffn_embed_dim (:obj:`int`, `optional`, defaults to 3072):
35 | Dimensionality of the "intermediate" (feed-forward) layer in the Transformer encoder.
36 | encoder_attention_heads (:obj:`int`, `optional`, defaults to 12):
37 | Number of attention heads for each attention layer in the Transformer encoder.
38 | dropout (:obj:`float`, `optional`, defaults to 0.1):
39 | The dropout probability for all fully connected layers in the embeddings and encoder.
40 | attention_dropout (:obj:`float`, `optional`, defaults to 0.1):
41 | The dropout ratio for the attention probabilities.
42 | activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
43 | The dropout ratio applied inside the feed-forward layers, after the activation.
44 | encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
45 | The LayerDrop probability for the encoder.
46 | max_positions (:obj:`int`, `optional`, defaults to 512):
47 | The maximum sequence length that this model might ever be used with.
48 | activation_fn (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
49 | The non-linear activation function used in the encoder and decoder.
50 | train_ratio (:obj:`str`, `optional`, defaults to :obj:`"0.5:0.5"`):
51 | Ratio string controlling how the encoder (masked LM) objective and the decoder objective are mixed during pretraining.
52 | decoder_atten_window (:obj:`int`, `optional`, defaults to 2):
53 | Size of the local attention window available to the weak decoder.
54 | decoder_layers (:obj:`int`, `optional`, defaults to 3):
55 | Number of hidden layers in the auxiliary Transformer decoder.
56 | decoder_embed_dim (:obj:`int`, `optional`, defaults to 768):
57 | Dimensionality of the decoder layers; :obj:`decoder_ffn_embed_dim` and
58 | :obj:`decoder_attention_heads` mirror their encoder counterparts.
59 | pooler_activation_fn (:obj:`str`, `optional`, defaults to :obj:`"tanh"`):
60 | The activation function used by the classification head's pooler.
61 | 
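62 | Example (a minimal usage sketch, assuming this repo's package layout)::
63 | 
64 | >>> from model.SEED_Encoder import SEEDEncoderConfig, SEEDEncoderModel
65 | >>> config = SEEDEncoderConfig()  # defaults: 12-layer encoder, 3-layer weak decoder
66 | >>> model = SEEDEncoderModel(config)
67 | 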
68 | """ 69 | model_type = "seed_encoder" 70 | 71 | def __init__( 72 | self, 73 | pad_token_id=1, 74 | vocab_size=32769, 75 | encoder_layers=12, 76 | encoder_embed_dim=768, 77 | encoder_ffn_embed_dim=3072, 78 | encoder_attention_heads=12, 79 | dropout=0.1, 80 | attention_dropout=0.1, 81 | activation_dropout=0.0, 82 | encoder_layerdrop=0.0, 83 | max_positions=512, 84 | activation_fn='gelu', 85 | quant_noise_pq=0.0, 86 | quant_noise_pq_block_size=8, 87 | train_ratio='0.5:0.5', 88 | decoder_atten_window=2, 89 | pooler_activation_fn='tanh', 90 | pooler_dropout=0.0, 91 | encoder_layers_to_keep=None, 92 | decoder_layers=3, 93 | decoder_embed_path=None, 94 | decoder_embed_dim=768, 95 | decoder_ffn_embed_dim=3072, 96 | decoder_attention_heads=12, 97 | decoder_normalize_before=True, 98 | decoder_learned_pos=True, 99 | adaptive_softmax_cutoff=None, 100 | adaptive_softmax_dropout=0, 101 | share_decoder_input_output_embed=True, 102 | share_all_embeddings=True, 103 | no_token_positional_embeddings=False, 104 | adaptive_input=False, 105 | no_cross_attention=False, 106 | cross_self_attention=False, 107 | no_scale_embedding=True, 108 | layernorm_embedding=True, 109 | tie_adaptive_weights=True, 110 | decoder_layers_to_keep=None, 111 | initializer_range=0.02, 112 | **kwargs 113 | ): 114 | super().__init__(**kwargs) 115 | 116 | self.pad_token_id=pad_token_id 117 | self.vocab_size=vocab_size 118 | self.encoder_layers=encoder_layers 119 | self.encoder_embed_dim=encoder_embed_dim 120 | self.encoder_ffn_embed_dim=encoder_ffn_embed_dim 121 | self.encoder_attention_heads=encoder_attention_heads 122 | 123 | self.dropout=dropout 124 | self.attention_dropout=attention_dropout 125 | self.activation_dropout=activation_dropout 126 | self.encoder_layerdrop=encoder_layerdrop 127 | self.max_positions=max_positions 128 | self.activation_fn=activation_fn 129 | self.quant_noise_pq=quant_noise_pq 130 | self.quant_noise_pq_block_size=quant_noise_pq_block_size 131 | 132 | 133 | self.train_ratio=train_ratio 134 | self.decoder_atten_window=decoder_atten_window 135 | self.pooler_activation_fn=pooler_activation_fn 136 | self.pooler_dropout=pooler_dropout 137 | 138 | 139 | self.encoder_layers_to_keep=encoder_layers_to_keep 140 | self.decoder_layers=decoder_layers 141 | self.decoder_embed_path=decoder_embed_path 142 | self.decoder_embed_dim=decoder_embed_dim 143 | self.decoder_ffn_embed_dim=decoder_ffn_embed_dim 144 | self.decoder_attention_heads=decoder_attention_heads 145 | self.decoder_normalize_before=decoder_normalize_before 146 | self.decoder_learned_pos=decoder_learned_pos 147 | self.adaptive_softmax_cutoff=adaptive_softmax_cutoff 148 | self.adaptive_softmax_dropout=adaptive_softmax_dropout 149 | self.share_decoder_input_output_embed=share_decoder_input_output_embed 150 | self.share_all_embeddings=share_all_embeddings 151 | self.no_token_positional_embeddings=no_token_positional_embeddings 152 | 153 | self.adaptive_input=adaptive_input 154 | self.no_cross_attention=no_cross_attention 155 | self.cross_self_attention=cross_self_attention 156 | 157 | self.decoder_output_dim=decoder_embed_dim 158 | self.decoder_input_dim=decoder_embed_dim 159 | 160 | self.no_scale_embedding=no_scale_embedding 161 | self.layernorm_embedding=layernorm_embedding 162 | self.tie_adaptive_weights=tie_adaptive_weights 163 | self.decoder_layers_to_keep=decoder_layers_to_keep 164 | 165 | 166 | self.decoder_layerdrop=0 167 | 168 | self.max_source_positions=max_positions 169 | self.max_target_positions=max_positions 170 | 171 | self.initializer_range = 
initializer_range 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /model/SEED_Encoder/modeling_seed_encoder.py: -------------------------------------------------------------------------------- 1 | #from transformers.utils import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | 10 | from .modules import ( 11 | LayerNorm, 12 | get_activation_fn, 13 | MultiheadAttention, 14 | 15 | ) 16 | from .modules import quant_noise as apply_quant_noise_ 17 | 18 | 19 | from .transformer_sentence_encoder import TransformerSentenceEncoder,TransformerDecoder,EncoderOut 20 | 21 | 22 | import os 23 | from transformers.modeling_utils import PreTrainedModel 24 | 25 | 26 | #logger = logging.get_logger(__name__) 27 | 28 | import logging 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | from model.SEED_Encoder import SEEDEncoderConfig 33 | 34 | 35 | 36 | 37 | 38 | class SEEDEncoderPretrainedModel(PreTrainedModel): 39 | 40 | config_class = SEEDEncoderConfig 41 | base_model_prefix = "seed_encoder" 42 | 43 | def _init_weights(self, module): 44 | """Initialize the weights.""" 45 | if isinstance(module, nn.Linear): 46 | # Slightly different from the TF version which uses truncated_normal for initialization 47 | # cf https://github.com/pytorch/pytorch/pull/5617 48 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 49 | if module.bias is not None: 50 | module.bias.data.zero_() 51 | elif isinstance(module, nn.Embedding): 52 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 53 | if module.padding_idx is not None: 54 | module.weight.data[module.padding_idx].zero_() 55 | elif isinstance(module, MultiheadAttention): 56 | module.q_proj.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 57 | module.k_proj.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 58 | module.v_proj.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 59 | elif isinstance(module, nn.LayerNorm): 60 | module.bias.data.zero_() 61 | module.weight.data.fill_(1.0) 62 | 63 | 64 | 65 | 66 | class RobertaEncoder(nn.Module): 67 | """RoBERTa encoder.""" 68 | 69 | def __init__(self, args): 70 | super().__init__() 71 | self.args = args 72 | 73 | if args.encoder_layers_to_keep: 74 | args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) 75 | 76 | self.sentence_encoder = TransformerSentenceEncoder( 77 | padding_idx=args.pad_token_id, 78 | vocab_size=args.vocab_size, 79 | num_encoder_layers=args.encoder_layers, 80 | embedding_dim=args.encoder_embed_dim, 81 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 82 | num_attention_heads=args.encoder_attention_heads, 83 | dropout=args.dropout, 84 | attention_dropout=args.attention_dropout, 85 | activation_dropout=args.activation_dropout, 86 | layerdrop=args.encoder_layerdrop, 87 | max_seq_len=args.max_positions, 88 | num_segments=0, 89 | encoder_normalize_before=True, 90 | apply_bert_init=True, 91 | activation_fn=args.activation_fn, 92 | q_noise=args.quant_noise_pq, 93 | qn_block_size=args.quant_noise_pq_block_size, 94 | ) 95 | #args.untie_weights_roberta = getattr(args, 'untie_weights_roberta', False) 96 | 97 | 98 | 99 | def forward(self, src_tokens, return_all_hiddens=False, **unused): 100 | 101 | inner_states, _ = self.sentence_encoder( 102 | src_tokens, 103 | 
last_state_only=not return_all_hiddens,
104 | )
105 | x = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
106 | 
107 | # x_origin=x
108 | # if not features_only:
109 | # x = self.output_layer(x, masked_tokens=masked_tokens)
110 | return x, {'inner_states': inner_states if return_all_hiddens else None}
111 | 
112 | 
113 | 
114 | 
115 | class SEEDEncoderModel(SEEDEncoderPretrainedModel):
116 | 
117 | def __init__(self, config):
118 | super().__init__(config)
119 | self.encoder = RobertaEncoder(config)
120 | self.init_weights()
121 | 
122 | 
123 | def forward(self, src_tokens, prev_tokens=None, return_all_hiddens=False, **kwargs): # prev_tokens is unused by the bare encoder; accepted for signature compatibility
124 | 
125 | x_encoder, extra = self.encoder(src_tokens, return_all_hiddens, **kwargs)
126 | 
127 | return x_encoder, extra
128 | 
129 | def get_input_embeddings(self):
130 | 
131 | return self.encoder.sentence_encoder.embed_tokens
132 | 
133 | def set_input_embeddings(self, value):
134 | 
135 | self.encoder.sentence_encoder.embed_tokens = value
136 | 
137 | 
138 | class SEEDEncoderForMaskedLM(SEEDEncoderPretrainedModel):
139 | """Masked LM head with SEED-Encoder's auxiliary weak decoder."""
140 | # _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
141 | # _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
142 | 
143 | 
144 | def __init__(self, config):
145 | super().__init__(config)
146 | self.seed_encoder = SEEDEncoderModel(config)
147 | self.decoder = TransformerDecoder(config, self.seed_encoder.encoder.sentence_encoder.embed_tokens, no_encoder_attn=config.no_cross_attention) # decoder shares the encoder's token embeddings
148 | self.lm_head = RobertaLMHead(
149 | embed_dim=config.encoder_embed_dim,
150 | output_dim=config.vocab_size,
151 | activation_fn=config.activation_fn,
152 | weight=self.seed_encoder.encoder.sentence_encoder.embed_tokens.weight )
153 | self.train_ratio = config.train_ratio
154 | self.decoder_atten_window = config.decoder_atten_window
155 | 
156 | self.init_weights()
157 | 
158 | 
159 | def forward(self, src_tokens, prev_tokens, masked_tokens=None, **kwargs):
160 | x_encoder, _ = self.seed_encoder(src_tokens)
161 | 
162 | h = x_encoder[:, 0:1, :]
163 | h = h.transpose(0, 1)
164 | h = EncoderOut(
165 | encoder_out=h, # T x B x C
166 | encoder_padding_mask=None, # B x T
167 | encoder_embedding=None, # B x T x C
168 | encoder_states=None, # List[T x B x C]
169 | src_tokens=None,
170 | src_lengths=None,
171 | )
172 | decoder_output = self.decoder(prev_tokens, encoder_out=h, local_attn_mask=self.decoder_atten_window)[0]
173 | 
174 | 
175 | features = self.lm_head(x_encoder, masked_tokens)
176 | return features, decoder_output
177 | 
178 | def get_output_embeddings(self):
179 | return self.lm_head.weight
180 | 
181 | def set_output_embeddings(self, new_embeddings):
182 | self.lm_head.weight = new_embeddings
183 | 
184 | 
185 | 
186 | class SEEDEncoderForSequenceClassification(SEEDEncoderPretrainedModel):
187 | """SEED encoder with a sentence-level classification head on top."""
188 | def __init__(self, config):
189 | super().__init__(config)
190 | self.seed_encoder = SEEDEncoderModel(config)
191 | self.classification_heads = RobertaClassificationHead(
192 | config.encoder_embed_dim,
193 | config.encoder_embed_dim,
194 | config.num_labels,
195 | config.pooler_activation_fn,
196 | config.pooler_dropout,
197 | config.quant_noise_pq,
198 | config.quant_noise_pq_block_size,)
199 | 
200 | self.init_weights()
201 | 
202 | def forward(self, src_tokens, return_all_hiddens=False, **kwargs):
203 | 
204 | x_encoder, extra = self.seed_encoder.encoder(src_tokens, return_all_hiddens, **kwargs)
205 | x = self.classification_heads(x_encoder, **kwargs)
206 | 
207 | return x
208 | 
209 | 
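# A minimal usage sketch of the masked-LM head above (illustrative only: the
# batch size, sequence length, and random tokens are made up; shapes follow
# the default base config):
#
#   config = SEEDEncoderConfig()
#   model = SEEDEncoderForMaskedLM(config)
#   src_tokens = torch.randint(2, config.vocab_size, (8, 128))   # corrupted encoder input
#   prev_tokens = torch.randint(2, config.vocab_size, (8, 128))  # shifted decoder input
#   mlm_logits, decoder_logits = model(src_tokens, prev_tokens)
#
# The decoder sees only the encoder's [CLS] vector (x_encoder[:, 0:1, :]) plus a
# local attention window of config.decoder_atten_window tokens, which pressures
# the encoder to pack the sequence information into the [CLS] representation.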
210 | 211 | 212 | 213 | 214 | class RobertaLMHead(nn.Module): 215 | """Head for masked language modeling.""" 216 | 217 | def __init__(self, embed_dim, output_dim, activation_fn, weight=None): 218 | super().__init__() 219 | self.dense = nn.Linear(embed_dim, embed_dim) 220 | self.activation_fn = get_activation_fn(activation_fn) 221 | self.layer_norm = LayerNorm(embed_dim) 222 | 223 | if weight is None: 224 | weight = nn.Linear(embed_dim, output_dim, bias=False).weight 225 | self.weight = weight 226 | self.bias = nn.Parameter(torch.zeros(output_dim)) 227 | 228 | def forward(self, features, masked_tokens=None, **kwargs): 229 | # Only project the masked tokens while training, 230 | # saves both memory and computation 231 | if masked_tokens is not None: 232 | features = features[masked_tokens, :] 233 | 234 | x = self.dense(features) 235 | x = self.activation_fn(x) 236 | x = self.layer_norm(x) 237 | # project back to size of vocabulary with bias 238 | x = F.linear(x, self.weight) + self.bias 239 | return x 240 | 241 | 242 | class RobertaClassificationHead(nn.Module): 243 | """Head for sentence-level classification tasks.""" 244 | 245 | def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout, q_noise=0, qn_block_size=8): 246 | super().__init__() 247 | self.dense = nn.Linear(input_dim, inner_dim) 248 | self.activation_fn = get_activation_fn(activation_fn) 249 | self.dropout = nn.Dropout(p=pooler_dropout) 250 | self.out_proj = apply_quant_noise_( 251 | nn.Linear(inner_dim, num_classes), q_noise, qn_block_size 252 | ) 253 | 254 | def forward(self, features, **kwargs): 255 | x = features[:, 0, :] # take token (equiv. to [CLS]) 256 | x = self.dropout(x) 257 | x = self.dense(x) 258 | x = self.activation_fn(x) 259 | x = self.dropout(x) 260 | x = self.out_proj(x) 261 | return x 262 | 263 | 264 | 265 | 266 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path += ['../'] 3 | import torch 4 | from torch import nn 5 | from transformers import ( 6 | RobertaConfig, 7 | RobertaModel, 8 | RobertaForSequenceClassification, 9 | RobertaTokenizer, 10 | BertModel, 11 | BertTokenizer, 12 | BertConfig 13 | ) 14 | import torch.nn.functional as F 15 | from data.process_fn import triple_process_fn, triple2dual_process_fn 16 | from model.SEED_Encoder import SEEDEncoderConfig, SEEDTokenizer, SEEDEncoderForSequenceClassification,SEEDEncoderForMaskedLM 17 | 18 | 19 | class EmbeddingMixin: 20 | """ 21 | Mixin for common functions in most embedding models. Each model should define its own bert-like backbone and forward. 
22 | We inherit from RobertaModel to use from_pretrained 23 | """ 24 | def __init__(self, model_argobj): 25 | if model_argobj is None: 26 | self.use_mean = False 27 | else: 28 | self.use_mean = model_argobj.use_mean 29 | print("Using mean:", self.use_mean) 30 | 31 | def _init_weights(self, module): 32 | """ Initialize the weights """ 33 | if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)): 34 | # Slightly different from the TF version which uses truncated_normal for initialization 35 | # cf https://github.com/pytorch/pytorch/pull/5617 36 | module.weight.data.normal_(mean=0.0, std=0.02) 37 | 38 | def masked_mean(self, t, mask): 39 | s = torch.sum(t * mask.unsqueeze(-1).float(), axis=1) 40 | d = mask.sum(axis=1, keepdim=True).float() 41 | return s / d 42 | 43 | def masked_mean_or_first(self, emb_all, mask): 44 | # emb_all is a tuple from bert - sequence output, pooler 45 | assert isinstance(emb_all, tuple) 46 | if self.use_mean: 47 | return self.masked_mean(emb_all[0], mask) 48 | else: 49 | return emb_all[0][:, 0] 50 | 51 | def query_emb(self, input_ids, attention_mask): 52 | raise NotImplementedError("Please Implement this method") 53 | 54 | def body_emb(self, input_ids, attention_mask): 55 | raise NotImplementedError("Please Implement this method") 56 | 57 | 58 | class NLL(EmbeddingMixin): 59 | def forward( 60 | self, 61 | query_ids, 62 | attention_mask_q, 63 | input_ids_a=None, 64 | attention_mask_a=None, 65 | input_ids_b=None, 66 | attention_mask_b=None, 67 | is_query=True): 68 | if input_ids_b is None and is_query: 69 | return self.query_emb(query_ids, attention_mask_q) 70 | elif input_ids_b is None: 71 | return self.body_emb(query_ids, attention_mask_q) 72 | 73 | q_embs = self.query_emb(query_ids, attention_mask_q) 74 | a_embs = self.body_emb(input_ids_a, attention_mask_a) 75 | b_embs = self.body_emb(input_ids_b, attention_mask_b) 76 | 77 | logit_matrix = torch.cat([(q_embs * a_embs).sum(-1).unsqueeze(1), 78 | (q_embs * b_embs).sum(-1).unsqueeze(1)], dim=1) # [B, 2] 79 | lsm = F.log_softmax(logit_matrix, dim=1) 80 | loss = -1.0 * lsm[:, 0] 81 | return (loss.mean(),) 82 | 83 | 84 | class NLL_MultiChunk(EmbeddingMixin): 85 | def forward( 86 | self, 87 | query_ids, 88 | attention_mask_q, 89 | input_ids_a=None, 90 | attention_mask_a=None, 91 | input_ids_b=None, 92 | attention_mask_b=None, 93 | is_query=True): 94 | if input_ids_b is None and is_query: 95 | return self.query_emb(query_ids, attention_mask_q) 96 | elif input_ids_b is None: 97 | return self.body_emb(query_ids, attention_mask_q) 98 | 99 | q_embs = self.query_emb(query_ids, attention_mask_q) 100 | a_embs = self.body_emb(input_ids_a, attention_mask_a) 101 | b_embs = self.body_emb(input_ids_b, attention_mask_b) 102 | 103 | [batchS, full_length] = input_ids_a.size() 104 | chunk_factor = full_length // self.base_len 105 | 106 | # special handle of attention mask ----- 107 | attention_mask_body = attention_mask_a.reshape( 108 | batchS, chunk_factor, -1)[:, :, 0] # [batchS, chunk_factor] 109 | inverted_bias = ((1 - attention_mask_body) * (-9999)).float() 110 | 111 | a12 = torch.matmul( 112 | q_embs.unsqueeze(1), a_embs.transpose( 113 | 1, 2)) # [batch, 1, chunk_factor] 114 | logits_a = (a12[:, 0, :] + inverted_bias).max(dim=- 115 | 1, keepdim=False).values # [batch] 116 | # ------------------------------------- 117 | 118 | # special handle of attention mask ----- 119 | attention_mask_body = attention_mask_b.reshape( 120 | batchS, chunk_factor, -1)[:, :, 0] # [batchS, chunk_factor] 121 | inverted_bias = ((1 - 
attention_mask_body) * (-9999)).float()
122 | 
123 | a12 = torch.matmul(
124 | q_embs.unsqueeze(1), b_embs.transpose(
125 | 1, 2)) # [batch, 1, chunk_factor]
126 | logits_b = (a12[:, 0, :] + inverted_bias).max(dim=-
127 | 1, keepdim=False).values # [batch]
128 | # -------------------------------------
129 | 
130 | logit_matrix = torch.cat(
131 | [logits_a.unsqueeze(1), logits_b.unsqueeze(1)], dim=1) # [B, 2]
132 | lsm = F.log_softmax(logit_matrix, dim=1)
133 | loss = -1.0 * lsm[:, 0]
134 | return (loss.mean(),)
135 | 
136 | 
137 | class RobertaDot_NLL_LN(NLL, RobertaForSequenceClassification):
138 | """RoBERTa dot-product (Siamese) model.
139 | Projects the pooled embedding to 768d through a linear + LayerNorm head, then computes the NLL loss.
140 | """
141 | 
142 | def __init__(self, config, model_argobj=None):
143 | NLL.__init__(self, model_argobj)
144 | RobertaForSequenceClassification.__init__(self, config)
145 | self.embeddingHead = nn.Linear(config.hidden_size, 768)
146 | self.norm = nn.LayerNorm(768)
147 | self.apply(self._init_weights)
148 | 
149 | def query_emb(self, input_ids, attention_mask):
150 | outputs1 = self.roberta(input_ids=input_ids,
151 | attention_mask=attention_mask)
152 | full_emb = self.masked_mean_or_first(outputs1, attention_mask)
153 | query1 = self.norm(self.embeddingHead(full_emb))
154 | return query1
155 | 
156 | def body_emb(self, input_ids, attention_mask):
157 | return self.query_emb(input_ids, attention_mask)
158 | 
159 | 
160 | class RobertaDot_CLF_ANN_NLL_MultiChunk(NLL_MultiChunk, RobertaDot_NLL_LN):
161 | def __init__(self, config):
162 | RobertaDot_NLL_LN.__init__(self, config)
163 | self.base_len = 512
164 | 
165 | def body_emb(self, input_ids, attention_mask):
166 | [batchS, full_length] = input_ids.size()
167 | chunk_factor = full_length // self.base_len
168 | 
169 | input_seq = input_ids.reshape(
170 | batchS,
171 | chunk_factor,
172 | full_length //
173 | chunk_factor).reshape(
174 | batchS *
175 | chunk_factor,
176 | full_length //
177 | chunk_factor)
178 | attention_mask_seq = attention_mask.reshape(
179 | batchS,
180 | chunk_factor,
181 | full_length //
182 | chunk_factor).reshape(
183 | batchS *
184 | chunk_factor,
185 | full_length //
186 | chunk_factor)
187 | 
188 | outputs_k = self.roberta(input_ids=input_seq,
189 | attention_mask=attention_mask_seq)
190 | 
191 | compressed_output_k = self.embeddingHead(
192 | outputs_k[0]) # [batch, len, dim]
193 | compressed_output_k = self.norm(compressed_output_k[:, 0, :])
194 | 
195 | [batch_expand, embeddingS] = compressed_output_k.size()
196 | complex_emb_k = compressed_output_k.reshape(
197 | batchS, chunk_factor, embeddingS)
198 | 
199 | return complex_emb_k # size [batchS, chunk_factor, embeddingS]
200 | 
201 | class SEEDEncoderDot_NLL_LN(NLL, SEEDEncoderForSequenceClassification):
202 | """SEED-Encoder dot-product (Siamese) model.
203 | Projects the pooled embedding to 768d through a linear + LayerNorm head, then computes the NLL loss. 
204 | """ 205 | def __init__(self, config, model_argobj=None): 206 | NLL.__init__(self, model_argobj) 207 | SEEDEncoderForSequenceClassification.__init__(self, config) 208 | self.embeddingHead = nn.Linear(config.encoder_embed_dim, 768) 209 | self.norm = nn.LayerNorm(768) 210 | self.apply(self._init_weights) 211 | 212 | def query_emb(self, input_ids, attention_mask=None): 213 | outputs1 = self.seed_encoder.encoder(input_ids) 214 | 215 | full_emb = self.masked_mean_or_first(outputs1, attention_mask) 216 | query1 = self.norm(self.embeddingHead(full_emb)) 217 | 218 | return query1 219 | 220 | def body_emb(self, input_ids, attention_mask=None): 221 | return self.query_emb(input_ids, attention_mask) 222 | 223 | class HFBertEncoder(BertModel): 224 | def __init__(self, config): 225 | BertModel.__init__(self, config) 226 | assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero' 227 | self.init_weights() 228 | @classmethod 229 | def init_encoder(cls, args, dropout: float = 0.1): 230 | cfg = BertConfig.from_pretrained("bert-base-uncased") 231 | if dropout != 0: 232 | cfg.attention_probs_dropout_prob = dropout 233 | cfg.hidden_dropout_prob = dropout 234 | return cls.from_pretrained("bert-base-uncased", config=cfg) 235 | def forward(self, input_ids, attention_mask): 236 | hidden_states = None 237 | sequence_output, pooled_output = super().forward(input_ids=input_ids, 238 | attention_mask=attention_mask) 239 | pooled_output = sequence_output[:, 0, :] 240 | return sequence_output, pooled_output, hidden_states 241 | def get_out_size(self): 242 | if self.encode_proj: 243 | return self.encode_proj.out_features 244 | return self.config.hidden_size 245 | 246 | 247 | class BiEncoder(nn.Module): 248 | """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders. 
249 | """ 250 | def __init__(self, args): 251 | super(BiEncoder, self).__init__() 252 | self.question_model = HFBertEncoder.init_encoder(args) 253 | self.ctx_model = HFBertEncoder.init_encoder(args) 254 | def query_emb(self, input_ids, attention_mask): 255 | sequence_output, pooled_output, hidden_states = self.question_model(input_ids, attention_mask) 256 | return pooled_output 257 | def body_emb(self, input_ids, attention_mask): 258 | sequence_output, pooled_output, hidden_states = self.ctx_model(input_ids, attention_mask) 259 | return pooled_output 260 | def forward(self, query_ids, attention_mask_q, input_ids_a = None, attention_mask_a = None, input_ids_b = None, attention_mask_b = None): 261 | if input_ids_b is None: 262 | q_embs = self.query_emb(query_ids, attention_mask_q) 263 | a_embs = self.body_emb(input_ids_a, attention_mask_a) 264 | return (q_embs, a_embs) 265 | q_embs = self.query_emb(query_ids, attention_mask_q) 266 | a_embs = self.body_emb(input_ids_a, attention_mask_a) 267 | b_embs = self.body_emb(input_ids_b, attention_mask_b) 268 | logit_matrix = torch.cat([(q_embs*a_embs).sum(-1).unsqueeze(1), (q_embs*b_embs).sum(-1).unsqueeze(1)], dim=1) #[B, 2] 269 | lsm = F.log_softmax(logit_matrix, dim=1) 270 | loss = -1.0*lsm[:,0] 271 | return (loss.mean(),) 272 | 273 | 274 | # -------------------------------------------------- 275 | ALL_MODELS = sum( 276 | ( 277 | tuple(conf.pretrained_config_archive_map.keys()) 278 | for conf in ( 279 | RobertaConfig, 280 | ) if hasattr(conf,'pretrained_config_archive_map') 281 | ), 282 | (), 283 | ) 284 | 285 | 286 | default_process_fn = triple_process_fn 287 | 288 | 289 | class MSMarcoConfig: 290 | def __init__(self, name, model, process_fn=default_process_fn, use_mean=True, tokenizer_class=RobertaTokenizer, config_class=RobertaConfig): 291 | self.name = name 292 | self.process_fn = process_fn 293 | self.model_class = model 294 | self.use_mean = use_mean 295 | self.tokenizer_class = tokenizer_class 296 | self.config_class = config_class 297 | 298 | 299 | configs = [ 300 | MSMarcoConfig(name="rdot_nll", 301 | model=RobertaDot_NLL_LN, 302 | use_mean=False, 303 | ), 304 | MSMarcoConfig(name="rdot_nll_multi_chunk", 305 | model=RobertaDot_CLF_ANN_NLL_MultiChunk, 306 | use_mean=False, 307 | ), 308 | MSMarcoConfig(name="dpr", 309 | model=BiEncoder, 310 | tokenizer_class=BertTokenizer, 311 | config_class=BertConfig, 312 | use_mean=False, 313 | ), 314 | MSMarcoConfig(name="seeddot_nll", 315 | model=SEEDEncoderDot_NLL_LN, 316 | use_mean=False, 317 | tokenizer_class=SEEDTokenizer, 318 | config_class=SEEDEncoderConfig, 319 | ), 320 | ] 321 | 322 | MSMarcoConfigDict = {cfg.name: cfg for cfg in configs} 323 | -------------------------------------------------------------------------------- /utils/dpr_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import sys 3 | sys.path += ['../'] 4 | import glob 5 | import logging 6 | import os 7 | from typing import List, Tuple, Dict 8 | import faiss 9 | import pickle 10 | import numpy as np 11 | import unicodedata 12 | import torch 13 | import torch.distributed as dist 14 | from torch import nn 15 | from torch.serialization import default_restore_location 16 | import regex 17 | from transformers import AdamW 18 | from utils.lamb import Lamb 19 | 20 | 21 | logger = logging.getLogger() 22 | 23 | CheckpointState = collections.namedtuple("CheckpointState", 24 | ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 25 | 'encoder_params']) 26 | 
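# Typical consumption of a CheckpointState (a sketch; the checkpoint path is
# hypothetical, and load_states_from_checkpoint / get_model_obj are defined
# later in this file):
#
#   saved_state = load_states_from_checkpoint('/path/to/checkpoint-40000')
#   model_to_load = get_model_obj(model)   # unwrap DataParallel/DDP if needed
#   model_to_load.load_state_dict(saved_state.model_dict)
#   optimizer.load_state_dict(saved_state.optimizer_dict)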
27 | def get_encoder_checkpoint_params_names():
28 | return ['do_lower_case', 'pretrained_model_cfg', 'encoder_model_type',
29 | 'pretrained_file',
30 | 'projection_dim', 'sequence_length']
31 | 
32 | def get_encoder_params_state(args):
33 | """
34 | Selects the param values to be saved in a checkpoint, so that a trained model file can be used for downstream
35 | tasks without the need to specify these parameters again
36 | :return: Dict of params to memorize in a checkpoint
37 | """
38 | params_to_save = get_encoder_checkpoint_params_names()
39 | 
40 | r = {}
41 | for param in params_to_save:
42 | r[param] = getattr(args, param)
43 | return r
44 | 
45 | def set_encoder_params_from_state(state, args):
46 | if not state:
47 | return
48 | params_to_save = get_encoder_checkpoint_params_names()
49 | 
50 | override_params = [(param, state[param]) for param in params_to_save if param in state and state[param]]
51 | for param, value in override_params:
52 | if hasattr(args, param):
53 | logger.warning('Overriding args parameter value from checkpoint state. Param = %s, value = %s', param,
54 | value)
55 | setattr(args, param, value)
56 | return args
57 | 
58 | def get_model_obj(model: nn.Module):
59 | return model.module if hasattr(model, 'module') else model
60 | 
61 | 
62 | def get_model_file(args, file_prefix) -> str:
63 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + '*')) if args.output_dir else []
64 | logger.info('Checkpoint files %s', out_cp_files)
65 | model_file = None
66 | 
67 | if args.model_file and os.path.exists(args.model_file):
68 | model_file = args.model_file
69 | elif len(out_cp_files) > 0:
70 | model_file = max(out_cp_files, key=os.path.getctime)
71 | return model_file
72 | 
73 | 
74 | def load_states_from_checkpoint(model_file: str) -> CheckpointState:
75 | logger.info('Reading saved model from %s', model_file)
76 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
77 | logger.info('model_state_dict keys %s', state_dict.keys())
78 | return CheckpointState(**state_dict)
79 | 
80 | def get_optimizer(args, model: nn.Module, weight_decay: float = 0.0) -> torch.optim.Optimizer:
81 | no_decay = ['bias', 'LayerNorm.weight']
82 | optimizer_grouped_parameters = [
83 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
84 | 'weight_decay': weight_decay},
85 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
86 | ]
87 | if args.optimizer == "adamW":
88 | return AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
89 | elif args.optimizer == "lamb":
90 | return Lamb(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
91 | else:
92 | raise Exception("optimizer {0} not recognized! Can only be lamb or adamW".format(args.optimizer))
93 | 
94 | 
95 | def all_gather_list(data, group=None, max_size=16384):
96 | """Gathers arbitrary data from all nodes into a list.
97 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
98 | data. Note that *data* must be picklable. 
99 | Args: 100 | data (Any): data from the local worker to be gathered on other workers 101 | group (optional): group of the collective 102 | """ 103 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 104 | 105 | enc = pickle.dumps(data) 106 | enc_size = len(enc) 107 | 108 | if enc_size + SIZE_STORAGE_BYTES > max_size: 109 | raise ValueError( 110 | 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) 111 | 112 | rank = dist.get_rank() 113 | world_size = dist.get_world_size() 114 | buffer_size = max_size * world_size 115 | 116 | if not hasattr(all_gather_list, '_buffer') or \ 117 | all_gather_list._buffer.numel() < buffer_size: 118 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 119 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 120 | 121 | buffer = all_gather_list._buffer 122 | buffer.zero_() 123 | cpu_buffer = all_gather_list._cpu_buffer 124 | 125 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( 126 | 256 ** SIZE_STORAGE_BYTES) 127 | 128 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 129 | 130 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 131 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 132 | 133 | start = rank * max_size 134 | size = enc_size + SIZE_STORAGE_BYTES 135 | buffer[start: start + size].copy_(cpu_buffer[:size]) 136 | 137 | if group is None: 138 | group = dist.group.WORLD 139 | dist.all_reduce(buffer, group=group) 140 | 141 | try: 142 | result = [] 143 | for i in range(world_size): 144 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 145 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 146 | if size > 0: 147 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) 148 | return result 149 | except pickle.UnpicklingError: 150 | raise Exception( 151 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 152 | 'workers to enter the function together, so this error usually indicates ' 153 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 154 | 'sync if one of them runs out of memory, or if there are other conditions ' 155 | 'in your training script that can cause one worker to finish an epoch ' 156 | 'while other workers are still iterating over their portions of the data.' 157 | ) 158 | 159 | 160 | 161 | class DenseHNSWFlatIndexer(object): 162 | """ 163 | Efficient index for retrieval. 
Note: default settings are for high accuracy but also high RAM usage
164 | """
165 | 
166 | def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512
167 | , ef_search: int = 128, ef_construction: int = 200):
168 | self.buffer_size = buffer_size
169 | self.index_id_to_db_id = []
170 | self.index = None
171 | 
172 | # IndexHNSWFlat supports L2 similarity only
173 | # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension
174 | index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
175 | index.hnsw.efSearch = ef_search
176 | index.hnsw.efConstruction = ef_construction
177 | self.index = index
178 | self.phi = 0
179 | 
180 | def index_data(self, data: List[Tuple[object, np.array]]):
181 | n = len(data)
182 | 
183 | # max norm is required before putting all vectors in the index to convert inner product similarity to L2
184 | if self.phi > 0:
185 | raise RuntimeError('DPR HNSWF index needs to index all data at once, '
186 | 'results will be unpredictable otherwise.')
187 | phi = 0
188 | for i, item in enumerate(data):
189 | id, doc_vector = item
190 | norms = (doc_vector ** 2).sum()
191 | phi = max(phi, norms)
192 | logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi))
193 | self.phi = phi # remember the max norm so a second index_data call trips the guard above
194 | 
195 | # indexing in batches is beneficial for many faiss index types
196 | for i in range(0, n, self.buffer_size):
197 | db_ids = [t[0] for t in data[i:i + self.buffer_size]]
198 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
199 | 
200 | norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
201 | aux_dims = [np.sqrt(phi - norm) for norm in norms]
202 | hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in
203 | enumerate(vectors)]
204 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)
205 | 
206 | self._update_id_mapping(db_ids)
207 | self.index.add(hnsw_vectors)
208 | logger.info('data indexed %d', len(self.index_id_to_db_id))
209 | 
210 | indexed_cnt = len(self.index_id_to_db_id)
211 | logger.info('Total data indexed %d', indexed_cnt)
212 | 
213 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
214 | 
215 | aux_dim = np.zeros(len(query_vectors), dtype='float32')
216 | query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
217 | logger.info('query_hnsw_vectors %s', query_hnsw_vectors.shape)
218 | scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
219 | # convert to external ids
220 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
221 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
222 | return result
223 | 
224 | def _update_id_mapping(self, db_ids: List):
225 | self.index_id_to_db_id.extend(db_ids)
226 | 
227 | 
228 | 
229 | def check_answer(passages, answers, doc_ids, tokenizer):
230 | """Search through all the top docs to see if they have any of the answers."""
231 | hits = []
232 | for i, doc_id in enumerate(doc_ids):
233 | text = passages[doc_id][0]
234 | hits.append(has_answer(answers, text, tokenizer))
235 | return hits
236 | 
237 | 
238 | def has_answer(answers, text, tokenizer):
239 | """Check if a document contains an answer string.
240 | If `match_type` is string, token matching is done between the text and answer.
241 | If `match_type` is regex, we search the whole text with the regex. 
242 | """ 243 | 244 | if text is None: 245 | logger.warning("no doc in db") 246 | return False 247 | 248 | text = _normalize(text) 249 | 250 | # Answer is a list of possible strings 251 | text = tokenizer.tokenize(text).words(uncased=True) 252 | 253 | for single_answer in answers: 254 | single_answer = _normalize(single_answer) 255 | single_answer = tokenizer.tokenize(single_answer) 256 | single_answer = single_answer.words(uncased=True) 257 | 258 | for i in range(0, len(text) - len(single_answer) + 1): 259 | if single_answer == text[i: i + len(single_answer)]: 260 | return True 261 | return False 262 | 263 | 264 | class SimpleTokenizer: 265 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 266 | NON_WS = r'[^\p{Z}\p{C}]' 267 | 268 | def __init__(self, **kwargs): 269 | """ 270 | Args: 271 | annotators: None or empty set (only tokenizes). 272 | """ 273 | self._regexp = regex.compile( 274 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 275 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 276 | ) 277 | if len(kwargs.get('annotators', {})) > 0: 278 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 279 | (type(self).__name__, kwargs.get('annotators'))) 280 | self.annotators = set() 281 | 282 | def tokenize(self, text): 283 | data = [] 284 | matches = [m for m in self._regexp.finditer(text)] 285 | for i in range(len(matches)): 286 | # Get text 287 | token = matches[i].group() 288 | 289 | # Get whitespace 290 | span = matches[i].span() 291 | start_ws = span[0] 292 | if i + 1 < len(matches): 293 | end_ws = matches[i + 1].span()[0] 294 | else: 295 | end_ws = span[1] 296 | 297 | # Format data 298 | data.append(( 299 | token, 300 | text[start_ws: end_ws], 301 | span, 302 | )) 303 | return Tokens(data, self.annotators) 304 | 305 | 306 | def _normalize(text): 307 | return unicodedata.normalize('NFD', text) 308 | 309 | 310 | class Tokens(object): 311 | """A class to represent a list of tokenized text.""" 312 | TEXT = 0 313 | TEXT_WS = 1 314 | SPAN = 2 315 | POS = 3 316 | LEMMA = 4 317 | NER = 5 318 | 319 | def __init__(self, data, annotators, opts=None): 320 | self.data = data 321 | self.annotators = annotators 322 | self.opts = opts or {} 323 | 324 | def __len__(self): 325 | """The number of tokens.""" 326 | return len(self.data) 327 | 328 | def words(self, uncased=False): 329 | """Returns a list of the text of each token 330 | 331 | Args: 332 | uncased: lower cases text 333 | """ 334 | if uncased: 335 | return [t[self.TEXT].lower() for t in self.data] 336 | else: 337 | return [t[self.TEXT] for t in self.data] 338 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path += ['../'] 3 | import pandas as pd 4 | from sklearn.metrics import roc_curve, auc 5 | import gzip 6 | import copy 7 | import torch 8 | from torch import nn 9 | import torch.distributed as dist 10 | from tqdm import tqdm, trange 11 | import os 12 | from os import listdir 13 | from os.path import isfile, join 14 | import json 15 | import logging 16 | import random 17 | import pytrec_eval 18 | import pickle 19 | import numpy as np 20 | import torch 21 | torch.multiprocessing.set_sharing_strategy('file_system') 22 | from multiprocessing import Process 23 | from torch.utils.data import DataLoader, Dataset, TensorDataset, IterableDataset 24 | import re 25 | from model.models import MSMarcoConfigDict, ALL_MODELS 26 | from typing import List, Set, Dict, Tuple, Callable, Iterable, Any 27 
| 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class InputFeaturesPair(object): 33 | """ 34 | A single set of features of data. 35 | 36 | Args: 37 | input_ids: Indices of input sequence tokens in the vocabulary. 38 | attention_mask: Mask to avoid performing attention on padding token indices. 39 | Mask values selected in ``[0, 1]``: 40 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 41 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 42 | label: Label corresponding to the input 43 | """ 44 | 45 | def __init__( 46 | self, 47 | input_ids_a, 48 | attention_mask_a=None, 49 | token_type_ids_a=None, 50 | input_ids_b=None, 51 | attention_mask_b=None, 52 | token_type_ids_b=None, 53 | label=None): 54 | 55 | self.input_ids_a = input_ids_a 56 | self.attention_mask_a = attention_mask_a 57 | self.token_type_ids_a = token_type_ids_a 58 | 59 | self.input_ids_b = input_ids_b 60 | self.attention_mask_b = attention_mask_b 61 | self.token_type_ids_b = token_type_ids_b 62 | 63 | self.label = label 64 | 65 | def __repr__(self): 66 | return str(self.to_json_string()) 67 | 68 | def to_dict(self): 69 | """Serializes this instance to a Python dictionary.""" 70 | output = copy.deepcopy(self.__dict__) 71 | return output 72 | 73 | def to_json_string(self): 74 | """Serializes this instance to a JSON string.""" 75 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 76 | 77 | 78 | def getattr_recursive(obj, name): 79 | for layer in name.split("."): 80 | if hasattr(obj, layer): 81 | obj = getattr(obj, layer) 82 | else: 83 | return None 84 | return obj 85 | 86 | 87 | def barrier_array_merge( 88 | args, 89 | data_array, 90 | merge_axis=0, 91 | prefix="", 92 | load_cache=False, 93 | only_load_in_master=False): 94 | # data array: [B, any dimension] 95 | # merge alone one axis 96 | 97 | if args.local_rank == -1: 98 | return data_array 99 | 100 | if not load_cache: 101 | rank = args.rank 102 | if is_first_worker(): 103 | if not os.path.exists(args.output_dir): 104 | os.makedirs(args.output_dir) 105 | 106 | dist.barrier() # directory created 107 | pickle_path = os.path.join( 108 | args.output_dir, 109 | "{1}_data_obj_{0}.pb".format( 110 | str(rank), 111 | prefix)) 112 | with open(pickle_path, 'wb') as handle: 113 | pickle.dump(data_array, handle, protocol=4) 114 | 115 | # make sure all processes wrote their data before first process 116 | # collects it 117 | dist.barrier() 118 | 119 | data_array = None 120 | 121 | data_list = [] 122 | 123 | # return empty data 124 | if only_load_in_master: 125 | if not is_first_worker(): 126 | dist.barrier() 127 | return None 128 | 129 | for i in range( 130 | args.world_size): # TODO: dynamically find the max instead of HardCode 131 | pickle_path = os.path.join( 132 | args.output_dir, 133 | "{1}_data_obj_{0}.pb".format( 134 | str(i), 135 | prefix)) 136 | try: 137 | with open(pickle_path, 'rb') as handle: 138 | b = pickle.load(handle) 139 | data_list.append(b) 140 | except BaseException: 141 | continue 142 | 143 | data_array_agg = np.concatenate(data_list, axis=merge_axis) 144 | dist.barrier() 145 | return data_array_agg 146 | 147 | 148 | def pad_input_ids(input_ids, max_length, 149 | pad_on_left=False, 150 | pad_token=0): 151 | padding_length = max_length - len(input_ids) 152 | padding_id = [pad_token] * padding_length 153 | 154 | if padding_length <= 0: 155 | input_ids = input_ids[:max_length] 156 | else: 157 | if pad_on_left: 158 | input_ids = padding_id + input_ids 159 | else: 160 
| input_ids = input_ids + padding_id 161 | 162 | return input_ids 163 | 164 | 165 | def pad_ids(input_ids, attention_mask, token_type_ids, max_length, 166 | pad_on_left=False, 167 | pad_token=0, 168 | pad_token_segment_id=0, 169 | mask_padding_with_zero=True): 170 | padding_length = max_length - len(input_ids) 171 | padding_id = [pad_token] * padding_length 172 | padding_type = [pad_token_segment_id] * padding_length 173 | padding_attention = [0 if mask_padding_with_zero else 1] * padding_length 174 | 175 | if padding_length <= 0: 176 | input_ids = input_ids[:max_length] 177 | attention_mask = attention_mask[:max_length] 178 | token_type_ids = token_type_ids[:max_length] 179 | else: 180 | if pad_on_left: 181 | input_ids = padding_id + input_ids 182 | attention_mask = padding_attention + attention_mask 183 | token_type_ids = padding_type + token_type_ids 184 | else: 185 | input_ids = input_ids + padding_id 186 | attention_mask = attention_mask + padding_attention 187 | token_type_ids = token_type_ids + padding_type 188 | 189 | return input_ids, attention_mask, token_type_ids 190 | 191 | 192 | # to reuse pytrec_eval, id must be string 193 | def convert_to_string_id(result_dict): 194 | string_id_dict = {} 195 | 196 | # format [string, dict[string, val]] 197 | for k, v in result_dict.items(): 198 | _temp_v = {} 199 | for inner_k, inner_v in v.items(): 200 | _temp_v[str(inner_k)] = inner_v 201 | 202 | string_id_dict[str(k)] = _temp_v 203 | 204 | return string_id_dict 205 | 206 | 207 | def set_seed(args): 208 | random.seed(args.seed) 209 | np.random.seed(args.seed) 210 | torch.manual_seed(args.seed) 211 | if args.n_gpu > 0: 212 | torch.cuda.manual_seed_all(args.seed) 213 | 214 | 215 | def is_first_worker(): 216 | return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0 217 | 218 | 219 | def concat_key(all_list, key, axis=0): 220 | return np.concatenate([ele[key] for ele in all_list], axis=axis) 221 | 222 | 223 | def get_checkpoint_no(checkpoint_path): 224 | nums = re.findall(r'\d+', checkpoint_path) 225 | return int(nums[-1]) if len(nums) > 0 else 0 226 | 227 | 228 | def get_latest_ann_data(ann_data_path): 229 | ANN_PREFIX = "ann_ndcg_" 230 | if not os.path.exists(ann_data_path): 231 | return -1, None, None 232 | files = list(next(os.walk(ann_data_path))[2]) 233 | num_start_pos = len(ANN_PREFIX) 234 | data_no_list = [int(s[num_start_pos:]) 235 | for s in files if s[:num_start_pos] == ANN_PREFIX] 236 | if len(data_no_list) > 0: 237 | data_no = max(data_no_list) 238 | with open(os.path.join(ann_data_path, ANN_PREFIX + str(data_no)), 'r') as f: 239 | ndcg_json = json.load(f) 240 | return data_no, os.path.join( 241 | ann_data_path, "ann_training_data_" + str(data_no)), ndcg_json 242 | return -1, None, None 243 | 244 | 245 | def numbered_byte_file_generator(base_path, file_no, record_size): 246 | for i in range(file_no): 247 | with open('{}_split{}'.format(base_path, i), 'rb') as f: 248 | while True: 249 | b = f.read(record_size) 250 | if not b: 251 | # eof 252 | break 253 | yield b 254 | 255 | 256 | class EmbeddingCache: 257 | def __init__(self, base_path, seed=-1): 258 | self.base_path = base_path 259 | with open(base_path + '_meta', 'r') as f: 260 | meta = json.load(f) 261 | self.dtype = np.dtype(meta['type']) 262 | self.total_number = meta['total_number'] 263 | self.record_size = int( 264 | meta['embedding_size']) * self.dtype.itemsize + 4 265 | if seed >= 0: 266 | self.ix_array = np.random.RandomState( 267 | seed).permutation(self.total_number) 268 | else: 269 | 
self.ix_array = np.arange(self.total_number) 270 | self.f = None 271 | 272 | def open(self): 273 | self.f = open(self.base_path, 'rb') 274 | 275 | def close(self): 276 | self.f.close() 277 | 278 | def read_single_record(self): 279 | record_bytes = self.f.read(self.record_size) 280 | passage_len = int.from_bytes(record_bytes[:4], 'big') 281 | passage = np.frombuffer(record_bytes[4:], dtype=self.dtype) 282 | return passage_len, passage 283 | 284 | def __enter__(self): 285 | self.open() 286 | return self 287 | 288 | def __exit__(self, type, value, traceback): 289 | self.close() 290 | 291 | def __getitem__(self, key): 292 | if key < 0 or key > self.total_number: 293 | raise IndexError( 294 | "Index {} is out of bound for cached embeddings of size {}".format( 295 | key, self.total_number)) 296 | self.f.seek(key * self.record_size) 297 | return self.read_single_record() 298 | 299 | def __iter__(self): 300 | self.f.seek(0) 301 | for i in range(self.total_number): 302 | new_ix = self.ix_array[i] 303 | yield self.__getitem__(new_ix) 304 | 305 | def __len__(self): 306 | return self.total_number 307 | 308 | 309 | class StreamingDataset(IterableDataset): 310 | def __init__(self, elements, fn, distributed=True): 311 | super().__init__() 312 | self.elements = elements 313 | self.fn = fn 314 | self.num_replicas=-1 315 | self.distributed = distributed 316 | 317 | def __iter__(self): 318 | if dist.is_initialized(): 319 | self.num_replicas = dist.get_world_size() 320 | self.rank = dist.get_rank() 321 | else: 322 | print("Not running in distributed mode") 323 | for i, element in enumerate(self.elements): 324 | if self.distributed and self.num_replicas != -1 and i % self.num_replicas != self.rank: 325 | continue 326 | records = self.fn(element, i) 327 | for rec in records: 328 | yield rec 329 | 330 | 331 | def tokenize_to_file(args, i, num_process, in_path, out_path, line_fn): 332 | 333 | configObj = MSMarcoConfigDict[args.model_type] 334 | tokenizer = configObj.tokenizer_class.from_pretrained( 335 | args.model_name_or_path, 336 | do_lower_case=True, 337 | cache_dir=None, 338 | ) 339 | 340 | with open(in_path, 'r', encoding='utf-8') if in_path[-2:] != "gz" else gzip.open(in_path, 'rt', encoding='utf8') as in_f,\ 341 | open('{}_split{}'.format(out_path, i), 'wb') as out_f: 342 | for idx, line in enumerate(in_f): 343 | if idx % num_process != i: 344 | continue 345 | out_f.write(line_fn(args, line, tokenizer)) 346 | 347 | 348 | def multi_file_process(args, num_process, in_path, out_path, line_fn): 349 | processes = [] 350 | for i in range(num_process): 351 | p = Process( 352 | target=tokenize_to_file, 353 | args=( 354 | args, 355 | i, 356 | num_process, 357 | in_path, 358 | out_path, 359 | line_fn, 360 | )) 361 | processes.append(p) 362 | p.start() 363 | for p in processes: 364 | p.join() 365 | 366 | 367 | def all_gather(data): 368 | """ 369 | Run all_gather on arbitrary picklable data (not necessarily tensors) 370 | Args: 371 | data: any picklable object 372 | Returns: 373 | list[data]: list of data gathered from each rank 374 | """ 375 | if not dist.is_initialized() or dist.get_world_size() == 1: 376 | return [data] 377 | 378 | world_size = dist.get_world_size() 379 | # serialized to a Tensor 380 | buffer = pickle.dumps(data) 381 | storage = torch.ByteStorage.from_buffer(buffer) 382 | tensor = torch.ByteTensor(storage).to("cuda") 383 | 384 | # obtain Tensor size of each rank 385 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 386 | size_list = [torch.LongTensor([0]).to("cuda") for _ in 
range(world_size)] 387 | dist.all_gather(size_list, local_size) 388 | size_list = [int(size.item()) for size in size_list] 389 | max_size = max(size_list) 390 | 391 | # receiving Tensor from all ranks 392 | # we pad the tensor because torch all_gather does not support 393 | # gathering tensors of different shapes 394 | tensor_list = [] 395 | for _ in size_list: 396 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 397 | if local_size != max_size: 398 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 399 | tensor = torch.cat((tensor, padding), dim=0) 400 | dist.all_gather(tensor_list, tensor) 401 | 402 | data_list = [] 403 | for size, tensor in zip(size_list, tensor_list): 404 | buffer = tensor.cpu().numpy().tobytes()[:size] 405 | data_list.append(pickle.loads(buffer)) 406 | 407 | return data_list 408 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval 2 | Lee Xiong*, Chenyan Xiong*, Ye Li, Kwok-Fung Tang, Jialin Liu, Paul Bennett, Junaid Ahmed, Arnold Overwijk 3 | 4 | This repo provides the code for reproducing the experiments in [Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval](https://arxiv.org/pdf/2007.00808.pdf) 5 | 6 | Conducting text retrieval in a dense learned representation space has many intriguing advantages over sparse retrieval. Yet the effectiveness of dense retrieval (DR) 7 | often requires combination with sparse retrieval. In this paper, we identify that 8 | the main bottleneck is in the training mechanisms, where the negative instances 9 | used in training are not representative of the irrelevant documents in testing. This 10 | paper presents Approximate nearest neighbor Negative Contrastive Estimation 11 | (ANCE), a training mechanism that constructs negatives from an Approximate 12 | Nearest Neighbor (ANN) index of the corpus, which is parallelly updated with the 13 | learning process to select more realistic negative training instances. This fundamentally resolves the discrepancy between the data distribution used in the training 14 | and testing of DR. In our experiments, ANCE boosts the BERT-Siamese DR 15 | model to outperform all competitive dense and sparse retrieval baselines. It nearly 16 | matches the accuracy of sparse-retrieval-and-BERT-reranking using dot-product in 17 | the ANCE-learned representation space and provides almost 100x speed-up. 18 | 19 | Our analyses further confirm that the negatives from sparse retrieval or other sampling methods differ 20 | drastically from the actual negatives in DR, and that ANCE fundamentally resolves this mismatch. 21 | We also show the influence of the asynchronous ANN refreshing on learning convergence and 22 | demonstrate that the efficiency bottleneck is in the encoding update, not in the ANN part during 23 | ANCE training. These qualifications demonstrate the advantages, perhaps also the necessity, of our 24 | asynchronous ANCE learning in dense retrieval. 
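In pseudocode, the asynchronous training procedure described above reduces to two cooperating jobs (a simplified sketch; the actual logic lives in drivers/run_ann.py and drivers/run_ann_data_gen.py, and the helper names below are illustrative, not the real driver API):

```python
# Job 1 trains continuously; Job 2 refreshes the ANN index and the negatives.
while not converged:
    ckpt = latest_checkpoint(output_dir)          # written by the training job
    corpus_emb = encode_corpus(ckpt)              # inference over all passages
    index = build_ann_index(corpus_emb)           # e.g. a FAISS index
    negatives = index.search(encode_queries(ckpt), topk_training)
    write_training_data(negatives)                # consumed by the training job
```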
25 | 26 | ## What's new 27 | * September 2021: [Released SEED-Encoder fine-tuning code.](https://github.com/microsoft/ANCE/tree/master/model/SEED_Encoder/SEED-Encoder.md) 28 | 29 | 30 | ## Requirements 31 | 32 | To install requirements, run the following commands: 33 | 34 | ```setup 35 | git clone https://github.com/microsoft/ANCE 36 | cd ANCE 37 | python setup.py install 38 | ``` 39 | 40 | ## Data Download 41 | To download all the needed data, run: 42 | ``` 43 | bash commands/data_download.sh 44 | ``` 45 | 46 | ## Data Preprocessing 47 | The command to preprocess passage and document data is listed below: 48 | 49 | ``` 50 | python data/msmarco_data.py \ 51 | --data_dir $raw_data_dir \ 52 | --out_data_dir $preprocessed_data_dir \ 53 | --model_type {use rdot_nll for ANCE FirstP, rdot_nll_multi_chunk for ANCE MaxP} \ 54 | --model_name_or_path roberta-base \ 55 | --max_seq_length {use 512 for ANCE FirstP, 2048 for ANCE MaxP} \ 56 | --data_type {use 1 for passage, 0 for document} 57 | ``` 58 | 59 | The data preprocessing command is included as the first step in the training command file commands/run_train.sh. 60 | 61 | ## Warmup for Training 62 | ANCE training starts from a pretrained BM25 warmup checkpoint. The command with the parameters we used to train this warmup checkpoint is in commands/run_train_warmup.sh and is shown below: 63 | 64 | python3 -m torch.distributed.launch --nproc_per_node=1 ../drivers/run_warmup.py \ 65 | --train_model_type rdot_nll \ 66 | --model_name_or_path roberta-base \ 67 | --task_name MSMarco \ 68 | --do_train \ 69 | --evaluate_during_training \ 70 | --data_dir ${location of your raw data} \ 71 | --max_seq_length 128 \ 72 | --per_gpu_eval_batch_size=256 \ 73 | --per_gpu_train_batch_size=32 \ 74 | --learning_rate 2e-4 \ 75 | --logging_steps 100 \ 76 | --num_train_epochs 2.0 \ 77 | --output_dir ${location for checkpoint saving} \ 78 | --warmup_steps 1000 \ 79 | --overwrite_output_dir \ 80 | --save_steps 30000 \ 81 | --gradient_accumulation_steps 1 \ 82 | --expected_train_size 35000000 \ 83 | --logging_steps_per_eval 1 \ 84 | --fp16 \ 85 | --optimizer lamb \ 86 | --log_dir ~/tensorboard/${DLWS_JOB_ID}/logs/OSpass 87 | 88 | ## Training 89 | 90 | To train the model(s) in the paper, you need to start two commands in the following order: 91 | 92 | 1. Run commands/run_train.sh, which does three things in sequence: 93 | 94 | a. Data preprocessing: this is explained in the previous data preprocessing section. This step checks whether the preprocessed data folder already exists and is skipped if it does. 95 | 96 | b. Initial ANN data generation: this step uses the pretrained BM25 warmup checkpoint to generate the initial training data. The command is as follows: 97 | 98 | python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py \ 99 | --training_dir {# checkpoint location, not used for initial data generation} \ 100 | --init_model_dir {pretrained BM25 warmup checkpoint location} \ 101 | --model_type rdot_nll \ 102 | --output_dir $model_ann_data_dir \ 103 | --cache_dir $model_ann_data_dir_cache \ 104 | --data_dir $preprocessed_data_dir \ 105 | --max_seq_length 512 \ 106 | --per_gpu_eval_batch_size 16 \ 107 | --topk_training {top k candidates for ANN search (e.g. 200)} \ 108 | --negative_sample {negative samples per query (e.g. 20)} \ 109 | --end_output_num 0 # set to 0 only for initial data generation; do not set it otherwise 110 | 111 | c. 
Training: ANCE training with the most recently generated ANN data. The command is as follows: 112 | 113 | python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann.py \ 114 | --model_type rdot_nll \ 115 | --model_name_or_path $pretrained_checkpoint_dir \ 116 | --task_name MSMarco \ 117 | --triplet {# a store_true flag, default False; pass it to enable triplet training} \ 118 | --data_dir $preprocessed_data_dir \ 119 | --ann_dir {location of the ANN generated training data} \ 120 | --max_seq_length 512 \ 121 | --per_gpu_train_batch_size=8 \ 122 | --gradient_accumulation_steps 2 \ 123 | --learning_rate 1e-6 \ 124 | --output_dir $model_dir \ 125 | --warmup_steps 5000 \ 126 | --logging_steps 100 \ 127 | --save_steps 10000 \ 128 | --optimizer lamb 129 | 130 | 2. Once training starts, start another job in parallel to fetch the latest checkpoint from the ongoing training and update the training data. To do that, run 131 | 132 | bash commands/run_ann_data_gen.sh 133 | 134 | The command is similar to the initial ANN data generation command explained previously. 135 | 136 | ## Inference 137 | The command for inferring query and passage/document embeddings is the same as the initial ANN data generation command described above, since the first step of ANN data generation is inference. However, you need to add --inference to the command so that the program stops after the initial inference step. commands/run_inference.sh provides a sample command. 138 | 139 | ## Evaluation 140 | 141 | The evaluation is done through "Calculate Metrics.ipynb". This notebook calculates the full ranking and reranking metrics used in the paper, including NDCG, MRR, hole rate, and recall, for the passage/document and dev/eval set specified by the user. In order to run it, you need to define the following parameters at the beginning of the Jupyter notebook. 142 | 143 | checkpoint_path = {location of dumped query and passage/document embeddings, i.e. the output_dir from run_ann_data_gen.py} 144 | checkpoint = {embedding from which checkpoint (e.g. 200000)} 145 | data_type = {0 for document, 1 for passage} 146 | test_set = {0 for MSMARCO dev_set, 1 for TREC eval_set} 147 | raw_data_dir = 148 | processed_data_dir = 149 | 150 | ## ANCE VS DPR on OpenQA Benchmarks 151 | We also evaluate ANCE on the OpenQA benchmark used in a parallel work ([DPR](https://github.com/facebookresearch/DPR)). At the time of our experiment, only the pre-processed NQ and TriviaQA data were released. 152 | Our experiments use the two released tasks and inherit the DPR retriever evaluation. The evaluation uses Coverage@20/100, i.e. whether the top-20/100 retrieved passages include the answer. We explain the steps to 153 | reproduce our results on the OpenQA benchmarks in this section. 154 | 155 | ### Download data 156 | commands/data_download.sh takes care of this step. 157 | 158 | ### ANN data generation & ANCE training 159 | Following the same training philosophy discussed before, the ANN data generation and ANCE training for OpenQA require two parallel jobs. 160 | 1. We need to preprocess data and generate an initial training set for ANCE to start training. The command for that is provided in: 161 | ``` 162 | commands/run_ann_data_gen_dpr.sh 163 | ``` 164 | We keep this data generation job running after it creates the initial training set, as it will keep generating training data with the newest checkpoints from the training process. 165 | 166 | 2. 
After an initial training set is generated, we start an ANCE training job with commands provided in: 167 | ``` 168 | commands/run_train_dpr.sh 169 | ``` 170 | During training, evaluation metrics are printed to TensorBoard each time the trainer receives new training data. Alternatively, you can check the metrics in the dumped file "ann_ndcg_#" in the directory specified by "model_ann_data_dir" in commands/run_ann_data_gen_dpr.sh each time new training data is generated. 171 | 172 | ## Results 173 | The run_train.sh and run_ann_data_gen.sh files contain the commands with the parameters we used for passage ANCE(FirstP), document ANCE(FirstP), and document ANCE(MaxP). 174 | Our model achieves the following performance on the MSMARCO dev set and the TREC eval set: 175 | 176 | 177 | | MSMARCO Dev Passage Retrieval | MRR@10 | Recall@1k | Steps | 178 | |---------------- | -------------- |-------------- | -------------- | 179 | | ANCE(FirstP) | 0.330 | 0.959 | [600K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Passage_ANCE_FirstP_Checkpoint.zip) | 180 | | ANCE(MaxP) | - | - | - | 181 | 182 | | TREC DL Passage NDCG@10 | Rerank | Retrieval | Steps | 183 | |---------------- | -------------- |-------------- | -------------- | 184 | | ANCE(FirstP) | 0.677 | 0.648 | [600K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Passage_ANCE_FirstP_Checkpoint.zip) | 185 | | ANCE(MaxP) | - | - | - | 186 | 187 | | TREC DL Document NDCG@10 | Rerank | Retrieval | Steps | 188 | |---------------- | -------------- |-------------- | -------------- | 189 | | ANCE(FirstP) | 0.641 | 0.615 | [210K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Document_ANCE_FirstP_Checkpoint.zip) | 190 | | ANCE(MaxP) | 0.671 | 0.628 | [139K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Document_ANCE_MaxP_Checkpoint.zip) | 191 | 192 | | MSMARCO Dev Passage Retrieval | MRR@10 | Steps | 193 | |---------------- | -------------- | -------------- | 194 | | pretrained BM25 warmup checkpoint | 0.311 | [60K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/warmup_checpoint.zip) | 195 | 196 | | ANCE Single-task Training | Top-20 | Top-100 | Steps | 197 | |---------------- | -------------- | -------------- |-------------- | 198 | | NQ | 81.9 | 87.5 | [136K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/nq.cp) | 199 | | TriviaQA | 80.3 | 85.3 | [100K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/trivia.cp) | 200 | 201 | | ANCE Multi-task Training | Top-20 | Top-100 | Steps | 202 | |---------------- | -------------- | -------------- |-------------- | 203 | | NQ | 82.1 | 87.9 | [300K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/multi.cp) | 204 | | TriviaQA | 80.3 | 85.2 | [300K](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/multi.cp) | 205 | 206 | 207 | Click the step counts in the tables to download the corresponding checkpoints. 208 | 209 | Our top-100 retrieved documents per query on the TREC eval set for document ANCE(FirstP) can be downloaded [here](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Results/ance_512_eval_top100.txt). 
210 | Our top-100 retrieved documents per query on the TREC eval set for document ANCE(MaxP) can be downloaded [here](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Results/ance_2048_eval_top100.txt). 211 | 212 | The TREC eval set query embeddings and their ids for our passage ANCE(FirstP) experiment can be downloaded [here](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Passage_ANCE_FirstP_Embedding.zip). 213 | The TREC eval set query embeddings and their ids for our document ANCE(FirstP) experiment can be downloaded [here](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Document_ANCE_FirstP_Embedding.zip). 214 | The TREC eval set query embeddings and their ids for our document 2048 ANCE(MaxP) experiment can be downloaded [here](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Document_ANCE_MaxP_Embedding.zip). 215 | 216 | The t-SNE plots for all the queries in the TREC document eval set for ANCE(FirstP) can be viewed [here](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/t-SNE.zip). 217 | 218 | The run_train.sh and run_ann_data_gen.sh files contain the commands with the parameters we used for passage ANCE(FirstP), document ANCE(FirstP), and document 2048 ANCE(MaxP) to reproduce the results in this section. 219 | run_train_warmup.sh contains the command to reproduce the results for the pretrained BM25 warmup checkpoint in this section. 220 | 221 | Note that the number of steps needed to reproduce results similar to those in the tables may differ slightly, due to differences in the synchronization between the training and ANN data generation processes and other environment differences across user setups. 
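The dumped embeddings above use the same sharded pickle layout that the evaluation notebook reads. A minimal loading sketch (the path and checkpoint number are placeholders):

```
import pickle
import numpy as np

checkpoint_path = "/path/to/output_dir/"  # output_dir of run_ann_data_gen.py
checkpoint = 200000                       # checkpoint the embeddings were dumped from

embeddings, embedding2id = [], []
for i in range(8):                        # one shard per inference process
    try:
        with open(checkpoint_path + "dev_query_" + str(checkpoint) + "__emb_p__data_obj_" + str(i) + ".pb", 'rb') as f:
            embeddings.append(pickle.load(f))
        with open(checkpoint_path + "dev_query_" + str(checkpoint) + "__embid_p__data_obj_" + str(i) + ".pb", 'rb') as f:
            embedding2id.append(pickle.load(f))
    except FileNotFoundError:             # stop at the first missing shard
        break

query_embedding = np.concatenate(embeddings, axis=0)
query_embedding2id = np.concatenate(embedding2id, axis=0)
```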
222 | 223 | 224 | -------------------------------------------------------------------------------- /data/msmarco_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | sys.path += ['../'] 5 | import gzip 6 | import pickle 7 | from utils.util import pad_input_ids, multi_file_process, numbered_byte_file_generator, EmbeddingCache 8 | import csv 9 | from model.models import MSMarcoConfigDict, ALL_MODELS 10 | from torch.utils.data import DataLoader, Dataset, TensorDataset, IterableDataset, get_worker_info 11 | import numpy as np 12 | from os import listdir 13 | from os.path import isfile, join 14 | import argparse 15 | import json 16 | 17 | 18 | def write_query_rel(args, pid2offset, query_file, positive_id_file, out_query_file, out_id_file): 19 | 20 | print( 21 | "Writing query files " + 22 | str(out_query_file) + 23 | " and " + 24 | str(out_id_file)) 25 | query_positive_id = set() 26 | 27 | query_positive_id_path = os.path.join( 28 | args.data_dir, 29 | positive_id_file, 30 | ) 31 | 32 | print("Loading query_2_pos_docid") 33 | with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', encoding='utf8') as f: 34 | if args.data_type == 0: 35 | tsvreader = csv.reader(f, delimiter=" ") 36 | else: 37 | tsvreader = csv.reader(f, delimiter="\t") 38 | for [topicid, _, docid, rel] in tsvreader: 39 | query_positive_id.add(int(topicid)) 40 | 41 | query_collection_path = os.path.join( 42 | args.data_dir, 43 | query_file, 44 | ) 45 | 46 | out_query_path = os.path.join( 47 | args.out_data_dir, 48 | out_query_file, 49 | ) 50 | 51 | qid2offset = {} 52 | 53 | print('start query file split processing') 54 | multi_file_process( 55 | args, 56 | 32, 57 | query_collection_path, 58 | out_query_path, 59 | QueryPreprocessingFn) 60 | 61 | print('start merging splits') 62 | 63 | idx = 0 64 | with open(out_query_path, 'wb') as f: 65 | for record in numbered_byte_file_generator( 66 | out_query_path, 32, 8 + 4 + args.max_query_length * 4): 67 | q_id = int.from_bytes(record[:8], 'big') 68 | if q_id not in query_positive_id: 69 | # exclude the query as it is not in label set 70 | continue 71 | f.write(record[8:]) 72 | qid2offset[q_id] = idx 73 | idx += 1 74 | if idx < 3: 75 | print(str(idx) + " " + str(q_id)) 76 | 77 | qid2offset_path = os.path.join( 78 | args.out_data_dir, 79 | "qid2offset.pickle", 80 | ) 81 | with open(qid2offset_path, 'wb') as handle: 82 | pickle.dump(qid2offset, handle, protocol=4) 83 | print("done saving qid2offset") 84 | 85 | print("Total lines written: " + str(idx)) 86 | meta = {'type': 'int32', 'total_number': idx, 87 | 'embedding_size': args.max_query_length} 88 | with open(out_query_path + "_meta", 'w') as f: 89 | json.dump(meta, f) 90 | 91 | embedding_cache = EmbeddingCache(out_query_path) 92 | print("First line") 93 | with embedding_cache as emb: 94 | print(emb[0]) 95 | 96 | out_id_path = os.path.join( 97 | args.out_data_dir, 98 | out_id_file, 99 | ) 100 | 101 | print("Writing qrels") 102 | with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', encoding='utf8') as f, \ 103 | open(out_id_path, "w", encoding='utf-8') as out_id: 104 | 105 | if args.data_type == 0: 106 | tsvreader = csv.reader(f, delimiter=" ") 107 | else: 108 | tsvreader = csv.reader(f, delimiter="\t") 109 | out_line_count = 0 110 | for [topicid, _, docid, rel] in tsvreader: 111 | topicid = 
int(topicid) 112 | if args.data_type == 0: 113 | docid = int(docid[1:]) 114 | else: 115 | docid = int(docid) 116 | out_id.write(str(qid2offset[topicid]) + 117 | "\t" + 118 | str(pid2offset[docid]) + 119 | "\t" + 120 | rel + 121 | "\n") 122 | out_line_count += 1 123 | print("Total lines written: " + str(out_line_count)) 124 | 125 | 126 | def preprocess(args): 127 | 128 | pid2offset = {} 129 | if args.data_type == 0: 130 | in_passage_path = os.path.join( 131 | args.data_dir, 132 | "msmarco-docs.tsv", 133 | ) 134 | else: 135 | in_passage_path = os.path.join( 136 | args.data_dir, 137 | "collection.tsv", 138 | ) 139 | 140 | out_passage_path = os.path.join( 141 | args.out_data_dir, 142 | "passages", 143 | ) 144 | 145 | if os.path.exists(out_passage_path): 146 | print("preprocessed data already exist, exit preprocessing") 147 | return 148 | 149 | out_line_count = 0 150 | 151 | print('start passage file split processing') 152 | multi_file_process( 153 | args, 154 | 32, 155 | in_passage_path, 156 | out_passage_path, 157 | PassagePreprocessingFn) 158 | 159 | print('start merging splits') 160 | with open(out_passage_path, 'wb') as f: 161 | for idx, record in enumerate(numbered_byte_file_generator( 162 | out_passage_path, 32, 8 + 4 + args.max_seq_length * 4)): 163 | p_id = int.from_bytes(record[:8], 'big') 164 | f.write(record[8:]) 165 | pid2offset[p_id] = idx 166 | if idx < 3: 167 | print(str(idx) + " " + str(p_id)) 168 | out_line_count += 1 169 | 170 | print("Total lines written: " + str(out_line_count)) 171 | meta = { 172 | 'type': 'int32', 173 | 'total_number': out_line_count, 174 | 'embedding_size': args.max_seq_length} 175 | with open(out_passage_path + "_meta", 'w') as f: 176 | json.dump(meta, f) 177 | embedding_cache = EmbeddingCache(out_passage_path) 178 | print("First line") 179 | with embedding_cache as emb: 180 | print(emb[0]) 181 | 182 | pid2offset_path = os.path.join( 183 | args.out_data_dir, 184 | "pid2offset.pickle", 185 | ) 186 | with open(pid2offset_path, 'wb') as handle: 187 | pickle.dump(pid2offset, handle, protocol=4) 188 | print("done saving pid2offset") 189 | 190 | if args.data_type == 0: 191 | write_query_rel( 192 | args, 193 | pid2offset, 194 | "msmarco-doctrain-queries.tsv", 195 | "msmarco-doctrain-qrels.tsv", 196 | "train-query", 197 | "train-qrel.tsv") 198 | write_query_rel( 199 | args, 200 | pid2offset, 201 | "msmarco-test2019-queries.tsv", 202 | "2019qrels-docs.txt", 203 | "dev-query", 204 | "dev-qrel.tsv") 205 | else: 206 | write_query_rel( 207 | args, 208 | pid2offset, 209 | "queries.train.tsv", 210 | "qrels.train.tsv", 211 | "train-query", 212 | "train-qrel.tsv") 213 | write_query_rel( 214 | args, 215 | pid2offset, 216 | "queries.dev.small.tsv", 217 | "qrels.dev.small.tsv", 218 | "dev-query", 219 | "dev-qrel.tsv") 220 | 221 | 222 | def PassagePreprocessingFn(args, line, tokenizer): 223 | if args.data_type == 0: 224 | line_arr = line.split('\t') 225 | p_id = int(line_arr[0][1:]) # remove "D" 226 | 227 | url = line_arr[1].rstrip() 228 | title = line_arr[2].rstrip() 229 | p_text = line_arr[3].rstrip() 230 | 231 | #full_text = url + "" + title + "" + p_text 232 | full_text = url + " "+tokenizer.sep_token+" " + title + " "+tokenizer.sep_token+" " + p_text 233 | # keep only first 10000 characters, should be sufficient for any 234 | # experiment that uses less than 500 - 1k tokens 235 | full_text = full_text[:args.max_doc_character] 236 | else: 237 | line = line.strip() 238 | line_arr = line.split('\t') 239 | p_id = int(line_arr[0]) 240 | 241 | p_text = line_arr[1].rstrip() 
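# Note: the bytes returned at the end of this function form a fixed-width record:
# an 8-byte big-endian p_id (used to build pid2offset and stripped when the split
# files are merged), a 4-byte big-endian passage length, and max_seq_length int32
# token ids -- the layout EmbeddingCache later reads back.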
242 | 243 | # keep only first 10000 characters, should be sufficient for any 244 | # experiment that uses less than 500 - 1k tokens 245 | full_text = p_text[:args.max_doc_character] 246 | 247 | passage = tokenizer.encode( 248 | full_text, 249 | add_special_tokens=True, 250 | max_length=args.max_seq_length, 251 | ) 252 | passage_len = min(len(passage), args.max_seq_length) 253 | input_id_b = pad_input_ids(passage, args.max_seq_length,pad_token=tokenizer.pad_token_id) 254 | 255 | 256 | 257 | 258 | return p_id.to_bytes(8,'big') + passage_len.to_bytes(4,'big') + np.array(input_id_b,np.int32).tobytes() 259 | 260 | 261 | def QueryPreprocessingFn(args, line, tokenizer): 262 | line_arr = line.split('\t') 263 | q_id = int(line_arr[0]) 264 | 265 | passage = tokenizer.encode( 266 | line_arr[1].rstrip(), 267 | add_special_tokens=True, 268 | max_length=args.max_query_length) 269 | passage_len = min(len(passage), args.max_query_length) 270 | input_id_b = pad_input_ids(passage, args.max_query_length,pad_token=tokenizer.pad_token_id) 271 | 272 | return q_id.to_bytes(8,'big') + passage_len.to_bytes(4,'big') + np.array(input_id_b,np.int32).tobytes() 273 | 274 | 275 | def GetProcessingFn(args, query=False): 276 | def fn(vals, i): 277 | passage_len, passage = vals 278 | max_len = args.max_query_length if query else args.max_seq_length 279 | 280 | pad_len = max(0, max_len - passage_len) 281 | token_type_ids = ([0] if query else [1]) * passage_len + [0] * pad_len 282 | attention_mask = [1] * passage_len + [0] * pad_len 283 | 284 | passage_collection = [(i, passage, attention_mask, token_type_ids)] 285 | 286 | query2id_tensor = torch.tensor( 287 | [f[0] for f in passage_collection], dtype=torch.long) 288 | all_input_ids_a = torch.tensor( 289 | [f[1] for f in passage_collection], dtype=torch.int) 290 | all_attention_mask_a = torch.tensor( 291 | [f[2] for f in passage_collection], dtype=torch.bool) 292 | all_token_type_ids_a = torch.tensor( 293 | [f[3] for f in passage_collection], dtype=torch.uint8) 294 | 295 | dataset = TensorDataset( 296 | all_input_ids_a, 297 | all_attention_mask_a, 298 | all_token_type_ids_a, 299 | query2id_tensor) 300 | 301 | return [ts for ts in dataset] 302 | 303 | return fn 304 | 305 | 306 | def GetTrainingDataProcessingFn(args, query_cache, passage_cache): 307 | def fn(line, i): 308 | line_arr = line.split('\t') 309 | qid = int(line_arr[0]) 310 | pos_pid = int(line_arr[1]) 311 | neg_pids = line_arr[2].split(',') 312 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 313 | 314 | all_input_ids_a = [] 315 | all_attention_mask_a = [] 316 | 317 | query_data = GetProcessingFn( 318 | args, query=True)( 319 | query_cache[qid], qid)[0] 320 | pos_data = GetProcessingFn( 321 | args, query=False)( 322 | passage_cache[pos_pid], pos_pid)[0] 323 | 324 | pos_label = torch.tensor(1, dtype=torch.long) 325 | neg_label = torch.tensor(0, dtype=torch.long) 326 | 327 | for neg_pid in neg_pids: 328 | neg_data = GetProcessingFn( 329 | args, query=False)( 330 | passage_cache[neg_pid], neg_pid)[0] 331 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], pos_label) 332 | yield (query_data[0], query_data[1], query_data[2], neg_data[0], neg_data[1], neg_data[2], neg_label) 333 | 334 | return fn 335 | 336 | 337 | def GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache): 338 | def fn(line, i): 339 | line_arr = line.split('\t') 340 | qid = int(line_arr[0]) 341 | pos_pid = int(line_arr[1]) 342 | neg_pids = line_arr[2].split(',') 343 | neg_pids = [int(neg_pid) for 
neg_pid in neg_pids] 344 | 345 | all_input_ids_a = [] 346 | all_attention_mask_a = [] 347 | 348 | query_data = GetProcessingFn( 349 | args, query=True)( 350 | query_cache[qid], qid)[0] 351 | pos_data = GetProcessingFn( 352 | args, query=False)( 353 | passage_cache[pos_pid], pos_pid)[0] 354 | 355 | for neg_pid in neg_pids: 356 | neg_data = GetProcessingFn( 357 | args, query=False)( 358 | passage_cache[neg_pid], neg_pid)[0] 359 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], 360 | neg_data[0], neg_data[1], neg_data[2]) 361 | 362 | return fn 363 | 364 | 365 | def get_arguments(): 366 | parser = argparse.ArgumentParser() 367 | 368 | parser.add_argument( 369 | "--data_dir", 370 | default=None, 371 | type=str, 372 | required=True, 373 | help="The input data dir", 374 | ) 375 | parser.add_argument( 376 | "--out_data_dir", 377 | default=None, 378 | type=str, 379 | required=True, 380 | help="The output data dir", 381 | ) 382 | parser.add_argument( 383 | "--model_type", 384 | default=None, 385 | type=str, 386 | required=True, 387 | help="Model type selected in the list: " + 388 | ", ".join( 389 | MSMarcoConfigDict.keys()), 390 | ) 391 | parser.add_argument( 392 | "--model_name_or_path", 393 | default=None, 394 | type=str, 395 | required=True, 396 | help="Path to pre-trained model or shortcut name selected in the list: " + 397 | ", ".join(ALL_MODELS), 398 | ) 399 | parser.add_argument( 400 | "--max_seq_length", 401 | default=128, 402 | type=int, 403 | help="The maximum total input sequence length after tokenization. Sequences longer " 404 | "than this will be truncated, sequences shorter will be padded.", 405 | ) 406 | parser.add_argument( 407 | "--max_query_length", 408 | default=64, 409 | type=int, 410 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 411 | "than this will be truncated, sequences shorter will be padded.", 412 | ) 413 | parser.add_argument( 414 | "--max_doc_character", 415 | default=10000, 416 | type=int, 417 | help="truncate raw text to this many characters before tokenization, to save tokenizer latency", 418 | ) 419 | parser.add_argument( 420 | "--data_type", 421 | default=0, 422 | type=int, 423 | help="0 for doc, 1 for passage", 424 | ) 425 | 426 | args = parser.parse_args() 427 | 428 | return args 429 | 430 | 431 | def main(): 432 | args = get_arguments() 433 | 434 | if not os.path.exists(args.out_data_dir): 435 | os.makedirs(args.out_data_dir) 436 | preprocess(args) 437 | 438 | 439 | if __name__ == '__main__': 440 | main() 441 | -------------------------------------------------------------------------------- /evaluation/Calculate Metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "source": [ 7 | "import sys\n", 8 | "sys.path += ['../utils']\n", 9 | "import csv\n", 10 | "from tqdm import tqdm \n", 11 | "import collections\n", 12 | "import gzip\n", 13 | "import pickle\n", 14 | "import numpy as np\n", 15 | "import faiss\n", 16 | "import os\n", 17 | "import pytrec_eval\n", 18 | "import json\n", 19 | "from msmarco_eval import quality_checks_qids, compute_metrics, load_reference" 20 | ], 21 | "outputs": [], 22 | "metadata": {} 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "source": [ 27 | "# Define params below" 28 | ], 29 | "metadata": {} 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "source": [ 35 | "checkpoint_path = # location of dumped query and passage/document embeddings, i.e. the output_dir \n", 36 | "checkpoint = 0 # embedding from which checkpoint (e.g. 200000)\n", 37 | "data_type = 0 # 0 for document, 1 for passage\n", 38 | "test_set = 1 # 0 for dev_set, 1 for eval_set\n", 39 | "raw_data_dir = \n", 40 | "processed_data_dir = " 41 | ], 42 | "outputs": [], 43 | "metadata": {} 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "source": [ 48 | "# Load Qrel" 49 | ], 50 | "metadata": {} 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "source": [ 56 | "if data_type == 0:\n", 57 | " topN = 100\n", 58 | "else:\n", 59 | " topN = 1000\n", 60 | "dev_query_positive_id = {}\n", 61 | "query_positive_id_path = os.path.join(processed_data_dir, \"dev-qrel.tsv\")\n", 62 | "\n", 63 | "with open(query_positive_id_path, 'r', encoding='utf8') as f:\n", 64 | " tsvreader = csv.reader(f, delimiter=\"\\t\")\n", 65 | " for [topicid, docid, rel] in tsvreader:\n", 66 | " topicid = int(topicid)\n", 67 | " docid = int(docid)\n", 68 | " if topicid not in dev_query_positive_id:\n", 69 | " dev_query_positive_id[topicid] = {}\n", 70 | " dev_query_positive_id[topicid][docid] = int(rel)" 71 | ], 72 | "outputs": [], 73 | "metadata": {} 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "source": [ 78 | "# Prepare rerank data" 79 | ], 80 | "metadata": {} 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "source": [ 86 | "qidmap_path = processed_data_dir+\"/qid2offset.pickle\"\n", 87 | "pidmap_path = processed_data_dir+\"/pid2offset.pickle\"\n", 88 | "if data_type == 0:\n", 89 | " if test_set == 1:\n", 90 | " query_path = raw_data_dir+\"/msmarco-test2019-queries.tsv\"\n", 91 | " passage_path = raw_data_dir+\"/msmarco-doctest2019-top100\"\n", 92 | " else:\n", 93 | " query_path = raw_data_dir+\"/msmarco-docdev-queries.tsv\"\n", 94 | " passage_path = 
raw_data_dir+\"/msmarco-docdev-top100\"\n", 95 | "else:\n", 96 | " if test_set == 1:\n", 97 | " query_path = raw_data_dir+\"/msmarco-test2019-queries.tsv\"\n", 98 | " passage_path = raw_data_dir+\"/msmarco-passagetest2019-top1000.tsv\"\n", 99 | " else:\n", 100 | " query_path = raw_data_dir+\"/queries.dev.small.tsv\"\n", 101 | " passage_path = raw_data_dir+\"/top1000.dev\"\n", 102 | " \n", 103 | "with open(qidmap_path, 'rb') as handle:\n", 104 | " qidmap = pickle.load(handle)\n", 105 | "\n", 106 | "with open(pidmap_path, 'rb') as handle:\n", 107 | " pidmap = pickle.load(handle)\n", 108 | "\n", 109 | "qset = set()\n", 110 | "with gzip.open(query_path, 'rt', encoding='utf-8') if query_path[-2:] == \"gz\" else open(query_path, 'rt', encoding='utf-8') as f:\n", 111 | " tsvreader = csv.reader(f, delimiter=\"\\t\")\n", 112 | " for [qid, query] in tsvreader:\n", 113 | " qset.add(qid)\n", 114 | "\n", 115 | "bm25 = collections.defaultdict(set)\n", 116 | "with gzip.open(passage_path, 'rt', encoding='utf-8') if passage_path[-2:] == \"gz\" else open(passage_path, 'rt', encoding='utf-8') as f:\n", 117 | " for line in tqdm(f):\n", 118 | " if data_type == 0:\n", 119 | " [qid, Q0, pid, rank, score, runstring] = line.split(' ')\n", 120 | " pid = pid[1:]\n", 121 | " else:\n", 122 | " [qid, pid, query, passage] = line.split(\"\\t\")\n", 123 | " if qid in qset and int(qid) in qidmap:\n", 124 | " bm25[qidmap[int(qid)]].add(pidmap[int(pid)]) \n", 125 | "\n", 126 | "print(\"number of queries with \" +str(topN) + \" BM25 passages:\", len(bm25))" 127 | ], 128 | "outputs": [], 129 | "metadata": {} 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "source": [ 134 | "# Calculate Metrics" 135 | ], 136 | "metadata": {} 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "source": [ 142 | "def convert_to_string_id(result_dict):\n", 143 | " string_id_dict = {}\n", 144 | "\n", 145 | " # format [string, dict[string, val]]\n", 146 | " for k, v in result_dict.items():\n", 147 | " _temp_v = {}\n", 148 | " for inner_k, inner_v in v.items():\n", 149 | " _temp_v[str(inner_k)] = inner_v\n", 150 | "\n", 151 | " string_id_dict[str(k)] = _temp_v\n", 152 | "\n", 153 | " return string_id_dict\n", 154 | "\n", 155 | "def EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, I_nearest_neighbor,topN):\n", 156 | " prediction = {} #[qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2)\n", 157 | "\n", 158 | " total = 0\n", 159 | " labeled = 0\n", 160 | " Atotal = 0\n", 161 | " Alabeled = 0\n", 162 | " qids_to_ranked_candidate_passages = {} \n", 163 | " for query_idx in range(len(I_nearest_neighbor)): \n", 164 | " seen_pid = set()\n", 165 | " query_id = query_embedding2id[query_idx]\n", 166 | " prediction[query_id] = {}\n", 167 | "\n", 168 | " top_ann_pid = I_nearest_neighbor[query_idx].copy()\n", 169 | " selected_ann_idx = top_ann_pid[:topN]\n", 170 | " rank = 0\n", 171 | " \n", 172 | " if query_id in qids_to_ranked_candidate_passages:\n", 173 | " pass \n", 174 | " else:\n", 175 | " # By default, all PIDs in the list of 1000 are 0. 
Only override those that are given\n", 176 | " tmp = [0] * 1000\n", 177 | " qids_to_ranked_candidate_passages[query_id] = tmp\n", 178 | " \n", 179 | " for idx in selected_ann_idx:\n", 180 | " pred_pid = passage_embedding2id[idx]\n", 181 | " \n", 182 | " if not pred_pid in seen_pid:\n", 183 | " # this check handles multiple vector per document\n", 184 | " qids_to_ranked_candidate_passages[query_id][rank]=pred_pid\n", 185 | " Atotal += 1\n", 186 | " if pred_pid not in dev_query_positive_id[query_id]:\n", 187 | " Alabeled += 1\n", 188 | " if rank < 10:\n", 189 | " total += 1\n", 190 | " if pred_pid not in dev_query_positive_id[query_id]:\n", 191 | " labeled += 1\n", 192 | " rank += 1\n", 193 | " prediction[query_id][pred_pid] = -rank\n", 194 | " seen_pid.add(pred_pid)\n", 195 | "\n", 196 | " # use out of the box evaluation script\n", 197 | " evaluator = pytrec_eval.RelevanceEvaluator(\n", 198 | " convert_to_string_id(dev_query_positive_id), {'map_cut', 'ndcg_cut', 'recip_rank','recall'})\n", 199 | "\n", 200 | " eval_query_cnt = 0\n", 201 | " result = evaluator.evaluate(convert_to_string_id(prediction))\n", 202 | " \n", 203 | " qids_to_relevant_passageids = {}\n", 204 | " for qid in dev_query_positive_id:\n", 205 | " qid = int(qid)\n", 206 | " if qid in qids_to_relevant_passageids:\n", 207 | " pass\n", 208 | " else:\n", 209 | " qids_to_relevant_passageids[qid] = []\n", 210 | " for pid in dev_query_positive_id[qid]:\n", 211 | " if pid>0:\n", 212 | " qids_to_relevant_passageids[qid].append(pid)\n", 213 | " \n", 214 | " ms_mrr = compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)\n", 215 | "\n", 216 | " ndcg = 0\n", 217 | " Map = 0\n", 218 | " mrr = 0\n", 219 | " recall = 0\n", 220 | " recall_1000 = 0\n", 221 | "\n", 222 | " for k in result.keys():\n", 223 | " eval_query_cnt += 1\n", 224 | " ndcg += result[k][\"ndcg_cut_10\"]\n", 225 | " Map += result[k][\"map_cut_10\"]\n", 226 | " mrr += result[k][\"recip_rank\"]\n", 227 | " recall += result[k][\"recall_\"+str(topN)]\n", 228 | "\n", 229 | " final_ndcg = ndcg / eval_query_cnt\n", 230 | " final_Map = Map / eval_query_cnt\n", 231 | " final_mrr = mrr / eval_query_cnt\n", 232 | " final_recall = recall / eval_query_cnt\n", 233 | " hole_rate = labeled/total\n", 234 | " Ahole_rate = Alabeled/Atotal\n", 235 | "\n", 236 | " return final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, result, prediction" 237 | ], 238 | "outputs": [], 239 | "metadata": {} 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "source": [ 245 | "dev_query_embedding = []\n", 246 | "dev_query_embedding2id = []\n", 247 | "passage_embedding = []\n", 248 | "passage_embedding2id = []\n", 249 | "for i in range(8):\n", 250 | " try:\n", 251 | " with open(checkpoint_path + \"dev_query_\"+str(checkpoint)+\"__emb_p__data_obj_\"+str(i)+\".pb\", 'rb') as handle:\n", 252 | " dev_query_embedding.append(pickle.load(handle))\n", 253 | " with open(checkpoint_path + \"dev_query_\"+str(checkpoint)+\"__embid_p__data_obj_\"+str(i)+\".pb\", 'rb') as handle:\n", 254 | " dev_query_embedding2id.append(pickle.load(handle))\n", 255 | " with open(checkpoint_path + \"passage_\"+str(checkpoint)+\"__emb_p__data_obj_\"+str(i)+\".pb\", 'rb') as handle:\n", 256 | " passage_embedding.append(pickle.load(handle))\n", 257 | " with open(checkpoint_path + \"passage_\"+str(checkpoint)+\"__embid_p__data_obj_\"+str(i)+\".pb\", 'rb') as handle:\n", 258 | " passage_embedding2id.append(pickle.load(handle))\n", 259 | " 
except FileNotFoundError:\n", 260 | " break\n", 261 | "if (not dev_query_embedding) or (not dev_query_embedding2id) or (not passage_embedding) or (not passage_embedding2id):\n", 262 | " print(\"No data found for checkpoint: \",checkpoint)\n", 263 | "\n", 264 | "dev_query_embedding = np.concatenate(dev_query_embedding, axis=0)\n", 265 | "dev_query_embedding2id = np.concatenate(dev_query_embedding2id, axis=0)\n", 266 | "passage_embedding = np.concatenate(passage_embedding, axis=0)\n", 267 | "passage_embedding2id = np.concatenate(passage_embedding2id, axis=0)" 268 | ], 269 | "outputs": [], 270 | "metadata": {} 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "source": [ 275 | "# reranking metrics" 276 | ], 277 | "metadata": {} 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "source": [ 283 | "pidmap = collections.defaultdict(list)\n", 284 | "for i in range(len(passage_embedding2id)):\n", 285 | " pidmap[passage_embedding2id[i]].append(i) # absolute pos (key) to relative pos (val)\n", 286 | "\n", 287 | "if len(bm25) == 0:\n", 288 | " print(\"Rerank data set is empty. Check whether your data preparation was done on the same data set. Rerank metrics are skipped.\")\n", 289 | "else:\n", 290 | " rerank_data = {}\n", 291 | " all_dev_I = []\n", 292 | " for i,qid in enumerate(dev_query_embedding2id):\n", 293 | " p_set = []\n", 294 | " p_set_map = {}\n", 295 | " if qid not in bm25:\n", 296 | " print(qid,\"not in bm25\")\n", 297 | " else:\n", 298 | " count = 0\n", 299 | " for k,pid in enumerate(bm25[qid]):\n", 300 | " if pid in pidmap:\n", 301 | " for val in pidmap[pid]:\n", 302 | " p_set.append(passage_embedding[val])\n", 303 | " p_set_map[count] = val # new relative pos (key) to old relative pos (val)\n", 304 | " count += 1\n", 305 | " else:\n", 306 | " print(pid,\"not in passages\")\n", 307 | " dim = passage_embedding.shape[1]\n", 308 | " faiss.omp_set_num_threads(16)\n", 309 | " cpu_index = faiss.IndexFlatIP(dim)\n", 310 | " p_set = np.asarray(p_set)\n", 311 | " cpu_index.add(p_set) \n", 312 | " _, dev_I = cpu_index.search(dev_query_embedding[i:i+1], len(p_set))\n", 313 | " for j in range(len(dev_I[0])):\n", 314 | " dev_I[0][j] = p_set_map[dev_I[0][j]]\n", 315 | " all_dev_I.append(dev_I[0])\n", 316 | " result = EvalDevQuery(dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, all_dev_I, topN)\n", 317 | " final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, metrics, prediction = result\n", 318 | " print(\"Reranking Results for checkpoint \"+str(checkpoint))\n", 319 | " print(\"Reranking NDCG@10:\" + str(final_ndcg))\n", 320 | " print(\"Reranking map@10:\" + str(final_Map))\n", 321 | " print(\"Reranking pytrec_mrr:\" + str(final_mrr))\n", 322 | " print(\"Reranking recall@\"+str(topN)+\":\" + str(final_recall))\n", 323 | " print(\"Reranking hole rate@10:\" + str(hole_rate))\n", 324 | " print(\"Reranking hole rate:\" + str(Ahole_rate))\n", 325 | " print(\"Reranking ms_mrr:\" + str(ms_mrr))" 326 | ], 327 | "outputs": [], 328 | "metadata": {} 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "source": [ 333 | "# full ranking metrics" 334 | ], 335 | "metadata": {} 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "source": [ 341 | "dim = passage_embedding.shape[1]\n", 342 | "faiss.omp_set_num_threads(16)\n", 343 | "cpu_index = faiss.IndexFlatIP(dim)\n", 344 | "cpu_index.add(passage_embedding) \n", 345 | "_, dev_I = cpu_index.search(dev_query_embedding, topN)\n", 346 | "result = EvalDevQuery(dev_query_embedding2id, 
passage_embedding2id, dev_query_positive_id, dev_I, topN)\n", 347 | "final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, metrics, prediction = result\n", 348 | "print(\"Results for checkpoint \"+str(checkpoint))\n", 349 | "print(\"NDCG@10:\" + str(final_ndcg))\n", 350 | "print(\"map@10:\" + str(final_Map))\n", 351 | "print(\"pytrec_mrr:\" + str(final_mrr))\n", 352 | "print(\"recall@\"+str(topN)+\":\" + str(final_recall))\n", 353 | "print(\"hole rate@10:\" + str(hole_rate))\n", 354 | "print(\"hole rate:\" + str(Ahole_rate))\n", 355 | "print(\"ms_mrr:\" + str(ms_mrr))" 356 | ], 357 | "outputs": [], 358 | "metadata": {} 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "source": [], 364 | "outputs": [], 365 | "metadata": {} 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": "Python 3", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.6.9" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 4 389 | } 390 | -------------------------------------------------------------------------------- /data/DPR_data.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import sys 3 | sys.path += ['../'] 4 | import argparse 5 | import json 6 | import os 7 | import random 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset, TensorDataset 11 | from model.models import MSMarcoConfigDict, ALL_MODELS 12 | import csv 13 | from utils.util import multi_file_process, numbered_byte_file_generator, EmbeddingCache 14 | import pickle 15 | 16 | 17 | def normalize_question(question: str) -> str: 18 | if question[-1] == '?': 19 | question = question[:-1] 20 | return question 21 | 22 | 23 | def write_qas_query(args, qas_file, out_query_file): 24 | print("Writing qas query files " + str(out_query_file)) 25 | print("print",args.answer_dir,qas_file) 26 | qas_path = os.path.join( 27 | args.answer_dir, 28 | qas_file, 29 | ) 30 | out_query_path = os.path.join( 31 | args.out_data_dir, 32 | out_query_file , 33 | ) 34 | 35 | configObj = MSMarcoConfigDict[args.model_type] 36 | tokenizer = configObj.tokenizer_class.from_pretrained( 37 | args.model_name_or_path, 38 | do_lower_case=True, 39 | cache_dir=None, 40 | ) 41 | 42 | qid = 0 43 | with open(qas_path, "r", encoding="utf-8") as f, open(out_query_path, "wb") as out_query: 44 | reader = csv.reader(f, delimiter='\t') 45 | for row in reader: 46 | question = normalize_question(row[0]) 47 | out_query.write(QueryPreprocessingFn(args, qid, question, tokenizer)) 48 | qid += 1 49 | 50 | meta = {'type': 'int32', 'total_number': qid, 'embedding_size': args.max_seq_length} 51 | with open(out_query_path + "_meta", 'w') as f: 52 | json.dump(meta, f) 53 | 54 | 55 | def write_query_rel(args, pid2offset, query_file, out_query_file, out_ann_file, out_train_file, passage_id_name="passage_id"): 56 | 57 | print("Writing query files " + str(out_query_file) + " and " + str(out_ann_file)) 58 | 59 | query_path = os.path.join( 60 | args.question_dir, 61 | query_file, 62 | ) 63 | 64 | with open(query_path, 'r', encoding="utf-8") as f: 65 | data = json.load(f) 66 | print('Aggregated data size: 
{}'.format(len(data))) 67 | 68 | data = [r for r in data if len(r['positive_ctxs']) > 0] 69 | print('Total cleaned data size: {}'.format(len(data))) 70 | data = [r for r in data if len(r['hard_negative_ctxs']) > 0] 71 | print('Total cleaned data size: {}'.format(len(data))) 72 | 73 | out_query_path = os.path.join( 74 | args.out_data_dir, 75 | out_query_file , 76 | ) 77 | 78 | out_ann_file = os.path.join( 79 | args.out_data_dir, 80 | out_ann_file , 81 | ) 82 | 83 | out_training_path = os.path.join( 84 | args.out_data_dir, 85 | out_train_file , 86 | ) 87 | 88 | qid = 0 89 | 90 | configObj = MSMarcoConfigDict[args.model_type] 91 | tokenizer = configObj.tokenizer_class.from_pretrained( 92 | args.model_name_or_path, 93 | do_lower_case=True, 94 | cache_dir=None, 95 | ) 96 | 97 | with open(out_query_path, "wb") as out_query, \ 98 | open(out_ann_file, "w", encoding='utf-8') as out_ann, \ 99 | open(out_training_path, "w", encoding='utf-8') as out_training: 100 | for sample in data: 101 | positive_ctxs = sample['positive_ctxs'] 102 | neg_ctxs = sample['hard_negative_ctxs'] 103 | question = normalize_question(sample['question']) 104 | first_pos_pid = pid2offset[int(positive_ctxs[0][passage_id_name])] 105 | neg_pids = [str(pid2offset[int(neg_ctx[passage_id_name])]) for neg_ctx in neg_ctxs] 106 | out_ann.write("{}\t{}\t{}\n".format(qid, first_pos_pid, sample["answers"])) 107 | out_training.write("{}\t{}\t{}\n".format(qid, first_pos_pid, ','.join(neg_pids))) 108 | out_query.write(QueryPreprocessingFn(args, qid, question, tokenizer)) 109 | qid += 1 110 | 111 | print("Total lines written: " + str(qid)) 112 | meta = {'type': 'int32', 'total_number': qid, 'embedding_size': args.max_seq_length} 113 | with open(out_query_path + "_meta", 'w') as f: 114 | json.dump(meta, f) 115 | 116 | embedding_cache = EmbeddingCache(out_query_path) 117 | print("First line") 118 | with embedding_cache as emb: 119 | print(emb[0]) 120 | 121 | 122 | def write_mapping(args, id2offset, out_name): 123 | out_path = os.path.join( 124 | args.out_data_dir, 125 | out_name , 126 | ) 127 | with open(out_path, 'w') as f: 128 | for item in id2offset.items(): 129 | f.write("{}\t{}\n".format(item[0], item[1])) 130 | 131 | 132 | def load_mapping(data_dir, out_name): 133 | out_path = os.path.join( 134 | data_dir, 135 | out_name , 136 | ) 137 | pid2offset = {} 138 | offset2pid = {} 139 | with open(out_path, 'r') as f: 140 | for line in f.readlines(): 141 | line_arr = line.split('\t') 142 | pid2offset[int(line_arr[0])] = int(line_arr[1]) 143 | offset2pid[int(line_arr[1])] = int(line_arr[0]) 144 | return pid2offset, offset2pid 145 | 146 | 147 | def preprocess(args): 148 | 149 | pid2offset = {} 150 | in_passage_path = os.path.join( 151 | args.wiki_dir, 152 | "psgs_w100.tsv" , 153 | ) 154 | out_passage_path = os.path.join( 155 | args.out_data_dir, 156 | "passages" , 157 | ) 158 | 159 | if os.path.exists(out_passage_path): 160 | print("preprocessed data already exist, exit preprocessing") 161 | return 162 | else: 163 | out_line_count = 0 164 | 165 | print('start passage file split processing') 166 | multi_file_process(args, 32, in_passage_path, out_passage_path, PassagePreprocessingFn) 167 | 168 | print('start merging splits') 169 | with open(out_passage_path, 'wb') as f: 170 | for idx, record in enumerate(numbered_byte_file_generator(out_passage_path, 32, 8 + 4 + args.max_seq_length * 4)): 171 | p_id = int.from_bytes(record[:8], 'big') 172 | f.write(record[8:]) 173 | pid2offset[p_id] = idx 174 | if idx < 3: 175 | print(str(idx) + " " + str(p_id)) 176 
| out_line_count += 1 177 | 178 | print("Total lines written: " + str(out_line_count)) 179 | meta = {'type': 'int32', 'total_number': out_line_count, 'embedding_size': args.max_seq_length} 180 | with open(out_passage_path + "_meta", 'w') as f: 181 | json.dump(meta, f) 182 | write_mapping(args, pid2offset, "pid2offset") 183 | 184 | embedding_cache = EmbeddingCache(out_passage_path) 185 | print("First line") 186 | with embedding_cache as emb: 187 | print(emb[pid2offset[1]]) 188 | 189 | if args.data_type == 0: 190 | write_query_rel(args, pid2offset, "nq-train.json", "train-query", "train-ann", "train-data") 191 | elif args.data_type == 1: 192 | write_query_rel(args, pid2offset, "trivia-train.json", "train-query", "train-ann", "train-data", "psg_id") 193 | else: 194 | # use both training dataset and merge them 195 | write_query_rel(args, pid2offset, "nq-train.json", "train-query-nq", "train-ann-nq", "train-data-nq") 196 | write_query_rel(args, pid2offset, "trivia-train.json", "train-query-trivia", "train-ann-trivia", "train-data-trivia", "psg_id") 197 | 198 | with open(args.out_data_dir + "train-query-nq", "rb") as nq_query, \ 199 | open(args.out_data_dir + "train-query-trivia", "rb") as trivia_query, \ 200 | open(args.out_data_dir + "train-query", "wb") as out_query: 201 | out_query.write(nq_query.read()) 202 | out_query.write(trivia_query.read()) 203 | 204 | with open(args.out_data_dir + "train-query-nq_meta", "r", encoding='utf-8') as nq_query, \ 205 | open(args.out_data_dir + "train-query-trivia_meta", "r", encoding='utf-8') as trivia_query, \ 206 | open(args.out_data_dir + "train-query_meta", "w", encoding='utf-8') as out_query: 207 | a = json.load(nq_query) 208 | b = json.load(trivia_query) 209 | meta = {'type': 'int32', 'total_number': a['total_number'] + b['total_number'], 'embedding_size': args.max_seq_length} 210 | json.dump(meta, out_query) 211 | 212 | embedding_cache = EmbeddingCache(args.out_data_dir + "train-query") 213 | print("First line after merge") 214 | with embedding_cache as emb: 215 | print(emb[58812]) 216 | 217 | with open(args.out_data_dir + "train-ann-nq", "r", encoding='utf-8') as nq_ann, \ 218 | open(args.out_data_dir + "train-ann-trivia", "r", encoding='utf-8') as trivia_ann, \ 219 | open(args.out_data_dir + "train-ann", "w", encoding='utf-8') as out_ann: 220 | out_ann.writelines(nq_ann.readlines()) 221 | out_ann.writelines(trivia_ann.readlines()) 222 | 223 | 224 | write_query_rel(args, pid2offset, "nq-dev.json", "dev-query", "dev-ann", "dev-data") 225 | write_query_rel(args, pid2offset, "trivia-dev.json", "dev-query-trivia", "dev-ann-trivia", "dev-data-trivia", "psg_id") 226 | write_qas_query(args, "nq-test.csv", "test-query") 227 | write_qas_query(args, "trivia-test.csv", "trivia-test-query") 228 | 229 | def PassagePreprocessingFn(args, line, tokenizer): 230 | line_arr = list(csv.reader([line], delimiter='\t'))[0] 231 | if line_arr[0] == 'id': 232 | return bytearray() 233 | 234 | p_id = int(line_arr[0]) 235 | text = line_arr[1] 236 | title = line_arr[2] 237 | 238 | token_ids = tokenizer.encode(title, text_pair=text, add_special_tokens=True, 239 | max_length=args.max_seq_length, 240 | pad_to_max_length=False) 241 | 242 | seq_len = args.max_seq_length 243 | passage_len = len(token_ids) 244 | if len(token_ids) < seq_len: 245 | token_ids = token_ids + [tokenizer.pad_token_id] * (seq_len - len(token_ids)) 246 | if len(token_ids) > seq_len: 247 | token_ids = token_ids[0:seq_len] 248 | token_ids[-1] = tokenizer.sep_token_id 249 | 250 | if p_id < 5: 251 | a = 
np.array(token_ids, np.int32) 252 | print("pid {}, passagelen {}, shape {}".format(p_id, passage_len, a.shape)) 253 | 254 | return p_id.to_bytes(8, 'big') + passage_len.to_bytes(4, 'big') + np.array(token_ids, np.int32).tobytes() 255 | 256 | 257 | def QueryPreprocessingFn(args, qid, text, tokenizer): 258 | token_ids = tokenizer.encode(text, add_special_tokens=True, max_length=args.max_seq_length, 259 | pad_to_max_length=False) 260 | 261 | seq_len = args.max_seq_length 262 | passage_len = len(token_ids) 263 | if len(token_ids) < seq_len: 264 | token_ids = token_ids + [tokenizer.pad_token_id] * (seq_len - len(token_ids)) 265 | if len(token_ids) > seq_len: 266 | token_ids = token_ids[0:seq_len] 267 | token_ids[-1] = tokenizer.sep_token_id 268 | 269 | if qid < 5: 270 | a = np.array(token_ids, np.int32) 271 | print("qid {}, passagelen {}, shape {}".format(qid, passage_len, a.shape)) 272 | 273 | return passage_len.to_bytes(4, 'big') + np.array(token_ids, np.int32).tobytes() 274 | 275 | 276 | def GetProcessingFn(args, query=False): 277 | def fn(vals, i): 278 | passage_len, passage = vals 279 | max_len = args.max_seq_length 280 | 281 | pad_len = max(0, max_len - passage_len) 282 | token_type_ids = [0] * passage_len + [0] * pad_len 283 | attention_mask = passage != 0 284 | 285 | passage_collection = [(i, passage, attention_mask, token_type_ids)] 286 | 287 | query2id_tensor = torch.tensor([f[0] for f in passage_collection], dtype=torch.long) 288 | all_input_ids_a = torch.tensor([f[1] for f in passage_collection], dtype=torch.int) 289 | all_attention_mask_a = torch.tensor([f[2] for f in passage_collection], dtype=torch.bool) 290 | all_token_type_ids_a = torch.tensor([f[3] for f in passage_collection], dtype=torch.uint8) 291 | 292 | dataset = TensorDataset(all_input_ids_a, all_attention_mask_a, all_token_type_ids_a, query2id_tensor) 293 | 294 | return [ts for ts in dataset] 295 | 296 | return fn 297 | 298 | 299 | def GetTrainingDataProcessingFn(args, query_cache, passage_cache, shuffle=True): 300 | def fn(line, i): 301 | line_arr = line.split('\t') 302 | qid = int(line_arr[0]) 303 | pos_pid = int(line_arr[1]) 304 | neg_pids = line_arr[2].split(',') 305 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 306 | 307 | all_input_ids_a = [] 308 | all_attention_mask_a = [] 309 | 310 | query_data = GetProcessingFn(args, query=True)(query_cache[qid], qid)[0] 311 | pos_data = GetProcessingFn(args, query=False)(passage_cache[pos_pid], pos_pid)[0] 312 | 313 | if shuffle: 314 | random.shuffle(neg_pids) 315 | 316 | neg_data = GetProcessingFn(args, query=False)(passage_cache[neg_pids[0]], neg_pids[0])[0] 317 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2]) 318 | yield (query_data[0], query_data[1], query_data[2], neg_data[0], neg_data[1], neg_data[2]) 319 | 320 | return fn 321 | 322 | 323 | def GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache, shuffle=True): 324 | def fn(line, i): 325 | line_arr = line.split('\t') 326 | qid = int(line_arr[0]) 327 | pos_pid = int(line_arr[1]) 328 | neg_pids = line_arr[2].split(',') 329 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 330 | 331 | all_input_ids_a = [] 332 | all_attention_mask_a = [] 333 | 334 | query_data = GetProcessingFn(args, query=True)(query_cache[qid], qid)[0] 335 | pos_data = GetProcessingFn(args, query=False)(passage_cache[pos_pid], pos_pid)[0] 336 | 337 | if shuffle: 338 | random.shuffle(neg_pids) 339 | 340 | neg_data = GetProcessingFn(args, query=False)(passage_cache[neg_pids[0]], 
neg_pids[0])[0] 341 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], 342 | neg_data[0], neg_data[1], neg_data[2]) 343 | 344 | return fn 345 | 346 | 347 | def main(): 348 | parser = argparse.ArgumentParser() 349 | parser.add_argument( 350 | "--out_data_dir", 351 | default="/webdata-nfs/jialliu/dpr/ann/ann_multi_data_256/", 352 | type=str, 353 | help="The output data dir", 354 | ) 355 | parser.add_argument( 356 | "--model_type", 357 | default="dpr", 358 | type=str, 359 | help="Model type selected in the list: " + ", ".join(MSMarcoConfigDict.keys()), 360 | ) 361 | parser.add_argument( 362 | "--model_name_or_path", 363 | default="bert-base-uncased", 364 | type=str, 365 | help="Path to pre-trained model or shortcut name selected in the list: " + 366 | ", ".join(ALL_MODELS), 367 | ) 368 | parser.add_argument( 369 | "--max_seq_length", 370 | default=256, 371 | type=int, 372 | help="The maximum total input sequence length after tokenization. Sequences longer " 373 | "than this will be truncated, sequences shorter will be padded.", 374 | ) 375 | parser.add_argument( 376 | "--data_type", 377 | default=0, 378 | type=int, 379 | help="0 is nq, 1 is trivia, 2 is both", 380 | ) 381 | parser.add_argument( 382 | "--question_dir", 383 | type=str, 384 | help="location of the raw QnA question data", 385 | ) 386 | parser.add_argument( 387 | "--wiki_dir", 388 | type=str, 389 | help="location of the wiki corpus", 390 | ) 391 | parser.add_argument( 392 | "--answer_dir", 393 | type=str, 394 | help="location of the QnA answers for evaluation", 395 | ) 396 | args = parser.parse_args() 397 | if not os.path.exists(args.out_data_dir): 398 | os.makedirs(args.out_data_dir) 399 | preprocess(args) 400 | 401 | 402 | if __name__ == '__main__': 403 | main() 404 | -------------------------------------------------------------------------------- /model/SEED_Encoder/tokenization_seed_encoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 Microsoft and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ Tokenization class for model DeBERTa.""" 16 | 17 | import os 18 | import unicodedata 19 | from typing import Any, Dict, List, Optional, Tuple 20 | 21 | import sentencepiece as sp 22 | import six 23 | 24 | from transformers.tokenization_utils import PreTrainedTokenizer 25 | from tokenizers import BertWordPieceTokenizer, normalizers, pre_tokenizers 26 | import re 27 | 28 | PRETRAINED_VOCAB_FILES_MAP = { 29 | "vocab_file": { 30 | "microsoft/seed-encoder-3-layer-decoder": "./vocab.txt", 31 | "microsoft/seed-encoder-1-layer-decoder": "./vocab.txt" 32 | } 33 | } 34 | 35 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 36 | "microsoft/seed-encoder-3-layer-decoder": 512, 37 | "microsoft/seed-encoder-1-layer-decoder": 512, 38 | } 39 | 40 | PRETRAINED_INIT_CONFIGURATION = { 41 | "microsoft/seed-encoder-3-layer-decoder": {"do_lower_case": False}, 42 | "microsoft/seed-encoder-1-layer-decoder": {"do_lower_case": False}, 43 | } 44 | 45 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 46 | 47 | 48 | 49 | class SEEDTokenizer(PreTrainedTokenizer): 50 | r""" 51 | Constructs a DeBERTa-v2 tokenizer. Based on `SentencePiece `__. 52 | Args: 53 | vocab_file (:obj:`str`): 54 | `SentencePiece `__ file (generally has a `.spm` extension) that 55 | contains the vocabulary necessary to instantiate a tokenizer. 56 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`): 57 | Whether or not to lowercase the input when tokenizing. 58 | bos_token (:obj:`string`, `optional`, defaults to "[CLS]"): 59 | The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. 60 | When building a sequence using special tokens, this is not the token that is used for the beginning of 61 | sequence. The token used is the :obj:`cls_token`. 62 | eos_token (:obj:`string`, `optional`, defaults to "[SEP]"): 63 | The end of sequence token. When building a sequence using special tokens, this is not the token that is 64 | used for the end of sequence. The token used is the :obj:`sep_token`. 65 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 66 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 67 | token instead. 68 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 69 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 70 | sequence classification or for a text and a question for question answering. It is also used as the last 71 | token of a sequence built with special tokens. 72 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 73 | The token used for padding, for example when batching sequences of different lengths. 74 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 75 | The classifier token which is used when doing sequence classification (classification of the whole sequence 76 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 77 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 78 | The token used for masking values. This is the token used when training this model with masked language 79 | modeling. This is the token which the model will try to predict. 80 | sp_model_kwargs (:obj:`dict`, `optional`): 81 | Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for SentencePiece 82 | `__ can be used, among other things, to set: 83 | - ``enable_sampling``: Enable subword regularization. 
84 |             - ``nbest_size``: Sampling parameters for unigram. Invalid for BPE-Dropout.
85 |             - ``nbest_size = {0,1}``: No sampling is performed.
86 |             - ``nbest_size > 1``: samples from the nbest_size results.
87 |             - ``nbest_size < 0``: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
88 |               using the forward-filtering-and-backward-sampling algorithm.
89 |             - ``alpha``: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
90 |               BPE-dropout.
91 |     """
92 | 
93 |     vocab_files_names = VOCAB_FILES_NAMES
94 |     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
95 |     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
96 |     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
97 | 
98 |     def __init__(
99 |         self,
100 |         vocab_file,
101 |         do_lower_case=False,
102 |         bos_token="[CLS]",
103 |         eos_token="[SEP]",
104 |         unk_token="[UNK]",
105 |         sep_token="[SEP]",
106 |         pad_token="[PAD]",
107 |         cls_token="[CLS]",
108 |         mask_token="<mask>",
109 |         fb_model_kwargs: Optional[Dict[str, Any]] = None,
110 |         **kwargs
111 |     ) -> None:
112 |         self.fb_model_kwargs = {} if fb_model_kwargs is None else fb_model_kwargs
113 | 
114 |         super().__init__(
115 |             do_lower_case=do_lower_case,
116 |             bos_token=bos_token,
117 |             eos_token=eos_token,
118 |             unk_token=unk_token,
119 |             sep_token=sep_token,
120 |             pad_token=pad_token,
121 |             cls_token=cls_token,
122 |             mask_token=mask_token,
123 |             fb_model_kwargs=self.fb_model_kwargs,
124 |             **kwargs,
125 |         )
126 | 
127 |         if not os.path.isfile(vocab_file):
128 |             raise ValueError(
129 |                 f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a pretrained "
130 |                 "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
131 |             )
132 |         self.do_lower_case = do_lower_case
133 | 
134 |         #self._tokenizer = BertWordPieceTokenizer(vocab_file, clean_text=False, strip_accents=False, lowercase=False)
135 |         self._tokenizer = FastBERTTokenizer(vocab_file, fb_model_kwargs=self.fb_model_kwargs)
136 | 
137 | 
138 | 
139 |     @property
140 |     def vocab_size(self):
141 |         return len(self.vocab)
142 | 
143 |     @property
144 |     def vocab(self):
145 |         return self._tokenizer.vocab
146 | 
147 |     def get_vocab(self):
148 |         vocab = self.vocab.copy()
149 |         vocab.update(self.get_added_vocab())
150 |         return vocab
151 | 
152 |     def _tokenize(self, text: str) -> List[str]:
153 |         """Take as input a string and return a list of strings (tokens) for words/sub-words."""
154 |         if self.do_lower_case:
155 |             escaped_special_toks = [re.escape(s_tok) for s_tok in ['[CLS]', '[PAD]', '[UNK]', '[SEP]']]
156 |             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
157 |             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
158 |         return self._tokenizer.tokenize(text)
159 | 
160 | 
161 |     def _convert_token_to_id(self, token):
162 |         """Converts a token (str) into an id using the vocab."""
163 |         #return self._tokenizer.spm.PieceToId(token)
164 |         return self._tokenizer._convert_token_to_id(token)
165 | 
166 |     def _convert_id_to_token(self, index):
167 |         """Converts an index (integer) into a token (str) using the vocab."""
168 |         #return self._tokenizer.spm.IdToPiece(index) if index < self.vocab_size else self.unk_token
169 |         return self._tokenizer.convert_ids_to_tokens(index)
170 | 
171 |     def convert_tokens_to_string(self, tokens):
172 |         """Converts a sequence of tokens (strings) into a single string."""
173 |         return self._tokenizer.decode(" ".join(str(self._convert_token_to_id(tok)) for tok in tokens))  # decode() expects whitespace-joined ids
174 | 
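    # Editorial sketch (not part of the original file): how the conversion
    # helpers above compose; the input string is illustrative:
    #
    #   toks = tokenizer.tokenize("SEED-Encoder pre-training")   # WordPiece pieces
    #   ids = tokenizer.convert_tokens_to_ids(toks)               # vocab lookups
    #   text = tokenizer.convert_tokens_to_string(toks)           # detokenized string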
175 |     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
176 |         """
177 |         Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
178 |         adding special tokens. A SEED-Encoder sequence has the following format:
179 |         - single sequence: [CLS] X [SEP]
180 |         - pair of sequences: [CLS] A [SEP] B [SEP]
181 |         Args:
182 |             token_ids_0 (:obj:`List[int]`):
183 |                 List of IDs to which the special tokens will be added.
184 |             token_ids_1 (:obj:`List[int]`, `optional`):
185 |                 Optional second list of IDs for sequence pairs.
186 |         Returns:
187 |             :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
188 |         """
189 | 
190 |         if token_ids_1 is None:
191 |             return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
192 |         cls = [self.cls_token_id]
193 |         sep = [self.sep_token_id]
194 |         return cls + token_ids_0 + sep + token_ids_1 + sep
195 | 
196 |     def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
197 |         """
198 |         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
199 |         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
200 |         Args:
201 |             token_ids_0 (:obj:`List[int]`):
202 |                 List of IDs.
203 |             token_ids_1 (:obj:`List[int]`, `optional`):
204 |                 Optional second list of IDs for sequence pairs.
205 |             already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
206 |                 Whether or not the token list is already formatted with special tokens for the model.
207 |         Returns:
208 |             :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
209 |         """
210 | 
211 |         if already_has_special_tokens:
212 |             return super().get_special_tokens_mask(
213 |                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
214 |             )
215 | 
216 |         if token_ids_1 is not None:
217 |             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
218 |         return [1] + ([0] * len(token_ids_0)) + [1]
219 | 
220 |     def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
221 |         """
222 |         Create a mask from the two sequences passed to be used in a sequence-pair classification task. A SEED-Encoder
223 |         sequence pair mask has the following format:
224 |         ::
225 |             0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
226 |             | first sequence    | second sequence |
227 |         If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
228 |         Args:
229 |             token_ids_0 (:obj:`List[int]`):
230 |                 List of IDs.
231 |             token_ids_1 (:obj:`List[int]`, `optional`):
232 |                 Optional second list of IDs for sequence pairs.
233 |         Returns:
234 |             :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
235 |             sequence(s).
236 | """ 237 | sep = [self.sep_token_id] 238 | cls = [self.cls_token_id] 239 | if token_ids_1 is None: 240 | return len(cls + token_ids_0 + sep) * [0] 241 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 242 | 243 | def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): 244 | add_prefix_space = kwargs.pop("add_prefix_space", False) 245 | if is_split_into_words or add_prefix_space: 246 | text = " " + text 247 | return (text, kwargs) 248 | 249 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 250 | return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) 251 | 252 | def encode(self,full_text, add_special_tokens,max_length): 253 | if self.do_lower_case: 254 | escaped_special_toks = [re.escape(s_tok) for s_tok in ['[CLS]','[PAD]','[UNK]','[SEP]']] 255 | pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" 256 | full_text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), full_text) 257 | return self._tokenizer.bert_tokenizer.encode(full_text, add_special_tokens=add_special_tokens).ids[:max_length] 258 | 259 | 260 | 261 | class FastBERTTokenizer: 262 | r""" 263 | Constructs a tokenizer based on `SentencePiece `__. 264 | Args: 265 | vocab_file (:obj:`str`): 266 | `SentencePiece `__ file (generally has a `.spm` extension) that 267 | contains the vocabulary necessary to instantiate a tokenizer. 268 | sp_model_kwargs (:obj:`dict`, `optional`): 269 | Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for SentencePiece 270 | `__ can be used, among other things, to set: 271 | - ``enable_sampling``: Enable subword regularization. 272 | - ``nbest_size``: Sampling parameters for unigram. Invalid for BPE-Dropout. 273 | - ``nbest_size = {0,1}``: No sampling is performed. 274 | - ``nbest_size > 1``: samples from the nbest_size results. 275 | - ``nbest_size < 0``: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) 276 | using forward-filtering-and-backward-sampling algorithm. 277 | - ``alpha``: Smoothing parameter for unigram sampling, and dropout probability of merge operations for 278 | BPE-dropout. 279 | """ 280 | 281 | def __init__(self, vocab_file, fb_model_kwargs: Optional[Dict[str, Any]] = None): 282 | self.vocab_file = vocab_file 283 | self.fb_model_kwargs = {} if fb_model_kwargs is None else fb_model_kwargs 284 | 285 | assert os.path.exists(vocab_file), "no existing vocab file." 286 | 287 | #spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) 288 | #spm.load(vocab_file) 289 | #bpe_vocab_size = spm.GetPieceSize() 290 | #self.spm = spm 291 | 292 | self.bert_tokenizer = BertWordPieceTokenizer(vocab_file, clean_text=False, strip_accents=False, lowercase=False) 293 | 294 | self.vocab={} 295 | self.ids_to_tokens=[] 296 | self.add_from_file(open(vocab_file,'r')) 297 | self.add_symbol('') 298 | self.vocab_size=len(self.vocab) 299 | 300 | 301 | def add_from_file(self, f): 302 | """ 303 | Loads a pre-existing dictionary from a text file and adds its symbols 304 | to this instance. 
305 | """ 306 | 307 | lines = f.readlines() 308 | 309 | for line in lines: 310 | line = line.rstrip() 311 | word = line 312 | self.add_symbol(word, overwrite=False) 313 | 314 | def add_symbol(self, word, overwrite=False): 315 | """Adds a word to the dictionary""" 316 | if word in self.vocab and not overwrite: 317 | idx = self.vocab[word] 318 | return idx 319 | else: 320 | idx = len(self.ids_to_tokens) 321 | self.vocab[word] = idx 322 | self.ids_to_tokens.append(word) 323 | return idx 324 | 325 | def tokenize(self, text): 326 | return self.bert_tokenizer.encode(text, add_special_tokens=False).tokens 327 | 328 | 329 | def convert_ids_to_tokens(self, index): 330 | return self.ids_to_tokens[index] if index < self.vocab_size else self.unk 331 | 332 | def _convert_token_to_id(self,token): 333 | return self.vocab[token] 334 | 335 | 336 | def decode(self, x:str) ->str: 337 | return self.bert_tokenizer.decode([ 338 | int(tok) for tok in x.split() 339 | ]) 340 | 341 | def pad(self): 342 | return "[PAD]" 343 | 344 | def bos(self): 345 | return "[CLS]" 346 | 347 | def eos(self): 348 | return "[SEP]" 349 | 350 | def unk(self): 351 | return "[UNK]" 352 | 353 | def mask(self): 354 | return "" 355 | 356 | def sym(self, id): 357 | return self.ids_to_tokens[id] 358 | 359 | def id(self, sym): 360 | return self.vocab[sym] if sym in self.vocab else 1 361 | 362 | 363 | def save_pretrained(self, path: str, filename_prefix: str = None): 364 | filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] 365 | if filename_prefix is not None: 366 | filename = filename_prefix + "-" + filename 367 | full_path = os.path.join(path, filename) 368 | with open(full_path, "w") as fs: 369 | #fs.write(self.spm.serialized_model_proto()) 370 | for item in self.ids_to_tokens: 371 | fs.write(str(item)+'\n') 372 | return (full_path,) 373 | #pass 374 | 375 | 376 | def _run_strip_accents(self, text): 377 | """Strips accents from a piece of text.""" 378 | text = unicodedata.normalize("NFD", text) 379 | output = [] 380 | for char in text: 381 | cat = unicodedata.category(char) 382 | if cat == "Mn": 383 | continue 384 | output.append(char) 385 | return "".join(output) 386 | 387 | def _run_split_on_punc(self, text): 388 | """Splits punctuation on a piece of text.""" 389 | chars = list(text) 390 | i = 0 391 | start_new_word = True 392 | output = [] 393 | while i < len(chars): 394 | char = chars[i] 395 | if _is_punctuation(char): 396 | output.append([char]) 397 | start_new_word = True 398 | else: 399 | if start_new_word: 400 | output.append([]) 401 | start_new_word = False 402 | output[-1].append(char) 403 | i += 1 404 | 405 | return ["".join(x) for x in output] 406 | 407 | 408 | 409 | 410 | def _is_whitespace(char): 411 | """Checks whether `chars` is a whitespace character.""" 412 | # \t, \n, and \r are technically control characters but we treat them 413 | # as whitespace since they are generally considered as such. 414 | if char == " " or char == "\t" or char == "\n" or char == "\r": 415 | return True 416 | cat = unicodedata.category(char) 417 | if cat == "Zs": 418 | return True 419 | return False 420 | 421 | 422 | def _is_control(char): 423 | """Checks whether `chars` is a control character.""" 424 | # These are technically control characters but we count them as whitespace 425 | # characters. 
426 |     if char == "\t" or char == "\n" or char == "\r":
427 |         return False
428 |     cat = unicodedata.category(char)
429 |     if cat.startswith("C"):
430 |         return True
431 |     return False
432 | 
433 | 
434 | def _is_punctuation(char):
435 |     """Checks whether `char` is a punctuation character."""
436 |     cp = ord(char)
437 |     # We treat all non-letter/number ASCII as punctuation.
438 |     # Characters such as "^", "$", and "`" are not in the Unicode
439 |     # Punctuation class but we treat them as punctuation anyways, for
440 |     # consistency.
441 |     if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
442 |         return True
443 |     cat = unicodedata.category(char)
444 |     if cat.startswith("P"):
445 |         return True
446 |     return False
447 | 
448 | 
449 | def convert_to_unicode(text):
450 |     """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
451 |     if six.PY3:
452 |         if isinstance(text, str):
453 |             return text
454 |         elif isinstance(text, bytes):
455 |             return text.decode("utf-8", "ignore")
456 |         else:
457 |             raise ValueError(f"Unsupported string type: {type(text)}")
458 |     elif six.PY2:
459 |         if isinstance(text, str):
460 |             return text.decode("utf-8", "ignore")
461 |         else:
462 |             raise ValueError(f"Unsupported string type: {type(text)}")
463 |     else:
464 |         raise ValueError("Not running on Python 2 or Python 3?")
--------------------------------------------------------------------------------
/drivers/run_ann_data_gen_dpr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | sys.path += ['../']
4 | import json
5 | import logging
6 | import os
7 | from os.path import isfile, join
8 | import random
9 | import time
10 | import csv
11 | import numpy as np
12 | import torch
13 | torch.multiprocessing.set_sharing_strategy('file_system')
14 | from torch.utils.data import DataLoader
15 | from tqdm import tqdm
16 | import torch.distributed as dist
17 | from torch import nn
18 | from model.models import MSMarcoConfigDict
19 | from utils.util import (
20 |     StreamingDataset,
21 |     EmbeddingCache,
22 |     get_checkpoint_no,
23 |     get_latest_ann_data,
24 |     barrier_array_merge,
25 |     is_first_worker,
26 | )
27 | from data.DPR_data import GetProcessingFn, load_mapping
28 | from utils.dpr_utils import load_states_from_checkpoint, get_model_obj, SimpleTokenizer, has_answer
29 | import ast
30 | import transformers
31 | from transformers import (
32 |     AdamW,
33 |     RobertaConfig,
34 |     RobertaForSequenceClassification,
35 |     RobertaTokenizer
36 | )
37 | from torch import nn
38 | logger = logging.getLogger(__name__)
39 | import faiss
40 | try:
41 |     from torch.utils.tensorboard import SummaryWriter
42 | except ImportError:
43 |     from tensorboardX import SummaryWriter
44 | 
45 | 
46 | def get_latest_checkpoint(args):
47 |     if not os.path.exists(args.training_dir):
48 |         return args.init_model_dir, 0
49 |     files = list(next(os.walk(args.training_dir))[2])
50 | 
51 |     def valid_checkpoint(checkpoint):
52 |         return checkpoint.startswith("checkpoint-")
53 | 
54 |     logger.info("checkpoint files")
55 |     logger.info(files)
56 |     checkpoint_nums = [get_checkpoint_no(s) for s in files if valid_checkpoint(s)]
57 | 
58 |     if len(checkpoint_nums) > 0:
59 |         return os.path.join(args.training_dir, "checkpoint-" + str(max(checkpoint_nums))), max(checkpoint_nums)
60 |     return args.init_model_dir, 0
61 | 
62 | 
63 | def load_data(args):
64 |     passage_path = os.path.join(args.passage_path, "psgs_w100.tsv")
65 |     test_qa_path = os.path.join(args.test_qa_path, "nq-test.csv")
66 |     trivia_test_qa_path = os.path.join(args.trivia_test_qa_path, "trivia-test.csv")
67 |     train_ann_path = os.path.join(args.data_dir, "train-ann")
68 | 
69 |     pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
70 | 
71 |     passage_text = {}
72 |     train_pos_id = []
73 |     train_answers = []
74 |     test_answers = []
75 |     test_answers_trivia = []
76 | 
77 |     logger.info("Loading train ann")
78 |     with open(train_ann_path, 'r', encoding='utf8') as f:
79 |         # file format: q_id, positive_pid, answers
80 |         tsvreader = csv.reader(f, delimiter="\t")
81 |         for row in tsvreader:
82 |             train_pos_id.append(int(row[1]))
83 |             train_answers.append(ast.literal_eval(row[2]))  # answers column holds a Python list literal
84 | 
85 |     logger.info("Loading test answers")
86 |     with open(test_qa_path, "r", encoding="utf-8") as ifile:
87 |         # file format: question, answers
88 |         reader = csv.reader(ifile, delimiter='\t')
89 |         for row in reader:
90 |             test_answers.append(ast.literal_eval(row[1]))
91 | 
92 |     logger.info("Loading trivia test answers")
93 |     with open(trivia_test_qa_path, "r", encoding="utf-8") as ifile:
94 |         # file format: question, answers
95 |         reader = csv.reader(ifile, delimiter='\t')
96 |         for row in reader:
97 |             test_answers_trivia.append(ast.literal_eval(row[1]))
98 | 
99 |     logger.info("Loading passages")
100 |     with open(passage_path, "r", encoding="utf-8") as tsvfile:
101 |         reader = csv.reader(tsvfile, delimiter='\t')
102 |         # file format: doc_id, doc_text, title
103 |         for row in reader:
104 |             if row[0] != 'id':
105 |                 passage_text[pid2offset[int(row[0])]] = (row[1], row[2])
106 | 
107 |     logger.info("Finished loading data, pos_id length %d, train answers length %d, test answers length %d", len(train_pos_id), len(train_answers), len(test_answers))
108 | 
109 |     return (passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia)
110 | 
111 | 
112 | def load_model(args, checkpoint_path):
113 |     label_list = ["0", "1"]
114 |     num_labels = len(label_list)
115 |     args.model_type = args.model_type.lower()
116 |     configObj = MSMarcoConfigDict[args.model_type]
117 |     args.model_name_or_path = checkpoint_path
118 | 
119 |     model = configObj.model_class(args)
120 | 
121 |     saved_state = load_states_from_checkpoint(checkpoint_path)
122 |     model_to_load = get_model_obj(model)
123 |     logger.info('Loading saved model state ...')
124 |     model_to_load.load_state_dict(saved_state.model_dict)
125 | 
126 |     model.to(args.device)
127 |     logger.info("Inference parameters %s", args)
128 |     if args.local_rank != -1:
129 |         model = torch.nn.parallel.DistributedDataParallel(
130 |             model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
131 |         )
132 |     return model
133 | 
134 | 
135 | def InferenceEmbeddingFromStreamDataLoader(args, model, train_dataloader, is_query_inference=True, prefix=""):
136 |     # expect dataset from ReconstructTrainingSet
137 |     results = {}
138 |     eval_batch_size = args.per_gpu_eval_batch_size
139 | 
140 |     # Inference!
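    # Editorial note: each batch yielded by the streaming dataloader is a tuple
    # in which batch[0] holds the input_ids, batch[1] the attention_mask, and
    # batch[3] the example ids, as consumed in the loop below.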
141 |     logger.info("***** Running ANN Embedding Inference *****")
142 |     logger.info(" Batch size = %d", eval_batch_size)
143 | 
144 |     embedding = []
145 |     embedding2id = []
146 | 
147 |     if args.local_rank != -1:
148 |         dist.barrier()
149 |     model.eval()
150 | 
151 |     for batch in tqdm(train_dataloader, desc="Inferencing", disable=args.local_rank not in [-1, 0], position=0, leave=True):
152 | 
153 |         idxs = batch[3].detach().numpy()  # [B]
154 | 
155 |         batch = tuple(t.to(args.device) for t in batch)
156 | 
157 |         with torch.no_grad():
158 |             inputs = {"input_ids": batch[0].long(), "attention_mask": batch[1].long()}
159 |             if is_query_inference:
160 |                 embs = get_model_obj(model).query_emb(**inputs)  # unwrap DDP so single-GPU runs also work
161 |             else:
162 |                 embs = get_model_obj(model).body_emb(**inputs)
163 | 
164 |         embs = embs.detach().cpu().numpy()
165 | 
166 |         # check for multi chunk output for long sequence
167 |         if len(embs.shape) == 3:
168 |             for chunk_no in range(embs.shape[1]):
169 |                 embedding2id.append(idxs)
170 |                 embedding.append(embs[:, chunk_no, :])
171 |         else:
172 |             embedding2id.append(idxs)
173 |             embedding.append(embs)
174 | 
175 | 
176 |     embedding = np.concatenate(embedding, axis=0)
177 |     embedding2id = np.concatenate(embedding2id, axis=0)
178 |     return embedding, embedding2id
179 | 
180 | 
181 | # streaming inference
182 | def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True, load_cache=False):
183 |     inference_batch_size = args.per_gpu_eval_batch_size  # * max(1, args.n_gpu)
184 |     #inference_dataloader = StreamingDataLoader(f, fn, batch_size=inference_batch_size, num_workers=1)
185 |     inference_dataset = StreamingDataset(f, fn)
186 |     inference_dataloader = DataLoader(inference_dataset, batch_size=inference_batch_size)
187 | 
188 |     if args.local_rank != -1:
189 |         dist.barrier()  # directory created
190 | 
191 |     if load_cache:
192 |         _embedding = None
193 |         _embedding2id = None
194 |     else:
195 |         _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(args, model, inference_dataloader, is_query_inference=is_query_inference, prefix=prefix)
196 | 
197 |     # preserve to memory
198 |     full_embedding = barrier_array_merge(args, _embedding, prefix=prefix + "_emb_p_", load_cache=load_cache, only_load_in_master=True)
199 |     full_embedding2id = barrier_array_merge(args, _embedding2id, prefix=prefix + "_embid_p_", load_cache=load_cache, only_load_in_master=True)
200 | 
201 |     return full_embedding, full_embedding2id
202 | 
203 | 
204 | def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):
205 | 
206 |     model = load_model(args, checkpoint_path)
207 |     pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
208 | 
209 |     logger.info("***** inference of train query *****")
210 |     train_query_collection_path = os.path.join(args.data_dir, "train-query")
211 |     train_query_cache = EmbeddingCache(train_query_collection_path)
212 |     with train_query_cache as emb:
213 |         query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
214 | 
215 |     logger.info("***** inference of dev query *****")
216 |     dev_query_collection_path = os.path.join(args.data_dir, "test-query")
217 |     dev_query_cache = EmbeddingCache(dev_query_collection_path)
218 |     with dev_query_cache as emb:
219 |         dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
220 | 
221 |     dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
222 |     dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
223 |     with dev_query_cache_trivia as emb:
224 |         dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "dev_query_trivia_" + str(latest_step_num) + "_", emb, is_query_inference=True)  # distinct prefix keeps the trivia cache separate from the NQ dev cache above
225 | 
226 |     logger.info("***** inference of passages *****")
227 |     passage_collection_path = os.path.join(args.data_dir, "passages")
228 |     passage_cache = EmbeddingCache(passage_collection_path)
229 |     with passage_cache as emb:
230 |         passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False, load_cache=False)
231 |     logger.info("***** Done passage inference *****")
232 | 
233 |     if is_first_worker():
234 |         passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
235 |         dim = passage_embedding.shape[1]
236 |         print('passage embedding shape: ' + str(passage_embedding.shape))
237 |         top_k = args.topk_training
238 |         faiss.omp_set_num_threads(16)
239 |         cpu_index = faiss.IndexFlatIP(dim)
240 |         cpu_index.add(passage_embedding)
241 |         logger.info("***** Done ANN Index *****")
242 | 
243 |         # measure top-k hit accuracy on the NQ test queries
244 |         _, dev_I = cpu_index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
245 |         top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)
246 | 
247 |         # measure top-k hit accuracy on the TriviaQA test queries
248 |         _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100)  # I: [number of queries, topk]
249 |         top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I, dev_query_embedding2id_trivia, passage_embedding2id)
250 | 
251 |         logger.info("Start searching for query embedding with length %d", len(query_embedding))
252 |         _, I = cpu_index.search(query_embedding, top_k)  # I: [number of queries, topk]
253 | 
254 |         logger.info("***** GenerateNegativePassageID *****")
255 |         effective_q_id = set(query_embedding2id.flatten())
256 | 
257 |         logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
258 |         query_negative_passage = GenerateNegativePassageID(args, passage_text, train_answers, query_embedding2id, passage_embedding2id, I, train_pos_id)
259 | 
260 |         logger.info("Done generating negative passages, output length %d", len(query_negative_passage))
261 | 
262 |         logger.info("***** Construct ANN Triplet *****")
263 |         train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))
264 | 
265 |         with open(train_data_output_path, 'w') as f:
266 |             query_range = list(range(I.shape[0]))
267 |             random.shuffle(query_range)
268 |             for query_idx in query_range:
269 |                 query_id = query_embedding2id[query_idx]
270 |                 # if not query_id in train_pos_id:
271 |                 #     continue
272 |                 pos_pid = train_pos_id[query_id]
273 |                 f.write("{}\t{}\t{}\n".format(query_id, pos_pid, ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))
274 | 
275 |         ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
276 |         with open(ndcg_output_path, 'w') as f:
277 |             json.dump({'top20': top_k_hits[19], 'top100': top_k_hits[99], 'top20_trivia': top_k_hits_trivia[19],
278 |                        'top100_trivia': top_k_hits_trivia[99], 'checkpoint': checkpoint_path}, f)
279 | 
280 | 
281 | def GenerateNegativePassageID(args, passages, answers, query_embedding2id, passage_embedding2id, closest_docs, training_query_positive_id):
282 |     query_negative_passage = {}
283 | 
284 |     tok_opts = {}
285 |     tokenizer = SimpleTokenizer(**tok_opts)
286 | 
287 |     for query_idx in range(closest_docs.shape[0]):
288 |         query_id = query_embedding2id[query_idx]
289 | 
290 |         pos_pid = training_query_positive_id[query_id]
291 |         doc_ids = [passage_embedding2id[pidx] for pidx in closest_docs[query_idx]]
292 | 
293 |         query_negative_passage[query_id] = []
294 |         neg_cnt = 0
295 | 
296 |         for doc_id in doc_ids:
297 |             if doc_id == pos_pid:
298 |                 continue
299 |             if doc_id in query_negative_passage[query_id]:
300 |                 continue
301 |             if neg_cnt >= args.negative_sample:
302 |                 break
303 | 
304 |             text = passages[doc_id][0]
305 |             if not has_answer(answers[query_id], text, tokenizer):
306 |                 query_negative_passage[query_id].append(doc_id)
307 |                 neg_cnt += 1
308 | 
309 |     return query_negative_passage
310 | 
311 | 
312 | def validate(passages, answers, closest_docs, query_embedding2id, passage_embedding2id):
313 | 
314 |     tok_opts = {}
315 |     tokenizer = SimpleTokenizer(**tok_opts)
316 | 
317 |     logger.info('Matching answers in top docs...')
318 |     scores = []
319 |     for query_idx in range(closest_docs.shape[0]):
320 |         query_id = query_embedding2id[query_idx]
321 |         doc_ids = [passage_embedding2id[pidx] for pidx in closest_docs[query_idx]]
322 |         hits = []
323 |         for i, doc_id in enumerate(doc_ids):
324 |             text = passages[doc_id][0]
325 |             hits.append(has_answer(answers[query_id], text, tokenizer))
326 |         scores.append(hits)
327 | 
328 |     logger.info('Per question validation results len=%d', len(scores))
329 | 
330 |     n_docs = len(closest_docs[0])
331 |     top_k_hits = [0] * n_docs
332 |     for question_hits in scores:
333 |         best_hit = next((i for i, x in enumerate(question_hits) if x), None)
334 |         if best_hit is not None:
335 |             top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
336 | 
337 |     logger.info('Validation results: top k documents hits %s', top_k_hits)
338 |     top_k_hits = [v / len(closest_docs) for v in top_k_hits]
339 |     logger.info('Validation results: top k documents hits accuracy %s', top_k_hits)
340 |     return top_k_hits
341 | 
342 | 
343 | def get_arguments():
344 |     parser = argparse.ArgumentParser()
345 | 
346 |     # Required parameters
347 |     parser.add_argument(
348 |         "--data_dir",
349 |         default=None,
350 |         type=str,
351 |         required=True,
352 |         help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
353 |     )
354 |     parser.add_argument(
355 |         "--training_dir",
356 |         default=None,
357 |         type=str,
358 |         required=True,
359 |         help="Training dir; the latest checkpoint dir will be looked up in here",
360 |     )
361 |     parser.add_argument(
362 |         "--init_model_dir",
363 |         default=None,
364 |         type=str,
365 |         required=True,
366 |         help="Initial model dir; used if no checkpoint is found in the training dir",
367 |     )
368 |     parser.add_argument(
369 |         "--last_checkpoint_dir",
370 |         default="",
371 |         type=str,
372 |         help="Last checkpoint used; set this when rerunning the script after some ann data has already been generated",
373 |     )
374 |     parser.add_argument(
375 |         "--model_type",
376 |         default=None,
377 |         type=str,
378 |         required=True,
379 |         help="Model type selected in the list: " + ", ".join(MSMarcoConfigDict.keys()),
380 |     )
381 |     parser.add_argument(
382 |         "--output_dir",
383 |         default=None,
384 |         type=str,
385 |         required=True,
386 |         help="The output directory where the training data will be written",
387 |     )
388 |     parser.add_argument(
389 |         "--cache_dir",
390 |         default=None,
391 |         type=str,
392 |         required=True,
393 |         help="The directory where cached data will be written",
394 |     )
395 |     parser.add_argument(
396 |         "--end_output_num",
397 |         default=-1,
398 |         type=int,
399 |         help="Stop after this number of data versions has been generated; the default (-1) runs forever",
400 |     )
401 |     parser.add_argument(
402 |         "--max_seq_length",
403 |         default=128,
404 |         type=int,
405 |         help="The maximum total input sequence length after tokenization. Sequences longer "
406 |              "than this will be truncated, sequences shorter will be padded.",
407 |     )
408 | 
409 |     parser.add_argument(
410 |         "--max_query_length",
411 |         default=64,
412 |         type=int,
413 |         help="The maximum total query sequence length after tokenization. Queries longer "
414 |              "than this will be truncated, queries shorter will be padded.",
415 |     )
416 | 
417 |     parser.add_argument(
418 |         "--max_doc_character",
419 |         default=10000,
420 |         type=int,
421 |         help="Documents are truncated to this many characters before tokenization, to save tokenizer latency",
422 |     )
423 | 
424 |     parser.add_argument(
425 |         "--per_gpu_eval_batch_size",
426 |         default=128,
427 |         type=int,
428 |         help="Per-GPU batch size for inference",
429 |     )
430 | 
431 |     parser.add_argument(
432 |         "--ann_chunk_factor",
433 |         default=5,  # for 500k queries, divided into 100k chunks for each epoch
434 |         type=int,
435 |         help="divide training queries into chunks",
436 |     )
437 | 
438 |     parser.add_argument(
439 |         "--topk_training",
440 |         default=500,
441 |         type=int,
442 |         help="top k from which negative samples are collected",
443 |     )
444 | 
445 |     parser.add_argument(
446 |         "--negative_sample",
447 |         default=5,
448 |         type=int,
449 |         help="how many negative samples to use per query at each resample",
450 |     )
451 | 
452 |     parser.add_argument(
453 |         "--ann_measure_topk_mrr",
454 |         default=False,
455 |         action="store_true",
456 |         help="whether to measure top-k MRR when generating ANN data",
457 |     )
458 | 
459 |     parser.add_argument(
460 |         "--only_keep_latest_embedding_file",
461 |         default=False,
462 |         action="store_true",
463 |         help="only keep the latest embedding files, to save disk space",
464 |     )
465 | 
466 |     parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
467 |     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
468 |     parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
469 |     parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
470 | 
471 |     parser.add_argument(
472 |         "--passage_path",
473 |         default=None,
474 |         type=str,
475 |         required=True,
476 |         help="directory containing psgs_w100.tsv",
477 |     )
478 | 
479 |     parser.add_argument(
480 |         "--test_qa_path",
481 |         default=None,
482 |         type=str,
483 |         required=True,
484 |         help="directory containing nq-test.csv",
485 |     )
486 | 
487 |     parser.add_argument(
488 |         "--trivia_test_qa_path",
489 |         default=None,
490 |         type=str,
491 |         required=True,
492 |         help="directory containing trivia-test.csv",
493 |     )
494 | 
495 |     args = parser.parse_args()
496 | 
497 |     return args
498 | 
499 | 
500 | def set_env(args):
501 |     # Setup CUDA, GPU & distributed training
502 |     if args.local_rank == -1 or args.no_cuda:
503 |         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
504 |         args.n_gpu = torch.cuda.device_count()
505 |     else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
506 |         torch.cuda.set_device(args.local_rank)
507 |         device = torch.device("cuda", args.local_rank)
508 |         torch.distributed.init_process_group(backend="nccl")
509 |         args.n_gpu = 1
510 |     args.device = device
511 | 
512 |     # store args
513 |     if args.local_rank != -1:
514 |         args.world_size = torch.distributed.get_world_size()
515 |         args.rank = dist.get_rank()
516 | 
517 |     # Setup logging
518 |     logging.basicConfig(
519 |         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
520 |         datefmt="%m/%d/%Y %H:%M:%S",
521 |         level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
522 |     )
523 |     logger.warning(
524 |         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
525 |         args.local_rank,
526 |         device,
527 |         args.n_gpu,
528 |         bool(args.local_rank != -1),
529 |     )
530 | 
531 | def ann_data_gen(args):
532 |     last_checkpoint = args.last_checkpoint_dir
533 |     ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir)
534 |     output_num = ann_no + 1
535 | 
536 |     logger.info("starting output number %d", output_num)
537 |     preloaded_data = None
538 | 
539 |     if is_first_worker():
540 |         if not os.path.exists(args.output_dir):
541 |             os.makedirs(args.output_dir)
542 |         if not os.path.exists(args.cache_dir):
543 |             os.makedirs(args.cache_dir)
544 |         preloaded_data = load_data(args)
545 | 
546 |     while args.end_output_num == -1 or output_num <= args.end_output_num:
547 |         next_checkpoint, latest_step_num = get_latest_checkpoint(args)
548 | 
549 |         if args.only_keep_latest_embedding_file:
550 |             latest_step_num = 0
551 | 
552 |         if next_checkpoint == last_checkpoint:
553 |             time.sleep(60)
554 |         else:
555 |             logger.info("start generating ann data number %d", output_num)
556 |             logger.info("next checkpoint at " + next_checkpoint)
557 |             generate_new_ann(args, output_num, next_checkpoint, preloaded_data, latest_step_num)
558 |             logger.info("finished generating ann data number %d", output_num)
559 |             output_num += 1
560 |             last_checkpoint = next_checkpoint
561 |         if args.local_rank != -1:
562 |             dist.barrier()
563 | 
564 | 
565 | def main():
566 |     args = get_arguments()
567 |     set_env(args)
568 |     ann_data_gen(args)
569 | 
570 | 
571 | if __name__ == "__main__":
572 |     main()
--------------------------------------------------------------------------------
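Example invocation of drivers/run_ann_data_gen_dpr.py (editorial sketch; every
directory below is an illustrative placeholder, and --nproc_per_node should
match the number of available GPUs):

    python3 -m torch.distributed.launch --nproc_per_node=4 drivers/run_ann_data_gen_dpr.py \
        --model_type dpr \
        --data_dir preprocessed_data/ \
        --training_dir checkpoints/ \
        --init_model_dir warmup_checkpoint/ \
        --output_dir ann_data/ \
        --cache_dir cache/ \
        --passage_path raw_data/ \
        --test_qa_path raw_data/ \
        --trivia_test_qa_path raw_data/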