├── .gitignore ├── src ├── utils │ ├── __init__.py │ ├── utils.py │ ├── tokenizers.py │ └── download_data.py ├── model.py ├── data.py ├── msmarco_utils │ └── search.py └── train_dpr.py ├── scripts └── tools │ ├── embed.sh │ └── test.sh ├── requirements.txt ├── config ├── nq │ ├── vram11 │ │ ├── train_dpr_nq_bsz8.yaml │ │ ├── train_dpr_nq_gradAccum_16.yaml │ │ └── train_dpr_nq_contAccum_cache8_accum16.yaml │ ├── vram24 │ │ ├── train_dpr_nq_bsz32.yaml │ │ ├── train_dpr_nq_gradAccum_4.yaml │ │ └── train_dpr_nq_contAccum_cache8_accum4.yaml │ └── train_dpr_nq_bsz128.yaml ├── webq │ ├── vram11 │ │ ├── train_dpr_webq_bsz8.yaml │ │ ├── train_dpr_webq_gradAccum_16.yaml │ │ └── train_dpr_webq_contAccum_cache1_accum16.yaml │ ├── vram24 │ │ ├── train_dpr_webq_bsz32.yaml │ │ ├── train_dpr_webq_gradAccum_4.yaml │ │ └── train_dpr_webq_contAccum_cache1_accum4.yaml │ └── train_dpr_webq_bsz128.yaml ├── trivia │ ├── vram11 │ │ ├── train_dpr_trivia_bsz8.yaml │ │ ├── train_dpr_trivia_gradAccum_16.yaml │ │ └── train_dpr_trivia_contAccum_cache4_accum16.yaml │ ├── vram24 │ │ ├── train_dpr_trivia_bsz32.yaml │ │ ├── train_dpr_trivia_gradAccum_4.yaml │ │ └── train_dpr_trivia_contAccum_cache4_accum4.yaml │ └── train_dpr_trivia_bsz128.yaml ├── trec │ ├── train_dpr_trec_bsz128.yaml │ ├── vram11 │ │ ├── train_dpr_trec_bsz8.yaml │ │ ├── train_dpr_trec_gradAccum_16.yaml │ │ └── train_dpr_trec_contAccum_cache1_accum16.yaml │ └── vram24 │ │ ├── train_dpr_trec_bsz32.yaml │ │ ├── train_dpr_trec_gradAccum_4.yaml │ │ └── train_dpr_trec_contAccum_cache1_accum4.yaml └── msmarco │ ├── train_dpr_msmarco_bsz128.yaml │ ├── vram11 │ ├── train_dpr_msmarco_bsz8.yaml │ ├── train_dpr_msmarco_gradAccum_16.yaml │ └── train_dpr_msmarco_contAccum_cache8_accum16.yaml │ └── vram24 │ ├── train_dpr_msmarco_bsz32.yaml │ ├── train_dpr_msmarco_gradAccum_4.yaml │ └── train_dpr_msmarco_contAccum_cache8_accum4.yaml ├── doc2embedding_msmarco.py ├── data ├── download_dpr_datasets.sh └── msmarco_download_and_preprocess.ipynb ├── test_msmarco.py ├── doc2embedding.py ├── README.md └── test_dpr.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * -------------------------------------------------------------------------------- /scripts/tools/embed.sh: -------------------------------------------------------------------------------- 1 | PRETRAINED_MODEL_PATH=$1 2 | OUTPUT_DIR=$2 3 | accelerate launch --num_processes=1 doc2embedding.py \ 4 | --model_save_dir $PRETRAINED_MODEL_PATH/doc_encoder \ 5 | --embed_dir $OUTPUT_DIR/embeddings -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.20.3 2 | beir==2.0.0 3 | elasticsearch==7.9.1 4 | faiss-cpu==1.7.4 5 | jsonlines==4.0.0 6 | omegaconf==2.3.0 7 | regex==2022.6.2 8 | sentence-transformers==2.2.2 9 | sentencepiece==0.1.99 10 | torch==1.13.0a0+08820cb 11 | tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1649051611147/work 12 | transformers==4.18.0 13 | wget==3.2 -------------------------------------------------------------------------------- /scripts/tools/test.sh: -------------------------------------------------------------------------------- 1 | DEVICE=$1 2 | SAVED_MODEL_PATH=$2 3 | 
EMBED_DIR=$3 4 | DATA_SPLIT=$4 5 | RESULT_FILE_PATH=$5 6 | 7 | export CUDA_VISIBLE_DEVICES=$DEVICE 8 | python test_dpr.py \ 9 | --embedding_dir $EMBED_DIR/embeddings \ 10 | --pretrained_model_path $SAVED_MODEL_PATH/query_encoder \ 11 | --data_split $DATA_SPLIT \ 12 | --result_file_path $RESULT_FILE_PATH 13 | -------------------------------------------------------------------------------- /config/nq/vram11/train_dpr_nq_bsz8.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 8 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 1 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.1375 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : nq_dpr_bsz8 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_hard_neg : False 31 | cache_query : False 32 | cache_size : 8 -------------------------------------------------------------------------------- /config/nq/vram24/train_dpr_nq_bsz32.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 32 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 1 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : nq_dpr_bsz32 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_hard_neg : False 31 | cache_query : False 32 | cache_size : 8 -------------------------------------------------------------------------------- /config/webq/vram11/train_dpr_webq_bsz8.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : webq_dpr_bsz8 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | 
cache_hard_neg : False 30 | cache_size : 8 31 | cache_query : False -------------------------------------------------------------------------------- /config/webq/vram24/train_dpr_webq_bsz32.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 32 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : webq_dpr_bsz32 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_hard_neg : False 30 | cache_size : 8 31 | cache_query : False -------------------------------------------------------------------------------- /config/trivia/vram11/train_dpr_trivia_bsz8.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 40 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : trivia_dpr_bsz8 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_hard_neg : False 30 | cache_size : 8 31 | cache_query : False -------------------------------------------------------------------------------- /config/trivia/vram24/train_dpr_trivia_bsz32.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 28 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 40 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 5 18 | num_other_negative_ctx: 5 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : trivia_dpr_bsz28 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_hard_neg : False 30 | cache_size : 8 31 | cache_query : False -------------------------------------------------------------------------------- /config/webq/train_dpr_webq_bsz128.yaml: 
-------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 128 9 | per_device_eval_batch_size: 64 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 1 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 1 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : webq_dpr_bsz128 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : False 31 | cache_hard_neg : False 32 | cache_size : 16 33 | -------------------------------------------------------------------------------- /config/trec/train_dpr_trec_bsz128.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 128 9 | per_device_eval_batch_size: 64 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 1 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 1 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : trec_dpr_bsz128 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : False 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/trec/vram11/train_dpr_trec_bsz8.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : trec_dpr_bsz8 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_hard_neg : False 30 | cache_size : 8 31 | cache_query : False -------------------------------------------------------------------------------- /config/trec/vram24/train_dpr_trec_bsz32.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | 
dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 32 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : trec_dpr_bsz32 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_hard_neg : False 30 | cache_size : 8 31 | cache_query : False -------------------------------------------------------------------------------- /config/nq/train_dpr_nq_bsz128.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 128 9 | per_device_eval_batch_size: 16 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 1 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 1 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : MDF_nq_dpr_bsz128 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : False 31 | cache_hard_neg : False 32 | cache_size : 16 33 | use_hard_neg : True -------------------------------------------------------------------------------- /config/nq/vram24/train_dpr_nq_gradAccum_4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 32 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : nq_dpr_gradAccum_bsz32_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : True 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/nq/vram11/train_dpr_nq_gradAccum_16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 8 9 | 
per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 16 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.1375 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : nq_dpr_gradAccum_bsz8_accum16 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : True 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/trivia/train_dpr_trivia_bsz128.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 64 # reduced due to an OOM issue 9 | per_device_eval_batch_size: 64 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 1 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 1 # full 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : trivia_dpr_bsz128 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : False 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/webq/vram24/train_dpr_webq_gradAccum_4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 32 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15 | seed: 19980406 16 | gradient_accumulation_steps: 4 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : webq_dpr_gradAccum_bsz32_accum4 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : True 30 | cache_hard_neg : False 31 | cache_size : 16 -------------------------------------------------------------------------------- /config/webq/vram11/train_dpr_webq_gradAccum_16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15
| seed: 19980406 16 | gradient_accumulation_steps: 16 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : webq_dpr_gradAccum_bsz8_accum16 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : True 30 | cache_hard_neg : False 31 | cache_size : 16 -------------------------------------------------------------------------------- /config/trivia/vram24/train_dpr_trivia_gradAccum_4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 28 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : trivia_dpr_gradAccum_bsz28_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : True 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/msmarco/train_dpr_msmarco_bsz128.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 128 8 | per_device_eval_batch_size: 16 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 1 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : msmarco_dpr_bsz128_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : False 30 | cache_hard_neg : False 31 | cache_size : 16 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/trec/vram24/train_dpr_trec_gradAccum_4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 32 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | 
num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : trec_dpr_gradAccum_bsz32_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : True 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/trivia/vram11/train_dpr_trivia_gradAccum_16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 8 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 16 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.1375 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : trivia_dpr_gradAccum_bsz8_accum16 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : True 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/msmarco/vram11/train_dpr_msmarco_bsz8.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : msmarco_dpr_bsz8_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : False 30 | cache_hard_neg : False 31 | cache_size : 16 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/msmarco/vram24/train_dpr_msmarco_bsz32.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 32 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 1 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : 
/workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : msmarco_dpr_bsz32_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : False 30 | cache_hard_neg : False 31 | cache_size : 16 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/trec/vram11/train_dpr_trec_gradAccum_16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 8 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 16 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.1375 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : trec_dpr_gradAccum_bsz8_accum16 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : False 29 | cont_cache : False 30 | cache_query : True 31 | cache_hard_neg : False 32 | cache_size : 16 -------------------------------------------------------------------------------- /config/nq/vram24/train_dpr_nq_contAccum_cache8_accum4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 32 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : MDF_nq_contAccum_bsz32_cache8_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : True 29 | cont_cache : True 30 | cache_query : True 31 | cache_hard_neg : True 32 | cache_size : 8 33 | use_hard_neg : True -------------------------------------------------------------------------------- /config/msmarco/vram24/train_dpr_msmarco_gradAccum_4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 32 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 4 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | 
run_name : msmarco_dpr_bsz32_accum4_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : False 30 | cache_hard_neg : False 31 | cache_size : 16 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/nq/vram11/train_dpr_nq_contAccum_cache8_accum16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/nq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 8 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 16 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.1375 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : MDF_nq_contAccum_bsz8_cache8_accum16 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : True 29 | cont_cache : True 30 | cache_query : True 31 | cache_hard_neg : True 32 | cache_size : 8 33 | use_hard_neg : True -------------------------------------------------------------------------------- /config/msmarco/vram11/train_dpr_msmarco_gradAccum_16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 16 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : msmarco_dpr_bsz8_accum16_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : False 28 | cont_cache : False 29 | cache_query : False 30 | cache_hard_neg : False 31 | cache_size : 16 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/webq/vram11/train_dpr_webq_contAccum_cache1_accum16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 100 15 | seed: 19980406 16 | gradient_accumulation_steps: 16 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : 
MDF_webq_contAccum_bsz8_cache1_accum16 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : True 28 | cont_cache : True 29 | cache_query : True 30 | cache_hard_neg : True 31 | cache_size : 1 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/webq/vram24/train_dpr_webq_contAccum_cache1_accum4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/webq-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 32 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : MDF_webq_contAccum_bsz32_cache1_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : True 29 | cont_cache : True 30 | cache_query : True 31 | cache_hard_neg : True 32 | cache_size : 1 33 | use_hard_neg : True -------------------------------------------------------------------------------- /config/trivia/vram11/train_dpr_trivia_contAccum_cache4_accum16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 40 15 | seed: 19980406 16 | gradient_accumulation_steps: 16 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : MDF_trivia_contAccum_bsz8_cache4_accum16 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : True 28 | cont_cache : True 29 | cache_query : True 30 | cache_hard_neg : True 31 | cache_size : 4 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/msmarco/vram24/train_dpr_msmarco_contAccum_cache8_accum4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 32 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 4 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.3 # 24GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : 
MDF_msmarco_dpr_bsz32_cache8_accum4_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : True 28 | cont_cache : True 29 | cache_query : True 30 | cache_hard_neg : True 31 | cache_size : 8 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/trivia/vram24/train_dpr_trivia_contAccum_cache4_accum4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/trivia-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 28 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 40 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : MDF_trivia_contAccum_bsz28_cache4_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : True 29 | cont_cache : True 30 | cache_query : True 31 | cache_hard_neg : True 32 | cache_size : 4 33 | use_hard_neg : True -------------------------------------------------------------------------------- /config/msmarco/vram11/train_dpr_msmarco_contAccum_cache8_accum16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json 3 | dev_file: /workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_dev.json 4 | 5 | ## training 6 | base_model: bert-base-uncased 7 | per_device_train_batch_size: 8 8 | per_device_eval_batch_size: 8 9 | adam_eps: 1.0e-8 10 | weight_decay: 0.0 11 | max_grad_norm: 2.0 12 | lr: 2.0e-5 13 | warmup_steps: 1237 14 | max_train_epochs: 10 15 | seed: 19980406 16 | gradient_accumulation_steps: 16 17 | num_hard_negative_ctx: 30 18 | num_other_negative_ctx: 30 19 | vram_fraction: 0.1375 # 11GB / 80GB 20 | 21 | ## logs 22 | log_dir : /workspace/mnt2/dpr_logs 23 | embed_dir : /workspace/mnt2/dpr_output/ 24 | run_name : MDF_msmarco_dpr_bsz8_cache8_accum16_filtered 25 | 26 | ## loss (Contrastive Accumulation) 27 | prev_cache : True 28 | cont_cache : True 29 | cache_query : True 30 | cache_hard_neg : True 31 | cache_size : 8 32 | use_hard_neg : True -------------------------------------------------------------------------------- /config/trec/vram24/train_dpr_trec_contAccum_cache1_accum4.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 32 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 4 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.3 # 24GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | 
run_name : MDF_trec_contAccum_bsz32_cache1_accum4 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : True 29 | cont_cache : True 30 | cache_query : True 31 | cache_hard_neg : True 32 | cache_size : 1 33 | use_hard_neg : True -------------------------------------------------------------------------------- /config/trec/vram11/train_dpr_trec_contAccum_cache1_accum16.yaml: -------------------------------------------------------------------------------- 1 | ## data 2 | train_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-train.json 3 | dev_file: /workspace/mnt2/dpr_datasets/downloads/data/retriever/curatedtrec-dev.json 4 | 5 | 6 | ## training 7 | base_model: bert-base-uncased 8 | per_device_train_batch_size: 8 9 | per_device_eval_batch_size: 8 10 | adam_eps: 1.0e-8 11 | weight_decay: 0.0 12 | max_grad_norm: 2.0 13 | lr: 2.0e-5 14 | warmup_steps: 1237 15 | max_train_epochs: 100 16 | seed: 19980406 17 | gradient_accumulation_steps: 16 18 | num_hard_negative_ctx: 30 19 | num_other_negative_ctx: 30 20 | vram_fraction: 0.1375 # 11GB / 80GB 21 | 22 | ## logs 23 | log_dir : /workspace/mnt2/dpr_logs 24 | embed_dir : /workspace/mnt2/dpr_output/ 25 | run_name : MDF_trec_contAccum_bsz8_cache1_accum16 26 | 27 | ## loss (Contrastive Accumulation) 28 | prev_cache : True 29 | cont_cache : True 30 | cache_query : True 31 | cache_hard_neg : True 32 | cache_size : 1 33 | use_hard_neg : True -------------------------------------------------------------------------------- /doc2embedding_msmarco.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from src.msmarco_utils.search import FlatIPFaissSearch, CustomBiEncoder 4 | 5 | from datasets import load_dataset 6 | 7 | if __name__ == "__main__" : 8 | from tqdm import tqdm 9 | import argparse 10 | import os 11 | 12 | MSMARCO_CORPUS = load_dataset('BeIR/msmarco', 'corpus', cache_dir='/workspace/mnt2/dpr_datasets/msmarco/original')["corpus"] 13 | print("MSMARCO_CORPUS loading") 14 | MSMARCO_CORPUS = {sample['_id'] : {"title" : sample['title'], "text" : sample['text']} for sample in tqdm(MSMARCO_CORPUS)} 15 | BATCH_SIZE=1024 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--embed_dir",type=str) 19 | parser.add_argument("--model_save_dir",required=True) 20 | args = parser.parse_args() 21 | 22 | print(f""" 23 | Arguments: 24 | embed_dir: {args.embed_dir} 25 | model_save_dir: {args.model_save_dir} 26 | """ 27 | ) 28 | 29 | os.makedirs(args.embed_dir, exist_ok=True) 30 | 31 | model = CustomBiEncoder(model_save_dir=args.model_save_dir) 32 | index_model = FlatIPFaissSearch(model, batch_size=BATCH_SIZE, output_dir=args.embed_dir) 33 | index_model.embed_and_save(MSMARCO_CORPUS, score_function='dot') -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | def set_seed(seed: int = 19980406): 2 | """ 3 | Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if 4 | installed). 5 | 6 | Args: 7 | seed (:obj:`int`): The seed to set. 
8 | """ 9 | import random 10 | import numpy as np 11 | import torch 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | def normalize_document(document: str): 18 | document = document.replace("\n", " ").replace("’", "'") 19 | if document.startswith('"'): 20 | document = document[1:] 21 | if document.endswith('"'): 22 | document = document[:-1] 23 | return document 24 | 25 | def normalize_query(question: str) -> str: 26 | question = question.replace("’", "'") 27 | return question 28 | 29 | def get_yaml_file(file_path): 30 | import yaml 31 | with open(file_path, "r") as file: 32 | config = yaml.safe_load(file) 33 | return config 34 | 35 | 36 | def get_linear_scheduler( 37 | optimizer, 38 | warmup_steps, 39 | total_training_steps, 40 | steps_shift=0, 41 | last_epoch=-1, 42 | ): 43 | from torch.optim.lr_scheduler import LambdaLR 44 | """Create a schedule with a learning rate that decreases linearly after 45 | linearly increasing during a warmup period. 46 | """ 47 | 48 | def lr_lambda(current_step): 49 | current_step += steps_shift 50 | if current_step < warmup_steps: 51 | return float(current_step) / float(max(1, warmup_steps)) 52 | return max( 53 | 1e-7, 54 | float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)), 55 | ) 56 | 57 | return LambdaLR(optimizer, lr_lambda, last_epoch) -------------------------------------------------------------------------------- /data/download_dpr_datasets.sh: -------------------------------------------------------------------------------- 1 | # wikipedia corpus 2 | python src/utils/download_data.py --resource data.wikipedia_split.psgs_w100 --output_dir /workspace/mnt2/dpr_datasets 3 | 4 | # natural questions 5 | python src/utils/download_data.py --resource data.retriever.nq-train --output_dir /workspace/mnt2/dpr_datasets 6 | python src/utils/download_data.py --resource data.retriever.nq-dev --output_dir /workspace/mnt2/dpr_datasets 7 | python src/utils/download_data.py --resource data.retriever.qas.nq-dev --output_dir /workspace/mnt2/dpr_datasets 8 | python src/utils/download_data.py --resource data.retriever.qas.nq-test --output_dir /workspace/mnt2/dpr_datasets 9 | 10 | # trivia 11 | python src/utils/download_data.py --resource data.retriever.trivia-train --output_dir /workspace/mnt2/dpr_datasets 12 | python src/utils/download_data.py --resource data.retriever.qas.trivia-test --output_dir /workspace/mnt2/dpr_datasets 13 | python src/utils/download_data.py --resource data.retriever.qas.trivia-dev --output_dir /workspace/mnt2/dpr_datasets 14 | python src/utils/download_data.py --resource data.retriever.trivia-dev --output_dir /workspace/mnt2/dpr_datasets 15 | 16 | # web questions 17 | python src/utils/download_data.py --resource data.retriever.webq-train --output_dir /workspace/mnt2/dpr_datasets 18 | python src/utils/download_data.py --resource data.retriever.qas.webq-test --output_dir /workspace/mnt2/dpr_datasets 19 | python src/utils/download_data.py --resource data.retriever.webq-dev --output_dir /workspace/mnt2/dpr_datasets 20 | 21 | # trec 22 | python src/utils/download_data.py --resource data.retriever.curatedtrec-train --output_dir /workspace/mnt2/dpr_datasets 23 | python src/utils/download_data.py --resource data.retriever.qas.curatedtrec-test --output_dir /workspace/mnt2/dpr_datasets 24 | python src/utils/download_data.py --resource data.retriever.curatedtrec-dev --output_dir /workspace/mnt2/dpr_datasets 25 | 
-------------------------------------------------------------------------------- /test_msmarco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | 7 | import transformers 8 | from beir.retrieval.evaluation import EvaluateRetrieval 9 | from datasets import load_dataset 10 | 11 | from src.msmarco_utils.search import FlatIPFaissSearch, CustomBiEncoder 12 | 13 | def convert_qrels_beir(qrels): 14 | new_qrels = defaultdict(dict) 15 | for qrel in qrels: 16 | if qrel['score'] > 0: 17 | new_qrels[str(qrel['query-id'])][str(qrel['corpus-id'])] = qrel['score'] 18 | return new_qrels 19 | 20 | transformers.logging.set_verbosity_error() 21 | NUM_SHARDS = 4 22 | 23 | if __name__ == '__main__': 24 | import faiss 25 | faiss.omp_set_num_threads(16) 26 | import argparse 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--embedding_dir",required=True) 29 | parser.add_argument("--model_save_dir",required=True) 30 | parser.add_argument("--data_split",required=True, choices=["validation", "test"]) 31 | parser.add_argument("--result_file_path",required=True) 32 | args = parser.parse_args() 33 | print(f""" 34 | Arguments: 35 | embedding_dir: {args.embedding_dir} 36 | model_save_dir: {args.model_save_dir} 37 | data_split: {args.data_split} 38 | result_file_path: {args.result_file_path} 39 | """) 40 | faiss.omp_set_num_threads(16) 41 | 42 | 43 | ## load QA dataset 44 | with open(f'/workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_{args.data_split}.json', 'r') as f: 45 | queries = json.load(f) 46 | queries = {query['question']['_id'] : query['question']['text'] for query in queries} 47 | 48 | corpus = load_dataset('BeIR/msmarco', 'corpus', cache_dir='/workspace/mnt2/dpr_datasets/msmarco/original')["corpus"] 49 | qrels = load_dataset('BeIR/msmarco-qrels', cache_dir='/workspace/mnt2/dpr_datasets/msmarco/original')[args.data_split] 50 | qrels = convert_qrels_beir(qrels) 51 | 52 | # make faiss index 53 | model = CustomBiEncoder(model_save_dir=args.model_save_dir) 54 | index_model = FlatIPFaissSearch(model, batch_size=1024, output_dir=args.embedding_dir) 55 | # from here 56 | index_model.load_and_index(embed_dir=args.embedding_dir, mapping_dict_dir=os.path.join(args.embedding_dir, 'mapping_dic.tsv')) 57 | retriever = EvaluateRetrieval(index_model, score_function="dot") 58 | results = retriever.retrieve(corpus, queries, ) 59 | 60 | k_values = [5, 10, 20, 100, 1000] 61 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, k_values) 62 | mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr") 63 | print(f"ndcg: {ndcg}, map: {_map}, recall: {recall}, precision: {precision}, mrr: {mrr}") 64 | 65 | 66 | index_name = "/".join(args.embedding_dir.split('/')[-2:]) 67 | print(f""" 68 | Results saved at {args.result_file_path} 69 | Row Name : {index_name} 70 | """) 71 | result_df = pd.DataFrame(dict(**ndcg, **_map, **recall, **precision, **mrr), index=[index_name]) 72 | if os.path.exists(args.result_file_path): 73 | result_df.to_csv(args.result_file_path, mode='a', header=False) 74 | else: 75 | result_df.to_csv(args.result_file_path) -------------------------------------------------------------------------------- /doc2embedding.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from tqdm import tqdm 3 | import os 4 | import transformers 5 | 
transformers.logging.set_verbosity_error() 6 | from transformers import ( 7 | DPRContextEncoder, 8 | DPRContextEncoderTokenizer, 9 | BertTokenizerFast, 10 | BertModel, 11 | ) 12 | import torch 13 | import numpy as np 14 | from accelerate import PartialState 15 | 16 | if __name__ == "__main__": 17 | NUM_DOCS = 21015324 18 | WIKIPEDIA_PATH = "/workspace/mnt2/dpr_datasets/downloads/data/wikipedia_split/psgs_w100.tsv" 19 | ENCODING_BATCH_SIZE = 1024 20 | 21 | import argparse 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--embed_dir",type=str) 24 | parser.add_argument("--model_save_dir",required=True) 25 | parser.add_argument("--log_dir",type=str) 26 | args = parser.parse_args() 27 | 28 | distributed_state = PartialState() 29 | device = distributed_state.device 30 | 31 | ## load encoder 32 | if args.model_save_dir == 'facebook/dpr-ctx_encoder-single-nq-base': 33 | doc_encoder = DPRContextEncoder.from_pretrained(args.model_save_dir) 34 | tokenizer = DPRContextEncoderTokenizer.from_pretrained(args.model_save_dir) 35 | else: 36 | doc_encoder = BertModel.from_pretrained(args.model_save_dir,add_pooling_layer=False) 37 | tokenizer = BertTokenizerFast.from_pretrained(args.model_save_dir) 38 | doc_encoder.eval() 39 | doc_encoder.to(device) 40 | 41 | 42 | ## load wikipedia passages 43 | progress_bar = tqdm(total=NUM_DOCS, disable=not distributed_state.is_main_process,ncols=100,desc='loading wikipedia...') 44 | id_col,text_col,title_col=0,1,2 45 | wikipedia = [] 46 | with open(WIKIPEDIA_PATH) as f: 47 | reader = csv.reader(f, delimiter="\t") 48 | for row in reader: 49 | if row[id_col] == "id":continue 50 | wikipedia.append( 51 | [row[title_col],row[text_col].strip('"')] 52 | ) 53 | progress_bar.update(1) 54 | 55 | with distributed_state.split_between_processes(wikipedia) as sharded_wikipedia: 56 | 57 | sharded_wikipedia = [sharded_wikipedia[idx:idx+ENCODING_BATCH_SIZE] for idx in range(0,len(sharded_wikipedia),ENCODING_BATCH_SIZE)] 58 | encoding_progress_bar = tqdm(total=len(sharded_wikipedia), disable=not distributed_state.is_main_process,ncols=100,desc='encoding wikipedia...') 59 | doc_embeddings = [] 60 | for data in sharded_wikipedia: 61 | title = [x[0] for x in data] 62 | passage = [x[1] for x in data] 63 | model_input = tokenizer(title,passage,max_length=256,padding='max_length',return_tensors='pt',truncation=True).to(device) 64 | with torch.no_grad(): 65 | if isinstance(doc_encoder,BertModel): 66 | CLS_POS = 0 67 | output = doc_encoder(**model_input).last_hidden_state[:,CLS_POS,:].cpu().numpy() 68 | else: 69 | output = doc_encoder(**model_input).pooler_output.cpu().numpy() 70 | doc_embeddings.append(output) 71 | encoding_progress_bar.update(1) 72 | doc_embeddings = np.concatenate(doc_embeddings,axis=0) 73 | os.makedirs(args.embed_dir,exist_ok=True) 74 | np.save(f'{args.embed_dir}/wikipedia_shard_{distributed_state.process_index}.npy',doc_embeddings) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Gradient Accumulation Method for Dense Retriever under Memory Constraint 2 | 3 | This repository is the official implementation of [A Gradient Accumulation Method for Dense Retriever under Memory Constraint](https://arxiv.org/abs/2406.12356). It is adapted from the repository [nanoDPR](https://github.com/Hannibal046/nanoDPR/tree/master), which offers a simplified replication of the DPR model. 4 | 5 | ## 1. 
Requirements 6 | --- 7 | To install the required packages: 8 | ```setup 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## 2. Preparing the data 13 | 14 | ### 2-1. Download DPR data 15 | --- 16 | DPR provides preprocessed datasets in their official repository. Download the datasets with the following command: 17 | ```bash 18 | bash data/download_dpr_datasets.sh 19 | ``` 20 | 21 | ### 2-2. Download and preprocess MS Marco data 22 | --- 23 | You can download and preprocess the MS Marco data using the provided scripts. The BEIR repository and Huggingface offer preprocessed MS Marco data. Additionally, you can filter hard negatives by cross-encoder scores. 24 | 25 | Find the download and preprocessing code in `data/msmarco_download_and_preprocess.ipynb`. 26 | 27 | ## 3. Training 28 | --- 29 | You can train the DPR model under various settings: 30 | ### 3-1. DPR with ContAccum in low-resource 31 | ```bash 32 | python src/train_dpr.py --config_file config/{data_name}/vram24/train_dpr_{data_name}_contAccum_cache{cache_size}_accum4.yaml  # cache size varies by dataset: 8 (NQ, MS Marco), 4 (Trivia), 1 (WebQ, TREC) 33 | ``` 34 | 35 | ### 3-2. DPR in high-resource 36 | ```bash 37 | python src/train_dpr.py --config_file config/{data_name}/train_dpr_{data_name}_bsz128.yaml 38 | ``` 39 | 40 | ### 3-3. DPR in low-resource 41 | ```bash 42 | python src/train_dpr.py --config_file config/{data_name}/vram11/train_dpr_{data_name}_bsz8.yaml 43 | ``` 44 | 45 | ### 3-4. DPR with gradient accumulation in low-resource 46 | ```bash 47 | python src/train_dpr.py --config_file config/{data_name}/vram11/train_dpr_{data_name}_gradAccum_16.yaml 48 | ``` 49 | 50 | ## 4. Extracting Embeddings of all passages 51 | ### For MS Marco 52 | ```bash 53 | accelerate launch --num_processes=4 doc2embedding_msmarco.py \ 54 |     --embed_dir /workspace/mnt2/dpr_output/{embed_dir} \ 55 |     --model_save_dir /workspace/mnt2/dpr_logs/{model_dir} 56 | ``` 57 | ### For DPR datasets 58 | ```bash 59 | bash scripts/tools/embed.sh {model_dir} {embed_dir} 60 | ``` 61 | 62 | ## 5. Evaluation 63 | ### For MS Marco 64 | ```bash 65 | python test_msmarco.py \ 66 |     --embedding_dir {embed_dir} \ 67 |     --model_save_dir {model_dir} \ 68 |     --data_split test \ 69 |     --result_file_path result.csv 70 | ``` 71 | 72 | ### For DPR datasets 73 | ```bash 74 | bash scripts/tools/test.sh {device} {model_dir} {embed_dir} {data_split} {result_file_path} 75 | ``` 76 | 77 | ## 6. Implementation details 78 | ### 6-1. Traditional InfoNCE Loss 79 | ```python 80 | # q_local: query representations in the same batch 81 | # p_local: passage representations in the same batch 82 | # labels: index of the positive passage for each query (the diagonal of an n x n identity matrix); single_device_query_num, single_device_doc_num, and accelerator come from the distributed training setup 83 | for batch in dataloader: 84 |     q_local, p_local = model(batch) 85 |     sim_matrix = torch.matmul(q_local, p_local.permute(1,0)) 86 |     labels = torch.cat([torch.arange(single_device_query_num) + gpu_index * single_device_doc_num for gpu_index in range(accelerator.num_processes)],dim=0).to(sim_matrix.device) 87 |     loss = F.nll_loss(input=F.log_softmax(sim_matrix,dim=1),target=labels) 88 |     loss.backward() 89 |     ... 90 | ``` 91 | 92 | ### 6-2.
92 | ### 6-2. ContAccum Implementation
93 | ```python
94 | # q_local: query representations in the same batch
95 | # p_local: passage representations in the same batch
96 | # the labels (diagonal targets) are built inside LossCalculator
97 | loss_calculator = LossCalculator(args,hard_neg=args.use_hard_neg)
98 | for step, batch in enumerate(dataloader):
99 |     q_local, p_local = model(batch)
100 |     loss = loss_calculator(q_local, p_local)
101 |     loss.backward()
102 |     if (step + 1) % gradient_accumulation_steps == 0:
103 |         optimizer.step()
104 |         optimizer.zero_grad()
105 |     ...
106 | ```
107 | ### 6-3. Hyperparameters for ContAccum
108 | All hyperparameters for ContAccum are contained in the `args` variable:
109 | 
110 | - prev_cache (boolean): Whether to keep the representations produced before the previous model update. If False, the memory bank is cleared after every model update.
111 | - cache_query (boolean): Whether to cache the query representations. If False, only the passage representations are cached.
112 | - cache_hard_neg (boolean): Whether to cache the hard-negative passage representations. This makes the passage memory bank twice as large as the query memory bank.
113 | - cache_size (int): The memory bank size. It should be the same as the local batch size.
114 | - use_hard_neg (boolean): Whether hard negatives are used for training. This is different from the cache_hard_neg parameter.
115 | 
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | class DualEncoder(nn.Module):
8 |     def __init__(self,query_encoder,doc_encoder):
9 |         super().__init__()
10 |         self.query_encoder = query_encoder
11 |         self.doc_encoder = doc_encoder
12 | 
13 |     def forward(
14 |         self,
15 |         query_input_ids, # [bs,seq_len]
16 |         query_attention_mask, # [bs,seq_len]
17 |         query_token_type_ids, # [bs,seq_len]
18 |         doc_input_ids, # [bs*n_doc,seq_len]
19 |         doc_attention_mask, # [bs*n_doc,seq_len]
20 |         doc_token_type_ids, # [bs*n_doc,seq_len]
21 |     ):
22 |         CLS_POS = 0
23 |         ## [bs,n_dim]
24 |         query_embedding = self.query_encoder(
25 |             input_ids=query_input_ids,
26 |             attention_mask = query_attention_mask,
27 |             token_type_ids = query_token_type_ids,
28 |             ).last_hidden_state[:,CLS_POS,:]
29 | 
30 |         ## [bs * n_doc,n_dim]
31 |         doc_embedding = self.doc_encoder(
32 |             input_ids = doc_input_ids,
33 |             attention_mask = doc_attention_mask,
34 |             token_type_ids = doc_token_type_ids,
35 |             ).last_hidden_state[:,CLS_POS,:]
36 | 
37 |         return query_embedding,doc_embedding
38 | 
39 | def calculate_dpr_loss(matching_score,labels):
40 |     return F.nll_loss(input=F.log_softmax(matching_score,dim=1),target=labels)
41 | 
42 | class LossCalculator(nn.Module):
43 |     def __init__(self, args, hard_neg=True):
44 |         super().__init__()
45 |         self.prev_cache = args.prev_cache
   |         self.cont_cache = args.cont_cache # stored here because empty_cache() reads this flag; the assignment was missing in the original __init__
46 |         self.cache_hard_neg = args.cache_hard_neg
47 |         self.gradient_accumulation_steps = args.gradient_accumulation_steps
48 |         self.cache_query = args.cache_query
49 |         self.hard_neg = hard_neg
50 | 
51 |         if self.cache_query:
52 |             self.query_cache = deque(maxlen=args.cache_size*self.gradient_accumulation_steps)
53 |         else :
54 |             self.query_cache = None
55 |         self.neg_doc_cache = deque(maxlen=args.cache_size*self.gradient_accumulation_steps)
56 |         if self.cache_hard_neg:
57 |             self.hard_neg_doc_cache = deque(maxlen=args.cache_size*self.gradient_accumulation_steps)
58 |         else :
59 |             self.hard_neg_doc_cache 
= None 60 | 61 | def forward(self,query_embedding, doc_embedding): 62 | """ 63 | Args: 64 | query_embedding : [bs,n_dim] 65 | doc_embedding : [bs*2,n_dim] (positive + negative) 66 | """ 67 | # concat with cache 68 | query_embedding = self.concat_with_cache(query_embedding,self.query_cache) 69 | if self.hard_neg: 70 | neg_doc_embedding, hard_neg_doc_embedding = self.split_hard_neg_doc(doc_embedding) 71 | neg_doc_embedding = self.concat_with_cache(neg_doc_embedding,self.neg_doc_cache) 72 | hard_neg_doc_embedding = self.concat_with_cache(hard_neg_doc_embedding,self.hard_neg_doc_cache) 73 | doc_embedding = torch.cat([neg_doc_embedding,hard_neg_doc_embedding],dim=0) 74 | else: 75 | neg_doc_embedding = self.concat_with_cache(doc_embedding,self.neg_doc_cache) 76 | doc_embedding = neg_doc_embedding 77 | 78 | len_query = query_embedding.shape[0] 79 | labels = torch.arange(len_query).to(query_embedding.device) 80 | matching_score = torch.matmul(query_embedding,doc_embedding.permute(1,0)) 81 | loss = calculate_dpr_loss(matching_score,labels=labels) 82 | return loss 83 | 84 | def concat_with_cache(self,embedding,cache): 85 | if type(cache) != deque : 86 | return embedding 87 | if len(cache) == 0: 88 | self.enque(embedding,cache) 89 | return embedding 90 | else: 91 | embedding_with_cache = torch.cat([embedding,torch.cat(list(cache)).to(embedding.device)],dim=0) 92 | self.enque(embedding,cache) 93 | return embedding_with_cache 94 | 95 | def enque(self, embedding, cache): 96 | if cache is None: 97 | return embedding 98 | else: 99 | cache.append(embedding.detach().clone()) 100 | 101 | def split_hard_neg_doc(self,doc_embedding): 102 | neg_doc_embedding = doc_embedding[:len(doc_embedding)//2] 103 | hard_neg_doc_embedding = doc_embedding[len(doc_embedding)//2:] 104 | return neg_doc_embedding,hard_neg_doc_embedding 105 | 106 | def empty_cache(self): 107 | if not self.prev_cache and self.cont_cache: 108 | if self.query_cache : 109 | self.query_cache.clear() 110 | self.neg_doc_cache.clear() 111 | if self.cache_hard_neg: 112 | self.hard_neg_doc_cache.clear() 113 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random, json 3 | 4 | from utils import normalize_query, normalize_document 5 | class QADataset(torch.utils.data.Dataset): 6 | def __init__(self,file_path): 7 | data = json.load(open(file_path)) 8 | self.data = [r for r in data if len(r["positive_ctxs"]) > 0 and (len(r["hard_negative_ctxs"]) > 0 or len(r["negative_ctxs"]) > 0)] 9 | 10 | print(f""" 11 | >>> DATASET INFO 12 | - file_path: {file_path} 13 | - number of samples (before cleaning): {len(data)} 14 | - number of samples (after cleaning): {len(self.data)} 15 | """) 16 | 17 | def __len__(self): 18 | return len(self.data) 19 | 20 | def __getitem__(self,idx): 21 | return self.data[idx] 22 | 23 | @staticmethod 24 | def collate_fn(samples,tokenizer,args,stage,hard_neg=True): 25 | 26 | # prepare query input 27 | queries = [normalize_query(x['question']) for x in samples] 28 | query_inputs = tokenizer(queries,max_length=256,padding=True,truncation=True,return_tensors='pt') 29 | 30 | # prepare document input 31 | ## select the first positive document 32 | positive_passages = [x['positive_ctxs'][0] for x in samples] 33 | positive_titles = [x['title'] for x in positive_passages] 34 | positive_docs = [x['text'] for x in positive_passages] 35 | 36 | if stage == 'train': 37 | ## random choose one negative 
document 38 | negative_passages = [random.choice(x['hard_negative_ctxs']) 39 | if len(x['hard_negative_ctxs']) != 0 40 | else random.choice(x['negative_ctxs']) 41 | for x in samples ] 42 | elif stage == 'dev': 43 | negative_passages = [x['hard_negative_ctxs'][:min(args.num_hard_negative_ctx,len(x['hard_negative_ctxs']))] 44 | + x['negative_ctxs'][:min(args.num_other_negative_ctx,len(x['negative_ctxs']))] 45 | for x in samples] 46 | negative_passages = [x for y in negative_passages for x in y] 47 | 48 | negative_titles = [x["title"] for x in negative_passages] 49 | negative_docs = [x["text"] for x in negative_passages] 50 | if hard_neg : 51 | titles = positive_titles + negative_titles 52 | docs = positive_docs + negative_docs 53 | else: 54 | titles = positive_titles 55 | docs = positive_docs 56 | doc_inputs = tokenizer(titles,docs,max_length=256,padding=True,truncation=True,return_tensors='pt') 57 | 58 | return { 59 | 'query_input_ids':query_inputs.input_ids, 60 | 'query_attention_mask':query_inputs.attention_mask, 61 | 'query_token_type_ids':query_inputs.token_type_ids, 62 | 63 | "doc_input_ids":doc_inputs.input_ids, 64 | "doc_attention_mask":doc_inputs.attention_mask, 65 | "doc_token_type_ids":doc_inputs.token_type_ids, 66 | } 67 | 68 | 69 | class BEIRDataset(torch.utils.data.Dataset): 70 | def __init__(self,file_path): 71 | self.data = json.load(open(file_path)) 72 | # self.data = [r for r in data if len(r["positive_ctxs"]) > 0 and (len(r["hard_negative_ctxs"]) > 0 or len(r["negative_ctxs"]) > 0)] 73 | 74 | print(f""" 75 | >>> DATASET INFO 76 | - file_path: {file_path} 77 | - number of samples : {len(self.data)} 78 | """) 79 | 80 | def __len__(self): 81 | return len(self.data) 82 | 83 | def __getitem__(self,idx): 84 | return self.data[idx] 85 | 86 | @staticmethod 87 | def collate_fn(samples,tokenizer,args,stage,hard_neg=True): 88 | 89 | # prepare query input 90 | queries = [normalize_query(x['question']['text']) for x in samples] 91 | query_inputs = tokenizer(queries,max_length=256,padding=True,truncation=True,return_tensors='pt') 92 | 93 | # prepare document input 94 | ## select the first positive document 95 | ## passage = title + document 96 | positive_passages = [x['positive_ctxs'][0] for x in samples] 97 | positive_titles = [x['title'] for x in positive_passages] 98 | positive_docs = [x['text'] for x in positive_passages] 99 | 100 | if stage == 'train': 101 | ## random choose one negative document 102 | negative_passages = [random.choice(x['negative_ctxs']) 103 | for x in samples ] 104 | negative_titles = [x["title"] for x in negative_passages] 105 | negative_docs = [x["text"] for x in negative_passages] 106 | 107 | if hard_neg: 108 | titles = positive_titles + negative_titles if stage == 'train' else positive_titles 109 | docs = positive_docs + negative_docs if stage == 'train' else positive_docs 110 | else: 111 | titles = positive_titles 112 | docs = positive_docs 113 | 114 | doc_inputs = tokenizer(titles,docs,max_length=256,padding=True,truncation=True,return_tensors='pt') 115 | 116 | return { 117 | 'query_input_ids':query_inputs.input_ids, 118 | 'query_attention_mask':query_inputs.attention_mask, 119 | 'query_token_type_ids':query_inputs.token_type_ids, 120 | 121 | "doc_input_ids":doc_inputs.input_ids, 122 | "doc_attention_mask":doc_inputs.attention_mask, 123 | "doc_token_type_ids":doc_inputs.token_type_ids, 124 | } -------------------------------------------------------------------------------- /test_dpr.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import unicodedata 4 | import time 5 | import pickle 6 | import regex as re 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import torch 12 | import transformers 13 | from transformers import DPRQuestionEncoder,DPRQuestionEncoderTokenizer,BertModel,BertTokenizerFast 14 | 15 | from src.utils.tokenizers import SimpleTokenizer 16 | from src.utils import normalize_query 17 | 18 | WIKIEPEDIA_PATH = "/workspace/mnt2/dpr_datasets/downloads/data/wikipedia_split/psgs_w100.tsv" 19 | TEST_FILE_DIR="/workspace/mnt2/dpr_datasets/downloads/data/retriever/qas/" 20 | ENCODING_BATCH_SIZE=32 21 | NUM_DOCS=21015324 22 | 23 | transformers.logging.set_verbosity_error() 24 | 25 | def normalize(text): 26 | return unicodedata.normalize("NFD", text) 27 | 28 | def regex_match(text, pattern): 29 | """Test if a regex pattern is contained within a text.""" 30 | try: 31 | pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) 32 | except BaseException: 33 | return False 34 | return pattern.search(text) is not None 35 | 36 | 37 | def has_answer(answers,doc,is_trec=False): 38 | if not is_trec: 39 | tokenizer = SimpleTokenizer() 40 | doc = tokenizer.tokenize(normalize(doc)).words(uncased=True) 41 | for answer in answers: 42 | answer = tokenizer.tokenize(normalize(answer)).words(uncased=True) 43 | for i in range(0, len(doc) - len(answer) + 1): 44 | if answer == doc[i : i + len(answer)]: 45 | return True 46 | else : 47 | for answer in answers : 48 | answer = normalize(answer) 49 | if regex_match(doc, answer) : 50 | return True 51 | return False 52 | 53 | if __name__ == '__main__': 54 | import faiss 55 | faiss.omp_set_num_threads(16) 56 | import argparse 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--num_shards",type=int,default=4) 59 | parser.add_argument("--embedding_dir",required=True) 60 | parser.add_argument("--pretrained_model_path",required=True) 61 | parser.add_argument("--data_split",required=True) 62 | parser.add_argument("--result_file_path",required=True) 63 | args = parser.parse_args() 64 | 65 | ## load QA dataset 66 | query_col,answers_col=0,1 67 | queries,answers = [],[] 68 | TEST_FILE = os.path.join(TEST_FILE_DIR,args.data_split+".csv") 69 | with open(TEST_FILE) as f: 70 | reader = csv.reader(f, delimiter="\t") 71 | for row in reader: 72 | queries.append(normalize_query(row[query_col])) 73 | answers.append(eval(row[answers_col])) 74 | queries = [queries[idx:idx+ENCODING_BATCH_SIZE] for idx in range(0,len(queries),ENCODING_BATCH_SIZE)] 75 | 76 | # make faiss index 77 | embedding_dimension = 768 78 | index = faiss.IndexFlatIP(embedding_dimension) 79 | for idx in tqdm(range(args.num_shards),desc='building index from embedding...'): 80 | data = np.load(f"{args.embedding_dir}/wikipedia_shard_{idx}.npy") 81 | index.add(data) 82 | 83 | ## load wikipedia passages 84 | id_col,text_col,title_col=0,1,2 85 | wiki_passages = [] 86 | with open(WIKIEPEDIA_PATH) as f: 87 | reader = csv.reader(f, delimiter="\t") 88 | for row in tqdm(reader,total=NUM_DOCS,desc="loading wikipedia passages..."): 89 | if row[id_col] == "id":continue 90 | wiki_passages.append(row[text_col].strip('"')) 91 | 92 | ## load query encoder 93 | if args.pretrained_model_path == 'facebook/dpr-question_encoder-single-nq-base': 94 | query_encoder = DPRQuestionEncoder.from_pretrained(args.pretrained_model_path) 95 | tokenizer = 
DPRQuestionEncoderTokenizer.from_pretrained(args.pretrained_model_path) 96 | else: 97 | query_encoder = BertModel.from_pretrained(args.pretrained_model_path,add_pooling_layer=False) 98 | tokenizer = BertTokenizerFast.from_pretrained(args.pretrained_model_path) 99 | device = "cuda" if torch.cuda.is_available() else "cpu" 100 | query_encoder.to(device).eval() 101 | 102 | ## embed queries 103 | query_embeddings = [] 104 | for query in tqdm(queries,desc='encoding queries...'): 105 | with torch.no_grad(): 106 | query_embedding = query_encoder(**tokenizer(query,max_length=256,truncation=True,padding='max_length',return_tensors='pt').to(device)) 107 | if isinstance(query_encoder,DPRQuestionEncoder): 108 | query_embedding = query_embedding.pooler_output 109 | else: 110 | query_embedding = query_embedding.last_hidden_state[:,0,:] 111 | query_embeddings.append(query_embedding.cpu().detach().numpy()) 112 | query_embeddings = np.concatenate(query_embeddings,axis=0) 113 | 114 | ## retrieve top-k documents 115 | print("searching index ") 116 | start_time = time.time() 117 | top_k = 100 118 | faiss.omp_set_num_threads(16) 119 | _,I = index.search(query_embeddings,top_k) 120 | print(f"takes {time.time()-start_time} s") 121 | 122 | hit_lists = [] 123 | if_trec = "trec" in args.data_split 124 | for answer_list,id_list in tqdm(zip(answers,I),total=len(answers),desc='calculating metrics...'): 125 | ## process single query 126 | hit_list = [] 127 | for doc_id in id_list: 128 | doc = wiki_passages[doc_id] 129 | hit_list.append(has_answer(answer_list,doc,if_trec)) 130 | hit_lists.append(hit_list) 131 | 132 | top_k_hits = [0]*top_k 133 | best_hits = [] 134 | for hit_list in hit_lists: 135 | best_hit = next((i for i, x in enumerate(hit_list) if x), None) 136 | if best_hit is not None: 137 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 138 | 139 | top_k_ratio = [x/len(answers) for x in top_k_hits] 140 | 141 | test_topk = [4,19,99] 142 | 143 | step_and_epoch = args.pretrained_model_path.split('/')[-2] 144 | exp_name = args.pretrained_model_path.split('/')[-3] 145 | epoch = step_and_epoch.split('-')[-1] 146 | 147 | result_dict = { 148 | "epoch" : epoch, 149 | "exp_name" : exp_name, 150 | "top_5" : top_k_ratio[4], 151 | "top_20" : top_k_ratio[19], 152 | "top_100" : top_k_ratio[99] 153 | } 154 | 155 | result_df = pd.DataFrame(result_dict, index=[0]) 156 | 157 | if os.path.exists(args.result_file_path) : 158 | result_df.to_csv(args.result_file_path, mode='a', header=False) 159 | else : 160 | result_df.to_csv(args.result_file_path) 161 | 162 | print(f"EXP: {exp_name} EPOCH: {epoch}") 163 | for idx in [4,19,99]: 164 | print(f"top-{idx+1} accuracy",top_k_ratio[idx]) 165 | -------------------------------------------------------------------------------- /data/msmarco_download_and_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 0. load libraries" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 3, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from beir import util\n", 17 | "from beir.datasets.data_loader import GenericDataLoader\n", 18 | "from tqdm.autonotebook import tqdm\n", 19 | "import os, gzip, json\n", 20 | "from datasets import load_dataset\n", 21 | "from tqdm import tqdm\n", 22 | "import numpy as np\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# 1. 
download hard negative passages of msmarco mined by sentence-transformers" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "triplets_url = \"https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz\"\n", 39 | "data_path = \"/workspace/mnt2/dpr_datasets/msmarco/sbert\"\n", 40 | "msmarco_triplets_filepath = os.path.join(data_path, \"msmarco-hard-negatives.jsonl.gz\")\n", 41 | "if not os.path.isfile(msmarco_triplets_filepath):\n", 42 | " util.download_url(triplets_url, msmarco_triplets_filepath)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "dataset = \"msmarco\"\n", 52 | "url = \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(dataset)\n", 53 | "out_dir = \"/workspace/mnt2/dpr_datasets/msmarco/beir/msmarco\"\n", 54 | "data_path = util.download_and_unzip(url, out_dir)\n", 55 | "corpus, queries, _ = GenericDataLoader(data_path).load(split=\"train\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# 2. select the hard negative passages which has cross encoder score lower than positive passages - 3 " 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "ce_score_margin = 3\n", 72 | "num_negs_per_system = 10\n", 73 | "train_queries = {}\n", 74 | "not_selected_samples = []\n", 75 | "cnt=0\n", 76 | "with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:\n", 77 | " for line in tqdm(fIn, total=502939):\n", 78 | " not_selected_samples.append(cnt)\n", 79 | " cnt = 0\n", 80 | " data = json.loads(line)\n", 81 | " \n", 82 | " #Get the positive passage ids\n", 83 | " pos_pids = [item['pid'] for item in data['pos']]\n", 84 | " pos_min_ce_score = min([item['ce-score'] for item in data['pos']])\n", 85 | " ce_score_threshold = pos_min_ce_score - ce_score_margin\n", 86 | " \n", 87 | " #Get the hard negatives\n", 88 | " neg_pids = set()\n", 89 | "\n", 90 | " if 'bm25' not in data['neg']:\n", 91 | " continue\n", 92 | " system_negs = data['neg']['bm25']\n", 93 | " negs_added = 0\n", 94 | " for item in system_negs:\n", 95 | " if item['ce-score'] > ce_score_threshold:\n", 96 | " cnt += 1\n", 97 | " continue\n", 98 | "\n", 99 | " pid = item['pid']\n", 100 | " if pid not in neg_pids:\n", 101 | " neg_pids.add(pid)\n", 102 | " negs_added += 1\n", 103 | " if negs_added >= num_negs_per_system:\n", 104 | " break\n", 105 | " \n", 106 | " if len(pos_pids) > 0 and len(neg_pids) > 0:\n", 107 | " train_queries[data['qid']] = {\n", 108 | " 'query': queries[data['qid']], \n", 109 | " 'pos': pos_pids, \n", 110 | " 'hard_neg': list(neg_pids)}\n", 111 | " " 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## 3. 
Preprocess the hard negative passages with the original msmarco data" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# it took more than 40 minutes to download the dataset\n", 128 | "corpus = load_dataset('BeIR/msmarco', 'corpus', cache_dir='/workspace/mnt2/dpr_datasets/msmarco/original')\n", 129 | "query = load_dataset('BeIR/msmarco', 'queries', cache_dir='/workspace/mnt2/dpr_datasets/msmarco/original')\n", 130 | "qrels = load_dataset('BeIR/msmarco-qrels', cache_dir='/workspace/mnt2/dpr_datasets/msmarco/original') # train/validation/test" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "queries = {}\n", 140 | "\n", 141 | "for line in tqdm(query['queries']):\n", 142 | " queries[line['_id']] = line\n", 143 | "\n", 144 | "corpus_ = {}\n", 145 | "\n", 146 | "for line in tqdm(corpus['corpus']):\n", 147 | " corpus_[line['_id']] = line" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "msmarco = []\n", 157 | "for qid, qrel in tqdm(train_queries.items()):\n", 158 | " data = {}\n", 159 | " data['dataset'] = 'msmarco'\n", 160 | " data['question'] = {'text' : qrel['query']}\n", 161 | " data['positive_ctxs'] = [corpus_[pid] for pid in qrel['pos']]\n", 162 | " data['negative_ctxs'] = [corpus_[pid] for pid in qrel['hard_neg']]\n", 163 | " msmarco.append(data)\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 10, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "with open('/workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json', 'w') as f:\n", 173 | " json.dump(msmarco, f, indent=4)" 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "base", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.0.0" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /src/msmarco_utils/search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict 3 | 4 | import torch 5 | import faiss 6 | import numpy as np 7 | from tqdm import trange, tqdm 8 | 9 | from beir.retrieval.search import BaseSearch 10 | from beir.retrieval.search.dense.util import save_dict_to_tsv, load_tsv_to_dict 11 | from beir.retrieval.search.dense.faiss_index import FaissIndex 12 | 13 | from accelerate import PartialState 14 | from transformers import BertTokenizer, BertModel 15 | 16 | class FlatIPFaissSearch(BaseSearch): 17 | def __init__( 18 | self, 19 | model, 20 | batch_size: int = 128, 21 | output_dir: str = None, 22 | **kwargs 23 | ): 24 | self.model = model 25 | self.batch_size = batch_size 26 | self.score_functions = ['cos_sim','dot'] 27 | self.mapping_tsv_keys = ["beir-docid", "faiss-docid"] 28 | self.dim_size = 768 29 | self.output_dir = output_dir 30 | 31 | def _create_mapping_ids(self, corpus_ids): 32 | self.mapping = {} 33 | self.rev_mapping = {} 34 | if not 
all(isinstance(doc_id, int) for doc_id in corpus_ids): 35 | for idx in range(len(corpus_ids)): 36 | self.mapping[corpus_ids[idx]] = idx 37 | self.rev_mapping[idx] = corpus_ids[idx] 38 | 39 | def save(self, output_dir: str): 40 | """ 41 | save embedding and mapping to disk 42 | """ 43 | # get current process id and concat to prefix 44 | distributed_state = PartialState() 45 | pid = distributed_state.process_index 46 | # Save Faiss Index to disk 47 | save_embed_path = os.path.join(output_dir, "{}_{}".format('embed', pid)) 48 | print("Saving Embeddings to path: {}".format(save_embed_path)) 49 | np.save(save_embed_path, self.corpus_embeddings) 50 | if pid == 0: 51 | print("Index size: {:.2f}MB".format(os.path.getsize(save_embed_path + ".npy")*0.000001*4)) 52 | 53 | # Save Mapping to disk 54 | if pid == 0: 55 | save_dict_to_tsv(self.mapping, os.path.join(output_dir, "mapping_dic.tsv"), self.mapping_tsv_keys) 56 | print("Mapping saved to path: {}".format(os.path.join(output_dir, "mapping_dic.tsv"))) 57 | 58 | def _embed(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs): 59 | corpus_ids = [key for key, value in corpus.items()] 60 | self._create_mapping_ids(corpus_ids) 61 | corpus = [corpus[cid] for cid in corpus_ids] 62 | 63 | print("Encoding Corpus in batches... Warning: This might take a while!") 64 | 65 | distributed_state = PartialState() 66 | device = distributed_state.device 67 | self.model.ctx_encoder.to(device) 68 | with distributed_state.split_between_processes(corpus) as sharded_corpus: 69 | shared_corpus = [sharded_corpus[idx:idx+self.batch_size] for idx in range(0, len(sharded_corpus), self.batch_size)] 70 | encoding_progress = tqdm(total=len(shared_corpus), desc="Encoding Passages: ", ncols=100, disable=not distributed_state.is_main_process) 71 | corpus_encoding = [] 72 | for data in shared_corpus: 73 | sub_encoding = self.model.encode_corpus(corpus_batch=data, device=device) 74 | corpus_encoding.append(sub_encoding) 75 | encoding_progress.update(1) 76 | self.corpus_embeddings = np.concatenate(corpus_encoding, axis=0) 77 | 78 | 79 | def embed_and_save(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs): 80 | self._embed(corpus, score_function, **kwargs) 81 | self.save(self.output_dir) 82 | 83 | def load_and_index(self, embed_dir: str = "embed/", mapping_dict_dir:str = "mapping_dic.tsv"): 84 | # load and get the faiss ids from mapping(msmarco-docid to faiss-docid) 85 | sub_corpus_embeddings = [np.load(os.path.join(embed_dir, f)) for f in os.listdir(embed_dir) if f.endswith(".npy")] # it should be 4 np files which are generated by 4 shards 86 | self.corpus_embeddings = np.concatenate(sub_corpus_embeddings, axis=0) 87 | self.mapping = load_tsv_to_dict(mapping_dict_dir, header=True) 88 | self.faiss_ids = list(self.mapping.values()) 89 | self.rev_mapping = {v: k for k, v in self.mapping.items()} 90 | 91 | # index the corpus 92 | base_index = faiss.IndexFlatIP(self.dim_size) 93 | self.faiss_index = FaissIndex.build(self.faiss_ids, self.corpus_embeddings, base_index) 94 | 95 | def search(self, 96 | corpus: Dict[str, Dict[str, str]], 97 | queries: Dict[str, str], 98 | top_k: int, 99 | score_function = str, **kwargs) -> Dict[str, Dict[str, float]]: 100 | self.results = {} 101 | 102 | assert score_function in self.score_functions 103 | normalize_embeddings = True if score_function == "cos_sim" else False 104 | 105 | query_ids = list(queries.keys()) 106 | queries = [queries[qid] for qid in queries] 107 | print("Computing Query Embeddings. 
Normalize: {}...".format(normalize_embeddings)) 108 | query_embeddings = self.model.encode_queries( 109 | queries, show_progress_bar=True, 110 | batch_size=self.batch_size, 111 | normalize_embeddings=normalize_embeddings) 112 | 113 | faiss_scores, faiss_doc_ids = self.faiss_index.search(query_embeddings, top_k, **kwargs) 114 | 115 | for idx in range(len(query_ids)): 116 | scores = [float(score) for score in faiss_scores[idx]] 117 | if len(self.rev_mapping) != 0: 118 | doc_ids = [self.rev_mapping[doc_id] for doc_id in faiss_doc_ids[idx]] 119 | else: 120 | doc_ids = [str(doc_id) for doc_id in faiss_doc_ids[idx]] 121 | self.results[query_ids[idx]] = dict(zip(doc_ids, scores)) 122 | 123 | return self.results 124 | 125 | def get_index_name(self): 126 | return "flat_faiss_index" 127 | 128 | 129 | class CustomBiEncoder : 130 | def __init__(self, model_save_dir=None, ) : 131 | self.query_encoder = BertModel.from_pretrained(os.path.join(model_save_dir, 'query_encoder')) 132 | self.ctx_encoder = BertModel.from_pretrained(os.path.join(model_save_dir, 'doc_encoder')) 133 | 134 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 135 | 136 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray: 137 | CLS_POS=0 138 | query_embeddings = [] 139 | self.query_encoder.cuda() 140 | with torch.no_grad(): 141 | for start_idx in trange(0, len(queries), batch_size, desc = '>>> Encoding Queries : '): 142 | encoded = self.tokenizer(queries[start_idx:start_idx+batch_size], max_length=256,padding='max_length',return_tensors='pt',truncation=True) 143 | query_repr = self.query_encoder( 144 | input_ids = encoded['input_ids'].cuda(), 145 | attention_mask = encoded['attention_mask'].cuda(), 146 | token_type_ids = encoded['token_type_ids'].cuda(), 147 | ).last_hidden_state[:,CLS_POS,:].detach().cpu() 148 | query_embeddings += query_repr.detach().cpu() 149 | 150 | return torch.stack(query_embeddings) 151 | 152 | def encode_corpus(self, corpus_batch: List[Dict[str, str]], device, **kwargs) -> np.ndarray: 153 | CLS_POS=0 154 | with torch.no_grad(): 155 | titles = [row['title'] for row in corpus_batch] 156 | texts = [row['text'] for row in corpus_batch] 157 | encoded = self.tokenizer(titles, texts, truncation='longest_first', padding=True, return_tensors='pt') 158 | corpus_repr = self.ctx_encoder( 159 | input_ids = encoded['input_ids'].to(device), 160 | attention_mask = encoded['attention_mask'].to(device), 161 | token_type_ids = encoded['token_type_ids'].to(device) 162 | ).last_hidden_state[:,CLS_POS,:].detach().cpu() 163 | return corpus_repr -------------------------------------------------------------------------------- /src/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
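# Usage sketch (added comment, not part of the original DPR file): test_dpr.py
# relies on SimpleTokenizer below for lexical answer matching, roughly:
#
#     tok = SimpleTokenizer()
#     doc_words = tok.tokenize("The Eiffel Tower is in Paris.").words(uncased=True)
#     # -> ['the', 'eiffel', 'tower', 'is', 'in', 'paris', '.']
#     # an answer counts as found if its token sequence appears as a contiguous
#     # sub-sequence of doc_words (see has_answer() in test_dpr.py)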
7 | 8 | 9 | """ 10 | copied from DPR codebase 11 | """ 12 | 13 | import copy 14 | import logging 15 | 16 | import regex 17 | import spacy 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Tokens(object): 23 | """A class to represent a list of tokenized text.""" 24 | 25 | TEXT = 0 26 | TEXT_WS = 1 27 | SPAN = 2 28 | POS = 3 29 | LEMMA = 4 30 | NER = 5 31 | 32 | def __init__(self, data, annotators, opts=None): 33 | self.data = data 34 | self.annotators = annotators 35 | self.opts = opts or {} 36 | 37 | def __len__(self): 38 | """The number of tokens.""" 39 | return len(self.data) 40 | 41 | def slice(self, i=None, j=None): 42 | """Return a view of the list of tokens from [i, j).""" 43 | new_tokens = copy.copy(self) 44 | new_tokens.data = self.data[i:j] 45 | return new_tokens 46 | 47 | def untokenize(self): 48 | """Returns the original text (with whitespace reinserted).""" 49 | return "".join([t[self.TEXT_WS] for t in self.data]).strip() 50 | 51 | def words(self, uncased=False): 52 | """Returns a list of the text of each token 53 | 54 | Args: 55 | uncased: lower cases text 56 | """ 57 | if uncased: 58 | return [t[self.TEXT].lower() for t in self.data] 59 | else: 60 | return [t[self.TEXT] for t in self.data] 61 | 62 | def offsets(self): 63 | """Returns a list of [start, end) character offsets of each token.""" 64 | return [t[self.SPAN] for t in self.data] 65 | 66 | def pos(self): 67 | """Returns a list of part-of-speech tags of each token. 68 | Returns None if this annotation was not included. 69 | """ 70 | if "pos" not in self.annotators: 71 | return None 72 | return [t[self.POS] for t in self.data] 73 | 74 | def lemmas(self): 75 | """Returns a list of the lemmatized text of each token. 76 | Returns None if this annotation was not included. 77 | """ 78 | if "lemma" not in self.annotators: 79 | return None 80 | return [t[self.LEMMA] for t in self.data] 81 | 82 | def entities(self): 83 | """Returns a list of named-entity-recognition tags of each token. 84 | Returns None if this annotation was not included. 85 | """ 86 | if "ner" not in self.annotators: 87 | return None 88 | return [t[self.NER] for t in self.data] 89 | 90 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 91 | """Returns a list of all ngrams from length 1 to n. 
92 | 93 | Args: 94 | n: upper limit of ngram length 95 | uncased: lower cases text 96 | filter_fn: user function that takes in an ngram list and returns 97 | True or False to keep or not keep the ngram 98 | as_string: return the ngram as a string vs list 99 | """ 100 | 101 | def _skip(gram): 102 | if not filter_fn: 103 | return False 104 | return filter_fn(gram) 105 | 106 | words = self.words(uncased) 107 | ngrams = [ 108 | (s, e + 1) 109 | for s in range(len(words)) 110 | for e in range(s, min(s + n, len(words))) 111 | if not _skip(words[s : e + 1]) 112 | ] 113 | 114 | # Concatenate into strings 115 | if as_strings: 116 | ngrams = ["{}".format(" ".join(words[s:e])) for (s, e) in ngrams] 117 | 118 | return ngrams 119 | 120 | def entity_groups(self): 121 | """Group consecutive entity tokens with the same NER tag.""" 122 | entities = self.entities() 123 | if not entities: 124 | return None 125 | non_ent = self.opts.get("non_ent", "O") 126 | groups = [] 127 | idx = 0 128 | while idx < len(entities): 129 | ner_tag = entities[idx] 130 | # Check for entity tag 131 | if ner_tag != non_ent: 132 | # Chomp the sequence 133 | start = idx 134 | while idx < len(entities) and entities[idx] == ner_tag: 135 | idx += 1 136 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 137 | else: 138 | idx += 1 139 | return groups 140 | 141 | 142 | class Tokenizer(object): 143 | """Base tokenizer class. 144 | Tokenizers implement tokenize, which should return a Tokens class. 145 | """ 146 | 147 | def tokenize(self, text): 148 | raise NotImplementedError 149 | 150 | def shutdown(self): 151 | pass 152 | 153 | def __del__(self): 154 | self.shutdown() 155 | 156 | 157 | class SimpleTokenizer(Tokenizer): 158 | ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+" 159 | NON_WS = r"[^\p{Z}\p{C}]" 160 | 161 | def __init__(self, **kwargs): 162 | """ 163 | Args: 164 | annotators: None or empty set (only tokenizes). 165 | """ 166 | self._regexp = regex.compile( 167 | "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS), 168 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE, 169 | ) 170 | if len(kwargs.get("annotators", {})) > 0: 171 | logger.warning( 172 | "%s only tokenizes! Skipping annotators: %s" % (type(self).__name__, kwargs.get("annotators")) 173 | ) 174 | self.annotators = set() 175 | 176 | def tokenize(self, text): 177 | data = [] 178 | matches = [m for m in self._regexp.finditer(text)] 179 | for i in range(len(matches)): 180 | # Get text 181 | token = matches[i].group() 182 | 183 | # Get whitespace 184 | span = matches[i].span() 185 | start_ws = span[0] 186 | if i + 1 < len(matches): 187 | end_ws = matches[i + 1].span()[0] 188 | else: 189 | end_ws = span[1] 190 | 191 | # Format data 192 | data.append( 193 | ( 194 | token, 195 | text[start_ws:end_ws], 196 | span, 197 | ) 198 | ) 199 | return Tokens(data, self.annotators) 200 | 201 | 202 | class SpacyTokenizer(Tokenizer): 203 | def __init__(self, **kwargs): 204 | """ 205 | Args: 206 | annotators: set that can include pos, lemma, and ner. 207 | model: spaCy model to use (either path, or keyword like 'en'). 208 | """ 209 | model = kwargs.get("model", "en_core_web_sm") # TODO: replace with en ? 
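        # Note (added comment): the keyword-style component flags assembled below
        # ("parser"/"tagger"/"entity") follow the old spaCy 1.x loading API; recent
        # spaCy versions expect spacy.load(model, disable=[...]) instead.  spacy is
        # also not listed in requirements.txt; the provided scripts only import
        # SimpleTokenizer, so SpacyTokenizer appears to be optional.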
210 | self.annotators = copy.deepcopy(kwargs.get("annotators", set())) 211 | nlp_kwargs = {"parser": False} 212 | if not any([p in self.annotators for p in ["lemma", "pos", "ner"]]): 213 | nlp_kwargs["tagger"] = False 214 | if "ner" not in self.annotators: 215 | nlp_kwargs["entity"] = False 216 | self.nlp = spacy.load(model, **nlp_kwargs) 217 | 218 | def tokenize(self, text): 219 | # We don't treat new lines as tokens. 220 | clean_text = text.replace("\n", " ") 221 | tokens = self.nlp.tokenizer(clean_text) 222 | if any([p in self.annotators for p in ["lemma", "pos", "ner"]]): 223 | self.nlp.tagger(tokens) 224 | if "ner" in self.annotators: 225 | self.nlp.entity(tokens) 226 | 227 | data = [] 228 | for i in range(len(tokens)): 229 | # Get whitespace 230 | start_ws = tokens[i].idx 231 | if i + 1 < len(tokens): 232 | end_ws = tokens[i + 1].idx 233 | else: 234 | end_ws = tokens[i].idx + len(tokens[i].text) 235 | 236 | data.append( 237 | ( 238 | tokens[i].text, 239 | text[start_ws:end_ws], 240 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 241 | tokens[i].tag_, 242 | tokens[i].lemma_, 243 | tokens[i].ent_type_, 244 | ) 245 | ) 246 | 247 | # Set special option for non-entity tag: '' vs 'O' in spaCy 248 | return Tokens(data, self.annotators, opts={"non_ent": ""}) 249 | -------------------------------------------------------------------------------- /src/train_dpr.py: -------------------------------------------------------------------------------- 1 | ## built-in 2 | import math,logging,functools,os 3 | import types 4 | os.environ["TOKENIZERS_PARALLELISM"]='true' 5 | 6 | ## third-party 7 | from accelerate import Accelerator 8 | from accelerate.logging import get_logger 9 | from accelerate.utils import DistributedDataParallelKwargs 10 | import transformers 11 | from transformers import ( 12 | BertTokenizer, 13 | BertModel, 14 | ) 15 | transformers.logging.set_verbosity_error() 16 | logging.basicConfig( 17 | format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt = '%m/%d/%Y %H:%M:%S', 19 | level = logging.INFO 20 | ) 21 | 22 | import torch 23 | import torch.distributed as dist 24 | from tqdm import tqdm 25 | 26 | ## own 27 | from utils import ( 28 | get_yaml_file, 29 | set_seed, 30 | get_linear_scheduler, 31 | ) 32 | from model import DualEncoder, calculate_dpr_loss, LossCalculator 33 | from data import QADataset, BEIRDataset 34 | 35 | logging.basicConfig(level=logging.INFO) 36 | logger = get_logger(__name__) 37 | 38 | def parse_args(): 39 | import argparse 40 | parser = argparse.ArgumentParser() 41 | ## adding args here for more control from CLI is possible 42 | parser.add_argument("--config_file",default='config/train_dpr_nq.yaml') 43 | args = parser.parse_args() 44 | 45 | yaml_config = get_yaml_file(args.config_file) 46 | args_dict = {k:v for k,v in vars(args).items() if v is not None} 47 | yaml_config.update(args_dict) 48 | args = types.SimpleNamespace(**yaml_config) 49 | return args 50 | 51 | def calculate_hit_cnt(matching_score,labels): 52 | _, max_ids = torch.max(matching_score,1) 53 | return (max_ids == labels).sum() 54 | 55 | def calculate_average_rank(matching_score,labels): 56 | _,indices = torch.sort(matching_score,dim=1,descending=True) 57 | ranks = [] 58 | for idx,label in enumerate(labels): 59 | rank = ((indices[idx] == label).nonzero()).item() + 1 ## rank starts from 1 60 | ranks.append(rank) 61 | return ranks 62 | 63 | def validate(model,dataloader,accelerator): 64 | model.eval() 65 | query_embeddings = [] 66 | positive_doc_embeddings = [] 67 | 
negative_doc_embeddings = [] 68 | for batch in dataloader: 69 | with torch.no_grad(): 70 | query_embedding,doc_embedding = model(**batch) 71 | query_num,_ = query_embedding.shape 72 | query_embeddings.append(query_embedding.cpu()) 73 | positive_doc_embeddings.append(doc_embedding[:query_num,:].cpu()) 74 | negative_doc_embeddings.append(doc_embedding[query_num:,:].cpu()) 75 | 76 | query_embeddings = torch.cat(query_embeddings,dim=0) 77 | doc_embeddings = torch.cat(positive_doc_embeddings+negative_doc_embeddings,dim=0) 78 | matching_score = torch.matmul(query_embeddings,doc_embeddings.permute(1,0)) # bs, num_pos+num_neg 79 | labels = torch.arange(query_embeddings.shape[0],dtype=torch.int64).to(matching_score.device) 80 | loss = calculate_dpr_loss(matching_score,labels=labels).item() 81 | ranks = calculate_average_rank(matching_score,labels=labels) 82 | if accelerator.use_distributed and accelerator.num_processes>1: 83 | ranks_from_all_gpus = [None for _ in range(accelerator.num_processes)] 84 | dist.all_gather_object(ranks_from_all_gpus,ranks) 85 | ranks = [x for y in ranks_from_all_gpus for x in y] 86 | 87 | loss_from_all_gpus = [None for _ in range(accelerator.num_processes)] 88 | dist.all_gather_object(loss_from_all_gpus,loss) 89 | loss = sum(loss_from_all_gpus)/len(loss_from_all_gpus) 90 | 91 | return sum(ranks)/len(ranks),loss 92 | 93 | def main(): 94 | args = parse_args() 95 | set_seed(args.seed) 96 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) 97 | if args.vram_fraction != 1: 98 | torch.cuda.set_per_process_memory_fraction(args.vram_fraction) 99 | print(f">>> set vram fraction to {args.vram_fraction*torch.cuda.get_device_properties(0).total_memory/1024**3}GB") 100 | accelerator = Accelerator( 101 | gradient_accumulation_steps=args.gradient_accumulation_steps, 102 | mixed_precision='no', 103 | kwargs_handlers=[kwargs] 104 | ) 105 | if accelerator.is_local_main_process: 106 | LOG_DIR = os.path.join(args.log_dir, args.run_name) 107 | 108 | tokenizer = BertTokenizer.from_pretrained(args.base_model) 109 | query_encoder = BertModel.from_pretrained(args.base_model,add_pooling_layer=False) 110 | doc_encoder = BertModel.from_pretrained(args.base_model,add_pooling_layer=False) 111 | dual_encoder = DualEncoder(query_encoder,doc_encoder) 112 | dual_encoder.train() 113 | 114 | collate_fn = QADataset.collate_fn if "msmarco" not in args.run_name else BEIRDataset.collate_fn 115 | 116 | train_dataset = QADataset(args.train_file) if "msmarco" not in args.run_name else BEIRDataset(args.train_file) 117 | train_collate_fn = functools.partial(collate_fn,tokenizer=tokenizer,stage='train',args=args,hard_neg=args.use_hard_neg) 118 | train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=args.per_device_train_batch_size,shuffle=True,collate_fn=train_collate_fn,num_workers=4,pin_memory=True) 119 | 120 | dev_dataset = QADataset(args.dev_file) if "msmarco" not in args.run_name else BEIRDataset(args.dev_file) 121 | dev_collate_fn = functools.partial(collate_fn,tokenizer=tokenizer,stage='dev',args=args,hard_neg=args.use_hard_neg) 122 | dev_dataloader = torch.utils.data.DataLoader(dev_dataset,batch_size=args.per_device_eval_batch_size,shuffle=False,collate_fn=dev_collate_fn,num_workers=4,pin_memory=True) 123 | 124 | 125 | no_decay = ["bias", "LayerNorm.weight"] 126 | optimizer_grouped_parameters = [ 127 | { 128 | "params": [p for n, p in dual_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 129 | "weight_decay": args.weight_decay, 130 | }, 131 | { 132 | 
"params": [p for n, p in dual_encoder.named_parameters() if any(nd in n for nd in no_decay)], 133 | "weight_decay": 0.0, 134 | }, 135 | ] 136 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters,lr=args.lr, eps=args.adam_eps) 137 | 138 | NUM_UPDATES_PER_EPOCH = math.ceil(len(train_dataloader) / (args.gradient_accumulation_steps*accelerator.num_processes)) 139 | MAX_TRAIN_STEPS = NUM_UPDATES_PER_EPOCH * args.max_train_epochs 140 | MAX_TRAIN_EPOCHS = math.ceil(MAX_TRAIN_STEPS / NUM_UPDATES_PER_EPOCH) 141 | TOTAL_TRAIN_BATCH_SIZE = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 142 | SAVE_EPOCHS = 20 if MAX_TRAIN_EPOCHS > 40 else 40 143 | lr_scheduler = get_linear_scheduler(optimizer,warmup_steps=args.warmup_steps,total_training_steps=MAX_TRAIN_STEPS) 144 | loss_calculator = LossCalculator(args,hard_neg=args.use_hard_neg) 145 | 146 | dual_encoder, optimizer, train_dataloader, dev_dataloader, loss_calculator = accelerator.prepare( 147 | dual_encoder, optimizer, train_dataloader, dev_dataloader, loss_calculator 148 | ) 149 | 150 | logger.info("***** Running training *****") 151 | logger.info(f" Num train examples = {len(train_dataset)}") 152 | logger.info(f" Num dev examples = {len(dev_dataset)}") 153 | logger.info(f" Num Epochs = {MAX_TRAIN_EPOCHS}") 154 | logger.info(f" Per device train batch size = {args.per_device_train_batch_size}") 155 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {TOTAL_TRAIN_BATCH_SIZE}") 156 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 157 | logger.info(f" Total optimization steps = {MAX_TRAIN_STEPS}") 158 | logger.info(f" Per device eval batch size = {args.per_device_eval_batch_size}") 159 | completed_steps = 0 160 | progress_bar = tqdm(range(MAX_TRAIN_STEPS), disable=not accelerator.is_local_main_process,ncols=100) 161 | 162 | for epoch in range(MAX_TRAIN_EPOCHS): 163 | set_seed(args.seed+epoch) 164 | progress_bar.set_description(f"epoch: {epoch+1}/{MAX_TRAIN_EPOCHS}") 165 | for step,batch in enumerate(train_dataloader): 166 | with accelerator.accumulate(dual_encoder): 167 | with accelerator.autocast(): 168 | query_embedding,doc_embedding = dual_encoder(**batch) 169 | single_device_query_num,_ = query_embedding.shape 170 | single_device_doc_num,_ = doc_embedding.shape 171 | if accelerator.use_distributed: 172 | doc_list = [torch.zeros_like(doc_embedding) for _ in range(accelerator.num_processes)] 173 | dist.all_gather(tensor_list=doc_list, tensor=doc_embedding.contiguous()) 174 | doc_list[dist.get_rank()] = doc_embedding 175 | doc_embedding = torch.cat(doc_list, dim=0) 176 | 177 | query_list = [torch.zeros_like(query_embedding) for _ in range(accelerator.num_processes)] 178 | dist.all_gather(tensor_list=query_list, tensor=query_embedding.contiguous()) 179 | query_list[dist.get_rank()] = query_embedding 180 | query_embedding = torch.cat(query_list, dim=0) 181 | 182 | if args.cont_cache : 183 | loss = loss_calculator(query_embedding,doc_embedding) 184 | else : 185 | matching_score = torch.matmul(query_embedding,doc_embedding.permute(1,0)) 186 | labels = torch.cat([torch.arange(single_device_query_num) + gpu_index * single_device_doc_num for gpu_index in range(accelerator.num_processes)],dim=0).to(matching_score.device) 187 | loss = calculate_dpr_loss(matching_score,labels=labels) 188 | 189 | accelerator.backward(loss) 190 | 191 | ## one optimization step 192 | if accelerator.sync_gradients: 193 | progress_bar.update(1) 194 | 
progress_bar.set_postfix(loss=f"{loss:.4f}",lr=f"{lr_scheduler.get_last_lr()[0]:6f}") 195 | completed_steps += 1 196 | accelerator.clip_grad_norm_(dual_encoder.parameters(), args.max_grad_norm) 197 | if not accelerator.optimizer_step_was_skipped: 198 | lr_scheduler.step() 199 | accelerator.log({"training_loss": loss}, step=completed_steps) 200 | accelerator.log({"lr": lr_scheduler.get_last_lr()[0], "epoch" : epoch}, step=completed_steps) 201 | optimizer.step() 202 | optimizer.zero_grad() 203 | loss_calculator.empty_cache() # empty cache after one optimization step, if prev_cache is False 204 | 205 | if (epoch+1) % 10 == 0 : 206 | print(f"evaluating on dev set...") 207 | avg_rank,loss = validate(dual_encoder,dev_dataloader,accelerator) 208 | dual_encoder.train() 209 | accelerator.log({"val_avg_rank": avg_rank, "val_loss":loss}, step=completed_steps) 210 | accelerator.wait_for_everyone() 211 | 212 | 213 | if accelerator.is_local_main_process: 214 | unwrapped_model = accelerator.unwrap_model(dual_encoder) 215 | unwrapped_model.query_encoder.save_pretrained(os.path.join(LOG_DIR,f"step-{completed_steps}_epoch-{epoch}/query_encoder")) 216 | tokenizer.save_pretrained(os.path.join(LOG_DIR,f"step-{completed_steps}_epoch-{epoch}/query_encoder")) 217 | 218 | unwrapped_model.doc_encoder.save_pretrained(os.path.join(LOG_DIR,f"step-{completed_steps}_epoch-{epoch}/doc_encoder")) 219 | tokenizer.save_pretrained(os.path.join(LOG_DIR,f"step-{completed_steps}_epoch-{epoch}/doc_encoder")) 220 | 221 | accelerator.wait_for_everyone() 222 | 223 | 224 | 225 | accelerator.end_training() 226 | 227 | if __name__ == '__main__': 228 | main() -------------------------------------------------------------------------------- /src/utils/download_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
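# Usage sketch (added comment, not part of the original DPR script): each key in
# RESOURCES_MAP below maps to an S3 URL plus metadata ("original_ext",
# "compressed", "desc", and optionally "license_files").  A single resource could
# be fetched manually along these lines (the output path is only an example):
#
#     info = RESOURCES_MAP["data.retriever.nq-train"]
#     wget.download(info["s3_url"], out="downloads/nq-train" + info["original_ext"] + ".gz")
#     # decompress the file afterwards when info["compressed"] is True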
7 | 8 | """ 9 | Copied from DPR 10 | """ 11 | 12 | import argparse 13 | import gzip 14 | import logging 15 | import os 16 | import pathlib 17 | import wget 18 | 19 | from typing import Tuple 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | # TODO: move to hydra config group 24 | 25 | NQ_LICENSE_FILES = [ 26 | "https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE", 27 | "https://dl.fbaipublicfiles.com/dpr/nq_license/README", 28 | ] 29 | 30 | RESOURCES_MAP = { 31 | "data.wikipedia_split.psgs_w100": { 32 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz", 33 | "original_ext": ".tsv", 34 | "compressed": True, 35 | "desc": "Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)", 36 | }, 37 | "data.retriever.nq-dev": { 38 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz", 39 | "original_ext": ".json", 40 | "compressed": True, 41 | "desc": "NQ dev subset with passages pools for the Retriever train time validation", 42 | "license_files": NQ_LICENSE_FILES, 43 | }, 44 | "data.retriever.nq-train": { 45 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz", 46 | "original_ext": ".json", 47 | "compressed": True, 48 | "desc": "NQ train subset with passages pools for the Retriever training", 49 | "license_files": NQ_LICENSE_FILES, 50 | }, 51 | "data.retriever.nq-adv-hn-train": { 52 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-adv-hn-train.json.gz", 53 | "original_ext": ".json", 54 | "compressed": True, 55 | "desc": "NQ train subset with hard negative passages mined using the baseline DPR NQ encoders & wikipedia index", 56 | "license_files": NQ_LICENSE_FILES, 57 | }, 58 | "data.retriever.trivia-dev": { 59 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-dev.json.gz", 60 | "original_ext": ".json", 61 | "compressed": True, 62 | "desc": "TriviaQA dev subset with passages pools for the Retriever train time validation", 63 | }, 64 | "data.retriever.trivia-train": { 65 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-train.json.gz", 66 | "original_ext": ".json", 67 | "compressed": True, 68 | "desc": "TriviaQA train subset with passages pools for the Retriever training", 69 | }, 70 | "data.retriever.squad1-train": { 71 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-squad1-train.json.gz", 72 | "original_ext": ".json", 73 | "compressed": True, 74 | "desc": "SQUAD 1.1 train subset with passages pools for the Retriever training", 75 | }, 76 | "data.retriever.squad1-dev": { 77 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-squad1-dev.json.gz", 78 | "original_ext": ".json", 79 | "compressed": True, 80 | "desc": "SQUAD 1.1 dev subset with passages pools for the Retriever train time validation", 81 | }, 82 | "data.retriever.webq-train": { 83 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-train.json.gz", 84 | "original_ext": ".json", 85 | "compressed": True, 86 | "desc": "WebQuestions dev subset with passages pools for the Retriever train time validation", 87 | }, 88 | "data.retriever.webq-dev": { 89 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-dev.json.gz", 90 | "original_ext": ".json", 91 | "compressed": True, 92 | "desc": "WebQuestions dev subset with passages pools for the Retriever train time validation", 93 | }, 94 | 
"data.retriever.curatedtrec-train": { 95 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-train.json.gz", 96 | "original_ext": ".json", 97 | "compressed": True, 98 | "desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation", 99 | }, 100 | "data.retriever.curatedtrec-dev": { 101 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-dev.json.gz", 102 | "original_ext": ".json", 103 | "compressed": True, 104 | "desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation", 105 | }, 106 | "data.retriever.qas.nq-dev": { 107 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv", 108 | "original_ext": ".csv", 109 | "compressed": False, 110 | "desc": "NQ dev subset for Retriever validation and IR results generation", 111 | "license_files": NQ_LICENSE_FILES, 112 | }, 113 | "data.retriever.qas.nq-test": { 114 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv", 115 | "original_ext": ".csv", 116 | "compressed": False, 117 | "desc": "NQ test subset for Retriever validation and IR results generation", 118 | "license_files": NQ_LICENSE_FILES, 119 | }, 120 | "data.retriever.qas.nq-train": { 121 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-train.qa.csv", 122 | "original_ext": ".csv", 123 | "compressed": False, 124 | "desc": "NQ train subset for Retriever validation and IR results generation", 125 | "license_files": NQ_LICENSE_FILES, 126 | }, 127 | # 128 | "data.retriever.qas.trivia-dev": { 129 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-dev.qa.csv.gz", 130 | "original_ext": ".csv", 131 | "compressed": True, 132 | "desc": "Trivia dev subset for Retriever validation and IR results generation", 133 | }, 134 | "data.retriever.qas.trivia-test": { 135 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-test.qa.csv.gz", 136 | "original_ext": ".csv", 137 | "compressed": True, 138 | "desc": "Trivia test subset for Retriever validation and IR results generation", 139 | }, 140 | "data.retriever.qas.trivia-train": { 141 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-train.qa.csv.gz", 142 | "original_ext": ".csv", 143 | "compressed": True, 144 | "desc": "Trivia train subset for Retriever validation and IR results generation", 145 | }, 146 | "data.retriever.qas.squad1-test": { 147 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/squad1-test.qa.csv", 148 | "original_ext": ".csv", 149 | "compressed": False, 150 | "desc": "Trivia test subset for Retriever validation and IR results generation", 151 | }, 152 | "data.retriever.qas.webq-test": { 153 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/webquestions-test.qa.csv", 154 | "original_ext": ".csv", 155 | "compressed": False, 156 | "desc": "WebQuestions test subset for Retriever validation and IR results generation", 157 | }, 158 | "data.retriever.qas.curatedtrec-test": { 159 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/curatedtrec-test.qa.csv", 160 | "original_ext": ".csv", 161 | "compressed": False, 162 | "desc": "CuratedTrec test subset for Retriever validation and IR results generation", 163 | }, 164 | "data.gold_passages_info.nq_train": { 165 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-train_gold_info.json.gz", 166 | "original_ext": ".json", 167 | "compressed": True, 168 | "desc": "Original NQ (our train subset) gold positive 
passages and alternative question tokenization", 169 | "license_files": NQ_LICENSE_FILES, 170 | }, 171 | "data.gold_passages_info.nq_dev": { 172 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-dev_gold_info.json.gz", 173 | "original_ext": ".json", 174 | "compressed": True, 175 | "desc": "Original NQ (our dev subset) gold positive passages and alternative question tokenization", 176 | "license_files": NQ_LICENSE_FILES, 177 | }, 178 | "data.gold_passages_info.nq_test": { 179 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-test_gold_info.json.gz", 180 | "original_ext": ".json", 181 | "compressed": True, 182 | "desc": "Original NQ (our test, original dev subset) gold positive passages and alternative question " 183 | "tokenization", 184 | "license_files": NQ_LICENSE_FILES, 185 | }, 186 | "pretrained.fairseq.roberta-base.dict": { 187 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/dict.txt", 188 | "original_ext": ".txt", 189 | "compressed": False, 190 | "desc": "Dictionary for pretrained fairseq roberta model", 191 | }, 192 | "pretrained.fairseq.roberta-base.model": { 193 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/model.pt", 194 | "original_ext": ".pt", 195 | "compressed": False, 196 | "desc": "Weights for pretrained fairseq roberta base model", 197 | }, 198 | "pretrained.pytext.bert-base.model": { 199 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/pretrained/pytext/bert/bert-base-uncased.pt", 200 | "original_ext": ".pt", 201 | "compressed": False, 202 | "desc": "Weights for pretrained pytext bert base model", 203 | }, 204 | "data.retriever_results.nq.single.wikipedia_passages": { 205 | "s3_url": [ 206 | "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single/nq/wiki_passages_{}".format(i) 207 | for i in range(50) 208 | ], 209 | "original_ext": ".pkl", 210 | "compressed": False, 211 | "desc": "Encoded wikipedia files using a biencoder checkpoint(" 212 | "checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ", 213 | }, 214 | "data.retriever_results.nq.single-adv-hn.wikipedia_passages": { 215 | "s3_url": [ 216 | "https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single-adv-hn/nq/wiki_passages_{}".format(i) 217 | for i in range(50) 218 | ], 219 | "original_ext": ".pkl", 220 | "compressed": False, 221 | "desc": "Encoded wikipedia files using a biencoder checkpoint(" 222 | "checkpoint.retriever.single-adv-hn.nq.bert-base-encoder) trained on NQ dataset + adversarial hard negatives", 223 | }, 224 | "data.retriever_results.nq.single.test": { 225 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz", 226 | "original_ext": ".json", 227 | "compressed": True, 228 | "desc": "Retrieval results of NQ test dataset for the encoder trained on NQ", 229 | "license_files": NQ_LICENSE_FILES, 230 | }, 231 | "data.retriever_results.nq.single.dev": { 232 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-dev.json.gz", 233 | "original_ext": ".json", 234 | "compressed": True, 235 | "desc": "Retrieval results of NQ dev dataset for the encoder trained on NQ", 236 | "license_files": NQ_LICENSE_FILES, 237 | }, 238 | "data.retriever_results.nq.single.train": { 239 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-train.json.gz", 240 | "original_ext": ".json", 241 | "compressed": True, 242 | "desc": "Retrieval results of NQ train dataset for the encoder trained on NQ", 243 | "license_files": 
NQ_LICENSE_FILES, 244 | }, 245 | "data.retriever_results.nq.single-adv-hn.test": { 246 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single-adv-hn/nq-test.json.gz", 247 | "original_ext": ".json", 248 | "compressed": True, 249 | "desc": "Retrieval results of NQ test dataset for the encoder trained on NQ + adversarial hard negatives", 250 | "license_files": NQ_LICENSE_FILES, 251 | }, 252 | "checkpoint.retriever.single.nq.bert-base-encoder": { 253 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriever/single/nq/hf_bert_base.cp", 254 | "original_ext": ".cp", 255 | "compressed": False, 256 | "desc": "Biencoder weights trained on NQ data and HF bert-base-uncased model", 257 | }, 258 | "checkpoint.retriever.multiset.bert-base-encoder": { 259 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/multiset/hf_bert_base.cp", 260 | "original_ext": ".cp", 261 | "compressed": False, 262 | "desc": "Biencoder weights trained on multi set data and HF bert-base-uncased model", 263 | }, 264 | "checkpoint.retriever.single-adv-hn.nq.bert-base-encoder": { 265 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/single-adv-hn/nq/hf_bert_base.cp", 266 | "original_ext": ".cp", 267 | "compressed": False, 268 | "desc": "Biencoder weights trained on the original DPR NQ data combined with adversarial hard negatives (See data.retriever.nq-adv-hn-train resource). " 269 | "The model is HF bert-base-uncased", 270 | }, 271 | "data.reader.nq.single.train": { 272 | "s3_url": ["https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/train.{}.pkl".format(i) for i in range(8)], 273 | "original_ext": ".pkl", 274 | "compressed": False, 275 | "desc": "Reader model NQ train dataset input data preprocessed from retriever results (also trained on NQ)", 276 | "license_files": NQ_LICENSE_FILES, 277 | }, 278 | "data.reader.nq.single.dev": { 279 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/dev.0.pkl", 280 | "original_ext": ".pkl", 281 | "compressed": False, 282 | "desc": "Reader model NQ dev dataset input data preprocessed from retriever results (also trained on NQ)", 283 | "license_files": NQ_LICENSE_FILES, 284 | }, 285 | "data.reader.nq.single.test": { 286 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/test.0.pkl", 287 | "original_ext": ".pkl", 288 | "compressed": False, 289 | "desc": "Reader model NQ test dataset input data preprocessed from retriever results (also trained on NQ)", 290 | "license_files": NQ_LICENSE_FILES, 291 | }, 292 | "data.reader.trivia.multi-hybrid.train": { 293 | "s3_url": [ 294 | "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/train.{}.pkl".format(i) 295 | for i in range(8) 296 | ], 297 | "original_ext": ".pkl", 298 | "compressed": False, 299 | "desc": "Reader model Trivia train dataset input data preprocessed from hybrid retriever results " 300 | "(where dense part is trained on multiset)", 301 | }, 302 | "data.reader.trivia.multi-hybrid.dev": { 303 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/dev.0.pkl", 304 | "original_ext": ".pkl", 305 | "compressed": False, 306 | "desc": "Reader model Trivia dev dataset input data preprocessed from hybrid retriever results " 307 | "(where dense part is trained on multiset)", 308 | }, 309 | "data.reader.trivia.multi-hybrid.test": { 310 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/test.0.pkl", 311 | "original_ext": ".pkl", 312 | "compressed": False, 313 | "desc": "Reader 
model Trivia test dataset input data preprocessed from hybrid retriever results " 314 | "(where dense part is trained on multiset)", 315 | }, 316 | "checkpoint.reader.nq-single.hf-bert-base": { 317 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single/hf_bert_base.cp", 318 | "original_ext": ".cp", 319 | "compressed": False, 320 | "desc": "Reader weights trained on NQ-single retriever results and HF bert-base-uncased model", 321 | }, 322 | "checkpoint.reader.nq-trivia-hybrid.hf-bert-base": { 323 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-trivia-hybrid/hf_bert_base.cp", 324 | "original_ext": ".cp", 325 | "compressed": False, 326 | "desc": "Reader weights trained on Trivia multi hybrid retriever results and HF bert-base-uncased model", 327 | }, 328 | # extra checkpoints for EfficientQA competition 329 | "checkpoint.reader.nq-single-subset.hf-bert-base": { 330 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single-seen_only/hf_bert_base.cp", 331 | "original_ext": ".cp", 332 | "compressed": False, 333 | "desc": "Reader weights trained on NQ-single retriever results and HF bert-base-uncased model, when only Wikipedia pages seen during training are considered", 334 | }, 335 | "checkpoint.reader.nq-tfidf.hf-bert-base": { 336 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-drqa/hf_bert_base.cp", 337 | "original_ext": ".cp", 338 | "compressed": False, 339 | "desc": "Reader weights trained on TFIDF results and HF bert-base-uncased model", 340 | }, 341 | "checkpoint.reader.nq-tfidf-subset.hf-bert-base": { 342 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-drqa-seen_only/hf_bert_base.cp", 343 | "original_ext": ".cp", 344 | "compressed": False, 345 | "desc": "Reader weights trained on TFIDF results and HF bert-base-uncased model, when only Wikipedia pages seen during training are considered", 346 | }, 347 | # retrieval indexes 348 | "indexes.single.nq.full.index": { 349 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/full.index.dpr", 350 | "original_ext": ".dpr", 351 | "compressed": False, 352 | "desc": "DPR index on NQ-single retriever", 353 | }, 354 | "indexes.single.nq.full.index_meta": { 355 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/full.index_meta.dpr", 356 | "original_ext": ".dpr", 357 | "compressed": False, 358 | "desc": "DPR index on NQ-single retriever (metadata)", 359 | }, 360 | "indexes.single.nq.subset.index": { 361 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/seen_only.index.dpr", 362 | "original_ext": ".dpr", 363 | "compressed": False, 364 | "desc": "DPR index on NQ-single retriever when only Wikipedia pages seen during training are considered", 365 | }, 366 | "indexes.single.nq.subset.index_meta": { 367 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/single/nq/seen_only.index_meta.dpr", 368 | "original_ext": ".dpr", 369 | "compressed": False, 370 | "desc": "DPR index on NQ-single retriever when only Wikipedia pages seen during training are considered (metadata)", 371 | }, 372 | "indexes.tfidf.nq.full": { 373 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/drqa/nq/full-tfidf.npz", 374 | "original_ext": ".npz", 375 | "compressed": False, 376 | "desc": "TFIDF index", 377 | }, 378 | "indexes.tfidf.nq.subset": { 379 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/indexes/drqa/nq/seen_only-tfidf.npz", 380 | "original_ext": ".npz", 381 | 
"compressed": False, 382 | "desc": "TFIDF index when only Wikipedia pages seen during training are considered", 383 | }, 384 | # Universal retriever project data 385 | "data.wikipedia_split.psgs_w100": { 386 | "s3_url": "https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz", 387 | "original_ext": ".tsv", 388 | "compressed": True, 389 | "desc": "Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)", 390 | }, 391 | } 392 | 393 | 394 | def unpack(gzip_file: str, out_file: str): 395 | logger.info("Uncompressing %s", gzip_file) 396 | input = gzip.GzipFile(gzip_file, "rb") 397 | s = input.read() 398 | input.close() 399 | output = open(out_file, "wb") 400 | output.write(s) 401 | output.close() 402 | logger.info(" Saved to %s", out_file) 403 | 404 | 405 | def download_resource( 406 | s3_url: str, original_ext: str, compressed: bool, resource_key: str, out_dir: str 407 | ) -> Tuple[str, str]: 408 | logger.info("Requested resource from %s", s3_url) 409 | path_names = resource_key.split(".") 410 | 411 | if out_dir: 412 | root_dir = out_dir 413 | else: 414 | # since hydra overrides the location for the 'current dir' for every run and we don't want to duplicate 415 | # resources multiple times, remove the current folder's volatile part 416 | root_dir = os.path.abspath("./") 417 | if "/outputs/" in root_dir: 418 | root_dir = root_dir[: root_dir.index("/outputs/")] 419 | 420 | logger.info("Download root_dir %s", root_dir) 421 | 422 | save_root = os.path.join(root_dir, "downloads", *path_names[:-1]) # last segment is for file name 423 | 424 | pathlib.Path(save_root).mkdir(parents=True, exist_ok=True) 425 | 426 | local_file_uncompressed = os.path.abspath(os.path.join(save_root, path_names[-1] + original_ext)) 427 | logger.info("File to be downloaded as %s", local_file_uncompressed) 428 | 429 | if os.path.exists(local_file_uncompressed): 430 | logger.info("File already exist %s", local_file_uncompressed) 431 | return save_root, local_file_uncompressed 432 | 433 | local_file = os.path.abspath(os.path.join(save_root, path_names[-1] + (".tmp" if compressed else original_ext))) 434 | 435 | wget.download(s3_url, out=local_file) 436 | 437 | logger.info("Downloaded to %s", local_file) 438 | 439 | if compressed: 440 | uncompressed_file = os.path.join(save_root, path_names[-1] + original_ext) 441 | unpack(local_file, uncompressed_file) 442 | os.remove(local_file) 443 | local_file = uncompressed_file 444 | return save_root, local_file 445 | 446 | 447 | def download_file(s3_url: str, out_dir: str, file_name: str): 448 | logger.info("Loading from %s", s3_url) 449 | local_file = os.path.join(out_dir, file_name) 450 | 451 | if os.path.exists(local_file): 452 | logger.info("File already exist %s", local_file) 453 | return 454 | 455 | wget.download(s3_url, out=local_file) 456 | logger.info("Downloaded to %s", local_file) 457 | 458 | 459 | def download(resource_key: str, out_dir: str = None): 460 | if resource_key not in RESOURCES_MAP: 461 | # match by prefix 462 | resources = [k for k in RESOURCES_MAP.keys() if k.startswith(resource_key)] 463 | logger.info("matched by prefix resources: %s", resources) 464 | if resources: 465 | for key in resources: 466 | download(key, out_dir) 467 | else: 468 | logger.info("no resources found for specified key") 469 | return [] 470 | download_info = RESOURCES_MAP[resource_key] 471 | 472 | s3_url = download_info["s3_url"] 473 | 474 | save_root_dir = None 475 | data_files = [] 476 | if isinstance(s3_url, list): 477 | for i, url in 
enumerate(s3_url): 478 | save_root_dir, local_file = download_resource( 479 | url, 480 | download_info["original_ext"], 481 | download_info["compressed"], 482 | "{}_{}".format(resource_key, i), 483 | out_dir, 484 | ) 485 | data_files.append(local_file) 486 | else: 487 | save_root_dir, local_file = download_resource( 488 | s3_url, 489 | download_info["original_ext"], 490 | download_info["compressed"], 491 | resource_key, 492 | out_dir, 493 | ) 494 | data_files.append(local_file) 495 | 496 | license_files = download_info.get("license_files", None) 497 | if license_files: 498 | download_file(license_files[0], save_root_dir, "LICENSE") 499 | download_file(license_files[1], save_root_dir, "README") 500 | return data_files 501 | 502 | 503 | def main(): 504 | parser = argparse.ArgumentParser() 505 | 506 | parser.add_argument( 507 | "--output_dir", 508 | default="./", 509 | type=str, 510 | help="The output directory to download files to", 511 | ) 512 | parser.add_argument( 513 | "--resource", 514 | type=str, 515 | help="Resource name. See RESOURCES_MAP for all possible values", 516 | ) 517 | args = parser.parse_args() 518 | if args.resource: 519 | download(args.resource, args.output_dir) 520 | else: 521 | logger.warning("Please specify a resource value. Possible options are:") 522 | for k, v in RESOURCES_MAP.items(): 523 | logger.warning("Resource key=%s : %s", k, v["desc"]) 524 | 525 | 526 | if __name__ == "__main__": 527 | main() 528 | --------------------------------------------------------------------------------
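A minimal usage sketch for this downloader (not part of the repository; the output directory is illustrative, and the resource key is one entry from RESOURCES_MAP above). Given the argparse interface in main(), downloading the 100-word Wikipedia passage split might look like:

    # assumes the script is run from the repository root;
    # download_resource() writes files under <output_dir>/downloads/<dotted resource path>/
    python src/utils/download_data.py \
        --resource data.wikipedia_split.psgs_w100 \
        --output_dir /path/to/dpr_datasets

Because download() falls back to prefix matching, passing a key prefix such as data.retriever.qas fetches every matching resource, and omitting --resource logs the full list of resource keys with their descriptions instead of downloading anything.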