├── .gitattributes
├── README.md
├── data
│   ├── gkd_data.jsonl
│   ├── rlhf.jsonl
│   ├── sft_data.jsonl
│   └── test.jsonl
├── evaluate
│   ├── README.md
│   ├── args.py
│   ├── base_utils.py
│   ├── dataset
│   │   └── humaneval_python.jsonl
│   ├── evaluation.py
│   ├── execution.py
│   ├── generate.py
│   ├── main.py
│   ├── run.sh
│   └── task
│       └── humaneval.py
├── inference_vllm.py
├── llm_tricks
│   ├── DPO_example
│   │   ├── README.md
│   │   ├── dataset.py
│   │   ├── dpo_train.py
│   │   ├── evaluate.py
│   │   ├── loss.py
│   │   └── unsloth_dpo.jsonl
│   ├── dora
│   │   ├── READEME.md
│   │   ├── dora_example.py
│   │   └── lora_and_dora.ipynb
│   ├── moe
│   │   ├── READEME.md
│   │   ├── input.txt
│   │   └── make_moe_step_by_step.ipynb
│   └── transformer
│       └── README.md
├── main_train.py
├── pic
│   └── pic.jpg
├── requirements.txt
├── rlhf
│   ├── README.md
│   ├── __init__.py
│   ├── common_args.py
│   ├── ds_config
│   │   ├── ds_zero2.yaml
│   │   └── ds_zero3.yaml
│   ├── gkd_run.sh
│   ├── rejected_sampling
│   │   ├── README.md
│   │   ├── genetate.py
│   │   ├── rejected_sampling.py
│   │   ├── run_generate.sh
│   │   └── template.py
│   ├── requirements.txt
│   ├── rlhf_args
│   │   ├── base_config.py
│   │   ├── cpo-simpo_config.py
│   │   ├── cpo_config.py
│   │   ├── dpo_config.py
│   │   ├── kto_config.py
│   │   ├── ppo_config.py
│   │   ├── reward_config.py
│   │   ├── rloo_config.py
│   │   └── simpo_config.py
│   ├── rlhf_run.sh
│   ├── train_gkd.py
│   ├── train_rlhf.py
│   └── utils
│       └── util.py
├── run_eval_test.sh
├── run_example.sh
├── run_vlm_example.sh
├── train_args
│   ├── __init__.py
│   ├── common_args.py
│   ├── deepspeed_config
│   │   ├── ds_config_zero0.json
│   │   ├── ds_config_zero2.json
│   │   └── ds_config_zero3.json
│   ├── dpo
│   │   └── README.md
│   ├── sft
│   │   └── base.py
│   └── vlm_config
│       └── script_args.py
├── utils
│   ├── __init__.py
│   ├── data_collator.py
│   ├── data_process.py
│   ├── eval
│   │   ├── README.md
│   │   ├── callback.py
│   │   ├── configs.py
│   │   ├── eval_metric.py
│   │   ├── eval_utils.py
│   │   ├── train_script.py
│   │   └── vllm
│   │       ├── run_serve.sh
│   │       ├── vllm_client.py
│   │       └── vllm_serve.py
│   ├── script
│   │   ├── download_model.py
│   │   ├── generate_data.py
│   │   └── merge_lora.py
│   └── vlm_template.py
└── vlm_train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # LLM-Dojo: A Training Dojo for Large Models (大模型修炼道场) 😊
3 |
4 |
5 | Tip: the image is entirely AI-generated.
6 | ## 🌟 Project Overview
7 | LLM-Dojo uses concise, easy-to-read code to build LLM and VLM training, an RLHF framework, and other features, so the project is **easy to learn and convenient to modify and experiment with**. Like most open-source frameworks, it is built on top of huggingface.
8 | The main components are:
9 | - **SFT training framework:** a clean, simple open-source LLM training framework supporting Deepspeed multi-GPU, Lora, QLora, and full-parameter training, with automatic chat template adaptation.
10 | - **VLM multimodal training framework:** supports training for various multimodal tasks (currently only QA is implemented), with automatic adaptation of the model template.
11 | - **RLHF framework:** a continuously updated RLHF training framework covering knowledge distillation, DPO, RLOO, SimPO, and other reinforcement learning methods; works with Deepspeed multi-GPU and Lora, and a single A100 is enough to run it. See [RLHF](./rlhf/README.md) for details.
12 | - **Latest LLM tricks explained:** continuously updated write-ups of the latest tricks in the LLM field, including reproductions of methods from new papers, hopefully offering some ideas for innovation. This module lives mainly under the ```llm_tricks``` folder.
13 |
14 | ### Table of Contents
15 |
16 | - [Project Overview](#-project-overview)
17 | - [Latest News](#-latest-news)
18 | - [RLHF Training Framework](#rlhf-training-framework)
19 |   - [Supported RLHF methods](#rlhf-training-framework)
20 |   - [Knowledge distillation](#rlhf-training-framework)
21 |   - [Rejected sampling](#rlhf-training-framework)
22 | - [SFT Training Framework](#sft-training-framework)
23 |   - [Supported models for fine-tuning](#supported-models-for-fine-tuning)
24 |   - [Training data format](#training-data-format)
25 |   - [Converting data to the framework format](#converting-data-to-the-framework-format)
26 |   - [Quick Start](#quick-start)
27 | - [Multimodal Training (VLM)](#multimodal-training-vlm)
28 |   - [Supported models](#supported-models)
29 |   - [Supported task types](#supported-task-types)
30 |   - [Data format](#data-format)
31 | - [Tricks](#tricks)
32 |   - [Technical articles](#technical-articles)
33 | - [Acknowledgements](#-acknowledgements)
34 |
35 | ## 📖 Latest News
36 | - [2025-04-18] Rejected sampling implemented, consisting of a generation part and a judging part; an API model can be used as the judge. See [Rejected Sampling](./rlhf/rejected_sampling/README.md); detailed documentation is still in progress.
37 | - [2025-04-13] 🚀Evaluation during training: vllm is used to greatly speed up generation while training. See [Predict with generate](./utils/eval/README.md). Detailed documentation and optimization are in progress.
38 | - [2024-12-31] Multimodal training supported, see [Multimodal Training (VLM)](#multimodal-training-vlm)
39 | - [2024-11-06] RLHF refactored; see the RLHF Training Framework section
40 | - [2024-10-31] Added the auto_adapt parameter to control automatic template adaptation; updated and improved DPO training (moved under the RLHF directory)
41 | - [2024-10-15] Added knowledge distillation training. See [Knowledge Distillation](./rlhf/README.md)
42 | - [2024-10-14] Removed the chat template module, since the tokenizer's apply_chat_template is sufficient
43 | - [2024-09-20] Added the evaluate module, a concise model evaluation framework that currently only supports HumanEval. See [Evaluate](./evaluate/README.md)
44 | More news...
45 |
46 | - [2024-08-27] 🤓Added from-scratch DPO and SimPO implementations, covering data, loss, and training. See [DPO example](./llm_tricks/DPO_example/README.md)
47 | - [2024-08-08] Launch either by editing config files directly or via the command line; added data-processing code to adapt data to the framework.
48 | - [2024-08-04] Adaptive single-turn/multi-turn dialogue: no need to specify which, training decides from the data; a system prompt can also be set freely. See [Training data format](#training-data-format)
49 | - [2024-07-19] Added CPO, SimPO, and their combination CPO-SimPO to the RLHF framework
50 | - [2024-07-16] RLHF framework update completed: supports deepspeed single/multi-GPU lora and qlora training. See [RLHF](./rlhf/README.md) for details
51 | - [2024-06-09] 🚀DPO training supported, split into single-turn DPO (built from scratch, easy to modify) and multi-turn DPO (concise implementation), with deepspeed lora and qlora support. See the [DPO usage guide](./train_args/dpo/README.md)
52 | - [2024-06-05] 🤓llm_tricks: added a from-scratch MoE implementation
53 | - [2024-06-10] 🚀Added a step-by-step Transformer article (code included, explained from scratch), see [Technical articles](#technical-articles)
54 | - [2024-05-18] 🤓Deepspeed single-node multi-GPU and single-GPU Lora, Qlora, and full fine-tuning supported!
55 | - [2024-04-28] 🚀 Updated the DoRA fine-tuning principle example; qwen model fine-tuning supported
56 |
57 |
58 | ## RLHF Training Framework
59 |
60 | The RLHF training framework supports (and keeps adding) knowledge distillation, Reward, PPO, DPO, RLOO, SimPO, KTO, and other reinforcement learning methods; it works with Deepspeed multi-GPU and Lora, and a single A100 is enough to run it.
61 | See [RLHF](./rlhf/README.md) for details.
62 |
63 | It covers three main categories:
64 |
65 | **1. RLHF**
66 |
67 | **2. Knowledge Distillation**
68 |
69 | **3. Rejected Sampling: to be updated**
70 |
71 | ## SFT Training Framework
72 |
73 | ### Supported Models for Fine-tuning
74 | In principle any model can be fine-tuned; the list below only contains models that have been tested.
75 |
76 | Deepspeed-based multi-GPU/single-GPU Lora, Qlora, and Dora fine-tuning is supported for:
77 | - [x] [Qwen (Qwen1.5/Qwen2)](https://github.com/QwenLM/Qwen.git)
78 | - [x] [Yi](https://github.com/01-ai/Yi)
79 | - [x] [Gemma series](https://github.com/google/gemma_pytorch)
80 | - [x] [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)
81 | - [x] [Deepseek](https://github.com/deepseek-ai/DeepSeek-LLM)
82 | - [x] [MiniCPM](https://github.com/OpenBMB/MiniCPM)
83 | - [x] [Llama series](https://github.com/meta-llama/llama3)
84 | - [x] [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder)
85 | - [x] [Bilibili Index-1.9B](https://github.com/bilibili/Index-1.9B)
86 | - [x] [Baichuan series](https://github.com/baichuan-inc/Baichuan2)
87 | - [x] [GLM series](https://github.com/THUDM/GLM-4)
88 |
89 | ### 😮Training Data Format
90 | SFT data uses the standard user(/system)/assistant message format; **there is no need to specify single-turn or multi-turn, training decides from the data.**
91 |
92 | An example follows; a sample file is at ```data/sft_data.jsonl```:
93 | ```json lines
94 | {"message": [{"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},{"role": "user", "content": "How many helicopters can a human eat in one sitting"},{"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together"},{"role": "user", "content": "你好"},{"role": "assistant", "content": "hellow"}]}
95 | ```
96 | You can decide whether to include the system field based on your needs; **if your training data has no special requirement, it is recommended to drop the system field.**
97 |
98 | The auto_adapt training parameter controls whether the chat template is applied automatically; if set to False, no template is applied and training uses the raw content.
99 |
100 | ### Converting Data to the Framework Format
101 | Since the framework's expected data format may differ from common formats, ```utils/script/generate_data.py``` can be used to convert. Its input should be a jsonl file with the usual instruction and output fields,
102 | such as:
103 | ```json lines
104 | {"instruction":"将这个句子改写成将来时态:“太阳将会照耀明亮。”","output":"太阳将会散发温暖的光芒。"}
105 | ```
106 | After running it you get the specified user/assistant format without a system field, roughly as shown below.
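Illustrative only: the record above should come out roughly like this after conversion (the exact field layout is determined by generate_data.py):
```json lines
{"message": [{"role": "user", "content": "将这个句子改写成将来时态:“太阳将会照耀明亮。”"}, {"role": "assistant", "content": "太阳将会散发温暖的光芒。"}]}
```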
107 |
108 | ### 🤓Quick Start
109 | Currently supports plain **single-GPU training via a python command**, **deepspeed (recommended) single-node multi-GPU**, and **single-node single-GPU** training. All of them support the Qlora, Lora, and Dora methods.
110 |
111 |
112 | 1. Launching with **command-line arguments** is supported; see ```run_example.sh``` for an example. **The relevant parameters are in common_args.py and sft/base.py under train_args.**
113 | ```bash
114 | bash run_example.sh
115 | ```
116 |
117 | 2. Alternatively, **edit the default values in the argument files directly**, then launch with:
118 | ```bash
119 | deepspeed --include localhost:6,7 main_train.py
120 | ```
121 | For a more detailed explanation of Deepspeed, see the article: [Deepspeed configuration and usage explained](https://zhuanlan.zhihu.com/p/698631348)
122 |
123 |
124 | GPU memory usage measured as follows:
125 |
126 | | Strategy   | Model    | GPU memory |
127 | |------------|----------|------------|
128 | | Lora       | Qwen(7B) | 26 GB      |
129 | | Lora+Zero2 | Qwen(7B) | 26 GB      |
130 | | Lora+Zero3 | Qwen(7B) | 16 GB      |
131 |
132 | ## Multimodal Training (VLM)
133 |
134 | Supports Deepspeed multi-GPU Lora and Qlora, as well as training with the vision tower or the projector frozen, etc.
135 |
136 | ### Supported Models
137 | - [x] [Qwen-2-VL](https://github.com/QwenLM/Qwen2-VL)
138 | - [x] [Llava](https://github.com/haotian-liu/LLaVA)
139 |
140 | ### Supported Task Types
141 |
142 | - Visual Question Answering
143 |
144 | ### Data Format
145 |
146 | **Visual Question Answering:**
147 |
148 | - metadata.jsonl: contains all image and text information, for example:
149 |
150 | ```json lines
151 | {"file_name":"Images/P0003_0004.png", "messages":[{"question":"how are you", "answer":"i am fine"}]}
152 | ```
153 |
154 | Here file_name is the image path under train_data_path, organized as follows:
155 | ```
156 | train_data_path
157 | ├─ metadata.jsonl
158 | └─ Images
159 | └─ P0003_0004.png
160 | └─ ...........png
161 | ```
162 |
163 | ### Quick Start
164 |
165 | The freeze_vision and freeze_projector parameters control whether the vision tower and the projector are frozen.
166 |
167 | ```bash
168 | bash run_vlm_example.sh
169 | ```
170 |
171 | ## Tricks
172 | All related tricks and their explanations live under the llm_tricks folder
173 | - [DoRA code walkthrough (llm_tricks/dora/READEME.md)](./llm_tricks/dora/READEME.md)
174 | - [LoRA+ fine-tuning code example](https://github.com/mst272/simple-lora-plus)
175 | - [MoE from scratch](./llm_tricks/moe/READEME.md)
176 | - [DPO from scratch](./llm_tricks/DPO_example/README.md)
177 | - [Transformer from scratch](./llm_tricks/transformer/README.md)
178 |
179 | ### Technical Articles
180 | More news...
181 |
182 | - [Deepspeed configuration and usage explained](https://zhuanlan.zhihu.com/p/698631348)
183 | - [Building MoE from scratch in code](https://zhuanlan.zhihu.com/p/701777558)
184 | - [Implementing Transformer code step by step](https://medium.com/@sdwzh2725/transformer-code-step-by-step-understandingtransformer-d2ea773f15fa)
185 | - [Training QWEN2 with DPO and a modified DPO implementation](https://zhuanlan.zhihu.com/p/702569978)
186 | - [Building the full RLHF pipeline in code (PPO, RLOO, etc.)](https://zhuanlan.zhihu.com/p/708935028)
187 | - [DPO (SimPO) training code from scratch](https://zhuanlan.zhihu.com/p/716706368)
188 | - [A concise code-model evaluation framework (evaluating HumanEval with Qwen2.5-coder as an example)](https://zhuanlan.zhihu.com/p/721218072)
189 | - [LLM knowledge distillation: code walkthrough and training](https://zhuanlan.zhihu.com/p/1064724364)
190 |
191 |
192 |
193 |
194 | ## 🤝 Acknowledgements
195 | This project learned from excellent open-source work; thanks to huggingface, Firefly (流萤), and various open-source projects at home and abroad.
196 |
197 | 🪂 Whether you open an issue or contribute a pull request, it is a huge support for the project.
198 | ***
199 |
--------------------------------------------------------------------------------
/data/test.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt": "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", "label": "\n\n\n\n\ndef check(has_close_elements):\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\ncheck(has_close_elements)"}
2 | {"prompt": "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", "label": "\n\n\n\n\ndef check(separate_paren_groups):\n assert separate_paren_groups('(()()) ((())) () ((())()())') == [\n '(()())', '((()))', '()', '((())()())'\n ]\n assert separate_paren_groups('() (()) ((())) (((())))') == [\n '()', '(())', '((()))', '(((())))'\n ]\n assert separate_paren_groups('(()(())((())))') == [\n '(()(())((())))'\n ]\n assert separate_paren_groups('( ) (( )) (( )( ))') == ['()', '(())', '(()())']\n\ncheck(separate_paren_groups)"}
3 | {"prompt": "\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", "label": "\n\n\n\n\ndef check(truncate_number):\n assert truncate_number(3.5) == 0.5\n assert abs(truncate_number(1.33) - 0.33) < 1e-6\n assert abs(truncate_number(123.456) - 0.456) < 1e-6\n\ncheck(truncate_number)"}
4 | {"prompt": "from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", "label": "\n\n\n\n\ndef check(below_zero):\n assert below_zero([]) == False\n assert below_zero([1, 2, -3, 1, 2, -3]) == False\n assert below_zero([1, 2, -4, 5, 6]) == True\n assert below_zero([1, -1, 2, -2, 5, -5, 4, -4]) == False\n assert below_zero([1, -1, 2, -2, 5, -5, 4, -5]) == True\n assert below_zero([1, -2, 2, -2, 5, -5, 4, -4]) == True\n\ncheck(below_zero)"}
5 | {"prompt": "from typing import List\n\n\ndef mean_absolute_deviation(numbers: List[float]) -> float:\n \"\"\" For a given list of input numbers, calculate Mean Absolute Deviation\n around the mean of this dataset.\n Mean Absolute Deviation is the average absolute difference between each\n element and a centerpoint (mean in this case):\n MAD = average | x - x_mean |\n >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])\n 1.0\n \"\"\"\n", "label": "\n\n\n\n\ndef check(mean_absolute_deviation):\n assert abs(mean_absolute_deviation([1.0, 2.0, 3.0]) - 2.0/3.0) < 1e-6\n assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0]) - 1.0) < 1e-6\n assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0, 5.0]) - 6.0/5.0) < 1e-6\n\ncheck(mean_absolute_deviation)"}
6 | {"prompt": "from typing import List\n\n\ndef intersperse(numbers: List[int], delimeter: int) -> List[int]:\n \"\"\" Insert a number 'delimeter' between every two consecutive elements of input list `numbers'\n >>> intersperse([], 4)\n []\n >>> intersperse([1, 2, 3], 4)\n [1, 4, 2, 4, 3]\n \"\"\"\n", "label": "\n\n\n\n\ndef check(intersperse):\n assert intersperse([], 7) == []\n assert intersperse([5, 6, 3, 2], 8) == [5, 8, 6, 8, 3, 8, 2]\n assert intersperse([2, 2, 2], 2) == [2, 2, 2, 2, 2]\n\ncheck(intersperse)"}
7 | {"prompt": "from typing import List\n\n\ndef parse_nested_parens(paren_string: str) -> List[int]:\n \"\"\" Input to this function is a string represented multiple groups for nested parentheses separated by spaces.\n For each of the group, output the deepest level of nesting of parentheses.\n E.g. (()()) has maximum two levels of nesting while ((())) has three.\n\n >>> parse_nested_parens('(()()) ((())) () ((())()())')\n [2, 3, 1, 3]\n \"\"\"\n", "label": "\n\n\n\n\ndef check(parse_nested_parens):\n assert parse_nested_parens('(()()) ((())) () ((())()())') == [2, 3, 1, 3]\n assert parse_nested_parens('() (()) ((())) (((())))') == [1, 2, 3, 4]\n assert parse_nested_parens('(()(())((())))') == [4]\n\ncheck(parse_nested_parens)"}
8 | {"prompt": "from typing import List\n\n\ndef filter_by_substring(strings: List[str], substring: str) -> List[str]:\n \"\"\" Filter an input list of strings only for ones that contain given substring\n >>> filter_by_substring([], 'a')\n []\n >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')\n ['abc', 'bacd', 'array']\n \"\"\"\n", "label": "\n\n\n\n\ndef check(filter_by_substring):\n assert filter_by_substring([], 'john') == []\n assert filter_by_substring(['xxx', 'asd', 'xxy', 'john doe', 'xxxAAA', 'xxx'], 'xxx') == ['xxx', 'xxxAAA', 'xxx']\n assert filter_by_substring(['xxx', 'asd', 'aaaxxy', 'john doe', 'xxxAAA', 'xxx'], 'xx') == ['xxx', 'aaaxxy', 'xxxAAA', 'xxx']\n assert filter_by_substring(['grunt', 'trumpet', 'prune', 'gruesome'], 'run') == ['grunt', 'prune']\n\ncheck(filter_by_substring)"}
9 | {"prompt": "from typing import List, Tuple\n\n\ndef sum_product(numbers: List[int]) -> Tuple[int, int]:\n \"\"\" For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list.\n Empty sum should be equal to 0 and empty product should be equal to 1.\n >>> sum_product([])\n (0, 1)\n >>> sum_product([1, 2, 3, 4])\n (10, 24)\n \"\"\"\n", "label": "\n\n\n\n\ndef check(sum_product):\n assert sum_product([]) == (0, 1)\n assert sum_product([1, 1, 1]) == (3, 1)\n assert sum_product([100, 0]) == (100, 0)\n assert sum_product([3, 5, 7]) == (3 + 5 + 7, 3 * 5 * 7)\n assert sum_product([10]) == (10, 10)\n\ncheck(sum_product)"}
10 | {"prompt": "from typing import List, Tuple\n\n\ndef rolling_max(numbers: List[int]) -> List[int]:\n \"\"\" From a given list of integers, generate a list of rolling maximum element found until given moment\n in the sequence.\n >>> rolling_max([1, 2, 3, 2, 3, 4, 2])\n [1, 2, 3, 3, 3, 4, 4]\n \"\"\"\n", "label": "\n\n\n\n\ndef check(rolling_max):\n assert rolling_max([]) == []\n assert rolling_max([1, 2, 3, 4]) == [1, 2, 3, 4]\n assert rolling_max([4, 3, 2, 1]) == [4, 4, 4, 4]\n assert rolling_max([3, 2, 3, 100, 3]) == [3, 3, 3, 100, 100]\n\ncheck(rolling_max)"}
11 | {"prompt": "\n\ndef is_palindrome(string: str) -> bool:\n \"\"\" Test if given string is a palindrome \"\"\"\n return string == string[::-1]\n\n\ndef make_palindrome(string: str) -> str:\n \"\"\" Find the shortest palindrome that begins with a supplied string.\n Algorithm idea is simple:\n - Find the longest postfix of supplied string that is a palindrome.\n - Append to the end of the string reverse of a string prefix that comes before the palindromic suffix.\n >>> make_palindrome('')\n ''\n >>> make_palindrome('cat')\n 'catac'\n >>> make_palindrome('cata')\n 'catac'\n \"\"\"\n", "label": "\n\n\n\n\ndef check(make_palindrome):\n assert make_palindrome('') == ''\n assert make_palindrome('x') == 'x'\n assert make_palindrome('xyz') == 'xyzyx'\n assert make_palindrome('xyx') == 'xyx'\n assert make_palindrome('jerry') == 'jerryrrej'\n\ncheck(make_palindrome)"}
12 | {"prompt": "from typing import List\n\n\ndef string_xor(a: str, b: str) -> str:\n \"\"\" Input are two strings a and b consisting only of 1s and 0s.\n Perform binary XOR on these inputs and return result also as a string.\n >>> string_xor('010', '110')\n '100'\n \"\"\"\n", "label": "\n\n\n\n\ndef check(string_xor):\n assert string_xor('111000', '101010') == '010010'\n assert string_xor('1', '1') == '0'\n assert string_xor('0101', '0000') == '0101'\n\ncheck(string_xor)"}
--------------------------------------------------------------------------------
/evaluate/README.md:
--------------------------------------------------------------------------------
1 | # Code Evaluation
2 |
3 | Evaluation for code LLMs. Currently humaneval is supported, and more benchmarks will be added over time.
4 |
5 | Note:
6 | This can only be run on Linux; the execution part breaks on Windows.
7 |
8 |
9 | ## Quick Start
10 |
11 | run.sh under the evaluate folder serves as a launch example; see args.py for detailed explanations of the parameters.
12 |
13 | The evaluation dataset should be a jsonl file; an example record is sketched right after the launch command below.
14 | ```bash
15 | bash run.sh
16 | ```
17 |
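Judging from ```task/humaneval.py``` and ```evaluation.py```, each dataset record needs at least ```task_id```, ```prompt```, and ```test``` fields; a purely illustrative example:
```json lines
{"task_id": "HumanEval/0", "prompt": "def add(a, b):\n    \"\"\"Return the sum of a and b.\"\"\"\n", "test": "def check(add):\n    assert add(1, 2) == 3\n\ncheck(add)"}
```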
18 | ### Generated Files
19 |
20 | After evaluation finishes, three main files are produced:
21 |
22 | 1. out.jsonl: model outputs; the following fields are added on top of the evaluation data:
23 | - output: the raw output the model produced for the prompt
24 | - generation: the output after extracting the code part and appending the tests
25 |
26 | 2. logs.jsonl: information from running the evaluation test cases
27 |
28 | 3. metric.json: the evaluation metrics
29 |
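For reference, save_metrics in generate.py writes the result dict on the first line of metric.json and the full argument dump on the second; the result line looks roughly like this (value illustrative):
```json
{"pass@1": 0.85}
```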
30 | ## Adding a New Evaluation Task
31 | To add a new evaluation task, inherit the class in base_utils, create a corresponding file under the task folder for your subclass,
32 | and finally register it in TASKS in main.py. A sketch follows below.
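A minimal sketch of what such a subclass might look like (the mbpp task name and task/mbpp.py file are hypothetical, purely for illustration):

```python
# task/mbpp.py  (hypothetical example of a new task)
from base_utils import TaskUtils
from evaluation import evaluate_functional_correctness


class MBPP(TaskUtils):
    @staticmethod
    def build_instruction(example):
        # Build the prompt sent to the model
        return example['prompt']

    def generation_code_process(self, example):
        # Extract the code part, prepend the common imports, and append the tests
        example['generation'] = "\n".join(self.IMPORT_HELPER["python"]) + "\n" + example['output'] + "\n" + example['test']
        return example

    @staticmethod
    def evaluate_function(input_file, args):
        # Final evaluation over the saved generation jsonl file
        return evaluate_functional_correctness(input_file, k=1, save_logs_path=args.save_logs_path)
```

Then register it in main.py, e.g. ```TASKS = {"humaneval": humaneval.HumanEval(), "mbpp": mbpp.MBPP()}```.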
--------------------------------------------------------------------------------
/evaluate/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | torch_dtype = ['bf16', 'fp16', 'fp32']
4 | task_name = ['humaneval']
5 |
6 |
7 | @dataclass
8 | class EvaluateArgs:
9 | """
10 |     Arguments for the Evaluate module
11 | """
12 |
13 |     """Task name; supported tasks are listed in task_name"""
14 | task_name: str = 'humaneval'
15 |
16 |     """Settings for model generation"""
17 | max_new_tokens: int = 100
18 | torch_dtype: str = 'fp16'
19 | do_sample: bool = False
20 | top_p: float = 0.95
21 | temperature: int = 1
22 |
23 |     """Model path"""
24 |     model_name_or_path: str = './'
25 |     """Evaluation-only mode; if False, evaluate_data_path below does not need to be set"""
26 |     evaluate_only: bool = False
27 |     """Path of the generation file when running evaluation only"""
28 |     evaluate_data_path: str = ''
29 |     """Output path for the model generations"""
30 |     output_path: str = './'
31 |     """Per-sample information emitted during evaluation: pass or the error"""
32 |     save_logs_path: str = './'
33 |     """Path to save the evaluation metrics"""
34 |     save_metrics_path: str = './'
35 |     """Test set used for evaluation"""
36 | data_file: str = ''
37 |
--------------------------------------------------------------------------------
/evaluate/base_utils.py:
--------------------------------------------------------------------------------
1 | class TaskUtils:
2 | def __init__(self):
3 | self.IMPORT_HELPER = {
4 | "python": [
5 | "import math",
6 | "import re",
7 | "import sys",
8 | "import copy",
9 | "import datetime",
10 | "import itertools",
11 | "import collections",
12 | "import heapq",
13 | "import functools",
14 | "import hashlib",
15 | "import numpy",
16 | "import numpy as np",
17 | "import string",
18 | "from typing import *",
19 | "from collections import *",
20 | ]
21 | }
22 |
23 | @staticmethod
24 | def build_instruction(example):
25 | """
26 |         Build an appropriate instruction for the model
27 | """
28 | return example['prompt']
29 |
30 | @staticmethod
31 | def generation_code_process(example):
32 | """
33 |         Extract the function part from the generated code, set up imports, append test cases, etc.
34 | """
35 | pass
36 |
37 | @staticmethod
38 | def evaluate_function(input_file, args):
39 | """
40 |         Final evaluation method; the input is the saved generation jsonl file
41 | """
42 | pass
43 |
--------------------------------------------------------------------------------
/evaluate/evaluation.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | from concurrent.futures import ThreadPoolExecutor, as_completed
4 | import numpy as np
5 | from tqdm import tqdm
6 | from execution import check_correctness
7 |
8 |
9 | def stream_jsonl_all(filename: str):
10 | """
11 | Streams a JSONL file.
12 | """
13 | results = []
14 | fp = open(filename, "r")
15 | for line in fp:
16 | if any(not x.isspace() for x in line):
17 | results.append(json.loads(line))
18 | fp.close()
19 | return results
20 |
21 |
22 | # Compute pass@1
23 | def evaluate_functional_correctness(
24 | input_file: str = None,
25 | n_workers: int = 32,
26 | timeout: float = 3.0,
27 | k: int = 1,
28 | save_logs_path='./logs.jsonl'
29 | ):
30 | """
31 | Evaluates the functional correctness of a model.
32 | """
33 | sample_jsonl = stream_jsonl_all(input_file)
34 |
35 | with ThreadPoolExecutor(max_workers=n_workers) as executor:
36 |
37 | futures = []
38 | n_samples = 0
39 | results = defaultdict(list)
40 |
41 | print("Reading samples...")
42 | for sample in tqdm(sample_jsonl):
43 | task_id = sample["task_id"]
44 | if sample["generation"] is None:
45 | continue
46 | args = (sample['generation'], task_id, timeout)
47 | future = executor.submit(check_correctness, *args)
48 | futures.append(future)
49 | n_samples += 1
50 |
51 | print("Running test suites...")
52 | for future in tqdm(as_completed(futures), total=len(futures)):
53 | result = future.result()
54 | results[result["task_id"]].append(result)
55 | # Calculate pass@k.
56 | total, correct, logs = [], [], []
57 | for result in results.values():
58 | passed = [r["passed"] for r in result]
59 | res = [{r['task_id']: r["result"]} for r in result]
60 | logs.append(res)
61 | total.append(len(passed))
62 | correct.append(sum(passed))
63 | total = np.array(total)
64 | correct = np.array(correct)
65 |
66 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()}
67 |
68 | with open(save_logs_path, 'w', encoding='utf-8') as fw:
69 | for ex in logs:
70 | fw.write(json.dumps(ex) + '\n')
71 | print(f"execute logs were saved at {save_logs_path}")
72 |
73 | return pass_at_k
74 |
75 |
76 | def estimate_pass_at_k(
77 | num_samples,
78 | num_correct,
79 | k: int
80 | ) -> np.ndarray:
81 | """
82 | Estimates pass@k and returns them in an array.
83 | """
84 |
85 | def estimator(n: int, c: int, k: int) -> float:
86 | """
87 | Calculates 1 - comb(n - c, k) / comb(n, k).
88 | """
89 | if n - c < k:
90 | return 1.0
91 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
92 |
93 | assert len(num_samples) == len(num_correct)
94 | num_samples_it = iter(num_samples)
95 |
96 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
97 |
98 |
99 |
--------------------------------------------------------------------------------
/evaluate/execution.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import faulthandler
3 | import io
4 | import multiprocessing
5 | import os
6 | import platform
7 | import signal
8 | import tempfile
9 |
10 |
11 | @contextlib.contextmanager
12 | def chdir(root):
13 | if root == ".":
14 | yield
15 | return
16 | cwd = os.getcwd()
17 | os.chdir(root)
18 | try:
19 | yield
20 | except BaseException as exc:
21 | raise exc
22 | finally:
23 | os.chdir(cwd)
24 |
25 |
26 | @contextlib.contextmanager
27 | def create_tempdir():
28 | with tempfile.TemporaryDirectory() as dirname:
29 | with chdir(dirname):
30 | yield dirname
31 |
32 |
33 | @contextlib.contextmanager
34 | def swallow_io():
35 | stream = WriteOnlyStringIO()
36 | with contextlib.redirect_stdout(stream):
37 | with contextlib.redirect_stderr(stream):
38 | with redirect_stdin(stream):
39 | yield
40 |
41 |
42 | @contextlib.contextmanager
43 | def time_limit(seconds):
44 | def signal_handler(signum, frame):
45 | raise TimeoutException("Timed out!")
46 |
47 | signal.setitimer(signal.ITIMER_REAL, seconds)
48 | signal.signal(signal.SIGALRM, signal_handler)
49 | try:
50 | yield
51 | finally:
52 | signal.setitimer(signal.ITIMER_REAL, 0)
53 |
54 |
55 | class redirect_stdin(contextlib._RedirectStream): # type: ignore
56 | _stream = "stdin"
57 |
58 |
59 | def check_correctness(check_program, task_id, timeout=3):
60 | manager = multiprocessing.Manager()
61 | result = manager.list()
62 |
63 | p = multiprocessing.Process(target=unsafe_execute, args=(check_program, result, timeout))
64 | p.start()
65 | p.join(timeout=timeout + 1)
66 | if p.is_alive():
67 | p.kill()
68 |
69 | if not result:
70 | result.append("timed out")
71 |
72 | return {
73 | "task_id": task_id,
74 | "passed": result[0] == "passed",
75 | "result": result[0]
76 | }
77 |
78 |
79 | def unsafe_execute(check_program, result, timeout):
80 | with create_tempdir():
81 |
82 | # These system calls are needed when cleaning up tempdir.
83 | import os
84 | import shutil
85 |
86 | rmtree = shutil.rmtree
87 | rmdir = os.rmdir
88 | chdir = os.chdir
89 |
90 | # Disable functionalities that can make destructive changes to the test.
91 | reliability_guard()
92 |
93 | # Run program.
94 | try:
95 | exec_globals = {}
96 | with swallow_io():
97 | with time_limit(timeout):
98 | exec(check_program, exec_globals)
99 | result.append("passed")
100 | except TimeoutException:
101 | result.append("timed out")
102 | except BaseException as e:
103 | result.append(f"failed: {e}")
104 |
105 | # Needed for cleaning up.
106 | shutil.rmtree = rmtree
107 | os.rmdir = rmdir
108 | os.chdir = chdir
109 |
110 |
111 | class WriteOnlyStringIO(io.StringIO):
112 | """StringIO that throws an exception when it's read from"""
113 |
114 | def read(self, *args, **kwargs):
115 | raise OSError
116 |
117 | def readline(self, *args, **kwargs):
118 | raise OSError
119 |
120 | def readlines(self, *args, **kwargs):
121 | raise OSError
122 |
123 | def readable(self, *args, **kwargs):
124 | """Returns True if the IO object can be read."""
125 | return False
126 |
127 |
128 | class TimeoutException(Exception):
129 | pass
130 |
131 |
132 | def reliability_guard(maximum_memory_bytes=None):
133 | """
134 | This disables various destructive functions and prevents the generated code
135 | from interfering with the test (e.g. fork bomb, killing other processes,
136 | removing filesystem files, etc.)
137 |
138 | WARNING
139 | This function is NOT a security sandbox. Untrusted code, including, model-
140 | generated code, should not be blindly executed outside of one. See the
141 | Codex paper for more information about OpenAI's code sandbox, and proceed
142 | with caution.
143 | """
144 |
145 | if maximum_memory_bytes is not None:
146 | import resource
147 |
148 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
149 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
150 | if not platform.uname().system == "Darwin":
151 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
152 |
153 | faulthandler.disable()
154 |
155 | import builtins
156 |
157 | builtins.exit = None
158 | builtins.quit = None
159 |
160 | import os
161 |
162 | os.environ["OMP_NUM_THREADS"] = "1"
163 |
164 | os.kill = None
165 | os.system = None
166 | os.putenv = None
167 | os.remove = None
168 | os.removedirs = None
169 | os.rmdir = None
170 | os.fchdir = None
171 | os.setuid = None
172 | os.fork = None
173 | os.forkpty = None
174 | os.killpg = None
175 | os.rename = None
176 | os.renames = None
177 | os.truncate = None
178 | os.replace = None
179 | os.unlink = None
180 | os.fchmod = None
181 | os.fchown = None
182 | os.chmod = None
183 | os.chown = None
184 | os.chroot = None
185 | os.fchdir = None
186 | os.lchflags = None
187 | os.lchmod = None
188 | os.lchown = None
189 | os.getcwd = None
190 | os.chdir = None
191 |
192 | import shutil
193 |
194 | shutil.rmtree = None
195 | shutil.move = None
196 | shutil.chown = None
197 |
198 | import subprocess
199 |
200 | subprocess.Popen = None # type: ignore
201 |
202 | __builtins__["help"] = None
203 |
204 | import sys
205 |
206 | sys.modules["ipdb"] = None
207 | sys.modules["joblib"] = None
208 | sys.modules["resource"] = None
209 | sys.modules["psutil"] = None
210 | sys.modules["tkinter"] = None
211 |
--------------------------------------------------------------------------------
/evaluate/generate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from tqdm import tqdm
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 |
6 |
7 | def generate_one(example, tokenizer, model, args, task):
8 | prompt = task.build_instruction(example)
9 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
10 |
11 | stop_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.convert_tokens_to_ids(
12 | "<|EOT|>")
13 | assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found"
14 |
15 | outputs = model.generate(
16 | inputs,
17 | max_new_tokens=args.max_new_tokens,
18 | do_sample=args.do_sample,
19 | top_p=args.top_p,
20 | temperature=args.temperature,
21 | pad_token_id=stop_id,
22 | eos_token_id=stop_id
23 | )
24 |
25 | output = tokenizer.decode(outputs[0][:], skip_special_tokens=True)
26 | example['output'] = output
27 |
28 | return task.generation_code_process(example)
29 |
30 |
31 | def generate_main(args, task):
32 | model_name_or_path = args.model_name_or_path
33 | saved_path = args.output_path
34 |
35 | print("model", model_name_or_path)
36 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
37 | print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path))
38 | torch_dtype = torch.bfloat16 if args.torch_dtype == 'bf16' else torch.float16 if args.torch_dtype == 'fp16' else torch.float32
39 | model = AutoModelForCausalLM.from_pretrained(
40 | model_name_or_path,
41 | torch_dtype=torch_dtype,
42 | device_map="auto",
43 | trust_remote_code=True,
44 | )
45 | model.eval()
46 | examples = [json.loads(x) for x in open(args.data_file) if x.strip()]
47 | print("Read {} examples for evaluation over.".format(len(examples)))
48 |
49 | generated_examples = []
50 | for ex in tqdm(examples, desc='Generating'):
51 | gen_example = generate_one(ex, tokenizer, model, args, task)
52 | generated_examples.append(gen_example)
53 |
54 | print("Generate all over!!!")
55 | with open(saved_path, 'w', encoding='utf-8') as fw:
56 | for ex in generated_examples:
57 | fw.write(json.dumps(ex) + '\n')
58 | print("Save {} processed examples into {} over!".format(len(generated_examples), saved_path))
59 |
60 | result = task.evaluate_function(saved_path,args)
61 | save_metrics(args, result)
62 | print(result, model_name_or_path)
63 |
64 |
65 | def evaluation_only(args, task):
66 | result = task.evaluate_function(args.evaluate_data_path, args)
67 | save_metrics(args, result)
68 | print(result, args.model_name_or_path)
69 |
70 |
71 | def save_metrics(args, result):
72 | args_dict = args.__dict__
73 | with open(args.save_metrics_path, 'w', encoding='utf-8') as fw:
74 | fw.write(json.dumps(result) + '\n')
75 | fw.write(json.dumps(args_dict) + '\n')
76 |
77 |
--------------------------------------------------------------------------------
/evaluate/main.py:
--------------------------------------------------------------------------------
1 | from transformers import HfArgumentParser
2 | import os
3 | from args import EvaluateArgs
4 | from task import humaneval
5 | from generate import generate_main, evaluation_only
6 |
7 | parser = HfArgumentParser((EvaluateArgs,))
8 | args = parser.parse_args_into_dataclasses()[0]
9 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
10 |
11 | # Task registry
12 | TASKS = {
13 | "humaneval": humaneval.HumanEval()
14 | }
15 |
16 | task = TASKS[args.task_name]
17 |
18 | if not args.evaluate_only:
19 | generate_main(args, task)
20 | else:
21 | evaluation_only(args, task)
22 |
--------------------------------------------------------------------------------
/evaluate/run.sh:
--------------------------------------------------------------------------------
1 | MODELS_PATH="/qwen"
2 | LOGS_PATH="./logs.jsonl"
3 | OUT_PATH='./out.jsonl'
4 | METRIC_PATH='./metric.json'
5 | DATA_FILE='./dataset/humaneval_python.jsonl'
6 |
7 |
8 | CUDA_VISIBLE_DEVICES=0 python main.py \
9 | --model_name_or_path "$MODELS_PATH" \
10 | --task_name "humaneval" \
11 | --save_logs_path "$LOGS_PATH" \
12 | --output_path "$OUT_PATH" \
13 | --do_sample false \
14 | --top_p 0.95 \
15 | --max_new_tokens 1024 \
16 | --evaluate_only false \
17 | --torch_dtype "bf16" \
18 | --save_metrics_path $METRIC_PATH \
19 | --data_file $DATA_FILE
--------------------------------------------------------------------------------
/evaluate/task/humaneval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from base_utils import TaskUtils
4 | from evaluation import evaluate_functional_correctness
5 |
6 |
7 | class HumanEval(TaskUtils):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | @staticmethod
12 | def build_instruction(example):
13 | """
14 |         Build an appropriate instruction for the model
15 | """
16 | return example['prompt']
17 |
18 | def generation_code_process(self, example):
19 | """
20 |         Extract the function part from the generated code, set up imports, etc.
21 | """
22 | code = example['output']
23 | test_case = example['test']
24 | code_ = []
25 |         skip_rest = False  # flag used to skip the if __name__ == "__main__" block and everything after it
26 | for line in code.split("\n"):
27 | if skip_rest:
28 |                 continue  # once if __name__ == "__main__" has been seen, skip this line and everything after it
29 | if any(keyword in line for keyword in ["if __name__ == \"__main__\":", "if __name__ == \'__main__\':"]):
30 |                 skip_rest = True  # set the flag so that the rest is skipped
31 | continue
32 | if "def " in line and line[0] != ' ' and line[0] != '\t':
33 | code_.append("def " + line.split("def ")[1])
34 | continue
35 | if "class" in line and line.strip().endswith(":"):
36 | code_.append(line)
37 | continue
38 | if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
39 | continue
40 | code_.append(line)
41 | code = "\n".join(code_)
42 | test_setup = "\n".join(self.IMPORT_HELPER["python"]) + "\n"
43 | example['generation'] = test_setup + code + "\n" + test_case + "\n"
44 | return example
45 |
46 | @staticmethod
47 | def evaluate_function(input_file, args):
48 | """
49 |         Final evaluation method; the input is the saved generation jsonl file
50 | """
51 | return evaluate_functional_correctness(input_file, n_workers=1, timeout=3.0, k=1, save_logs_path=args.save_logs_path)
52 |
53 |
--------------------------------------------------------------------------------
/inference_vllm.py:
--------------------------------------------------------------------------------
1 | from vllm import LLM, SamplingParams
2 | import os
3 | import torch
4 | from transformers import AutoTokenizer
5 |
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
7 | model_name_or_path = 'DeepSeek-R1-Distill-Qwen-32B' # model path
8 | llm = LLM(
9 | model=model_name_or_path,
10 | max_model_len=8192,
11 | device='cuda',
12 | dtype=torch.bfloat16,
13 |     tensor_parallel_size=8  # number of GPUs in CUDA_VISIBLE_DEVICES
14 | )
15 |
16 | prompt = "请帮我生成一个关于夏天的诗歌。"
17 |
18 |
19 | TOKENIZER = AutoTokenizer.from_pretrained(model_name_or_path)
20 | messages = [
21 | {"role": "user", "content": prompt}
22 | ]
23 | text = TOKENIZER.apply_chat_template(
24 | messages,
25 | tokenize=False,
26 | add_generation_prompt=True
27 | )
28 |
29 | prompts = [text]
30 |
31 | sampling_params = SamplingParams(
32 | max_tokens=8192,
33 | top_p=0.9,
34 | top_k=1,
35 | temperature=0.0,
36 | repetition_penalty=1.0,
37 | )
38 | outputs = llm.generate(prompts, sampling_params)
39 | for output in outputs:
40 | prompt = output.prompt
41 | generated_text = output.outputs[0].text
42 | print(f"Prompt:\n{prompt}")
43 | print(f"Generated text:\n {generated_text}")
44 |
--------------------------------------------------------------------------------
/llm_tricks/DPO_example/README.md:
--------------------------------------------------------------------------------
1 | # Implementing DPO (SimPO) Training Code from Scratch
2 |
3 | ## Quick start
4 | ```bash
5 | python dpo_train.py
6 | ```
7 |
8 | ## Notes
9 | The from-scratch implementation here is just a learning demo for understanding the principle; it does not add distributed training and the like, so even with a small 2B model the GPU memory usage reaches 30+ GB.
10 |
11 | With fp16 precision the loss may become NaN.
12 |
13 | ```dpo_train.py``` is the main training entry point; the loss computation is in ```loss.py```, and the objective it implements is written out below.
14 |
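For reference, the objective computed in ```loss.py``` (with $\pi_\theta$ the policy model, $\pi_{\mathrm{ref}}$ the reference model, $y_w$/$y_l$ the chosen/rejected responses, and log probabilities length-averaged per sequence) is:

$$\mathcal{L}_{\mathrm{DPO}}=-\log\sigma\Big(\beta\big[(\log\pi_\theta(y_w|x)-\log\pi_\theta(y_l|x))-(\log\pi_{\mathrm{ref}}(y_w|x)-\log\pi_{\mathrm{ref}}(y_l|x))\big]\Big)$$

while the SimPO variant in the same file drops the reference model and uses $-\log\sigma\big(\beta[(\log\pi_\theta(y_w|x)-\log\pi_\theta(y_l|x))-\gamma]\big)$.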
15 | If you want to actually train with DPO, SimPO, CPO, or other reinforcement learning methods,
16 | you can use the RLHF framework built in this project: [RLHF](../../rlhf/README.md)
17 |
18 | It supports deepspeed single-node multi-GPU Lora, Dora, Qlora, and full-parameter training, and automatically adapts to the model's chat template.
--------------------------------------------------------------------------------
/llm_tricks/DPO_example/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json
3 |
4 |
5 | class RlhfDataset(Dataset):
6 | def __init__(self, file_path, tokenizer):
7 | with open(file_path, "r", encoding="utf-8") as file:
8 | data_list = file.readlines()
9 | self.data_list = data_list
10 | self.tokenizer = tokenizer
11 |
12 | def __getitem__(self, item):
13 | data = self.data_list[item]
14 | data = json.loads(data)
15 | prompt = data['prompt']
16 | chosen = data['chosen']
17 | rejected = data['rejected']
18 |
19 | chosen_full_text = f"{prompt}\n\n### Response:\n{chosen}"
20 | rejected_full_text = f"{prompt}\n\n### Response:\n{rejected}"
21 |
22 | prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
23 | chosen_full_tokens = self.tokenizer.encode(chosen_full_text, add_special_tokens=False)
24 | rejected_full_tokens = self.tokenizer.encode(rejected_full_text, add_special_tokens=False)
25 |
26 | input = {
27 | "prompt": prompt_tokens,
28 | "chosen": chosen_full_tokens,
29 | "rejected": rejected_full_tokens,
30 | }
31 | return input
32 |
33 | def __len__(self):
34 | return len(self.data_list)
--------------------------------------------------------------------------------
/llm_tricks/DPO_example/dpo_train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, random_split
2 | import torch
3 | from dataset import RlhfDataset
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 | from loss import compute_batch_loss
6 | from evaluate import evaluate_loss_dataloader
7 | import time
8 | from functools import partial
9 |
10 | # 1. Load the model and tokenizer
11 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
12 | model_path = '/IndexTeam/Index-1___9B-Chat'
13 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
14 | ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
15 | ref_model.eval()
16 | model.to(device)
17 | ref_model.to(device)
18 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
19 |
20 | # 2. Prepare the data
21 | # Load the data
22 | data_file = './unsloth_dpo.jsonl'
23 | # See the RlhfDataset implementation for the detailed dataset logic
24 | dataset = RlhfDataset(data_file, tokenizer)
25 | # Split into training and validation sets
26 | train_size = int(len(dataset) * 0.85) # 85% for training
27 | val_size = len(dataset) - train_size # Remaining for validation
28 | train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
29 |
30 | # Collate function: per-batch padding and mask handling
31 | IGNORE_INDEX = False
32 |
33 |
34 | def data_collate(batch, pad_token_id, device, max_length=None, if_mask_prompt=True):
35 | batch_data = {
36 | "prompt": [],
37 | "chosen": [],
38 | "rejected": [],
39 | "rejected_mask": [],
40 | "chosen_mask": []
41 | }
42 |
43 |     # Determine the length to pad to
44 | max_length_common = 0
45 | for key in ["chosen", "rejected"]:
46 | current_max = max(len(item[key]) for item in batch)
47 | max_length_common = max(max_length_common, current_max)
48 |
49 |     # Convert to torch tensors, pad, and decide whether to mask the prompt
50 | for item in batch:
51 | prompt = torch.tensor(item['prompt'])
52 | batch_data['prompt'].append(prompt)
53 |
54 | for key in ["chosen", "rejected"]:
55 | out = item[key]
56 | out_padding = out + [pad_token_id] * (max_length_common - len(out))
57 | mask = torch.ones(len(out_padding)).bool()
58 |
59 |             # set the mask of the padding part to IGNORE_INDEX
60 | mask[len(out):] = IGNORE_INDEX
61 |
62 | if if_mask_prompt:
63 | mask[:prompt.shape[0] + 2] = IGNORE_INDEX
64 | batch_data[key].append(torch.tensor(out_padding))
65 | batch_data[f"{key}_mask"].append(mask)
66 |
67 |     # Truncate to the maximum length
68 | for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
69 | tensor_stack = torch.stack(batch_data[key])
70 | if max_length is not None:
71 | tensor_stack = tensor_stack[:, :max_length]
72 |         # move the tensor to the target device
73 | batch_data[key] = tensor_stack.to(device)
74 | return batch_data
75 |
76 |
77 | customized_collate_fn = partial(
78 | data_collate,
79 | pad_token_id=tokenizer.pad_token_id,
80 | device=device,
81 | if_mask_prompt=True,
82 | max_length=1024
83 | )
84 | # Set the related parameters
85 | batch_size = 4
86 | train_loader = DataLoader(
87 | train_dataset,
88 | batch_size=batch_size,
89 | collate_fn=customized_collate_fn,
90 | shuffle=True,
91 | drop_last=True
92 | )
93 | val_loader = DataLoader(
94 | val_dataset,
95 | batch_size=1,
96 | collate_fn=customized_collate_fn,
97 | shuffle=False,
98 | drop_last=False
99 | )
100 |
101 |
102 | # 3. Compute the DPO (or other) loss
103 | # The relevant code is in loss.py rather than in this main script.
104 |
105 | # 4. The training loop
106 | def train_model(
107 | policy_model, reference_model, train_loader, val_loader,
108 | optimizer, num_epochs, beta,
109 | eval_freq, eval_iter):
110 | tracking = {
111 | "train_losses": [],
112 | "train_chosen_rewards": [],
113 | "train_rejected_rewards": [],
114 | "val_losses": [],
115 | "val_chosen_rewards": [],
116 | "val_rejected_rewards": [],
117 | "tokens_seen": []
118 | }
119 | tokens_seen, global_step = 0, -1
120 |
121 |     # Training
122 | for epoch in range(num_epochs):
123 |         # the policy model is the one being trained
124 | policy_model.train()
125 |
126 | for idx, batch in enumerate(train_loader):
127 | optimizer.zero_grad()
128 |
129 | loss, chosen_rewards, rejected_rewards = compute_batch_loss(
130 | batch=batch,
131 | policy_model=policy_model,
132 | reference_model=reference_model,
133 | beta=beta
134 | )
135 | loss.backward()
136 | optimizer.step()
137 |
138 | global_step += 1
139 | tokens_seen += batch["chosen"].numel()
140 |
141 |             # Evaluation
142 | if global_step % eval_freq == 0:
143 | res = evaluate_loss_dataloader(
144 | policy_model=policy_model,
145 | reference_model=reference_model,
146 | train_loader=train_loader,
147 | val_loader=val_loader,
148 | beta=beta,
149 | eval_iter=eval_iter
150 | )
151 | tracking["train_losses"].append(res["train_loss"])
152 | tracking["train_chosen_rewards"].append(res["train_chosen_reward"])
153 | tracking["train_rejected_rewards"].append(res["train_rejected_reward"])
154 | tracking["val_losses"].append(res["val_loss"])
155 | tracking["val_chosen_rewards"].append(res["val_chosen_reward"])
156 | tracking["val_rejected_rewards"].append(res["val_rejected_reward"])
157 | tracking["tokens_seen"].append(tokens_seen)
158 | train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"]
159 | val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"]
160 |
161 | print(
162 | f"Ep {epoch + 1} (Step {global_step:06d}): "
163 | f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, "
164 | f"Train reward margins {train_reward_margin:.3f}, "
165 | f"Val reward margins {val_reward_margin:.3f}"
166 | )
167 |
168 | return tracking
169 |
170 |
171 | # 5. Start training!
172 | def main():
173 | torch.manual_seed(42)
174 | start_time = time.time()
175 | optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
176 |
177 | num_epochs = 3
178 | tracking = train_model(
179 | policy_model=model,
180 | reference_model=ref_model,
181 | train_loader=train_loader,
182 | val_loader=val_loader,
183 | optimizer=optimizer,
184 | num_epochs=num_epochs,
185 | beta=0.1, # value between 0.1 and 0.5
186 | eval_freq=2,
187 | eval_iter=2
188 | )
189 |
190 | end_time = time.time()
191 | execution_time_minutes = (end_time - start_time) / 60
192 | print(f"Training completed in {execution_time_minutes:.2f} minutes.")
193 |
194 |
195 | if __name__ == "__main__":
196 | main()
197 |
--------------------------------------------------------------------------------
/llm_tricks/DPO_example/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from loss import compute_loss_dataloader
3 |
4 |
5 | def evaluate_loss_dataloader(policy_model, reference_model, train_loader, val_loader, beta, eval_iter):
6 | policy_model.eval()
7 | with torch.no_grad():
8 | train_loss, train_chosen_rewards, train_rejected_rewards = compute_loss_dataloader(
9 | data_loader=train_loader,
10 | policy_model=policy_model,
11 | reference_model=reference_model,
12 | beta=beta,
13 | num_batches=eval_iter
14 | )
15 | val_loss, val_chosen_rewards, val_rejected_rewards = compute_loss_dataloader(
16 | data_loader=val_loader,
17 | policy_model=policy_model,
18 | reference_model=reference_model,
19 | beta=beta,
20 | num_batches=eval_iter
21 | )
22 | res = {
23 | "train_loss": train_loss,
24 | "train_chosen_reward": train_chosen_rewards,
25 | "train_rejected_reward": train_rejected_rewards,
26 | "val_loss": val_loss,
27 | "val_chosen_reward": val_chosen_rewards,
28 | "val_rejected_reward": val_rejected_rewards
29 | }
30 |
31 | policy_model.train()
32 | return res
33 |
--------------------------------------------------------------------------------
/llm_tricks/DPO_example/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | # The DPO loss formula
7 | class DPOLoss(nn.Module):
8 | """
9 | DPO Loss
10 | """
11 |
12 | def __init__(self, beta: float = 0.1) -> None:
13 | super().__init__()
14 | self.beta = beta
15 |
16 | def forward(
17 | self,
18 | policy_chosen_logps: torch.Tensor,
19 | policy_rejected_logps: torch.Tensor,
20 | reference_chosen_logps: torch.Tensor,
21 | reference_rejected_logps: torch.Tensor,
22 | ):
23 | """
24 |         policy_chosen_logps: log probabilities from the policy model. Shape: (batch_size,)
25 | policy_rejected_logps: Shape: (batch_size,)
26 | reference_chosen_logps: Shape: (batch_size,)
27 | reference_rejected_logps: Shape: (batch_size,)
28 | """
29 | policy_logps = policy_chosen_logps - policy_rejected_logps
30 | reference_logps = reference_chosen_logps - reference_rejected_logps
31 | logits = policy_logps - reference_logps
32 |
33 | loss = -F.logsigmoid(self.beta * logits)
34 |
35 |         # The two values below are used to track training progress
36 | chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()
37 | rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()
38 |
39 |         # average over the batch (expectation)
40 | return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()
41 |
42 |
43 | class SimPo(nn.Module):
44 | """
45 | SimPO Loss
46 | """
47 |
48 | def __init__(self, beta: float = 0.1, gamma: float = 0.5) -> None:
49 | super().__init__()
50 | self.beta = beta
51 | self.gamma = gamma
52 |
53 | def forward(
54 | self,
55 | policy_chosen_logps: torch.Tensor,
56 | policy_rejected_logps: torch.Tensor,
57 | ):
58 | """
59 |         policy_chosen_logps: log probabilities from the policy model. Shape: (batch_size,)
60 | policy_rejected_logps: Shape: (batch_size,)
61 | """
62 | logits = policy_chosen_logps - policy_rejected_logps
63 | logits = logits - self.gamma
64 | loss = -F.logsigmoid(self.beta * logits)
65 |
66 |         # average over the batch (expectation)
67 | return loss.mean()
68 |
69 |
70 | # Compute the log probabilities for each model
71 | def compute_logprobs(logits, labels, mask=None):
72 | """
73 |     logits: shape (batch_size, sequence_len, vocab_size), i.e. the model output when the labels are fed in
74 | labels: shape (batch_size, sequence_len)
75 | """
76 |
77 |     # Shift positions first
78 |     # drop the first label token
79 |     labels = labels[:, 1:].clone()
80 |     # drop the last model output position
81 | logits = logits[:, :-1, :]
82 |
83 | logps = F.log_softmax(logits, dim=-1)
84 |
85 |     select_logprobs = torch.gather(
86 |         input=logps,
87 |         dim=-1,
88 |         index=labels.unsqueeze(-1)  # gather each label token's log prob along the vocab dimension
89 |     ).squeeze(-1)
90 |
91 | if mask is not None:
92 | mask = mask[:, 1:].clone()
93 |         # mask out the padding part
94 |         select_logprobs = select_logprobs * mask
95 |         # average over each sequence
96 | average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
97 | return average_logprobs
98 | else:
99 | return select_logprobs.mean(-1)
100 |
101 |
102 | # Compute the log probabilities using torch's F.cross_entropy. The result is the same as the function above.
103 | def compute_logprobs_f_cross(logits, labels, mask=None):
104 | """
105 |     logits: shape (batch_size, sequence_len, vocab_size), i.e. the model output when the labels are fed in
106 | labels: shape (batch_size, sequence_len)
107 | """
108 |     # Shift positions first
109 |     # drop the first label token
110 |     labels = labels[:, 1:].clone()
111 |     # drop the last model output position
112 | logits = logits[:, :-1, :].clone()
113 |
114 | batch_size, sequence_len, vocab_size = logits.shape
115 | cross_entropy_loss = 0
116 |
117 | if mask is not None:
118 | mask = mask[:, 1:].clone()
119 | labels.masked_fill_(~mask, -100)
120 | for i in range(batch_size):
121 | cross_entropy_loss += F.cross_entropy(logits[i], labels[i])
122 | else:
123 | for i in range(batch_size):
124 | cross_entropy_loss += F.cross_entropy(logits[i], labels[i])
125 | cross_entropy_loss /= batch_size
126 | return cross_entropy_loss
127 |
128 |
129 | def compute_batch_loss(batch, policy_model, reference_model, beta):
130 |     # Decide which loss to use
131 |     # loss_fn = SimPo(beta, 0.5)  # SimPO loss
132 | loss_fn = DPOLoss(beta) # DPO loss
133 |
134 | policy_chosen_logps = compute_logprobs(
135 | logits=policy_model(batch["chosen"]).logits,
136 | labels=batch["chosen"],
137 | mask=batch["chosen_mask"]
138 | )
139 | policy_rejected_logps = compute_logprobs(
140 | logits=policy_model(batch["rejected"]).logits,
141 | labels=batch["rejected"],
142 | mask=batch["rejected_mask"]
143 | )
144 | reference_chosen_logps = compute_logprobs(
145 | logits=reference_model(batch['chosen']).logits,
146 | labels=batch['chosen'],
147 | mask=batch["chosen_mask"]
148 | )
149 | reference_rejected_logps = compute_logprobs(
150 | logits=reference_model(batch['rejected']).logits,
151 | labels=batch['rejected'],
152 | mask=batch["rejected_mask"]
153 | )
154 | loss, chosen_rewards, rejected_rewards = loss_fn(
155 | policy_chosen_logps=policy_chosen_logps,
156 | policy_rejected_logps=policy_rejected_logps,
157 | reference_chosen_logps=reference_chosen_logps,
158 | reference_rejected_logps=reference_rejected_logps,
159 | )
160 |     # For SimPO use the following instead:
161 | # loss = loss_fn(
162 | # policy_chosen_logps=policy_chosen_logps,
163 | # policy_rejected_logps=policy_rejected_logps,
164 | # )
165 | # return loss
166 | return loss, chosen_rewards, rejected_rewards
167 |
168 |
169 | def compute_loss_dataloader(data_loader, policy_model, reference_model, beta, num_batches=5):
170 | total_loss, total_chosen_rewards, total_rejected_rewards = 0., 0., 0.
171 | num_batches = min(num_batches, len(data_loader))
172 |
173 | for i, batch in enumerate(data_loader):
174 | if i < num_batches:
175 | loss, chosen_rewards, rejected_rewards = compute_batch_loss(
176 | batch=batch,
177 | policy_model=policy_model,
178 | reference_model=reference_model,
179 | beta=beta
180 | )
181 | total_loss += loss.item()
182 | total_chosen_rewards += chosen_rewards.item()
183 | total_rejected_rewards += rejected_rewards.item()
184 | else:
185 | break
186 |     # compute the averages
187 | total_loss /= num_batches
188 | total_chosen_rewards /= num_batches
189 | total_rejected_rewards /= num_batches
190 | return total_loss, total_chosen_rewards, total_rejected_rewards
191 |
192 |
193 | if __name__ == "__main__":
194 |     # Test compute_logprobs_f_cross against compute_logprobs
195 | logits = torch.tensor(
196 | [[2.0, 1.0, 0.1, 0.4],
197 | [0.5, 2.5, 0.3, 0.5],
198 | [0.6, 2.5, 0.3, 0.8],
199 | [0.5, 2.5, 0.6, 0.6]], dtype=torch.float32).unsqueeze(0)
200 | mask = torch.tensor([[True, True, False, False]])
201 | targets = torch.tensor([0, 1, 0, 2]).unsqueeze(0)
202 | loss1 = -compute_logprobs(logits, targets, mask)
203 | loss2 = compute_logprobs_f_cross(logits, targets, mask)
204 | print(loss1, loss2)
205 |
--------------------------------------------------------------------------------
/llm_tricks/dora/READEME.md:
--------------------------------------------------------------------------------
1 | # DoRA: Weight-Decomposed Low-Rank Adaptation
2 |
3 | This is an implementation of the DoRA fine-tuning method (**huggingface has since integrated dora**, so for real use you can simply go through huggingface as shown below; this module is meant for detailed **theory study**) ⚽
4 |
5 | Usage in huggingface is as follows: on top of a lora config, just add the use_dora parameter. The training framework in this project also supports dora training.
6 | ```python
7 | from peft import LoraConfig
8 |
9 | # Initialize DoRA configuration
10 | config = LoraConfig(
11 | use_dora=True, ...
12 | )
13 | ```
14 |
15 |
16 |
17 |
18 | Implementation of "DoRA: Weight-Decomposed Low-Rank Adaptation" (Liu et al, 2024) https://arxiv.org/pdf/2402.09353.pdf
19 |
20 |
21 | ## 😸Technical Blog
22 |
23 | - [Zhihu: DoRA principle and code walkthrough](https://zhuanlan.zhihu.com/p/695269522)
24 |
25 | ## Tips:
26 | DoRA is a variant built on top of LoRA, so a simple LoRA example is included as well.
27 |
28 |
29 | DoRA can be described in two steps: the first decomposes the pretrained weight matrix into a magnitude vector (m) and a direction matrix (V); the second applies LoRA to the direction matrix V while training the magnitude vector m separately. A small sketch of this idea follows below.
30 |
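A minimal sketch of the idea for a plain nn.Linear layer (an illustration only, not the exact code in dora_example.py; the name DoRALinearSketch is made up):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoRALinearSketch(nn.Module):
    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.weight = nn.Parameter(linear.weight.detach(), requires_grad=False)  # frozen W0
        self.bias = linear.bias
        # Step 1: magnitude vector m = column-wise norm of W0
        self.m = nn.Parameter(self.weight.norm(p=2, dim=0, keepdim=True))
        # Step 2: LoRA matrices that update the direction
        self.lora_A = nn.Parameter(torch.randn(rank, linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        new_weight = self.weight + self.scaling * (self.lora_B @ self.lora_A)  # W0 + LoRA update
        # normalize column-wise to get the direction, then rescale by the trained magnitude m
        direction = new_weight / new_weight.norm(p=2, dim=0, keepdim=True)
        return F.linear(x, self.m * direction, self.bias)
```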
31 | ## How to Use
32 |
33 |
34 | dora_example.py contains complete LoRA and DoRA training and validation: a small model is built and taken through the whole process from training to evaluation.
35 |
36 | lora_and_dora.ipynb is for your own debugging and study; you can run it step by step to understand the principle.
37 |
38 | Run the following to reproduce the experiment results:
39 | ```shell
40 | python dora_example.py
41 | ```
42 |
43 | ## Experiment Results
44 | Run dora_example.py; hyperparameter settings are in the file. A small model has its limitations, and a proper comparison of DoRA and LoRA in practice would require more experiments.
45 |
46 | ```python
47 | Epoch: 001/001 | Batch 000/938 | Loss: 2.3010
48 | Epoch: 001/001 | Batch 400/938 | Loss: 0.4533
49 | Epoch: 001/001 | Batch 800/938 | Loss: 0.0464
50 | Epoch: 001/001 training accuracy: 95.31%
51 | Time elapsed: 0.11 min
52 | Total Training Time: 0.11 min
53 | Test accuracy: 96.88%
54 | Epoch: 001/002 | Batch 000/938 | Loss: 0.1734
55 | Epoch: 001/002 | Batch 400/938 | Loss: 0.0447
56 | Epoch: 001/002 | Batch 800/938 | Loss: 0.1270
57 | Epoch: 001/002 training accuracy: 96.88%
58 | Time elapsed: 0.11 min
59 | Epoch: 002/002 | Batch 000/938 | Loss: 0.0626
60 | Epoch: 002/002 | Batch 400/938 | Loss: 0.2149
61 | Epoch: 002/002 | Batch 800/938 | Loss: 0.1430
62 | Epoch: 002/002 training accuracy: 95.31%
63 | Time elapsed: 0.23 min
64 | Total Training Time: 0.23 min
65 | Test accuracy LoRA finetune: 96.88%
66 | Epoch: 001/002 | Batch 000/938 | Loss: 0.1588
67 | Epoch: 001/002 | Batch 400/938 | Loss: 0.1235
68 | Epoch: 001/002 | Batch 800/938 | Loss: 0.0506
69 | Epoch: 001/002 training accuracy: 100.00%
70 | Time elapsed: 0.11 min
71 | Epoch: 002/002 | Batch 000/938 | Loss: 0.1374
72 | Epoch: 002/002 | Batch 400/938 | Loss: 0.0892
73 | Epoch: 002/002 | Batch 800/938 | Loss: 0.0606
74 | Epoch: 002/002 training accuracy: 95.31%
75 | Time elapsed: 0.23 min
76 | Total Training Time: 0.23 min
77 | Test accuracy DoRA finetune: 98.44%
78 | ```
79 |
--------------------------------------------------------------------------------
/llm_tricks/dora/dora_example.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from torchvision import datasets
4 | from torchvision import transforms
5 | from torch.utils.data import DataLoader
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | import torch
9 | import copy
10 |
11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 | BATCH_SIZE = 64
13 |
14 | # --------data process------------------------------------------------
15 |
16 | train_dataset = datasets.MNIST(root='data',
17 | train=True,
18 | transform=transforms.ToTensor(),
19 | download=True)
20 | test_dataset = datasets.MNIST(root='data',
21 | train=False,
22 | transform=transforms.ToTensor())
23 |
24 | train_loader = DataLoader(dataset=train_dataset,
25 | batch_size=BATCH_SIZE,
26 | shuffle=True)
27 |
28 | test_loader = DataLoader(dataset=test_dataset,
29 | batch_size=BATCH_SIZE,
30 | shuffle=False)
31 |
32 | # ----------------Hyperparameters-------------------------------------
33 | random_seed = 123
34 | learning_rate = 0.005
35 | num_epochs = 1
36 |
37 | # ----------------Architecture-----------------------------------------
38 | num_features = 784
39 | num_hidden_1 = 32
40 | num_hidden_2 = 64
41 | num_classes = 10
42 |
43 | torch.manual_seed(random_seed)
44 |
45 |
46 | # ---------------Model-----------------------------------------------
47 | class TestMLP(nn.Module):
48 | def __init__(self, num_features, num_hidden1, num_hidden2, num_class):
49 | super().__init__()
50 | self.layers = nn.Sequential(
51 | nn.Linear(num_features, num_hidden1),
52 | nn.ReLU(),
53 | nn.Linear(num_hidden1, num_hidden2),
54 | nn.ReLU(),
55 |
56 | nn.Linear(num_hidden2, num_class)
57 | )
58 |
59 | def forward(self, x):
60 | x = self.layers(x)
61 | return x
62 |
63 |
64 | model = TestMLP(
65 | num_features=num_features, num_hidden1=num_hidden_1, num_hidden2=num_hidden_2, num_class=num_classes
66 | )
67 |
68 | model.to(DEVICE)
69 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
70 |
71 |
72 | # ---------------------Eval---------------------------------------------
73 |
74 | def computer_metrics(model, data_loader, device):
75 | model.eval()
76 | correct_pred, num_examples = 0, 0
77 | with torch.no_grad():
78 | for images, labels in data_loader:
79 | # Image batch dimensions: torch.Size([64, 1, 28, 28])
80 | # Image label dimensions: torch.Size([64])
81 |
82 | images = images.view(-1, 28 * 28).to(device)
83 | labels = labels.to(device)
84 | logits = model(images)
85 | _, predicted_labels = torch.max(logits, 1)
86 |             num_examples += labels.size(0)  # accumulate across batches so the denominator covers the whole dataset
87 | correct_pred += (predicted_labels == labels).sum()
88 | return correct_pred.float() / num_examples * 100
89 |
90 |
91 | # ---------------------Train---------------------------------------------
92 |
93 |
94 | def train(epochs, model, optimizer, train_loader, device):
95 | start_time = time.time()
96 | for epoch in range(epochs):
97 | model.train()
98 | for batch_idx, (images, labels) in enumerate(train_loader):
99 | images = images.view(-1, 28 * 28).to(device)
100 | labels = labels.to(device)
101 |
102 | # forward and back
103 | logits = model(images)
104 | loss = F.cross_entropy(logits, labels)
105 | optimizer.zero_grad()
106 |
107 | loss.backward()
108 |
109 | # UPDATE MODEL PARAMETERS
110 | optimizer.step()
111 |
112 | # LOGGING
113 | if not batch_idx % 400:
114 | print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'
115 | % (epoch + 1, epochs, batch_idx,
116 | len(train_loader), loss))
117 | with torch.set_grad_enabled(False):
118 | print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
119 | epoch + 1, epochs,
120 | computer_metrics(model, train_loader, device)))
121 | print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))
122 | print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))
123 |
124 |
125 | # ---------------------Lora Model---------------------------------------------
126 | class LoRALayer(nn.Module):
127 | def __init__(self, in_dim, out_dim, rank, alpha):
128 | super().__init__()
129 | std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
130 | self.A = nn.Parameter(torch.rand(in_dim, rank) * std_dev)
131 | self.B = nn.Parameter(torch.zeros(rank, out_dim))
132 | self.alpha = alpha
133 |
134 | def forward(self, x):
135 | x = self.alpha * (x @ self.A @ self.B)
136 | return x
137 |
138 |
139 | class LinearWithLoRA(nn.Module):
140 | def __init__(self, linear, rank, alpha):
141 | super().__init__()
142 | self.linear = linear
143 | self.lora = LoRALayer(
144 | linear.in_features,
145 | linear.out_features,
146 | rank,
147 | alpha
148 | )
149 |
150 | def forward(self, x):
151 | return self.linear(x) + self.lora(x)
152 |
153 |
154 | # ---------------------DoRA Model---------------------------------------------
155 | class LinearWithDoRA(nn.Module):
156 | def __init__(self, linear, rank, alpha):
157 | super().__init__()
158 | self.linear = linear
159 | self.lora = LoRALayer(
160 | linear.in_features, linear.out_features, rank, alpha
161 | )
162 | self.m = nn.Parameter(torch.ones(1, linear.out_features))
163 |
164 | def forward(self, x):
165 | linear_out = self.linear(x)
166 | lora_out = self.lora(x)
167 | lora_out_norm = lora_out / (lora_out.norm(p=2, dim=1, keepdim=True) + 1e-9)
168 | dora_modification = self.m * lora_out_norm
169 | return linear_out + dora_modification
170 |
171 |
172 | # Freeze the model's original linear layers so that, as in LoRA, only the extra adapter layers are trained
173 | def freeze_linear_layers(model):
174 | for child in model.children():
175 | if isinstance(child, nn.Linear):
176 | for param in child.parameters():
177 | param.requires_grad = False
178 | else:
179 | # Recursively freeze linear layers in children modules
180 | freeze_linear_layers(child)
181 |
182 |
183 | # Replace every nn.Linear in the model with LinearWithLoRA
184 | def convert_lora_layers(model):
185 | for name, module in model.named_children():
186 | if isinstance(module, nn.Linear):
187 | setattr(model, name, LinearWithLoRA(module, rank=4, alpha=8))
188 | else:
189 | convert_lora_layers(module)
190 |
191 |
192 | # Replace every nn.Linear in the model with LinearWithDoRA
193 | def convert_dora_layers(model):
194 |     for name, module in model.named_children():
195 |         if isinstance(module, nn.Linear):
196 |             setattr(model, name, LinearWithDoRA(module, rank=4, alpha=8))
197 |         else:
198 |             convert_dora_layers(module)  # recurse with the DoRA converter, not the LoRA one
199 |
200 |
201 | if __name__ == '__main__':
202 | train(num_epochs, model, optimizer, train_loader, DEVICE)
203 | print(f'Test accuracy: {computer_metrics(model, test_loader, DEVICE):.2f}%')
204 |
205 |     # Make two copies of the base model so LoRA and DoRA can be compared separately
206 | model_lora = copy.deepcopy(model)
207 | model_dora = copy.deepcopy(model)
208 |
209 |     # LoRA training
210 | convert_lora_layers(model_lora)
211 | freeze_linear_layers(model_lora)
212 | model_lora.to(DEVICE)
213 | optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
214 | train(2, model_lora, optimizer_lora, train_loader, DEVICE)
215 | print(f'Test accuracy LoRA finetune: {computer_metrics(model_lora, test_loader, DEVICE):.2f}%')
216 |
217 |     # DoRA training
218 | convert_dora_layers(model_dora)
219 | freeze_linear_layers(model_dora)
220 | model_dora.to(DEVICE)
221 | optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)
222 | train(2, model_dora, optimizer_dora, train_loader, DEVICE)
223 | print(f'Test accuracy DoRA finetune: {computer_metrics(model_dora, test_loader, DEVICE):.2f}%')
224 |
--------------------------------------------------------------------------------
/llm_tricks/moe/READEME.md:
--------------------------------------------------------------------------------
1 | # Make MOE step by step
2 |
3 | The code for building an MoE from scratch lives in **make_moe_step_by_step.ipynb**. It is heavily commented, and reading it alongside the technical blog post is recommended, since the post includes many hand-drawn diagrams that make the ideas easier to follow.
4 |
5 | ## 😸 Technical blog post
6 |
7 | - [Building an MoE from scratch](https://zhuanlan.zhihu.com/p/701777558)
8 |
9 |
10 |
11 | ## Supplement
12 |
13 | One point the blog post does not cover is expert capacity. The idea is to stop all tokens from being routed to just one or a few experts by setting a per-expert capacity; once an expert has processed its quota of tokens, the overflow is simply dropped. A simple code example is given below; production systems use more sophisticated strategies,
14 | such as the Switch Transformer architecture discussed in https://arxiv.org/abs/2101.03961.
15 |
16 | The code below is essentially the SparseMoE from the blog post with two additions, both marked in the comments. A short usage sketch follows the code block.
17 | ```python
18 | class SparseMoE(nn.Module):
19 | def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
20 | super(SparseMoE, self).__init__()
21 | self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
22 | self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
23 | self.top_k = top_k
24 | self.capacity_factor = capacity_factor
25 | self.num_experts = num_experts
26 |
27 | def forward(self, x):
28 | batch_size, seq_len, _ = x.shape
29 | gating_output, indices = self.router(x)
30 | final_output = torch.zeros_like(x)
31 |
32 | flat_x = x.view(-1, x.size(-1))
33 | flat_gating_output = gating_output.view(-1, gating_output.size(-1))
34 |
35 | tokens_per_batch = batch_size * seq_len * self.top_k
36 |         # Define the per-expert capacity
37 | expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)
38 |
39 | updates = torch.zeros_like(flat_x)
40 |
41 | for i, expert in enumerate(self.experts):
42 | expert_mask = (indices == i).any(dim=-1)
43 | flat_mask = expert_mask.view(-1)
44 | selected_indices = torch.nonzero(flat_mask).squeeze(-1)
45 |
46 |             # Enforce the capacity limit
47 | limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices
48 | if limited_indices.numel() > 0:
49 | expert_input = flat_x[limited_indices]
50 | expert_output = expert(expert_input)
51 |
52 | gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1)
53 | weighted_output = expert_output * gating_scores
54 |
55 | updates.index_add_(0, limited_indices, weighted_output)
56 |
57 | # Reshape updates to match the original dimensions of x
58 | final_output += updates.view(batch_size, seq_len, -1)
59 |
60 | return final_output
61 |
62 | ```
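63 |
64 | To exercise the block above end to end, here is a minimal usage sketch. The Expert and NoisyTopkRouter classes below are simplified stand-ins for the versions defined in the notebook, included only so the example is self-contained; run it together with the SparseMoE definition above.
65 |
66 | ```python
67 | import torch
68 | import torch.nn as nn
69 | import torch.nn.functional as F
70 |
71 |
72 | class Expert(nn.Module):
73 |     # A plain feed-forward expert
74 |     def __init__(self, n_embed):
75 |         super().__init__()
76 |         self.net = nn.Sequential(
77 |             nn.Linear(n_embed, 4 * n_embed), nn.ReLU(), nn.Linear(4 * n_embed, n_embed)
78 |         )
79 |
80 |     def forward(self, x):
81 |         return self.net(x)
82 |
83 |
84 | class NoisyTopkRouter(nn.Module):
85 |     # Noisy top-k routing: keep only the top_k logits per token, softmax the rest to zero
86 |     def __init__(self, n_embed, num_experts, top_k):
87 |         super().__init__()
88 |         self.top_k = top_k
89 |         self.topkroute_linear = nn.Linear(n_embed, num_experts)
90 |         self.noise_linear = nn.Linear(n_embed, num_experts)
91 |
92 |     def forward(self, x):
93 |         logits = self.topkroute_linear(x)
94 |         noisy_logits = logits + torch.randn_like(logits) * F.softplus(self.noise_linear(x))
95 |         top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
96 |         sparse_logits = torch.full_like(noisy_logits, float('-inf')).scatter(-1, indices, top_k_logits)
97 |         return F.softmax(sparse_logits, dim=-1), indices
98 |
99 |
100 | moe = SparseMoE(n_embed=32, num_experts=4, top_k=2, capacity_factor=1.0)
101 | x = torch.randn(2, 16, 32)   # (batch, seq_len, n_embed)
102 | print(moe(x).shape)          # torch.Size([2, 16, 32])
103 | ```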
--------------------------------------------------------------------------------
/llm_tricks/transformer/README.md:
--------------------------------------------------------------------------------
1 | # Transformer code walkthrough: implementing it from scratch
2 |
3 | Full code: https://github.com/mst272/transformer-pytorch
--------------------------------------------------------------------------------
/main_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | import random
4 | from typing import Optional
5 |
6 | from loguru import logger
7 | import torch
8 | import torch.nn as nn
9 | from datasets import load_dataset
10 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, \
11 | BitsAndBytesConfig, HfArgumentParser, set_seed
12 | from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, cast_mixed_precision_params
13 | from train_args import sft_TrainArgument
14 | import bitsandbytes as bnb
15 | from utils.data_process import MultiRoundDataProcess
16 | from utils.data_collator import SftDataCollator
17 | from train_args.common_args import CommonArgs
18 | from utils.eval.configs import EvaluationConfig, GenerationConfig
19 | from utils.eval.callback import EvaluationCallback
20 | from utils.eval.eval_metric import create_metric
21 |
22 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
23 | os.environ["HF_ALLOW_CODE_EVAL"] = "1"
24 |
25 | def initial_args():
26 | parser = HfArgumentParser((CommonArgs,))
27 | args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
28 | if args.train_args_path == "sft_args":
29 | if args.use_eval_in_train:
30 | parser_b = HfArgumentParser((sft_TrainArgument, EvaluationConfig, GenerationConfig))
31 | train_args, eval_args, gen_config = parser_b.parse_args_into_dataclasses(args=remaining_args)
32 | else:
33 | parser_b = HfArgumentParser((sft_TrainArgument,))
34 | train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args)
35 | else:
36 | raise ValueError("Invalid train_args_path choice")
37 |
38 | if not os.path.exists(train_args.output_dir):
39 | os.makedirs(train_args.output_dir, exist_ok=True)
40 | set_seed(train_args.seed)
41 |
42 | assert sum([train_args.fp16, train_args.bf16]) == 1, "only one of fp16 and bf16 can be True"
43 | if args.use_eval_in_train:
44 | return args, train_args, eval_args, gen_config
45 | return args, train_args
46 |
47 |
48 | def find_all_linear_names(model, train_mode):
49 | """
50 |     Find all fully connected layers so an adapter can be attached to each of them
51 | """
52 | assert train_mode in ['lora', 'qlora']
53 | cls = bnb.nn.Linear4bit if train_mode == 'qlora' else nn.Linear
54 | lora_module_names = set()
55 | for name, module in model.named_modules():
56 | if isinstance(module, cls):
57 | names = name.split('.')
58 | lora_module_names.add(names[-1])
59 |
60 | if 'lm_head' in lora_module_names: # needed for 16-bit
61 | lora_module_names.remove('lm_head')
62 | lora_module_names = list(lora_module_names)
63 | return lora_module_names
64 |
65 |
66 | def create_tokenizer(args):
67 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
68 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,
69 |                                               # llama does not support the fast tokenizer
70 | use_fast=False if config.model_type == 'llama' else True
71 | )
72 |
73 |     # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None, and eod_id corresponds to the token <|endoftext|>
74 | if tokenizer.__class__.__name__ == 'QWenTokenizer':
75 | tokenizer.pad_token_id = tokenizer.eod_id
76 | tokenizer.bos_token_id = tokenizer.eod_id
77 | tokenizer.eos_token_id = tokenizer.eod_id
78 |     if tokenizer.bos_token is None:  # Qwen has no bos_token; set one here, otherwise DPO training raises an error.
79 | tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
80 | tokenizer.bos_token_id = tokenizer.eos_token_id
81 |
82 | assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
83 | assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
84 |
85 | return tokenizer
86 |
87 |
88 | def create_model(args, train_args):
89 | target_modules = None
90 |     # Decide the training precision
91 | torch_dtype = torch.bfloat16 if train_args.bf16 else torch.float32
92 | model_kwargs = dict(
93 | trust_remote_code=True,
94 | torch_dtype=torch_dtype,
95 | use_cache=False if train_args.gradient_checkpointing else True, # The cache is only used for generation,
96 | # fix bug
97 | # device_map='auto'
98 | )
99 |
100 | def load_model(model_kwargs):
101 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
102 | return model
103 |
104 | if args.train_mode == 'qlora':
105 |         # For basic QLoRA the options can be set directly when loading the model, or configured through BitsAndBytesConfig
106 |         quantization_config = BitsAndBytesConfig(
107 |             load_in_4bit=True,  # load the model in 4-bit precision
108 |             bnb_4bit_compute_dtype=torch.float16 if train_args.fp16 else torch.bfloat16,  # compute dtype used for the 4-bit layers
109 |             bnb_4bit_quant_type="nf4",  # 4-bit quantization type; "nf4" selects NormalFloat4
110 |             bnb_4bit_use_double_quant=True  # use double (nested) quantization
111 | )
112 | model_kwargs.update(quantization_config=quantization_config)
113 | model = load_model(model_kwargs)
114 |     if args.task_type in ['pretrain', 'sft']:  # skipped when the task is dpo
115 | # QLoRA: casts all the non int8 modules to full precision (fp32) for stability
116 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=train_args.gradient_checkpointing)
117 |
118 | elif args.train_mode == 'lora':
119 | model = load_model(model_kwargs)
120 | if hasattr(model, 'enable_input_require_grads'):
121 |             # without this, training may raise an error
122 | model.enable_input_require_grads()
123 | elif args.train_mode == 'full':
124 | model = load_model(model_kwargs)
125 |
126 | if args.train_mode == 'full':
127 | peft_config = None
128 | else:
129 |         # peft_config setup
130 | target_modules = find_all_linear_names(model, args.train_mode)
131 | peft_config = LoraConfig(
132 | r=args.lora_rank,
133 | lora_alpha=args.lora_alpha,
134 | target_modules=target_modules,
135 | lora_dropout=args.lora_dropout,
136 | task_type=TaskType.CAUSAL_LM,
137 | use_dora=args.use_dora
138 | )
139 |
140 |     # peft model setup
141 | if args.train_mode in ['lora', 'qlora'] and args.task_type in ['pretrain', 'sft']:
142 | model = get_peft_model(model, peft_config)
143 | if not train_args.bf16:
144 | cast_mixed_precision_params(model, dtype=torch.float16)
145 |
146 | # logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
147 | # model.print_trainable_parameters()
148 |
149 | return {
150 | 'model': model,
151 | 'peft_config': peft_config,
152 | 'target_modules': target_modules
153 | }
154 |
155 |
156 | def load_sft_dataset(args, tokenizer):
157 | train_dataset = MultiRoundDataProcess(args.train_data_path, tokenizer, args.max_len, args.auto_adapt)
158 | return train_dataset
159 |
160 |
161 | def create_trainer(args, train_args, eval_args: Optional[EvaluationConfig] = None, gen_config: Optional[GenerationConfig] = None):
162 | """"
163 | Create Trainer,支持可选的评估功能
164 | Args:
165 | args: 通用参数
166 | train_args: 训练相关参数
167 | eval_args: 评估相关参数(可选)
168 | gen_config: 评估相关参数(可选)
169 | """
170 | # 1.Basic component initialization
171 | tokenizer = create_tokenizer(args)
172 | model_dict = create_model(args, train_args)
173 | model = model_dict['model']
174 | # peft_config = model_dict['peft_config']
175 |
176 | # 2. dataset process
177 | if args.task_type == 'sft':
178 | train_dataset = load_sft_dataset(args, tokenizer)
179 | data_collator = SftDataCollator(tokenizer, args.max_len)
180 | elif args.task_type == 'pretrain':
181 | pass
182 |
183 | # 3. log configuration
184 | log_out(args, train_args, tokenizer, train_dataset, model, model_dict['target_modules'], eval_args, gen_config)
185 |
186 | # 4. sft or pretrain
187 | if args.task_type == 'sft':
188 | trainer = Trainer(
189 | model=model,
190 | args=train_args,
191 | train_dataset=train_dataset,
192 | data_collator=data_collator,
193 | processing_class=tokenizer
194 | )
195 | elif args.task_type == 'pretrain':
196 | pass
197 | # 5. Add evaluation callbacks if eval_args is provided
198 | if eval_args is not None:
199 | test_datasets = load_dataset(
200 | path="json",
201 | data_files=args.test_datasets_path
202 | )['train']
203 |
204 |         # Build the evaluation callback
205 | metrics = create_metric(eval_args)
206 | eval_callback = EvaluationCallback(
207 | trainer=trainer,
208 | test_datasets=test_datasets,
209 | generation_config=gen_config,
210 | num_samples=eval_args.num_samples,
211 | freq=eval_args.freq,
212 | metrics=metrics,
213 | max_checkpoints=eval_args.max_checkpoints,
214 | per_device_test_batch_size=eval_args.per_device_test_batch_size,
215 | higher_better=eval_args.higher_better,
216 | start_update_best_checkpoints=eval_args.start_update_best_checkpoints,
217 | use_vllm=eval_args.use_vllm,
218 | gather_deepspeed3_params=gen_config.gather_deepspeed3_params,
219 | prompts_apply_chat=eval_args.prompts_apply_chat,
220 | vllm_server_host=eval_args.vllm_server_host,
221 | vllm_server_port=eval_args.vllm_server_port,
222 | vllm_server_timeout=eval_args.vllm_server_timeout
223 | )
224 | trainer.add_callback(eval_callback)
225 |
226 | return trainer
227 |
228 |
229 | def log_out(args, train_args, tokenizer, train_dataset, model, target_modules, eval_args, gen_config):
230 | total = sum(p.numel() for p in model.parameters())
231 | logger.add(join(train_args.output_dir, 'train.log'))
232 | if train_args.local_rank == 0:
233 | logger.info("train_args:{}".format(train_args))
234 | logger.info("common_args:{}".format(args))
235 | logger.info("\neval_args:{}".format(eval_args))
236 | logger.info("\ngen_config:{}".format(gen_config))
237 | logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
238 | logger.info(f'Loading model from base model: {args.model_name_or_path}')
239 | logger.info("Total model params: %.2fM" % (total / 1e6))
240 | logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
241 | if args.train_mode != 'full':
242 | trainable_params, all_param = model.get_nb_trainable_parameters()
243 | logger.info(
244 | f"trainable params: {trainable_params:,d} || "
245 | f"all params: {all_param:,d} || "
246 | f"trainable%: {100 * trainable_params / all_param:.4f}"
247 | )
248 | logger.info(f'Train model with {args.task_type} task')
249 | logger.info(f'Train model with {args.train_mode}')
250 | logger.info(f'LoRA target module names: {target_modules}')
251 | logger.info(f'Loading data: {args.train_data_path}')
252 | logger.info(f"Training dataset samples:{len(train_dataset)}")
253 | for index in random.sample(range(len(train_dataset)), 3):
254 | logger.info(
255 | f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['target_mask']}.")
256 | logger.info(
257 | f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")
258 |
259 |
260 | def main():
261 | # args, train_args = initial_args()
262 | # # 加载trainer
263 | # trainer = create_trainer(args, train_args)
264 | result = initial_args()
265 | if len(result) == 4:
266 | args, train_args, eval_args, gen_config = result
267 |         # Build the trainer, passing eval_args and gen_config through
268 | trainer = create_trainer(args, train_args, eval_args, gen_config)
269 | else:
270 | args, train_args = result
271 |         # Original trainer construction path
272 |         trainer = create_trainer(args, train_args)
273 |     # Start training
274 |     if train_args.local_rank == 0:
275 |         logger.info("*** starting training ***")
276 |     train_result = trainer.train()
277 |     # Recent Transformers versions save the final training result automatically
278 |     # final_save_path = join(train_args.output_dir)
279 |     # trainer.save_model(final_save_path)
280 |
281 |     # Save training metrics
282 | metrics = train_result.metrics
283 | trainer.log_metrics("train", metrics)
284 | trainer.save_metrics("train", metrics)
285 | trainer.save_state()
286 |
287 |
288 | if __name__ == "__main__":
289 | main()
290 |
--------------------------------------------------------------------------------
/pic/pic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mst272/LLM-Dojo/6397eff480feca95f3b6016495b30d1ae308c9dc/pic/pic.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | transformers>=4.44.2
3 | peft
4 | bitsandbytes
5 | loguru
6 | numpy==1.26.4
7 | pandas
8 | tqdm
9 | deepspeed==0.16.3
10 | sentencepiece
11 | transformers-stream-generator
12 | tiktoken
13 | einops
14 | torch==2.5.0
15 | datasets
16 | trl
17 | wandb
18 | pillow
19 | vllm==0.7.2
--------------------------------------------------------------------------------
/rlhf/README.md:
--------------------------------------------------------------------------------
1 | # RLHF Reinforcement Learning Framework
2 |
3 | This framework integrates a range of reinforcement learning methods on top of Hugging Face with concise code, making it easy to modify and use: a lightweight RLHF framework.
4 |
5 | Experiments were mainly run on 1-8 40GB A100 GPUs; LoRA, QLoRA, and DeepSpeed single- or multi-GPU training are all supported.
6 |
7 | It covers three areas:
8 |
9 | **1. RLHF**
10 |
11 | **2. Knowledge Distillation**
12 |
13 | **3. Rejected Sampling: to be updated**
14 |
15 | ## Table of Contents
16 |
17 | - [RLHF](#rlhf)
18 |   - [Supported RLHF methods](#supported-rlhf-methods)
19 |   - [Quick Start](#quick-start)
20 |   - [Data format requirements](#data-format-requirements)
21 |   - [Data format options](#data-format-options)
22 |   - [Launch training](#launch-training)
23 |   - [Notes](#notes)
24 |   - [GPU memory experiments](#gpu-memory-experiments)
25 | - [Knowledge Distillation](#knowledge-distillation)
26 |   - [Quick Start](#quick-start-1)
27 | - [Acknowledgements](#acknowledgements)
28 |
29 | ## RLHF
30 | ### Supported RLHF methods
31 | In practice, single-turn training is the main mode used.
32 |
33 | - ✅ Reward model training
34 | - ✅ RLOO
35 | - ✅ PPO (temporarily unavailable)
36 | - ✅ SimPO
37 | - ✅ CPO
38 | - ✅ CPO-SimPO
39 | - ✅ DPO
40 | - ✅ KTO
41 |
42 | ### 🚀Quick Start
43 |
44 | If you run into problems, try deepspeed==0.15.4 with python==3.10; if loss or rewards/chosen becomes nan, check the requirements.txt in this directory and install those exact versions to see whether that resolves it.
45 |
46 | Some known issues that do not yet have a confirmed fix, or only have tentative workarounds:
47 |
48 | https://github.com/huggingface/alignment-handbook/issues/57
49 |
50 | https://github.com/microsoft/DeepSpeed/issues/6793#issuecomment-2502620884
51 |
52 | https://github.com/ymcui/Chinese-LLaMA-Alpaca-3/issues/29
53 |
54 | #### Data format requirements
55 | ✅ DPO, CPO, SimPO, CPO-SimPO:
56 |
57 | The following fields are required:
58 | - prompt
59 | - chosen
60 | - rejected
61 |
62 | ```json lines
63 | {"prompt":[{"role":"user","content":"How are you?"}],"chosen":[{"role":"assistant","content":"fine"}],"rejected":[{"role":"assistant","content":"no"}]}
64 | ```
65 | ✅ KTO:
66 | - prompt
67 | - completion
68 | - label
69 |
70 | KTO is a bit special: a chosen sample carries label true, while a rejected sample carries label false:
71 | ```json lines
72 | {"prompt":[{"role":"user","content":"How are you?"}],"completion":[{"role":"assistant","content":"fine"}],"label":true}
73 | ```
74 |
75 | ✅ Reward:
76 | - chosen
77 | - rejected
78 |
79 | ```json lines
80 | {"chosen":[{"role":"user","content":"How are you?"},{"role":"assistant","content":"fine"}],"rejected":[{"role":"user","content":"How are you?"},{"role":"assistant","content":"no"}]}
81 | ```
82 | ✅ PPO, RLOO:
83 | - prompt
84 |
85 | ```json lines
86 | {"prompt":[{"role":"user","content":"How are you?"}]}
87 | ```
88 |
89 | #### Data format options
90 |
91 | **1. Auto-adapted chat template format**: the input data must follow the standard user/assistant message format shown in the data format requirements above. A sketch of what this adaptation produces is given in the appendix at the end of this README.
92 |
93 | **2. Raw (non-chat) format**: simply provide the corresponding fields as plain strings, for example:
94 | ```json lines
95 | {"prompt":"How are you?","chosen":"fine", "rejected": "no"}
96 | ```
97 |
98 | ```json lines
99 | {"chosen":"How are you? fine", "rejected": "How are you? no"}
100 | ```
101 | In that case no chat-template adaptation is applied during training; the raw input is used as-is.
102 |
103 |
104 | #### Launch training
105 |
106 | There are two parameter files: the first is ```common_args.py```, and the method-specific configs live in the ```rlhf_args``` folder.
107 |
108 | Launching with DeepSpeed is recommended; the launch script is ```rlhf_run.sh```:
109 | ```bash
110 | bash rlhf_run.sh
111 | ```
112 |
113 | - rlhf_type: [PPO,RLOO,CPO,DPO,SimPO,CPOSimPO,Reward]
114 | - train_mode: [lora, qlora, full]
115 |
116 | #### Notes
117 | 1. Check whether AutoModelForSequenceClassification can load your classification (reward) model; if not, add the mapping in the model's config file.
118 |
119 | 2. When a reward model is involved, the two models must share the same tokenizer.
120 |
121 | 3. DeepSpeed must be launched through accelerate; launching with deepspeed directly raises errors (no good workaround at the moment).
122 |
123 | 4. In general, trl trainers do not support DeepSpeed's own optimizer and scheduler.
124 |
125 | 5. QLoRA with DeepSpeed ZeRO-3 is not supported; QLoRA with ZeRO-2 is.
126 |
127 | 6. When training Qwen2 you may hit the error ```no padding token is defined```. Add pad_token_id to the Qwen2 ```config.json```; setting it on the tokenizer does not help.
128 |
129 | 7. PPO/RLOO parameter explanation:
130 |
131 | See: https://github.com/huggingface/trl/issues/1740
132 |
133 | The ``num_train_epochs`` and ``num_ppo_epochs`` are actually two different things. The num_train_epochs means how many epochs we go over the dataset, while num_ppo_epochs means the number of epochs of PPO updates performed on a batch of data. So there is a subtle but meaningful difference here.
134 |
135 | 8. The CPO family does not support fp16; use bf16.
136 |
137 | #### GPU memory experiments
138 | res_length is 64
139 |
140 | | **RLHF** | **deepspeed** | **Mode** | **Reward Model** | **SFT Model** | **GPU memory** |
141 | |----------|---------------|--------|------------------|----------------|------------------------|
142 | | RLOO | Zero 3 | Lora | QWEN2(7B) | QWEN2(7B) | 2 x A100(40GB): 15~30G |
143 | | RLOO | Zero 3 | Full | QWEN2(7B) | QWEN2(7B) | 2 x A100(40GB): very slow |
144 | | RLOO | Zero 2 | Qlora | QWEN2(7B) | QWEN2(7B) | 2 x A100(40GB): 30~40G |
145 | | PPO | Zero 2 | Lora | MiniCPM(2B) | Deepseek(6.7B) | 2 x A100(40GB): OOM |
146 | | PPO | Zero 3 | Lora | MiniCPM(2B) | Deepseek(6.7B) | 2 x A100(40GB): 20-25G |
147 | | PPO | Zero 2 | Qlora | MiniCPM(2B) | Deepseek(6.7B) | 2 x A100(40GB): 30G |
148 |
149 | ## Knowledge Distillation
150 | Three types of knowledge distillation are currently supported; GKD works best:
151 | - Supervised KD (off-policy)
152 | - SeqKD (off-policy)
153 | - GKD (on-policy)
154 |
155 | For details, see the article: [Knowledge Distillation](https://zhuanlan.zhihu.com/p/1064724364)
156 |
157 | ### Quick Start
158 | Run ```gkd_run.sh``` with bash and adjust the parameters as needed. DeepSpeed is supported here as well.
159 |
160 |
161 | ```bash
162 | bash gkd_run.sh
163 | ```
164 |
165 | **Parameter guide**:
166 | - lmbda: 0 gives Supervised KD, 1 gives GKD. Any value in [0,1] mixes the two in that proportion.
167 | - beta: 0 makes the loss KLD, 1 makes it JSD. Any value in [0,1] mixes the two in that proportion.
168 | - seq_kd: when True, Supervised KD is replaced by SeqKD; defaults to False, everything else unchanged.
169 | - model_name_or_path: the student model, i.e. the model you want to train.
170 | - teacher_model_name_or_path: the teacher model, which is not trained.
171 |
172 | ## Rejected Sampling
173 | To be updated.
174 |
175 | ## Acknowledgements
176 |
177 | Special thanks to the Hugging Face trl project for its great work; with trl, implementing RLHF really does become simple and concise.
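178 |
179 | ## Appendix: what chat-template adaptation produces
180 |
181 | A rough, illustrative sketch of what happens to one DPO-style sample when a chat template is applied. This is not the exact trl code path, and the model name is only an example; any chat model with a chat template behaves similarly.
182 |
183 | ```python
184 | from transformers import AutoTokenizer
185 |
186 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
187 |
188 | sample = {
189 |     "prompt": [{"role": "user", "content": "How are you?"}],
190 |     "chosen": [{"role": "assistant", "content": "fine"}],
191 |     "rejected": [{"role": "assistant", "content": "no"}],
192 | }
193 |
194 | # The prompt is rendered with a generation prompt appended; chosen/rejected become
195 | # the assistant turns that the trainer scores against each other.
196 | prompt_text = tokenizer.apply_chat_template(sample["prompt"], tokenize=False, add_generation_prompt=True)
197 | chosen_text = tokenizer.apply_chat_template(sample["prompt"] + sample["chosen"], tokenize=False)
198 | rejected_text = tokenizer.apply_chat_template(sample["prompt"] + sample["rejected"], tokenize=False)
199 | print(prompt_text)
200 | print(chosen_text)
201 | ```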
--------------------------------------------------------------------------------
/rlhf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mst272/LLM-Dojo/6397eff480feca95f3b6016495b30d1ae308c9dc/rlhf/__init__.py
--------------------------------------------------------------------------------
/rlhf/common_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class CommonArgs:
7 | """
8 |     Common custom arguments
9 | """
10 | train_data_path: str = field(default='', metadata={"help": "训练数据路径"})
11 |     # Selection and configuration of the fine-tuning method
12 | rlhf_type: str = field(default="DPO",
13 | metadata={"help": "选择使用的RLHF方法,目前支持[PPO,RLOO,DPO,CPO,SimPO,CPOSimPO,Reward]"})
14 | train_mode: str = field(default='lora', metadata={"help": "选择采用的训练方式:[qlora, lora, full]"})
15 |
16 |     # Model / LoRA / QLoRA configuration
17 | model_name_or_path: str = './'
18 | use_dora: bool = field(default=False, metadata={"help": "仅在train_mode==lora时可以使用。是否使用Dora(一个基于Lora的变体)"})
19 | lora_rank: Optional[int] = field(default=32, metadata={"help": "lora rank"})
20 | lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"})
21 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"})
22 |
--------------------------------------------------------------------------------
/rlhf/ds_config/ds_zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: cpu
6 | zero3_init_flag: false
7 | zero_stage: 2
8 | distributed_type: DEEPSPEED
9 | downcast_bf16: 'no'
10 | machine_rank: 0
11 | main_training_function: main
12 | mixed_precision: 'bf16'
13 | num_machines: 1
14 | num_processes: 2
15 | rdzv_backend: static
16 | same_network: true
17 | tpu_env: []
18 | tpu_use_cluster: false
19 | tpu_use_sudo: false
20 | use_cpu: false
21 | main_process_port: 29501
--------------------------------------------------------------------------------
/rlhf/ds_config/ds_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: cpu
6 | offload_param_device: cpu
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 2
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 | main_process_port: 29501
--------------------------------------------------------------------------------
/rlhf/gkd_run.sh:
--------------------------------------------------------------------------------
1 | # Set the number of GPUs used by editing num_processes in the yaml config
2 |
3 | # LoRA mode; for QLoRA or full-parameter training, just tweak the parameters slightly
4 | CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_config/ds_zero3.yaml ./train_gkd.py \
5 | --model_name_or_path deepseek-coder-6.7b-instruct \
6 | --teacher_model_name_or_path deepseek-coder-33b-instruct \
7 | --dataset_name ../data/gkd_data.jsonl \
8 | --learning_rate 2e-5 \
9 | --per_device_train_batch_size 4 \
10 | --gradient_accumulation_steps 8 \
11 | --output_dir gkd-model2 \
12 | --logging_steps 2 \
13 | --num_train_epochs 1 \
14 | --gradient_checkpointing \
15 | --lmbda 0.5 \
16 | --beta 0.5 \
17 | --use_peft \
18 | --lora_r 32 \
19 | --lora_alpha 16 \
20 | --trust_remote_code \
21 | --bf16 \
22 | --save_strategy "steps" \
23 | --save_steps 180 \
24 | --save_total_limit 5 \
25 | --warmup_steps 10 \
26 | --lr_scheduler_type "cosine" \
27 | --torch_dtype bfloat16 > logs.log 2>&1 &
--------------------------------------------------------------------------------
/rlhf/rejected_sampling/README.md:
--------------------------------------------------------------------------------
1 | # Rejected Sampling
2 |
3 | ## 1、Generate
4 |
5 | ### Data format
6 |
7 | JSONL format with the following fields:
8 | - prompt
9 | - answer: may be an empty string; used as the reference answer
10 | ```json lines
11 | {"prompt":"Hellow","answer":"nice"}
12 | ```
13 |
14 | If your data is already in a messages format that can be passed straight to apply_chat_template, it can be used as-is without modification (it must contain no system field). The assistant reply is likewise treated as the reference answer, and you can later choose whether to include it among the candidates to be scored:
15 |
16 | ```json lines
17 | {"message": [{"role": "user", "content": "How many helicopters can a human eat in one sitting"},{"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together"}]}
18 | ```
19 |
20 | Output is saved in chunks; the generated files contain n times as many rows as the input (n completions per prompt), with the following fields:
21 | - messages: a messages-format record that can be fed directly into training via apply_chat_template, where the assistant turn is one of the n generations
22 | - model_completion: the completion generated by the model for this row
23 | - reference_completion: the reference answer from your original data
24 |
25 |
26 | ## 2、Rejected sampling evaluation stage
27 |
28 | Scoring and selection are currently only supported through an API-based judge; traditional classification reward models are rarely used any more, so support for them was removed.
29 |
30 | The parameters still need to be wired up; this is only a draft for now, and you must fill in the relevant API details yourself in the make_request function.
31 |
32 |
33 |
34 | ## Acknowledgements
35 | 1. The implementation borrows from open-instruct; thanks for open-sourcing it!
--------------------------------------------------------------------------------
/rlhf/rejected_sampling/genetate.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from dataclasses import dataclass, asdict
3 | import os
4 | import json
5 | from typing import Dict, List
6 |
7 | import torch
8 | from datasets import load_dataset, concatenate_datasets, Dataset
9 | from transformers import AutoTokenizer
10 | from vllm import LLM, SamplingParams
11 | from rlhf.utils.util import ArgumentParserPlus
12 |
13 |
14 | @dataclass
15 | class Args:
16 |     dataset_name: str = './test.jsonl'  # dataset path (file or directory)
17 | model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr"
18 | save_filename: str = "completions.jsonl"
19 | auto_adapt: bool = True # if apply chat template
20 | system: str = '' # chat template default
21 | chunk_size: int = 50000
22 |
23 |
24 | @dataclass
25 | class GenerationArgs:
26 | num_completions: int = 3
27 | temperature: float = 0.8
28 | response_length: int = 4096
29 | top_p: float = 0.9
30 | tensor_parallel_size: int = 1
31 | dtype: torch.dtype = torch.bfloat16
32 |
33 |
34 | def load_datasets(data_files: str, shuffle: bool):
35 | """
36 |     Load the dataset from a single jsonl file or from a directory of jsonl files.
37 |     """
38 |     if os.path.isfile(data_files):
39 |         # Single file: load it directly
40 |         if not data_files.endswith('.jsonl'):
41 |             raise ValueError(f"File '{data_files}' is not a JSONL file")
42 |         datasets = load_dataset("json", data_files=data_files)['train']
43 |     else:
44 |         # Directory: load every JSONL file in it and concatenate them
45 |         parts = []
46 |         jsonl_files = [f for f in os.listdir(data_files) if f.endswith('.jsonl')]
47 |         for file_name in jsonl_files:
48 |             # join with the directory, otherwise only the bare file name would be looked up
49 |             parts.append(load_dataset("json", data_files=os.path.join(data_files, file_name))['train'])
50 |         datasets = concatenate_datasets(parts)
51 |     if shuffle:
52 |         datasets = datasets.shuffle(seed=42)
53 |     return datasets
54 |
55 |
56 | def save_jsonl(save_filename: str, table: Dict[str, List]):
57 | first_key = list(table.keys())[0]
58 | os.makedirs(os.path.dirname(save_filename), exist_ok=True)
59 | with open(save_filename, "w") as outfile:
60 | for i in range(len(table[first_key])):
61 | json.dump({key: table[key][i] for key in table}, outfile)
62 | outfile.write("\n")
63 |
64 |
65 | def save_jsonl_in_chunks_to_files(base_filename: str, table: Dict[str, List], chunksize: int):
66 | """
67 | 将字典数据按指定的 chunksize 分块保存为多个 JSONL 文件。
68 |
69 | Args:
70 | base_filename: 保存的文件名的基本名称(不包含 chunk 编号)。
71 | table: 包含数据的字典,其中 values 是等长的列表。
72 | chunksize: 每个 chunk 文件保存的行数。
73 | """
74 | first_key = list(table.keys())[0]
75 | num_rows = len(table[first_key])
76 | os.makedirs(os.path.dirname(base_filename), exist_ok=True)
77 | chunk_number = 0
78 | for i in range(0, num_rows, chunksize):
79 | chunk_number += 1
80 | save_filename = f"{base_filename}_chunk_{chunk_number}.jsonl"
81 | with open(save_filename, "w") as outfile:
82 | for j in range(i, min(i + chunksize, num_rows)):
83 | json.dump({key: table[key][j] for key in table}, outfile)
84 | outfile.write("\n")
85 |
86 |
87 | def generate_with_vllm(model_name_or_path: str, prompt_token_ids: List[int], gen_args: GenerationArgs):
88 | llm = LLM(
89 | model=model_name_or_path,
90 | tensor_parallel_size=gen_args.tensor_parallel_size,
91 | max_model_len=gen_args.response_length,
92 |         dtype=gen_args.dtype
93 | )
94 |
95 | # filter out prompts which are beyond the model's max token length
96 | max_model_len = llm.llm_engine.scheduler_config.max_model_len
97 | prompt_token_ids_len = len(prompt_token_ids)
98 | prompt_token_ids = [item for item in prompt_token_ids if len(item) < max_model_len]
99 | if len(prompt_token_ids) != prompt_token_ids_len:
100 | print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceeds max token length")
101 |
102 | outputs = llm.generate(
103 | prompt_token_ids=prompt_token_ids,
104 | sampling_params=SamplingParams(
105 | n=gen_args.num_completions,
106 | temperature=gen_args.temperature,
107 |             top_p=gen_args.top_p,
108 | max_tokens=gen_args.response_length,
109 | include_stop_str_in_output=True,
110 | ),
111 | )
112 |
113 | return [
114 | {
115 | "outputs": [asdict(out) for out in output.outputs],
116 | "prompt": output.prompt,
117 | "prompt_logprobs": output.prompt_logprobs,
118 | "metrics": output.metrics,
119 | }
120 | for output in outputs
121 | ]
122 |
123 |
124 | def tokenize(dataset: Dataset, auto_adapt: bool, system: str, tokenizer):
125 | def tokenize_fn(row):
126 | answer = row['answer'] if 'answer' in row else row['messages'][1]['content']
127 | prompt = row['prompt'] if 'prompt' in row else row['messages'][0]['content']
128 |         messages = []
129 |         # A system message (if provided) must come first, before the user turn
130 |         if system:
131 |             messages.append({"role": "system", "content": system})
132 |         messages.append({"role": "user", "content": prompt})
133 | outputs = tokenizer.apply_chat_template(
134 | messages,
135 | tokenize=True,
136 | add_generation_prompt=True
137 | )
138 | return {"input_ids": outputs, "prompt": prompt, "answer": answer}
139 |
140 | def tokenize_fn_origin(row):
141 | prompt = row['prompt'] if 'prompt' in row else row['messages'][0]['content']
142 | answer = row['answer'] if 'answer' in row else row['messages'][1]['content']
143 | outputs = tokenizer.encode(prompt)
144 | return {"input_ids": outputs, "prompt": prompt, "answer": answer}
145 |
146 | return dataset.map(
147 | tokenize_fn if auto_adapt else tokenize_fn_origin,
148 | desc="Tokenizing and reformatting rejected sampling data",
149 | )
150 |
151 |
152 | def main(args: Args, gen_args: GenerationArgs):
153 | dataset = load_datasets(data_files=args.dataset_name, shuffle=True)
154 |
155 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
156 | dataset = tokenize(dataset=dataset, auto_adapt=args.auto_adapt, system=args.system, tokenizer=tokenizer)
157 | prompt_token_ids = dataset['input_ids']
158 | outputs = generate_with_vllm(args.model_name_or_path, prompt_token_ids, gen_args)
159 |
160 | # Assuming we generate n=3 completions per prompt; the outputs will look like:
161 | # prompt | completions
162 | # -------|------------
163 | # q1 | a1
164 | # q1 | a2
165 | # q1 | a3
166 | # q2 | a1
167 | # ...
168 | table = defaultdict(list)
169 | num_prompt_with_identical_completions = 0
170 | for output, answer, prompt in zip(outputs, dataset["answer"], dataset['prompt']):
171 | # if the model completions are exactly the same across all completions per prompt, we can skip this
172 | if len(set(tuple(item["text"]) for item in output["outputs"])) == 1:
173 | num_prompt_with_identical_completions += 1
174 | continue
175 |
176 | for item in output["outputs"]:
177 | new_messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": item["text"]}]
178 | table["messages"].append(new_messages)
179 | table["model_completion"].append(item["text"])
180 | table["reference_completion"].append(answer)
181 |
182 | print(f"Number prompts with identical completions: {num_prompt_with_identical_completions}")
183 | # save_jsonl(args.save_filename, table)
184 | save_jsonl_in_chunks_to_files(args.save_filename, table, args.chunk_size)
185 |
186 |
187 | if __name__ == "__main__":
188 | parser = ArgumentParserPlus((Args, GenerationArgs))
189 | main(*parser.parse())
190 |
--------------------------------------------------------------------------------
/rlhf/rejected_sampling/rejected_sampling.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | from datetime import datetime
4 | import re
5 | from tqdm.auto import tqdm
6 | import string
7 | import json
8 | from typing import List, Dict
9 | import glob
10 | import orjson
11 | import os
12 |
13 |
14 | def load_completions(input_path: str) -> List[dict]:
15 | """
16 |     Load completions data from a JSONL file or a directory (using orjson).
17 |
18 |     Args:
19 |         input_path: a file path or a directory path.
20 |
21 |     Returns:
22 |         A list containing all loaded completions.
23 | """
24 | all_completions = []
25 | if os.path.isfile(input_path):
26 |         with open(input_path, 'rb') as f:  # note the binary read mode 'rb'
27 | for line in f:
28 | all_completions.append(orjson.loads(line))
29 | elif os.path.isdir(input_path):
30 | for filename in glob.glob(os.path.join(input_path, '*.jsonl')):
31 | print(f"正在加载文件: {filename}")
32 |             with open(filename, 'rb') as f:  # note the binary read mode 'rb'
33 | for line in f:
34 | all_completions.append(orjson.loads(line))
35 | else:
36 | print(f"错误: 输入路径 '{input_path}' 不是有效的文件或文件夹。")
37 | return all_completions
38 |
39 |
40 | def reason_post_process(code, index):
41 | """
42 | Args:
43 |         code (str): the input string.
44 |         index (int/str): index of the current string (used only for logging).
45 |
46 |     Returns:
47 |         str: the content of the last fenced code block if one is found;
48 |              otherwise the input string is returned unchanged.
49 | """
50 |
51 | # Look for code blocks
52 | code_pattern = r'```(?:python|go|ts|php|csharp|bash|javascript|cpp|cs|java)(.*?)```'
53 | code_match = re.findall(code_pattern, code, re.DOTALL)
54 |
55 | if code_match:
56 | # If code block exists, return its content (excluding the ``` markers)
57 | return code_match[-1].strip()
58 | else:
59 | # If no code block, return the solution content directly
60 | print('---', index)
61 | # print(code)
62 | return code
63 |
64 |
65 | def create_dynamic_comparison_prompt(prompt: str, responses: list[str]) -> str:
66 | """
67 |     Given a prompt and a list of candidate code responses, build a dynamic
68 |     comparison prompt for judging those code snippets against each other.
69 |
70 |     Args:
71 |         prompt: the problem description string.
72 |         responses: a list of code snippets (strings).
73 |
74 |     Returns:
75 |         A fully formatted prompt string for LLM-based evaluation.
76 |         If responses is empty, an error message is returned instead.
77 | """
78 | if not responses:
79 | return "错误:未提供任何代码响应进行比较。"
80 |
81 | num_responses = len(responses)
82 |
83 |     # 1. Build the static header of the template
84 |     # The introduction is adapted to handle multiple snippets
85 | header = f"""Compare the following {num_responses} code snippets that aim to solve the given problem.
86 | Evaluate each snippet based on efficiency, readability, and adherence to best practices.
87 | Identify the preferred snippet or rank them if applicable.
88 |
89 | ### Problem:
90 | {prompt}
91 | """
92 |
93 |     # 2. Dynamically build the section for each code snippet
94 |     snippets_section = ""
95 |     for i, response in enumerate(responses):
96 |         # Generate labels: Code A, Code B, ...
97 |         if i < 26:  # supports up to Z
98 |             label = string.ascii_uppercase[i]
99 |         else:  # beyond 26, fall back to numeric labels
100 |             label = str(i + 1)
101 |
102 |         snippets_section += f"\n### Code {label}:\n{response}\n"  # newlines around each code block
103 |
104 |     # 3. Build the static footer of the template
105 | footer = """
106 | Code Analysis (Provide a brief analysis for each snippet, discussing its pros and cons regarding efficiency, readability, and best practices):
107 |
108 | Preferred Code (Output only the single letter label of the most preferred code snippet in 【】 below this line, e.g., 【answer here】):
109 | """
110 |
111 |     # 4. Combine all parts
112 | full_prompt = header + snippets_section + footer
113 | return full_prompt
114 |
115 |
116 | async def make_request(session: aiohttp.ClientSession, prompt: str, index: int, api_key: str, post_url: str,
117 | cookie: str) -> Dict:
118 | url = post_url
119 | headers = {
120 | "Authorization": api_key,
121 | "Content-Type": "application/json",
122 | "Cookie": cookie
123 | }
124 |
125 | payload = {
126 | "stream": False,
127 | "model": "default",
128 | "messages": [
129 | {
130 | "role": "user",
131 | "content": prompt
132 | }
133 | ],
134 | "max_tokens": 4096,
135 | "temperature": 0.0,
136 | "n": 1
137 | }
138 |
139 | try:
140 | async with session.post(url, headers=headers, json=payload, timeout=1000) as response:
141 | response.raise_for_status()
142 | json_response = await response.json()
143 | return {
144 | 'index': index,
145 | 'status': 'success',
146 | 'prompt': prompt,
147 | 'response': reason_post_process(json_response['choices'][0]['message']['content'], index)
148 | }
149 | except aiohttp.ClientError as e:
150 | return {
151 | 'index': index,
152 | 'status': 'error',
153 | 'prompt': prompt,
154 | 'response': f"请求失败:{str(e)}"
155 | }
156 | except json.JSONDecodeError as e:
157 | return {
158 | 'index': index,
159 | 'status': 'error',
160 | 'prompt': prompt,
161 | 'response': f"JSON解析错误:{str(e)}"
162 | }
163 |
164 |
165 | def extract_answer(text):
166 | """
167 |     Extract the content inside the first 【】 pair in the string and return it.
168 |
169 |     Args:
170 |         text: the input string.
171 |
172 |     Returns:
173 |         The string inside the first 【】 pair, or None if there is no match.
174 | """
175 | pattern = r'【(.*?)】'
176 | match = re.search(pattern, text)
177 | if match:
178 | return match.group(1)
179 | else:
180 | return None
181 |
182 |
183 | ### Optimised path
184 | async def process_group(session: aiohttp.ClientSession, group: dict) -> dict:
185 | """处理单个分组的异步函数"""
186 | prompt = group['prompt']
187 | responses = group['responses']
188 | full_prompt = create_dynamic_comparison_prompt(prompt, responses)
189 |
190 |     # Send the request with the already-open session
191 |     result = await make_request(session, full_prompt, 0)  # the index does not matter here; api_key, post_url and cookie still need to be wired in (see README)
192 | selected_label = extract_answer(result['response'])
193 |
194 | return {
195 | 'prompt': prompt,
196 | 'responses': responses,
197 | 'indices': group['indices'],
198 | 'selected_label': selected_label
199 | }
200 |
201 |
202 | async def process_comparisons_async(completions, num_per_group=3, max_concurrent=5):
203 | start_time = datetime.now()
204 | grouped_data = []
205 | current_group = {
206 | 'prompt': None,
207 | 'responses': [],
208 | 'indices': []
209 | }
210 |
211 |     # 1. Group the data and record the original indices
212 | for idx, item in enumerate(completions):
213 | prompt = item['messages'][0]['content']
214 | response = item['messages'][1]['content']
215 |
216 | if idx % num_per_group == 0 and idx != 0:
217 | grouped_data.append(current_group)
218 | current_group = {
219 | 'prompt': prompt,
220 | 'responses': [response],
221 | 'indices': [idx]
222 | }
223 | else:
224 | if not current_group['responses']:
225 | current_group['prompt'] = prompt
226 | current_group['responses'].append(response)
227 | current_group['indices'].append(idx)
228 |
229 |     # Handle the final group
230 | if current_group['responses']:
231 | grouped_data.append(current_group)
232 |
233 | print(f"总共需要处理 {len(grouped_data)} 个分组")
234 |
235 |     # 2. Process all groups concurrently
236 |     async with aiohttp.ClientSession() as session:
237 |         tasks = []
238 |         # Use a semaphore to cap the number of concurrent requests
239 |         semaphore = asyncio.Semaphore(max_concurrent)
240 |
241 |         # Keep the counters in the outer scope
242 |         success_count = 0
243 |         error_count = 0
244 |
245 |         # Nested helper structured so that nonlocal works correctly
246 |         async def process_with_semaphore(group, pbar):
247 |             nonlocal success_count, error_count  # the nonlocal declaration must sit here
248 |
249 | async with semaphore:
250 | try:
251 | result = await process_group(session, group)
252 | success_count += 1
253 | pbar.update(1)
254 | pbar.set_postfix({'成功': success_count, '失败': error_count})
255 | return result
256 | except Exception as e:
257 | error_count += 1
258 | pbar.update(1)
259 | pbar.set_postfix({'成功': success_count, '失败': error_count})
260 | return {
261 | 'prompt': group['prompt'],
262 | 'responses': group['responses'],
263 | 'indices': group['indices'],
264 | 'selected_label': None,
265 | 'error': str(e)
266 | }
267 |
268 | with tqdm(total=len(grouped_data), desc="处理进度") as pbar:
269 | tasks = [process_with_semaphore(group, pbar) for group in grouped_data]
270 | results = await asyncio.gather(*tasks)
271 |
272 |     # Print summary statistics
273 | end_time = datetime.now()
274 | print(f"\n处理完成时间: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
275 | print(f"处理结果统计:")
276 | print(f"- 成功: {success_count}")
277 | print(f"- 失败: {error_count}")
278 | print(f"- 总计: {len(grouped_data)}")
279 | print(f"- 耗时: {end_time - start_time}")
280 |
281 |     # 3. Write the results back into the original data
282 | for result in results:
283 | if result.get('error'):
284 |             # Handle the error case
285 | for original_idx in result['indices']:
286 | completions[original_idx]['comparison'] = {
287 | 'error': result['error']
288 | }
289 | continue
290 |
291 | selected_label = result['selected_label'].strip().upper() if result['selected_label'] else None
292 | for pos, original_idx in enumerate(result['indices']):
293 | label = string.ascii_uppercase[pos]
294 | completions[original_idx]['comparison'] = {
295 | 'group_prompt': result['prompt'],
296 | 'position_label': label,
297 | 'is_best': label == selected_label if selected_label else False,
298 | 'best_label': selected_label,
299 | 'compared_with': len(result['responses'])
300 | }
301 |
302 | return completions
303 |
304 |
305 | def _save_chunk_to_jsonl(chunk: List[dict], output_filename: str) -> bool:
306 | """
307 |     Save one chunk of data as a JSONL file (using orjson).
308 |
309 |     Args:
310 |         chunk: the chunk of data to save.
311 |         output_filename: the output file name.
312 |
313 |     Returns:
314 |         True if the save succeeded, False if an error occurred.
315 | """
316 | try:
317 |         with open(output_filename, 'wb') as f:  # note the binary write mode 'wb'
318 |             for item in chunk:
319 |                 f.write(orjson.dumps(item) + b'\n')  # orjson.dumps returns bytes
320 | return True
321 | except Exception as e:
322 | print(f"保存 chunk 到文件 '{output_filename}' 时发生错误: {str(e)}")
323 | return False
324 |
325 |
326 | def save_results_in_chunks(completions: List[dict], output_prefix: str = 'output', chunksize: int = 1000):
327 | """
328 |     Save the processed results as multiple JSONL files, split into chunks of chunksize (using orjson).
329 |
330 |     Args:
331 |         completions: the full list of processed data.
332 |         output_prefix: prefix for the output file names.
333 |         chunksize: number of records saved per file.
334 | """
335 |     # Extract the directory part of the prefix
336 |     output_dir = os.path.dirname(output_prefix)
337 |     # Create the directory if it does not exist
338 | if output_dir and not os.path.exists(output_dir):
339 | os.makedirs(output_dir, exist_ok=True)
340 |
341 | num_chunks = (len(completions) + chunksize - 1) // chunksize
342 | for i in range(num_chunks):
343 | start_index = i * chunksize
344 | end_index = min((i + 1) * chunksize, len(completions))
345 | chunk = completions[start_index:end_index]
346 | output_filename = f"{output_prefix}_part_{i + 1}.jsonl"
347 | if _save_chunk_to_jsonl(chunk, output_filename):
348 | print(f"已保存 chunk {i + 1} 到文件: {output_filename}")
349 |
350 |
351 | if __name__ == "__main__":
352 | completions = load_completions('/rejected')
353 | # completions = completions[:10]
354 | processed_data = asyncio.run(process_comparisons_async(completions, 3, 60))
355 | output_file = './v3_eval/v3eval'
356 |
357 | save_results_in_chunks(processed_data, output_file, 16000)
358 |
--------------------------------------------------------------------------------
/rlhf/rejected_sampling/run_generate.sh:
--------------------------------------------------------------------------------
1 | dataset_name=''
2 | model_name_or_path=''
3 | save_filename='rejected_generate'
4 |
5 | nohup python genetate.py \
6 | --dataset_name "$dataset_name" \
7 | --model_name_or_path "$model_name_or_path" \
8 | --save_filename "$save_filename" \
9 | --auto_adapt True \
10 | --num_completions 3 \
11 | --temperature 0.8 \
12 | --response_length 4096 \
13 | --top_p 0.9 \
14 | --tensor_parallel_size 8 \
15 | --chunk_size 50000
--------------------------------------------------------------------------------
/rlhf/rejected_sampling/template.py:
--------------------------------------------------------------------------------
1 | # prompt_templates.py
2 |
3 | DEFAULT_SKILL = "summarization"
4 |
5 | GENERATION_TEMPLATES = {
6 | "chat": """
7 | You are an expert assistant and your goal is to provide the most helpful and accurate response to the following chat. The chat may be between two users, User A and User B, or it can be a single turn from User A. Please ensure that your response is clear, concise, and addresses all aspects of the prompt.
8 | Don't answer by simply one word. Try to make your answer diverse and interesting.
9 |
10 | ### Prompt:
11 | {prompt}
12 |
13 | ### Response:
14 | """,
15 | "summarization": """
16 | Please provide a concise summary of the following text, highlighting the most important points without including unimportant or irrelevant details.
17 |
18 | ### Text to Summarize:
19 | {prompt}
20 |
21 | Summary:
22 | """,
23 | "code_generation": """
24 | Please write a Python function that solves the following problem. Ensure the code is efficient, readable, and follows best practices.
25 |
26 | ### Problem:
27 | {prompt}
28 |
29 | Python Code:
30 | """,
31 | "safety": """
32 | Please provide a safe and appropriate response to the following scenario or question. Ensure your response adheres to ethical guidelines and promotes user safety.
33 |
34 | ### Scenario:
35 | {prompt}
36 |
37 | Safe Response:
38 | """,
39 | }
40 |
41 | JUDGMENT_TEMPLATES = {
42 | "chat": """
43 | You are an evaluator tasked with assessing the response to the prompt based on three criteria: accuracy, instruction following, and tone/presentation.
44 | For each response provided to the prompt, assign a score from 1 to 3 for each criterion, where 1 is poor, 2 is average, and 3 is excellent.
45 | Don't give full scores unless you're mind blown by the response.
46 |
47 | ### Prompt:
48 | {prompt}
49 |
50 | ### Response:
51 | {response}
52 |
53 | ### Evaluation Criteria:
54 | 1. **Accuracy**: How factually correct and relevant is the response?
55 | 2. **Instruction Following**: How well does the response adhere to the given instructions?
56 | 3. **Tone/Presentation**: How clear, appropriate, and well-structured is the response?
57 |
58 | ### Scoring Template:
59 | Response:
60 | - **Accuracy**: [Score 1-3]
61 | - **Instruction Following**: [Score 1-3]
62 | - **Tone/Presentation**: [Score 1-3]
63 |
64 | Total score:
65 | """,
66 | "summarization": """
67 | Compare the following summaries of the given text. Which one does a better job of summarizing the most important points without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence.
68 |
69 | ### Original Text:
70 | {prompt}
71 |
72 | ### Summary A:
73 | {response1}
74 |
75 | ### Summary B:
76 | {response2}
77 |
78 | Comparison:
79 | Preferred Summary:
80 | """,
81 | "code_generation": """
82 | Compare the following two code snippets that aim to solve the given problem. Which one is more efficient, readable, and follows best practices?
83 |
84 | ### Problem:
85 | {prompt}
86 |
87 | ### Code A:
88 | {response1}
89 |
90 | ### Code B:
91 | {response2}
92 |
93 | Code Analysis:
94 | Preferred Code:
95 | """,
96 | "safety": """
97 | Evaluate the safety implications of the following two responses to the given scenario. Which one is safer and more appropriate?
98 |
99 | ### Scenario:
100 | {prompt}
101 |
102 | ### Response A:
103 | {response1}
104 |
105 | ### Response B:
106 | {response2}
107 |
108 | Safety Analysis:
109 | Preferred Response:
110 | """,
111 | }
112 |
113 |
114 | def get_generation_template(skill: str) -> str:
115 | return GENERATION_TEMPLATES.get(skill, GENERATION_TEMPLATES[DEFAULT_SKILL])
116 |
117 |
118 | def get_judgment_template(skill: str) -> str:
119 | return JUDGMENT_TEMPLATES.get(skill, JUDGMENT_TEMPLATES[DEFAULT_SKILL])
120 |
--------------------------------------------------------------------------------
/rlhf/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.1.1
2 | datasets==3.1.0
3 | peft==0.13.2
4 | transformers==4.46.3
5 | trl==0.12.1
6 | deepspeed==0.15.4
7 | torch==2.5.1
--------------------------------------------------------------------------------
/rlhf/rlhf_args/base_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from transformers import TrainingArguments
3 |
4 |
5 | @dataclass
6 | class BaseConfig(TrainingArguments):
7 | """
8 |     Training arguments
9 | """
10 | output_dir: str = field(default='./output', metadata={"help": "模型训练完成后的保存路径"})
11 |     num_train_epochs: int = 1  # no trailing comma: it would turn the default into a tuple
12 |
13 | per_device_train_batch_size: int = 2
14 | gradient_checkpointing: bool = True
15 |     gradient_accumulation_steps: int = 16
16 |
17 | learning_rate: float = 2e-4
18 | logging_steps: int = 10
19 | save_steps: int = 500
20 | save_strategy: str = "steps"
21 | save_total_limit: int = 2
22 | lr_scheduler_type: str = "cosine",
23 | warmup_steps: int = 10
24 | optim: str = 'adamw_torch'
25 | report_to: str = 'tensorboard'
26 | remove_unused_columns: bool = False
27 | bf16: bool = False
28 | fp16: bool = False
--------------------------------------------------------------------------------
/rlhf/rlhf_args/cpo-simpo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal
3 | from cpo_config import CPOConfig
4 |
5 |
6 | @dataclass
7 | class CPOSimPOConfig(CPOConfig):
8 | """
9 |     Based on CPOConfig: just set loss_type to "simpo" and keep cpo_alpha non-zero.
10 | """
11 | loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "simpo"
12 | """The type of loss to use."""
13 | cpo_alpha: float = 0.5
14 |     """A non-zero cpo_alpha enables the combined use of CPO and SimPO, which gives more stable training and
15 |     improved performance."""
16 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/cpo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from rlhf_args.base_config import BaseConfig
3 | from trl import CPOConfig as TrlCPOConfig
4 |
5 |
6 | @dataclass
7 | class CPOConfig(BaseConfig, TrlCPOConfig):
8 | pass
9 |
10 |
11 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/dpo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from rlhf_args.base_config import BaseConfig
3 | from typing import Literal
4 | from trl import DPOConfig as TrlDPOConfig
5 |
6 |
7 | @dataclass
8 | class DPOConfig(BaseConfig, TrlDPOConfig):
9 | """
10 | 训练参数, 可直接在此修改. 想看更多参数可直接在TrlDPOConfig中去看
11 | """
12 | beta: float = 0.1
13 | label_smoothing: float = 0.0
14 | loss_type: Literal[
15 | "sigmoid",
16 | "hinge",
17 | "ipo",
18 | "exo_pair",
19 | "nca_pair",
20 | "robust",
21 | "bco_pair",
22 | "sppo_hard",
23 | "aot",
24 | "aot_pair",
25 | "apo_zero",
26 | "apo_down",
27 | ] = "sigmoid"
28 | label_pad_token_id: int = -100
29 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/kto_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from rlhf_args.base_config import BaseConfig
3 | from trl import KTOConfig as TrlKTOConfig
4 |
5 |
6 | @dataclass
7 | class KTOConfig(BaseConfig, TrlKTOConfig):
8 | desirable_weight: float = 1.0
9 | undesirable_weight: float = 1.0
10 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/ppo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from trl import PPOConfig as TrlPPOConfig
3 | from rlhf_args.base_config import BaseConfig
4 |
5 |
6 | @dataclass
7 | class PPOConfig(BaseConfig, TrlPPOConfig):
8 |     eval_samples: int = 30  # eval数量
9 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/reward_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from rlhf_args.base_config import BaseConfig
3 | from trl import RewardConfig as TrlRewardConfig
4 |
5 |
6 | @dataclass
7 | class RewardConfig(BaseConfig, TrlRewardConfig):
8 | pass
9 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/rloo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from rlhf_args.base_config import BaseConfig
4 | from trl import RLOOConfig as TrlRLOOConfig
5 |
6 |
7 | # 支持直接通过total_episodes确定训练步数,也支持通过在TrainingArguments中配置num_train_epochs确定训练步数。
8 | @dataclass
9 | class RLOOConfig(BaseConfig, TrlRLOOConfig):
10 | reward_model_path: str = "./"
11 | sft_model_path: str = "./"
12 | total_episodes: Optional[int] = None
13 |     eval_samples: int = 30  # eval数量
14 |
--------------------------------------------------------------------------------
/rlhf/rlhf_args/simpo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal
3 | from cpo_config import CPOConfig
4 |
5 |
6 | @dataclass
7 | class SimPOConfig(CPOConfig):
8 | """
9 | 基于CPOConfig,只需修改CPO的loss type为simpo,cpo_alpha设为0即可
10 | """
11 | loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "simpo"
12 | """The type of loss to use."""
13 | cpo_alpha: float = 0
14 | """A hyperparameter that controls the strength of the BC regularizer in CPO training."""
15 | simpo_gamma: float = 0.5
16 | """A target reward margin for the SimPO loss, used only when the "simpo" option is enabled."""
17 |
18 |
--------------------------------------------------------------------------------
/rlhf/rlhf_run.sh:
--------------------------------------------------------------------------------
1 | # 使用显卡数量需在yaml文件中修改num_processes参数
2 |
3 | # rlhf_type:[PPO,RLOO,CPO,DPO,SimPO,CPOSimPO,Reward]
4 | # train_mode:[lora, qlora, full]
5 |
6 | TRAIN_DATA='./'
7 | MODEL_PATH='./'
8 | OUTPUT_PATH='./'
9 |
10 | CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_config/ds_zero2.yaml ./train_rlhf.py \
11 | --model_name_or_path "$MODEL_PATH" \
12 | --train_data_path "$TRAIN_DATA" \
13 | --output_dir "$OUTPUT_PATH" \
14 | --rlhf_type "DPO" \
15 | --train_mode "lora" \
16 | --learning_rate 2e-5 \
17 | --per_device_train_batch_size 2 \
18 | --gradient_checkpointing \
19 | --gradient_accumulation_steps 8 \
20 | --logging_steps 2 \
21 | --num_train_epochs 1 \
22 | --bf16 \
23 | --save_strategy "steps" \
24 | --report_to "wandb" \
25 | --save_steps 180 \
26 | --save_total_limit 5 \
27 | --warmup_steps 10 \
28 |   --remove_unused_columns False \
29 | --lr_scheduler_type "cosine"
30 |
31 | # [CPO,DPO,SimPO,CPOSimPO,Reward] 可直接使用上述运行
32 |
33 | # [PPO,RLOO] 需要额外添加如下参数:
34 | # --reward_model_path './'\
35 | # --local_rollout_forward_batch_size 1\
36 | # --missing_eos_penalty 1.0\
37 | # --num_ppo_epochs 1 \
38 | # --num_mini_batches 1
--------------------------------------------------------------------------------
/rlhf/train_gkd.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import random
3 | from transformers import AutoTokenizer
4 | from trl import (
5 | GKDConfig,
6 | GKDTrainer,
7 | LogCompletionsCallback,
8 | ModelConfig,
9 | ScriptArguments,
10 | TrlParser,
11 | get_kbit_device_map,
12 | get_peft_config,
13 | get_quantization_config,
14 | )
15 | from accelerate import PartialState
16 |
17 |
18 | if __name__ == "__main__":
19 | parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
20 | args, training_args, model_config = parser.parse_args_and_config()
21 |
22 | ################
23 | # Model & Tokenizer
24 | ################
25 | quantization_config = get_quantization_config(model_config)
26 | model_kwargs = dict(
27 | revision=model_config.model_revision,
28 | trust_remote_code=model_config.trust_remote_code,
29 | attn_implementation=model_config.attn_implementation,
30 | torch_dtype=model_config.torch_dtype,
31 | use_cache=False if training_args.gradient_checkpointing else True,
32 | device_map=get_kbit_device_map() if quantization_config is not None else None,
33 | quantization_config=quantization_config,
34 | )
35 | training_args.model_init_kwargs = model_kwargs
36 |
37 | teacher_model_kwargs = dict(
38 | revision=model_config.model_revision,
39 | trust_remote_code=model_config.trust_remote_code,
40 | attn_implementation=model_config.attn_implementation,
41 | torch_dtype=model_config.torch_dtype,
42 | use_cache=True,
43 | device_map=get_kbit_device_map() if quantization_config is not None else None,
44 | quantization_config=quantization_config,
45 | )
46 | training_args.teacher_model_init_kwargs = teacher_model_kwargs
47 |
48 | tokenizer = AutoTokenizer.from_pretrained(
49 | model_config.model_name_or_path,
50 | trust_remote_code=model_config.trust_remote_code,
51 | padding_side="left",
52 | )
53 | if tokenizer.pad_token is None:
54 | tokenizer.pad_token = tokenizer.eos_token
55 |
56 | ################
57 | # Dataset
58 | ################
59 | dataset = load_dataset(data_files=args.dataset_name, path='json') # 适配jsonl格式
60 |
61 | with PartialState().local_main_process_first():
62 | dataset = dataset.map(
63 | lambda x: {
64 | "prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True)
65 | },
66 | num_proc=training_args.dataset_num_proc,
67 | )
68 | train_data = dataset['train']
69 | test_data = train_data.select(random.sample(range(len(train_data)), 20))
70 |
71 | ################
72 | # Training
73 | ################
74 | trainer = GKDTrainer(
75 | model=model_config.model_name_or_path,
76 | teacher_model=training_args.teacher_model_name_or_path,
77 | args=training_args,
78 | train_dataset=dataset[args.dataset_train_split],
79 | eval_dataset=test_data,
80 | processing_class=tokenizer,
81 | peft_config=get_peft_config(model_config),
82 | )
83 | completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
84 | trainer.add_callback(completions_callback)
85 | trainer.train()
86 |
87 | # Save
88 | trainer.save_model(training_args.output_dir)
89 |
--------------------------------------------------------------------------------
/rlhf/train_rlhf.py:
--------------------------------------------------------------------------------
1 | import deepspeed
2 | deepspeed.ops.op_builder.CPUAdamBuilder().load()
3 | import importlib
4 | import os
5 | from peft import LoraConfig, TaskType
6 | from datasets import load_dataset
7 | from transformers import (
8 | AutoModelForCausalLM,
9 | AutoTokenizer,
10 | AutoModelForSequenceClassification,
11 | HfArgumentParser,
12 | BitsAndBytesConfig,
13 | )
14 | import torch
15 | from accelerate import PartialState
16 | from trl import DPOTrainer, CPOTrainer, PPOTrainer, RLOOTrainer, RewardTrainer, KTOTrainer
17 | from common_args import CommonArgs
18 | from utils.util import find_all_linear_names
19 | from loguru import logger
20 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
21 |
22 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
23 |
24 | WITH_REWARD_MODEL = ['RLOO', 'PPO']
25 | USE_REF_MODEL = ['DPO', 'RLOO', 'KTO']
26 |
27 | trainer_map = {
28 | 'PPO': PPOTrainer,
29 | "RLOO": RLOOTrainer,
30 | "DPO": DPOTrainer,
31 | "CPO": CPOTrainer,
32 | "SimPO": CPOTrainer,
33 | "CPOSimPO": CPOTrainer,
34 | 'Reward': RewardTrainer,
35 | "KTO": KTOTrainer
36 | }
37 |
38 | train_args_path = {
39 | 'PPO': 'rlhf_args/ppo_config.py',
40 | "RLOO": 'rlhf_args/rloo_config.py',
41 | "DPO": 'rlhf_args/dpo_config.py',
42 | "CPO": 'rlhf_args/cpo_config.py',
43 | "SimPO": 'rlhf_args/simpo_config.py',
44 | "CPOSimPO": 'rlhf_args/cpo-simpo_config.py',
45 | 'Reward': 'rlhf_args/reward_config.py',
46 | 'KTO': 'rlhf_args/kto_config.py'
47 | }
48 |
49 |
50 | def load_config(args, remaining_args):
51 | # 根据config_option加载相应的配置
52 |     module_path = train_args_path[args.rlhf_type].replace("/", ".").removesuffix(".py")
53 | # 动态导入模块
54 | module = importlib.import_module(module_path)
55 |     # 每个模块中的配置类名均为 rlhf_type + "Config",例如 DPO 对应 DPOConfig
56 | class_name = args.rlhf_type + "Config"
57 | # 使用getattr获取模块中的类
58 | argument = getattr(module, class_name)
59 |
60 | parser_b = HfArgumentParser((argument,))
61 | train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args)
62 | return train_args
63 |
64 |
65 | def load_judge_reward():
66 | pass
67 |
68 |
69 | def load_classification_reward(path, model_kwargs):
70 | try:
71 | reward_model = AutoModelForSequenceClassification.from_pretrained(path, num_labels=1,
72 | **model_kwargs)
73 | return reward_model
74 | except Exception as e:
75 |         raise ValueError("模型不支持AutoModelForSequenceClassification,需要在对应config文件中添加映射") from e
76 |
77 |
78 | def load_tokenizer(path):
79 | tokenizer = AutoTokenizer.from_pretrained(
80 | path,
81 | padding_side="left",
82 | trust_remote_code=True,
83 | )
84 |
85 | if tokenizer.pad_token is None:
86 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
87 | return tokenizer
88 |
89 |
90 | def prepare_dataset(dataset, tokenizer):
91 | """pre-tokenize the dataset before training; only collate during training"""
92 |
93 | def tokenize(element):
94 | outputs = tokenizer(
95 | element['prompt'],
96 | padding=False,
97 | )
98 | return {"input_ids": outputs["input_ids"]}
99 |
100 | return dataset.map(
101 | tokenize,
102 | batched=True
103 | )
104 |
105 |
106 | def main():
107 | parser = HfArgumentParser((CommonArgs,))
108 | args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
109 | # 根据CommonArgs中的config_option动态加载配置
110 | config = load_config(args, remaining_args)
111 |
112 | ################
113 | # Data
114 | ################
115 | train_dataset = load_dataset(data_files=args.train_data_path, path='json')
116 |
117 | ################
118 | # Model & Tokenizer
119 | ################
120 | tokenizer = load_tokenizer(args.model_name_or_path)
121 |
122 | model_kwargs = dict(
123 | trust_remote_code=True,
124 | torch_dtype=torch.float16 if config.fp16 else torch.bfloat16,
125 | )
126 |
127 | if args.train_mode == 'qlora':
128 | quantization_config = BitsAndBytesConfig(
129 | load_in_4bit=True,
130 | bnb_4bit_compute_dtype=torch.float16 if config.fp16 else torch.bfloat16,
131 | bnb_4bit_use_double_quant=True,
132 | bnb_4bit_quant_type="nf4",
133 | llm_int8_threshold=6.0,
134 | llm_int8_has_fp16_weight=False,
135 | )
136 | model_kwargs.update(quantization_config=quantization_config)
137 |
138 | # 加载policy model
139 | if args.rlhf_type == 'Reward':
140 | policy = load_classification_reward(args.model_name_or_path, model_kwargs)
141 | else:
142 | policy = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
143 |     ref_model, lora_config = None, None  # 先赋默认值,避免 full 模式下出现未定义
144 | if args.train_mode in ['lora', 'qlora']:
145 | lora_config = LoraConfig(
146 | task_type=TaskType.SEQ_CLS if args.rlhf_type == 'Reward' else TaskType.CAUSAL_LM,
147 | target_modules=find_all_linear_names(policy),
148 | r=args.lora_rank, # Lora 秩
149 | lora_alpha=args.lora_alpha, # Lora alpha,具体作用参见 Lora 原理
150 | lora_dropout=args.lora_dropout, # Dropout 比例
151 | use_dora=args.use_dora
152 | )
153 | ref_model = None # if peft, the model with a disabled adapter
154 | elif args.rlhf_type in USE_REF_MODEL:
155 | ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
156 | lora_config = None
157 |
158 | # 决定是否加载Reward model
159 | if args.rlhf_type in WITH_REWARD_MODEL:
160 | # 如果模型不支持AutoModelForSequenceClassification需要在对应config文件中添加映射
161 | reward_model = load_classification_reward(config.reward_model_path, model_kwargs)
162 |
163 | # data process
164 | # Compute that only on the main process for faster data processing.
165 | # see: https://github.com/huggingface/trl/pull/1255
166 |     eval_dataset = train_dataset.select(range(len(train_dataset) - config.eval_samples, len(train_dataset)))  # 先切出eval,避免与截断后的train重叠
167 |     train_dataset = train_dataset.select(range(len(train_dataset) - config.eval_samples))
168 | with PartialState().local_main_process_first():
169 | train_dataset = prepare_dataset(train_dataset, tokenizer)
170 | eval_dataset = prepare_dataset(eval_dataset, tokenizer)
171 | if ref_model is None:
172 | ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True,
173 | torch_dtype=torch.float16 if config.fp16 else torch.bfloat16)
174 |
175 | ################
176 | # Training
177 | ################
178 | trainer_kwargs_map = {
179 | "DPO": dict(
180 | model=policy,
181 | ref_model=ref_model,
182 | args=config,
183 | train_dataset=train_dataset['train'],
184 | eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None,
185 | processing_class=tokenizer,
186 | peft_config=lora_config,
187 | )
188 | if args.rlhf_type in ['DPO', 'KTO']
189 | else dict()
190 | ,
191 | 'CPO': dict(
192 | model=policy,
193 | args=config,
194 | train_dataset=train_dataset['train'],
195 | eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None,
196 | processing_class=tokenizer,
197 | peft_config=lora_config,
198 | )
199 | if args.rlhf_type in ['CPO', 'SimPO', 'CPOSimPO']
200 | else dict()
201 | ,
202 | "PPO": dict(
203 | ),
204 | "RLOO": dict(
205 | config=config,
206 | processing_class=tokenizer,
207 | policy=policy,
208 | ref_policy=ref_model,
209 | reward_model=reward_model,
210 | train_dataset=train_dataset,
211 | eval_dataset=eval_dataset,
212 | )
213 | if args.rlhf_type == 'RLOO'
214 | else dict()
215 | ,
216 | 'Reward': dict(
217 | model=policy,
218 | processing_class=tokenizer,
219 | args=config,
220 | train_dataset=train_dataset['train'],
221 | eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None,
222 | peft_config=lora_config,
223 | )
224 | if args.rlhf_type == 'Reward'
225 | else dict()
226 | }
227 |
228 | trainer_kwargs_map['SimPO'] = trainer_kwargs_map['CPO'].copy()
229 | trainer_kwargs_map['CPOSimPO'] = trainer_kwargs_map['CPO'].copy()
230 | trainer_kwargs_map['KTO'] = trainer_kwargs_map['DPO'].copy()
231 |
232 | # 从字典中获取相应的 Trainer 类
233 | trainer_kwargs = trainer_kwargs_map.get(args.rlhf_type)
234 | TrainerClass = trainer_map.get(args.rlhf_type)
235 | if TrainerClass is None:
236 | raise ValueError(f"Unknown trainer type: {args.rlhf_type}")
237 |
238 | trainer = TrainerClass(**trainer_kwargs)
239 | trainer.train()
240 | # trainer.save_model(config.output_dir)
241 |
242 |
243 | if __name__ == "__main__":
244 | main()
245 |
--------------------------------------------------------------------------------
/rlhf/utils/util.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import os
3 | import sys
4 | from dataclasses import dataclass
5 | from typing import List, Dict, Optional, Tuple, Union
6 | import copy
7 |
8 | import torch
9 | import torch.nn as nn
10 | from transformers import HfArgumentParser
11 | from transformers.hf_argparser import DataClassType
12 |
13 |
14 | class ArgumentParserPlus(HfArgumentParser):
15 | def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
16 | """
17 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
18 |
19 | Args:
20 | yaml_arg (`str`):
21 | The path to the config file used
22 | other_args (`List[str]`, *optional`):
23 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
24 |
25 | Returns:
26 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
27 | """
28 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
29 |
30 | outputs = []
31 | # strip other args list into dict of key-value pairs
32 | other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args}
33 | used_args = {}
34 |
35 | # overwrite the default/loaded value with the value provided to the command line
36 | # noqa adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
37 | for data_yaml, data_class in zip(arg_list, self.dataclass_types):
38 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
39 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
40 | for arg, val in other_args.items():
41 | # add only if in keys
42 |
43 | if arg in keys:
44 | base_type = data_yaml.__dataclass_fields__[arg].type
45 | inputs[arg] = val
46 |
47 | # cast type for ints, floats (default to strings)
48 | if base_type in [int, float]:
49 | inputs[arg] = base_type(val)
50 |
51 | if base_type == List[str]:
52 | inputs[arg] = [str(v) for v in val.split(",")]
53 |
54 | # bool of a non-empty string is True, so we manually check for bools
55 | if base_type == bool:
56 | if val in ["true", "True"]:
57 | inputs[arg] = True
58 | else:
59 | inputs[arg] = False
60 |
61 | # add to used-args so we can check if double add
62 | if arg not in used_args:
63 | used_args[arg] = val
64 | else:
65 | raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
66 |
67 | obj = data_class(**inputs)
68 | outputs.append(obj)
69 |
70 | return outputs
71 |
72 | def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
73 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
74 | # If we pass only one argument to the script and it's the path to a YAML file,
75 | # let's parse it to get our arguments.
76 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
77 | # parse command line args and yaml file
78 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
79 | output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
80 | # parse command line args only
81 | else:
82 | output = self.parse_args_into_dataclasses()
83 |
84 | if len(output) == 1:
85 | output = output[0]
86 | return output
87 |
88 |
89 | def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor:
90 | """
91 | Finds the index of the first `True` value in each row of a boolean tensor. If no `True` value exists in a row,
92 | it returns the length of the row.
93 |
94 | Args:
95 | bools (torch.Tensor): A boolean tensor of shape (batch_size, sequence_length), where `True` values indicate
96 | the positions of interest.
97 | dtype (torch.dtype): The data type to use for the output indices (default is torch.long).
98 |
99 | Returns:
100 | torch.Tensor: A tensor of shape (batch_size,) containing the index of the first `True` value in each row.
101 | If a row has no `True` value, the index will be the length of the row.
102 | """
103 |
104 | # Get the length of each row (i.e., the number of columns in the last dimension)
105 | # row_len is a scalar representing the length of each sequence (sequence_length)
106 | row_len = bools.size(-1)
107 |
108 | # Calculate the index positions for the first `True` in each row
109 | # ~bools: Invert the boolean values (True becomes False and vice versa)
110 | # ~bools.type(dtype): Convert the inverted boolean tensor to the specified dtype (0 for True, 1 for False)
111 | # row_len * (~bools).type(dtype): For `False` values, this will give `row_len`, for `True` values it gives 0.
112 | # torch.arange(row_len, dtype=dtype, device=bools.device): Generates a tensor with values [0, 1, 2, ..., row_len-1]
113 | # for each row. Shape: (sequence_length,)
114 | # zero_or_index: Shape (batch_size, sequence_length). This tensor contains the indices for `True` values and `row_len`
115 | # for `False` values.
116 | zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
117 |
118 | # Return the minimum value in each row (i.e., the first `True` index or `row_len` if none exist)
119 | # torch.min(zero_or_index, dim=-1).values: This returns the minimum value in each row, which corresponds to the first
120 | # `True` value's index or `row_len` if there is no `True` in that row.
121 | # The returned tensor has shape (batch_size,)
122 | return torch.min(zero_or_index, dim=-1).values
123 |
124 |
125 | def get_reward(
126 | model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
127 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128 | """
129 | This function computes reward scores for a batch of query responses based on a pre-trained reward model.
130 |
131 | Args:
132 | model (torch.nn.Module): The pre-trained reward model.
133 | query_responses (torch.Tensor): Tensor containing the tokenized responses for which to compute rewards.
134 | Shape: (batch_size, sequence_length)
135 | pad_token_id (int): The ID used for padding tokens in the tokenized sequences.
136 | context_length (int): The length of the prompt or context preceding the completions.
137 |
138 | Returns:
139 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
140 | - reward_logits: The logits output from the model for all tokens in the sequences.
141 | Shape: (batch_size, sequence_length)
142 | - final_scores: The final reward scores, one for each sequence, after adjusting for sequence lengths.
143 | Shape: (batch_size,)
144 | - sequence_lengths: The lengths of each sequence (excluding padding).
145 | Shape: (batch_size,)
146 |
147 | For example:
148 | query_responses = torch.tensor([
149 | [token0, token1, token2, token3, 0, 0],
150 | [token0, token1, token4, 0, 0, 0]
151 | ]) # 形状: (2, 6)
152 |
153 | attention_mask = query_responses != 0
154 | # [[1, 1, 1, 1, 0, 0],
155 | # [1, 1, 1, 0, 0, 0]]
156 |
157 | position_ids = attention_mask.cumsum(1) - attention_mask.long()
158 | # [[0, 1, 2, 3, 4, 4],
159 | # [0, 1, 2, 3, 3, 3]]
160 |
161 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
162 | # 在此例中,填充 token 已为 0,无变化
163 |
164 | reward_logits = torch.tensor([
165 | [[r0_0], [r0_1], [r0_2], [r0_3], [r0_4], [r0_5]],
166 | [[r1_0], [r1_1], [r1_2], [r1_3], [r1_4], [r1_5]]
167 | ]) # 形状: (2, 6, 1)
168 |
169 | query_responses[:, 2:] == 0
170 | # [[False, False, True, True],
171 | # [False, True, True, True]]
172 |
173 | sequence_lengths = first_true_indices(...) - 1 + 2
174 | # first_true_indices = [2, 1]
175 | # sequence_lengths = [2-1+2, 1-1+2] = [3, 2]
176 |
177 | final_scores = reward_logits[torch.arange(2), [3, 2]].squeeze(-1)
178 | # = reward_logits[[0,1], [3, 2]] ---> reward_logits[0, 3, :], reward_logits[1, 2, :]
179 | # = [r0_3, r1_2],形状: (2,)
180 | """
181 |
182 | # Create an attention mask where tokens that are not padding have a value of 1, and padding tokens have a value of 0
183 | # Shape: (batch_size, sequence_length)
184 | attention_mask = query_responses != pad_token_id
185 |
186 | # Calculate position IDs for each token, considering the cumulative sum of the attention mask (to exclude padding)
187 | # Shape: (batch_size, sequence_length)
188 | position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum
189 |
190 | # Access the LM backbone from the reward model using its base model prefix
191 | lm_backbone = getattr(model, model.base_model_prefix)
192 |
193 | # Replace padding tokens with zeros in the input IDs (so padding tokens won't affect the model's processing)
194 | # Shape: (batch_size, sequence_length)
195 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
196 | output = lm_backbone(
197 | input_ids=input_ids,
198 | attention_mask=attention_mask,
199 | position_ids=position_ids,
200 | return_dict=True,
201 | output_hidden_states=True,
202 | use_cache=False, # otherwise mistral-based RM would error out
203 | )
204 | reward_logits = model.score(output.hidden_states[-1]) # (batch_size, sequence_length)
205 |
206 | # Calculate the length of each sequence by finding the first occurrence of a padding token after the context
207 | # sequence_lengths shape: (batch_size,)
208 | sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
209 | assert (
210 | reward_logits.shape[-1] == 1
211 | ), "Reward model should output a single scalar per token. Check if you added `num_labels=1` when doing `AutoModelForSequenceClassification.from_pretrained(...)`."
212 | # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
213 |
214 | # Return the reward logits for all tokens, the final reward scores for each sequence, and the sequence lengths
215 | return (
216 | # reward_logits shape: (batch_size, sequence_length)
217 | reward_logits,
218 | # final_scores shape: (batch_size,)
219 | reward_logits[
220 | torch.arange(reward_logits.size(0), device=reward_logits.device),
221 | sequence_lengths,
222 | ].squeeze(
223 | -1
224 | ), # Shape: (batch_size,)
225 | sequence_lengths,
226 | )
227 |
228 |
229 | def find_all_linear_names(model):
230 | """
231 | 找出所有全连接层,为所有全连接添加adapter
232 | """
233 | cls = nn.Linear
234 | lora_module_names = set()
235 | for name, module in model.named_modules():
236 | if isinstance(module, cls):
237 | names = name.split('.')
238 | lora_module_names.add(names[-1])
239 |
240 | if 'lm_head' in lora_module_names: # needed for 16-bit
241 | lora_module_names.remove('lm_head')
242 | lora_module_names = list(lora_module_names)
243 | return lora_module_names
244 |
245 |
246 | def is_right_apply_chat(tokenizer, prompt: List[Dict[str, str]], assistant_content: List[Dict[str, str]]) -> bool:
247 | """
248 | Checks if the assistant's content is correctly applied to the prompt in a chat template.
249 | Args:
250 | tokenizer: The tokenizer.
251 | prompt: The initial prompt message.
252 | assistant_content: The content provided by the assistant.
253 | Returns:
254 | bool: True if the assistant's content is correctly applied, False otherwise.
255 | """
256 | try:
257 | test_assistant = tokenizer.apply_chat_template(assistant_content, tokenize=False)
258 | test_prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
259 | conversation = copy.deepcopy(prompt)
260 | conversation.append(assistant_content[0])
261 | if tokenizer.apply_chat_template(conversation) == test_prompt + test_assistant:
262 | return True
263 | else:
264 | return False
265 | except Exception as e:
266 | return False
267 |
268 |
269 | def fix_chat_template_if_needed(tokenizer, prompt: List[Dict[str, str]], chosen: List[Dict[str, str]],
270 | rejected: List[Dict[str, str]]):
271 | """
272 | Fixes the chat template if needed.
273 | Args:
274 | tokenizer: The tokenizer.
275 | prompt: The initial prompt message.
276 | chosen: The chosen response, a list containing a single dictionary representing the chosen message.
277 | rejected: The rejected response, a list containing a single dictionary representing the rejected message.
278 | Returns:
279 | - tuple: A tuple containing the fixed prompt, fixed chosen response, and fixed rejected response.
280 | """
281 | conversation_chosen = copy.deepcopy(prompt)
282 | conversation_rejected = copy.deepcopy(prompt)
283 | conversation_chosen.append(chosen[0])
284 | conversation_rejected.append(rejected[0])
285 | conversation_chosen = tokenizer.apply_chat_template(conversation_chosen, tokenize=False)
286 | conversation_rejected = tokenizer.apply_chat_template(conversation_rejected, tokenize=False)
287 | # find position
288 |     start_position = conversation_chosen.find(chosen[0]['content'])
289 | # The following is right
290 | fixed_prompt = conversation_chosen[:start_position]
291 | fixed_chosen = conversation_chosen[start_position:]
292 | fixed_rejected = conversation_rejected[start_position:]
293 | return fixed_prompt, fixed_chosen, fixed_rejected
294 |
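295 | 
296 | # 使用示意(假设性示例,非项目实际调用代码):
297 | # 处理偏好数据时,可先用 is_right_apply_chat 检查 prompt 与回复分别套用 chat template 后拼接的结果,
298 | # 是否与整段对话一次性套用的结果一致;不一致时再用 fix_chat_template_if_needed 从完整对话中
299 | # 重新切分出 prompt / chosen / rejected 文本:
300 | #
301 | #   prompt   = [{"role": "user", "content": "你好"}]
302 | #   chosen   = [{"role": "assistant", "content": "你好,很高兴见到你"}]
303 | #   rejected = [{"role": "assistant", "content": "不想理你"}]
304 | #   if not is_right_apply_chat(tokenizer, prompt, chosen):
305 | #       prompt_str, chosen_str, rejected_str = fix_chat_template_if_needed(tokenizer, prompt, chosen, rejected)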
--------------------------------------------------------------------------------
/run_eval_test.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=''
2 | OUTPUT_PATH=""
3 | MODEL_PATH=""
4 |
5 | BASE_CMD="CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 nohup accelerate launch --config_file rlhf/ds_config/ds_zero3.yaml main_train.py \
6 | --train_data_path "$DATA_PATH" \
7 | --model_name_or_path "$MODEL_PATH" \
8 | --max_len 4096 \
9 | --num_train_epochs 1 \
10 | --per_device_train_batch_size 2 \
11 | --gradient_accumulation_steps 8 \
12 | --task_type "sft" \
13 | --train_mode "full" \
14 | --output_dir "$OUTPUT_PATH" \
15 | --save_strategy "steps" \
16 | --save_steps 100 \
17 | --save_total_limit 3 \
18 | --learning_rate 2e-5 \
19 | --warmup_steps 16 \
20 | --logging_steps 1 \
21 | --lr_scheduler_type "cosine" \
22 | --gradient_checkpointing True \
23 | --report_to "wandb" \
24 | --bf16 True \
25 | --auto_adapt True"
26 |
27 | # 评估相关参数(仅在use_eval_in_train为True时使用)
28 | EVAL_ARGS="--use_eval_in_train True \
29 | --test_datasets_path "data/test.jsonl" \
30 | --max_new_tokens 4096 \
31 | --freq 4 \
32 | --metrics "code" \
33 | --vllm_server_port 8001 \
34 | --vllm_server_timeout 30 \
35 | --save_best_checkpoints True \
36 | --max_checkpoints 2 \
37 | --start_update_best_checkpoints 4 \
38 | --prompts_apply_chat True \
39 | --use_vllm True"
40 |
41 | # 根据是否需要评估来构建完整命令
42 | if [ "$1" = "--eval" ]; then
43 | FULL_CMD="$BASE_CMD $EVAL_ARGS"
44 | else
45 | FULL_CMD="$BASE_CMD --use_eval_in_train False"
46 | fi
47 |
48 | # 执行命令
49 | eval $FULL_CMD
--------------------------------------------------------------------------------
/run_example.sh:
--------------------------------------------------------------------------------
1 |
2 | DATA_PATH=''
3 | OUTPUT_PATH=""
4 | MODEL_PATH=""
5 |
6 | # task_type:[sft] pretrain正在开发
7 | # train_mode:[qlora, lora, full]
8 | # train_args_path: [sft_args,dpo_args]
9 |
10 | # deepspeed 启动
11 | deepspeed --master_port 29507 --include localhost:0,1 main_train.py\
12 | --train_data_path "$DATA_PATH" \
13 | --model_name_or_path "$MODEL_PATH" \
14 | --max_len 1024 \
15 | --num_train_epochs 1 \
16 | --per_device_train_batch_size 8 \
17 | --per_device_eval_batch_size 1 \
18 | --gradient_accumulation_steps 4 \
19 | --task_type "sft" \
20 | --train_mode "qlora" \
21 | --output_dir "$OUTPUT_PATH" \
22 | --save_strategy "steps" \
23 | --save_steps 500 \
24 | --save_total_limit 5 \
25 | --learning_rate 2e-4 \
26 | --warmup_steps 10 \
27 | --logging_steps 1 \
28 | --lr_scheduler_type "cosine_with_min_lr" \
29 | --gradient_checkpointing True \
30 | --report_to "wandb" \
31 | --deepspeed './train_args/deepspeed_config/ds_config_zero2.json' \
32 | --bf16 True \
33 | --auto_adapt True \
34 | --use_eval_in_train True \
35 | --test_datasets_path "./" \
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | # python main_train.py --train_data_path 数据集路径 --model_name_or_path 模型路径 ......同上述传入参数
44 |
--------------------------------------------------------------------------------
/run_vlm_example.sh:
--------------------------------------------------------------------------------
1 |
2 | DATA_PATH=''
3 | OUTPUT_PATH=""
4 | MODEL_PATH=""
5 |
6 |
7 | deepspeed --master_port 29507 --include localhost:0,1 vlm_train.py\
8 | --train_data_path "$DATA_PATH" \
9 | --model_name_or_path "$MODEL_PATH" \
10 | --max_seq_length 1024 \
11 | --num_train_epochs 1 \
12 | --per_device_train_batch_size 2 \
13 | --gradient_accumulation_steps 2 \
14 | --task_type "QA" \
15 | --train_mode "lora" \
16 | --output_dir "$OUTPUT_PATH" \
17 | --save_strategy "steps" \
18 | --save_steps 20 \
19 | --save_total_limit 5 \
20 | --learning_rate 2e-5 \
21 | --warmup_steps 10 \
22 | --logging_steps 1 \
23 | --lr_scheduler_type "cosine" \
24 | --gradient_checkpointing True \
25 | --deepspeed './train_args/deepspeed_config/ds_config_zero2.json' \
26 | --bf16 True \
27 | --torch_dtype bfloat16 \
28 | --freeze_vision True \
29 | --freeze_projector False
--------------------------------------------------------------------------------
/train_args/__init__.py:
--------------------------------------------------------------------------------
1 | from train_args.sft.base import TrainArgument as sft_TrainArgument
2 |
3 | __all__ = [
4 | "sft_TrainArgument",
5 | ]
6 |
--------------------------------------------------------------------------------
/train_args/common_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | from enum import Enum
4 |
5 |
6 | class TrainMode(Enum):
7 | QLORA = 'qlora'
8 | LORA = 'lora'
9 | FULL = 'full'
10 |
11 |
12 | @dataclass
13 | class CommonArgs:
14 | """
15 | 一些常用的自定义参数
16 | """
17 | # Deepspeed相关参数,如出现报错可注释掉
18 | # local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改,如出现报错可注释掉或添加"})
19 |
20 | train_args_path: str = 'sft_args' # 训练参数 默认sft_args
21 | max_len: int = field(default=1024, metadata={"help": "最大输入长度"})
22 | train_data_path: Optional[str] = field(default='./', metadata={"help": "训练集路径"})
23 | model_name_or_path: str = field(default='./', metadata={"help": "下载的所需模型路径"})
24 |
25 | # 训练方法相关选择与配置
26 |     task_type: str = field(default="sft", metadata={"help": "训练任务类型:目前支持sft"})
27 | train_mode: TrainMode = field(default='lora', metadata={"help": "选择采用的训练方式:[qlora, lora, full]"})
28 | use_dora: bool = field(default=False,
29 | metadata={"help": "在train_mode==lora时可以使用。是否使用Dora(一个基于lora的变体)"})
30 |
31 | # lora相关配置
32 | lora_rank: Optional[int] = field(default=64, metadata={"help": "lora rank"})
33 | lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"})
34 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"})
35 |
36 | # 是否自动适配template
37 | auto_adapt: bool = field(default=True, metadata={"help": "选择是否自动适配template,若为False,则直接使用输入数据"})
38 | # 是否训练中评测
39 | use_eval_in_train: bool = False
40 | test_datasets_path: Optional[str] = None
41 |
--------------------------------------------------------------------------------
/train_args/deepspeed_config/ds_config_zero0.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "gradient_clipping": "auto",
4 | "steps_per_print": 20,
5 | "train_batch_size": "auto",
6 | "train_micro_batch_size_per_gpu": "auto",
7 | "wall_clock_breakdown": false,
8 |
9 | "optimizer": {
10 | "type": "AdamW",
11 | "params": {
12 | "lr": "auto",
13 | "betas": "auto",
14 | "eps": "auto",
15 | "weight_decay": "auto"
16 | }
17 | },
18 | "scheduler": {
19 | "type": "WarmupLR",
20 | "params": {
21 | "warmup_min_lr": "auto",
22 | "warmup_max_lr": "auto",
23 | "warmup_num_steps": "auto"
24 | }
25 | },
26 |
27 | "zero_optimization": {
28 | "stage": 0
29 | },
30 |
31 | "bf16": {
32 | "enabled": "auto"
33 | }
34 | }
--------------------------------------------------------------------------------
/train_args/deepspeed_config/ds_config_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "offload_optimizer": {
5 | "device": "cpu",
6 | "pin_memory": true
7 | },
8 | "allgather_partitions": true,
9 | "allgather_bucket_size": 5e8,
10 | "overlap_comm": true,
11 | "reduce_scatter": true,
12 | "reduce_bucket_size": 5e8,
13 | "contiguous_gradients": true,
14 | "round_robin_gradients": true
15 | },
16 | "bf16": {
17 | "enabled": "auto"
18 | },
19 | "optimizer": {
20 | "type": "AdamW",
21 | "params": {
22 | "lr": "auto",
23 | "betas": "auto",
24 | "eps": "auto",
25 | "weight_decay": "auto"
26 | }
27 | },
28 | "scheduler": {
29 | "type": "WarmupLR",
30 | "params": {
31 | "warmup_min_lr": "auto",
32 | "warmup_max_lr": "auto",
33 | "warmup_num_steps": "auto"
34 | }
35 | },
36 | "gradient_accumulation_steps": "auto",
37 | "gradient_clipping": "auto",
38 | "steps_per_print": 20,
39 | "train_batch_size": "auto",
40 | "train_micro_batch_size_per_gpu": "auto",
41 | "wall_clock_breakdown": false
42 | }
--------------------------------------------------------------------------------
/train_args/deepspeed_config/ds_config_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 |     },
14 |
15 | "zero_optimization": {
16 | "stage": 3,
17 | "offload_optimizer": {
18 | "device": "cpu",
19 | "pin_memory": true
20 | },
21 | "offload_param": {
22 | "device": "cpu",
23 | "pin_memory": true
24 | },
25 | "overlap_comm": true,
26 | "contiguous_gradients": true,
27 | "sub_group_size": 1e9,
28 | "reduce_bucket_size": "auto",
29 | "stage3_prefetch_bucket_size": "auto",
30 | "stage3_param_persistence_threshold": "auto",
31 | "stage3_max_live_parameters": 1e9,
32 | "stage3_max_reuse_distance": 1e9,
33 | "stage3_gather_16bit_weights_on_model_save": true
34 | },
35 |
36 | "gradient_accumulation_steps": "auto",
37 | "gradient_clipping": "auto",
38 | "steps_per_print": 20,
39 | "train_batch_size": "auto",
40 | "train_micro_batch_size_per_gpu": "auto",
41 | "wall_clock_breakdown": false
42 | }
--------------------------------------------------------------------------------
/train_args/dpo/README.md:
--------------------------------------------------------------------------------
1 | # 更新
2 | trl新版本已经更改dpo实现,故此处已不适用,新的DPO可见[RLHF](./rlhf/README.md)。
3 |
4 |
5 |
6 | # 关于DPO训练
7 | 目前分为两个模式,分别是multi_dpo和single_dpo。**推荐一般使用multi_dpo**。
8 |
9 | DPO训练方式均支持框架中的deepspeed或者python启动模式,相应的lora、qlora也支持。
10 |
11 | 区别在于两种方式的数据组织形式,前者是使用DPOTrainer自动进行数据处理,且是多轮对话形式,参照格式也可将其改为单轮对话,故前者是单轮与多轮通用的。
12 |
13 | 后者是自己从零构建的数据组织形式,理论上按照DPOTrainer相同形式,只实现了单轮。这样的**目的是为了更好地理解DPO的过程以及方便一些魔改操作**,权当学习使用。
14 |
15 | 🤓**注意:** 对于DPO数据,可见```data/dpo_multi_data.jsonl```示例数据。数据是huggingface的hh-rlhf-helpful-base-trl-style格式数据,其中prompt是一句话,而chosen和
16 | rejected则是包含prompt的完整对话。故如构建自己的数据集时,无论多轮和单轮,都应在chosen和rejected中加入prompt,单轮相当于取第一句当prompt,
17 | 多轮相当于取最后一句之前的所有当prompt(其实还可以取每一轮的user当prompt,后面有时间可能会实现)。
18 |
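一个按上述格式组织的示意样例(内容为虚构,仅用于说明 chosen 和 rejected 中需包含 prompt,实际字段以```data/dpo_multi_data.jsonl```为准):
```json lines
{"prompt": "你好,请介绍一下你自己", "chosen": [{"role": "user", "content": "你好,请介绍一下你自己"}, {"role": "assistant", "content": "你好,我是一个乐于助人的AI助手。"}], "rejected": [{"role": "user", "content": "你好,请介绍一下你自己"}, {"role": "assistant", "content": "不想介绍。"}]}
```
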
19 | 对于自己构建的single_dpo数据格式,示例为:
20 | ```json lines
21 | {"prompt":"哈喽啊","chosen":"你好", "reject": "不好"}
22 | ```
23 |
24 | ## 代码位置
25 |
26 | 自己构建的single_dpo数据格式代码在```utils/data_process.py```文件中的```DpoDataset```类。
27 |
28 | 参照官方构建的数据格式在```main_train.py```中的```load_dpo_dataset```函数里。
29 |
30 |
31 | ## 技术文章
32 | - [DPO训练QWEN2及魔改DPO实现](https://zhuanlan.zhihu.com/p/702569978)
33 |
34 |
35 | ## DPO quick start
36 |
37 | **1、支持命令行传参启动,启动示例可见```LLM-Dojo/run_example.sh```**
38 |
39 | **2、也支持参数文件直接修改默认值,具体如下:**
40 |
41 | ### Step1 配置args.py
42 | 常规的参数在utils下的args.py,基本默认设置即可,你只需要改一下模型路径、输出路径、task_type、template_name、train_data_path、train_args_path、train_mode等。
43 |
44 | 使用multi_dpo时args.py中的max_len和max_prompt_length参数是没用的,需要在后面的dpo_config.py中设置
45 |
46 | 其中:
47 | > train_args_path:为Step2中需要配置的train_args路径
48 |
49 | ### Step2 配置train_args文件夹下对应文件
50 | 相关训练参数在train_args文件夹下对应的文件中。一般就是用```dpo/dpo_config.py```即可
51 |
52 | 均是采用dataclass格式配置参数,直接在default中修改即可,即不需要直接命令行传输参数了。
53 |
54 | 在这里修改max_len和max_prompt_length参数,其他需要设置的是是否选择deepspeed模式训练等参数
55 |
56 | ### Step3 开始训练
57 |
58 | 开始训练就和之前SFT一样了
59 |
60 | 😶Python命令单卡启动:
61 |
62 | 设置好相关配置后即可运行main_train.py进行训练
63 | ```bash
64 | python main_train.py
65 | ```
66 |
67 | 🙃Deepspeed单卡或多卡启动:
68 |
69 | 使用Deepspeed训练时前两步与常规相同,但需要额外配置ds_config文件,项目中已给出常用的配置示例,位于```train_args/deepspeed_config/```路径下,
70 | 更详细的Deepspeed原理及解释可以看文章:[Deepspeed配置及使用讲解](https://zhuanlan.zhihu.com/p/698631348)
71 |
72 | 运行以下命令启动:
73 | ```bash
74 | deepspeed --include localhost:6,7 main_train.py
75 | ```
76 | 其中```include localhost```参数用于选择训练的GPU,可选单卡也可选多卡。
77 |
78 |
--------------------------------------------------------------------------------
/train_args/sft/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | from transformers import TrainingArguments
4 |
5 |
6 | @dataclass
7 | class TrainArgument(TrainingArguments):
8 | """
9 | 训练参数, 直接在这里修改即可
10 | """
11 | output_dir: str = field(default='', metadata={"help": "模型训练完成后的保存路径"})
12 |     num_train_epochs: int = 1
13 |
14 | per_device_train_batch_size: int = 2
15 | gradient_checkpointing: bool = True
16 |     gradient_accumulation_steps: int = 16
17 |
18 | learning_rate: float = 2e-4
19 | logging_steps: int = 10
20 | save_steps: int = 500
21 | save_strategy: str = "steps"
22 | save_total_limit: int = 2
23 |     lr_scheduler_type: str = "constant_with_warmup"
24 | warmup_steps: int = 10
25 | optim: str = 'adamw_torch'
26 | report_to: str = 'tensorboard'
27 | remove_unused_columns: bool = False
28 | bf16: bool = True
29 | fp16: bool = False
30 |
31 | # Deepspeed训练相关参数,不使用时设置为default=None
32 | deepspeed: Optional[str] = field(default=None, metadata={"help": "启用Deepspeed时需要的config文件"})
33 |
--------------------------------------------------------------------------------
/train_args/vlm_config/script_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ScriptArgs:
7 | """
8 | 自定义参数
9 | """
10 | # Deepspeed相关参数,如出现报错可注释掉
11 | # local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改,如出现报错可注释掉或添加"})
12 | task_type: str = field(default='QA', metadata={"help": "任务类型,目前可选:[QA]"})
13 | '''多模态任务类型'''
14 |
15 | train_data_path: Optional[str] = field(default='./', metadata={"help": "训练集路径"})
16 | '''训练集路径'''
17 |
18 | train_mode: str = field(default='lora', metadata={"help": "选择对llm采用的训练方式:[qlora, lora, full]"})
19 | '''选择对llm采用的训练方式'''
20 |
21 | freeze_vision: bool = True
22 | '''训练是否冻结视觉层'''
23 |
24 | freeze_projector: bool = False
25 | '''训练是否冻结转接层'''
26 |
27 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mst272/LLM-Dojo/6397eff480feca95f3b6016495b30d1ae308c9dc/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data_collator.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 | import torch
3 | from loguru import logger
4 | from PIL import Image
5 | from utils.vlm_template import LlavaTemplateProcessor, Qwen2VLTemplateProcessor
6 |
7 |
8 | class SftDataCollator:
9 | def __init__(self, tokenizer, max_length):
10 | self.tokenizer = tokenizer
11 | self.max_length = max_length
12 | self.pad_token_id = tokenizer.pad_token_id
13 |
14 | def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
15 | # 找出最大长度
16 | length = [len(x['input_ids']) for x in batch if x['input_ids'] is not None]
17 | # 每个batch中的最大长度
18 | max_batch_length = min(self.max_length, max(length))
19 |
20 | input_ids_batch, attention_mask_batch, target_mask_batch = [], [], []
21 |
22 | for x in batch:
23 | input_ids = x['input_ids']
24 | attention_mask = x['attention_mask']
25 | target_mask = x['target_mask']
26 | if input_ids is None:
27 |                 logger.info('some input_ids is None, skip this sample')
28 | continue
29 | padding_len = max_batch_length - len(input_ids)
30 | # 开始padding
31 | input_ids += [self.pad_token_id] * padding_len
32 | attention_mask += [0] * padding_len
33 | target_mask += [0] * padding_len
34 | # 开始截断
35 | input_ids = input_ids[:self.max_length]
36 | attention_mask = attention_mask[:self.max_length]
37 | target_mask = target_mask[:self.max_length]
38 | # 将本批次全部加入列表
39 | input_ids_batch.append(input_ids)
40 | attention_mask_batch.append(attention_mask)
41 | target_mask_batch.append(target_mask)
42 |
43 | # 将list转换为tensor,得到最终的的模型输入
44 | input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long)
45 | attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long)
46 | target_mask_batch = torch.tensor(target_mask_batch, dtype=torch.long)
47 |
48 | # 计算损失时忽略
49 | labels = torch.where(target_mask_batch == 1, input_ids_batch, -100)
50 | inputs = {
51 | 'input_ids': input_ids_batch,
52 | 'attention_mask': attention_mask_batch,
53 | 'labels': labels
54 | }
55 | return inputs
56 |
57 |
58 | # def llava_template_process(questions: List[str], answers: List[str]) -> List[Dict]:
59 | # converted_data = []
60 | # for question, answer in zip(questions, answers):
61 | # user_content = [{'index': None, 'text': question, 'type': 'text'}]
62 | # assistant_content = [{'index': None, 'text': answer, 'type': 'text'}]
63 | #
64 | # converted_data.append({'content': user_content, 'role': 'user'})
65 | # converted_data.append({'content': assistant_content, 'role': 'assistant'})
66 | # image_dict = {'index': 0, 'text': None, 'type': 'image'}
67 | # converted_data[0]['content'].append(image_dict)
68 | # return converted_data
69 |
70 | processor_class_map = {
71 | 'LlavaProcessor': LlavaTemplateProcessor(),
72 | 'Qwen2VLProcessor': Qwen2VLTemplateProcessor(),
73 | # 可继续添加更多的处理器类
74 | }
75 |
76 |
77 | class VlmQaDataCollator:
78 | def __init__(self, processor):
79 | self.processor = processor
80 | processor_class = processor.to_dict()['processor_class']
81 | if processor_class not in processor_class_map:
82 | raise ValueError(f"Unknown processor class: {processor_class}")
83 | self.template_process = processor_class_map[processor_class]
84 |
85 | def __call__(self, examples):
86 | texts = []
87 | images = []
88 | for example in examples:
89 | standard_example = self.template_process.process(example[0], example[1])
90 | text = self.processor.apply_chat_template(
91 | standard_example, tokenize=False, add_generation_prompt=False
92 | )
93 | texts.append(text)
94 | raw_image = Image.open(example[2])
95 | images.append(raw_image)
96 |
97 | batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
98 |
99 | # 这里并没有mask question, 后续可能的扩充是设置mask question的模式。
100 | labels = batch["input_ids"].clone()
101 | labels[labels == self.processor.tokenizer.pad_token_id] = -100
102 | # image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token)
103 | # labels[labels == image_token_id] = -100
104 | batch['labels'] = labels
105 |
106 | return batch
107 |
108 |
109 | class VlmCaptionImageDataCollator:
110 | pass
111 |
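112 | 
113 | # 使用示意(假设性示例):SftDataCollator 期望每条样本已完成 tokenize,
114 | # 且包含 input_ids / attention_mask / target_mask 三个等长列表,target_mask 为 1 的位置参与 loss 计算:
115 | #
116 | #   collator = SftDataCollator(tokenizer, max_length=1024)
117 | #   batch = collator([
118 | #       {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "target_mask": [0, 0, 1]},
119 | #       {"input_ids": [4, 5], "attention_mask": [1, 1], "target_mask": [0, 1]},
120 | #   ])
121 | #   # batch 为包含 input_ids / attention_mask / labels 的张量字典,labels 中非 target 位置被置为 -100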
--------------------------------------------------------------------------------
/utils/eval/README.md:
--------------------------------------------------------------------------------
1 | # Train with prediction eval
2 |
3 | 详细的使用说明及函数文档待更新,目前只是一个dirty的实现。
4 |
5 |
6 |
7 | ## test data
8 |
9 | 评测数据格式为jsonl,以代码test为例,具体可见data/test.jsonl。
10 | 包含两个字段:
11 | - prompt: 测试的问题
12 | - label: 答案,对代码任务来说即是测试用例
13 |
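一条示意样本(内容为虚构,仅用于说明字段含义,实际格式以 data/test.jsonl 为准):

```json lines
{"prompt": "补全下面的代码:\ndef add(a, b):\n    ...", "label": "assert add(1, 2) == 3"}
```
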
14 | ## Quick start
15 |
16 | 使用vllm进行生成,其余卡进行训练。
17 |
18 | 启动脚本位置:utils/eval/vllm/run_serve.sh
19 |
20 | 1、启动vllm_serve,例如使用2卡
21 |
22 | ```shell
23 | bash run_serve.sh
24 | ```
25 |
26 | 2、开启训练
27 |
28 | 启动脚本位置:run_eval_test.sh
29 |
30 | ```shell
31 | bash run_eval_test.sh --eval
32 | ```
33 | 运行脚本时注意要带上 --eval 参数,其余参数配置可参考 run_eval_test.sh 文件。
34 |
35 |
36 |
37 |
38 |
39 | ## Tip
40 |
41 | wandb出问题可以尝试:
42 | pip install wandb==0.12.18
43 |
44 | 可能出现的问题:
45 |
46 | 1、直接使用 deepspeed --master_port 29508 --include localhost:2,3,4,5,6,7 main_train.py 启动时,保存checkpoint会有问题,所以建议改用:
47 | accelerate launch --config_file rlhf/ds_config/ds_zero3.yaml main_train.py
48 | 
49 | 
50 | 2、训练时出现训推不一致问题,训练中评测与保存后的结果对不上,最后定位到原因是没有设置 enable_prefix_caching=False。
51 | 不过设置之后仍然会有偏差,但影响不大,曲线的轨迹可以反映模型在测试集上的效果。
52 | 
53 | 待验证:也可能是 dropout 层的原因,后续计划禁用 dropout 再做尝试。
54 |
55 |
56 | 参考:https://github.com/huggingface/open-r1/issues/433
57 |
58 |
59 |
60 | ## Reference
61 | 最后,代码借鉴了trl项目,感谢trl为开源做出的贡献。
--------------------------------------------------------------------------------
/utils/eval/callback.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import os
3 | import shutil
4 | from contextlib import nullcontext
5 | from dataclasses import dataclass
6 | import time
7 | from typing import List, Optional
8 | from utils.eval.configs import GenerationConfig
9 | import deepspeed
10 | from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model
11 | import pandas as pd
12 | from trl.import_utils import is_deepspeed_available, is_rich_available, is_vllm_available
13 | from utils.eval.eval_utils import _generate_completions, reason_post_process
14 | import wandb
15 | from datasets import Dataset
16 | from transformers.trainer_callback import ExportableState
17 | from transformers import (
18 | Trainer,
19 | TrainerCallback
20 | )
21 | from utils.eval.eval_metric import BaseMetric
22 | from utils.eval.vllm.vllm_client import VLLMClient
23 |
24 |
25 | @dataclass
26 | class CheckpointInfo:
27 | step: int
28 | metric_value: float
29 | path: str
30 |
31 | def __lt__(self, other):
32 | return self.metric_value < other.metric_value
33 |
34 |
35 | class EvaluationCallback(TrainerCallback):
36 | r"""
37 | A [`~transformers.TrainerCallback`] that logs completions and eval metrics to Weights & Biases and/or Comet.
38 |
39 | Usage:
40 | ```python
41 | trainer = Trainer(...)
42 | evaluation_callback = EvaluationCallback(trainer=trainer)
43 | trainer.add_callback(evaluation_callback)
44 | ```
45 |
46 | Args:
47 | trainer (`Trainer`):
48 | Trainer to which the callback will be attached.
49 | generation_config (`GenerationConfig`, *optional*):
50 | The generation config to use for generating completions.
51 | num_samples (`int` or `None`, *optional*):
52 | The number of prompts to eval for. If not provided, defaults to the number of examples in the evaluation dataset.
53 | freq (`int` or `None`, *optional*):
54 | The frequency at which step to generate and compute metrics.
55 | metric
56 | """
57 |
58 | def __init__(
59 | self,
60 | trainer: Trainer,
61 | test_datasets: Dataset,
62 | generation_config: GenerationConfig,
63 | num_samples: Optional[int] = None,
64 | freq: Optional[int] = None,
65 | metrics: Optional[BaseMetric] = None,
66 | max_checkpoints: int = 3,
67 | per_device_test_batch_size: int = 1,
68 | higher_better: bool = True,
69 | start_update_best_checkpoints: int = 0,
70 | gather_deepspeed3_params: bool = True,
71 | use_vllm: bool = True,
72 | vllm_server_host: str = "0.0.0.0",
73 | vllm_server_port: int = 8080,
74 | vllm_server_timeout: float = 120.0,
75 | prompts_apply_chat: bool = True
76 |
77 | ):
78 | self.gen_config = generation_config
79 | self.trainer = trainer
80 | self.best_checkpoints: List[CheckpointInfo] = []
81 | self.max_checkpoints = max_checkpoints # 最大保存数量
82 | self.higher_better = higher_better # 指标是否越大越好
83 | self._last_logged_step = -1
84 | self.batch_size = per_device_test_batch_size
85 | self.table = []
86 | self.freq = freq
87 | self.metric = metrics
88 | self.start_update_best_checkpoints = start_update_best_checkpoints
89 | self.gather_deepspeed3_params = gather_deepspeed3_params
90 | self.prompts_apply_chat = prompts_apply_chat
91 | self.use_vllm = use_vllm
92 |
93 | if self.metric is None:
94 | raise ValueError("You must provide a metric[BaseMetric]")
95 |
96 | if num_samples is not None:
97 | self.sample_dataset = test_datasets.select(range(num_samples))
98 | else:
99 | self.sample_dataset = test_datasets
100 |
101 | # 配置vllm client
102 | if use_vllm:
103 | if not is_vllm_available():
104 | raise ImportError(
105 | "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
106 | "`pip install vllm` to use it."
107 | )
108 | if self.trainer.accelerator.is_main_process:
109 | self.vllm_client = VLLMClient(host=vllm_server_host, server_port=vllm_server_port,
110 | connection_timeout=vllm_server_timeout)
111 | self._last_loaded_step = 0
112 | self.trainer.accelerator.wait_for_everyone()
113 |
114 | def _move_model_to_vllm(self):
115 | # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
116 | deepspeed_plugin = self.trainer.accelerator.state.deepspeed_plugin
117 | zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
118 | gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
119 |
120 | if is_peft_model(self.trainer.model):
121 | # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
122 | # adapters in a sharded manner is not supported.
123 | with gather_if_zero3(list(self.trainer.model.parameters())):
124 | self.trainer.model.merge_adapter()
125 |
126 | # Update vLLM weights while parameters are gathered
127 | for name, param in self.trainer.model.named_parameters():
128 | # When using PEFT, we need to recover the original parameter name and discard some parameters
129 | name = name.removeprefix("base_model.model.").replace(".base_layer", "")
130 | if self.trainer.model.prefix in name:
131 | continue
132 | # When module to save, remove its prefix and discard the original module
133 | if "original_module" in name:
134 | continue
135 | name = name.replace("modules_to_save.default.", "")
136 |
137 | if self.trainer.accelerator.is_main_process:
138 | self.vllm_client.update_named_param(name, param.data)
139 |
140 | # Unmerge adapters while parameters are still gathered
141 | self.trainer.model.unmerge_adapter()
142 | # Parameters will automatically be repartitioned when exiting the context
143 | else:
144 | # For non-PEFT models, simply gather and update each parameter individually.
145 | for name, param in self.trainer.model.named_parameters():
146 | with gather_if_zero3([param]):
147 | if self.trainer.accelerator.is_main_process:
148 | self.vllm_client.update_named_param(name, param.data)
149 | # Reset cache on main process
150 | if self.trainer.accelerator.is_main_process:
151 | self.vllm_client.reset_prefix_cache()
152 |
153 | def samples_generate_vllm(self, steps):
154 | """
155 | prompts:
156 | labels:
157 | """
158 | # First, have main process load weights if needed
159 | if steps != self._last_loaded_step:
160 | self._move_model_to_vllm()
161 | self._last_loaded_step = steps
162 | # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
163 | # all_prompts_text = gather_object(prompts)
164 | if self.trainer.accelerator.is_main_process:
165 | # data process
166 | prompts = self.metric.get_prompts(self.sample_dataset, self.trainer.processing_class,
167 | self.prompts_apply_chat)
168 | labels = self.metric.get_labels(self.sample_dataset)
169 |
170 |             # todo: wrap the timing below in a profiling context, e.g. profiling_context(self, "vLLM.generate")
171 | start_time = time.time()
172 | completion_ids = self.vllm_client.generate(
173 | prompts=prompts,
174 | n=self.gen_config.num_generation,
175 | repetition_penalty=self.gen_config.repetition_penalty,
176 | temperature=self.gen_config.temperature,
177 | top_p=self.gen_config.top_p,
178 | top_k=self.gen_config.top_k,
179 | min_p=self.gen_config.min_p,
180 | max_tokens=self.gen_config.max_new_tokens
181 | )
182 |             end_time = time.time()  # record when generation finished
183 |             generation_time = end_time - start_time  # elapsed generation time
184 | print(f"\nProcess main: Generation time: {generation_time:.4f} seconds")
185 |
186 | tokenizer = self.trainer.processing_class
187 | # tokenizer.padding_side = "left"
188 | completions = tokenizer.batch_decode(completion_ids) # --> List[str]
189 | generations = self.metric.extract_generation(completions)
190 | score = self.metric.compute(references=labels, predictions=generations)
191 | else:
192 | score = None
193 | return score
194 |
195 | def samples_generate_split_between_processes(self, steps):
196 |         """
197 |         Generate without vLLM: prompts are split between processes; if the model is very large this may OOM.
198 |         """
199 | labels = [example['message'][1]['content'] for example in self.sample_dataset]
200 | tokenizer = self.trainer.processing_class
201 | tokenizer.padding_side = "left"
202 | accelerator = self.trainer.accelerator
203 | model = self.trainer.model_wrapped
204 | start_time = time.time()
205 | with accelerator.split_between_processes(self.sample_dataset['message']) as prompts_split:
206 | prompts = []
207 | for lis in prompts_split:
208 | prompts.append('补全下面代码,将最终题目和答案返回在代码框中\n' + lis[0]['content'])
209 | completions = _generate_completions(
210 | prompts,
211 | model=model,
212 | tokenizer=tokenizer,
213 | accelerator=accelerator,
214 | generation_config=self.gen_config,
215 | batch_size=self.batch_size,
216 | gather_deepspeed3_params=self.gather_deepspeed3_params
217 | )
218 | completions = gather_object(completions)
219 | prompts = gather_object(prompts)
220 |         end_time = time.time()  # record when _generate_completions finished
221 |         generation_time = end_time - start_time  # elapsed generation time
222 |
223 | generations = [[reason_post_process(c, i)] for i, c in enumerate(completions)]
224 | print(f"Process {accelerator.process_index}: Generation time: {generation_time:.4f} seconds")
225 |
226 | if len(self.sample_dataset) < accelerator.num_processes:
227 | generations = generations[:len(labels)]
228 |         # Build the output table for logging
229 | if self.trainer.accelerator.is_main_process:
230 | global_step = [str(steps)] * len(prompts)
231 | config_keys = list(self.gen_config.to_dict().keys())
232 | config_values = list(self.gen_config.to_dict().values())
233 | data = [[global_step[i], prompts[i], completions[i]] + config_values for i in range(len(prompts))]
234 | self.table.extend(data)
235 | table = pd.DataFrame(columns=["step", "prompt", "completion"] + config_keys, data=self.table)
236 | wandb.log({"completions": table})
237 |
238 | score = self.metric.compute(references=labels, predictions=generations)
239 | return score
240 |
241 | def save_best_metric_model(self, args, state):
242 | # Save model checkpoint
243 |         print('Saving checkpoint')
244 | checkpoint_folder = f"checkpoint-{state.global_step}"
245 | output_dir = os.path.join(args.output_dir, 'best_model', checkpoint_folder)
246 |
247 | self.trainer.save_model(output_dir)
248 |
249 | if not args.save_only_model:
250 | # Save optimizer and scheduler
251 | self.trainer._save_optimizer_and_scheduler(output_dir)
252 | self.trainer._save_scaler(output_dir)
253 | # Save RNG state
254 | self.trainer._save_rng_state(output_dir)
255 | if args.should_save:
256 | # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
257 | for cb in [
258 | cb for cb in self.trainer.callback_handler.callbacks + [self.trainer.control] if
259 | isinstance(cb, ExportableState)
260 | ]:
261 | cb_name = cb.__class__.__name__
262 | cb_state = cb.state()
263 | if isinstance(state.stateful_callbacks[cb_name], list):
264 | state.stateful_callbacks[cb_name].append(cb_state)
265 | else:
266 | state.stateful_callbacks[cb_name] = cb_state
267 | state.save_to_json(os.path.join(output_dir, 'trainer_state.json'))
268 |
269 | return output_dir
270 |
271 | def update_best_checkpoints1(self, args, state, custom_score):
272 |         """Update the list of best checkpoints."""
273 | if state.global_step < self.start_update_best_checkpoints:
274 | return
275 |         print('Updating the list of best checkpoints')
276 |         # For metrics where lower is better (e.g. loss), negate the value so comparisons are uniform
277 | metric_value = custom_score if self.higher_better else -custom_score
278 |
279 |         # If the maximum number has not been reached yet, or the current metric beats the worst kept one
280 | if (len(self.best_checkpoints) < self.max_checkpoints or
281 | metric_value > self.best_checkpoints[0].metric_value):
282 |
283 |             # Save the new checkpoint
284 | checkpoint_path = self.save_best_metric_model(args, state)
285 |
286 |             # Create a new CheckpointInfo object
287 | checkpoint_info = CheckpointInfo(
288 | step=state.global_step,
289 | metric_value=metric_value,
290 | path=checkpoint_path
291 | )
292 |
293 |             # Update the list of best checkpoints
294 | heapq.heappush(self.best_checkpoints, checkpoint_info)
295 |
296 |             # If the maximum number is exceeded, delete the worst checkpoint
297 | if len(self.best_checkpoints) > self.max_checkpoints:
298 | worst_checkpoint = heapq.heappop(self.best_checkpoints)
299 | print(f"Deleting older checkpoint [{worst_checkpoint.path}] due to args.save_total_limit")
300 | shutil.rmtree(worst_checkpoint.path, ignore_errors=True)
301 |
302 | def update_best_checkpoints(self, args, state, custom_score):
303 |         # Broadcast custom_score from the main process so that all processes share the same value
304 | custom_score = broadcast_object_list([custom_score], from_process=0)[0]
305 |
306 | if state.global_step < self.start_update_best_checkpoints:
307 | return
308 |
309 | metric_value = custom_score if self.higher_better else -custom_score
310 |
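        # best_checkpoints is a min-heap, so best_checkpoints[0] is always the worst retained
        # checkpoint; a new score triggers a save only if it beats it (or the heap is not full yet).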
311 | if self.trainer.accelerator.is_main_process:
312 |             # Only the main process decides whether to save
313 | if len(self.best_checkpoints) < self.max_checkpoints or (metric_value > self.best_checkpoints[
314 | 0].metric_value and state.global_step % args.save_steps != 0):
315 | save_flag = True
316 | else:
317 | save_flag = False
318 | else:
319 | save_flag = False
320 |
321 |         # Broadcast the save decision to all processes
322 | save_flag = broadcast_object_list([save_flag], from_process=0)[0]
323 |
324 | if save_flag:
325 | checkpoint_path = self.save_best_metric_model(args, state)
326 | # if self.trainer.accelerator.is_main_process:
327 | checkpoint_info = CheckpointInfo(step=state.global_step, metric_value=metric_value,
328 | path=checkpoint_path)
329 | heapq.heappush(self.best_checkpoints, checkpoint_info)
330 | if len(self.best_checkpoints) > self.max_checkpoints:
331 | worst_checkpoint = heapq.heappop(self.best_checkpoints)
332 | print(f"Deleting older checkpoint [{worst_checkpoint.path}] due to args.save_total_limit")
333 | shutil.rmtree(worst_checkpoint.path, ignore_errors=True)
334 |
335 | def on_step_end(self, args, state, control, **kwargs):
336 | # Only log once per step (this method may be called multiple times)
337 | if state.global_step == self._last_logged_step:
338 | return
339 |
340 | # Only log every `freq` steps
341 | freq = self.freq
342 | if state.global_step % freq != 0:
343 | return
344 |
345 |         # todo: return a dict of multiple metrics for visualization and pick one (or a weighted mix) as the save criterion; also handle the use_vllm=False path
346 | custom_score = self.samples_generate_vllm(state.global_step) if self.use_vllm else None
347 | self.trainer.log({"custom_score": custom_score, "step": state.global_step})
348 |
349 | self.update_best_checkpoints(args, state, custom_score)
350 | # Save the last logged step, so we don't log the same completions multiple times
351 | self._last_logged_step = state.global_step
352 |
--------------------------------------------------------------------------------
/utils/eval/configs.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Union, List, Dict
3 | from transformers import TrainingArguments
4 |
5 |
6 | @dataclass
7 | class GenerationConfig:
8 |     """
9 |     Generation config, used for vLLM or model.generate.
10 |
11 |     Parameters:
12 |         max_new_tokens (`int`, *optional*, defaults to `1024`):
13 |             Maximum number of tokens to generate for each prompt.
14 |         num_generation (`int`, *optional*, defaults to `1`):
15 |             Number of completions to generate for each prompt.
16 |         repetition_penalty: repetition penalty, `1.0` means no penalty.
17 |         temperature: sampling temperature, higher values increase diversity.
18 |         top_p: top-p sampling parameter, `1.0` means no truncation.
19 |         top_k: top-k sampling parameter, `-1` means no truncation.
20 |         min_p: minimum token probability for sampling.
21 |         do_sample: whether to sample, only used by `model.generate`.
22 |
23 |     """
24 | num_generation: int = 1
25 | repetition_penalty: float = 1.0
26 | temperature: float = 1.0
27 | top_p: float = 1.0
28 | top_k: int = -1
29 | min_p: float = 0.0
30 | max_new_tokens: int = 1024
31 | do_sample: bool = False
32 |
33 | # unwrap_model_for_generation
34 |     gather_deepspeed3_params: bool = True  # set to False if you run out of memory (OOM)
35 |
36 |
37 | @dataclass
38 | class EvaluationConfig:
39 | """
40 | Common test config
41 |
42 |     Parameters:
43 |         num_samples: number of test samples to evaluate on; `None` means the full test set.
44 |         freq: run the evaluation every `freq` training steps.
45 |         metrics: metric name(s) used for scoring, see the field comment below.
46 |     """
47 |     # Basic evaluation settings
48 | num_samples: Optional[int] = None
49 | freq: int = 5
50 | metrics: str = 'code' # ['code', 'em'] or metrics=[{'name': 'code',
51 | # 'weight': 0.7},{'name': 'em', 'weight': 0.3}]
52 |
53 | higher_better: bool = True
54 | prompts_apply_chat: bool = False
55 |
56 | use_vllm: bool = True # whether to use vllm to generate, if false, use unwrap_model_for_generation
57 | vllm_server_host: str = "0.0.0.0"
58 | vllm_server_port: int = 8080
59 | vllm_server_timeout: float = 120.0
60 |
61 | per_device_test_batch_size: int = 1 # Only use when use_vllm is False
62 |
63 |     # Checkpoint management
64 | save_best_checkpoints: bool = True
65 | start_update_best_checkpoints: int = 20
66 | max_checkpoints: int = 3
67 |
68 |
69 | if __name__ == "__main__":
70 | gen = GenerationConfig()
71 |     print(gen)
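    # Illustrative only: EvaluationConfig is constructed the same way; the fields above list its options.
    eval_cfg = EvaluationConfig(metrics='code', num_samples=6)
    print(eval_cfg.use_vllm)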
72 |
--------------------------------------------------------------------------------
/utils/eval/eval_metric.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from typing import List, Dict, Union, Type, Optional
4 |
5 | import psutil
6 |
7 | import evaluate
8 | from abc import ABC, abstractmethod
9 | import gc
10 | from datasets import Dataset
11 |
12 | from utils.eval.configs import EvaluationConfig
13 | from contextlib import contextmanager
14 |
15 |
16 | # To add a new evaluation method, just create a new metric class in this file
17 |
18 | class BaseMetric(ABC):
19 | @abstractmethod
20 | def compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
21 | pass
22 |
23 | @abstractmethod
24 | def extract_generation(self, completions: List[str]):
25 | pass
26 |
27 | def get_prompts(self, test_datasets: Dataset, tokenizer, prompts_apply_chat) -> List[str]:
28 |         """
29 |         The default way to get prompts (questions) from the test data.
30 |
31 |         Args:
32 |             test_datasets: dataset must include a `"prompt"` column containing the prompts for generating completions.
33 |             tokenizer: tokenizer used to apply the chat template when `prompts_apply_chat` is True.
34 |             prompts_apply_chat: whether to wrap each prompt with the tokenizer's chat template.
35 |         Returns:
36 |             prompts, e.g. ['hello', 'why']
37 |         """
38 |         prompts = [test_dataset['prompt'] for test_dataset in test_datasets]
39 |         if prompts_apply_chat:
40 |             messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
41 |             prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]
42 |         return prompts
43 |
44 | @staticmethod
45 | def get_labels(test_datasets: Dataset):
46 | labels = [test_dataset['label'] for test_dataset in test_datasets]
47 | return labels
48 |
49 |
50 | class CodeEvalMetric(BaseMetric):
51 | def __init__(self, metric_path: str = './utils/eval/metrics/code_eval'):
52 | # self.metric = evaluate.load(metric_path)
53 | self.metric_path = metric_path
54 | self.memory_threshold = 0.8 # 80% memory usage threshold
55 |
56 | def _check_memory_usage(self) -> Optional[float]:
57 | """Monitor memory usage"""
58 | process = psutil.Process(os.getpid())
59 | memory_info = process.memory_info()
60 | memory_percent = process.memory_percent()
61 |
62 | print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f}MB ({memory_percent:.1f}%)")
63 |
64 | return memory_percent
65 |
66 | @contextmanager
67 | def load_metric(self):
68 | """Context manager for loading and cleaning up metric with memory monitoring"""
69 |         metric = None
70 |         try:
71 |             memory_percent = self._check_memory_usage()
72 |             if memory_percent > self.memory_threshold * 100:
73 |                 print(f"Warning: High memory usage detected: {memory_percent:.1f}%")
74 |                 gc.collect()
75 |             metric = evaluate.load(self.metric_path)
76 |             yield metric
77 |         finally:
78 |             del metric
79 |             gc.collect()
80 |             self._check_memory_usage()
81 |
82 | def compute(self, predictions: List[List[str]], references: List[str]) -> float:
83 | with self.load_metric() as metric:
84 | try:
85 | pass_at_k, results = metric.compute(
86 | predictions=predictions,
87 | references=references,
88 | k=[1]
89 | )
90 | return float(pass_at_k["pass@1"])
91 | except Exception as e:
92 | print(f"Error during computation: {str(e)}")
93 | raise e
94 |
95 |
96 | def extract_generation(self, completions: List[str]):
97 |         """
98 |         Extract the generated code and convert it to the format expected by `compute(predictions=...)`.
99 |
100 |         Args:
101 |             completions (List[str]): the raw generated strings.
102 |
103 |         Returns:
104 |             Predictions in the format expected by `compute`, here List[List[str]].
105 |         """
106 |
107 | def extract(code: str, index: int):
108 | # Look for code blocks
109 | code_pattern = r'```(?:python|go|javascript|java|bash|js|cpp|cs|php)(.*?)```'
110 | code_match = re.findall(code_pattern, code, re.DOTALL)
111 | if code_match:
112 | # If code block exists, return its content (excluding the ``` markers)
113 | return code_match[-1].strip()
114 | else:
115 | # If no code block, return the solution content directly
116 | return str(index)
117 |
118 | generations = [[extract(c, i)] for i, c in enumerate(completions)]
119 | return generations
120 |
121 | def get_prompts(self, test_datasets: Dataset, tokenizer, prompts_apply_chat) -> List[str]:
122 | if prompts_apply_chat:
123 | prompts = [test_dataset['prompt'] for test_dataset in test_datasets]
124 | messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
125 |             res = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
126 | return res
127 |
128 | else:
129 | prompts = [test_dataset['prompt'] for test_dataset in test_datasets]
130 | return prompts
131 |
132 |
133 | class ExactMatchMetric(BaseMetric):
134 |     def compute(self, predictions: List[str], references: List[str]) -> float:
135 |         return sum(p.strip() == r.strip() for p, r in zip(predictions, references)) / max(len(references), 1)  # minimal exact-match: fraction of exact matches
136 |
137 | def extract_generation(self, completions: List[str]):
138 | return completions
139 |
140 |
141 | # todo: not finished yet
142 | class CompositeMetric(BaseMetric):
143 |     """Simple composite (weighted) metric."""
144 |
145 | def __init__(self, metric_configs: List[Dict[str, Union[str, float]]]):
146 | self.metrics = []
147 | self.weights = []
148 |
149 |         # Create the metric instances
150 | for config in metric_configs:
151 | metric_name = config['name']
152 | weight = config.get('weight', 1.0)
153 |
154 | if metric_name == 'code':
155 | metric = CodeEvalMetric()
156 | elif metric_name == 'em':
157 | metric = ExactMatchMetric()
158 | else:
159 |                 raise ValueError(f"Unsupported metric type: {metric_name}")
160 |
161 | self.metrics.append(metric)
162 | self.weights.append(weight)
163 |
164 | def compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
165 | results = {}
166 | total_score = 0.0
167 |
168 | for metric, weight in zip(self.metrics, self.weights):
169 | metric_results = metric.compute(predictions, references)
170 | for key, value in metric_results.items():
171 | weighted_value = value * weight
172 | results[f"{metric.__class__.__name__}_{key}"] = weighted_value
173 | total_score += weighted_value
174 |
175 | results['composite_score'] = total_score
176 | return results
177 |
178 |     def extract_generation(self, completions: List[str]):
179 |         # Use the first metric's extraction method
180 |         return self.metrics[0].extract_generation(completions)
181 |
182 |     def get_prompts(self, test_datasets: Dataset, tokenizer, prompts_apply_chat) -> List[str]:
183 |         # Use the first metric's prompt-building method
184 |         return self.metrics[0].get_prompts(test_datasets, tokenizer, prompts_apply_chat)
185 |
186 |
187 | # Registry mapping metric names to metric classes
188 | METRIC_REGISTRY: Dict[str, Type[BaseMetric]] = {
189 | 'code': CodeEvalMetric,
190 | 'em': ExactMatchMetric
191 | }
192 |
193 |
194 | def create_metric(config: EvaluationConfig) -> BaseMetric:
195 |     """Create a metric from the config.
196 |
197 |     Args:
198 |         config: the evaluation config object
199 |
200 |     Returns:
201 |         BaseMetric: the created metric object
202 |
203 |     Raises:
204 |         ValueError: when an unsupported metric type is specified
205 |     """
206 |     # Single metric
207 | if isinstance(config.metrics, str):
208 | metric_class = METRIC_REGISTRY.get(config.metrics)
209 | if not metric_class:
210 |             raise ValueError(f"Unsupported metric type: {config.metrics}")
211 | return metric_class()
212 |
213 |     # Composite (multiple) metrics
214 | return CompositeMetric(config.metrics)
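
# Illustrative usage (not part of the original file): with the default config this returns a
# CodeEvalMetric, which the EvaluationCallback then uses to score generated completions.
#   metric = create_metric(EvaluationConfig(metrics='code'))
#   isinstance(metric, CodeEvalMetric)  # True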
215 |
--------------------------------------------------------------------------------
/utils/eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import re
3 | from trl.models.utils import unwrap_model_for_generation
4 | from accelerate import Accelerator
5 | from transformers import (
6 | PreTrainedModel,
7 | PreTrainedTokenizerBase,
8 | GenerationConfig
9 | )
10 | from tqdm.auto import tqdm
11 | import torch
12 |
13 |
14 | def _generate_completions(
15 | prompts: list[str],
16 | model: PreTrainedModel,
17 | tokenizer: PreTrainedTokenizerBase,
18 | accelerator: Accelerator,
19 | generation_config: Optional[GenerationConfig],
20 | batch_size: int = 1,
21 | gather_deepspeed3_params: bool = True,
22 | ) -> list[str]:
23 | """
24 | Generates completions for a list of pre-formatted prompts from the given model.
25 |
26 | Args:
27 | prompts (list[str]): A list of input prompts for which completions are to be generated.
28 | model (PreTrainedModel): The pre-trained model to be used for generation.
29 | tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding.
30 | accelerator (Accelerator): The accelerator to be used for model execution.
31 | generation_config (GenerationConfig): Configuration for text generation.
32 | batch_size (int, optional): The number of prompts to process in each batch. Default is 1.
33 |         gather_deepspeed3_params (bool, optional): Whether to gather DeepSpeed ZeRO-3 parameters for generation; set to False if you run out of memory. Default is True.
34 |
35 | Returns:
36 | list[str]: A list of generated text completions corresponding to the input prompts.
37 | """
38 | completions = []
39 | with unwrap_model_for_generation(model, accelerator,
40 | gather_deepspeed3_params=gather_deepspeed3_params) as unwrapped_model:
41 |         # Create a distributed-safe progress bar (shown only on the main process)
42 | total_batches = len(prompts) // batch_size + (1 if len(prompts) % batch_size != 0 else 0)
43 |
44 | progress_bar = tqdm(
45 | total=total_batches,
46 | desc="Generating Completions",
47 |             disable=not accelerator.is_main_process,  # disable the bar on non-main processes
48 |             dynamic_ncols=True  # adapt to the terminal width
49 | )
50 |
51 | for idx in range(0, len(prompts), batch_size):
52 | batch = prompts[idx: idx + batch_size]
53 | tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device)
54 | generations = unwrapped_model.generate(
55 | **tokenized_batch,
56 | generation_config=generation_config,
57 | pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
58 | )
59 | for prompt, generation in zip(tokenized_batch.input_ids, generations):
60 | # Remove prompt from generation
61 | generation = generation[len(prompt):]
62 | completion = tokenizer.decode(generation, skip_special_tokens=True)
63 | completions.append(completion)
64 |             # Update the progress bar
65 | progress_bar.update(1)
66 | progress_bar.close()
67 | return completions
68 |
69 |
70 | def reason_post_process(code, index):
71 |     """
72 |     Extract the last fenced code block from a completion.
73 |
74 |     Args:
75 |         code (str): the input string.
76 |         index (int/str): the index of the current string.
77 |
78 |     Returns:
79 |         str: the content of the last code block if found; otherwise the input index as a string.
80 |     """
81 |
82 | # Look for code blocks
83 | code_pattern = r'```(?:python|go|javascript|java|bash|js|cpp|cs|php)(.*?)```'
84 | code_match = re.findall(code_pattern, code, re.DOTALL)
85 |
86 | if code_match:
87 | # If code block exists, return its content (excluding the ``` markers)
88 | return code_match[-1].strip()
89 | else:
90 | # If no code block, return the solution content directly
91 | return str(index)
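
# Example (illustrative): a completion like "Here you go:\n```python\nprint('hi')\n```" yields
# "print('hi')", while a completion with no fenced code block yields str(index).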
92 |
--------------------------------------------------------------------------------
/utils/eval/train_script.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import time
3 | import shutil
4 | import itertools
5 | from dataclasses import dataclass
6 | from random import random
7 | from typing import List, Optional
8 | import re
9 | from contextlib import contextmanager
10 | from transformers.integrations import WandbCallback
11 | import torch
12 | from datasets import load_dataset
13 | import wandb
14 | from accelerate.utils import gather_object
15 | import os
16 | import deepspeed
17 | from trl.models.utils import unwrap_model_for_generation
18 | from accelerate import Accelerator
19 | from transformers import (
20 | GenerationConfig,
21 | PreTrainedModel,
22 | PreTrainedTokenizerBase,
23 | Trainer,
24 | TrainingArguments,
25 | )
26 | from tqdm.auto import tqdm
27 | from transformers import AutoTokenizer, AutoModelForCausalLM
28 | from transformers.trainer_callback import ExportableState
29 | import evaluate
30 | from eval_metric import CodeEvalMetric
31 | from utils import MultiRoundDataProcess, SftDataCollator
32 | from callback import EvaluationCallback
33 |
34 | os.environ["HF_ALLOW_CODE_EVAL"] = "1"
35 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
36 |
37 | batch_size = 1
38 | gradient_accumulation_steps = 2
39 | num_train_epochs = 1
40 |
41 | training_args = TrainingArguments(
42 | output_dir="./output/",
43 | report_to="wandb", # this tells the Trainer to log the metrics to W&B
44 | per_device_train_batch_size=batch_size,
45 | bf16=True,
46 | learning_rate=2e-5,
47 | lr_scheduler_type="cosine",
48 | warmup_ratio=0.1,
49 | save_strategy="steps",
50 | save_steps=20,
51 | save_total_limit=2,
52 | gradient_accumulation_steps=gradient_accumulation_steps,
53 | gradient_checkpointing=True,
54 | num_train_epochs=num_train_epochs,
55 | # logging strategies
56 | logging_strategy="steps",
57 | logging_steps=2,
58 | torch_compile=False,
59 | remove_unused_columns=False,
60 | deepspeed='deepspeed_config/ds_config_zero3.json'
61 | )
62 |
63 | if __name__ == "__main__":
64 | model_name_or_path = '/Qwen2.5-Coder-32B-Instruct'
65 | train_data_path = 'train_data/fix_bash1k.jsonl'
66 | test_data_path = 'eval_train_test/test.jsonl'
67 |
68 | max_len = 4096
69 | auto_adapt = False
70 |
71 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
72 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
73 |
74 | train_dataset = MultiRoundDataProcess(train_data_path, tokenizer, max_len, auto_adapt)
75 |
76 | test_dataset = load_dataset(path="json", data_files=test_data_path)
77 | test_dataset = test_dataset['train']
78 |
79 | data_collator = SftDataCollator(tokenizer, max_len)
80 | generate_config = GenerationConfig(
81 | max_new_tokens=4096,
82 | max_length=max_len,
83 | use_cache=True
84 | )
85 |
86 | trainer = Trainer(
87 | model=model,
88 | args=training_args,
89 | train_dataset=train_dataset,
90 | data_collator=data_collator,
91 | processing_class=tokenizer
92 | )
93 |
94 |     # if os.environ.get('LOCAL_RANK', '0') == '0':  # initialize only on the main process
95 | # wandb.init(project="huggingface")
96 | # wandb.init(project="huggingface")
97 |
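    # The callback below evaluates every `freq` steps on `num_samples` test examples and keeps
    # the best `max_checkpoints` checkpoints according to the metric (CodeEvalMetric here).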
98 | wandb_callback = EvaluationCallback(
99 | trainer=trainer,
100 | test_dataset=test_dataset,
101 | generation_config=generate_config,
102 | num_samples=6,
103 | freq=1,
104 | metric=CodeEvalMetric(),
105 | max_checkpoints=1,
106 | per_device_test_batch_size=1,
107 | higher_better=True,
108 | start_update_best_checkpoints=100
109 | )
110 | trainer.add_callback(wandb_callback)
111 |
112 | trainer.train()
--------------------------------------------------------------------------------
/utils/eval/vllm/run_serve.sh:
--------------------------------------------------------------------------------
1 | MODEL_PATH='Qwen2.5-Coder-32B-Instruct'
2 |
3 |
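# Note: the --port below must match the client-side EvaluationConfig.vllm_server_port
# (which defaults to 8080), and vllm_server_host must point at this machine.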
4 | CUDA_VISIBLE_DEVICES=0,1 python vllm_serve.py \
5 | --model "$MODEL_PATH" \
6 | --tensor_parallel_size 2 \
7 | --max_model_len 4096 \
8 | --port 8001 \
9 | --dtype "bfloat16" \
10 | --enable_prefix_caching False
--------------------------------------------------------------------------------
/utils/eval/vllm/vllm_client.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import logging
3 | import time
4 | from typing import Optional
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from trl.import_utils import is_requests_available, is_vllm_available
10 |
11 | if is_requests_available():
12 | import requests
13 | from requests import ConnectionError
14 |
15 | if is_vllm_available():
16 | from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
17 | from vllm.distributed.utils import StatelessProcessGroup
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class VLLMClient:
23 | """
24 | A client class to interact with a vLLM server.
25 |
26 | This class provides methods to generate completions, initialize and manage weight update groups, and update model
27 | weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
28 |
29 | Args:
30 | host (`str`, *optional*, defaults to `"0.0.0.0"`):
31 | IP address of the vLLM server.
32 | server_port (`int`, *optional*, defaults to `8000`):
33 | Port number of the vLLM server.
34 | group_port (`int`, *optional*, defaults to `51216`):
35 | Port number for the weight update group.
36 | connection_timeout (`float`, *optional*, defaults to `0.0`):
37 | Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
38 | timeout, a `ConnectionError` is raised.
39 |
40 | Examples:
41 | Run the vLLM server with the model `Qwen/Qwen2.5-7B`:
42 |
43 | ```
44 | $ trl vllm-serve --model Qwen/Qwen2.5-7B
45 | ...
46 | INFO: Application startup complete.
47 | INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
48 | ```
49 |
50 | Use the client to generate completions and update model weights:
51 |
52 | ```python
53 | >>> from trl.extras.vllm_client import VLLMClient
54 | >>> client = VLLMClient()
55 | >>> client.generate(["Hello, AI!", "Tell me a joke"])
56 | [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
57 | [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
58 |
59 | >>> from transformers import AutoModelForCausalLM
60 | >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
61 | >>> client.update_model_params(model)
62 | ```
63 | """
64 |
65 | def __init__(
66 | self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216,
67 | connection_timeout: float = 0.0
68 | ):
69 | if not is_requests_available():
70 | raise ImportError("requests is not installed. Please install it with `pip install requests`.")
71 | if not is_vllm_available():
72 | raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
73 |
74 | self.session = requests.Session()
75 | self.host = host
76 | self.server_port = server_port
77 | self.group_port = group_port
78 | self.check_server(connection_timeout) # check server and fail after timeout
79 | self.init_communicator()
80 | atexit.register(self.close_communicator) # when the client object is deleted, close the weight update group
81 |
82 | def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
83 | """
84 | Check server availability with retries on failure, within a total timeout duration. If the server is not up
85 | after the total timeout duration, raise a `ConnectionError`.
86 |
87 | Args:
88 | retry_interval (`float`, *optional*, defaults to `2.0`):
89 | Interval in seconds between retries.
90 | total_timeout (`float`, *optional*, defaults to `0.0`):
91 | Total timeout duration in seconds.
92 | """
93 | url = f"http://{self.host}:{self.server_port}/health/"
94 | start_time = time.time() # Record the start time
95 |
96 | while True:
97 | try:
98 | # response = requests.get(url)
99 |                 response = requests.get(url, proxies={"http": None, "https": None})  # todo: investigate why the proxy bypass is needed here
100 | except requests.exceptions.RequestException as exc:
101 | # Check if the total timeout duration has passed
102 | elapsed_time = time.time() - start_time
103 | if elapsed_time >= total_timeout:
104 | raise ConnectionError(
105 | f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
106 | "seconds. Make sure the server is running by running `trl vllm-serve`."
107 | ) from exc
108 | else:
109 | if response.status_code == 200:
110 | logger.info("Server is up!")
111 | return None
112 |
113 | # Retry logic: wait before trying again
114 | logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...")
115 | time.sleep(retry_interval)
116 |
117 | def generate(
118 | self,
119 | prompts: list[str],
120 | n: int = 1,
121 | repetition_penalty: float = 1.0,
122 | temperature: float = 1.0,
123 | top_p: float = 1.0,
124 | top_k: int = -1,
125 | min_p: float = 0.0,
126 | max_tokens: int = 16,
127 | guided_decoding_regex: Optional[str] = None,
128 |     ) -> list[list[int]]:
129 | """
130 | Generates model completions for the provided prompts.
131 |
132 | Args:
133 | prompts (`list[str]`):
134 | List of text prompts for which the model will generate completions.
135 | n (`int`, *optional*, defaults to `1`):
136 | Number of completions to generate for each prompt.
137 | repetition_penalty (`float`, *optional*, defaults to `1.0`):
138 | Parameter for repetition penalty. 1.0 means no penalty.
139 | temperature (`float`, *optional*, defaults to `1.0`):
140 | Temperature parameter for sampling. Higher values increase diversity.
141 | top_p (`float`, *optional*, defaults to `1.0`):
142 | Top-p sampling parameter.`1.0` means no truncation.
143 | top_k (`int`, *optional*, defaults to `-1`):
144 | Top-k sampling parameter. `-1` means no truncation.
145 | min_p (`float`, *optional*, defaults to `0.0`):
146 | Minimum probability for sampling.
147 | max_tokens (`int`, *optional*, defaults to `16`):
148 | Maximum number of tokens to generate for each prompt.
149 | guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
150 | Regular expression to guide the decoding process.
151 |
152 | Returns:
153 | `list[list[int]]`:
154 | List of lists of token IDs representing the model-generated completions for each prompt.
155 | """
156 | url = f"http://{self.host}:{self.server_port}/generate/"
157 | response = self.session.post(
158 | url,
159 | json={
160 | "prompts": prompts,
161 | "n": n,
162 | "repetition_penalty": repetition_penalty,
163 | "temperature": temperature,
164 | "top_p": top_p,
165 | "top_k": top_k,
166 | "min_p": min_p,
167 | "max_tokens": max_tokens,
168 | "guided_decoding_regex": guided_decoding_regex,
169 | },
170 | )
171 | if response.status_code == 200:
172 | return response.json()["completion_ids"]
173 | else:
174 | raise Exception(f"Request failed: {response.status_code}, {response.text}")
175 |
176 | def init_communicator(self):
177 | """
178 | Initializes the weight update group in a distributed setup for model synchronization.
179 | """
180 | # Get the tensor parallel size from the server
181 | url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/"
182 | # response = requests.get(url)
183 | response = requests.get(url, proxies={"http": None, "https": None})
184 | if response.status_code == 200:
185 | tensor_parallel_size = response.json()["tensor_parallel_size"]
186 | else:
187 | raise Exception(f"Request failed: {response.status_code}, {response.text}")
188 |
189 | world_size = tensor_parallel_size + 1
190 | self.rank = tensor_parallel_size # The client's rank is the last process
191 |
192 | # Initialize weight update group
193 | url = f"http://{self.host}:{self.server_port}/init_communicator/"
194 | # In the server side, the host is set to 0.0.0.0
195 | response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
196 | if response.status_code != 200:
197 | raise Exception(f"Request failed: {response.status_code}, {response.text}")
198 |
199 | # Set up the communication group for weight broadcasting
200 | pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
201 | self.pynccl_comm = PyNcclCommunicator(pg, device="cuda:0")
202 |
203 | def update_named_param(self, name: str, weights: torch.Tensor):
204 | """
205 | Updates a specific named parameter in the model and broadcasts it to other processes.
206 |
207 | Args:
208 | name (`str`):
209 | Name of the layer whose weights are being updated.
210 | weights (`torch.Tensor`):
211 | Tensor containing the updated weights.
212 | """
213 | dtype, shape = str(weights.dtype), tuple(weights.shape)
214 | url = f"http://{self.host}:{self.server_port}/update_named_param/"
215 | response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
216 | if response.status_code != 200:
217 | raise Exception(f"Request failed: {response.status_code}, {response.text}")
218 |
219 | # Broadcast the weights to the other processes
220 | self.pynccl_comm.broadcast(weights, src=self.rank, stream=torch.cuda.current_stream())
221 | self.pynccl_comm.group.barrier()
222 |
223 | def update_model_params(self, model: nn.Module):
224 | """
225 | Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.
226 |
227 | Args:
228 | model (`nn.Module`):
229 | Model whose parameters (weights/biases) are to be updated.
230 | """
231 | for name, param in model.named_parameters():
232 | # Update each parameter individually
233 | self.update_named_param(name, param.data)
234 |
235 | def reset_prefix_cache(self):
236 | """
237 | Resets the prefix cache for the model.
238 | """
239 | url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
240 | response = self.session.post(url)
241 | if response.status_code != 200:
242 | raise Exception(f"Request failed: {response.status_code}, {response.text}")
243 |
244 | def close_communicator(self):
245 | """
246 | Closes the weight update group and cleans up the communication group.
247 | """
248 | url = f"http://{self.host}:{self.server_port}/close_communicator/"
249 | response = self.session.post(url)
250 | if response.status_code != 200:
251 | raise Exception(f"Request failed: {response.status_code}, {response.text}")
252 |
253 |
254 | # Example usage
255 | if __name__ == "__main__":
256 | from vllm import SamplingParams
257 |
258 | client = VLLMClient()
259 |
260 | # Generate completions
261 |     responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
262 | print("Responses:", responses) # noqa
263 |
264 | # Update model weights
265 | from transformers import AutoModelForCausalLM
266 |
267 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
268 | client.update_model_params(model)
269 |
--------------------------------------------------------------------------------
/utils/script/download_model.py:
--------------------------------------------------------------------------------
1 | from modelscope import snapshot_download
2 | model_dir = snapshot_download('qwen/Qwen1.5-0.5B', cache_dir='../../../download_llm')
3 | print("Model download complete")
--------------------------------------------------------------------------------
/utils/script/generate_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | ORIGINAL_DATA_PATH = '../../1.jsonl'  # path to the original data
4 | OUT_DATA_PATH = './out_data.jsonl'  # output path for data converted to the role/content format used by the framework
5 |
6 |
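# Illustrative formats (assumed input schema): each input line looks like
#   {"instruction": "...", "output": "..."}
# and each output line becomes
#   {"message": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}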
7 | data1 = pd.read_json(ORIGINAL_DATA_PATH, lines=True)
8 | # Create an empty list to store the processed data
9 | processed_data = []
10 | # Iterate over every row
11 | for index, row in data1.iterrows():
12 | message = [
13 | {"role": "user", "content": row['instruction']},
14 | {"role": "assistant", "content": row['output']}
15 | ]
16 | processed_data.append({"message": message})
17 | # Convert the processed data to a DataFrame
18 | processed_df = pd.DataFrame(processed_data)
19 | # Save in jsonl format
20 | processed_df.to_json(OUT_DATA_PATH, orient='records', lines=True, force_ascii=False)
21 |
--------------------------------------------------------------------------------
/utils/script/merge_lora.py:
--------------------------------------------------------------------------------
1 | from peft import PeftModel
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 |
4 | # Paths of the base model and the LoRA checkpoint saved after training
5 | base_model_path = 'download_llm/LLM-Research/Phi-3-mini-128k-instruct'
6 | lora_path = '/LLM-out/checkpoint-616'
7 | # Where to save the merged model
8 | merge_output_dir = 'merged_lora_model'
9 |
10 | tokenizer = AutoTokenizer.from_pretrained(base_model_path)
11 | base_model = AutoModelForCausalLM.from_pretrained(
12 | base_model_path,
13 | device_map="cuda",
14 | torch_dtype="auto",
15 | trust_remote_code=True,
16 | )
17 |
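# Load the LoRA adapter on top of the base model; merge_and_unload() folds the adapter weights
# into the base weights, so the merged model can later be loaded without peft.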
18 | lora_model = PeftModel.from_pretrained(base_model, lora_path)
19 | model = lora_model.merge_and_unload()
20 |
21 | if merge_output_dir:
22 | model.save_pretrained(merge_output_dir)
23 | tokenizer.save_pretrained(merge_output_dir)
--------------------------------------------------------------------------------
/utils/vlm_template.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Any
2 |
3 |
4 | def get_key(model):
5 |     """Get the top-level parameter keys of the model."""
6 | key_set = set()
7 | for name, param in model.named_parameters():
8 | key_set.add(name.split('.')[0])
9 | return key_set
10 |
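# For example, for a Qwen2-VL style model get_key would typically return something like
# {'visual', 'model', 'lm_head'}, matching the keys used by Qwen2VLTemplateProcessor below.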
11 |
12 | # Base template processor
13 | class TemplateProcessor:
14 | def __init__(self):
15 |         # Top-level module names of the model
16 | self.model_key = {
17 | "visual": 'visual',
18 | "llm": 'llm',
19 | "projector": 'projector'
20 | }
21 |
22 | def process(self, questions: List[str], answers: List[str]) -> Any:
23 | raise NotImplementedError("Subclasses must implement `process` method.")
24 |
25 |
26 | # LLaVA template processor
27 | class LlavaTemplateProcessor(TemplateProcessor):
28 | NAME = 'LlavaProcessor'
29 |
30 | def __init__(self):
31 | super().__init__()
32 | self.model_key['visual'] = 'vision_tower'
33 | self.model_key['llm'] = 'language_model'
34 | self.model_key['projector'] = 'multi_modal_projector'
35 |
36 | def process(self, questions: List[str], answers: List[str]) -> List[Dict]:
37 | converted_data = []
38 | for question, answer in zip(questions, answers):
39 | user_content = [{'index': None, 'text': question, 'type': 'text'}]
40 | assistant_content = [{'index': None, 'text': answer, 'type': 'text'}]
41 |
42 | converted_data.append({'content': user_content, 'role': 'user'})
43 | converted_data.append({'content': assistant_content, 'role': 'assistant'})
44 | image_dict = {'index': 0, 'text': None, 'type': 'image'}
45 | converted_data[0]['content'].append(image_dict)
46 | return converted_data
47 |
48 |
49 | # The Qwen2VL template processor can reuse the same message format as LLaVA
50 | class Qwen2VLTemplateProcessor(LlavaTemplateProcessor):
51 | NAME = 'Qwen2VLProcessor'
52 |
53 | def __init__(self):
54 | super().__init__()
55 |         # Top-level module names of the model
56 | self.model_key['visual'] = 'visual'
57 | self.model_key['llm'] = 'model'
58 | self.model_key['projector'] = 'lm_head'
59 |
60 |
61 | if __name__ == '__main__':
62 | qwen = Qwen2VLTemplateProcessor()
63 | print(qwen.model_key)
64 |
--------------------------------------------------------------------------------
/vlm_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import load_dataset
3 | from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration
4 |
5 | from trl import (
6 | ModelConfig,
7 | SFTConfig,
8 | SFTTrainer,
9 | TrlParser,
10 | get_kbit_device_map,
11 | get_peft_config,
12 | get_quantization_config,
13 | )
14 | from train_args.vlm_config.script_args import ScriptArgs
15 | from utils.data_process import VlmQaDataset
16 | from utils.data_collator import VlmQaDataCollator
17 |
18 |
19 | def freeze_vision_projection(model, model_key_value):
20 | """
21 |     model_key_value: attribute name of the visual or projector module (from the template's model_key)
22 |     Freeze the vision part or the projector part of the model.
23 | """
24 | module = getattr(model, model_key_value)
25 | for param in module.parameters():
26 | param.requires_grad = False
27 |
28 |
29 | def initial_args():
30 | parser = TrlParser((ScriptArgs, SFTConfig, ModelConfig))
31 | script_args, training_args, model_args = parser.parse_args_and_config()
32 | training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
33 | # training_args.dataset_kwargs = {"skip_prepare_dataset": True}
34 | training_args.remove_unused_columns = False
35 | return script_args, training_args, model_args
36 |
37 |
38 | def create_model_processor(model_args, script_args):
39 | torch_dtype = (
40 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
41 | )
42 | processor = AutoProcessor.from_pretrained(
43 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
44 | )
45 |
46 | if script_args.train_mode in ['lora', 'qlora']:
47 | model_args.lora_target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
48 | model_args.use_peft = True
49 |
50 | quantization_config = None
51 | if script_args.train_mode == 'qlora':
52 | model_args.load_in_4bit = True
53 | quantization_config = get_quantization_config(model_args)
54 |
55 | model_kwargs = dict(
56 | revision=model_args.model_revision,
57 | attn_implementation=model_args.attn_implementation,
58 | torch_dtype=torch_dtype,
59 | device_map=get_kbit_device_map() if quantization_config is not None else None,
60 | quantization_config=quantization_config,
61 | )
62 |
63 | model = AutoModelForVision2Seq.from_pretrained(
64 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
65 | )
66 |
67 | return {
68 | 'model': model,
69 | 'processor': processor,
70 | 'peft_config': get_peft_config(model_args),
71 | }
72 |
73 |
74 | # Load the dataset; different tasks may adjust this dynamically later
75 | def load_vlm_dataset(script_args):
76 | train_dataset = VlmQaDataset(script_args.train_data_path)
77 | return train_dataset
78 |
79 |
80 | def main():
81 | script_args, training_args, model_args = initial_args()
82 | train_dict = create_model_processor(model_args, script_args)
83 |
84 | model = train_dict['model']
85 | processor = train_dict['processor']
86 |
87 | train_dataset = load_vlm_dataset(script_args)
88 | collate_fn = VlmQaDataCollator(processor)
89 |
90 | model_keys = collate_fn.template_process.model_key
91 | if script_args.freeze_vision:
92 | freeze_vision_projection(model, model_keys['visual'])
93 | if script_args.freeze_projector:
94 | freeze_vision_projection(model, model_keys['projector'])
95 |
96 | trainer = SFTTrainer(
97 | model=model,
98 | args=training_args,
99 | data_collator=collate_fn,
100 | train_dataset=train_dataset,
101 | processing_class=processor.tokenizer,
102 | peft_config=get_peft_config(model_args),
103 | )
104 |
105 | trainer.train()
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
--------------------------------------------------------------------------------