├── README.md ├── pic └── stat.png ├── requirements.txt ├── script ├── train_cross_encoder.sh └── train_dual_encoder.sh └── src ├── convert2trec.py ├── dataset_factory.py ├── modeling.py ├── msmarco_eval.py ├── train_cross_encoder.py ├── train_dual_encoder.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # T2Ranking 2 | 3 | ## Introduction 4 | T2Ranking is a large-scale Chinese benchmark for passage ranking. The details of T2Ranking are elaborated in [this paper](https://arxiv.org/abs/2304.03679#). 5 | 6 | 7 | Passage ranking is an important and challenging topic for both academia and industry in the area of Information Retrieval (IR). The goal of passage ranking is to compile a search result list from a large passage collection, ordered by relevance to the query. Typically, passage ranking involves two stages: passage retrieval and passage re-ranking. 8 | 9 | To support passage ranking research, various benchmark datasets have been constructed. However, the commonly-used datasets for passage ranking usually focus on English. For non-English scenarios, such as Chinese, the existing datasets are limited in data scale and fine-grained relevance annotation, and suffer from false-negative issues. 10 | 11 | 12 | To address these problems, we introduce T2Ranking, a large-scale Chinese benchmark for passage ranking. T2Ranking comprises more than 300K queries and over 2M unique passages from real-world search engines. Specifically, we sample question-based search queries from user logs of the Sogou search engine, a popular search system in China. For each query, we extract the content of corresponding documents from different search engines. After model-based passage segmentation and clustering-based passage de-duplication, a large-scale passage corpus is obtained. For a given query and its corresponding passages, we hire expert annotators to provide 4-level relevance judgments for each query-passage pair. 13 | 14 | 15 | <div align=center><img src="pic/stat.png"></div>
16 | <div align=center>Table 1: The data statistics of datasets commonly used in passage ranking. FR (SR): First (Second)-stage of passage ranking, i.e., passage Retrieval (Re-ranking).</div>
17 | 18 | 19 | 20 | Compared with existing datasets, the T2Ranking dataset has the following characteristics and advantages: 21 | * The proposed dataset focuses on the Chinese search scenario and has advantages in data scale over existing Chinese passage ranking datasets, which can better support the design of deep learning algorithms. 22 | * The proposed dataset has a large number of fine-grained relevance annotations, which is helpful for mining fine-grained relationships between queries and passages and for constructing more accurate ranking algorithms. 23 | * By retrieving passage results from multiple commercial search engines and providing complete annotations, we ease the false-negative problem to some extent, which enables more accurate evaluation. 24 | * We design multiple strategies to ensure the high quality of our dataset, such as using a passage segmentation model and a passage clustering model to enhance the semantic integrity and diversity of passages, and employing an active-learning-based annotation method to improve the efficiency and quality of data annotation. 25 | 26 | ## Data Download 27 | The whole dataset is hosted on [Hugging Face](https://huggingface.co/datasets/THUIR/T2Ranking), and the data formats are presented in the following table. 28 |
29 | 30 | | Description| Filename|Num Records|Format| 31 | |-------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------|----------:|-----------------------------------:| 32 | | Collection | collection.tsv | 2,303,643 | tsv: pid, passage | 33 | | Queries Train | queries.train.tsv | 258,042 | tsv: qid, query | 34 | | Queries Dev | queries.dev.tsv | 24,832 | tsv: qid, query | 35 | | Queries Test | queries.test.tsv | 24,832 | tsv: qid, query | 36 | | Qrels Train for re-ranking | qrels.train.tsv | 1,613,421 | TREC qrels format | 37 | | Qrels Dev for re-ranking | qrels.dev.tsv | 400,536 | TREC qrels format | 38 | | Qrels Retrieval Train | qrels.retrieval.train.tsv | 744,663 | tsv: qid, pid | 39 | | Qrels Retrieval Dev | qrels.retrieval.dev.tsv | 118,933 | tsv: qid, pid | 40 | | BM25 Negatives | train.bm25.tsv | 200,359,731 | tsv: qid, pid, index | 41 | | Hard Negatives | train.mined.tsv | 200,376,001 | tsv: qid, pid, index, score | 42 | 43 |
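All files are plain TSV, so they can be inspected directly with pandas, read the same way `src/dataset_factory.py` reads them; a minimal sketch (the file paths assume the `data/` layout shown below):

```python
import pandas as pd

# Passage collection: pid \t passage. quoting=3 (QUOTE_NONE) matches how
# dataset_factory.py reads it, so stray quote characters in passages are kept.
collection = pd.read_csv("data/collection.tsv", sep="\t", quoting=3)
collection.columns = ["pid", "para"]

# Retrieval qrels: qid \t pid, one relevant (qid, pid) pair per line.
qrels = pd.read_csv("data/qrels.retrieval.dev.tsv", sep="\t")
qrels.columns = ["qid", "pid"]

print(collection.head())
print(qrels.head())
```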
44 | 45 | You can download the dataset by running the following command: 46 | ```bash 47 | git lfs install 48 | git clone https://huggingface.co/datasets/THUIR/T2Ranking 49 | ``` 50 | After downloading, you can find the following files in the folder: 51 | ``` 52 | ├── data 53 | │ ├── collection.tsv 54 | │ ├── qrels.dev.tsv 55 | │ ├── qrels.retrieval.dev.tsv 56 | │ ├── qrels.retrieval.train.tsv 57 | │ ├── qrels.train.tsv 58 | │ ├── queries.dev.tsv 59 | │ ├── queries.test.tsv 60 | │ ├── queries.train.tsv 61 | │ ├── train.bm25.tsv 62 | │ └── train.mined.tsv 63 | ├── script 64 | │ ├── train_cross_encoder.sh 65 | │ └── train_dual_encoder.sh 66 | └── src 67 | ├── convert2trec.py 68 | ├── dataset_factory.py 69 | ├── modeling.py 70 | ├── msmarco_eval.py 71 | ├── train_cross_encoder.py 72 | ├── train_dual_encoder.py 73 | └── utils.py 74 | ``` 75 | 76 | 77 | 78 | ## Training and Evaluation 79 | The dual-encoder can be trained by running the following command: 80 | ```bash 81 | sh script/train_dual_encoder.sh 82 | ``` 83 | After training the model, you can evaluate the model by running the following command: 84 | ```bash 85 | python src/msmarco_eval.py data/qrels.retrieval.dev.tsv output/res.top1000.step20 86 | ``` 87 | 88 | 89 | The cross-encoder can be trained by running the following command: 90 | ```bash 91 | sh script/train_cross_encoder.sh 92 | ``` 93 | After training the model, you can evaluate the model by running the following command: 94 | ```bash 95 | python src/convert2trec.py output/res.step-20 && python src/msmarco_eval.py data/qrels.retrieval.dev.tsv output/res.step-20.trec && path_to/trec_eval -m ndcg_cut.5 data/qrels.dev.tsv res.step-20.trec 96 | ``` 97 | 98 | We have uploaded some checkpoints to Huggingface Hub. 99 | 100 | | Model | Description | Link | 101 | | ------------------ | --------------------------------------------------------- | ------------------------------------------------------------ | 102 | | dual-encoder 1 | dual-encoder trained with bm25 negatives | [DE1](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dual-encoder-trained-with-bm25-negatives.p) | 103 | | dual-encoder 2 | dual-encoder trained with self-mined hard negatives | [DE2](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dual-encoder-trained-with-hard-negatives.p) | 104 | | cross-encoder | cross-encoder trained with self-mined hard negatives | [CE](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/cross-encoder.p) | 105 | 106 | 107 | BM25 on DEV set 108 | ```bash 109 | ##################### 110 | MRR @10: 0.35894801237316354 111 | QueriesRanked: 24831 112 | recall@1: 0.05098711868967141 113 | recall@1000: 0.7464097131133757 114 | recall@50: 0.4942572226146033 115 | ##################### 116 | ``` 117 | 118 | DPR trained with BM25 negatives on DEV set 119 | 120 | ```bash 121 | ##################### 122 | MRR @10: 0.4856112079562753 123 | QueriesRanked: 24831 124 | recall@1: 0.07367235058688999 125 | recall@1000: 0.9082753169878586 126 | recall@50: 0.7099350889583964 127 | ##################### 128 | ``` 129 | 130 | DPR trained with self-mined hard negatives on DEV set 131 | 132 | ```bash 133 | ##################### 134 | MRR @10: 0.5166915171959451 135 | QueriesRanked: 24831 136 | recall@1: 0.08047455688965123 137 | recall@1000: 0.9135220125786163 138 | recall@50: 0.7327044025157232 139 | ##################### 140 | ``` 141 | 142 | 143 | BM25 retrieved+CE reranked on DEV set 144 | 145 | The reranked run file is placed in 
[here](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dev.bm25.tsv). 146 | ```bash 147 | ##################### 148 | MRR @10: 0.5188107959009376 149 | QueriesRanked: 24831 150 | recall@1: 0.08545219116806242 151 | recall@1000: 0.7464097131133757 152 | recall@50: 0.595298153566744 153 | ##################### 154 | ndcg_cut_20 all 0.4405 155 | ndcg_cut_100 all 0.4705 156 | ##################### 157 | ``` 158 | 159 | DPR retrieved+CE reranked on DEV set 160 | 161 | The reranked run file is available [here](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dev.dpr.tsv). 162 | ```bash 163 | ##################### 164 | MRR @10: 0.5508822816845231 165 | QueriesRanked: 24831 166 | recall@1: 0.08903406988867588 167 | recall@1000: 0.9135220125786163 168 | recall@50: 0.7393720781623112 169 | ##################### 170 | ndcg_cut_20 all 0.5131 171 | ndcg_cut_100 all 0.5564 172 | ##################### 173 | ``` 174 | 175 | ## License 176 | The dataset is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0.html). 177 | 178 | 179 | ## Citation 180 | If you use this dataset in your research, please cite our paper: 181 | ``` 182 | @misc{xie2023t2ranking, 183 | title={T2Ranking: A large-scale Chinese Benchmark for Passage Ranking}, 184 | author={Xiaohui Xie and Qian Dong and Bingning Wang and Feiyang Lv and Ting Yao and Weinan Gan and Zhijing Wu and Xiangsheng Li and Haitao Li and Yiqun Liu and Jin Ma}, 185 | year={2023}, 186 | eprint={2304.03679}, 187 | archivePrefix={arXiv}, 188 | primaryClass={cs.IR} 189 | } 190 | ``` 191 | -------------------------------------------------------------------------------- /pic/stat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUIR/T2Ranking/3ab0a0de72dd50bf84d852a985f6188334781403/pic/stat.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss_cpu==1.7.4 2 | faiss_gpu==1.7.2 3 | numpy==1.24.3 4 | pandas==2.0.1 5 | pytorch_pretrained_bert==0.6.2 6 | scikit_learn==1.2.2 7 | torch==2.0.0 8 | torch_optimizer==0.3.0 9 | tqdm==4.65.0 10 | transformers==4.28.1 11 | -------------------------------------------------------------------------------- /script/train_cross_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset=marco 3 | sample_num=64 4 | batch_size=16 5 | echo "batch size ${batch_size}" 6 | dev_batch_size=1024 7 | min_index=0 8 | max_index=256 9 | max_seq_len=332 10 | q_max_seq_len=32 11 | p_max_seq_len=128 12 | model_name_or_path=checkpoint/bert-base-chinese/ 13 | top1000=data/train.mined.tsv 14 | warm_start_from=data/cross-encoder.p 15 | dev_top1000=data/dev.dpr.tsv 16 | dev_query=data/queries.dev.tsv 17 | collection=data/collection.tsv 18 | qrels=data/qrels.retrieval.train.tsv 19 | dev_qrels=data/qrels.retrieval.dev.tsv 20 | query=data/queries.train.tsv 21 | learning_rate=3e-5 22 | ### The settings below never need to be changed 23 | warmup_proportion=0.1 24 | eval_step_proportion=0.01 25 | report_step=100 26 | epoch=20 27 | fp16=true 28 | output_dir=output 29 | log_dir=${output_dir}/log 30 | mkdir -p ${output_dir} 31 | mkdir -p ${log_dir} 32 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 33 | python -m torch.distributed.launch \ 34 | --log_dir ${log_dir} \ 35 | --nproc_per_node=8 \ 36 | src/train_cross_encoder.py \ 37 | 
--model_name_or_path=${model_name_or_path} \ 38 | --batch_size=${batch_size} \ 39 | --warmup_proportion=${warmup_proportion} \ 40 | --eval_step_proportion=${eval_step_proportion} \ 41 | --report=${report_step} \ 42 | --qrels=${qrels} \ 43 | --dev_qrels=${dev_qrels} \ 44 | --query=${query} \ 45 | --dev_query=${dev_query} \ 46 | --collection=${collection} \ 47 | --top1000=${top1000} \ 48 | --min_index=${min_index} \ 49 | --max_index=${max_index} \ 50 | --epoch=${epoch} \ 51 | --sample_num=${sample_num} \ 52 | --dev_batch_size=${dev_batch_size} \ 53 | --max_seq_len=${max_seq_len} \ 54 | --learning_rate=${learning_rate} \ 55 | --q_max_seq_len=${q_max_seq_len} \ 56 | --p_max_seq_len=${p_max_seq_len} \ 57 | --dev_top1000=${dev_top1000} \ 58 | --warm_start_from=${warm_start_from} \ 59 | | tee ${log_dir}/train.log 60 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 61 | 62 | -------------------------------------------------------------------------------- /script/train_dual_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset=marco 3 | sample_num=2 4 | batch_size=64 5 | echo "batch size ${batch_size}" 6 | max_index=200 7 | retriever_model_name_or_path=checkpoint/bert-base-chinese/ 8 | top1000=data/train.mined.tsv 9 | warm_start_from=data/dual-encoder-trained-with-hard-negatives.p 10 | learning_rate=2e-5 11 | ### The settings below never need to be changed 12 | dev_batch_size=256 13 | min_index=0 14 | max_seq_len=332 15 | q_max_seq_len=32 16 | p_max_seq_len=300 17 | dev_query=data/queries.dev.tsv 18 | collection=data/collection.tsv 19 | qrels=data/qrels.retrieval.train.tsv 20 | dev_qrels=data/qrels.retrieval.dev.tsv 21 | query=data/queries.train.tsv 22 | warmup_proportion=0.1 23 | eval_step_proportion=0.01 24 | report_step=100 25 | epoch=200 26 | fp16=true 27 | output_dir=output 28 | log_dir=${output_dir}/log 29 | mkdir -p ${output_dir} 30 | mkdir -p ${log_dir} 31 | master_port=29500 32 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 33 | python -m torch.distributed.launch \ 34 | --log_dir ${log_dir} \ 35 | --nproc_per_node=8 \ 36 | --master_port=${master_port} \ 37 | src/train_dual_encoder.py \ 38 | --retriever_model_name_or_path=${retriever_model_name_or_path} \ 39 | --batch_size=${batch_size} \ 40 | --warmup_proportion=${warmup_proportion} \ 41 | --eval_step_proportion=${eval_step_proportion} \ 42 | --report=${report_step} \ 43 | --qrels=${qrels} \ 44 | --dev_qrels=${dev_qrels} \ 45 | --query=${query} \ 46 | --dev_query=${dev_query} \ 47 | --collection=${collection} \ 48 | --top1000=${top1000} \ 49 | --min_index=${min_index} \ 50 | --max_index=${max_index} \ 51 | --epoch=${epoch} \ 52 | --sample_num=${sample_num} \ 53 | --dev_batch_size=${dev_batch_size} \ 54 | --max_seq_len=${max_seq_len} \ 55 | --learning_rate=${learning_rate} \ 56 | --q_max_seq_len=${q_max_seq_len} \ 57 | --p_max_seq_len=${p_max_seq_len} \ 58 | --warm_start_from=${warm_start_from} \ 59 | | tee ${log_dir}/train.log 60 | 61 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 62 | 63 | -------------------------------------------------------------------------------- /src/convert2trec.py: -------------------------------------------------------------------------------- 1 | import sys 2 | fin=sys.argv[1] 3 | def convert_to_trec(fin): 4 | with open(fin) as f: 5 | lines=f.readlines() 6 | saved=[] 7 | for line in lines: 8 | qid,pid,index,score=line.strip().split() 9 | saved.append(str(int(float(qid)))+'\t'+"vanilla_bert\t"+str(int(float(pid)))+'\t'+str(index)+'\t'+str(score)+'\tvanilla_bert\n') 10 | with open(fin+".trec","w") as f: 11 | f.writelines(saved) 12 | 13 | if __name__=="__main__": 14 | convert_to_trec(fin)
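For reference, `convert2trec.py` maps each 4-column run record (qid, pid, rank, score) onto the 6-column TREC run format that `trec_eval` expects; a minimal sketch of the mapping on a single hypothetical record:

```python
# Hypothetical input line from a run file such as output/res.step-20:
line = "1001\t2302\t1\t0.97"
qid, pid, rank, score = line.strip().split()

# TREC run format: qid <iter> pid rank score run_id
# (the script fills both <iter> and run_id with the string "vanilla_bert")
trec_line = "\t".join([qid, "vanilla_bert", pid, rank, score, "vanilla_bert"])
print(trec_line)  # 1001  vanilla_bert  2302  1  0.97  vanilla_bert
```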
-------------------------------------------------------------------------------- /src/dataset_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytorch_pretrained_bert 7 | import torch 8 | from sklearn.utils import shuffle 9 | from torch.utils.data import DataLoader, Dataset 10 | from tqdm import tqdm 11 | from transformers import AutoModel, AutoTokenizer, DataCollatorForWholeWordMask 12 | 13 | class PassageDataset(Dataset): 14 | def __init__(self, args): 15 | self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path) 16 | try: 17 | self.rank = torch.distributed.get_rank() 18 | self.n_procs = torch.distributed.get_world_size() 19 | except: 20 | self.rank, self.n_procs = 0, 1 # single-process fallback; n_procs must be >= 1 for the sharding below 21 | self.args = args 22 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3) 23 | self.collection.columns=['pid', 'para'] 24 | self.collection = self.collection.fillna("NA") 25 | self.collection.index = self.collection.pid 26 | total_cnt = len(self.collection) 27 | shard_cnt = total_cnt//self.n_procs 28 | if self.rank!=self.n_procs-1: 29 | self.collection = self.collection[self.rank*shard_cnt:(self.rank+1)*shard_cnt] 30 | else: 31 | self.collection = self.collection[self.rank*shard_cnt:] 32 | self.num_samples = len(self.collection) 33 | print('rank:',self.rank,'samples:',self.num_samples) 34 | 35 | def _collate_fn(self, psgs): 36 | p_records = self.tokenizer(psgs, padding=True, truncation=True, return_tensors="pt", max_length=self.args.p_max_seq_len) 37 | return p_records 38 | 39 | def __getitem__(self, idx): 40 | cols = self.collection.iloc[idx] 41 | para = cols.para 42 | psg = para 43 | return psg 44 | 45 | def __len__(self): 46 | return self.num_samples 47 | 48 | 49 | class QueryDataset(Dataset): 50 | def __init__(self, args): 51 | self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path) 52 | self.args = args 53 | self.collection = pd.read_csv(args.dev_query, sep="\t", quoting=3) 54 | self.collection.columns = ['qid','qry'] 55 | self.collection = self.collection.fillna("NA") 56 | self.num_samples = len(self.collection) 57 | 58 | def _collate_fn(self, qrys): 59 | return self.tokenizer(qrys, padding=True, truncation=True, return_tensors="pt", max_length=self.args.q_max_seq_len) 60 | 61 | def __getitem__(self, idx): 62 | return self.collection.iloc[idx].qry 63 | 64 | def __len__(self): 65 | return self.num_samples 66 | 67 | 68 | class CrossEncoderTrainDataset(Dataset): 69 | def __init__(self, args): 70 | self.tokenizer = AutoTokenizer.from_pretrained(args.reranker_model_name_or_path) 71 | try: 72 | self.rank = torch.distributed.get_rank() 73 | self.n_procs = torch.distributed.get_world_size() 74 | except: 75 | self.rank, self.n_procs = 0, 1 # single-process fallback 76 | self.args = args 77 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3) 78 | self.collection.columns=['pid', 'para'] 79 | self.collection = self.collection.fillna("NA") 80 | self.collection.index = self.collection.pid 81 | self.collection.pop('pid') 82 | self.query = pd.read_csv(args.query,sep="\t") 83 | self.query.columns = ['qid','text'] 84 | self.query.index = self.query.qid 85 | self.query.pop('qid') 86 | 
self.top1000 = pd.read_csv(args.top1000, sep="\t") 87 | self.top1000.columns=['qid','pid','index', 'score'] 88 | self.top1000 = list(self.top1000.groupby("qid")) 89 | self.len = len(self.top1000) 90 | self.min_index = args.min_index 91 | self.max_index = args.max_index 92 | qrels={} 93 | with open(args.qrels,'r') as f: 94 | lines = f.readlines() 95 | for line in lines[1:]: 96 | qid,pid = line.split() 97 | qid=int(qid) 98 | pid=int(pid) 99 | x=qrels.get(qid,[]) 100 | x.append(pid) 101 | qrels[qid]=x 102 | self.qrels = qrels 103 | self.sample_num = args.sample_num-1 104 | self.epoch = 0 105 | self.num_samples = len(self.top1000) 106 | 107 | def set_epoch(self, epoch): 108 | self.epoch = epoch 109 | print(self.epoch) 110 | 111 | def sample(self, qid, pids, sample_num): 112 | ''' 113 | qid:int 114 | pids:list 115 | sample_num:int 116 | ''' 117 | pids = [pid for pid in pids if pid not in self.qrels[qid]] 118 | pids = pids[self.args.min_index:self.args.max_index] 119 | interval = max(1, len(pids)//sample_num) # guard against queries with fewer candidates than sample_num 120 | offset = self.epoch%interval 121 | sample_pids = pids[offset::interval][:sample_num] 122 | return sample_pids 123 | 124 | def __getitem__(self, idx): 125 | cols = self.top1000[idx] 126 | qid = cols[0] 127 | pids = list(cols[1]['pid']) 128 | sample_neg_pids = self.sample(qid, pids, self.sample_num) 129 | pos_id = random.choice(self.qrels.get(qid)) 130 | query = self.query.loc[qid]['text'] 131 | data = [(query, self.collection.loc[pos_id]['para'])] 132 | for neg_pid in sample_neg_pids: 133 | data.append((query, self.collection.loc[neg_pid]['para'])) 134 | return data 135 | 136 | def _collate_fn(self, sample_list): 137 | qrys = [] 138 | psgs = [] 139 | for qp_pairs in sample_list: 140 | for q,p in qp_pairs: 141 | qrys.append(q) 142 | psgs.append(p) 143 | features = self.tokenizer(qrys, psgs, padding=True, truncation=True, return_tensors="pt", max_length=self.args.max_seq_len) 144 | return features 145 | 146 | def __len__(self): 147 | return self.num_samples 148 | 149 | class CrossEncoderDevDataset(Dataset): 150 | def __init__(self, args): 151 | self.tokenizer = AutoTokenizer.from_pretrained(args.reranker_model_name_or_path) 152 | try: 153 | self.rank = torch.distributed.get_rank() 154 | self.n_procs = torch.distributed.get_world_size() 155 | except: 156 | self.rank, self.n_procs = 0, 1 # single-process fallback 157 | self.args = args 158 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3) 159 | self.collection.columns=['pid', 'para'] 160 | self.collection = self.collection.fillna("NA") 161 | self.collection.index = self.collection.pid 162 | self.collection.pop('pid') 163 | self.query = pd.read_csv(args.dev_query,sep="\t") 164 | self.query.columns = ['qid','text'] 165 | self.query.index = self.query.qid 166 | self.query.pop('qid') 167 | self.top1000 = pd.read_csv(args.dev_top1000, sep="\t", header=None) 168 | self.num_samples = len(self.top1000) 169 | 170 | 171 | def __getitem__(self, idx): 172 | cols = self.top1000.iloc[idx] 173 | qid = cols[0] 174 | pid = cols[1] 175 | return self.query.loc[qid]['text'], self.collection.loc[pid]['para'], qid, pid 176 | 177 | def _collate_fn(self, sample_list): 178 | qrys = [] 179 | psgs = [] 180 | qids = [] 181 | pids = [] 182 | for q,p,qid,pid in sample_list: 183 | qrys.append(q) 184 | psgs.append(p) 185 | qids.append(qid) 186 | pids.append(pid) 187 | features = self.tokenizer(qrys, psgs, padding=True, truncation=True, return_tensors="pt", max_length=self.args.max_seq_len) 188 | return features, {"qids":np.array(qids),"pids":np.array(pids)} 189 | 190 | def __len__(self): 191 | return self.num_samples 192 | 
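The epoch-dependent `offset` in `sample()` above strides through a different slice of the candidate pool each epoch, so repeated epochs see fresh hard negatives; a toy illustration of the indexing:

```python
# Toy illustration of the epoch-shifted negative sampling in sample() above.
pids = list(range(100, 116))  # 16 candidate negatives after qrels filtering
sample_num = 4

interval = len(pids) // sample_num  # 4
for epoch in range(3):
    offset = epoch % interval
    print(epoch, pids[offset::interval][:sample_num])
# epoch 0 -> [100, 104, 108, 112]
# epoch 1 -> [101, 105, 109, 113]
# epoch 2 -> [102, 106, 110, 114]
```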
193 | class DualEncoderTrainDataset(Dataset): 194 | def __init__(self, args): 195 | self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path) 196 | try: 197 | self.rank = torch.distributed.get_rank() 198 | self.n_procs = torch.distributed.get_world_size() 199 | except: 200 | self.rank, self.n_procs = 0, 1 # single-process fallback 201 | self.args = args 202 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3) 203 | self.collection.columns=['pid','para'] 204 | self.collection = self.collection.fillna("NA") 205 | self.collection.index = self.collection.pid 206 | self.collection.pop('pid') 207 | self.query = pd.read_csv(args.query,sep="\t") 208 | self.query.columns = ['qid','text'] 209 | self.query.index = self.query.qid 210 | self.query.pop('qid') 211 | self.top1000 = pd.read_csv(args.top1000, sep="\t") 212 | if len(self.top1000.columns)==3: 213 | self.top1000.columns=['qid','pid','index'] 214 | else: 215 | self.top1000.columns=['qid','pid','index','score'] 216 | self.top1000 = list(self.top1000.groupby("qid")) 217 | self.len = len(self.top1000) 218 | self.min_index = args.min_index 219 | self.max_index = args.max_index 220 | qrels={} 221 | with open(args.qrels,'r') as f: 222 | lines = f.readlines() 223 | for line in lines[1:]: 224 | qid,pid = line.split() 225 | qid=int(qid) 226 | pid=int(pid) 227 | x=qrels.get(qid,[]) 228 | x.append(pid) 229 | qrels[qid]=x 230 | self.qrels = qrels 231 | self.sample_num = args.sample_num-1 232 | self.epoch = 0 233 | self.num_samples = len(self.top1000) 234 | 235 | def set_epoch(self, epoch): 236 | self.epoch = epoch 237 | print(self.epoch) 238 | 239 | def sample(self, qid, pids, sample_num): 240 | ''' 241 | qid:int 242 | pids:list 243 | sample_num:int 244 | ''' 245 | pids = [pid for pid in pids if pid not in self.qrels[qid]] 246 | pids = pids[self.args.min_index:self.args.max_index] 247 | if len(pids) < sample_num: -------------------------------------------------------------------------------- /src/msmarco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module computes evaluation metrics for MSMARCO dataset on the ranking task. 3 | Command line: 4 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file> 5 | Creation Date : 06/12/2018 6 | Last Modified : 4/6/2023 by Qian Dong and Haitao Li 7 | Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com> 8 | """ 9 | import itertools 10 | import sys 11 | from collections import Counter 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | MaxMRRRank = 10 17 | 18 | def load_reference_from_stream(f): 19 | """Load reference relevant passages 20 | Args:f (stream): stream to load. 21 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). 22 | """ 23 | qids_to_relevant_passageids = {} 24 | for l in f: 25 | try: 26 | l = l.strip().split('\t') 27 | qid = int(l[0]) 28 | if qid in qids_to_relevant_passageids: 29 | pass 30 | else: 31 | qids_to_relevant_passageids[qid] = [] 32 | qids_to_relevant_passageids[qid].append(int(l[1])) 33 | except: 34 | # raise IOError('\"%s\" is not valid format' % l) 35 | pass 36 | return qids_to_relevant_passageids 37 | 38 | 39 | def load_reference(path_to_reference): 40 | """Load reference relevant passages 41 | Args:path_to_reference (str): path to a file to load. 42 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). 43 | """ 44 | with open(path_to_reference, 'r') as f: 45 | qids_to_relevant_passageids = load_reference_from_stream(f) 46 | return qids_to_relevant_passageids 47 | 48 | 49 | def load_candidate_from_stream(f): 50 | """Load candidate data from a stream. 51 | Args:f (stream): stream to load. 
52 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 53 | """ 54 | qid_to_ranked_candidate_passages = {} 55 | for l in f: 56 | try: 57 | l = l.strip().split() 58 | qid = int(float(l[0])) 59 | pid = int(float(l[1])) 60 | rank = int(float(l[2])) 61 | if qid in qid_to_ranked_candidate_passages: 62 | pass 63 | else: 64 | # By default, all PIDs in the list of 1000 are 0. Only override those that are given 65 | tmp = [0] * 1000 66 | qid_to_ranked_candidate_passages[qid] = tmp 67 | qid_to_ranked_candidate_passages[qid][rank - 1] = pid 68 | except: 69 | # raise IOError('\"%s\" is not valid format' % l) 70 | pass 71 | return qid_to_ranked_candidate_passages 72 | 73 | 74 | def load_candidate(path_to_candidate): 75 | """Load candidate data from a file. 76 | Args:path_to_candidate (str): path to file to load. 77 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 78 | """ 79 | 80 | with open(path_to_candidate, 'r') as f: 81 | qid_to_ranked_candidate_passages = load_candidate_from_stream(f) 82 | return qid_to_ranked_candidate_passages 83 | 84 | 85 | def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 86 | """Perform quality checks on the dictionaries 87 | Args: 88 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 89 | Dict as read in with load_reference or load_reference_from_stream 90 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 91 | Returns: 92 | bool,str: Boolean whether allowed, message to be shown in case of a problem 93 | """ 94 | message = '' 95 | allowed = True 96 | 97 | # Create sets of the QIDs for the submitted and reference queries 98 | candidate_set = set(qids_to_ranked_candidate_passages.keys()) 99 | ref_set = set(qids_to_relevant_passageids.keys()) 100 | 101 | # Check that we do not have multiple passages per query 102 | for qid in qids_to_ranked_candidate_passages: 103 | # Remove all zeros from the candidates 104 | duplicate_pids = set( 105 | [item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]) 106 | 107 | if len(duplicate_pids - set([0])) > 0: 108 | message = "Cannot rank a passage multiple times for a single query. 
QID={qid}, PID={pid}".format( 109 | qid=qid, pid=list(duplicate_pids)[0]) 110 | allowed = False 111 | 112 | return allowed, message 113 | 114 | 115 | def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 116 | """Compute MRR metric 117 | Args: 118 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 119 | Dict as read in with load_reference or load_reference_from_stream 120 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 121 | Returns: 122 | dict: dictionary of metrics {'MRR': <MRR Score>} 123 | """ 124 | all_scores = {} 125 | MRR = 0 126 | qids_with_relevant_passages = 0 127 | ranking = [] 128 | recall_q_top1 = [] 129 | recall_q_top50 = [] 130 | recall_q_top1000 = [] 131 | recall_q_all = [] 132 | all_num = 0 133 | 134 | for qid in qids_to_ranked_candidate_passages: 135 | if qid in qids_to_relevant_passageids: 136 | ranking.append(0) 137 | target_pid = qids_to_relevant_passageids[qid] 138 | all_num = all_num + len(target_pid) 139 | candidate_pid = qids_to_ranked_candidate_passages[qid] 140 | for i in range(0, MaxMRRRank): 141 | if candidate_pid[i] in target_pid: 142 | MRR += 1.0 / (i + 1) 143 | ranking.pop() 144 | ranking.append(i + 1) 145 | break 146 | for i, pid in enumerate(candidate_pid): 147 | if pid in target_pid: 148 | recall_q_all.append(pid) 149 | if i < 50: 150 | recall_q_top50.append(pid) 151 | if i < 1000: 152 | recall_q_top1000.append(pid) 153 | if i == 0: 154 | recall_q_top1.append(pid) 155 | 156 | 157 | if len(ranking) == 0: 158 | raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?") 159 | 160 | 161 | MRR = MRR / len(qids_to_ranked_candidate_passages) 162 | recall_top1 = len(recall_q_top1) * 1.0 / all_num 163 | recall_top50 = len(recall_q_top50) * 1.0 / all_num 164 | recall_all = len(recall_q_top1000) * 1.0 / all_num 165 | all_scores['MRR @10'] = MRR 166 | all_scores["recall@1"] = recall_top1 167 | all_scores["recall@50"] = recall_top50 168 | all_scores["recall@1000"] = recall_all 169 | all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages) 170 | return all_scores 171 | 172 | 
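As a quick sanity check on `compute_metrics` above: MRR@10 averages, over all ranked queries, the reciprocal rank of the first relevant passage within the top 10, contributing 0 when none appears. A toy example with hypothetical ids:

```python
# Toy check of the MRR@10 logic in compute_metrics above.
qids_to_relevant = {1: [42], 2: [7, 9], 3: [5]}
qids_to_ranked = {
    1: [42, 8, 6],   # first relevant at rank 1 -> RR = 1.0
    2: [3, 9, 11],   # first relevant at rank 2 -> RR = 0.5
    3: [1, 2, 4],    # no relevant passage      -> RR = 0.0
}

mrr = 0.0
for qid, candidates in qids_to_ranked.items():
    for i, pid in enumerate(candidates[:10]):
        if pid in qids_to_relevant[qid]:
            mrr += 1.0 / (i + 1)
            break
mrr /= len(qids_to_ranked)
print(mrr)  # (1.0 + 0.5 + 0.0) / 3 = 0.5
```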
173 | def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True): 174 | """Compute MRR metric 175 | Args: 176 | p_path_to_reference_file (str): path to reference file. 177 | Reference file should contain lines in the following format: 178 | QUERYID\tPASSAGEID 179 | Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs 180 | p_path_to_candidate_file (str): path to candidate file. 181 | Candidate file should contain lines in the following format: 182 | QUERYID\tPASSAGEID1\tRank 183 | If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is 184 | QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID 185 | Where the values are separated by tabs and ranked in order of relevance 186 | Returns: 187 | dict: dictionary of metrics {'MRR': <MRR Score>} 188 | """ 189 | 190 | qids_to_relevant_passageids = load_reference(path_to_reference) 191 | qids_to_ranked_candidate_passages = load_candidate(path_to_candidate) 192 | if perform_checks: 193 | allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages) 194 | if message != '': print(message) 195 | 196 | return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages) 197 | 198 | 199 | def main(): 200 | """Command line: 201 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file> 202 | """ 203 | 204 | if len(sys.argv) == 3: 205 | path_to_reference = sys.argv[1] 206 | path_to_candidate = sys.argv[2] 207 | 208 | else: 209 | print('Usage: msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>') 210 | exit() 211 | 212 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate) 213 | print('#####################') 214 | for metric in sorted(metrics): 215 | print('{}: {}'.format(metric, metrics[metric])) 216 | print('#####################') 217 | 218 | 219 | def calc_mrr(path_to_reference, path_to_candidate): 220 | """Command line: 221 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file> 222 | """ 223 | 224 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate) 225 | print('#####################') 226 | for metric in sorted(metrics): 227 | print('{}: {}'.format(metric, metrics[metric])) 228 | print('#####################') 229 | return metrics 230 | 231 | 232 | def get_mrr(path_to_reference="/home/dongqian06/codes/NAACL2021-RocketQA/corpus/marco/qrels.dev.tsv", path_to_candidate="output/step_0_pred_dev_scores.txt"): 233 | all_data = pd.read_csv(path_to_candidate,sep="\t",header=None) 234 | all_data.columns = ["qid","pid","score"] 235 | all_data = all_data.groupby("qid").apply(lambda x: x.sort_values('score', ascending=False).reset_index(drop=True)) 236 | all_data.columns = ['query_id',"para_id","score"] 237 | all_data = all_data.reset_index() 238 | all_data.pop("qid") 239 | all_data.columns = ["index","qid","pid","score"] 240 | all_data = all_data.loc[:,["qid","pid","index","score"]] 241 | all_data['index']+=1 242 | path_to_candidate = path_to_candidate.replace("txt","qrels") 243 | all_data.to_csv(path_to_candidate, header=None,index=False,sep="\t") 244 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate) 245 | return metrics['MRR @10'] 246 | 247 | if __name__ == '__main__': 248 | main() 249 | -------------------------------------------------------------------------------- /src/train_cross_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" 4 | import argparse 5 | import random 6 | import subprocess 7 | import tempfile 8 | import time 9 | from collections import defaultdict 10 | 11 | import faiss 12 | import numpy as np 13 | import torch 14 | import torch.distributed as dist 15 | import torch.nn.functional as F 16 | from torch import distributed, optim 17 | from torch.cuda.amp import GradScaler, autocast 18 | from torch.nn.parallel import DistributedDataParallel 19 | from torch.utils.data import DataLoader, Dataset 20 | from tqdm import tqdm 21 | from transformers import AutoConfig, AutoTokenizer, BertModel 22 | 23 | import dataset_factory 24 | import utils 25 | from modeling import Reranker 26 | from 
msmarco_eval import calc_mrr 27 | from utils import add_prefix, build_engine, load_qid, read_embed, search 28 | 29 | SEED = 2023 30 | best_mrr=-1 31 | torch.manual_seed(SEED) 32 | torch.cuda.manual_seed_all(SEED) 33 | random.seed(SEED) 34 | def define_args(): 35 | import argparse 36 | parser = argparse.ArgumentParser('BERT-ranker model') 37 | parser.add_argument('--batch_size', type=int, default=2) 38 | parser.add_argument('--dev_batch_size', type=int, default=64) 39 | parser.add_argument('--max_seq_len', type=int, default=160) 40 | parser.add_argument('--q_max_seq_len', type=int, default=160) 41 | parser.add_argument('--p_max_seq_len', type=int, default=160) 42 | parser.add_argument('--model_name_or_path', type=str, default="../../data/bert-base-uncased/") 43 | parser.add_argument('--reranker_model_name_or_path', type=str, default="../../data/bert-base-uncased/") 44 | parser.add_argument('--warm_start_from', type=str, default="") 45 | parser.add_argument('--model_out_dir', type=str, default="output") 46 | parser.add_argument('--learning_rate', type=float, default=1e-5) 47 | parser.add_argument('--weight_decay', type=float, default=0.01) 48 | parser.add_argument('--warmup_proportion', type=float, default=0.1) 49 | parser.add_argument('--eval_step_proportion', type=float, default=1.0) 50 | parser.add_argument('--report', type=int, default=1) 51 | parser.add_argument('--epoch', type=int, default=3) 52 | parser.add_argument('--qrels', type=str, default="../../data/marco/qrels.train.debug.tsv") 53 | parser.add_argument('--dev_qrels', type=str, default="../../data/marco/qrels.train.debug.tsv") 54 | parser.add_argument('--top1000', type=str, default="../../data/marco/run.msmarco-passage.train.debug.tsv") 55 | parser.add_argument('--dev_top1000', type=str, default="../../data/marco/run.msmarco-passage.train.debug.tsv") 56 | parser.add_argument('--collection', type=str, default="../../data/marco/collection.debug.tsv") 57 | parser.add_argument('--query', type=str, default="../../data/marco/train.query.debug.txt") 58 | parser.add_argument('--dev_query', type=str, default="../../data/marco/train.query.debug.txt") 59 | parser.add_argument('--min_index', type=int, default=0) 60 | parser.add_argument('--max_index', type=int, default=256) 61 | parser.add_argument('--sample_num', type=int, default=128) 62 | parser.add_argument('--num_labels', type=int, default=1) 63 | parser.add_argument('--local-rank', type=int, default=0) 64 | parser.add_argument('--local_rank', type=int, default=0) 65 | parser.add_argument('--fp16', type=bool, default=True) 66 | parser.add_argument('--gradient_checkpoint', type=bool, default=True) 67 | parser.add_argument('--negatives_x_device', type=bool, default=True) 68 | parser.add_argument('--untie_encoder', type=bool, default=True) 69 | parser.add_argument('--add_pooler', type=bool, default=False) 70 | parser.add_argument('--Temperature', type=float, default=1.0) 71 | 72 | # args = parser.parse_args(args=[]) 73 | args = parser.parse_args() 74 | return args 75 | 76 | def merge(eval_cnts, file_pattern='output/res.step-%d.part-0%d'): 77 | f_list = [] 78 | total_part = torch.distributed.get_world_size() 79 | for part in range(total_part): 80 | f0 = open(file_pattern % (eval_cnts, part)) 81 | f_list+=f0.readlines() 82 | f_list = [l.strip().split("\t") for l in f_list] 83 | dedup = defaultdict(dict) 84 | for qid,pid,score in f_list: 85 | dedup[qid][pid] = float(score) 86 | mp = defaultdict(list) 87 | for qid in dedup: 88 | for pid in dedup[qid]: 89 | mp[qid].append((pid, 
dedup[qid][pid])) 90 | for qid in mp: 91 | mp[qid].sort(key=lambda x:x[1], reverse=True) 92 | with open(file_pattern.replace('.part-0%d','')%eval_cnts, 'w') as f: 93 | for qid in mp: 94 | for idx, (pid, score) in enumerate(mp[qid]): 95 | f.write(str(qid)+"\t"+str(pid)+'\t'+str(idx+1)+"\t"+str(score)+'\n') 96 | for part in range(total_part): 97 | os.remove(file_pattern % (eval_cnts, part)) 98 | 99 | def train_cross_encoder(args, model, optimizer): 100 | epoch = 0 101 | local_rank = torch.distributed.get_rank() 102 | if local_rank==0: 103 | print(f'Starting training, up to {args.epoch} epochs, LR={args.learning_rate}', flush=True) 104 | 105 | # load the datasets 106 | dev_dataset = dataset_factory.CrossEncoderDevDataset(args) 107 | dev_sampler = torch.utils.data.distributed.DistributedSampler(dev_dataset) 108 | dev_loader = DataLoader(dev_dataset, batch_size=args.dev_batch_size, collate_fn=dev_dataset._collate_fn, sampler=dev_sampler, num_workers=4) 109 | validate_multi_gpu(model, dev_loader, epoch, args) 110 | 111 | train_dataset = dataset_factory.CrossEncoderTrainDataset(args) 112 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 113 | 114 | for epoch in range(1, args.epoch+1): 115 | train_dataset.set_epoch(epoch) # the selected negatives shift with the epoch 116 | train_sampler.set_epoch(epoch) # shuffle batch 117 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_dataset._collate_fn, sampler=train_sampler, num_workers=4, drop_last=True) 118 | train_iteration_multi_gpu(model, optimizer, train_loader, epoch, args) 119 | torch.distributed.barrier() 120 | del train_loader 121 | if epoch%1==0: 122 | validate_multi_gpu(model, dev_loader, epoch, args) 123 | torch.distributed.barrier() 124 | 125 | def validate_multi_gpu(model, dev_loader, epoch, args): 126 | global best_mrr 127 | local_start = time.time() 128 | local_rank = torch.distributed.get_rank() 129 | world_size = torch.distributed.get_world_size() 130 | with torch.no_grad(): 131 | model.eval() 132 | scores_lst = [] 133 | qids_lst = [] 134 | pids_lst = [] 135 | for record1, record2 in tqdm(dev_loader): 136 | with autocast(): 137 | scores = model(_prepare_inputs(record1)) 138 | qids = record2['qids'] 139 | pids = record2['pids'] 140 | scores_lst.append(scores.detach().cpu().numpy().copy()) 141 | qids_lst.append(qids.copy()) 142 | pids_lst.append(pids.copy()) 143 | qids_lst = np.concatenate(qids_lst).reshape(-1) 144 | pids_lst = np.concatenate(pids_lst).reshape(-1) 145 | scores_lst = np.concatenate(scores_lst).reshape(-1) 146 | with open("output/res.step-%d.part-0%d"%(epoch, local_rank), 'w') as f: 147 | for qid,pid,score in zip(qids_lst, pids_lst, scores_lst): 148 | f.write(str(qid)+'\t'+str(pid)+'\t'+str(score)+'\n') 149 | torch.distributed.barrier() 150 | if local_rank==0: 151 | merge(epoch) 152 | metrics = calc_mrr(args.dev_qrels, 'output/res.step-%d'%epoch) 153 | mrr = metrics['MRR @10'] 154 | if mrr>best_mrr: 155 | print("*"*50) 156 | print("new top") 157 | print("*"*50) 158 | best_mrr = mrr 159 | torch.save(model.module.lm.state_dict(), os.path.join(args.model_out_dir, "reranker.p")) 160 | 161 | 162 | 163 | def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: 164 | rt = tensor.clone() 165 | distributed.all_reduce(rt, op=distributed.ReduceOp.SUM) 166 | rt /= distributed.get_world_size()  # number of processes 167 | return rt 168 | 169 | def _prepare_inputs(record): 170 | prepared = {} 171 | local_rank = torch.distributed.get_rank() 172 | for key in record: 173 | x = record[key] 174 | if isinstance(x, torch.Tensor): 
175 | prepared[key] = x.to(local_rank) 176 | else: 177 | prepared[key] = _prepare_inputs(x) 178 | return prepared 179 | 180 | def train_iteration_multi_gpu(model, optimizer, data_loader, epoch, args): 181 | total = 0 182 | model.train() 183 | total_loss = 0. 184 | local_rank = torch.distributed.get_rank() 185 | world_size = torch.distributed.get_world_size() 186 | start = time.time() 187 | local_start = time.time() 188 | all_steps_per_epoch = len(data_loader) 189 | step = 0 190 | scaler = GradScaler() 191 | for record in data_loader: 192 | record = _prepare_inputs(record) 193 | if args.fp16: 194 | with autocast(): 195 | loss = model(record) 196 | else: 197 | loss = model(record) 198 | torch.distributed.barrier() 199 | reduced_loss = reduce_tensor(loss.data) 200 | total_loss += reduced_loss.item() 201 | # optimize 202 | optimizer.zero_grad() 203 | scaler.scale(loss).backward() 204 | scaler.step(optimizer) 205 | scaler.update() 206 | step+=1 207 | if step%args.report==0 and local_rank==0: 208 | seconds = time.time()-local_start 209 | m, s = divmod(seconds, 60) 210 | h, m = divmod(m, 60) 211 | local_start = time.time() 212 | print("epoch:%d training step: %d/%d, mean loss: %.5f, current loss: %.5f,"%(epoch, step, all_steps_per_epoch, total_loss/step, loss.cpu().detach().numpy()),"report used time:%02d:%02d:%02d," % (h, m, s), end=' ') 213 | seconds = time.time()-start 214 | m, s = divmod(seconds, 60) 215 | h, m = divmod(m, 60) 216 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ') 217 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 218 | if local_rank==0: 219 | # model.save(os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch))) 220 | torch.save(model.module.state_dict(), os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch))) 221 | seconds = time.time()-start 222 | m, s = divmod(seconds, 60) 223 | h, m = divmod(m, 60) 224 | print(f'train epoch={epoch} loss={total_loss}') 225 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ') 226 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 227 | 228 | if __name__ == '__main__': 229 | args = define_args() 230 | args = vars(args) 231 | args = utils.HParams(**args) 232 | args.reranker_model_name_or_path = args.model_name_or_path 233 | # set up multi-GPU distributed training 234 | torch.distributed.init_process_group(backend="nccl", init_method='env://') 235 | local_rank = torch.distributed.get_rank() 236 | if local_rank==0: 237 | args.print_config() 238 | torch.cuda.set_device(local_rank) 239 | device = torch.device("cuda", local_rank) 240 | 241 | model = Reranker(args) 242 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 243 | model.to(device) 244 | 245 | params = [(k, v) for k, v in model.named_parameters() if v.requires_grad] 246 | params = {'params': [v for k, v in params]} 247 | optimizer = torch.optim.Adam([params], lr=args.learning_rate, weight_decay=0.0) 248 | 249 | if args.warm_start_from: 250 | print('warm start from ', args.warm_start_from) 251 | state_dict = torch.load(args.warm_start_from, map_location=device) 252 | for k in list(state_dict.keys()): 253 | state_dict[k.replace('module.','')] = state_dict.pop(k) 254 | model.load_state_dict(state_dict) 255 | 256 | 257 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) 258 | os.makedirs(args.model_out_dir, exist_ok=True) 259 | 260 | train_cross_encoder(args, model, optimizer) 261 | --------------------------------------------------------------------------------
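The `DualEncoder` model imported by the next file lives in `src/modeling.py`, which is not included in this listing. As background, dual encoders of this kind are usually trained with a softmax cross-entropy over query-passage similarities, treating each query's first passage as its positive and all other passages in the batch as negatives. A minimal sketch under that assumption (the function name and shapes are illustrative, not the repo's actual API):

```python
import torch
import torch.nn.functional as F

def in_batch_ce_loss(q_reps: torch.Tensor, p_reps: torch.Tensor, group_size: int) -> torch.Tensor:
    """q_reps: (B, H) query embeddings; p_reps: (B * group_size, H) passage
    embeddings, where each query's first passage is its positive."""
    scores = q_reps @ p_reps.t()  # (B, B * group_size) similarity matrix
    # index of each query's positive within the flattened passage list
    target = torch.arange(q_reps.size(0), device=q_reps.device) * group_size
    return F.cross_entropy(scores, target)

# e.g. a batch of 4 queries with sample_num=2 (1 positive + 1 mined negative each)
loss = in_batch_ce_loss(torch.randn(4, 768), torch.randn(8, 768), group_size=2)
```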
/src/train_dual_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" 4 | import argparse 5 | import random 6 | import subprocess 7 | import tempfile 8 | import time 9 | from collections import defaultdict 10 | 11 | import faiss 12 | import numpy as np 13 | import torch 14 | import torch.distributed as dist 15 | import torch.nn.functional as F 16 | from torch import distributed 17 | import torch_optimizer as optim 18 | from torch.cuda.amp import GradScaler, autocast 19 | from torch.nn.parallel import DistributedDataParallel 20 | from torch.utils.data import DataLoader, Dataset 21 | from tqdm import tqdm 22 | from transformers import AutoConfig, AutoTokenizer, BertModel 23 | 24 | import dataset_factory 25 | import utils 26 | from modeling import DualEncoder 27 | from msmarco_eval import calc_mrr 28 | from utils import add_prefix, build_engine, load_qid, merge, read_embed, search 29 | 30 | SEED = 2023 31 | best_mrr=-1 32 | torch.manual_seed(SEED) 33 | torch.cuda.manual_seed_all(SEED) 34 | random.seed(SEED) 35 | def define_args(): 36 | parser = argparse.ArgumentParser('BERT-retrieval model') 37 | parser.add_argument('--batch_size', type=int, default=2) 38 | parser.add_argument('--dev_batch_size', type=int, default=64) 39 | parser.add_argument('--max_seq_len', type=int, default=160) 40 | parser.add_argument('--q_max_seq_len', type=int, default=160) 41 | parser.add_argument('--p_max_seq_len', type=int, default=160) 42 | parser.add_argument('--retriever_model_name_or_path', type=str, default="") 43 | parser.add_argument('--model_out_dir', type=str, default="output") 44 | parser.add_argument('--learning_rate', type=float, default=1e-5) 45 | parser.add_argument('--weight_decay', type=float, default=0.01) 46 | parser.add_argument('--warmup_proportion', type=float, default=0.1) 47 | parser.add_argument('--eval_step_proportion', type=float, default=1.0) 48 | parser.add_argument('--report', type=int, default=1) 49 | parser.add_argument('--epoch', type=int, default=3) 50 | parser.add_argument('--qrels', type=str, default="/home/dongqian06/hdfs_data/data_train/qrels.train.debug.tsv") 51 | parser.add_argument('--dev_qrels', type=str, default="/home/dongqian06/hdfs_data/data_train/qrels.train.debug.tsv") 52 | parser.add_argument('--top1000', type=str, default="/home/dongqian06/codes/anserini/runs/run.msmarco-passage.train.debug.tsv") 53 | parser.add_argument('--collection', type=str, default="/home/dongqian06/hdfs_data/data_train/marco/collection.debug.tsv") 54 | parser.add_argument('--query', type=str, default="/home/dongqian06/hdfs_data/data_train/train.query.debug.txt") 55 | parser.add_argument('--dev_query', type=str, default="/home/dongqian06/hdfs_data/data_train/train.query.debug.txt") 56 | parser.add_argument('--min_index', type=int, default=0) 57 | parser.add_argument('--max_index', type=int, default=256) 58 | parser.add_argument('--sample_num', type=int, default=256) 59 | parser.add_argument('--num_labels', type=int, default=1) 60 | parser.add_argument('--local-rank', type=int, default=0) 61 | parser.add_argument('--local_rank', type=int, default=0) 62 | parser.add_argument('--fp16', type=bool, default=True) 63 | parser.add_argument('--gradient_checkpoint', type=bool, default=False) 64 | parser.add_argument('--negatives_x_device', type=bool, default=True) 65 | parser.add_argument('--negatives_in_device', type=bool, default=True) 66 | parser.add_argument('--untie_encoder', type=bool, 
default=True) 67 | parser.add_argument('--add_pooler', type=bool, default=False) 68 | parser.add_argument('--warm_start_from', type=str, default="") 69 | 70 | # args = parser.parse_args(args=[]) 71 | args = parser.parse_args() 72 | return args 73 | 74 | 75 | def main_multi(args, model, optimizer): 76 | epoch = 0 77 | local_rank = torch.distributed.get_rank() 78 | if local_rank==0: 79 | print(f'Starting training, up to {args.epoch} epochs, LR={args.learning_rate}', flush=True) 80 | 81 | # load the datasets 82 | query_dataset = dataset_factory.QueryDataset(args) 83 | query_loader = DataLoader(query_dataset, batch_size=args.dev_batch_size, collate_fn=query_dataset._collate_fn, num_workers=3) 84 | passage_dataset = dataset_factory.PassageDataset(args) 85 | passage_loader = DataLoader(passage_dataset, batch_size=args.dev_batch_size, collate_fn=passage_dataset._collate_fn, num_workers=3) 86 | validate_multi_gpu(model, query_loader, passage_loader, epoch, args) 87 | 88 | train_dataset = dataset_factory.DualEncoderTrainDataset(args) 89 | 90 | for epoch in range(1, args.epoch+1): 91 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 92 | train_sampler.set_epoch(epoch) 93 | train_dataset.set_epoch(epoch) 94 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_dataset._collate_fn, sampler=train_sampler, num_workers=4) 95 | loss = train_iteration_multi_gpu(model, optimizer, train_loader, epoch, args) 96 | del train_loader 97 | torch.distributed.barrier() 98 | if epoch%10==0: 99 | validate_multi_gpu(model, query_loader, passage_loader, epoch, args) 100 | torch.distributed.barrier() 101 | 102 | def validate_multi_gpu(model, query_loader, passage_loader, epoch, args): 103 | global best_mrr 104 | local_start = time.time() 105 | local_rank = torch.distributed.get_rank() 106 | world_size = torch.distributed.get_world_size() 107 | _output_file_name = 'output/_para.index.part%d'%local_rank 108 | output_file_name = 'output/para.index.part%d'%local_rank 109 | top_k = 1000 110 | q_output_file_name = 'output/query.emb.step%d.npy'%epoch 111 | if local_rank==0: 112 | q_embs = [] 113 | with torch.no_grad(): 114 | model.eval() 115 | for records in query_loader: 116 | if args.fp16: 117 | with autocast(): 118 | q_reps = model(query_inputs=_prepare_inputs(records)) 119 | else: 120 | q_reps = model(query_inputs=_prepare_inputs(records)) 121 | q_embs.append(q_reps.cpu().detach().numpy()) 122 | emb_matrix = np.concatenate(q_embs, axis=0) 123 | np.save(q_output_file_name, emb_matrix) 124 | print("predict q_embs cnt: %s" % len(emb_matrix)) 125 | with torch.no_grad(): 126 | model.eval() 127 | para_embs = [] 128 | for records in tqdm(passage_loader, disable=args.local_rank>0): 129 | if args.fp16: 130 | with autocast(): 131 | p_reps = model(passage_inputs=_prepare_inputs(records)) 132 | else: 133 | p_reps = model(passage_inputs=_prepare_inputs(records)) 134 | para_embs.append(p_reps.cpu().detach().numpy()) 135 | torch.distributed.barrier() 136 | para_embs = np.concatenate(para_embs, axis=0) 137 | # para_embs = np.load('output/_para.emb.part%d.npy'%local_rank) 138 | print("predict embs cnt: %s" % len(para_embs)) 139 | # engine = build_engine(para_embs, 768) 140 | # faiss.write_index(engine, _output_file_name) 141 | engine = torch.from_numpy(para_embs).cuda() 142 | np.save('output/_para.emb.part%d.npy'%local_rank, para_embs) 143 | print('create index done!') 144 | qid_list = load_qid(args.dev_query) 145 | search(engine, q_output_file_name, qid_list, 
"output/res.top%d.part%d.step%d"%(top_k, local_rank, epoch), top_k=top_k) 146 | torch.distributed.barrier() 147 | if local_rank==0: 148 | f_list = [] 149 | for part in range(world_size): 150 | f_list.append('output/res.top%d.part%d.step%d' % (top_k, part, epoch)) 151 | shift = np.load("output/_para.emb.part0.npy").shape[0] 152 | merge(world_size, shift, top_k, epoch) 153 | metrics = calc_mrr(args.dev_qrels, 'output/res.top%d.step%d'%(top_k, epoch)) 154 | for run in f_list: 155 | os.remove(run) 156 | mrr = metrics['MRR @10'] 157 | if mrr>best_mrr: 158 | print("*"*50) 159 | print("new top") 160 | print("*"*50) 161 | best_mrr = mrr 162 | for part in range(world_size): 163 | os.rename('output/_para.emb.part%d.npy'%part, 'output/para.emb.part%d.npy'%part) 164 | torch.save(model.state_dict(), "output/best.p") 165 | seconds = time.time()-local_start 166 | m, s = divmod(seconds, 60) 167 | h, m = divmod(m, 60) 168 | print("******************eval, mrr@10: %.10f,"%(mrr),"report used time:%02d:%02d:%02d," % (h, m, s)) 169 | 170 | 171 | def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: 172 | rt = tensor.clone() 173 | distributed.all_reduce(rt, op=distributed.ReduceOp.SUM) 174 | rt /= distributed.get_world_size()#进程数 175 | return rt 176 | 177 | def _prepare_inputs(record): 178 | prepared = {} 179 | local_rank = torch.distributed.get_rank() 180 | for key in record: 181 | x = record[key] 182 | if isinstance(x, torch.Tensor): 183 | prepared[key] = x.to(local_rank) 184 | elif x is None: 185 | prepared[key] = x 186 | else: 187 | prepared[key] = _prepare_inputs(x) 188 | return prepared 189 | 190 | def train_iteration_multi_gpu(model, optimizer, data_loader, epoch, args): 191 | total = 0 192 | model.train() 193 | total_loss = 0. 194 | total_ce_loss = 0. 195 | local_rank = torch.distributed.get_rank() 196 | world_size = torch.distributed.get_world_size() 197 | start = time.time() 198 | local_start = time.time() 199 | all_steps_per_epoch = len(data_loader) 200 | step = 0 201 | scaler = GradScaler() 202 | for record in data_loader: 203 | record = _prepare_inputs(record) 204 | with autocast(): 205 | retriever_ce_loss = model(**record) 206 | loss = retriever_ce_loss 207 | torch.distributed.barrier() 208 | reduced_loss = reduce_tensor(loss.data) 209 | total_loss += reduced_loss.item() 210 | total_ce_loss += float(retriever_ce_loss.cpu().detach().numpy()) 211 | 212 | # optimize 213 | optimizer.zero_grad() 214 | scaler.scale(loss).backward() 215 | scaler.step(optimizer) 216 | scaler.update() 217 | step+=1 218 | if step%args.report==0 and local_rank==0: 219 | seconds = time.time()-local_start 220 | m, s = divmod(seconds, 60) 221 | h, m = divmod(m, 60) 222 | local_start = time.time() 223 | print(f"epoch:{epoch} training step: {step}/{all_steps_per_epoch}, mean loss: {total_loss/step}, ce loss: {total_ce_loss/step}, ", "report used time:%02d:%02d:%02d," % (h, m, s), end=' ') 224 | seconds = time.time()-start 225 | m, s = divmod(seconds, 60) 226 | h, m = divmod(m, 60) 227 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ') 228 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 229 | if local_rank==0: 230 | # model.save(os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch))) 231 | # torch.save(model.state_dict(), os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch))) 232 | seconds = time.time()-start 233 | m, s = divmod(seconds, 60) 234 | h, m = divmod(m, 60) 235 | print(f'train epoch={epoch} loss={total_loss}') 236 | print("total used time:%02d:%02d:%02d" % (h, m, s), 
end=' ') 237 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 238 | return total_loss 239 | 240 | 241 | def main_cli(): 242 | args = define_args() 243 | args = vars(args) 244 | args = utils.HParams(**args) 245 | # set up multi-GPU distributed training 246 | torch.distributed.init_process_group(backend="nccl", init_method='env://') 247 | local_rank = torch.distributed.get_rank() 248 | if local_rank==0: 249 | args.print_config() 250 | torch.cuda.set_device(local_rank) 251 | device = torch.device("cuda", local_rank) 252 | 253 | model = DualEncoder(args) 254 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 255 | model.to(device) 256 | 257 | params = [(k, v) for k, v in model.named_parameters() if v.requires_grad] 258 | params = {'params': [v for k, v in params]} 259 | # optimizer = torch.optim.Adam([params], lr=args.learning_rate, weight_decay=0.0) 260 | optimizer = optim.Lamb([params], lr=args.learning_rate, weight_decay=0.0) 261 | 262 | if args.warm_start_from: 263 | print('warm start from ', args.warm_start_from) 264 | state_dict = torch.load(args.warm_start_from, map_location=device) 265 | for k in list(state_dict.keys()): 266 | state_dict[k.replace('module.','')] = state_dict.pop(k) 267 | model.load_state_dict(state_dict, strict=True) 268 | 269 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) 270 | print("model loaded on GPU%d"%local_rank) 271 | print(args.model_out_dir) 272 | os.makedirs(args.model_out_dir, exist_ok=True) 273 | 274 | main_multi(args, model, optimizer) 275 | 276 | if __name__ == '__main__': 277 | main_cli() 278 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import six 3 | import os 4 | import faiss 5 | import numpy as np 6 | class HParams(object): 7 | """Hyperparameter container""" 8 | 9 | def __init__(self, **kwargs): 10 | for k, v in kwargs.items(): 11 | self.__dict__[k] = v 12 | 13 | def __contains__(self, key): 14 | return key in self.__dict__ 15 | 16 | def __getitem__(self, key): 17 | if key not in self.__dict__: 18 | raise ValueError('key(%s) not in HParams.' 
66 | 
67 | def _get_dict_from_environ_or_json_or_file(args, env_name):
68 |     if args == '':
69 |         return None
70 |     if args is None:
71 |         s = os.environ.get(env_name)
72 |     else:
73 |         s = args
74 |     if os.path.exists(s):
75 |         s = open(s).read()
76 |     if isinstance(s, six.string_types):
77 |         try:
78 |             r = eval(s)  # accepts Python dict literals as well as JSON; only use with trusted input
79 |         except SyntaxError as e:
80 |             raise ValueError('json parse error: %s \n>Got json: %s' %
81 |                              (repr(e), s))
82 |         return r
83 |     else:
84 |         return s  # None
85 | 
86 | 
87 | def parse_file(filename):
88 |     """useless api"""
89 |     d = _get_dict_from_environ_or_json_or_file(filename, None)
90 |     if d is None:
91 |         raise ValueError('file(%s) not found' % filename)
92 |     return d
93 | 
94 | def build_engine(p_emb_matrix, dim):
95 |     index = faiss.IndexFlatIP(dim)  # exact (flat) inner-product index
96 |     index.add(p_emb_matrix.astype('float32'))
97 |     return index
98 | from tqdm import tqdm
99 | def read_embed(file_name, dim=768, bs=100):
100 |     if file_name.endswith('npy'):
101 |         i = 0
102 |         emb_np = np.load(file_name)
103 |         with tqdm(total=len(emb_np)//bs+1) as pbar:
104 |             while i < len(emb_np):
105 |                 vec_list = emb_np[i:i+bs]
106 |                 i += bs
107 |                 pbar.update(1)
108 |                 yield vec_list
109 |     else:
110 |         vec_list = []
111 |         with open(file_name) as inp:
112 |             for line in tqdm(inp):
113 |                 data = line.strip()
114 |                 vector = [float(item) for item in data.split(' ')]
115 |                 assert len(vector) == dim
116 |                 vec_list.append(vector)
117 |                 if len(vec_list) == bs:
118 |                     yield vec_list
119 |                     vec_list = []
120 |         if vec_list:  # flush the final, possibly smaller batch
121 |             yield vec_list
122 | 
123 | def load_qid(file_name):
124 |     qid_list = []
125 |     with open(file_name) as inp:
126 |         for line in inp:
127 |             line = line.strip()
128 |             qid = line.split('\t')[0]
129 |             try:
130 |                 int(qid)
131 |                 qid_list.append(qid)
132 |             except ValueError:  # skip header lines and anything without a numeric qid
133 |                 pass
134 |     return qid_list
135 | 
136 | import torch
137 | def topk_query_passage(query_vector, passage_vector, k):
138 |     """
139 |     Compute inner products between query and passage vectors and return the top-k passages per query.
140 | 
141 |     Args:
142 |         query_vector (torch.Tensor): query embeddings, shape (num_queries, dim)
143 |         passage_vector (torch.Tensor): passage embeddings, shape (num_passages, dim)
144 |         k (int): how many passages to keep per query
145 | 
146 |     Returns:
147 |         Two numpy arrays of shape (num_queries, k): the top-k scores and the matching passage row indices
148 |     """
149 |     # inner product between every query and every passage
150 |     scores = torch.matmul(query_vector, passage_vector.t())  # shape (num_queries, num_passages)
151 | 
152 |     # sort each row of scores and keep the top-k entries
153 |     res_dist, res_p_id = torch.topk(scores, k=k, dim=1)  # both of shape (num_queries, k)
154 | 
155 |     return res_dist.cpu().numpy(), res_p_id.cpu().numpy()
156 | 
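`topk_query_passage` is an exact inner-product search on the GPU; the same top-k can be obtained from the flat faiss index returned by `build_engine`. A small sketch with hypothetical random embeddings (assumes a GPU is available):

```python
import numpy as np
import torch

p_emb = np.random.rand(1000, 768).astype('float32')   # hypothetical passage embeddings
q_emb = torch.rand(4, 768).cuda()                     # hypothetical query embeddings

# GPU brute force, as used by search() below:
scores, pids = topk_query_passage(q_emb, torch.from_numpy(p_emb).cuda(), k=10)

# The same search through the CPU faiss index:
index = build_engine(p_emb, 768)
scores2, pids2 = index.search(q_emb.cpu().numpy(), 10)  # returns (scores, ids)
```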
157 | def search(index, emb_file, qid_list, outfile, top_k):  # `index` here is a passage-embedding tensor on GPU, not a faiss index
158 |     q_idx = 0
159 |     with open(outfile, 'w') as out:
160 |         for batch_vec in read_embed(emb_file):
161 |             q_emb_matrix = np.array(batch_vec)
162 |             q_emb_matrix = torch.from_numpy(q_emb_matrix)
163 |             q_emb_matrix = q_emb_matrix.cuda()
164 |             res_dist, res_p_id = topk_query_passage(q_emb_matrix, index, top_k)
165 |             for i in range(len(q_emb_matrix)):
166 |                 qid = qid_list[q_idx]
167 |                 for j in range(top_k):
168 |                     pid = res_p_id[i][j]
169 |                     score = res_dist[i][j]
170 |                     out.write('%s\t%s\t%s\t%s\n' % (qid, pid, j+1, score))
171 |                 q_idx += 1
172 | 
173 | def merge(total_part, shift, top, eval_cnts):  # merge per-rank result shards into one ranking per query
174 |     f_list = []
175 |     for part in range(total_part):
176 |         f0 = open('output/res.top%d.part%d.step%d' % (top, part, eval_cnts))
177 |         f_list.append(f0)
178 | 
179 |     line_list = []
180 |     for part in range(total_part):
181 |         line = f_list[part].readline()
182 |         line_list.append(line)
183 | 
184 |     out = open('output/res.top%d.step%d' % (top, eval_cnts), 'w')
185 |     last_q = ''
186 |     ans_list = {}
187 |     while line_list[-1]:  # all part files list the same queries in the same order, one line per step
188 |         cur_list = []
189 |         for line in line_list:
190 |             sub = line.strip().split('\t')
191 |             cur_list.append(sub)
192 | 
193 |         if last_q == '':
194 |             last_q = cur_list[0][0]
195 |         if cur_list[0][0] != last_q:  # a new query starts: flush the merged ranking for the previous one
196 |             rank = sorted(ans_list.items(), key=lambda a: a[1], reverse=True)
197 |             for i in range(top):
198 |                 out.write("%s\t%s\t%s\t%s\n" % (last_q, rank[i][0], i+1, rank[i][1]))
199 |             ans_list = {}
200 |         for i, sub in enumerate(cur_list):
201 |             ans_list[int(sub[1]) + shift*i] = float(sub[-1])  # map shard-local pid of part i to a global pid
202 |         last_q = cur_list[0][0]
203 | 
204 |         line_list = []
205 |         for f0 in f_list:
206 |             line = f0.readline()
207 |             line_list.append(line)
208 | 
209 |     rank = sorted(ans_list.items(), key=lambda a: a[1], reverse=True)
210 |     for i in range(top):
211 |         out.write("%s\t%s\t%s\t%s\n" % (last_q, rank[i][0], i+1, rank[i][1]))
212 |     out.close()
213 | 
214 | 
215 | def add_prefix(state_dict, prefix='module.'):  # add the DDP 'module.' prefix to every key if it is missing
216 |     if all(key.startswith(prefix) for key in state_dict.keys()):
217 |         return state_dict
218 |     prefixed_state_dict = {}
219 |     for key in list(state_dict.keys()):
220 |         key2 = prefix + key
221 |         prefixed_state_dict[key2] = state_dict.pop(key)
222 |     return prefixed_state_dict
223 | 
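The `shift` arithmetic in `merge` assumes every rank encoded an equally sized shard of the collection (`shift` is taken from the row count of `_para.emb.part0.npy` in the training script), so a shard-local row id from part `i` maps back to a global passage id as `local_id + shift * i`. A worked example with hypothetical sizes:

```python
shift = 1_000_000                     # hypothetical passages per shard
for part, local_id in [(0, 42), (1, 42), (2, 7)]:
    global_id = local_id + shift * part
    print(f"part {part}, row {local_id} -> global pid {global_id}")
# part 0, row 42 -> global pid 42
# part 1, row 42 -> global pid 1000042
# part 2, row 7 -> global pid 2000007
```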
224 | def filter_stop_words(txts):
225 |     stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 'won', 'wouldn']  # NLTK-style English stop word list
226 |     txts = [t.split() for t in txts]
227 |     txts = [list(set(list(filter(lambda x: x not in stop_words, t)))) for t in txts]
228 |     rets = []
229 |     for t in txts:
230 |         rets += t
231 |     return list(set(rets))  # unique non-stopword tokens pooled over all inputs (unordered)
232 | 
--------------------------------------------------------------------------------
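For completeness, `filter_stop_words` pools the unique non-stopword tokens over all input strings; because the tokens pass through `set()`, duplicates and ordering are discarded. A tiny sketch with made-up inputs:

```python
print(filter_stop_words(["the cat sat on the mat", "a dog and the cat"]))
# e.g. ['mat', 'dog', 'cat', 'sat'] -- order is arbitrary
```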