├── README.md
├── pic
│   └── stat.png
├── requirements.txt
├── script
│   ├── train_cross_encoder.sh
│   └── train_dual_encoder.sh
└── src
    ├── convert2trec.py
    ├── dataset_factory.py
    ├── modeling.py
    ├── msmarco_eval.py
    ├── train_cross_encoder.py
    ├── train_dual_encoder.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # T2Ranking
2 |
3 | ## Introduction
4 | T2Ranking is a large-scale Chinese benchmark for passage ranking. The details about T2Ranking are elaborated in [this paper](https://arxiv.org/abs/2304.03679#).
5 |
6 |
7 | Passage ranking is an important and challenging topic for both academia and industry in the area of Information Retrieval (IR). The goal of passage ranking is to compile a search result list, ordered by relevance to the query, from a large passage collection. Typically, passage ranking involves two stages: passage retrieval and passage re-ranking.
8 |
9 | To support passage ranking research, various benchmark datasets have been constructed. However, the commonly used datasets for passage ranking focus mainly on the English language. For non-English scenarios, such as Chinese, the existing datasets are limited in data scale, lack fine-grained relevance annotation, and suffer from false negative issues.
10 |
11 |
12 | To address this problem, we introduce T2Ranking, a large-scale Chinese benchmark for passage ranking. T2Ranking comprises more than 300K queries and over 2M unique passages from real-world search engines. Specifically, we sample question-based search queries from user logs of the Sogou search engine, a popular search system in China. For each query, we extract the content of corresponding documents from different search engines. After model-based passage segmentation and clustering-based passage de-duplication, a large-scale passage corpus is obtained. For a given query and its corresponding passages, we hire expert annotators to provide 4-level relevance judgments for each query-passage pair.
13 |
14 |
15 |

<img src="pic/stat.png">

16 | Table 1: The data statistics of datasets commonly used in passage ranking. FR (SR): First (Second) stage of passage ranking, i.e., passage Retrieval (Re-ranking).
17 |
18 |
19 |
20 | Compared with existing datasets, the T2Ranking dataset has the following characteristics and advantages:
21 | * The proposed dataset focuses on the Chinese search scenario and has advantages in data scale over existing Chinese passage ranking datasets, which can better support the design of deep learning algorithms.
22 | * The proposed dataset has a large number of fine-grained relevance annotations, which is helpful for mining fine-grained relationships between queries and passages and constructing more accurate ranking algorithms.
23 | * By retrieving passage results from multiple commercial search engines and providing complete annotations, we ease the false negative problem to some extent, which enables more accurate evaluation.
24 | * We design multiple strategies to ensure the high quality of our dataset, such as using a passage segmentation model and a passage clustering model to enhance the semantic integrity and diversity of passages, and employing an active-learning-based annotation method to improve the efficiency and quality of data annotation.
25 |
26 | ## Data Download
27 | The whole dataset is hosted on [Hugging Face](https://huggingface.co/datasets/THUIR/T2Ranking), and the data formats are described in the following table.
28 |
29 |
30 | | Description| Filename|Num Records|Format|
31 | |-------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------|----------:|-----------------------------------:|
32 | | Collection | collection.tsv | 2,303,643 | tsv: pid, passage |
33 | | Queries Train | queries.train.tsv | 258,042 | tsv: qid, query |
34 | | Queries Dev | queries.dev.tsv | 24,832 | tsv: qid, query |
35 | | Queries Test | queries.test.tsv | 24,832 | tsv: qid, query |
36 | | Qrels Train for re-ranking | qrels.train.tsv | 1,613,421 | TREC qrels format |
37 | | Qrels Dev for re-ranking | qrels.dev.tsv | 400,536 | TREC qrels format |
38 | | Qrels Retrieval Train | qrels.retrieval.train.tsv | 744,663 | tsv: qid, pid |
39 | | Qrels Retrieval Dev | qrels.retrieval.dev.tsv | 118,933 | tsv: qid, pid |
40 | | BM25 Negatives | train.bm25.tsv | 200,359,731 | tsv: qid, pid, index |
41 | | Hard Negatives | train.mined.tsv | 200,376,001 | tsv: qid, pid, index, score |
42 |
43 |
44 |
45 | You can download the dataset by running the following command:
46 | ```bash
47 | git lfs install
48 | git clone https://huggingface.co/datasets/THUIR/T2Ranking
49 | ```
50 | After downloading, you can find the following files in the folder:
51 | ```
52 | ├── data
53 | │ ├── collection.tsv
54 | │ ├── qrels.dev.tsv
55 | │ ├── qrels.retrieval.dev.tsv
56 | │ ├── qrels.retrieval.train.tsv
57 | │ ├── qrels.train.tsv
58 | │ ├── queries.dev.tsv
59 | │ ├── queries.test.tsv
60 | │ ├── queries.train.tsv
61 | │ ├── train.bm25.tsv
62 | │ └── train.mined.tsv
63 | ├── script
64 | │ ├── train_cross_encoder.sh
65 | │ └── train_dual_encoder.sh
66 | └── src
67 | ├── convert2trec.py
68 | ├── dataset_factory.py
69 | ├── modeling.py
70 | ├── msmarco_eval.py
71 | ├── train_cross_encoder.py
72 | ├── train_dual_encoder.py
73 | └── utils.py
74 | ```
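
Each file is plain TSV. As a quick sanity check after downloading, the collection and retrieval qrels can be loaded with pandas; the sketch below mirrors how `src/dataset_factory.py` reads these files (paths assume the `data/` layout above):

```python
import pandas as pd

# pid \t passage; quoting=3 (QUOTE_NONE) matches src/dataset_factory.py
collection = pd.read_csv("data/collection.tsv", sep="\t", quoting=3)
collection.columns = ["pid", "para"]

# qid \t pid pairs listing the relevant passages per query
qrels = pd.read_csv("data/qrels.retrieval.dev.tsv", sep="\t")
qrels.columns = ["qid", "pid"]

print(collection.head())
print(qrels.head())
```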
75 |
76 |
77 |
78 | ## Training and Evaluation
79 | The dual-encoder can be trained by running the following command:
80 | ```bash
81 | sh script/train_dual_encoder.sh
82 | ```
83 | After training, you can evaluate the model by running the following command:
84 | ```bash
85 | python src/msmarco_eval.py data/qrels.retrieval.dev.tsv output/res.top1000.step20
86 | ```
87 |
88 |
89 | The cross-encoder can be trained by running the following command:
90 | ```bash
91 | sh script/train_cross_encoder.sh
92 | ```
93 | After training, you can evaluate the model by running the following command:
94 | ```bash
95 | python src/convert2trec.py output/res.step-20 && python src/msmarco_eval.py data/qrels.retrieval.dev.tsv output/res.step-20 && path_to/trec_eval -m ndcg_cut.5 data/qrels.dev.tsv output/res.step-20.trec
96 | ```
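
For reference, the merged run file `output/res.step-20` produced during cross-encoder validation contains `qid \t pid \t rank \t score` lines (see `merge()` in `src/train_cross_encoder.py`), and `src/convert2trec.py` rewrites each line into a six-column TREC run entry. A single line before and after conversion (the ids and score are hypothetical):

```
# output/res.step-20
1001	2303	1	12.75
# output/res.step-20.trec
1001	vanilla_bert	2303	1	12.75	vanilla_bert
```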
97 |
98 | We have uploaded some checkpoints to the Hugging Face Hub:
99 |
100 | | Model | Description | Link |
101 | | ------------------ | --------------------------------------------------------- | ------------------------------------------------------------ |
102 | | dual-encoder 1 | dual-encoder trained with bm25 negatives | [DE1](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dual-encoder-trained-with-bm25-negatives.p) |
103 | | dual-encoder 2 | dual-encoder trained with self-mined hard negatives | [DE2](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dual-encoder-trained-with-hard-negatives.p) |
104 | | cross-encoder | cross-encoder trained with self-mined hard negatives | [CE](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/cross-encoder.p) |
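
Each checkpoint is a raw PyTorch `state_dict`. The sketch below mirrors the warm-start logic in `src/train_dual_encoder.py`; it assumes an `args` object built the same way as in that script, so treat the constructor call as an assumption rather than a fixed API:

```python
import torch

from modeling import DualEncoder  # src/modeling.py

model = DualEncoder(args)  # args as constructed in src/train_dual_encoder.py
state_dict = torch.load("dual-encoder-trained-with-bm25-negatives.p", map_location="cpu")
# checkpoints saved under DistributedDataParallel carry a "module." prefix; strip it
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
```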
105 |
106 |
107 | BM25 on DEV set
108 | ```bash
109 | #####################
110 | MRR @10: 0.35894801237316354
111 | QueriesRanked: 24831
112 | recall@1: 0.05098711868967141
113 | recall@1000: 0.7464097131133757
114 | recall@50: 0.4942572226146033
115 | #####################
116 | ```
117 |
118 | DPR trained with BM25 negatives on DEV set
119 |
120 | ```bash
121 | #####################
122 | MRR @10: 0.4856112079562753
123 | QueriesRanked: 24831
124 | recall@1: 0.07367235058688999
125 | recall@1000: 0.9082753169878586
126 | recall@50: 0.7099350889583964
127 | #####################
128 | ```
129 |
130 | DPR trained with self-mined hard negatives on DEV set
131 |
132 | ```bash
133 | #####################
134 | MRR @10: 0.5166915171959451
135 | QueriesRanked: 24831
136 | recall@1: 0.08047455688965123
137 | recall@1000: 0.9135220125786163
138 | recall@50: 0.7327044025157232
139 | #####################
140 | ```
141 |
142 |
143 | BM25 retrieved+CE reranked on DEV set
144 |
145 | The reranked run file is available [here](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dev.bm25.tsv).
146 | ```bash
147 | #####################
148 | MRR @10: 0.5188107959009376
149 | QueriesRanked: 24831
150 | recall@1: 0.08545219116806242
151 | recall@1000: 0.7464097131133757
152 | recall@50: 0.595298153566744
153 | #####################
154 | ndcg_cut_20 all 0.4405
155 | ndcg_cut_100 all 0.4705
156 | #####################
157 | ```
158 |
159 | DPR retrieved+CE reranked on DEV set
160 |
161 | The reranked run file is available [here](https://huggingface.co/datasets/THUIR/T2Ranking/blob/main/data/dev.dpr.tsv).
162 | ```bash
163 | #####################
164 | MRR @10: 0.5508822816845231
165 | QueriesRanked: 24831
166 | recall@1: 0.08903406988867588
167 | recall@1000: 0.9135220125786163
168 | recall@50: 0.7393720781623112
169 | #####################
170 | ndcg_cut_20 all 0.5131
171 | ndcg_cut_100 all 0.5564
172 | #####################
173 | ```
174 |
175 | ## License
176 | The dataset is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0.html).
177 |
178 |
179 | ## Citation
180 | If you use this dataset in your research, please cite our paper:
181 | ```
182 | @misc{xie2023t2ranking,
183 | title={T2Ranking: A large-scale Chinese Benchmark for Passage Ranking},
184 | author={Xiaohui Xie and Qian Dong and Bingning Wang and Feiyang Lv and Ting Yao and Weinan Gan and Zhijing Wu and Xiangsheng Li and Haitao Li and Yiqun Liu and Jin Ma},
185 | year={2023},
186 | eprint={2304.03679},
187 | archivePrefix={arXiv},
188 | primaryClass={cs.IR}
189 | }
190 | ```
191 |
--------------------------------------------------------------------------------
/pic/stat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUIR/T2Ranking/3ab0a0de72dd50bf84d852a985f6188334781403/pic/stat.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss_cpu==1.7.4
2 | faiss_gpu==1.7.2
3 | numpy==1.24.3
4 | pandas==2.0.1
5 | pytorch_pretrained_bert==0.6.2
6 | scikit_learn==1.2.2
7 | torch==2.0.0
8 | torch_optimizer==0.3.0
9 | tqdm==4.65.0
10 | transformers==4.28.1
11 |
--------------------------------------------------------------------------------
/script/train_cross_encoder.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | dataset=marco
3 | sample_num=64
4 | batch_size=16
5 | echo "batch size ${batch_size}"
6 | dev_batch_size=1024
7 | min_index=0
8 | max_index=256
9 | max_seq_len=332
10 | q_max_seq_len=32
11 | p_max_seq_len=128
12 | model_name_or_path=checkpoint/bert-base-chinese/
13 | top1000=data/train.mined.tsv
14 | warm_start_from=data/cross-encoder.p
15 | dev_top1000=data/dev.dpr.tsv
16 | dev_query=data/queries.dev.tsv
17 | collection=data/collection.tsv
18 | qrels=data/qrels.retrieval.train.tsv
19 | dev_qrels=data/qrels.retrieval.dev.tsv
20 | query=data/queries.train.tsv
21 | learning_rate=3e-5
22 | ### The settings below never need to be changed
23 | warmup_proportion=0.1
24 | eval_step_proportion=0.01
25 | report_step=100
26 | epoch=20
27 | fp16=true
28 | output_dir=output
29 | log_dir=${output_dir}/log
30 | mkdir -p ${output_dir}
31 | mkdir -p ${log_dir}
32 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}=================="
33 | python -m torch.distributed.launch \
34 | --log_dir ${log_dir} \
35 | --nproc_per_node=8 \
36 | src/train_cross_encoder.py \
37 | --model_name_or_path=${model_name_or_path} \
38 | --batch_size=${batch_size} \
39 | --warmup_proportion=${warmup_proportion} \
40 | --eval_step_proportion=${eval_step_proportion} \
41 | --report=${report_step} \
42 | --qrels=${qrels} \
43 | --dev_qrels=${dev_qrels} \
44 | --query=${query} \
45 | --dev_query=${dev_query} \
46 | --collection=${collection} \
47 | --top1000=${top1000} \
48 | --min_index=${min_index} \
49 | --max_index=${max_index} \
50 | --epoch=${epoch} \
51 | --sample_num=${sample_num} \
52 | --dev_batch_size=${dev_batch_size} \
53 | --max_seq_len=${max_seq_len} \
54 | --learning_rate=${learning_rate} \
55 | --q_max_seq_len=${q_max_seq_len} \
56 | --p_max_seq_len=${p_max_seq_len} \
57 | --dev_top1000=${dev_top1000} \
58 | --warm_start_from=${warm_start_from} \
59 | | tee ${log_dir}/train.log
60 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}=================="
61 |
62 |
--------------------------------------------------------------------------------
/script/train_dual_encoder.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | dataset=marco
3 | sample_num=2
4 | batch_size=64
5 | echo "batch size ${batch_size}"
6 | max_index=200
7 | retriever_model_name_or_path=checkpoint/bert-base-chinese/
8 | top1000=data/train.mined.tsv
9 | warm_start_from=data/dual-encoder-trained-with-hard-negatives.p
10 | learning_rate=2e-5
11 | ### The settings below never need to be changed
12 | dev_batch_size=256
13 | min_index=0
14 | max_seq_len=332
15 | q_max_seq_len=32
16 | p_max_seq_len=300
17 | dev_query=data/queries.dev.tsv
18 | collection=data/collection.tsv
19 | qrels=data/qrels.retrieval.train.tsv
20 | dev_qrels=data/qrels.retrieval.dev.tsv
21 | query=data/queries.train.tsv
22 | warmup_proportion=0.1
23 | eval_step_proportion=0.01
24 | report_step=100
25 | epoch=200
26 | fp16=true
27 | output_dir=output
28 | log_dir=${output_dir}/log
29 | mkdir -p ${output_dir}
30 | mkdir -p ${log_dir}
31 | master_port=29500
32 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}=================="
33 | python -m torch.distributed.launch \
34 | --log_dir ${log_dir} \
35 | --nproc_per_node=8 \
36 | --master_port=${master_port} \
37 | src/train_dual_encoder.py \
38 | --retriever_model_name_or_path=${retriever_model_name_or_path} \
39 | --batch_size=${batch_size} \
40 | --warmup_proportion=${warmup_proportion} \
41 | --eval_step_proportion=${eval_step_proportion} \
42 | --report=${report_step} \
43 | --qrels=${qrels} \
44 | --dev_qrels=${dev_qrels} \
45 | --query=${query} \
46 | --dev_query=${dev_query} \
47 | --collection=${collection} \
48 | --top1000=${top1000} \
49 | --min_index=${min_index} \
50 | --max_index=${max_index} \
51 | --epoch=${epoch} \
52 | --sample_num=${sample_num} \
53 | --dev_batch_size=${dev_batch_size} \
54 | --max_seq_len=${max_seq_len} \
55 | --learning_rate=${learning_rate} \
56 | --q_max_seq_len=${q_max_seq_len} \
57 | --p_max_seq_len=${p_max_seq_len} \
58 | --warm_start_from=${warm_start_from} \
59 | | tee ${log_dir}/train.log
60 |
61 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}=================="
62 |
63 |
--------------------------------------------------------------------------------
/src/convert2trec.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | def convert_to_trec(fin):
4 |     """Rewrite a run file of `qid\tpid\trank\tscore` lines into TREC run format."""
5 |     with open(fin) as f:
6 |         lines = f.readlines()
7 |     saved = []
8 |     for line in lines:
9 |         qid, pid, index, score = line.strip().split()
10 |         # TREC run line: qid, iteration tag (unused by trec_eval), pid, rank, score, run id
11 |         saved.append(str(int(float(qid))) + "\tvanilla_bert\t" + str(int(float(pid))) + "\t" + str(index) + "\t" + str(score) + "\tvanilla_bert\n")
12 |     with open(fin + ".trec", "w") as f:
13 |         f.writelines(saved)
14 | 
15 | if __name__ == "__main__":
16 |     convert_to_trec(sys.argv[1])
--------------------------------------------------------------------------------
/src/dataset_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytorch_pretrained_bert
7 | import torch
8 | from sklearn.utils import shuffle
9 | from torch.utils.data import DataLoader, Dataset
10 | from tqdm import tqdm
11 | from transformers import AutoModel, AutoTokenizer, DataCollatorForWholeWordMask
12 |
13 | class PassageDataset(Dataset):
14 | def __init__(self, args):
15 | self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)
16 | try:
17 | self.rank = torch.distributed.get_rank()
18 | self.n_procs = torch.distributed.get_world_size()
19 | except:
20 | self.rank, self.n_procs = 0, 1  # single-process fallback
21 | self.args = args
22 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3)
23 | self.collection.columns=['pid', 'para']
24 | self.collection = self.collection.fillna("NA")
25 | self.collection.index = self.collection.pid
26 | total_cnt = len(self.collection)
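# shard the collection across processes so each GPU embeds a distinct slice; the last rank also takes the remainder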
27 | shard_cnt = total_cnt//self.n_procs
28 | if self.rank!=self.n_procs-1:
29 | self.collection = self.collection[self.rank*shard_cnt:(self.rank+1)*shard_cnt]
30 | else:
31 | self.collection = self.collection[self.rank*shard_cnt:]
32 | self.num_samples = len(self.collection)
33 | print('rank:',self.rank,'samples:',self.num_samples)
34 |
35 | def _collate_fn(self, psgs):
36 | p_records = self.tokenizer(psgs, padding=True, truncation=True, return_tensors="pt", max_length=self.args.p_max_seq_len)
37 | return p_records
38 |
39 | def __getitem__(self, idx):
40 | cols = self.collection.iloc[idx]
41 | para = cols.para
42 | psg = para
43 | return psg
44 |
45 | def __len__(self):
46 | return self.num_samples
47 |
48 |
49 | class QueryDataset(Dataset):
50 | def __init__(self, args):
51 | self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)
52 | self.args = args
53 | self.collection = pd.read_csv(args.dev_query, sep="\t", quoting=3)
54 | self.collection.columns = ['qid','qry']
55 | self.collection = self.collection.fillna("NA")
56 | self.num_samples = len(self.collection)
57 |
58 | def _collate_fn(self, qrys):
59 | return self.tokenizer(qrys, padding=True, truncation=True, return_tensors="pt", max_length=self.args.q_max_seq_len)
60 |
61 | def __getitem__(self, idx):
62 | return self.collection.iloc[idx].qry
63 |
64 | def __len__(self):
65 | return self.num_samples
66 |
67 |
68 | class CrossEncoderTrainDataset(Dataset):
69 | def __init__(self, args):
70 | self.tokenizer = AutoTokenizer.from_pretrained(args.reranker_model_name_or_path)
71 | try:
72 | self.rank = torch.distributed.get_rank()
73 | self.n_procs = torch.distributed.get_world_size()
74 | except:
75 | self.rank, self.n_procs = 0, 1  # single-process fallback
76 | self.args = args
77 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3)
78 | self.collection.columns=['pid', 'para']
79 | self.collection = self.collection.fillna("NA")
80 | self.collection.index = self.collection.pid
81 | self.collection.pop('pid')
82 | self.query = pd.read_csv(args.query,sep="\t")
83 | self.query.columns = ['qid','text']
84 | self.query.index = self.query.qid
85 | self.query.pop('qid')
86 | self.top1000 = pd.read_csv(args.top1000, sep="\t")
87 | self.top1000.columns=['qid','pid','index', 'score']
88 | self.top1000 = list(self.top1000.groupby("qid"))
89 | self.len = len(self.top1000)
90 | self.min_index = args.min_index
91 | self.max_index = args.max_index
92 | qrels={}
93 | with open(args.qrels,'r') as f:
94 | lines = f.readlines()
95 | for line in lines[1:]:
96 | qid,pid = line.split()
97 | qid=int(qid)
98 | pid=int(pid)
99 | x=qrels.get(qid,[])
100 | x.append(pid)
101 | qrels[qid]=x
102 | self.qrels = qrels
103 | self.sample_num = args.sample_num-1
104 | self.epoch = 0
105 | self.num_samples = len(self.top1000)
106 |
107 | def set_epoch(self, epoch):
108 | self.epoch = epoch
109 | print(self.epoch)
110 |
111 | def sample(self, qid, pids, sample_num):
112 | '''
113 | qid:int
114 | pids:list
115 | sample_num:int
116 | '''
117 | pids = [pid for pid in pids if pid not in self.qrels[qid]]
118 | pids = pids[self.args.min_index:self.args.max_index]
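# strided sampling over the candidate window: take every interval-th pid, rotating the offset with the epoch so successive epochs draw different negatives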
119 | interval = len(pids)//sample_num
120 | offset = self.epoch%interval
121 | sample_pids = pids[offset::interval][:sample_num]
122 | return sample_pids
123 |
124 | def __getitem__(self, idx):
125 | cols = self.top1000[idx]
126 | qid = cols[0]
127 | pids = list(cols[1]['pid'])
128 | sample_neg_pids = self.sample(qid, pids, self.sample_num)
129 | pos_id = random.choice(self.qrels.get(qid))
130 | query = self.query.loc[qid]['text']
131 | data = [(query, self.collection.loc[pos_id]['para'])]
132 | for neg_pid in sample_neg_pids:
133 | data.append((query, self.collection.loc[neg_pid]['para']))
134 | return data
135 |
136 | def _collate_fn(self, sample_list):
137 | qrys = []
138 | psgs = []
139 | for qp_pairs in sample_list:
140 | for q,p in qp_pairs:
141 | qrys.append(q)
142 | psgs.append(p)
143 | features = self.tokenizer(qrys, psgs, padding=True, truncation=True, return_tensors="pt", max_length=self.args.max_seq_len)
144 | return features
145 |
146 | def __len__(self):
147 | return self.num_samples
148 |
149 | class CrossEncoderDevDataset(Dataset):
150 | def __init__(self, args):
151 | self.tokenizer = AutoTokenizer.from_pretrained(args.reranker_model_name_or_path)
152 | try:
153 | self.rank = torch.distributed.get_rank()
154 | self.n_procs = torch.distributed.get_world_size()
155 | except:
156 | self.rank, self.n_procs = 0, 1  # single-process fallback
157 | self.args = args
158 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3)
159 | self.collection.columns=['pid', 'para']
160 | self.collection = self.collection.fillna("NA")
161 | self.collection.index = self.collection.pid
162 | self.collection.pop('pid')
163 | self.query = pd.read_csv(args.dev_query,sep="\t")
164 | self.query.columns = ['qid','text']
165 | self.query.index = self.query.qid
166 | self.query.pop('qid')
167 | self.top1000 = pd.read_csv(args.dev_top1000, sep="\t", header=None)
168 | self.num_samples = len(self.top1000)
169 |
170 |
171 | def __getitem__(self, idx):
172 | cols = self.top1000.iloc[idx]
173 | qid = cols[0]
174 | pid = cols[1]
175 | return self.query.loc[qid]['text'], self.collection.loc[pid]['para'], qid, pid
176 |
177 | def _collate_fn(self, sample_list):
178 | qrys = []
179 | psgs = []
180 | qids = []
181 | pids = []
182 | for q,p,qid,pid in sample_list:
183 | qrys.append(q)
184 | psgs.append(p)
185 | qids.append(qid)
186 | pids.append(pid)
187 | features = self.tokenizer(qrys, psgs, padding=True, truncation=True, return_tensors="pt", max_length=self.args.max_seq_len)
188 | return features, {"qids":np.array(qids),"pids":np.array(pids)}
189 |
190 | def __len__(self):
191 | return self.num_samples
192 |
193 | class DualEncoderTrainDataset(Dataset):
194 | def __init__(self, args):
195 | self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)
196 | try:
197 | self.rank = torch.distributed.get_rank()
198 | self.n_procs = torch.distributed.get_world_size()
199 | except:
200 | self.rank, self.n_procs = 0, 1  # single-process fallback
201 | self.args = args
202 | self.collection = pd.read_csv(args.collection,sep="\t", quoting=3)
203 | self.collection.columns=['pid','para']
204 | self.collection = self.collection.fillna("NA")
205 | self.collection.index = self.collection.pid
206 | self.collection.pop('pid')
207 | self.query = pd.read_csv(args.query,sep="\t")
208 | self.query.columns = ['qid','text']
209 | self.query.index = self.query.qid
210 | self.query.pop('qid')
211 | self.top1000 = pd.read_csv(args.top1000, sep="\t")
212 | if len(self.top1000.columns)==3:
213 | self.top1000.columns=['qid','pid','index']
214 | else:
215 | self.top1000.columns=['qid','pid','index','score']
216 | self.top1000 = list(self.top1000.groupby("qid"))
217 | self.len = len(self.top1000)
218 | self.min_index = args.min_index
219 | self.max_index = args.max_index
220 | qrels={}
221 | with open(args.qrels,'r') as f:
222 | lines = f.readlines()
223 | for line in lines[1:]:
224 | qid,pid = line.split()
225 | qid=int(qid)
226 | pid=int(pid)
227 | x=qrels.get(qid,[])
228 | x.append(pid)
229 | qrels[qid]=x
230 | self.qrels = qrels
231 | self.sample_num = args.sample_num-1
232 | self.epoch = 0
233 | self.num_samples = len(self.top1000)
234 |
235 | def set_epoch(self, epoch):
236 | self.epoch = epoch
237 | print(self.epoch)
238 |
239 | def sample(self, qid, pids, sample_num):
240 | '''
241 | qid:int
242 | pids:list
243 | sample_num:int
244 | '''
245 | pids = [pid for pid in pids if pid not in self.qrels[qid]]
246 | pids = pids[self.args.min_index:self.args.max_index]
247 | if len(pids) < sample_num:
248 |     pids = (pids * sample_num)[:sample_num]  # reconstructed guard (file truncated here in the dump): repeat negatives when too few remain after filtering; mirrors CrossEncoderTrainDataset.sample
249 | interval = len(pids)//sample_num
250 | offset = self.epoch%interval
251 | sample_pids = pids[offset::interval][:sample_num]
252 | return sample_pids
--------------------------------------------------------------------------------
/src/msmarco_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This module computes evaluation metrics for the MS MARCO dataset on the ranking task.
3 | Command line:
4 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
5 | Creation Date : 06/12/2018
6 | Last Modified : 4/6/2023 by Qian Dong and Haitao Li
7 | Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com>
8 | """
9 | import itertools
10 | import sys
11 | from collections import Counter
12 |
13 | import numpy as np
14 | import pandas as pd
15 |
16 | MaxMRRRank = 10
17 |
18 | def load_reference_from_stream(f):
19 | """Load Reference reference relevant passages
20 | Args:f (stream): stream to load.
21 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
22 | """
23 | qids_to_relevant_passageids = {}
24 | for l in f:
25 | try:
26 | l = l.strip().split('\t')
27 | qid = int(l[0])
28 | if qid in qids_to_relevant_passageids:
29 | pass
30 | else:
31 | qids_to_relevant_passageids[qid] = []
32 | qids_to_relevant_passageids[qid].append(int(l[1]))
33 | except:
34 | # raise IOError('\"%s\" is not valid format' % l)
35 | pass
36 | return qids_to_relevant_passageids
37 |
38 |
39 | def load_reference(path_to_reference):
40 | """Load Reference reference relevant passages
41 | Args:path_to_reference (str): path to a file to load.
42 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
43 | """
44 | with open(path_to_reference, 'r') as f:
45 | qids_to_relevant_passageids = load_reference_from_stream(f)
46 | return qids_to_relevant_passageids
47 |
48 |
49 | def load_candidate_from_stream(f):
50 | """Load candidate data from a stream.
51 | Args:f (stream): stream to load.
52 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
53 | """
54 | qid_to_ranked_candidate_passages = {}
55 | for l in f:
56 | try:
57 | l = l.strip().split()
58 | qid = int(float(l[0]))
59 | pid = int(float(l[1]))
60 | rank = int(float(l[2]))
61 | if qid in qid_to_ranked_candidate_passages:
62 | pass
63 | else:
64 | # By default, all PIDs in the list of 1000 are 0. Only override those that are given
65 | tmp = [0] * 1000
66 | qid_to_ranked_candidate_passages[qid] = tmp
67 | qid_to_ranked_candidate_passages[qid][rank - 1] = pid
68 | except:
69 | # raise IOError('\"%s\" is not valid format' % l)
70 | pass
71 | return qid_to_ranked_candidate_passages
72 |
73 |
74 | def load_candidate(path_to_candidate):
75 | """Load candidate data from a file.
76 | Args:path_to_candidate (str): path to file to load.
77 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
78 | """
79 |
80 | with open(path_to_candidate, 'r') as f:
81 | qid_to_ranked_candidate_passages = load_candidate_from_stream(f)
82 | return qid_to_ranked_candidate_passages
83 |
84 |
85 | def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
86 | """Perform quality checks on the dictionaries
87 | Args:
88 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
89 | Dict as read in with load_reference or load_reference_from_stream
90 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
91 | Returns:
92 | bool,str: Boolean whether allowed, message to be shown in case of a problem
93 | """
94 | message = ''
95 | allowed = True
96 |
97 | # Create sets of the QIDs for the submitted and reference queries
98 | candidate_set = set(qids_to_ranked_candidate_passages.keys())
99 | ref_set = set(qids_to_relevant_passageids.keys())
100 |
101 | # Check that we do not have multiple passages per query
102 | for qid in qids_to_ranked_candidate_passages:
103 | # Remove all zeros from the candidates
104 | duplicate_pids = set(
105 | [item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])
106 |
107 | if len(duplicate_pids - set([0])) > 0:
108 | message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
109 | qid=qid, pid=list(duplicate_pids)[0])
110 | allowed = False
111 |
112 | return allowed, message
113 |
114 |
115 | def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
116 | """Compute MRR metric
117 | Args:
118 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
119 | Dict as read in with load_reference or load_reference_from_stream
120 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
121 | Returns:
122 | dict: dictionary of metrics {'MRR': <MRR score>}
123 | """
124 | all_scores = {}
125 | MRR = 0
126 | qids_with_relevant_passages = 0
127 | ranking = []
128 | recall_q_top1 = []
129 | recall_q_top50 = []
130 | recall_q_top1000 = []
131 | recall_q_all = []
132 | all_num = 0
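# all_num accumulates the total number of relevant passages over all evaluated queries; the recall@k values below use it as the denominator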
133 |
134 | for qid in qids_to_ranked_candidate_passages:
135 | if qid in qids_to_relevant_passageids:
136 | ranking.append(0)
137 | target_pid = qids_to_relevant_passageids[qid]
138 | all_num = all_num + len(target_pid)
139 | candidate_pid = qids_to_ranked_candidate_passages[qid]
140 | for i in range(0, MaxMRRRank):
141 | if candidate_pid[i] in target_pid:
142 | MRR += 1.0 / (i + 1)
143 | ranking.pop()
144 | ranking.append(i + 1)
145 | break
146 | for i, pid in enumerate(candidate_pid):
147 | if pid in target_pid:
148 | recall_q_all.append(pid)
149 | if i < 50:
150 | recall_q_top50.append(pid)
151 | if i < 1000:
152 | recall_q_top1000.append(pid)
153 | if i == 0:
154 | recall_q_top1.append(pid)
155 |
156 |
157 | if len(ranking) == 0:
158 | raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")
159 |
160 |
161 | MRR = MRR / len(qids_to_ranked_candidate_passages)
162 | recall_top1 = len(recall_q_top1) * 1.0 / all_num
163 | recall_top50 = len(recall_q_top50) * 1.0 / all_num
164 | recall_all = len(recall_q_top1000) * 1.0 / all_num
165 | all_scores['MRR @10'] = MRR
166 | all_scores["recall@1"] = recall_top1
167 | all_scores["recall@50"] = recall_top50
168 | all_scores["recall@1000"] = recall_all
169 | all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
170 | return all_scores
171 |
172 |
173 | def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
174 | """Compute MRR metric
175 | Args:
176 | p_path_to_reference_file (str): path to reference file.
177 | Reference file should contain lines in the following format:
178 | QUERYID\tPASSAGEID
179 | Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
180 | p_path_to_candidate_file (str): path to candidate file.
181 | Candidate file should contain lines in the following format:
182 | QUERYID\tPASSAGEID1\tRank
183 | If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
184 | QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
185 | Where the values are separated by tabs and ranked in order of relevance
186 | Returns:
187 | dict: dictionary of metrics {'MRR': <MRR score>}
188 | """
189 |
190 | qids_to_relevant_passageids = load_reference(path_to_reference)
191 | qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
192 | if perform_checks:
193 | allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
194 | if message != '': print(message)
195 |
196 | return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
197 |
198 |
199 | def main():
200 | """Command line:
201 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
202 | """
203 |
204 | if len(sys.argv) == 3:
205 | path_to_reference = sys.argv[1]
206 | path_to_candidate = sys.argv[2]
207 |
208 | else:
209 | print('Usage: msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>')
210 | exit()
211 |
212 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
213 | print('#####################')
214 | for metric in sorted(metrics):
215 | print('{}: {}'.format(metric, metrics[metric]))
216 | print('#####################')
217 |
218 |
219 | def calc_mrr(path_to_reference, path_to_candidate):
220 | """Command line:
221 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
222 | """
223 |
224 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
225 | print('#####################')
226 | for metric in sorted(metrics):
227 | print('{}: {}'.format(metric, metrics[metric]))
228 | print('#####################')
229 | return metrics
230 |
231 |
232 | def get_mrr(path_to_reference="/home/dongqian06/codes/NAACL2021-RocketQA/corpus/marco/qrels.dev.tsv", path_to_candidate="output/step_0_pred_dev_scores.txt"):
233 | all_data = pd.read_csv(path_to_candidate,sep="\t",header=None)
234 | all_data.columns = ["qid","pid","score"]
235 | all_data = all_data.groupby("qid").apply(lambda x: x.sort_values('score', ascending=False).reset_index(drop=True))
236 | all_data.columns = ['query_id',"para_id","score"]
237 | all_data = all_data.reset_index()
238 | all_data.pop("qid")
239 | all_data.columns = ["index","qid","pid","score"]
240 | all_data = all_data.loc[:,["qid","pid","index","score"]]
241 | all_data['index']+=1
242 | path_to_candidate = path_to_candidate.replace("txt","qrels")
243 | all_data.to_csv(path_to_candidate, header=None,index=False,sep="\t")
244 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
245 | return metrics['MRR @10']
246 |
247 | if __name__ == '__main__':
248 | main()
249 |
--------------------------------------------------------------------------------
/src/train_cross_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
4 | import argparse
5 | import random
6 | import subprocess
7 | import tempfile
8 | import time
9 | from collections import defaultdict
10 |
11 | import faiss
12 | import numpy as np
13 | import torch
14 | import torch.distributed as dist
15 | import torch.nn.functional as F
16 | from torch import distributed, optim
17 | from torch.cuda.amp import GradScaler, autocast
18 | from torch.nn.parallel import DistributedDataParallel
19 | from torch.utils.data import DataLoader, Dataset
20 | from tqdm import tqdm
21 | from transformers import AutoConfig, AutoTokenizer, BertModel
22 |
23 | import dataset_factory
24 | import utils
25 | from modeling import Reranker
26 | from msmarco_eval import calc_mrr
27 | from utils import add_prefix, build_engine, load_qid, read_embed, search
28 |
29 | SEED = 2023
30 | best_mrr=-1
31 | torch.manual_seed(SEED)
32 | torch.cuda.manual_seed_all(SEED)
33 | random.seed(SEED)
34 | def define_args():
35 | import argparse
36 | parser = argparse.ArgumentParser('BERT-ranker model')
37 | parser.add_argument('--batch_size', type=int, default=2)
38 | parser.add_argument('--dev_batch_size', type=int, default=64)
39 | parser.add_argument('--max_seq_len', type=int, default=160)
40 | parser.add_argument('--q_max_seq_len', type=int, default=160)
41 | parser.add_argument('--p_max_seq_len', type=int, default=160)
42 | parser.add_argument('--model_name_or_path', type=str, default="../../data/bert-base-uncased/")
43 | parser.add_argument('--reranker_model_name_or_path', type=str, default="../../data/bert-base-uncased/")
44 | parser.add_argument('--warm_start_from', type=str, default="")
45 | parser.add_argument('--model_out_dir', type=str, default="output")
46 | parser.add_argument('--learning_rate', type=float, default=1e-5)
47 | parser.add_argument('--weight_decay', type=float, default=0.01)
48 | parser.add_argument('--warmup_proportion', type=float, default=0.1)
49 | parser.add_argument('--eval_step_proportion', type=float, default=1.0)
50 | parser.add_argument('--report', type=int, default=1)
51 | parser.add_argument('--epoch', type=int, default=3)
52 | parser.add_argument('--qrels', type=str, default="../../data/marco/qrels.train.debug.tsv")
53 | parser.add_argument('--dev_qrels', type=str, default="../../data/marco/qrels.train.debug.tsv")
54 | parser.add_argument('--top1000', type=str, default="../../data/marco/run.msmarco-passage.train.debug.tsv")
55 | parser.add_argument('--dev_top1000', type=str, default="../../data/marco/run.msmarco-passage.train.debug.tsv")
56 | parser.add_argument('--collection', type=str, default="../../data/marco/collection.debug.tsv")
57 | parser.add_argument('--query', type=str, default="../../data/marco/train.query.debug.txt")
58 | parser.add_argument('--dev_query', type=str, default="../../data/marco/train.query.debug.txt")
59 | parser.add_argument('--min_index', type=int, default=0)
60 | parser.add_argument('--max_index', type=int, default=256)
61 | parser.add_argument('--sample_num', type=int, default=128)
62 | parser.add_argument('--num_labels', type=int, default=1)
63 | parser.add_argument('--local-rank', type=int, default=0)
64 | parser.add_argument('--local_rank', type=int, default=0)
65 | parser.add_argument('--fp16', type=bool, default=True)
66 | parser.add_argument('--gradient_checkpoint', type=bool, default=True)
67 | parser.add_argument('--negatives_x_device', type=bool, default=True)
68 | parser.add_argument('--untie_encoder', type=bool, default=True)
69 | parser.add_argument('--add_pooler', type=bool, default=False)
70 | parser.add_argument('--Temperature', type=float, default=1.0)
71 |
72 | # args = parser.parse_args(args=[])
73 | args = parser.parse_args()
74 | return args
75 |
76 | def merge(eval_cnts, file_pattern='output/res.step-%d.part-0%d'):
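# merge the per-rank shards written during validation: dedupe (qid, pid) pairs, sort each query's candidates by score, write one qid\tpid\trank\tscore run file, then delete the shards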
77 | f_list = []
78 | total_part = torch.distributed.get_world_size()
79 | for part in range(total_part):
80 | f0 = open(file_pattern % (eval_cnts, part))
81 | f_list+=f0.readlines()
82 | f_list = [l.strip().split("\t") for l in f_list]
83 | dedup = defaultdict(dict)
84 | for qid,pid,score in f_list:
85 | dedup[qid][pid] = float(score)
86 | mp = defaultdict(list)
87 | for qid in dedup:
88 | for pid in dedup[qid]:
89 | mp[qid].append((pid, dedup[qid][pid]))
90 | for qid in mp:
91 | mp[qid].sort(key=lambda x:x[1], reverse=True)
92 | with open(file_pattern.replace('.part-0%d','')%eval_cnts, 'w') as f:
93 | for qid in mp:
94 | for idx, (pid, score) in enumerate(mp[qid]):
95 | f.write(str(qid)+"\t"+str(pid)+'\t'+str(idx+1)+"\t"+str(score)+'\n')
96 | for part in range(total_part):
97 | os.remove(file_pattern % (eval_cnts, part))
98 |
99 | def train_cross_encoder(args, model, optimizer):
100 | epoch = 0
101 | local_rank = torch.distributed.get_rank()
102 | if local_rank==0:
103 | print(f'Starting training, upto {args.epoch} epochs, LR={args.learning_rate}', flush=True)
104 |
105 | # load the datasets
106 | dev_dataset = dataset_factory.CrossEncoderDevDataset(args)
107 | dev_sampler = torch.utils.data.distributed.DistributedSampler(dev_dataset)
108 | dev_loader = DataLoader(dev_dataset, batch_size=args.dev_batch_size, collate_fn=dev_dataset._collate_fn, sampler=dev_sampler, num_workers=4)
109 | validate_multi_gpu(model, dev_loader, epoch, args)
110 |
111 | train_dataset = dataset_factory.CrossEncoderTrainDataset(args)
112 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
113 |
114 | for epoch in range(1, args.epoch+1):
115 | train_dataset.set_epoch(epoch) # selected negatives shift with the epoch
116 | train_sampler.set_epoch(epoch) # shuffle batch
117 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_dataset._collate_fn, sampler=train_sampler, num_workers=4, drop_last=True)
118 | train_iteration_multi_gpu(model, optimizer, train_loader, epoch, args)
119 | torch.distributed.barrier()
120 | del train_loader
121 | if epoch%1==0:
122 | validate_multi_gpu(model, dev_loader, epoch, args)
123 | torch.distributed.barrier()
124 |
125 | def validate_multi_gpu(model, dev_loader, epoch, args):
126 | global best_mrr
127 | local_start = time.time()
128 | local_rank = torch.distributed.get_rank()
129 | world_size = torch.distributed.get_world_size()
130 | with torch.no_grad():
131 | model.eval()
132 | scores_lst = []
133 | qids_lst = []
134 | pids_lst = []
135 | for record1, record2 in tqdm(dev_loader):
136 | with autocast():
137 | scores = model(_prepare_inputs(record1))
138 | qids = record2['qids']
139 | pids = record2['pids']
140 | scores_lst.append(scores.detach().cpu().numpy().copy())
141 | qids_lst.append(qids.copy())
142 | pids_lst.append(pids.copy())
143 | qids_lst = np.concatenate(qids_lst).reshape(-1)
144 | pids_lst = np.concatenate(pids_lst).reshape(-1)
145 | scores_lst = np.concatenate(scores_lst).reshape(-1)
146 | with open("output/res.step-%d.part-0%d"%(epoch, local_rank), 'w') as f:
147 | for qid,pid,score in zip(qids_lst, pids_lst, scores_lst):
148 | f.write(str(qid)+'\t'+str(pid)+'\t'+str(score)+'\n')
149 | torch.distributed.barrier()
150 | if local_rank==0:
151 | merge(epoch)
152 | metrics = calc_mrr(args.dev_qrels, 'output/res.step-%d'%epoch)
153 | mrr = metrics['MRR @10']
154 | if mrr>best_mrr:
155 | print("*"*50)
156 | print("new top")
157 | print("*"*50)
158 | best_mrr = mrr
159 | torch.save(model.module.lm.state_dict(), os.path.join(args.model_out_dir, "reranker.p"))
160 |
161 |
162 |
163 | def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
164 | rt = tensor.clone()
165 | distributed.all_reduce(rt, op=distributed.ReduceOp.SUM)
166 | rt /= distributed.get_world_size()  # divide by the number of processes
167 | return rt
168 |
169 | def _prepare_inputs(record):
170 | prepared = {}
171 | local_rank = torch.distributed.get_rank()
172 | for key in record:
173 | x = record[key]
174 | if isinstance(x, torch.Tensor):
175 | prepared[key] = x.to(local_rank)
176 | else:
177 | prepared[key] = _prepare_inputs(x)
178 | return prepared
179 |
180 | def train_iteration_multi_gpu(model, optimizer, data_loader, epoch, args):
181 | total = 0
182 | model.train()
183 | total_loss = 0.
184 | local_rank = torch.distributed.get_rank()
185 | world_size = torch.distributed.get_world_size()
186 | start = time.time()
187 | local_start = time.time()
188 | all_steps_per_epoch = len(data_loader)
189 | step = 0
190 | scaler = GradScaler()
191 | for record in data_loader:
192 | record = _prepare_inputs(record)
193 | if args.fp16:
194 | with autocast():
195 | loss = model(record)
196 | else:
197 | loss = model(record)
198 | torch.distributed.barrier()
199 | reduced_loss = reduce_tensor(loss.data)
200 | total_loss += reduced_loss.item()
201 | # optimize
202 | optimizer.zero_grad()
203 | scaler.scale(loss).backward()
204 | scaler.step(optimizer)
205 | scaler.update()
206 | step+=1
207 | if step%args.report==0 and local_rank==0:
208 | seconds = time.time()-local_start
209 | m, s = divmod(seconds, 60)
210 | h, m = divmod(m, 60)
211 | local_start = time.time()
212 | print("epoch:%d training step: %d/%d, mean loss: %.5f, current loss: %.5f,"%(epoch, step, all_steps_per_epoch, total_loss/step, loss.cpu().detach().numpy()),"report used time:%02d:%02d:%02d," % (h, m, s), end=' ')
213 | seconds = time.time()-start
214 | m, s = divmod(seconds, 60)
215 | h, m = divmod(m, 60)
216 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ')
217 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime()))
218 | if local_rank==0:
219 | # model.save(os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch)))
220 | torch.save(model.module.state_dict(), os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch)))
221 | seconds = time.time()-start
222 | m, s = divmod(seconds, 60)
223 | h, m = divmod(m, 60)
224 | print(f'train epoch={epoch} loss={total_loss}')
225 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ')
226 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime()))
227 |
228 | if __name__ == '__main__':
229 | args = define_args()
230 | args = vars(args)
231 | args = utils.HParams(**args)
232 | args.reranker_model_name_or_path = args.model_name_or_path
233 | # set up multi-GPU distributed training
234 | torch.distributed.init_process_group(backend="nccl", init_method='env://')
235 | local_rank = torch.distributed.get_rank()
236 | if local_rank==0:
237 | args.print_config()
238 | torch.cuda.set_device(local_rank)
239 | device = torch.device("cuda", local_rank)
240 |
241 | model = Reranker(args)
242 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
243 | model.to(device)
244 |
245 | params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
246 | params = {'params': [v for k, v in params]}
247 | optimizer = torch.optim.Adam([params], lr=args.learning_rate, weight_decay=0.0)
248 |
249 | if args.warm_start_from:
250 | print('warm start from ', args.warm_start_from)
251 | state_dict = torch.load(args.warm_start_from, map_location=device)
252 | for k in list(state_dict.keys()):
253 | state_dict[k.replace('module.','')] = state_dict.pop(k)
254 | model.load_state_dict(state_dict)
255 |
256 |
257 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
258 | os.makedirs(args.model_out_dir, exist_ok=True)
259 |
260 | train_cross_encoder(args, model, optimizer)
261 |
--------------------------------------------------------------------------------
/src/train_dual_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
4 | import argparse
5 | import random
6 | import subprocess
7 | import tempfile
8 | import time
9 | from collections import defaultdict
10 |
11 | import faiss
12 | import numpy as np
13 | import torch
14 | import torch.distributed as dist
15 | import torch.nn.functional as F
16 | from torch import distributed
17 | import torch_optimizer as optim
18 | from torch.cuda.amp import GradScaler, autocast
19 | from torch.nn.parallel import DistributedDataParallel
20 | from torch.utils.data import DataLoader, Dataset
21 | from tqdm import tqdm
22 | from transformers import AutoConfig, AutoTokenizer, BertModel
23 |
24 | import dataset_factory
25 | import utils
26 | from modeling import DualEncoder
27 | from msmarco_eval import calc_mrr
28 | from utils import add_prefix, build_engine, load_qid, merge, read_embed, search
29 |
30 | SEED = 2023
31 | best_mrr=-1
32 | torch.manual_seed(SEED)
33 | torch.cuda.manual_seed_all(SEED)
34 | random.seed(SEED)
35 | def define_args():
36 | parser = argparse.ArgumentParser('BERT-retrieval model')
37 | parser.add_argument('--batch_size', type=int, default=2)
38 | parser.add_argument('--dev_batch_size', type=int, default=64)
39 | parser.add_argument('--max_seq_len', type=int, default=160)
40 | parser.add_argument('--q_max_seq_len', type=int, default=160)
41 | parser.add_argument('--p_max_seq_len', type=int, default=160)
42 | parser.add_argument('--retriever_model_name_or_path', type=str, default="")
43 | parser.add_argument('--model_out_dir', type=str, default="output")
44 | parser.add_argument('--learning_rate', type=float, default=1e-5)
45 | parser.add_argument('--weight_decay', type=float, default=0.01)
46 | parser.add_argument('--warmup_proportion', type=float, default=0.1)
47 | parser.add_argument('--eval_step_proportion', type=float, default=1.0)
48 | parser.add_argument('--report', type=int, default=1)
49 | parser.add_argument('--epoch', type=int, default=3)
50 | parser.add_argument('--qrels', type=str, default="/home/dongqian06/hdfs_data/data_train/qrels.train.debug.tsv")
51 | parser.add_argument('--dev_qrels', type=str, default="/home/dongqian06/hdfs_data/data_train/qrels.train.debug.tsv")
52 | parser.add_argument('--top1000', type=str, default="/home/dongqian06/codes/anserini/runs/run.msmarco-passage.train.debug.tsv")
53 | parser.add_argument('--collection', type=str, default="/home/dongqian06/hdfs_data/data_train/marco/collection.debug.tsv")
54 | parser.add_argument('--query', type=str, default="/home/dongqian06/hdfs_data/data_train/train.query.debug.txt")
55 | parser.add_argument('--dev_query', type=str, default="/home/dongqian06/hdfs_data/data_train/train.query.debug.txt")
56 | parser.add_argument('--min_index', type=int, default=0)
57 | parser.add_argument('--max_index', type=int, default=256)
58 | parser.add_argument('--sample_num', type=int, default=256)
59 | parser.add_argument('--num_labels', type=int, default=1)
60 | parser.add_argument('--local-rank', type=int, default=0)
61 | parser.add_argument('--local_rank', type=int, default=0)
62 | parser.add_argument('--fp16', type=bool, default=True)
63 | parser.add_argument('--gradient_checkpoint', type=bool, default=False)
64 | parser.add_argument('--negatives_x_device', type=bool, default=True)
65 | parser.add_argument('--negatives_in_device', type=bool, default=True)
66 | parser.add_argument('--untie_encoder', type=bool, default=True)
67 | parser.add_argument('--add_pooler', type=bool, default=False)
68 | parser.add_argument('--warm_start_from', type=str, default="")
69 |
70 | # args = parser.parse_args(args=[])
71 | args = parser.parse_args()
72 | return args
73 |
74 |
75 | def main_multi(args, model, optimizer):
76 | epoch = 0
77 | local_rank = torch.distributed.get_rank()
78 | if local_rank==0:
79 | print(f'Starting training, upto {args.epoch} epochs, LR={args.learning_rate}', flush=True)
80 |
81 | # load the datasets
82 | query_dataset = dataset_factory.QueryDataset(args)
83 | query_loader = DataLoader(query_dataset, batch_size=args.dev_batch_size, collate_fn=query_dataset._collate_fn, num_workers=3)
84 | passage_dataset = dataset_factory.PassageDataset(args)
85 | passage_loader = DataLoader(passage_dataset, batch_size=args.dev_batch_size, collate_fn=passage_dataset._collate_fn, num_workers=3)
86 | validate_multi_gpu(model, query_loader, passage_loader, epoch, args)
87 |
88 | train_dataset = dataset_factory.DualEncoderTrainDataset(args)
89 |
90 | for epoch in range(1, args.epoch+1):
91 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
92 | train_sampler.set_epoch(epoch)
93 | train_dataset.set_epoch(epoch)
94 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_dataset._collate_fn, sampler=train_sampler, num_workers=4)
95 | loss = train_iteration_multi_gpu(model, optimizer, train_loader, epoch, args)
96 | del train_loader
97 | torch.distributed.barrier()
98 | if epoch%10==0:
99 | validate_multi_gpu(model, query_loader, passage_loader, epoch, args)
100 | torch.distributed.barrier()
101 |
102 | def validate_multi_gpu(model, query_loader, passage_loader, epoch, args):
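# distributed retrieval evaluation: rank 0 encodes the dev queries, every rank encodes its shard of the collection (see PassageDataset) and scores it by brute-force inner product on GPU; rank 0 merges the per-shard top-k lists and reports MRR@10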
103 | global best_mrr
104 | local_start = time.time()
105 | local_rank = torch.distributed.get_rank()
106 | world_size = torch.distributed.get_world_size()
107 | _output_file_name = 'output/_para.index.part%d'%local_rank
108 | output_file_name = 'output/para.index.part%d'%local_rank
109 | top_k = 1000
110 | q_output_file_name = 'output/query.emb.step%d.npy'%epoch
111 | if local_rank==0:
112 | q_embs = []
113 | with torch.no_grad():
114 | model.eval()
115 | for records in query_loader:
116 | if args.fp16:
117 | with autocast():
118 | q_reps = model(query_inputs=_prepare_inputs(records))
119 | else:
120 | q_reps = model(query_inputs=_prepare_inputs(records))
121 | q_embs.append(q_reps.cpu().detach().numpy())
122 | emb_matrix = np.concatenate(q_embs, axis=0)
123 | np.save(q_output_file_name, emb_matrix)
124 | print("predict q_embs cnt: %s" % len(emb_matrix))
125 | with torch.no_grad():
126 | model.eval()
127 | para_embs = []
128 | for records in tqdm(passage_loader, disable=args.local_rank>0):
129 | if args.fp16:
130 | with autocast():
131 | p_reps = model(passage_inputs=_prepare_inputs(records))
132 | else:
133 | p_reps = model(passage_inputs=_prepare_inputs(records))
134 | para_embs.append(p_reps.cpu().detach().numpy())
135 | torch.distributed.barrier()
136 | para_embs = np.concatenate(para_embs, axis=0)
137 | # para_embs = np.load('output/_para.emb.part%d.npy'%local_rank)
138 | print("predict embs cnt: %s" % len(para_embs))
139 | # engine = build_engine(para_embs, 768)
140 | # faiss.write_index(engine, _output_file_name)
141 | engine = torch.from_numpy(para_embs).cuda()
142 | np.save('output/_para.emb.part%d.npy'%local_rank, para_embs)
143 | print('create index done!')
144 | qid_list = load_qid(args.dev_query)
145 | search(engine, q_output_file_name, qid_list, "output/res.top%d.part%d.step%d"%(top_k, local_rank, epoch), top_k=top_k)
146 | torch.distributed.barrier()
147 | if local_rank==0:
148 | f_list = []
149 | for part in range(world_size):
150 | f_list.append('output/res.top%d.part%d.step%d' % (top_k, part, epoch))
151 | shift = np.load("output/_para.emb.part0.npy").shape[0]
152 | merge(world_size, shift, top_k, epoch)
153 | metrics = calc_mrr(args.dev_qrels, 'output/res.top%d.step%d'%(top_k, epoch))
154 | for run in f_list:
155 | os.remove(run)
156 | mrr = metrics['MRR @10']
157 | if mrr>best_mrr:
158 | print("*"*50)
159 | print("new top")
160 | print("*"*50)
161 | best_mrr = mrr
162 | for part in range(world_size):
163 | os.rename('output/_para.emb.part%d.npy'%part, 'output/para.emb.part%d.npy'%part)
164 | torch.save(model.state_dict(), "output/best.p")
165 | seconds = time.time()-local_start
166 | m, s = divmod(seconds, 60)
167 | h, m = divmod(m, 60)
168 | print("******************eval, mrr@10: %.10f,"%(mrr),"report used time:%02d:%02d:%02d," % (h, m, s))
169 |
170 |
171 | def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
172 | rt = tensor.clone()
173 | distributed.all_reduce(rt, op=distributed.ReduceOp.SUM)
174 | rt /= distributed.get_world_size()  # divide by the number of processes
175 | return rt
176 |
177 | def _prepare_inputs(record):
178 | prepared = {}
179 | local_rank = torch.distributed.get_rank()
180 | for key in record:
181 | x = record[key]
182 | if isinstance(x, torch.Tensor):
183 | prepared[key] = x.to(local_rank)
184 | elif x is None:
185 | prepared[key] = x
186 | else:
187 | prepared[key] = _prepare_inputs(x)
188 | return prepared
189 |
190 | def train_iteration_multi_gpu(model, optimizer, data_loader, epoch, args):
191 | total = 0
192 | model.train()
193 | total_loss = 0.
194 | total_ce_loss = 0.
195 | local_rank = torch.distributed.get_rank()
196 | world_size = torch.distributed.get_world_size()
197 | start = time.time()
198 | local_start = time.time()
199 | all_steps_per_epoch = len(data_loader)
200 | step = 0
201 | scaler = GradScaler()
202 | for record in data_loader:
203 | record = _prepare_inputs(record)
204 | with autocast():
205 | retriever_ce_loss = model(**record)
206 | loss = retriever_ce_loss
207 | torch.distributed.barrier()
208 | reduced_loss = reduce_tensor(loss.data)
209 | total_loss += reduced_loss.item()
210 | total_ce_loss += float(retriever_ce_loss.cpu().detach().numpy())
211 |
212 | # optimize
213 | optimizer.zero_grad()
214 | scaler.scale(loss).backward()
215 | scaler.step(optimizer)
216 | scaler.update()
217 | step+=1
218 | if step%args.report==0 and local_rank==0:
219 | seconds = time.time()-local_start
220 | m, s = divmod(seconds, 60)
221 | h, m = divmod(m, 60)
222 | local_start = time.time()
223 | print(f"epoch:{epoch} training step: {step}/{all_steps_per_epoch}, mean loss: {total_loss/step}, ce loss: {total_ce_loss/step}, ", "report used time:%02d:%02d:%02d," % (h, m, s), end=' ')
224 | seconds = time.time()-start
225 | m, s = divmod(seconds, 60)
226 | h, m = divmod(m, 60)
227 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ')
228 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime()))
229 | if local_rank==0:
230 | # model.save(os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch)))
231 | # torch.save(model.state_dict(), os.path.join(args.model_out_dir, "weights.epoch-%d.p"%(epoch)))
232 | seconds = time.time()-start
233 | m, s = divmod(seconds, 60)
234 | h, m = divmod(m, 60)
235 | print(f'train epoch={epoch} loss={total_loss}')
236 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ')
237 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime()))
238 | return total_loss
239 |
240 |
241 | def main_cli():
242 | args = define_args()
243 | args = vars(args)
244 | args = utils.HParams(**args)
245 | # set up multi-GPU distributed training
246 | torch.distributed.init_process_group(backend="nccl", init_method='env://')
247 | local_rank = torch.distributed.get_rank()
248 | if local_rank==0:
249 | args.print_config()
250 | torch.cuda.set_device(local_rank)
251 | device = torch.device("cuda", local_rank)
252 |
253 | model = DualEncoder(args)
254 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
255 | model.to(device)
256 |
257 | params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
258 | params = {'params': [v for k, v in params]}
259 | # optimizer = torch.optim.Adam([params], lr=args.learning_rate, weight_decay=0.0)
260 | optimizer = optim.Lamb([params], lr=args.learning_rate, weight_decay=0.0)
261 |
262 | if args.warm_start_from:
263 | print('warm start from ', args.warm_start_from)
264 | state_dict = torch.load(args.warm_start_from, map_location=device)
265 | for k in list(state_dict.keys()):
266 | state_dict[k.replace('module.','')] = state_dict.pop(k)
267 | model.load_state_dict(state_dict, strict=True)
268 |
269 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
270 | print("model loaded on GPU%d"%local_rank)
271 | print(args.model_out_dir)
272 | os.makedirs(args.model_out_dir, exist_ok=True)
273 |
274 | main_multi(args, model, optimizer)
275 |
276 | if __name__ == '__main__':
277 | main_cli()
278 |
--------------------------------------------------------------------------------
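Only the tail of `reduce_tensor` is visible at the top of this file: the all-reduced sum is divided by the world size to obtain a cross-process mean of the loss. A minimal, self-contained sketch of that standard pattern (an assumption, since the full function body is not shown here; it must run under an initialized process group, e.g. launched via `torchrun --nproc_per_node=8 src/train_dual_encoder.py`):

```python
# Sketch of the all-reduce mean pattern used by reduce_tensor above.
# Assumes torch.distributed has been initialized (init_process_group).
import torch
import torch.distributed as dist

def reduce_mean(t: torch.Tensor) -> torch.Tensor:
    rt = t.clone()                             # avoid mutating the caller's tensor
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the value across all ranks
    rt /= dist.get_world_size()                # divide by the number of processes
    return rt
```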
/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import six
3 | import os
4 | import faiss
5 | import numpy as np
6 | class HParams(object):
7 | """Hyper paramerter"""
8 |
9 | def __init__(self, **kwargs):
10 | for k, v in kwargs.items():
11 | self.__dict__[k] = v
12 |
13 | def __contains__(self, key):
14 | return key in self.__dict__
15 |
16 | def __getitem__(self, key):
17 | if key not in self.__dict__:
18 | raise ValueError('key(%s) not in HParams.' % key)
19 | return self.__dict__[key]
20 |
21 | def __repr__(self):
22 | return repr(self.to_dict())
23 |
24 | def __setitem__(self, key, val):
25 | self.__dict__[key] = val
26 |
27 | @classmethod
28 | def from_json(cls, json_str):
29 | """doc"""
30 | d = json.loads(json_str)
31 | if not isinstance(d, dict):
32 | raise ValueError('json object must be dict.')
33 | return HParams.from_dict(d)
34 |
35 | def get(self, key, default=None):
36 | """doc"""
37 | return self.__dict__.get(key, default)
38 |
39 | @classmethod
40 | def from_dict(cls, d):
41 | """doc"""
42 | if type(d) != dict:
43 | raise ValueError('input must be dict.')
44 | hp = HParams(**d)
45 | return hp
46 |
47 | def to_json(self):
48 | """doc"""
49 | return json.dumps(self.__dict__)
50 |
51 | def to_dict(self):
52 | """doc"""
53 | return self.__dict__
54 |
55 | def print_config(self):
56 | for key,value in self.__dict__.items():
57 | print(key + ":", value)
58 |
59 | def join(self, other):
60 | """doc"""
61 | if not isinstance(other, HParams):
62 | raise ValueError('input must be HParams instance.')
63 | self.__dict__.update(**other.__dict__)
64 | return self
65 |
66 |
67 | def _get_dict_from_environ_or_json_or_file(args, env_name):
68 | if args == '':
69 | return None
70 | if args is None:
71 | s = os.environ.get(env_name)
72 | else:
73 | s = args
74 | if s is not None and os.path.exists(s):
75 | s = open(s).read()
76 | if isinstance(s, six.string_types):
77 | try:
78 | r = eval(s)  # accepts Python literals as well as JSON; use only with trusted input
79 | except SyntaxError as e:
80 | raise ValueError('json parse error: %s \n>Got json: %s' %
81 | (repr(e), s))
82 | return r
83 | else:
84 | return s  # None when neither args nor the environment variable is set
85 |
86 |
87 | def parse_file(filename):
88 | """useless api"""
89 | d = _get_dict_from_environ_or_json_or_file(filename, None)
90 | if d is None:
91 | raise ValueError('file(%s) not found' % filename)
92 | return d
93 |
94 | def build_engine(p_emb_matrix, dim):
95 | index = faiss.IndexFlatIP(dim)
96 | index.add(p_emb_matrix.astype('float32'))
97 | return index
98 | from tqdm import tqdm
99 | def read_embed(file_name, dim=768, bs=100):
100 | if file_name.endswith('npy'):
101 | i = 0
102 | emb_np = np.load(file_name)
103 | with tqdm(total=len(emb_np)//bs+1) as pbar:
104 | while i < len(emb_np):
105 | vec_list = emb_np[i:i+bs]
106 | i += bs
107 | pbar.update(1)
108 | yield vec_list
109 | else:
110 | vec_list = []
111 | with open(file_name) as inp:
112 | for line in tqdm(inp):
113 | data = line.strip()
114 | vector = [float(item) for item in data.split(' ')]
115 | assert len(vector) == dim
116 | vec_list.append(vector)
117 | if len(vec_list) == bs:
118 | yield vec_list
119 | vec_list = []
120 | if vec_list:
121 | yield vec_list
122 |
123 | def load_qid(file_name):
124 | qid_list = []
125 | with open(file_name) as inp:
126 | for line in inp:
127 | line = line.strip()
128 | qid = line.split('\t')[0]
129 | try:
130 | int(qid)
131 | qid_list.append(qid)
132 | except ValueError:  # skip lines whose first column is not an integer qid
133 | pass
134 | return qid_list
135 |
136 | import torch
137 | def topk_query_passage(query_vector, passage_vector, k):
138 | """
139 | 对query vector和passage vector进行内积计算,并返回top k的索引
140 |
141 | Args:
142 | query_vector (torch.Tensor): query向量,形状为 (batch_size, query_dim)
143 | passage_vector (torch.Tensor): passage向量,形状为 (batch_size, passage_dim)
144 | k (int): 返回的top k值
145 |
146 | Returns:
147 | torch.Tensor: top k值的索引,形状为 (batch_size, k)
148 | """
149 | # 计算query向量和passage向量的内积
150 | scores = torch.matmul(query_vector, passage_vector.t()) # 形状为 (batch_size, batch_size)
151 |
152 | # 对每个batch进行排序,取top k值
153 | res_dist, res_p_id = torch.topk(scores, k=k, dim=1) # 形状为 (batch_size, k)
154 |
155 | return res_dist.cpu().numpy(), res_p_id.cpu().numpy()
156 |
157 | def search(index, emb_file, qid_list, outfile, top_k):  # index: torch.Tensor of passage embeddings on GPU (not a faiss index, despite the name)
158 | q_idx = 0
159 | with open(outfile, 'w') as out:
160 | for batch_vec in read_embed(emb_file):
161 | q_emb_matrix = np.array(batch_vec)
162 | q_emb_matrix = torch.from_numpy(q_emb_matrix)
163 | q_emb_matrix = q_emb_matrix.cuda()
164 | res_dist, res_p_id = topk_query_passage(q_emb_matrix, index, top_k)
165 | for i in range(len(q_emb_matrix)):
166 | qid = qid_list[q_idx]
167 | for j in range(top_k):
168 | pid = res_p_id[i][j]
169 | score = res_dist[i][j]
170 | out.write('%s\t%s\t%s\t%s\n' % (qid, pid, j+1, score))
171 | q_idx += 1
172 |
173 | def merge(total_part, shift, top, eval_cnts):  # merge per-part retrieval results into one run file; local pids map to global ids via pid + shift*part
174 | f_list = []
175 | for part in range(total_part):
176 | f0 = open('output/res.top%d.part%d.step%d' % (top, part, eval_cnts))
177 | f_list.append(f0)
178 |
179 | line_list = []
180 | for part in range(total_part):
181 | line = f_list[part].readline()
182 | line_list.append(line)
183 |
184 | out = open('output/res.top%d.step%d' % (top, eval_cnts), 'w')
185 | last_q = ''
186 | ans_list = {}
187 | while line_list[-1]:
188 | cur_list = []
189 | for line in line_list:
190 | sub = line.strip().split('\t')
191 | cur_list.append(sub)
192 |
193 | if last_q == '':
194 | last_q = cur_list[0][0]
195 | if cur_list[0][0] != last_q:
196 | rank = sorted(ans_list.items(), key = lambda a:a[1], reverse=True)
197 | for i in range(top):
198 | out.write("%s\t%s\t%s\t%s\n" % (last_q, rank[i][0], i+1, rank[i][1]))
199 | ans_list = {}
200 | for i, sub in enumerate(cur_list):
201 | ans_list[int(sub[1]) + shift*i] = float(sub[-1])
202 | last_q = cur_list[0][0]
203 |
204 | line_list = []
205 | for f0 in f_list:
206 | line = f0.readline()
207 | line_list.append(line)
208 |
209 | rank = sorted(ans_list.items(), key = lambda a:a[1], reverse=True)
210 | for i in range(top):
211 | out.write("%s\t%s\t%s\t%s\n" % (last_q, rank[i][0], i+1, rank[i][1]))
212 | out.close()
213 | for f0 in f_list: f0.close()
214 |
215 | def add_prefix(state_dict, prefix='module.'):
216 | if all(key.startswith(prefix) for key in state_dict.keys()):
217 | return state_dict
218 | prefixed_state_dict = {}
219 | for key in list(state_dict.keys()):
220 | key2 = prefix + key
221 | prefixed_state_dict[key2] = state_dict.pop(key)
222 | return prefixed_state_dict
223 |
224 | def filter_stop_words(txts):
225 | stop_words=['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 'won', 'wouldn']
226 | txts = [t.split() for t in txts]
227 | txts = [list(set(filter(lambda x: x not in stop_words, t))) for t in txts]
228 | rets = []
229 | for t in txts:
230 | rets += t
231 | return list(set(rets))
232 |
--------------------------------------------------------------------------------
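A short usage sketch for the helpers in `utils.py`. The embedding arrays below are random stand-ins; in practice the `.npy` dumps would come from dual-encoder inference, which is not part of this file:

```python
# Hypothetical usage of HParams and build_engine; array contents are stand-ins.
import numpy as np
from utils import HParams, build_engine

args = HParams(learning_rate=2e-5, model_out_dir='output')
print(args.learning_rate)           # attribute-style access: 2e-05
print(args.get('batch_size', 128))  # dict-style access with a default: 128

p_emb = np.random.rand(1000, 768).astype('float32')  # stand-in passage embeddings
index = build_engine(p_emb, dim=768)                 # flat inner-product faiss index
q_emb = np.random.rand(4, 768).astype('float32')     # stand-in query embeddings
scores, pids = index.search(q_emb, 10)               # top-10 passage ids per query
```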