├── InstructDistill ├── Instruction-Distillation.pdf ├── README.md ├── bm25_retrieval.py ├── instruction_distill.py ├── pairwise_ranking.py ├── rank_loss.py ├── trec_eval.py └── zero2_bf16_config.yaml ├── LICENSE.txt ├── NovelEval ├── README.md ├── corpus.tsv ├── data.json ├── qrels.txt └── queries.tsv ├── README.md ├── assets ├── benchmark-results.png └── specialization-results.png ├── pointwise.py ├── rank_gpt.py ├── rank_loss.py ├── requirements.txt ├── run_evaluation.py ├── specialization.py └── trec_eval.py /InstructDistill/Instruction-Distillation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunnweiwei/RankGPT/0d62bc3855c7c118048a7c47c18e719b938e291a/InstructDistill/Instruction-Distillation.pdf -------------------------------------------------------------------------------- /InstructDistill/README.md: -------------------------------------------------------------------------------- 1 | # Instruction Distillation 2 | 3 | Code for the paper [Instruction Distillation Makes Large Language Models Efficient Zero-shot Rankers](https://arxiv.org/abs/2311.01555). 4 | 5 | *Instruction Distillation* is an unsupervised approach to specialize LLMs on ranking tasks by distilling instructions. 6 | 7 | This work is presented at *The 1st Workshop on "Recommendation with Generative Models"* at CIKM 2023. 8 | 9 | ## Pre-trained Models 10 | 11 | The following code shows how to predict the relevance of a (query, passage) pair. 12 | 13 | ```python 14 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM 15 | import torch 16 | 17 | query = "How much impact do masks have on preventing the spread of the COVID-19?" 18 | passage = "Title: Universal Masking is Urgent in the COVID-19 Pandemic: SEIR and Agent Based Models, Empirical Validation, Policy Recommendations Content: We present two models for the COVID-19 pandemic predicting the impact of universal face mask wearing upon the spread of the SARS-CoV-2 virus--one employing a stochastic dynamic network based compartmental SEIR (susceptible-exposed-infectious-recovered) approach, and the other employing individual ABM (agent-based modelling) Monte Carlo simulation--indicating (1) significant impact under (near) universal masking when at least 80% of a population is wearing masks, versus minimal impact when only 50% or less of the population is wearing masks, and (2) significant impact when universal masking is adopted early, by Day 50 of a regional outbreak, versus minimal impact when universal masking is adopted late. These effects hold even at the lower filtering rates of homemade masks. To validate these theoretical models, we compare their predictions against a new empirical data set we have collected" 19 | instruction = "Predict whether the given passage answer the question.\n\nQuestion: {0}\n\nPassage: {1}\n\nDoes the passage answer the question?" 
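# Both released checkpoints score a (query, passage) pair pointwise: the model reads the
# filled-in instruction and the logit of the "Yes" token at the last position is taken as
# the relevance score (token id 2163 for the flan-t5 model, 3869 for the llama model).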
20 | instruction = instruction.format(query, passage) 21 | ``` 22 | Usage with flan-t5 models: 23 | ```python 24 | tokenizer = AutoTokenizer.from_pretrained("fireballoon/rank-flan-t5-xl") 25 | model = AutoModelForSeq2SeqLM.from_pretrained("fireballoon/rank-flan-t5-xl", torch_dtype=torch.float16) 26 | token_of_Yes = 2163 27 | features = tokenizer([instruction], padding=True, truncation=True, return_tensors="pt", max_length=1024) 28 | features['decoder_input_ids'] = torch.zeros(features['input_ids'].size(0), 1).long() 29 | scores = model(**features).logits[:, -1, token_of_Yes] 30 | ``` 31 | Usage with llama models: 32 | ```python 33 | tokenizer = AutoTokenizer.from_pretrained("fireballoon/rank-llama-2-7b", use_fast=False, padding_side="left") 34 | model = AutoModelForCausalLM.from_pretrained("fireballoon/rank-llama-2-7b", torch_dtype=torch.float16) 35 | token_of_Yes = 3869 36 | features = tokenizer([instruction], padding=True, truncation=True, return_tensors="pt", max_length=1024) 37 | scores = model(**features).logits[:, -1, token_of_Yes] 38 | ``` 39 | 40 | ## Training 41 | Retrieve passages using BM25 42 | ``` 43 | python bm25_retrieval.py 44 | ``` 45 | (Optional) Evaluate Pairwise Ranking Prompting (PRP) on the benchmarks. 46 | ``` 47 | python pairwise_ranking.py --model google/flan-t5-xl --eval true --generate false 48 | ``` 49 | Get PRP's predictions on MS MARCO (`data/marco-train-10k.jsonl`, which can be downloaded from [RankGPT](https://github.com/sunnweiwei/RankGPT/tree/main#download-data-and-model)). The ranking results will be saved at `out/marco-train-10k-flan-xl.json`. 50 | ``` 51 | python pairwise_ranking.py \ 52 | --model google/flan-t5-xl \ 53 | --eval false \ 54 | --generate true \ 55 | --data data/marco-train-10k.jsonl \ 56 | --save_path out/marco-train-10k-flan-xl.json 57 | ``` 58 | Train the pointwise ranker using PRP's predictions. The model checkpoints will be saved at `out/rank-flan-t5-xl`. 59 | ``` 60 | python instruction_distill.py \ 61 | --model google/flan-t5-xl \ 62 | --loss rank_net \ 63 | --data data/marco-train-10k.jsonl \ 64 | --save_path out/rank-flan-t5-xl \ 65 | --permutation out/marco-train-10k-flan-xl.json \ 66 | --do_train true \ 67 | --do_eval false 68 | ``` 69 | Convert the DeepSpeed checkpoint. 70 | ``` 71 | python zero_to_fp32.py . 
pytorch_model.bin 72 | ``` 73 | 74 | ## Cite 75 | ``` 76 | @inproceedings{Sun2023InstructionDM, 77 | title={Instruction Distillation Makes Large Language Models Efficient Zero-shot Rankers}, 78 | author={Weiwei Sun and Zheng Chen and Xinyu Ma and Lingyong Yan and Shuaiqiang Wang and Pengjie Ren and Zhumin Chen and Dawei Yin and Zhaochun Ren}, 79 | booktitle={GenRec workshop at CIKM}, 80 | year={2023}, 81 | } 82 | ``` 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /InstructDistill/bm25_retrieval.py: -------------------------------------------------------------------------------- 1 | THE_RESULTS = { 2 | 'dl19': 'data/rank_results/dl19.json', 3 | 'dl20': 'data/rank_results/dl20.json', 4 | 'covid': 'data/rank_results/beir-trec-covid.json', 5 | 'arguana': 'data/rank_results/beir-arguana.json', 6 | 'touche': 'data/rank_results/beir-touche.json', 7 | 'news': 'data/rank_results/beir-news.json', 8 | 'scifact': 'data/rank_results/beir-scifact.json', 9 | 'fiqa': 'data/rank_results/beir-fiqa.json', 10 | 'scidocs': 'data/rank_results/beir-scidocs.json', 11 | 'nfc': 'data/rank_results/beir-nfc.json', 12 | 'quora': 'data/rank_results/beir-quora.json', 13 | 'dbpedia': 'data/rank_results/beir-dbpedia.json', 14 | 'fever': 'data/rank_results/beir-fever.json', 15 | 'robust04': 'data/rank_results/beir-robust04.json', 16 | 'signal': 'data/rank_results/beir-signal.json', 17 | } 18 | 19 | THE_INDEX = { 20 | 'dl19': 'msmarco-v1-passage', 21 | 'dl20': 'msmarco-v1-passage', 22 | 'covid': 'beir-v1.0.0-trec-covid.flat', 23 | 'arguana': 'beir-v1.0.0-arguana.flat', 24 | 'touche': 'beir-v1.0.0-webis-touche2020.flat', 25 | 'news': 'beir-v1.0.0-trec-news.flat', 26 | 'scifact': 'beir-v1.0.0-scifact.flat', 27 | 'fiqa': 'beir-v1.0.0-fiqa.flat', 28 | 'scidocs': 'beir-v1.0.0-scidocs.flat', 29 | 'nfc': 'beir-v1.0.0-nfcorpus.flat', 30 | 'quora': 'beir-v1.0.0-quora.flat', 31 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity.flat', 32 | 'fever': 'beir-v1.0.0-fever-flat', 33 | 'robust04': 'beir-v1.0.0-robust04.flat', 34 | 'signal': 'beir-v1.0.0-signal1m.flat', 35 | 36 | 'mrtydi-ar': 'mrtydi-v1.1-arabic', 37 | 'mrtydi-bn': 'mrtydi-v1.1-bengali', 38 | 'mrtydi-fi': 'mrtydi-v1.1-finnish', 39 | 'mrtydi-id': 'mrtydi-v1.1-indonesian', 40 | 'mrtydi-ja': 'mrtydi-v1.1-japanese', 41 | 'mrtydi-ko': 'mrtydi-v1.1-korean', 42 | 'mrtydi-ru': 'mrtydi-v1.1-russian', 43 | 'mrtydi-sw': 'mrtydi-v1.1-swahili', 44 | 'mrtydi-te': 'mrtydi-v1.1-telugu', 45 | 'mrtydi-th': 'mrtydi-v1.1-thai', 46 | } 47 | 48 | THE_TOPICS = { 49 | 'dl19': 'dl19-passage', 50 | 'dl20': 'dl20-passage', 51 | 'covid': 'beir-v1.0.0-trec-covid-test', 52 | 'arguana': 'beir-v1.0.0-arguana-test', 53 | 'touche': 'beir-v1.0.0-webis-touche2020-test', 54 | 'news': 'beir-v1.0.0-trec-news-test', 55 | 'scifact': 'beir-v1.0.0-scifact-test', 56 | 'fiqa': 'beir-v1.0.0-fiqa-test', 57 | 'scidocs': 'beir-v1.0.0-scidocs-test', 58 | 'nfc': 'beir-v1.0.0-nfcorpus-test', 59 | 'quora': 'beir-v1.0.0-quora-test', 60 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity-test', 61 | 'fever': 'beir-v1.0.0-fever-test', 62 | 'robust04': 'beir-v1.0.0-robust04-test', 63 | 'signal': 'beir-v1.0.0-signal1m-test', 64 | 65 | 'mrtydi-ar': 'mrtydi-v1.1-arabic-test', 66 | 'mrtydi-bn': 'mrtydi-v1.1-bengali-test', 67 | 'mrtydi-fi': 'mrtydi-v1.1-finnish-test', 68 | 'mrtydi-id': 'mrtydi-v1.1-indonesian-test', 69 | 'mrtydi-ja': 'mrtydi-v1.1-japanese-test', 70 | 'mrtydi-ko': 'mrtydi-v1.1-korean-test', 71 | 'mrtydi-ru': 'mrtydi-v1.1-russian-test', 72 | 'mrtydi-sw': 'mrtydi-v1.1-swahili-test', 73 | 
'mrtydi-te': 'mrtydi-v1.1-telugu-test', 74 | 'mrtydi-th': 'mrtydi-v1.1-thai-test', 75 | 76 | } 77 | 78 | from pyserini.search import LuceneSearcher, get_topics, get_qrels 79 | import json 80 | from tqdm import tqdm 81 | 82 | 83 | def run_retriever(topics, searcher, qrels=None, k=100, qid=None): 84 | ranks = [] 85 | if isinstance(topics, str): 86 | hits = searcher.search(topics, k=k) 87 | ranks.append({'query': topics, 'hits': []}) 88 | rank = 0 89 | for hit in hits: 90 | rank += 1 91 | content = json.loads(searcher.doc(hit.docid).raw()) 92 | if 'title' in content: 93 | content = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text'] 94 | else: 95 | content = content['contents'] 96 | content = ' '.join(content.split()) 97 | ranks[-1]['hits'].append({ 98 | 'content': content, 99 | 'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score}) 100 | return ranks[-1] 101 | 102 | for qid in tqdm(topics): 103 | if qid in qrels: 104 | query = topics[qid]['title'] 105 | ranks.append({'query': query, 'hits': []}) 106 | hits = searcher.search(query, k=k) 107 | rank = 0 108 | for hit in hits: 109 | rank += 1 110 | content = json.loads(searcher.doc(hit.docid).raw()) 111 | if 'title' in content: 112 | content = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text'] 113 | else: 114 | content = content['contents'] 115 | content = ' '.join(content.split()) 116 | ranks[-1]['hits'].append({ 117 | 'content': content, 118 | 'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score}) 119 | return ranks 120 | 121 | def do_retrieval(): 122 | for data in ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04']: 123 | print('#' * 20) 124 | print(f'Evaluation on {data}') 125 | print('#' * 20) 126 | 127 | # Retrieve passages using pyserini BM25. 128 | # Get a specific doc: 129 | # * searcher.num_docs 130 | # * json.loads(searcher.object.reader.document(4).fields[1].fieldsData) -> {"id": "1", "contents": ""} 131 | try: 132 | searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data]) 133 | topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20') 134 | qrels = get_qrels(THE_TOPICS[data]) 135 | rank_results = run_retriever(topics, searcher, qrels, k=100) 136 | 137 | # Store JSON in rank_results to a file 138 | with open(f'rank_results_{data}.json', 'w') as f: 139 | json.dump(rank_results, f, indent=2) 140 | # Store the QRELS of the dataset 141 | with open(f'qrels_{data}.json', 'w') as f: 142 | json.dump(qrels, f, indent=2) 143 | except: 144 | print(f'Failed to retrieve passages for {data}') 145 | 146 | for data in ['mrtydi-ar', 'mrtydi-bn', 'mrtydi-fi', 'mrtydi-id', 'mrtydi-ja', 'mrtydi-ko', 'mrtydi-ru', 'mrtydi-sw', 147 | 'mrtydi-te', 'mrtydi-th']: 148 | print('#' * 20) 149 | print(f'Evaluation on {data}') 150 | print('#' * 20) 151 | 152 | # Retrieve passages using pyserini BM25. 
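        # Same BM25 retrieval as the loop above; for the Mr.TyDi datasets the rank
        # results are additionally truncated to the first 100 queries before being
        # written to data/rank_results/{data}.json.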
153 | try: 154 | searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data]) 155 | topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20') 156 | qrels = get_qrels(THE_TOPICS[data]) 157 | rank_results = run_retriever(topics, searcher, qrels, k=100) 158 | rank_results = rank_results[:100] 159 | 160 | # Store JSON in rank_results to a file 161 | with open(f'data/rank_results/{data}.json', 'w') as f: 162 | json.dump(rank_results, f, indent=2) 163 | # Store the QRELS of the dataset 164 | with open(f'data/qrels/{data}.json', 'w') as f: 165 | json.dump(qrels, f, indent=2) 166 | except: 167 | print(f'Failed to retrieve passages for {data}') 168 | -------------------------------------------------------------------------------- /InstructDistill/instruction_distill.py: -------------------------------------------------------------------------------- 1 | try: 2 | from fastchat.train.llama_flash_attn_monkey_patch import ( 3 | replace_llama_attn_with_flash_attn, 4 | ) 5 | 6 | replace_llama_attn_with_flash_attn() 7 | except: 8 | print('Install fastchat to use flash attention. Refer to https://github.com/lm-sys/FastChat') 9 | 10 | import json 11 | from torch.utils.data import Dataset 12 | from accelerate import Accelerator 13 | from transformers import AutoTokenizer, AdamW, AutoModelForSeq2SeqLM, AutoConfig, AutoModelForCausalLM 14 | import torch 15 | import torch.distributed as dist 16 | from torch.utils.data import DistributedSampler 17 | from tqdm import tqdm 18 | from rank_loss import RankLoss 19 | import argparse 20 | import numpy as np 21 | import os 22 | 23 | 24 | class RerankData(Dataset): 25 | def __init__(self, data, tokenizer, psg_num=20, label=True): 26 | self.data = data 27 | self.tokenizer = tokenizer 28 | self.psg_num = psg_num 29 | self.label = label 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | @staticmethod 35 | def prompt(query, psg, max_len=200): 36 | psg = ' '.join(psg.split()[:max_len]) 37 | return f"Predict whether the given passage answer the question.\n\nQuestion: {query}\n\nPassage: {psg}\n\nDoes the passage answer the question?" 
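    # __getitem__ below expands one training example into `psg_num` such prompts
    # (the positive passage first when label=True, padded with empty strings if fewer
    # candidates are available); collate_fn then flattens them into a single tokenized batch.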
38 | 39 | def __getitem__(self, item): 40 | item = self.data[item] 41 | query = item['query'] 42 | 43 | if self.label: 44 | pos = [str(item['positive_passages'][0]['text'])] 45 | pos_id = [psg['docid'] for psg in item['positive_passages']] 46 | neg = [str(psg['text']) for psg in item['retrieved_passages'] if psg['docid'] not in pos_id][:self.psg_num] 47 | else: 48 | pos = [] 49 | neg = [str(psg['text']) for psg in item['retrieved_passages']][:self.psg_num] 50 | 51 | passages = pos + neg 52 | passages = passages[:self.psg_num] 53 | passages = passages + [''] * (self.psg_num - len(passages)) 54 | data = [self.prompt(query, psg) for psg in passages] 55 | return data 56 | 57 | def collate_fn(self, data): 58 | data = sum(data, []) 59 | batch_size = len(data) 60 | features = self.tokenizer(data, padding=True, truncation=True, return_tensors="pt", 61 | max_length=2048) 62 | 63 | features['decoder_input_ids'] = torch.zeros(batch_size, 1).long() 64 | return features 65 | 66 | 67 | def receive_response(data, responses): 68 | def clean_response(response: str): 69 | new_response = '' 70 | for c in response: 71 | if not c.isdigit(): 72 | new_response += ' ' 73 | else: 74 | new_response += c 75 | new_response = new_response.strip() 76 | return new_response 77 | 78 | def remove_duplicate(response): 79 | new_response = [] 80 | for c in response: 81 | if c not in new_response: 82 | new_response.append(c) 83 | return new_response 84 | 85 | new_data = [] 86 | for item, response in zip(data, responses): 87 | response = clean_response(response) 88 | response = [int(x) - 1 for x in response.split()] 89 | response = remove_duplicate(response) 90 | passages = item['retrieved_passages'] 91 | original_rank = [tt for tt in range(len(passages))] 92 | response = [ss for ss in response if ss in original_rank] 93 | response = response + [tt for tt in original_rank if tt not in response] 94 | new_passages = [passages[ii] for ii in response] 95 | new_data.append({'query': item['query'], 96 | 'positive_passages': item['positive_passages'], 97 | 'retrieved_passages': new_passages}) 98 | return new_data 99 | 100 | 101 | def split_data(data, process_idx, num_processes): 102 | if isinstance(data, torch.Tensor): 103 | sublist_length, remainder = divmod(data.size(0), num_processes) 104 | return data[process_idx * sublist_length + min(process_idx, remainder):(process_idx + 1) * sublist_length + min( 105 | process_idx + 1, remainder)] 106 | else: 107 | return data 108 | 109 | 110 | def gather_tensors(local_tensor, pad=False): 111 | if not dist.is_initialized(): 112 | return local_tensor 113 | 114 | if pad: 115 | local_size = torch.tensor([local_tensor.size(0)], device=local_tensor.device) 116 | sizes = [torch.zeros_like(local_size) for _ in range(dist.get_world_size())] 117 | dist.all_gather(sizes, local_size) 118 | 119 | max_size = max(torch.stack(sizes)).item() 120 | 121 | padded_tensor = torch.zeros(max_size, *local_tensor.size()[1:], device=local_tensor.device) 122 | padded_tensor[:local_tensor.size(0)] = local_tensor 123 | 124 | gathered_tensors = [torch.zeros_like(padded_tensor) for _ in range(dist.get_world_size())] 125 | dist.all_gather(gathered_tensors, padded_tensor) 126 | 127 | mask = [torch.arange(padded_tensor.size(0), device=padded_tensor.device) < size_tensor.item() 128 | for size_tensor in sizes] 129 | 130 | gathered_tensors = [gathered_tensor[mask_tensor] for gathered_tensor, mask_tensor in 131 | zip(gathered_tensors, mask)] 132 | 133 | else: 134 | gathered_tensors = [torch.zeros_like(local_tensor) for _ in 
range(dist.get_world_size())] 135 | dist.all_gather(gathered_tensors, local_tensor) 136 | 137 | gathered_tensors[dist.get_rank()] = local_tensor 138 | 139 | return torch.cat(gathered_tensors, dim=0) 140 | 141 | 142 | def train(args): 143 | model_name = args.model 144 | loss_type = args.loss 145 | data_path = args.data 146 | save_path = args.save_path 147 | permutation = args.permutation 148 | 149 | accelerator = Accelerator(gradient_accumulation_steps=2) 150 | batch_size = 16 151 | psg_num = 8 152 | 153 | # Load model and tokenizer 154 | if 't5' in model_name: # flan-t5 155 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 156 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 157 | token_Yes = 2163 158 | else: # llama 159 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="left", model_max_length=4096) 160 | tokenizer.pad_token_id = 0 161 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 162 | token_Yes = 3869 163 | 164 | model.config.use_cache = False 165 | model.gradient_checkpointing_enable() 166 | 167 | # Load data and permutation 168 | data = [json.loads(line) for line in open(data_path)] 169 | response = json.load(open(permutation)) 170 | data = receive_response(data, response) 171 | dataset = RerankData(data, tokenizer, psg_num=psg_num, label=False) 172 | 173 | # Distributed training 174 | train_sampler = DistributedSampler(dataset, num_replicas=1, rank=0) 175 | data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, 176 | batch_size=batch_size, num_workers=0, 177 | sampler=train_sampler) 178 | 179 | optimizer = AdamW(model.parameters(), 2e-5) 180 | model, optimizer, _ = accelerator.prepare(model, optimizer, data_loader) 181 | 182 | loss_function = getattr(RankLoss, loss_type) 183 | 184 | for epoch in range(3): 185 | accelerator.print(f'Training {save_path} {epoch}') 186 | accelerator.wait_for_everyone() 187 | model.train() 188 | tk0 = tqdm(data_loader, total=len(data_loader)) 189 | loss_report = [] 190 | for batch in tk0: 191 | with accelerator.accumulate(model): 192 | # Split the tensor based on the GPU id 193 | batch = {k: split_data(v, accelerator.process_index, accelerator.num_processes) for k, v in 194 | batch.items()} 195 | batch = {k: v.cuda() for k, v in batch.items()} 196 | 197 | out = model(**batch) 198 | logits = gather_tensors(out.logits[:, -1, token_Yes].contiguous()) # Gather all predictions across GPUs 199 | logits = logits.view(-1, psg_num) 200 | 201 | y_true = torch.tensor([[1 / (i + 1) for i in range(logits.size(1))]] * logits.size(0)).cuda() 202 | 203 | loss = loss_function(logits, y_true) 204 | 205 | accelerator.backward(loss) 206 | accelerator.clip_grad_norm_(model.parameters(), 1.) 
207 | optimizer.step() 208 | optimizer.zero_grad() 209 | 210 | loss_report.append(accelerator.gather(loss).mean().item()) 211 | tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:])) 212 | 213 | accelerator.wait_for_everyone() 214 | model.save_checkpoint(f'{save_path}/{epoch}') 215 | return model, tokenizer 216 | 217 | 218 | def eval_on_benchmark(args, model, tokenizer): 219 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 220 | from bm25_retrieval import THE_RESULTS 221 | from trec_eval import EvalFunction 222 | 223 | # save_path = 'out/new-flan-t5-large-from-large/2.pt' 224 | save_path = 'out/new-flan-t5-xl-from-xl/1/pytorch_model.bin' 225 | 226 | model_name = 'models/flan-t5-xl' 227 | 228 | print(save_path) 229 | print(model_name) 230 | 231 | if model is not None and tokenizer is not None: 232 | pass 233 | elif 't5' in model_name: # flan-t5 234 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 235 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 236 | else: # llama 237 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="left", 238 | model_max_length=4096) 239 | tokenizer.pad_token_id = 0 240 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 241 | 242 | token_Yes = 2163 if 't5' in model_name else 3869 243 | 244 | model.load_state_dict(torch.load(f'{save_path}')) 245 | model = model.cuda() 246 | model.eval() 247 | 248 | # data_list = ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04'] 249 | data_list = ['dl19', 'dl20'] 250 | for data_name in data_list: 251 | print() 252 | print('#' * 20) 253 | print(save_path) 254 | print(f'Now eval [{data_name}]') 255 | print('#' * 20) 256 | 257 | rank_results = json.load(open(THE_RESULTS[data_name])) 258 | saved = [] 259 | for item in tqdm(rank_results): 260 | q = item['query'] 261 | passages = [psg['content'] for i, psg in enumerate(item['hits'])][:100] 262 | if len(passages) == 0: 263 | saved.append('') 264 | continue 265 | 266 | i = 0 267 | normalized_scores = [] 268 | while i < len(passages): 269 | batch = passages[i: i + 10] 270 | i += 10 271 | 272 | features = tokenizer([RerankData.prompt(q, psg) for psg in batch], padding=True, truncation=True, 273 | return_tensors="pt", max_length=1024) 274 | if 't5' in model_name: 275 | features['decoder_input_ids'] = torch.zeros(len(batch), 1).long() 276 | 277 | features = {k: v.cuda() for k, v in features.items()} 278 | with torch.no_grad(): 279 | scores = model(**features).logits[:, -1, token_Yes] 280 | normalized_scores.extend([float(score) for score in scores]) 281 | 282 | ranked = np.argsort(normalized_scores)[::-1] 283 | response = ' > '.join([str(ss + 1) for ss in ranked]) 284 | saved.append(response) 285 | 286 | rank_results = EvalFunction.receive_responses(rank_results, saved, cut_start=0, cut_end=100) 287 | tmp_path = save_path.replace('/', '-') 288 | tmp_path = 'tmp/' + tmp_path 289 | EvalFunction.write_file(rank_results, tmp_path) 290 | EvalFunction.main(data_name, tmp_path) 291 | 292 | 293 | def parse_args(): 294 | parser = argparse.ArgumentParser() 295 | parser.add_argument('--model', type=str, default='google/flan-t5-xl') 296 | parser.add_argument('--loss', type=str, default='rank_net') 297 | parser.add_argument('--data', type=str, default='data/marco-train-10k.jsonl') 298 | parser.add_argument('--save_path', type=str, default='out/flan-t5-xl-id') 299 | parser.add_argument('--permutation', type=str, 
default='marco-train-10k-gpt3.5.json') 300 | parser.add_argument('--do_train', type=lambda x: str(x).lower() == 'true', default=True) 301 | parser.add_argument('--do_eval', type=lambda x: str(x).lower() == 'true', default=True) 302 | args = parser.parse_args() 303 | 304 | print('====Input Arguments====') 305 | print(json.dumps(vars(args), indent=2, sort_keys=False)) 306 | return args 307 | 308 | 309 | if __name__ == '__main__': 310 | args = parse_args() 311 | model, tokenizer = None, None 312 | if args.do_train: 313 | model, tokenizer = train(args) 314 | if args.do_eval: 315 | eval_on_benchmark(args, model, tokenizer) 316 | -------------------------------------------------------------------------------- /InstructDistill/pairwise_ranking.py: -------------------------------------------------------------------------------- 1 | # Pairwise Ranking Prompting 2 | # 3 | 4 | import json 5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | import os 10 | import argparse 11 | 12 | 13 | FLAN_PRP_PROMPT = '''Question: Given a query "{0}", which of the following two passages is more relevant to the query? 14 | 15 | passage A: {1} 16 | 17 | passage B: {2} 18 | 19 | Output the identifier of the more relevant passage. The answer must be passage A or passage B. 20 | Answer:''' 21 | 22 | GPT_PRP_PROMPT = '''### System: 23 | You are a pairwise passage ranker that can judge which passage is more relevant to the query. 24 | 25 | ### User: 26 | Given a query "{0}", which of the following two passages is more relevant to the query? 27 | 28 | Passage A: {1} 29 | 30 | Passage B: {2} 31 | 32 | Output the identifier of the more relevant passage. The answer must be Passage A or Passage B. 33 | 34 | ### Assistant: 35 | The more relevant passage is Passage''' 36 | 37 | 38 | def eval_prp(model_name): 39 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 40 | from trec_eval import EvalFunction 41 | from bm25_retrieval import THE_RESULTS 42 | 43 | print(model_name) 44 | 45 | if 't5' in model_name: 46 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 47 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16) 48 | token_passage = 5454 49 | token_A = 71 50 | token_B = 272 51 | PROMPT = FLAN_PRP_PROMPT 52 | else: 53 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="left", 54 | model_max_length=4096) 55 | tokenizer.pad_token_id = 0 56 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) 57 | token_passage = None 58 | token_A = 319 59 | token_B = 350 60 | PROMPT = GPT_PRP_PROMPT 61 | 62 | model = model.cuda() 63 | model.eval() 64 | 65 | # data_list = ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04'] 66 | data_list = ['dl19', 'dl20'] 67 | for data_name in data_list: 68 | print() 69 | print('#' * 20) 70 | print(f'Now eval [{data_name}]') 71 | print('#' * 20) 72 | 73 | rank_results = json.load(open(THE_RESULTS[data_name])) 74 | saved = [] 75 | for item in tqdm(rank_results): 76 | q = item['query'] 77 | passages = [psg['content'] for i, psg in enumerate(item['hits'])][:100] 78 | passages = [' '.join(psg.split()[:100]) for psg in passages] 79 | if len(passages) == 0: 80 | saved.append('') 81 | continue 82 | 83 | all_score = [0 for _ in range(len(passages))] 84 | 85 | new_passages = [] 86 | for i in range(len(passages)): 87 | for j in range(len(passages)): 88 | if i == j: 89 | continue 90 | prompt = PROMPT.format(q, passages[i], 
passages[j]) 91 | new_passages.append([prompt, i, j]) 92 | passages = new_passages 93 | 94 | i = 0 95 | while i < len(passages): 96 | batch = passages[i: i + 10] 97 | i += 10 98 | features = tokenizer([psg[0] for psg in batch], padding=True, truncation=True, return_tensors="pt") 99 | if 't5' in model_name: 100 | features['decoder_input_ids'] = torch.tensor([[0, token_passage]] * len(batch)).long() 101 | features = {k: v.cuda() for k, v in features.items()} 102 | with torch.no_grad(): 103 | scores = model(**features).logits[:, -1] 104 | for score, psg in zip(scores, batch): 105 | if score[token_A] > score[token_B]: 106 | all_score[psg[1]] += 1 107 | elif score[token_B] > score[token_A]: 108 | all_score[psg[2]] += 1 109 | else: 110 | all_score[psg[1]] += 0.5 111 | all_score[psg[2]] += 0.5 112 | all_score = [s + 1 / (10 + r) for r, s in enumerate(all_score)] 113 | ranked = np.argsort(all_score)[::-1] 114 | response = ' > '.join([str(ss + 1) for ss in ranked]) 115 | saved.append(response) 116 | 117 | rank_results = EvalFunction.receive_responses(rank_results, saved, cut_start=0, cut_end=100) 118 | tmp_path = 'tmp_rank_results' 119 | EvalFunction.write_file(rank_results, tmp_path) 120 | 121 | EvalFunction.main(data_name, tmp_path) 122 | 123 | 124 | def generate_data(model_name, data_path, save_path): 125 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 126 | 127 | if 't5' in model_name: 128 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 129 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16) 130 | token_passage = 5454 131 | token_A = 71 132 | token_B = 272 133 | PROMPT = FLAN_PRP_PROMPT 134 | else: 135 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="left", 136 | model_max_length=4096) 137 | tokenizer.pad_token_id = 0 138 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) 139 | token_passage = None 140 | token_A = 319 141 | token_B = 350 142 | PROMPT = GPT_PRP_PROMPT 143 | 144 | rank_results = [json.loads(line) for line in open(data_path)][:10000] 145 | saved = [] 146 | for item in tqdm(rank_results): 147 | q = item['query'] 148 | passages = [psg['text'] for i, psg in enumerate(item['retrieved_passages'])][:20] 149 | passages = [' '.join(psg.split()[:100]) for psg in passages] 150 | if len(passages) == 0: 151 | saved.append('') 152 | continue 153 | 154 | all_score = [0 for _ in range(len(passages))] 155 | 156 | new_passages = [] 157 | for i in range(len(passages)): 158 | for j in range(len(passages)): 159 | if i == j: 160 | continue 161 | prompt = PROMPT.format(q, passages[i], passages[j]) 162 | new_passages.append([prompt, i, j]) 163 | passages = new_passages 164 | 165 | i = 0 166 | while i < len(passages): 167 | batch = passages[i: i + 10] 168 | i += 10 169 | features = tokenizer([psg[0] for psg in batch], padding=True, truncation=True, return_tensors="pt", 170 | max_length=1024) 171 | if 't5' in model_name: 172 | features['decoder_input_ids'] = torch.tensor([[0, token_passage]] * len(batch)).long() 173 | features = {k: v.cuda() for k, v in features.items()} 174 | with torch.no_grad(): 175 | scores = model(**features).logits[:, -1] 176 | for score, psg in zip(scores, batch): 177 | if score[token_A] > score[token_B]: 178 | all_score[psg[1]] += 1 179 | elif score[token_B] > score[token_A]: 180 | all_score[psg[2]] += 1 181 | else: 182 | all_score[psg[1]] += 0.5 183 | all_score[psg[2]] += 0.5 184 | all_score = [s + 1 / (10 + r) for r, s in enumerate(all_score)] 185 | 
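        # all_score now holds the round-robin results of Pairwise Ranking Prompting:
        # +1 per pairwise win, +0.5 per tie, plus a small 1 / (10 + rank) bonus that
        # breaks ties in favour of the original retrieval order; sorting the scores
        # in descending order gives the final permutation.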
ranked = np.argsort(all_score)[::-1] 186 | response = ' > '.join([str(ss + 1) for ss in ranked]) 187 | saved.append(response) 188 | json.dump(saved, open(save_path, 'w')) 189 | 190 | 191 | def parse_args(): 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument('--model', type=str, default='google/flan-t5-xl') 194 | parser.add_argument('--eval', type=lambda x: str(x).lower() == 'true', default=True) 195 | parser.add_argument('--generate', type=lambda x: str(x).lower() == 'true', default=True) 196 | parser.add_argument('--data', type=str, default='data/marco-train-10k.jsonl') 197 | parser.add_argument('--save_path', type=str, default='out/rpr-flan-t5-xl.json') 198 | args = parser.parse_args() 199 | 200 | print('====Input Arguments====') 201 | print(json.dumps(vars(args), indent=2, sort_keys=False)) 202 | return args 203 | 204 | 205 | if __name__ == '__main__': 206 | args = parse_args() 207 | 208 | # Eval pairwise ranking on benchmarks 209 | if args.eval: 210 | eval_prp(args.model) 211 | 212 | # Get predictions on MS MARCO 213 | if args.generate: 214 | generate_data(args.model, args.data, args.save_path) 215 | -------------------------------------------------------------------------------- /InstructDistill/rank_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from torch.nn import BCELoss, BCEWithLogitsLoss 4 | from itertools import product 5 | 6 | 7 | class RankLoss: 8 | 9 | @staticmethod 10 | def softmax_ce_loss(y_pred, *args, **kwargs): 11 | return F.cross_entropy(y_pred, torch.zeros((y_pred.size(0),)).long().cuda()) 12 | 13 | @staticmethod 14 | def pointwise_rmse(y_pred, y_true=None): 15 | if y_true is None: 16 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 17 | y_true[:, 0] = 1 18 | errors = (y_true - y_pred) 19 | squared_errors = errors ** 2 20 | valid_mask = (y_true != -100).float() 21 | mean_squared_errors = torch.sum(squared_errors, dim=1) / torch.sum(valid_mask, dim=1) 22 | rmses = torch.sqrt(mean_squared_errors) 23 | return torch.mean(rmses) 24 | 25 | @staticmethod 26 | def pointwise_bce(y_pred, y_true=None): 27 | if y_true is None: 28 | y_true = torch.zeros_like(y_pred).float().to(y_pred.device) 29 | y_true[:, 0] = 1 30 | loss = F.binary_cross_entropy(torch.sigmoid(y_pred), y_true) 31 | return loss 32 | 33 | @staticmethod 34 | def list_net(y_pred, y_true=None, padded_value_indicator=-100, eps=1e-10): 35 | """ 36 | ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach". 37 | :param y_pred: predictions from the model, shape [batch_size, slate_length] 38 | :param y_true: ground truth labels, shape [batch_size, slate_length] 39 | :param eps: epsilon value, used for numerical stability 40 | :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 41 | :return: loss value, a torch.Tensor 42 | """ 43 | if y_true is None: 44 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 45 | y_true[:, 0] = 1 46 | 47 | preds_smax = F.softmax(y_pred, dim=1) 48 | true_smax = F.softmax(y_true, dim=1) 49 | 50 | preds_smax = preds_smax + eps 51 | preds_log = torch.log(preds_smax) 52 | 53 | return torch.mean(-torch.sum(true_smax * preds_log, dim=1)) 54 | 55 | @staticmethod 56 | def rank_net(y_pred, y_true=None, padded_value_indicator=-100, weight_by_diff=False, 57 | weight_by_diff_powed=False): 58 | """ 59 | RankNet loss introduced in "Learning to Rank using Gradient Descent". 
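        For each candidate pair (i, j) with y_true[i] > y_true[j], the loss is the
        binary cross-entropy between sigmoid(y_pred[i] - y_pred[j]) and a target of 1,
        i.e. -log(sigmoid(y_pred[i] - y_pred[j])); the pairwise terms can optionally be
        weighted by the (squared) ground-truth difference.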
60 | :param y_pred: predictions from the model, shape [batch_size, slate_length] 61 | :param y_true: ground truth labels, shape [batch_size, slate_length] 62 | :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences. 63 | :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences. 64 | :return: loss value, a torch.Tensor 65 | """ 66 | if y_true is None: 67 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 68 | y_true[:, 0] = 1 69 | 70 | # here we generate every pair of indices from the range of document length in the batch 71 | document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2)) 72 | 73 | pairs_true = y_true[:, document_pairs_candidates] 74 | selected_pred = y_pred[:, document_pairs_candidates] 75 | 76 | # here we calculate the relative true relevance of every candidate pair 77 | true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1] 78 | pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1] 79 | 80 | # here we filter just the pairs that are 'positive' and did not involve a padded instance 81 | # we can do that since in the candidate pairs we had symetric pairs so we can stick with 82 | # positive ones for a simpler loss function formulation 83 | the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs)) 84 | 85 | pred_diffs = pred_diffs[the_mask] 86 | 87 | weight = None 88 | if weight_by_diff: 89 | abs_diff = torch.abs(true_diffs) 90 | weight = abs_diff[the_mask] 91 | elif weight_by_diff_powed: 92 | true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2) 93 | abs_diff = torch.abs(true_pow_diffs) 94 | weight = abs_diff[the_mask] 95 | 96 | # here we 'binarize' true relevancy diffs since for a pairwise loss we just need to know 97 | # whether one document is better than the other and not about the actual difference in 98 | # their relevancy levels 99 | true_diffs = (true_diffs > 0).type(torch.float32) 100 | true_diffs = true_diffs[the_mask] 101 | 102 | return BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs) 103 | 104 | @staticmethod 105 | def lambda_loss(y_pred, y_true=None, eps=1e-10, padded_value_indicator=-100, weighing_scheme=None, k=None, 106 | sigma=1., mu=10., reduction="mean", reduction_log="binary"): 107 | """ 108 | LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization". 109 | Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet. 110 | :param y_pred: predictions from the model, shape [batch_size, slate_length] 111 | :param y_true: ground truth labels, shape [batch_size, slate_length] 112 | :param eps: epsilon value, used for numerical stability 113 | :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. 
-1 114 | :param weighing_scheme: a string corresponding to a name of one of the weighing schemes 115 | :param k: rank at which the loss is truncated 116 | :param sigma: score difference weight used in the sigmoid function 117 | :param mu: optional weight used in NDCGLoss2++ weighing scheme 118 | :param reduction: losses reduction method, could be either a sum or a mean 119 | :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural 120 | :return: loss value, a torch.Tensor 121 | """ 122 | if y_true is None: 123 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 124 | y_true[:, 0] = 1 125 | 126 | device = y_pred.device 127 | 128 | # Here we sort the true and predicted relevancy scores. 129 | y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1) 130 | y_true_sorted, _ = y_true.sort(descending=True, dim=-1) 131 | 132 | # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element. 133 | true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred) 134 | true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :] 135 | padded_pairs_mask = torch.isfinite(true_diffs) 136 | 137 | if weighing_scheme != "ndcgLoss1_scheme": 138 | padded_pairs_mask = padded_pairs_mask & (true_diffs > 0) 139 | 140 | ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device) 141 | ndcg_at_k_mask[:k, :k] = 1 142 | 143 | # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs) 144 | true_sorted_by_preds.clamp_(min=0.) 145 | y_true_sorted.clamp_(min=0.) 146 | 147 | # Here we find the gains, discounts and ideal DCGs per slate. 148 | pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device) 149 | D = torch.log2(1. + pos_idxs.float())[None, :] 150 | maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps) 151 | G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None] 152 | 153 | # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0) 154 | if weighing_scheme is None: 155 | weights = 1. 156 | else: 157 | weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds) # type: ignore 158 | 159 | # We are clamping the array entries to maintain correct backprop (log(0) and division by 0) 160 | scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8) 161 | scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.) 162 | weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps) 163 | if reduction_log == "natural": 164 | losses = torch.log(weighted_probas) 165 | elif reduction_log == "binary": 166 | losses = torch.log2(weighted_probas) 167 | else: 168 | raise ValueError("Reduction logarithm base can be either natural or binary") 169 | 170 | if reduction == "sum": 171 | loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask]) 172 | elif reduction == "mean": 173 | loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask]) 174 | else: 175 | raise ValueError("Reduction method can be either sum or mean") 176 | 177 | return loss 178 | 179 | 180 | def ndcgLoss1_scheme(G, D, *args): 181 | return (G / D)[:, :, None] 182 | 183 | 184 | def ndcgLoss2_scheme(G, D, *args): 185 | pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device) 186 | delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :]) 187 | deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) 
- torch.pow(torch.abs(D[0, delta_idxs]), -1.)) 188 | deltas.diagonal().zero_() 189 | 190 | return deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :]) 191 | 192 | 193 | def lambdaRank_scheme(G, D, *args): 194 | return torch.abs(torch.pow(D[:, :, None], -1.) - torch.pow(D[:, None, :], -1.)) * torch.abs( 195 | G[:, :, None] - G[:, None, :]) 196 | 197 | 198 | def ndcgLoss2PP_scheme(G, D, *args): 199 | return args[0] * ndcgLoss2_scheme(G, D) + lambdaRank_scheme(G, D) 200 | 201 | 202 | def rankNet_scheme(G, D, *args): 203 | return 1. 204 | 205 | 206 | def rankNetWeightedByGTDiff_scheme(G, D, *args): 207 | return torch.abs(args[1][:, :, None] - args[1][:, None, :]) 208 | 209 | 210 | def rankNetWeightedByGTDiffPowed_scheme(G, D, *args): 211 | return torch.abs(torch.pow(args[1][:, :, None], 2) - torch.pow(args[1][:, None, :], 2)) -------------------------------------------------------------------------------- /InstructDistill/trec_eval.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import tempfile 3 | import os 4 | import copy 5 | from typing import Dict, Tuple 6 | import pytrec_eval 7 | 8 | 9 | def trec_eval(qrels: Dict[str, Dict[str, int]], 10 | results: Dict[str, Dict[str, float]], 11 | k_values: Tuple[int] = (10, 50, 100, 200, 1000)) -> Dict[str, float]: 12 | ndcg, _map, recall = {}, {}, {} 13 | 14 | for k in k_values: 15 | ndcg[f"NDCG@{k}"] = 0.0 16 | _map[f"MAP@{k}"] = 0.0 17 | recall[f"Recall@{k}"] = 0.0 18 | 19 | map_string = "map_cut." + ",".join([str(k) for k in k_values]) 20 | ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) 21 | recall_string = "recall." + ",".join([str(k) for k in k_values]) 22 | 23 | evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string}) 24 | scores = evaluator.evaluate(results) 25 | 26 | for query_id in scores: 27 | for k in k_values: 28 | ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)] 29 | _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)] 30 | recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)] 31 | 32 | def _normalize(m: dict) -> dict: 33 | return {k: round(v / len(scores), 5) for k, v in m.items()} 34 | 35 | ndcg = _normalize(ndcg) 36 | _map = _normalize(_map) 37 | recall = _normalize(recall) 38 | 39 | all_metrics = {} 40 | for mt in [ndcg, _map, recall]: 41 | all_metrics.update(mt) 42 | 43 | return all_metrics 44 | 45 | 46 | def get_qrels_file(name): 47 | THE_TOPICS = { 48 | 'dl19': 'dl19-passage', 49 | 'dl20': 'dl20-passage', 50 | 'covid': 'beir-v1.0.0-trec-covid-test', 51 | 'arguana': 'beir-v1.0.0-arguana-test', 52 | 'touche': 'beir-v1.0.0-webis-touche2020-test', 53 | 'news': 'beir-v1.0.0-trec-news-test', 54 | 'scifact': 'beir-v1.0.0-scifact-test', 55 | 'fiqa': 'beir-v1.0.0-fiqa-test', 56 | 'scidocs': 'beir-v1.0.0-scidocs-test', 57 | 'nfc': 'beir-v1.0.0-nfcorpus-test', 58 | 'quora': 'beir-v1.0.0-quora-test', 59 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity-test', 60 | 'fever': 'beir-v1.0.0-fever-test', 61 | 'robust04': 'beir-v1.0.0-robust04-test', 62 | 'signal': 'beir-v1.0.0-signal1m-test', 63 | } 64 | name = THE_TOPICS[name] 65 | name = name.replace('-test', '.test') 66 | name = 'data/label_file/qrels.' 
+ name + '.txt' 67 | return name 68 | 69 | 70 | def remove_duplicate(response): 71 | new_response = [] 72 | for c in response: 73 | if c not in new_response: 74 | new_response.append(c) 75 | else: 76 | print('duplicate') 77 | return new_response 78 | 79 | 80 | def clean_response(response: str): 81 | new_response = '' 82 | for c in response: 83 | if not c.isdigit(): 84 | new_response += ' ' 85 | else: 86 | try: 87 | new_response += str(int(c)) 88 | except: 89 | new_response += ' ' 90 | new_response = new_response.strip() 91 | return new_response 92 | 93 | 94 | class EvalFunction: 95 | @staticmethod 96 | def receive_responses(rank_results, responses, cut_start=0, cut_end=100): 97 | print('receive_responses', len(responses), len(rank_results)) 98 | for i in range(len(responses)): 99 | response = responses[i] 100 | response = clean_response(response) 101 | response = [int(x) - 1 for x in response.split()] 102 | response = remove_duplicate(response) 103 | cut_range = copy.deepcopy(rank_results[i]['hits'][cut_start: cut_end]) 104 | original_rank = [tt for tt in range(len(cut_range))] 105 | response = [ss for ss in response if ss in original_rank] 106 | response = response + [tt for tt in original_rank if tt not in response] 107 | for j, x in enumerate(response): 108 | rank_results[i]['hits'][j + cut_start] = { 109 | 'content': cut_range[x]['content'], 'qid': cut_range[x]['qid'], 'docid': cut_range[x]['docid'], 110 | 'rank': cut_range[j]['rank'], 'score': cut_range[j]['score']} 111 | return rank_results 112 | 113 | @staticmethod 114 | def write_file(rank_results, file): 115 | print('write_file') 116 | with open(file, 'w') as f: 117 | for i in range(len(rank_results)): 118 | rank = 1 119 | hits = rank_results[i]['hits'] 120 | for hit in hits: 121 | f.write(f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n") 122 | rank += 1 123 | return True 124 | 125 | @staticmethod 126 | def trunc(qrels, run): 127 | qrels = get_qrels_file(qrels) 128 | # print(qrels) 129 | run = pd.read_csv(run, delim_whitespace=True, header=None) 130 | qrels = pd.read_csv(qrels, delim_whitespace=True, header=None) 131 | run[0] = run[0].astype(str) 132 | qrels[0] = qrels[0].astype(str) 133 | 134 | qrels = qrels[qrels[0].isin(run[0])] 135 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 136 | qrels.to_csv(temp_file, sep='\t', header=None, index=None) 137 | return temp_file 138 | 139 | @staticmethod 140 | def main(args_qrel, args_run): 141 | 142 | args_qrel = EvalFunction.trunc(args_qrel, args_run) 143 | 144 | assert os.path.exists(args_qrel) 145 | assert os.path.exists(args_run) 146 | 147 | with open(args_qrel, 'r') as f_qrel: 148 | qrel = pytrec_eval.parse_qrel(f_qrel) 149 | 150 | with open(args_run, 'r') as f_run: 151 | run = pytrec_eval.parse_run(f_run) 152 | 153 | all_metrics = trec_eval(qrel, run, k_values=(1, 5, 10)) 154 | print(all_metrics) 155 | return all_metrics 156 | 157 | 158 | if __name__ == '__main__': 159 | EvalFunction.main('dl19', 'ranking_results_file') 160 | -------------------------------------------------------------------------------- /InstructDistill/zero2_bf16_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | 
main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 4 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2019-2021 Pyserini authors 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 192 | -------------------------------------------------------------------------------- /NovelEval/README.md: -------------------------------------------------------------------------------- 1 | # NovelEval 2 | *A new test set with the novel queries and passages that have not been contaminated by the latest LLMs* 3 | 4 | The questions in the current benchmark dataset are typically gathered years ago, which raises the issue that existing LLMs already possess knowledge of these questions. 5 | Furthermore, since many LLMs do not disclose information about their training data, there is a potential risk of contamination of the existing benchmark test set. 6 | However, re-ranking models are expected to possess the capability to comprehend, deduce, and rank knowledge that is inherently unknown to them. 7 | 8 | Therefore, we suggest constructing **continuously updated IR test sets** to ensure that the questions, passages to be ranked, and relevance annotations have not been learned by the latest LLMs for a fair evaluation. 9 | 10 | ## Data Collection 11 | As an initial effort, we built **NovelEval-2306**, a novel test set with 21 novel questions collected during 2023-06. 
12 | This test set is constructed by gathering questions and passages fromfrom 4 domains that were published after the release of GPT-4. 13 | To ensure that GPT-4 did not possess prior knowledge of these questions, we presented them to both gpt-4-0314 and gpt-4-0613. 14 | For instance, question *"Which film was the 2023 Palme d'Or winner?"* pertains to the Cannes Film Festival that took place on May 27, 2023, rendering its answer inaccessible to most existing LLMs. 15 | Next, we searched 20 candidate passages for each question using Google search. 16 | The relevance of these passages was manually labeled as: 0 for not relevant, 1 for partially relevant, and 2 for relevant. 17 | 18 | 19 | ## Files 20 | | Type | Filename | Format| 21 | | ---- | ---- | ---- | 22 | | Corpus | [corpus.tsv](https://github.com/sunnweiwei/RankGPT/blob/main/NovelEval/corpus.tsv) | tsv: docid, content | 23 | | Queries | [queries.tsv](https://github.com/sunnweiwei/RankGPT/blob/main/NovelEval/queries.tsv) | tsv: qid, query | 24 | | Qrels | [qrels.txt](https://github.com/sunnweiwei/RankGPT/blob/main/NovelEval/qrels.txt) | TREC qrels format: qid, Q0, docid, relevance-score | 25 | 26 | ## Results 27 | 28 | | Method | nDCG@1 | nDCG@5 | nDCG@10 | 29 | | ---- | ----- | ----- | ----- | 30 | | BM25 | 33.33 | 45.96 | 55.77 | 31 | | monoBERT (340M) | 78.57 | 70.65 | 77.27 | 32 | | monoT5 (220M) | 83.33 | 77.46 | 81.27 | 33 | | monoT5 (3B) | 83.33 | 78.38 | 84.62 | 34 | | gpt-3.5-turbo | 76.19 | 74.15 | 75.71 | 35 | | **gpt-4** | **85.71** | **87.49** | **90.45** | 36 | -------------------------------------------------------------------------------- /NovelEval/qrels.txt: -------------------------------------------------------------------------------- 1 | 0 Q0 0-0 0 2 | 0 Q0 0-1 0 3 | 0 Q0 0-2 0 4 | 0 Q0 0-3 2 5 | 0 Q0 0-4 2 6 | 0 Q0 0-5 0 7 | 0 Q0 0-6 2 8 | 0 Q0 0-7 0 9 | 0 Q0 0-8 0 10 | 0 Q0 0-9 0 11 | 0 Q0 0-10 0 12 | 0 Q0 0-11 0 13 | 0 Q0 0-12 0 14 | 0 Q0 0-13 0 15 | 0 Q0 0-14 0 16 | 0 Q0 0-15 0 17 | 0 Q0 0-16 0 18 | 0 Q0 0-17 0 19 | 0 Q0 0-18 0 20 | 0 Q0 0-19 0 21 | 1 Q0 1-0 2 22 | 1 Q0 1-1 1 23 | 1 Q0 1-2 0 24 | 1 Q0 1-3 2 25 | 1 Q0 1-4 0 26 | 1 Q0 1-5 0 27 | 1 Q0 1-6 2 28 | 1 Q0 1-7 2 29 | 1 Q0 1-8 1 30 | 1 Q0 1-9 1 31 | 1 Q0 1-10 0 32 | 1 Q0 1-11 0 33 | 1 Q0 1-12 0 34 | 1 Q0 1-13 2 35 | 1 Q0 1-14 0 36 | 1 Q0 1-15 0 37 | 1 Q0 1-16 0 38 | 1 Q0 1-17 0 39 | 1 Q0 1-18 0 40 | 1 Q0 1-19 0 41 | 2 Q0 2-0 2 42 | 2 Q0 2-1 2 43 | 2 Q0 2-2 0 44 | 2 Q0 2-3 2 45 | 2 Q0 2-4 0 46 | 2 Q0 2-5 0 47 | 2 Q0 2-6 0 48 | 2 Q0 2-7 2 49 | 2 Q0 2-8 0 50 | 2 Q0 2-9 2 51 | 2 Q0 2-10 0 52 | 2 Q0 2-11 0 53 | 2 Q0 2-12 1 54 | 2 Q0 2-13 0 55 | 2 Q0 2-14 0 56 | 2 Q0 2-15 0 57 | 2 Q0 2-16 0 58 | 2 Q0 2-17 0 59 | 2 Q0 2-18 0 60 | 2 Q0 2-19 0 61 | 3 Q0 3-0 2 62 | 3 Q0 3-1 1 63 | 3 Q0 3-2 2 64 | 3 Q0 3-3 0 65 | 3 Q0 3-4 0 66 | 3 Q0 3-5 0 67 | 3 Q0 3-6 0 68 | 3 Q0 3-7 0 69 | 3 Q0 3-8 1 70 | 3 Q0 3-9 0 71 | 3 Q0 3-10 0 72 | 3 Q0 3-11 0 73 | 3 Q0 3-12 0 74 | 3 Q0 3-13 0 75 | 3 Q0 3-14 0 76 | 3 Q0 3-15 0 77 | 3 Q0 3-16 0 78 | 3 Q0 3-17 0 79 | 3 Q0 3-18 0 80 | 3 Q0 3-19 0 81 | 4 Q0 4-0 0 82 | 4 Q0 4-1 0 83 | 4 Q0 4-2 0 84 | 4 Q0 4-3 0 85 | 4 Q0 4-4 2 86 | 4 Q0 4-5 0 87 | 4 Q0 4-6 0 88 | 4 Q0 4-7 1 89 | 4 Q0 4-8 2 90 | 4 Q0 4-9 2 91 | 4 Q0 4-10 0 92 | 4 Q0 4-11 0 93 | 4 Q0 4-12 2 94 | 4 Q0 4-13 1 95 | 4 Q0 4-14 0 96 | 4 Q0 4-15 0 97 | 4 Q0 4-16 0 98 | 4 Q0 4-17 0 99 | 4 Q0 4-18 2 100 | 4 Q0 4-19 2 101 | 5 Q0 5-0 0 102 | 5 Q0 5-1 2 103 | 5 Q0 5-2 1 104 | 5 Q0 5-3 0 105 | 5 Q0 5-4 2 106 | 5 Q0 5-5 2 107 | 5 Q0 5-6 0 108 | 5 Q0 5-7 0 109 | 5 Q0 5-8 0 110 | 5 Q0 5-9 0 
111 | 5 Q0 5-10 0 112 | 5 Q0 5-11 0 113 | 5 Q0 5-12 0 114 | 5 Q0 5-13 0 115 | 5 Q0 5-14 2 116 | 5 Q0 5-15 0 117 | 5 Q0 5-16 0 118 | 5 Q0 5-17 1 119 | 5 Q0 5-18 0 120 | 5 Q0 5-19 0 121 | 6 Q0 6-0 0 122 | 6 Q0 6-1 0 123 | 6 Q0 6-2 2 124 | 6 Q0 6-3 0 125 | 6 Q0 6-4 0 126 | 6 Q0 6-5 0 127 | 6 Q0 6-6 0 128 | 6 Q0 6-7 2 129 | 6 Q0 6-8 2 130 | 6 Q0 6-9 2 131 | 6 Q0 6-10 0 132 | 6 Q0 6-11 1 133 | 6 Q0 6-12 0 134 | 6 Q0 6-13 0 135 | 6 Q0 6-14 0 136 | 6 Q0 6-15 0 137 | 6 Q0 6-16 1 138 | 6 Q0 6-17 0 139 | 6 Q0 6-18 0 140 | 6 Q0 6-19 0 141 | 7 Q0 7-0 2 142 | 7 Q0 7-1 1 143 | 7 Q0 7-2 2 144 | 7 Q0 7-3 2 145 | 7 Q0 7-4 0 146 | 7 Q0 7-5 0 147 | 7 Q0 7-6 0 148 | 7 Q0 7-7 0 149 | 7 Q0 7-8 0 150 | 7 Q0 7-9 0 151 | 7 Q0 7-10 0 152 | 7 Q0 7-11 0 153 | 7 Q0 7-12 0 154 | 7 Q0 7-13 0 155 | 7 Q0 7-14 0 156 | 7 Q0 7-15 0 157 | 7 Q0 7-16 2 158 | 7 Q0 7-17 1 159 | 7 Q0 7-18 0 160 | 7 Q0 7-19 0 161 | 8 Q0 8-0 2 162 | 8 Q0 8-1 2 163 | 8 Q0 8-2 0 164 | 8 Q0 8-3 0 165 | 8 Q0 8-4 0 166 | 8 Q0 8-5 0 167 | 8 Q0 8-6 0 168 | 8 Q0 8-7 0 169 | 8 Q0 8-8 0 170 | 8 Q0 8-9 2 171 | 8 Q0 8-10 0 172 | 8 Q0 8-11 0 173 | 8 Q0 8-12 0 174 | 8 Q0 8-13 0 175 | 8 Q0 8-14 0 176 | 8 Q0 8-15 0 177 | 8 Q0 8-16 0 178 | 8 Q0 8-17 0 179 | 8 Q0 8-18 0 180 | 8 Q0 8-19 0 181 | 9 Q0 9-0 2 182 | 9 Q0 9-1 2 183 | 9 Q0 9-2 2 184 | 9 Q0 9-3 0 185 | 9 Q0 9-4 0 186 | 9 Q0 9-5 0 187 | 9 Q0 9-6 0 188 | 9 Q0 9-7 0 189 | 9 Q0 9-8 0 190 | 9 Q0 9-9 2 191 | 9 Q0 9-10 0 192 | 9 Q0 9-11 2 193 | 9 Q0 9-12 2 194 | 9 Q0 9-13 2 195 | 9 Q0 9-14 2 196 | 9 Q0 9-15 0 197 | 9 Q0 9-16 0 198 | 9 Q0 9-17 2 199 | 9 Q0 9-18 0 200 | 9 Q0 9-19 2 201 | 10 Q0 10-0 2 202 | 10 Q0 10-1 0 203 | 10 Q0 10-2 0 204 | 10 Q0 10-3 0 205 | 10 Q0 10-4 0 206 | 10 Q0 10-5 0 207 | 10 Q0 10-6 0 208 | 10 Q0 10-7 0 209 | 10 Q0 10-8 1 210 | 10 Q0 10-9 0 211 | 10 Q0 10-10 0 212 | 10 Q0 10-11 0 213 | 10 Q0 10-12 0 214 | 10 Q0 10-13 0 215 | 10 Q0 10-14 0 216 | 10 Q0 10-15 0 217 | 10 Q0 10-16 0 218 | 10 Q0 10-17 2 219 | 10 Q0 10-18 0 220 | 10 Q0 10-19 0 221 | 11 Q0 11-0 2 222 | 11 Q0 11-1 2 223 | 11 Q0 11-2 0 224 | 11 Q0 11-3 0 225 | 11 Q0 11-4 0 226 | 11 Q0 11-5 0 227 | 11 Q0 11-6 1 228 | 11 Q0 11-7 0 229 | 11 Q0 11-8 1 230 | 11 Q0 11-9 0 231 | 11 Q0 11-10 1 232 | 11 Q0 11-11 2 233 | 11 Q0 11-12 1 234 | 11 Q0 11-13 0 235 | 11 Q0 11-14 0 236 | 11 Q0 11-15 0 237 | 11 Q0 11-16 0 238 | 11 Q0 11-17 0 239 | 11 Q0 11-18 0 240 | 11 Q0 11-19 2 241 | 12 Q0 12-0 2 242 | 12 Q0 12-1 2 243 | 12 Q0 12-2 2 244 | 12 Q0 12-3 0 245 | 12 Q0 12-4 0 246 | 12 Q0 12-5 0 247 | 12 Q0 12-6 0 248 | 12 Q0 12-7 0 249 | 12 Q0 12-8 1 250 | 12 Q0 12-9 0 251 | 12 Q0 12-10 0 252 | 12 Q0 12-11 2 253 | 12 Q0 12-12 2 254 | 12 Q0 12-13 1 255 | 12 Q0 12-14 2 256 | 12 Q0 12-15 0 257 | 12 Q0 12-16 2 258 | 12 Q0 12-17 2 259 | 12 Q0 12-18 0 260 | 12 Q0 12-19 0 261 | 13 Q0 13-0 0 262 | 13 Q0 13-1 0 263 | 13 Q0 13-2 2 264 | 13 Q0 13-3 1 265 | 13 Q0 13-4 0 266 | 13 Q0 13-5 0 267 | 13 Q0 13-6 2 268 | 13 Q0 13-7 0 269 | 13 Q0 13-8 2 270 | 13 Q0 13-9 0 271 | 13 Q0 13-10 0 272 | 13 Q0 13-11 0 273 | 13 Q0 13-12 0 274 | 13 Q0 13-13 0 275 | 13 Q0 13-14 0 276 | 13 Q0 13-15 0 277 | 13 Q0 13-16 0 278 | 13 Q0 13-17 2 279 | 13 Q0 13-18 2 280 | 13 Q0 13-19 0 281 | 14 Q0 14-0 2 282 | 14 Q0 14-1 0 283 | 14 Q0 14-2 2 284 | 14 Q0 14-3 0 285 | 14 Q0 14-4 0 286 | 14 Q0 14-5 0 287 | 14 Q0 14-6 0 288 | 14 Q0 14-7 0 289 | 14 Q0 14-8 0 290 | 14 Q0 14-9 0 291 | 14 Q0 14-10 0 292 | 14 Q0 14-11 0 293 | 14 Q0 14-12 0 294 | 14 Q0 14-13 0 295 | 14 Q0 14-14 0 296 | 14 Q0 14-15 2 297 | 14 Q0 14-16 1 298 | 14 Q0 14-17 0 299 | 14 Q0 14-18 0 300 | 14 Q0 14-19 0 301 | 15 Q0 15-0 0 302 | 
15 Q0 15-1 0 303 | 15 Q0 15-2 0 304 | 15 Q0 15-3 0 305 | 15 Q0 15-4 2 306 | 15 Q0 15-5 1 307 | 15 Q0 15-6 0 308 | 15 Q0 15-7 0 309 | 15 Q0 15-8 0 310 | 15 Q0 15-9 0 311 | 15 Q0 15-10 0 312 | 15 Q0 15-11 0 313 | 15 Q0 15-12 2 314 | 15 Q0 15-13 2 315 | 15 Q0 15-14 0 316 | 15 Q0 15-15 0 317 | 15 Q0 15-16 0 318 | 15 Q0 15-17 0 319 | 15 Q0 15-18 0 320 | 15 Q0 15-19 0 321 | 16 Q0 16-0 0 322 | 16 Q0 16-1 2 323 | 16 Q0 16-2 1 324 | 16 Q0 16-3 1 325 | 16 Q0 16-4 0 326 | 16 Q0 16-5 2 327 | 16 Q0 16-6 0 328 | 16 Q0 16-7 0 329 | 16 Q0 16-8 0 330 | 16 Q0 16-9 0 331 | 16 Q0 16-10 0 332 | 16 Q0 16-11 0 333 | 16 Q0 16-12 0 334 | 16 Q0 16-13 2 335 | 16 Q0 16-14 0 336 | 16 Q0 16-15 0 337 | 16 Q0 16-16 0 338 | 16 Q0 16-17 0 339 | 16 Q0 16-18 1 340 | 16 Q0 16-19 0 341 | 17 Q0 17-0 2 342 | 17 Q0 17-1 2 343 | 17 Q0 17-2 1 344 | 17 Q0 17-3 1 345 | 17 Q0 17-4 0 346 | 17 Q0 17-5 0 347 | 17 Q0 17-6 0 348 | 17 Q0 17-7 0 349 | 17 Q0 17-8 2 350 | 17 Q0 17-9 0 351 | 17 Q0 17-10 0 352 | 17 Q0 17-11 0 353 | 17 Q0 17-12 0 354 | 17 Q0 17-13 0 355 | 17 Q0 17-14 0 356 | 17 Q0 17-15 0 357 | 17 Q0 17-16 0 358 | 17 Q0 17-17 0 359 | 17 Q0 17-18 0 360 | 17 Q0 17-19 0 361 | 18 Q0 18-0 1 362 | 18 Q0 18-1 2 363 | 18 Q0 18-2 2 364 | 18 Q0 18-3 2 365 | 18 Q0 18-4 1 366 | 18 Q0 18-5 0 367 | 18 Q0 18-6 2 368 | 18 Q0 18-7 0 369 | 18 Q0 18-8 0 370 | 18 Q0 18-9 0 371 | 18 Q0 18-10 0 372 | 18 Q0 18-11 0 373 | 18 Q0 18-12 0 374 | 18 Q0 18-13 0 375 | 18 Q0 18-14 0 376 | 18 Q0 18-15 0 377 | 18 Q0 18-16 0 378 | 18 Q0 18-17 0 379 | 18 Q0 18-18 1 380 | 18 Q0 18-19 2 381 | 19 Q0 19-0 2 382 | 19 Q0 19-1 2 383 | 19 Q0 19-2 2 384 | 19 Q0 19-3 2 385 | 19 Q0 19-4 2 386 | 19 Q0 19-5 0 387 | 19 Q0 19-6 1 388 | 19 Q0 19-7 1 389 | 19 Q0 19-8 0 390 | 19 Q0 19-9 1 391 | 19 Q0 19-10 0 392 | 19 Q0 19-11 1 393 | 19 Q0 19-12 0 394 | 19 Q0 19-13 0 395 | 19 Q0 19-14 1 396 | 19 Q0 19-15 2 397 | 19 Q0 19-16 0 398 | 19 Q0 19-17 0 399 | 19 Q0 19-18 0 400 | 19 Q0 19-19 1 401 | 20 Q0 20-0 2 402 | 20 Q0 20-1 0 403 | 20 Q0 20-2 1 404 | 20 Q0 20-3 1 405 | 20 Q0 20-4 0 406 | 20 Q0 20-5 0 407 | 20 Q0 20-6 0 408 | 20 Q0 20-7 0 409 | 20 Q0 20-8 2 410 | 20 Q0 20-9 0 411 | 20 Q0 20-10 0 412 | 20 Q0 20-11 0 413 | 20 Q0 20-12 0 414 | 20 Q0 20-13 0 415 | 20 Q0 20-14 0 416 | 20 Q0 20-15 0 417 | 20 Q0 20-16 0 418 | 20 Q0 20-17 0 419 | 20 Q0 20-18 0 420 | 20 Q0 20-19 0 421 | -------------------------------------------------------------------------------- /NovelEval/queries.tsv: -------------------------------------------------------------------------------- 1 | 0 How many different Spider-Men are there in Across the Spider-Verse? 2 | 1 What is the screen resolution of vision pro? 3 | 2 Which film was the 2023 Palme d'Or winner? 4 | 3 Who will be the CEO of Twitter after Elon Musk is no longer the CEO? 5 | 4 How many goals did Haaland scored in the 2023 Champions League Final 6 | 5 Where did Benzema go after leaving Real Madrid? 7 | 6 Where was the 2023 Premier League FA Cup Final held? 8 | 7 What is the name of the combined Deepmind and Google Brain? 9 | 8 Where will Blackpink's 2023 world tour concert in France be held? 10 | 9 Where did the G7 Summit 2023 take place? 11 | 10 What are the best papers of CVPR 2023? 12 | 11 What is the release date of song Middle Ground? 13 | 12 Who wins NBA Finals 2023? 14 | 13 Who does Momoa play in Fast X? 15 | 14 What is Messi's annual income after transferring to Miami? 16 | 15 Who sang the theme song of Transformers Rise of the Beasts? 17 | 16 How much video memory does the DGX GH200 have? 18 | 17 What are the new features of PyTorch 2? 
19 | 18 Who is the villain in The Flash? 20 | 19 Who win 2023 Laureus World Sportsman Of The Year Award? 21 | 20 The Little Mermaid first week box office? 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RankGPT: LLMs as Re-Ranking Agent 2 | 3 | [![Generic badge](https://img.shields.io/badge/arXiv-2304.09542-red.svg)](https://arxiv.org/abs/2304.09542) 4 | [![LICENSE](https://img.shields.io/badge/license-Apache-blue.svg?style=flat)](https://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | Code for paper "[Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agent](https://arxiv.org/abs/2304.09542)" 7 | 8 | This project aims to explore generative LLMs such as ChatGPT and GPT-4 for relevance ranking in Information Retrieval (IR). 9 | 10 | 11 | ## News 12 | - **[2023.12.10]** Our [RankGPT](https://arxiv.org/abs/2304.09542) paper won the Outstanding Paper Award of EMNLP2023! 🎉🎉🎉 13 | - **[2023.11.06]** Introduce [Instruction Distillation](https://github.com/sunnweiwei/RankGPT/tree/main/InstructDistill): Simplifying complex ranking instructions to enhance the efficiency of LLMs. Achieves SOTA ranking performance with only open-source LLMs! 14 | - **[2023.10.08]** Our paper has been accepted for presentation at the EMNLP 2023 main conference. See the updated version at https://arxiv.org/pdf/2304.09542.pdf! 15 | - **[2023.08.05]** Now supports Azure, Claude, Cohere, and Llama2 via [LiteLLM](https://github.com/BerriAI/litellm)! 16 | - **[2023.07.11]** Release a new test set, NovelEval, with novel search questions and passages that have not been contaminated by the latest LLMs (e.g., GPT-4). See [NovelEval](https://github.com/sunnweiwei/RankGPT/tree/main/NovelEval) for details. 17 | - **[2023.04.23]** Sharing 100K ChatGPT predicted permutations on MS MARCO training set [here](#download-data-and-model). 18 | - **[2023.04.19]** Our paper is now available at https://arxiv.org/abs/2304.09542 19 | 20 | ## Quick example 21 | Below defines a query and three candidate passages: 22 | 23 | ```python 24 | item = { 25 | 'query': 'How much impact do masks have on preventing the spread of the COVID-19?', 26 | 'hits': [ 27 | {'content': 'Title: Universal Masking is Urgent in the COVID-19 Pandemic: SEIR and Agent Based Models, Empirical Validation, Policy Recommendations Content: We present two models for the COVID-19 pandemic predicting the impact of universal face mask wearing upon the spread of the SARS-CoV-2 virus--one employing a stochastic dynamic network based compartmental SEIR (susceptible-exposed-infectious-recovered) approach, and the other employing individual ABM (agent-based modelling) Monte Carlo simulation--indicating (1) significant impact under (near) universal masking when at least 80% of a population is wearing masks, versus minimal impact when only 50% or less of the population is wearing masks, and (2) significant impact when universal masking is adopted early, by Day 50 of a regional outbreak, versus minimal impact when universal masking is adopted late. These effects hold even at the lower filtering rates of homemade masks.
To validate these theoretical models, we compare their predictions against a new empirical data set we have collected'}, 28 | {'content': 'Title: Masking the general population might attenuate COVID-19 outbreaks Content: The effect of masking the general population on a COVID-19 epidemic is estimated by computer simulation using two separate state-of-the-art web-based softwares, one of them calibrated for the SARS-CoV-2 virus. The questions addressed are these: 1. Can mask use by the general population limit the spread of SARS-CoV-2 in a country? 2. What types of masks exist, and how elaborate must a mask be to be effective against COVID-19? 3. Does the mask have to be applied early in an epidemic? 4. A brief general discussion of masks and some possible future research questions regarding masks and SARS-CoV-2. Results are as follows: (1) The results indicate that any type of mask, even simple home-made ones, may be effective. Masks use seems to have an effect in lowering new patients even the protective effect of each mask (here dubbed"one-mask protection") is'}, 29 | {'content': 'Title: To mask or not to mask: Modeling the potential for face mask use by the general public to curtail the COVID-19 pandemic Content: Face mask use by the general public for limiting the spread of the COVID-19 pandemic is controversial, though increasingly recommended, and the potential of this intervention is not well understood. We develop a compartmental model for assessing the community-wide impact of mask use by the general, asymptomatic public, a portion of which may be asymptomatically infectious. Model simulations, using data relevant to COVID-19 dynamics in the US states of New York and Washington, suggest that broad adoption of even relatively ineffective face masks may meaningfully reduce community transmission of COVID-19 and decrease peak hospitalizations and deaths. Moreover, mask use decreases the effective transmission rate in nearly linear proportion to the product of mask effectiveness (as a fraction of potentially infectious contacts blocked) and coverage rate (as'} 30 | ] 31 | } 32 | 33 | ``` 34 | 35 | We can re-rank the passages using ChatGPT with instructional permutation generation: 36 | 37 | ```python 38 | from rank_gpt import permutation_pipeline 39 | new_item = permutation_pipeline(item, rank_start=0, rank_end=3, model_name='gpt-3.5-turbo', api_key='Your OPENAI Key!') 40 | print(new_item) 41 | ``` 42 | 43 | We get the following result: 44 | 45 | ```python 46 | { 47 | 'query': 'How much impact do masks have on preventing the spread of the COVID-19?', 48 | 'hits': [ 49 | {'content': 'Title: Universal Masking is Urgent in the COVID-19 Pandemic: SEIR and Agent Based Models, Empirical Validation, Policy Recommendations Content: We present two models for the COVID-19 pandemic predicting the impact of universal face mask wearing upon the spread of the SARS-CoV-2 virus--one employing a stochastic dynamic network based compartmental SEIR (susceptible-exposed-infectious-recovered) approach, and the other employing individual ABM (agent-based modelling) Monte Carlo simulation--indicating (1) significant impact under (near) universal masking when at least 80% of a population is wearing masks, versus minimal impact when only 50% or less of the population is wearing masks, and (2) significant impact when universal masking is adopted early, by Day 50 of a regional outbreak, versus minimal impact when universal masking is adopted late. 
These effects hold even at the lower filtering rates of homemade masks. To validate these theoretical models, we compare their predictions against a new empirical data set we have collected'}, 50 | {'content': 'Title: To mask or not to mask: Modeling the potential for face mask use by the general public to curtail the COVID-19 pandemic Content: Face mask use by the general public for limiting the spread of the COVID-19 pandemic is controversial, though increasingly recommended, and the potential of this intervention is not well understood. We develop a compartmental model for assessing the community-wide impact of mask use by the general, asymptomatic public, a portion of which may be asymptomatically infectious. Model simulations, using data relevant to COVID-19 dynamics in the US states of New York and Washington, suggest that broad adoption of even relatively ineffective face masks may meaningfully reduce community transmission of COVID-19 and decrease peak hospitalizations and deaths. Moreover, mask use decreases the effective transmission rate in nearly linear proportion to the product of mask effectiveness (as a fraction of potentially infectious contacts blocked) and coverage rate (as'}, 51 | {'content': 'Title: Masking the general population might attenuate COVID-19 outbreaks Content: The effect of masking the general population on a COVID-19 epidemic is estimated by computer simulation using two separate state-of-the-art web-based softwares, one of them calibrated for the SARS-CoV-2 virus. The questions addressed are these: 1. Can mask use by the general population limit the spread of SARS-CoV-2 in a country? 2. What types of masks exist, and how elaborate must a mask be to be effective against COVID-19? 3. Does the mask have to be applied early in an epidemic? 4. A brief general discussion of masks and some possible future research questions regarding masks and SARS-CoV-2. Results are as follows: (1) The results indicate that any type of mask, even simple home-made ones, may be effective. Masks use seems to have an effect in lowering new patients even the protective effect of each mask (here dubbed"one-mask protection") is'} 52 | ] 53 | } 54 | ``` 55 | 56 |
57 | Step by step example 58 | 59 | ```python 60 | from rank_gpt import create_permutation_instruction, run_llm, receive_permutation 61 | 62 | # (1) Create permutation generation instruction 63 | messages = create_permutation_instruction(item=item, rank_start=0, rank_end=3, model_name='gpt-3.5-turbo') 64 | # (2) Get ChatGPT predicted permutation 65 | permutation = run_llm(messages, api_key="Your OPENAI Key!", model_name='gpt-3.5-turbo') 66 | # (3) Use permutation to re-rank the passage 67 | item = receive_permutation(item, permutation, rank_start=0, rank_end=3) 68 | 69 | ``` 70 | 71 |
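The permutation returned in step (2) is a plain string of ranked identifiers in the format requested by the prompt (e.g. `[1] > [2]`). `receive_permutation` extracts the numeric identifiers, drops duplicates and out-of-range entries, and reorders `item['hits'][rank_start:rank_end]` accordingly. Below is a minimal sketch of step (3) using a made-up permutation string rather than an actual model response:

```python
# Illustrative only: an assumed model response for the three passages above,
# in the "[x] > [y] > ..." format that the instruction asks for.
example_permutation = "[2] > [1] > [3]"

# receive_permutation parses the identifiers and moves the passages into that order;
# any existing 'rank'/'score' fields stay attached to the list positions, not the passages.
reordered = receive_permutation(item, example_permutation, rank_start=0, rank_end=3)
for hit in reordered['hits']:
    print(hit['content'][:60])
```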
72 | 73 | ## Sliding window strategy 74 | 75 | We introduce a sliding window strategy for instructional permutation generation, which enables LLMs to rank more passages than their maximum token limit allows. 76 | 77 | The idea is to rank from back to front using a sliding window, re-ranking only the passages within the window at a time. 78 | 79 | Below is an example of re-ranking 3 passages with a window size of 2 and a step size of 1: 80 | 81 | ```python 82 | from rank_gpt import sliding_windows 83 | api_key = "Your OPENAI Key" 84 | new_item = sliding_windows(item, rank_start=0, rank_end=3, window_size=2, step=1, model_name='gpt-3.5-turbo', api_key=api_key) 85 | print(new_item) 86 | ``` 87 | 88 | ## Evaluation on Benchmarks 89 | We use [pyserini](https://github.com/castorini/pyserini) to retrieve 100 passages for each query and re-rank them using instructional permutation generation. 90 | 91 | Example of evaluation on TREC-DL19: 92 | 93 | ```python 94 | from pyserini.search import LuceneSearcher, get_topics, get_qrels 95 | from rank_gpt import run_retriever, sliding_windows 96 | import tempfile 97 | from tqdm import tqdm 98 | openai_key = None # Your openai key 99 | # Retrieve passages using pyserini BM25. 100 | searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage') 101 | topics = get_topics('dl19-passage') 102 | qrels = get_qrels('dl19-passage') 103 | rank_results = run_retriever(topics, searcher, qrels, k=100) 104 | 105 | # Run sliding window permutation generation 106 | new_results = [] 107 | for item in tqdm(rank_results): 108 | new_item = sliding_windows(item, rank_start=0, rank_end=100, window_size=20, step=10, model_name='gpt-3.5-turbo', api_key=openai_key) 109 | new_results.append(new_item) 110 | 111 | # Evaluate nDCG@10 112 | from trec_eval import EvalFunction 113 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 114 | EvalFunction.write_file(new_results, temp_file) 115 | EvalFunction.main('dl19-passage', temp_file) 116 | ``` 117 | 118 | Run evaluation on all benchmarks: 119 | 120 | ```sh 121 | python run_evaluation.py 122 | ``` 123 | 124 | Below are the results (average nDCG@10) of our preliminary experiments on [TREC](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2020.html), [BEIR](https://github.com/beir-cellar/beir) and [Mr. TyDi](https://github.com/castorini/mr.tydi).
125 | 126 | ![Results on benchmarks](assets/benchmark-results.png) 127 | 128 | 129 | ## Training Specialized Models 130 | 131 | ### Download data and model 132 | 133 | | File | Note | Link | 134 | |:-------------------------------|:--------|:--------:| 135 | | marco-train-10k.jsonl | 10K queries sampled from MS MARCO | [Google drive](https://drive.google.com/file/d/1G3MpQ5a4KgUS13JJZFE9aQvCbQfgSQzj/view?usp=share_link) | 136 | | marco-train-10k-gpt3.5.json | Permutations predicted by ChatGPT | [Google drive](https://drive.google.com/file/d/1i7ckK7kN7BAqq5g7xAd0dLv3cTYYiclA/view?usp=share_link) | 137 | | deberta-10k-rank_net | Specialized DeBERTa model trained with RankNet loss | [Google drive](https://drive.google.com/file/d/1-KEpJ2KnJCqiJof4zNEA4m78tnwgxKhb/view?usp=share_link) | 138 | | marco-train-100k.jsonl | 100K queries from MS MARCO | [Google drive](https://drive.google.com/file/d/1OgF4rj89FWSr7pl1c7Hu4x0oQYIMwhik/view?usp=share_link) | 139 | | marco-train-100k-gpt3.5.json | Permutations by ChatGPT of the 100K queries | [Google drive](https://drive.google.com/file/d/1z327WOKr70rC4UfOlQVBQnuLxChi_uPs/view?usp=share_link) | 140 | 141 | ### Distill LLM to a small specialized model 142 | 143 | ```bash 144 | python specialization.py \ 145 | --model microsoft/deberta-v3-base \ 146 | --loss rank_net \ 147 | --data data/marco-train-10k.jsonl \ 148 | --permutation marco-train-10k-gpt3.5.json \ 149 | --save_path out/deberta-10k-rank_net \ 150 | --do_train true \ 151 | --do_eval true 152 | ``` 153 | 154 | or run on multiple GPUs using [accelerate](https://github.com/huggingface/accelerate): 155 | 156 | ```bash 157 | accelerate launch --num_processes 4 specialization.py \ 158 | --model microsoft/deberta-v3-base \ 159 | --loss rank_net \ 160 | --data data/marco-train-10k.jsonl \ 161 | --permutation marco-train-10k-gpt3.5.json \ 162 | --save_path out/deberta-10k-rank_net \ 163 | --do_train true \ 164 | --do_eval true 165 | ``` 166 | 167 | ### Evaluate the distilled model on benchmarks 168 | 169 | ```bash 170 | python specialization.py \ 171 | --model out/deberta-10k-rank_net \ 172 | --do_train false \ 173 | --do_eval true 174 | ``` 175 | 176 | The following figure shows the results of the distilled specialized models with different model sizes and numbers of training queries. 177 | 178 | ![Specialization results.](assets/specialization-results.png) 179 | 180 | ## Cite 181 | 182 | ```latex 183 | @article{Sun2023IsCG, 184 | title={Is ChatGPT Good at Search?
Investigating Large Language Models as Re-Ranking Agent}, 185 | author={Weiwei Sun and Lingyong Yan and Xinyu Ma and Pengjie Ren and Dawei Yin and Zhaochun Ren}, 186 | journal={ArXiv}, 187 | year={2023}, 188 | volume={abs/2304.09542} 189 | } 190 | ``` 191 | ``` 192 | @article{Sun2023InstructionDM, 193 | title={Instruction Distillation Makes Large Language Models Efficient Zero-shot Rankers}, 194 | author={Weiwei Sun and Zheng Chen and Xinyu Ma and Lingyong Yan and Shuaiqiang Wang and Pengjie Ren and Zhumin Chen and Dawei Yin and Zhaochun Ren}, 195 | journal={ArXiv}, 196 | year={2023}, 197 | volume={abs/2311.01555}, 198 | } 199 | ``` 200 | -------------------------------------------------------------------------------- /assets/benchmark-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunnweiwei/RankGPT/0d62bc3855c7c118048a7c47c18e719b938e291a/assets/benchmark-results.png -------------------------------------------------------------------------------- /assets/specialization-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunnweiwei/RankGPT/0d62bc3855c7c118048a7c47c18e719b938e291a/assets/specialization-results.png -------------------------------------------------------------------------------- /pointwise.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | # This file includes the implementation of Relevance Generation and Query Generation as described in the RankGPT paper. For more details, refer to the paper available at: https://arxiv.org/abs/2304.09542 4 | 5 | client = OpenAI(api_key='sk-xxx') 6 | 7 | FEW_SHOT_EXAMPLE = '''Given a passage and a question, predict whether the passage includes an answer to the question by producing either `Yes` or `No`. 8 | 9 | Passage: Its 25 drops per ml, you guys are all wrong. If it is water, the standard was changed 15 - 20 years ago to make 20 drops = 1mL. The viscosity of most things is temperature dependent, so this would be at room temperature. Hope this helps. 10 | Query: how many eye drops per ml 11 | Does the passage answer the query? 12 | Answer: Yes 13 | 14 | Passage: RE: How many eyedrops are there in a 10 ml bottle of Cosopt? My Kaiser pharmacy insists that 2 bottles should last me 100 days but I run out way before that time when I am using 4 drops per day.In the past other pharmacies have given me 3 10-ml bottles for 100 days.E: How many eyedrops are there in a 10 ml bottle of Cosopt? My Kaiser pharmacy insists that 2 bottles should last me 100 days but I run out way before that time when I am using 4 drops per day. 15 | Query: how many eye drops per ml 16 | Does the passage answer the query? 17 | Answer: No 18 | 19 | Passage: : You can transfer money to your checking account from other Wells Fargo. accounts through Wells Fargo Mobile Banking with the mobile app, online, at any. Wells Fargo ATM, or at a Wells Fargo branch. 1 Money in — deposits. 20 | Query: can you open a wells fargo account online 21 | Does the passage answer the query? 22 | Answer: No 23 | 24 | Passage: You can open a Wells Fargo banking account from your home or even online. It is really easy to do, provided you have all of the appropriate documentation. Wells Fargo has so many bank account options that you will be sure to find one that works for you. They offer free checking accounts with free online banking. 
25 | Query: can you open a wells fargo account online 26 | Does the passage answer the query? 27 | Answer: Yes 28 | ''' 29 | 30 | ZERO_SHOT_EXAMPLE = '''Given a passage and a question, predict whether the passage includes an answer to the question by producing either `Yes` or `No`.''' 31 | 32 | 33 | def relevance_generation(query, passage, instruction: str = ZERO_SHOT_EXAMPLE, model='gpt-3.5-turbo'): 34 | prompt = f"{instruction}\nPassage: {passage}\nQuery: {query}\nDoes the passage answer the query?\nAnswer:" 35 | if 'instruct' in model or 'text' in model or 'davinci' in model: 36 | response = client.completions.create( 37 | model=model, 38 | prompt=prompt, 39 | temperature=0, logprobs=5, max_tokens=2 40 | ) 41 | text = response.choices[0].text 42 | token_logprobs = response.choices[0].logprobs.token_logprobs[0] 43 | top_logprobs = response.choices[0].logprobs.top_logprobs[0] 44 | else: 45 | response = client.chat.completions.create( 46 | model=model, 47 | messages=[{"role": "user", "content": prompt}], 48 | temperature=0, max_tokens=2, logprobs=True, top_logprobs=5 49 | ) 50 | text = response.choices[0].message.content 51 | token_logprobs = response.choices[0].logprobs.content[0].logprob 52 | top_logprobs = response.choices[0].logprobs.content[0].top_logprobs 53 | top_logprobs = {word.token: word.logprob for word in top_logprobs} 54 | 55 | if 'Yes' in text: 56 | logprobs = token_logprobs 57 | logprobs = - 1 / logprobs 58 | rel = logprobs 59 | elif 'No' in text: 60 | logprobs = token_logprobs 61 | logprobs = 1 / logprobs 62 | rel = logprobs 63 | else: 64 | if ' Yes' in top_logprobs and ' No' in top_logprobs and top_logprobs[' Yes'] > top_logprobs[' No']: 65 | logprobs = top_logprobs[' Yes'] 66 | logprobs = - 1 / logprobs 67 | rel = logprobs 68 | elif ' Yes' in top_logprobs and ' No' in top_logprobs and top_logprobs[' Yes'] < top_logprobs[' No']: 69 | logprobs = top_logprobs[' No'] 70 | logprobs = 1 / logprobs 71 | rel = logprobs 72 | elif ' Yes' in top_logprobs: 73 | logprobs = top_logprobs[' Yes'] 74 | logprobs = - 1 / logprobs 75 | rel = logprobs 76 | elif ' No' in top_logprobs: 77 | logprobs = top_logprobs[' No'] 78 | logprobs = 1 / logprobs 79 | rel = logprobs 80 | elif 'yes' in text.lower(): 81 | rel = 0 82 | else: 83 | rel = -1000000 84 | return rel 85 | 86 | 87 | def query_generation(query, passage, model='davinci-002'): 88 | prompt = [f"Please write a question based on this passage.\nPassage: {passage}\Question:", 89 | f" {query}"] 90 | 91 | response = client.completions.create( 92 | model=model, 93 | prompt=prompt[0] + prompt[1], 94 | temperature=0, logprobs=0, max_tokens=0, echo=True 95 | ) 96 | # print(response) 97 | out = response.choices[0] 98 | assert prompt[0] + prompt[1] == out.text 99 | i = out.logprobs.text_offset.index(len(prompt[0]) - 1) 100 | if i == 0: 101 | i = i + 1 102 | loss = -sum(out.logprobs.token_logprobs[i:-1]) # ignore the last '.' 103 | avg_loss = loss / (len(out.logprobs.text_offset) - i - 1) # 1 is the last '.' 104 | rel = avg_loss 105 | return rel 106 | 107 | 108 | def main(): 109 | query = 'hello world' 110 | passage1 = '''A "Hello, World!" program is generally a simple computer program which outputs (or displays) to the screen (often the console) a message similar to "Hello, World!" while ignoring any user input. A small piece of code in most general-purpose programming languages, this program is used to illustrate a language's basic syntax. A "Hello, World!" 
program is often the first written by a student of a new programming language,[1] but such a program can also be used as a sanity check to ensure that the computer software intended to compile or run source code is correctly installed, and that its operator understands how to use it.''' 111 | passage2 = '''Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation.[31] Python is dynamically typed and garbage-collected. It supports multiple programming paradigms, including structured (particularly procedural), object-oriented and functional programming. It is often described as a "batteries included" language due to its comprehensive standard library.[32][33] Guido van Rossum began working on Python in the late 1980s as a successor to the ABC programming language and first released it in 1991 as Python 0.9.0.[34] Python 2.0 was released in 2000. Python 3.0, released in 2008, was a major revision not completely backward-compatible with earlier versions.''' 112 | 113 | print(relevance_generation(query, passage1, model='gpt-3.5-turbo')) 114 | print(relevance_generation(query, passage2, model='gpt-3.5-turbo')) 115 | 116 | print(query_generation(query, passage1, model='babbage-002')) 117 | print(query_generation(query, passage2, model='babbage-002')) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /rank_gpt.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tqdm import tqdm 3 | import time 4 | import json 5 | 6 | 7 | class OpenaiClient: 8 | def __init__(self, keys=None, start_id=None, proxy=None): 9 | from openai import OpenAI 10 | import openai 11 | if isinstance(keys, str): 12 | keys = [keys] 13 | if keys is None: 14 | raise "Please provide OpenAI Key." 
15 | 16 | self.key = keys 17 | self.key_id = start_id or 0 18 | self.key_id = self.key_id % len(self.key) 19 | self.api_key = self.key[self.key_id % len(self.key)] 20 | self.client = OpenAI(api_key=self.api_key) 21 | 22 | def chat(self, *args, return_text=False, reduce_length=False, **kwargs): 23 | while True: 24 | try: 25 | completion = self.client.chat.completions.create(*args, **kwargs, timeout=30) 26 | break 27 | except Exception as e: 28 | print(str(e)) 29 | if "This model's maximum context length is" in str(e): 30 | print('reduce_length') 31 | return 'ERROR::reduce_length' 32 | time.sleep(0.1) 33 | if return_text: 34 | completion = completion.choices[0].message.content 35 | return completion 36 | 37 | def text(self, *args, return_text=False, reduce_length=False, **kwargs): 38 | while True: 39 | try: 40 | completion = self.client.completions.create( 41 | *args, **kwargs 42 | ) 43 | break 44 | except Exception as e: 45 | print(e) 46 | if "This model's maximum context length is" in str(e): 47 | print('reduce_length') 48 | return 'ERROR::reduce_length' 49 | time.sleep(0.1) 50 | if return_text: 51 | completion = completion.choices[0].text 52 | return completion 53 | 54 | 55 | class ClaudeClient: 56 | def __init__(self, keys): 57 | from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT 58 | self.anthropic = Anthropic(api_key=keys) 59 | 60 | def chat(self, messages, return_text=True, max_tokens=300, *args, **kwargs): 61 | system = ' '.join([turn['content'] for turn in messages if turn['role'] == 'system']) 62 | messages = [turn for turn in messages if turn['role'] != 'system'] 63 | if len(system) == 0: 64 | system = None 65 | completion = self.anthropic.beta.messages.create(messages=messages, system=system, max_tokens=max_tokens, *args, **kwargs) 66 | if return_text: 67 | completion = completion.content[0].text 68 | return completion 69 | 70 | def text(self, max_tokens=None, return_text=True, *args, **kwargs): 71 | completion = self.anthropic.beta.messages.create(max_tokens_to_sample=max_tokens, *args, **kwargs) 72 | if return_text: 73 | completion = completion.completion 74 | return completion 75 | 76 | 77 | class LitellmClient: 78 | # https://github.com/BerriAI/litellm 79 | def __init__(self, keys=None): 80 | self.api_key = keys 81 | 82 | def chat(self, return_text=True, *args, **kwargs): 83 | from litellm import completion 84 | response = completion(api_key=self.api_key, *args, **kwargs) 85 | if return_text: 86 | response = response.choices[0].message.content 87 | return response 88 | 89 | 90 | def convert_messages_to_prompt(messages): 91 | # convert chat message into a single prompt; used for completion model (eg davinci) 92 | prompt = '' 93 | for turn in messages: 94 | if turn['role'] == 'system': 95 | prompt += f"{turn['content']}\n\n" 96 | elif turn['role'] == 'user': 97 | prompt += f"{turn['content']}\n\n" 98 | else: # 'assistant' 99 | pass 100 | prompt += "The ranking results of the 20 passages (only identifiers) is:" 101 | return prompt 102 | 103 | 104 | def run_retriever(topics, searcher, qrels=None, k=100, qid=None): 105 | ranks = [] 106 | if isinstance(topics, str): 107 | hits = searcher.search(topics, k=k) 108 | ranks.append({'query': topics, 'hits': []}) 109 | rank = 0 110 | for hit in hits: 111 | rank += 1 112 | content = json.loads(searcher.doc(hit.docid).raw()) 113 | if 'title' in content: 114 | content = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text'] 115 | else: 116 | content = content['contents'] 117 | content = ' '.join(content.split()) 118 | 
ranks[-1]['hits'].append({ 119 | 'content': content, 120 | 'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score}) 121 | return ranks[-1] 122 | 123 | for qid in tqdm(topics): 124 | if qid in qrels: 125 | query = topics[qid]['title'] 126 | ranks.append({'query': query, 'hits': []}) 127 | hits = searcher.search(query, k=k) 128 | rank = 0 129 | for hit in hits: 130 | rank += 1 131 | content = json.loads(searcher.doc(hit.docid).raw()) 132 | if 'title' in content: 133 | content = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text'] 134 | else: 135 | content = content['contents'] 136 | content = ' '.join(content.split()) 137 | ranks[-1]['hits'].append({ 138 | 'content': content, 139 | 'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score}) 140 | return ranks 141 | 142 | 143 | def get_prefix_prompt(query, num): 144 | return [{'role': 'system', 145 | 'content': "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query."}, 146 | {'role': 'user', 147 | 'content': f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}."}, 148 | {'role': 'assistant', 'content': 'Okay, please provide the passages.'}] 149 | 150 | 151 | def get_post_prompt(query, num): 152 | return f"Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain." 153 | 154 | 155 | def create_permutation_instruction(item=None, rank_start=0, rank_end=100, model_name='gpt-3.5-turbo'): 156 | query = item['query'] 157 | num = len(item['hits'][rank_start: rank_end]) 158 | 159 | max_length = 300 160 | 161 | messages = get_prefix_prompt(query, num) 162 | rank = 0 163 | for hit in item['hits'][rank_start: rank_end]: 164 | rank += 1 165 | content = hit['content'] 166 | content = content.replace('Title: Content: ', '') 167 | content = content.strip() 168 | # For Japanese should cut by character: content = content[:int(max_length)] 169 | content = ' '.join(content.split()[:int(max_length)]) 170 | messages.append({'role': 'user', 'content': f"[{rank}] {content}"}) 171 | messages.append({'role': 'assistant', 'content': f'Received passage [{rank}].'}) 172 | messages.append({'role': 'user', 'content': get_post_prompt(query, num)}) 173 | 174 | return messages 175 | 176 | 177 | def run_llm(messages, api_key=None, model_name="gpt-3.5-turbo"): 178 | if 'gpt' in model_name: 179 | Client = OpenaiClient 180 | elif 'claude' in model_name: 181 | Client = ClaudeClient 182 | else: 183 | Client = LitellmClient 184 | 185 | agent = Client(api_key) 186 | response = agent.chat(model=model_name, messages=messages, temperature=0, return_text=True) 187 | return response 188 | 189 | 190 | def clean_response(response: str): 191 | new_response = '' 192 | for c in response: 193 | if not c.isdigit(): 194 | new_response += ' ' 195 | else: 196 | new_response += c 197 | new_response = new_response.strip() 198 | return new_response 199 | 200 | 201 | def remove_duplicate(response): 202 | new_response = [] 203 | for c in response: 204 | if c not in new_response: 205 | new_response.append(c) 206 | return new_response 207 | 208 | 209 | def receive_permutation(item, permutation, rank_start=0, rank_end=100): 210 | response = 
clean_response(permutation) 211 | response = [int(x) - 1 for x in response.split()] 212 | response = remove_duplicate(response) 213 | cut_range = copy.deepcopy(item['hits'][rank_start: rank_end]) 214 | original_rank = [tt for tt in range(len(cut_range))] 215 | response = [ss for ss in response if ss in original_rank] 216 | response = response + [tt for tt in original_rank if tt not in response] 217 | for j, x in enumerate(response): 218 | item['hits'][j + rank_start] = copy.deepcopy(cut_range[x]) 219 | if 'rank' in item['hits'][j + rank_start]: 220 | item['hits'][j + rank_start]['rank'] = cut_range[j]['rank'] 221 | if 'score' in item['hits'][j + rank_start]: 222 | item['hits'][j + rank_start]['score'] = cut_range[j]['score'] 223 | return item 224 | 225 | 226 | def permutation_pipeline(item=None, rank_start=0, rank_end=100, model_name='gpt-3.5-turbo', api_key=None): 227 | messages = create_permutation_instruction(item=item, rank_start=rank_start, rank_end=rank_end, 228 | model_name=model_name) 229 | permutation = run_llm(messages, api_key=api_key, model_name=model_name) 230 | item = receive_permutation(item, permutation, rank_start=rank_start, rank_end=rank_end) 231 | return item 232 | 233 | 234 | def sliding_windows(item=None, rank_start=0, rank_end=100, window_size=20, step=10, model_name='gpt-3.5-turbo', 235 | api_key=None): 236 | item = copy.deepcopy(item) 237 | end_pos = rank_end 238 | start_pos = rank_end - window_size 239 | while start_pos >= rank_start: 240 | start_pos = max(start_pos, rank_start) 241 | item = permutation_pipeline(item, start_pos, end_pos, model_name=model_name, api_key=api_key) 242 | end_pos = end_pos - step 243 | start_pos = start_pos - step 244 | return item 245 | 246 | 247 | def write_eval_file(rank_results, file): 248 | with open(file, 'w') as f: 249 | for i in range(len(rank_results)): 250 | rank = 1 251 | hits = rank_results[i]['hits'] 252 | for hit in hits: 253 | f.write(f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n") 254 | rank += 1 255 | return True 256 | 257 | 258 | def main(): 259 | from pyserini.search import LuceneSearcher 260 | from pyserini.search import get_topics, get_qrels 261 | import tempfile 262 | 263 | api_key = None # Your openai key 264 | 265 | searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage') 266 | topics = get_topics('dl19-passage') 267 | qrels = get_qrels('dl19-passage') 268 | 269 | rank_results = run_retriever(topics, searcher, qrels, k=100) 270 | 271 | new_results = [] 272 | for item in tqdm(rank_results): 273 | new_item = permutation_pipeline(item, rank_start=0, rank_end=20, model_name='gpt-3.5-turbo', 274 | api_key=api_key) 275 | new_results.append(new_item) 276 | 277 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 278 | from trec_eval import EvalFunction 279 | 280 | EvalFunction.write_file(new_results, temp_file) 281 | EvalFunction.main('dl19-passage', temp_file) 282 | 283 | 284 | if __name__ == '__main__': 285 | main() 286 | -------------------------------------------------------------------------------- /rank_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from torch.nn import BCELoss, BCEWithLogitsLoss 4 | from itertools import product 5 | 6 | 7 | class RankLoss: 8 | 9 | @staticmethod 10 | def softmax_ce_loss(y_pred, *args, **kwargs): 11 | return F.cross_entropy(y_pred, torch.zeros((y_pred.size(0),)).long().cuda()) 12 | 13 | @staticmethod 14 | def pointwise_rmse(y_pred,
y_true=None): 15 | if y_true is None: 16 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 17 | y_true[:, 0] = 1 18 | errors = (y_true - y_pred) 19 | squared_errors = errors ** 2 20 | valid_mask = (y_true != -100).float() 21 | mean_squared_errors = torch.sum(squared_errors, dim=1) / torch.sum(valid_mask, dim=1) 22 | rmses = torch.sqrt(mean_squared_errors) 23 | return torch.mean(rmses) 24 | 25 | @staticmethod 26 | def pointwise_bce(y_pred, y_true=None): 27 | if y_true is None: 28 | y_true = torch.zeros_like(y_pred).float().to(y_pred.device) 29 | y_true[:, 0] = 1 30 | loss = F.binary_cross_entropy(torch.sigmoid(y_pred), y_true) 31 | return loss 32 | 33 | @staticmethod 34 | def list_net(y_pred, y_true=None, padded_value_indicator=-100, eps=1e-10): 35 | """ 36 | ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach". 37 | :param y_pred: predictions from the model, shape [batch_size, slate_length] 38 | :param y_true: ground truth labels, shape [batch_size, slate_length] 39 | :param eps: epsilon value, used for numerical stability 40 | :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 41 | :return: loss value, a torch.Tensor 42 | """ 43 | if y_true is None: 44 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 45 | y_true[:, 0] = 1 46 | 47 | preds_smax = F.softmax(y_pred, dim=1) 48 | true_smax = F.softmax(y_true, dim=1) 49 | 50 | preds_smax = preds_smax + eps 51 | preds_log = torch.log(preds_smax) 52 | 53 | return torch.mean(-torch.sum(true_smax * preds_log, dim=1)) 54 | 55 | @staticmethod 56 | def rank_net(y_pred, y_true=None, padded_value_indicator=-100, weight_by_diff=False, 57 | weight_by_diff_powed=False): 58 | """ 59 | RankNet loss introduced in "Learning to Rank using Gradient Descent". 60 | :param y_pred: predictions from the model, shape [batch_size, slate_length] 61 | :param y_true: ground truth labels, shape [batch_size, slate_length] 62 | :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences. 63 | :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences. 
64 | :return: loss value, a torch.Tensor 65 | """ 66 | if y_true is None: 67 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 68 | y_true[:, 0] = 1 69 | 70 | # here we generate every pair of indices from the range of document length in the batch 71 | document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2)) 72 | 73 | pairs_true = y_true[:, document_pairs_candidates] 74 | selected_pred = y_pred[:, document_pairs_candidates] 75 | 76 | # here we calculate the relative true relevance of every candidate pair 77 | true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1] 78 | pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1] 79 | 80 | # here we filter just the pairs that are 'positive' and did not involve a padded instance 81 | # we can do that since in the candidate pairs we had symetric pairs so we can stick with 82 | # positive ones for a simpler loss function formulation 83 | the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs)) 84 | 85 | pred_diffs = pred_diffs[the_mask] 86 | 87 | weight = None 88 | if weight_by_diff: 89 | abs_diff = torch.abs(true_diffs) 90 | weight = abs_diff[the_mask] 91 | elif weight_by_diff_powed: 92 | true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2) 93 | abs_diff = torch.abs(true_pow_diffs) 94 | weight = abs_diff[the_mask] 95 | 96 | # here we 'binarize' true relevancy diffs since for a pairwise loss we just need to know 97 | # whether one document is better than the other and not about the actual difference in 98 | # their relevancy levels 99 | true_diffs = (true_diffs > 0).type(torch.float32) 100 | true_diffs = true_diffs[the_mask] 101 | 102 | return BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs) 103 | 104 | @staticmethod 105 | def lambda_loss(y_pred, y_true=None, eps=1e-10, padded_value_indicator=-100, weighing_scheme=None, k=None, 106 | sigma=1., mu=10., reduction="mean", reduction_log="binary"): 107 | """ 108 | LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization". 109 | Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet. 110 | :param y_pred: predictions from the model, shape [batch_size, slate_length] 111 | :param y_true: ground truth labels, shape [batch_size, slate_length] 112 | :param eps: epsilon value, used for numerical stability 113 | :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 114 | :param weighing_scheme: a string corresponding to a name of one of the weighing schemes 115 | :param k: rank at which the loss is truncated 116 | :param sigma: score difference weight used in the sigmoid function 117 | :param mu: optional weight used in NDCGLoss2++ weighing scheme 118 | :param reduction: losses reduction method, could be either a sum or a mean 119 | :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural 120 | :return: loss value, a torch.Tensor 121 | """ 122 | if y_true is None: 123 | y_true = torch.zeros_like(y_pred).to(y_pred.device) 124 | y_true[:, 0] = 1 125 | 126 | device = y_pred.device 127 | 128 | # Here we sort the true and predicted relevancy scores. 129 | y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1) 130 | y_true_sorted, _ = y_true.sort(descending=True, dim=-1) 131 | 132 | # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element. 
133 | true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred) 134 | true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :] 135 | padded_pairs_mask = torch.isfinite(true_diffs) 136 | 137 | if weighing_scheme != "ndcgLoss1_scheme": 138 | padded_pairs_mask = padded_pairs_mask & (true_diffs > 0) 139 | 140 | ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device) 141 | ndcg_at_k_mask[:k, :k] = 1 142 | 143 | # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs) 144 | true_sorted_by_preds.clamp_(min=0.) 145 | y_true_sorted.clamp_(min=0.) 146 | 147 | # Here we find the gains, discounts and ideal DCGs per slate. 148 | pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device) 149 | D = torch.log2(1. + pos_idxs.float())[None, :] 150 | maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps) 151 | G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None] 152 | 153 | # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0) 154 | if weighing_scheme is None: 155 | weights = 1. 156 | else: 157 | weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds) # type: ignore 158 | 159 | # We are clamping the array entries to maintain correct backprop (log(0) and division by 0) 160 | scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8) 161 | scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.) 162 | weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps) 163 | if reduction_log == "natural": 164 | losses = torch.log(weighted_probas) 165 | elif reduction_log == "binary": 166 | losses = torch.log2(weighted_probas) 167 | else: 168 | raise ValueError("Reduction logarithm base can be either natural or binary") 169 | 170 | if reduction == "sum": 171 | loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask]) 172 | elif reduction == "mean": 173 | loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask]) 174 | else: 175 | raise ValueError("Reduction method can be either sum or mean") 176 | 177 | return loss 178 | 179 | 180 | def ndcgLoss1_scheme(G, D, *args): 181 | return (G / D)[:, :, None] 182 | 183 | 184 | def ndcgLoss2_scheme(G, D, *args): 185 | pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device) 186 | delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :]) 187 | deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) - torch.pow(torch.abs(D[0, delta_idxs]), -1.)) 188 | deltas.diagonal().zero_() 189 | 190 | return deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :]) 191 | 192 | 193 | def lambdaRank_scheme(G, D, *args): 194 | return torch.abs(torch.pow(D[:, :, None], -1.) - torch.pow(D[:, None, :], -1.)) * torch.abs( 195 | G[:, :, None] - G[:, None, :]) 196 | 197 | 198 | def ndcgLoss2PP_scheme(G, D, *args): 199 | return args[0] * ndcgLoss2_scheme(G, D) + lambdaRank_scheme(G, D) 200 | 201 | 202 | def rankNet_scheme(G, D, *args): 203 | return 1. 
204 | 205 | 206 | def rankNetWeightedByGTDiff_scheme(G, D, *args): 207 | return torch.abs(args[1][:, :, None] - args[1][:, None, :]) 208 | 209 | 210 | def rankNetWeightedByGTDiffPowed_scheme(G, D, *args): 211 | return torch.abs(torch.pow(args[1][:, :, None], 2) - torch.pow(args[1][:, None, :], 2)) 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | openai 3 | tiktoken 4 | pyserini 5 | 6 | -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | THE_INDEX = { 2 | 'dl19': 'msmarco-v1-passage', 3 | 'dl20': 'msmarco-v1-passage', 4 | 'covid': 'beir-v1.0.0-trec-covid.flat', 5 | 'arguana': 'beir-v1.0.0-arguana.flat', 6 | 'touche': 'beir-v1.0.0-webis-touche2020.flat', 7 | 'news': 'beir-v1.0.0-trec-news.flat', 8 | 'scifact': 'beir-v1.0.0-scifact.flat', 9 | 'fiqa': 'beir-v1.0.0-fiqa.flat', 10 | 'scidocs': 'beir-v1.0.0-scidocs.flat', 11 | 'nfc': 'beir-v1.0.0-nfcorpus.flat', 12 | 'quora': 'beir-v1.0.0-quora.flat', 13 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity.flat', 14 | 'fever': 'beir-v1.0.0-fever-flat', 15 | 'robust04': 'beir-v1.0.0-robust04.flat', 16 | 'signal': 'beir-v1.0.0-signal1m.flat', 17 | 18 | 'mrtydi-ar': 'mrtydi-v1.1-arabic', 19 | 'mrtydi-bn': 'mrtydi-v1.1-bengali', 20 | 'mrtydi-fi': 'mrtydi-v1.1-finnish', 21 | 'mrtydi-id': 'mrtydi-v1.1-indonesian', 22 | 'mrtydi-ja': 'mrtydi-v1.1-japanese', 23 | 'mrtydi-ko': 'mrtydi-v1.1-korean', 24 | 'mrtydi-ru': 'mrtydi-v1.1-russian', 25 | 'mrtydi-sw': 'mrtydi-v1.1-swahili', 26 | 'mrtydi-te': 'mrtydi-v1.1-telugu', 27 | 'mrtydi-th': 'mrtydi-v1.1-thai', 28 | } 29 | 30 | THE_TOPICS = { 31 | 'dl19': 'dl19-passage', 32 | 'dl20': 'dl20-passage', 33 | 'covid': 'beir-v1.0.0-trec-covid-test', 34 | 'arguana': 'beir-v1.0.0-arguana-test', 35 | 'touche': 'beir-v1.0.0-webis-touche2020-test', 36 | 'news': 'beir-v1.0.0-trec-news-test', 37 | 'scifact': 'beir-v1.0.0-scifact-test', 38 | 'fiqa': 'beir-v1.0.0-fiqa-test', 39 | 'scidocs': 'beir-v1.0.0-scidocs-test', 40 | 'nfc': 'beir-v1.0.0-nfcorpus-test', 41 | 'quora': 'beir-v1.0.0-quora-test', 42 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity-test', 43 | 'fever': 'beir-v1.0.0-fever-test', 44 | 'robust04': 'beir-v1.0.0-robust04-test', 45 | 'signal': 'beir-v1.0.0-signal1m-test', 46 | 47 | 'mrtydi-ar': 'mrtydi-v1.1-arabic-test', 48 | 'mrtydi-bn': 'mrtydi-v1.1-bengali-test', 49 | 'mrtydi-fi': 'mrtydi-v1.1-finnish-test', 50 | 'mrtydi-id': 'mrtydi-v1.1-indonesian-test', 51 | 'mrtydi-ja': 'mrtydi-v1.1-japanese-test', 52 | 'mrtydi-ko': 'mrtydi-v1.1-korean-test', 53 | 'mrtydi-ru': 'mrtydi-v1.1-russian-test', 54 | 'mrtydi-sw': 'mrtydi-v1.1-swahili-test', 55 | 'mrtydi-te': 'mrtydi-v1.1-telugu-test', 56 | 'mrtydi-th': 'mrtydi-v1.1-thai-test', 57 | 58 | } 59 | 60 | from rank_gpt import run_retriever, sliding_windows, write_eval_file 61 | from pyserini.search import LuceneSearcher, get_topics, get_qrels 62 | from tqdm import tqdm 63 | import tempfile 64 | import os 65 | import json 66 | import shutil 67 | 68 | openai_key = os.environ.get("OPENAI_API_KEY", None) 69 | 70 | for data in ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04']: 71 | print('#' * 20) 72 | print(f'Evaluation on {data}') 73 | print('#' * 20) 74 | 75 | 76 | # Retrieve passages using pyserini BM25. 
77 | try: 78 | searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data]) 79 | topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20') 80 | qrels = get_qrels(THE_TOPICS[data]) 81 | rank_results = run_retriever(topics, searcher, qrels, k=100) 82 | except: 83 | print(f'Failed to retrieve passages for {data}') 84 | continue 85 | 86 | # Run sliding window permutation generation 87 | new_results = [] 88 | for item in tqdm(rank_results): 89 | new_item = sliding_windows(item, rank_start=0, rank_end=100, window_size=20, step=10, 90 | model_name='gpt-3.5-turbo', api_key=openai_key) 91 | new_results.append(new_item) 92 | 93 | # Evaluate nDCG@10 94 | from trec_eval import EvalFunction 95 | 96 | # Create an empty text file to write results, and pass the name to eval 97 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 98 | EvalFunction.write_file(new_results, temp_file) 99 | EvalFunction.main(THE_TOPICS[data], temp_file) 100 | 101 | 102 | for data in ['mrtydi-ar', 'mrtydi-bn', 'mrtydi-fi', 'mrtydi-id', 'mrtydi-ja', 'mrtydi-ko', 'mrtydi-ru', 'mrtydi-sw', 'mrtydi-te', 'mrtydi-th']: 103 | print('#' * 20) 104 | print(f'Evaluation on {data}') 105 | print('#' * 20) 106 | 107 | # Retrieve passages using pyserini BM25. 108 | try: 109 | searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data]) 110 | topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20') 111 | qrels = get_qrels(THE_TOPICS[data]) 112 | rank_results = run_retriever(topics, searcher, qrels, k=100) 113 | rank_results = rank_results[:100] 114 | 115 | except: 116 | print(f'Failed to retrieve passages for {data}') 117 | continue 118 | 119 | # Run sliding window permutation generation 120 | new_results = [] 121 | for item in tqdm(rank_results): 122 | new_item = sliding_windows(item, rank_start=0, rank_end=100, window_size=20, step=10, 123 | model_name='gpt-3.5-turbo', api_key=openai_key) 124 | new_results.append(new_item) 125 | 126 | # Evaluate nDCG@10 127 | from trec_eval import EvalFunction 128 | 129 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 130 | EvalFunction.write_file(new_results, temp_file) 131 | EvalFunction.main(THE_TOPICS[data], temp_file) 132 | -------------------------------------------------------------------------------- /specialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torch.utils.data import Dataset 3 | from accelerate import Accelerator 4 | from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer, AdamW 5 | import torch 6 | from tqdm import tqdm 7 | from rank_loss import RankLoss 8 | import numpy as np 9 | import os 10 | import argparse 11 | import tempfile 12 | import copy 13 | 14 | 15 | class RerankData(Dataset): 16 | def __init__(self, data, tokenizer, neg_num=20, label=True): 17 | self.data = data 18 | self.tokenizer = tokenizer 19 | self.neg_num = neg_num 20 | self.label = label 21 | if not label: 22 | self.neg_num += 1 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, item): 28 | item = self.data[item] 29 | query = item['query'] 30 | 31 | if self.label: 32 | pos = [str(item['positive_passages'][0]['text'])] 33 | pos_id = [psg['docid'] for psg in item['positive_passages']] 34 | neg = [str(psg['text']) for psg in item['retrieved_passages'] if psg['docid'] not in pos_id][:self.neg_num] 35 | else: 36 | pos = [] 37 | neg = [str(psg['text']) for psg in item['retrieved_passages']][:self.neg_num] 38 | neg = neg + [''] * (self.neg_num 
/specialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | from torch.utils.data import Dataset
3 | from accelerate import Accelerator
4 | from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer, AdamW
5 | import torch
6 | from tqdm import tqdm
7 | from rank_loss import RankLoss
8 | import numpy as np
9 | import os
10 | import argparse
11 | import tempfile
12 | import copy
13 | 
14 | 
15 | class RerankData(Dataset):
16 |     def __init__(self, data, tokenizer, neg_num=20, label=True):
17 |         self.data = data
18 |         self.tokenizer = tokenizer
19 |         self.neg_num = neg_num
20 |         self.label = label
21 |         if not label:
22 |             self.neg_num += 1
23 | 
24 |     def __len__(self):
25 |         return len(self.data)
26 | 
27 |     def __getitem__(self, item):
28 |         item = self.data[item]
29 |         query = item['query']
30 | 
31 |         if self.label:
32 |             pos = [str(item['positive_passages'][0]['text'])]
33 |             pos_id = [psg['docid'] for psg in item['positive_passages']]
34 |             neg = [str(psg['text']) for psg in item['retrieved_passages'] if psg['docid'] not in pos_id][:self.neg_num]
35 |         else:
36 |             pos = []
37 |             neg = [str(psg['text']) for psg in item['retrieved_passages']][:self.neg_num]
38 |         neg = neg + [''] * (self.neg_num - len(neg))
39 |         passages = pos + neg
40 |         return [query] * len(passages), passages
41 | 
42 |     def collate_fn(self, data):
43 |         query, passages = zip(*data)
44 |         query = sum(query, [])
45 |         passages = sum(passages, [])
46 |         features = self.tokenizer(query, passages, padding=True, truncation=True, return_tensors="pt",
47 |                                   max_length=500)
48 |         return features
49 | 
50 | 
51 | def receive_response(data, responses):
52 |     def clean_response(response: str):
53 |         new_response = ''
54 |         for c in response:
55 |             if not c.isdigit():
56 |                 new_response += ' '
57 |             else:
58 |                 new_response += c
59 |         new_response = new_response.strip()
60 |         return new_response
61 | 
62 |     def remove_duplicate(response):
63 |         new_response = []
64 |         for c in response:
65 |             if c not in new_response:
66 |                 new_response.append(c)
67 |         return new_response
68 | 
69 |     new_data = []
70 |     for item, response in zip(data, responses):
71 |         response = clean_response(response)
72 |         response = [int(x) - 1 for x in response.split()]
73 |         response = remove_duplicate(response)
74 |         passages = item['retrieved_passages']
75 |         original_rank = [tt for tt in range(len(passages))]
76 |         response = [ss for ss in response if ss in original_rank]
77 |         response = response + [tt for tt in original_rank if tt not in response]
78 |         new_passages = [passages[ii] for ii in response]
79 |         new_data.append({'query': item['query'],
80 |                          'positive_passages': item['positive_passages'],
81 |                          'retrieved_passages': new_passages})
82 |     return new_data
83 | 
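# Note on the expected input: `responses` are the teacher permutations loaded from the
# --permutation file, typically strings such as "[2] > [1] > [4] > [3]" (or "2 > 1 > 4 > 3").
# clean_response() keeps only the digits, the 1-based indices are shifted to 0-based and
# de-duplicated, and any passages the permutation does not mention are appended after the
# mentioned ones in their original retrieval order.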
84 | 
85 | def parse_args():
86 |     parser = argparse.ArgumentParser()
87 |     parser.add_argument('--model', type=str, default='microsoft/deberta-v3-base')
88 |     parser.add_argument('--loss', type=str, default='rank_net')
89 |     parser.add_argument('--data', type=str, default='data/marco-train-10k.jsonl')
90 |     parser.add_argument('--save_path', type=str, default='out/deberta-rank_net')
91 |     parser.add_argument('--permutation', type=str, default='marco-train-10k-gpt3.5.json')
92 |     parser.add_argument('--do_train', type=bool, default=True)
93 |     parser.add_argument('--do_eval', type=bool, default=True)
94 |     args = parser.parse_args()
95 | 
96 |     print('====Input Arguments====')
97 |     print(json.dumps(vars(args), indent=2, sort_keys=False))
98 |     return args
99 | 
100 | 
101 | def train(args):
102 |     model_name = args.model
103 |     loss_type = args.loss
104 |     data_path = args.data
105 |     save_path = args.save_path
106 |     permutation = args.permutation
107 | 
108 |     accelerator = Accelerator(gradient_accumulation_steps=8)
109 |     neg_num = 19
110 | 
111 |     # Create cross encoder model
112 |     config = AutoConfig.from_pretrained(model_name)
113 |     config.num_labels = 1
114 |     model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
115 |     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
116 | 
117 |     # Load data and permutation
118 |     data = [json.loads(line) for line in open(data_path)]
119 |     response = json.load(open(permutation))
120 |     data = receive_response(data, response)
121 |     dataset = RerankData(data, tokenizer, neg_num=neg_num, label=False)
122 | 
123 |     # Prepare data loader
124 |     data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn,
125 |                                               batch_size=1, shuffle=True, num_workers=0)
126 |     optimizer = AdamW(model.parameters(), 5e-5)
127 |     model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
128 | 
129 |     # Prepare loss function
130 |     loss_function = getattr(RankLoss, loss_type)
131 | 
132 |     # Train for 3 epochs
133 |     for epoch in range(3):
134 |         accelerator.print(f'Training {save_path} {epoch}')
135 |         accelerator.wait_for_everyone()
136 |         model.train()
137 |         tk0 = tqdm(data_loader, total=len(data_loader))
138 |         loss_report = []
139 |         for batch in tk0:
140 |             with accelerator.accumulate(model):
141 |                 out = model(**batch)
142 |                 logits = out.logits
143 |                 logits = logits.view(-1, neg_num + 1)
144 | 
145 |                 y_true = torch.tensor([[1 / (i + 1) for i in range(logits.size(1))]] * logits.size(0)).cuda()
146 |                 loss = loss_function(logits, y_true)
147 | 
148 |                 accelerator.backward(loss)
149 |                 accelerator.clip_grad_norm_(model.parameters(), 1.)
150 |                 optimizer.step()
151 |                 optimizer.zero_grad()
152 |                 loss_report.append(accelerator.gather(loss).mean().item())
153 |                 tk0.set_postfix(loss=sum(loss_report) / len(loss_report))
154 |         accelerator.wait_for_everyone()
155 | 
156 |         # Save model
157 |         unwrap_model = accelerator.unwrap_model(model)
158 |         os.makedirs(save_path, exist_ok=True)
159 |         unwrap_model.save_pretrained(save_path)
160 | 
161 |     return model, tokenizer
162 | 
163 | 
164 | def eval_on_benchmark(args, model=None, tokenizer=None):
165 |     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
166 |     from rank_gpt import run_retriever, receive_permutation, write_eval_file
167 |     from trec_eval import EvalFunction
168 |     from pyserini.search import LuceneSearcher, get_topics, get_qrels
169 | 
170 |     THE_INDEX = {
171 |         'dl19': 'msmarco-v1-passage',
172 |         'dl20': 'msmarco-v1-passage',
173 |         'covid': 'beir-v1.0.0-trec-covid.flat',
174 |         'arguana': 'beir-v1.0.0-arguana.flat',
175 |         'touche': 'beir-v1.0.0-webis-touche2020.flat',
176 |         'news': 'beir-v1.0.0-trec-news.flat',
177 |         'scifact': 'beir-v1.0.0-scifact.flat',
178 |         'fiqa': 'beir-v1.0.0-fiqa.flat',
179 |         'scidocs': 'beir-v1.0.0-scidocs.flat',
180 |         'nfc': 'beir-v1.0.0-nfcorpus.flat',
181 |         'quora': 'beir-v1.0.0-quora.flat',
182 |         'dbpedia': 'beir-v1.0.0-dbpedia-entity.flat',
183 |         'fever': 'beir-v1.0.0-fever.flat',
184 |         'robust04': 'beir-v1.0.0-robust04.flat',
185 |         'signal': 'beir-v1.0.0-signal1m.flat',
186 |     }
187 |     THE_TOPICS = {
188 |         'dl19': 'dl19-passage',
189 |         'dl20': 'dl20-passage',
190 |         'covid': 'beir-v1.0.0-trec-covid-test',
191 |         'arguana': 'beir-v1.0.0-arguana-test',
192 |         'touche': 'beir-v1.0.0-webis-touche2020-test',
193 |         'news': 'beir-v1.0.0-trec-news-test',
194 |         'scifact': 'beir-v1.0.0-scifact-test',
195 |         'fiqa': 'beir-v1.0.0-fiqa-test',
196 |         'scidocs': 'beir-v1.0.0-scidocs-test',
197 |         'nfc': 'beir-v1.0.0-nfcorpus-test',
198 |         'quora': 'beir-v1.0.0-quora-test',
199 |         'dbpedia': 'beir-v1.0.0-dbpedia-entity-test',
200 |         'fever': 'beir-v1.0.0-fever-test',
201 |         'robust04': 'beir-v1.0.0-robust04-test',
202 |         'signal': 'beir-v1.0.0-signal1m-test',
203 |     }
204 | 
205 |     if model is None or tokenizer is None:
206 |         tokenizer = AutoTokenizer.from_pretrained(args.model)
207 |         model = AutoModelForSequenceClassification.from_pretrained(args.model)
208 |     model = model.cuda()
209 | 
210 |     model.eval()
211 | 
212 |     for data in ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04']:
213 |         print()
214 |         print('#' * 20)
215 |         print(f'Now eval [{data}]')
216 |         print('#' * 20)
217 | 
218 |         searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data])
219 |         topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20')
220 |         qrels = get_qrels(THE_TOPICS[data])
221 |         rank_results = run_retriever(topics, searcher, qrels, k=100)
222 | 
223 |         reranked_data = []
224 |         for item in tqdm(rank_results):
225 |             q = item['query']
226 |             passages = [psg['content'] for i, psg in enumerate(item['hits'])][:100]
227 |             if len(passages) == 0:
228 |                 reranked_data.append(item)
229 |                 continue
230 |             features = tokenizer([q] * len(passages), passages, padding=True, truncation=True, return_tensors="pt",
231 |                                  max_length=500)
232 |             features = {k: v.cuda() for k, v in features.items()}
233 |             with torch.no_grad():
234 |                 scores = model(**features).logits
235 |             normalized_scores = [float(score[0]) for score in scores]
236 |             ranked = np.argsort(normalized_scores)[::-1]
237 |             response = ' > '.join([str(ss + 1) for ss in ranked])
238 |             reranked_data.append(receive_permutation(item, response, rank_start=0, rank_end=100))
239 | 
240 |         temp_file = tempfile.NamedTemporaryFile(delete=False).name
241 |         EvalFunction.write_file(reranked_data, temp_file)
242 |         EvalFunction.main(THE_TOPICS[data], temp_file)
243 | 
244 | 
245 | if __name__ == '__main__':
246 |     args = parse_args()
247 |     model, tokenizer = None, None
248 |     if args.do_train:
249 |         model, tokenizer = train(args)
250 |     if args.do_eval:
251 |         eval_on_benchmark(args, model, tokenizer)
252 | 
--------------------------------------------------------------------------------
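For reference, the specialized cross-encoder saved by `train()` can also be applied outside the benchmark loop. A minimal sketch, mirroring the scoring logic in `eval_on_benchmark` and assuming a checkpoint at the default `--save_path` (`out/deberta-rank_net`), a CUDA device, and two toy passages:

```python
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# train() saves only the model weights, so the tokenizer is loaded from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')
model = AutoModelForSequenceClassification.from_pretrained('out/deberta-rank_net').cuda().eval()

query = 'how long is the life cycle of a flea'
passages = ['The life cycle of a flea can last anywhere from 20 days to an entire year.',
            'Fleas are wingless insects that feed on the blood of mammals and birds.']

features = tokenizer([query] * len(passages), passages, padding=True, truncation=True,
                     return_tensors='pt', max_length=500)
features = {k: v.cuda() for k, v in features.items()}
with torch.no_grad():
    scores = model(**features).logits[:, 0]  # one relevance logit per (query, passage) pair
ranking = np.argsort(scores.cpu().numpy())[::-1]  # highest score first
print(' > '.join(str(i + 1) for i in ranking))  # e.g. "1 > 2"
```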
/trec_eval.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import tempfile
3 | import os
4 | import copy
5 | from typing import Dict, Tuple
6 | import pytrec_eval
7 | 
8 | 
9 | def trec_eval(qrels: Dict[str, Dict[str, int]],
10 |               results: Dict[str, Dict[str, float]],
11 |               k_values: Tuple[int] = (10, 50, 100, 200, 1000)) -> Dict[str, float]:
12 |     ndcg, _map, recall = {}, {}, {}
13 | 
14 |     for k in k_values:
15 |         ndcg[f"NDCG@{k}"] = 0.0
16 |         _map[f"MAP@{k}"] = 0.0
17 |         recall[f"Recall@{k}"] = 0.0
18 | 
19 |     map_string = "map_cut." + ",".join([str(k) for k in k_values])
20 |     ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
21 |     recall_string = "recall." + ",".join([str(k) for k in k_values])
22 | 
23 |     evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string})
24 |     scores = evaluator.evaluate(results)
25 | 
26 |     for query_id in scores:
27 |         for k in k_values:
28 |             ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
29 |             _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
30 |             recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
31 | 
32 |     def _normalize(m: dict) -> dict:
33 |         return {k: round(v / len(scores), 5) for k, v in m.items()}
34 | 
35 |     ndcg = _normalize(ndcg)
36 |     _map = _normalize(_map)
37 |     recall = _normalize(recall)
38 | 
39 |     all_metrics = {}
40 |     for mt in [ndcg, _map, recall]:
41 |         all_metrics.update(mt)
42 | 
43 |     return all_metrics
44 | 
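# The returned dict is flat, e.g. {'NDCG@1': ..., 'NDCG@5': ..., 'NDCG@10': ...,
# 'MAP@1': ..., 'MAP@5': ..., 'MAP@10': ..., 'Recall@1': ..., 'Recall@5': ..., 'Recall@10': ...}
# for k_values=(1, 5, 10); every value is averaged over the evaluated queries and rounded
# to 5 decimal places.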
45 | 
46 | def get_qrels_file(name):
47 |     THE_TOPICS = {
48 |         'dl19': 'dl19-passage',
49 |         'dl20': 'dl20-passage',
50 |         'covid': 'beir-v1.0.0-trec-covid-test',
51 |         'arguana': 'beir-v1.0.0-arguana-test',
52 |         'touche': 'beir-v1.0.0-webis-touche2020-test',
53 |         'news': 'beir-v1.0.0-trec-news-test',
54 |         'scifact': 'beir-v1.0.0-scifact-test',
55 |         'fiqa': 'beir-v1.0.0-fiqa-test',
56 |         'scidocs': 'beir-v1.0.0-scidocs-test',
57 |         'nfc': 'beir-v1.0.0-nfcorpus-test',
58 |         'quora': 'beir-v1.0.0-quora-test',
59 |         'dbpedia': 'beir-v1.0.0-dbpedia-entity-test',
60 |         'fever': 'beir-v1.0.0-fever-test',
61 |         'robust04': 'beir-v1.0.0-robust04-test',
62 |         'signal': 'beir-v1.0.0-signal1m-test',
63 |     }
64 |     name = THE_TOPICS.get(name, '')
65 |     name = name.replace('-test', '.test')
66 |     name = 'data/label_file/qrels.' + name + '.txt'  # try to use cache
67 |     if not os.path.exists(name):
68 |         from pyserini.search import get_qrels_file
69 |         return get_qrels_file(name)  # download from pyserini
70 |     return name
71 | 
72 | 
73 | def remove_duplicate(response):
74 |     new_response = []
75 |     for c in response:
76 |         if c not in new_response:
77 |             new_response.append(c)
78 |         else:
79 |             print('duplicate')
80 |     return new_response
81 | 
82 | 
83 | def clean_response(response: str):
84 |     new_response = ''
85 |     for c in response:
86 |         if not c.isdigit():
87 |             new_response += ' '
88 |         else:
89 |             try:
90 |                 new_response += str(int(c))
91 |             except:
92 |                 new_response += ' '
93 |     new_response = new_response.strip()
94 |     return new_response
95 | 
96 | 
97 | class EvalFunction:
98 |     @staticmethod
99 |     def receive_responses(rank_results, responses, cut_start=0, cut_end=100):
100 |         print('receive_responses', len(responses), len(rank_results))
101 |         for i in range(len(responses)):
102 |             response = responses[i]
103 |             response = clean_response(response)
104 |             response = [int(x) - 1 for x in response.split()]
105 |             response = remove_duplicate(response)
106 |             cut_range = copy.deepcopy(rank_results[i]['hits'][cut_start: cut_end])
107 |             original_rank = [tt for tt in range(len(cut_range))]
108 |             response = [ss for ss in response if ss in original_rank]
109 |             response = response + [tt for tt in original_rank if tt not in response]
110 |             for j, x in enumerate(response):
111 |                 rank_results[i]['hits'][j + cut_start] = {
112 |                     'content': cut_range[x]['content'], 'qid': cut_range[x]['qid'], 'docid': cut_range[x]['docid'],
113 |                     'rank': cut_range[j]['rank'], 'score': cut_range[j]['score']}
114 |         return rank_results
115 | 
116 |     @staticmethod
117 |     def write_file(rank_results, file):
118 |         print('write_file')
119 |         with open(file, 'w') as f:
120 |             for i in range(len(rank_results)):
121 |                 rank = 1
122 |                 hits = rank_results[i]['hits']
123 |                 for hit in hits:
124 |                     f.write(f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n")
125 |                     rank += 1
126 |         return True
127 | 
128 |     @staticmethod
129 |     def trunc(qrels, run):
130 |         qrels = get_qrels_file(qrels)
131 |         # print(qrels)
132 |         run = pd.read_csv(run, delim_whitespace=True, header=None)
133 |         qrels = pd.read_csv(qrels, delim_whitespace=True, header=None)
134 |         run[0] = run[0].astype(str)
135 |         qrels[0] = qrels[0].astype(str)
136 | 
137 |         qrels = qrels[qrels[0].isin(run[0])]
138 |         temp_file = tempfile.NamedTemporaryFile(delete=False).name
139 |         qrels.to_csv(temp_file, sep='\t', header=None, index=None)
140 |         return temp_file
141 | 
142 |     @staticmethod
143 |     def main(args_qrel, args_run):
144 | 
145 |         args_qrel = EvalFunction.trunc(args_qrel, args_run)
146 | 
147 |         assert os.path.exists(args_qrel)
148 |         assert os.path.exists(args_run)
149 | 
150 |         with open(args_qrel, 'r') as f_qrel:
151 |             qrel = pytrec_eval.parse_qrel(f_qrel)
152 | 
153 |         with open(args_run, 'r') as f_run:
154 |             run = pytrec_eval.parse_run(f_run)
155 | 
156 |         all_metrics = trec_eval(qrel, run, k_values=(1, 5, 10))
157 |         print(all_metrics)
158 |         return all_metrics
159 | 
160 | 
161 | if __name__ == '__main__':
162 |     EvalFunction.main('dl19', 'ranking_results_file')
163 | 
--------------------------------------------------------------------------------
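As a quick sanity check, `trec_eval` can also be called directly on in-memory dictionaries in the `pytrec_eval` format, without writing a run file first. A minimal sketch with toy ids and scores (made up purely for illustration):

```python
from trec_eval import trec_eval

# Toy qrels: query id -> {doc id -> relevance label}
qrels = {'q1': {'d1': 1, 'd2': 0, 'd3': 1}}
# Toy run: query id -> {doc id -> retrieval score}
run = {'q1': {'d1': 2.5, 'd2': 1.7, 'd3': 0.4}}

metrics = trec_eval(qrels, run, k_values=(1, 5, 10))
print(metrics['NDCG@10'], metrics['MAP@10'], metrics['Recall@10'])
```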