├── llm2vec ├── __init__.py ├── dataset │ ├── __init__.py │ ├── utils.py │ ├── dataset.py │ └── E5Data.py ├── loss │ ├── __init__.py │ ├── utils.py │ ├── HardNegativeNLLLoss.py │ └── loss_utils.py ├── experiment_utils.py └── llm2vec.py ├── images └── main_new-1.png ├── requirements.txt ├── scripts ├── mteb_eval.py └── run_supervised.py ├── train_config ├── Minicpm2-2B.json └── Minicpm3-4B.json ├── test_config └── mteb │ └── task_to_instructions.json ├── llm2vec_models.py └── README.md /llm2vec/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm2vec/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .E5Data import E5Data 2 | -------------------------------------------------------------------------------- /llm2vec/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .HardNegativeNLLLoss import HardNegativeNLLLoss 2 | -------------------------------------------------------------------------------- /images/main_new-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBMB/DEBATER/HEAD/images/main_new-1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | torch 4 | peft 5 | transformers>=4.43.1,<=4.44.2 6 | datasets 7 | evaluate 8 | scikit-learn 9 | mteb>=1.14.12 -------------------------------------------------------------------------------- /llm2vec/loss/utils.py: -------------------------------------------------------------------------------- 1 | from .HardNegativeNLLLoss import HardNegativeNLLLoss 2 | 3 | 4 | def load_loss(loss_class, *args, **kwargs): 5 | if loss_class == "HardNegativeNLLLoss": 6 | loss_cls = HardNegativeNLLLoss 7 | else: 8 | raise ValueError(f"Unknown loss class {loss_class}") 9 | return loss_cls(*args, **kwargs) 10 | -------------------------------------------------------------------------------- /scripts/mteb_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | import mteb 5 | from mteb import MTEB 6 | 7 | logging.basicConfig(level=logging.WARNING) 8 | logger = logging.getLogger(__name__) 9 | 10 | if __name__ == "__main__": 11 | tasks = mteb.get_tasks(tasks=["ArguAna"], languages=["eng"]) 12 | 13 | evaluation = MTEB(tasks=tasks) 14 | model_kwargs = {} 15 | with open("../test_config/mteb/task_to_instructions.json", "r") as f: 16 | task_to_instructions = json.load(f) 17 | model_kwargs["task_to_instructions"] = task_to_instructions 18 | 19 | model = mteb.get_model("xxxxx", **model_kwargs) # Same name as defined in llm2vec_models.py 20 | 21 | evaluation.run(model, output_folder="xxxxx",eval_splits=["test"]) 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /llm2vec/dataset/utils.py: -------------------------------------------------------------------------------- 1 | from ..dataset import E5Data 2 | 3 | def load_dataset(dataset_name, split="validation", file_path=None, **kwargs): 4 | """ 5 | Loads a dataset by name. 6 | 7 | Args: 8 | dataset_name (str): Name of the dataset to load. 9 | split (str): Split of the dataset to load. 
10 | file_path (str): Path to the dataset file. 11 | """ 12 | dataset_mapping = { 13 | "E5": E5Data, 14 | } 15 | 16 | if dataset_name not in dataset_mapping: 17 | raise NotImplementedError(f"Dataset name {dataset_name} not supported.") 18 | 19 | if split not in ["train", "validation", "test"]: 20 | raise NotImplementedError(f"Split {split} not supported.") 21 | 22 | return dataset_mapping[dataset_name](split=split, file_path=file_path, **kwargs) 23 | -------------------------------------------------------------------------------- /train_config/Minicpm2-2B.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "xxxx", 3 | "peft_model_name_or_path": null, 4 | "pooling_mode": "M_token", 5 | "dataset_name": "E5", 6 | "dataset_file_path": "xxxx/data/echo-data", 7 | "remove_unused_columns": false, 8 | "learning_rate": 2e-4, 9 | "num_train_epochs": 1, 10 | "warmup_steps": 300, 11 | "per_device_train_batch_size": 64, 12 | "per_device_eval_batch_size": 64, 13 | "gradient_accumulation_steps": 1, 14 | "do_train": true, 15 | "disable_tqdm": false, 16 | "max_seq_length": 512, 17 | "overwrite_output_dir": true, 18 | "output_dir": "xxxxxx", 19 | "logging_steps": 50, 20 | "save_steps": 200, 21 | "save_only_model": true, 22 | "stop_after_n_steps": 1000, 23 | "lora_r": 16, 24 | "gradient_checkpointing": true, 25 | "torch_dtype": "bfloat16", 26 | "attn_implementation": "flash_attention_2", 27 | "seed": 42 28 | } -------------------------------------------------------------------------------- /train_config/Minicpm3-4B.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "xxxx", 3 | "peft_model_name_or_path": null, 4 | "pooling_mode": "M_token", 5 | "dataset_name": "E5", 6 | "dataset_file_path": "xxxx/data/echo-data", 7 | "remove_unused_columns": false, 8 | "learning_rate": 2e-4, 9 | "num_train_epochs": 1, 10 | "warmup_steps": 300, 11 | "per_device_train_batch_size": 64, 12 | "per_device_eval_batch_size": 64, 13 | "gradient_accumulation_steps": 1, 14 | "do_train": true, 15 | "disable_tqdm": false, 16 | "max_seq_length": 512, 17 | "overwrite_output_dir": true, 18 | "output_dir": "xxxxxx", 19 | "logging_steps": 50, 20 | "save_steps": 200, 21 | "save_only_model": true, 22 | "stop_after_n_steps": 1000, 23 | "lora_r": 16, 24 | "gradient_checkpointing": true, 25 | "torch_dtype": "bfloat16", 26 | "attn_implementation": "flash_attention_2", 27 | "seed": 42 28 | } -------------------------------------------------------------------------------- /llm2vec/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union, List 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class DataSample: 9 | id_: int 10 | query: str 11 | positive: str 12 | negative: str = None 13 | task_name: str = None 14 | 15 | 16 | class TrainSample: 17 | """ 18 | Structure for one input example with texts, the label and a unique id 19 | """ 20 | 21 | def __init__( 22 | self, guid: str = "", texts: List[str] = None, label: Union[int, float] = 0 23 | ): 24 | """ 25 | Creates one TrainSample with the given texts, guid and label 26 | 27 | 28 | :param guid 29 | id for the example 30 | :param texts 31 | the texts for the example. 
32 | :param label 33 | the label for the example 34 | """ 35 | self.guid = guid 36 | self.texts = texts 37 | self.label = label 38 | 39 | def __str__(self): 40 | return " label: {}, texts: {}".format( 41 | str(self.label), "; ".join(self.texts) 42 | ) 43 | 44 | 45 | class Dataset(torch.utils.data.Dataset): 46 | def load_data(self, file_path: str = None): 47 | raise NotImplementedError() 48 | 49 | def __getitem__(self, index): 50 | raise NotImplementedError() 51 | 52 | def __len__(self): 53 | raise NotImplementedError() 54 | -------------------------------------------------------------------------------- /llm2vec/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def generate_experiment_id( 4 | name, 5 | split, 6 | model_name, 7 | pooling_mode, 8 | train_batch_size, 9 | max_seq_length, 10 | epochs, 11 | seed, 12 | warmup_steps, 13 | lr, 14 | lora_r, 15 | ): 16 | experiment_id = name + "_" + split 17 | 18 | if isinstance(model_name, str): 19 | experiment_id += f"_m-{model_name}" 20 | if isinstance(pooling_mode, str): 21 | experiment_id += f"_p-{pooling_mode}" 22 | if isinstance(train_batch_size, int): 23 | experiment_id += f"_b-{train_batch_size}" 24 | if isinstance(max_seq_length, int): 25 | experiment_id += f"_l-{max_seq_length}" 26 | if isinstance(epochs, int): 27 | experiment_id += f"_e-{epochs}" 28 | if isinstance(seed, int): 29 | experiment_id += f"_s-{seed}" 30 | if isinstance(warmup_steps, int): 31 | experiment_id += f"_w-{warmup_steps}" 32 | if isinstance(lr, float): 33 | experiment_id += f"_lr-{lr}" 34 | if isinstance(lora_r, int): 35 | experiment_id += f"_lora_r-{lora_r}" 36 | 37 | return experiment_id 38 | 39 | 40 | def parse_experiment_id(experiment_id): 41 | """ 42 | Parses experiment identifier into key-value pairs. 43 | 44 | Args: 45 | experiment_id (str): Unique experiment identifier to parse. 46 | 47 | Returns: 48 | dict: Dictionary containing the parsed key-value pairs. 
49 |     """
50 |     regex, post_regex = "", ""
51 |     if "/" in experiment_id:
52 |         regex = "([A-Za-z0-9-_./]*)/"
53 |         post_regex = "/([A-Za-z0-9-_./]*)"
54 |     regex += "([A-Za-z0-9-_.]+)"
55 |     regex += "_m-([A-Z-a-z0-9-_.]+)"
56 |     regex += "_p-([A-Z-a-z0-9-_.]+)"
57 |     regex += "_b-(\d+)"
58 |     regex += "_l-(\d+)"
59 |     regex += "_e-(\d+)"
60 |     regex += "_s-(\d+)"
61 |     regex += "_w-(\d+)"
62 |     regex += "_lr-([A-Z-a-z0-9-_.]+)"
63 |     regex += "_lora_r-(\d+)"
64 |     regex += post_regex
65 | 
66 |     parts = re.match(regex, experiment_id).groups()
67 |     if post_regex != "":
68 |         parts = parts[1:-1]
69 | 
70 |     result = {
71 |         "name": parts[0],
72 |         "model_name_or_path": parts[1],
73 |         "pooling_mode": parts[2],
74 |         "train_batch_size": int(parts[3]),
75 |         "max_seq_length": int(parts[4]),
76 |         "epochs": int(parts[5]),
77 |         "seed": int(parts[6]),
78 |         "warmup_steps": int(parts[7]),
79 |         "lr": float(parts[8]),
80 |         "lora_r": int(parts[9]),
81 |     }
82 | 
83 |     return result
84 | 
--------------------------------------------------------------------------------
/llm2vec/loss/HardNegativeNLLLoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from .loss_utils import cos_sim, mismatched_sizes_all_gather
4 | 
5 | 
6 | class HardNegativeNLLLoss:
7 |     def __init__(
8 |         self,
9 |         scale: float = 30.0,
10 |         similarity_fct=cos_sim,
11 |     ):
12 |         self.scale = scale
13 |         self.similarity_fct = similarity_fct
14 |         self.cross_entropy_loss = nn.CrossEntropyLoss()
15 | 
16 |     def __call__(
17 |         self,
18 |         q_reps: Tensor,
19 |         d_reps_pos: Tensor,
20 |         d_reps_neg: Tensor = None,
21 |     ):
22 |         if d_reps_neg is None:
23 |             d_reps_neg = d_reps_pos[:0, :]
24 | 
25 |         if torch.distributed.is_initialized():
26 |             full_d_reps_pos = mismatched_sizes_all_gather(d_reps_pos)
27 |             full_d_reps_pos = torch.cat(full_d_reps_pos)
28 | 
29 |             full_q_reps = mismatched_sizes_all_gather(q_reps)
30 |             full_q_reps = torch.cat(full_q_reps)
31 | 
32 |             full_d_reps_neg = mismatched_sizes_all_gather(d_reps_neg)
33 |             full_d_reps_neg = torch.cat(full_d_reps_neg)
34 |         else:
35 |             full_d_reps_pos = d_reps_pos
36 |             full_q_reps = q_reps
37 |             full_d_reps_neg = d_reps_neg
38 | 
39 |         d_reps_pos_last8_loss1 = full_d_reps_pos[:, -1, :]  # final deliberation step (last token) of each positive
40 |         d_reps_neg_last8_loss1 = full_d_reps_neg[:, -1, :]  # final deliberation step of each hard negative
41 | 
42 |         d_reps_last8_loss1 = torch.cat([d_reps_pos_last8_loss1, d_reps_neg_last8_loss1], dim=0)
43 |         scores_loss1 = self.similarity_fct(full_q_reps, d_reps_last8_loss1)  # query vs. last-step scores (student)
44 |         scores_loss1 = scores_loss1 * self.scale
45 |         d_reps_pos_last8 = full_d_reps_pos
46 |         d_reps_neg_last8 = full_d_reps_neg
47 | 
48 |         d_reps_last8 = torch.cat([d_reps_pos_last8, d_reps_neg_last8], dim=0)  # all retained deliberation steps per document
49 |         q_b = full_q_reps.size(0)
50 |         d_b = d_reps_last8.size(0)
51 |         d_views = d_reps_last8.size(1)
52 |         h_dim = d_reps_last8.size(2)
53 | 
54 |         d_reps_last8 = d_reps_last8.reshape(d_b * d_views, h_dim)
55 |         sim_scores = self.similarity_fct(full_q_reps, d_reps_last8)
56 |         max_scores, _ = torch.max(sim_scores.view(q_b, d_b, d_views), dim=2)  # best-matching step per query-document pair
57 |         scores = max_scores * self.scale
58 | 
59 |         labels_scores = torch.tensor(
60 |             range(len(scores)), dtype=torch.long, device=scores.device
61 |         )
62 | 
63 |         loss1 = self.cross_entropy_loss(scores, labels_scores)  # contrastive loss over the max-pooled scores
64 | 
65 |         teacher_targets = torch.softmax(scores.detach(), dim=-1)  # self-distillation: max-pooled scores act as teacher
66 |         loss2 = - torch.mean(
67 |             torch.sum(torch.log_softmax(scores_loss1, dim=-1) * teacher_targets, dim=-1))
68 | 
69 |         loss = (loss1 + loss2) / 2  # average of the contrastive and self-distillation terms
70 | 
71 |         return loss
72 | 
-------------------------------------------------------------------------------- /test_config/mteb/task_to_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim:", 3 | "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question:", 4 | "FEVER": "Given a claim, retrieve documents that support or refute the claim:", 5 | "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query:", 6 | "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia:", 7 | "NQ": "Given a question, retrieve Wikipedia passages that answer the question:", 8 | "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question:", 9 | "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper:", 10 | "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query:", 11 | "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question:", 12 | "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim:", 13 | "NFCorpus": "Given a question, retrieve relevant documents that best answer the question:", 14 | "ArguAna": "Given a claim, find documents that refute the claim:", 15 | "CQADupstackTexRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 16 | "CQADupstackWebmastersRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 17 | "CQADupstackEnglishRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 18 | "CQADupstackGamingRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 19 | "CQADupstackGisRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 20 | "CQADupstackUnixRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 21 | "CQADupstackMathematicaRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 22 | "CQADupstackStatsRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 23 | "CQADupstackPhysicsRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 24 | "CQADupstackProgrammersRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 25 | "CQADupstackAndroidRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 26 | "CQADupstackWordpressRetrieval": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question:", 27 | "FiQA2018": "Given a financial question, retrieve user replies that best answer the question:" 28 | } 
-------------------------------------------------------------------------------- /llm2vec/loss/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | class AllGather(torch.autograd.Function): 6 | """ 7 | all_gather with gradient back-propagation 8 | """ 9 | 10 | @staticmethod 11 | def forward(ctx, tensor_list, tensor, group, async_op): 12 | torch.distributed.all_gather( 13 | tensor_list, tensor, group=group, async_op=async_op 14 | ) 15 | return tuple(tensor_list) 16 | 17 | @staticmethod 18 | def backward(ctx, *grad_list): 19 | grad_list = list(grad_list) 20 | rank = torch.distributed.get_rank() 21 | 22 | dist_ops = [ 23 | torch.distributed.reduce(grad_list[i], i, async_op=True) 24 | for i in range(torch.distributed.get_world_size()) 25 | ] 26 | 27 | for op in dist_ops: 28 | op.wait() 29 | 30 | return None, grad_list[rank], None, None 31 | 32 | 33 | all_gather_with_grad = AllGather.apply 34 | 35 | 36 | def cos_sim(a: Tensor, b: Tensor): 37 | """ 38 | Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. 39 | :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) 40 | """ 41 | if not isinstance(a, torch.Tensor): 42 | a = torch.tensor(a) 43 | 44 | if not isinstance(b, torch.Tensor): 45 | b = torch.tensor(b) 46 | 47 | if len(a.shape) == 1: 48 | a = a.unsqueeze(0) 49 | 50 | if len(b.shape) == 1: 51 | b = b.unsqueeze(0) 52 | 53 | a_norm = torch.nn.functional.normalize(a, p=2, dim=1) 54 | b_norm = torch.nn.functional.normalize(b, p=2, dim=1) 55 | return torch.mm(a_norm, b_norm.transpose(0, 1)) 56 | 57 | 58 | def mismatched_sizes_all_gather( 59 | tensor: Tensor, group=None, async_op=False, mismatched_axis=0 60 | ): 61 | # all_gather doesn't support tensor lists where the first dimension is mismatched. This does. 62 | assert torch.distributed.is_initialized(), "torch.distributed not initialized" 63 | world_size = torch.distributed.get_world_size() 64 | # let's get the sizes for everyone 65 | mismatched_sizes = torch.tensor( 66 | [tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda" 67 | ) 68 | sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)] 69 | torch.distributed.all_gather( 70 | sizes, mismatched_sizes, group=group, async_op=async_op 71 | ) 72 | sizes = torch.cat(sizes).cpu().tolist() 73 | # now pad to the max dim-0 size 74 | max_size = max(sizes) 75 | padded = torch.zeros( 76 | ( 77 | *tensor.shape[:mismatched_axis], 78 | max_size, 79 | *tensor.shape[mismatched_axis + 1 :], 80 | ), 81 | device=tensor.device, 82 | dtype=tensor.dtype, 83 | ) 84 | # selects the place where we're adding information 85 | padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis]) 86 | padded_to_fill[...] 
= tensor 87 | # gather the padded tensors 88 | tensor_list = [ 89 | torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype) 90 | for _ in range(world_size) 91 | ] 92 | all_gather_with_grad(tensor_list, padded, group, async_op) 93 | # trim off the padding 94 | for rank in range(world_size): 95 | # checks that the rest is 0 96 | assert ( 97 | not tensor_list[rank] 98 | .narrow( 99 | mismatched_axis, 100 | sizes[rank], 101 | padded.shape[mismatched_axis] - sizes[rank], 102 | ) 103 | .count_nonzero() 104 | .is_nonzero() 105 | ), "This would remove non-padding information" 106 | tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank]) 107 | return tensor_list 108 | -------------------------------------------------------------------------------- /llm2vec_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Any, Callable, Literal 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from mteb.encoder_interface import Encoder 10 | from mteb.model_meta import ModelMeta 11 | from mteb.models.text_formatting_utils import corpus_to_texts 12 | 13 | from .instructions import task_to_instruction 14 | 15 | logging.basicConfig(level=logging.WARNING) 16 | logger = logging.getLogger(__name__) 17 | 18 | EncodeTypes = Literal["query", "passage"] 19 | 20 | class LLM2VecWrapper: 21 | def __init__(self, *args, **kwargs): 22 | try: 23 | from llm2vec import LLM2Vec 24 | except ImportError: 25 | raise ImportError( 26 | "To use the LLM2Vec models `llm2vec` is required. Please install it with `pip install llm2vec`." 27 | ) 28 | extra_kwargs = {} 29 | try: 30 | import flash_attn # noqa 31 | 32 | extra_kwargs["attn_implementation"] = "flash_attention_2" 33 | except ImportError: 34 | logger.warning( 35 | "LLM2Vec models were trained with flash attention enabled. For optimal performance, please install the `flash_attn` package with `pip install flash-attn --no-build-isolation`." 
36 |             )
37 |         self.task_to_instructions = None
38 |         if "task_to_instructions" in kwargs:
39 |             self.task_to_instructions = kwargs.pop("task_to_instructions")
40 | 
41 |         if "device" in kwargs:
42 |             kwargs["device_map"] = kwargs.pop("device")
43 |         elif torch.cuda.device_count() > 1:
44 |             kwargs["device_map"] = None
45 | 
46 |         self.model = LLM2Vec.from_pretrained(*args, **extra_kwargs, **kwargs)
47 | 
48 |     def encode(
49 |         self,
50 |         sentences: list[str],
51 |         *,
52 |         prompt_name: str = None,
53 |         **kwargs: Any,  # noqa
54 |     ) -> np.ndarray:
55 |         if prompt_name is not None:
56 |             instruction = (
57 |                 self.task_to_instructions[prompt_name]
58 |                 if self.task_to_instructions
59 |                 and prompt_name in self.task_to_instructions
60 |                 else llm2vec_instruction(task_to_instruction(prompt_name))
61 |             )
62 |         else:
63 |             instruction = ""
64 |         sentences = [[instruction, sentence + ""] for sentence in sentences]
65 |         return self.model.encode(sentences, **kwargs)
66 | 
67 |     def encode_corpus(
68 |         self,
69 |         corpus: list[dict[str, str]] | dict[str, list[str]] | list[str],
70 |         **kwargs: Any,
71 |     ) -> np.ndarray:
72 |         sentences = corpus_to_texts(corpus, sep=" ")
73 |         sentences = [["", sentence + " Use eight words to represent the above text in multiple aspects: "] for sentence in sentences]
74 |         if "request_qid" in kwargs:
75 |             kwargs.pop("request_qid")
76 |         return self.model.encode(sentences, **kwargs)
77 | 
78 |     def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
79 |         kwargs['is_q'] = True
80 |         return self.encode(queries, **kwargs)
81 | 
82 | 
83 | def _loader(wrapper: type[LLM2VecWrapper], **kwargs) -> Callable[..., Encoder]:
84 |     _kwargs = kwargs
85 | 
86 |     def loader_inner(**kwargs: Any) -> Encoder:
87 |         return wrapper(**_kwargs, **kwargs)
88 | 
89 |     return loader_inner
90 | 
91 | 
92 | 
93 | llm2vec_MiniCPM2B = ModelMeta(
94 |     loader=_loader(
95 |         LLM2VecWrapper,
96 |         base_model_name_or_path="xxxxx",  # Base MiniCPM Model
97 |         peft_model_name_or_path="xxxxx",  # Trained lora parameters
98 |         device_map="auto",
99 |         torch_dtype=torch.bfloat16,
100 |     ),
101 |     name="xxxxxx",  # Custom Name
102 |     languages=["eng_Latn"],
103 |     open_source=True,
104 |     revision=None,
105 |     release_date="2025-01-02",
106 | )
107 | 
108 | 
109 | def llm2vec_instruction(instruction: str) -> str:
110 |     # Fallback used by `encode` above when a task is missing from task_to_instructions;
111 |     # it normalizes the instruction to end with a colon, as in mteb's reference file.
112 |     if len(instruction) > 0 and instruction[-1] != ":":
113 |         instruction = instruction.strip(".") + ":"
114 |     return instruction
115 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # *Learning More Effective Representations for Dense Retrieval through Deliberate Thinking Before Search*
3 | 
4 | [![arxiv](https://img.shields.io/badge/arXiv-2502.12974-orange?link=http%3A%2F%2Farxiv.org%2Fabs%2F2502.12974)](https://arxiv.org/abs/2502.12974) [![HF Link](https://img.shields.io/badge/HuggingFace-DEBATER_2B-brightgreen)](https://huggingface.co/bigtailwolf/DEBATER-2B) [![HF Link two](https://img.shields.io/badge/HuggingFace-DEBATER_4B-green)](https://huggingface.co/bigtailwolf/DEBATER-4B)
5 | 
6 | DEBATER is a novel framework that introduces the Chain-of-Deliberation mechanism to iteratively optimize document representations through a continuous chain of thought. To consolidate information from multiple thinking steps, DEBATER incorporates the Self Distillation mechanism, which identifies the most informative steps and integrates them into a unified text embedding.
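
In code terms, the document encoder keeps the hidden states of the last eight prompt tokens as deliberation steps, each query is scored against its best-matching step, and that best-step score distribution also supervises the plain last-token score (see `llm2vec/llm2vec.py` and `llm2vec/loss/HardNegativeNLLLoss.py`). The snippet below is only a schematic summary with illustrative tensor names, not code from this repository:

```python
import torch
import torch.nn.functional as F

def debater_objective(q, d_steps, scale=30.0):
    # q: (B, H) last-token query states; d_steps: (D, 8, H) last-8-token document states.
    q_n, d_n = F.normalize(q, dim=-1), F.normalize(d_steps, dim=-1)

    # Chain-of-Deliberation: score every deliberation step, keep the best one per document.
    step_scores = torch.einsum("qh,dsh->qds", q_n, d_n) * scale   # (B, D, 8)
    best_scores = step_scores.max(dim=-1).values                  # (B, D)
    labels = torch.arange(q.size(0), device=q.device)             # positives assumed in the first B columns
    contrastive = F.cross_entropy(best_scores, labels)

    # Self Distillation: the best-step score distribution teaches the final-step scores.
    teacher = best_scores.detach().softmax(dim=-1)
    distill = -(teacher * F.log_softmax(step_scores[..., -1], dim=-1)).sum(-1).mean()

    return (contrastive + distill) / 2
```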
7 | 
8 | ![model](./images/main_new-1.png)
9 | 
10 | ## Installation
11 | ```bash
12 | pip install -r requirements.txt
13 | pip install flash-attn --no-build-isolation
14 | ```
15 | ## Training
16 | We use the dataset from [Repetition Improves Language Model Embeddings](https://arxiv.org/abs/2402.15449). The dataset can be downloaded from their [GitHub](https://github.com/jakespringer/echo-embeddings#training). After downloading, put it in the data folder:
17 | 
18 | 
19 | ```
20 | data
21 | └── echo-data
22 |     ├── allnli.jsonl
23 |     ├── dureader.jsonl
24 |     ...
25 | ```
26 | To train the MiniCPM model, you can run the following script:
27 | 
28 | ```bash
29 | torchrun --nproc_per_node=4 scripts/run_supervised.py train_config/Minicpm2-2B.json
30 | ```
31 | or:
32 | ```bash
33 | torchrun --nproc_per_node=4 scripts/run_supervised.py train_config/Minicpm3-4B.json
34 | ```
35 | 
36 | 
37 | Please modify the contents of the `train_config` folder as needed. For example:
38 | 
39 | ```json
40 | {
41 |     "model_name_or_path": "xxxx",
42 |     "peft_model_name_or_path": null,
43 |     "pooling_mode": "M_token",
44 |     "dataset_name": "E5",
45 |     "dataset_file_path": "xxxx/data/echo-data",
46 |     "remove_unused_columns": false,
47 |     "learning_rate": 2e-4,
48 |     "num_train_epochs": 1,
49 |     "warmup_steps": 300,
50 |     "per_device_train_batch_size": 64,
51 |     "per_device_eval_batch_size": 64,
52 |     "gradient_accumulation_steps": 1,
53 |     "do_train": true,
54 |     "disable_tqdm": false,
55 |     "max_seq_length": 512,
56 |     "overwrite_output_dir": true,
57 |     "output_dir": "xxxxxx",
58 |     "logging_steps": 50,
59 |     "save_steps": 200,
60 |     "save_only_model": true,
61 |     "stop_after_n_steps": 1000,
62 |     "lora_r": 16,
63 |     "gradient_checkpointing": true,
64 |     "torch_dtype": "bfloat16",
65 |     "attn_implementation": "flash_attention_2",
66 |     "seed": 42
67 | }
68 | ```
69 | The main parameters to modify are:
70 | ```bash
71 | "model_name_or_path": "xxxx", # The path of the MiniCPM model
72 | "dataset_file_path": "xxxx", # The path of the training data
73 | "output_dir": "xxxx" # The path of the output directory
74 | ```
75 | 
76 | 
77 | ## Evaluation
78 | 
79 | To use the trained model for evaluation, modify the mteb package file at `your_env_site-packages/mteb/models/llm2vec_models.py` and register the trained model at the end of the file:
80 | ```python
81 | llm2vec_MiniCPM2B = ModelMeta(
82 |     loader=_loader(
83 |         LLM2VecWrapper,
84 |         base_model_name_or_path="xxxxx",  # Base MiniCPM Model
85 |         peft_model_name_or_path="xxxxx",  # Trained lora parameters
86 |         device_map="auto",
87 |         torch_dtype=torch.bfloat16,
88 |     ),
89 |     name="xxxxxx",  # Custom Name
90 |     languages=["eng_Latn"],
91 |     open_source=True,
92 |     revision=None,
93 |     release_date="2025-01-02",
94 | )
95 | ```
96 | We provide a sample file `llm2vec_models.py` for reference.
97 | Run the following script to evaluate, taking ArguAna as an example:
98 | 
99 | ```bash
100 | cd scripts && python mteb_eval.py
101 | ```
102 | The checkpoints can be obtained at the following addresses:
103 | 
104 | (1) The checkpoint of DEBATER-2B is available [here](https://huggingface.co/bigtailwolf/DEBATER-2B).
105 | 
106 | (2) The checkpoint of DEBATER-4B is available [here](https://huggingface.co/bigtailwolf/DEBATER-4B).
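
For a quick sanity check outside of MTEB, the local `LLM2Vec` wrapper in `llm2vec/llm2vec.py` can also be used directly through its `tokenize`/`forward` API. The snippet below is an illustrative sketch rather than part of the released code: the model paths are placeholders, the query instruction is taken from `test_config/mteb/task_to_instructions.json`, the example texts are made up, and the scoring mirrors the max-over-deliberation-steps similarity used in `llm2vec/loss/HardNegativeNLLLoss.py`.

```python
import torch
from llm2vec.llm2vec import LLM2Vec

# Placeholder paths: point these at the MiniCPM base model and a DEBATER LoRA checkpoint.
model = LLM2Vec.from_pretrained(
    base_model_name_or_path="path/to/MiniCPM",
    peft_model_name_or_path="path/to/DEBATER-2B",
    merge_peft=True,
    pooling_mode="M_token",
    max_length=512,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.eval()

# "!@#$%^&*()" separates the instruction from the text (see llm2vec/llm2vec.py);
# documents carry the deliberation suffix used during training (see llm2vec/dataset/E5Data.py).
sep = "!@#$%^&*()"
query = f"Given a claim, find documents that refute the claim: {sep}animals have feelings"
doc = (
    f"{sep}Animals can perceive pain and emotions much like humans do."
    " Use eight words to represent the above text in multiple aspects: "
)

with torch.no_grad():
    q_emb = model.forward(model.tokenize([query]), is_q=True)   # (1, hidden): last-token state
    d_emb = model.forward(model.tokenize([doc]), is_q=False)    # (1, 8, hidden): last 8 states

# Relevance = cosine similarity of the query with the best-matching deliberation step.
q_n = torch.nn.functional.normalize(q_emb.float(), dim=-1)
d_n = torch.nn.functional.normalize(d_emb.float(), dim=-1)
print(torch.einsum("qh,dvh->qdv", q_n, d_n).max(dim=-1).values)
```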
107 | 108 | 109 | ## Citation 110 | 111 | If you find our work to be of value and helpful to your research, please acknowledge our contributions by citing us in your publications or projects: 112 | ```bibtex 113 | @misc{ji2025learningeffectiverepresentationsdense, 114 | title={Learning More Effective Representations for Dense Retrieval through Deliberate Thinking Before Search}, 115 | author={Yifan Ji and Zhipeng Xu and Zhenghao Liu and Yukun Yan and Shi Yu and Yishan Li and Zhiyuan Liu and Yu Gu and Ge Yu and Maosong Sun}, 116 | year={2025}, 117 | eprint={2502.12974}, 118 | archivePrefix={arXiv}, 119 | primaryClass={cs.IR}, 120 | url={https://arxiv.org/abs/2502.12974}, 121 | } 122 | ``` 123 | ## Contact 124 | If you have questions, suggestions, and bug reports, please email: 125 | ```bash 126 | bigtailwolf001@gmail.com 127 | ``` 128 | -------------------------------------------------------------------------------- /llm2vec/dataset/E5Data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import os 4 | 5 | from .dataset import DataSample, TrainSample, Dataset 6 | from accelerate.logging import get_logger 7 | 8 | logger = get_logger(__name__, log_level="INFO") 9 | 10 | E5_EMBEDDING_PROMPTS = { 11 | "allnli": [ 12 | "Given a premise, retrieve a hypothesis that is entailed by the premise:", 13 | "Retrieve semantically similar text:", 14 | ], 15 | "dureader": "Given a Chinese search query, retrieve web passages that answer the question:", 16 | "eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum:", 17 | "fever": "Given a claim, retrieve documents that support or refute the claim:", 18 | "hotpot_qa": "Given a multi-hop question, retrieve documents that can help answer the question:", 19 | "miracl": "Given a question, retrieve Wikipedia passages that answer the question:", 20 | "mrtydi": "Given a question, retrieve Wikipedia passages that answer the question:", 21 | "msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query:", 22 | "msmarco_document": "Given a web search query, retrieve relevant documents that answer the query:", 23 | "nq": "Given a question, retrieve Wikipedia passages that answer the question:", 24 | "quora_duplicates": [ 25 | "Given a question, retrieve questions that are semantically equivalent to the given question:", 26 | "Find questions that have the same meaning as the input question:", 27 | ], 28 | "squad": "Retrieve Wikipedia passages that answer the question:", 29 | "t2ranking": "Given a Chinese search query, retrieve web passages that answer the question:", 30 | "trivia_qa": "Retrieve Wikipedia passages that answer the question:", 31 | } 32 | 33 | 34 | class E5Data(Dataset): 35 | def __init__( 36 | self, 37 | dataset_name: str = "E5", 38 | split: str = "validation", 39 | file_path: str = "data/echo-data", 40 | effective_batch_size: int = 32, 41 | shuffle_individual_datasets: bool = True, 42 | separator: str = "!@#$%^&*()", 43 | ): 44 | self.dataset_name = dataset_name 45 | self.split = split 46 | self.effective_batch_size = effective_batch_size 47 | self.shuffle_individual_datasets = shuffle_individual_datasets 48 | self.separator = separator 49 | 50 | self.data = [] 51 | self.load_data(file_path) 52 | 53 | def __len__(self): 54 | return len(self.data) 55 | 56 | def load_data(self, file_path: str = None): 57 | logger.info(f"Loading E5 data from {file_path}...") 58 | # file path is actually a directory 59 | 
60 | data_map = {} 61 | all_samples = [] 62 | id_ = 0 63 | for dataset in E5_EMBEDDING_PROMPTS: 64 | logger.info(f"Loading dataset {dataset}...") 65 | if dataset not in data_map: 66 | data_map[dataset] = [] 67 | with open(os.path.join(file_path, f"{dataset}.jsonl"), "r") as f: 68 | dataset_samples = f.readlines() 69 | 70 | dataset_samples = [json.loads(d) for d in dataset_samples] 71 | 72 | for i, sample in enumerate(dataset_samples): 73 | instruction = ( 74 | E5_EMBEDDING_PROMPTS[dataset] 75 | if isinstance(E5_EMBEDDING_PROMPTS[dataset], str) 76 | else E5_EMBEDDING_PROMPTS[dataset][i % 2] 77 | ) 78 | query = f"{instruction} " + self.separator + sample["query"] + "" 79 | pos = self.separator + sample["positive"] + " Use eight words to represent the above text in multiple aspects: " 80 | neg = self.separator + sample["negative"] + " Use eight words to represent the above text in multiple aspects: " 81 | 82 | data_map[dataset].append(id_) 83 | 84 | all_samples.append( 85 | DataSample( 86 | id_=id_, 87 | query=query, 88 | positive=pos, 89 | negative=neg, 90 | task_name=dataset, 91 | ) 92 | ) 93 | id_ += 1 94 | 95 | # combine split1 and split2 96 | new_data_map = {} 97 | for dataset in data_map: 98 | new_dataset = dataset.replace("_split1", "").replace("_split2", "") 99 | if new_dataset not in new_data_map: 100 | new_data_map[new_dataset] = [] 101 | new_data_map[new_dataset] += data_map[dataset] 102 | data_map = new_data_map 103 | 104 | if self.shuffle_individual_datasets: 105 | for task, samples in data_map.items(): 106 | random.shuffle(samples) 107 | 108 | datasets = list(data_map.keys()) 109 | 110 | logger.info( 111 | f"Batching Echo data properly for effective batch size of {self.effective_batch_size}..." 112 | ) 113 | all_batches = [] 114 | for dataset in datasets: 115 | dataset_samples = data_map[dataset] 116 | for i in range(0, len(dataset_samples), self.effective_batch_size): 117 | batch = dataset_samples[i: i + self.effective_batch_size] 118 | if len(batch) == self.effective_batch_size: 119 | all_batches.append(batch) 120 | else: 121 | logger.info(f"Skip 1 batch for dataset {dataset}.") 122 | random.shuffle(all_batches) 123 | 124 | final_idx_order = [] 125 | for batch in all_batches: 126 | for idx in batch: 127 | final_idx_order.append(idx) 128 | 129 | self.data = [all_samples[idx] for idx in final_idx_order] 130 | logger.info(f"Loaded {len(self.data)} samples.") 131 | 132 | def __getitem__(self, index): 133 | sample = self.data[index] 134 | if self.split == "train": 135 | return TrainSample( 136 | texts=[sample.query, sample.positive, sample.negative], label=1.0 137 | ) 138 | elif self.split == "validation": 139 | assert False, "E5Data does not have a validation split." 
-------------------------------------------------------------------------------- /scripts/run_supervised.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | import os 4 | import sys 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader, SequentialSampler 10 | 11 | from accelerate import Accelerator, DistributedDataParallelKwargs 12 | from accelerate.logging import get_logger 13 | 14 | import transformers 15 | from transformers import ( 16 | MODEL_FOR_MASKED_LM_MAPPING, 17 | HfArgumentParser, 18 | TrainingArguments, 19 | Trainer, 20 | TrainerCallback, 21 | set_seed, 22 | ) 23 | from transformers.trainer_utils import seed_worker 24 | 25 | from peft import LoraConfig, get_peft_model 26 | 27 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 28 | sys.path.append(project_root) 29 | 30 | from llm2vec.dataset.utils import load_dataset 31 | from llm2vec.llm2vec import LLM2Vec 32 | from llm2vec.loss.utils import load_loss 33 | from llm2vec.experiment_utils import generate_experiment_id 34 | 35 | from tqdm import tqdm 36 | 37 | transformers.logging.set_verbosity_error() 38 | 39 | logging.basicConfig( 40 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 41 | datefmt="%Y-%m-%d %H:%M:%S", 42 | level=logging.INFO, 43 | ) 44 | logger = get_logger(__name__, log_level="INFO") 45 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 46 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 47 | 48 | 49 | 50 | def initialize_peft( 51 | model, 52 | lora_r: int = 8, 53 | lora_alpha: int = 16, 54 | lora_dropout: float = 0.05, 55 | lora_modules: Optional[List[str]] = None, 56 | ): 57 | print(model.config.__class__.__name__) 58 | model.config.output_hidden_states = True 59 | if lora_modules is None and model.config.__class__.__name__ in [ 60 | "MiniCPMConfig", 61 | "MiniCPM3Config" 62 | ]: 63 | lora_modules = [ 64 | "q_proj", 65 | "v_proj", 66 | "k_proj", 67 | "o_proj", 68 | "gate_proj", 69 | "up_proj", 70 | "down_proj", 71 | ] 72 | elif lora_modules is None: 73 | raise ValueError("lora_modules must be specified for this model.") 74 | 75 | config = LoraConfig( 76 | r=lora_r, 77 | lora_alpha=lora_alpha, 78 | target_modules=lora_modules, 79 | lora_dropout=lora_dropout, 80 | bias="none", 81 | task_type=None, 82 | ) 83 | 84 | model = get_peft_model(model, config) 85 | print(f"Model's Lora trainable parameters:") 86 | model.print_trainable_parameters() 87 | return model 88 | 89 | 90 | @dataclass 91 | class ModelArguments: 92 | """ 93 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 94 | """ 95 | 96 | model_name_or_path: Optional[str] = field( 97 | default=None, 98 | metadata={ 99 | "help": ( 100 | "The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 101 | ) 102 | }, 103 | ) 104 | peft_model_name_or_path: Optional[str] = field( 105 | default=None, 106 | metadata={"help": ("The PEFT model checkpoint to add on top of base model.")}, 107 | ) 108 | max_seq_length: Optional[int] = field( 109 | default=None, 110 | metadata={ 111 | "help": ( 112 | "The maximum total input sequence length after tokenization. Sequences longer " 113 | "than this will be truncated." 
114 | ) 115 | }, 116 | ) 117 | torch_dtype: Optional[str] = field( 118 | default=None, 119 | metadata={ 120 | "help": ( 121 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 122 | "dtype will be automatically derived from the model's weights." 123 | ), 124 | "choices": ["auto", "bfloat16", "float16", "float32"], 125 | }, 126 | ) 127 | attn_implementation: Optional[str] = field( 128 | default="sdpa", 129 | metadata={ 130 | "help": ("The attention implementation to use in the model."), 131 | "choices": ["eager", "sdpa", "flash_attention_2"], 132 | }, 133 | ) 134 | pooling_mode: Optional[str] = field( 135 | default="last_token", 136 | metadata={ 137 | "help": ("The pooling mode to use in the model."), 138 | "choices": ["last_token","M_token"], 139 | }, 140 | ) 141 | 142 | 143 | @dataclass 144 | class DataTrainingArguments: 145 | """ 146 | Arguments pertaining to what data we are going to input our model for training and eval. 147 | """ 148 | 149 | dataset_name: Optional[str] = field( 150 | default=None, 151 | metadata={"help": "The name of the dataset to use. Options: E5"}, 152 | ) 153 | dataset_file_path: Optional[str] = field( 154 | default=None, metadata={"help": "The input training data file or folder."} 155 | ) 156 | # TODO: implement this 157 | max_train_samples: Optional[int] = field( 158 | default=None, 159 | metadata={ 160 | "help": ( 161 | "For debugging purposes or quicker training, truncate the number of training examples to this " 162 | "value if set." 163 | ) 164 | }, 165 | ) 166 | 167 | 168 | @dataclass 169 | class CustomArguments: 170 | """ 171 | Custom arguments for the script 172 | """ 173 | 174 | lora_dropout: float = field( 175 | default=0.05, metadata={"help": "The dropout rate for lora"} 176 | ) 177 | 178 | lora_r: int = field(default=8, metadata={"help": "The r value for lora"}) 179 | 180 | stop_after_n_steps: int = field( 181 | default=10000, metadata={"help": "Stop training after n steps"} 182 | ) 183 | 184 | experiment_id: Optional[str] = field( 185 | default=None, metadata={"help": "The experiment id"} 186 | ) 187 | 188 | loss_class: Optional[str] = field( 189 | default="HardNegativeNLLLoss", 190 | metadata={ 191 | "help": "The loss class to use for training. 
Options: HardNegativeNLLLoss" 192 | }, 193 | ) 194 | 195 | loss_scale: float = field( 196 | default=30.0, metadata={"help": "The loss scale for the loss function"} 197 | ) 198 | 199 | 200 | @dataclass 201 | class DefaultCollator: 202 | model: LLM2Vec 203 | 204 | def __init__(self, model: LLM2Vec) -> None: 205 | self.model = model 206 | 207 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 208 | batch = features 209 | num_texts = len(batch[0].texts) 210 | texts = [[] for _ in range(num_texts)] 211 | labels = [] 212 | 213 | for example in batch: 214 | for idx, text in enumerate(example.texts): 215 | texts[idx].append(text) 216 | labels.append(example.label) 217 | labels = torch.tensor(labels) 218 | 219 | sentence_features = [] 220 | for idx in range(num_texts): 221 | tokenized = self.model.tokenize(texts[idx]) 222 | sentence_features.append(tokenized) 223 | 224 | return sentence_features, labels 225 | 226 | 227 | class StopTrainingCallback(TrainerCallback): 228 | def __init__(self, stop_after_n_steps: int): 229 | self.stop_after_n_steps = stop_after_n_steps 230 | 231 | def on_step_end(self, args, state, control, **kwargs): 232 | if state.global_step >= self.stop_after_n_steps: 233 | control.should_training_stop = True 234 | 235 | 236 | class LLM2VecSupervisedTrainer(Trainer): 237 | 238 | def __init__( 239 | self, 240 | *args, 241 | loss_function=None, 242 | **kwargs, 243 | ) -> None: 244 | super().__init__(*args, **kwargs) 245 | self.loss_function = loss_function 246 | 247 | def compute_loss( 248 | self, 249 | model: nn.Module, 250 | inputs: Dict[str, Union[torch.Tensor, Any]], 251 | return_outputs: bool = False, 252 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 253 | features, labels = inputs 254 | 255 | q_reps = self.model.forward(features[0], is_q=True) 256 | d_reps = self.model.forward(features[1], is_q=False) 257 | 258 | d_reps_neg = None 259 | if len(features) > 2: 260 | d_reps_neg = self.model.forward(features[2], is_q=False) 261 | 262 | loss = self.loss_function(q_reps, d_reps, d_reps_neg) 263 | 264 | if return_outputs: 265 | output = torch.cat( 266 | [model(row)["sentence_embedding"][:, None] for row in features], dim=1 267 | ) 268 | return loss, output 269 | 270 | return loss 271 | 272 | def get_train_dataloader(self) -> DataLoader: 273 | # Copying most of the code from the parent class, changing the sampler to SequentialSampler 274 | if self.train_dataset is None: 275 | raise ValueError("Trainer: training requires a train_dataset.") 276 | 277 | train_dataset = self.train_dataset 278 | data_collator = self.data_collator 279 | 280 | data_collator = self._get_collator_with_removed_columns( 281 | data_collator, description="training" 282 | ) 283 | 284 | dataloader_params = { 285 | "batch_size": self._train_batch_size, 286 | "collate_fn": data_collator, 287 | "num_workers": self.args.dataloader_num_workers, 288 | "pin_memory": self.args.dataloader_pin_memory, 289 | "persistent_workers": self.args.dataloader_persistent_workers, 290 | } 291 | 292 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 293 | # Changing from random sampler to sequential sampler 294 | dataloader_params["sampler"] = SequentialSampler(train_dataset) 295 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 296 | dataloader_params["worker_init_fn"] = seed_worker 297 | 298 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 299 | 300 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 301 | 
# If we are executing this function, we are the process zero, so we don't check for that. 302 | output_dir = output_dir if output_dir is not None else self.args.output_dir 303 | os.makedirs(output_dir, exist_ok=True) 304 | logger.info(f"Saving model checkpoint to {output_dir}") 305 | 306 | self.model.save(output_dir) 307 | 308 | # Good practice: save your training arguments together with the trained model 309 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 310 | 311 | 312 | def main(): 313 | parser = HfArgumentParser( 314 | (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments) 315 | ) 316 | 317 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 318 | # If we pass only one argument to the script and it's the path to a json file, 319 | # let's parse it to get our arguments. 320 | model_args, data_args, training_args, custom_args = parser.parse_json_file( 321 | json_file=os.path.abspath(sys.argv[1]) 322 | ) 323 | else: 324 | ( 325 | model_args, 326 | data_args, 327 | training_args, 328 | custom_args, 329 | ) = parser.parse_args_into_dataclasses() 330 | 331 | training_args.report_to = ["tensorboard"] 332 | if training_args.ddp_find_unused_parameters: 333 | kwargs = [ 334 | DistributedDataParallelKwargs( 335 | dim=0, 336 | broadcast_buffers=True, 337 | bucket_cap_mb=25, 338 | find_unused_parameters=True, 339 | check_reduction=False, 340 | gradient_as_bucket_view=False, 341 | ) 342 | ] 343 | else: 344 | kwargs = [] 345 | accelerator = Accelerator(kwargs_handlers=kwargs) 346 | 347 | set_seed(training_args.seed) 348 | 349 | if training_args.gradient_checkpointing: 350 | training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} 351 | 352 | if custom_args.experiment_id is not None: 353 | experiment_id = custom_args.experiment_id 354 | else: 355 | experiment_id = generate_experiment_id( 356 | name=data_args.dataset_name, 357 | split="train", 358 | model_name=( 359 | model_args.model_name_or_path 360 | if "/" not in model_args.model_name_or_path 361 | else model_args.model_name_or_path.split("/")[-1] 362 | ), 363 | pooling_mode=model_args.pooling_mode, 364 | train_batch_size=training_args.per_device_train_batch_size 365 | * accelerator.num_processes 366 | * training_args.gradient_accumulation_steps, 367 | max_seq_length=model_args.max_seq_length, 368 | epochs=training_args.num_train_epochs, 369 | seed=training_args.seed, 370 | warmup_steps=training_args.warmup_steps, 371 | lr=training_args.learning_rate, 372 | lora_r=custom_args.lora_r, 373 | ) 374 | 375 | training_args.output_dir = f"{training_args.output_dir}/{experiment_id}" 376 | 377 | # TODO: can also pass separator arg here 378 | train_dataset = load_dataset( 379 | data_args.dataset_name, 380 | split="train", 381 | file_path=data_args.dataset_file_path, 382 | effective_batch_size=training_args.per_device_train_batch_size 383 | * accelerator.num_processes, 384 | ) 385 | 386 | train_examples = [ 387 | train_dataset[i] 388 | for i in tqdm( 389 | range(len(train_dataset)), 390 | desc="Loading train examples...", 391 | disable=not accelerator.is_main_process, 392 | ) 393 | ] 394 | 395 | torch_dtype = ( 396 | model_args.torch_dtype 397 | if model_args.torch_dtype in ["auto", None] 398 | else getattr(torch, model_args.torch_dtype) 399 | ) 400 | model = LLM2Vec.from_pretrained( 401 | 402 | trust_remote_code=True, 403 | base_model_name_or_path=model_args.model_name_or_path, 404 | peft_model_name_or_path=model_args.peft_model_name_or_path, 405 | merge_peft=True, 406 | 
pooling_mode=model_args.pooling_mode, 407 | max_length=model_args.max_seq_length, 408 | torch_dtype=torch_dtype, 409 | attn_implementation=model_args.attn_implementation, 410 | ) 411 | # for name, param in model.named_parameters(): 412 | # print(f"Parameter: {name}, dtype: {param.dtype}") 413 | # print("model.config : ") 414 | # print(model.config.__class__.__name__) 415 | 416 | # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model 417 | model.model = initialize_peft( 418 | model.model, 419 | lora_r=custom_args.lora_r, 420 | lora_alpha=2 * custom_args.lora_r, 421 | lora_dropout=custom_args.lora_dropout, 422 | ) 423 | 424 | tokenizer = model.tokenizer 425 | 426 | train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale) 427 | 428 | data_collator = DefaultCollator(model) 429 | 430 | trainer = LLM2VecSupervisedTrainer( 431 | model=model, 432 | args=training_args, 433 | train_dataset=train_examples, 434 | data_collator=data_collator, 435 | tokenizer=tokenizer, 436 | loss_function=train_loss, 437 | ) 438 | 439 | if custom_args.stop_after_n_steps is not None: 440 | trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps)) 441 | 442 | trainer.train() 443 | 444 | 445 | if __name__ == "__main__": 446 | main() 447 | -------------------------------------------------------------------------------- /llm2vec/llm2vec.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from typing import Dict, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.multiprocessing as mp 9 | from peft import PeftModel 10 | from torch import Tensor, device, nn 11 | from tqdm.autonotebook import tqdm, trange 12 | from transformers import ( 13 | AutoModel, 14 | PretrainedConfig, 15 | AutoTokenizer, 16 | ) 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | def batch_to_device(batch, target_device: device): 21 | """ 22 | send a pytorch batch to a device (CPU/GPU) 23 | """ 24 | for key in batch: 25 | if isinstance(batch[key], Tensor): 26 | batch[key] = batch[key].to(target_device) 27 | return batch 28 | 29 | 30 | class LLM2Vec(nn.Module): 31 | def __init__( 32 | self, 33 | model: AutoModel, 34 | tokenizer: AutoTokenizer, 35 | pooling_mode: str = "last_token", 36 | max_length: int = 512, 37 | doc_max_length: int = 400, 38 | skip_instruction: bool = True, 39 | ): 40 | super().__init__() 41 | self.model = model 42 | self.tokenizer = tokenizer 43 | self.pooling_mode = pooling_mode 44 | self.skip_instruction = skip_instruction 45 | self.max_length = max_length 46 | self.doc_max_length = doc_max_length 47 | self.config = model.config 48 | 49 | @classmethod 50 | def from_pretrained( 51 | cls, 52 | base_model_name_or_path, 53 | peft_model_name_or_path=None, 54 | merge_peft=False, 55 | **kwargs, 56 | ): 57 | # pop out encoder args 58 | keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"] 59 | encoder_args = { 60 | key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None 61 | } 62 | 63 | tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, trust_remote_code=True) 64 | tokenizer.pad_token = tokenizer.eos_token 65 | tokenizer.padding_side = "left" 66 | 67 | model = AutoModel.from_pretrained(base_model_name_or_path, **kwargs) 68 | 69 | if os.path.isdir(base_model_name_or_path) and os.path.exists( 70 | f"{base_model_name_or_path}/config.json" 71 | ): 72 | with 
open(f"{base_model_name_or_path}/config.json", "r") as fIn: 73 | config_dict = json.load(fIn) 74 | config = PretrainedConfig.from_dict(config_dict) 75 | model.config._name_or_path = config._name_or_path 76 | 77 | # For special case where config.json and adapter weights are in the same directory 78 | if hasattr(model, "peft_config"): 79 | model = PeftModel.from_pretrained( 80 | model, 81 | base_model_name_or_path, 82 | ) 83 | model = model.merge_and_unload() 84 | 85 | if peft_model_name_or_path is not None: 86 | model = PeftModel.from_pretrained( 87 | model, 88 | peft_model_name_or_path, 89 | ) 90 | 91 | if merge_peft: 92 | model = model.merge_and_unload() 93 | 94 | config = {} 95 | config_addr = ( 96 | peft_model_name_or_path 97 | if peft_model_name_or_path is not None 98 | else base_model_name_or_path 99 | ) 100 | if os.path.exists(f"{config_addr}/llm2vec_config.json"): 101 | with open(f"{config_addr}/llm2vec_config.json", "r") as fIn: 102 | llm2vec_config = json.load(fIn) 103 | config.update(llm2vec_config) 104 | 105 | for key, value in encoder_args.items(): 106 | config[key] = value 107 | return cls(model=model, tokenizer=tokenizer, **config) 108 | 109 | def tokenize(self, texts): 110 | texts_2 = [] 111 | original_texts = [] 112 | for text in texts: 113 | t = text.split("!@#$%^&*()") 114 | texts_2.append(t[1] if len(t) > 1 else "") 115 | original_texts.append("".join(t)) 116 | 117 | original = self.tokenizer( 118 | original_texts, 119 | return_tensors="pt", 120 | padding=True, 121 | truncation=True, 122 | max_length=self.max_length, 123 | ) 124 | 125 | embed_mask = None 126 | for t_i, t in enumerate(texts_2): 127 | ids = self.tokenizer( 128 | [t], 129 | return_tensors="pt", 130 | padding=True, 131 | truncation=True, 132 | max_length=self.max_length, 133 | add_special_tokens=False, 134 | ) 135 | if embed_mask is None: 136 | e_m = torch.zeros_like(original["attention_mask"][t_i]) 137 | if len(ids["input_ids"][0]) > 0: 138 | e_m[-len(ids["input_ids"][0]):] = torch.ones( 139 | len(ids["input_ids"][0]) 140 | ) 141 | embed_mask = e_m.unsqueeze(0) 142 | else: 143 | e_m = torch.zeros_like(original["attention_mask"][t_i]) 144 | if len(ids["input_ids"][0]) > 0: 145 | e_m[-len(ids["input_ids"][0]):] = torch.ones( 146 | len(ids["input_ids"][0]) 147 | ) 148 | embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) 149 | 150 | original["embed_mask"] = embed_mask 151 | return original 152 | 153 | def _skip_instruction(self, sentence_feature): 154 | assert ( 155 | sentence_feature["attention_mask"].shape 156 | == sentence_feature["embed_mask"].shape 157 | ) 158 | sentence_feature["attention_mask"] = sentence_feature["embed_mask"] 159 | 160 | def forward(self, sentence_feature: Dict[str, Tensor],is_q=False): 161 | embed_mask = None 162 | if "embed_mask" in sentence_feature: 163 | embed_mask = sentence_feature.pop("embed_mask") 164 | reps = self.model(**sentence_feature) 165 | sentence_feature["embed_mask"] = embed_mask 166 | 167 | return self.get_pooling(sentence_feature, reps.last_hidden_state,is_q=is_q) 168 | 169 | 170 | def get_pooling(self, features, last_hidden_states, is_q=False): 171 | 172 | assert ( 173 | self.tokenizer.padding_side == "left" 174 | ), "Pooling modes are implemented for padding from left." 
175 | if self.skip_instruction: 176 | self._skip_instruction(features) 177 | 178 | if self.pooling_mode == "last_token": 179 | return last_hidden_states[:, -1] 180 | elif self.pooling_mode == "M_token": 181 | if not is_q: 182 | return last_hidden_states[:, -8:] 183 | else: 184 | return last_hidden_states[:, -1] 185 | else: 186 | raise ValueError(f"{self.pooling_mode} is not implemented yet.") 187 | 188 | def _convert_to_str(self, instruction, text): 189 | tokenized_q = self.tokenizer( 190 | text, 191 | return_tensors="pt", 192 | padding=True, 193 | truncation=True, 194 | max_length=self.max_length, 195 | add_special_tokens=False, 196 | ) 197 | tokenized_q_length = len(tokenized_q["input_ids"][0]) 198 | 199 | while tokenized_q_length > self.doc_max_length: 200 | reduction_ratio = self.doc_max_length / tokenized_q_length 201 | reduced_length = int(len(text.split()) * reduction_ratio) 202 | text = " ".join(text.split()[:reduced_length]) 203 | tokenized_q = self.tokenizer( 204 | text, 205 | return_tensors="pt", 206 | padding=True, 207 | truncation=True, 208 | max_length=self.max_length, 209 | add_special_tokens=False, 210 | ) 211 | tokenized_q_length = len(tokenized_q["input_ids"][0]) 212 | 213 | return ( 214 | f"{instruction.strip()} !@#$%^&*(){text}" 215 | if instruction 216 | else f"!@#$%^&*(){text}" 217 | ) 218 | 219 | def encode( 220 | self, 221 | sentences: Union[str, List[str]], 222 | batch_size: int = 128, 223 | show_progress_bar: bool = True, 224 | convert_to_numpy: bool = False, 225 | convert_to_tensor: bool = False, 226 | device: Optional[str] = None, 227 | is_q=False, 228 | ): 229 | """ 230 | Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string. 231 | Args: 232 | sentences: sentence or sentences to encode. 233 | batch_size: batch size for turning sentence tokens into embeddings. 234 | show_progress_bar: whether to show progress bars during encoding steps. 235 | convert_to_numpy: If true, return numpy arrays instead of torch tensors. 236 | convert_to_tensor: If true, return torch tensors (default). 237 | device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified, 238 | the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports 239 | multiprocessing as currently implemented. 240 | 241 | Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation). 
242 | 243 | """ 244 | if isinstance(sentences[0], str) and isinstance(sentences[-1], int): 245 | sentences = [sentences] 246 | # required for MEDI version of MTEB 247 | if isinstance(sentences[0], str): 248 | sentences = [[""] + [sentence] for sentence in sentences] 249 | 250 | if device is None: 251 | device = "cuda" if torch.cuda.is_available() else "cpu" 252 | 253 | concatenated_input_texts = [] 254 | for sentence in sentences: 255 | assert isinstance(sentence[0], str) 256 | assert isinstance(sentence[1], str) 257 | concatenated_input_texts.append( 258 | self._convert_to_str(sentence[0], sentence[1]) 259 | ) 260 | sentences = concatenated_input_texts 261 | 262 | self.eval() 263 | 264 | if convert_to_tensor: 265 | convert_to_numpy = False 266 | 267 | length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) 268 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 269 | all_embeddings = [] 270 | 271 | if torch.cuda.device_count() <= 1: 272 | # This branch also support mps devices 273 | self.to(device) 274 | for start_index in trange( 275 | 0, 276 | len(sentences), 277 | batch_size, 278 | desc="Batches", 279 | disable=not show_progress_bar, 280 | ): 281 | sentences_batch = sentences_sorted[ 282 | start_index: start_index + batch_size 283 | ] 284 | embeddings = self._encode( 285 | sentences_batch, device=device, convert_to_numpy=convert_to_numpy, is_q=is_q 286 | ) 287 | all_embeddings.append(embeddings) 288 | else: 289 | 290 | num_proc = torch.cuda.device_count() 291 | cuda_compatible_multiprocess = mp.get_context("spawn") 292 | with cuda_compatible_multiprocess.Pool(num_proc) as p: 293 | sentences_batches = [ 294 | sentences_sorted[start_index: start_index + batch_size] 295 | for start_index in range(0, len(sentences), batch_size) 296 | ] 297 | 298 | progress_bar = tqdm( 299 | total=len(sentences_batches), 300 | desc="Batches", 301 | disable=not show_progress_bar, 302 | ) 303 | results = [] 304 | 305 | def update(*args): 306 | progress_bar.update() 307 | 308 | for batch in sentences_batches: 309 | results.append( 310 | p.apply_async( 311 | self._encode, 312 | args=(batch, None, convert_to_numpy, True, is_q), 313 | callback=update, 314 | ) 315 | ) 316 | 317 | all_embeddings = [result.get() for result in results] 318 | progress_bar.close() 319 | 320 | all_embeddings = torch.cat(all_embeddings, dim=0) 321 | all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] 322 | all_embeddings = all_embeddings.to(torch.float32) 323 | if convert_to_numpy: 324 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 325 | return all_embeddings 326 | 327 | def save(self, output_path, merge_before_save=False, save_config=True): 328 | if merge_before_save and isinstance(self.model, PeftModel): 329 | self.model = self.model.merge_and_unload() 330 | if hasattr(self.model, "_hf_peft_config_loaded"): 331 | self.model._hf_peft_config_loaded = False 332 | 333 | self.model.save_pretrained(output_path) 334 | self.tokenizer.save_pretrained(output_path) 335 | 336 | llm2vec_config = { 337 | "pooling_mode": self.pooling_mode, 338 | "max_length": self.max_length, 339 | "doc_max_length": self.doc_max_length, 340 | "skip_instruction": self.skip_instruction, 341 | } 342 | 343 | if save_config: 344 | os.makedirs(output_path, exist_ok=True) 345 | with open(f"{output_path}/llm2vec_config.json", "w") as fOut: 346 | json.dump(llm2vec_config, fOut, indent=4) 347 | 348 | def _encode( 349 | self, 350 | sentences_batch, 351 | device: Optional[str] = None, 352 | 
multiprocessing=False, 353 | is_q=False 354 | ): 355 | if multiprocessing: 356 | # multiprocessing only supports CUDA devices at this time, so we ignore the value of device 357 | # and use cuda:rank for the device 358 | rank = mp.current_process()._identity[0] 359 | if device is None and torch.cuda.is_available(): 360 | device = f"cuda:{rank % torch.cuda.device_count()}" 361 | self.to(device) 362 | features = self.tokenize( 363 | [self.prepare_for_tokenization(sentence) for sentence in sentences_batch] 364 | ) 365 | features = batch_to_device(features, device) 366 | 367 | with torch.no_grad(): 368 | embeddings = self.forward(features, is_q=is_q) 369 | embeddings = embeddings.detach() 370 | embeddings = embeddings.cpu() 371 | 372 | return embeddings 373 | 374 | def _text_length(self, text: Union[List[int], List[List[int]]]): 375 | """ 376 | Help function to get the length for the input text. Text can be either a string (which means a single text) 377 | a list of ints (which means a single tokenized text), or a tuple of list of ints 378 | (representing several text inputs to the model). 379 | """ 380 | if ( 381 | isinstance(text, str) 382 | or (isinstance(text, list) and isinstance(text[0], int)) 383 | or len(text) == 0 384 | ): # Single text, list of ints, or empty 385 | return len(text) 386 | if isinstance(text, dict): # {key: value} case 387 | return len(next(iter(text.values()))) 388 | elif not hasattr(text, "__len__"): # Object has no len() method 389 | return 1 390 | else: 391 | return sum([len(t) for t in text]) 392 | 393 | def resize_token_embeddings( 394 | self, 395 | new_num_tokens: Optional[int] = None, 396 | pad_to_multiple_of: Optional[int] = None, 397 | ) -> nn.Embedding: 398 | return self.model.resize_token_embeddings( 399 | new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of 400 | ) 401 | 402 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): 403 | self.model.gradient_checkpointing_enable( 404 | gradient_checkpointing_kwargs=gradient_checkpointing_kwargs 405 | ) 406 | --------------------------------------------------------------------------------