├── src ├── __init__.py ├── scripts │ ├── __init__.py │ ├── trec_eval.py │ ├── indexes_and_topics.py │ └── run_evaluation.py ├── modeling │ ├── __init__.py │ ├── causal_lm │ │ ├── __init__.py │ │ ├── modeling_llama.py │ │ └── modeling_mistral.py │ ├── rank_lm │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── modeling_llama.py │ │ └── modeling_mistral.py │ ├── model.py │ ├── encoder.py │ ├── meta.py │ └── builder.py ├── constants.py ├── arguments.py ├── trainer.py ├── utils.py ├── evaluate.py ├── ranker.py ├── train.py └── data.py ├── docs └── images │ └── cover.jpg ├── requirements.txt ├── scripts ├── zero2.json ├── train_s1.sh └── train_s2.sh ├── .gitignore ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuqi6777/pe_rank/HEAD/docs/images/cover.jpg -------------------------------------------------------------------------------- /src/modeling/causal_lm/__init__.py: -------------------------------------------------------------------------------- 1 | from modeling.causal_lm.modeling_llama import EmbedLlamaForCausalLM 2 | from modeling.causal_lm.modeling_mistral import EmbedMistralForCausalLM -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | from transformers.trainer_pt_utils import LabelSmoother 2 | 3 | 4 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 5 | 6 | PLACEHOLDER = '' 7 | RANK_TOKEN = '' 8 | -------------------------------------------------------------------------------- /src/modeling/rank_lm/__init__.py: -------------------------------------------------------------------------------- 1 | from modeling.rank_lm.modeling_llama import EmbedLlamaConfig, EmbedLlamaModel, EmbedLlamaForRankLM 2 | from modeling.rank_lm.modeling_mistral import EmbedMistralConfig, EmbedMistralModel, EmbedMistralForRankLM -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepspeed==0.12.3 2 | huggingface_hub==0.22.1 3 | numpy==1.26.4 4 | peft==0.6.2 5 | pyserini==0.24.0 6 | pytrec_eval==0.5 7 | sentence_transformers==2.5.1 8 | torch==2.1.2 9 | tqdm==4.66.1 10 | transformers==4.37.2 11 | ujson==5.8.0 12 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | 
"overlap_comm": true, 19 | "contiguous_gradients": false, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/train_s1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --include="localhost:0,1,2,3" src/train.py \ 4 | --deepspeed scripts/zero2.json \ 5 | --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \ 6 | --data_path ./data/wiki2m.jsonl \ 7 | --encoder_name jinaai/jina-embeddings-v2-base-en \ 8 | --encoder_pooling mean \ 9 | --projector_type mlp2x_gelu \ 10 | --freeze_backbone \ 11 | --tune_mlp_adapter \ 12 | --bf16 \ 13 | --output_dir ./checkpoints/mistral.jina.projector \ 14 | --num_train_epochs 1 \ 15 | --per_device_train_batch_size 1 \ 16 | --gradient_accumulation_steps 1 \ 17 | --evaluation_strategy "no" \ 18 | --save_strategy "steps" \ 19 | --save_steps 1000 \ 20 | --save_total_limit 1 \ 21 | --learning_rate 1e-3 \ 22 | --warmup_ratio 0.03 \ 23 | --lr_scheduler_type "cosine" \ 24 | --logging_steps 1 \ 25 | --tf32 True \ 26 | --model_max_length 512 \ 27 | --gradient_checkpointing \ 28 | --attn_implementation flash_attention_2 \ 29 | --dataloader_num_workers 4 -------------------------------------------------------------------------------- /scripts/train_s2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --include="localhost:4,5,6,7" --master_port="29700" src/train.py \ 4 | --deepspeed ./scripts/zero2.json \ 5 | --model_type rank_lm \ 6 | --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \ 7 | --data_path ./data/train.jsonl \ 8 | --use_embedding_with_content True \ 9 | --use_embedding_without_content True \ 10 | --kl_loss_weight 0.2 \ 11 | --loss1_weight 1 \ 12 | --loss2_weight 1 \ 13 | --encoder_name jinaai/jina-embeddings-v2-base-en \ 14 | --encoder_pooling mean \ 15 | --pretrain_mlp_adapter ./checkpoints/mistral.jina.projector/projector.bin \ 16 | --projector_type mlp2x_gelu \ 17 | --tune_mlp_adapter \ 18 | --bf16 True \ 19 | --tf32 True \ 20 | --output_dir "./checkpoints/pe-rank-mistral-jina" \ 21 | --overwrite_output_dir \ 22 | --num_train_epochs 1 \ 23 | --per_device_train_batch_size 4 \ 24 | --gradient_accumulation_steps 2 \ 25 | --save_strategy "steps" \ 26 | --save_steps 3000 \ 27 | --save_total_limit 2 \ 28 | --learning_rate 2e-5 \ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --model_max_length 4096 \ 33 | --gradient_checkpointing True \ 34 | --attn_implementation flash_attention_2 \ 35 | --dataloader_num_workers 2 36 | -------------------------------------------------------------------------------- /src/modeling/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from modeling.encoder import Encoder, build_projector 5 | 6 | 7 | class ELMMetaModel: 8 | def __init__(self, config): 9 | super().__init__(config) 10 | self.config = config 11 | 12 | if hasattr(config, 'encoder_name'): 13 | self.encoder = Encoder(config.encoder_name, config) 14 | self.projector = build_projector(config) 15 | 16 | def get_encoder(self): 17 | return getattr(self, 'encoder', None) 18 | 19 | def get_projector(self): 20 | return getattr(self, 'projector', None) 21 | 22 | def initialize_modules(self, model_args): 23 | encoder_name = model_args.encoder_name 24 | pretrain_mlp_adapter = 
model_args.pretrain_mlp_adapter 25 | 26 | self.config.encoder_name = encoder_name 27 | self.config.encoder_pooling = model_args.encoder_pooling 28 | if self.get_encoder() is None: 29 | self.encoder = Encoder(self.config.encoder_name, self.config) 30 | 31 | self.config.use_proj = True 32 | self.config.projector_type = getattr( 33 | model_args, 'projector_type', 'linear') 34 | self.config.embedding_size = self.encoder.config.hidden_size 35 | 36 | if self.get_projector() is None: 37 | self.projector = build_projector(self.config) 38 | else: 39 | # In case it is frozen by LoRA 40 | for p in self.projector.parameters(): 41 | p.requires_grad = True 42 | 43 | if pretrain_mlp_adapter is not None: 44 | projector_weights = torch.load(pretrain_mlp_adapter, map_location='cpu') 45 | 46 | def get_w(weights, keyword): 47 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 48 | 49 | self.projector.load_state_dict(get_w(projector_weights, 'projector')) 50 | # self.set_encoder_head() 51 | 52 | def encode_texts(self, **inputs: dict): 53 | embeddings = self.get_encoder()(**inputs) 54 | project_as_token_embeddings = self.get_projector()(embeddings) 55 | # no need to normalize to align with the original token embedding space 56 | return project_as_token_embeddings 57 | -------------------------------------------------------------------------------- /src/scripts/trec_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pytrec_eval 3 | from pyserini.search import get_qrels_file 4 | 5 | 6 | def compute_metrics( 7 | qrels: dict[str, dict[str, int]], 8 | results: dict[str, dict[str, float]], 9 | k_values: tuple[int] = (10, 50, 100, 200, 1000) 10 | ) -> dict[str, float]: 11 | ndcg, _map, recall = {}, {}, {} 12 | 13 | for k in k_values: 14 | _map[f"MAP@{k}"] = 0.0 15 | ndcg[f"NDCG@{k}"] = 0.0 16 | recall[f"Recall@{k}"] = 0.0 17 | 18 | map_string = "map_cut." + ",".join([str(k) for k in k_values]) 19 | ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) 20 | recall_string = "recall." 
+ ",".join([str(k) for k in k_values]) 21 | 22 | evaluator = pytrec_eval.RelevanceEvaluator( 23 | qrels, {map_string, ndcg_string, recall_string}) 24 | scores = evaluator.evaluate(results) 25 | 26 | for query_id in scores: 27 | for k in k_values: 28 | _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)] 29 | ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)] 30 | recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)] 31 | 32 | def _normalize(m: dict) -> dict: 33 | return {k: round(v / len(scores), 4) for k, v in m.items()} 34 | 35 | _map = _normalize(_map) 36 | ndcg = _normalize(ndcg) 37 | recall = _normalize(recall) 38 | 39 | all_metrics = {} 40 | for mt in [_map, ndcg, recall]: 41 | all_metrics.update(mt) 42 | 43 | return all_metrics 44 | 45 | 46 | def pretty_print_metrics(metrics: dict[str, float]): 47 | for metric, value in metrics.items(): 48 | print(f"{metric:<12}\t{value}") 49 | 50 | 51 | def trec_eval(dataset, ranking): 52 | with open(ranking, 'r') as f_run: 53 | run = pytrec_eval.parse_run(f_run) 54 | with open(get_qrels_file(dataset), 'r') as f_qrel: 55 | qrels = pytrec_eval.parse_qrel(f_qrel) 56 | all_metrics = compute_metrics(qrels, run, k_values=(1, 5, 10, 20, 100)) 57 | pretty_print_metrics(all_metrics) 58 | return all_metrics 59 | 60 | 61 | if __name__ == '__main__': 62 | from indexes_and_topics import TOPICS 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--dataset', type=str, default='dl19') 66 | parser.add_argument('--ranking', type=str, required=True) 67 | args = parser.parse_args() 68 | trec_eval(TOPICS[args.dataset], args.ranking) 69 | -------------------------------------------------------------------------------- /src/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from transformers import TrainingArguments as HFTrainingArguments 4 | 5 | 6 | @dataclass 7 | class ModelArguments: 8 | model_type: str = field( 9 | default="causal_lm", 10 | metadata={ 11 | "help": "The type of model to use. Can be 'causal_lm' or 'rank_lm'." 
12 | }, 13 | ) 14 | model_name_or_path: Optional[str] = field(default="JackFram/llama-68m") 15 | trust_remote_code: bool = field( 16 | default=False, 17 | metadata={ 18 | "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" 19 | }, 20 | ) 21 | padding_side: str = field( 22 | default="right", metadata={"help": "The padding side in tokenizer"} 23 | ) 24 | version: Optional[str] = field(default="v0") 25 | freeze_backbone: bool = field(default=False) 26 | freeze_embedding_layer: bool = field(default=False) 27 | tune_mlp_adapter: bool = field(default=False) 28 | encoder_name: Optional[str] = field(default=None) 29 | encoder_pooling: Optional[str] = field( 30 | default="mean", metadata={"help": "mean or cls"} 31 | ) 32 | pretrain_mlp_adapter: Optional[str] = field(default=None) 33 | projector_type: Optional[str] = field(default='linear') 34 | 35 | 36 | @dataclass 37 | class DataArguments: 38 | data_path: str = field( 39 | default=None, metadata={"help": "Path to the training data."} 40 | ) 41 | eval_data_path: str = field( 42 | default=None, metadata={"help": "Path to the evaluation data."} 43 | ) 44 | use_embedding_with_content: bool = field(default=True) 45 | use_embedding_without_content: bool = field(default=False) 46 | 47 | 48 | @dataclass 49 | class TrainingArguments(HFTrainingArguments): 50 | cache_dir: Optional[str] = field(default=None) 51 | optim: str = field(default="adamw_torch") 52 | remove_unused_columns: bool = field(default=False) 53 | model_max_length: int = field( 54 | default=512, 55 | metadata={ 56 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 57 | }, 58 | ) 59 | attn_implementation: Optional[str] = field( 60 | default=None, 61 | metadata={ 62 | "help": "The implementation of attention. Can be 'flash_attention_2'." 
63 | }, 64 | ) 65 | 66 | loss1_weight: float = field(default=1.0) 67 | loss2_weight: float = field(default=1.0) 68 | kl_loss_weight: float = field(default=0.0) 69 | 70 | 71 | @dataclass 72 | class LoraArguments: 73 | lora_enable: bool = False 74 | lora_r: int = 8 75 | lora_alpha: int = 16 76 | lora_dropout: float = 0.05 77 | lora_target_modules: list[str] = field( 78 | default_factory=lambda: ["q_proj", "v_proj"] 79 | ) 80 | lora_weight_path: str = "" 81 | lora_bias: str = "none" 82 | -------------------------------------------------------------------------------- /src/modeling/rank_lm/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor, LongTensor 3 | 4 | from constants import IGNORE_TOKEN_ID 5 | 6 | 7 | class ListMLELoss(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def _rank_minus_one(self, ranking: LongTensor) -> LongTensor: 12 | if ranking.min() == 1: 13 | ranking = ranking - 1 14 | return ranking 15 | 16 | def _make_ranking_mask(self, ranking: LongTensor) -> LongTensor: 17 | ranking = self._rank_minus_one(ranking) 18 | n = ranking.shape[-1] 19 | mask = torch.triu(torch.ones(n, n, dtype=torch.long, device=ranking.device)) 20 | mask = mask[:, torch.sort(ranking, dim=-1).indices].contiguous() 21 | return mask 22 | 23 | def _make_label_mask(self, ranking: LongTensor) -> LongTensor: 24 | ranking = self._rank_minus_one(ranking) 25 | n = ranking.shape[-1] 26 | mask = torch.zeros(n, n, dtype=torch.long, device=ranking.device) 27 | mask[torch.arange(n), ranking] = 1 28 | return mask 29 | 30 | def _make_mask_with_labels( 31 | self, 32 | labels: LongTensor, 33 | ranking: LongTensor, 34 | ) -> tuple[LongTensor, LongTensor, Tensor]: 35 | assert labels.shape[0] == ranking.shape[0] 36 | assert labels.shape[1] >= ranking.shape[1] 37 | label_mask = torch.zeros( 38 | labels.shape[0], labels.shape[1], ranking.shape[1], dtype=torch.long, device=labels.device) 39 | ranking_mask = torch.zeros( 40 | labels.shape[0], labels.shape[1], ranking.shape[1], dtype=torch.long, device=labels.device) 41 | 42 | for i in range(ranking.shape[0]): 43 | assert (labels[i] != IGNORE_TOKEN_ID).sum() == ranking.shape[1] 44 | label_mask[i, labels[i] != IGNORE_TOKEN_ID] = self._make_label_mask(ranking[i]) 45 | ranking_mask[i, labels[i] != IGNORE_TOKEN_ID] = self._make_ranking_mask(ranking[i]) 46 | 47 | return label_mask.contiguous(), ranking_mask.contiguous() 48 | 49 | def forward( 50 | self, 51 | hidden_states: Tensor, 52 | text_embeddings: Tensor, 53 | labels: LongTensor, 54 | ranking: LongTensor 55 | ) -> tuple[Tensor, Tensor]: 56 | logits = (hidden_states @ text_embeddings.permute(0, 2, 1)).contiguous() 57 | 58 | # Shift so that tokens < n predict n 59 | shift_logits = logits[..., :-1, :].contiguous() 60 | shift_labels = labels[..., 1:].contiguous() 61 | 62 | label_mask, ranking_mask = self._make_mask_with_labels(shift_labels, ranking) 63 | labels = self._rank_minus_one(ranking).view(-1) 64 | shift_logits[ranking_mask == 0] = -1e9 # don't set to -inf, otherwise it will cause NaN 65 | shift_logits = shift_logits[label_mask.sum(-1).bool()] 66 | logprob = torch.nn.functional.cross_entropy(shift_logits, labels, reduce=False).view(ranking.shape[0], -1).float() 67 | loss = logprob.mean() 68 | return loss, shift_logits.reshape(ranking.shape[0], ranking.shape[1], -1) 69 | -------------------------------------------------------------------------------- /src/trainer.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import Trainer as HFTrainer 4 | 5 | from utils import get_adapter_state_maybe_zero_3 6 | 7 | 8 | class Trainer(HFTrainer): 9 | def _save_checkpoint(self, model, trial, metrics=None): 10 | if getattr(self.args, 'tune_mlp_adapter', False) and getattr(self.args, 'freeze_backbone', False): 11 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 12 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 13 | 14 | run_dir = self._get_output_dir(trial=trial) 15 | output_dir = os.path.join(run_dir, checkpoint_folder) 16 | 17 | # Only save Adapter 18 | keys_to_match = ['projector'] 19 | 20 | weight_to_save = get_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 21 | 22 | if self.args.local_rank == 0 or self.args.local_rank == -1: 23 | self.model.config.save_pretrained(output_dir) 24 | torch.save(weight_to_save, os.path.join(output_dir, f'projector.bin')) 25 | else: 26 | super()._save_checkpoint(model, trial, metrics) 27 | 28 | def _save(self, output_dir=None, state_dict=None): 29 | if getattr(self.args, 'tune_mlp_adapter', False) and getattr(self.args, 'freeze_backbone', False): 30 | pass 31 | else: 32 | super()._save(output_dir, state_dict) 33 | 34 | 35 | class RankTrainer(Trainer): 36 | def compute_loss(self, model, inputs, return_outputs=False): 37 | 38 | inputs_w_content = inputs.pop("inputs_w_content", None) 39 | inputs_wo_content = inputs.pop("inputs_wo_content", None) 40 | extra_text_inputs = inputs.pop("extra_text_inputs", dict()) 41 | 42 | if inputs_wo_content is not None: 43 | outputs1 = model( 44 | **inputs_wo_content, 45 | **extra_text_inputs, 46 | **inputs, 47 | ) 48 | if inputs_w_content is not None: 49 | outputs2 = model( 50 | **inputs_w_content, 51 | **extra_text_inputs, 52 | **inputs, 53 | ) 54 | 55 | if inputs_wo_content is not None: 56 | loss1, logits1 = outputs1.loss, outputs1.logits 57 | else: 58 | loss1, logits1 = None, None 59 | if inputs_w_content is not None: 60 | loss2, logits2 = outputs2.loss, outputs2.logits 61 | else: 62 | loss2, logits2 = None, None 63 | 64 | if self.args.kl_loss_weight > 0 and loss1 is not None and loss2 is not None: 65 | loss = self.args.loss1_weight * loss1 + self.args.loss2_weight * loss2 66 | kl_loss = torch.nn.functional.kl_div( 67 | input=torch.log_softmax(logits1, dim=-1), 68 | target=torch.log_softmax(logits2, dim=-1), 69 | log_target=True, 70 | reduction="batchmean", 71 | ) 72 | loss += self.args.kl_loss_weight * kl_loss 73 | else: 74 | loss = self.args.loss1_weight * loss1 + self.args.loss2_weight * loss2 \ 75 | if (loss1 and loss2) else (loss1 or loss2) 76 | 77 | outputs = { 78 | "loss": loss, 79 | "logits1": logits1, 80 | "logits2": logits2, 81 | } 82 | 83 | return (loss, outputs) if return_outputs else loss 84 | -------------------------------------------------------------------------------- /src/scripts/indexes_and_topics.py: -------------------------------------------------------------------------------- 1 | INDEX = { 2 | 'bm25': { 3 | 'dl19': 'msmarco-v1-passage', 4 | 'dl20': 'msmarco-v1-passage', 5 | 'covid': 'beir-v1.0.0-trec-covid.flat', 6 | 'arguana': 'beir-v1.0.0-arguana.flat', 7 | 'touche': 'beir-v1.0.0-webis-touche2020.flat', 8 | 'news': 'beir-v1.0.0-trec-news.flat', 9 | 'scifact': 'beir-v1.0.0-scifact.flat', 10 | 'fiqa': 'beir-v1.0.0-fiqa.flat', 11 | 'scidocs': 'beir-v1.0.0-scidocs.flat', 12 | 'nfc': 'beir-v1.0.0-nfcorpus.flat', 13 | 'quora': 
'beir-v1.0.0-quora.flat', 14 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity.flat', 15 | 'fever': 'beir-v1.0.0-fever.flat', 16 | 'robust04': 'beir-v1.0.0-robust04.flat', 17 | 'signal': 'beir-v1.0.0-signal1m.flat', 18 | 'nq': 'beir-v1.0.0-nq.flat', 19 | 'cfever': 'beir-v1.0.0-climate-fever.flat', 20 | 'hotpotqa': 'beir-v1.0.0-hotpotqa.flat', 21 | }, 22 | 'splade++ed': { 23 | 'dl19': 'msmarco-v1-passage-splade-pp-ed-text', 24 | 'dl20': 'msmarco-v1-passage-splade-pp-ed-text', 25 | 'covid': 'beir-v1.0.0-trec-covid.splade-pp-ed', 26 | 'arguana': 'beir-v1.0.0-arguana.splade-pp-ed', 27 | 'touche': 'beir-v1.0.0-webis-touche2020.splade-pp-ed', 28 | 'news': 'beir-v1.0.0-trec-news.splade-pp-ed', 29 | 'scifact': 'beir-v1.0.0-scifact.splade-pp-ed', 30 | 'fiqa': 'beir-v1.0.0-fiqa.splade-pp-ed', 31 | 'scidocs': 'beir-v1.0.0-scidocs.splade-pp-ed', 32 | 'nfc': 'beir-v1.0.0-nfcorpus.splade-pp-ed', 33 | 'quora': 'beir-v1.0.0-quora.splade-pp-ed', 34 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity.splade-pp-ed', 35 | 'fever': 'beir-v1.0.0-fever.splade-pp-ed', 36 | 'robust04': 'beir-v1.0.0-robust04.splade-pp-ed', 37 | 'signal': 'beir-v1.0.0-signal1m.splade-pp-ed', 38 | 'nq': 'beir-v1.0.0-nq.splade-pp-ed', 39 | 'cfever': 'beir-v1.0.0-climate-fever.splade-pp-ed', 40 | 'hotpotqa': 'beir-v1.0.0-hotpotqa.splade-pp-ed' 41 | }, 42 | 'dense': { 43 | 'dl19': 'msmarco-v1-passage', 44 | 'dl20': 'msmarco-v1-passage', 45 | 'covid': 'trec-covid', 46 | 'arguana': 'arguana', 47 | 'touche': 'webis-touche2020', 48 | 'news': 'trec-news', 49 | 'scifact': 'scifact', 50 | 'fiqa': 'fiqa', 51 | 'scidocs': 'scidocs', 52 | 'nfc': 'nfcorpus', 53 | 'quora': 'quora', 54 | 'dbpedia': 'dbpedia-entity', 55 | 'fever': 'fever', 56 | 'robust04': 'robust04', 57 | 'signal': 'signal1m', 58 | 'nq': 'nq', 59 | 'cfever': 'climate-fever', 60 | 'hotpotqa': 'hotpotqa' 61 | } 62 | } 63 | 64 | TOPICS = { 65 | 'dl19': 'dl19-passage', 66 | 'dl20': 'dl20-passage', 67 | 'covid': 'beir-v1.0.0-trec-covid-test', 68 | 'arguana': 'beir-v1.0.0-arguana-test', 69 | 'touche': 'beir-v1.0.0-webis-touche2020-test', 70 | 'news': 'beir-v1.0.0-trec-news-test', 71 | 'scifact': 'beir-v1.0.0-scifact-test', 72 | 'fiqa': 'beir-v1.0.0-fiqa-test', 73 | 'scidocs': 'beir-v1.0.0-scidocs-test', 74 | 'nfc': 'beir-v1.0.0-nfcorpus-test', 75 | 'quora': 'beir-v1.0.0-quora-test', 76 | 'dbpedia': 'beir-v1.0.0-dbpedia-entity-test', 77 | 'fever': 'beir-v1.0.0-fever-test', 78 | 'robust04': 'beir-v1.0.0-robust04-test', 79 | 'signal': 'beir-v1.0.0-signal1m-test', 80 | 'nq': 'beir-v1.0.0-nq-test', 81 | 'cfever': 'beir-v1.0.0-climate-fever-test', 82 | 'hotpotqa': 'beir-v1.0.0-hotpotqa-test', 83 | } 84 | -------------------------------------------------------------------------------- /src/modeling/encoder.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn, Tensor 4 | from transformers import AutoModel 5 | from peft import PeftModel, PeftConfig 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"projector_type": 'identity'} 18 | 19 | 20 | def build_projector(config): 21 | projector_type = getattr(config, 'projector_type', 'linear') 22 | 23 | if projector_type == 'linear': 24 | return nn.Linear(config.embedding_size, config.hidden_size) 25 | 26 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 27 | if mlp_gelu_match: 28 | mlp_depth = 
int(mlp_gelu_match.group(1)) 29 | modules = [nn.Linear(config.embedding_size, config.hidden_size)] 30 | for _ in range(1, mlp_depth): 31 | modules.append(nn.GELU()) 32 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 33 | return nn.Sequential(*modules) 34 | 35 | if projector_type == 'identity': 36 | return IdentityMap() 37 | 38 | raise ValueError(f'Unknown projector type: {projector_type}') 39 | 40 | 41 | def get_peft_model(peft_model_name): 42 | config = PeftConfig.from_pretrained(peft_model_name) 43 | base_model = AutoModel.from_pretrained(config.base_model_name_or_path) 44 | model = PeftModel.from_pretrained(base_model, peft_model_name) 45 | model = model.merge_and_unload() 46 | return model 47 | 48 | 49 | class Encoder(nn.Module): 50 | def __init__(self, model_name, config): 51 | super().__init__() 52 | if "lora" in model_name: 53 | self.encoder = get_peft_model(model_name) 54 | else: 55 | self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True) 56 | 57 | self.config = self.encoder.config 58 | self.pooling = config.encoder_pooling 59 | self.requires_grad_(False) 60 | 61 | @staticmethod 62 | @torch.no_grad() 63 | def mean_pooling(embeddings: Tensor, attention_mask: Tensor) -> Tensor: 64 | return (torch.sum(embeddings * attention_mask.unsqueeze(-1), dim=1) \ 65 | / torch.clamp(torch.sum(attention_mask, dim=1, keepdims=True), min=1e-9)).to(embeddings.dtype) 66 | 67 | @torch.no_grad() 68 | def forward(self, **inputs) -> dict[str, Tensor]: 69 | for key in inputs: 70 | inputs[key] = inputs[key].to(self.encoder.device) 71 | batch_size = 16 72 | all_embeddings = [] 73 | for i in range(0, len(inputs['input_ids']), batch_size): 74 | batch_inputs = {key: value[i:i+batch_size] for key, value in inputs.items()} 75 | outputs = self.encoder(**batch_inputs) 76 | if self.pooling == 'mean': 77 | embeddings = self.mean_pooling(outputs.last_hidden_state, batch_inputs['attention_mask']) 78 | elif self.pooling == 'cls': 79 | embeddings = outputs.last_hidden_state[:, 0] 80 | elif self.pooling == 'last_token': 81 | embeddings = outputs.last_hidden_state[torch.arange(batch_inputs['attention_mask'].shape[0], device=outputs.last_hidden_state.device), 82 | batch_inputs['attention_mask'].sum(-1) - 1] 83 | all_embeddings.append(embeddings) 84 | all_embeddings = torch.cat(all_embeddings, dim=0) 85 | all_embeddings = torch.nn.functional.normalize(all_embeddings, p=2, dim=-1) 86 | return all_embeddings 87 | -------------------------------------------------------------------------------- /src/modeling/meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from torch import Tensor 4 | from typing import Optional, Union 5 | 6 | from modeling.model import ELMMetaModel 7 | from constants import PLACEHOLDER, RANK_TOKEN 8 | 9 | 10 | class MetaLM: 11 | 12 | model: ELMMetaModel 13 | 14 | def get_model(self) -> ELMMetaModel: 15 | return self.model 16 | 17 | def get_encoder(self): 18 | return self.get_model().get_encoder() 19 | 20 | def get_projector(self): 21 | return self.get_model().get_projector() 22 | 23 | def get_encoder_head(self): 24 | return self.get_model().get_encoder_head() 25 | 26 | def prepare_inputs_labels_embeddings( 27 | self, 28 | input_ids: Optional[Tensor], 29 | position_ids: Optional[Tensor], 30 | attention_mask: Optional[Tensor], 31 | past_key_values: Optional[tuple], 32 | labels: Optional[Tensor], 33 | **extra_texts_inputs: dict[str, Tensor] 34 | ): 35 | if self.get_model().get_encoder() 
is None or "extra_text_input_ids" not in extra_texts_inputs: 36 | return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None 37 | 38 | assert "extra_text_input_ids" in extra_texts_inputs and "extra_text_attention_mask" in extra_texts_inputs, extra_texts_inputs.keys() 39 | 40 | input_embeddings = [] 41 | all_text_embeddings = [] 42 | extra_text_positions = [] 43 | 44 | # TODO: allow one text corresponding to multiple placeholders, now it's 1 to 1 45 | for extra_text_input_ids, extra_text_attention_masks, cur_input_ids in \ 46 | zip(extra_texts_inputs["extra_text_input_ids"], extra_texts_inputs["extra_text_attention_mask"], input_ids): 47 | 48 | num_extra_texts = (cur_input_ids == PLACEHOLDER_ID).sum() + (cur_input_ids == RANK_TOKEN_ID).sum() 49 | assert num_extra_texts <= extra_text_input_ids.shape[0] 50 | extra_text_input_ids = extra_text_input_ids[:num_extra_texts] 51 | extra_text_attention_masks = extra_text_attention_masks[:num_extra_texts] 52 | project_text_embeddings = self.get_model().encode_texts( 53 | input_ids=extra_text_input_ids, attention_mask=extra_text_attention_masks) 54 | 55 | cur_input_embeds = self.get_model().embed_tokens(cur_input_ids.to(self.get_model().device)) 56 | new_input_embeds = cur_input_embeds.clone() 57 | project_text_embeddings = project_text_embeddings.to(new_input_embeds.device) 58 | 59 | text_as_token_indices = (cur_input_ids == PLACEHOLDER_ID) | (cur_input_ids == RANK_TOKEN_ID) 60 | new_input_embeds[text_as_token_indices] = project_text_embeddings.to(cur_input_embeds.dtype) 61 | input_embeddings.append(new_input_embeds) 62 | 63 | all_text_embeddings.append( 64 | project_text_embeddings[:(cur_input_ids == PLACEHOLDER_ID).sum().item(), :]) 65 | 66 | extra_text_position = (cur_input_ids == PLACEHOLDER_ID) 67 | extra_text_positions.append(extra_text_position) 68 | 69 | input_embeddings = torch.stack(input_embeddings) 70 | all_text_embeddings = torch.stack(all_text_embeddings) 71 | extra_text_positions = torch.stack(extra_text_positions) 72 | 73 | return None, position_ids, attention_mask, past_key_values, input_embeddings, labels, \ 74 | all_text_embeddings, extra_text_positions 75 | 76 | def initialize_tokenizer(self, tokenizer: transformers.PreTrainedTokenizer): 77 | global PLACEHOLDER_ID 78 | global RANK_TOKEN_ID 79 | tokenizer.add_tokens([PLACEHOLDER], special_tokens=True) 80 | tokenizer.add_tokens([RANK_TOKEN], special_tokens=True) 81 | self.resize_token_embeddings(len(tokenizer)) 82 | PLACEHOLDER_ID = tokenizer.convert_tokens_to_ids(PLACEHOLDER) 83 | RANK_TOKEN_ID = tokenizer.convert_tokens_to_ids(RANK_TOKEN) 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific ignores 2 | checkpoints/ 3 | runs/ 4 | wandb/ 5 | results/ 6 | data/ 7 | indexes/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | -------------------------------------------------------------------------------- /src/modeling/rank_lm/modeling_llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | from transformers.models.llama.modeling_llama import LlamaConfig, LlamaModel, LlamaForCausalLM 5 | from transformers.file_utils import ModelOutput 6 | from modeling.model import ELMMetaModel 7 | from modeling.meta import MetaLM 8 | from modeling.rank_lm.loss import ListMLELoss 9 | 10 | 11 | class EmbedLlamaConfig(LlamaConfig): 12 | model_type = "embed_llama" 13 | 14 | 15 | class EmbedLlamaModel(ELMMetaModel, LlamaModel): 16 | config_class = EmbedLlamaConfig 17 | 18 | def __init__(self, config: LlamaConfig): 19 | super().__init__(config) 20 | 21 | 22 | @dataclass 23 | class RankingOutput(ModelOutput): 24 | loss: Optional[torch.FloatTensor] = None 25 | logits: Optional[torch.FloatTensor] = None 26 | ranking: Optional[torch.LongTensor] = None 27 | 28 | 29 | class EmbedLlamaForRankLM(MetaLM, LlamaForCausalLM): 30 | config_class = EmbedLlamaConfig 31 | 32 | def __init__(self, config: LlamaConfig): 33 | super().__init__(config) 34 | self.model = EmbedLlamaModel(config) 35 | self.config = config 36 | self.oringinal_vocab_size = config.vocab_size 37 | self.post_init() 38 | 39 | self.loss_function = ListMLELoss(weighted="weighted_1") 40 | self.normalize_embeddings = False 41 | 42 | def forward( 43 | self, 44 | input_ids: torch.LongTensor = None, 45 | attention_mask: Optional[torch.Tensor] = None, 46 | position_ids: Optional[torch.LongTensor] = None, 47 | past_key_values: Optional[list[torch.FloatTensor]] = None, 48 | inputs_embeds: Optional[torch.FloatTensor] = None, 49 | labels: Optional[torch.LongTensor] = None, 50 | use_cache: Optional[bool] = None, 51 | output_attentions: Optional[bool] = None, 52 | output_hidden_states: Optional[bool] = None, 53 | return_dict: Optional[bool] = None, 54 | **extra_texts_inputs 55 | ) -> RankingOutput: 56 | 57 | if inputs_embeds is None: 58 | ( 59 | input_ids, 60 | position_ids, 61 | attention_mask, 62 | past_key_values, 63 | inputs_embeds, 64 | labels, 65 | extra_embeddings, 66 | _, 67 | ) = self.prepare_inputs_labels_embeddings( 68 | input_ids, 69 | position_ids, 70 | attention_mask, 71 | past_key_values, 72 | labels, 73 | **extra_texts_inputs 74 | ) 75 | 76 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 77 | output_hidden_states = ( 78 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 79 | ) 80 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 81 | 82 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 83 | outputs = self.model( 84 | input_ids=input_ids, 85 | attention_mask=attention_mask, 86 | position_ids=position_ids, 87 | past_key_values=past_key_values, 88 | inputs_embeds=inputs_embeds, 89 | use_cache=use_cache, 90 | output_attentions=output_attentions, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | ranking = extra_texts_inputs["ranking"] 95 | hidden_states = outputs[0] 96 | if self.normalize_embeddings: 97 | hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1) 98 | extra_embeddings = torch.nn.functional.normalize(extra_embeddings, p=2, dim=-1) 99 | loss, logits = self.loss_function(hidden_states, 
extra_embeddings, labels, ranking) 100 | return RankingOutput( 101 | loss=loss, 102 | logits=logits, 103 | ranking=ranking, 104 | ) 105 | -------------------------------------------------------------------------------- /src/modeling/rank_lm/modeling_mistral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | from transformers.models.mistral.modeling_mistral import MistralConfig, MistralModel, MistralForCausalLM 5 | from transformers.file_utils import ModelOutput 6 | from modeling.model import ELMMetaModel 7 | from modeling.meta import MetaLM 8 | from modeling.rank_lm.loss import ListMLELoss 9 | 10 | 11 | class EmbedMistralConfig(MistralConfig): 12 | model_type = "embed_mistral" 13 | 14 | 15 | class EmbedMistralModel(ELMMetaModel, MistralModel): 16 | config_class = EmbedMistralConfig 17 | 18 | def __init__(self, config: MistralConfig): 19 | super().__init__(config) 20 | 21 | 22 | @dataclass 23 | class RankingOutput(ModelOutput): 24 | loss: Optional[torch.FloatTensor] = None 25 | logits: Optional[torch.FloatTensor] = None 26 | ranking: Optional[torch.LongTensor] = None 27 | 28 | 29 | class EmbedMistralForRankLM(MetaLM, MistralForCausalLM): 30 | config_class = EmbedMistralConfig 31 | 32 | def __init__(self, config: MistralConfig): 33 | super().__init__(config) 34 | self.model = EmbedMistralModel(config) 35 | self.config = config 36 | self.oringinal_vocab_size = config.vocab_size 37 | self.post_init() 38 | 39 | self.loss_function = ListMLELoss() 40 | self.normalize_embeddings = False 41 | 42 | def forward( 43 | self, 44 | input_ids: torch.LongTensor = None, 45 | attention_mask: Optional[torch.Tensor] = None, 46 | position_ids: Optional[torch.LongTensor] = None, 47 | past_key_values: Optional[list[torch.FloatTensor]] = None, 48 | inputs_embeds: Optional[torch.FloatTensor] = None, 49 | labels: Optional[torch.LongTensor] = None, 50 | use_cache: Optional[bool] = None, 51 | output_attentions: Optional[bool] = None, 52 | output_hidden_states: Optional[bool] = None, 53 | return_dict: Optional[bool] = None, 54 | **extra_texts_inputs 55 | ) -> RankingOutput: 56 | 57 | if inputs_embeds is None: 58 | ( 59 | input_ids, 60 | position_ids, 61 | attention_mask, 62 | past_key_values, 63 | inputs_embeds, 64 | labels, 65 | extra_embeddings, 66 | _, 67 | ) = self.prepare_inputs_labels_embeddings( 68 | input_ids, 69 | position_ids, 70 | attention_mask, 71 | past_key_values, 72 | labels, 73 | **extra_texts_inputs 74 | ) 75 | 76 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 77 | output_hidden_states = ( 78 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 79 | ) 80 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 81 | 82 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 83 | outputs = self.model( 84 | input_ids=input_ids, 85 | attention_mask=attention_mask, 86 | position_ids=position_ids, 87 | past_key_values=past_key_values, 88 | inputs_embeds=inputs_embeds, 89 | use_cache=use_cache, 90 | output_attentions=output_attentions, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | ranking = extra_texts_inputs["ranking"] 95 | hidden_states = outputs[0] 96 | if self.normalize_embeddings: 97 | hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1) 98 | 
extra_embeddings = torch.nn.functional.normalize(extra_embeddings, p=2, dim=-1) 99 | loss, logits = self.loss_function(hidden_states, extra_embeddings, labels, ranking) 100 | return RankingOutput( 101 | loss=loss, 102 | logits=logits, 103 | ranking=ranking, 104 | ) 105 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from deepspeed import zero 5 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 6 | import torch 7 | import transformers 8 | 9 | 10 | def maybe_zero_3(param, ignore_status=False, name=None): 11 | if hasattr(param, "ds_id"): 12 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 13 | if not ignore_status: 14 | logging.warning( 15 | f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 16 | with zero.GatheredParameters([param]): 17 | param = param.data.detach().cpu().clone() 18 | else: 19 | param = param.detach().cpu().clone() 20 | return param 21 | 22 | # Borrowed from peft.utils.get_peft_model_state_dict 23 | 24 | 25 | def get_peft_state_maybe_zero_3(named_params, bias): 26 | if bias == "none": 27 | to_return = {k: t for k, t in named_params if "lora_" in k} 28 | elif bias == "all": 29 | to_return = {k: t for k, 30 | t in named_params if "lora_" in k or "bias" in k} 31 | elif bias == "lora_only": 32 | to_return = {} 33 | maybe_lora_bias = {} 34 | lora_bias_names = set() 35 | for k, t in named_params: 36 | if "lora_" in k: 37 | to_return[k] = t 38 | bias_name = k.split("lora_")[0] + "bias" 39 | lora_bias_names.add(bias_name) 40 | elif "bias" in k: 41 | maybe_lora_bias[k] = t 42 | for k, t in maybe_lora_bias: 43 | if bias_name in lora_bias_names: 44 | to_return[bias_name] = t 45 | else: 46 | raise NotImplementedError 47 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} 48 | return to_return 49 | 50 | 51 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): 52 | to_return = {k: t for k, t in named_params if "lora_" not in k} 53 | if require_grad_only: 54 | to_return = {k: t for k, t in to_return.items() if t.requires_grad} 55 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() 56 | for k, v in to_return.items()} 57 | return to_return 58 | 59 | 60 | def get_adapter_state_maybe_zero_3(named_params, keys_to_match): 61 | to_return = {k: t for k, t in named_params if any( 62 | key_match in k for key_match in keys_to_match)} 63 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() 64 | for k, v in to_return.items()} 65 | return to_return 66 | 67 | 68 | def find_all_linear_names(model): 69 | cls = torch.nn.Linear 70 | lora_module_names = set() 71 | keywords = ['projector'] 72 | for name, module in model.named_modules(): 73 | if any(keyword in name for keyword in keywords): 74 | continue 75 | if isinstance(module, cls): 76 | names = name.split('.') 77 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 78 | 79 | if 'lm_head' in lora_module_names: # needed for 16-bit 80 | lora_module_names.remove('lm_head') 81 | return list(lora_module_names) 82 | 83 | 84 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, 85 | output_dir: str): 86 | """Collects the state dict and dump to disk.""" 87 | 88 | # FIXME 89 | 90 | if getattr(trainer.args, "tune_mlp_adapter", False) and getattr(trainer.args, 'freeze_backbone', False): 91 | # Only save Adapter 92 | keys_to_match = ['projector'] 93 | 94 | 
weight_to_save = get_adapter_state_maybe_zero_3( 95 | trainer.model.named_parameters(), keys_to_match) 96 | trainer.model.config.save_pretrained(output_dir) 97 | 98 | current_folder = output_dir.split('/')[-1] 99 | parent_folder = os.path.dirname(output_dir) 100 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: 101 | if current_folder.startswith('checkpoint-'): 102 | projector_folder = os.path.join(parent_folder, "projector") 103 | os.makedirs(projector_folder, exist_ok=True) 104 | torch.save(weight_to_save, os.path.join( 105 | projector_folder, f'{current_folder}.bin')) 106 | else: 107 | torch.save(weight_to_save, os.path.join( 108 | output_dir, f'projector.bin')) 109 | return 110 | 111 | if trainer.deepspeed: 112 | torch.cuda.synchronize() 113 | trainer.save_model(output_dir) 114 | return 115 | 116 | state_dict = trainer.model.state_dict() 117 | if trainer.args.should_save: 118 | cpu_state_dict = { 119 | key: value.cpu() 120 | for key, value in state_dict.items() 121 | } 122 | del state_dict 123 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 124 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import os 5 | from typing import Optional, Callable, Any 6 | from tqdm import tqdm 7 | 8 | from ranker import * 9 | 10 | 11 | def write_results(rerank_results, output_file): 12 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 13 | with open(output_file, "w") as f: 14 | for i, hits in enumerate(rerank_results): 15 | for j, hit in enumerate(hits): 16 | f.write( 17 | f"{hit['qid']} Q{i} {hit['docid']} {j + 1} {round(1 / (j + 1), 3)} rank") 18 | f.write("\n") 19 | 20 | 21 | class ListwiseSilidingWindowReranker: 22 | def rerank( 23 | self, 24 | query: str, 25 | candidates: list[str], 26 | ranking_func: Callable[[str, list[str]], list[int]], 27 | rank_start: int = 0, 28 | rank_end: int = None, 29 | window_size: Optional[int] = None, 30 | step: int = 10, 31 | **kwargs: dict[str, Any], 32 | ) -> list[str]: 33 | 34 | rerank_result = copy.deepcopy(candidates) 35 | 36 | window_size = window_size or len(candidates) 37 | rank_end = rank_end or len(candidates) 38 | start_pos, end_pos = rank_end - window_size, rank_end 39 | while end_pos > rank_start and start_pos + step != rank_start: 40 | start_pos = max(start_pos, rank_start) 41 | # range from 0 to window_size 42 | permutation = ranking_func(query, rerank_result[start_pos:end_pos]) 43 | 44 | # receive permutation 45 | cut_range = copy.deepcopy(rerank_result[start_pos:end_pos]) 46 | for local_rank, index in enumerate(permutation): 47 | rerank_result[start_pos + 48 | local_rank] = copy.deepcopy(cut_range[index]) 49 | 50 | start_pos, end_pos = start_pos - step, end_pos - step 51 | 52 | return rerank_result 53 | 54 | 55 | def eval_model(args): 56 | from scripts.trec_eval import trec_eval 57 | from scripts.indexes_and_topics import TOPICS 58 | 59 | reranker = ListwiseSilidingWindowReranker() 60 | 61 | if args.ranker == "listwise-text-embedding": 62 | ranking_model = ListwiseTextEmbeddingRanker( 63 | model_path=args.model_path, 64 | model_base=args.model_base, 65 | ) 66 | elif args.ranker == "listwise-embedding": 67 | ranking_model = ListwiseEmbeddingRanker( 68 | model_path=args.model_path, 69 | model_base=args.model_base, 70 | ) 71 | elif args.ranker == "listwise-text": 72 | ranking_model = ListwiseTextRanker( 73 | 
model_path=args.model_path, 74 | model_base=args.model_base, 75 | model_name="mistral" 76 | ) 77 | else: 78 | raise ValueError(f"Ranker {args.ranker} not supported") 79 | 80 | for dataset in args.datasets: 81 | 82 | output_file = os.path.join( 83 | "results", "rerank_results", args.retriever, 84 | f"eval_{dataset}_{ranking_model.model_name.split('/')[-1]}_{args.ranker}_top{args.topk}.txt" 85 | ) 86 | if os.path.exists(output_file) and not args.overwrite: 87 | print(f"{output_file} exists, skipping") 88 | trec_eval(TOPICS[dataset], output_file) 89 | continue 90 | if os.path.exists(output_file) and args.overwrite: 91 | output_file = output_file.replace(".txt", "_1.txt") 92 | 93 | input_file = os.path.join( 94 | "results", "retrieval_results", args.retriever, 95 | f"{dataset}_top{args.topk}.jsonl" 96 | ) 97 | with open(input_file, "r") as f: 98 | data = [json.loads(line) for line in f] 99 | 100 | rerank_results = [] 101 | for i in tqdm(range(len(data))): 102 | rerank_result = reranker.rerank( 103 | query=data[i]["query"], 104 | candidates=data[i]["hits"], 105 | ranking_func=ranking_model, 106 | window_size=20, 107 | step=10, 108 | ) 109 | rerank_results.append(rerank_result) 110 | write_results(rerank_results, output_file) 111 | trec_eval(TOPICS[dataset], output_file) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--model-path", type=str, default=None) 117 | parser.add_argument("--model-base", type=str, default=None) 118 | parser.add_argument("--datasets", nargs="+", default=["dl19", "dl20"]) 119 | parser.add_argument( 120 | "--retriever", type=str, default="bm25", 121 | choices=["bm25", "jina-embeddings-v2-base-en", "e5-mistral", "splade++ed", "bge-base-en-v1.5"] 122 | ) 123 | parser.add_argument("--ranker", type=str, default="listwise-embedding",) 124 | parser.add_argument("--topk", type=int, default=100) 125 | parser.add_argument("--overwrite", action="store_true") 126 | args = parser.parse_args() 127 | eval_model(args) 128 | -------------------------------------------------------------------------------- /src/scripts/run_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import torch 5 | import numpy as np 6 | from pyserini.index import IndexReader 7 | from pyserini.search import LuceneSearcher, LuceneImpactSearcher, FaissSearcher, get_topics, get_qrels 8 | from pyserini.search.faiss import AutoQueryEncoder 9 | from trec_eval import trec_eval 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification 12 | from sentence_transformers import SentenceTransformer 13 | 14 | from indexes_and_topics import INDEX, TOPICS 15 | 16 | 17 | def run_retriever(topics, searcher, index_reader, qrels=None, topk=100, qid=None): 18 | ranks = [] 19 | for qid in tqdm(topics): 20 | if qid in qrels: 21 | query = topics[qid]['title'] 22 | ranks.append({'query': query, 'hits': []}) 23 | hits = searcher.search(query, k=topk) 24 | rank = 0 25 | for hit in hits: 26 | rank += 1 27 | if index_reader.doc(hit.docid): 28 | content = json.loads(index_reader.doc(hit.docid).raw()) 29 | else: 30 | continue 31 | if "title" in content: 32 | content = ( 33 | "Title: " + content["title"] + 34 | " " + "Content: " + content["text"] 35 | ) 36 | elif "contents" in content: 37 | content = content["contents"] 38 | else: 39 | content = content["passage"] 40 | content = ' '.join(content.split()) 41 | 
ranks[-1]['hits'].append({ 42 | 'content': content, 43 | 'qid': qid, 44 | 'docid': hit.docid, 45 | 'rank': rank, 46 | 'score': hit.score if isinstance(hit.score, float) else hit.score.item() 47 | }) 48 | return ranks 49 | 50 | 51 | def write_retrival_results(rank_results, file): 52 | with open(file, 'w') as f: 53 | for item in rank_results: 54 | f.write((json.dumps(item) + '\n')) 55 | return True 56 | 57 | 58 | def write_eval_file(rank_results, file): 59 | with open(file, 'w') as f: 60 | for i in range(len(rank_results)): 61 | rank = 1 62 | hits = rank_results[i]['hits'] 63 | for hit in hits: 64 | f.write(f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n") 65 | rank += 1 66 | return True 67 | 68 | 69 | def eval_dataset(args): 70 | 71 | dataset, retriever, topk = args.dataset, args.retriever, args.topk 72 | 73 | print('#' * 20) 74 | print(f'Evaluation on {dataset}') 75 | print('#' * 20) 76 | 77 | retrieval_results_path = os.path.join('results', 'retrieval_results', retriever.split('/')[-1]) 78 | retrieval_results_file = os.path.join(retrieval_results_path, f'{dataset}_top{topk}.jsonl') 79 | if os.path.exists(retrieval_results_file): 80 | with open(retrieval_results_file) as f: 81 | retrieval_results = [json.loads(line) for line in f] 82 | else: 83 | if retriever == 'bm25': 84 | searcher = LuceneSearcher.from_prebuilt_index(INDEX[retriever][dataset]) 85 | elif retriever == 'splade++ed': 86 | searcher = LuceneImpactSearcher.from_prebuilt_index( 87 | INDEX[retriever][dataset], 88 | query_encoder='SpladePlusPlusEnsembleDistil', 89 | min_idf=0, 90 | encoder_type='onnx' 91 | ) 92 | else: 93 | encoder = AutoQueryEncoder(retriever, pooling=args.dense_encoder_pooling, l2_norm=True) 94 | retriever = retriever.split('/')[-1] # maybe hf model 95 | index_dir = os.path.join( 96 | 'indexes', f'{INDEX["dense"][dataset]}.{retriever}') 97 | searcher = FaissSearcher( 98 | index_dir=index_dir, 99 | query_encoder=encoder 100 | ) 101 | 102 | index_reader = IndexReader.from_prebuilt_index(INDEX["bm25"][dataset]) 103 | topics = get_topics(TOPICS[dataset] if dataset != 'dl20' else 'dl20') 104 | qrels = get_qrels(TOPICS[dataset]) 105 | retrieval_results = run_retriever(topics, searcher, index_reader, qrels, topk=topk) 106 | os.makedirs(retrieval_results_path, exist_ok=True) 107 | write_retrival_results( 108 | retrieval_results, 109 | retrieval_results_file 110 | ) 111 | 112 | output_file = os.path.join(retrieval_results_path, f'eval_{dataset}_top{topk}.txt') 113 | write_eval_file(retrieval_results, output_file) 114 | trec_eval(TOPICS[dataset], output_file) 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--dataset', type=str, required=True) 120 | retriever = parser.add_argument_group('retriever') 121 | retriever.add_argument('--retriever', type=str, default='bm25') 122 | retriever.add_argument('--dense-encoder-pooling', type=str, default='mean') 123 | retriever.add_argument('--topk', type=int, default=100) 124 | args = parser.parse_args() 125 | eval_dataset(args) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PE-Rank 2 | 3 | Code for paper [Leveraging Passage Embeddings for Efficient Listwise Reranking with Large Language Models](https://arxiv.org/abs/2406.14848) 4 | 5 |

6 | ![cover](docs/images/cover.jpg) 7 | 8 | 9 | Figure 1: Comparison between RankGPT (upper) and PE-Rank (lower). RankGPT takes the whole passages as input and outputs ordered numbers, while PE-Rank takes a list of special tokens as both input and output. On the right side, we show the reranking results on DL19 using different forms of inputs. 10 | 
11 | 12 | ## Installation 13 | 14 | ```bash 15 | git clone git@github.com:liuqi6777/pe_rank.git 16 | ``` 17 | 18 | ## Evaluation 19 | 20 | The checkpoint of the PE-Rank model is available at this link: [PE-Rank](https://huggingface.co/liuqi6777/pe-rank-mistral-jina). 21 | 22 | ### Retrieval 23 | 24 | We provide scripts for first-stage retrieval. For example, you can use the following command to run BM25 as the retrieval model: 25 | 26 | ```bash 27 | python src/scripts/run_evaluation.py --dataset dl19 --retriever bm25 --topk 100 28 | ``` 29 | 30 | This command runs BM25 retrieval on the DL19 dataset and saves the results to `results/retrieval_results/bm25/dl19_top100.jsonl`. 31 | 32 | As an alternative, we also provide all the retrieval results at this link: [https://huggingface.co/liuqi6777/pyserini_retrieval_results](https://huggingface.co/liuqi6777/pyserini_retrieval_results). You can download the retrieval results to the `results/retrieval_results` folder. 33 | 34 | ### Reranking 35 | 36 | To run the reranking stage, you can use the following command: 37 | 38 | ```bash 39 | python src/evaluate.py --datasets dl19 --model-path liuqi6777/pe-rank-mistral-jina --retriever bm25 --topk 100 40 | ``` 41 | 42 | The reranking results will be saved to `results/rerank_results/bm25/eval_dl19_pe-rank-mistral-jina_listwise-embedding_top100.txt`, and you can use the following command to compute the evaluation metrics: 43 | 44 | ```bash 45 | python src/scripts/trec_eval.py --dataset dl19 --ranking results/rerank_results/bm25/eval_dl19_pe-rank-mistral-jina_listwise-embedding_top100.txt 46 | ``` 47 | 48 | For other datasets or other retrieval models, just replace the `--datasets` and `--retriever` arguments. 49 | 50 | ### More Usage 51 | 52 | Coming soon. 53 | 54 | ## Training 55 | 56 | If you want to train the PE-Rank model from scratch or with customized settings, you can follow the instructions below. 57 | 58 | ### Data Preparation 59 | 60 | All datasets used in the paper are available at this link: [pe_rank_data](https://huggingface.co/datasets/liuqi6777/pe_rank_data). Please download the data to the `data` folder. 61 | 62 | For example, you can run the following command: 63 | 64 | ```bash 65 | git clone git@hf.co:datasets/liuqi6777/pe_rank_data ./data 66 | ``` 67 | 68 | You can refer to the paper for more details about the datasets. 
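Before launching training, it can help to verify that the files referenced by the training scripts are in place: `./data/wiki2m.jsonl` for the alignment stage and `./data/train.jsonl` for the learning-to-rank stage. Below is a minimal, optional sanity check using standard shell tools; the exact record schema of each JSONL file is defined by the data loading code in `src/data.py`.

```bash
# Count records in the two training files referenced by the scripts.
wc -l ./data/wiki2m.jsonl ./data/train.jsonl

# Pretty-print the first record of the ranking data to inspect its fields.
head -n 1 ./data/train.jsonl | python -m json.tool | head -n 20
```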
69 | 70 | ### Alignment Stage 71 | 72 | To run the alignment stage, you can use the following command: 73 | 74 | ```bash 75 | deepspeed --include="localhost:0,1,2,3" src/train.py \ 76 | --deepspeed scripts/zero2.json \ 77 | --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \ 78 | --data_path ./data/wiki2m.jsonl \ 79 | --encoder_name jinaai/jina-embeddings-v2-base-en \ 80 | --encoder_pooling mean \ 81 | --projector_type mlp2x_gelu \ 82 | --freeze_backbone \ 83 | --tune_mlp_adapter \ 84 | --bf16 \ 85 | --output_dir ./checkpoints/mistral.jina.projector \ 86 | --num_train_epochs 1 \ 87 | --per_device_train_batch_size 1 \ 88 | --gradient_accumulation_steps 1 \ 89 | --evaluation_strategy "no" \ 90 | --save_strategy "steps" \ 91 | --save_steps 1000 \ 92 | --save_total_limit 1 \ 93 | --learning_rate 1e-3 \ 94 | --warmup_ratio 0.03 \ 95 | --lr_scheduler_type "cosine" \ 96 | --logging_steps 1 \ 97 | --tf32 True \ 98 | --model_max_length 512 \ 99 | --gradient_checkpointing \ 100 | --attn_implementation flash_attention_2 \ 101 | --dataloader_num_workers 4 102 | ``` 103 | 104 | This command will run the alignment stage using the Mistral-7B model as the backbone and Jina-Embeddings as the encoder. 105 | 106 | ### Learning-to-Rank Stage 107 | 108 | To run the learning-to-rank stage, you can use the following command: 109 | 110 | ```bash 111 | deepspeed --include="localhost:4,5,6,7" --master_port="29700" src/train.py \ 112 | --deepspeed ./scripts/zero2.json \ 113 | --model_type rank_lm \ 114 | --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \ 115 | --data_path ./data/train.jsonl \ 116 | --use_embedding_with_content True \ 117 | --use_embedding_without_content True \ 118 | --kl_loss_weight 0.2 \ 119 | --loss1_weight 1 \ 120 | --loss2_weight 1 \ 121 | --encoder_name jinaai/jina-embeddings-v2-base-en \ 122 | --encoder_pooling mean \ 123 | --pretrain_mlp_adapter ./checkpoints/mistral.jina.projector/projector.bin \ 124 | --projector_type mlp2x_gelu \ 125 | --tune_mlp_adapter \ 126 | --bf16 True \ 127 | --tf32 True \ 128 | --output_dir "./checkpoints/pe-rank-mistral-jina" \ 129 | --overwrite_output_dir \ 130 | --num_train_epochs 1 \ 131 | --per_device_train_batch_size 4 \ 132 | --gradient_accumulation_steps 2 \ 133 | --save_strategy "steps" \ 134 | --save_steps 3000 \ 135 | --save_total_limit 2 \ 136 | --learning_rate 2e-5 \ 137 | --warmup_ratio 0.03 \ 138 | --lr_scheduler_type "cosine" \ 139 | --logging_steps 1 \ 140 | --model_max_length 4096 \ 141 | --gradient_checkpointing True \ 142 | --attn_implementation flash_attention_2 \ 143 | --dataloader_num_workers 2 144 | ``` 145 | 146 | This command will run the full learning-to-rank stage. 
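Once training finishes, the saved checkpoint can also be used programmatically through the ranker classes in `src/ranker.py`. The snippet below is a minimal, illustrative sketch rather than the official evaluation entry point (use `src/evaluate.py` for full benchmark runs); it assumes the checkpoint was written to `./checkpoints/pe-rank-mistral-jina` as in the command above and that `src/` is on your `PYTHONPATH`.

```python
# Minimal sketch: rerank a few passages with a trained PE-Rank checkpoint.
# Paths and example texts are placeholders; adjust them to your own setup.
from ranker import ListwiseEmbeddingRanker  # requires `src/` on PYTHONPATH

ranker = ListwiseEmbeddingRanker(
    model_path="./checkpoints/pe-rank-mistral-jina",  # output_dir of the learning-to-rank stage
    model_base=None,
    model_name="embed_mistral",  # must contain "embed" so the PE-Rank model classes are loaded
)

query = "how do large language models rerank passages?"
candidates = [
    {"content": "PE-Rank compresses each passage into a single embedding for listwise reranking."},
    {"content": "BM25 is a classical lexical retrieval model based on term statistics."},
    {"content": "Cross-encoders score one query-passage pair at a time."},
]

# Returns indices into `candidates`, ordered from most to least relevant.
ranking = ranker(query, candidates)
print(ranking)
```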
147 | 148 | ## Citation 149 | 150 | ```bibtex 151 | @article{liu2024leveraging, 152 | title={Leveraging Passage Embeddings for Efficient Listwise Reranking with Large Language Models}, 153 | author={Liu, Qi and Wang, Bo and Wang, Nan and Mao, Jiaxin}, 154 | journal={arXiv preprint arXiv:2406.14848}, 155 | year={2024} 156 | } 157 | ``` 158 | -------------------------------------------------------------------------------- /src/modeling/causal_lm/modeling_llama.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, List, Tuple 2 | import torch 3 | from torch import nn, Tensor 4 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 5 | from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM 6 | 7 | from transformers.modeling_outputs import CausalLMOutputWithPast 8 | from transformers.generation.utils import GenerateOutput 9 | 10 | from modeling.model import ELMMetaModel 11 | from modeling.meta import MetaLM 12 | 13 | 14 | class EmbedLlamaConfig(LlamaConfig): 15 | model_type = "embed_llama" 16 | 17 | 18 | class EmbedLlamaModel(ELMMetaModel, LlamaModel): 19 | config_class = EmbedLlamaConfig 20 | 21 | def __init__(self, config: LlamaConfig): 22 | super().__init__(config) 23 | 24 | 25 | class EmbedLlamaForCausalLM(MetaLM, LlamaForCausalLM): 26 | config_class = EmbedLlamaConfig 27 | 28 | def __init__(self, config: LlamaConfig): 29 | super(LlamaForCausalLM, self).__init__(config) 30 | self.model = EmbedLlamaModel(config) 31 | self.pretraining_tp = config.pretraining_tp 32 | 33 | self.vocab_size = self.original_vocab_size = config.vocab_size 34 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 35 | 36 | # Initialize weights and apply final processing 37 | self.post_init() 38 | 39 | def forward( 40 | self, 41 | input_ids: torch.LongTensor = None, 42 | attention_mask: Optional[torch.Tensor] = None, 43 | position_ids: Optional[torch.LongTensor] = None, 44 | past_key_values: Optional[List[torch.FloatTensor]] = None, 45 | inputs_embeds: Optional[torch.FloatTensor] = None, 46 | labels: Optional[torch.LongTensor] = None, 47 | use_cache: Optional[bool] = None, 48 | output_attentions: Optional[bool] = None, 49 | output_hidden_states: Optional[bool] = None, 50 | return_dict: Optional[bool] = None, 51 | **extra_texts_inputs 52 | ) -> Union[Tuple, CausalLMOutputWithPast]: 53 | 54 | if inputs_embeds is None: 55 | ( 56 | input_ids, 57 | position_ids, 58 | attention_mask, 59 | past_key_values, 60 | inputs_embeds, 61 | labels, 62 | _, 63 | _ 64 | ) = self.prepare_inputs_labels_embeddings( 65 | input_ids, 66 | position_ids, 67 | attention_mask, 68 | past_key_values, 69 | labels, 70 | **extra_texts_inputs 71 | ) 72 | 73 | return super().forward( 74 | input_ids=input_ids, 75 | attention_mask=attention_mask, 76 | position_ids=position_ids, 77 | past_key_values=past_key_values, 78 | inputs_embeds=inputs_embeds, 79 | labels=labels, 80 | use_cache=use_cache, 81 | output_attentions=output_attentions, 82 | output_hidden_states=output_hidden_states, 83 | return_dict=return_dict 84 | ) 85 | 86 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, 87 | inputs_embeds=None, **kwargs) -> dict[str, Tensor]: 88 | extra_text_input_ids = kwargs.pop("extra_text_input_ids", None) 89 | extra_text_attention_mask = kwargs.pop("extra_text_attention_mask", None) 90 | inputs = super().prepare_inputs_for_generation( 91 | input_ids, past_key_values=past_key_values, 
attention_mask=attention_mask, 92 | inputs_embeds=inputs_embeds, **kwargs 93 | ) 94 | if extra_text_input_ids is not None and extra_text_attention_mask is not None: 95 | inputs["extra_text_input_ids"] = extra_text_input_ids 96 | inputs["extra_text_attention_mask"] = extra_text_attention_mask 97 | return inputs 98 | 99 | @torch.no_grad() 100 | def generate( 101 | self, 102 | inputs: Optional[torch.Tensor] = None, 103 | **kwargs, 104 | ) -> Union[GenerateOutput, torch.LongTensor]: 105 | position_ids = kwargs.pop("position_ids", None) 106 | attention_mask = kwargs.pop("attention_mask", None) 107 | if "inputs_embeds" in kwargs: 108 | raise NotImplementedError("`inputs_embeds` is not supported") 109 | 110 | extra_text_input_ids = kwargs.pop("extra_text_input_ids", None) 111 | extra_text_attention_mask = kwargs.pop("extra_text_attention_mask", None) 112 | 113 | if extra_text_input_ids is not None and extra_text_attention_mask is not None: 114 | ( 115 | inputs, 116 | position_ids, 117 | attention_mask, 118 | _, 119 | inputs_embeds, 120 | _, 121 | extra_text_embeddings, 122 | _, 123 | ) = self.prepare_inputs_labels_embeddings( 124 | inputs, 125 | position_ids, 126 | attention_mask, 127 | None, 128 | None, 129 | extra_text_input_ids=extra_text_input_ids, 130 | extra_text_attention_mask=extra_text_attention_mask 131 | ) 132 | 133 | n = extra_text_embeddings.shape[1] 134 | if self.vocab_size < self.original_vocab_size + n: 135 | self.resize_token_embeddings(self.original_vocab_size + n) 136 | self.get_input_embeddings().weight.data[self.original_vocab_size:] = extra_text_embeddings 137 | self.get_output_embeddings().weight.data[self.original_vocab_size:] = extra_text_embeddings 138 | else: 139 | inputs_embeds = self.get_model().embed_tokens(inputs) 140 | 141 | return super().generate( 142 | position_ids=position_ids, 143 | attention_mask=attention_mask, 144 | inputs_embeds=inputs_embeds, 145 | **kwargs 146 | ) 147 | 148 | 149 | AutoConfig.register("embed_llama", EmbedLlamaConfig) 150 | AutoModelForCausalLM.register(EmbedLlamaConfig, EmbedLlamaForCausalLM) 151 | -------------------------------------------------------------------------------- /src/modeling/causal_lm/modeling_mistral.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, List, Tuple 2 | import torch 3 | from torch import nn, Tensor 4 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 5 | from transformers import MistralConfig, MistralModel, MistralForCausalLM 6 | 7 | from transformers.modeling_outputs import CausalLMOutputWithPast 8 | from transformers.generation.utils import GenerateOutput 9 | 10 | from modeling.model import ELMMetaModel 11 | from modeling.meta import MetaLM 12 | 13 | 14 | class EmbedMistralConfig(MistralConfig): 15 | model_type = "embed_mistral" 16 | 17 | 18 | class EmbedMistralModel(ELMMetaModel, MistralModel): 19 | config_class = EmbedMistralConfig 20 | 21 | def __init__(self, config: MistralConfig): 22 | super().__init__(config) 23 | 24 | 25 | class EmbedMistralForCausalLM(MetaLM, MistralForCausalLM): 26 | config_class = EmbedMistralConfig 27 | 28 | def __init__(self, config: MistralConfig): 29 | super(MistralForCausalLM, self).__init__(config) 30 | self.model = EmbedMistralModel(config) 31 | 32 | self.vocab_size = self.original_vocab_size = config.vocab_size 33 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 34 | 35 | # Initialize weights and apply final processing 36 | self.post_init() 37 | 
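    # NOTE: the overrides below mirror EmbedLlamaForCausalLM above. When raw input_ids
    # are passed (optionally together with extra_text_input_ids / extra_text_attention_mask
    # for the passages), prepare_inputs_labels_embeddings -- inherited from the MetaLM
    # mixin -- first builds inputs_embeds that include the projected passage embeddings,
    # and only then is the standard MistralForCausalLM forward / generation path invoked.
    # During generate(), the vocabulary is also extended so that each passage embedding
    # becomes an extra token the model can emit directly in its ranking output.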
38 | def forward( 39 | self, 40 | input_ids: torch.LongTensor = None, 41 | attention_mask: Optional[torch.Tensor] = None, 42 | position_ids: Optional[torch.LongTensor] = None, 43 | past_key_values: Optional[List[torch.FloatTensor]] = None, 44 | inputs_embeds: Optional[torch.FloatTensor] = None, 45 | labels: Optional[torch.LongTensor] = None, 46 | use_cache: Optional[bool] = None, 47 | output_attentions: Optional[bool] = None, 48 | output_hidden_states: Optional[bool] = None, 49 | return_dict: Optional[bool] = None, 50 | **extra_texts_inputs 51 | ) -> Union[Tuple, CausalLMOutputWithPast]: 52 | 53 | if inputs_embeds is None: 54 | ( 55 | input_ids, 56 | position_ids, 57 | attention_mask, 58 | past_key_values, 59 | inputs_embeds, 60 | labels, 61 | _, 62 | _ 63 | ) = self.prepare_inputs_labels_embeddings( 64 | input_ids, 65 | position_ids, 66 | attention_mask, 67 | past_key_values, 68 | labels, 69 | **extra_texts_inputs 70 | ) 71 | 72 | return super().forward( 73 | input_ids=input_ids, 74 | attention_mask=attention_mask, 75 | position_ids=position_ids, 76 | past_key_values=past_key_values, 77 | inputs_embeds=inputs_embeds, 78 | labels=labels, 79 | use_cache=use_cache, 80 | output_attentions=output_attentions, 81 | output_hidden_states=output_hidden_states, 82 | return_dict=return_dict 83 | ) 84 | 85 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, 86 | inputs_embeds=None, **kwargs) -> dict[str, Tensor]: 87 | extra_text_input_ids = kwargs.pop("extra_text_input_ids", None) 88 | extra_text_attention_mask = kwargs.pop("extra_text_attention_mask", None) 89 | inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, attention_mask=attention_mask, 91 | inputs_embeds=inputs_embeds, **kwargs 92 | ) 93 | if extra_text_input_ids is not None and extra_text_attention_mask is not None: 94 | inputs["extra_text_input_ids"] = extra_text_input_ids 95 | inputs["extra_text_attention_mask"] = extra_text_attention_mask 96 | return inputs 97 | 98 | @torch.no_grad() 99 | def generate( 100 | self, 101 | inputs: Optional[torch.Tensor] = None, 102 | **kwargs, 103 | ) -> Union[GenerateOutput, torch.LongTensor]: 104 | position_ids = kwargs.pop("position_ids", None) 105 | attention_mask = kwargs.pop("attention_mask", None) 106 | if "inputs_embeds" in kwargs: 107 | raise NotImplementedError("`inputs_embeds` is not supported") 108 | 109 | extra_text_input_ids = kwargs.pop("extra_text_input_ids", None) 110 | extra_text_attention_mask = kwargs.pop("extra_text_attention_mask", None) 111 | 112 | if extra_text_input_ids is not None and extra_text_attention_mask is not None: 113 | ( 114 | inputs, 115 | position_ids, 116 | attention_mask, 117 | _, 118 | inputs_embeds, 119 | _, 120 | extra_text_embeddings, 121 | _, 122 | ) = self.prepare_inputs_labels_embeddings( 123 | inputs, 124 | position_ids, 125 | attention_mask, 126 | None, 127 | None, 128 | extra_text_input_ids=extra_text_input_ids, 129 | extra_text_attention_mask=extra_text_attention_mask 130 | ) 131 | 132 | n = extra_text_embeddings.shape[1] 133 | if self.vocab_size != self.original_vocab_size + n: 134 | self.resize_token_embeddings(self.original_vocab_size + n) 135 | assert self.vocab_size == self.original_vocab_size + n 136 | self.get_input_embeddings().weight.data[self.original_vocab_size:] = extra_text_embeddings 137 | self.get_output_embeddings().weight.data[self.original_vocab_size:] = extra_text_embeddings 138 | else: 139 | inputs_embeds = self.get_model().embed_tokens(inputs) 140 | 
141 | return super().generate( 142 | position_ids=position_ids, 143 | attention_mask=attention_mask, 144 | inputs_embeds=inputs_embeds, 145 | **kwargs 146 | ) 147 | 148 | 149 | AutoConfig.register("embed_mistral", EmbedMistralConfig) 150 | AutoModelForCausalLM.register(EmbedMistralConfig, EmbedMistralForCausalLM) 151 | -------------------------------------------------------------------------------- /src/modeling/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 4 | from huggingface_hub import hf_hub_download 5 | from peft import PeftModel, PeftConfig 6 | 7 | from constants import PLACEHOLDER 8 | from modeling.encoder import Encoder, build_projector 9 | from modeling.causal_lm import EmbedLlamaForCausalLM, EmbedMistralForCausalLM 10 | from modeling.rank_lm import EmbedLlamaForRankLM, EmbedMistralForRankLM 11 | 12 | 13 | def load_from_hf(repo_id, filename, subfolder=None): 14 | cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder) 15 | return torch.load(cache_file, map_location='cpu') 16 | 17 | 18 | def load_pretrained_model( 19 | model_path, 20 | model_base=None, 21 | model_name=None, 22 | model_type="causal_lm", 23 | load_8bit=False, 24 | load_4bit=False, 25 | device_map="auto", 26 | device="cuda", 27 | use_flash_attn=False, 28 | **kwargs 29 | ): 30 | kwargs = {"device_map": device_map, **kwargs} 31 | 32 | if device != "cuda": 33 | kwargs['device_map'] = {"": device} 34 | 35 | if load_8bit: 36 | kwargs['load_in_8bit'] = True 37 | elif load_4bit: 38 | kwargs['load_in_4bit'] = True 39 | kwargs['quantization_config'] = BitsAndBytesConfig( 40 | load_in_4bit=True, 41 | bnb_4bit_compute_dtype=torch.float16, 42 | bnb_4bit_use_double_quant=True, 43 | bnb_4bit_quant_type='nf4' 44 | ) 45 | else: 46 | kwargs['torch_dtype'] = torch.float16 47 | 48 | if use_flash_attn: 49 | kwargs['attn_implementation'] = 'flash_attention_2' 50 | 51 | if 'embed' in model_name.lower(): 52 | model_cls = EmbedMistralForRankLM if model_type == "rank_lm" else EmbedMistralForCausalLM 53 | if os.path.exists(os.path.join(model_path, 'adapter_config.json')): 54 | # load lora model 55 | lora_cfg_pretrained = PeftConfig.from_pretrained(model_path) 56 | if not model_base: 57 | model_base = lora_cfg_pretrained.base_model_name_or_path 58 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 59 | tokenizer.pad_token = tokenizer.eos_token 60 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 61 | cfg_pretrained.vocab_size = len(tokenizer) 62 | model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 63 | # now didn't initialize addintional token embeddings, encoder, and projector 64 | # initialize them manually as follows 65 | model.initialize_tokenizer(tokenizer) 66 | model.original_vocab_size = len(tokenizer) 67 | model.get_model().encoder = Encoder(cfg_pretrained.encoder_name, cfg_pretrained).cuda() 68 | model.get_model().projector = build_projector(cfg_pretrained).cuda() 69 | 70 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 71 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 72 | else: 73 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 74 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v 75 | for k, v in 
non_lora_trainables.items()} 76 | if any(k.startswith('model.model.') for k in non_lora_trainables): 77 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v 78 | for k, v in non_lora_trainables.items()} 79 | model.load_state_dict(non_lora_trainables, strict=False) 80 | 81 | projector_weights = torch.load(os.path.join(model_path, 'projector.bin'), map_location='cpu') 82 | projector_weights = {k[17:] if k.startswith('base_model.model.') else k: v 83 | for k, v in projector_weights.items()} 84 | projector_weights = {k: v.to(torch.float16) for k, v in projector_weights.items()} 85 | model.load_state_dict(projector_weights, strict=False) 86 | # model.get_model().set_encoder_head() 87 | 88 | print('Loading LoRA weights...') 89 | model = PeftModel.from_pretrained(model, model_path) 90 | print('Merging LoRA weights...') 91 | model = model.merge_and_unload() 92 | print('Model is loaded...') 93 | 94 | elif model_base is not None: 95 | # this may be projector only 96 | print('Loading model from base model...') 97 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 98 | tokenizer.pad_token = tokenizer.eos_token 99 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 100 | cfg_pretrained.vocab_size = len(tokenizer) 101 | model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 102 | # now didn't initialize addintional token embeddings, encoder, and projector 103 | # initialize them manually as follows 104 | model.initialize_tokenizer(tokenizer) 105 | model.original_vocab_size = len(tokenizer) 106 | print('Loading encoder...') 107 | model.get_model().encoder = Encoder(cfg_pretrained.encoder_name, cfg_pretrained).cuda() 108 | print('Loading projector...') 109 | model.get_model().projector = build_projector(cfg_pretrained).cuda() 110 | projector_weights = torch.load(os.path.join(model_path, 'projector.bin'), map_location='cpu') 111 | projector_weights = {k: v.to(torch.float16) for k, v in projector_weights.items()} 112 | model.load_state_dict(projector_weights, strict=False) 113 | else: 114 | print(f'Loading model from {model_path}...') 115 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 116 | config = AutoConfig.from_pretrained(model_path) 117 | model = model_cls.from_pretrained( 118 | model_path, 119 | config=config, 120 | low_cpu_mem_usage=True, 121 | **kwargs 122 | ) 123 | model.initialize_tokenizer(tokenizer) 124 | 125 | model.resize_token_embeddings(len(tokenizer)) 126 | 127 | global PLACEHOLDER_ID 128 | PLACEHOLDER_ID = tokenizer.convert_tokens_to_ids(PLACEHOLDER) 129 | 130 | encoder = model.get_encoder() 131 | if device_map != 'auto': 132 | encoder.to(device=device_map, dtype=torch.float16) 133 | else: 134 | # Load language model 135 | if model_base is not None: 136 | # PEFT model 137 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 138 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 139 | print(f"Loading LoRA weights from {model_path}") 140 | model = PeftModel.from_pretrained(model, model_path) 141 | print(f"Merging weights") 142 | model = model.merge_and_unload() 143 | print('Convert to FP16...') 144 | model.to(torch.float16) 145 | else: 146 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 147 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 148 | 149 | if hasattr(model.config, "max_sequence_length"): 150 | context_len = model.config.max_sequence_length 
151 | else: 152 | context_len = 2048 153 | 154 | return tokenizer, model, context_len 155 | -------------------------------------------------------------------------------- /src/ranker.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from transformers import AutoTokenizer 4 | 5 | from modeling.builder import load_pretrained_model 6 | 7 | 8 | PLACEHOLDER = "" 9 | 10 | 11 | class Ranker: 12 | def __init__(self, model_path, model_base, model_name="embed_mistral"): 13 | self._tokenizer, self._model, _ = load_pretrained_model( 14 | model_path=model_path, 15 | model_base=model_base, 16 | model_name=model_name, 17 | device_map="cuda", 18 | ) 19 | self.model_name = model_path 20 | self._model.to(torch.float16) 21 | self._model.config.use_cache = True 22 | self._model.eval() 23 | if getattr(self._model.config, "encoder_name", None): 24 | self._encoder_tokenizer = AutoTokenizer.from_pretrained(self._model.config.encoder_name) 25 | else: 26 | self._encoder_tokenizer = None 27 | self.oringinal_vocab_size = self._model.config.vocab_size 28 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | 30 | def _get_encoder_inputs(self, candidates: list[str]) -> torch.Tensor: 31 | input_ids = self._encoder_tokenizer( 32 | [p["content"] for p in candidates], 33 | return_tensors="pt", 34 | padding="longest", 35 | max_length=512, 36 | truncation=True, 37 | ).input_ids.unsqueeze(0).to(self.device) 38 | return input_ids 39 | 40 | 41 | class ListwiseTextEmbeddingRanker(Ranker): 42 | 43 | def _add_prefix_prompt(self, query: str, num: int) -> str: 44 | return f"""I will provide you with {num} passages, each with a special token representing the passage enclosed in [], followed by original text. 45 | Rank the passages based on their relevance to the search query: {query}. 46 | """ 47 | 48 | def _add_post_prompt(self, query: str, num: int) -> str: 49 | return f"""Search Query: {query}. 50 | Rank the {num} relatively ordered passages above based on their relevance to the search query, output the ranking in descending order. Only output the {num} unique special token in the ranking. 
51 | 52 | """ 53 | 54 | def _replace_number(self, s: str) -> str: 55 | return re.sub(r"\[(\d+)\]", r"(\1)", s) 56 | 57 | def _get_message(self, query: str, candidates: list[str]) -> str: 58 | num = len(candidates) 59 | candidates = [p["content"] for p in candidates] 60 | messages = [] 61 | input_context = self._add_prefix_prompt(query, num) 62 | for i, content in enumerate(candidates): 63 | content = self._replace_number(content.strip()) 64 | input_context += self._get_input_for_one_passage(content, i + 1) 65 | input_context += self._add_post_prompt(query, num) 66 | messages.append({"role": "user", "content": input_context}) 67 | return messages 68 | 69 | def _get_input_for_one_passage(self, content: str, i: int) -> str: 70 | return f"Passage {i}: [] {content}\n\n" 71 | 72 | def _get_llm_inputs(self, query: str, candidates: list[str]) -> torch.Tensor: 73 | messages = self._get_message(query, candidates) 74 | input_ids = self._tokenizer.apply_chat_template( 75 | messages, 76 | add_generation_prompt=True, 77 | return_tensors="pt", 78 | padding="longest", 79 | max_length=32768, 80 | truncation=True, 81 | ).to(self.device) 82 | return input_ids 83 | 84 | @torch.no_grad() 85 | def __call__(self, query: str, candidates: list[str]) -> dict[str]: 86 | input_ids = self._get_llm_inputs(query, candidates) 87 | extra_text_input_ids = self._get_encoder_inputs(candidates) 88 | 89 | def prefix_allowed_tokens_fn(batch_id, prev_ids): 90 | allowed_tokens = list( 91 | set([x + self.oringinal_vocab_size for x in range(len(candidates))]) \ 92 | - set(prev_ids.tolist()) 93 | ) 94 | if len(allowed_tokens) == 0: 95 | return [self._tokenizer.eos_token_id] 96 | elif len(allowed_tokens) == len(candidates): 97 | if prev_ids[-1] != self._tokenizer.bos_token_id: 98 | return [self._tokenizer.bos_token_id] 99 | return allowed_tokens 100 | else: 101 | return allowed_tokens 102 | 103 | outputs = self._model.generate( 104 | input_ids, 105 | max_new_tokens=128, 106 | do_sample=False, 107 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 108 | pad_token_id=self._model.config.eos_token_id, 109 | extra_text_input_ids=extra_text_input_ids, 110 | extra_text_attention_mask=extra_text_input_ids.ne(self._encoder_tokenizer.pad_token_id), 111 | ) 112 | 113 | rankings = outputs[0].cpu().tolist() 114 | rankings = [x - self.oringinal_vocab_size for x in rankings if x >= self.oringinal_vocab_size] 115 | return rankings 116 | 117 | 118 | class ListwiseEmbeddingRanker(ListwiseTextEmbeddingRanker): 119 | def _get_input_for_one_passage(self, content: str, i: int) -> str: 120 | return f"Passage {i}: []\n\n" 121 | 122 | 123 | class ListwiseTextRanker(ListwiseTextEmbeddingRanker): 124 | def _get_input_for_one_passage(self, content: str, i: int) -> str: 125 | return f"Passage {i}: {content}\n\n" 126 | 127 | def _add_prefix_prompt(self, query: str, num: int) -> str: 128 | return f"""I will provide you with {num} passages. 129 | Rank the passages based on their relevance to the search query: {query}. 130 | """ 131 | 132 | def _add_post_prompt(self, query: str, num: int) -> str: 133 | return f"""Search Query: {query}. 134 | Rank the {num} relatively ordered passages above based on their relevance to the search query, output the ranking in descending order. The output format should be [] > [] > ..., e.g., [4] > [2] > ..., Only respond with the ranking results with {num} unique numbers, do not say anything else or explain. 
135 | 136 | """ 137 | 138 | def parse_output(self, output: str) -> list[int]: 139 | response = self._clean_response(output) 140 | response = [int(x) - 1 for x in response.split()] 141 | response = self._remove_duplicate(response) 142 | return response 143 | 144 | def _clean_response(self, response: str) -> str: 145 | new_response = "" 146 | for c in response: 147 | if not c.isdigit(): 148 | new_response += " " 149 | else: 150 | new_response += c 151 | new_response = new_response.strip() 152 | return new_response 153 | 154 | def _remove_duplicate(self, response: list[int]) -> list[int]: 155 | new_response = [] 156 | for c in response: 157 | if c not in new_response: 158 | new_response.append(c) 159 | return new_response 160 | 161 | def __call__(self, query: str, candidates: list[str]) -> dict[str]: 162 | input_ids = self._get_llm_inputs(query, candidates) 163 | 164 | outputs = self._model.generate( 165 | input_ids, 166 | max_new_tokens=256, 167 | do_sample=False, 168 | pad_token_id=self._model.config.eos_token_id, 169 | ) 170 | outputs = self._tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True) 171 | 172 | permutation = self.parse_output(outputs) 173 | original_rank = [tt for tt in range(len(candidates))] 174 | permutation = [ss for ss in permutation if ss in original_rank] 175 | permutation = permutation + [tt for tt in original_rank if tt not in permutation] 176 | return permutation 177 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pathlib 4 | 5 | from peft import LoraConfig, get_peft_model 6 | import torch 7 | import transformers 8 | 9 | from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments 10 | from data import make_data_module 11 | from modeling.causal_lm import EmbedLlamaForCausalLM, EmbedMistralForCausalLM 12 | from modeling.rank_lm import EmbedLlamaForRankLM, EmbedMistralForRankLM 13 | from trainer import Trainer, RankTrainer 14 | from utils import * 15 | 16 | 17 | local_rank = None 18 | 19 | 20 | def rank0_print(*args): 21 | if local_rank == 0: 22 | print(*args) 23 | 24 | 25 | def train(): 26 | global local_rank 27 | 28 | parser = transformers.HfArgumentParser( 29 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments) 30 | ) 31 | ( 32 | model_args, 33 | data_args, 34 | training_args, 35 | lora_args, 36 | ) = parser.parse_args_into_dataclasses() 37 | 38 | local_rank = training_args.local_rank 39 | 40 | device_map = None 41 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 42 | ddp = world_size != 1 43 | 44 | compute_dtype = ( 45 | torch.float16 46 | if training_args.fp16 47 | else (torch.bfloat16 if training_args.bf16 else torch.float32) 48 | ) 49 | 50 | # Set up model 51 | if model_args.encoder_name: 52 | if model_args.model_type == "causal_lm": 53 | model = EmbedMistralForCausalLM.from_pretrained( 54 | model_args.model_name_or_path, 55 | cache_dir=training_args.cache_dir, 56 | attn_implementation=training_args.attn_implementation, 57 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 58 | ) 59 | model.generation_config.do_sample = True 60 | elif model_args.model_type == "rank_lm": 61 | model = EmbedMistralForRankLM.from_pretrained( 62 | model_args.model_name_or_path, 63 | cache_dir=training_args.cache_dir, 64 | attn_implementation=training_args.attn_implementation, 65 | torch_dtype=(torch.bfloat16 if training_args.bf16 else 
None), 66 | ) 67 | else: 68 | raise ValueError(f"Invalid model type: {model_args.model_type}") 69 | else: 70 | model = transformers.MistralForCausalLM.from_pretrained( 71 | model_args.model_name_or_path, 72 | cache_dir=training_args.cache_dir, 73 | attn_implementation=training_args.attn_implementation, 74 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 75 | ) 76 | model.generation_config.do_sample = True 77 | 78 | # Set RoPE scaling factor 79 | orig_ctx_len = getattr(model.config, "max_position_embeddings", None) 80 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 81 | scaling_factor = float( 82 | math.ceil(training_args.model_max_length / orig_ctx_len)) 83 | model.config.rope_scaling = { 84 | "type": "linear", "factor": scaling_factor} 85 | model.config.use_cache = False 86 | 87 | # Load Lora 88 | if lora_args.lora_enable: 89 | lora_config = LoraConfig( 90 | r=lora_args.lora_r, 91 | lora_alpha=lora_args.lora_alpha, 92 | target_modules=lora_args.lora_target_modules, 93 | lora_dropout=lora_args.lora_dropout, 94 | bias=lora_args.lora_bias, 95 | task_type="CAUSAL_LM", 96 | ) 97 | 98 | model = get_peft_model(model, lora_config) 99 | 100 | # Load tokenizer 101 | tokenizer = transformers.AutoTokenizer.from_pretrained( 102 | model_args.model_name_or_path, 103 | cache_dir=training_args.cache_dir, 104 | model_max_length=training_args.model_max_length, 105 | padding_side=model_args.padding_side, 106 | use_fast=False, 107 | trust_remote_code=model_args.trust_remote_code, 108 | ) 109 | 110 | if tokenizer.pad_token != tokenizer.unk_token: 111 | tokenizer.pad_token = tokenizer.unk_token 112 | 113 | if model_args.encoder_name: 114 | model.get_model().initialize_modules(model_args) 115 | model.get_encoder().to(compute_dtype) 116 | model.get_projector().to(compute_dtype) 117 | 118 | if "lora" in model_args.encoder_name: 119 | peft_config = LoraConfig.from_pretrained(model_args.encoder_name) 120 | encoder_tokenizer = transformers.AutoTokenizer.from_pretrained( 121 | peft_config.base_model_name_or_path, 122 | cache_dir=training_args.cache_dir, 123 | trust_remote_code=model_args.trust_remote_code, 124 | ) 125 | else: 126 | encoder_tokenizer = transformers.AutoTokenizer.from_pretrained( 127 | model_args.encoder_name, 128 | cache_dir=training_args.cache_dir, 129 | trust_remote_code=model_args.trust_remote_code, 130 | ) 131 | 132 | model.config.tokenizer_padding_side = tokenizer.padding_side 133 | model.config.tokenizer_model_max_length = tokenizer.model_max_length 134 | 135 | model.config.tune_mlp_adapter = training_args.tune_mlp_adapter = model_args.tune_mlp_adapter 136 | model.config.freeze_backbone = training_args.freeze_backbone = model_args.freeze_backbone 137 | if model_args.freeze_backbone: 138 | model.requires_grad_(False) 139 | for p in model.get_model().projector.parameters(): 140 | p.requires_grad = training_args.tune_mlp_adapter 141 | 142 | model.initialize_tokenizer(tokenizer) 143 | 144 | if training_args.gradient_checkpointing: 145 | if hasattr(model, "enable_input_require_grads"): 146 | model.enable_input_require_grads() 147 | else: 148 | def make_inputs_require_grad(module, input, output): 149 | output.requires_grad_(True) 150 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 151 | 152 | training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} 153 | 154 | if model_args.freeze_embedding_layer: 155 | for p in model.get_input_embeddings().parameters(): 156 | p.requires_grad = False 157 | 158 | # Load data 159 | 
data_module = make_data_module( 160 | tokenizer=tokenizer, 161 | encoder_tokenizer=None if not model_args.encoder_name else encoder_tokenizer, 162 | data_args=data_args, 163 | model_type=model_args.model_type, 164 | ) 165 | 166 | # Start trainner 167 | if model_args.model_type == "causal_lm": 168 | trainer = Trainer( 169 | model=model, tokenizer=tokenizer, args=training_args, **data_module 170 | ) 171 | elif model_args.model_type == "rank_lm": 172 | trainer = RankTrainer( 173 | model=model, tokenizer=tokenizer, args=training_args, **data_module 174 | ) 175 | else: 176 | raise ValueError(f"Invalid model type: {model_args.model_type}") 177 | 178 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 179 | trainer.train(resume_from_checkpoint=True) 180 | else: 181 | trainer.train() 182 | 183 | # Save model 184 | model.config.use_cache = True 185 | 186 | if lora_args.lora_enable: 187 | state_dict = get_peft_state_maybe_zero_3( 188 | model.named_parameters(), lora_args.lora_bias 189 | ) 190 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( 191 | model.named_parameters() 192 | ) 193 | projector_state_dict = get_adapter_state_maybe_zero_3( 194 | model.named_parameters(), ["projector"] 195 | ) 196 | if training_args.local_rank == 0 or training_args.local_rank == -1: 197 | model.config.save_pretrained(training_args.output_dir) 198 | model.save_pretrained(training_args.output_dir, state_dict=state_dict) 199 | torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) 200 | torch.save(projector_state_dict, os.path.join(training_args.output_dir, 'projector.bin')) 201 | else: 202 | safe_save_model_for_hf_trainer(trainer=trainer, 203 | output_dir=training_args.output_dir) 204 | 205 | 206 | if __name__ == "__main__": 207 | train() 208 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import ujson as json 3 | from typing import Sequence, Callable 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.utils.data import Dataset 8 | import transformers 9 | 10 | from utils import * 11 | from constants import RANK_TOKEN, IGNORE_TOKEN_ID 12 | 13 | 14 | def preprocess_messages( 15 | tokenizer: transformers.PreTrainedTokenizer, 16 | messages: list[dict[str, str]], 17 | mask_targets_func: Callable[[list[dict[str, str]], Tensor], Tensor], 18 | ) -> dict[str, Tensor]: 19 | if messages[-1]["role"] == "assistant": 20 | if messages[-1]["content"].startswith("["): 21 | messages[-1]["content"] = ' ' + messages[-1]["content"] 22 | input_ids = tokenizer.apply_chat_template( 23 | messages, 24 | return_tensors="pt", 25 | padding="longest", 26 | max_length=tokenizer.model_max_length, 27 | truncation=True, 28 | ) 29 | targets = input_ids.clone() 30 | targets = mask_targets_func(tokenizer, messages, targets) 31 | return dict( 32 | input_ids=input_ids, 33 | labels=targets, 34 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 35 | ) 36 | 37 | 38 | def _get_messages_length( 39 | messages: list[dict[str, str]], 40 | tokenizer: transformers.PreTrainedTokenizer 41 | ) -> str: 42 | return tokenizer.apply_chat_template( 43 | messages, 44 | return_tensors='pt', 45 | max_length=tokenizer.model_max_length, 46 | truncation=True 47 | ).shape[1] 48 | 49 | 50 | def _mask_targets_for_causal_lm( 51 | tokenizer: transformers.PreTrainedTokenizer, 52 | messages: list[dict[str, str]], 53 | targets: Tensor, 54 | ) -> Tensor: 55 | for message_idx, message in enumerate(messages): 56 | if message["role"] != "assistant": 57 | message_start_idx = _get_messages_length(messages[:message_idx], tokenizer) if message_idx > 0 else 0 58 | message_end_idx = _get_messages_length(messages[:message_idx+1], tokenizer) 59 | targets[:, message_start_idx:message_end_idx] = IGNORE_TOKEN_ID 60 | if message_end_idx >= tokenizer.model_max_length: 61 | break 62 | return targets 63 | 64 | 65 | def _mask_targets_for_ranking( 66 | tokenizer: transformers.PreTrainedTokenizer, 67 | messages: list[list[dict[str, str]]], 68 | targets: Tensor, 69 | ) -> Tensor: 70 | if RANK_TOKEN not in tokenizer.all_special_tokens: 71 | tokenizer.add_tokens([RANK_TOKEN], special_tokens=True) 72 | for target in targets: 73 | target[target != tokenizer.convert_tokens_to_ids(RANK_TOKEN)] = IGNORE_TOKEN_ID 74 | return targets 75 | 76 | 77 | class SFTDataset(Dataset): 78 | def __init__( 79 | self, 80 | data_path: str, 81 | tokenizer: transformers.PreTrainedTokenizer, 82 | encoder_tokenizer: transformers.PreTrainedTokenizer, 83 | ): 84 | super().__init__() 85 | self.tokenizer = tokenizer 86 | if not self.tokenizer.pad_token: 87 | tokenizer.pad_token = tokenizer.eos_token 88 | self.encoder_tokenizer = encoder_tokenizer 89 | if self.encoder_tokenizer and self.encoder_tokenizer.eos_token: 90 | print("WARNING: will add eos token to the end of extra texts") 91 | self.encoder_tokenizer.pad_token = self.encoder_tokenizer.eos_token 92 | self.raw_data = self.load_data(data_path) 93 | 94 | def __len__(self): 95 | return len(self.raw_data) 96 | 97 | def __getitem__(self, i): 98 | raise NotImplementedError 99 | 100 | def load_data(self, data_path): 101 | if data_path.endswith(".json"): 102 | with open(data_path, "r") as f: 103 | raw_data = 
json.load(f) 104 | assert isinstance(raw_data, list) 105 | elif data_path.endswith(".jsonl"): 106 | with open(data_path, "r") as f: 107 | raw_data = [json.loads(line) for line in f] 108 | else: 109 | raise ValueError(f"Unsupported data format: {data_path}") 110 | print(f"Loaded {len(raw_data)} examples from {data_path}") 111 | return raw_data 112 | 113 | 114 | class DatasetForCausalLM(SFTDataset): 115 | 116 | def __getitem__(self, i) -> dict[str, Tensor]: 117 | ret = preprocess_messages( 118 | self.tokenizer, 119 | self.raw_data[i]["messages"], 120 | _mask_targets_for_causal_lm, 121 | ) 122 | ret = dict( 123 | input_ids=ret["input_ids"][0], 124 | labels=ret["labels"][0], 125 | attention_mask=ret["attention_mask"][0], 126 | ) 127 | if "extra_texts" in self.raw_data[i]: 128 | if self.encoder_tokenizer.eos_token: 129 | self.raw_data[i]["extra_texts"] = [ 130 | f"{text} {self.encoder_tokenizer.eos_token}" 131 | for text in self.raw_data[i]["extra_texts"] 132 | ] 133 | extra_text_inputs = self.encoder_tokenizer( 134 | self.raw_data[i]["extra_texts"], 135 | return_tensors="pt", 136 | padding="max_length", 137 | max_length=128, 138 | truncation=True, 139 | ) 140 | ret["extra_text_inputs"] = dict( 141 | input_ids=extra_text_inputs["input_ids"], 142 | ) 143 | 144 | return ret 145 | 146 | 147 | class DatasetForRanking(SFTDataset): 148 | 149 | def __init__( 150 | self, 151 | data_path: str, 152 | tokenizer: transformers.PreTrainedTokenizer, 153 | encoder_tokenizer: transformers.PreTrainedTokenizer, 154 | use_embedding_with_content: bool = True, 155 | use_embedding_without_content: bool = False, 156 | ): 157 | super().__init__(data_path, tokenizer, encoder_tokenizer) 158 | self.use_embedding_with_content = use_embedding_with_content 159 | self.use_embedding_without_content = use_embedding_without_content 160 | 161 | def __getitem__(self, i) -> dict[str, Tensor]: 162 | ranking = torch.tensor(self.raw_data[i]["ranking"], dtype=torch.long) 163 | 164 | if self.use_embedding_with_content: 165 | messages_w_content = self.raw_data[i]["messages_w_content"] 166 | if messages_w_content[-1]["role"] == "assistant": 167 | messages_w_content[-1]["content"] = f"{RANK_TOKEN}" * len(ranking) 168 | else: 169 | messages_w_content.append({"role": "assistant", "content": f"{RANK_TOKEN}" * len(ranking)}) 170 | inputs_w_content = preprocess_messages( 171 | self.tokenizer, 172 | messages_w_content, 173 | _mask_targets_for_ranking 174 | ) 175 | inputs_w_content = dict( 176 | input_ids=inputs_w_content["input_ids"][0], 177 | labels=inputs_w_content["labels"][0], 178 | attention_mask=inputs_w_content["attention_mask"][0], 179 | ) 180 | else: 181 | inputs_w_content = None 182 | if self.use_embedding_without_content: 183 | messages_wo_content = self.raw_data[i]["messages_wo_content"] 184 | if messages_wo_content[-1]["role"] == "assistant": 185 | messages_wo_content[-1]["content"] = f"{RANK_TOKEN}" * len(ranking) 186 | else: 187 | messages_wo_content.append({"role": "assistant", "content": f"{RANK_TOKEN}" * len(ranking)}) 188 | inputs_wo_content = preprocess_messages( 189 | self.tokenizer, 190 | messages_wo_content, 191 | _mask_targets_for_ranking 192 | ) 193 | inputs_wo_content = dict( 194 | input_ids=inputs_wo_content["input_ids"][0], 195 | labels=inputs_wo_content["labels"][0], 196 | attention_mask=inputs_wo_content["attention_mask"][0], 197 | ) 198 | else: 199 | inputs_wo_content = None 200 | if "extra_texts" in self.raw_data[i]: 201 | if self.encoder_tokenizer.eos_token: 202 | self.raw_data[i]["extra_texts"] = [ 203 | 
f"{text}{self.encoder_tokenizer.eos_token}" for text in self.raw_data[i]["extra_texts"] 204 | ] 205 | extra_text_inputs = self.encoder_tokenizer( 206 | self.raw_data[i]["extra_texts"], 207 | return_tensors="pt", 208 | padding="max_length", 209 | max_length=128, 210 | truncation=True, 211 | ) 212 | extra_text_inputs = dict( 213 | input_ids=extra_text_inputs["input_ids"], 214 | ) 215 | else: 216 | extra_text_inputs = None 217 | ret = dict( 218 | inputs_w_content=inputs_w_content, 219 | inputs_wo_content=inputs_wo_content, 220 | extra_text_inputs=extra_text_inputs, 221 | ranking=ranking, 222 | ) 223 | 224 | return ret 225 | 226 | 227 | @dataclass 228 | class DataCollatorForCausalLM: 229 | tokenizer: transformers.PreTrainedTokenizer 230 | encoder_tokenizer: transformers.PreTrainedTokenizer 231 | 232 | def __call__(self, instances: Sequence[dict]) -> dict[str, Tensor]: 233 | input_ids, labels = tuple([instance[key] for instance in instances] 234 | for key in ("input_ids", "labels")) 235 | input_ids = torch.nn.utils.rnn.pad_sequence( 236 | input_ids, 237 | batch_first=True, 238 | padding_value=self.tokenizer.pad_token_id) 239 | labels = torch.nn.utils.rnn.pad_sequence( 240 | labels, 241 | batch_first=True, 242 | padding_value=IGNORE_TOKEN_ID 243 | ) 244 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 245 | labels = labels[:, :self.tokenizer.model_max_length] 246 | batch = dict( 247 | input_ids=input_ids, 248 | labels=labels, 249 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 250 | ) 251 | 252 | if instances[0].get("extra_text_inputs", None) is not None: 253 | extra_text_input_ids = [instance["extra_text_inputs"]["input_ids"] for instance in instances] 254 | extra_text_input_ids = torch.nn.utils.rnn.pad_sequence( 255 | extra_text_input_ids, 256 | batch_first=True, 257 | padding_value=self.encoder_tokenizer.pad_token_id 258 | ) 259 | batch["extra_text_input_ids"] = extra_text_input_ids 260 | batch["extra_text_attention_mask"] = extra_text_input_ids.ne(self.encoder_tokenizer.pad_token_id) 261 | return batch 262 | 263 | 264 | @dataclass 265 | class DataCollatorForRanking: 266 | tokenizer: transformers.PreTrainedTokenizer 267 | encoder_tokenizer: transformers.PreTrainedTokenizer 268 | 269 | def __call__(self, instances: Sequence[dict]) -> dict[str, Tensor]: 270 | batch = dict() 271 | 272 | if instances[0].get("inputs_w_content", None) is not None: 273 | input_ids, labels = tuple( 274 | [instance["inputs_w_content"][key] for instance in instances] for key in ("input_ids", "labels") 275 | ) 276 | input_ids = torch.nn.utils.rnn.pad_sequence( 277 | input_ids, 278 | batch_first=True, 279 | padding_value=self.tokenizer.pad_token_id) 280 | labels = torch.nn.utils.rnn.pad_sequence( 281 | labels, 282 | batch_first=True, 283 | padding_value=IGNORE_TOKEN_ID 284 | ) 285 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 286 | labels = labels[:, :self.tokenizer.model_max_length] 287 | batch["inputs_w_content"] = dict( 288 | input_ids=input_ids, 289 | labels=labels, 290 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 291 | ) 292 | 293 | if instances[0].get("inputs_wo_content", None) is not None: 294 | input_ids, labels = tuple( 295 | [instance["inputs_wo_content"][key] for instance in instances] for key in ("input_ids", "labels") 296 | ) 297 | input_ids = torch.nn.utils.rnn.pad_sequence( 298 | input_ids, 299 | batch_first=True, 300 | padding_value=self.tokenizer.pad_token_id) 301 | labels = torch.nn.utils.rnn.pad_sequence( 302 | labels, 303 | batch_first=True, 304 | 
padding_value=IGNORE_TOKEN_ID 305 | ) 306 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 307 | labels = labels[:, :self.tokenizer.model_max_length] 308 | batch["inputs_wo_content"] = dict( 309 | input_ids=input_ids, 310 | labels=labels, 311 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 312 | ) 313 | 314 | if instances[0].get("extra_text_inputs", None) is not None: 315 | extra_text_input_ids = [instance["extra_text_inputs"]["input_ids"] for instance in instances] 316 | extra_text_input_ids = torch.nn.utils.rnn.pad_sequence( 317 | extra_text_input_ids, 318 | batch_first=True, 319 | padding_value=self.encoder_tokenizer.pad_token_id 320 | ) 321 | batch["extra_text_inputs"] = dict( 322 | extra_text_input_ids=extra_text_input_ids, 323 | extra_text_attention_mask=extra_text_input_ids.ne(self.encoder_tokenizer.pad_token_id) 324 | ) 325 | 326 | batch["ranking"] = torch.stack([instance["ranking"] for instance in instances]) 327 | 328 | return batch 329 | 330 | 331 | def make_data_module( 332 | tokenizer: transformers.PreTrainedTokenizer, 333 | encoder_tokenizer: transformers.PreTrainedTokenizer, 334 | data_args, 335 | model_type: str, 336 | ) -> dict: 337 | if model_type == "causal_lm": 338 | train_dataset = DatasetForCausalLM( 339 | data_args.data_path, 340 | tokenizer=tokenizer, 341 | encoder_tokenizer=encoder_tokenizer, 342 | ) 343 | else: 344 | train_dataset = DatasetForRanking( 345 | data_args.data_path, 346 | tokenizer=tokenizer, 347 | encoder_tokenizer=encoder_tokenizer, 348 | use_embedding_with_content=data_args.use_embedding_with_content, 349 | use_embedding_without_content=data_args.use_embedding_without_content, 350 | ) 351 | data_collator_cls = DataCollatorForCausalLM if model_type == "causal_lm" else DataCollatorForRanking 352 | data_collator = data_collator_cls( 353 | tokenizer=tokenizer, 354 | encoder_tokenizer=encoder_tokenizer 355 | ) 356 | 357 | # TODO: make eval_dataset available 358 | # if data_args.eval_data_path: 359 | # eval_dataset = dataset_cls( 360 | # data_args.eval_data_path, 361 | # tokenizer=tokenizer, 362 | # encoder_tokenizer=encoder_tokenizer 363 | # ) 364 | # else: 365 | # eval_dataset = None 366 | 367 | return dict( 368 | train_dataset=train_dataset, 369 | eval_dataset=None, 370 | data_collator=data_collator 371 | ) 372 | --------------------------------------------------------------------------------