├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── launch_job.sh
├── prompts.py
├── rank1.py
├── requirements.txt
├── run_mteb.py
└── train_configs
├── README.md
├── export_model.yaml
├── train_lora_llama.yaml
├── train_lora_mistral.yaml
├── train_lora_qwen_14b.yaml
├── train_lora_qwen_32b.yaml
└── train_lora_qwen_7b.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | results
2 | rank1-run-files
3 | .venv
4 | __pycache__
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "mteb"]
2 | path = mteb_branch
3 | url = https://github.com/embeddings-benchmark/mteb.git
4 | branch = ce-update
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Orion Weller
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rank1: Test-Time Compute for Reranking in Information Retrieval
2 |
3 |
4 |
5 | [Model/Data Links](#links) |
6 | [Installation](#installation) |
7 | [Usage](#usage) |
8 | [Citation](#citing)
9 |
10 |
11 |
12 | Official repository for [rank1, a reasoning reranker model that "thinks"](http://arxiv.org/abs/2502.18418). Rank1 leverages test-time compute to generate reasoning chains before making relevance judgments.
13 |
14 | ## Links
15 | #### Models
16 | | Resource | Description |
17 | |:---------|:------------|
18 | | [rank1-0.5b](https://huggingface.co/jhu-clsp/rank1-0.5b) | Trained from Qwen2.5-0.5B base |
19 | | [rank1-1.5b](https://huggingface.co/jhu-clsp/rank1-1.5b) | Trained from Qwen2.5-1.5B base |
20 | | [rank1-3b](https://huggingface.co/jhu-clsp/rank1-3b) | Trained from Qwen2.5-3B base |
21 | | [rank1-7b](https://huggingface.co/jhu-clsp/rank1-7b) | Trained from Qwen2.5-7B base |
22 | | [rank1-14b](https://huggingface.co/jhu-clsp/rank1-14b) | Trained from Qwen2.5-14B base |
23 | | [rank1-32b](https://huggingface.co/jhu-clsp/rank1-32b) | Trained from Qwen2.5-32B base |
24 | | [rank1-mistral-2501-24b](https://huggingface.co/jhu-clsp/rank1-mistral-2501-24b) | Trained from Mistral-Small 2501 24B base |
25 | | [rank1-llama3-8b](https://huggingface.co/jhu-clsp/rank1-llama3-8b) | Trained from Llama 3.1 8B base |
26 |
27 | #### Quantized Models (fit on 24GB GPUs)
28 | | Resource | Description |
29 | |:---------|:------------|
30 | | [rank1-7b-awq](https://huggingface.co/jhu-clsp/rank1-7b-awq) | Quantized version of rank1-7b |
31 | | [rank1-14b-awq](https://huggingface.co/jhu-clsp/rank1-14b-awq) | Quantized version of rank1-14b |
32 | | [rank1-32b-awq](https://huggingface.co/jhu-clsp/rank1-32b-awq) | Quantized version of rank1-32b |
33 | | [rank1-mistral-2501-24b-awq](https://huggingface.co/jhu-clsp/rank1-mistral-2501-24b-awq) | Quantized version of rank1-mistral-2501-24b |
34 | | [rank1-llama3-8b-awq](https://huggingface.co/jhu-clsp/rank1-llama3-8b-awq) | Quantized version of rank1-llama3-8b |
35 |
36 | #### Datasets
37 | | Resource | Description |
38 | |:---------|:------------|
39 | | [rank1-r1-msmarco](https://huggingface.co/datasets/jhu-clsp/rank1-R1-MSMARCO) | All R1 output examples from MS MARCO |
40 | | [rank1-training-data](https://huggingface.co/datasets/jhu-clsp/rank1-training-data) | Training data used for rank1 models |
41 | | [rank1-run-files](https://huggingface.co/datasets/jhu-clsp/rank1-Run-Files) | Pre-computed run files for use in top 100 doc reranking |
42 |
43 | ## Installation
44 | To reproduce the experiments, use the following commands (with uv for fast, reliable dependency management):
45 |
46 | ```bash
47 | git clone https://github.com/orionw/rank1.git
48 | cd rank1/
49 | git submodule update --init --recursive
50 |
51 | # Install uv if you don't have it already
52 | curl -LsSf https://astral.sh/uv/install.sh | sh
53 |
54 | # Create and activate virtual environment with uv
55 | uv venv env --python=3.10
56 | source env/bin/activate
57 |
58 | # Install dependencies with uv
59 | uv pip install -r requirements.txt
60 | uv pip install -e mteb_branch/
61 | uv pip install --no-build-isolation xformers==0.0.28.post3
62 | uv pip install vllm==0.7.2
63 |
64 | # Recommended: download a flash attention wheel from https://github.com/Dao-AILab/flash-attention/releases and `uv pip install` it
65 | # wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
66 | # uv pip install flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
67 |
68 | # Download the Rank1-Run-Files repository (required for evaluation)
69 | git lfs install # if you don't have it already
70 | git clone https://huggingface.co/datasets/jhu-clsp/rank1-run-files
71 | ```
72 |
73 | ## Usage
74 | ### Tips
75 | **Reproducibility** Depending on your evaluation batch size, you will see slightly different results due to non-determinism in vLLM. For our experiments we processed all instances in one batch (e.g. `batch_size=99999999999`). We also found that the `enforce_eager` flag sped up inference for the smaller models but not for the larger ones.
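76 |
77 | For example, a minimal sketch of the single-batch setup (the constructor default is already a very large batch size):
78 |
79 | ```python
80 | from rank1 import rank1
81 |
82 | # Hand everything to vLLM in one batch and let it schedule internally
83 | model = rank1(
84 |     model_name_or_path="jhu-clsp/rank1-7b",
85 |     batch_size=99999999999,  # effectively "no batching on our side"
86 |     num_gpus=1,
87 | )
88 | ```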
76 |
77 | **Adapting to New Tasks** You may want to use these models on tasks where the relevance definition differs from the one used for MS MARCO. For those tasks you will want a custom prompt that tells the model what counts as relevant; you can see examples for various datasets in `prompts.py`, and a minimal sketch follows below.
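78 |
79 | For instance, a minimal sketch (reusing helpers already in this repo) of passing a custom prompt; your query is substituted for `FILL_QUERY_HERE` at inference time:
80 |
81 | ```python
82 | from prompts import get_prompt
83 | from rank1 import rank1
84 |
85 | # Reuse a predefined prompt, or write your own string containing FILL_QUERY_HERE
86 | prompt = get_prompt("BrightRetrieval", "leetcode")
87 |
88 | model = rank1(
89 |     model_name_or_path="jhu-clsp/rank1-7b",
90 |     num_gpus=1,
91 |     dataset_prompt=prompt,
92 | )
93 | ```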
78 |
79 |
80 | ### Running Evaluations
81 | To run an evaluation with the rank1 model on a specific dataset:
82 |
83 | ```bash
84 | bash launch_job.sh jhu-clsp/rank1-7b NevIR default 1
85 | ```
86 |
87 | Parameters:
88 | - `jhu-clsp/rank1-7b`: Model name or path
89 | - `NevIR`: Dataset name
90 | - `default`: Subtask name (use "default" if no subtask)
91 | - `1`: Number of GPUs to use
92 |
93 |
94 | ### Using Rank1 in Your Own Code
95 | You can integrate rank1 into your code:
96 |
97 | ```python
98 | from rank1 import rank1
99 |
100 | # Initialize the model
101 | model = rank1(
102 |     model_name_or_path="jhu-clsp/rank1-7b",
103 | num_gpus=1,
104 | device="cuda",
105 | context_size=16000,
106 | max_output_tokens=8192,
107 | fp_options="float16"
108 | )
109 |
110 | # Rerank documents: predict() takes a list of (query, passage) pairs
111 | results = model.predict([
112 |     ("Your query/prompt here", "Document 1 content"),
113 |     ("Your query/prompt here", "Document 2 content"),
114 | ])
115 | ```
116 |
117 | ### MTEB Integration
118 | Rank1 is compatible with the MTEB benchmarking framework. To evaluate rank1 on an MTEB task:
119 |
120 | ```python
121 | from mteb import MTEB
122 | from rank1 import rank1
123 |
124 | # Initialize your model
125 | model = rank1(
126 | model_name_or_path="jhu-clsp/rank1-7b",
127 | num_gpus=1,
128 | device="cuda"
129 | )
130 |
131 | # Select tasks (or use specific task names)
132 | evaluation = MTEB(tasks=["NevIR"])
133 |
134 | # Run evaluation
135 | results = evaluation.run(model)
136 | ```
137 |
138 | ## Citing
139 | If you use rank1, please cite:
140 |
141 | ```bibtex
142 | @misc{weller2025rank1testtimecomputereranking,
143 | title={Rank1: Test-Time Compute for Reranking in Information Retrieval},
144 | author={Orion Weller and Kathryn Ricci and Eugene Yang and Andrew Yates and Dawn Lawrie and Benjamin Van Durme},
145 | year={2025},
146 | eprint={2502.18418},
147 | archivePrefix={arXiv},
148 | primaryClass={cs.IR},
149 | url={https://arxiv.org/abs/2502.18418},
150 | }
151 | ```
152 |
153 | ## License
154 | [MIT](LICENSE)
155 |
--------------------------------------------------------------------------------
/launch_job.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | sweep_id=$1
5 | # Load Python environment
6 | source env/bin/activate
7 | export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
8 | export VLLM_WORKER_MULTIPROC_METHOD=spawn
9 |
10 | # Print debug information
11 | echo "=== Debug Information ==="
12 | echo "Hostname: $(hostname)"
13 | echo "Current directory: $(pwd)"
14 | echo "Date: $(date)"
15 |
16 | # Print GPU information if nvidia-smi is available
17 | if command -v nvidia-smi &> /dev/null; then
18 | echo -e "\n=== GPU Information ==="
19 | nvidia-smi
20 | else
21 | echo "nvidia-smi not found - no GPU information available"
22 | fi
23 |
24 | # Print memory information
25 | echo -e "\n=== Memory Information ==="
26 | free -h
27 |
28 | # Print CPU information
29 | echo -e "\n=== CPU Information ==="
30 | lscpu | grep "Model name"
31 | lscpu | grep "CPU(s):"
32 |
33 | # Print arguments
34 | echo -e "\n=== Script Arguments ==="
35 | echo "Model: $1"
36 | echo "Dataset: $2"
37 | echo "Subtask: $3"
38 | echo "Number of GPUs: $4"
39 |
40 | echo -e "\n=== Environment ==="
41 | echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
42 | echo "PATH: $PATH"
43 | echo "PYTHONPATH: $PYTHONPATH"
44 |
45 | echo -e "=== End Debug Information ===\n"
46 |
47 |
48 | model=$1
49 | dataset=$2
50 | subtask=$3
51 | num_gpus=$4
52 |
53 | # if num_gpus is not provided, then set it to 1
54 | if [ -z "$num_gpus" ]; then
55 | num_gpus=1
56 | fi
57 |
58 | # if subtask is "default" ignore it
59 | echo "$(which python)"
60 | if [ "$subtask" != "default" ]; then
61 | echo "Running with subtask: $subtask"
62 | python run_mteb.py -m $model -d $dataset -s $subtask -n $num_gpus
63 | else
64 | echo "Running with no subtask"
65 | python run_mteb.py -m $model -d $dataset -n $num_gpus
66 | fi
67 |
68 | # example: bash launch_job.sh jhu-clsp/rank1-7b BrightRetrieval biology 1
69 |
--------------------------------------------------------------------------------
/prompts.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 |
4 | PROMPT_DICT = {
5 | "SciFact": """Claim: FILL_QUERY_HERE
6 |
7 | A relevant passage would provide evidence that either **supports** or **refutes** this claim. A passage with any information on any related subpart should be relevant.""",
8 |
9 |     "ClimateFEVER": """I am looking to write an essay and need to find evidence that supports or contradicts this statement:
10 |
11 | FILL_QUERY_HERE
12 |
13 | If the passage provides information that supports or contradicts it in any way it is relevant so I can cite it.
14 |
15 | """,
16 |
17 | "TRECCOVID": """FILL_QUERY_HERE If the article answers any part of the question it is relevant.""",
18 |
19 | "ArguAna": """I am looking to write an essay and need to find counterarguments against this statement:
20 |
21 | FILL_QUERY_HERE
22 |
23 | Does this passage have any counterargument or evidence that could be used to help me?
24 |
25 | """,
26 |
27 | "DBPedia": """I am looking to write an essay on this topic and need as much related background information to help me. The topic is:
28 |
29 | FILL_QUERY_HERE
30 |
31 | If the passage provides any background information that could be connected it is relevant.
32 |
33 | """,
34 |
35 | "FiQA2018": """FILL_QUERY_HERE Find a passage that would be a good answer from StackExchange.""",
36 |
37 | "NFCorpus": """Topic: FILL_QUERY_HERE
38 |
39 | Given the above topic, I need to learn about all aspects of it. It does not need to be directly relevant, only tangentially informational. Please mark as relevant any passages with even weak connections. I need to learn fast for my job, which means I need to understand each part individually.
40 |
41 | Again remember, any connection means relevant even if indirect. So if it is not addressed, that is okay -- it does not need to be explicitly.
42 |
43 | Find me passages with any type of connection, including weak connections!!!!""",
44 |
45 | "Touche2020": """I am looking to write an essay and need to find arguments for or against this statement:
46 |
47 | FILL_QUERY_HERE
48 |
49 | Does this passage have any argument or evidence that could be used to help me?
50 |
51 | """,
52 |
53 | "SCIDOCS": """papers that could be cited in FILL_QUERY_HERE. Anything with even indirect relevance should be relevant. This includes papers in the same broader field of science""",
54 |
55 | "BrightRetrieval_aops": """Find different but similar math problems to FILL_QUERY_HERE\n\nA document is relevant if it uses the same class of functions and shares **any** overlapping techniques.""",
56 |
57 | "BrightRetrieval_theoremqa_questions": """Find a passage which uses the same mathematical process as this one: FILL_QUERY_HERE""",
58 |
59 | "BrightRetrieval_leetcode": """I am looking to find different problems that share similar data structures (of any kind) or algorithms (e.g. DFS, DP, sorting, traversals, etc.). I am looking for problems that share one or both of these similarities to this:
60 |
61 | FILL_QUERY_HERE
62 |
63 | Does this passage share any similarities? e.g. if there was a textbook on leetcode problems, this would be in the same book even though it could be in a different chapter.
64 |
65 |
66 | """,
67 |
68 | "BrightRetrieval_pony": """I will use the programming language pony. Problem: FILL_QUERY_HERE
69 |
70 | But to solve the problem above, I need to know things about pony. A passage is relevant if it contains docs that match **any** part (even basic parts) of the code I will have to write for the above program.""",
71 |
72 | "BrightRetrieval": """Can you find background information about the concepts used to answer the question:
73 |
74 | FILL_QUERY_HERE
75 |
76 | A passage is relevant if it contains background information about a **sub-concept** that someone might cite/link to when answering the above question."""
77 |
78 | }
79 |
80 | PROMPT_DICT["BrightRetrieval_theoremqa_theorems"] = PROMPT_DICT["BrightRetrieval_theoremqa_questions"]
81 |
82 |
83 | def get_prompt(task_name, subtask_name: str = None):
84 | if subtask_name is not None and task_name in PROMPT_DICT:
85 | # if subtask is present, use that, otherwise use just the task name
86 | if f"{task_name}_{subtask_name}" in PROMPT_DICT:
87 | return PROMPT_DICT[f"{task_name}_{subtask_name}"]
88 | else: # default for subtask (e.g. BrightRetrieval)
89 | return PROMPT_DICT[task_name]
90 | elif task_name in PROMPT_DICT:
91 | # no subtask
92 | return PROMPT_DICT[task_name]
93 | else:
94 | return None
95 |
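96 | # Illustrative examples (not executed) of how get_prompt resolves task/subtask names:
97 | #   get_prompt("BrightRetrieval", "leetcode")  -> PROMPT_DICT["BrightRetrieval_leetcode"]
98 | #   get_prompt("BrightRetrieval", "biology")   -> PROMPT_DICT["BrightRetrieval"]  (no subtask-specific prompt)
99 | #   get_prompt("NevIR")                        -> None  (no custom prompt defined)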
96 |
97 | BEIR_DATASETS = [
98 | "ArguAna",
99 | "ClimateFEVER",
100 | "DBPedia",
101 | "FEVER",
102 | "FiQA2018",
103 | "HotpotQA",
104 | "NFCorpus",
105 | "NQ",
106 | "QuoraRetrieval",
107 | "SCIDOCS",
108 | "SciFact",
109 | "TRECCOVID",
110 | "Touche2020",
111 | ]
112 |
113 |
114 | def validate_json(file_path: str) -> bool:
115 | try:
116 | with open(file_path, "r") as f:
117 | data = json.load(f)
118 | # assert there are string keys and that within that a dict of key -> float values
119 | for key in data:
120 | assert isinstance(key, str), f"Key is not a string: {key}"
121 | assert isinstance(data[key], dict), f"Data is not a dict: {data[key]}"
122 | for inner_key, inner_value in data[key].items():
123 | assert isinstance(inner_key, str), f"Inner key is not a string: {inner_key}"
124 | assert isinstance(inner_value, float), f"Inner value is not a float: {inner_value}"
125 | return True
126 | except Exception as e:
127 | print(e)
128 | return False
--------------------------------------------------------------------------------
/rank1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import mteb
3 | from mteb import MTEB
4 | import logging
5 | import os
6 | import json
7 |
8 | from functools import partial
9 | import logging
10 | import math
11 | from typing import Any, Callable, List, Tuple
12 |
13 | import torch
14 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
15 | from vllm import LLM, SamplingParams
16 |
17 | from mteb.encoder_interface import Encoder
18 | from mteb.evaluation.evaluators.RetrievalEvaluator import DenseRetrievalExactSearch
19 | from mteb.model_meta import ModelMeta
20 | from mteb.models.rerankers_custom import RerankerWrapper
21 |
22 |
23 | logging.basicConfig(level=logging.INFO)
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | class rank1(RerankerWrapper):
28 | name: str = "rank1"
29 |
30 | def __init__(
31 | self,
32 | model_name_or_path: str = "jhu-clsp/rank1-7b",
33 | batch_size: int = 999999999999,
34 | context_size: int = 16000,
35 | max_output_tokens: int = 8192,
36 | fp_options: str = "float16",
37 | num_gpus: int = 1,
38 | device: str = "cuda",
39 | force_rethink: int = 0,
40 | dataset_prompt: str = None,
41 | **kwargs,
42 | ):
43 | """
44 | rank1 is a reasoning reranker model (using test-time compute) which generates a reasoning chain before deciding true or false
45 |
46 | Args:
47 | model_name_or_path: Path to the model or name of the model on HuggingFace Hub
48 | batch_size: Maximum batch size for processing (default: very large number to let vLLM handle batching)
49 |             context_size: Maximum context length for the model (default: 16000)
50 |             max_output_tokens: Maximum number of tokens to generate (default: 8192)
51 | fp_options: Floating point precision to use, e.g. 'float16' (default: 'float16')
52 | num_gpus: Number of GPUs to use for tensor parallelism (default: 1)
53 | device: Device to load the model on (default: 'cuda')
54 | force_rethink: Number of times to force model to rethink its answer (default: 0)
55 | **kwargs: Additional keyword arguments passed to parent RerankerWrapper
56 | """
57 | super().__init__(model_name_or_path, batch_size=batch_size, fp_options=fp_options, **kwargs)
58 |
59 | self.context_size = context_size
60 | self.max_output_tokens = max_output_tokens
61 | self.num_gpus = num_gpus
62 | self.device = device
63 | self.force_rethink = force_rethink
64 | self.model_name_or_path = model_name_or_path
65 | self.dataset_prompt = dataset_prompt
66 |
67 |         # Initialize the tokenizer (left padding; pad token set to EOS)
68 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
69 | self.tokenizer.padding_side = "left"
70 | self.tokenizer.pad_token = self.tokenizer.eos_token
71 |
72 | # Cache commonly used token IDs
73 | self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
74 | self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]
75 |         self.think_token = self.tokenizer("<think>", add_special_tokens=False).input_ids[0]
76 |         self.think_end_token = self.tokenizer("</think>", add_special_tokens=False).input_ids[-1]
77 |
78 | self.model = LLM(
79 | model=model_name_or_path,
80 | tensor_parallel_size=int(num_gpus),
81 | trust_remote_code=True,
82 | max_model_len=context_size,
83 | gpu_memory_utilization=0.9,
84 | dtype=fp_options,
85 | )
86 | self.sampling_params = SamplingParams(
87 | temperature=0,
88 | max_tokens=max_output_tokens,
89 | logprobs=20,
90 |             stop=["</think> true", "</think> false"],
91 | skip_special_tokens=False
92 | )
93 |
94 | def _fix_incomplete_responses(
95 | self,
96 | original_prompts: List[str],
97 | generated_texts: List[str]
98 | ) -> Tuple[List[str], List[int], List[float]]:
99 | """
100 | This function is used to fix incomplete responses from the vLLM model. In some cases the model does not generate the end token.
101 | In these cases, we should force it to generate it so that we have some prediction.
102 |
103 | Args:
104 | original_prompts: The original prompts that were used to generate the texts
105 | generated_texts: The texts that were generated by the vLLM model
106 |
107 | Returns:
108 | final_texts: The texts that were generated by the vLLM model + the outputs from the forcing step
109 | token_counts: The number of tokens in the texts total
110 | scores: The scores of the texts
111 | """
112 | cleaned_texts = []
113 | for text in generated_texts:
114 | text = text.rstrip()
115 | if not text.endswith(('.', '!', '?')):
116 | last_punct = max(text.rfind('.'), text.rfind('!'), text.rfind('?'))
117 | if last_punct != -1:
118 | text = text[:last_punct + 1]
119 | cleaned_texts.append(text.strip())
120 |
121 | forced_prompts = [
122 |             f"{original_prompt}\n{cleaned_text}\n</think>"
123 | for original_prompt, cleaned_text in zip(original_prompts, cleaned_texts)
124 | ]
125 |
126 | new_sampling_args = SamplingParams(
127 | temperature=0,
128 | max_tokens=1,
129 | logprobs=20,
130 | allowed_token_ids=[self.true_token, self.false_token],
131 | skip_special_tokens=False
132 | )
133 | outputs = self.model.generate(forced_prompts, new_sampling_args)
134 |
135 | # get the next token logits of just the next token
136 | all_final_texts = []
137 | all_token_counts = []
138 | all_scores = []
139 | for i in range(len(outputs)):
140 | try:
141 | text = outputs[i].outputs[0].text
142 | final_logits = outputs[i].outputs[0].logprobs[-1]
143 | assert self.false_token in final_logits and self.true_token in final_logits, f"final logits are missing true or false: {final_logits}"
144 | except Exception as e:
145 | print(f"Error: {e} on fixing error, setting at 0.5 score: {outputs[i].outputs}")
146 | all_scores.append(0.5)
147 | all_token_counts.append(len(outputs[i].outputs[0].token_ids))
148 | all_final_texts.append(text)
149 | continue
150 |
151 | token_count = len(outputs[i].outputs[0].token_ids)
152 | true_logit = final_logits[self.true_token].logprob
153 | false_logit = final_logits[self.false_token].logprob
154 | true_score = math.exp(true_logit)
155 | false_score = math.exp(false_logit)
156 | score = true_score / (true_score + false_score)
157 |
158 | all_final_texts.append(text)
159 | all_token_counts.append(token_count)
160 | all_scores.append(score)
161 |
162 | return all_final_texts, all_token_counts, all_scores
163 |
164 | def truncate_input(self, text: str) -> str:
165 | """
166 | Truncate the input text to the context size. This is not used, except if you are using the Llama 8B quantized model
167 | """
168 | if len(self.tokenizer(text)["input_ids"]) >= self.context_size:
169 | return self.tokenizer.decode(self.tokenizer(text)["input_ids"][:self.context_size])
170 | else:
171 | return text
172 |
173 | def _process_with_vllm(self, prompts):
174 | """
175 | vLLM is significantly faster than HF, so we use it by default. This function handles the cases where the model does not generate the end token.
176 |
177 | Args:
178 | prompts: The prompts to generate from
179 |
180 | Returns:
181 | outputs: The outputs from the vLLM model
182 | """
183 | # prompts = [self.truncate_input(prompt) for prompt in prompts]
184 | outputs = self.model.generate(prompts, self.sampling_params)
185 |
186 | # Pre-allocate lists with None values
187 | total_length = len(prompts)
188 | all_outputs = [None] * total_length
189 | all_output_token_counts = [None] * total_length
190 | all_scores = [None] * total_length
191 |
192 | incomplete_prompts = []
193 | incomplete_texts = []
194 | incomplete_indices = []
195 |
196 | # Process complete responses first
197 | for i, output in enumerate(outputs):
198 | text = output.outputs[0].text
199 | try:
200 | final_logits = output.outputs[0].logprobs[-1]
201 | except Exception as e:
202 | print(f"Error: {e} on getting final logits: {output.outputs[0]}")
203 | incomplete_prompts.append(prompts[i])
204 | incomplete_texts.append(text)
205 | incomplete_indices.append(i)
206 | continue
207 |
208 | if self.true_token not in final_logits or self.false_token not in final_logits:
209 | incomplete_prompts.append(prompts[i])
210 | incomplete_texts.append(text)
211 | incomplete_indices.append(i)
212 | continue
213 |
214 | token_count = len(output.outputs[0].token_ids)
215 | true_logit = final_logits[self.true_token].logprob
216 | false_logit = final_logits[self.false_token].logprob
217 | true_score = math.exp(true_logit)
218 | false_score = math.exp(false_logit)
219 | score = true_score / (true_score + false_score)
220 |
221 | all_outputs[i] = text
222 | all_output_token_counts[i] = token_count
223 | all_scores[i] = score
224 |
225 | # Handle incomplete responses
226 | if incomplete_indices:
227 | fixed_texts, fixed_counts, fixed_scores = self._fix_incomplete_responses(
228 | incomplete_prompts, incomplete_texts
229 | )
230 |
231 | # Fill in the fixed responses at their original positions
232 | for orig_idx, (text, count, score) in zip(
233 | incomplete_indices, zip(fixed_texts, fixed_counts, fixed_scores)
234 | ):
235 | all_outputs[orig_idx] = text
236 | all_output_token_counts[orig_idx] = count
237 | all_scores[orig_idx] = score
238 |
239 | return all_outputs, all_output_token_counts, all_scores
240 |
241 | def return_prompt(self, query, doc_content, prompt) -> str:
242 | query = prompt.replace("FILL_QUERY_HERE", query) if prompt else query
243 | return "Determine if the following passage is relevant to the query. " \
244 | "Answer only with 'true' or 'false'.\n" \
245 | f"Query: {query}\n" \
246 | f"Passage: {doc_content}\n" \
247 |             "<think>" # force the model to start with this
248 |
249 |     def _prepare_prompts_for_rethink(self, prompts: List[str], texts: List[str], rethink_text: str = "Wait") -> Tuple[List[str], List[str]]:
250 | """Prepare prompts for the rethinking step."""
251 | full_texts = [p + t for p, t in zip(prompts, texts)]
252 |         stripped_texts = [t.split("</think>")[0] for t in full_texts]
253 |         just_generated_texts = [t.split("<think>")[-1] for t in full_texts]  # text generated after the <think> token
254 | return [s + f"\n{rethink_text}" for s in stripped_texts], just_generated_texts
255 |
256 | @torch.inference_mode()
257 | def predict(self, input_to_rerank, **kwargs):
258 |         """This is set up to run with MTEB, but can be adapted to your own pipeline"""
259 | inputs = list(zip(*input_to_rerank))
260 | if len(input_to_rerank[0]) == 2:
261 | queries, passages = inputs
262 | instructions = None
263 | else:
264 | queries, passages, instructions = inputs
265 |
266 | if instructions is not None and instructions[0] is not None:
267 | queries = [f"{q} {i}".strip() if q.strip() != i.strip() else q.strip() for i, q in zip(instructions, queries)]
268 |
269 | if isinstance(passages[0], dict):
270 | passages = [f"{v['title']} {v['text']}" if 'title' in v else v['text'] for v in passages]
271 |
272 | prompts = [
273 | self.return_prompt(query, passage, self.dataset_prompt)
274 | for query, passage in zip(queries, passages)
275 | ]
276 | print(f"Example prompt: ```\n{prompts[0]}\n```")
277 |
278 | texts, token_counts, scores = self._process_with_vllm(prompts)
279 | while self.force_rethink:
280 | revised_prompts, previously_generated_texts = self._prepare_prompts_for_rethink(prompts, texts)
281 | new_texts, new_token_counts, new_scores = self._process_with_vllm(revised_prompts)
282 | # add to the previous output
283 |             texts = [prev + "\nWait" + new_text for prev, new_text in zip(texts, new_texts)]  # "Wait" matches the default rethink_text
284 | scores = new_scores
285 | token_counts = [prev_token_count + new_token_count for prev_token_count, new_token_count in zip(token_counts, new_token_counts)]
286 | self.force_rethink -= 1
287 |
288 | return scores
289 |
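290 | # Minimal usage sketch (illustrative; mirrors the README) -- predict() takes (query, passage) pairs:
291 | #   model = rank1(model_name_or_path="jhu-clsp/rank1-7b", num_gpus=1, fp_options="float16")
292 | #   scores = model.predict([("what is python?", "Python is a programming language."),
293 | #                           ("what is python?", "Snakes are reptiles.")])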
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.3.0
2 | bm25s==0.2.7.post1
3 | datasets==3.2.0
4 | einops==0.8.1
5 | gritlm==1.0.2
6 | huggingface-hub==0.28.1
7 | numpy==1.26.4
8 | pystemmer==2.2.0.3
9 | scikit-learn==1.6.1
10 | scipy==1.13.1
11 | sentence-transformers==3.4.1
12 | sentencepiece==0.2.0
13 | tiktoken==0.7.0
14 | tokenizers==0.21.0
15 | tomlkit==0.12.0
16 | torch==2.5.1
17 | torchaudio==2.5.1
18 | torchvision==0.20.1
19 | tqdm==4.67.1
20 | transformers==4.48.2
21 | peft
22 |
--------------------------------------------------------------------------------
/run_mteb.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import mteb
5 | from mteb import MTEB
6 | import logging
7 | import os
8 | import json
9 |
10 | from functools import partial
11 | import logging
12 | import math
13 | from typing import Any, Callable, List, Tuple
14 |
15 | from mteb.encoder_interface import Encoder
16 | from mteb.evaluation.evaluators.RetrievalEvaluator import DenseRetrievalExactSearch
17 | from mteb.model_meta import ModelMeta
18 | from mteb.models.rerankers_custom import RerankerWrapper
19 |
20 | from prompts import get_prompt, PROMPT_DICT, validate_json, BEIR_DATASETS
21 | from rank1 import rank1
22 |
23 | logging.basicConfig(level=logging.INFO)
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def get_safe_folder_name(model_name):
28 | return model_name.replace("/", "_").replace("\\", "_")
29 |
30 | def run_evaluation(dataset_name: str, subtask: str, model_name: str, num_gpus: int, skip_prompt: bool) -> None:
31 | """Run MTEB evaluation for specified model and dataset
32 |
33 | Args:
34 | dataset_name: Name of MTEB dataset to evaluate
35 | """
36 | # Initialize MTEB task and evaluation
37 | tasks = mteb.get_tasks(tasks=[dataset_name])
38 | evaluation = MTEB(tasks=tasks)
39 |
40 | if dataset_name == "BrightRetrieval":
41 | previous_results = f"rank1-run-files/{subtask}_bm25_long_False/score.json"
42 | eval_splits = ["standard"]
43 | elif dataset_name == "mFollowIRCrossLingual":
44 | previous_results = f"rank1-run-files/mFollowIRCrossLingual_{subtask}_predictions.json"
45 | eval_splits = None
46 | elif dataset_name == "mFollowIR":
47 | previous_results = f"rank1-run-files/mFollowIR_{subtask}_predictions.json"
48 | eval_splits = None
49 | elif dataset_name in BEIR_DATASETS:
50 | eval_splits = ["test"]
51 | previous_results = f"rank1-run-files/{dataset_name}_default_predictions.json"
52 | else:
53 | print(f"Running with no subtask or eval splits for dataset: {dataset_name}")
54 | previous_results = None
55 | eval_splits = None
56 |
57 | encode_kwargs = {
58 | # use vLLM to batch
59 | "batch_size": 999999 if "rank1" in model_name.lower() else 32
60 | }
61 |
62 | prompt = get_prompt(dataset_name, subtask)
63 | if prompt is not None and not skip_prompt:
64 | is_prompted = True
65 | else:
66 | is_prompted = False
67 | prompt = None
68 | print(f"Prompt: {prompt}")
69 |
70 | if subtask == "default":
71 | subtask = None
72 |
73 | if previous_results is not None:
74 | assert validate_json(previous_results), f"Previous results are not valid json: {previous_results}"
75 | print(f"Previous results: {previous_results}")
76 |
77 | model = rank1(model_name_or_path=model_name.strip(), num_gpus=num_gpus, dataset_prompt=prompt)
78 | output_dir = f"results/{model_name}/{dataset_name}_{subtask}"
79 | print(f"Output directory: {output_dir}")
80 |
81 | # Run evaluation
82 | evaluation.run(
83 | model,
84 | save_predictions=True,
85 | encode_kwargs=encode_kwargs,
86 | output_folder=output_dir,
87 | previous_results=previous_results,
88 | eval_subsets=[subtask] if previous_results else None,
89 | eval_splits=eval_splits,
90 | top_k=100
91 | )
92 |
93 |
94 | def main():
95 | parser = argparse.ArgumentParser(description="Run MTEB evaluation")
96 | parser.add_argument("-d", "--dataset", required=True, help="MTEB dataset name")
97 | parser.add_argument("-s", "--subtask", required=False, default=None, help="MTEB subtask name")
98 | parser.add_argument("-m", "--model_name", required=True, help="Model name")
99 | parser.add_argument("-n", "--num_gpus", required=False, help="Number of GPUs", default=1)
100 | parser.add_argument("-p", "--skip_prompt", action="store_true", help="Skip prompt")
101 | args = parser.parse_args()
102 | run_evaluation(args.dataset.strip(), args.subtask.strip() if args.subtask is not None else None, args.model_name.strip(), args.num_gpus, args.skip_prompt)
103 |
104 |
105 | if __name__ == "__main__":
106 | os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
107 | main()
108 |
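109 | # Example invocations (illustrative; mirror launch_job.sh):
110 | #   python run_mteb.py -m jhu-clsp/rank1-7b -d NevIR -n 1
111 | #   python run_mteb.py -m jhu-clsp/rank1-7b -d BrightRetrieval -s biology -n 1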
--------------------------------------------------------------------------------
/train_configs/README.md:
--------------------------------------------------------------------------------
1 | # Rank1 Model Training Configurations
2 |
3 | This directory contains configuration files for training and exporting Rank1 models using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
4 |
5 | ## Getting Started
6 |
7 | ### 1. Setup Training Data
8 |
9 | We use the data in `jhu-clsp/rank1-training-data` for training. Follow these steps to set up:
10 |
11 | 1. Download the training data and save it as `train.json` in LLaMA-Factory's `data` folder
12 | 2. Update the `dataset_info.json` file to reference your data:
13 |
14 | ```json
15 | {
16 | "r1_reranking_self_filtered_negs_only": {
17 | "file_name": "train.json"
18 | }
19 | }
20 | ```
21 |
22 | ### 2. Available Training Configurations
23 |
24 | The following configuration files are available:
25 |
26 | - `train_lora_llama.yaml` - LoRA training configuration for Llama-3 8B
27 | - `train_lora_mistral.yaml` - LoRA training configuration for Mistral 24B
28 | - `train_lora_qwen_7b.yaml` - LoRA training configuration for Qwen 7B
29 | - `train_lora_qwen_14b.yaml` - LoRA training configuration for Qwen 14B
30 | - `train_lora_qwen_32b.yaml` - LoRA training configuration for Qwen 32B
31 |
32 | ### 3. Training Process
33 |
34 | After installing LLaMA-Factory requirements, start training with:
35 |
36 | ```bash
37 | llamafactory-cli train train_configs/train_lora_XXX.yaml
38 | ```
39 |
40 | Replace `XXX` with the specific model you want to train (e.g., `llama`, `mistral`, `qwen_7b`).
41 |
42 | ### 4. Exporting Model Weights
43 |
44 | After training, export the merged LoRA weights using:
45 |
46 | ```bash
47 | llamafactory-cli export train_configs/export_model.yaml
48 | ```
49 |
50 | Before running the export command, modify the `export_model.yaml` file with your specific values:
51 |
52 | ```yaml
53 | model_name_or_path: BASE_MODEL_HERE # Path to the base model
54 | adapter_name_or_path: INPUT_DIR # Path to the trained adapter
55 | export_dir: EXPORT_DIR # Where to save the exported model
56 | ```
57 |
58 | ## Configuration Parameters
59 |
60 | Each YAML file contains configuration sections for:
61 | - Model settings
62 | - Training method (LoRA parameters)
63 | - Dataset configuration
64 | - Output settings
65 | - Training hyperparameters
66 | - Evaluation settings
67 |
68 | Refer to each specific YAML file for detailed configuration options.
69 |
70 |
--------------------------------------------------------------------------------
/train_configs/export_model.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: BASE_MODEL_HERE
3 | adapter_name_or_path: INPUT_DIR
4 | template: empty
5 | finetuning_type: lora
6 | trust_remote_code: true
7 |
8 | ### export
9 | export_dir: EXPORT_DIR
10 | export_size: 10
11 | export_device: cpu
12 | export_legacy_format: false
13 | new_special_tokens: "<think>,</think>"
14 | resize_vocab: true
--------------------------------------------------------------------------------
/train_configs/train_lora_llama.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Meta-Llama-3-8B
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 32
10 | lora_alpha: 64
11 | lora_target: all
12 |
13 | ### dataset
14 | dataset: r1_reranking_self_filtered_negs_only
15 | template: empty
16 | cutoff_len: 2500
17 | overwrite_cache: false
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: saves/llama-8b/lora/sft
22 | logging_steps: 10
23 | plot_loss: true
24 | overwrite_output_dir: true
25 |
26 | ### train
27 | per_device_train_batch_size: 4
28 | gradient_accumulation_steps: 8
29 | learning_rate: 1.0e-4
30 | num_train_epochs: 2.0
31 | lr_scheduler_type: cosine
32 | warmup_ratio: 0.05
33 | bf16: true
34 | ddp_timeout: 180000000
35 | new_special_tokens: "<think>,</think>"
36 | resize_vocab: true
37 |
38 | ### eval
39 | val_size: 0.01
40 | per_device_eval_batch_size: 1
41 | eval_strategy: steps
42 | eval_steps: 999999999999
43 | save_steps: 250
44 | seed: 12345
45 |
46 | report_to: wandb
47 | run_name: llama_8b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_mistral.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: mistralai/Mistral-Small-24B-Base-2501
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 32
10 | lora_alpha: 64
11 | lora_target: all
12 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
13 |
14 | ### dataset
15 | dataset: r1_reranking_self_filtered_negs_only
16 | template: empty
17 | cutoff_len: 2500
18 | overwrite_cache: false
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/mistral-24b/lora/sft
23 | logging_steps: 10
24 | plot_loss: true
25 | overwrite_output_dir: true
26 |
27 | ### train
28 | per_device_train_batch_size: 2
29 | gradient_accumulation_steps: 16
30 | learning_rate: 1.0e-4
31 | num_train_epochs: 2.0
32 | lr_scheduler_type: cosine
33 | warmup_ratio: 0.05
34 | bf16: true
35 | ddp_timeout: 180000000
36 | new_special_tokens: "<think>,</think>"
37 | resize_vocab: true
38 |
39 | ### eval
40 | val_size: 0.01
41 | per_device_eval_batch_size: 1
42 | eval_strategy: steps
43 | eval_steps: 999999999999
44 | save_steps: 250
45 | seed: 12345
46 |
47 | report_to: wandb
48 | run_name: mistral_24b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_qwen_14b.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-14B
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 32
10 | lora_alpha: 64
11 | lora_target: all
12 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
13 |
14 | ### dataset
15 | dataset: r1_reranking_self_filtered_negs_only
16 | template: empty
17 | cutoff_len: 2500
18 | overwrite_cache: false
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/qwen-14b/lora/sft
23 | logging_steps: 10
24 | plot_loss: true
25 | overwrite_output_dir: true
26 |
27 | ### train
28 | per_device_train_batch_size: 2
29 | gradient_accumulation_steps: 16
30 | learning_rate: 1.0e-4
31 | num_train_epochs: 2.0
32 | lr_scheduler_type: cosine
33 | warmup_ratio: 0.05
34 | bf16: true
35 | ddp_timeout: 180000000
36 | new_special_tokens: "<think>,</think>"
37 | resize_vocab: true
38 |
39 | ### eval
40 | val_size: 0.01
41 | per_device_eval_batch_size: 1
42 | eval_strategy: steps
43 | eval_steps: 999999999999
44 | save_steps: 250
45 | seed: 12345
46 |
47 | report_to: wandb
48 | run_name: qwen_14b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_qwen_32b.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-32B
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 32
10 | lora_alpha: 64
11 | lora_target: all
12 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
13 |
14 | ### dataset
15 | dataset: r1_reranking_self_filtered_negs_only
16 | template: empty
17 | cutoff_len: 2500
18 | overwrite_cache: false
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/qwen-32b/lora/sft
23 | logging_steps: 10
24 | plot_loss: true
25 | overwrite_output_dir: true
26 |
27 | ### train
28 | per_device_train_batch_size: 1
29 | gradient_accumulation_steps: 32
30 | learning_rate: 1.0e-4
31 | num_train_epochs: 2.0
32 | lr_scheduler_type: cosine
33 | warmup_ratio: 0.05
34 | bf16: true
35 | ddp_timeout: 180000000
36 | new_special_tokens: "<think>,</think>"
37 | resize_vocab: true
38 |
39 | ### eval
40 | val_size: 0.01
41 | per_device_eval_batch_size: 1
42 | eval_strategy: steps
43 | eval_steps: 999999999999
44 | save_steps: 250
45 | seed: 12345
46 |
47 | report_to: wandb
48 | run_name: qwen_32b_lora
--------------------------------------------------------------------------------
/train_configs/train_lora_qwen_7b.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-7B
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 32
10 | lora_alpha: 64
11 | lora_target: all
12 |
13 | ### dataset
14 | dataset: r1_reranking_self_filtered_negs_only
15 | template: empty
16 | cutoff_len: 2500
17 | overwrite_cache: false
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: saves/qwen-7b/lora/sft
22 | logging_steps: 10
23 | plot_loss: true
24 | overwrite_output_dir: true
25 |
26 | ### train
27 | per_device_train_batch_size: 4
28 | gradient_accumulation_steps: 8
29 | learning_rate: 1.0e-4
30 | num_train_epochs: 2.0
31 | lr_scheduler_type: cosine
32 | warmup_ratio: 0.05
33 | bf16: true
34 | ddp_timeout: 180000000
35 | new_special_tokens: "<think>,</think>"
36 | resize_vocab: true
37 |
38 | ### eval
39 | val_size: 0.01
40 | per_device_eval_batch_size: 1
41 | eval_strategy: steps
42 | eval_steps: 999999999999
43 | save_steps: 250
44 | seed: 12345
45 |
46 | report_to: wandb
47 | run_name: qwen_7b_lora
--------------------------------------------------------------------------------