├── README.md
└── example_with_embeddings
    ├── arguments.py
    ├── finetune.py
    ├── medi-data
    │   └── README.txt
    ├── models.py
    ├── mteb_benchmark.py
    ├── output
    │   └── checkpoint-1000
    │       └── adapter_model.bin
    ├── preprocess_and_tokenize.py
    ├── results
    │   ├── ArguAna.json
    │   ├── CQADupstackAndroidRetrieval.json
    │   └── CQADupstackWebmastersRetrieval.json
    └── results_finetuned
        └── ArguAna.json

/README.md:
--------------------------------------------------------------------------------
1 | # Text Embedding Based on ChatGLM
2 | 
3 | **Text embedding based on ChatGLM:** This project prefix-tunes ChatGLM-6B on the text-embedding task so that a single model can handle both text embedding and question answering, which removes the overhead of deploying a separate embedding model. It can easily be plugged into LangChain to build a fully locally deployed knowledge question-answering system.
4 | 
5 | ## Introduction
6 | 
7 | This project attempts to prefix-tune ChatGLM-6B on the text-embedding task, so that one model can perform both text embedding and question answering. This avoids the cost of deploying an additional embedding model and makes it easy to integrate with LangChain to build a fully locally deployed knowledge question-answering model.
8 | 
9 | Fine-tuning method: [mymusise/ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning) and ChatGLM-6B's [P-tuning v2](https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README.md)
10 | 
11 | Dataset: [medi-data](https://github.com/HKUNLP/instructor-embedding#train-instructor)
12 | 
13 | How the text embedding is computed: on top of prefix-tuning, a \[MASK\] token is appended to the end of each text to be embedded, and the transformer's output at the position of the \[MASK\] token is taken as the text embedding.
14 | 
15 | Because of limited hardware, I fine-tuned the model on only about 8,000 English samples from medi-data. After fine-tuning, it performs slightly better than glove.6B.300d on the MTEB ArguAna task. Further fine-tuning with more data or with Chinese corpora could be explored later.
16 | 
17 | ## Usage
18 | 
19 | ### Model training
20 | 
21 | ```bash
22 | cd example_with_embeddings
23 | ```
24 | 
25 | Tokenization
26 | 
27 | ```bash
28 | python preprocess_and_tokenize.py \
29 |     --train_file medi-data/medi-data.json \
30 |     --overwrite_cache --output_dir output \
31 |     --model_name_or_path "THUDM/chatglm-6b-int4" \
32 |     --max_source_length 512 \
33 |     --per_device_train_batch_size 8
34 | ```
35 | 
36 | Training
37 | 
38 | ```bash
39 |
python finetune.py \
40 |     --dataset_path medi-data/processed \
41 |     --model_name_or_path "THUDM/chatglm-6b-int4" \
42 |     --per_device_train_batch_size 8 \
43 |     --gradient_accumulation_steps 1 \
44 |     --num_train_epochs 1 \
45 |     --max_steps 1000 \
46 |     --save_steps 50 \
47 |     --save_total_limit 2 \
48 |     --learning_rate 1e-4 \
49 |     --remove_unused_columns false \
50 |     --logging_steps 1 \
51 |     --output_dir output
52 | ```
53 | 
54 | ### Model evaluation
55 | 
56 | ```bash
57 | python mteb_benchmark.py
58 | ```
--------------------------------------------------------------------------------
/example_with_embeddings/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | 
4 | 
5 | @dataclass
6 | class ModelArguments:
7 |     """
8 |     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 |     """
10 | 
11 |     model_name_or_path: str = field(
12 |         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 |     )
14 |     config_name: Optional[str] = field(
15 |         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
16 |     )
17 |     tokenizer_name: Optional[str] = field(
18 |         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
19 |     )
20 |     cache_dir: Optional[str] = field(
21 |         default=None,
22 |         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
23 |     )
24 |     use_fast_tokenizer: bool = field(
25 |         default=True,
26 |         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
27 |     )
28 |     model_revision: str = field(
29 |         default="main",
30 |         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
31 |     )
32 |     use_auth_token: bool = field(
33 |         default=False,
34 |         metadata={
35 |             "help": (
36 |                 "Will use the token
generated when running `huggingface-cli login` (necessary to use this script " 37 | "with private models)." 38 | ) 39 | }, 40 | ) 41 | resize_position_embeddings: Optional[bool] = field( 42 | default=None, 43 | metadata={ 44 | "help": ( 45 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 46 | "the model's position embeddings." 47 | ) 48 | }, 49 | ) 50 | quantization_bit: Optional[int] = field( 51 | default=None 52 | ) 53 | pre_seq_len: Optional[int] = field( 54 | default=None 55 | ) 56 | prefix_projection: bool = field( 57 | default=False 58 | ) 59 | 60 | 61 | @dataclass 62 | class DataTrainingArguments: 63 | """ 64 | Arguments pertaining to what data we are going to input our model for training and eval. 65 | """ 66 | 67 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 68 | 69 | dataset_name: Optional[str] = field( 70 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 71 | ) 72 | dataset_config_name: Optional[str] = field( 73 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 74 | ) 75 | prompt_column: Optional[str] = field( 76 | default=None, 77 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 78 | ) 79 | response_column: Optional[str] = field( 80 | default=None, 81 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 82 | ) 83 | train_file: Optional[str] = field( 84 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 85 | ) 86 | validation_file: Optional[str] = field( 87 | default=None, 88 | metadata={ 89 | "help": ( 90 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 
91 | ) 92 | }, 93 | ) 94 | test_file: Optional[str] = field( 95 | default=None, 96 | metadata={ 97 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 98 | }, 99 | ) 100 | overwrite_cache: bool = field( 101 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 102 | ) 103 | preprocessing_num_workers: Optional[int] = field( 104 | default=None, 105 | metadata={"help": "The number of processes to use for the preprocessing."}, 106 | ) 107 | max_source_length: Optional[int] = field( 108 | default=1024, 109 | metadata={ 110 | "help": ( 111 | "The maximum total input sequence length after tokenization. Sequences longer " 112 | "than this will be truncated, sequences shorter will be padded." 113 | ) 114 | }, 115 | ) 116 | max_target_length: Optional[int] = field( 117 | default=128, 118 | metadata={ 119 | "help": ( 120 | "The maximum total sequence length for target text after tokenization. Sequences longer " 121 | "than this will be truncated, sequences shorter will be padded." 122 | ) 123 | }, 124 | ) 125 | val_max_target_length: Optional[int] = field( 126 | default=None, 127 | metadata={ 128 | "help": ( 129 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 130 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. " 131 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 132 | "during ``evaluate`` and ``predict``." 133 | ) 134 | }, 135 | ) 136 | pad_to_max_length: bool = field( 137 | default=False, 138 | metadata={ 139 | "help": ( 140 | "Whether to pad all samples to model maximum sentence length. " 141 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 142 | "efficient on GPU but very bad for TPU."
143 | ) 144 | }, 145 | ) 146 | max_train_samples: Optional[int] = field( 147 | default=None, 148 | metadata={ 149 | "help": ( 150 | "For debugging purposes or quicker training, truncate the number of training examples to this " 151 | "value if set." 152 | ) 153 | }, 154 | ) 155 | max_eval_samples: Optional[int] = field( 156 | default=None, 157 | metadata={ 158 | "help": ( 159 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 160 | "value if set." 161 | ) 162 | }, 163 | ) 164 | max_predict_samples: Optional[int] = field( 165 | default=None, 166 | metadata={ 167 | "help": ( 168 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 169 | "value if set." 170 | ) 171 | }, 172 | ) 173 | num_beams: Optional[int] = field( 174 | default=None, 175 | metadata={ 176 | "help": ( 177 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 178 | "which is used during ``evaluate`` and ``predict``." 179 | ) 180 | }, 181 | ) 182 | ignore_pad_token_for_loss: bool = field( 183 | default=True, 184 | metadata={ 185 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 186 | }, 187 | ) 188 | source_prefix: Optional[str] = field( 189 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 190 | ) 191 | 192 | forced_bos_token: Optional[str] = field( 193 | default=None, 194 | metadata={ 195 | "help": ( 196 | "The token to force as the first generated token after the decoder_start_token_id." 
197 | " Useful for multilingual models like mBART where the first generated token " 198 | "needs to be the target language token (Usually it is the target language token)" 199 | ) 200 | }, 201 | ) 202 | 203 | 204 | 205 | def __post_init__(self): 206 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 207 | raise ValueError("Need either a dataset name or a training/validation/test file.") 208 | else: 209 | if self.train_file is not None: 210 | extension = self.train_file.split(".")[-1] 211 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 212 | if self.validation_file is not None: 213 | extension = self.validation_file.split(".")[-1] 214 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 215 | if self.val_max_target_length is None: 216 | self.val_max_target_length = self.max_target_length 217 | 218 | -------------------------------------------------------------------------------- /example_with_embeddings/finetune.py: -------------------------------------------------------------------------------- 1 | # from transformers.integrations import TensorBoardCallback 2 | # from torch.utils.tensorboard import SummaryWriter 3 | from typing import Optional, Tuple 4 | 5 | from torch.optim import AdamW 6 | from torch.utils.data import SequentialSampler, DistributedSampler 7 | from transformers import TrainingArguments, DataCollatorForSeq2Seq 8 | from transformers import Trainer, HfArgumentParser 9 | from transformers import AutoTokenizer, AutoModel 10 | import torch 11 | import torch.nn as nn 12 | from peft import get_peft_model, TaskType, PrefixTuningConfig, LoraConfig 13 | from dataclasses import dataclass, field 14 | import datasets 15 | import os 16 | 17 | @dataclass 18 | class FinetuneArguments: 19 | dataset_path: str = field(default="medi-data/processed") 20 | model_path: str = field(default="output") 21 | lora_rank:
int = field(default=8) 22 | 23 | from types import MethodType 24 | def embedding_forward( 25 | self, 26 | input_ids: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.Tensor] = None, 28 | attention_mask: Optional[torch.Tensor] = None, 29 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 30 | inputs_embeds: Optional[torch.Tensor] = None, 31 | labels: Optional[torch.Tensor] = None, 32 | use_cache: Optional[bool] = None, 33 | output_attentions: Optional[bool] = None, 34 | output_hidden_states: Optional[bool] = None, 35 | return_dict: Optional[bool] = None, 36 | ): 37 | use_cache = use_cache if use_cache is not None else self.config.use_cache 38 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 39 | 40 | MASK, gMASK = 150000, 150001 41 | mask_token = MASK if MASK in input_ids else gMASK 42 | use_gmask = False if MASK in input_ids else gMASK 43 | seqs = input_ids.tolist() 44 | mask_positions = [seq.index(mask_token) for seq in seqs] 45 | 46 | position_ids = self.get_position_ids( 47 | input_ids, 48 | mask_positions = mask_positions, 49 | device=input_ids.device, 50 | ) 51 | 52 | past_key_values = tuple([i.permute((0,3,1,2,4)).contiguous() for i in past_key_values]) 53 | 54 | transformer_outputs = self.transformer( 55 | input_ids=input_ids, 56 | position_ids=position_ids, 57 | attention_mask=attention_mask, 58 | past_key_values=past_key_values, 59 | inputs_embeds=inputs_embeds, 60 | use_cache=use_cache, 61 | output_attentions=output_attentions, 62 | output_hidden_states=output_hidden_states, 63 | return_dict=return_dict, 64 | ) 65 | 66 | hidden_states = transformer_outputs.last_hidden_state[-1] 67 | 68 | return hidden_states 69 | 70 | def has_length(dataset): 71 | """ 72 | Checks if the dataset implements __len__() and it doesn't raise an error 73 | """ 74 | try: 75 | return len(dataset) is not None 76 | except TypeError: 77 | # TypeError: len() of unsized object 78 | return False 79 | 80 | class 
ModifiedTrainer(Trainer): 81 | cl_temperature = 0.01 82 | 83 | def _get_train_sampler(self) : 84 | if self.train_dataset is None or not has_length(self.train_dataset): 85 | return None 86 | 87 | generator = None 88 | if self.args.world_size <= 1: 89 | generator = torch.Generator() 90 | # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with 91 | # `args.seed`) if data_seed isn't provided. 92 | # Further on in this method, we default to `args.seed` instead. 93 | if self.args.data_seed is None: 94 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 95 | else: 96 | seed = self.args.data_seed 97 | generator.manual_seed(seed) 98 | 99 | seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed 100 | 101 | if self.args.world_size <= 1: 102 | return SequentialSampler(self.train_dataset) 103 | else: 104 | return DistributedSampler( 105 | self.train_dataset, 106 | num_replicas=self.args.world_size, 107 | rank=self.args.process_index, 108 | seed=seed, 109 | ) 110 | 111 | def compute_loss(self, model, inputs, return_outputs=False): 112 | # for task_id in inputs['task_name']: 113 | # assert task_id==inputs['task_name'][0],f"Examples in the same batch should come from the same task, " \ 114 | # f"but task {task_id} and task {inputs['task_name'][0]} are found" 115 | cur_results = {} 116 | for k in ['query', 'pos', 'neg']: 117 | cur_inputs = { 118 | 'input_ids': inputs[f'{k}_input_ids'], 119 | # 'attention_mask': inputs[f'{k}_attention_mask'], 120 | # 'context_masks': inputs[f'{k}_context_masks'], 121 | } 122 | cur_results[k] = model(**cur_inputs) 123 | embeddings_query = cur_results['query'] 124 | embeddings_pos = cur_results['pos'] 125 | embeddings_neg = cur_results['neg'] 126 | 127 | num = len(embeddings_query) 128 | all_scores = None 129 | from torch import nn 130 | similarity_fct = nn.CosineSimilarity(dim=-1) 131 | for i in range(0, num): 132 | anchor_emb = embeddings_query[i].unsqueeze(0) 133 
| pos_emb = embeddings_pos[i].unsqueeze(0) 134 | cur_score = similarity_fct(anchor_emb, pos_emb) / self.cl_temperature 135 | 136 | for j in range(0, num): 137 | one_neg_emb = embeddings_neg[j].unsqueeze(0) 138 | one_neg_score = similarity_fct(anchor_emb, one_neg_emb) / self.cl_temperature 139 | cur_score = torch.cat([cur_score, one_neg_score], dim=-1) 140 | if all_scores is None: 141 | all_scores = cur_score.unsqueeze(0) 142 | else: 143 | all_scores = torch.cat([all_scores, cur_score.unsqueeze(0)], dim=0) 144 | 145 | labels = torch.zeros(all_scores.size(0)).long().to(embeddings_query.device) 146 | loss = nn.CrossEntropyLoss()(all_scores, labels) 147 | 148 | all_another_scores = None 149 | for i in range(0, num): 150 | anchor_emb = embeddings_pos[i].unsqueeze(0) 151 | pos_emb = embeddings_query[i].unsqueeze(0) 152 | cur_score = similarity_fct(anchor_emb, pos_emb) / self.cl_temperature 153 | 154 | for j in range(0, num): 155 | if i == j: 156 | continue 157 | one_neg_emb = embeddings_query[j].unsqueeze(0) 158 | one_neg_score = similarity_fct(anchor_emb, one_neg_emb) / self.cl_temperature 159 | cur_score = torch.cat([cur_score, one_neg_score], dim=-1) 160 | if all_another_scores is None: 161 | all_another_scores = cur_score.unsqueeze(0) 162 | else: 163 | all_another_scores = torch.cat([all_another_scores, cur_score.unsqueeze(0)], dim=0) 164 | labels_another = torch.zeros(all_another_scores.size(0)).long().to(embeddings_query.device) 165 | loss += nn.CrossEntropyLoss()(all_another_scores, labels_another) 166 | 167 | return loss 168 | 169 | def save_model(self, output_dir=None, _internal_call=False): 170 | from transformers.trainer import TRAINING_ARGS_NAME 171 | 172 | os.makedirs(output_dir, exist_ok=True) 173 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 174 | saved_params = { 175 | k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad 176 | } 177 | torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin")) 178 
| 179 | ignore_pad_token_for_loss = False 180 | 181 | from arguments import ModelArguments 182 | 183 | def main(): 184 | # writer = SummaryWriter() 185 | finetune_args, training_args,model_args = HfArgumentParser( 186 | (FinetuneArguments, TrainingArguments,ModelArguments) 187 | ).parse_args_into_dataclasses() 188 | training_args.remove_unused_columns = False 189 | 190 | # init model 191 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, truncation_side="left") 192 | 193 | model = AutoModel.from_pretrained( 194 | model_args.model_name_or_path, trust_remote_code=True 195 | ).half().cuda() 196 | model.gradient_checkpointing_enable() 197 | model.enable_input_require_grads() 198 | model.is_parallelizable = True 199 | model.model_parallel = True 200 | # model.lm_head = CastOutputToFloat(model.lm_head) 201 | model.config.use_cache = ( 202 | False # silence the warnings. Please re-enable for inference! 203 | ) 204 | original_forward = model.forward 205 | model.forward = MethodType(embedding_forward, model) 206 | 207 | # setup peft 208 | peft_config = PrefixTuningConfig( 209 | task_type=TaskType.CAUSAL_LM, 210 | num_virtual_tokens=8, 211 | # encoder_hidden_size=8, 212 | prefix_projection=False 213 | ) 214 | model = get_peft_model(model, peft_config) 215 | 216 | model = model.half().cuda() 217 | 218 | # load dataset 219 | dataset = datasets.load_from_disk(finetune_args.dataset_path) 220 | print(f"\n{len(dataset)=}\n") 221 | 222 | label_pad_token_id = -100 if ignore_pad_token_for_loss else tokenizer.pad_token_id 223 | data_collator = DataCollatorForSeq2Seq( 224 | tokenizer, 225 | model=model, 226 | label_pad_token_id=label_pad_token_id, 227 | pad_to_multiple_of=None, 228 | # padding=False, 229 | ) 230 | 231 | optimizer = AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-6) 232 | 233 | # start train 234 | trainer = ModifiedTrainer( 235 | model=model, 236 | train_dataset=dataset, 237 | args=training_args, 238 | # 
callbacks=[TensorBoardCallback(writer)], 239 | data_collator=data_collator, 240 | optimizers=(optimizer, None) # LambdaLR(optimizer,lr_lambda=training_args.learning_rate)), 241 | ) 242 | checkpoint = None 243 | if training_args.resume_from_checkpoint is not None: 244 | checkpoint = training_args.resume_from_checkpoint 245 | trainer.train(resume_from_checkpoint=checkpoint) 246 | # writer.close() 247 | # save model 248 | model.save_pretrained(training_args.output_dir) 249 | 250 | 251 | if __name__ == "__main__": 252 | main() -------------------------------------------------------------------------------- /example_with_embeddings/medi-data/README.txt: -------------------------------------------------------------------------------- 1 | This folder contains the Multitask Embeddings Data with Instructions (MEDI) for the paper: One Embedder, Any Task: Instruction-Finetuned Text Embeddings. 2 | 3 | It contains the following files: 4 | - medi-data.json 5 | # Training Examples: 1,435,000 6 | - README.txt 7 | 8 | The MEDI data consists of a collection of 330 datasets from Super-NI (Super-NaturalInstructions), sentence-transformer embedding training data, and KILT, spanning a wide range of domains and tasks. 9 | 10 | If you use the dataset, please cite the following papers (Su et al., 2022; Wang et al., 2022; Petroni et al., 2021) and the sentence-transformers embedding training data at https://huggingface.co/datasets/sentence-transformers/embedding-training-data. 11 | 12 | @inproceedings{INSTRUCTOR, 13 | title={One Embedder, Any Task: Instruction-Finetuned Text Embeddings}, 14 | author={Hongjin Su, Weijia Shi, Jungo Kasai, Yizhong Wang, Yushi Hu, Mari Ostendorf, Wen-tau Yih, Noah A.
Smith, Luke Zettlemoyer, Tao Yu}, 15 | url={https://arxiv.org/abs/2212.09741}, 16 | year={2022}, 17 | } 18 | 19 | @inproceedings{wang2022super, 20 | title={Super-naturalinstructions: generalization via declarative instructions on 1600+ tasks}, 21 | author={Wang, Yizhong and Mishra, Swaroop and Alipoormolabashi, Pegah and Kordi, Yeganeh and Mirzaei, Amirreza and Arunkumar, Anjana and Ashok, Arjun and Dhanasekaran, Arut Selvan and Naik, Atharva and Stap, David and others}, 22 | year={2022}, 23 | organization={EMNLP} 24 | } 25 | 26 | @article{petroni2020kilt, 27 | title={KILT: a benchmark for knowledge intensive language tasks}, 28 | author={Petroni, Fabio and Piktus, Aleksandra and Fan, Angela and Lewis, Patrick and Yazdani, Majid and De Cao, Nicola and Thorne, James and Jernite, Yacine and Karpukhin, Vladimir and Maillard, Jean and others}, 29 | journal={arXiv preprint arXiv:2009.02252}, 30 | year={2020} 31 | } 32 | 33 | 34 | -------------------------------------------------------------------------------- /example_with_embeddings/models.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from types import MethodType 3 | from typing import Optional, Tuple 4 | 5 | from transformers import AutoModel,AutoTokenizer 6 | from peft import PeftModel, PrefixTuningConfig, TaskType, get_peft_model, PromptLearningConfig, PeftType 7 | import torch 8 | 9 | MODEL = "THUDM/chatglm-6b-int4" #"C:\Documents\data\大型语言模型\Models\ChatGLM-6B-main\model-int4" 10 | Embedding_prefix = "./output/checkpoint-1000/adapter_model.bin" 11 | 12 | def embedding_forward( 13 | self, 14 | input_ids: Optional[torch.Tensor] = None, 15 | position_ids: Optional[torch.Tensor] = None, 16 | attention_mask: Optional[torch.Tensor] = None, 17 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 18 | inputs_embeds: Optional[torch.Tensor] = None, 19 | labels: Optional[torch.Tensor] = None, 20 | use_cache: Optional[bool] = None, 21 | output_attentions: 
Optional[bool] = None, 22 | output_hidden_states: Optional[bool] = None, 23 | return_dict: Optional[bool] = None, 24 | ): 25 | use_cache = use_cache if use_cache is not None else self.config.use_cache 26 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 27 | 28 | MASK, gMASK = 150000, 150001 29 | mask_token = MASK if MASK in input_ids else gMASK 30 | use_gmask = False if MASK in input_ids else gMASK 31 | seqs = input_ids.tolist() 32 | mask_positions = [seq.index(mask_token) for seq in seqs] 33 | 34 | position_ids = self.get_position_ids( 35 | input_ids, 36 | mask_positions = mask_positions, 37 | device=input_ids.device, 38 | ) 39 | 40 | past_key_values = [torch.permute(i,(0,3,1,2,4)) for i in past_key_values] 41 | 42 | transformer_outputs = self.transformer( 43 | input_ids=input_ids, 44 | position_ids=position_ids, 45 | attention_mask=attention_mask, 46 | past_key_values=past_key_values, 47 | inputs_embeds=inputs_embeds, 48 | use_cache=use_cache, 49 | output_attentions=output_attentions, 50 | output_hidden_states=output_hidden_states, 51 | return_dict=return_dict, 52 | ) 53 | 54 | hidden_states = transformer_outputs.last_hidden_state[-1] 55 | 56 | return hidden_states 57 | 58 | def peft_forward( 59 | self, 60 | input_ids=None, 61 | attention_mask=None, 62 | inputs_embeds=None, 63 | labels=None, 64 | output_attentions=None, 65 | output_hidden_states=None, 66 | return_dict=None, 67 | **kwargs, 68 | ): 69 | batch_size = input_ids.shape[0] 70 | if attention_mask is not None: 71 | # concat prompt attention mask 72 | prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device) 73 | attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) 74 | 75 | if kwargs.get("position_ids", None) is not None: 76 | warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") 77 | kwargs["position_ids"] = None 78 | if kwargs.get("token_type_ids", None) is not None: 79 | warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids") 80 | kwargs["token_type_ids"] = None 81 | kwargs.update( 82 | { 83 | "attention_mask": attention_mask, 84 | "labels": labels, 85 | "output_attentions": output_attentions, 86 | "output_hidden_states": output_hidden_states, 87 | "return_dict": return_dict, 88 | } 89 | ) 90 | 91 | past_key_values = self.get_prompt(batch_size) 92 | return embedding_forward(self.base_model,input_ids=input_ids, past_key_values=past_key_values, **kwargs) # self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs) 93 | 94 | tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True,truncation_side="left") 95 | model = AutoModel.from_pretrained(MODEL, trust_remote_code=True).half().cuda() 96 | model = model.eval() 97 | 98 | # setup peft 99 | peft_config = PrefixTuningConfig( 100 | task_type=TaskType.CAUSAL_LM, 101 | num_virtual_tokens=8, 102 | prefix_projection=False 103 | ) 104 | 105 | embedding_model = get_peft_model(model, peft_config).half().cuda() 106 | 107 | embedding_model.load_state_dict(torch.load(Embedding_prefix),strict=False) 108 | 109 | embedding_model.forward = MethodType(peft_forward,embedding_model) -------------------------------------------------------------------------------- /example_with_embeddings/mteb_benchmark.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | from transformers import AutoTokenizer, AutoModel 6 | import torch 7 | 8 | from mteb import MTEB,DRESModel 9 | 10 | from models import embedding_model,tokenizer 11 | 12 | def get_embeddings(tokenizer, query): 13 | tokenized = tokenizer([q+'[MASK]' for q in query], return_tensors="pt",
padding=True,truncation=True,max_length=1024)['input_ids'].to(embedding_model.device) 14 | with torch.no_grad(): 15 | outputs = embedding_model(tokenized) 16 | return outputs.tolist() 17 | # outputs = model.transformer(tokenized) 18 | # return outputs.last_hidden_state[-1, :, :].tolist() 19 | 20 | class MyModel(DRESModel): 21 | 22 | def encode(self,sentences,**kwargs): 23 | batch_size = 4 24 | dataloader = DataLoader(sentences, batch_size=batch_size, shuffle=False) 25 | embeddings = [] 26 | for batch in tqdm(dataloader): 27 | embeddings.extend(get_embeddings(tokenizer, batch)) 28 | return embeddings 29 | 30 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs): 31 | return self.encode(queries,**kwargs) 32 | 33 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs): 34 | if type(corpus) is dict: 35 | sentences = [ 36 | (corpus["title"][i] + self.sep + corpus["text"][i]).strip() 37 | if "title" in corpus 38 | else corpus["text"][i].strip() 39 | for i in range(len(corpus["text"])) 40 | ] 41 | else: 42 | sentences = [ 43 | (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() 44 | for doc in corpus 45 | ] 46 | return self.encode(sentences, **kwargs) 47 | 48 | 49 | get_embeddings(tokenizer, ["I am a sentence"]) 50 | 51 | evalModel = MyModel(None) 52 | # evaluation = MTEB(task_types=['Retrieval']) 53 | evaluation = MTEB(tasks=['ArguAna']) 54 | evaluation.run(evalModel,output_folder=f'results_finetuned') -------------------------------------------------------------------------------- /example_with_embeddings/output/checkpoint-1000/adapter_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgechen1827/ChatGLM-text-embedding/0c30ca2e76fbeb18910eaa15d719d4f58fe07f94/example_with_embeddings/output/checkpoint-1000/adapter_model.bin -------------------------------------------------------------------------------- 
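The bidirectional contrastive loss in `ModifiedTrainer.compute_loss` (`finetune.py`) is assembled one example at a time with nested Python loops. The sketch below restates the same objective in vectorized form so the math is easier to check. It is a minimal illustration, assuming NumPy arrays of shape `(batch, hidden)` stand in for the model's torch embeddings; the function names are hypothetical and not part of the repository. As in the trainer, cosine similarities are divided by the temperature (`cl_temperature = 0.01`) and fed to a softmax cross-entropy whose target is the positive: the first direction scores each query against its own positive plus every hard negative in the batch, and the second scores each positive against its own query plus the other queries in the batch.

```python
import numpy as np


def _l2_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Row-wise L2 normalization, so that dot products become cosine similarities."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def _cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean softmax cross-entropy; targets[i] is the column of the correct logit in row i."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


def bidirectional_contrastive_loss(query, pos, neg, temperature: float = 0.01) -> float:
    """Hypothetical vectorized restatement of the loops in ModifiedTrainer.compute_loss.

    query, pos, neg: arrays of shape (batch, hidden), one embedding per row.
    """
    q, p, n = (_l2_normalize(np.asarray(a, dtype=np.float64)) for a in (query, pos, neg))
    batch = q.shape[0]

    # Direction 1: anchor query_i; candidates are [pos_i] + [neg_j for every j in the batch].
    pos_sim = np.sum(q * p, axis=1, keepdims=True)  # (batch, 1)
    neg_sim = q @ n.T                               # (batch, batch)
    logits_q = np.concatenate([pos_sim, neg_sim], axis=1) / temperature
    loss_q = _cross_entropy(logits_q, np.zeros(batch, dtype=int))  # positive sits in column 0

    # Direction 2: anchor pos_i; candidates are [query_i] + [query_j for j != i].
    # The full (batch, batch) matrix with the positive on the diagonal holds the same
    # candidate set, and cross-entropy does not depend on candidate order.
    logits_p = (p @ q.T) / temperature
    loss_p = _cross_entropy(logits_p, np.arange(batch))

    return loss_q + loss_p


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, p, n = (rng.normal(size=(4, 16)) for _ in range(3))
    print(bidirectional_contrastive_loss(q, p, n))
```

Because softmax cross-entropy does not depend on the order of the candidates, the second direction can use the full `p @ q.T` matrix with the targets on the diagonal instead of moving the positive into column 0. The same steps map onto `torch.nn.functional.normalize` and `torch.nn.functional.cross_entropy` if you want to try a vectorized version inside the trainer.
--------------------------------------------------------------------------------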
/example_with_embeddings/preprocess_and_tokenize.py: -------------------------------------------------------------------------------- 1 | # This script is adapted from https://github.com/huggingface/transformers 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import json 7 | 8 | import torch 9 | from datasets import DatasetDict, Dataset 10 | from transformers import ( 11 | AutoTokenizer, 12 | HfArgumentParser, 13 | Seq2SeqTrainingArguments, 14 | set_seed, 15 | ) 16 | 17 | from arguments import ModelArguments, DataTrainingArguments 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 23 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 24 | # If we pass only one argument to the script and it's the path to a json file, 25 | # let's parse it to get our arguments. 26 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 27 | else: 28 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 29 | 30 | max_source_length = data_args.max_source_length # honor --max_source_length 31 | max_train_samples = data_args.max_train_samples # None by default (use all examples) 32 | 33 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True,truncation_side="left") 34 | 35 | set_seed(2023) 36 | with open(data_args.train_file) as f: # honor --train_file 37 | train_examples_raw = json.load(f) 38 | 39 | old_train_examples_raw = train_examples_raw 40 | train_examples_raw = [] 41 | total_n = len(old_train_examples_raw) 42 | real_batch_size = max(training_args.per_device_train_batch_size, 43 | training_args.per_device_train_batch_size * torch.cuda.device_count()) 44 | # print('real_batch_size: ', real_batch_size,training_args.per_device_train_batch_size,torch.cuda.device_count()) 45 | for idx in range(0, total_n, real_batch_size): 46 | local_task_name = old_train_examples_raw[idx]['task_name'] 47 | cur_batch = [] 48 | include_batch = True 49 | for idx1 in
range(idx, min(idx + real_batch_size, total_n)): 50 | if not old_train_examples_raw[idx1]['task_name'] == local_task_name: 51 | print(f'one batch in task {old_train_examples_raw[idx1]["task_name"]} is skipped') 52 | include_batch = False 53 | break 54 | else: 55 | cur_batch.append(old_train_examples_raw[idx1]) 56 | if include_batch and len(cur_batch) == real_batch_size: 57 | train_examples_raw.append(cur_batch) 58 | random.shuffle(train_examples_raw) 59 | train_examples_raw_batch = train_examples_raw 60 | train_examples_raw = [] 61 | for b in train_examples_raw_batch: 62 | train_examples_raw += b 63 | print(f'There are {len(train_examples_raw)} pairs to train in total') 64 | 65 | train_examples = {'query': [], 'pos': [], 'neg': [], 'task_name': []} 66 | task_name_map = {} 67 | total_train_num = len(train_examples_raw) 68 | task_count = 0 69 | for i in range(total_train_num): 70 | cur_e = train_examples_raw[i] 71 | for k in ['query', 'pos', 'neg']: 72 | for s in cur_e[k][:-1]: 73 | assert not '!@#$%^&**!@#$%^&**' in s 74 | cur_e[k][-1] = str(cur_e[k][-1]) 75 | if True: # not data_args.add_prompt_to_document: 76 | cur_e[k][0] = '' 77 | assert cur_e[k][0].startswith('Represent ') or cur_e[k][0] == '' 78 | train_examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k])) # '!@#$%^&**!@#$%^&**' 79 | if not cur_e['task_name'] in task_name_map: 80 | task_name_map[cur_e['task_name']] = task_count 81 | task_count += 1 82 | train_examples['task_name'].append(task_name_map[cur_e['task_name']]) 83 | 84 | raw_datasets = DatasetDict({'train': Dataset.from_dict(train_examples)}) 85 | 86 | column_names = raw_datasets["train"].column_names 87 | 88 | 89 | def preprocess_function(examples): 90 | all_tokenized = None 91 | for key in ['query', 'pos', 'neg']: 92 | num = len(examples[key]) 93 | contexts = [] 94 | for local_idx in range(num): 95 | splits = examples[key][local_idx].split('!@#$%^&**!@#$%^&**') 96 | # assert len(splits) == 2 97 | contexts.append(splits[-1]) 98 | assert 
isinstance(contexts[-1], str) 99 | tokenized = tokenizer([q + '[MASK]' for q in contexts], padding=True, truncation=True, 100 | return_tensors="pt", max_length=max_source_length) 101 | # tokenized['context_masks'] = torch.sum(context_tok['attention_mask'], dim=1) 102 | # tokenized['context_masks'] = tokenized['context_masks'] - 1 103 | # for my_idx in range(len(tokenized['context_masks'])): 104 | # if tokenized['context_masks'][my_idx] <= 1: 105 | # tokenized['context_masks'][my_idx] = 0 106 | keys = tokenized.keys() 107 | if all_tokenized is None: 108 | all_tokenized = tokenized.copy() 109 | for k in keys: 110 | all_tokenized[k] = all_tokenized[k].tolist() 111 | for k in keys: 112 | all_tokenized[f'{key}_{k}'] = tokenized[k].tolist() 113 | all_tokenized['task_name'] = examples['task_name'] 114 | # all_tokenized['label'] = all_tokenized['input_ids'] 115 | return all_tokenized 116 | 117 | 118 | train_dataset = raw_datasets["train"] 119 | if max_train_samples is not None: 120 | max_train_samples = min(len(train_dataset), max_train_samples) 121 | train_dataset = train_dataset.select(range(max_train_samples)) 122 | # with TrainingArguments(output_dir='/medi-data').main_process_first(desc="train dataset map pre-processing"): 123 | with training_args.main_process_first(desc="train dataset map pre-processing"): 124 | train_dataset = train_dataset.map( 125 | preprocess_function, 126 | # batch_size=1, 127 | batched=True, 128 | num_proc=data_args.preprocessing_num_workers, 129 | remove_columns=column_names, 130 | load_from_cache_file=not data_args.overwrite_cache, 131 | desc="Running tokenizer on train dataset", 132 | ) 133 | train_dataset.save_to_disk('medi-data/processed') -------------------------------------------------------------------------------- /example_with_embeddings/results/ArguAna.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_revision": null, 3 | "mteb_dataset_name": "ArguAna", 4 | "mteb_version": 
"1.0.1", 5 | "test": { 6 | "evaluation_time": 1099.35, 7 | "map_at_1": 0.10669, 8 | "map_at_10": 0.22237, 9 | "map_at_100": 0.23785, 10 | "map_at_1000": 0.23823, 11 | "map_at_3": 0.17342, 12 | "map_at_5": 0.19842, 13 | "mrr_at_1": 0.1074, 14 | "mrr_at_10": 0.22277, 15 | "mrr_at_100": 0.23831, 16 | "mrr_at_1000": 0.23869, 17 | "mrr_at_3": 0.17402, 18 | "mrr_at_5": 0.19866, 19 | "ndcg_at_1": 0.10669, 20 | "ndcg_at_10": 0.29922, 21 | "ndcg_at_100": 0.37089, 22 | "ndcg_at_1000": 0.38178, 23 | "ndcg_at_3": 0.1957, 24 | "ndcg_at_5": 0.24123, 25 | "precision_at_1": 0.10669, 26 | "precision_at_10": 0.05512, 27 | "precision_at_100": 0.00878, 28 | "precision_at_1000": 0.00097, 29 | "precision_at_3": 0.08677, 30 | "precision_at_5": 0.0744, 31 | "recall_at_1": 0.10669, 32 | "recall_at_10": 0.55121, 33 | "recall_at_100": 0.87767, 34 | "recall_at_1000": 0.96515, 35 | "recall_at_3": 0.26031, 36 | "recall_at_5": 0.37198 37 | } 38 | } -------------------------------------------------------------------------------- /example_with_embeddings/results/CQADupstackAndroidRetrieval.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_revision": null, 3 | "mteb_dataset_name": "CQADupstackAndroidRetrieval", 4 | "mteb_version": "1.0.1", 5 | "test": { 6 | "evaluation_time": 2281.82, 7 | "map_at_1": 0.04766, 8 | "map_at_10": 0.06811, 9 | "map_at_100": 0.07251, 10 | "map_at_1000": 0.0733, 11 | "map_at_3": 0.0603, 12 | "map_at_5": 0.06362, 13 | "mrr_at_1": 0.05866, 14 | "mrr_at_10": 0.08773, 15 | "mrr_at_100": 0.09265, 16 | "mrr_at_1000": 0.09336, 17 | "mrr_at_3": 0.07797, 18 | "mrr_at_5": 0.08262, 19 | "ndcg_at_1": 0.05866, 20 | "ndcg_at_10": 0.08602, 21 | "ndcg_at_100": 0.10998, 22 | "ndcg_at_1000": 0.13444, 23 | "ndcg_at_3": 0.07146, 24 | "ndcg_at_5": 0.07593, 25 | "precision_at_1": 0.05866, 26 | "precision_at_10": 0.0186, 27 | "precision_at_100": 0.00386, 28 | "precision_at_1000": 0.00073, 29 | "precision_at_3": 0.0372, 30 | 
"precision_at_5": 0.0269, 31 | "recall_at_1": 0.04766, 32 | "recall_at_10": 0.12038, 33 | "recall_at_100": 0.22867, 34 | "recall_at_1000": 0.40955, 35 | "recall_at_3": 0.07669, 36 | "recall_at_5": 0.09071 37 | } 38 | } -------------------------------------------------------------------------------- /example_with_embeddings/results/CQADupstackWebmastersRetrieval.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_revision": null, 3 | "mteb_dataset_name": "CQADupstackWebmastersRetrieval", 4 | "mteb_version": "1.0.1", 5 | "test": { 6 | "evaluation_time": 1565.16, 7 | "map_at_1": 0.02926, 8 | "map_at_10": 0.04326, 9 | "map_at_100": 0.0474, 10 | "map_at_1000": 0.04865, 11 | "map_at_3": 0.03544, 12 | "map_at_5": 0.04066, 13 | "mrr_at_1": 0.03953, 14 | "mrr_at_10": 0.05811, 15 | "mrr_at_100": 0.06305, 16 | "mrr_at_1000": 0.06412, 17 | "mrr_at_3": 0.04908, 18 | "mrr_at_5": 0.05451, 19 | "ndcg_at_1": 0.03953, 20 | "ndcg_at_10": 0.05781, 21 | "ndcg_at_100": 0.08062, 22 | "ndcg_at_1000": 0.11664, 23 | "ndcg_at_3": 0.04302, 24 | "ndcg_at_5": 0.05175, 25 | "precision_at_1": 0.03953, 26 | "precision_at_10": 0.01245, 27 | "precision_at_100": 0.00354, 28 | "precision_at_1000": 0.00104, 29 | "precision_at_3": 0.02108, 30 | "precision_at_5": 0.01858, 31 | "recall_at_1": 0.02926, 32 | "recall_at_10": 0.08778, 33 | "recall_at_100": 0.19427, 34 | "recall_at_1000": 0.46043, 35 | "recall_at_3": 0.04558, 36 | "recall_at_5": 0.06786 37 | } 38 | } -------------------------------------------------------------------------------- /example_with_embeddings/results_finetuned/ArguAna.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_revision": null, 3 | "mteb_dataset_name": "ArguAna", 4 | "mteb_version": "1.0.2", 5 | "test": { 6 | "evaluation_time": 10709.28, 7 | "map_at_1": 0.16145, 8 | "map_at_10": 0.28938, 9 | "map_at_100": 0.30486, 10 | "map_at_1000": 0.30512, 11 | "map_at_3": 0.23815, 
12 | "map_at_5": 0.26382, 13 | "mrr_at_1": 0.16287, 14 | "mrr_at_10": 0.28989, 15 | "mrr_at_100": 0.30537, 16 | "mrr_at_1000": 0.30562, 17 | "mrr_at_3": 0.23826, 18 | "mrr_at_5": 0.26433, 19 | "ndcg_at_1": 0.16145, 20 | "ndcg_at_10": 0.37405, 21 | "ndcg_at_100": 0.44096, 22 | "ndcg_at_1000": 0.44733, 23 | "ndcg_at_3": 0.26489, 24 | "ndcg_at_5": 0.31147, 25 | "precision_at_1": 0.16145, 26 | "precision_at_10": 0.06515, 27 | "precision_at_100": 0.00946, 28 | "precision_at_1000": 0.001, 29 | "precision_at_3": 0.11427, 30 | "precision_at_5": 0.09132, 31 | "recall_at_1": 0.16145, 32 | "recall_at_10": 0.65149, 33 | "recall_at_100": 0.94595, 34 | "recall_at_1000": 0.99502, 35 | "recall_at_3": 0.34282, 36 | "recall_at_5": 0.45661 37 | } 38 | } --------------------------------------------------------------------------------
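The batch filter in `preprocess_and_tokenize.py` above can be sketched in isolation as follows. This is an illustrative re-implementation, not code from the repository: `build_single_task_batches` is a hypothetical helper name, and the sketch assumes each example is a dict with a `task_name` key, as in medi-data. It should keep and drop the same windows as the script, but the shuffled batch order may differ, because the sketch uses its own seeded `random.Random` instead of the global RNG seeded by `set_seed(2023)`.

```python
import random
from typing import Dict, List


def build_single_task_batches(examples: List[Dict], batch_size: int, seed: int = 2023) -> List[Dict]:
    """Illustrative re-implementation of the batch filter in preprocess_and_tokenize.py.

    Examples are read in consecutive windows of batch_size. A window is kept only
    if it is full and every example shares the first example's task_name. Kept
    windows are shuffled as whole units and then flattened back into one list.
    """
    kept = []
    for start in range(0, len(examples), batch_size):
        window = examples[start:start + batch_size]
        task = window[0]["task_name"]
        # Drop windows that mix tasks or are shorter than batch_size (the tail).
        if len(window) == batch_size and all(e["task_name"] == task for e in window):
            kept.append(window)
    # Shuffle the order of batches, never the examples inside a batch.
    random.Random(seed).shuffle(kept)
    return [e for window in kept for e in window]


if __name__ == "__main__":
    # Toy data: tasks a, a, a, b, b, b, c with batch_size=2.
    data = [{"task_name": t, "query": i} for i, t in enumerate("aaabbbc", start=1)]
    out = build_single_task_batches(data, batch_size=2)
    # Windows [1,2] and [5,6] survive; [3,4] mixes tasks and [7] is incomplete.
    print(sorted(e["query"] for e in out))  # [1, 2, 5, 6]
```

Note that the trailing window is dropped even when it is single-task, simply because it is shorter than the batch size, so the number of training pairs printed by the script is always a multiple of `real_batch_size`.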
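The two `ArguAna.json` files above (baseline in `results/`, fine-tuned in `results_finetuned/`) share the same `test` layout, so the effect of fine-tuning can be read off metric by metric. A minimal sketch, with hypothetical helper names `load_scores` and `metric_deltas`; the demo uses a few values copied from those two files. The runs report different `mteb_version` values (1.0.1 vs 1.0.2), so the comparison assumes the task and metrics did not change between those versions.

```python
import json
from typing import Dict


def load_scores(path: str) -> Dict:
    """Read one MTEB result file (e.g. results/ArguAna.json)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def metric_deltas(baseline: Dict, finetuned: Dict, split: str = "test") -> Dict[str, float]:
    """Return finetuned minus baseline for each metric present in both runs.

    evaluation_time is skipped because it measures runtime, not retrieval quality.
    """
    base, tuned = baseline[split], finetuned[split]
    return {
        name: round(tuned[name] - base[name], 5)
        for name in sorted(base)
        if name in tuned and name != "evaluation_time"
    }


if __name__ == "__main__":
    # A few values copied from results/ArguAna.json and results_finetuned/ArguAna.json.
    baseline = {"test": {"ndcg_at_10": 0.29922, "recall_at_10": 0.55121, "evaluation_time": 1099.35}}
    finetuned = {"test": {"ndcg_at_10": 0.37405, "recall_at_10": 0.65149, "evaluation_time": 10709.28}}
    print(metric_deltas(baseline, finetuned))  # {'ndcg_at_10': 0.07483, 'recall_at_10': 0.10028}
```

Against the real files (paths relative to `example_with_embeddings/`), the same comparison would be `metric_deltas(load_scores("results/ArguAna.json"), load_scores("results_finetuned/ArguAna.json"))`.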