├── requirements.txt
├── .gitignore
├── data_prep
│   ├── get_train_qids.sh
│   ├── get_ance_ranking.sh
│   ├── get_prf_data.sh
│   ├── get_train_qids.py
│   ├── data_utils.py
│   ├── preprocess_data.sh
│   ├── get_ance_embs.sh
│   ├── get_ance_ranking.py
│   ├── download_data.sh
│   ├── get_train_query_embeds.py
│   ├── get_prf_data.py
│   ├── preprocess_data.py
│   └── run_ann_data_gen.py
├── convert_output.sh
├── train_encoder.sh
├── eval.sh
├── get_marco_eval_output.py
├── get_eval_metrics.sh
├── README.md
├── model.py
├── lamb.py
├── get_eval_metrics.py
├── main.py
├── test_ance.py
├── msmarco_eval.py
├── data.py
├── runner.py
└── utils
    └── util.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==2.3.0
2 | pytrec-eval
3 | faiss-cpu
4 | pandas
5 | scikit-learn
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # pycharm files
2 | .idea/
3 | .venv/
4 | 
5 | # macos files
6 | .DS_Store
7 | 
8 | # data
9 | outputs/**
10 | data/**
11 | model/**
12 | *__pycache__
--------------------------------------------------------------------------------
/data_prep/get_train_qids.sh:
--------------------------------------------------------------------------------
1 | # This script assigns integer ids to the MARCO training queries (see get_train_qids.py).
2 | 
3 | repo_dir=$(builtin cd ..; pwd)
4 | data_dir=${repo_dir}/data
5 | 
6 | echo "Start generating train query ids..."
7 | python -u get_train_qids.py \
8 |     --raw_data_dir ${data_dir}/marco_raw_data
9 | echo "Done generating train query ids."
--------------------------------------------------------------------------------
/convert_output.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | source activate ance
4 | dataset=$1
5 | ckpt=$2
6 | mode="fullrank"
7 | num_feedbacks=$3
8 | python get_marco_eval_output.py \
9 |     --ance_checkpoint_path "/bos/tmp10/hongqiay/ance/${dataset}_output/" \
10 |     --processed_data_dir "/bos/tmp10/hongqiay/ance/${dataset}_preprocessed" \
11 |     --devI_path "/bos/tmp10/hongqiay/prf-query-encoder/models/k_${num_feedbacks}/checkpoint-${ckpt}/${dataset}_devI_${mode}.npy"
--------------------------------------------------------------------------------
/train_encoder.sh:
--------------------------------------------------------------------------------
1 | gpu_no=2
2 | lr=1e-5
3 | num_feedbacks=3
4 | repo_dir=$(pwd)
5 | output_dir=${repo_dir}/outputs
6 | data_dir=${repo_dir}/data
7 | mkdir -p ${output_dir}
8 | 
9 | dataset="marco"
10 | python -m torch.distributed.launch --nproc_per_node=${gpu_no} main.py \
11 |     --train \
12 |     --logging_steps 100 \
13 |     --save_steps 2000 \
14 |     --gradient_accumulation_steps 8 \
15 |     --warmup_steps=5000 \
16 |     --output_dir ${output_dir}/k_${num_feedbacks} \
17 |     --learning_rate ${lr} \
18 |     --num_feedbacks ${num_feedbacks} \
19 |     --per_gpu_train_batch_size 4 \
20 |     --load_optimizer_scheduler \
21 |     --ance_checkpoint_path ${data_dir}/${dataset}_output \
22 |     --preprocessed_dir ${data_dir}/${dataset}_preprocessed \
23 |     --train_data_dir ${data_dir}/${dataset}_output
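# Note: with this configuration the effective training batch size is
# per_gpu_train_batch_size (4) x gradient_accumulation_steps (8) x gpu_no (2) = 64.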
"full" if you would like to see the full retrieval results 12 | 13 | CUDA_VISIBLE_DEVICES=${gpu_id} python main.py \ 14 | --eval \ 15 | --first_stage_inn_path ${data_dir}/${dataset}_output/${dataset}_dev_I.npy \ 16 | --output_dir ${output_dir}/k_${num_feedbacks} \ 17 | --save_steps ${save_steps} \ 18 | --eval_mode ${eval_mode} \ 19 | --ance_checkpoint_path ${data_dir}/${dataset}_output/ \ 20 | --preprocessed_dir ${data_dir}/${dataset}_preprocessed \ 21 | --dev_data_dir ${data_dir}/${dataset}_output \ 22 | --dataset ${dataset} \ 23 | --per_gpu_eval_batch_size 50 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /get_marco_eval_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from util import load_embedding_prefix, load_embeddings 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--ance_checkpoint_path", default="/bos/tmp10/hongqiay/ance/marco_output/", 12 | help="location for dumpped query and passage/document embeddings which is output_dir") 13 | parser.add_argument("--processed_data_dir", default="/bos/tmp10/hongqiay/ance/marco_preprocessed") 14 | parser.add_argument("--devI_path", required=True) 15 | 16 | return parser.parse_args() 17 | 18 | 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | offset2qid = get_offset2qid(args) 24 | embedid2qid = get_embedid2qid(args, offset2qid) 25 | devI_to_tein(args, embedid2qid) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /data_prep/get_ance_ranking.sh: -------------------------------------------------------------------------------- 1 | # Run initial retrieval on all the datasets. 2 | 3 | repo_dir=$(builtin cd ..; pwd) 4 | data_dir=${repo_dir}/data 5 | model_dir=${repo_dir}/model 6 | 7 | pids="" 8 | 9 | dataset="marco" 10 | echo "Getting initial ANCE ranking on ${dataset} train set..." 11 | python -u get_ance_ranking.py \ 12 | --processed_data_dir ${data_dir}/${dataset}_preprocessed \ 13 | --ance_checkpoint_path ${data_dir}/${dataset}_output \ 14 | --dataset ${dataset} \ 15 | --mode "train" & 16 | pids="$pids $!" 17 | 18 | dataset_array=( 19 | marco 20 | trec19psg 21 | trec20psg 22 | dlhard 23 | ) 24 | 25 | for dataset in "${dataset_array[@]}"; do 26 | echo "Getting initial ANCE ranking on ${dataset} dev set..." 27 | python -u get_ance_ranking.py \ 28 | --processed_data_dir ${data_dir}/${dataset}_preprocessed \ 29 | --ance_checkpoint_path ${data_dir}/${dataset}_output \ 30 | --dataset ${dataset} \ 31 | --mode "dev" & 32 | pids="$pids $!" 33 | done 34 | 35 | for pid in $pids; do 36 | wait $pid 37 | done 38 | echo "Initial ranking on all datasets done." 
39 | 40 | 41 | -------------------------------------------------------------------------------- /get_eval_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH -N 1 3 | #SBATCH -n 1 4 | #SBATCH --mem=60000 # Memory - Use up to 40G 5 | #SBATCH --time=0 # No time limit 6 | #SBATCH --mail-user=hongqiay@andrew.cmu.edu 7 | #SBATCH --mail-type=END 8 | #SBATCH -p gpu 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --nodelist=boston-2-35 11 | 12 | source activate ance 13 | 14 | lr=1e-5 15 | save_steps=10000 16 | prev_ckpt=0 17 | end_eval_ckpt=450000 18 | num_feedbacks=3 19 | dataset="marco" 20 | eval_mode="rerank" 21 | python get_eval_metrics.py \ 22 | --eval \ 23 | --first_stage_inn_path "/bos/usr0/hongqiay/ANCE/results/ance_fullrank_top1K_${dataset}" \ 24 | --output_dir "/bos/tmp10/hongqiay/prf-query-encoder/models/k_${num_feedbacks}" \ 25 | --save_steps ${save_steps} \ 26 | --eval_mode ${eval_mode} \ 27 | --prev_evaluated_ckpt ${prev_ckpt} \ 28 | --ance_checkpoint_path "/bos/tmp10/hongqiay/ance/${dataset}_output/" \ 29 | --preprocessed_dir "/bos/tmp10/hongqiay/ance/${dataset}_preprocessed" \ 30 | --dev_data_dir "/bos/tmp10/hongqiay/prf-query-encoder/ance_format_data/prf_encoder_dev_${dataset}" \ 31 | --end_eval_ckpt ${end_eval_ckpt} \ 32 | --dataset ${dataset} \ 33 | --per_gpu_eval_batch_size 150 -------------------------------------------------------------------------------- /data_prep/get_prf_data.sh: -------------------------------------------------------------------------------- 1 | # Generates PRF data for all datasets. 2 | 3 | data_dir=$(builtin cd ../data; pwd) 4 | 5 | pids="" 6 | 7 | echo "Generating PRF training data from ANCE top ranking on MARCO training set..." 8 | dataset="marco" 9 | mode="train" 10 | python -u get_prf_data.py \ 11 | --processed_data_dir ${data_dir}/${dataset}_preprocessed \ 12 | --output_dir ${data_dir}/${dataset}_output \ 13 | --inn_path ${data_dir}/${dataset}_output/${dataset}_${mode}_I.npy \ 14 | --ance_checkpoint_path ${data_dir}/${dataset}_output \ 15 | --dataset ${dataset} \ 16 | --mode train & 17 | pids="$pids $!" 18 | 19 | 20 | dataset_array=( 21 | marco 22 | trec19psg 23 | trec20psg 24 | dlhard 25 | ) 26 | 27 | mode="dev" 28 | for dataset in "${dataset_array[@]}"; do 29 | echo "Generating PRF dev data from ANCE top ranking on ${dataset} dev set..." 30 | python -u get_prf_data.py \ 31 | --processed_data_dir ${data_dir}/${dataset}_preprocessed \ 32 | --output_dir ${data_dir}/${dataset}_output \ 33 | --inn_path ${data_dir}/${dataset}_output/${dataset}_${mode}_I.npy \ 34 | --ance_checkpoint_path ${data_dir}/${dataset}_output \ 35 | --dataset ${dataset} \ 36 | --mode dev & 37 | pids="$pids $!" 38 | done 39 | 40 | for pid in $pids; do 41 | wait $pid 42 | done 43 | echo "Generated PRF data for all datasets." 
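# Outputs land in data/<dataset>_output: prf_train.tsv / prf_dev.tsv plus pickled
# passage-embedding files; the training set is additionally split into
# --ann_chunk_factor (default 100) chunks by get_prf_data.py.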
-------------------------------------------------------------------------------- /data_prep/get_train_qids.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import csv 4 | import json 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--raw_data_dir", default=None, help="Path to MARCO raw data.") 10 | return parser.parse_args() 11 | 12 | 13 | def add_train_query_ids(args): 14 | in_path = os.path.join(f"{args.raw_data_dir}", "triples.train.small.tsv") 15 | out_path = os.path.join(f"{args.raw_data_dir}", "id.triples.train.small.tsv") 16 | 17 | query2qid = dict() 18 | with open(in_path, "r") as fin: 19 | reader = csv.reader(fin, delimiter="\t", quoting=csv.QUOTE_MINIMAL) 20 | print("Reading training queries...") 21 | for row in reader: 22 | query = row[0] 23 | if query not in query2qid: 24 | query2qid[query] = len(query2qid) 25 | fin.seek(0) 26 | print("Writing qids to CSV format...") 27 | with open(out_path, "w") as fout: 28 | writer = csv.writer(fout, delimiter="\t", quoting=csv.QUOTE_MINIMAL) 29 | for row in reader: 30 | query = row[0] 31 | writer.writerow([query2qid[query]] + row) 32 | print("Dumping to json...") 33 | with open(os.path.join(f"{args.raw_data_dir}", "train.query2qid.json"), "w") as f: 34 | json.dump(query2qid, f) 35 | 36 | if __name__ == '__main__': 37 | args = parse_args() 38 | add_train_query_ids(args) -------------------------------------------------------------------------------- /data_prep/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | # def load_embeddings(args, checkpoint=0): 6 | # split = args.split 7 | # passage_embedding = [] 8 | # passage_embedding2id = [] 9 | 10 | # with open(os.path.join(args.output_dir, f"query_embeds_split{split}.npy"), "rb") as f: 11 | # query_embedding = np.load(f, allow_pickle=True) 12 | # with open(os.path.join(args.output_dir, f"embedid2qid{split}.pkl"), "rb") as f: 13 | # query_embedid2qid = pickle.load(f) 14 | 15 | # for i in range(8): 16 | # try: 17 | # with open(os.path.join(args.output_dir, "passage_" + str(checkpoint) + "__emb_p__data_obj_" + str(i) + ".pb"), 'rb') as handle: 18 | # passage_embedding.append(np.load(handle)) 19 | # with open(os.path.join(args.output_dir, "passage_" + str(checkpoint) + "__embid_p__data_obj_" + str(i) + ".pb"), 'rb') as handle: 20 | # passage_embedding2id.append(np.load(handle)) 21 | # except: 22 | # print(f"Loaded {i} passage embedding splits.") 23 | # break 24 | 25 | # passage_embedding = np.concatenate(passage_embedding, axis=0) 26 | # passage_embedding2id = np.concatenate(passage_embedding2id, axis=0) 27 | # return query_embedding, query_embedid2qid, passage_embedding, passage_embedding2id 28 | 29 | 30 | def offset_to_orig_id(orig2offset): 31 | offset2orig = dict() 32 | for k, v in orig2offset.items(): 33 | offset2orig[v] = k 34 | return offset2orig -------------------------------------------------------------------------------- /data_prep/preprocess_data.sh: -------------------------------------------------------------------------------- 1 | # This script does two things: 2 | # 1. 
Multi-process tokenization of the passages & queries, and save the tokenized data in binary format: 3 | # (1) Binary passage tokens are saved in `passages` 4 | # (2) Binary query tokens are saved in `train-query` and `dev-query` 5 | # Note that files suffixed with `_split*` are simply unmerged segments of (1) & (2). 6 | # 2. Save pid2offset & qid2offset since multiprocessing messed up the passage & query order. 7 | # pid & qid are the ids from the original dataset. offsets are the orders of the passages/queries in the multiprocessed binary. 8 | 9 | 10 | dataset_array=( 11 | marco 12 | trec19psg 13 | trec20psg 14 | dlhard 15 | ) 16 | data_dir=$(builtin cd ../data; pwd) 17 | 18 | for dataset in "${dataset_array[@]}"; do 19 | echo "Preprocessing ${dataset} data..." 20 | python preprocess_data.py \ 21 | --dataset ${dataset} \ 22 | --data_dir ${data_dir}/${dataset}_raw_data \ 23 | --out_data_dir ${data_dir}/${dataset}_preprocessed/ \ 24 | --model_type rdot_nll \ 25 | --model_name_or_path roberta-base \ 26 | --max_seq_length 512 \ 27 | --data_type 1 28 | if [ "$dataset" != "marco" ]; then 29 | # create soft-link to use preprocessed passage data from marco 30 | ln -s ${data_dir}/marco_preprocessed/passage* ${data_dir}/${dataset}_preprocessed/ 31 | ln -s ${data_dir}/marco_preprocessed/pid2offset.pickle ${data_dir}/${dataset}_preprocessed/pid2offset.pickle 32 | fi 33 | done 34 | 35 | echo "Finished preprocessing data." -------------------------------------------------------------------------------- /data_prep/get_ance_embs.sh: -------------------------------------------------------------------------------- 1 | # This script uses ANCE FirstP model to generate 2 | # marco passage embeddings, marco training query embeddings, and dev query embeddings for all the datasets. 3 | 4 | gpu_no=${1:-4} 5 | 6 | repo_dir=$(builtin cd ..; pwd) 7 | model_dir=${repo_dir}/model 8 | data_dir=${repo_dir}/data 9 | mkdir -p ${model_dir} 10 | 11 | echo "Downloading ANCE model..." 12 | cd ${model_dir} 13 | wget https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Passage_ANCE_FirstP_Checkpoint.zip 14 | unzip Passage_ANCE_FirstP_Checkpoint.zip 15 | mv "Passage ANCE(FirstP) Checkpoint" ance_firstp 16 | rm Passage_ANCE_FirstP_Checkpoint.zip 17 | 18 | echo "Encoding passages using ANCE model..." 19 | cd ${repo_dir}/data_prep 20 | 21 | dataset_array=( 22 | marco 23 | trec19psg 24 | trec20psg 25 | dlhard 26 | ) 27 | 28 | for dataset in "${dataset_array[@]}"; do 29 | echo "Generating ANCE embeddings for dataset ${dataset}..." 
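    # run_ann_data_gen.py is adapted from the ANCE codebase; with --inference it is
    # used here only to dump query/passage embeddings into ${dataset}_output,
    # sharded across the ${gpu_no} distributed workers.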
30 |     python -m torch.distributed.launch --nproc_per_node=$gpu_no run_ann_data_gen.py \
31 |         --dataset ${dataset} \
32 |         --init_model_dir ${model_dir}/ance_firstp \
33 |         --model_type rdot_nll \
34 |         --output_dir ${data_dir}/${dataset}_output \
35 |         --cache_dir ${data_dir}/${dataset}_cache \
36 |         --data_dir ${data_dir}/${dataset}_preprocessed \
37 |         --max_seq_length 512 \
38 |         --per_gpu_eval_batch_size 64 \
39 |         --topk_training 200 \
40 |         --negative_sample 20 \
41 |         --end_output_num 0 \
42 |         --inference
43 |     if [ "$dataset" != "marco" ]; then
44 |         ln -s ${data_dir}/marco_output/passage* ${data_dir}/${dataset}_output/
45 |     fi
46 | done
--------------------------------------------------------------------------------
/data_prep/get_ance_ranking.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
4 | import argparse
5 | import numpy as np
6 | import csv
7 | import faiss
8 | import pickle
9 | from utils.util import *
10 | # from data_utils import *
11 | 
12 | def parse_args():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("--processed_data_dir", default=None, help="Path to _preprocessed folder.")
15 |     parser.add_argument("--ance_checkpoint_path", default=None, help="Path to _output.")
16 |     parser.add_argument("--dataset", default=None, help="Name of the dataset.")
17 |     parser.add_argument("--mode", choices=["train", "dev"], required=True)
18 |     return parser.parse_args()
19 | 
20 | def ance_full_rank(args, query_embedding, passage_embedding, topN=1000):
21 |     dim = passage_embedding.shape[1]
22 |     faiss.omp_set_num_threads(16)
23 |     cpu_index = faiss.IndexFlatIP(dim)
24 |     cpu_index.add(passage_embedding)
25 |     print(f"Starting CPU search on {args.dataset} {args.mode} data...")
26 |     _, I = cpu_index.search(query_embedding, topN)
27 |     print(f"Finished CPU search on {args.dataset} {args.mode} data.")
28 |     with open(os.path.join(args.ance_checkpoint_path, f"{args.dataset}_{args.mode}_I.npy"), "wb") as f:
29 |         np.save(f, I)
30 |     return I
31 | 
32 | 
33 | if __name__ == '__main__':
34 |     args = parse_args()
35 |     query_embedding, query_embedding2id, passage_embedding, passage_embedding2id = load_embeddings(args, args.mode)
36 |     ance_full_rank(args, query_embedding, passage_embedding)
37 |     if args.mode == "dev":
38 |         query_embedding2qid = get_embedding2qid(args)
39 |         devI_path = os.path.join(args.ance_checkpoint_path, f"{args.dataset}_{args.mode}_I.npy")
40 |         devI_to_tein(args, query_embedding2qid, devI_path)
41 | 
42 |     print(f"Initial ranking on {args.dataset} {args.mode} data done.")
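# Note: IndexFlatIP above does an exact (brute-force) inner-product search, so the
# whole passage matrix must fit in RAM; for MS MARCO's ~8.8M passages at 768
# float32 dimensions that is roughly 27GB.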
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Query Representations for Dense Retrieval with Pseudo Relevance Feedback
2 | 
3 | HongChien Yu, Chenyan Xiong, Jamie Callan
4 | 
5 | This repository holds code that reproduces results reported in
6 | [Improving Query Representations for Dense Retrieval with Pseudo Relevance Feedback](https://arxiv.org/abs/2108.13454).
7 | 
8 | 
9 | Dense retrieval systems conduct first-stage retrieval using embedded representations and simple similarity metrics to
10 | match a query to documents. Their effectiveness depends on the encoded embeddings capturing the semantics of queries and
11 | documents, a challenging task due to the shortness and ambiguity of search queries. This paper proposes ANCE-PRF,
12 | a new query encoder that uses pseudo relevance feedback (PRF) to improve query representations for dense retrieval.
13 | ANCE-PRF uses a BERT encoder that consumes the query and the top retrieved documents from a dense retrieval model,
14 | ANCE, and it learns to produce better query embeddings directly from relevance labels.
15 | It also keeps the document index unchanged to reduce overhead. ANCE-PRF significantly outperforms ANCE and other recent
16 | dense retrieval systems on several datasets. Analysis shows that the PRF encoder effectively captures the relevant and
17 | complementary information from PRF documents, while ignoring the noise with its learned attention mechanism.
18 | 
19 | ## Requirements
20 | ```bash
21 | pip install -r requirements.txt
22 | ```
23 | ## Data Preparation
24 | 
25 | ### Data Preprocessing
26 | Run the following script to preprocess data:
27 | ```bash
28 | cd data_prep
29 | bash download_data.sh
30 | bash preprocess_data.sh
31 | ```
32 | 
33 | 
34 | ### Get ANCE Passage Embeddings
35 | ```bash
36 | cd data_prep
37 | bash get_ance_embs.sh
38 | ```
39 | 
40 | 
41 | ### Get ANCE Ranking
42 | ```bash
43 | bash get_ance_ranking.sh
44 | ```
45 | 
46 | ### Prepare PRF data
47 | Run the following command to create PRF data from ANCE top-retrieved documents:
48 | ```bash
49 | cd data_prep
50 | bash get_prf_data.sh
51 | ```
52 | 
53 | ## Training
54 | ```bash
55 | bash train_encoder.sh
56 | ```
57 | While training is running, concurrently run
58 | ```bash
59 | bash eval.sh
60 | ```
61 | which keeps looking for the newest checkpoint and evaluates it on the MS MARCO dev set.
62 | Admittedly this is not an efficient use of the GPU in terms of utilization, but it speeds up training by avoiding periodic switches between training and evaluation.
63 | 
64 | In our work, we picked the checkpoint that performed best on the MS MARCO dev set, as reported to TensorBoard by `eval.sh`.
65 | 
66 | 
67 | ## Trained Models and Ranking Files
68 | Trained models for k=3 can be downloaded [here](https://drive.google.com/file/d/1xbMgP0Z5tuoqymbWUhfuvRvUx6TvNuVw/view?usp=sharing).
69 | 
70 | Ranking files for k=3 can be downloaded [here](https://drive.google.com/drive/folders/1FybKqWbE1Ap1xDd8MR01ZOqXn9W0Xy8b?usp=sharing).
71 | 
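## What the PRF Encoder Consumes

For quick reference, ANCE-PRF re-encodes each query together with its top-k ANCE-retrieved passages as a single RoBERTa input. The sketch below is illustrative only: the exact packing and truncation logic lives in `data.py` (not shown above), and `build_prf_input` is a name introduced here for illustration.

```python
from transformers import RobertaTokenizer

def build_prf_input(query, feedback_passages, tokenizer, max_seq_length=512):
    # query </s> psg_1 </s> ... </s> psg_k, padded/truncated to max_seq_length
    text = query
    for psg in feedback_passages:
        text += tokenizer.sep_token + psg
    return tokenizer.encode_plus(text, max_length=max_seq_length,
                                 pad_to_max_length=True, return_tensors="pt")

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
inputs = build_prf_input("what is pseudo relevance feedback",
                         ["feedback passage 1", "feedback passage 2", "feedback passage 3"],
                         tokenizer)
print(inputs["input_ids"].shape)  # torch.Size([1, 512])
```

At training time the encoder's output embedding is scored against the positive and negative ANCE passage embeddings with the NLL loss in `model.py`.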
--------------------------------------------------------------------------------
/data_prep/download_data.sh:
--------------------------------------------------------------------------------
1 | # This script downloads all data used in the paper.
2 | # It also renames the data files using the same format,
3 | # so that the same preprocessing pipeline can be applied more easily.
4 | 
5 | repo_dir=$(pwd)/..
6 | data_dir=${repo_dir}/data
7 | mkdir -p ${data_dir}
8 | 
9 | echo "Downloading MS MARCO passage data..."
10 | dataset="marco"
11 | raw_data_dir=${data_dir}/${dataset}_raw_data
12 | mkdir -p ${raw_data_dir}
13 | cd ${raw_data_dir}
14 | 
15 | wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
16 | tar -zxvf collectionandqueries.tar.gz
17 | rm collectionandqueries.tar.gz
18 | 
19 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz
20 | gunzip msmarco-passagetest2019-top1000.tsv.gz
21 | 
22 | wget https://msmarco.blob.core.windows.net/msmarcoranking/top1000.dev.tar.gz
23 | tar -zxvf top1000.dev.tar.gz
24 | mv top1000.dev top1000.dev.tsv
25 | rm top1000.dev.tar.gz
26 | 
27 | wget https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz
28 | tar -zxvf triples.train.small.tar.gz
29 | rm triples.train.small.tar.gz
30 | 
31 | echo "Downloading TREC DL 2019 passage data..."
32 | dataset="trec19psg"
33 | raw_data_dir=${data_dir}/${dataset}_raw_data
34 | mkdir -p ${raw_data_dir}
35 | cd ${raw_data_dir}
36 | 
37 | # trec dl 19 shares marco corpus
38 | ln -s ${data_dir}/marco_raw_data/collection.tsv .
39 | 
40 | wget --no-check-certificate https://trec.nist.gov/data/deep/2019qrels-pass.txt -O qrels.dev.small.tsv
41 | 
42 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
43 | gunzip msmarco-test2019-queries.tsv.gz
44 | mv msmarco-test2019-queries.tsv queries.dev.small.tsv
45 | 
46 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz
47 | gunzip msmarco-passagetest2019-top1000.tsv.gz
48 | mv msmarco-passagetest2019-top1000.tsv top1000.dev.tsv
49 | 
50 | 
51 | echo "Downloading TREC DL 2020 passage data..."
52 | dataset="trec20psg"
53 | raw_data_dir=${data_dir}/${dataset}_raw_data
54 | mkdir -p ${raw_data_dir}
55 | cd ${raw_data_dir}
56 | 
57 | # trec dl 20 shares marco corpus
58 | ln -s ${data_dir}/marco_raw_data/collection.tsv .
59 | 
60 | wget --no-check-certificate https://trec.nist.gov/data/deep/2020qrels-pass.txt -O qrels.dev.small.tsv
61 | 
62 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz
63 | gunzip msmarco-test2020-queries.tsv.gz
64 | mv msmarco-test2020-queries.tsv queries.dev.small.tsv
65 | 
66 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2020-top1000.tsv.gz
67 | gunzip msmarco-passagetest2020-top1000.tsv.gz
68 | mv msmarco-passagetest2020-top1000.tsv top1000.dev.tsv
69 | 
70 | 
71 | echo "Downloading DL-HARD passage data..."
72 | # Using the files from this commit: https://github.com/grill-lab/DL-Hard/commit/c58ce8d9e8932a7b560c0f2cb3435b6c2db578fe
73 | 
74 | dataset="dlhard"
75 | raw_data_dir=${data_dir}/${dataset}_raw_data
76 | mkdir -p ${raw_data_dir}
77 | cd ${raw_data_dir}
78 | 
79 | # dl-hard shares marco corpus
80 | ln -s ${data_dir}/marco_raw_data/collection.tsv .
81 | wget https://raw.githubusercontent.com/grill-lab/DL-Hard/main/dataset/dl_hard-passage.qrels -O qrels.dev.small.tsv
82 | wget https://raw.githubusercontent.com/grill-lab/DL-Hard/main/dataset/topics.tsv -O queries.dev.small.tsv
83 | wget https://raw.githubusercontent.com/grill-lab/DL-Hard/main/dataset/baselines/passage/bm25.run -O top1000.dev.tsv
84 | 
85 | echo "Finished downloading data."
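# After this script runs, data/ holds one <dataset>_raw_data directory per dataset
# (marco, trec19psg, trec20psg, dlhard), each normalized to the same file names:
# collection.tsv, queries.dev.small.tsv, qrels.dev.small.tsv and top1000.dev.tsv
# (marco additionally has triples.train.small.tsv).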
86 | -------------------------------------------------------------------------------- /data_prep/get_train_query_embeds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) 4 | import torch 5 | import argparse 6 | import numpy as np 7 | import csv 8 | import faiss 9 | import json 10 | import pickle 11 | from model import RobertaDot_NLL_LN 12 | from transformers import RobertaTokenizer, RobertaConfig 13 | from tqdm import tqdm 14 | from data_utils import * 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--raw_data_dir", default=None, help="Path to MARCO raw data.") 19 | parser.add_argument("--output_dir", default=None, help="Path to save the embeddings.") 20 | parser.add_argument("--model_name_or_path", default=None, help="Path to the ANCE first-p model.") 21 | parser.add_argument("--split", type=int) 22 | parser.add_argument("--chunk_size", type=int, default=10000) 23 | 24 | return parser.parse_args() 25 | 26 | 27 | def get_train_query_embeds(args): 28 | qid2query_path=os.path.join(f"{args.raw_data_dir}", "train.query2qid.json") 29 | split, chunk_size = args.split, args.chunk_size 30 | config = RobertaConfig.from_pretrained(args.model_name_or_path, num_labels=2, finetuning_task="MSMarco") 31 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=True) 32 | qry_encoder = RobertaDot_NLL_LN.from_pretrained(args.model_name_or_path, 33 | from_tf=bool(".ckpt" in args.model_name_or_path), 34 | config=config).eval() 35 | 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | qry_encoder.to(device) 38 | qry_encoder = qry_encoder.query_emb 39 | 40 | idx = 0 41 | with open(qid2query_path, "r") as f: 42 | queries = json.load(f) 43 | query_embeds = dict() 44 | for query, qid in tqdm(queries.items(), desc="Encode Query"): 45 | if idx < split * chunk_size or idx >= (split + 1) * chunk_size: 46 | idx += 1 47 | continue 48 | tokenized = tokenizer.encode_plus(query, return_tensors="pt", max_length=512) 49 | query_input_ids = tokenized["input_ids"].to(device) 50 | attention_mask = tokenized["attention_mask"].to(device) 51 | query_embed = qry_encoder(input_ids=query_input_ids, attention_mask=attention_mask).detach().cpu().numpy() 52 | query_embeds[qid] = query_embed 53 | idx += 1 54 | with open(os.path.join(args.output_dir, f"query_embeds_split{split}.pkl"), "wb") as fout: 55 | pickle.dump(query_embeds, fout) 56 | 57 | 58 | def get_np_query_embeds(args): 59 | embedid2qid = [] 60 | split = args.split 61 | with open(os.path.join(args.output_dir, f"query_embeds_split{split}.pkl"), "rb") as f: 62 | data = pickle.load(f) 63 | embeds = [] 64 | for k, v in data.items(): 65 | embeds.append(v) 66 | embedid2qid.append(k) 67 | with open(os.path.join(args.output_dir, f"query_embeds_split{split}.npy"), "wb") as fnp: 68 | embeds = np.concatenate(embeds) 69 | np.save(fnp, embeds) 70 | with open(os.path.join(args.output_dir, f"embedid2qid{split}.pkl"), "wb") as fout: 71 | pickle.dump(embedid2qid, fout) 72 | 73 | if __name__ == '__main__': 74 | args = parse_args() 75 | os.makedirs(args.output_dir, exist_ok=True) 76 | get_train_query_embeds(args) 77 | get_np_query_embeds(args) 78 | print(f"Done generating train query embeddings for split {args.split}.") -------------------------------------------------------------------------------- /model.py: 
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(os.path.dirname(__file__)))
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from transformers import RobertaForSequenceClassification, RobertaTokenizer, RobertaConfig
8 | 
9 | class EmbeddingMixin:
10 |     """
11 |     Mixin for common functions in most embedding models. Each model should define its own bert-like backbone and forward.
12 |     We inherit from RobertaModel to use from_pretrained
13 |     """
14 |     def __init__(self, model_argobj):
15 |         if model_argobj is None:
16 |             self.use_mean = False
17 |         else:
18 |             self.use_mean = model_argobj.use_mean
19 |         print("Using mean:", self.use_mean)
20 | 
21 |     def _init_weights(self, module):
22 |         """ Initialize the weights """
23 |         if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
24 |             # Slightly different from the TF version which uses truncated_normal for initialization
25 |             # cf https://github.com/pytorch/pytorch/pull/5617
26 |             module.weight.data.normal_(mean=0.0, std=0.02)
27 | 
28 |     def masked_mean(self, t, mask):
29 |         s = torch.sum(t * mask.unsqueeze(-1).float(), axis=1)
30 |         d = mask.sum(axis=1, keepdim=True).float()
31 |         return s / d
32 | 
33 |     def masked_mean_or_first(self, emb_all, mask):
34 |         # emb_all is a tuple from bert - sequence output, pooler (in old versions only)
35 |         # assert isinstance(emb_all, tuple)
36 |         if self.use_mean:
37 |             return self.masked_mean(emb_all[0], mask)
38 |         else:
39 |             return emb_all[0][:, 0]
40 | 
41 |     def query_emb(self, input_ids, attention_mask):
42 |         raise NotImplementedError("Please Implement this method")
43 | 
44 |     def body_emb(self, input_ids, attention_mask):
45 |         raise NotImplementedError("Please Implement this method")
46 | 
47 | 
48 | class NLL(EmbeddingMixin):
49 |     def forward(self, input_ids, attention_mask, pos_emb, neg_emb):
50 |         # pdb.set_trace()
51 |         q_embs = self.query_emb(input_ids, attention_mask)
52 | 
53 |         logit_matrix = torch.cat([(q_embs * pos_emb).sum(-1).unsqueeze(1),
54 |                                   (q_embs * neg_emb).sum(-1).unsqueeze(1)], dim=1)  # [B, 2]
55 |         lsm = F.log_softmax(logit_matrix, dim=1)
56 |         loss = -1.0 * lsm[:, 0]
57 |         return (loss.mean(),)
58 | 
59 | 
60 | class RobertaDot_NLL_LN(NLL, RobertaForSequenceClassification):
61 |     """
62 |     Projects the [CLS] embedding through a 768-d linear head with LayerNorm, then computes NLL loss.
63 | """ 64 | 65 | def __init__(self, config, model_argobj=None): 66 | NLL.__init__(self, model_argobj) 67 | RobertaForSequenceClassification.__init__(self, config) 68 | self.embeddingHead = nn.Linear(config.hidden_size, 768) 69 | self.norm = nn.LayerNorm(768) 70 | self.apply(self._init_weights) 71 | 72 | def query_emb(self, input_ids, attention_mask): 73 | outputs1 = self.roberta(input_ids=input_ids, 74 | attention_mask=attention_mask) 75 | full_emb = self.masked_mean_or_first(outputs1, attention_mask) 76 | query1 = self.norm(self.embeddingHead(full_emb)) 77 | return query1 78 | 79 | def body_emb(self, input_ids, attention_mask): 80 | return self.query_emb(input_ids, attention_mask) 81 | 82 | class MSMarcoConfig: 83 | def __init__(self, name, model, process_fn, use_mean=True, tokenizer_class=RobertaTokenizer, config_class=RobertaConfig): 84 | self.name = name 85 | self.process_fn = process_fn 86 | self.model_class = model 87 | self.use_mean = use_mean 88 | self.tokenizer_class = tokenizer_class 89 | self.config_class = config_class 90 | 91 | -------------------------------------------------------------------------------- /lamb.py: -------------------------------------------------------------------------------- 1 | """Lamb optimizer.""" 2 | 3 | import collections 4 | import torch 5 | from tensorboardX import SummaryWriter 6 | from torch.optim import Optimizer 7 | 8 | 9 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 10 | """Log a histogram of trust ratio scalars in across layers.""" 11 | results = collections.defaultdict(list) 12 | for group in optimizer.param_groups: 13 | for p in group['params']: 14 | state = optimizer.state[p] 15 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 16 | if i in state: 17 | results[i].append(state[i]) 18 | 19 | for k, v in results.items(): 20 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 21 | 22 | 23 | class Lamb(Optimizer): 24 | r"""Implements Lamb algorithm. 25 | 26 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 27 | 28 | Arguments: 29 | params (iterable): iterable of parameters to optimize or dicts defining 30 | parameter groups 31 | lr (float, optional): learning rate (default: 1e-3) 32 | betas (Tuple[float, float], optional): coefficients used for computing 33 | running averages of gradient and its square (default: (0.9, 0.999)) 34 | eps (float, optional): term added to the denominator to improve 35 | numerical stability (default: 1e-8) 36 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 37 | adam (bool, optional): always use trust ratio = 1, which turns this into 38 | Adam. Useful for comparison purposes. 39 | 40 | .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
41 |         https://arxiv.org/abs/1904.00962
42 |     """
43 | 
44 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
45 |                  weight_decay=0, adam=False):
46 |         if not 0.0 <= lr:
47 |             raise ValueError("Invalid learning rate: {}".format(lr))
48 |         if not 0.0 <= eps:
49 |             raise ValueError("Invalid epsilon value: {}".format(eps))
50 |         if not 0.0 <= betas[0] < 1.0:
51 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
52 |         if not 0.0 <= betas[1] < 1.0:
53 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
54 |         defaults = dict(lr=lr, betas=betas, eps=eps,
55 |                         weight_decay=weight_decay)
56 |         self.adam = adam
57 |         super(Lamb, self).__init__(params, defaults)
58 | 
59 |     def step(self, closure=None):
60 |         """Performs a single optimization step.
61 | 
62 |         Arguments:
63 |             closure (callable, optional): A closure that reevaluates the model
64 |                 and returns the loss.
65 |         """
66 |         loss = None
67 |         if closure is not None:
68 |             loss = closure()
69 | 
70 |         for group in self.param_groups:
71 |             for p in group['params']:
72 |                 if p.grad is None:
73 |                     continue
74 |                 grad = p.grad.data
75 |                 if grad.is_sparse:
76 |                     raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
77 | 
78 |                 state = self.state[p]
79 | 
80 |                 # State initialization
81 |                 if len(state) == 0:
82 |                     state['step'] = 0
83 |                     # Exponential moving average of gradient values
84 |                     state['exp_avg'] = torch.zeros_like(p.data)
85 |                     # Exponential moving average of squared gradient values
86 |                     state['exp_avg_sq'] = torch.zeros_like(p.data)
87 | 
88 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
89 |                 beta1, beta2 = group['betas']
90 | 
91 |                 state['step'] += 1
92 | 
93 |                 # Decay the first and second moment running average coefficient
94 |                 # m_t
95 |                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
96 |                 # v_t
97 |                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
98 | 
99 |                 # Paper v3 does not use debiasing.
100 |                 # Apply bias to lr to avoid broadcast.
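                # Layer-wise adaptation (below): each parameter's Adam-style update is
                # rescaled by trust_ratio = ||w|| / ||adam_step||, so layers whose weights
                # are large relative to their proposed update take correspondingly larger steps.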
101 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 102 | 103 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 104 | 105 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 106 | if group['weight_decay'] != 0: 107 | adam_step.add_(group['weight_decay'], p.data) 108 | 109 | adam_norm = adam_step.pow(2).sum().sqrt() 110 | if weight_norm == 0 or adam_norm == 0: 111 | trust_ratio = 1 112 | else: 113 | trust_ratio = weight_norm / adam_norm 114 | state['weight_norm'] = weight_norm 115 | state['adam_norm'] = adam_norm 116 | state['trust_ratio'] = trust_ratio 117 | if self.adam: 118 | trust_ratio = 1 119 | 120 | p.data.add_(-step_size * trust_ratio, adam_step) 121 | 122 | return loss 123 | -------------------------------------------------------------------------------- /get_eval_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | import argparse 5 | from data import * 6 | from util import * 7 | from torch.utils.data import DataLoader 8 | from runner import Trainer, Evaluator 9 | import numpy as np 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--first_stage_inn_path", default="/bos/usr0/hongqiay/ANCE/results/ance_fullrank_top1K_marco") 17 | parser.add_argument("--eval_mode", choices=["rerank", "full"], default="rerank") 18 | parser.add_argument("--ann_chunk_factor", type=int, default=100) 19 | parser.add_argument("--ance_checkpoint_path", default="/bos/tmp10/hongqiay/ance/marco_output_np/", 20 | help="location for dumpped query and passage/document embeddings which is output_dir") 21 | parser.add_argument("--train_data_dir", 22 | default="/bos/tmp10/hongqiay/prf-query-encoder/ance_format_data/prf_encoder_100splits") 23 | parser.add_argument("--dev_data_dir", 24 | default="/bos/tmp10/hongqiay/prf-query-encoder/ance_format_data/prf_encoder_dev") 25 | parser.add_argument("--max_seq_length", type=int, default=512) 26 | parser.add_argument("--optimizer", choices=["lamb", "adamw"], default="lamb") 27 | parser.add_argument("--num_queries", type=int, default=502939) 28 | parser.add_argument("--num_feedbacks", type=int, default=3) 29 | parser.add_argument("--preprocessed_dir", default="/bos/tmp10/hongqiay/ance/marco_preprocessed") 30 | parser.add_argument("--train", action="store_true") 31 | parser.add_argument("--eval", action="store_true") 32 | parser.add_argument("--local_rank", type=int, default=-1) 33 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 34 | help="Batch size per GPU/CPU for training.") 35 | parser.add_argument("--per_gpu_eval_batch_size", default=50, type=int, 36 | help="Batch size per GPU/CPU for evaluation.") 37 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 38 | help="Number of updates steps to accumulate before performing a backward/update pass.") 39 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 40 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 41 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 42 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 43 | parser.add_argument("--max_steps", default=1000000, type=int, 44 | help="If > 0: set total number of training steps to perform") 45 | 
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 46 | parser.add_argument("--model_name_or_path", default="roberta-base", type=str) 47 | parser.add_argument("--load_optimizer_scheduler", default=False, action="store_true", 48 | help="load scheduler from checkpoint or not") 49 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 50 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 51 | parser.add_argument("--seed", type=int, default=42) 52 | parser.add_argument("--cache_dir", default="", type=str, 53 | help="Where do you want to store the pre-trained models downloaded from s3") 54 | parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.") 55 | parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") 56 | parser.add_argument("--save_steps", type=int, default=2000, help="Save checkpoint every X updates steps.") 57 | parser.add_argument("--output_dir") 58 | parser.add_argument("--prev_evaluated_ckpt", type=int, default=0, help="Start evaluating ckpt after this step.") 59 | parser.add_argument("--end_eval_ckpt", type=int, default=1000000, help="Start evaluating ckpt after this step.") 60 | parser.add_argument("--dataset", choices=["marco", "marco_eval", "trec19psg", "trec20psg"], default="marco") 61 | 62 | return parser.parse_args() 63 | 64 | 65 | def main(): 66 | args = parse_args() 67 | if is_first_worker() and args.train: 68 | os.makedirs(args.output_dir, exist_ok=True) 69 | with open(os.path.join(args.output_dir, "args"), "w") as f: 70 | f.write(str(args)) 71 | set_env(args) 72 | tokenizer, model = load_model(args) 73 | 74 | passage_collection_path = os.path.join(args.preprocessed_dir, "passages") 75 | passage_cache = EmbeddingCache(passage_collection_path) 76 | 77 | rerank_depths = None if args.eval_mode == "full" else [100] 78 | query_collection_path = os.path.join(args.preprocessed_dir, "dev-query") 79 | query_cache = EmbeddingCache(query_collection_path) 80 | with open(os.path.join(args.dev_data_dir, f"qry_encoder_dev_data_full")) as f: 81 | all_lines = f.readlines() 82 | 83 | results = [] 84 | with query_cache, passage_cache: 85 | dev_dataset = StreamingDataset(args, all_lines, 86 | GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache, 87 | num_feedbacks=args.num_feedbacks)) 88 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 89 | dev_dataloader = DataLoader(dev_dataset, batch_size=args.eval_batch_size) 90 | with torch.no_grad(): 91 | evaluator = Evaluator(args, dev_dataloader) 92 | evaluator.eval(rerank_depths=rerank_depths, mode=args.eval_mode, results=results) 93 | np.save(os.path.join(args.output_dir, "metrics.npy"), results) 94 | 95 | 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import logging 5 | import argparse 6 | from data import * 7 | from utils.util import * 8 | from torch.utils.data import DataLoader 9 | from runner import Trainer, Evaluator 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--first_stage_inn_path", default=None) 18 | parser.add_argument("--eval_mode", choices=["rerank", 
"full"], default="rerank") 19 | parser.add_argument("--ann_chunk_factor", type=int, default=100) 20 | parser.add_argument("--ance_checkpoint_path", default=None, 21 | help="location for dumpped query and passage/document embeddings which is output_dir") 22 | parser.add_argument("--train_data_dir", 23 | default=None, help="Path to training data.") 24 | parser.add_argument("--dev_data_dir", 25 | default=None, help="Path to dev data.") 26 | parser.add_argument("--max_seq_length", type=int, default=512) 27 | parser.add_argument("--optimizer", choices=["lamb", "adamw"], default="lamb") 28 | parser.add_argument("--num_queries", type=int, default=502939) 29 | parser.add_argument("--num_feedbacks", type=int, default=3) 30 | parser.add_argument("--preprocessed_dir", default=None, help="Path to [dataset]_preprocessed folder.") 31 | parser.add_argument("--train", action="store_true") 32 | parser.add_argument("--eval", action="store_true") 33 | parser.add_argument("--local_rank", type=int, default=-1) 34 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 35 | help="Batch size per GPU/CPU for training.") 36 | parser.add_argument("--per_gpu_eval_batch_size", default=50, type=int, 37 | help="Batch size per GPU/CPU for evaluation.") 38 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 39 | help="Number of updates steps to accumulate before performing a backward/update pass.") 40 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 41 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 42 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 43 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 44 | parser.add_argument("--max_steps", default=1000000, type=int, 45 | help="If > 0: set total number of training steps to perform") 46 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 47 | parser.add_argument("--model_name_or_path", default="roberta-base", type=str) 48 | parser.add_argument("--load_optimizer_scheduler", default=False, action="store_true", 49 | help="load scheduler from checkpoint or not") 50 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 51 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 52 | parser.add_argument("--seed", type=int, default=42) 53 | parser.add_argument("--cache_dir", default="", type=str, 54 | help="Where do you want to store the pre-trained models downloaded from s3") 55 | parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.") 56 | parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") 57 | parser.add_argument("--save_steps", type=int, default=2000, help="Save checkpoint every X updates steps.") 58 | parser.add_argument("--output_dir", default=None, help="Directory to save output.") 59 | parser.add_argument("--prev_evaluated_ckpt", type=int, default=None , help="Start evaluating ckpt after this step.") 60 | parser.add_argument("--dataset", choices=["marco", "marco_eval", "trec19psg", "trec20psg", "dlhard"], default="marco") 61 | parser.add_argument("--end_eval_ckpt", type=int, default=1000000, help="Start evaluating ckpt after this step.") 62 | 63 | return parser.parse_args() 64 | 65 | 66 | def 
main():
67 |     args = parse_args()
68 |     if is_first_worker() and args.train:
69 |         os.makedirs(args.output_dir, exist_ok=True)
70 |         with open(os.path.join(args.output_dir, "args"), "w") as f:
71 |             f.write(str(args))
72 |     set_env(args)
73 |     tokenizer, model = load_model(args)
74 | 
75 |     passage_collection_path = os.path.join(args.preprocessed_dir, "passages")
76 |     passage_cache = EmbeddingCache(passage_collection_path)
77 | 
78 |     if args.train and args.eval:
79 |         raise ValueError("train and eval are supposed to be initiated as separate tasks.")
80 | 
81 |     if args.train:
82 |         query_collection_path = os.path.join(args.preprocessed_dir, "train-query")
83 |         query_cache = EmbeddingCache(query_collection_path)
84 |         all_lines = []
85 |         with open(os.path.join(args.train_data_dir, "prf_train.tsv"), "r") as f:
86 |             all_lines = f.readlines()
87 |         with query_cache, passage_cache:
88 |             train_dataset = StreamingDataset(args, all_lines,
89 |                                              GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache,
90 |                                                                                 num_feedbacks=args.num_feedbacks))
91 |             args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
92 |             train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
93 |             trainer = Trainer(args, model, tokenizer, train_dataloader)
94 |             trainer.train()
95 | 
96 |     if args.eval:
97 |         rerank_depths = None if args.eval_mode == "full" else [20, 50, 100, 200, 500, 1000]
98 |         query_collection_path = os.path.join(args.preprocessed_dir, "dev-query")
99 |         query_cache = EmbeddingCache(query_collection_path)
100 |         with open(os.path.join(args.dev_data_dir, "prf_dev.tsv")) as f:
101 |             all_lines = f.readlines()
102 | 
103 |         with query_cache, passage_cache:
104 |             dev_dataset = StreamingDataset(args, all_lines,
105 |                                            GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache,
106 |                                                                               num_feedbacks=args.num_feedbacks))
107 |             args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
108 |             dev_dataloader = DataLoader(dev_dataset, batch_size=args.eval_batch_size)
109 |             with torch.no_grad():
110 |                 evaluator = Evaluator(args, dev_dataloader)
111 |                 evaluator.eval(rerank_depths=rerank_depths, mode=args.eval_mode)
112 | 
113 | 
114 | if __name__ == '__main__':
115 |     main()
--------------------------------------------------------------------------------
/data_prep/get_prf_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4 | import csv
5 | import random
6 | import argparse
7 | import faiss
8 | import logging
9 | import pickle
10 | import numpy as np
11 | from utils.util import load_embeddings
12 | 
13 | 
14 | logger = logging.getLogger(__name__)
15 | 
16 | 
17 | def parse_args():
18 |     parser = argparse.ArgumentParser()
19 |     parser.add_argument("--processed_data_dir", required=True)
20 |     parser.add_argument("--negative_sample", type=int, default=20, help="Number of negative documents to sample.")
21 |     parser.add_argument("--num_feedbacks", type=int, default=20, help="Number of feedback documents to save.")
22 |     parser.add_argument("--output_dir")
23 |     parser.add_argument("--inn_path", help="Path to the nearest-neighbor index array (*_I.npy).", required=True)
24 |     parser.add_argument("--ance_checkpoint_path", help="The path to the _output directory", required=True)
25 |     parser.add_argument("--mode", choices=["train", "dev"], required=True)
26 |     parser.add_argument("--dataset", default=None, help="Name of the dataset.", required=True)
27 |     parser.add_argument("--ann_chunk_factor", type=int, default=100, help="Number of chunks to split the training data into.")
28 | 
29 |     return parser.parse_args()
30 | 
31 | 
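# Output files written to --output_dir:
#   prf_{mode}.tsv       one line per query: qid \t positive pid \t
#                        comma-separated feedback pids \t comma-separated negative pids
#   psg_embeds_{mode}    pickled array of the positive + negative ANCE embeddings per query
#   prf_train_{i}.tsv / psg_embeds_train_{i}   the training data split into --ann_chunk_factor chunks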
parser.add_argument("--ann_chunk_factor", type=int, default=100, help="number of chunks to split the training data into.") 28 | 29 | return parser.parse_args() 30 | 31 | 32 | def add_psgs(res_psgs, passage_embedding2id, query_id, selected_ann_idx, num_examples): 33 | cnt = 0 34 | rank = 0 35 | for idx in selected_ann_idx: 36 | pid = passage_embedding2id[idx] 37 | rank += 1 38 | 39 | if pid in res_psgs[query_id]: 40 | continue 41 | 42 | if cnt >= num_examples: 43 | break 44 | 45 | res_psgs[query_id].append(pid) 46 | cnt += 1 47 | 48 | 49 | def generate_pids(args, query_embedding2id, passage_embedding2id, I_nearest_neighbor, effective_q_id): 50 | 51 | query_negative_passage = {} 52 | ance_top_passage = {} 53 | num_queries = 0 54 | 55 | for query_idx in range(I_nearest_neighbor.shape[0]): 56 | 57 | query_id = query_embedding2id[query_idx] 58 | 59 | if query_id not in effective_q_id: 60 | continue 61 | 62 | num_queries += 1 63 | 64 | top_ann_pid = I_nearest_neighbor[query_idx, :].copy() 65 | 66 | ance_top_passage[query_id] = [] 67 | add_psgs(ance_top_passage, passage_embedding2id, query_id, top_ann_pid, args.num_feedbacks) 68 | 69 | # Randomly sample negative 70 | negative_sample_I_idx = list(range(I_nearest_neighbor.shape[1])) 71 | random.shuffle(negative_sample_I_idx) 72 | selected_ann_idx = top_ann_pid[negative_sample_I_idx] 73 | query_negative_passage[query_id] = [] 74 | add_psgs(query_negative_passage, passage_embedding2id, query_id, selected_ann_idx, args.negative_sample) 75 | 76 | return query_negative_passage, ance_top_passage 77 | 78 | # TODO: this function is problematic! dev data sometimes have more than one rel docs 79 | def load_positive_ids(args): 80 | mode = args.mode 81 | logger.info("Loading query_2_pos_docid") 82 | query_positive_id = {} 83 | query_positive_id_path = os.path.join(args.processed_data_dir, f"{mode}-qrel.tsv") 84 | with open(query_positive_id_path, 'r', encoding='utf8') as f: 85 | tsvreader = csv.reader(f, delimiter="\t") 86 | for [topicid, docid, rel] in tsvreader: 87 | if int(rel) != 0: 88 | # assert rel == "1" 89 | topicid = int(topicid) 90 | docid = int(docid) 91 | query_positive_id[topicid] = docid 92 | 93 | return query_positive_id 94 | 95 | 96 | def generate_data(args, query_embedding2id, passage_embedding2id): 97 | query_positive_id = load_positive_ids(args) 98 | with open(args.inn_path, "rb") as f: 99 | I = np.load(f) 100 | effective_q_id = set(query_embedding2id.flatten()) 101 | query_negative_passage, ance_top_passage = generate_pids(args, query_embedding2id, passage_embedding2id, I, effective_q_id) 102 | with open(os.path.join(args.output_dir, f"prf_{args.mode}.tsv"), "w") as f: 103 | query_range = list(range(I.shape[0])) 104 | for query_idx in query_range: 105 | query_id = query_embedding2id[query_idx] 106 | if query_id not in effective_q_id or query_id not in query_positive_id: 107 | print(f"invalid qid {query_id}") 108 | pos_pid = query_positive_id[query_id] 109 | f.write( 110 | "{}\t{}\t{}\t{}\n".format( 111 | query_id, pos_pid, 112 | ','.join(str(feedback_pid) for feedback_pid in ance_top_passage[query_id]), 113 | ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id]))) 114 | 115 | 116 | def get_psg_embeds(args, passage_embedding, passage_embedding2id): 117 | id2embedding = dict() 118 | for embed_id, id in enumerate(passage_embedding2id): 119 | id2embedding[id] = embed_id 120 | embeddings = [] 121 | tsv_path = os.path.join(args.output_dir, f"prf_{args.mode}.tsv") 122 | 123 | with open(tsv_path, "r") as f: 124 | for l in f: 125 | 
117 | def get_psg_embeds(args, passage_embedding, passage_embedding2id):
118 |     id2embedding = dict()
119 |     for embed_id, id in enumerate(passage_embedding2id):
120 |         id2embedding[id] = embed_id
121 |     embeddings = []
122 |     tsv_path = os.path.join(args.output_dir, f"prf_{args.mode}.tsv")
123 | 
124 |     with open(tsv_path, "r") as f:
125 |         for l in f:
126 |             line_arr = l.split("\t")
127 |             pos_pid = id2embedding[int(line_arr[1])]
128 |             neg_pids = line_arr[3].split(",")
129 |             neg_pids = [id2embedding[int(neg_pid)] for neg_pid in neg_pids]
130 |             all_pids = [pos_pid] + neg_pids
131 |             embeddings.append(passage_embedding[all_pids])
132 |     embeddings = np.array(embeddings)
133 |     output_path = os.path.join(args.output_dir, f"psg_embeds_{args.mode}")
134 |     with open(output_path, "wb") as fout:
135 |         pickle.dump(embeddings, fout, protocol=4)
136 |     return embeddings
137 | 
138 | 
139 | def split_train_data(args, psg_embeds, num_queries):
140 |     tsv_path = os.path.join(args.output_dir, "prf_train.tsv")
141 |     queries_per_chunk = num_queries // args.ann_chunk_factor
142 |     with open(tsv_path, "r") as f:
143 |         for i in range(args.ann_chunk_factor):
144 |             output_tsv_path = os.path.join(args.output_dir, f"prf_train_{i}.tsv")
145 |             output_embed_path = os.path.join(args.output_dir, f"psg_embeds_{args.mode}_{i}")
146 |             q_start_idx = queries_per_chunk * i
147 |             # the last chunk absorbs the remainder
148 |             q_end_idx = num_queries if i == (args.ann_chunk_factor - 1) else (q_start_idx + queries_per_chunk)
149 |             with open(output_tsv_path, "w") as fout:
150 |                 for _ in range(q_start_idx, q_end_idx):  # was `for _ in q_start_idx, q_end_idx`, which only copied two lines
151 |                     l = f.readline()
152 |                     fout.write(l.strip() + "\n")
153 |             with open(output_embed_path, "wb") as fout:
154 |                 pickle.dump(psg_embeds[q_start_idx:q_end_idx], fout, protocol=4)
155 | 
156 | 
157 | if __name__ == '__main__':
158 |     args = parse_args()
159 |     os.makedirs(args.output_dir, exist_ok=True)
160 |     query_embedding, query_embedding2id, passage_embedding, passage_embedding2id = load_embeddings(args, args.mode)
161 |     generate_data(args, query_embedding2id, passage_embedding2id)
162 |     psg_embeds = get_psg_embeds(args, passage_embedding, passage_embedding2id)
163 |     if args.mode == "train":
164 |         num_queries = psg_embeds.shape[0]
165 |         print(f"Splitting training data into {args.ann_chunk_factor} chunks. There are {num_queries} queries in total.")
166 |         split_train_data(args, psg_embeds, num_queries)
167 |     print(f"Generated PRF {args.mode} data for {args.dataset}")
168 | 
--------------------------------------------------------------------------------
/test_ance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from transformers import RobertaForSequenceClassification, RobertaTokenizer, RobertaConfig
5 | import pdb
6 | 
7 | 
8 | def pad_ids(input_ids, attention_mask, token_type_ids, max_length, pad_token, mask_padding_with_zero,
9 |             pad_token_segment_id, pad_on_left=False):
10 |     padding_length = max_length - len(input_ids)
11 |     if pad_on_left:
12 |         input_ids = ([pad_token] * padding_length) + input_ids
13 |         attention_mask = ([0 if mask_padding_with_zero else 1]
14 |                           * padding_length) + attention_mask
15 |         token_type_ids = ([pad_token_segment_id] *
16 |                           padding_length) + token_type_ids
17 |     else:
18 |         input_ids += [pad_token] * padding_length
19 |         attention_mask += [0 if mask_padding_with_zero else 1] * padding_length
20 |         token_type_ids += [pad_token_segment_id] * padding_length
21 | 
22 |     return input_ids, attention_mask, token_type_ids
23 | 
24 | 
25 | def triple_process_fn(line, i, tokenizer, args):
26 |     features = []
27 |     cells = line.split("\t")
28 |     if len(cells) == 3:
29 |         # this is for training and validation
30 |         # query, positive_passage, negative_passage = line
31 |         mask_padding_with_zero = True
32 |         pad_token_segment_id = 0
33 |         pad_on_left = False
34 | 
35 |         for text in cells:
36 |             input_id_a = tokenizer.encode(
37 |                 text.strip(), add_special_tokens=True, max_length=args.max_seq_length,)
38 |             token_type_ids_a = [0] * len(input_id_a)
39 |             attention_mask_a = [
40 |                 1 if mask_padding_with_zero else 0] * len(input_id_a)
41 |             input_id_a, attention_mask_a, token_type_ids_a = pad_ids(
42 |                 input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id,
43 |                 mask_padding_with_zero, pad_token_segment_id, pad_on_left)
44 |             features += [torch.tensor(input_id_a, dtype=torch.int),
45 |                          torch.tensor(attention_mask_a, dtype=torch.bool)]
46 |     else:
47 |         raise Exception(
48 |             "Line doesn't have correct length: {0}. Expected 3.".format(str(len(cells))))
49 |     return [features]
50 | 
51 | 
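# triple_process_fn above turns one "query \t positive \t negative" line into three
# (input_ids, attention_mask) tensor pairs, each padded to args.max_seq_length.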
57 | def __init__(self, model_argobj):
58 | if model_argobj is None:
59 | self.use_mean = False
60 | else:
61 | self.use_mean = model_argobj.use_mean
62 | print("Using mean:", self.use_mean)
63 | 
64 | def _init_weights(self, module):
65 | """ Initialize the weights """
66 | if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
67 | # Slightly different from the TF version which uses truncated_normal for initialization
68 | # cf https://github.com/pytorch/pytorch/pull/5617
69 | module.weight.data.normal_(mean=0.0, std=0.02)
70 | 
71 | def masked_mean(self, t, mask):
72 | s = torch.sum(t * mask.unsqueeze(-1).float(), axis=1)
73 | d = mask.sum(axis=1, keepdim=True).float()
74 | return s / d
75 | 
76 | def masked_mean_or_first(self, emb_all, mask):
77 | # emb_all is a tuple from bert - sequence output, pooler
78 | assert isinstance(emb_all, tuple)
79 | if self.use_mean:
80 | return self.masked_mean(emb_all[0], mask)
81 | else:
82 | return emb_all[0][:, 0]
83 | 
84 | def query_emb(self, input_ids, attention_mask):
85 | raise NotImplementedError("Please Implement this method")
86 | 
87 | def body_emb(self, input_ids, attention_mask):
88 | raise NotImplementedError("Please Implement this method")
89 | 
90 | 
91 | class NLL(EmbeddingMixin):
92 | def forward(
93 | self,
94 | query_ids,
95 | attention_mask_q,
96 | input_ids_a=None,
97 | attention_mask_a=None,
98 | input_ids_b=None,
99 | attention_mask_b=None,
100 | is_query=True):
101 | if input_ids_b is None and is_query:
102 | return self.query_emb(query_ids, attention_mask_q)
103 | elif input_ids_b is None:
104 | return self.body_emb(query_ids, attention_mask_q)
105 | # pdb.set_trace()
106 | q_embs = self.query_emb(query_ids, attention_mask_q)
107 | a_embs = self.body_emb(input_ids_a, attention_mask_a)
108 | b_embs = self.body_emb(input_ids_b, attention_mask_b)
109 | 
110 | logit_matrix = torch.cat([(q_embs * a_embs).sum(-1).unsqueeze(1),
111 | (q_embs * b_embs).sum(-1).unsqueeze(1)], dim=1) # [B, 2]
112 | lsm = F.log_softmax(logit_matrix, dim=1)
113 | loss = -1.0 * lsm[:, 0]
114 | return (loss.mean(),)
115 | 
116 | 
117 | class RobertaDot_NLL_LN(NLL, RobertaForSequenceClassification):
118 | """
119 | Maps the pooled embedding through a 768-d linear head with LayerNorm, then computes the NLL loss.
120 | """
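# The NLL.forward above scores each (query, positive, negative) triple with
# two inner products that form a [B, 2] logit matrix; the loss is the
# negative log-softmax of the positive column, i.e.
#   loss_i = -log( exp(q_i . a_i) / (exp(q_i . a_i) + exp(q_i . b_i)) ),
# averaged over the batch. query_emb and body_emb below share all weights,
# so queries and passages are encoded by the same tower.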
121 | 
122 | def __init__(self, config, model_argobj=None):
123 | NLL.__init__(self, model_argobj)
124 | RobertaForSequenceClassification.__init__(self, config)
125 | self.embeddingHead = nn.Linear(config.hidden_size, 768)
126 | self.norm = nn.LayerNorm(768)
127 | self.apply(self._init_weights)
128 | 
129 | def query_emb(self, input_ids, attention_mask):
130 | outputs1 = self.roberta(input_ids=input_ids,
131 | attention_mask=attention_mask)
132 | full_emb = self.masked_mean_or_first(outputs1, attention_mask)
133 | query1 = self.norm(self.embeddingHead(full_emb))
134 | return query1
135 | 
136 | def body_emb(self, input_ids, attention_mask):
137 | return self.query_emb(input_ids, attention_mask)
138 | 
139 | 
140 | default_process_fn = triple_process_fn
141 | 
142 | 
143 | class MSMarcoConfig:
144 | def __init__(self, name, model, process_fn=default_process_fn, use_mean=True, tokenizer_class=RobertaTokenizer,
145 | config_class=RobertaConfig):
146 | self.name = name
147 | self.process_fn = process_fn
148 | self.model_class = model
149 | self.use_mean = use_mean
150 | self.tokenizer_class = tokenizer_class
151 | self.config_class = config_class
152 | 
153 | 
154 | def test_loss():
155 | config_obj = MSMarcoConfig(name="rdot_nll", model=RobertaDot_NLL_LN, use_mean=False)
156 | config = config_obj.config_class.from_pretrained("roberta-base", num_labels=2)
157 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base", do_lower_case=False)
158 | model = RobertaDot_NLL_LN.from_pretrained("roberta-base", config=config)
159 | query = "Test query"
160 | doc1 = "Positive doc"
161 | doc2 = "Negative doc"
# NOTE: the tokenizer.__call__ API used below requires transformers >= 3.0,
# which is newer than the 2.3.0 pinned in requirements.txt.
162 | query_tokens = tokenizer(query, return_tensors="pt", max_length=128, truncation=True)
163 | doc1_tokens = tokenizer(doc1, return_tensors="pt", max_length=128, truncation=True)
164 | doc2_tokens = tokenizer(doc2, return_tensors="pt", max_length=128, truncation=True)
165 | _ = model(query_ids=query_tokens["input_ids"], attention_mask_q=query_tokens["attention_mask"],
166 | input_ids_a=doc1_tokens["input_ids"], attention_mask_a=doc1_tokens["attention_mask"],
167 | input_ids_b=doc2_tokens["input_ids"], attention_mask_b=doc2_tokens["attention_mask"])
168 | 
169 | 
170 | def main():
171 | test_loss()
172 | 
173 | 
174 | if __name__ == '__main__':
175 | main()
-------------------------------------------------------------------------------- /msmarco_eval.py: --------------------------------------------------------------------------------
1 | """
2 | This is the official eval script open-sourced on the MS MARCO site (not written or owned by us).
3 | 
4 | This module computes evaluation metrics for the MS MARCO dataset on the ranking task.
5 | Command line:
6 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
7 | 
8 | Creation Date : 06/12/2018
9 | Last Modified : 1/21/2019
10 | Authors : Daniel Campos , Rutger van Haasteren 
11 | """
12 | import sys
13 | import statistics
14 | 
15 | from collections import Counter
16 | import pdb
17 | 
18 | MaxMRRRank = 10
19 | 
20 | def load_reference_from_stream(f):
21 | """Load the reference relevant passages
22 | Args:f (stream): stream to load.
23 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
24 | """ 25 | qids_to_relevant_passageids = {} 26 | for l in f: 27 | try: 28 | l = l.strip().split('\t') 29 | qid = int(l[0]) 30 | if qid in qids_to_relevant_passageids: 31 | pass 32 | else: 33 | qids_to_relevant_passageids[qid] = [] 34 | qids_to_relevant_passageids[qid].append(int(l[2])) 35 | except: 36 | raise IOError('\"%s\" is not valid format' % l) 37 | return qids_to_relevant_passageids 38 | 39 | def load_reference(path_to_reference): 40 | """Load Reference reference relevant passages 41 | Args:path_to_reference (str): path to a file to load. 42 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). 43 | """ 44 | with open(path_to_reference,'r') as f: 45 | qids_to_relevant_passageids = load_reference_from_stream(f) 46 | return qids_to_relevant_passageids 47 | 48 | def load_candidate_from_stream(f): 49 | """Load candidate data from a stream. 50 | Args:f (stream): stream to load. 51 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 52 | """ 53 | qid_to_ranked_candidate_passages = {} 54 | for l in f: 55 | try: 56 | l = l.strip().split('\t') 57 | if len(l) == 6: 58 | qid, pid, rank = int(l[0]), int(l[2]), int(l[3]) 59 | else: 60 | qid = int(l[0]) 61 | pid = int(l[1]) 62 | rank = int(l[2]) 63 | if qid in qid_to_ranked_candidate_passages: 64 | pass 65 | else: 66 | # By default, all PIDs in the list of 1000 are 0. Only override those that are given 67 | tmp = [0] * 1000 68 | qid_to_ranked_candidate_passages[qid] = tmp 69 | qid_to_ranked_candidate_passages[qid][rank-1]=pid 70 | except Exception as e: 71 | raise IOError('\"%s\" is not valid format' % l) 72 | return qid_to_ranked_candidate_passages 73 | 74 | def load_candidate(path_to_candidate): 75 | """Load candidate data from a file. 76 | Args:path_to_candidate (str): path to file to load. 77 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 78 | """ 79 | 80 | with open(path_to_candidate,'r') as f: 81 | qid_to_ranked_candidate_passages = load_candidate_from_stream(f) 82 | return qid_to_ranked_candidate_passages 83 | 84 | def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 85 | """Perform quality checks on the dictionaries 86 | 87 | Args: 88 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 89 | Dict as read in with load_reference or load_reference_from_stream 90 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 91 | Returns: 92 | bool,str: Boolean whether allowed, message to be shown in case of a problem 93 | """ 94 | message = '' 95 | allowed = True 96 | 97 | # Create sets of the QIDs for the submitted and reference queries 98 | candidate_set = set(qids_to_ranked_candidate_passages.keys()) 99 | ref_set = set(qids_to_relevant_passageids.keys()) 100 | 101 | # Check that we do not have multiple passages per query 102 | for qid in qids_to_ranked_candidate_passages: 103 | # Remove all zeros from the candidates 104 | duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]) 105 | 106 | if len(duplicate_pids-set([0])) > 0: 107 | message = "Cannot rank a passage multiple times for a single query. 
QID={qid}, PID={pid}".format(
108 | qid=qid, pid=list(duplicate_pids)[0])
109 | allowed = False
110 | 
111 | return allowed, message
112 | 
113 | def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
114 | """Compute MRR metric
115 | Args:
116 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
117 | Dict as read in with load_reference or load_reference_from_stream
118 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
119 | Returns:
120 | dict: dictionary of metrics {'MRR': <MRR Score>}
121 | """
122 | all_scores = {}
123 | MRR = 0
124 | qids_with_relevant_passages = 0
125 | ranking = []
126 | for qid in qids_to_ranked_candidate_passages:
127 | if qid in qids_to_relevant_passageids:
128 | ranking.append(0)
129 | target_pid = qids_to_relevant_passageids[qid]
130 | candidate_pid = qids_to_ranked_candidate_passages[qid]
131 | for i in range(0,MaxMRRRank):
132 | if candidate_pid[i] in target_pid:
133 | MRR += 1/(i + 1)
134 | ranking.pop()
135 | ranking.append(i+1)
136 | break
137 | if len(ranking) == 0:
138 | raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")
139 | 
140 | MRR = MRR/len(qids_to_relevant_passageids)
141 | all_scores['MRR @10'] = MRR
142 | all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
143 | return all_scores
144 | 
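# Worked example for compute_metrics above (hypothetical numbers): if the
# first relevant passage for three labeled queries appears at ranks 1, 4 and
# 12 respectively, the summed reciprocal ranks are 1/1 + 1/4 + 0 = 1.25
# (rank 12 falls outside MaxMRRRank = 10 and contributes nothing), so
# MRR@10 = 1.25 / 3 ~= 0.4167.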
145 | def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
146 | """Compute MRR metric
147 | Args:
148 | p_path_to_reference_file (str): path to reference file.
149 | Reference file should contain lines in the following format:
150 | QUERYID\tPASSAGEID
151 | Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
152 | p_path_to_candidate_file (str): path to candidate file.
153 | Candidate file should contain lines in the following format:
154 | QUERYID\tPASSAGEID1\tRank
155 | If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
156 | QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
157 | Where the values are separated by tabs and ranked in order of relevance
158 | Returns:
159 | dict: dictionary of metrics {'MRR': <MRR Score>}
160 | """
161 | 
162 | qids_to_relevant_passageids = load_reference(path_to_reference)
163 | qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
164 | if perform_checks:
165 | allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
166 | if message != '': print(message)
167 | 
168 | return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
169 | 
170 | 
171 | def get_mrr(path_to_reference, path_to_candidate):
172 | return compute_metrics_from_files(path_to_reference, path_to_candidate)["MRR @10"]
173 | 
174 | def main():
175 | """Command line:
176 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
177 | """
178 | print("Eval Started")
179 | if len(sys.argv) == 3:
180 | path_to_reference = sys.argv[1]
181 | path_to_candidate = sys.argv[2]
182 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
183 | print('#####################')
184 | for metric in sorted(metrics):
185 | print('{}: {}'.format(metric, metrics[metric]))
186 | print('#####################')
187 | 
188 | else:
189 | print('Usage: msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>')
190 | exit()
191 | 
192 | if __name__ == '__main__':
193 | main()
-------------------------------------------------------------------------------- /data.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | import csv
6 | import torch
7 | from torch.utils.data import IterableDataset, TensorDataset
8 | import torch.distributed as dist
9 | import logging
10 | import pdb
11 | import random
12 | 
13 | 
14 | class EmbeddingCache:
15 | def __init__(self, base_path, seed=-1):
16 | self.base_path = base_path
17 | with open(base_path + '_meta', 'r') as f:
18 | meta = json.load(f)
19 | self.dtype = np.dtype(meta['type'])
20 | self.total_number = meta['total_number']
21 | self.record_size = int(
22 | meta['embedding_size']) * self.dtype.itemsize + 4
23 | if seed >= 0:
24 | self.ix_array = np.random.RandomState(
25 | seed).permutation(self.total_number)
26 | else:
27 | self.ix_array = np.arange(self.total_number)
28 | self.f = None
29 | 
30 | def open(self):
31 | self.f = open(self.base_path, 'rb')
32 | 
33 | def close(self):
34 | self.f.close()
35 | 
36 | def read_single_record(self):
37 | record_bytes = self.f.read(self.record_size)
38 | passage_len = int.from_bytes(record_bytes[:4], 'big')
39 | passage = np.frombuffer(record_bytes[4:], dtype=self.dtype)
40 | return passage_len, passage
41 | 
42 | def __enter__(self):
43 | self.open()
44 | return self
45 | 
46 | def __exit__(self, type, value, traceback):
47 | self.close()
48 | 
49 | def __getitem__(self, key):
50 | if key < 0 or key >= self.total_number:  # >= so that key == total_number is rejected too
51 | raise IndexError(
52 | "Index {} is out of bounds for cached embeddings of size {}".format(
53 | key, self.total_number))
54 | self.f.seek(key * self.record_size)
55 | return self.read_single_record()
56 | 
57 | def __iter__(self):
58 | self.f.seek(0)
59 | for i in range(self.total_number):
60 | new_ix = self.ix_array[i]
61 | yield self.__getitem__(new_ix)
62 | 
63 | def __len__(self):
64 | return self.total_number
65 | 
66 | 
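# A minimal sketch of the on-disk layout EmbeddingCache expects, for
# illustration only (this helper is hypothetical and not part of the repo):
# each record is a 4-byte big-endian token count followed by exactly
# `embedding_size` int32 token ids, and a JSON "_meta" sidecar records the
# dtype, record count and per-record width from which record_size is derived.
def _write_demo_cache(base_path, token_id_lists, embedding_size):
    meta = {'type': 'int32', 'total_number': len(token_id_lists),
            'embedding_size': embedding_size}
    with open(base_path + '_meta', 'w') as f:
        json.dump(meta, f)
    with open(base_path, 'wb') as f:
        for ids in token_id_lists:
            # zero-pad to the fixed record width (assumes len(ids) <= embedding_size)
            padded = np.zeros(embedding_size, dtype=np.int32)
            padded[:len(ids)] = ids
            f.write(len(ids).to_bytes(4, 'big') + padded.tobytes())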
67 | class StreamingDataset(IterableDataset):
68 | def __init__(self, args, elements, fn, distributed=True):
69 | super().__init__()
70 | self.elements = elements
71 | self.fn = fn
72 | self.num_replicas = -1
73 | self.distributed = distributed
74 | self.psg_embeds = None
75 | if args.train:
76 | with open(os.path.join(args.train_data_dir, "psg_embeds_train_0"), "rb") as f:
77 | self.psg_embeds = pickle.load(f)
78 | elif args.eval:
79 | with open(os.path.join(args.dev_data_dir, "psg_embeds_dev"), "rb") as f:
80 | self.psg_embeds = pickle.load(f)
81 | self.curr_split_idx = 0
82 | self.queries_per_chunk = args.num_queries // args.ann_chunk_factor
83 | self.args = args
84 | 
85 | def load_psg_embeds(self, i):
86 | split_idx = (i // self.queries_per_chunk) % self.args.ann_chunk_factor
87 | if self.curr_split_idx != split_idx:
88 | logging.info(f"Training on split {split_idx}...")
89 | with open(os.path.join(self.args.train_data_dir, f"psg_embeds_train_{split_idx}"), "rb") as f:
90 | self.psg_embeds = pickle.load(f)
91 | self.curr_split_idx = split_idx
92 | 
93 | def __iter__(self):
94 | if dist.is_initialized():
95 | self.num_replicas = dist.get_world_size()
96 | self.rank = dist.get_rank()
97 | else:
98 | print("Not running in distributed mode")
99 | for i, element in enumerate(self.elements):
100 | if self.args.train:
101 | self.load_psg_embeds(i)
102 | if self.distributed and self.num_replicas != -1 and i % self.num_replicas != self.rank:
103 | continue
104 | records = self.fn(self.psg_embeds, element, i)
105 | # Each file line corresponds to several examples (1 positive + the negative samples)
106 | for rec in records:
107 | yield rec
108 | 
109 | 
110 | def GetProcessingFn(args):
111 | """
112 | Modified from ANCE's GetProcessingFn.
113 | :param args: run arguments; only args.max_seq_length is used here.
114 | :return: fn(psg_emb, qry_len, qry, feedback_cache_items) that concatenates a query with its feedback passages and builds the model inputs.
115 | """
116 | def fn(psg_emb, qry_len, qry, feedback_cache_items):
117 | # Get model input data
118 | sep_id = 2
119 | all_input_ids = [qry[:qry_len]]
120 | for feedback_len, feedback in feedback_cache_items:
121 | all_input_ids.append(feedback[1:feedback_len])
122 | all_input_ids = np.concatenate(all_input_ids)
123 | if len(all_input_ids) > args.max_seq_length:
124 | all_input_ids = np.append(all_input_ids[:args.max_seq_length - 1], all_input_ids[-1])  # index -1 keeps the final </s> (SEP) token
125 | content_len = args.max_seq_length
126 | else:
127 | all_sep_idxs = [idx for idx in range(len(all_input_ids)) if all_input_ids[idx] == sep_id]
128 | content_len = all_sep_idxs[-1] + 1
129 | pad_len = args.max_seq_length - content_len
130 | attention_mask = [1] * content_len + [0] * pad_len
131 | all_input_ids = torch.tensor([list(all_input_ids) + [0] * pad_len], dtype=torch.int)
132 | attention_mask = torch.tensor([attention_mask], dtype=torch.bool)
133 | if psg_emb is not None:
134 | psg_emb = torch.tensor([psg_emb], dtype=torch.float32)
135 | 
136 | tensors = [all_input_ids, attention_mask]  # dev data passes psg_emb=None
137 | if psg_emb is not None:
138 | tensors.append(psg_emb)
139 | dataset = TensorDataset(*tensors)
140 | 
141 | return [ts for ts in dataset]
142 | 
143 | return fn
144 | 
145 | 
146 | def GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache, num_feedbacks=3):
147 | def fn(ordered_psg_embeds, line, i):
148 | queries_per_chunk = args.num_queries // args.ann_chunk_factor
149 | effective_idx = i % queries_per_chunk
150 | psg_embeds = ordered_psg_embeds[effective_idx]
151 | 
152 | line_arr = line.split('\t')
153 | qid = int(line_arr[0])
154 | pos_pid = int(line_arr[1])
155 | feedback_pids = line_arr[2].split(",")[:num_feedbacks]
156 | feedback_pids = [int(feedback_pid) for feedback_pid in feedback_pids]
157 | neg_pids = line_arr[3].split(',')
158 | neg_pids = [int(neg_pid) for neg_pid in neg_pids]
159 | 
160 | 
161 | qry_len, qry = query_cache[qid]
162 | feedback_cache_items = 
[passage_cache[feedback_pid] for feedback_pid in feedback_pids] 163 | 164 | pos_data = GetProcessingFn(args)(psg_embeds[0], qry_len, qry, feedback_cache_items)[0] 165 | 166 | if args.eval: 167 | neg_pids = neg_pids[:1] 168 | for i in range(len(neg_pids)): 169 | neg_psg_emb = psg_embeds[1+i] 170 | neg_psg_emb = torch.tensor(neg_psg_emb, dtype=torch.float32) 171 | res = [data for data in pos_data] 172 | res.append(neg_psg_emb) 173 | yield res 174 | 175 | return fn 176 | 177 | 178 | def attentionVisProcessingFn(args): 179 | sep_id = 2 180 | def fn(qry_len, qry, feedback_cache_items): 181 | all_input_ids = [qry[:qry_len]] 182 | for feedback_len, feedback in feedback_cache_items: 183 | all_input_ids.append(feedback[1:feedback_len]) 184 | all_input_ids = np.concatenate(all_input_ids) 185 | if len(all_input_ids) > args.max_seq_length: 186 | all_input_ids = np.append(all_input_ids[:args.max_seq_length - 1], all_input_ids[-1]) # -1 is CLS 187 | all_sep_idxs = [idx for idx in range(len(all_input_ids)) if all_input_ids[idx] == sep_id] 188 | content_len = args.max_seq_length 189 | else: 190 | all_sep_idxs = [idx for idx in range(len(all_input_ids)) if all_input_ids[idx] == sep_id] 191 | content_len = all_sep_idxs[-1] + 1 192 | if len(all_sep_idxs) < (args.num_feedbacks + 1): 193 | all_sep_idxs += [all_sep_idxs[-1]] * (args.num_feedbacks + 1 - len(all_sep_idxs)) 194 | all_sep_idxs = torch.tensor([all_sep_idxs], dtype=torch.int) 195 | pad_len = args.max_seq_length - content_len 196 | attention_mask = [1] * content_len + [0] * pad_len 197 | all_input_ids = torch.tensor([list(all_input_ids) + [0] * pad_len], dtype=torch.int) 198 | attention_mask = torch.tensor([attention_mask], dtype=torch.bool) 199 | 200 | dataset = TensorDataset( 201 | all_input_ids, 202 | attention_mask, 203 | all_sep_idxs 204 | ) 205 | return [ts for ts in dataset] 206 | 207 | return fn 208 | 209 | 210 | def GetAttentionVisProcessingFn(args, query_cache, passage_cache, num_feedbacks=3): 211 | def fn(ordered_psg_embeds, line, i): 212 | line_arr = line.split('\t') 213 | qid = int(line_arr[0]) 214 | pos_pid = int(line_arr[1]) 215 | feedback_pids = line_arr[2].split(",")[:num_feedbacks] 216 | feedback_pids = [int(feedback_pid) for feedback_pid in feedback_pids] 217 | 218 | qry_len, qry = query_cache[qid] 219 | feedback_cache_items = [passage_cache[feedback_pid] for feedback_pid in feedback_pids] 220 | 221 | pos_data = attentionVisProcessingFn(args)(qry_len, qry, feedback_cache_items)[0] 222 | 223 | if pos_pid in feedback_pids: 224 | pos_idx = feedback_pids.index(pos_pid) 225 | else: 226 | pos_idx = -1 227 | 228 | yield [data for data in pos_data] + [pos_idx] + [qid] 229 | 230 | return fn 231 | 232 | 233 | def GetDotProductProcessingFn(args, query_cache, passage_cache, ance_query_embedding, ance_qid2embedid, 234 | ance_passage_embedding, ance_pid2embedid, 235 | num_feedbacks=3): 236 | def fn(ordered_psg_embeds, line, i): 237 | queries_per_chunk = args.num_queries // args.ann_chunk_factor 238 | effective_idx = i % queries_per_chunk 239 | psg_embeds = ordered_psg_embeds[effective_idx] 240 | line_arr = line.split('\t') 241 | qid = int(line_arr[0]) 242 | pos_pid = int(line_arr[1]) 243 | feedback_pids = line_arr[2].split(",")[:num_feedbacks] 244 | feedback_pids = [int(feedback_pid) for feedback_pid in feedback_pids] 245 | included = pos_pid in feedback_pids 246 | if included: 247 | random.shuffle(feedback_pids) 248 | pos_idx = feedback_pids.index(pos_pid) 249 | neg_pids = line_arr[3].split(',') 250 | neg_pids = [int(neg_pid) for neg_pid in 
neg_pids] 251 | 252 | ance_qry_embed = torch.tensor([ance_query_embedding[ance_qid2embedid[qid]]], dtype=torch.float32) 253 | pos_qry_embed = torch.tensor([psg_embeds[0]], dtype=torch.float32) 254 | neg_psg_embeds = [neg_psg_emb for neg_psg_emb in psg_embeds[1:]] 255 | neg_psg_embeds = torch.tensor([neg_psg_embeds], dtype=torch.float32) 256 | feedback_embeds = [ance_passage_embedding[ance_pid2embedid[feedback_pid]] for feedback_pid in feedback_pids] 257 | feedback_embeds = torch.tensor([feedback_embeds], dtype=torch.float32) 258 | 259 | qry_len, qry = query_cache[qid] 260 | if included: 261 | feedback_cache_items = [passage_cache[feedback_pid] for feedback_pid in feedback_pids] 262 | else: 263 | pos_idx = random.randrange(3) 264 | feedback_cache_items = [passage_cache[pos_pid] if i == pos_idx else passage_cache[feedback_pid] 265 | for i, feedback_pid in enumerate(feedback_pids)] 266 | pos_idx = torch.tensor([pos_idx], dtype=torch.int) 267 | 268 | pos_data = GetProcessingFn(args)(psg_embeds[0], qry_len, qry, feedback_cache_items)[0] 269 | res = [data for data in pos_data[:-1]] 270 | res += [ance_qry_embed, pos_qry_embed, neg_psg_embeds, feedback_embeds, pos_idx] 271 | 272 | yield res # (all_input_ids, attention_mask, ance_qry_emb, pos_emb, neg_embs, feedback_embeds) 273 | 274 | return fn 275 | 276 | 277 | 278 | def GetTsneDotProductrocessingFn(args, query_cache, passage_cache, ance_query_embedding, ance_qid2embedid, 279 | ance_passage_embedding, ance_pid2embedid, 280 | num_feedbacks=3): 281 | def fn(ordered_psg_embeds, line, i): 282 | queries_per_chunk = args.num_queries // args.ann_chunk_factor 283 | effective_idx = i % queries_per_chunk 284 | psg_embeds = ordered_psg_embeds[effective_idx] 285 | line_arr = line.split('\t') 286 | qid = int(line_arr[0]) 287 | pos_pid = int(line_arr[1]) 288 | feedback_pids = line_arr[2].split(",")[:num_feedbacks] 289 | feedback_pids = [int(feedback_pid) for feedback_pid in feedback_pids] 290 | neg_pids = line_arr[3].split(',') 291 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 292 | 293 | curr_qid = torch.tensor([qid], dtype=torch.int) 294 | ance_qry_embed = torch.tensor([ance_query_embedding[ance_qid2embedid[qid]]], dtype=torch.float32) 295 | pos_qry_embed = torch.tensor([psg_embeds[0]], dtype=torch.float32) 296 | neg_psg_embeds = [neg_psg_emb for neg_psg_emb in psg_embeds[1:]] 297 | neg_psg_embeds = torch.tensor([neg_psg_embeds], dtype=torch.float32) 298 | feedback_embeds = [ance_passage_embedding[ance_pid2embedid[feedback_pid]] for feedback_pid in feedback_pids] 299 | feedback_embeds = torch.tensor([feedback_embeds], dtype=torch.float32) 300 | 301 | qry_len, qry = query_cache[qid] 302 | feedback_cache_items = [passage_cache[feedback_pid] for feedback_pid in feedback_pids] 303 | 304 | pos_data = GetProcessingFn(args)(psg_embeds[0], qry_len, qry, feedback_cache_items)[0] 305 | res = [data for data in pos_data[:-1]] 306 | res += [ance_qry_embed, pos_qry_embed, neg_psg_embeds, feedback_embeds, curr_qid] 307 | 308 | yield res # (all_input_ids, attention_mask, ance_qry_emb, pos_emb, neg_embs, feedback_embeds) 309 | 310 | return fn 311 | 312 | 313 | def GetTripletDevDataProcessingFn(args, query_cache, passage_cache, num_feedbacks=3): 314 | def fn(ordered_psg_embeds, line, i): 315 | 316 | line_arr = line.split('\t') 317 | qid = int(line_arr[0]) 318 | feedback_pids = line_arr[1].split(",")[:num_feedbacks] 319 | feedback_pids = [int(feedback_pid) for feedback_pid in feedback_pids] 320 | qry_len, qry = query_cache[qid] 321 | feedback_cache_items = 
[passage_cache[feedback_pid] for feedback_pid in feedback_pids] 322 | 323 | data = GetProcessingFn(args)(None, qry_len, qry, feedback_cache_items)[0] 324 | yield data 325 | 326 | return fn -------------------------------------------------------------------------------- /data_prep/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | import gzip 6 | import pickle 7 | from utils.util import pad_input_ids, multi_file_process, numbered_byte_file_generator, UtilEmbeddingCache 8 | import csv 9 | from torch.utils.data import DataLoader, Dataset, TensorDataset, IterableDataset, get_worker_info 10 | import numpy as np 11 | import argparse 12 | import json 13 | 14 | 15 | def write_query_rel(args, pid2offset, query_file, positive_id_file, out_query_file, out_id_file): 16 | 17 | print( 18 | "Writing query files " + 19 | str(out_query_file) + 20 | " and " + 21 | str(out_id_file)) 22 | query_positive_id = set() 23 | 24 | query_positive_id_path = os.path.join( 25 | args.data_dir, 26 | positive_id_file, 27 | ) 28 | 29 | print("Loading query_2_pos_docid") 30 | with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', encoding='utf8') as f: 31 | if args.data_type == 0 or args.dataset != "marco": 32 | tsvreader = csv.reader(f, delimiter=" ") 33 | elif args.data_type == 1: 34 | tsvreader = csv.reader(f, delimiter="\t") 35 | for [topicid, _, docid, rel] in tsvreader: 36 | query_positive_id.add(int(topicid)) 37 | 38 | query_collection_path = os.path.join( 39 | args.data_dir, 40 | query_file, 41 | ) 42 | 43 | out_query_path = os.path.join( 44 | args.out_data_dir, 45 | out_query_file, 46 | ) 47 | 48 | qid2offset = {} 49 | 50 | print('start query file split processing') 51 | multi_file_process( 52 | args, 53 | 32, 54 | query_collection_path, 55 | out_query_path, 56 | QueryPreprocessingFn) 57 | 58 | print('start merging splits') 59 | 60 | idx = 0 61 | with open(out_query_path, 'wb') as f: 62 | for record in numbered_byte_file_generator( 63 | out_query_path, 32, 8 + 4 + args.max_query_length * 4): 64 | q_id = int.from_bytes(record[:8], 'big') 65 | if q_id not in query_positive_id: 66 | # exclude the query as it is not in label set 67 | continue 68 | f.write(record[8:]) 69 | qid2offset[q_id] = idx 70 | idx += 1 71 | if idx < 3: 72 | print(str(idx) + " " + str(q_id)) 73 | 74 | print(f"query_file = {query_file}") 75 | if "dev" in query_file: 76 | prefix = "dev-" 77 | else: 78 | prefix = "train-" 79 | qid2offset_path = os.path.join( 80 | args.out_data_dir, 81 | f"{prefix}qid2offset.pickle" 82 | ) 83 | 84 | with open(qid2offset_path, 'wb') as handle: 85 | pickle.dump(qid2offset, handle, protocol=4) 86 | print("done saving qid2offset") 87 | 88 | print("Total lines written: " + str(idx)) 89 | meta = {'type': 'int32', 'total_number': idx, 90 | 'embedding_size': args.max_query_length} 91 | with open(out_query_path + "_meta", 'w') as f: 92 | json.dump(meta, f) 93 | 94 | embedding_cache = UtilEmbeddingCache(out_query_path) 95 | print("First line") 96 | with embedding_cache as emb: 97 | print(emb[0]) 98 | 99 | out_id_path = os.path.join( 100 | args.out_data_dir, 101 | out_id_file, 102 | ) 103 | 104 | print("Writing qrels") 105 | with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', 
encoding='utf8') as f, \ 106 | open(out_id_path, "w", encoding='utf-8') as out_id: 107 | 108 | if args.data_type == 0 or args.dataset != "marco": 109 | tsvreader = csv.reader(f, delimiter=" ") 110 | else: 111 | tsvreader = csv.reader(f, delimiter="\t") 112 | out_line_count = 0 113 | for [topicid, _, docid, rel] in tsvreader: 114 | topicid = int(topicid) 115 | if args.data_type == 0: 116 | docid = int(docid[1:]) 117 | else: 118 | docid = int(docid) 119 | out_id.write(str(qid2offset[topicid]) + 120 | "\t" + 121 | str(pid2offset[docid]) + 122 | "\t" + 123 | rel + 124 | "\n") 125 | out_line_count += 1 126 | print("Total lines written: " + str(out_line_count)) 127 | 128 | 129 | def preprocess(args): 130 | 131 | pid2offset = {} 132 | if args.data_type == 0: 133 | in_passage_path = os.path.join( 134 | args.data_dir, 135 | "msmarco-docs.tsv", 136 | ) 137 | else: 138 | in_passage_path = os.path.join( 139 | args.data_dir, 140 | "collection.tsv", 141 | ) 142 | 143 | out_passage_path = os.path.join( 144 | args.out_data_dir, 145 | "passages", 146 | ) 147 | 148 | if os.path.exists(out_passage_path): 149 | print("preprocessed data already exist, exit preprocessing") 150 | return 151 | 152 | out_line_count = 0 153 | 154 | print('start passage file split processing') 155 | if args.dataset == "marco": 156 | multi_file_process( 157 | args, 158 | 32, 159 | in_passage_path, 160 | out_passage_path, 161 | PassagePreprocessingFn) 162 | 163 | print('start merging splits') 164 | with open(out_passage_path, 'wb') as f: 165 | for idx, record in enumerate(numbered_byte_file_generator( 166 | out_passage_path, 32, 8 + 4 + args.max_seq_length * 4)): 167 | p_id = int.from_bytes(record[:8], 'big') 168 | f.write(record[8:]) 169 | pid2offset[p_id] = idx 170 | if idx < 3: 171 | print(str(idx) + " " + str(p_id)) 172 | out_line_count += 1 173 | 174 | print("Total lines written: " + str(out_line_count)) 175 | meta = { 176 | 'type': 'int32', 177 | 'total_number': out_line_count, 178 | 'embedding_size': args.max_seq_length} 179 | with open(out_passage_path + "_meta", 'w') as f: 180 | json.dump(meta, f) 181 | embedding_cache = UtilEmbeddingCache(out_passage_path) 182 | print("First line") 183 | with embedding_cache as emb: 184 | print(emb[0]) 185 | 186 | pid2offset_path = os.path.join( 187 | args.out_data_dir, 188 | "pid2offset.pickle", 189 | ) 190 | with open(pid2offset_path, 'wb') as handle: 191 | pickle.dump(pid2offset, handle, protocol=4) 192 | print("done saving pid2offset") 193 | else: 194 | with open(os.path.join(args.data_dir, "..", "marco_preprocessed", "pid2offset.pickle"), "rb") as f: 195 | pid2offset = pickle.load(f) 196 | 197 | if args.data_type == 0: 198 | write_query_rel( 199 | args, 200 | pid2offset, 201 | "msmarco-doctrain-queries.tsv", 202 | "msmarco-doctrain-qrels.tsv", 203 | "train-query", 204 | "train-qrel.tsv") 205 | write_query_rel( 206 | args, 207 | pid2offset, 208 | "msmarco-test2019-queries.tsv", 209 | "2019qrels-docs.txt", 210 | "dev-query", 211 | "dev-qrel.tsv") 212 | else: 213 | if args.dataset == "marco": 214 | write_query_rel( 215 | args, 216 | pid2offset, 217 | "queries.train.tsv", 218 | "qrels.train.tsv", 219 | "train-query", 220 | "train-qrel.tsv") 221 | write_query_rel( 222 | args, 223 | pid2offset, 224 | "queries.dev.small.tsv", 225 | "qrels.dev.small.tsv", 226 | "dev-query", 227 | "dev-qrel.tsv") 228 | 229 | 230 | def PassagePreprocessingFn(args, line, tokenizer): 231 | if args.data_type == 0: 232 | line_arr = line.split('\t') 233 | p_id = int(line_arr[0][1:]) # remove "D" 234 | 235 | url = 
line_arr[1].rstrip()
236 | title = line_arr[2].rstrip()
237 | p_text = line_arr[3].rstrip()
238 | 
239 | # full_text = url + "<sep>" + title + "<sep>" + p_text
240 | full_text = url + " "+tokenizer.sep_token+" " + title + " "+tokenizer.sep_token+" " + p_text
241 | # keep only first 10000 characters, should be sufficient for any
242 | # experiment that uses less than 500 - 1k tokens
243 | full_text = full_text[:args.max_doc_character]
244 | else:
245 | line = line.strip()
246 | line_arr = line.split('\t')
247 | p_id = int(line_arr[0])
248 | p_text = line_arr[1].rstrip()
249 | 
250 | # keep only first 10000 characters, should be sufficient for any
251 | # experiment that uses less than 500 - 1k tokens
252 | full_text = p_text[:args.max_doc_character]
253 | 
254 | passage = tokenizer.encode(
255 | full_text,
256 | add_special_tokens=True,
257 | max_length=args.max_seq_length,
258 | )
259 | passage_len = min(len(passage), args.max_seq_length)
260 | input_id_b = pad_input_ids(passage, args.max_seq_length, pad_token=tokenizer.pad_token_id)
261 | 
262 | return p_id.to_bytes(8,'big') + passage_len.to_bytes(4,'big') + np.array(input_id_b,np.int32).tobytes()
263 | 
264 | 
265 | def QueryPreprocessingFn(args, line, tokenizer):
266 | line_arr = line.split('\t')
267 | q_id = int(line_arr[0])
268 | 
269 | passage = tokenizer.encode(
270 | line_arr[1].rstrip(),
271 | add_special_tokens=True,
272 | max_length=args.max_query_length)
273 | passage_len = min(len(passage), args.max_query_length)
274 | input_id_b = pad_input_ids(passage, args.max_query_length, pad_token=tokenizer.pad_token_id)
275 | 
276 | return q_id.to_bytes(8,'big') + passage_len.to_bytes(4,'big') + np.array(input_id_b,np.int32).tobytes()
277 | 
278 | 
279 | def GetProcessingFn(args, query=False):
280 | def fn(vals, i):
281 | passage_len, passage = vals
282 | max_len = args.max_query_length if query else args.max_seq_length
283 | 
284 | pad_len = max(0, max_len - passage_len)
285 | token_type_ids = ([0] if query else [1]) * passage_len + [0] * pad_len
286 | attention_mask = [1] * passage_len + [0] * pad_len
287 | 
288 | passage_collection = [(i, passage, attention_mask, token_type_ids)]
289 | 
290 | query2id_tensor = torch.tensor(
291 | [f[0] for f in passage_collection], dtype=torch.long)
292 | all_input_ids_a = torch.tensor(
293 | [f[1] for f in passage_collection], dtype=torch.int)
294 | all_attention_mask_a = torch.tensor(
295 | [f[2] for f in passage_collection], dtype=torch.bool)
296 | all_token_type_ids_a = torch.tensor(
297 | [f[3] for f in passage_collection], dtype=torch.uint8)
298 | 
299 | dataset = TensorDataset(
300 | all_input_ids_a,
301 | all_attention_mask_a,
302 | all_token_type_ids_a,
303 | query2id_tensor)
304 | 
305 | return [ts for ts in dataset]
306 | 
307 | return fn
308 | 
309 | 
310 | def GetTrainingDataProcessingFn(args, query_cache, passage_cache):
311 | def fn(line, i):
312 | line_arr = line.split('\t')
313 | qid = int(line_arr[0])
314 | pos_pid = int(line_arr[1])
315 | neg_pids = line_arr[2].split(',')
316 | neg_pids = [int(neg_pid) for neg_pid in neg_pids]
317 | 
318 | all_input_ids_a = []
319 | all_attention_mask_a = []
320 | 
321 | query_data = GetProcessingFn(
322 | args, query=True)(
323 | query_cache[qid], qid)[0]
324 | pos_data = GetProcessingFn(
325 | args, query=False)(
326 | passage_cache[pos_pid], pos_pid)[0]
327 | 
328 | pos_label = torch.tensor(1, dtype=torch.long)
329 | neg_label = torch.tensor(0, dtype=torch.long)
330 | 
331 | for neg_pid in neg_pids:
332 | neg_data = GetProcessingFn(
333 | args, query=False)(
334 | 
passage_cache[neg_pid], neg_pid)[0] 335 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], pos_label) 336 | yield (query_data[0], query_data[1], query_data[2], neg_data[0], neg_data[1], neg_data[2], neg_label) 337 | 338 | return fn 339 | 340 | 341 | def GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache): 342 | def fn(line, i): 343 | line_arr = line.split('\t') 344 | qid = int(line_arr[0]) 345 | pos_pid = int(line_arr[1]) 346 | neg_pids = line_arr[2].split(',') 347 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 348 | 349 | all_input_ids_a = [] 350 | all_attention_mask_a = [] 351 | 352 | query_data = GetProcessingFn( 353 | args, query=True)( 354 | query_cache[qid], qid)[0] 355 | pos_data = GetProcessingFn( 356 | args, query=False)( 357 | passage_cache[pos_pid], pos_pid)[0] 358 | 359 | for neg_pid in neg_pids: 360 | neg_data = GetProcessingFn( 361 | args, query=False)( 362 | passage_cache[neg_pid], neg_pid)[0] 363 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], 364 | neg_data[0], neg_data[1], neg_data[2]) 365 | 366 | return fn 367 | 368 | 369 | def get_arguments(): 370 | parser = argparse.ArgumentParser() 371 | 372 | parser.add_argument( 373 | "--data_dir", 374 | default=None, 375 | type=str, 376 | required=True, 377 | help="The input data dir", 378 | ) 379 | parser.add_argument( 380 | "--out_data_dir", 381 | default=None, 382 | type=str, 383 | required=True, 384 | help="The output data dir", 385 | ) 386 | parser.add_argument( 387 | "--model_type", 388 | default=None, 389 | type=str, 390 | required=True, 391 | help="We use rdot_nll in this work." 392 | ) 393 | parser.add_argument( 394 | "--model_name_or_path", 395 | default=None, 396 | type=str, 397 | required=True, 398 | help="We use roberta-base in this work." 399 | ) 400 | parser.add_argument( 401 | "--max_seq_length", 402 | default=128, 403 | type=int, 404 | help="The maximum total input sequence length after tokenization. Sequences longer " 405 | "than this will be truncated, sequences shorter will be padded.", 406 | ) 407 | parser.add_argument( 408 | "--max_query_length", 409 | default=64, 410 | type=int, 411 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 412 | "than this will be truncated, sequences shorter will be padded.", 413 | ) 414 | parser.add_argument( 415 | "--max_doc_character", 416 | default=10000, 417 | type=int, 418 | help="used before tokenizer to save tokenizer latency", 419 | ) 420 | parser.add_argument( 421 | "--data_type", 422 | default=0, 423 | type=int, 424 | help="0 for doc, 1 for passage", 425 | ) 426 | parser.add_argument( 427 | "--dataset", 428 | choices=["marco", "trec19psg", "trec20psg", "dlhard"], 429 | help="Name of the dataset.", 430 | required=True 431 | ) 432 | 433 | args = parser.parse_args() 434 | 435 | return args 436 | 437 | 438 | def main(): 439 | args = get_arguments() 440 | 441 | if not os.path.exists(args.out_data_dir): 442 | os.makedirs(args.out_data_dir) 443 | preprocess(args) 444 | 445 | 446 | if __name__ == '__main__': 447 | main() -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from utils.util import * 4 | from data import * 5 | from lamb import Lamb 6 | import torch 7 | from tqdm import tqdm 8 | from tensorboardX import SummaryWriter 9 | from transformers import AdamW, get_linear_schedule_with_warmup 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class Trainer: 14 | def __init__(self, args, model, tokenizer, train_dataloader): 15 | self.args = args 16 | self.model = model 17 | self.tokenizer = tokenizer 18 | self.train_dataloader = train_dataloader 19 | 20 | def train(self): 21 | args, model, tokenizer, train_dataloader = self.args, self.model, self.tokenizer, self.train_dataloader 22 | """ Train the model """ 23 | logger.info("Training/evaluation parameters %s", args) 24 | tb_writer = None 25 | if is_first_worker(): 26 | tb_writer = SummaryWriter(log_dir=os.path.join(args.output_dir, "tb_logs")) 27 | 28 | real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \ 29 | (torch.distributed.get_world_size() if args.local_rank != -1 else 1) 30 | 31 | optimizer_grouped_parameters = [] 32 | layer_optim_params = set() 33 | for layer_name in [ 34 | "roberta.embeddings", 35 | "score_out", 36 | "downsample1", 37 | "downsample2", 38 | "downsample3"]: 39 | layer = getattr_recursive(model, layer_name) 40 | if layer is not None: 41 | optimizer_grouped_parameters.append({"params": layer.parameters()}) 42 | for p in layer.parameters(): 43 | layer_optim_params.add(p) 44 | if getattr_recursive(model, "roberta.encoder.layer") is not None: 45 | for layer in model.roberta.encoder.layer: 46 | optimizer_grouped_parameters.append({"params": layer.parameters()}) 47 | for p in layer.parameters(): 48 | layer_optim_params.add(p) 49 | 50 | optimizer_grouped_parameters.append( 51 | {"params": [p for p in model.parameters() if p not in layer_optim_params]}) 52 | 53 | if args.optimizer.lower() == "lamb": 54 | optimizer = Lamb( 55 | optimizer_grouped_parameters, 56 | lr=args.learning_rate, 57 | eps=args.adam_epsilon) 58 | elif args.optimizer.lower() == "adamw": 59 | optimizer = AdamW( 60 | optimizer_grouped_parameters, 61 | lr=args.learning_rate, 62 | eps=args.adam_epsilon) 63 | else: 64 | raise Exception( 65 | "optimizer {0} not recognized! 
Can only be lamb or adamW".format(
66 | args.optimizer))
67 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
68 | num_training_steps=args.max_steps)
69 | 
70 | # Check if saved optimizer or scheduler states exist
71 | # Load in optimizer and scheduler states
72 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and args.load_optimizer_scheduler:
73 | optimizer.load_state_dict(
74 | torch.load(
75 | os.path.join(
76 | args.model_name_or_path,
77 | "optimizer.pt")))
78 | if os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) and args.load_optimizer_scheduler:
79 | scheduler.load_state_dict(
80 | torch.load(
81 | os.path.join(
82 | args.model_name_or_path,
83 | "scheduler.pt")))
84 | 
85 | 
86 | # multi-gpu training (should be after apex fp16 initialization)
87 | if args.n_gpu > 1:
88 | model = torch.nn.DataParallel(model)
89 | 
90 | # Distributed training (should be after apex fp16 initialization)
91 | if args.local_rank != -1:
92 | model = torch.nn.parallel.DistributedDataParallel(
93 | model,
94 | device_ids=[
95 | args.local_rank],
96 | output_device=args.local_rank,
97 | find_unused_parameters=True,
98 | )
99 | 
100 | # Train
101 | logger.info("***** Running training *****")
102 | logger.info(" Max steps = %d", args.max_steps)
103 | logger.info(
104 | " Instantaneous batch size per GPU = %d",
105 | args.per_gpu_train_batch_size)
106 | logger.info(
107 | " Total train batch size (w. parallel, distributed & accumulation) = %d",
108 | args.train_batch_size
109 | * args.gradient_accumulation_steps
110 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
111 | )
112 | logger.info(
113 | " Gradient Accumulation steps = %d",
114 | args.gradient_accumulation_steps)
115 | 
116 | global_step = 0
117 | # Check if continuing training from a checkpoint
118 | if os.path.exists(args.model_name_or_path):
119 | # set global_step to global_step of last saved checkpoint from model
120 | # path
121 | if "-" in args.model_name_or_path:
122 | global_step = int(
123 | args.model_name_or_path.split("-")[-1].split("/")[0])
124 | else:
125 | global_step = 0
126 | logger.info(
127 | " Continuing training from checkpoint, will skip to saved global_step")
128 | logger.info(" Continuing training from global step %d", global_step)
129 | 
130 | tr_loss = 0.0
131 | model.zero_grad()
132 | model.train()
133 | set_seed(args) # Added here for reproducibility
134 | 
135 | step = 0
136 | 
137 | train_dataloader = self.train_dataloader
138 | train_dataloader_iter = iter(train_dataloader)
139 | 
140 | 
141 | while global_step < args.max_steps:
142 | # pdb.set_trace()
143 | try:
144 | batch = next(train_dataloader_iter)
145 | except StopIteration:
146 | logger.info("Finished iterating current dataset, begin reiterate")
147 | train_dataloader_iter = iter(train_dataloader)
148 | batch = next(train_dataloader_iter)
149 | 
150 | batch = tuple(t.to(args.device) for t in batch)
151 | step += 1
152 | inputs = {
153 | "input_ids": batch[0].long(),
154 | "attention_mask": batch[1].long(),
155 | "pos_emb": batch[2].float(),
156 | "neg_emb": batch[3].float(),
157 | }
158 | # pdb.set_trace()
159 | # sync gradients only at gradient accumulation step
160 | if step % args.gradient_accumulation_steps == 0 or args.local_rank == -1:
161 | outputs = model(**inputs)
162 | else:
163 | with model.no_sync():
164 | outputs = model(**inputs)
165 | # model outputs are always tuple in transformers (see doc)
166 | loss = outputs[0]
167 | 
168 | if args.n_gpu > 1:
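# DataParallel returns one loss value per GPU, so average them into a scalar;
# the later division by gradient_accumulation_steps makes the accumulated
# gradient match what a single batch of size
# per_gpu_train_batch_size * n_gpu * gradient_accumulation_steps would produce.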
169 | loss = loss.mean() # mean() to average on multi-gpu parallel training 170 | if args.gradient_accumulation_steps > 1: 171 | loss = loss / args.gradient_accumulation_steps 172 | 173 | if step % args.gradient_accumulation_steps == 0 or args.local_rank == -1: 174 | loss.backward() 175 | else: 176 | with model.no_sync(): 177 | loss.backward() 178 | 179 | tr_loss += loss.item() 180 | if step % args.gradient_accumulation_steps == 0: 181 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 182 | optimizer.step() 183 | scheduler.step() # Update learning rate schedule 184 | model.zero_grad() 185 | global_step += 1 186 | 187 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 188 | logs = {} 189 | loss_scalar = tr_loss / args.logging_steps 190 | learning_rate_scalar = scheduler.get_lr()[0] 191 | logs["learning_rate"] = learning_rate_scalar 192 | logs["loss"] = loss_scalar 193 | tr_loss = 0 194 | 195 | if is_first_worker(): 196 | for key, value in logs.items(): 197 | tb_writer.add_scalar(key, value, global_step) 198 | logger.info(json.dumps({**logs, **{"step": global_step}})) 199 | 200 | if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0: 201 | # Save model checkpoint 202 | output_dir = os.path.join( 203 | args.output_dir, "checkpoint-{}".format(global_step)) 204 | if not os.path.exists(output_dir): 205 | os.makedirs(output_dir) 206 | model_to_save = ( 207 | model.module if hasattr(model, "module") else model 208 | ) # Take care of distributed/parallel training 209 | model_to_save.save_pretrained(output_dir) 210 | tokenizer.save_pretrained(output_dir) 211 | 212 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 213 | logger.info("Saving model checkpoint to %s", output_dir) 214 | 215 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 216 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 217 | # make sure logging_steps < save_steps 218 | torch.save( 219 | { 220 | "loss": loss_scalar, 221 | "learning_rate": learning_rate_scalar 222 | }, 223 | os.path.join(output_dir, "stats.pt") 224 | ) 225 | logger.info( 226 | "Saving optimizer and scheduler states to %s", 227 | output_dir) 228 | 229 | if is_first_worker(): 230 | tb_writer.close() 231 | 232 | return global_step 233 | 234 | 235 | class Evaluator: 236 | def __init__(self, args, dev_dataloader): 237 | self.args = args 238 | self.dev_dataloader = dev_dataloader 239 | self.prev_evaluated_ckpt = 0 if args.prev_evaluated_ckpt is None else args.prev_evaluated_ckpt 240 | 241 | def eval(self, rerank_depths=None, mode="rerank", results=None): 242 | dataset = self.args.dataset 243 | _, query_embedding2id, passage_embedding, passage_embedding2id = load_embeddings(self.args, mode="dev") 244 | # query_embs, query_embedding2id, passage_embedding, passage_embedding2id = load_embeddings(self.args, mode="dev") 245 | dev_query_positive_id = load_positve_query_id(self.args) 246 | binary_dev_query_positive_id = None if dataset in {"marco", "marco_eval"} else load_positve_query_id(self.args, binary=True) 247 | 248 | topN = 1000 249 | if is_first_worker(): 250 | # TODO: this is for current compatibility. 
change all names to f"eval_logs_{mode}_{dataset}" later 251 | name = f"eval_logs_{mode}" if dataset == "marco" else f"eval_logs_{mode}_{dataset}" 252 | tb_writer = SummaryWriter(log_dir=os.path.join(self.args.output_dir, name)) 253 | while self.prev_evaluated_ckpt < self.args.max_steps: 254 | curr_ckpt = self.prev_evaluated_ckpt + self.args.save_steps 255 | if curr_ckpt >= self.args.end_eval_ckpt: 256 | break 257 | output_dir = os.path.join(self.args.output_dir, f"checkpoint-{curr_ckpt}") 258 | while not os.path.exists(output_dir): 259 | logging.info(f"Waiting for step {curr_ckpt}") 260 | time.sleep(100) 261 | logging.info(f"Evaluating step {curr_ckpt}...") 262 | self.args.model_name_or_path = output_dir 263 | 264 | model = None 265 | while model is None: 266 | try: 267 | _, model = load_model(self.args) 268 | except: 269 | time.sleep(100) 270 | pass 271 | model.eval() 272 | 273 | # query embedding inference 274 | dev_query_embs_path = os.path.join(output_dir, f"dev_query_embs_{dataset}.npy") 275 | # TODO: distributed eval currently not supported (currently not necessary considering the size of marco dev) 276 | loss = 0 277 | cnt = 0 278 | if os.path.exists(dev_query_embs_path): 279 | query_embs = np.load(dev_query_embs_path) 280 | else: 281 | query_embs = [] 282 | for batch in tqdm(self.dev_dataloader, desc=f"Eval ckpt{curr_ckpt}"): 283 | cnt += 1 284 | batch = tuple(t.to(self.args.device) for t in batch) 285 | inputs = { 286 | "input_ids": batch[0].long(), 287 | "attention_mask": batch[1].long(), 288 | "pos_emb": batch[2].float(), 289 | "neg_emb": batch[3].float(), 290 | } 291 | query_emb = model.query_emb(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]) 292 | loss += model(**inputs)[0].item() 293 | query_embs.append(query_emb.detach().cpu().numpy()) 294 | # 20 is the number of negative documents per positive document 295 | query_embs = np.concatenate(query_embs) 296 | np.save(os.path.join(output_dir, f"dev_query_embs_{dataset}.npy"), query_embs) 297 | logs = dict() 298 | if mode == "full": 299 | dev_I = full_rank(query_embs, passage_embedding, output_dir, dataset) 300 | result = EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I, topN) 301 | final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, \ 302 | metrics, prediction = result 303 | logs = { 304 | "full_ndcg_10": final_ndcg, 305 | "full_map_10": final_Map, 306 | "full_pytrec_mrr": final_mrr, 307 | "full_ms_mrr": ms_mrr["MRR @10"], 308 | f"full_recall_{topN}": final_recall, 309 | f"hole_rate": hole_rate, 310 | f"Ahole_rate": Ahole_rate 311 | } 312 | if dataset not in {"marco", "marco_eval"}: 313 | binary_result = EvalDevQuery(query_embedding2id, passage_embedding2id, binary_dev_query_positive_id, 314 | dev_I, topN) 315 | _, _, _, _, binary_recall, _, binary_ms_mrr, _, _, _ = binary_result 316 | logs[f"full_binary_recall_{topN}"] = binary_recall 317 | logs["full_binary_ms_mrr"] = binary_ms_mrr["MRR @10"] 318 | 319 | elif mode == "rerank": 320 | first_stage_inn = np.load(self.args.first_stage_inn_path, allow_pickle=True) 321 | dev_I = rerank(first_stage_inn, query_embs, query_embedding2id, passage_embedding, passage_embedding2id, 322 | output_dir, dataset) 323 | reranked_w_scores = [] 324 | for inn in dev_I: 325 | reranked_w_scores.append({pid: rank for (rank, pid) in enumerate(inn)}) # rank, id 326 | if rerank_depths is not None: 327 | for depth in rerank_depths: 328 | reranked_I = get_inn_rerank_depth(reranked_w_scores, first_stage_inn, 
depth) 329 | result = EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, 330 | reranked_I, topN) 331 | final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, \ 332 | metrics, prediction = result 333 | if loss > 0 and cnt > 0: 334 | logs["loss"] = loss / cnt 335 | logs[f"reranked_depth{depth}_ndcg_10"] = final_ndcg 336 | logs[f"reranked_depth{depth}_map_10"] = final_Map 337 | logs[f"reranked_depth{depth}_pytrec_mrr"] = final_mrr 338 | logs[f"reranked_depth{depth}_ms_mrr"] = ms_mrr["MRR @10"] 339 | logs["hole_rate"] = hole_rate 340 | logs["Ahole_rate"] = Ahole_rate 341 | if dataset not in {"marco", "marco_eval"}: 342 | binary_result = EvalDevQuery(query_embedding2id, passage_embedding2id, 343 | binary_dev_query_positive_id, 344 | reranked_I, topN) 345 | _, _, _, _, binary_recall, _, binary_ms_mrr, _, _, _ = binary_result 346 | logs[f"rerank_depth{depth}_binary_recall_{topN}"] = binary_recall 347 | logs[f"rerank_depth{depth}_binary_ms_mrr"] = binary_ms_mrr["MRR @10"] 348 | if results is not None: 349 | results.append(logs) 350 | if is_first_worker(): 351 | logger.info(json.dumps({**logs, **{"step": curr_ckpt}})) 352 | for key, value in logs.items(): 353 | tb_writer.add_scalar(key, value, curr_ckpt) 354 | self.prev_evaluated_ckpt = curr_ckpt 355 | 356 | 357 | -------------------------------------------------------------------------------- /data_prep/run_ann_data_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | import torch 5 | from utils.util import ( 6 | barrier_array_merge, 7 | convert_to_string_id, 8 | is_first_worker, 9 | UtilStreamingDataset, 10 | UtilEmbeddingCache, 11 | get_checkpoint_no, 12 | get_latest_ann_data, 13 | ) 14 | import csv 15 | from preprocess_data import GetProcessingFn 16 | from model import RobertaDot_NLL_LN 17 | from transformers import RobertaTokenizer, RobertaConfig 18 | import torch.distributed as dist 19 | from tqdm import tqdm, trange 20 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 21 | import numpy as np 22 | import argparse 23 | import logging 24 | import random 25 | import time 26 | import pytrec_eval 27 | 28 | torch.multiprocessing.set_sharing_strategy('file_system') 29 | 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | # ANN - active learning ------------------------------------------------------ 35 | 36 | try: 37 | from torch.utils.tensorboard import SummaryWriter 38 | except ImportError: 39 | from tensorboardX import SummaryWriter 40 | 41 | 42 | def get_latest_checkpoint(args): 43 | if not os.path.exists(args.training_dir): 44 | return args.init_model_dir, 0 45 | subdirectories = list(next(os.walk(args.training_dir))[1]) 46 | 47 | def valid_checkpoint(checkpoint): 48 | chk_path = os.path.join(args.training_dir, checkpoint) 49 | scheduler_path = os.path.join(chk_path, "scheduler.pt") 50 | return os.path.exists(scheduler_path) 51 | 52 | checkpoint_nums = [get_checkpoint_no( 53 | s) for s in subdirectories if valid_checkpoint(s)] 54 | 55 | if len(checkpoint_nums) > 0: 56 | return os.path.join(args.training_dir, "checkpoint-" + 57 | str(max(checkpoint_nums))) + "/", max(checkpoint_nums) 58 | return args.init_model_dir, 0 59 | 60 | 61 | def load_positive_ids(args): 62 | 63 | logger.info("Loading query_2_pos_docid") 64 | training_query_positive_id = {} 65 | # query_positive_id_path = 
os.path.join(args.data_dir, "train-qrel.tsv") 66 | # with open(query_positive_id_path, 'r', encoding='utf8') as f: 67 | # tsvreader = csv.reader(f, delimiter="\t") 68 | # for [topicid, docid, rel] in tsvreader: 69 | # assert rel == "1" 70 | # topicid = int(topicid) 71 | # docid = int(docid) 72 | # training_query_positive_id[topicid] = docid 73 | 74 | logger.info("Loading dev query_2_pos_docid") 75 | dev_query_positive_id = {} 76 | query_positive_id_path = os.path.join(args.data_dir, "dev-qrel.tsv") 77 | 78 | with open(query_positive_id_path, 'r', encoding='utf8') as f: 79 | tsvreader = csv.reader(f, delimiter="\t") 80 | for [topicid, docid, rel] in tsvreader: 81 | topicid = int(topicid) 82 | docid = int(docid) 83 | if topicid not in dev_query_positive_id: 84 | dev_query_positive_id[topicid] = {} 85 | dev_query_positive_id[topicid][docid] = int(rel) 86 | 87 | return training_query_positive_id, dev_query_positive_id 88 | 89 | 90 | def load_model(args, checkpoint_path): 91 | args.model_type = args.model_type.lower() 92 | args.model_name_or_path = checkpoint_path 93 | config = RobertaConfig.from_pretrained(args.model_name_or_path, num_labels=2, finetuning_task="MSMarco") 94 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=True) 95 | model = RobertaDot_NLL_LN.from_pretrained( 96 | args.model_name_or_path, 97 | from_tf=bool(".ckpt" in args.model_name_or_path), 98 | config=config, 99 | cache_dir=args.cache_dir if args.cache_dir else None, 100 | ) 101 | model.to(args.device) 102 | logger.info("Inference parameters %s", args) 103 | if args.local_rank != -1: 104 | model = torch.nn.parallel.DistributedDataParallel( 105 | model, 106 | device_ids=[ 107 | args.local_rank], 108 | output_device=args.local_rank, 109 | find_unused_parameters=True, 110 | ) 111 | return config, tokenizer, model 112 | 113 | 114 | def InferenceEmbeddingFromStreamDataLoader( 115 | args, 116 | model, 117 | train_dataloader, 118 | is_query_inference=True, 119 | prefix=""): 120 | # expect dataset from ReconstructTrainingSet 121 | results = {} 122 | eval_batch_size = args.per_gpu_eval_batch_size 123 | 124 | # Inference! 
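# The loop below streams batches through the encoder and collects embeddings
# and their offset ids in two aligned lists. For long inputs the model may
# return one embedding per chunk (a 3-D array), in which case each chunk is
# appended as its own row under the same id before the final np.concatenate.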
125 | logger.info("***** Running ANN Embedding Inference *****") 126 | logger.info(" Batch size = %d", eval_batch_size) 127 | 128 | embedding = [] 129 | embedding2id = [] 130 | 131 | if args.local_rank != -1: 132 | dist.barrier() 133 | model.eval() 134 | 135 | for batch in tqdm(train_dataloader, 136 | desc="Inferencing", 137 | disable=args.local_rank not in [-1, 138 | 0], 139 | position=0, 140 | leave=True): 141 | 142 | idxs = batch[3].detach().numpy() # [#B] 143 | 144 | batch = tuple(t.to(args.device) for t in batch) 145 | 146 | with torch.no_grad(): 147 | inputs = { 148 | "input_ids": batch[0].long(), 149 | "attention_mask": batch[1].long()} 150 | if is_query_inference: 151 | embs = (model.module if hasattr(model, "module") else model).query_emb(**inputs) # .module only exists under DDP; use the bare model on a single GPU 152 | else: 153 | embs = (model.module if hasattr(model, "module") else model).body_emb(**inputs) 154 | 155 | embs = embs.detach().cpu().numpy() 156 | 157 | # check for multi-chunk output for long sequences 158 | if len(embs.shape) == 3: 159 | for chunk_no in range(embs.shape[1]): 160 | embedding2id.append(idxs) 161 | embedding.append(embs[:, chunk_no, :]) 162 | else: 163 | embedding2id.append(idxs) 164 | embedding.append(embs) 165 | 166 | embedding = np.concatenate(embedding, axis=0) 167 | embedding2id = np.concatenate(embedding2id, axis=0) 168 | return embedding, embedding2id 169 | 170 | 171 | # streaming inference 172 | def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True): 173 | inference_batch_size = args.per_gpu_eval_batch_size # * max(1, args.n_gpu) 174 | inference_dataset = UtilStreamingDataset(f, fn) 175 | inference_dataloader = DataLoader( 176 | inference_dataset, 177 | batch_size=inference_batch_size) 178 | 179 | if args.local_rank != -1: 180 | dist.barrier() # directory created 181 | 182 | _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader( 183 | args, model, inference_dataloader, is_query_inference=is_query_inference, prefix=prefix) 184 | 185 | logger.info(f"merging {prefix} embeddings") 186 | 187 | barrier_array_merge( 188 | args, 189 | _embedding, 190 | prefix=prefix + 191 | "_emb_p_", 192 | load_cache=False) 193 | 194 | logger.info(f"finished merging {prefix} embeddings") 195 | 196 | barrier_array_merge( 197 | args, 198 | _embedding2id, 199 | prefix=prefix + 200 | "_embid_p_", 201 | load_cache=False) 202 | 203 | logger.info(f"finished merging {prefix} embedding ids") 204 | 205 | def generate_new_ann( 206 | args, 207 | checkpoint_path, 208 | latest_step_num): 209 | config, tokenizer, model = load_model(args, checkpoint_path) 210 | 211 | logger.info("***** inference of dev query *****") 212 | dev_query_collection_path = os.path.join(args.data_dir, "dev-query") 213 | dev_query_cache = UtilEmbeddingCache(dev_query_collection_path) 214 | with dev_query_cache as emb: 215 | StreamInferenceDoc(args, model, GetProcessingFn( 216 | args, query=True), "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True) 217 | 218 | if args.dataset == "marco": 219 | logger.info("***** inference of train query *****") 220 | train_query_collection_path = os.path.join(args.data_dir, "train-query") 221 | train_query_cache = UtilEmbeddingCache(train_query_collection_path) 222 | with train_query_cache as emb: 223 | StreamInferenceDoc(args, model, GetProcessingFn( 224 | args, query=True), "train_query_" + str(latest_step_num) + "_", emb, is_query_inference=True) 225 | 226 | logger.info("***** inference of passages *****") 227 | passage_collection_path = os.path.join(args.data_dir, "passages") 228 | passage_cache = UtilEmbeddingCache(passage_collection_path) 229 | with 
passage_cache as emb: 230 | StreamInferenceDoc(args, model, GetProcessingFn(args, query=False), "passage_" + str(latest_step_num) + "_", 231 | emb, is_query_inference=False) 232 | logger.info("***** Done passage inference *****") 233 | 234 | 235 | def GenerateNegativePassaageID( 236 | args, 237 | query_embedding2id, 238 | passage_embedding2id, 239 | training_query_positive_id, 240 | I_nearest_neighbor, 241 | effective_q_id): 242 | query_negative_passage = {} 243 | SelectTopK = args.ann_measure_topk_mrr 244 | mrr = 0 # only meaningful if it is SelectTopK = True 245 | num_queries = 0 246 | 247 | for query_idx in range(I_nearest_neighbor.shape[0]): 248 | 249 | query_id = query_embedding2id[query_idx] 250 | 251 | if query_id not in effective_q_id: 252 | continue 253 | 254 | num_queries += 1 255 | 256 | pos_pid = training_query_positive_id[query_id] 257 | top_ann_pid = I_nearest_neighbor[query_idx, :].copy() 258 | 259 | if SelectTopK: 260 | selected_ann_idx = top_ann_pid[:args.negative_sample + 1] 261 | else: 262 | negative_sample_I_idx = list(range(I_nearest_neighbor.shape[1])) 263 | random.shuffle(negative_sample_I_idx) 264 | selected_ann_idx = top_ann_pid[negative_sample_I_idx] 265 | 266 | query_negative_passage[query_id] = [] 267 | 268 | neg_cnt = 0 269 | rank = 0 270 | 271 | for idx in selected_ann_idx: 272 | neg_pid = passage_embedding2id[idx] 273 | rank += 1 274 | if neg_pid == pos_pid: 275 | if rank <= 10: 276 | mrr += 1 / rank 277 | continue 278 | 279 | if neg_pid in query_negative_passage[query_id]: 280 | continue 281 | 282 | if neg_cnt >= args.negative_sample: 283 | break 284 | 285 | query_negative_passage[query_id].append(neg_pid) 286 | neg_cnt += 1 287 | 288 | if SelectTopK: 289 | print("Rank:" + str(args.rank) + 290 | " --- ANN MRR:" + str(mrr / num_queries)) 291 | 292 | return query_negative_passage 293 | 294 | 295 | def EvalDevQuery( 296 | args, 297 | query_embedding2id, 298 | passage_embedding2id, 299 | dev_query_positive_id, 300 | I_nearest_neighbor): 301 | # [qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2) 302 | prediction = {} 303 | 304 | for query_idx in range(I_nearest_neighbor.shape[0]): 305 | query_id = query_embedding2id[query_idx] 306 | prediction[query_id] = {} 307 | 308 | top_ann_pid = I_nearest_neighbor[query_idx, :].copy() 309 | selected_ann_idx = top_ann_pid[:50] 310 | rank = 0 311 | seen_pid = set() 312 | for idx in selected_ann_idx: 313 | pred_pid = passage_embedding2id[idx] 314 | 315 | if pred_pid not in seen_pid: 316 | # this check handles multiple vector per document 317 | rank += 1 318 | prediction[query_id][pred_pid] = -rank 319 | seen_pid.add(pred_pid) 320 | 321 | # use out of the box evaluation script 322 | evaluator = pytrec_eval.RelevanceEvaluator( 323 | convert_to_string_id(dev_query_positive_id), {'map', 'map_cut', 'ndcg_cut'}) 324 | 325 | eval_query_cnt = 0 326 | result = evaluator.evaluate(convert_to_string_id(prediction)) 327 | metric_results = { 328 | "map": 0, 329 | "ndcg_cut_3": 0, 330 | "ndcg_cut_5": 0, 331 | "ndcg_cut_10": 0 332 | } 333 | 334 | for k in result.keys(): 335 | eval_query_cnt += 1 336 | for metric in metric_results.keys(): 337 | metric_results[metric] += result[k][metric] 338 | # ndcg += result[k]["ndcg_cut_10"] 339 | for metric in metric_results.keys(): 340 | metric_results[metric] /= eval_query_cnt 341 | # final_ndcg = ndcg / eval_query_cnt 342 | final_ndcg = metric_results["ndcg_cut_3"] 343 | 344 | # print("Rank:" + str(args.rank) + " --- ANN NDCG@10:" + 
str(final_ndcg)) 345 | print("Rank:" + str(args.rank) + " --- ANN Results:" + str(metric_results)) 346 | 347 | return final_ndcg, eval_query_cnt 348 | 349 | 350 | def get_arguments(): 351 | parser = argparse.ArgumentParser() 352 | 353 | # Required parameters 354 | parser.add_argument( 355 | "--data_dir", 356 | default=None, 357 | type=str, 358 | required=True, 359 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.", 360 | ) 361 | 362 | parser.add_argument( 363 | "--training_dir", 364 | default="", 365 | type=str, 366 | help="Training dir, will look for latest checkpoint dir in here", 367 | ) 368 | 369 | parser.add_argument( 370 | "--init_model_dir", 371 | default=None, 372 | type=str, 373 | required=True, 374 | help="Initial model dir, will use this if no checkpoint is found in model_dir", 375 | ) 376 | 377 | parser.add_argument( 378 | "--last_checkpoint_dir", 379 | default="", 380 | type=str, 381 | help="Last checkpoint used, this is for rerunning this script when some ann data is already generated", 382 | ) 383 | 384 | parser.add_argument( 385 | "--model_type", 386 | default=None, 387 | type=str, 388 | required=True, 389 | help="Model type." 390 | ) 391 | 392 | parser.add_argument( 393 | "--output_dir", 394 | default=None, 395 | type=str, 396 | required=True, 397 | help="The output directory where the training data will be written", 398 | ) 399 | 400 | parser.add_argument( 401 | "--cache_dir", 402 | default=None, 403 | type=str, 404 | required=True, 405 | help="The directory where cached data will be written", 406 | ) 407 | 408 | parser.add_argument( 409 | "--end_output_num", 410 | default=- 411 | 1, 412 | type=int, 413 | help="Stop after this number of data versions has been generated, default run forever", 414 | ) 415 | 416 | parser.add_argument( 417 | "--max_seq_length", 418 | default=128, 419 | type=int, 420 | help="The maximum total input sequence length after tokenization. Sequences longer " 421 | "than this will be truncated, sequences shorter will be padded.", 422 | ) 423 | 424 | parser.add_argument( 425 | "--max_query_length", 426 | default=64, 427 | type=int, 428 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 429 | "than this will be truncated, sequences shorter will be padded.", 430 | ) 431 | 432 | parser.add_argument( 433 | "--max_doc_character", 434 | default=10000, 435 | type=int, 436 | help="truncate document text to this many characters before tokenization, to save tokenizer latency", 437 | ) 438 | 439 | parser.add_argument( 440 | "--per_gpu_eval_batch_size", 441 | default=128, 442 | type=int, 443 | help="batch size per GPU for embedding inference", 444 | ) 445 | 446 | parser.add_argument( 447 | "--ann_chunk_factor", 448 | default=5, # e.g. 500k training queries divided by a factor of 5 gives 100k-query chunks per epoch 449 | type=int, 450 | help="divide training queries into this many chunks", 451 | ) 452 | 453 | parser.add_argument( 454 | "--topk_training", 455 | default=500, 456 | type=int, 457 | help="top k from which negative samples are collected", 458 | ) 459 | 460 | parser.add_argument( 461 | "--negative_sample", 462 | default=5, 463 | type=int, 464 | help="how many negative samples to use per query at each resample", 465 | ) 466 | 467 | parser.add_argument( 468 | "--ann_measure_topk_mrr", 469 | default=False, 470 | action="store_true", 471 | help="whether to measure MRR of the top-k ANN retrieval while sampling negatives", 472 | ) 473 | 474 | parser.add_argument( 475 | "--only_keep_latest_embedding_file", 476 | default=False, 477 | action="store_true", 478 | help="whether to keep only the latest embedding files on disk", 479 | ) 480 | 481 | parser.add_argument( 482 | "--no_cuda", 483 | action="store_true", 484 | help="Avoid using CUDA when available", 485 | ) 486 | 487 | parser.add_argument( 488 | "--local_rank", 489 | type=int, 490 | default=-1, 491 | help="For distributed training: local_rank", 492 | ) 493 | 494 | parser.add_argument( 495 | "--server_ip", 496 | type=str, 497 | default="", 498 | help="For distant debugging.", 499 | ) 500 | 501 | parser.add_argument( 502 | "--server_port", 503 | type=str, 504 | default="", 505 | help="For distant debugging.", 506 | ) 507 | 508 | parser.add_argument( 509 | "--inference", 510 | default=False, 511 | action="store_true", 512 | help="only do inference if specified", 513 | ) 514 | 515 | parser.add_argument( 516 | "--encode_passages", 517 | default=False, 518 | action="store_true", 519 | help="only encode the passages if specified", 520 | ) 521 | 522 | parser.add_argument( 523 | "--dataset", 524 | default=None, 525 | help="Name of the dataset.", 526 | )
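# A hypothetical invocation, for orientation only (every path below is a
# placeholder, not a layout this script requires):
#
#   python -m torch.distributed.launch --nproc_per_node=4 run_ann_data_gen.py \
#     --data_dir data/marco_preprocessed \
#     --init_model_dir model/ance_checkpoint \
#     --training_dir outputs/k_3 \
#     --model_type rdot_nll \
#     --output_dir data/marco_output \
#     --cache_dir data/marco_output/cache \
#     --dataset marco \
#     --inference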
527 | 528 | args = parser.parse_args() 529 | 530 | return args 531 | 532 | 533 | def set_env(args): 534 | # Setup CUDA, GPU & distributed training 535 | if args.local_rank == -1 or args.no_cuda: 536 | device = torch.device( 537 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 538 | args.n_gpu = torch.cuda.device_count() 539 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 540 | torch.cuda.set_device(args.local_rank) 541 | device = torch.device("cuda", args.local_rank) 542 | torch.distributed.init_process_group(backend="nccl") 543 | args.n_gpu = 1 544 | args.device = device 545 | 546 | # store args 547 | if args.local_rank != -1: 548 | args.world_size = torch.distributed.get_world_size() 549 | args.rank = dist.get_rank() 550 | 551 | # Setup logging 552 | logging.basicConfig( 553 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 554 | datefmt="%m/%d/%Y %H:%M:%S", 555 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 556 | ) 557 | logger.warning( 558 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s", 559 | args.local_rank, 560 | device, 561 | args.n_gpu, 562 | bool(args.local_rank != -1), 563 | ) 564 | 565 | 566 | def ann_data_gen(args): 567 | last_checkpoint = args.last_checkpoint_dir 568 | ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir) 569 | output_num = ann_no + 1 570 | 571 | logger.info("starting output number %d", output_num) 572 | 573 | if is_first_worker(): 574 | if not os.path.exists(args.output_dir): 575 | os.makedirs(args.output_dir) 576 | if not os.path.exists(args.cache_dir): 577 | os.makedirs(args.cache_dir) 578 | if args.local_rank != -1: dist.barrier() # the barrier is only valid in distributed runs 579 | 580 | while args.end_output_num == -1 or output_num <= args.end_output_num: 581 | next_checkpoint, latest_step_num = get_latest_checkpoint(args) 582 | 583 | if args.only_keep_latest_embedding_file: 584 | latest_step_num = 0 585 | 586 | if next_checkpoint == last_checkpoint: 587 | time.sleep(60) 588 | else: 589 | logger.info("start generating ann data number %d", output_num) 590 | logger.info("next checkpoint at " + next_checkpoint) 591 | generate_new_ann(args, next_checkpoint, latest_step_num) 592 | if args.inference: 593 | break 594 | logger.info("finished generating ann data number %d", output_num) 595 | output_num += 1 596 | last_checkpoint = next_checkpoint 597 | if args.local_rank != -1: 598 | dist.barrier() 599 | 600 | 601 | def main(): 602 | args = get_arguments() 603 | set_env(args) 604 | ann_data_gen(args) 605 | 606 | 607 | if __name__ == "__main__": 608 | main() 609 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import listdir 4 | from os.path import isfile, join 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 6 | import pandas as pd 7 | from sklearn.metrics import roc_curve, auc 8 | import gzip 9 | import copy 10 | import torch 11 | from torch import nn 12 | import torch.distributed as dist 13 | from tqdm import tqdm, trange 14 | import json 15 | import logging 16 | import random 17 | import pytrec_eval 18 | import pickle 19 | import numpy as np 20 | import torch 21 | import csv 22 | import faiss 23 | from msmarco_eval import compute_metrics 24 | 25 | torch.multiprocessing.set_sharing_strategy('file_system') 26 | from multiprocessing import Process 27 | from torch.utils.data import 
DataLoader, Dataset, TensorDataset, IterableDataset 28 | import re 29 | from model import MSMarcoConfig, RobertaDot_NLL_LN 30 | from transformers import RobertaConfig, RobertaTokenizer 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def load_embedding_prefix(prefix): 36 | embedding = [] 37 | embedding2id = [] 38 | for i in range(8): # embeddings are dumped as one shard per worker; this assumes at most 8 shards 39 | try: 40 | with open(prefix + "__emb_p__data_obj_" + str(i) + ".pb", 41 | "rb") as handle: 42 | embedding.append(np.load(handle, allow_pickle=True)) 43 | with open(prefix + "__embid_p__data_obj_" + str(i) + ".pb", 44 | "rb") as handle: 45 | embedding2id.append(np.load(handle, allow_pickle=True)) 46 | except Exception: # stop at the first missing shard 47 | print(f"Loaded {i} chunks of embeddings.") 48 | break 49 | embedding = np.concatenate(embedding, axis=0) 50 | embedding2id = np.concatenate(embedding2id, axis=0) 51 | return embedding, embedding2id 52 | 53 | 54 | def load_embeddings(args, mode="train", checkpoint=0): 55 | 56 | query_prefix = f"{mode}_query_" 57 | 58 | query_embedding, query_embedding2id = load_embedding_prefix(os.path.join(args.ance_checkpoint_path, query_prefix + str(checkpoint))) 59 | passage_embedding, passage_embedding2id = load_embedding_prefix(os.path.join(args.ance_checkpoint_path, "passage_" + str(checkpoint))) 60 | 61 | return query_embedding, query_embedding2id, passage_embedding, passage_embedding2id 62 | 63 | 64 | class InputFeaturesPair(object): 65 | """ 66 | A single set of features of data. 67 | 68 | Args: 69 | input_ids: Indices of input sequence tokens in the vocabulary. 70 | attention_mask: Mask to avoid performing attention on padding token indices. 71 | Mask values selected in ``[0, 1]``: 72 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 73 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 
74 | label: Label corresponding to the input 75 | """ 76 | 77 | def __init__( 78 | self, 79 | input_ids_a, 80 | attention_mask_a=None, 81 | token_type_ids_a=None, 82 | input_ids_b=None, 83 | attention_mask_b=None, 84 | token_type_ids_b=None, 85 | label=None): 86 | self.input_ids_a = input_ids_a 87 | self.attention_mask_a = attention_mask_a 88 | self.token_type_ids_a = token_type_ids_a 89 | 90 | self.input_ids_b = input_ids_b 91 | self.attention_mask_b = attention_mask_b 92 | self.token_type_ids_b = token_type_ids_b 93 | 94 | self.label = label 95 | 96 | def __repr__(self): 97 | return str(self.to_json_string()) 98 | 99 | def to_dict(self): 100 | """Serializes this instance to a Python dictionary.""" 101 | output = copy.deepcopy(self.__dict__) 102 | return output 103 | 104 | def to_json_string(self): 105 | """Serializes this instance to a JSON string.""" 106 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 107 | 108 | 109 | def getattr_recursive(obj, name): 110 | for layer in name.split("."): 111 | if hasattr(obj, layer): 112 | obj = getattr(obj, layer) 113 | else: 114 | return None 115 | return obj 116 | 117 | 118 | def barrier_array_merge( 119 | args, 120 | data_array, 121 | prefix="", 122 | load_cache=False): 123 | # data array: [B, any dimension] 124 | # merge alone one axis 125 | 126 | if args.local_rank == -1: 127 | return data_array 128 | 129 | if not load_cache: 130 | rank = args.rank 131 | 132 | pickle_path = os.path.join( 133 | args.output_dir, 134 | "{1}_data_obj_{0}.pb".format( 135 | str(rank), 136 | prefix)) 137 | 138 | with open(pickle_path, 'wb') as handle: 139 | np.save(handle, data_array) 140 | 141 | 142 | def pad_input_ids(input_ids, max_length, 143 | pad_on_left=False, 144 | pad_token=0): 145 | padding_length = max_length - len(input_ids) 146 | padding_id = [pad_token] * padding_length 147 | 148 | if padding_length <= 0: 149 | input_ids = input_ids[:max_length] 150 | else: 151 | if pad_on_left: 152 | input_ids = padding_id + input_ids 153 | else: 154 | input_ids = input_ids + padding_id 155 | 156 | return input_ids 157 | 158 | def pad_ids(input_ids, attention_mask, token_type_ids, max_length, 159 | pad_on_left=False, 160 | pad_token=0, 161 | pad_token_segment_id=0, 162 | mask_padding_with_zero=True): 163 | padding_length = max_length - len(input_ids) 164 | padding_id = [pad_token] * padding_length 165 | padding_type = [pad_token_segment_id] * padding_length 166 | padding_attention = [0 if mask_padding_with_zero else 1] * padding_length 167 | 168 | if padding_length <= 0: 169 | input_ids = input_ids[:max_length] 170 | attention_mask = attention_mask[:max_length] 171 | token_type_ids = token_type_ids[:max_length] 172 | else: 173 | if pad_on_left: 174 | input_ids = padding_id + input_ids 175 | attention_mask = padding_attention + attention_mask 176 | token_type_ids = padding_type + token_type_ids 177 | else: 178 | input_ids = input_ids + padding_id 179 | attention_mask = attention_mask + padding_attention 180 | token_type_ids = token_type_ids + padding_type 181 | 182 | return input_ids, attention_mask, token_type_ids 183 | 184 | 185 | def triple_process_fn(line, i, tokenizer, args): 186 | features = [] 187 | cells = line.split("\t") 188 | if len(cells) == 3: 189 | # this is for training and validation 190 | # query, positive_passage, negative_passage = line 191 | mask_padding_with_zero = True 192 | pad_token_segment_id = 0 193 | pad_on_left = False 194 | 195 | for text in cells: 196 | input_id_a = tokenizer.encode( 197 | text.strip(), 
add_special_tokens=True, max_length=args.max_seq_length, ) 198 | token_type_ids_a = [0] * len(input_id_a) 199 | attention_mask_a = [ 200 | 1 if mask_padding_with_zero else 0] * len(input_id_a) 201 | input_id_a, attention_mask_a, token_type_ids_a = pad_ids( 202 | input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, pad_on_left=pad_on_left, 203 | pad_token=tokenizer.pad_token_id, pad_token_segment_id=pad_token_segment_id, mask_padding_with_zero=mask_padding_with_zero) 204 | features += [torch.tensor(input_id_a, dtype=torch.int), 205 | torch.tensor(attention_mask_a, dtype=torch.bool)] 206 | else: 207 | raise Exception( 208 | "Line doesn't have correct length: {0}. Expected 3.".format(str(len(cells)))) 209 | return [features] 210 | 211 | 212 | # to reuse pytrec_eval, id must be string 213 | def convert_to_string_id(result_dict): 214 | string_id_dict = {} 215 | 216 | # format [string, dict[string, val]] 217 | for k, v in result_dict.items(): 218 | _temp_v = {} 219 | for inner_k, inner_v in v.items(): 220 | _temp_v[str(inner_k)] = inner_v 221 | 222 | string_id_dict[str(k)] = _temp_v 223 | 224 | return string_id_dict 225 | 226 | 227 | def set_seed(args): 228 | random.seed(args.seed) 229 | np.random.seed(args.seed) 230 | torch.manual_seed(args.seed) 231 | if args.n_gpu > 0: 232 | torch.cuda.manual_seed_all(args.seed) 233 | 234 | 235 | def set_env(args): 236 | # Setup CUDA, GPU & distributed training 237 | if args.local_rank == -1: 238 | device = torch.device( 239 | "cuda" if torch.cuda.is_available() else "cpu") 240 | args.n_gpu = torch.cuda.device_count() 241 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 242 | torch.cuda.set_device(args.local_rank) 243 | device = torch.device("cuda", args.local_rank) 244 | torch.distributed.init_process_group(backend="nccl") 245 | args.n_gpu = 1 246 | args.device = device 247 | 248 | # Setup logging 249 | logging.basicConfig( 250 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 251 | datefmt="%m/%d/%Y %H:%M:%S", 252 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 253 | ) 254 | logger.warning( 255 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s", 256 | args.local_rank, 257 | device, 258 | args.n_gpu, 259 | bool(args.local_rank != -1), 260 | ) 261 | 262 | # Set seed 263 | set_seed(args) 264 | 265 | 266 | def load_model(args, output_attentions=False): 267 | # Prepare GLUE task 268 | args.output_mode = "classification" 269 | label_list = ["0", "1"] 270 | num_labels = len(label_list) 271 | 272 | # store args 273 | if args.local_rank != -1: 274 | args.world_size = torch.distributed.get_world_size() 275 | args.rank = dist.get_rank() 276 | 277 | # Load pretrained model and tokenizer 278 | if args.local_rank not in [-1, 0]: 279 | # Make sure only the first process in distributed training will 280 | # download model & vocab 281 | torch.distributed.barrier() 282 | 283 | config = RobertaConfig.from_pretrained( 284 | args.model_name_or_path, 285 | num_labels=num_labels, 286 | cache_dir=args.cache_dir if args.cache_dir else None, 287 | output_attentions=output_attentions 288 | ) 289 | tokenizer = RobertaTokenizer.from_pretrained( 290 | args.model_name_or_path, 291 | do_lower_case=args.do_lower_case, 292 | cache_dir=args.cache_dir if args.cache_dir else None, 293 | use_fast=True 294 | ) 295 | model = RobertaDot_NLL_LN.from_pretrained( 296 | args.model_name_or_path, 297 | from_tf=bool(".ckpt" in args.model_name_or_path), 298 | config=config, 299 | cache_dir=args.cache_dir if args.cache_dir else None, 300 | ) 301 | 302 | if 
args.local_rank == 0: 303 | # Make sure only the first process in distributed training will 304 | # download model & vocab 305 | torch.distributed.barrier() 306 | 307 | model.to(args.device) 308 | 309 | return tokenizer, model 310 | 311 | 312 | def is_first_worker(): 313 | return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0 314 | 315 | 316 | def concat_key(all_list, key, axis=0): 317 | return np.concatenate([ele[key] for ele in all_list], axis=axis) 318 | 319 | 320 | def get_checkpoint_no(checkpoint_path): 321 | nums = re.findall(r'\d+', checkpoint_path) 322 | return int(nums[-1]) if len(nums) > 0 else 0 323 | 324 | 325 | def get_latest_ann_data(ann_data_path): 326 | ANN_PREFIX = "ann_ndcg_" 327 | if not os.path.exists(ann_data_path): 328 | return -1, None, None 329 | files = list(next(os.walk(ann_data_path))[2]) 330 | num_start_pos = len(ANN_PREFIX) 331 | data_no_list = [int(s[num_start_pos:]) 332 | for s in files if s[:num_start_pos] == ANN_PREFIX] 333 | if len(data_no_list) > 0: 334 | data_no = max(data_no_list) 335 | with open(os.path.join(ann_data_path, ANN_PREFIX + str(data_no)), 'r') as f: 336 | ndcg_json = json.load(f) 337 | return data_no, os.path.join( 338 | ann_data_path, "ann_training_data_" + str(data_no)), ndcg_json 339 | return -1, None, None 340 | 341 | 342 | def numbered_byte_file_generator(base_path, file_no, record_size): 343 | for i in range(file_no): 344 | with open('{}_split{}'.format(base_path, i), 'rb') as f: 345 | while True: 346 | b = f.read(record_size) 347 | if not b: 348 | # eof 349 | break 350 | yield b 351 | 352 | 353 | def tokenize_to_file(args, i, num_process, in_path, out_path, line_fn): 354 | configObj = MSMarcoConfig(name="rdot_nll", model=RobertaDot_NLL_LN, process_fn=triple_process_fn, use_mean=False) 355 | tokenizer = configObj.tokenizer_class.from_pretrained( 356 | args.model_name_or_path, 357 | do_lower_case=True, 358 | cache_dir=None, 359 | ) 360 | 361 | with open(in_path, 'r', encoding='utf-8') if in_path[-2:] != "gz" else gzip.open(in_path, 'rt', 362 | encoding='utf8') as in_f, \ 363 | open('{}_split{}'.format(out_path, i), 'wb') as out_f: 364 | for idx, line in enumerate(in_f): 365 | if idx % num_process != i: 366 | continue 367 | out_f.write(line_fn(args, line, tokenizer)) 368 | 369 | 370 | def multi_file_process(args, num_process, in_path, out_path, line_fn): 371 | processes = [] 372 | for i in range(num_process): 373 | p = Process( 374 | target=tokenize_to_file, 375 | args=( 376 | args, 377 | i, 378 | num_process, 379 | in_path, 380 | out_path, 381 | line_fn, 382 | )) 383 | processes.append(p) 384 | p.start() 385 | for p in processes: 386 | p.join() 387 | 388 | 389 | def all_gather(data): 390 | """ 391 | Run all_gather on arbitrary picklable data (not necessarily tensors) 392 | Args: 393 | data: any picklable object 394 | Returns: 395 | list[data]: list of data gathered from each rank 396 | """ 397 | if not dist.is_initialized() or dist.get_world_size() == 1: 398 | return [data] 399 | 400 | world_size = dist.get_world_size() 401 | # serialized to a Tensor 402 | buffer = pickle.dumps(data) 403 | storage = torch.ByteStorage.from_buffer(buffer) 404 | tensor = torch.ByteTensor(storage).to("cuda") 405 | 406 | # obtain Tensor size of each rank 407 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 408 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 409 | dist.all_gather(size_list, local_size) 410 | size_list = [int(size.item()) for size in size_list] 411 | max_size = 
max(size_list) 412 | 413 | # receiving Tensor from all ranks 414 | # we pad the tensor because torch all_gather does not support 415 | # gathering tensors of different shapes 416 | tensor_list = [] 417 | for _ in size_list: 418 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 419 | if local_size != max_size: 420 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 421 | tensor = torch.cat((tensor, padding), dim=0) 422 | dist.all_gather(tensor_list, tensor) 423 | 424 | data_list = [] 425 | for size, tensor in zip(size_list, tensor_list): 426 | buffer = tensor.cpu().numpy().tobytes()[:size] 427 | data_list.append(pickle.loads(buffer)) 428 | 429 | return data_list 430 | 431 | 432 | def get_offset2qid(args): 433 | path = os.path.join(args.processed_data_dir, f"{args.mode}-offset2qid.pickle") 434 | if os.path.exists(path): 435 | with open(path, "rb") as f: 436 | return pickle.load(f) 437 | with open(os.path.join(args.processed_data_dir, f"{args.mode}-qid2offset.pickle"), "rb") as f: 438 | qid2offset = pickle.load(f) 439 | offset2qid = {} 440 | for qid, offset in qid2offset.items(): 441 | offset2qid[offset] = qid 442 | with open(path, "wb") as f: 443 | pickle.dump(offset2qid, f) 444 | return offset2qid 445 | 446 | 447 | def get_embedding2qid(args): 448 | path = os.path.join(args.processed_data_dir, f"{args.mode}-embedding2qid.pickle") 449 | if os.path.exists(path): 450 | with open(path, "rb") as f: 451 | return pickle.load(f) 452 | 453 | query_prefix = f"{args.mode}_query_0" 454 | query_embedding, query_embedding2id = load_embedding_prefix(os.path.join(args.ance_checkpoint_path, query_prefix)) 455 | offset2qid = get_offset2qid(args) 456 | embedding2qid = {} 457 | for embeddingid, offset in enumerate(query_embedding2id): 458 | embedding2qid[embeddingid] = offset2qid[offset] 459 | with open(path, "wb") as f: 460 | pickle.dump(embedding2qid, f) 461 | return embedding2qid 462 | 463 | 464 | def offset_to_orig_id(orig2offset): 465 | offset2orig = dict() 466 | for k, v in orig2offset.items(): 467 | offset2orig[v] = k 468 | return offset2orig 469 | 470 | 471 | # qid, pid = marco qid, pid 472 | # offset = anceID (e.g. 0, 3, 6 ...) 
473 | # embed_id = id into the passage_embedding / query_embedding matrices (as the order it appear when it's given to faiss) 474 | # In ANCE implementation, all evaluations are done on offset 475 | def ance_ranking_to_tein(args, dev_I, passage_embedding2id, query_embedid2qid, output_path, top=10): 476 | with open(os.path.join(args.processed_data_dir, "pid2offset.pickle"), "rb") as f: 477 | pid2offset = pickle.load(f) 478 | offset2pid = offset_to_orig_id(pid2offset) 479 | 480 | with open(output_path, "w") as f1, open(output_path+".marco", "w") as f2, \ 481 | open(f"{output_path}.marco.{top}", "w") as f3: 482 | writer1 = csv.writer(f1, delimiter="\t") 483 | writer2 = csv.writer(f2, delimiter="\t") 484 | writer3 = csv.writer(f3, delimiter="\t") 485 | for qry_embed_id, rel_psg_embed_ids in enumerate(dev_I): 486 | qid = query_embedid2qid[qry_embed_id] 487 | pids = [offset2pid[passage_embedding2id[p_embed_id]] for p_embed_id in rel_psg_embed_ids] 488 | rows1 = [[qid, "Q0", pid, rank+1, -rank, "full_ance"] for rank, pid in enumerate(pids)] 489 | rows2 = [[qid, pid, rank+1] for rank, pid in enumerate(pids)] 490 | rows3 = [[qid, pid, rank+1] for rank, pid in enumerate(pids[:top])] 491 | writer1.writerows(rows1) 492 | writer2.writerows(rows2) 493 | writer3.writerows(rows3) 494 | 495 | 496 | def devI_to_tein(args, query_embedid2qid, devI_path): 497 | _, passage_embedding2id = load_embedding_prefix(os.path.join(args.ance_checkpoint_path, "passage_0")) 498 | 499 | with open(devI_path, "rb") as f: 500 | dev_I = np.load(f) 501 | tein_path = devI_path + ".tein" 502 | ance_ranking_to_tein(args, dev_I, passage_embedding2id, query_embedid2qid, tein_path) 503 | 504 | 505 | def get_binary_qrel(qrel_tsv_path, bin_qrel_path): 506 | with open(qrel_tsv_path, "r") as fin, open(bin_qrel_path, "w") as fout: 507 | for l in fin: 508 | qid, pid, rel = l.split() 509 | rel = 0 if int(rel) < 2 else 1 510 | fout.write(f"{qid}\t{pid}\t{rel}\n") 511 | 512 | 513 | def load_positve_query_id(args, mode="dev", binary=False): 514 | positive_id = {} 515 | if not binary: 516 | query_positive_id_path = os.path.join(args.preprocessed_dir, f"{mode}-qrel.tsv") 517 | else: 518 | query_positive_id_path = os.path.join(args.preprocessed_dir, f"binary-{mode}-qrel.tsv") 519 | 520 | with open(query_positive_id_path, 'r', encoding='utf8') as f: 521 | tsvreader = csv.reader(f, delimiter="\t") 522 | for [topicid, docid, rel] in tsvreader: 523 | topicid = int(topicid) 524 | docid = int(docid) 525 | if topicid not in positive_id: 526 | positive_id[topicid] = {} 527 | positive_id[topicid][docid] = int(rel) 528 | return positive_id 529 | 530 | 531 | def rerank(first_stage_inn, query_embedding, query_embedding2id, passage_embedding, passage_embedding2id, output_dir, 532 | dataset, mode="dev"): 533 | # pidmap = collections.defaultdict(list) 534 | # for i in range(len(passage_embedding2id)): 535 | # pidmap[passage_embedding2id[i]].append(i) # abs pos(key) to rele pos(val) (old p_offset->old p_embid) 536 | 537 | all_dev_I = [] 538 | for i, qid in enumerate(query_embedding2id): 539 | p_set = [] 540 | p_set_map = {} 541 | 542 | count = 0 543 | for k, p_embid in enumerate(first_stage_inn[i]): # 544 | p_set.append(passage_embedding[p_embid]) # ids 545 | p_set_map[count] = p_embid # new p_embid->old p_embid 546 | count += 1 547 | dim = passage_embedding.shape[1] 548 | faiss.omp_set_num_threads(16) 549 | cpu_index = faiss.IndexFlatIP(dim) 550 | p_set = np.asarray(p_set) 551 | cpu_index.add(p_set) 552 | _, dev_I = cpu_index.search(query_embedding[i:i + 1], 
len(p_set)) 553 | for j in range(len(dev_I[0])): 554 | dev_I[0][j] = p_set_map[dev_I[0][j]] 555 | all_dev_I.append(dev_I[0]) 556 | with open(os.path.join(output_dir, f"{dataset}_{mode}I_rerank.npy"), "wb") as f: 557 | np.save(f, np.array(all_dev_I)) 558 | return all_dev_I 559 | 560 | 561 | def get_inn_rerank_depth(reranked_w_scores, first_stage_inn, depth): 562 | first_stage_inn = copy.deepcopy(first_stage_inn) 563 | for i, inn in enumerate(first_stage_inn): 564 | first_part = np.array(sorted(inn[:depth], key=lambda pid: reranked_w_scores[i][pid])) 565 | first_stage_inn[i][:depth] = first_part 566 | return first_stage_inn 567 | 568 | 569 | def EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, I_nearest_neighbor, topN): 570 | prediction = {} # [qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2) 571 | 572 | total = 0 573 | labeled = 0 574 | Atotal = 0 575 | Alabeled = 0 576 | qids_to_ranked_candidate_passages = {} 577 | for query_idx in range(len(I_nearest_neighbor)): 578 | seen_pid = set() 579 | query_id = query_embedding2id[query_idx] 580 | prediction[query_id] = {} 581 | 582 | top_ann_pid = I_nearest_neighbor[query_idx].copy() 583 | selected_ann_idx = top_ann_pid[:topN] 584 | rank = 0 585 | 586 | if query_id in qids_to_ranked_candidate_passages: 587 | pass 588 | else: 589 | # By default, all PIDs in the list of 1000 are 0. Only override those that are given 590 | tmp = [0] * 1000 591 | qids_to_ranked_candidate_passages[query_id] = tmp 592 | 593 | for idx in selected_ann_idx: 594 | pred_pid = passage_embedding2id[idx] 595 | 596 | if not pred_pid in seen_pid: 597 | # this check handles multiple vector per document 598 | qids_to_ranked_candidate_passages[query_id][rank] = pred_pid 599 | Atotal += 1 600 | if pred_pid not in dev_query_positive_id[query_id]: 601 | Alabeled += 1 602 | if rank < 10: 603 | total += 1 604 | if pred_pid not in dev_query_positive_id[query_id]: 605 | labeled += 1 606 | rank += 1 607 | prediction[query_id][pred_pid] = -rank 608 | seen_pid.add(pred_pid) 609 | 610 | # use out of the box evaluation script 611 | evaluator = pytrec_eval.RelevanceEvaluator( 612 | convert_to_string_id(dev_query_positive_id), {'map_cut', 'ndcg_cut', 'recip_rank', 'recall'}) 613 | 614 | eval_query_cnt = 0 615 | result = evaluator.evaluate(convert_to_string_id(prediction)) 616 | 617 | qids_to_relevant_passageids = {} 618 | for qid in dev_query_positive_id: 619 | qid = int(qid) 620 | if qid in qids_to_relevant_passageids: 621 | pass 622 | else: 623 | qids_to_relevant_passageids[qid] = [] 624 | for pid in dev_query_positive_id[qid]: 625 | if pid > 0: 626 | qids_to_relevant_passageids[qid].append(pid) 627 | 628 | ms_mrr = compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages) 629 | 630 | ndcg = 0 631 | Map = 0 632 | mrr = 0 633 | recall = 0 634 | recall_1000 = 0 635 | 636 | for k in result.keys(): 637 | eval_query_cnt += 1 638 | ndcg += result[k]["ndcg_cut_10"] 639 | Map += result[k]["map_cut_10"] 640 | mrr += result[k]["recip_rank"] 641 | recall += result[k]["recall_" + str(topN)] 642 | 643 | final_ndcg = ndcg / eval_query_cnt 644 | final_Map = Map / eval_query_cnt 645 | final_mrr = mrr / eval_query_cnt 646 | final_recall = recall / eval_query_cnt 647 | hole_rate = labeled / total 648 | Ahole_rate = Alabeled / Atotal 649 | 650 | return final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, result, prediction 651 | 652 | 653 | class 
UtilEmbeddingCache: 654 | def __init__(self, base_path, seed=-1): 655 | self.base_path = base_path 656 | with open(base_path + '_meta', 'r') as f: 657 | meta = json.load(f) 658 | self.dtype = np.dtype(meta['type']) 659 | self.total_number = meta['total_number'] 660 | self.record_size = int( 661 | meta['embedding_size']) * self.dtype.itemsize + 4 # each record: 4-byte length header + fixed-size payload 662 | if seed >= 0: 663 | self.ix_array = np.random.RandomState( 664 | seed).permutation(self.total_number) 665 | else: 666 | self.ix_array = np.arange(self.total_number) 667 | self.f = None 668 | 669 | def open(self): 670 | self.f = open(self.base_path, 'rb') 671 | 672 | def close(self): 673 | self.f.close() 674 | 675 | def read_single_record(self): 676 | record_bytes = self.f.read(self.record_size) 677 | passage_len = int.from_bytes(record_bytes[:4], 'big') # big-endian length header 678 | passage = np.frombuffer(record_bytes[4:], dtype=self.dtype) 679 | return passage_len, passage 680 | 681 | def __enter__(self): 682 | self.open() 683 | return self 684 | 685 | def __exit__(self, type, value, traceback): 686 | self.close() 687 | 688 | def __getitem__(self, key): 689 | if key < 0 or key >= self.total_number: 690 | raise IndexError( 691 | "Index {} is out of bound for cached embeddings of size {}".format( 692 | key, self.total_number)) 693 | self.f.seek(key * self.record_size) 694 | return self.read_single_record() 695 | 696 | def __iter__(self): 697 | self.f.seek(0) 698 | for i in range(self.total_number): 699 | new_ix = self.ix_array[i] 700 | yield self.__getitem__(new_ix) 701 | 702 | def __len__(self): 703 | return self.total_number 704 | 705 | 706 | class UtilStreamingDataset(IterableDataset): 707 | def __init__(self, elements, fn, distributed=True): 708 | super().__init__() 709 | self.elements = elements 710 | self.fn = fn 711 | self.num_replicas = -1 712 | self.distributed = distributed 713 | 714 | def __iter__(self): 715 | if dist.is_initialized(): 716 | self.num_replicas = dist.get_world_size() 717 | self.rank = dist.get_rank() 718 | else: 719 | print("Not running in distributed mode") 720 | for i, element in enumerate(self.elements): 721 | if self.distributed and self.num_replicas != -1 and i % self.num_replicas != self.rank: 722 | continue 723 | records = self.fn(element, i) 724 | for rec in records: 725 | yield rec --------------------------------------------------------------------------------
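A minimal usage sketch of the utilities above, assuming run_ann_data_gen.py has already dumped embedding shards for checkpoint 0. The paths, the retrieval depth of 100, and the argparse.Namespace stand-in are placeholders, and every dev query is assumed to appear in dev-qrel.tsv (EvalDevQuery indexes the qrels by query id directly):

# Hypothetical end-to-end sketch: load the dumped ANCE embeddings, run full
# retrieval with a flat inner-product faiss index, and score the run with
# EvalDevQuery. All paths below are placeholders.
import argparse

import faiss
import numpy as np

from utils.util import EvalDevQuery, load_embeddings, load_positve_query_id

args = argparse.Namespace(
    ance_checkpoint_path="data/marco_output",    # placeholder: where the *_emb_p_* shards live
    preprocessed_dir="data/marco_preprocessed",  # placeholder: where dev-qrel.tsv lives
)

# query/passage ids here are ANCE offsets, matching the ids used in the qrels
q_emb, q_emb2id, p_emb, p_emb2id = load_embeddings(args, mode="dev", checkpoint=0)
dev_positive_id = load_positve_query_id(args, mode="dev")

index = faiss.IndexFlatIP(p_emb.shape[1])  # ANCE scores by inner product
index.add(p_emb.astype(np.float32))
_, dev_I = index.search(q_emb.astype(np.float32), 100)  # top-100 full retrieval

(ndcg, n_queries, mean_ap, mrr, recall, hole_rate, ms_mrr, a_hole_rate,
 metrics, prediction) = EvalDevQuery(q_emb2id, p_emb2id, dev_positive_id, dev_I, 100)
print(f"queries={n_queries} NDCG@10={ndcg:.4f} MS MRR@10={ms_mrr['MRR @10']:.4f}")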