├── README.md ├── README_zh.md ├── assets └── cmteb-0505.png ├── data_example ├── cls.dataset │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json ├── retri.dataset │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json └── sts.dataset │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json ├── ds_config_zero1.json ├── finetune └── train.py ├── meta_lists └── piccolo-ft.txt ├── piccolo ├── __pycache__ │ ├── arguments.cpython-310.pyc │ ├── criteria.cpython-310.pyc │ ├── data.cpython-310.pyc │ ├── data_structures.cpython-310.pyc │ └── model.cpython-310.pyc ├── arguments.py ├── criteria.py ├── data.py ├── data_structures.py └── model.py ├── requirements.txt └── scripts └── ft.sh /README.md: -------------------------------------------------------------------------------- 1 | [EN](README.md) | [简体中文](README_zh.md) 2 | 3 | # [Piccolo2: General Text Embeddings with Multi-Task Hybrid loss Training](https://arxiv.org/abs/2405.06932) 4 | 5 | 🚀 **New SOTA on CMTEB** 6 | 7 | 🔥 Our general sentence embedding [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) achieves SOTA on the CMTEB Leaderboard with an average score of 70.95. [2024/4/23] 8 | 9 |
10 | <details>
11 |   <summary> 📄 Results on CMTEB Leaderboard [click to expand] </summary>
12 |   <img src="assets/cmteb-0505.png">
13 | </details>
14 | 
15 | 
16 | ## 💡Model Highlights
17 | Here we introduce Piccolo2, an embedding model that surpasses other models in the comprehensive evaluation over 6 tasks on the CMTEB benchmark, setting a new state of the art. Piccolo2 primarily leverages an efficient multi-task hybrid loss training approach, effectively harnessing textual data and labels from diverse downstream tasks. In addition, Piccolo2 scales up the embedding dimension and uses MRL training to support more flexible vector dimensions.
18 | 
19 | For the Hugging Face models, please refer to our space: https://huggingface.co/sensenova
20 | For training details, please refer to our tech report: https://arxiv.org/abs/2405.06932
21 | 
22 | ## 📖 Repo Details
23 | In this repo, we release our training code and implement some tricks that are helpful for embedding training, including:
24 | - Multi-task Hybrid Loss Training
25 | - Matryoshka Representation Learning
26 | - Embedding Dimension Scaling
27 | - Task-Homogenous Dataset
28 | - Position Embedding Hierarchical Decomposition
29 | 
30 | To save memory, we use deepspeed-zero1, gradient checkpointing and mixed-precision training by default. We also provide SLURM scripts to help conduct distributed training.
31 | 
32 | ### Tips
33 | 1. By default, the repository trains with the multi-task hybrid loss described in our technical report.
34 | 
35 | 2. For embedding dimension scaling, we hard-code the path of the pretrained scaling layer, the input dim and the output dim in the code:
36 | ```python
37 | self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792)
38 | if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')):
39 |     scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin'))
40 |     self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True)
41 | ```
42 | 3. For MRL training, we hard-code the nesting list. Feel free to modify it yourself.
43 | ```python
44 | self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792]
45 | ```
46 | 
47 | 4. If you want to increase the length of the position embedding, set `extend_pe` to True and then set `max_length` to your expected length in the scripts.
48 | 
49 | ## 🔨 Best Practice
50 | ### 1. Environment
51 | ```shell
52 | pip install -r requirements.txt
53 | ```
54 | 
55 | ### 2. Prepare Dataset
56 | We divide the datasets into three major categories: retrieval/reranking, clustering/classification, and sts/pair classification, and we adopt a different loss function for each category. In the `data_example` dir, we provide an example for each of these three types of data.
57 | 
58 | 1) `Retri`: For retrieval and reranking datasets, we implement standard InfoNCE with in-batch negatives. This dataset contains 4 columns: `text`, `text_pos`, `text_neg`, `type`. Here 'type' is 'retri_contrast'.
59 | 
60 | 2) `STS`: For STS and pair-classification datasets, we implement the cosent loss. This dataset contains 4 columns: `text`, `text_pair`, `label`, `type`. Here 'type' is 'cosent'.
61 | 
62 | 3) `Cls`: For classification and clustering datasets, we implement InfoNCE without in-batch negatives. This dataset contains 4 columns: `text`, `text_pos`, `text_neg`, `type`. Here 'type' is 'cls_contrast'.
63 | 
64 | The 'type' column indicates the type of the current dataset; we read the type of the current batch to apply the corresponding loss function.
65 | 
66 | ### 3. Training
67 | We provide the training script at `scripts/ft.sh`; the variables used within it are explained below (a short dataset-construction sketch follows this paragraph).
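
Before going through the script variables, here is a minimal, illustrative sketch of how a `retri_contrast` dataset in the format described in Step 2 could be built with the 🤗 `datasets` library (the texts and the save path below are placeholders, not part of this repo):

```python
from datasets import Dataset

# Illustrative rows only; real data should follow the column layout described in Step 2.
rows = {
    'text': ['what is the capital of France?'],
    'text_pos': ['Paris is the capital and largest city of France.'],
    'text_neg': [['Berlin is the capital of Germany.']],  # one list of hard negatives per sample
    'type': ['retri_contrast'],
}
Dataset.from_dict(rows).save_to_disk('data_example/my_retri.dataset')  # placeholder path
```

The `cosent` and `cls_contrast` datasets follow the same pattern with their respective columns (`text`, `text_pair`, `label`, `type` and `text`, `text_pos`, `text_neg`, `type`); to use a new dataset, add its relative path and a repeat count to the meta list file.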
68 | 
69 | **Environment Variables**
70 | - ROOT: The absolute path of the repo on your machine.
71 | - GPUS_PER_NODE: Number of GPUs on a single node.
72 | The provided script defaults to single-machine training. If you are in a multi-node distributed training scenario, you should additionally set the environment variables below:
73 | - WORLD_SIZE: Number of nodes
74 | - RANK: Rank of the current node; it is often obtained through a SLURM environment variable
75 | - MASTER_ADDR / MASTER_PORT: Address and port used for communication
76 | 
77 | **Training Variables**
78 | - MODEL_NAME_OR_PATH: the absolute path of the pretrained model
79 | - DS_PATH: the deepspeed config; we provide a default config in `./ds_config_zero1.json`
80 | - META_PATHS: the list of datasets to use. We provide a sample in `meta_lists/piccolo-ft.txt`. Each row consists of two columns: the relative path of the dataset and the number of repeats.
81 | - ROOT_DIRS: The absolute path of the dataset root directory
82 | 
83 | **Run**
84 | ```shell
85 | bash scripts/ft.sh
86 | ```
87 | 
88 | ## 🤗 **Model List**
89 | | Model|Language||Description|prompt|
90 | |:-|:-:|:-:|:--------------------------------------------:|:---------:|
91 | | [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) | Chinese | | version 2: fine-tuned with multi-task hybrid loss training | None |
92 | | [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh) | Chinese | | version 1: pretrained on 400 million Chinese text pairs | '查询'/'结果' |
93 | | [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh) | Chinese | | version 1: pretrained on 400 million Chinese text pairs | '查询'/'结果' |
94 | 
95 | 
96 | ## Citation
97 | If you find our tech report, models or code helpful, please cite our report or give us a star on GitHub or Hugging Face!
98 | ```bibtex
99 | @misc{2405.06932,
100 | Author = {Junqin Huang and Zhongjie Hu and Zihao Jing and Mengya Gao and Yichao Wu},
101 | Title = {Piccolo2: General Text Embedding with Multi-task Hybrid Loss Training},
102 | Year = {2024},
103 | Eprint = {arXiv:2405.06932},
104 | }
105 | ```
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | [EN](README.md) | [简体中文](README_zh.md)
2 | 
3 | # [Piccolo2: General Text Embeddings with Multi-Task Hybrid loss Training](https://arxiv.org/abs/2405.06932)
4 | 
5 | 🚀 **New SOTA on CMTEB**
6 | 
7 | 🔥 我们最新的通用embedding模型 [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) 在CMTEB评测榜单上取得了70.95的均分! [2024/4/23]
8 | 
9 | 
10 | <details>
11 |   <summary> 📄 CMTEB结果 [点击展开] </summary>
12 |   <img src="assets/cmteb-0505.png">
13 | </details>
14 | 
15 | 16 | ## 💡Model Highlights 17 | Piccolo2 在CMTEB榜单上的6项任务的综合评估中超越了其他模型,目前位于第一位。Piccolo2 主要利用高效的多任务混合损失训练方法,有效地利用来自不同下游任务的文本数据和标签。 此外,Piccolo2 扩大了Embedding维度,并使用MRL训练来支持更灵活的向量维度。 18 | huggingface上放了我们最新的模型: https://huggingface.co/sensenova 19 | 对于训练细节,可以参考我们的技术报告: https://arxiv.org/abs/2405.06932 20 | 21 | ## 📖 Repo Details 22 | 在这个repo里面,我们放出了训练的代码,里面提供了一些有助于提升Embedding模型性能的训练技巧: 23 | - Multi-task Hybrid Loss Training 24 | - Matryoshka Representation Learning 25 | - Embdding Dimension Scaling 26 | - Task-Homogenous Dataset 27 | - Position Embedding Hierarchical Decomposition 28 | 29 | 为了节省内存,我们默认使用 deepspeed-zero1、gradient checkpointing和mix-precision进行训练。我们还提供了一个脚本用来帮助大家进行分布式训练。 30 | 31 | ### Tips 32 | 1. 该项目会默认使用multi-task hybrid loss来进行训练,前提是数据按规定格式准备好。 33 | 34 | 2. 对于scaling dimension length, 我们将它的一些参数写死在了代码里面, 请根据需要自行更改: 35 | ```python 36 | self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792) 37 | if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')): 38 | scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')) 39 | self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True) 40 | ``` 41 | 3. 对于MRL训练, 我们也把它的参数写死在了代码里. 42 | ```python 43 | self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792] 44 | ``` 45 | 46 | 4. 如果你想增长模型的position embedding,我们实现了一个简单的层次分解方法, 只需要将 `extend_pe` 设置为True 然后把 `max_length` 设置为你的目标长度就可以了. 47 | 48 | 49 | ## 🔨 使用指南 50 | ### 1. 环境 51 | ```shell 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | ### 2. 数据准备 56 | 我们将数据集分为三大类:检索/排序、聚类/分类、句对相似度/句对分类,并对不同类别采用不同的损失函数。 在`data_example`目录中,我们提供了这三种类型数据的示例。 57 | 58 | 1) `Retri`: 对于检索、重排数据集, 我们采用标准的InfoNCE进行优化,同时采用in-batch-negative来扩充负样本数量。这个数据集需要四列: `text`, `text_pos`, `text_neg`, `type`. 这里的 'type' 需要被标注为 'retri_contrast'。 59 | 60 | 2) `STS`: 对于句对相似度,句对分类数据集, 我们采用了cosent这种排序损失, 这个数据集同样有四列: `text`, `text_pair`, `label`, `type`. 这里 'type' 需要被标注为 'cosent'。 61 | 62 | 3) `Cls`: 对于分类、聚类任务, 我们将文本和它的语义标签视为正负样本对,同样采用了InfoNCE损失来优化,但不再采样in-batch-negative(因为这很容易造成训练冲突),这类数据集同样包含四列: `text`, `text_pos`, `text_neg`, `type`. 这里 'type' 需要被标注为 'cls_contrast' 63 | 64 | 'type' 列表明了当前数据的类型,在训练的时候,我们通过获取当前数据的类型,以采用不同的损失进行优化。 65 | 66 | ### 3. 训练 67 | 我们提供了训练的脚本 `scripts/ft.sh`. 下面我们对这个脚本里的一些变量做了简单的解释. 68 | 69 | **环境参数** 70 | - ROOT: 为该项目在本地机器上的绝对路径. 71 | - GPUS_PER_NODE: 单卡的GPU数量 72 | 默认提供的脚本是在单机下进行训练的,如果你是在多机多卡下进行分布式训练,你需要额外填上: 73 | - WORLD_SIZE: 机器的数量 74 | - RANK: 当前机器的Rank序号,通常需要从SLURM的环境变量获取 75 | - MASTER_ADDR:MASTER_PORT: 通信端口 76 | 77 | 78 | **训练参数** 79 | - MODEL_NAME_OR_PATH: pretrain model的绝对路径。 80 | - DS_PATH: deepspeed参数, 默认的config放在了 `./de_config_zero1.json`。 81 | - META_PATHS: 使用的数据集列表. 我们提供了一个样本 `meta_lists/piccolo.txt`. 该txt文件的每一行有两列,第一列是数据集的相对路径,第二列是数据集的repeat次数。 82 | - ROOT_DIRS: 数据集的目录的绝对路径。 83 | 84 | **Run** 85 | ```shell 86 | bash scripts/ft.sh 87 | ``` 88 | 89 | ## 🤗 **Model List** 90 | | Model|语言||简介|prompt| 91 | |:-|:-:|:-:|:--------------------------------------------:|:---------:| 92 | | [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) | Chinese | | version2: 采用了多任务混合损失进行训练 | None | 93 | | [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh) | Chinese | | version1: 使用4亿的中文pair对进行预训练 | '查询'/'结果' | 94 | | [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh) | Chinese | | version1: 使用4亿的中文pair对进行预训练 | '查询'/'结果' | 95 | 96 | 97 | ## Citation 98 | 如果您认为我们的技术报告、模型、代码有帮助,请像下面这样引用我们的文章,或者在github和huggingface上点一个免费的赞! 
99 | ```bibtex 100 | @misc{2405.06932, 101 | Author = {Junqin Huang and Zhongjie Hu and Zihao Jing and Mengya Gao and Yichao Wu}, 102 | Title = {Piccolo2: General Text Embedding with Multi-task Hybrid Loss Training}, 103 | Year = {2024}, 104 | Eprint = {arXiv:2405.06932}, 105 | } 106 | ``` -------------------------------------------------------------------------------- /assets/cmteb-0505.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/assets/cmteb-0505.png -------------------------------------------------------------------------------- /data_example/cls.dataset/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/data_example/cls.dataset/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /data_example/cls.dataset/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "text": { 6 | "dtype": "string", 7 | "_type": "Value" 8 | }, 9 | "text_pos": { 10 | "dtype": "string", 11 | "_type": "Value" 12 | }, 13 | "text_neg": { 14 | "feature": { 15 | "dtype": "string", 16 | "_type": "Value" 17 | }, 18 | "_type": "Sequence" 19 | }, 20 | "type": { 21 | "dtype": "string", 22 | "_type": "Value" 23 | } 24 | }, 25 | "homepage": "", 26 | "license": "" 27 | } -------------------------------------------------------------------------------- /data_example/cls.dataset/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "885b3f7efdf053e4", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /data_example/retri.dataset/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/data_example/retri.dataset/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /data_example/retri.dataset/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "text": { 6 | "dtype": "string", 7 | "_type": "Value" 8 | }, 9 | "text_pos": { 10 | "dtype": "string", 11 | "_type": "Value" 12 | }, 13 | "text_neg": { 14 | "feature": { 15 | "dtype": "string", 16 | "_type": "Value" 17 | }, 18 | "_type": "Sequence" 19 | }, 20 | "type": { 21 | "dtype": "string", 22 | "_type": "Value" 23 | } 24 | }, 25 | "homepage": "", 26 | "license": "" 27 | } -------------------------------------------------------------------------------- /data_example/retri.dataset/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "7e61224bcfd17036", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | 
"_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /data_example/sts.dataset/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/data_example/sts.dataset/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /data_example/sts.dataset/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "text": { 6 | "dtype": "string", 7 | "_type": "Value" 8 | }, 9 | "text_pair": { 10 | "dtype": "string", 11 | "_type": "Value" 12 | }, 13 | "label": { 14 | "dtype": "int64", 15 | "_type": "Value" 16 | }, 17 | "type": { 18 | "dtype": "string", 19 | "_type": "Value" 20 | } 21 | }, 22 | "homepage": "", 23 | "license": "" 24 | } -------------------------------------------------------------------------------- /data_example/sts.dataset/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "6f5ce22b311c38ae", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /ds_config_zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 2e8, 6 | "overlap_comm": true, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 2e8, 9 | "contiguous_gradients": true 10 | }, 11 | "train_micro_batch_size_per_gpu": "auto" 12 | } -------------------------------------------------------------------------------- /finetune/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | from pathlib import Path 5 | from transformers.trainer import logger, Optional 6 | from datasets import Dataset as HfDataset 7 | from datasets import concatenate_datasets, load_from_disk 8 | from dataclasses import asdict 9 | from transformers import AutoTokenizer, HfArgumentParser, TrainerCallback # type: ignore 10 | from transformers.trainer import Trainer 11 | 12 | from piccolo.arguments import ModelArguments, DataArguments, STETrainingArguments 13 | from piccolo.data import DatsetWithInfo, UniCollator, UniDataset 14 | from piccolo.model import STEmbedder 15 | from tqdm import tqdm 16 | 17 | def load_all_datasets(meta_paths, root_dirs, query_prefix, doc_prefix) -> list[DatsetWithInfo]: 18 | all_datasets = [] 19 | for meta_path, root_dir in zip(meta_paths, root_dirs): 20 | CNT = 0 21 | meta_file = open(meta_path, 'r') 22 | for line in tqdm(meta_file.readlines()): 23 | dataset_name, repeat_num = line.strip().split(' ') 24 | dataset_dict = load_from_disk(str(os.path.join(root_dir, dataset_name))) 25 | if isinstance(dataset_dict, dict): 26 | dataset: HfDataset = concatenate_datasets(list(dataset_dict.values())) 27 | else: 28 | dataset = dataset_dict 29 | for idx in range(int(repeat_num)): 30 | all_datasets.append( 31 | DatsetWithInfo(hf_dataset=dataset, name=dataset_name + 
'_{}'.format(idx), 32 | query_prefix=query_prefix, passage_prefix=doc_prefix) 33 | ) 34 | CNT += 1 35 | print('loading {} datasets from path: {}'.format(CNT, meta_path)) 36 | return all_datasets 37 | 38 | class MyCallback(TrainerCallback): 39 | def on_epoch_end(self, args, state, control, train_dataloader, **kwargs): 40 | train_dataloader.dataset.create_or_refresh_data() 41 | 42 | 43 | class STETrainer(Trainer): 44 | def __init__(self, efficient_save, **kwargs): 45 | super().__init__(**kwargs) 46 | self.efficient_save = efficient_save 47 | 48 | def save_ckpt_for_sentence_transformers(self, tmp_dir, output_dir, pooling_mode: str = 'mean'): 49 | '''convert to sentence transformer format''' 50 | import shutil 51 | from sentence_transformers import models, SentenceTransformer 52 | word_embedding_model = models.Transformer(tmp_dir) 53 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean') 54 | if os.path.exists(os.path.join(tmp_dir, 'scaling_layer.bin')): 55 | state_dict = torch.load(os.path.join(tmp_dir, 'scaling_layer.bin')) 56 | in_features, out_features = state_dict['linear.weight'].shape[1], state_dict['linear.weight'].shape[0] 57 | scaling_layer = models.Dense(in_features, out_features, bias=True, activation_function=torch.nn.modules.linear.Identity()) 58 | scaling_layer.load_state_dict(state_dict, strict=True) 59 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model, scaling_layer], device='cpu') 60 | else: 61 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu') 62 | model.save(output_dir, safe_serialization=False) 63 | shutil.rmtree(tmp_dir) 64 | 65 | def _save(self, output_dir: Optional[str] = None, **kwargs): 66 | '''save the unwrap model''' 67 | 68 | output_dir = output_dir if output_dir is not None else self.args.output_dir 69 | os.makedirs(output_dir, exist_ok=True) 70 | 71 | logger.info("Saving model checkpoint to %s", output_dir) 72 | unwrap_model = self.model.embedder.encoder 73 | if self.is_world_process_zero(): 74 | # first saves to the tmp dir, then converts to sentence-transformer 75 | tmp_dir = output_dir + '-tmp' 76 | unwrap_model.save_pretrained(tmp_dir, safe_serialization=self.args.save_safetensors) 77 | self.tokenizer.save_pretrained(tmp_dir) 78 | if hasattr(self.model, 'scaling_layer'): 79 | scaling_layer = {'linear.weight': self.model.scaling_layer.state_dict()['linear.weight'].data.cpu(), 80 | 'linear.bias': self.model.scaling_layer.state_dict()['linear.bias'].data.cpu()} 81 | torch.save(scaling_layer, os.path.join(tmp_dir, 'scaling_layer.bin')) 82 | self.save_ckpt_for_sentence_transformers(tmp_dir, output_dir, self.model.embedder.pooling_strategy.value) 83 | 84 | def _save_checkpoint(self, model, trial, metrics=None): 85 | if self.efficient_save: 86 | '''only save the model ckpt weights to save disk mem''' 87 | from transformers.trainer import PREFIX_CHECKPOINT_DIR 88 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 89 | run_dir = self._get_output_dir(trial=trial) 90 | output_dir = os.path.join(run_dir, checkpoint_folder) 91 | self.save_model(output_dir, _internal_call=True) 92 | else: 93 | super()._save_checkpoint(model, trial, metrics) 94 | 95 | 96 | def main(): 97 | parser = HfArgumentParser((ModelArguments, DataArguments, STETrainingArguments)) 98 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 99 | model_args: ModelArguments 100 | data_args: DataArguments 101 | training_args: STETrainingArguments 
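# Note: the three dataclasses above are populated by HfArgumentParser from the CLI flags in scripts/ft.sh:
#   ModelArguments       -> model_name_or_path, embedding_strategy, max_length, use_scaling_layer, use_mrl, extend_pe
#   DataArguments        -> meta_paths, root_dirs, batch_size, neg_num, query_prefix, doc_prefix
#   STETrainingArguments -> the usual transformers TrainingArguments plus efficient_save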
102 | 103 | # DataLoader 104 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 105 | all_datasets = load_all_datasets(data_args.meta_paths, data_args.root_dirs, data_args.query_prefix, data_args.doc_prefix) 106 | train_dataset = UniDataset(all_datasets, batch_size=data_args.batch_size, neg_num=data_args.neg_num) 107 | data_collator = UniCollator(tokenizer=tokenizer, max_length=model_args.max_length) 108 | 109 | # Model 110 | model = STEmbedder( 111 | model_name_or_path=model_args.model_name_or_path, 112 | embedding_strategy=model_args.embedding_strategy, 113 | freeze_pos_emb=True, 114 | add_scaling_layer=model_args.use_scaling_layer, 115 | use_mrl=model_args.use_mrl, 116 | extend_pe=model_args.extend_pe, 117 | max_length=model_args.max_length 118 | ) 119 | model.embedder.encoder.config.pad_token_id = tokenizer.pad_token_id 120 | 121 | # Trainer 122 | trainer = STETrainer( 123 | model=model, 124 | args=training_args, 125 | train_dataset=train_dataset, 126 | data_collator=data_collator, 127 | tokenizer=tokenizer, 128 | callbacks=[MyCallback], 129 | efficient_save=training_args.efficient_save 130 | ) 131 | 132 | # save training info 133 | if trainer.is_world_process_zero(): 134 | Path(training_args.output_dir).mkdir(parents=True, exist_ok=True) 135 | Path(os.path.join(training_args.output_dir, 'parameters')).mkdir(parents=True, exist_ok=True) 136 | ## save data list info 137 | meta_paths = data_args.meta_paths 138 | with open(os.path.join(training_args.output_dir, 'parameters','data.list'), 'w') as f: 139 | for meta_path in meta_paths: 140 | f.writelines(f'list_name: {meta_path} \n') 141 | f.writelines(open(meta_path, 'r').readlines()) 142 | f.writelines('\n\n') 143 | 144 | trainer.train() 145 | 146 | # save parameter and model at the end 147 | if trainer.is_world_process_zero(): 148 | trainer.save_model(training_args.output_dir, _internal_call=True) 149 | ## save parameter 150 | parameter_dict = {'model_args': asdict(model_args), 'data_args': asdict(data_args), 'train_args': asdict(training_args)} 151 | Path(os.path.join(training_args.output_dir, 'parameters')).mkdir(parents=True, exist_ok=True) 152 | with open(os.path.join(training_args.output_dir, 'parameters', 'param.yaml'), 'w') as yaml_file: 153 | yaml.dump(parameter_dict, yaml_file) 154 | 155 | 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /meta_lists/piccolo-ft.txt: -------------------------------------------------------------------------------- 1 | cls.dataset 1 2 | retri.dataset 1 3 | sts.dataset 1 -------------------------------------------------------------------------------- /piccolo/__pycache__/arguments.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/arguments.cpython-310.pyc -------------------------------------------------------------------------------- /piccolo/__pycache__/criteria.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/criteria.cpython-310.pyc -------------------------------------------------------------------------------- /piccolo/__pycache__/data.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/data.cpython-310.pyc -------------------------------------------------------------------------------- /piccolo/__pycache__/data_structures.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/data_structures.cpython-310.pyc -------------------------------------------------------------------------------- /piccolo/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /piccolo/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from transformers import TrainingArguments 3 | from piccolo.model import PoolingStrategy 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | model_name_or_path: str = field() # must require 11 | embedding_strategy: PoolingStrategy = field(default=PoolingStrategy.mean) 12 | extend_pe: bool = field(default=False) 13 | max_length: int = field(default=512) 14 | # scaling layer and mrl Training 15 | use_scaling_layer: bool = field(default=False) 16 | use_mrl: bool = field(default=False) 17 | 18 | @dataclass 19 | class DataArguments: 20 | # train data 21 | meta_paths: list[str] = field() # must require 22 | root_dirs: list[str] = field() # must require 23 | batch_size: int = field(default=16) 24 | query_prefix: str = field(default='') 25 | doc_prefix: str = field(default='') 26 | # hard neg 27 | neg_num: int = field(default=1) # only affects retri_contrast_loss 28 | 29 | @dataclass 30 | class STETrainingArguments(TrainingArguments): 31 | efficient_save: bool = field(default=True) 32 | -------------------------------------------------------------------------------- /piccolo/criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ContrastLoss(torch.nn.Module): 4 | def __init__(self, temperature: float = 0.05): 5 | super().__init__() 6 | self.temperature = temperature 7 | 8 | 9 | class RetriContrastLoss(ContrastLoss): 10 | ''' 11 | loss for retrieval 12 | if use_all_pair is set to true, it will use the query-query pair as neg, 13 | otherwise it use query-passage as neg 14 | ''' 15 | def __init__(self, temperature: float = 0.05, use_all_pair=False): 16 | super().__init__() 17 | self.temperature = temperature 18 | self._cross_entropy_loss = torch.nn.CrossEntropyLoss() 19 | self.use_all_pair=use_all_pair 20 | 21 | def forward( 22 | self, 23 | text_embeddings: torch.Tensor, 24 | text_pos_embeddings: torch.Tensor, 25 | text_neg_embeddings: torch.Tensor | None, 26 | text_neg_index: torch.Tensor | None, 27 | ) -> torch.Tensor: 28 | if text_neg_embeddings is None: 29 | sim_matrix = torch.cosine_similarity( 30 | text_embeddings.unsqueeze(1), 31 | text_pos_embeddings.unsqueeze(0), 32 | dim=-1, 33 | ) 34 | sim_matrix = sim_matrix / self.temperature 35 | labels = torch.arange(sim_matrix.size(0), device=text_embeddings.device, dtype=torch.long) 36 | 
loss = self._cross_entropy_loss(sim_matrix, labels) 37 | return loss 38 | 39 | sim_neg_matrix = torch.cosine_similarity( 40 | text_embeddings.unsqueeze(1), 41 | text_neg_embeddings.unsqueeze(0), 42 | dim=-1, 43 | ) 44 | if self.use_all_pair: 45 | sim_pos_matrix = torch.cosine_similarity( 46 | text_embeddings.unsqueeze(1), 47 | text_pos_embeddings.unsqueeze(0), 48 | dim=-1, 49 | ) 50 | sim_matrix = torch.cat([sim_pos_matrix, sim_neg_matrix], dim=1) 51 | labels = torch.arange(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device) 52 | else: 53 | sim_pos_vector = torch.cosine_similarity(text_embeddings, text_pos_embeddings, dim=-1) 54 | sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1) 55 | labels = torch.zeros(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device) 56 | sim_matrix = sim_matrix / self.temperature 57 | loss = self._cross_entropy_loss(sim_matrix, labels) 58 | return loss 59 | 60 | 61 | class CoSentLoss(ContrastLoss): 62 | ''' 63 | loss for sts and pair classification. 64 | here we hard code the cosent loss weight to 0.04 65 | ''' 66 | bias: torch.Tensor 67 | 68 | def __init__(self, temperature: float = 0.05, cosent_w: float = 0.04) -> None: 69 | super().__init__(temperature) 70 | self.register_buffer('bias', torch.tensor([0.0])) 71 | self.cosent_w = cosent_w 72 | 73 | def forward(self, predict_similarity: torch.Tensor, true_similarity: torch.Tensor) -> torch.Tensor: 74 | predict_similarity = predict_similarity / self.temperature 75 | cosine_similarity_diff = -(predict_similarity.unsqueeze(0) - predict_similarity.unsqueeze(1)) 76 | smaller_mask = true_similarity.unsqueeze(0) <= true_similarity.unsqueeze(1) 77 | cosine_similarity_diff = cosine_similarity_diff[~smaller_mask] 78 | cosine_diff_scores_add_bias = torch.cat((cosine_similarity_diff, self.bias)) 79 | loss = torch.logsumexp(cosine_diff_scores_add_bias, dim=0) * self.cosent_w 80 | return loss 81 | 82 | class ClsContrastLoss(torch.nn.Module): 83 | ''' 84 | loss for clustering and classification 85 | here we hard code the cls contrast loss weight to 0.2 86 | ''' 87 | def __init__(self, temperature: float = 0.05, cls_w = 0.2): 88 | super().__init__() 89 | self.temperature = temperature 90 | self._cross_entropy_loss = torch.nn.CrossEntropyLoss() 91 | self.cls_w = cls_w 92 | 93 | def forward( 94 | self, 95 | text_embeddings: torch.Tensor, 96 | text_pos_embeddings: torch.Tensor, 97 | text_neg_embeddings: torch.Tensor, 98 | ) -> torch.Tensor: 99 | bs = text_embeddings.shape[0] 100 | assert text_neg_embeddings.shape[0] % bs == 0, 'neg num is not equal for each sample' 101 | neg_num = int(text_neg_embeddings.shape[0] // bs) 102 | 103 | sim_neg_matrix = torch.cosine_similarity( 104 | text_embeddings.unsqueeze(1), 105 | text_neg_embeddings.unsqueeze(0), 106 | dim=-1, 107 | ) 108 | sim_pos_vector = torch.cosine_similarity(text_embeddings, text_pos_embeddings, dim=-1) 109 | 110 | # find the neg for eatch training sample 111 | neg_matrix = [] 112 | for i in range(bs): 113 | neg_matrix.append(sim_neg_matrix[i, i * neg_num : (i + 1) * neg_num]) 114 | sim_neg_matrix = torch.stack(neg_matrix) 115 | sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1) 116 | sim_matrix = sim_matrix / self.temperature 117 | labels = torch.zeros(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device) 118 | loss = self._cross_entropy_loss(sim_matrix, labels) * self.cls_w 119 | return loss -------------------------------------------------------------------------------- /piccolo/data.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Any, Sequence, cast, TypeAlias 7 | from datasets import Dataset as HfDataset 8 | from torch.utils.data import Dataset, RandomSampler 9 | 10 | from piccolo.data_structures import PairRetriContrastRecord, PairScoredRecord, PairClsContrastRecord 11 | from transformers.tokenization_utils import PreTrainedTokenizer 12 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 13 | 14 | Tokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast 15 | 16 | 17 | class UniCollator: 18 | ''' 19 | Uni Data Collator 20 | 21 | for retrieval, sts, pair classification, the query max length is 64, doc max length is 512 22 | for clustering and classification, we specially set the query max length to 512, 23 | bcz for clustering and classification task, query('text') is ofent much longer than the pos/neg ('label') 24 | ''' 25 | def __init__(self, tokenizer: Tokenizer, max_length: int, q_max_length: int = 64) -> None: 26 | self.tokenizer = tokenizer 27 | self.max_length = max_length 28 | self.q_max_length = q_max_length 29 | 30 | def __call__(self, records: list) -> dict[str, torch.Tensor]: 31 | records = records[0] 32 | if isinstance(records[0], PairClsContrastRecord): 33 | texts = [record.text for record in records] 34 | texts_pos = [record.text_pos for record in records] 35 | texts_neg = [] 36 | for i, record in enumerate(records): 37 | for neg in record.text_neg: 38 | texts_neg.append(neg) 39 | text_ids = self.tokenizer(texts, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] 40 | text_pos_ids = self.tokenizer(texts_pos, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] 41 | text_neg_ids = self.tokenizer(texts_neg, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] 42 | 43 | return { 44 | 'text_ids': cast(torch.Tensor, text_ids), 45 | 'text_pos_ids': cast(torch.Tensor, text_pos_ids), 46 | 'text_neg_ids': cast(torch.Tensor, text_neg_ids), 47 | 'type': 'cls_contrast', 48 | } 49 | elif isinstance(records[0], PairRetriContrastRecord): 50 | texts = [record.text for record in records] 51 | texts_pos = [record.text_pos for record in records] 52 | texts_neg = [] 53 | texts_neg_index = [] # index indictates for which text the negative sample belongs 54 | for i, record in enumerate(records): 55 | for neg in record.text_neg: 56 | texts_neg.append(neg) 57 | texts_neg_index.append(i) 58 | 59 | text_ids = self.tokenizer(texts, padding=True, max_length=self.q_max_length, truncation=True, return_tensors='pt',)['input_ids'] 60 | text_pos_ids = self.tokenizer(texts_pos, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] 61 | if len(texts_neg) > 0: 62 | text_neg_ids = self.tokenizer(texts_neg, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] 63 | else: 64 | text_neg_ids = None 65 | return { 66 | 'text_ids': cast(torch.Tensor, text_ids), 67 | 'text_pos_ids': cast(torch.Tensor, text_pos_ids), 68 | 'text_neg_ids': cast(torch.Tensor, text_neg_ids), 69 | 'text_neg_index': cast(torch.Tensor, texts_neg_index), 70 | 'type': 'retri_contrast', 71 | } 72 | elif isinstance(records[0], PairScoredRecord): 73 | texts = [record.text for record in records] 74 | texts_pair = [record.text_pair 
for record in records] 75 | labels = [record.label for record in records] 76 | labels = torch.tensor(labels, dtype=torch.float32) 77 | 78 | text_ids = self.tokenizer(texts, padding=True, max_length=self.q_max_length, truncation=True, return_tensors='pt',)['input_ids'] 79 | text_pair_ids = self.tokenizer(texts_pair, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] 80 | 81 | return { 82 | 'text_ids': cast(torch.Tensor, text_ids), 83 | 'text_pair_ids': cast(torch.Tensor, text_pair_ids), 84 | 'labels': labels, 85 | 'type': 'cosent', 86 | } 87 | else: 88 | raise NotImplementedError("only support pairscored and pairneg records") 89 | 90 | 91 | @dataclass 92 | class TaskBatchIndex: 93 | name: str 94 | batch_index: list[int] 95 | 96 | 97 | @dataclass 98 | class DatsetWithInfo: 99 | hf_dataset: HfDataset 100 | name: str 101 | query_prefix: str = '' 102 | passage_prefix: str = '' 103 | 104 | 105 | class UniDataset(Dataset): 106 | ''' 107 | Task-Homogenous Dataset 108 | 109 | Code is modified from M3E: https://github.com/wangyuxinwhy/uniem 110 | It's also adopted in SFR Embedding: https://blog.salesforceairesearch.com/sfr-embedded-mistral 111 | 112 | This technique can ensure that the in-batch-negative samples come from the same data set, 113 | and it is generally believed that this can improve the quality of the in batch negatives. 114 | ''' 115 | def __init__( 116 | self, 117 | hf_datasets: list[DatsetWithInfo], 118 | neg_num: int = 1, 119 | batch_size: int = 32, 120 | max_samples: int | None = None, 121 | ): 122 | self.batch_size = batch_size 123 | self.hf_datasets = hf_datasets 124 | self.max_samples = max_samples 125 | self.name_dataset_map = {dataset.name: dataset.hf_dataset for dataset in hf_datasets} 126 | self.neg_num = neg_num 127 | self.query_prefix_map = {dataset.name: dataset.query_prefix for dataset in hf_datasets} 128 | self.passage_prefix_map = {dataset.name: dataset.passage_prefix for dataset in hf_datasets} 129 | self.create_or_refresh_data() 130 | 131 | def __len__(self): 132 | return len(self.task_batch_index_list) 133 | 134 | @staticmethod 135 | def is_valid_text(text: Any) -> bool: 136 | return isinstance(text, str) and bool(text.strip()) 137 | 138 | def create_or_refresh_data(self): 139 | self.task_batch_index_list: list[TaskBatchIndex] = [] 140 | for dataset in self.hf_datasets: 141 | max_samples = self.max_samples or len(dataset.hf_dataset) 142 | batch_size = self.batch_size 143 | num_samples = (max_samples // batch_size) * batch_size 144 | buffer = [] 145 | for i in RandomSampler(dataset.hf_dataset, num_samples=num_samples): 146 | buffer.append(i) 147 | if len(buffer) == batch_size: 148 | self.task_batch_index_list.append(TaskBatchIndex(name=dataset.name, batch_index=buffer)) 149 | buffer = [] 150 | self.random_index_list = list(RandomSampler(self.task_batch_index_list)) 151 | 152 | 153 | def get_pair_scored_records(self, records, task_name): 154 | pair_records = [] 155 | for record in records: 156 | text = record['text'] 157 | text_pair = record['text_pair'] 158 | label = record['label'] 159 | if not (self.is_valid_text(text) and self.is_valid_text(text_pair)): 160 | continue 161 | text = self.query_prefix_map[task_name] + text 162 | text_pair = self.passage_prefix_map[task_name] + text_pair 163 | pair_records.append(PairScoredRecord(text=text, text_pair=text_pair, label=label)) 164 | return pair_records 165 | 166 | def get_pair_retri_contrast_records(self, records, task_name, hf_dataset, batch_index): 167 | 168 | def 
process_records(record): 169 | text = record['text'] 170 | if isinstance(record['text_pos'], list): # random sample a positive 171 | assert len(record['text_pos']) >= 1, 'text pos num should be at least 1' 172 | text_pos = random.sample(record['text_pos'], 1)[0] 173 | else: 174 | text_pos = record['text_pos'] 175 | 176 | if not (self.is_valid_text(text) and self.is_valid_text(text_pos)): 177 | # skip current sample and random sample an index 178 | random_index = random.sample(range(len(hf_dataset)), k=1)[0] 179 | while random_index in batch_index: 180 | random_index = random.sample(range(len(hf_dataset)), k=1)[0] 181 | return process_records(hf_dataset[random_index]) 182 | 183 | text_neg = random.sample(record['text_neg'], min(self.neg_num, len(record['text_neg']))) 184 | text = self.query_prefix_map[task_name] + text 185 | text_pos = self.passage_prefix_map[task_name] + text_pos 186 | text_neg = [self.passage_prefix_map[task_name] + neg for neg in text_neg] 187 | return text, text_pos, text_neg 188 | 189 | pair_records = [] 190 | for record in records: 191 | text, text_pos, text_neg = process_records(record) 192 | pair_records.append(PairRetriContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg)) 193 | assert len(pair_records) == self.batch_size, 'error, current batch size not match !!!' 194 | return pair_records 195 | 196 | def get_pair_cls_contrast_records(self, records, task_name): 197 | pair_records = [] 198 | for record in records: 199 | text, text_pos, text_neg = record['text'], record['text_pos'], record['text_neg'] 200 | if isinstance(record['text_pos'], list): 201 | text_pos = random.sample(record['text_pos'], 1)[0] 202 | elif isinstance(record['text_pos'], str): 203 | text_pos = record['text_pos'] 204 | else: 205 | assert False, 'type error' 206 | text_neg = random.sample(record['text_neg'], min(10, len(record['text_neg']))) # TODO, hard code the neg num to 10 207 | if self.is_valid_text(text) and self.is_valid_text(text_pos): 208 | pair_records.append(PairClsContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg)) 209 | return pair_records 210 | 211 | def __getitem__(self, index: int): 212 | index = self.random_index_list[index] 213 | task_batch_index = self.task_batch_index_list[index] 214 | task_name = task_batch_index.name 215 | batch_index = task_batch_index.batch_index 216 | 217 | hf_dataset = self.name_dataset_map[task_name] 218 | records = [hf_dataset[i] for i in batch_index] 219 | if hf_dataset[0]['type'] == 'cls_contrast': 220 | pair_records = self.get_pair_cls_contrast_records(records, task_name) 221 | elif hf_dataset[0]['type'] == 'retri_contrast': 222 | pair_records = self.get_pair_retri_contrast_records(records, task_name, hf_dataset, batch_index) 223 | elif hf_dataset[0]['type'] == 'cosent': 224 | pair_records = self.get_pair_scored_records(records, task_name) 225 | else: 226 | raise NotImplementedError('only support pair contrast and pair scored') 227 | 228 | if not pair_records: 229 | print(f'records is empty', records) 230 | return self.__getitem__(index + 1) 231 | return pair_records 232 | -------------------------------------------------------------------------------- /piccolo/data_structures.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass(slots=True) 4 | class PairRetriContrastRecord: 5 | text: str 6 | text_pos: str 7 | text_neg: list 8 | 9 | @dataclass(slots=True) 10 | class PairClsContrastRecord: 11 | text: str 12 | text_pos: str 13 | text_neg: 
list 14 | 15 | @dataclass(slots=True) 16 | class PairScoredRecord: 17 | text: str 18 | text_pair: str 19 | label: float -------------------------------------------------------------------------------- /piccolo/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | 5 | from enum import Enum 6 | from pathlib import Path 7 | from typing import ClassVar, Literal, Type, TypeVar, cast 8 | from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel # type: ignore 9 | from piccolo.criteria import CoSentLoss, ClsContrastLoss, RetriContrastLoss 10 | 11 | class PoolingStrategy(str, Enum): 12 | cls = 'cls' 13 | mean = 'mean' 14 | 15 | class InBatchNegLossType(str, Enum): 16 | cosent = 'cosent' 17 | retri_contrast = 'retri_contrast' 18 | cls_contrast = 'cls_contrast' 19 | 20 | def build_loss(loss_type, temperature, **kwargs): 21 | loss_type = InBatchNegLossType(loss_type) 22 | match loss_type: 23 | case InBatchNegLossType.cosent: 24 | return CoSentLoss(temperature) 25 | case InBatchNegLossType.cls_contrast: 26 | return ClsContrastLoss(temperature) 27 | case InBatchNegLossType.retri_contrast: 28 | return RetriContrastLoss(temperature, **kwargs) 29 | 30 | def creat_attention_mask_from_input_ids(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor: 31 | return input_ids != pad_token_id 32 | 33 | def mean_pooling(hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: 34 | if attention_mask is None: 35 | return torch.mean(hidden_state, dim=1) 36 | attention_mask = attention_mask.float() 37 | return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum(attention_mask, dim=-1, keepdim=True) 38 | 39 | def load_hf_pretrained_model(model_name_or_path: str) -> PreTrainedModel: 40 | config = AutoConfig.from_pretrained(model_name_or_path) 41 | if config.model_type == 't5': 42 | from transformers import T5EncoderModel # type: ignore 43 | pretrained_model = T5EncoderModel.from_pretrained(model_name_or_path) 44 | else: 45 | pretrained_model = AutoModel.from_pretrained(model_name_or_path) 46 | return pretrained_model # type: ignore 47 | 48 | 49 | StrategyEmbedderClsMap: dict[PoolingStrategy, Type['Embedder']] = {} 50 | 51 | class Embedder(torch.nn.Module): 52 | pooling_strategy: ClassVar[PoolingStrategy] 53 | 54 | def __init__(self, encoder: PreTrainedModel, pad_token_id: int | None = None): 55 | super().__init__() 56 | self.encoder = encoder 57 | 58 | if pad_token_id is None: 59 | if encoder.config.pad_token_id is not None: 60 | self.pad_token_id = encoder.config.pad_token_id 61 | else: 62 | self.pad_token_id = 0 63 | else: 64 | self.pad_token_id = pad_token_id 65 | 66 | def __init_subclass__(cls) -> None: 67 | StrategyEmbedderClsMap[cls.pooling_strategy] = cls 68 | 69 | def save_pretrained(self, path: str | Path): 70 | self.encoder.save_pretrained(path) 71 | 72 | @classmethod 73 | def from_pretrained(cls, model_name_or_path: str): 74 | encoder = load_hf_pretrained_model(model_name_or_path) 75 | return cls(encoder) 76 | 77 | 78 | class LastMeanEmbedder(Embedder): 79 | pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.mean 80 | 81 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: 82 | if attention_mask is None: 83 | attention_mask = creat_attention_mask_from_input_ids(input_ids, self.pad_token_id) 84 | embeddings = self.encoder(input_ids, 
attention_mask=attention_mask).last_hidden_state 85 | embeddings = mean_pooling(embeddings, attention_mask) 86 | return embeddings 87 | 88 | 89 | class ClsEmbedder(Embedder): 90 | pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.cls 91 | 92 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: 93 | if attention_mask is None: 94 | attention_mask = creat_attention_mask_from_input_ids(input_ids, self.pad_token_id) 95 | embeddings = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0] 96 | return embeddings 97 | 98 | 99 | class EmbedderForTrain(torch.nn.Module): 100 | embedder: Embedder 101 | 102 | def __init__(self, embedder: Embedder): 103 | super().__init__() 104 | self.embedder = embedder 105 | 106 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs): 107 | self.embedder.encoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs) 108 | 109 | 110 | class ScalingLayer(torch.nn.Module): 111 | def __init__(self, origin_dim: int = 1024, scaling_dim: int = 1792): 112 | super().__init__() 113 | self.linear = torch.nn.Linear(in_features=origin_dim, out_features=scaling_dim, bias=True) 114 | def forward(self, input): 115 | return self.linear(input) 116 | 117 | 118 | class STEmbedder(EmbedderForTrain): 119 | """ 120 | Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. 121 | 122 | Args: 123 | model_name_or_path (`str` ): 124 | the path of the pretrain model 125 | embedding_strategy (`PoolingStrategy`): 126 | only support 'mean' and 'cls' 127 | for 'mean', we use the average embedding across all valid tokens embeddings of BERT's last layer. 128 | for 'cls', we use the first token embedding of the BERT's last layer 129 | freeze_pos_emb (`bool`): 130 | whether fix the position embedding or not, default is True 131 | add_scaling_layer (`bool`): 132 | add a scaling layer after the last layer of BERT, it is used for dimension scaling. 133 | use_mrl (`bool`) 134 | employ the Matryoshka Representation Learning algorithm to support flexible dimension 135 | extend_pe (`bool`) 136 | extend position embedding to longer length, here we adopt a very simple method from: 137 | https://kexue.fm/archives/7947 138 | 139 | Notation: 140 | some parameter are hard-coded in this code, such as in_feature and out_feature of scaling layer, and nesting list of MRL. 
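    Example (illustrative usage; variable names below are placeholders):
        >>> model = STEmbedder('path/to/a-bert-like-model', embedding_strategy='mean')
        >>> # `batch` is a dict produced by UniCollator (text_ids, text_pos_ids, ..., type=...)
        >>> loss = model(**batch)['loss']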
141 | """ 142 | def __init__( 143 | self, 144 | model_name_or_path: str, 145 | embedding_strategy: PoolingStrategy | str = PoolingStrategy.mean, 146 | freeze_pos_emb: bool = True, 147 | add_scaling_layer: bool = False, 148 | use_mrl: bool = False, 149 | extend_pe: bool = False, 150 | max_length: int = 512, 151 | ): 152 | pretrained_model = load_hf_pretrained_model(model_name_or_path) 153 | embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)](pretrained_model) 154 | super().__init__(embedder) 155 | self.retri_contrst_loss = build_loss('retri_contrast', temperature=0.01, use_all_pair=True) 156 | self.cosent_loss = build_loss('cosent', temperature=0.05) 157 | self.cls_contrast_loss = build_loss('cls_contrast', temperature=0.05) 158 | self.use_mrl = use_mrl 159 | self.add_scaling_layer = add_scaling_layer 160 | 161 | if add_scaling_layer: 162 | ''' 163 | Here we hard code the scaling layer pretrain path, input_dim and output_dim, you can modify it by yourself 164 | ''' 165 | self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792) 166 | if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')): 167 | scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')) 168 | self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True) 169 | print('load scaling layer successfully') 170 | else: 171 | print('not found pretrain, random init scaling layer') 172 | 173 | if use_mrl: 174 | self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792] # hard code here 175 | 176 | if extend_pe: 177 | sp = 0 # TODO, hard code here, for xlm roberta, this should be 2 178 | # re-init the position embeddings 179 | org_pe = self.embedder.encoder.embeddings.position_embeddings 180 | pad_idx = self.embedder.encoder.embeddings.position_embeddings.padding_idx 181 | extended_pe = torch.nn.Embedding(max_length + sp, org_pe.embedding_dim, padding_idx=pad_idx) 182 | for start_idx in range(0, max_length + sp, org_pe.num_embeddings): # 迭代式地去复制,从而扩增embedding 183 | end_idx = min(start_idx + org_pe.num_embeddings, max_length + sp) 184 | extended_pe.weight.data[start_idx : end_idx] = org_pe.weight.data[:end_idx - start_idx].clone() 185 | self.embedder.encoder.embeddings.position_embeddings = extended_pe 186 | self.embedder.encoder.embeddings.position_ids = torch.arange(max_length + sp).expand((1, -1)) 187 | self.embedder.encoder.embeddings.token_type_ids = \ 188 | torch.zeros(self.embedder.encoder.embeddings.position_ids.size(), dtype=torch.long) 189 | self.embedder.encoder.config.max_position_embeddings = max_length + sp 190 | 191 | if not extend_pe and freeze_pos_emb: # extend pe时, 不能 freeze pos emb 192 | for name, param in self.embedder.encoder.embeddings.named_parameters(): 193 | if "position_embeddings" in name: 194 | param.requires_grad = False 195 | 196 | def get_embedding(self, text_ids): 197 | if text_ids is None: 198 | return None 199 | text_embeddings = self.embedder(text_ids) 200 | if self.add_scaling_layer: 201 | text_embeddings = self.scaling_layer(text_embeddings.half()).float() 202 | return text_embeddings 203 | 204 | def compute_cls_loss(self, text_ids: torch.Tensor, text_labels: torch.tensor): 205 | text_embeddings = self.get_embedding(text_ids) 206 | pred_cls = self.cls_head(text_embeddings.half()) 207 | loss = torch.nn.functional.cross_entropy(pred_cls, text_labels) 208 | return {'loss': loss} 209 | 210 | def compute_cls_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor, 211 | 
text_neg_ids: torch.Tensor = None, type: str = 'cls_contrast') -> dict[str, torch.Tensor]: 212 | text_embeddings = self.get_embedding(text_ids) 213 | text_pos_embeddings = self.get_embedding(text_pos_ids) 214 | text_neg_embeddings = self.get_embedding(text_neg_ids) 215 | 216 | if self.use_mrl: 217 | loss = torch.tensor(0.0, device=text_embeddings.device) 218 | for num_feat in self.mrl_nesting_list: 219 | emb, pos_emb, neg_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat], text_neg_embeddings[..., :num_feat] 220 | loss += self.cls_contrast_loss(emb, pos_emb, neg_emb) / len(self.mrl_nesting_list) 221 | else: 222 | loss = self.cls_contrast_loss(text_embeddings, text_pos_embeddings, text_neg_embeddings) 223 | print('cls contrast loss: ', loss) 224 | return {'loss': loss} 225 | 226 | def compute_retri_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor, text_neg_ids: torch.Tensor = None, 227 | type: str = 'retri_contrast', **kwargs) -> dict[str, torch.Tensor]: 228 | text_embeddings = self.get_embedding(text_ids) 229 | text_pos_embeddings = self.get_embedding(text_pos_ids) 230 | text_neg_embeddings = self.get_embedding(text_neg_ids) 231 | 232 | if self.use_mrl: 233 | loss = torch.tensor(0.0, device=text_embeddings.device) 234 | for num_feat in self.mrl_nesting_list: 235 | if text_neg_embeddings is not None: 236 | emb, pos_emb, neg_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat], text_neg_embeddings[..., :num_feat] 237 | else: 238 | emb, pos_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat] 239 | neg_emb = None 240 | loss += self.retri_contrst_loss(emb, pos_emb, neg_emb, **kwargs) / len(self.mrl_nesting_list) 241 | else: 242 | loss = self.retri_contrst_loss(text_embeddings, text_pos_embeddings, text_neg_embeddings, **kwargs) 243 | print('triplet loss: ', loss) 244 | return {'loss': loss} 245 | 246 | def compute_cosent_loss(self, text_ids: torch.Tensor, text_pair_ids: torch.Tensor, labels: torch.Tensor, type: str = 'cosent'): 247 | text_embeddings = self.get_embedding(text_ids) 248 | text_pair_embeddings = self.get_embedding(text_pair_ids) 249 | if self.use_mrl: 250 | loss = torch.tensor(0.0, device=text_embeddings.device) 251 | for num_feat in self.mrl_nesting_list: 252 | emb, emb_pair = text_embeddings[..., :num_feat], text_pair_embeddings[..., :num_feat] 253 | predict_labels = torch.cosine_similarity(emb, emb_pair, dim=-1) 254 | loss += self.cosent_loss(predict_labels, labels) / len(self.mrl_nesting_list) 255 | else: 256 | predict_labels = torch.cosine_similarity(text_embeddings, text_pair_embeddings, dim=-1) 257 | loss = self.cosent_loss(predict_labels, labels) 258 | print('cosent loss: ', loss) 259 | return {'loss': loss, 'predict_labels': predict_labels} 260 | 261 | 262 | def forward(self, **kwargs): 263 | if kwargs['type'] == 'cls_contrast': 264 | return self.compute_cls_contrast_loss(**kwargs) 265 | elif kwargs['type'] == 'retri_contrast': 266 | return self.compute_retri_contrast_loss(**kwargs) 267 | elif kwargs['type'] == 'cosent': 268 | return self.compute_cosent_loss(**kwargs) 269 | else: 270 | raise NotImplementedError('not suuport current input kwargs') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.36.1 2 | datasets>=2.16.1 3 | torch>=2.1.1 4 | enum -------------------------------------------------------------------------------- 
/scripts/ft.sh: -------------------------------------------------------------------------------- 1 | ROOT=/mnt/lustre/huangjunqin/piccolo 2 | export PYTHONPATH=$ROOT:${PYTHONPATH} 3 | 4 | # SLURM Parameter 5 | GPUS_PER_NODE=1 6 | if [ -z "$WORLD_SIZE" ]; then 7 | WORLD_SIZE=1 8 | RANK=0 9 | MASTER_ADDR=127.0.0.1 10 | MASTER_PORT=6000 11 | fi 12 | 13 | # Hyper Parameter Start 14 | MODEL_NAME_OR_PATH=/mnt/lustre/huangjunqin/test/piccolo2-large-zh-0417 15 | EPOCHS=3 16 | BATCH_SIZE=8 17 | LR=1e-5 18 | NEG_NUM=1 19 | DS_PATH=$ROOT/ds_config_zero1.json 20 | MAX_LENGTH=512 21 | META_PATHS=( 22 | meta_lists/piccolo-ft.txt 23 | ) 24 | 25 | ROOT_DIRS=( 26 | data_example/ 27 | ) 28 | # Hyper Parameter End 29 | 30 | 31 | model_args=( 32 | "--model_name_or_path" $MODEL_NAME_OR_PATH 33 | "--max_length=$MAX_LENGTH" 34 | "--query_prefix=''" 35 | "--doc_prefix=''" 36 | "--use_scaling_layer=True" 37 | "--use_mrl=True" 38 | ) 39 | 40 | data_args=( 41 | "--meta_paths" "${META_PATHS[@]}" 42 | "--root_dirs" "${ROOT_DIRS[@]}" 43 | "--neg_num=$NEG_NUM" 44 | ) 45 | 46 | train_args=( 47 | "--fp16" 48 | "--gradient_checkpointing=True" 49 | "--output_dir=output_test" 50 | "--num_train_epochs=$EPOCHS" 51 | "--dataloader_num_workers=0" 52 | "--batch_size=$BATCH_SIZE" 53 | "--learning_rate=$LR" 54 | "--deepspeed=$DS_PATH" 55 | "--logging_steps=500" 56 | "--save_safetensors=False" 57 | "--report_to=tensorboard" 58 | "--save_strategy=epoch" 59 | "--per_device_train_batch_size=1" 60 | ) 61 | 62 | all_args=("${model_args[@]}" "${data_args[@]}" "${train_args[@]}") 63 | 64 | 65 | export LAUNCHER="python -u -m torch.distributed.run \ 66 | --nproc_per_node $GPUS_PER_NODE \ 67 | --nnodes $WORLD_SIZE \ 68 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ 69 | --rdzv_backend c10d \ 70 | --max_restarts 0 \ 71 | --tee 3 \ 72 | " 73 | 74 | export CMD=" \ 75 | $ROOT/finetune/train.py \ 76 | " 77 | 78 | echo $CMD 79 | 80 | bash -c "$LAUNCHER $CMD ${all_args[*]}" --------------------------------------------------------------------------------
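Note: since `STETrainer._save` exports checkpoints in sentence-transformers format, a trained model can be loaded for inference roughly as follows (a minimal sketch; `output_test` is the `output_dir` set in `scripts/ft.sh`, replace it with your own path and texts):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('output_test')  # directory written by STETrainer
embeddings = model.encode(['如何更换花呗绑定银行卡', '花呗更改绑定银行卡'])
print(embeddings.shape)
```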