├── README.md
├── README_zh.md
├── assets
│   └── cmteb-0505.png
├── data_example
│   ├── cls.dataset
│   │   ├── data-00000-of-00001.arrow
│   │   ├── dataset_info.json
│   │   └── state.json
│   ├── retri.dataset
│   │   ├── data-00000-of-00001.arrow
│   │   ├── dataset_info.json
│   │   └── state.json
│   └── sts.dataset
│       ├── data-00000-of-00001.arrow
│       ├── dataset_info.json
│       └── state.json
├── ds_config_zero1.json
├── finetune
│   └── train.py
├── meta_lists
│   └── piccolo-ft.txt
├── piccolo
│   ├── __pycache__
│   │   ├── arguments.cpython-310.pyc
│   │   ├── criteria.cpython-310.pyc
│   │   ├── data.cpython-310.pyc
│   │   ├── data_structures.cpython-310.pyc
│   │   └── model.cpython-310.pyc
│   ├── arguments.py
│   ├── criteria.py
│   ├── data.py
│   ├── data_structures.py
│   └── model.py
├── requirements.txt
└── scripts
    └── ft.sh
/README.md:
--------------------------------------------------------------------------------
1 | [EN](README.md) | [简体中文](README_zh.md)
2 |
3 | # [Piccolo2: General Text Embeddings with Multi-Task Hybrid Loss Training](https://arxiv.org/abs/2405.06932)
4 |
5 | 🚀 **New SOTA on CMTEB**
6 |
7 | 🔥 Our general sentence embedding [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) achieves SOTA on the CMTEB Leaderboard with an average score of 70.95. [2024/4/23]
8 |
9 |
10 | 📄 Results on CMTEB Leaderboard:
11 |
12 | ![CMTEB results](assets/cmteb-0505.png)
13 |
14 |
15 |
16 | ## 💡Model Highlights
17 | Here we introduce Piccolo2, an embedding model that surpasses other models in a comprehensive evaluation over 6 tasks on the CMTEB benchmark, setting a new state-of-the-art. Piccolo2 primarily leverages an efficient multi-task hybrid loss training approach, effectively harnessing textual data and labels from diverse downstream tasks. In addition, Piccolo2 scales up the embedding dimension and uses MRL training to support more flexible vector dimensions.
18 |
19 | For the Hugging Face models, please refer to our space: https://huggingface.co/sensenova
20 | For training details, please refer to our tech report: https://arxiv.org/abs/2405.06932
21 |
22 | ## 📖 Repo Details
23 | In this repo, we release our training code and implement some tricks that are helpful for embedding training, including:
24 | - Multi-task Hybrid Loss Training
25 | - Matryoshka Representation Learning
26 | - Embedding Dimension Scaling
27 | - Task-Homogeneous Dataset
28 | - Position Embedding Hierarchical Decomposition
29 |
30 | To save memory, we use DeepSpeed ZeRO-1, gradient checkpointing, and mixed-precision training by default. We also provide SLURM scripts to help conduct distributed training.
31 |
32 | ### Tips
33 | 1. The repository will default to training using a multi-task hybrid loss, which is described in our technical report.
34 |
35 | 2. For embedding dimension scaling, we hard-code the pretrained weight path, input dim, and output dim in the code:
36 | ```python
37 | self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792)
38 | if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')):
39 | scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin'))
40 | self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True)
41 | ```
42 | 3. For MRL training, we hard-code the nesting list. Feel free to modify it yourself.
43 | ```python
44 | self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792]
45 | ```
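At inference time, an embedding trained with MRL can simply be truncated to any dimension in the nesting list and re-normalized. A minimal sketch (the helper name is ours; it assumes a matrix of full-dimension embeddings):
```python
import numpy as np

def truncate_embeddings(embeddings: np.ndarray, dim: int = 512) -> np.ndarray:
    """Keep the first `dim` features (a nested MRL prefix) and re-normalize."""
    prefix = embeddings[:, :dim]
    return prefix / np.linalg.norm(prefix, axis=1, keepdims=True)
```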
46 |
47 | 4. If you want to increase the length of the position embedding, set `extend_pe` to True and then set `max_length` to your expected length in the scripts.
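Under the hood these flags map to the `extend_pe` / `max_length` arguments of `STEmbedder` in `piccolo/model.py`. A minimal sketch (the model path and the 1024 target length are placeholder examples):
```python
from piccolo.model import STEmbedder

# Hierarchical decomposition: the original position-embedding table is tiled up to max_length.
model = STEmbedder(
    model_name_or_path='/path/to/pretrained/model',  # placeholder path
    extend_pe=True,
    max_length=1024,  # example target length
)
```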
48 |
49 | ## 🔨 Best Practice
50 | ### 1. Environment
51 | ```shell
52 | pip install -r requirements.txt
53 | ```
54 |
55 | ### 2. Prepare Dataset
56 | We divide the datasets into three major categories: retrieval/reranking, clustering/classification, and sts/pair-classification, and we adopt different loss functions for each category. In the `data_example` dir, we provide examples of these three types of data.
57 |
58 | 1) `Retri`: For Retrieval and Reranking datasets, we implement standard InfoNCE with in-batch negatives. This dataset contains 4 columns: `text`, `text_pos`, `text_neg`, `type`. Here `type` is `retri_contrast`.
59 |
60 | 2) `STS`: For STS and Pair-Classification datasets, we implement the CoSent loss. This dataset contains 4 columns: `text`, `text_pair`, `label`, `type`. Here `type` is `cosent`.
61 |
62 | 3) `Cls`: For Classification and Clustering datasets, we implement InfoNCE w/o in-batch negatives. This dataset contains 4 columns: `text`, `text_pos`, `text_neg`, `type`. Here `type` is `cls_contrast`.
63 |
64 | The `type` column indicates the category of the current dataset; during training we read the type of the current batch and apply the corresponding loss function, as sketched below.
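As a minimal sketch of how such a dataset can be built and saved with the 🤗 `datasets` library (the toy texts and the `my_retri.dataset` path are illustrative only; see `data_example/` for real examples):
```python
from datasets import Dataset

# One toy retrieval-style row: a query, one positive, a list of negatives, and the loss-type tag.
records = {
    'text': ['what is the capital of France?'],
    'text_pos': ['Paris is the capital of France.'],
    'text_neg': [['Berlin is the capital of Germany.', 'Madrid is the capital of Spain.']],
    'type': ['retri_contrast'],
}
Dataset.from_dict(records).save_to_disk('data_example/my_retri.dataset')
```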
65 |
66 | ### 3. Training
67 | We provide the training script at `scripts/ft.sh`. Below are explanations of the variables used within the script.
68 |
69 | **Environment Variables**
70 | - ROOT: The absolute path of the repo on your machine.
71 | - GPUS_PER_NODE: Number of GPUs on a single machine.
72 | The provided script defaults to single-machine training. If you are in a multi-node distributed training scenario, you should additionally specify the environment variables below:
73 | - WORLD_SIZE: Number of nodes
74 | - RANK: Rank of the current node; it is usually obtained through SLURM environment variables
75 | - MASTER_ADDR / MASTER_PORT: Address and port used for communication
76 |
77 | **Training Variables**
78 | - MODEL_NAME_OR_PATH: the absolute path of the pretrained model
79 | - DS_PATH: the deepspeed config; we provide a default config in `./ds_config_zero1.json`
80 | - META_PATHS: the list of datasets. We provide a sample in `meta_lists/piccolo-ft.txt`. Each row consists of two columns: the relative path of the dataset and the number of repeats.
81 | - ROOT_DIRS: the absolute path of the dataset directory
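For reference, the provided `meta_lists/piccolo-ft.txt`, used together with `ROOT_DIRS=data_example/`, looks like this:
```
cls.dataset 1
retri.dataset 1
sts.dataset 1
```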
82 |
83 | **Run**
84 | ```shell
85 | bash scripts/ft.sh
86 | ```
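After training, `STETrainer._save` exports the checkpoint in sentence-transformers format, so it can be loaded directly for inference. A minimal sketch, assuming the default `output_test` output directory from `scripts/ft.sh` (the example sentences are arbitrary):
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('output_test')  # matches --output_dir in scripts/ft.sh
embeddings = model.encode(['样例句子一', '样例句子二'], normalize_embeddings=True)
print(embeddings.shape)  # (2, 1792) when the default scaling layer is used
```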
87 |
88 | ## 🤗 **Model List**
89 | | Model | Language | Description | prompt |
90 | |:-|:-:|:-|:-:|
91 | | [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) | Chinese | version 2: fine-tuned with multi-task hybrid loss training | None |
92 | | [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh) | Chinese | version 1: pretrained on 400 million Chinese text pairs | '查询'/'结果' |
93 | | [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh) | Chinese | version 1: pretrained on 400 million Chinese text pairs | '查询'/'结果' |
94 |
95 |
96 | ## Citation
97 | If you find our tech report, models, or code helpful, please cite our report or give us a star on GitHub or Hugging Face!
98 | ```bibtex
99 | @misc{2405.06932,
100 | Author = {Junqin Huang and Zhongjie Hu and Zihao Jing and Mengya Gao and Yichao Wu},
101 | Title = {Piccolo2: General Text Embedding with Multi-task Hybrid Loss Training},
102 | Year = {2024},
103 | Eprint = {arXiv:2405.06932},
104 | }
105 | ```
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | [EN](README.md) | [简体中文](README_zh.md)
2 |
3 | # [Piccolo2: General Text Embeddings with Multi-Task Hybrid Loss Training](https://arxiv.org/abs/2405.06932)
4 |
5 | 🚀 **New SOTA on CMTEB**
6 |
7 | 🔥 我们最新的通用embedding模型 [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) 在CMTEB评测榜单上取得了70.95的均分! [2024/4/23]
8 |
9 |
10 | 📄 CMTEB结果:
11 |
12 | ![CMTEB结果](assets/cmteb-0505.png)
13 |
14 |
15 |
16 | ## 💡Model Highlights
17 | Piccolo2 在CMTEB榜单上的6项任务的综合评估中超越了其他模型,目前位于第一位。Piccolo2 主要利用高效的多任务混合损失训练方法,有效地利用来自不同下游任务的文本数据和标签。 此外,Piccolo2 扩大了Embedding维度,并使用MRL训练来支持更灵活的向量维度。
18 | huggingface上放了我们最新的模型: https://huggingface.co/sensenova
19 | 对于训练细节,可以参考我们的技术报告: https://arxiv.org/abs/2405.06932
20 |
21 | ## 📖 Repo Details
22 | 在这个repo里面,我们放出了训练的代码,里面提供了一些有助于提升Embedding模型性能的训练技巧:
23 | - Multi-task Hybrid Loss Training
24 | - Matryoshka Representation Learning
25 | - Embedding Dimension Scaling
26 | - Task-Homogeneous Dataset
27 | - Position Embedding Hierarchical Decomposition
28 |
29 | 为了节省内存,我们默认使用 deepspeed-zero1、gradient checkpointing 和 mixed-precision 进行训练。我们还提供了 SLURM 脚本用来帮助大家进行分布式训练。
30 |
31 | ### Tips
32 | 1. 该项目会默认使用multi-task hybrid loss来进行训练,前提是数据按规定格式准备好。
33 |
34 | 2. 对于embedding dimension scaling, 我们将它的一些参数(预训练权重路径、输入输出维度)写死在了代码里面, 请根据需要自行更改:
35 | ```python
36 | self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792)
37 | if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')):
38 | scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin'))
39 | self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True)
40 | ```
41 | 3. 对于MRL训练, 我们也把它的参数写死在了代码里.
42 | ```python
43 | self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792]
44 | ```
45 |
46 | 4. 如果你想增长模型的position embedding,我们实现了一个简单的层次分解方法, 只需要将 `extend_pe` 设置为True 然后把 `max_length` 设置为你的目标长度就可以了.
47 |
48 |
49 | ## 🔨 使用指南
50 | ### 1. 环境
51 | ```shell
52 | pip install -r requirements.txt
53 | ```
54 |
55 | ### 2. 数据准备
56 | 我们将数据集分为三大类:检索/排序、聚类/分类、句对相似度/句对分类,并对不同类别采用不同的损失函数。 在`data_example`目录中,我们提供了这三种类型数据的示例。
57 |
58 | 1) `Retri`: 对于检索、重排数据集, 我们采用标准的InfoNCE进行优化,同时采用in-batch-negative来扩充负样本数量。这个数据集需要四列: `text`, `text_pos`, `text_neg`, `type`. 这里的 'type' 需要被标注为 'retri_contrast'。
59 |
60 | 2) `STS`: 对于句对相似度,句对分类数据集, 我们采用了cosent这种排序损失, 这个数据集同样有四列: `text`, `text_pair`, `label`, `type`. 这里 'type' 需要被标注为 'cosent'。
61 |
62 | 3) `Cls`: 对于分类、聚类任务, 我们将文本和它的语义标签视为正负样本对,同样采用了InfoNCE损失来优化,但不再采样in-batch-negative(因为这很容易造成训练冲突),这类数据集同样包含四列: `text`, `text_pos`, `text_neg`, `type`. 这里 'type' 需要被标注为 'cls_contrast'
63 |
64 | 'type' 列表明了当前数据的类型,在训练的时候,我们通过获取当前数据的类型,以采用不同的损失进行优化。
65 |
66 | ### 3. 训练
67 | 我们提供了训练的脚本 `scripts/ft.sh`. 下面我们对这个脚本里的一些变量做了简单的解释.
68 |
69 | **环境参数**
70 | - ROOT: 为该项目在本地机器上的绝对路径.
71 | - GPUS_PER_NODE: 单机的GPU数量
72 | 默认提供的脚本是在单机下进行训练的,如果你是在多机多卡下进行分布式训练,你需要额外填上:
73 | - WORLD_SIZE: 机器的数量
74 | - RANK: 当前机器的Rank序号,通常需要从SLURM的环境变量获取
75 | - MASTER_ADDR / MASTER_PORT: 通信地址和端口
76 |
77 |
78 | **训练参数**
79 | - MODEL_NAME_OR_PATH: pretrain model的绝对路径。
80 | - DS_PATH: deepspeed参数, 默认的config放在了 `./ds_config_zero1.json`。
81 | - META_PATHS: 使用的数据集列表. 我们提供了一个样例 `meta_lists/piccolo-ft.txt`. 该txt文件的每一行有两列,第一列是数据集的相对路径,第二列是数据集的repeat次数。
82 | - ROOT_DIRS: 数据集的目录的绝对路径。
83 |
84 | **Run**
85 | ```shell
86 | bash scripts/ft.sh
87 | ```
88 |
89 | ## 🤗 **Model List**
90 | | Model | 语言 | 简介 | prompt |
91 | |:-|:-:|:-|:-:|
92 | | [sensenova/piccolo-large-zh-v2](https://huggingface.co/sensenova/piccolo-large-zh-v2) | Chinese | version2: 采用了多任务混合损失进行训练 | None |
93 | | [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh) | Chinese | version1: 使用4亿的中文pair对进行预训练 | '查询'/'结果' |
94 | | [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh) | Chinese | version1: 使用4亿的中文pair对进行预训练 | '查询'/'结果' |
95 |
96 |
97 | ## Citation
98 | 如果您认为我们的技术报告、模型、代码有帮助,请像下面这样引用我们的文章,或者在github和huggingface上点一个免费的赞!
99 | ```bibtex
100 | @misc{2405.06932,
101 | Author = {Junqin Huang and Zhongjie Hu and Zihao Jing and Mengya Gao and Yichao Wu},
102 | Title = {Piccolo2: General Text Embedding with Multi-task Hybrid Loss Training},
103 | Year = {2024},
104 | Eprint = {arXiv:2405.06932},
105 | }
106 | ```
--------------------------------------------------------------------------------
/assets/cmteb-0505.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/assets/cmteb-0505.png
--------------------------------------------------------------------------------
/data_example/cls.dataset/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/data_example/cls.dataset/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/data_example/cls.dataset/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "citation": "",
3 | "description": "",
4 | "features": {
5 | "text": {
6 | "dtype": "string",
7 | "_type": "Value"
8 | },
9 | "text_pos": {
10 | "dtype": "string",
11 | "_type": "Value"
12 | },
13 | "text_neg": {
14 | "feature": {
15 | "dtype": "string",
16 | "_type": "Value"
17 | },
18 | "_type": "Sequence"
19 | },
20 | "type": {
21 | "dtype": "string",
22 | "_type": "Value"
23 | }
24 | },
25 | "homepage": "",
26 | "license": ""
27 | }
--------------------------------------------------------------------------------
/data_example/cls.dataset/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "885b3f7efdf053e4",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": null
13 | }
--------------------------------------------------------------------------------
/data_example/retri.dataset/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/data_example/retri.dataset/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/data_example/retri.dataset/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "citation": "",
3 | "description": "",
4 | "features": {
5 | "text": {
6 | "dtype": "string",
7 | "_type": "Value"
8 | },
9 | "text_pos": {
10 | "dtype": "string",
11 | "_type": "Value"
12 | },
13 | "text_neg": {
14 | "feature": {
15 | "dtype": "string",
16 | "_type": "Value"
17 | },
18 | "_type": "Sequence"
19 | },
20 | "type": {
21 | "dtype": "string",
22 | "_type": "Value"
23 | }
24 | },
25 | "homepage": "",
26 | "license": ""
27 | }
--------------------------------------------------------------------------------
/data_example/retri.dataset/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "7e61224bcfd17036",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": null
13 | }
--------------------------------------------------------------------------------
/data_example/sts.dataset/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/data_example/sts.dataset/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/data_example/sts.dataset/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "citation": "",
3 | "description": "",
4 | "features": {
5 | "text": {
6 | "dtype": "string",
7 | "_type": "Value"
8 | },
9 | "text_pair": {
10 | "dtype": "string",
11 | "_type": "Value"
12 | },
13 | "label": {
14 | "dtype": "int64",
15 | "_type": "Value"
16 | },
17 | "type": {
18 | "dtype": "string",
19 | "_type": "Value"
20 | }
21 | },
22 | "homepage": "",
23 | "license": ""
24 | }
--------------------------------------------------------------------------------
/data_example/sts.dataset/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "6f5ce22b311c38ae",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": null
13 | }
--------------------------------------------------------------------------------
/ds_config_zero1.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 2e8,
6 | "overlap_comm": true,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 2e8,
9 | "contiguous_gradients": true
10 | },
11 | "train_micro_batch_size_per_gpu": "auto"
12 | }
--------------------------------------------------------------------------------
/finetune/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import torch
4 | from pathlib import Path
5 | from transformers.trainer import logger, Optional
6 | from datasets import Dataset as HfDataset
7 | from datasets import concatenate_datasets, load_from_disk
8 | from dataclasses import asdict
9 | from transformers import AutoTokenizer, HfArgumentParser, TrainerCallback # type: ignore
10 | from transformers.trainer import Trainer
11 |
12 | from piccolo.arguments import ModelArguments, DataArguments, STETrainingArguments
13 | from piccolo.data import DatsetWithInfo, UniCollator, UniDataset
14 | from piccolo.model import STEmbedder
15 | from tqdm import tqdm
16 |
17 | def load_all_datasets(meta_paths, root_dirs, query_prefix, doc_prefix) -> list[DatsetWithInfo]:
18 | all_datasets = []
19 | for meta_path, root_dir in zip(meta_paths, root_dirs):
20 | CNT = 0
21 | meta_file = open(meta_path, 'r')
22 | for line in tqdm(meta_file.readlines()):
23 | dataset_name, repeat_num = line.strip().split(' ')
24 | dataset_dict = load_from_disk(str(os.path.join(root_dir, dataset_name)))
25 | if isinstance(dataset_dict, dict):
26 | dataset: HfDataset = concatenate_datasets(list(dataset_dict.values()))
27 | else:
28 | dataset = dataset_dict
29 | for idx in range(int(repeat_num)):
30 | all_datasets.append(
31 | DatsetWithInfo(hf_dataset=dataset, name=dataset_name + '_{}'.format(idx),
32 | query_prefix=query_prefix, passage_prefix=doc_prefix)
33 | )
34 | CNT += 1
35 | print('loading {} datasets from path: {}'.format(CNT, meta_path))
36 | return all_datasets
37 |
38 | class MyCallback(TrainerCallback):
39 | def on_epoch_end(self, args, state, control, train_dataloader, **kwargs):
40 | train_dataloader.dataset.create_or_refresh_data()
41 |
42 |
43 | class STETrainer(Trainer):
44 | def __init__(self, efficient_save, **kwargs):
45 | super().__init__(**kwargs)
46 | self.efficient_save = efficient_save
47 |
48 | def save_ckpt_for_sentence_transformers(self, tmp_dir, output_dir, pooling_mode: str = 'mean'):
49 | '''convert to sentence transformer format'''
50 | import shutil
51 | from sentence_transformers import models, SentenceTransformer
52 | word_embedding_model = models.Transformer(tmp_dir)
53 |         pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)  # respect the pooling mode passed by the caller
54 | if os.path.exists(os.path.join(tmp_dir, 'scaling_layer.bin')):
55 | state_dict = torch.load(os.path.join(tmp_dir, 'scaling_layer.bin'))
56 | in_features, out_features = state_dict['linear.weight'].shape[1], state_dict['linear.weight'].shape[0]
57 | scaling_layer = models.Dense(in_features, out_features, bias=True, activation_function=torch.nn.modules.linear.Identity())
58 | scaling_layer.load_state_dict(state_dict, strict=True)
59 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model, scaling_layer], device='cpu')
60 | else:
61 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
62 | model.save(output_dir, safe_serialization=False)
63 | shutil.rmtree(tmp_dir)
64 |
65 | def _save(self, output_dir: Optional[str] = None, **kwargs):
66 | '''save the unwrap model'''
67 |
68 | output_dir = output_dir if output_dir is not None else self.args.output_dir
69 | os.makedirs(output_dir, exist_ok=True)
70 |
71 | logger.info("Saving model checkpoint to %s", output_dir)
72 | unwrap_model = self.model.embedder.encoder
73 | if self.is_world_process_zero():
74 | # first saves to the tmp dir, then converts to sentence-transformer
75 | tmp_dir = output_dir + '-tmp'
76 | unwrap_model.save_pretrained(tmp_dir, safe_serialization=self.args.save_safetensors)
77 | self.tokenizer.save_pretrained(tmp_dir)
78 | if hasattr(self.model, 'scaling_layer'):
79 | scaling_layer = {'linear.weight': self.model.scaling_layer.state_dict()['linear.weight'].data.cpu(),
80 | 'linear.bias': self.model.scaling_layer.state_dict()['linear.bias'].data.cpu()}
81 | torch.save(scaling_layer, os.path.join(tmp_dir, 'scaling_layer.bin'))
82 | self.save_ckpt_for_sentence_transformers(tmp_dir, output_dir, self.model.embedder.pooling_strategy.value)
83 |
84 | def _save_checkpoint(self, model, trial, metrics=None):
85 | if self.efficient_save:
86 | '''only save the model ckpt weights to save disk mem'''
87 | from transformers.trainer import PREFIX_CHECKPOINT_DIR
88 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
89 | run_dir = self._get_output_dir(trial=trial)
90 | output_dir = os.path.join(run_dir, checkpoint_folder)
91 | self.save_model(output_dir, _internal_call=True)
92 | else:
93 | super()._save_checkpoint(model, trial, metrics)
94 |
95 |
96 | def main():
97 | parser = HfArgumentParser((ModelArguments, DataArguments, STETrainingArguments))
98 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
99 | model_args: ModelArguments
100 | data_args: DataArguments
101 | training_args: STETrainingArguments
102 |
103 | # DataLoader
104 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
105 | all_datasets = load_all_datasets(data_args.meta_paths, data_args.root_dirs, data_args.query_prefix, data_args.doc_prefix)
106 | train_dataset = UniDataset(all_datasets, batch_size=data_args.batch_size, neg_num=data_args.neg_num)
107 | data_collator = UniCollator(tokenizer=tokenizer, max_length=model_args.max_length)
108 |
109 | # Model
110 | model = STEmbedder(
111 | model_name_or_path=model_args.model_name_or_path,
112 | embedding_strategy=model_args.embedding_strategy,
113 | freeze_pos_emb=True,
114 | add_scaling_layer=model_args.use_scaling_layer,
115 | use_mrl=model_args.use_mrl,
116 | extend_pe=model_args.extend_pe,
117 | max_length=model_args.max_length
118 | )
119 | model.embedder.encoder.config.pad_token_id = tokenizer.pad_token_id
120 |
121 | # Trainer
122 | trainer = STETrainer(
123 | model=model,
124 | args=training_args,
125 | train_dataset=train_dataset,
126 | data_collator=data_collator,
127 | tokenizer=tokenizer,
128 | callbacks=[MyCallback],
129 | efficient_save=training_args.efficient_save
130 | )
131 |
132 | # save training info
133 | if trainer.is_world_process_zero():
134 | Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
135 | Path(os.path.join(training_args.output_dir, 'parameters')).mkdir(parents=True, exist_ok=True)
136 | ## save data list info
137 | meta_paths = data_args.meta_paths
138 | with open(os.path.join(training_args.output_dir, 'parameters','data.list'), 'w') as f:
139 | for meta_path in meta_paths:
140 | f.writelines(f'list_name: {meta_path} \n')
141 | f.writelines(open(meta_path, 'r').readlines())
142 | f.writelines('\n\n')
143 |
144 | trainer.train()
145 |
146 | # save parameter and model at the end
147 | if trainer.is_world_process_zero():
148 | trainer.save_model(training_args.output_dir, _internal_call=True)
149 | ## save parameter
150 | parameter_dict = {'model_args': asdict(model_args), 'data_args': asdict(data_args), 'train_args': asdict(training_args)}
151 | Path(os.path.join(training_args.output_dir, 'parameters')).mkdir(parents=True, exist_ok=True)
152 | with open(os.path.join(training_args.output_dir, 'parameters', 'param.yaml'), 'w') as yaml_file:
153 | yaml.dump(parameter_dict, yaml_file)
154 |
155 |
156 |
157 | if __name__ == "__main__":
158 | main()
159 |
--------------------------------------------------------------------------------
/meta_lists/piccolo-ft.txt:
--------------------------------------------------------------------------------
1 | cls.dataset 1
2 | retri.dataset 1
3 | sts.dataset 1
--------------------------------------------------------------------------------
/piccolo/__pycache__/arguments.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/arguments.cpython-310.pyc
--------------------------------------------------------------------------------
/piccolo/__pycache__/criteria.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/criteria.cpython-310.pyc
--------------------------------------------------------------------------------
/piccolo/__pycache__/data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/data.cpython-310.pyc
--------------------------------------------------------------------------------
/piccolo/__pycache__/data_structures.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/data_structures.cpython-310.pyc
--------------------------------------------------------------------------------
/piccolo/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSenseNova/piccolo-embedding/9dd1d66cd47981b462d76a6040507663e87ee143/piccolo/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/piccolo/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from transformers import TrainingArguments
3 | from piccolo.model import PoolingStrategy
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 | model_name_or_path: str = field() # must require
11 | embedding_strategy: PoolingStrategy = field(default=PoolingStrategy.mean)
12 | extend_pe: bool = field(default=False)
13 | max_length: int = field(default=512)
14 | # scaling layer and mrl Training
15 | use_scaling_layer: bool = field(default=False)
16 | use_mrl: bool = field(default=False)
17 |
18 | @dataclass
19 | class DataArguments:
20 | # train data
21 | meta_paths: list[str] = field() # must require
22 | root_dirs: list[str] = field() # must require
23 | batch_size: int = field(default=16)
24 | query_prefix: str = field(default='')
25 | doc_prefix: str = field(default='')
26 | # hard neg
27 | neg_num: int = field(default=1) # only affects retri_contrast_loss
28 |
29 | @dataclass
30 | class STETrainingArguments(TrainingArguments):
31 | efficient_save: bool = field(default=True)
32 |
--------------------------------------------------------------------------------
/piccolo/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class ContrastLoss(torch.nn.Module):
4 | def __init__(self, temperature: float = 0.05):
5 | super().__init__()
6 | self.temperature = temperature
7 |
8 |
9 | class RetriContrastLoss(ContrastLoss):
10 | '''
11 | loss for retrieval
12 |     if use_all_pair is set to True, in-batch positives of other queries are also used as negatives;
13 |     otherwise only the explicit negative passages in the batch are used as negatives
14 | '''
15 | def __init__(self, temperature: float = 0.05, use_all_pair=False):
16 | super().__init__()
17 | self.temperature = temperature
18 | self._cross_entropy_loss = torch.nn.CrossEntropyLoss()
19 | self.use_all_pair=use_all_pair
20 |
21 | def forward(
22 | self,
23 | text_embeddings: torch.Tensor,
24 | text_pos_embeddings: torch.Tensor,
25 | text_neg_embeddings: torch.Tensor | None,
26 | text_neg_index: torch.Tensor | None,
27 | ) -> torch.Tensor:
28 | if text_neg_embeddings is None:
29 | sim_matrix = torch.cosine_similarity(
30 | text_embeddings.unsqueeze(1),
31 | text_pos_embeddings.unsqueeze(0),
32 | dim=-1,
33 | )
34 | sim_matrix = sim_matrix / self.temperature
35 | labels = torch.arange(sim_matrix.size(0), device=text_embeddings.device, dtype=torch.long)
36 | loss = self._cross_entropy_loss(sim_matrix, labels)
37 | return loss
38 |
39 | sim_neg_matrix = torch.cosine_similarity(
40 | text_embeddings.unsqueeze(1),
41 | text_neg_embeddings.unsqueeze(0),
42 | dim=-1,
43 | )
44 | if self.use_all_pair:
45 | sim_pos_matrix = torch.cosine_similarity(
46 | text_embeddings.unsqueeze(1),
47 | text_pos_embeddings.unsqueeze(0),
48 | dim=-1,
49 | )
50 | sim_matrix = torch.cat([sim_pos_matrix, sim_neg_matrix], dim=1)
51 | labels = torch.arange(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device)
52 | else:
53 | sim_pos_vector = torch.cosine_similarity(text_embeddings, text_pos_embeddings, dim=-1)
54 | sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1)
55 | labels = torch.zeros(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device)
56 | sim_matrix = sim_matrix / self.temperature
57 | loss = self._cross_entropy_loss(sim_matrix, labels)
58 | return loss
59 |
60 |
61 | class CoSentLoss(ContrastLoss):
62 | '''
63 | loss for sts and pair classification.
64 |     here we hard-code the cosent loss weight to 0.04; the loss is cosent_w * log(1 + sum_{(i,j): y_i > y_j} exp((s_j - s_i) / t))
65 | '''
66 | bias: torch.Tensor
67 |
68 | def __init__(self, temperature: float = 0.05, cosent_w: float = 0.04) -> None:
69 | super().__init__(temperature)
70 | self.register_buffer('bias', torch.tensor([0.0]))
71 | self.cosent_w = cosent_w
72 |
73 | def forward(self, predict_similarity: torch.Tensor, true_similarity: torch.Tensor) -> torch.Tensor:
74 | predict_similarity = predict_similarity / self.temperature
75 | cosine_similarity_diff = -(predict_similarity.unsqueeze(0) - predict_similarity.unsqueeze(1))
76 | smaller_mask = true_similarity.unsqueeze(0) <= true_similarity.unsqueeze(1)
77 | cosine_similarity_diff = cosine_similarity_diff[~smaller_mask]
78 | cosine_diff_scores_add_bias = torch.cat((cosine_similarity_diff, self.bias))
79 | loss = torch.logsumexp(cosine_diff_scores_add_bias, dim=0) * self.cosent_w
80 | return loss
81 |
82 | class ClsContrastLoss(torch.nn.Module):
83 | '''
84 | loss for clustering and classification
85 | here we hard code the cls contrast loss weight to 0.2
86 | '''
87 | def __init__(self, temperature: float = 0.05, cls_w = 0.2):
88 | super().__init__()
89 | self.temperature = temperature
90 | self._cross_entropy_loss = torch.nn.CrossEntropyLoss()
91 | self.cls_w = cls_w
92 |
93 | def forward(
94 | self,
95 | text_embeddings: torch.Tensor,
96 | text_pos_embeddings: torch.Tensor,
97 | text_neg_embeddings: torch.Tensor,
98 | ) -> torch.Tensor:
99 | bs = text_embeddings.shape[0]
100 | assert text_neg_embeddings.shape[0] % bs == 0, 'neg num is not equal for each sample'
101 | neg_num = int(text_neg_embeddings.shape[0] // bs)
102 |
103 | sim_neg_matrix = torch.cosine_similarity(
104 | text_embeddings.unsqueeze(1),
105 | text_neg_embeddings.unsqueeze(0),
106 | dim=-1,
107 | )
108 | sim_pos_vector = torch.cosine_similarity(text_embeddings, text_pos_embeddings, dim=-1)
109 |
110 |         # find the negatives for each training sample
111 | neg_matrix = []
112 | for i in range(bs):
113 | neg_matrix.append(sim_neg_matrix[i, i * neg_num : (i + 1) * neg_num])
114 | sim_neg_matrix = torch.stack(neg_matrix)
115 | sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1)
116 | sim_matrix = sim_matrix / self.temperature
117 | labels = torch.zeros(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device)
118 | loss = self._cross_entropy_loss(sim_matrix, labels) * self.cls_w
119 | return loss
--------------------------------------------------------------------------------
/piccolo/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 | from typing import Any, Sequence, cast, TypeAlias
7 | from datasets import Dataset as HfDataset
8 | from torch.utils.data import Dataset, RandomSampler
9 |
10 | from piccolo.data_structures import PairRetriContrastRecord, PairScoredRecord, PairClsContrastRecord
11 | from transformers.tokenization_utils import PreTrainedTokenizer
12 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
13 |
14 | Tokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
15 |
16 |
17 | class UniCollator:
18 | '''
19 | Uni Data Collator
20 |
21 | for retrieval, sts, pair classification, the query max length is 64, doc max length is 512
22 | for clustering and classification, we specially set the query max length to 512,
23 |     because for clustering and classification tasks, the query ('text') is often much longer than the pos/neg ('label')
24 | '''
25 | def __init__(self, tokenizer: Tokenizer, max_length: int, q_max_length: int = 64) -> None:
26 | self.tokenizer = tokenizer
27 | self.max_length = max_length
28 | self.q_max_length = q_max_length
29 |
30 | def __call__(self, records: list) -> dict[str, torch.Tensor]:
31 | records = records[0]
32 | if isinstance(records[0], PairClsContrastRecord):
33 | texts = [record.text for record in records]
34 | texts_pos = [record.text_pos for record in records]
35 | texts_neg = []
36 | for i, record in enumerate(records):
37 | for neg in record.text_neg:
38 | texts_neg.append(neg)
39 | text_ids = self.tokenizer(texts, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids']
40 | text_pos_ids = self.tokenizer(texts_pos, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids']
41 | text_neg_ids = self.tokenizer(texts_neg, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids']
42 |
43 | return {
44 | 'text_ids': cast(torch.Tensor, text_ids),
45 | 'text_pos_ids': cast(torch.Tensor, text_pos_ids),
46 | 'text_neg_ids': cast(torch.Tensor, text_neg_ids),
47 | 'type': 'cls_contrast',
48 | }
49 | elif isinstance(records[0], PairRetriContrastRecord):
50 | texts = [record.text for record in records]
51 | texts_pos = [record.text_pos for record in records]
52 | texts_neg = []
53 |             texts_neg_index = []  # index indicating which text each negative sample belongs to
54 | for i, record in enumerate(records):
55 | for neg in record.text_neg:
56 | texts_neg.append(neg)
57 | texts_neg_index.append(i)
58 |
59 | text_ids = self.tokenizer(texts, padding=True, max_length=self.q_max_length, truncation=True, return_tensors='pt',)['input_ids']
60 | text_pos_ids = self.tokenizer(texts_pos, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids']
61 | if len(texts_neg) > 0:
62 | text_neg_ids = self.tokenizer(texts_neg, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids']
63 | else:
64 | text_neg_ids = None
65 | return {
66 | 'text_ids': cast(torch.Tensor, text_ids),
67 | 'text_pos_ids': cast(torch.Tensor, text_pos_ids),
68 | 'text_neg_ids': cast(torch.Tensor, text_neg_ids),
69 | 'text_neg_index': cast(torch.Tensor, texts_neg_index),
70 | 'type': 'retri_contrast',
71 | }
72 | elif isinstance(records[0], PairScoredRecord):
73 | texts = [record.text for record in records]
74 | texts_pair = [record.text_pair for record in records]
75 | labels = [record.label for record in records]
76 | labels = torch.tensor(labels, dtype=torch.float32)
77 |
78 | text_ids = self.tokenizer(texts, padding=True, max_length=self.q_max_length, truncation=True, return_tensors='pt',)['input_ids']
79 | text_pair_ids = self.tokenizer(texts_pair, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids']
80 |
81 | return {
82 | 'text_ids': cast(torch.Tensor, text_ids),
83 | 'text_pair_ids': cast(torch.Tensor, text_pair_ids),
84 | 'labels': labels,
85 | 'type': 'cosent',
86 | }
87 | else:
88 | raise NotImplementedError("only support pairscored and pairneg records")
89 |
90 |
91 | @dataclass
92 | class TaskBatchIndex:
93 | name: str
94 | batch_index: list[int]
95 |
96 |
97 | @dataclass
98 | class DatsetWithInfo:
99 | hf_dataset: HfDataset
100 | name: str
101 | query_prefix: str = ''
102 | passage_prefix: str = ''
103 |
104 |
105 | class UniDataset(Dataset):
106 | '''
107 |     Task-Homogeneous Dataset
108 |
109 | Code is modified from M3E: https://github.com/wangyuxinwhy/uniem
110 | It's also adopted in SFR Embedding: https://blog.salesforceairesearch.com/sfr-embedded-mistral
111 |
112 |     This technique ensures that the in-batch negative samples come from the same dataset,
113 |     which is generally believed to improve the quality of the in-batch negatives.
114 | '''
115 | def __init__(
116 | self,
117 | hf_datasets: list[DatsetWithInfo],
118 | neg_num: int = 1,
119 | batch_size: int = 32,
120 | max_samples: int | None = None,
121 | ):
122 | self.batch_size = batch_size
123 | self.hf_datasets = hf_datasets
124 | self.max_samples = max_samples
125 | self.name_dataset_map = {dataset.name: dataset.hf_dataset for dataset in hf_datasets}
126 | self.neg_num = neg_num
127 | self.query_prefix_map = {dataset.name: dataset.query_prefix for dataset in hf_datasets}
128 | self.passage_prefix_map = {dataset.name: dataset.passage_prefix for dataset in hf_datasets}
129 | self.create_or_refresh_data()
130 |
131 | def __len__(self):
132 | return len(self.task_batch_index_list)
133 |
134 | @staticmethod
135 | def is_valid_text(text: Any) -> bool:
136 | return isinstance(text, str) and bool(text.strip())
137 |
138 | def create_or_refresh_data(self):
139 | self.task_batch_index_list: list[TaskBatchIndex] = []
140 | for dataset in self.hf_datasets:
141 | max_samples = self.max_samples or len(dataset.hf_dataset)
142 | batch_size = self.batch_size
143 | num_samples = (max_samples // batch_size) * batch_size
144 | buffer = []
145 | for i in RandomSampler(dataset.hf_dataset, num_samples=num_samples):
146 | buffer.append(i)
147 | if len(buffer) == batch_size:
148 | self.task_batch_index_list.append(TaskBatchIndex(name=dataset.name, batch_index=buffer))
149 | buffer = []
150 | self.random_index_list = list(RandomSampler(self.task_batch_index_list))
151 |
152 |
153 | def get_pair_scored_records(self, records, task_name):
154 | pair_records = []
155 | for record in records:
156 | text = record['text']
157 | text_pair = record['text_pair']
158 | label = record['label']
159 | if not (self.is_valid_text(text) and self.is_valid_text(text_pair)):
160 | continue
161 | text = self.query_prefix_map[task_name] + text
162 | text_pair = self.passage_prefix_map[task_name] + text_pair
163 | pair_records.append(PairScoredRecord(text=text, text_pair=text_pair, label=label))
164 | return pair_records
165 |
166 | def get_pair_retri_contrast_records(self, records, task_name, hf_dataset, batch_index):
167 |
168 | def process_records(record):
169 | text = record['text']
170 | if isinstance(record['text_pos'], list): # random sample a positive
171 | assert len(record['text_pos']) >= 1, 'text pos num should be at least 1'
172 | text_pos = random.sample(record['text_pos'], 1)[0]
173 | else:
174 | text_pos = record['text_pos']
175 |
176 | if not (self.is_valid_text(text) and self.is_valid_text(text_pos)):
177 | # skip current sample and random sample an index
178 | random_index = random.sample(range(len(hf_dataset)), k=1)[0]
179 | while random_index in batch_index:
180 | random_index = random.sample(range(len(hf_dataset)), k=1)[0]
181 | return process_records(hf_dataset[random_index])
182 |
183 | text_neg = random.sample(record['text_neg'], min(self.neg_num, len(record['text_neg'])))
184 | text = self.query_prefix_map[task_name] + text
185 | text_pos = self.passage_prefix_map[task_name] + text_pos
186 | text_neg = [self.passage_prefix_map[task_name] + neg for neg in text_neg]
187 | return text, text_pos, text_neg
188 |
189 | pair_records = []
190 | for record in records:
191 | text, text_pos, text_neg = process_records(record)
192 | pair_records.append(PairRetriContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg))
193 | assert len(pair_records) == self.batch_size, 'error, current batch size not match !!!'
194 | return pair_records
195 |
196 | def get_pair_cls_contrast_records(self, records, task_name):
197 | pair_records = []
198 | for record in records:
199 | text, text_pos, text_neg = record['text'], record['text_pos'], record['text_neg']
200 | if isinstance(record['text_pos'], list):
201 | text_pos = random.sample(record['text_pos'], 1)[0]
202 | elif isinstance(record['text_pos'], str):
203 | text_pos = record['text_pos']
204 | else:
205 | assert False, 'type error'
206 | text_neg = random.sample(record['text_neg'], min(10, len(record['text_neg']))) # TODO, hard code the neg num to 10
207 | if self.is_valid_text(text) and self.is_valid_text(text_pos):
208 | pair_records.append(PairClsContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg))
209 | return pair_records
210 |
211 | def __getitem__(self, index: int):
212 | index = self.random_index_list[index]
213 | task_batch_index = self.task_batch_index_list[index]
214 | task_name = task_batch_index.name
215 | batch_index = task_batch_index.batch_index
216 |
217 | hf_dataset = self.name_dataset_map[task_name]
218 | records = [hf_dataset[i] for i in batch_index]
219 | if hf_dataset[0]['type'] == 'cls_contrast':
220 | pair_records = self.get_pair_cls_contrast_records(records, task_name)
221 | elif hf_dataset[0]['type'] == 'retri_contrast':
222 | pair_records = self.get_pair_retri_contrast_records(records, task_name, hf_dataset, batch_index)
223 | elif hf_dataset[0]['type'] == 'cosent':
224 | pair_records = self.get_pair_scored_records(records, task_name)
225 | else:
226 | raise NotImplementedError('only support pair contrast and pair scored')
227 |
228 | if not pair_records:
229 |             print('records is empty:', records)
230 | return self.__getitem__(index + 1)
231 | return pair_records
232 |
--------------------------------------------------------------------------------
/piccolo/data_structures.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | @dataclass(slots=True)
4 | class PairRetriContrastRecord:
5 | text: str
6 | text_pos: str
7 | text_neg: list
8 |
9 | @dataclass(slots=True)
10 | class PairClsContrastRecord:
11 | text: str
12 | text_pos: str
13 | text_neg: list
14 |
15 | @dataclass(slots=True)
16 | class PairScoredRecord:
17 | text: str
18 | text_pair: str
19 | label: float
--------------------------------------------------------------------------------
/piccolo/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 |
5 | from enum import Enum
6 | from pathlib import Path
7 | from typing import ClassVar, Literal, Type, TypeVar, cast
8 | from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel # type: ignore
9 | from piccolo.criteria import CoSentLoss, ClsContrastLoss, RetriContrastLoss
10 |
11 | class PoolingStrategy(str, Enum):
12 | cls = 'cls'
13 | mean = 'mean'
14 |
15 | class InBatchNegLossType(str, Enum):
16 | cosent = 'cosent'
17 | retri_contrast = 'retri_contrast'
18 | cls_contrast = 'cls_contrast'
19 |
20 | def build_loss(loss_type, temperature, **kwargs):
21 | loss_type = InBatchNegLossType(loss_type)
22 | match loss_type:
23 | case InBatchNegLossType.cosent:
24 | return CoSentLoss(temperature)
25 | case InBatchNegLossType.cls_contrast:
26 | return ClsContrastLoss(temperature)
27 | case InBatchNegLossType.retri_contrast:
28 | return RetriContrastLoss(temperature, **kwargs)
29 |
30 | def creat_attention_mask_from_input_ids(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
31 | return input_ids != pad_token_id
32 |
33 | def mean_pooling(hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
34 | if attention_mask is None:
35 | return torch.mean(hidden_state, dim=1)
36 | attention_mask = attention_mask.float()
37 | return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum(attention_mask, dim=-1, keepdim=True)
38 |
39 | def load_hf_pretrained_model(model_name_or_path: str) -> PreTrainedModel:
40 | config = AutoConfig.from_pretrained(model_name_or_path)
41 | if config.model_type == 't5':
42 | from transformers import T5EncoderModel # type: ignore
43 | pretrained_model = T5EncoderModel.from_pretrained(model_name_or_path)
44 | else:
45 | pretrained_model = AutoModel.from_pretrained(model_name_or_path)
46 | return pretrained_model # type: ignore
47 |
48 |
49 | StrategyEmbedderClsMap: dict[PoolingStrategy, Type['Embedder']] = {}
50 |
51 | class Embedder(torch.nn.Module):
52 | pooling_strategy: ClassVar[PoolingStrategy]
53 |
54 | def __init__(self, encoder: PreTrainedModel, pad_token_id: int | None = None):
55 | super().__init__()
56 | self.encoder = encoder
57 |
58 | if pad_token_id is None:
59 | if encoder.config.pad_token_id is not None:
60 | self.pad_token_id = encoder.config.pad_token_id
61 | else:
62 | self.pad_token_id = 0
63 | else:
64 | self.pad_token_id = pad_token_id
65 |
66 | def __init_subclass__(cls) -> None:
67 | StrategyEmbedderClsMap[cls.pooling_strategy] = cls
68 |
69 | def save_pretrained(self, path: str | Path):
70 | self.encoder.save_pretrained(path)
71 |
72 | @classmethod
73 | def from_pretrained(cls, model_name_or_path: str):
74 | encoder = load_hf_pretrained_model(model_name_or_path)
75 | return cls(encoder)
76 |
77 |
78 | class LastMeanEmbedder(Embedder):
79 | pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.mean
80 |
81 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
82 | if attention_mask is None:
83 | attention_mask = creat_attention_mask_from_input_ids(input_ids, self.pad_token_id)
84 | embeddings = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
85 | embeddings = mean_pooling(embeddings, attention_mask)
86 | return embeddings
87 |
88 |
89 | class ClsEmbedder(Embedder):
90 | pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.cls
91 |
92 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
93 | if attention_mask is None:
94 | attention_mask = creat_attention_mask_from_input_ids(input_ids, self.pad_token_id)
95 | embeddings = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
96 | return embeddings
97 |
98 |
99 | class EmbedderForTrain(torch.nn.Module):
100 | embedder: Embedder
101 |
102 | def __init__(self, embedder: Embedder):
103 | super().__init__()
104 | self.embedder = embedder
105 |
106 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs):
107 | self.embedder.encoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
108 |
109 |
110 | class ScalingLayer(torch.nn.Module):
111 | def __init__(self, origin_dim: int = 1024, scaling_dim: int = 1792):
112 | super().__init__()
113 | self.linear = torch.nn.Linear(in_features=origin_dim, out_features=scaling_dim, bias=True)
114 | def forward(self, input):
115 | return self.linear(input)
116 |
117 |
118 | class STEmbedder(EmbedderForTrain):
119 | """
120 |     STEmbedder wraps a Hugging Face encoder with pooling, an optional scaling layer and MRL, and computes the multi-task hybrid loss for training.
121 |
122 | Args:
123 | model_name_or_path (`str` ):
124 | the path of the pretrain model
125 | embedding_strategy (`PoolingStrategy`):
126 | only support 'mean' and 'cls'
127 | for 'mean', we use the average embedding across all valid tokens embeddings of BERT's last layer.
128 | for 'cls', we use the first token embedding of the BERT's last layer
129 | freeze_pos_emb (`bool`):
130 | whether fix the position embedding or not, default is True
131 | add_scaling_layer (`bool`):
132 | add a scaling layer after the last layer of BERT, it is used for dimension scaling.
133 | use_mrl (`bool`)
134 | employ the Matryoshka Representation Learning algorithm to support flexible dimension
135 | extend_pe (`bool`)
136 | extend position embedding to longer length, here we adopt a very simple method from:
137 | https://kexue.fm/archives/7947
138 |
139 |     Note:
140 |     some parameters are hard-coded in this code, such as the in/out features of the scaling layer and the nesting list of MRL.
141 | """
142 | def __init__(
143 | self,
144 | model_name_or_path: str,
145 | embedding_strategy: PoolingStrategy | str = PoolingStrategy.mean,
146 | freeze_pos_emb: bool = True,
147 | add_scaling_layer: bool = False,
148 | use_mrl: bool = False,
149 | extend_pe: bool = False,
150 | max_length: int = 512,
151 | ):
152 | pretrained_model = load_hf_pretrained_model(model_name_or_path)
153 | embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)](pretrained_model)
154 | super().__init__(embedder)
155 | self.retri_contrst_loss = build_loss('retri_contrast', temperature=0.01, use_all_pair=True)
156 | self.cosent_loss = build_loss('cosent', temperature=0.05)
157 | self.cls_contrast_loss = build_loss('cls_contrast', temperature=0.05)
158 | self.use_mrl = use_mrl
159 | self.add_scaling_layer = add_scaling_layer
160 |
161 | if add_scaling_layer:
162 | '''
163 | Here we hard code the scaling layer pretrain path, input_dim and output_dim, you can modify it by yourself
164 | '''
165 | self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792)
166 | if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')):
167 | scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin'))
168 | self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True)
169 | print('load scaling layer successfully')
170 | else:
171 | print('not found pretrain, random init scaling layer')
172 |
173 | if use_mrl:
174 | self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792] # hard code here
175 |
176 | if extend_pe:
177 | sp = 0 # TODO, hard code here, for xlm roberta, this should be 2
178 | # re-init the position embeddings
179 | org_pe = self.embedder.encoder.embeddings.position_embeddings
180 | pad_idx = self.embedder.encoder.embeddings.position_embeddings.padding_idx
181 | extended_pe = torch.nn.Embedding(max_length + sp, org_pe.embedding_dim, padding_idx=pad_idx)
182 |             for start_idx in range(0, max_length + sp, org_pe.num_embeddings):  # iteratively copy the original table to fill the extended position embeddings
183 | end_idx = min(start_idx + org_pe.num_embeddings, max_length + sp)
184 | extended_pe.weight.data[start_idx : end_idx] = org_pe.weight.data[:end_idx - start_idx].clone()
185 | self.embedder.encoder.embeddings.position_embeddings = extended_pe
186 | self.embedder.encoder.embeddings.position_ids = torch.arange(max_length + sp).expand((1, -1))
187 | self.embedder.encoder.embeddings.token_type_ids = \
188 | torch.zeros(self.embedder.encoder.embeddings.position_ids.size(), dtype=torch.long)
189 | self.embedder.encoder.config.max_position_embeddings = max_length + sp
190 |
191 |         if not extend_pe and freeze_pos_emb:  # when extending the position embeddings, they must not be frozen
192 | for name, param in self.embedder.encoder.embeddings.named_parameters():
193 | if "position_embeddings" in name:
194 | param.requires_grad = False
195 |
196 | def get_embedding(self, text_ids):
197 | if text_ids is None:
198 | return None
199 | text_embeddings = self.embedder(text_ids)
200 | if self.add_scaling_layer:
201 | text_embeddings = self.scaling_layer(text_embeddings.half()).float()
202 | return text_embeddings
203 |
204 | def compute_cls_loss(self, text_ids: torch.Tensor, text_labels: torch.tensor):
205 | text_embeddings = self.get_embedding(text_ids)
206 | pred_cls = self.cls_head(text_embeddings.half())
207 | loss = torch.nn.functional.cross_entropy(pred_cls, text_labels)
208 | return {'loss': loss}
209 |
210 | def compute_cls_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor,
211 | text_neg_ids: torch.Tensor = None, type: str = 'cls_contrast') -> dict[str, torch.Tensor]:
212 | text_embeddings = self.get_embedding(text_ids)
213 | text_pos_embeddings = self.get_embedding(text_pos_ids)
214 | text_neg_embeddings = self.get_embedding(text_neg_ids)
215 |
216 | if self.use_mrl:
217 | loss = torch.tensor(0.0, device=text_embeddings.device)
218 | for num_feat in self.mrl_nesting_list:
219 | emb, pos_emb, neg_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat], text_neg_embeddings[..., :num_feat]
220 | loss += self.cls_contrast_loss(emb, pos_emb, neg_emb) / len(self.mrl_nesting_list)
221 | else:
222 | loss = self.cls_contrast_loss(text_embeddings, text_pos_embeddings, text_neg_embeddings)
223 | print('cls contrast loss: ', loss)
224 | return {'loss': loss}
225 |
226 | def compute_retri_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor, text_neg_ids: torch.Tensor = None,
227 | type: str = 'retri_contrast', **kwargs) -> dict[str, torch.Tensor]:
228 | text_embeddings = self.get_embedding(text_ids)
229 | text_pos_embeddings = self.get_embedding(text_pos_ids)
230 | text_neg_embeddings = self.get_embedding(text_neg_ids)
231 |
232 | if self.use_mrl:
233 | loss = torch.tensor(0.0, device=text_embeddings.device)
234 | for num_feat in self.mrl_nesting_list:
235 | if text_neg_embeddings is not None:
236 | emb, pos_emb, neg_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat], text_neg_embeddings[..., :num_feat]
237 | else:
238 | emb, pos_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat]
239 | neg_emb = None
240 | loss += self.retri_contrst_loss(emb, pos_emb, neg_emb, **kwargs) / len(self.mrl_nesting_list)
241 | else:
242 | loss = self.retri_contrst_loss(text_embeddings, text_pos_embeddings, text_neg_embeddings, **kwargs)
243 | print('triplet loss: ', loss)
244 | return {'loss': loss}
245 |
246 | def compute_cosent_loss(self, text_ids: torch.Tensor, text_pair_ids: torch.Tensor, labels: torch.Tensor, type: str = 'cosent'):
247 | text_embeddings = self.get_embedding(text_ids)
248 | text_pair_embeddings = self.get_embedding(text_pair_ids)
249 | if self.use_mrl:
250 | loss = torch.tensor(0.0, device=text_embeddings.device)
251 | for num_feat in self.mrl_nesting_list:
252 | emb, emb_pair = text_embeddings[..., :num_feat], text_pair_embeddings[..., :num_feat]
253 | predict_labels = torch.cosine_similarity(emb, emb_pair, dim=-1)
254 | loss += self.cosent_loss(predict_labels, labels) / len(self.mrl_nesting_list)
255 | else:
256 | predict_labels = torch.cosine_similarity(text_embeddings, text_pair_embeddings, dim=-1)
257 | loss = self.cosent_loss(predict_labels, labels)
258 | print('cosent loss: ', loss)
259 | return {'loss': loss, 'predict_labels': predict_labels}
260 |
261 |
262 | def forward(self, **kwargs):
263 | if kwargs['type'] == 'cls_contrast':
264 | return self.compute_cls_contrast_loss(**kwargs)
265 | elif kwargs['type'] == 'retri_contrast':
266 | return self.compute_retri_contrast_loss(**kwargs)
267 | elif kwargs['type'] == 'cosent':
268 | return self.compute_cosent_loss(**kwargs)
269 | else:
270 |             raise NotImplementedError('unsupported type in input kwargs')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.36.1
2 | datasets>=2.16.1
3 | torch>=2.1.1
4 | sentence-transformers
5 | deepspeed
6 | pyyaml
7 | tqdm
--------------------------------------------------------------------------------
/scripts/ft.sh:
--------------------------------------------------------------------------------
1 | ROOT=/mnt/lustre/huangjunqin/piccolo
2 | export PYTHONPATH=$ROOT:${PYTHONPATH}
3 |
4 | # SLURM Parameter
5 | GPUS_PER_NODE=1
6 | if [ -z "$WORLD_SIZE" ]; then
7 | WORLD_SIZE=1
8 | RANK=0
9 | MASTER_ADDR=127.0.0.1
10 | MASTER_PORT=6000
11 | fi
12 |
13 | # Hyper Parameter Start
14 | MODEL_NAME_OR_PATH=/mnt/lustre/huangjunqin/test/piccolo2-large-zh-0417
15 | EPOCHS=3
16 | BATCH_SIZE=8
17 | LR=1e-5
18 | NEG_NUM=1
19 | DS_PATH=$ROOT/ds_config_zero1.json
20 | MAX_LENGTH=512
21 | META_PATHS=(
22 | meta_lists/piccolo-ft.txt
23 | )
24 |
25 | ROOT_DIRS=(
26 | data_example/
27 | )
28 | # Hyper Parameter End
29 |
30 |
31 | model_args=(
32 | "--model_name_or_path" $MODEL_NAME_OR_PATH
33 | "--max_length=$MAX_LENGTH"
34 | "--query_prefix=''"
35 | "--doc_prefix=''"
36 | "--use_scaling_layer=True"
37 | "--use_mrl=True"
38 | )
39 |
40 | data_args=(
41 | "--meta_paths" "${META_PATHS[@]}"
42 | "--root_dirs" "${ROOT_DIRS[@]}"
43 | "--neg_num=$NEG_NUM"
44 | )
45 |
46 | train_args=(
47 | "--fp16"
48 | "--gradient_checkpointing=True"
49 | "--output_dir=output_test"
50 | "--num_train_epochs=$EPOCHS"
51 | "--dataloader_num_workers=0"
52 | "--batch_size=$BATCH_SIZE"
53 | "--learning_rate=$LR"
54 | "--deepspeed=$DS_PATH"
55 | "--logging_steps=500"
56 | "--save_safetensors=False"
57 | "--report_to=tensorboard"
58 | "--save_strategy=epoch"
59 | "--per_device_train_batch_size=1"
60 | )
61 |
62 | all_args=("${model_args[@]}" "${data_args[@]}" "${train_args[@]}")
63 |
64 |
65 | export LAUNCHER="python -u -m torch.distributed.run \
66 | --nproc_per_node $GPUS_PER_NODE \
67 | --nnodes $WORLD_SIZE \
68 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
69 | --rdzv_backend c10d \
70 | --max_restarts 0 \
71 | --tee 3 \
72 | "
73 |
74 | export CMD=" \
75 | $ROOT/finetune/train.py \
76 | "
77 |
78 | echo $CMD
79 |
80 | bash -c "$LAUNCHER $CMD ${all_args[*]}"
--------------------------------------------------------------------------------