├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── launch_job.sh ├── prompts.py ├── rank1.py ├── requirements.txt ├── run_mteb.py └── train_configs ├── README.md ├── export_model.yaml ├── train_lora_llama.yaml ├── train_lora_mistral.yaml ├── train_lora_qwen_14b.yaml ├── train_lora_qwen_32b.yaml └── train_lora_qwen_7b.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | rank1-run-files 3 | .venv 4 | __pycache__ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mteb"] 2 | path = mteb_branch 3 | url = https://github.com/embeddings-benchmark/mteb.git 4 | branch = ce-update 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Orion Weller 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Rank1: Test-Time Compute for Reranking in Information Retrieval

Model/Data Links | Installation | Usage | Citation

11 | 12 | Official repository for [rank1, a reasoning reranker model that "thinks"](http://arxiv.org/abs/2502.18418). Rank1 leverages test-time compute to generate reasoning chains before making relevance judgments. 13 | 14 | ## Links 15 | #### Models 16 | | Resource | Description | 17 | |:---------|:------------| 18 | | [rank1-0.5b](https://huggingface.co/jhu-clsp/rank1-0.5b) | Trained from Qwen2.5-0.5B base | 19 | | [rank1-1.5b](https://huggingface.co/jhu-clsp/rank1-1.5b) | Trained from Qwen2.5-1.5B base | 20 | | [rank1-3b](https://huggingface.co/jhu-clsp/rank1-3b) | Trained from Qwen2.5-3B base | 21 | | [rank1-7b](https://huggingface.co/jhu-clsp/rank1-7b) | Trained from Qwen2.5-7B base | 22 | | [rank1-14b](https://huggingface.co/jhu-clsp/rank1-14b) | Trained from Qwen2.5-14B base | 23 | | [rank1-32b](https://huggingface.co/jhu-clsp/rank1-32b) | Trained from Qwen2.5-32B base | 24 | | [rank1-mistral-2501-24b](https://huggingface.co/jhu-clsp/rank1-mistral-2501-24b) | Trained from Mistral-Small 2501 24B base | 25 | | [rank1-llama3-8b](https://huggingface.co/jhu-clsp/rank1-llama3-8b) | Trained from Llama 3.1 8B base | 26 | 27 | #### Quantized Models (fits in 24GB GPUs) 28 | | Resource | Description | 29 | |:---------|:------------| 30 | | [rank1-7b-awq](https://huggingface.co/jhu-clsp/rank1-7b-awq) | Quantized version of rank1-7b | 31 | | [rank1-14b-awq](https://huggingface.co/jhu-clsp/rank1-14b-awq) | Quantized version of rank1-14b | 32 | | [rank1-32b-awq](https://huggingface.co/jhu-clsp/rank1-32b-awq) | Quantized version of rank1-32b | 33 | | [rank1-mistral-2501-24b-awq](https://huggingface.co/jhu-clsp/rank1-mistral-2501-24b-awq) | Quantized version of rank1-mistral-24b | 34 | | [rank1-llama3-8b-awq](https://huggingface.co/jhu-clsp/rank1-llama3-8b-awq) | Quantized version of rank1-llama3-8b | 35 | 36 | #### Datasets 37 | | Resource | Description | 38 | |:---------|:------------| 39 | | [rank1-r1-msmarco](https://huggingface.co/datasets/jhu-clsp/rank1-R1-MSMARCO) | 
All R1 output examples from MS MARCO |
40 | | [rank1-training-data](https://huggingface.co/datasets/jhu-clsp/rank1-training-data) | Training data used for rank1 models |
41 | | [rank1-run-files](https://huggingface.co/datasets/jhu-clsp/rank1-Run-Files) | Pre-computed run files for use in top 100 doc reranking |
42 |
43 | ## Installation
44 | To reproduce the experiments, you can use the following code, which uses uv for fast, reliable dependency management:
45 |
46 | ```bash
47 | git clone https://github.com/orionw/rank1.git
48 | cd rank1/
49 | git submodule update --init --recursive
50 |
51 | # Install uv if you don't have it already
52 | curl -LsSf https://astral.sh/uv/install.sh | sh
53 |
54 | # Create and activate virtual environment with uv
55 | uv venv env --python=3.10
56 | source env/bin/activate
57 |
58 | # Install dependencies with uv
59 | uv pip install -r requirements.txt
60 | uv pip install -e mteb_branch/
61 | uv pip install --no-build-isolation xformers==0.0.28.post3
62 | uv pip install vllm==0.7.2
63 |
64 | # Recommended: download a flash attention wheel from https://github.com/Dao-AILab/flash-attention/releases and `uv pip install` it
65 | # wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
66 | # uv pip install flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
67 |
68 | # Download the Rank1-Run-Files repository (required for evaluation)
69 | git lfs install # if you don't have it already
70 | git clone https://huggingface.co/datasets/jhu-clsp/rank1-run-files
71 | ```
72 |
73 | ## Usage
74 | ### Tips
75 | **Reproducibility** Depending on your evaluation batch size, you will get slightly different results due to non-determinism in vLLM. For our experiments we processed all instances in one batch (e.g. batch_size=99999999999).
We also found that using the flag `enforce_eager` sped up inference for the smaller models but not for the larger models.
76 |
77 | **Adapting to New Tasks** You may want to use these models on tasks where the definition of relevance differs from MS MARCO. For these tasks you will want a custom prompt that tells the model what counts as relevant. See `prompts.py` for the prompts used for various datasets.
78 |
79 |
80 | ### Running Evaluations
81 | To run an evaluation with the rank1 model on a specific dataset:
82 |
83 | ```bash
84 | bash launch_job.sh jhu-clsp/rank1-7b NevIR default 1
85 | ```
86 |
87 | Parameters:
88 | - `jhu-clsp/rank1-7b`: Model name or path
89 | - `NevIR`: Dataset name
90 | - `default`: Subtask name (use "default" if no subtask)
91 | - `1`: Number of GPUs to use
92 |
93 |
94 | ### Using Rank1 in Your Own Code
95 | You can integrate rank1 into your code:
96 |
97 | ```python
98 | from rank1 import rank1
99 |
100 | # Initialize the model
101 | model = rank1(
102 |     model_name_or_path="jhu-clsp/rank1-7b",
103 |     num_gpus=1,
104 |     device="cuda",
105 |     context_size=16000,
106 |     max_output_tokens=8192,
107 |     fp_options="float16"
108 | )
109 |
110 | # Rerank documents: pass one (query, passage) tuple per document; predict returns one score per tuple
111 | results = model.predict([
112 |     ("Your query/prompt here", "Document 1 content"),
113 |     ("Your query/prompt here", "Document 2 content"),
114 | ])
115 | ```
116 |
117 | ### MTEB Integration
118 | Rank1 is compatible with the MTEB benchmarking framework.
To evaluate your model: 119 | 120 | ```python 121 | from mteb import MTEB 122 | from rank1 import rank1 123 | 124 | # Initialize your model 125 | model = rank1( 126 | model_name_or_path="jhu-clsp/rank1-7b", 127 | num_gpus=1, 128 | device="cuda" 129 | ) 130 | 131 | # Select tasks (or use specific task names) 132 | evaluation = MTEB(tasks=["NevIR"]) 133 | 134 | # Run evaluation 135 | results = evaluation.run(model) 136 | ``` 137 | 138 | ## Citing 139 | If you use rank1 you can cite: 140 | 141 | ```bibtex 142 | @misc{weller2025rank1testtimecomputereranking, 143 | title={Rank1: Test-Time Compute for Reranking in Information Retrieval}, 144 | author={Orion Weller and Kathryn Ricci and Eugene Yang and Andrew Yates and Dawn Lawrie and Benjamin Van Durme}, 145 | year={2025}, 146 | eprint={2502.18418}, 147 | archivePrefix={arXiv}, 148 | primaryClass={cs.IR}, 149 | url={https://arxiv.org/abs/2502.18418}, 150 | } 151 | ``` 152 | 153 | ## License 154 | [MIT](LICENSE) 155 | -------------------------------------------------------------------------------- /launch_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | sweep_id=$1 5 | # Load Python environment 6 | source env/bin/activate 7 | export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 8 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 9 | 10 | # Print debug information 11 | echo "=== Debug Information ===" 12 | echo "Hostname: $(hostname)" 13 | echo "Current directory: $(pwd)" 14 | echo "Date: $(date)" 15 | 16 | # Print GPU information if nvidia-smi is available 17 | if command -v nvidia-smi &> /dev/null; then 18 | echo -e "\n=== GPU Information ===" 19 | nvidia-smi 20 | else 21 | echo "nvidia-smi not found - no GPU information available" 22 | fi 23 | 24 | # Print memory information 25 | echo -e "\n=== Memory Information ===" 26 | free -h 27 | 28 | # Print CPU information 29 | echo -e "\n=== CPU Information ===" 30 | lscpu | grep "Model name" 31 | lscpu | grep "CPU(s):" 32 | 33 | 
# Print arguments 34 | echo -e "\n=== Script Arguments ===" 35 | echo "Model: $1" 36 | echo "Dataset: $2" 37 | echo "Subtask: $3" 38 | echo "Number of GPUs: $4" 39 | 40 | echo -e "\n=== Environment ===" 41 | echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" 42 | echo "PATH: $PATH" 43 | echo "PYTHONPATH: $PYTHONPATH" 44 | 45 | echo "=== End Debug Information ===\n" 46 | 47 | 48 | model=$1 49 | dataset=$2 50 | subtask=$3 51 | num_gpus=$4 52 | 53 | # if num_gpus is not provided, then set it to 1 54 | if [ -z "$num_gpus" ]; then 55 | num_gpus=1 56 | fi 57 | 58 | # if subtask is "default" ignore it 59 | echo "$(which python)" 60 | if [ "$subtask" != "default" ]; then 61 | echo "Running with subtask: $subtask" 62 | python run_mteb.py -m $model -d $dataset -s $subtask -n $num_gpus 63 | else 64 | echo "Running with no subtask" 65 | python run_mteb.py -m $model -d $dataset -n $num_gpus 66 | fi 67 | 68 | # example: bash launch_job.sh jhu-clsp/Rank1-7B BrightRetrieval biology 1 69 | -------------------------------------------------------------------------------- /prompts.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | 4 | PROMPT_DICT = { 5 | "SciFact": """Claim: FILL_QUERY_HERE 6 | 7 | A relevant passage would provide evidence that either **supports** or **refutes** this claim. A passage with any information on any related subpart should be relevant.""", 8 | 9 | "ClimateFEVER": """I am looking to write an essay for and need to find evidence the supports or contradict this statement: 10 | 11 | FILL_QUERY_HERE 12 | 13 | If the passage provides information that supports or contradicts it in any way it is relevant so I can cite it. 
14 | 15 | """, 16 | 17 | "TRECCOVID": """FILL_QUERY_HERE If the article answers any part of the question it is relevant.""", 18 | 19 | "ArguAna": """I am looking to write an essay and need to find counterarguments against this statement: 20 | 21 | FILL_QUERY_HERE 22 | 23 | Does this passage have any counterargument or evidence that could be used to help me? 24 | 25 | """, 26 | 27 | "DBPedia": """I am looking to write an essay on this topic and need as much related background information to help me. The topic is: 28 | 29 | FILL_QUERY_HERE 30 | 31 | If the passage provides any background information that could be connected it is relevant. 32 | 33 | """, 34 | 35 | "FiQA2018": """FILL_QUERY_HERE Find a passage that would be a good answer from StackExchange.""", 36 | 37 | "NFCorpus": """Topic: FILL_QUERY_HERE 38 | 39 | Given the above topic, I need to learn about all aspects of it. It does not need to be directly relevant, only tangentially informational. Please mark as relevant any passages with even weak connections. I need to learn fast for my job, which means I need to understand each part individually. 40 | 41 | Again remember, any connection means relevant even if indirect. So if it is not addressed, that is okay -- it does not need to be explicitly. 42 | 43 | Find me passages with any type of connection, including weak connections!!!!""", 44 | 45 | "Touche2020": """I am looking to write an essay and need to find arguments for or against this statement: 46 | 47 | FILL_QUERY_HERE 48 | 49 | Does this passage have any argument or evidence that could be used to help me? 50 | 51 | """, 52 | 53 | "SCIDOCS": """papers that could be cited in FILL_QUERY_HERE. Anything with even indirect relevance should be relevant. 
This includes papers in the same broader field of science""", 54 | 55 | "BrightRetrieval_aops": """Find different but similar math problems to FILL_QUERY_HERE\n\nA document is relevant if it uses the same class of functions and shares **any** overlapping techniques.""", 56 | 57 | "BrightRetrieval_theoremqa_questions": """Find a passage which uses the same mathematical process as this one: FILL_QUERY_HERE""", 58 | 59 | "BrightRetrieval_leetcode": """I am looking to find different problems that share similar data structures (of any kind) or algorithms (e.g. DFS, DP, sorting, traversals, etc.). I am looking for problems that share one or both of these similarities to this: 60 | 61 | FILL_QUERY_HERE 62 | 63 | Does this passage share any similarities? e.g. if there was a textbook on leetcode problems, this would be in the same book even though it could be in a different chapter. 64 | 65 | 66 | """, 67 | 68 | "BrightRetrieval_pony": """I will use the programming language pony. Problem: FILL_QUERY_HERE 69 | 70 | But to solve the problem above, I need to know things about pony. 
A passage is relevant if it contains docs that match **any** part (even basic parts) of the code I will have to write for the above program.""", 71 | 72 | "BrightRetrieval": """Can you find background information about the concepts used to answer the question: 73 | 74 | FILL_QUERY_HERE 75 | 76 | A passage is relevant if it contains background information about a **sub-concept** that someone might cite/link to when answering the above question.""" 77 | 78 | } 79 | 80 | PROMPT_DICT["BrightRetrieval_theoremqa_theorems"] = PROMPT_DICT["BrightRetrieval_theoremqa_questions"] 81 | 82 | 83 | def get_prompt(task_name, subtask_name: str = None): 84 | if subtask_name is not None and task_name in PROMPT_DICT: 85 | # if subtask is present, use that, otherwise use just the task name 86 | if f"{task_name}_{subtask_name}" in PROMPT_DICT: 87 | return PROMPT_DICT[f"{task_name}_{subtask_name}"] 88 | else: # default for subtask (e.g. BrightRetrieval) 89 | return PROMPT_DICT[task_name] 90 | elif task_name in PROMPT_DICT: 91 | # no subtask 92 | return PROMPT_DICT[task_name] 93 | else: 94 | return None 95 | 96 | 97 | BEIR_DATASETS = [ 98 | "ArguAna", 99 | "ClimateFEVER", 100 | "DBPedia", 101 | "FEVER", 102 | "FiQA2018", 103 | "HotpotQA", 104 | "NFCorpus", 105 | "NQ", 106 | "QuoraRetrieval", 107 | "SCIDOCS", 108 | "SciFact", 109 | "TRECCOVID", 110 | "Touche2020", 111 | ] 112 | 113 | 114 | def validate_json(file_path: str) -> bool: 115 | try: 116 | with open(file_path, "r") as f: 117 | data = json.load(f) 118 | # assert there are string keys and that within that a dict of key -> float values 119 | for key in data: 120 | assert isinstance(key, str), f"Key is not a string: {key}" 121 | assert isinstance(data[key], dict), f"Data is not a dict: {data[key]}" 122 | for inner_key, inner_value in data[key].items(): 123 | assert isinstance(inner_key, str), f"Inner key is not a string: {inner_key}" 124 | assert isinstance(inner_value, float), f"Inner value is not a float: {inner_value}" 125 | return 
True
126 |     except Exception as e:
127 |         print(e)
128 |         return False
--------------------------------------------------------------------------------
/rank1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import mteb
3 | from mteb import MTEB
4 | import logging
5 | import os
6 | import json
7 |
8 | from functools import partial
9 | import logging
10 | import math
11 | from typing import Any, Callable, List, Tuple
12 |
13 | import torch
14 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
15 | from vllm import LLM, SamplingParams
16 |
17 | from mteb.encoder_interface import Encoder
18 | from mteb.evaluation.evaluators.RetrievalEvaluator import DenseRetrievalExactSearch
19 | from mteb.model_meta import ModelMeta
20 | from mteb.models.rerankers_custom import RerankerWrapper
21 |
22 |
23 | logging.basicConfig(level=logging.INFO)
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | class rank1(RerankerWrapper):
28 |     name: str = "rank1"
29 |
30 |     def __init__(
31 |         self,
32 |         model_name_or_path: str = "jhu-clsp/rank1-7b",
33 |         batch_size: int = 999999999999,
34 |         context_size: int = 16000,
35 |         max_output_tokens: int = 8192,
36 |         fp_options: str = "float16",
37 |         num_gpus: int = 1,
38 |         device: str = "cuda",
39 |         force_rethink: int = 0,
40 |         dataset_prompt: str = None,
41 |         **kwargs,
42 |     ):
43 |         """
44 |         rank1 is a reasoning reranker model (using test-time compute) which generates a reasoning chain before deciding true or false
45 |
46 |         Args:
47 |             model_name_or_path: Path to the model or name of the model on HuggingFace Hub
48 |             batch_size: Maximum batch size for processing (default: very large number to let vLLM handle batching)
49 |             context_size: Maximum context length for the model (default: 16000)
50 |             max_output_tokens: Maximum number of tokens to generate (default: 8192)
51 |             fp_options: Floating point precision to use, e.g.
'float16' (default: 'float16')
52 |             num_gpus: Number of GPUs to use for tensor parallelism (default: 1)
53 |             device: Device to load the model on (default: 'cuda')
54 |             force_rethink: Number of times to force model to rethink its answer (default: 0)
55 |             **kwargs: Additional keyword arguments passed to parent RerankerWrapper
56 |         """
57 |         super().__init__(model_name_or_path, batch_size=batch_size, fp_options=fp_options, **kwargs)
58 |
59 |         self.context_size = context_size
60 |         self.max_output_tokens = max_output_tokens
61 |         self.num_gpus = num_gpus
62 |         self.device = device
63 |         self.force_rethink = force_rethink
64 |         self.model_name_or_path = model_name_or_path
65 |         self.dataset_prompt = dataset_prompt
66 |
67 |         # Initialize the tokenizer (left padding, EOS token reused as the pad token)
68 |         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
69 |         self.tokenizer.padding_side = "left"
70 |         self.tokenizer.pad_token = self.tokenizer.eos_token
71 |
72 |         # Cache commonly used token IDs
73 |         self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
74 |         self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]
75 |         self.think_token = self.tokenizer("<think>", add_special_tokens=False).input_ids[0]
76 |         self.think_end_token = self.tokenizer("</think>", add_special_tokens=False).input_ids[-1]
77 |
78 |         self.model = LLM(
79 |             model=model_name_or_path,
80 |             tensor_parallel_size=int(num_gpus),
81 |             trust_remote_code=True,
82 |             max_model_len=context_size,
83 |             gpu_memory_utilization=0.9,
84 |             dtype=fp_options,
85 |         )
86 |         self.sampling_params = SamplingParams(
87 |             temperature=0,
88 |             max_tokens=max_output_tokens,
89 |             logprobs=20,
90 |             stop=["</think> true", "</think> false"],
91 |             skip_special_tokens=False
92 |         )
93 |
94 |     def _fix_incomplete_responses(
95 |         self,
96 |         original_prompts: List[str],
97 |         generated_texts: List[str]
98 |     ) -> Tuple[List[str], List[int], List[float]]:
99 |         """
100 |         This function is used to fix incomplete responses from the vLLM model.
In some cases the model does not generate the end token.
101 |         In these cases, we should force it to generate it so that we have some prediction.
102 |
103 |         Args:
104 |             original_prompts: The original prompts that were used to generate the texts
105 |             generated_texts: The texts that were generated by the vLLM model
106 |
107 |         Returns:
108 |             final_texts: The texts that were generated by the vLLM model + the outputs from the forcing step
109 |             token_counts: The total number of tokens generated in the forcing step for each text
110 |             scores: The relevance scores of the texts
111 |         """
112 |         cleaned_texts = []
113 |         for text in generated_texts:
114 |             text = text.rstrip()
115 |             if not text.endswith(('.', '!', '?')):
116 |                 last_punct = max(text.rfind('.'), text.rfind('!'), text.rfind('?'))
117 |                 if last_punct != -1:
118 |                     text = text[:last_punct + 1]
119 |             cleaned_texts.append(text.strip())
120 |
121 |         forced_prompts = [
122 |             f"{original_prompt}\n{cleaned_text}\n</think>"
123 |             for original_prompt, cleaned_text in zip(original_prompts, cleaned_texts)
124 |         ]
125 |
126 |         new_sampling_args = SamplingParams(
127 |             temperature=0,
128 |             max_tokens=1,
129 |             logprobs=20,
130 |             allowed_token_ids=[self.true_token, self.false_token],
131 |             skip_special_tokens=False
132 |         )
133 |         outputs = self.model.generate(forced_prompts, new_sampling_args)
134 |
135 |         # read the log-probabilities of the single forced next token
136 |         all_final_texts = []
137 |         all_token_counts = []
138 |         all_scores = []
139 |         for i in range(len(outputs)):
140 |             try:
141 |                 text = outputs[i].outputs[0].text
142 |                 final_logits = outputs[i].outputs[0].logprobs[-1]
143 |                 assert self.false_token in final_logits and self.true_token in final_logits, f"final logits are missing true or false: {final_logits}"
144 |             except Exception as e:
145 |                 print(f"Error: {e} on fixing error, setting at 0.5 score: {outputs[i].outputs}")
146 |                 all_scores.append(0.5)
147 |                 all_token_counts.append(len(outputs[i].outputs[0].token_ids))
148 |                 all_final_texts.append(text)
149 |                 continue
150 |
| token_count = len(outputs[i].outputs[0].token_ids) 152 | true_logit = final_logits[self.true_token].logprob 153 | false_logit = final_logits[self.false_token].logprob 154 | true_score = math.exp(true_logit) 155 | false_score = math.exp(false_logit) 156 | score = true_score / (true_score + false_score) 157 | 158 | all_final_texts.append(text) 159 | all_token_counts.append(token_count) 160 | all_scores.append(score) 161 | 162 | return all_final_texts, all_token_counts, all_scores 163 | 164 | def truncate_input(self, text: str) -> str: 165 | """ 166 | Truncate the input text to the context size. This is not used, except if you are using the Llama 8B quantized model 167 | """ 168 | if len(self.tokenizer(text)["input_ids"]) >= self.context_size: 169 | return self.tokenizer.decode(self.tokenizer(text)["input_ids"][:self.context_size]) 170 | else: 171 | return text 172 | 173 | def _process_with_vllm(self, prompts): 174 | """ 175 | vLLM is significantly faster than HF, so we use it by default. This function handles the cases where the model does not generate the end token. 
176 | 177 | Args: 178 | prompts: The prompts to generate from 179 | 180 | Returns: 181 | outputs: The outputs from the vLLM model 182 | """ 183 | # prompts = [self.truncate_input(prompt) for prompt in prompts] 184 | outputs = self.model.generate(prompts, self.sampling_params) 185 | 186 | # Pre-allocate lists with None values 187 | total_length = len(prompts) 188 | all_outputs = [None] * total_length 189 | all_output_token_counts = [None] * total_length 190 | all_scores = [None] * total_length 191 | 192 | incomplete_prompts = [] 193 | incomplete_texts = [] 194 | incomplete_indices = [] 195 | 196 | # Process complete responses first 197 | for i, output in enumerate(outputs): 198 | text = output.outputs[0].text 199 | try: 200 | final_logits = output.outputs[0].logprobs[-1] 201 | except Exception as e: 202 | print(f"Error: {e} on getting final logits: {output.outputs[0]}") 203 | incomplete_prompts.append(prompts[i]) 204 | incomplete_texts.append(text) 205 | incomplete_indices.append(i) 206 | continue 207 | 208 | if self.true_token not in final_logits or self.false_token not in final_logits: 209 | incomplete_prompts.append(prompts[i]) 210 | incomplete_texts.append(text) 211 | incomplete_indices.append(i) 212 | continue 213 | 214 | token_count = len(output.outputs[0].token_ids) 215 | true_logit = final_logits[self.true_token].logprob 216 | false_logit = final_logits[self.false_token].logprob 217 | true_score = math.exp(true_logit) 218 | false_score = math.exp(false_logit) 219 | score = true_score / (true_score + false_score) 220 | 221 | all_outputs[i] = text 222 | all_output_token_counts[i] = token_count 223 | all_scores[i] = score 224 | 225 | # Handle incomplete responses 226 | if incomplete_indices: 227 | fixed_texts, fixed_counts, fixed_scores = self._fix_incomplete_responses( 228 | incomplete_prompts, incomplete_texts 229 | ) 230 | 231 | # Fill in the fixed responses at their original positions 232 | for orig_idx, (text, count, score) in zip( 233 | 
incomplete_indices, zip(fixed_texts, fixed_counts, fixed_scores)
234 |             ):
235 |                 all_outputs[orig_idx] = text
236 |                 all_output_token_counts[orig_idx] = count
237 |                 all_scores[orig_idx] = score
238 |
239 |         return all_outputs, all_output_token_counts, all_scores
240 |
241 |     def return_prompt(self, query, doc_content, prompt) -> str:
242 |         query = prompt.replace("FILL_QUERY_HERE", query) if prompt else query
243 |         return "Determine if the following passage is relevant to the query. " \
244 |                 "Answer only with 'true' or 'false'.\n" \
245 |                 f"Query: {query}\n" \
246 |                 f"Passage: {doc_content}\n" \
247 |                 "<think>" # force the model to start with this
248 |
249 |     def _prepare_prompts_for_rethink(self, prompts: List[str], texts: List[str], rethink_text: str = "Wait") -> Tuple[List[str], List[str]]:
250 |         """Prepare prompts for the rethinking step."""
251 |         full_texts = [p + t for p, t in zip(prompts, texts)]
252 |         stripped_texts = [t.split("</think>")[0] for t in full_texts]
253 |         just_generated_texts = [t.split("</think>")[0] for t in full_texts]
254 |         return [s + f"\n{rethink_text}" for s in stripped_texts], just_generated_texts
255 |
256 |     @torch.inference_mode()
257 |     def predict(self, input_to_rerank, **kwargs):
258 |         """This is set up to run with mteb but can be adapted to your purpose"""
259 |         inputs = list(zip(*input_to_rerank))
260 |         if len(input_to_rerank[0]) == 2:
261 |             queries, passages = inputs
262 |             instructions = None
263 |         else:
264 |             queries, passages, instructions = inputs
265 |
266 |         if instructions is not None and instructions[0] is not None:
267 |             queries = [f"{q} {i}".strip() if q.strip() != i.strip() else q.strip() for i, q in zip(instructions, queries)]
268 |
269 |         if isinstance(passages[0], dict):
270 |             passages = [f"{v['title']} {v['text']}" if 'title' in v else v['text'] for v in passages]
271 |
272 |         prompts = [
273 |             self.return_prompt(query, passage, self.dataset_prompt)
274 |             for query, passage in zip(queries, passages)
275 |         ]
276 |         print(f"Example prompt: ```\n{prompts[0]}\n```")
277 |         rethink_text = "Wait"  # text appended to force another round of reasoning
278 |         texts, token_counts, scores = self._process_with_vllm(prompts)
279 |         for _ in range(self.force_rethink):  # loop without mutating self.force_rethink
280 |             revised_prompts, previously_generated_texts = self._prepare_prompts_for_rethink(prompts, texts, rethink_text)
281 |             new_texts, new_token_counts, new_scores = self._process_with_vllm(revised_prompts)
282 |             # add to the previous output
283 |             texts = [prev + f"\n{rethink_text}" + f"{new_text}" for prev, new_text in zip(texts, new_texts)]
284 |             scores = new_scores
285 |             token_counts = [prev_token_count + new_token_count for prev_token_count, new_token_count in zip(token_counts, new_token_counts)]
286 |
287 |
288 |         return scores
289 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.3.0
2 | bm25s==0.2.7.post1
3 | datasets==3.2.0
4 | einops==0.8.1
5 | gritlm==1.0.2
6 | huggingface-hub==0.28.1
7 | numpy==1.26.4
8 | pystemmer==2.2.0.3
9 | scikit-learn==1.6.1
10 | scipy==1.13.1
11 | sentence-transformers==3.4.1
12 | sentencepiece==0.2.0
13 | tiktoken==0.7.0
14 | tokenizers==0.21.0
15 | tomlkit==0.12.0
16 | torch==2.5.1
17 | torchaudio==2.5.1
18 | torchvision==0.20.1
19 | tqdm==4.67.1
20 | transformers==4.48.2
21 | peft
22 |
--------------------------------------------------------------------------------
/run_mteb.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import mteb
5 | from mteb import MTEB
6 | import logging
7 | import os
8 | import json
9 |
10 | from functools import partial
11 | import logging
12 | import math
13 | from typing import Any, Callable, List, Tuple
14 |
15 | from mteb.encoder_interface import Encoder
16 | from mteb.evaluation.evaluators.RetrievalEvaluator import DenseRetrievalExactSearch
17 | from mteb.model_meta import ModelMeta
18 | from mteb.models.rerankers_custom import
RerankerWrapper 19 | 20 | from prompts import get_prompt, PROMPT_DICT, validate_json, BEIR_DATASETS 21 | from rank1 import rank1 22 | 23 | logging.basicConfig(level=logging.INFO) 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def get_safe_folder_name(model_name): 28 | return model_name.replace("/", "_").replace("\\", "_") 29 | 30 | def run_evaluation(dataset_name: str, subtask: str, model_name: str, num_gpus: int, skip_prompt: bool) -> None: 31 | """Run MTEB evaluation for specified model and dataset 32 | 33 | Args: 34 | dataset_name: Name of MTEB dataset to evaluate 35 | """ 36 | # Initialize MTEB task and evaluation 37 | tasks = mteb.get_tasks(tasks=[dataset_name]) 38 | evaluation = MTEB(tasks=tasks) 39 | 40 | if dataset_name == "BrightRetrieval": 41 | previous_results = f"rank1-run-files/{subtask}_bm25_long_False/score.json" 42 | eval_splits = ["standard"] 43 | elif dataset_name == "mFollowIRCrossLingual": 44 | previous_results = f"rank1-run-files/mFollowIRCrossLingual_{subtask}_predictions.json" 45 | eval_splits = None 46 | elif dataset_name == "mFollowIR": 47 | previous_results = f"rank1-run-files/mFollowIR_{subtask}_predictions.json" 48 | eval_splits = None 49 | elif dataset_name in BEIR_DATASETS: 50 | eval_splits = ["test"] 51 | previous_results = f"rank1-run-files/{dataset_name}_default_predictions.json" 52 | else: 53 | print(f"Running with no subtask or eval splits for dataset: {dataset_name}") 54 | previous_results = None 55 | eval_splits = None 56 | 57 | encode_kwargs = { 58 | # use vLLM to batch 59 | "batch_size": 999999 if "rank1" in model_name.lower() else 32 60 | } 61 | 62 | prompt = get_prompt(dataset_name, subtask) 63 | if prompt is not None and not skip_prompt: 64 | is_prompted = True 65 | else: 66 | is_prompted = False 67 | prompt = None 68 | print(f"Prompt: {prompt}") 69 | 70 | if subtask == "default": 71 | subtask = None 72 | 73 | if previous_results is not None: 74 | assert validate_json(previous_results), f"Previous results are not 
valid json: {previous_results}" 75 | print(f"Previous results: {previous_results}") 76 | 77 | model = rank1(model_name_or_path=model_name.strip(), num_gpus=num_gpus, dataset_prompt=prompt) 78 | output_dir = f"results/{model_name}/{dataset_name}_{subtask}" 79 | print(f"Output directory: {output_dir}") 80 | 81 | # Run evaluation 82 | evaluation.run( 83 | model, 84 | save_predictions=True, 85 | encode_kwargs=encode_kwargs, 86 | output_folder=output_dir, 87 | previous_results=previous_results, 88 | eval_subsets=[subtask] if previous_results else None, 89 | eval_splits=eval_splits, 90 | top_k=100 91 | ) 92 | 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser(description="Run MTEB evaluation") 96 | parser.add_argument("-d", "--dataset", required=True, help="MTEB dataset name") 97 | parser.add_argument("-s", "--subtask", required=False, default=None, help="MTEB subtask name") 98 | parser.add_argument("-m", "--model_name", required=True, help="Model name") 99 | parser.add_argument("-n", "--num_gpus", required=False, help="Number of GPUs", default=1) 100 | parser.add_argument("-p", "--skip_prompt", action="store_true", help="Skip prompt") 101 | args = parser.parse_args() 102 | run_evaluation(args.dataset.strip(), args.subtask.strip() if args.subtask is not None else None, args.model_name.strip(), args.num_gpus, args.skip_prompt) 103 | 104 | 105 | if __name__ == "__main__": 106 | os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" 107 | main() 108 | -------------------------------------------------------------------------------- /train_configs/README.md: -------------------------------------------------------------------------------- 1 | # Rank1 Model Training Configurations 2 | 3 | This directory contains configuration files for training and exporting Rank1 models using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory). 4 | 5 | ## Getting Started 6 | 7 | ### 1. Setup Training Data 8 | 9 | We use the data in `jhu-clsp/rank1-training-data` for training. 
Follow these steps to set it up:

1. Download the training data and save it as `train.json` in LLaMA-Factory's `data` folder
2. Update the `dataset_info.json` file to reference your data:

```json
{
  "r1_reranking_self_filtered_negs_only": {
    "file_name": "train.json"
  }
}
```

### 2. Available Training Configurations

The following configuration files are available:

- `train_lora_llama.yaml` - LoRA training configuration for Llama-3 8B
- `train_lora_mistral.yaml` - LoRA training configuration for Mistral Small 24B
- `train_lora_qwen_7b.yaml` - LoRA training configuration for Qwen2.5 7B
- `train_lora_qwen_14b.yaml` - LoRA training configuration for Qwen2.5 14B
- `train_lora_qwen_32b.yaml` - LoRA training configuration for Qwen2.5 32B

### 3. Training Process

After installing the LLaMA-Factory requirements, start training with:

```bash
llamafactory-cli train train_configs/train_lora_XXX.yaml
```

Replace `XXX` with the model you want to train (e.g., `llama`, `mistral`, `qwen_7b`).

### 4. Exporting Model Weights

After training, merge the LoRA adapter into the base model and export the weights with:

```bash
llamafactory-cli export train_configs/export_model.yaml
```

Before running the export command, fill in the placeholders in `export_model.yaml` with your own values:

```yaml
model_name_or_path: BASE_MODEL_HERE  # Path to the base model
adapter_name_or_path: INPUT_DIR      # Path to the trained adapter
export_dir: EXPORT_DIR               # Where to save the exported model
```

## Configuration Parameters

Each training YAML file contains configuration sections for:
- Model settings
- Training method (LoRA parameters)
- Dataset configuration
- Output settings
- Training hyperparameters
- Evaluation settings

Refer to each specific YAML file for detailed configuration options.


--------------------------------------------------------------------------------
/train_configs/export_model.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: BASE_MODEL_HERE
adapter_name_or_path: INPUT_DIR
template: empty
finetuning_type: lora
trust_remote_code: true

### export
export_dir: EXPORT_DIR
export_size: 10
export_device: cpu
export_legacy_format: false
new_special_tokens: ","
resize_vocab: true
--------------------------------------------------------------------------------
/train_configs/train_lora_llama.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 32
lora_alpha: 64
lora_target: all

### dataset
dataset: r1_reranking_self_filtered_negs_only
template: empty
cutoff_len: 2500
overwrite_cache: false
preprocessing_num_workers: 16

### output
output_dir: saves/llama-8b/lora/sft
logging_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
ddp_timeout: 180000000
new_special_tokens: ","
resize_vocab: true

### eval
val_size: 0.01
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 999999999999
save_steps: 250
seed: 12345

report_to: wandb
run_name: llama_8b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_mistral.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: mistralai/Mistral-Small-24B-Base-2501
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 32
lora_alpha: 64
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]

### dataset
dataset: r1_reranking_self_filtered_negs_only
template: empty
cutoff_len: 2500
overwrite_cache: false
preprocessing_num_workers: 16

### output
output_dir: saves/mistral-24b/lora/sft
logging_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 16
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
ddp_timeout: 180000000
new_special_tokens: ","
resize_vocab: true

### eval
val_size: 0.01
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 999999999999
save_steps: 250
seed: 12345

report_to: wandb
run_name: mistral_24b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_qwen_14b.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: Qwen/Qwen2.5-14B
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 32
lora_alpha: 64
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]

### dataset
dataset: r1_reranking_self_filtered_negs_only
template: empty
cutoff_len: 2500
overwrite_cache: false
preprocessing_num_workers: 16

### output
output_dir: saves/qwen-14b/lora/sft
logging_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 16
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
ddp_timeout: 180000000
new_special_tokens: ","
resize_vocab: true

### eval
val_size: 0.01
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 999999999999
save_steps: 250
seed: 12345

report_to: wandb
run_name: qwen_14b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_qwen_32b.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: Qwen/Qwen2.5-32B
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 32
lora_alpha: 64
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]

### dataset
dataset: r1_reranking_self_filtered_negs_only
template: empty
cutoff_len: 2500
overwrite_cache: false
preprocessing_num_workers: 16

### output
output_dir: saves/qwen-32b/lora/sft
logging_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 32
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
ddp_timeout: 180000000
new_special_tokens: ","
resize_vocab: true

### eval
val_size: 0.01
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 999999999999
save_steps: 250
seed: 12345

report_to: wandb
run_name: qwen_32b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_qwen_7b.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: Qwen/Qwen2.5-7B
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 32
lora_alpha: 64
lora_target: all

### dataset
dataset: r1_reranking_self_filtered_negs_only
template: empty
cutoff_len: 2500
overwrite_cache: false
preprocessing_num_workers: 16

### output
output_dir: saves/qwen-7b/lora/sft
logging_steps: 10
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
ddp_timeout: 180000000
new_special_tokens: ","
resize_vocab: true

### eval
val_size: 0.01
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 999999999999
save_steps: 250
seed: 12345

report_to: wandb
run_name: qwen_7b_lora
--------------------------------------------------------------------------------
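The five training configs above trade `per_device_train_batch_size` against `gradient_accumulation_steps` as the model size grows. The sketch below is illustrative and not part of the repository: the `CONFIGS` dict and the `effective_batch_size` helper are hypothetical names, and the values are copied by hand from the YAML files. It checks that the product of the two settings is the same in every config. The number of GPUs is not specified in the configs, so it is left as a parameter, and plain data parallelism is assumed.

```python
# Hypothetical helper, not repository code: values are copied by hand from
# the train_lora_*.yaml files as (per_device_train_batch_size,
# gradient_accumulation_steps).
CONFIGS = {
    "train_lora_llama.yaml": (4, 8),
    "train_lora_mistral.yaml": (2, 16),
    "train_lora_qwen_7b.yaml": (4, 8),
    "train_lora_qwen_14b.yaml": (2, 16),
    "train_lora_qwen_32b.yaml": (1, 32),
}


def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Sequences consumed per optimizer step, assuming plain data parallelism."""
    return per_device * grad_accum * num_gpus


if __name__ == "__main__":
    for name, (per_device, grad_accum) in CONFIGS.items():
        # Every config yields the same per-GPU value
        print(f"{name}: {effective_batch_size(per_device, grad_accum)} sequences per step per GPU")
```

Under these assumptions, each config processes 32 sequences per optimizer step per GPU, so the larger models appear to keep the same effective batch size while lowering per-step memory use.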