├── .gitattributes ├── README.md ├── data ├── gkd_data.jsonl ├── rlhf.jsonl ├── sft_data.jsonl └── test.jsonl ├── evaluate ├── README.md ├── args.py ├── base_utils.py ├── dataset │ └── humaneval_python.jsonl ├── evaluation.py ├── execution.py ├── generate.py ├── main.py ├── run.sh └── task │ └── humaneval.py ├── inference_vllm.py ├── llm_tricks ├── DPO_example │ ├── README.md │ ├── dataset.py │ ├── dpo_train.py │ ├── evaluate.py │ ├── loss.py │ └── unsloth_dpo.jsonl ├── dora │ ├── READEME.md │ ├── dora_example.py │ └── lora_and_dora.ipynb ├── moe │ ├── READEME.md │ ├── input.txt │ └── make_moe_step_by_step.ipynb └── transformer │ └── README.md ├── main_train.py ├── pic └── pic.jpg ├── requirements.txt ├── rlhf ├── README.md ├── __init__.py ├── common_args.py ├── ds_config │ ├── ds_zero2.yaml │ └── ds_zero3.yaml ├── gkd_run.sh ├── rejected_sampling │ ├── README.md │ ├── genetate.py │ ├── rejected_sampling.py │ ├── run_generate.sh │ └── template.py ├── requirements.txt ├── rlhf_args │ ├── base_config.py │ ├── cpo-simpo_config.py │ ├── cpo_config.py │ ├── dpo_config.py │ ├── kto_config.py │ ├── ppo_config.py │ ├── reward_config.py │ ├── rloo_config.py │ └── simpo_config.py ├── rlhf_run.sh ├── train_gkd.py ├── train_rlhf.py └── utils │ └── util.py ├── run_eval_test.sh ├── run_example.sh ├── run_vlm_example.sh ├── train_args ├── __init__.py ├── common_args.py ├── deepspeed_config │ ├── ds_config_zero0.json │ ├── ds_config_zero2.json │ └── ds_config_zero3.json ├── dpo │ └── README.md ├── sft │ └── base.py └── vlm_config │ └── script_args.py ├── utils ├── __init__.py ├── data_collator.py ├── data_process.py ├── eval │ ├── README.md │ ├── callback.py │ ├── configs.py │ ├── eval_metric.py │ ├── eval_utils.py │ ├── train_script.py │ └── vllm │ │ ├── run_serve.sh │ │ ├── vllm_client.py │ │ └── vllm_serve.py ├── script │ ├── download_model.py │ ├── generate_data.py │ └── merge_lora.py └── vlm_template.py └── vlm_train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # LLM-Dojo: 大模型修炼道场 😊 3 | 4 | 5 | Tips: 图片完全由AI生成 6 | ## 🌟 项目简介 7 | LLM-Dojo使用简洁且易阅读的代码构建LLM、VLM模型训练、RLHF框架等各种功能,使项目**易于学习且方便魔改与实验**,与大多开源框架相同均是基于huggingface。 8 | 主要内容如下: 9 | - **SFT训练框架:** 简洁清晰的开源大模型训练框架,支持Deepspeed多卡、Lora、QLora、全参等训练,自动适配chat template。 10 | - **VLM多模态训练框架:** 支持多模态各种任务训练(目前仅实现QA),自动适配模型template。 11 | - **RLHF框架:** RLHF训练框架,持续更新,包括 知识蒸馏,DPO、RLOO、SimPO等各种强化学习方法,适配Deepspeed多卡及Lora,一张A100即可运行,详情可见: [RLHF](./rlhf/README.md)。 12 | - **最新LLM tricks详解:** 持续更新大模型领域最新tricks介绍,包括新论文方法的复现等,希望可以给你一些创新的想法,该模块主要集中在```llm_tricks```文件夹下。 13 | 14 | ### 目录 15 | 16 | - [项目简介](#-项目简介) 17 | - [Latest News](#-latest-news) 18 | - [RLHF训练框架](#rlhf训练框架) 19 | - [已支持RLHF方法](#rlhf训练框架) 20 | - [知识蒸馏](#rlhf训练框架) 21 | - [拒绝采样](#rlhf训练框架) 22 | - [SFT训练框架](#sft训练框架) 23 | - [已支持微调模型](#已支持微调模型) 24 | - [训练数据格式说明](#训练数据格式说明) 25 | - [适配框架数据处理](#适配框架数据处理) 26 | - [Quick Start](#quick-start) 27 | - [多模态训练(VLM)](#多模态训练vlm) 28 | - [已支持模型](#已支持模型) 29 | - [已支持任务类型](#已支持任务类型) 30 | - [数据格式](#数据格式) 31 | - [Tricks](#tricks) 32 | - [技术发文](#技术发文) 33 | - [致谢](#-致谢) 34 | 35 | ## 📖 Latest News 36 | - [2025-04-18] 拒绝采样实现,由generate部分和评测部分组成,支持API作为评价模型,具体可见[Rejected Sampling](./rlhf/rejected_sampling/README.md),详细文档解释尚未完成。 37 | - [2025-04-13] 🚀训练中评测,使用vllm大幅提升训练中模型生成速度,具体可见[Predict with 
generate](./utils/eval/README.md)。详细文档解释与优化正在进行。 38 | - [2024-12-31] 支持多模态训练,可见[多模态训练(VLM)](#多模态训练vlm) 39 | - [2024-11-06] 重构RLHF,具体可见目录中RLHF训练框架部分 40 | - [2024-10-31] 添加auto_adapt参数控制是否自动适配template、更新优化DPO训练(迁移至RLHF目录下) 41 | - [2024-10-15] 增加知识蒸馏训练方法。可见[知识蒸馏](./rlhf/README.md) 42 | - [2024-10-14] 删除chat template模块,因为使用tokenizer的apply_chat_template即可 43 | - [2024-09-20] 增加evaluate模块,一个简洁的模型评测框架,目前仅支持Humaneval。可见[Evaluate](./evaluate/README.md) 44 |
More news... 45 | 46 | - [2024-08-27] 🤓增加从零实现自己编写DPO、SimPO代码,包括数据、loss、训练等部分。可见[DPO example](./llm_tricks/DPO_example/README.md) 47 | - [2024-08-08] 支持直接修改配置文件启动及命令行启动,增加框架适配数据处理代码。 48 | - [2024-08-04] 支持自适应单轮或多轮对话,无需指定单轮或多轮,训练根据数据自行判断单轮或多轮。且可自主设置system命令。可见[训练数据格式说明](#训练数据格式说明) 49 | - [2024-07-19] RLHF 强化学习框架新增CPO,SimPO,以及二者融合CPO-SimPO 50 | - [2024-07-16] RLHF 强化学习框架更新完成,支持deepspeed单卡/多卡 进行强化学习lora、qlora等训练,详细可见[RLHF](./rlhf/README.md) 51 | - [2024-06-9] 🚀支持DPO训练,分为单轮对话DPO(自己构建,方便魔改)和多轮对话DPO(简洁实现),支持deepspeed的lora和qlora,具体介绍可见 [DPO使用说明](./train_args/dpo/README.md) 52 | - [2024-06-5] 🤓llm_tricks 增加从头开始实现MOE 53 | - [2024-06-10] 🚀增加一步一步实现Transformer技术发文(包括代码等从零介绍),可见 [技术发文](#技术发文) 54 | - [2024-05-18] 🤓支持Deepspeed单机多卡、单机单卡的Lora、Qlora、全量微调等训练! 55 | - [2024-04-28] 🚀 更新dora微调原理示例、支持qwen模型微调 56 |
57 | 58 | ## RLHF训练框架 59 | 60 | RLHF训练框架,支持并持续更新 知识蒸馏、Reward、PPO、DPO、RLOO、SimPO、KTO等各种强化学习方法,适配Deepspeed多卡及Lora,一张A100即可运行。 61 | 详情可见: [RLHF](./rlhf/README.md)。 62 | 63 | 主要包括三类: 64 | 65 | **1、RLHF** 66 | 67 | **2、Knowledge Distillation (知识蒸馏)** 68 | 69 | **3、Rejected Sampling (拒绝采样) :待更新** 70 | 71 | ## SFT训练框架 72 | 73 | ### 已支持微调模型 74 | 理论上支持对所有模型的微调,下述仅为测试过。 75 | 76 | 支持基于Deepspeed的多卡/单卡 Lora、Qlora、Dora微调: 77 | - [x] [Qwen(Qwen1.5/Qwen2)](https://github.com/QwenLM/Qwen.git) 78 | - [x] [Yi](https://github.com/01-ai/Yi) 79 | - [x] [Gemma系列](https://github.com/google/gemma_pytorch) 80 | - [x] [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) 81 | - [x] [Deepseek](https://github.com/deepseek-ai/DeepSeek-LLM) 82 | - [x] [MiniCPM](https://github.com/OpenBMB/MiniCPM) 83 | - [x] [Llama系列](https://github.com/meta-llama/llama3) 84 | - [x] [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder) 85 | - [x] [哔哩哔哩 Index-1.9B](https://github.com/bilibili/Index-1.9B) 86 | - [x] [baichuan系列](https://github.com/baichuan-inc/Baichuan2) 87 | - [x] [GLM系列](https://github.com/THUDM/GLM-4) 88 | 89 | ### 😮训练数据格式说明 90 | SFT数据格式为user(system) assistant标准模式,**无需指定单轮或多轮,训练根据数据自行判断单轮或多轮。** 91 | 92 | 示例如下,示例文件可参见```data/sft_data.jsonl```: 93 | ```json lines 94 | {"message": [{"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},{"role": "user", "content": "How many helicopters can a human eat in one sitting"},{"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together"},{"role": "user", "content": "你好"},{"role": "assistant", "content": "hellow"}]} 95 | ``` 96 | 可根据需求自行决定是否增加system字段,**建议训练数据没有特殊需求可删除system字段** 97 | 98 | 训练参数中auto_adapt参数控制是否自动适配template,如设置为False,则不自动适配,按原始的content进行训练。 99 | 100 | ### 适配框架数据处理 101 | 鉴于框架指定格式数据可能会跟常规数据有些不同,故可以通过```utils/script/generate_data.py```文件进行处理,输入应为正常的instruction和output的jsonl格式文件, 102 | 如下: 103 | ```json lines 104 | {"instruction":"将这个句子改写成将来时态:“太阳将会照耀明亮。”","output":"太阳将会散发温暖的光芒。"} 105 | ``` 106 | 运行后即可得到无system的user、assistant指定格式。 107 | 108 | ### 🤓Quick Start 109 | 目前支持直接**python命令单卡训练**、**deepspeed(推荐使用)单机多卡**及**单机单卡训练**. 
所有方式均支持Qlora、Lora、Dora方法。 110 | 111 | 112 | 1、 支持**命令行传参**启动,启动示例可见: ```run_example.sh```。 **相关参数在train_args下的common_args.py和sft/base.py。** 113 | ```bash 114 | bash run_example.sh 115 | ``` 116 | 117 | 2、 也支持**参数文件直接修改默认值**,改好参数后运行以下命令启动: 118 | ```bash 119 | deepspeed --include localhost:6,7 main_train.py 120 | ``` 121 | 更详细的Deepspeed原理及解释可以看文章:[Deepspeed配置及使用讲解](https://zhuanlan.zhihu.com/p/698631348) 122 | 123 | 124 | 显存占用测试如下: 125 | 126 | | 策略 | 模型大小 | 显存占用 | 127 | |------------|----------|------| 128 | | Lora | Qwen(7B) | 26g | 129 | | Lora+Zero2 | Qwen(7B) | 26g | 130 | | Lora+zero3 | Qwen(7B) | 16g | 131 | 132 | ## 多模态训练(VLM) 133 | 134 | 支持Deepspeed多卡 Lora、Qlora,冻结vision、冻结projector训练等 135 | 136 | ### 已支持模型 137 | - [x] [Qwen-2-VL](https://github.com/QwenLM/Qwen2-VL) 138 | - [x] [Llava](https://github.com/haotian-liu/LLaVA) 139 | 140 | ### 已支持任务类型 141 | 142 | - Visual Question Answering 143 | 144 | ### 数据格式 145 | 146 | **Visual Question Answering:** 147 | 148 | - metadata.jsonl: 包含所有图片与文字信息,示例如下: 149 | 150 | ```json lines 151 | {"file_name":"Images/P0003_0004.png", "messages":[{"question":"how are you", "answer":"i am fine"}]} 152 | ``` 153 | 154 | 其中file_name为train_data_path下的的图片路径,具体可如下: 155 | ``` 156 | train_data_path 157 | ├─ metadata.jsonl 158 | └─ Images 159 | └─ P0003_0004.png 160 | └─ ...........png 161 | ``` 162 | 163 | ### Quick Start 164 | 165 | 通过freeze_vision、freeze_projector参数控制是否冻结vision、projector。 166 | 167 | ```bash 168 | bash run_vlm_example.sh 169 | ``` 170 | 171 | ## Tricks 172 | 所有相关的trciks及讲解都在llm_tricks文件夹下 173 | - [Dora代码讲解(llm_tricks/dora/READEME.md)](./llm_tricks/dora/READEME.md) 174 | - [Lora+微调代码实例](https://github.com/mst272/simple-lora-plus) 175 | - [从零实现MOE](./llm_tricks/moe/READEME.md) 176 | - [从零实现DPO](./llm_tricks/DPO_example/README.md) 177 | - [从零实现Transformer](./llm_tricks/transformer/README.md) 178 | 179 | ### 技术发文 180 |
More news... 181 | 182 | - [Deepspeed配置及使用讲解](https://zhuanlan.zhihu.com/p/698631348) 183 | - [从零代码构建MOE](https://zhuanlan.zhihu.com/p/701777558) 184 | - [一步一步实现Transformer代码](https://medium.com/@sdwzh2725/transformer-code-step-by-step-understandingtransformer-d2ea773f15fa) 185 | - [DPO训练QWEN2及魔改DPO实现](https://zhuanlan.zhihu.com/p/702569978) 186 | - [实现强化学习(RLHF)全流程代码构建(PPO、RLOO等)](https://zhuanlan.zhihu.com/p/708935028) 187 | - [从零实现强化学习DPO(SimPO)训练代码](https://zhuanlan.zhihu.com/p/716706368) 188 | - [实现一个简洁的代码模型评测框架(以Qwen2.5-coder 评测Humaneval为例)](https://zhuanlan.zhihu.com/p/721218072) 189 | - [大模型LLM知识蒸馏代码讲解与训练](https://zhuanlan.zhihu.com/p/1064724364) 190 | 191 |
192 | 193 | 194 | ## 🤝 致谢! 195 | 项目学习了优秀开源项目,感谢huggingface、流萤及一些国内外开源项目。 196 | 197 | 🪂 无论是提出问题(Issue)还是贡献代码(Pull Request),都是对项目的巨大支持。 198 | *** 199 | -------------------------------------------------------------------------------- /data/test.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt": "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", "label": "\n\n\n\n\ndef check(has_close_elements):\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\ncheck(has_close_elements)"} 2 | {"prompt": "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", "label": "\n\n\n\n\ndef check(separate_paren_groups):\n assert separate_paren_groups('(()()) ((())) () ((())()())') == [\n '(()())', '((()))', '()', '((())()())'\n ]\n assert separate_paren_groups('() (()) ((())) (((())))') == [\n '()', '(())', '((()))', '(((())))'\n ]\n assert separate_paren_groups('(()(())((())))') == [\n '(()(())((())))'\n ]\n assert separate_paren_groups('( ) (( )) (( )( ))') == ['()', '(())', '(()())']\n\ncheck(separate_paren_groups)"} 3 | {"prompt": "\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", "label": "\n\n\n\n\ndef check(truncate_number):\n assert truncate_number(3.5) == 0.5\n assert abs(truncate_number(1.33) - 0.33) < 1e-6\n assert abs(truncate_number(123.456) - 0.456) < 1e-6\n\ncheck(truncate_number)"} 4 | {"prompt": "from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. 
Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", "label": "\n\n\n\n\ndef check(below_zero):\n assert below_zero([]) == False\n assert below_zero([1, 2, -3, 1, 2, -3]) == False\n assert below_zero([1, 2, -4, 5, 6]) == True\n assert below_zero([1, -1, 2, -2, 5, -5, 4, -4]) == False\n assert below_zero([1, -1, 2, -2, 5, -5, 4, -5]) == True\n assert below_zero([1, -2, 2, -2, 5, -5, 4, -4]) == True\n\ncheck(below_zero)"} 5 | {"prompt": "from typing import List\n\n\ndef mean_absolute_deviation(numbers: List[float]) -> float:\n \"\"\" For a given list of input numbers, calculate Mean Absolute Deviation\n around the mean of this dataset.\n Mean Absolute Deviation is the average absolute difference between each\n element and a centerpoint (mean in this case):\n MAD = average | x - x_mean |\n >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])\n 1.0\n \"\"\"\n", "label": "\n\n\n\n\ndef check(mean_absolute_deviation):\n assert abs(mean_absolute_deviation([1.0, 2.0, 3.0]) - 2.0/3.0) < 1e-6\n assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0]) - 1.0) < 1e-6\n assert abs(mean_absolute_deviation([1.0, 2.0, 3.0, 4.0, 5.0]) - 6.0/5.0) < 1e-6\n\ncheck(mean_absolute_deviation)"} 6 | {"prompt": "from typing import List\n\n\ndef intersperse(numbers: List[int], delimeter: int) -> List[int]:\n \"\"\" Insert a number 'delimeter' between every two consecutive elements of input list `numbers'\n >>> intersperse([], 4)\n []\n >>> intersperse([1, 2, 3], 4)\n [1, 4, 2, 4, 3]\n \"\"\"\n", "label": "\n\n\n\n\ndef check(intersperse):\n assert intersperse([], 7) == []\n assert intersperse([5, 6, 3, 2], 8) == [5, 8, 6, 8, 3, 8, 2]\n assert intersperse([2, 2, 2], 2) == [2, 2, 2, 2, 2]\n\ncheck(intersperse)"} 7 | {"prompt": "from typing import List\n\n\ndef parse_nested_parens(paren_string: str) -> List[int]:\n \"\"\" Input to this function is a string represented multiple groups for nested parentheses separated by spaces.\n For each of the group, output the deepest level of nesting of parentheses.\n E.g. 
(()()) has maximum two levels of nesting while ((())) has three.\n\n >>> parse_nested_parens('(()()) ((())) () ((())()())')\n [2, 3, 1, 3]\n \"\"\"\n", "label": "\n\n\n\n\ndef check(parse_nested_parens):\n assert parse_nested_parens('(()()) ((())) () ((())()())') == [2, 3, 1, 3]\n assert parse_nested_parens('() (()) ((())) (((())))') == [1, 2, 3, 4]\n assert parse_nested_parens('(()(())((())))') == [4]\n\ncheck(parse_nested_parens)"} 8 | {"prompt": "from typing import List\n\n\ndef filter_by_substring(strings: List[str], substring: str) -> List[str]:\n \"\"\" Filter an input list of strings only for ones that contain given substring\n >>> filter_by_substring([], 'a')\n []\n >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')\n ['abc', 'bacd', 'array']\n \"\"\"\n", "label": "\n\n\n\n\ndef check(filter_by_substring):\n assert filter_by_substring([], 'john') == []\n assert filter_by_substring(['xxx', 'asd', 'xxy', 'john doe', 'xxxAAA', 'xxx'], 'xxx') == ['xxx', 'xxxAAA', 'xxx']\n assert filter_by_substring(['xxx', 'asd', 'aaaxxy', 'john doe', 'xxxAAA', 'xxx'], 'xx') == ['xxx', 'aaaxxy', 'xxxAAA', 'xxx']\n assert filter_by_substring(['grunt', 'trumpet', 'prune', 'gruesome'], 'run') == ['grunt', 'prune']\n\ncheck(filter_by_substring)"} 9 | {"prompt": "from typing import List, Tuple\n\n\ndef sum_product(numbers: List[int]) -> Tuple[int, int]:\n \"\"\" For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list.\n Empty sum should be equal to 0 and empty product should be equal to 1.\n >>> sum_product([])\n (0, 1)\n >>> sum_product([1, 2, 3, 4])\n (10, 24)\n \"\"\"\n", "label": "\n\n\n\n\ndef check(sum_product):\n assert sum_product([]) == (0, 1)\n assert sum_product([1, 1, 1]) == (3, 1)\n assert sum_product([100, 0]) == (100, 0)\n assert sum_product([3, 5, 7]) == (3 + 5 + 7, 3 * 5 * 7)\n assert sum_product([10]) == (10, 10)\n\ncheck(sum_product)"} 10 | {"prompt": "from typing import List, Tuple\n\n\ndef rolling_max(numbers: List[int]) -> List[int]:\n \"\"\" From a given list of integers, generate a list of rolling maximum element found until given moment\n in the sequence.\n >>> rolling_max([1, 2, 3, 2, 3, 4, 2])\n [1, 2, 3, 3, 3, 4, 4]\n \"\"\"\n", "label": "\n\n\n\n\ndef check(rolling_max):\n assert rolling_max([]) == []\n assert rolling_max([1, 2, 3, 4]) == [1, 2, 3, 4]\n assert rolling_max([4, 3, 2, 1]) == [4, 4, 4, 4]\n assert rolling_max([3, 2, 3, 100, 3]) == [3, 3, 3, 100, 100]\n\ncheck(rolling_max)"} 11 | {"prompt": "\n\ndef is_palindrome(string: str) -> bool:\n \"\"\" Test if given string is a palindrome \"\"\"\n return string == string[::-1]\n\n\ndef make_palindrome(string: str) -> str:\n \"\"\" Find the shortest palindrome that begins with a supplied string.\n Algorithm idea is simple:\n - Find the longest postfix of supplied string that is a palindrome.\n - Append to the end of the string reverse of a string prefix that comes before the palindromic suffix.\n >>> make_palindrome('')\n ''\n >>> make_palindrome('cat')\n 'catac'\n >>> make_palindrome('cata')\n 'catac'\n \"\"\"\n", "label": "\n\n\n\n\ndef check(make_palindrome):\n assert make_palindrome('') == ''\n assert make_palindrome('x') == 'x'\n assert make_palindrome('xyz') == 'xyzyx'\n assert make_palindrome('xyx') == 'xyx'\n assert make_palindrome('jerry') == 'jerryrrej'\n\ncheck(make_palindrome)"} 12 | {"prompt": "from typing import List\n\n\ndef string_xor(a: str, b: str) -> str:\n \"\"\" Input are two strings a and b consisting only of 1s and 0s.\n Perform binary XOR 
on these inputs and return result also as a string.\n >>> string_xor('010', '110')\n '100'\n \"\"\"\n", "label": "\n\n\n\n\ndef check(string_xor):\n assert string_xor('111000', '101010') == '010010'\n assert string_xor('1', '1') == '0'\n assert string_xor('0101', '0000') == '0101'\n\ncheck(string_xor)"} -------------------------------------------------------------------------------- /evaluate/README.md: -------------------------------------------------------------------------------- 1 | # Code Evaluation 2 | 3 | 面向代码大模型评测,目前支持的有humaneval,后续会逐步新增一些。 4 | 5 | 注意: 6 | 只能在Linux机器上进行,windows上 execution 部分有错误 7 | 8 | 9 | ## Quick Start 10 | 11 | evaluate文件下的run.sh作为一个启动示例,详细参数解释可见args.py。 12 | 13 | 其中评测集数据应为jsonl格式 14 | ```bash 15 | bash run.sh 16 | ``` 17 | 18 | ### 评测生成文件 19 | 20 | 模型评测完成后主要生成三个文件: 21 | 22 | 1、out.jsonl: 模型输出,在评测数据的基础上新增字段如下: 23 | - output:模型接收prompt后产生的原本输出 24 | - generation: 经过提取代码部分及添加测试后的输出 25 | 26 | 2、logs.jsonl: 评测测试用例运行的信息 27 | 28 | 3、metric.json: 评测结果指标 29 | 30 | ## 新增评测 31 | 如若想要新增评测任务,可以继承base_utils中的类进行相关设置,然后在task文件夹下创建相关文件进行继承。 32 | 最后在main.py文件的TASKS中添加即可。 -------------------------------------------------------------------------------- /evaluate/args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | torch_dtype = ['bf16', 'fp16', 'fp32'] 4 | task_name = ['humaneval'] 5 | 6 | 7 | @dataclass 8 | class EvaluateArgs: 9 | """ 10 | 配置Evaluate的参数 11 | """ 12 | 13 | """任务名,支持的任务见列表task_name""" 14 | task_name: str = 'humaneval' 15 | 16 | """模型生成的一些参数设置""" 17 | max_new_tokens: int = 100 18 | torch_dtype: str = 'fp16' 19 | do_sample: bool = False 20 | top_p: float = 0.95 21 | temperature: int = 1 22 | 23 | """模型路径""" 24 | model_name_or_path: str = './' 25 | """是否仅评测模式,若为False,则下面的evaluate_data_path不用填""" 26 | evaluate_only: bool = False 27 | """仅评测时的文件路径""" 28 | evaluate_data_path: str = '' 29 | """模型生成的输出路径""" 30 | output_path: str = './' 31 | """模型评测时输出的信息:pass或错误""" 32 | save_logs_path: str = './' 33 | """模型结果保存路径""" 34 | save_metrics_path: str = './' 35 | """评测所需要的测试集""" 36 | data_file: str = '' 37 | -------------------------------------------------------------------------------- /evaluate/base_utils.py: -------------------------------------------------------------------------------- 1 | class TaskUtils: 2 | def __init__(self): 3 | self.IMPORT_HELPER = { 4 | "python": [ 5 | "import math", 6 | "import re", 7 | "import sys", 8 | "import copy", 9 | "import datetime", 10 | "import itertools", 11 | "import collections", 12 | "import heapq", 13 | "import functools", 14 | "import hashlib", 15 | "import numpy", 16 | "import numpy as np", 17 | "import string", 18 | "from typing import *", 19 | "from collections import *", 20 | ] 21 | } 22 | 23 | @staticmethod 24 | def build_instruction(example): 25 | """ 26 | 根据模型构建合适的指令 27 | """ 28 | return example['prompt'] 29 | 30 | @staticmethod 31 | def generation_code_process(example): 32 | """ 33 | 对生成的代码提取函数部分 及 设置import、添加test用例等操作 34 | """ 35 | pass 36 | 37 | @staticmethod 38 | def evaluate_function(input_file, args): 39 | """ 40 | 最终评测的方法,输入为保存的生成jsonl文件 41 | """ 42 | pass 43 | -------------------------------------------------------------------------------- /evaluate/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | import numpy as np 5 | from tqdm import tqdm 6 | from execution import 
check_correctness 7 | 8 | 9 | def stream_jsonl_all(filename: str): 10 | """ 11 | Streams a JSONL file. 12 | """ 13 | results = [] 14 | fp = open(filename, "r") 15 | for line in fp: 16 | if any(not x.isspace() for x in line): 17 | results.append(json.loads(line)) 18 | fp.close() 19 | return results 20 | 21 | 22 | # 计算pass@1 23 | def evaluate_functional_correctness( 24 | input_file: str = None, 25 | n_workers: int = 32, 26 | timeout: float = 3.0, 27 | k: int = 1, 28 | save_logs_path='./logs.jsonl' 29 | ): 30 | """ 31 | Evaluates the functional correctness of a model. 32 | """ 33 | sample_jsonl = stream_jsonl_all(input_file) 34 | 35 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 36 | 37 | futures = [] 38 | n_samples = 0 39 | results = defaultdict(list) 40 | 41 | print("Reading samples...") 42 | for sample in tqdm(sample_jsonl): 43 | task_id = sample["task_id"] 44 | if sample["generation"] is None: 45 | continue 46 | args = (sample['generation'], task_id, timeout) 47 | future = executor.submit(check_correctness, *args) 48 | futures.append(future) 49 | n_samples += 1 50 | 51 | print("Running test suites...") 52 | for future in tqdm(as_completed(futures), total=len(futures)): 53 | result = future.result() 54 | results[result["task_id"]].append(result) 55 | # Calculate pass@k. 56 | total, correct, logs = [], [], [] 57 | for result in results.values(): 58 | passed = [r["passed"] for r in result] 59 | res = [{r['task_id']: r["result"]} for r in result] 60 | logs.append(res) 61 | total.append(len(passed)) 62 | correct.append(sum(passed)) 63 | total = np.array(total) 64 | correct = np.array(correct) 65 | 66 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()} 67 | 68 | with open(save_logs_path, 'w', encoding='utf-8') as fw: 69 | for ex in logs: 70 | fw.write(json.dumps(ex) + '\n') 71 | print(f"execute logs were saved at {save_logs_path}") 72 | 73 | return pass_at_k 74 | 75 | 76 | def estimate_pass_at_k( 77 | num_samples, 78 | num_correct, 79 | k: int 80 | ) -> np.ndarray: 81 | """ 82 | Estimates pass@k and returns them in an array. 83 | """ 84 | 85 | def estimator(n: int, c: int, k: int) -> float: 86 | """ 87 | Calculates 1 - comb(n - c, k) / comb(n, k). 
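        (Unbiased pass@k estimator from the HumanEval/Codex paper, Chen et al. 2021: the probability that at least
        one of k samples drawn without replacement from n generations, c of them correct, passes. If n - c < k, any
        k samples must include a correct one, so the estimate is 1.0; the running product form avoids computing large
        binomial coefficients directly.)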
88 | """ 89 | if n - c < k: 90 | return 1.0 91 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 92 | 93 | assert len(num_samples) == len(num_correct) 94 | num_samples_it = iter(num_samples) 95 | 96 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) 97 | 98 | 99 | -------------------------------------------------------------------------------- /evaluate/execution.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import faulthandler 3 | import io 4 | import multiprocessing 5 | import os 6 | import platform 7 | import signal 8 | import tempfile 9 | 10 | 11 | @contextlib.contextmanager 12 | def chdir(root): 13 | if root == ".": 14 | yield 15 | return 16 | cwd = os.getcwd() 17 | os.chdir(root) 18 | try: 19 | yield 20 | except BaseException as exc: 21 | raise exc 22 | finally: 23 | os.chdir(cwd) 24 | 25 | 26 | @contextlib.contextmanager 27 | def create_tempdir(): 28 | with tempfile.TemporaryDirectory() as dirname: 29 | with chdir(dirname): 30 | yield dirname 31 | 32 | 33 | @contextlib.contextmanager 34 | def swallow_io(): 35 | stream = WriteOnlyStringIO() 36 | with contextlib.redirect_stdout(stream): 37 | with contextlib.redirect_stderr(stream): 38 | with redirect_stdin(stream): 39 | yield 40 | 41 | 42 | @contextlib.contextmanager 43 | def time_limit(seconds): 44 | def signal_handler(signum, frame): 45 | raise TimeoutException("Timed out!") 46 | 47 | signal.setitimer(signal.ITIMER_REAL, seconds) 48 | signal.signal(signal.SIGALRM, signal_handler) 49 | try: 50 | yield 51 | finally: 52 | signal.setitimer(signal.ITIMER_REAL, 0) 53 | 54 | 55 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 56 | _stream = "stdin" 57 | 58 | 59 | def check_correctness(check_program, task_id, timeout=3): 60 | manager = multiprocessing.Manager() 61 | result = manager.list() 62 | 63 | p = multiprocessing.Process(target=unsafe_execute, args=(check_program, result, timeout)) 64 | p.start() 65 | p.join(timeout=timeout + 1) 66 | if p.is_alive(): 67 | p.kill() 68 | 69 | if not result: 70 | result.append("timed out") 71 | 72 | return { 73 | "task_id": task_id, 74 | "passed": result[0] == "passed", 75 | "result": result[0] 76 | } 77 | 78 | 79 | def unsafe_execute(check_program, result, timeout): 80 | with create_tempdir(): 81 | 82 | # These system calls are needed when cleaning up tempdir. 83 | import os 84 | import shutil 85 | 86 | rmtree = shutil.rmtree 87 | rmdir = os.rmdir 88 | chdir = os.chdir 89 | 90 | # Disable functionalities that can make destructive changes to the test. 91 | reliability_guard() 92 | 93 | # Run program. 94 | try: 95 | exec_globals = {} 96 | with swallow_io(): 97 | with time_limit(timeout): 98 | exec(check_program, exec_globals) 99 | result.append("passed") 100 | except TimeoutException: 101 | result.append("timed out") 102 | except BaseException as e: 103 | result.append(f"failed: {e}") 104 | 105 | # Needed for cleaning up. 
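        # reliability_guard() replaced shutil.rmtree / os.rmdir / os.chdir with None above; restoring the saved
        # references lets create_tempdir()'s TemporaryDirectory remove its files and chdir back out on exit.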
106 | shutil.rmtree = rmtree 107 | os.rmdir = rmdir 108 | os.chdir = chdir 109 | 110 | 111 | class WriteOnlyStringIO(io.StringIO): 112 | """StringIO that throws an exception when it's read from""" 113 | 114 | def read(self, *args, **kwargs): 115 | raise OSError 116 | 117 | def readline(self, *args, **kwargs): 118 | raise OSError 119 | 120 | def readlines(self, *args, **kwargs): 121 | raise OSError 122 | 123 | def readable(self, *args, **kwargs): 124 | """Returns True if the IO object can be read.""" 125 | return False 126 | 127 | 128 | class TimeoutException(Exception): 129 | pass 130 | 131 | 132 | def reliability_guard(maximum_memory_bytes=None): 133 | """ 134 | This disables various destructive functions and prevents the generated code 135 | from interfering with the test (e.g. fork bomb, killing other processes, 136 | removing filesystem files, etc.) 137 | 138 | WARNING 139 | This function is NOT a security sandbox. Untrusted code, including, model- 140 | generated code, should not be blindly executed outside of one. See the 141 | Codex paper for more information about OpenAI's code sandbox, and proceed 142 | with caution. 143 | """ 144 | 145 | if maximum_memory_bytes is not None: 146 | import resource 147 | 148 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 149 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 150 | if not platform.uname().system == "Darwin": 151 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 152 | 153 | faulthandler.disable() 154 | 155 | import builtins 156 | 157 | builtins.exit = None 158 | builtins.quit = None 159 | 160 | import os 161 | 162 | os.environ["OMP_NUM_THREADS"] = "1" 163 | 164 | os.kill = None 165 | os.system = None 166 | os.putenv = None 167 | os.remove = None 168 | os.removedirs = None 169 | os.rmdir = None 170 | os.fchdir = None 171 | os.setuid = None 172 | os.fork = None 173 | os.forkpty = None 174 | os.killpg = None 175 | os.rename = None 176 | os.renames = None 177 | os.truncate = None 178 | os.replace = None 179 | os.unlink = None 180 | os.fchmod = None 181 | os.fchown = None 182 | os.chmod = None 183 | os.chown = None 184 | os.chroot = None 185 | os.fchdir = None 186 | os.lchflags = None 187 | os.lchmod = None 188 | os.lchown = None 189 | os.getcwd = None 190 | os.chdir = None 191 | 192 | import shutil 193 | 194 | shutil.rmtree = None 195 | shutil.move = None 196 | shutil.chown = None 197 | 198 | import subprocess 199 | 200 | subprocess.Popen = None # type: ignore 201 | 202 | __builtins__["help"] = None 203 | 204 | import sys 205 | 206 | sys.modules["ipdb"] = None 207 | sys.modules["joblib"] = None 208 | sys.modules["resource"] = None 209 | sys.modules["psutil"] = None 210 | sys.modules["tkinter"] = None 211 | -------------------------------------------------------------------------------- /evaluate/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from tqdm import tqdm 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | 6 | 7 | def generate_one(example, tokenizer, model, args, task): 8 | prompt = task.build_instruction(example) 9 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device) 10 | 11 | stop_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.convert_tokens_to_ids( 12 | "<|EOT|>") 13 | assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found" 14 | 15 | outputs = 
model.generate( 16 | inputs, 17 | max_new_tokens=args.max_new_tokens, 18 | do_sample=args.do_sample, 19 | top_p=args.top_p, 20 | temperature=args.temperature, 21 | pad_token_id=stop_id, 22 | eos_token_id=stop_id 23 | ) 24 | 25 | output = tokenizer.decode(outputs[0][:], skip_special_tokens=True) 26 | example['output'] = output 27 | 28 | return task.generation_code_process(example) 29 | 30 | 31 | def generate_main(args, task): 32 | model_name_or_path = args.model_name_or_path 33 | saved_path = args.output_path 34 | 35 | print("model", model_name_or_path) 36 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 37 | print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path)) 38 | torch_dtype = torch.bfloat16 if args.torch_dtype == 'bf16' else torch.float16 if args.torch_dtype == 'fp16' else torch.float32 39 | model = AutoModelForCausalLM.from_pretrained( 40 | model_name_or_path, 41 | torch_dtype=torch_dtype, 42 | device_map="auto", 43 | trust_remote_code=True, 44 | ) 45 | model.eval() 46 | examples = [json.loads(x) for x in open(args.data_file) if x.strip()] 47 | print("Read {} examples for evaluation over.".format(len(examples))) 48 | 49 | generated_examples = [] 50 | for ex in tqdm(examples, desc='Generating'): 51 | gen_example = generate_one(ex, tokenizer, model, args, task) 52 | generated_examples.append(gen_example) 53 | 54 | print("Generate all over!!!") 55 | with open(saved_path, 'w', encoding='utf-8') as fw: 56 | for ex in generated_examples: 57 | fw.write(json.dumps(ex) + '\n') 58 | print("Save {} processed examples into {} over!".format(len(generated_examples), saved_path)) 59 | 60 | result = task.evaluate_function(saved_path,args) 61 | save_metrics(args, result) 62 | print(result, model_name_or_path) 63 | 64 | 65 | def evaluation_only(args, task): 66 | result = task.evaluate_function(args.evaluate_data_path, args) 67 | save_metrics(args, result) 68 | print(result, args.model_name_or_path) 69 | 70 | 71 | def save_metrics(args, result): 72 | args_dict = args.__dict__ 73 | with open(args.save_metrics_path, 'w', encoding='utf-8') as fw: 74 | fw.write(json.dumps(result) + '\n') 75 | fw.write(json.dumps(args_dict) + '\n') 76 | 77 | -------------------------------------------------------------------------------- /evaluate/main.py: -------------------------------------------------------------------------------- 1 | from transformers import HfArgumentParser 2 | import os 3 | from args import EvaluateArgs 4 | from task import humaneval 5 | from generate import generate_main, evaluation_only 6 | 7 | parser = HfArgumentParser((EvaluateArgs,)) 8 | args = parser.parse_args_into_dataclasses()[0] 9 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 10 | 11 | # 任务列表 12 | TASKS = { 13 | "humaneval": humaneval.HumanEval() 14 | } 15 | 16 | task = TASKS[args.task_name] 17 | 18 | if not args.evaluate_only: 19 | generate_main(args, task) 20 | else: 21 | evaluation_only(args, task) 22 | -------------------------------------------------------------------------------- /evaluate/run.sh: -------------------------------------------------------------------------------- 1 | MODELS_PATH="/qwen" 2 | LOGS_PATH="./logs.jsonl" 3 | OUT_PATH='./out.jsonl' 4 | METRIC_PATH='./metric.json' 5 | DATA_FILE='./dataset/humaneval_python.jsonl' 6 | 7 | 8 | CUDA_VISIBLE_DEVICES=0 python main.py \ 9 | --model_name_or_path "$MODELS_PATH" \ 10 | --task_name "humaneval" \ 11 | --save_logs_path "$LOGS_PATH" \ 12 | --output_path "$OUT_PATH" \ 13 | --do_sample false \ 14 | --top_p 0.95 \ 15 | 
--max_new_tokens 1024 \ 16 | --evaluate_only false \ 17 | --torch_dtype "bf16" \ 18 | --save_metrics_path $METRIC_PATH \ 19 | --data_file $DATA_FILE -------------------------------------------------------------------------------- /evaluate/task/humaneval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from base_utils import TaskUtils 4 | from evaluation import evaluate_functional_correctness 5 | 6 | 7 | class HumanEval(TaskUtils): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | @staticmethod 12 | def build_instruction(example): 13 | """ 14 | 根据模型构建合适的指令 15 | """ 16 | return example['prompt'] 17 | 18 | def generation_code_process(self, example): 19 | """ 20 | 对生成的代码提取函数部分 及 设置import等操作 21 | """ 22 | code = example['output'] 23 | test_case = example['test'] 24 | code_ = [] 25 | skip_rest = False # 新增标志位,用于跳过if __name__ == "__main__"及其后面的内容 26 | for line in code.split("\n"): 27 | if skip_rest: 28 | continue # 如果遇到if __name__ == "__main__",跳过该行及其后面的所有内容 29 | if any(keyword in line for keyword in ["if __name__ == \"__main__\":", "if __name__ == \'__main__\':"]): 30 | skip_rest = True # 设置标志位,表示需要跳过后续内容 31 | continue 32 | if "def " in line and line[0] != ' ' and line[0] != '\t': 33 | code_.append("def " + line.split("def ")[1]) 34 | continue 35 | if "class" in line and line.strip().endswith(":"): 36 | code_.append(line) 37 | continue 38 | if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t': 39 | continue 40 | code_.append(line) 41 | code = "\n".join(code_) 42 | test_setup = "\n".join(self.IMPORT_HELPER["python"]) + "\n" 43 | example['generation'] = test_setup + code + "\n" + test_case + "\n" 44 | return example 45 | 46 | @staticmethod 47 | def evaluate_function(input_file, args): 48 | """ 49 | 最终评测的方法,输入为保存的生成jsonl文件 50 | """ 51 | return evaluate_functional_correctness(input_file, n_workers=1, timeout=3.0, k=1, save_logs_path=args.save_logs_path) 52 | 53 | -------------------------------------------------------------------------------- /inference_vllm.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | import os 3 | import torch 4 | from transformers import AutoTokenizer 5 | 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7' 7 | model_name_or_path = 'DeepSeek-R1-Distill-Qwen-32B' # model path 8 | llm = LLM( 9 | model=model_name_or_path, 10 | max_model_len=8192, 11 | device='cuda', 12 | dtype=torch.bfloat16, 13 | tensor_parallel_size=8 # CUDA_VISIBLE_DEVICES 数量 14 | ) 15 | 16 | prompt = "请帮我生成一个关于夏天的诗歌。" 17 | 18 | 19 | TOKENIZER = AutoTokenizer.from_pretrained(model_name_or_path) 20 | messages = [ 21 | {"role": "user", "content": prompt} 22 | ] 23 | text = TOKENIZER.apply_chat_template( 24 | messages, 25 | tokenize=False, 26 | add_generation_prompt=True 27 | ) 28 | 29 | prompts = [text] 30 | 31 | sampling_params = SamplingParams( 32 | max_tokens=8192, 33 | top_p=0.9, 34 | top_k=1, 35 | temperature=0.0, 36 | repetition_penalty=1.0, 37 | ) 38 | outputs = llm.generate(prompts, sampling_params) 39 | for output in outputs: 40 | prompt = output.prompt 41 | generated_text = output.outputs[0].text 42 | print(f"Prompt:\n{prompt}") 43 | print(f"Generated text:\n {generated_text}") 44 | -------------------------------------------------------------------------------- /llm_tricks/DPO_example/README.md: -------------------------------------------------------------------------------- 1 | # 从零实现强化学习DPO(SimPO)训练代码 2 | 3 | ## 
Quick start 4 | ```python 5 | python dpo_train.py 6 | ``` 7 | 8 | ## 说明 9 | 本文档下的从零实现只是一个学习的demo,用以理解原理所用,并没有增加分布式等。所以尽管使用2B的小模型,显存占用也高达30+GB。 10 | 11 | 精度设置fp16可能会出现loss 为nan的现象 12 | 13 | ```dpo_train.py```为训练主路径, 相关loss计算在```loss.py```. 14 | 15 | 如果想要使用DPO或者Simpo、CPO等强化学习方法真正训练的话, 16 | 可以使用本项目中的rlhf构建的强化学习框架:[RLHF](../../rlhf/README.md) 17 | 18 | 支持deepspeed的单机多卡Lora、Dora、Qlora、全量参数训练,并自动适配模型的chat template。 -------------------------------------------------------------------------------- /llm_tricks/DPO_example/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | 4 | 5 | class RlhfDataset(Dataset): 6 | def __init__(self, file_path, tokenizer): 7 | with open(file_path, "r", encoding="utf-8") as file: 8 | data_list = file.readlines() 9 | self.data_list = data_list 10 | self.tokenizer = tokenizer 11 | 12 | def __getitem__(self, item): 13 | data = self.data_list[item] 14 | data = json.loads(data) 15 | prompt = data['prompt'] 16 | chosen = data['chosen'] 17 | rejected = data['rejected'] 18 | 19 | chosen_full_text = f"{prompt}\n\n### Response:\n{chosen}" 20 | rejected_full_text = f"{prompt}\n\n### Response:\n{rejected}" 21 | 22 | prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False) 23 | chosen_full_tokens = self.tokenizer.encode(chosen_full_text, add_special_tokens=False) 24 | rejected_full_tokens = self.tokenizer.encode(rejected_full_text, add_special_tokens=False) 25 | 26 | input = { 27 | "prompt": prompt_tokens, 28 | "chosen": chosen_full_tokens, 29 | "rejected": rejected_full_tokens, 30 | } 31 | return input 32 | 33 | def __len__(self): 34 | return len(self.data_list) -------------------------------------------------------------------------------- /llm_tricks/DPO_example/dpo_train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, random_split 2 | import torch 3 | from dataset import RlhfDataset 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | from loss import compute_batch_loss 6 | from evaluate import evaluate_loss_dataloader 7 | import time 8 | from functools import partial 9 | 10 | # 1、加载模型与tokenizer 11 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 12 | model_path = '/IndexTeam/Index-1___9B-Chat' 13 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) 14 | ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) 15 | ref_model.eval() 16 | model.to(device) 17 | ref_model.to(device) 18 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True) 19 | 20 | # 2、处理数据 21 | # 加载数据 22 | data_file = './unsloth_dpo.jsonl' 23 | # Dataset详细逻辑可看进入RlhfDataset实现 24 | dataset = RlhfDataset(data_file, tokenizer) 25 | # 划分训练集验证集 26 | train_size = int(len(dataset) * 0.85) # 85% for training 27 | val_size = len(dataset) - train_size # Remaining for validation 28 | train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) 29 | 30 | # 编写batch批次的padding及mask处理函数 31 | IGNORE_INDEX = False 32 | 33 | 34 | def data_collate(batch, pad_token_id, device, max_length=None, if_mask_prompt=True): 35 | batch_data = { 36 | "prompt": [], 37 | "chosen": [], 38 | "rejected": [], 39 | "rejected_mask": [], 40 | "chosen_mask": [] 41 | } 42 | 43 | # 判断长度及padding 44 | max_length_common = 0 45 | for key in ["chosen", 
"rejected"]: 46 | current_max = max(len(item[key]) for item in batch) 47 | max_length_common = max(max_length_common, current_max) 48 | 49 | # 转为torch tensor并padding,决定是否对prompt进行mask 50 | for item in batch: 51 | prompt = torch.tensor(item['prompt']) 52 | batch_data['prompt'].append(prompt) 53 | 54 | for key in ["chosen", "rejected"]: 55 | out = item[key] 56 | out_padding = out + [pad_token_id] * (max_length_common - len(out)) 57 | mask = torch.ones(len(out_padding)).bool() 58 | 59 | # padding部分的mask设置为 IGNORE_INDEX 60 | mask[len(out):] = IGNORE_INDEX 61 | 62 | if if_mask_prompt: 63 | mask[:prompt.shape[0] + 2] = IGNORE_INDEX 64 | batch_data[key].append(torch.tensor(out_padding)) 65 | batch_data[f"{key}_mask"].append(mask) 66 | 67 | # 进行最大长度截断 68 | for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]: 69 | tensor_stack = torch.stack(batch_data[key]) 70 | if max_length is not None: 71 | tensor_stack = tensor_stack[:, :max_length] 72 | # 将tensor移到对应的device 73 | batch_data[key] = tensor_stack.to(device) 74 | return batch_data 75 | 76 | 77 | customized_collate_fn = partial( 78 | data_collate, 79 | pad_token_id=tokenizer.pad_token_id, 80 | device=device, 81 | if_mask_prompt=True, 82 | max_length=1024 83 | ) 84 | # 设置相关参数 85 | batch_size = 4 86 | train_loader = DataLoader( 87 | train_dataset, 88 | batch_size=batch_size, 89 | collate_fn=customized_collate_fn, 90 | shuffle=True, 91 | drop_last=True 92 | ) 93 | val_loader = DataLoader( 94 | val_dataset, 95 | batch_size=1, 96 | collate_fn=customized_collate_fn, 97 | shuffle=False, 98 | drop_last=False 99 | ) 100 | 101 | 102 | # 3、开始计算DPO(或其他)的损失函数 103 | # 相关代码可以再loss里查看,就不写在主函数里了。 104 | 105 | # 4、编写训练函数 106 | def train_model( 107 | policy_model, reference_model, train_loader, val_loader, 108 | optimizer, num_epochs, beta, 109 | eval_freq, eval_iter): 110 | tracking = { 111 | "train_losses": [], 112 | "train_chosen_rewards": [], 113 | "train_rejected_rewards": [], 114 | "val_losses": [], 115 | "val_chosen_rewards": [], 116 | "val_rejected_rewards": [], 117 | "tokens_seen": [] 118 | } 119 | tokens_seen, global_step = 0, -1 120 | 121 | # 训练 122 | for epoch in range(num_epochs): 123 | # policy 模型需要训练 124 | policy_model.train() 125 | 126 | for idx, batch in enumerate(train_loader): 127 | optimizer.zero_grad() 128 | 129 | loss, chosen_rewards, rejected_rewards = compute_batch_loss( 130 | batch=batch, 131 | policy_model=policy_model, 132 | reference_model=reference_model, 133 | beta=beta 134 | ) 135 | loss.backward() 136 | optimizer.step() 137 | 138 | global_step += 1 139 | tokens_seen += batch["chosen"].numel() 140 | 141 | # 验证 142 | if global_step % eval_freq == 0: 143 | res = evaluate_loss_dataloader( 144 | policy_model=policy_model, 145 | reference_model=reference_model, 146 | train_loader=train_loader, 147 | val_loader=val_loader, 148 | beta=beta, 149 | eval_iter=eval_iter 150 | ) 151 | tracking["train_losses"].append(res["train_loss"]) 152 | tracking["train_chosen_rewards"].append(res["train_chosen_reward"]) 153 | tracking["train_rejected_rewards"].append(res["train_rejected_reward"]) 154 | tracking["val_losses"].append(res["val_loss"]) 155 | tracking["val_chosen_rewards"].append(res["val_chosen_reward"]) 156 | tracking["val_rejected_rewards"].append(res["val_rejected_reward"]) 157 | tracking["tokens_seen"].append(tokens_seen) 158 | train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"] 159 | val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"] 160 | 161 | print( 162 | f"Ep {epoch + 1} (Step 
{global_step:06d}): " 163 | f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, " 164 | f"Train reward margins {train_reward_margin:.3f}, " 165 | f"Val reward margins {val_reward_margin:.3f}" 166 | ) 167 | 168 | return tracking 169 | 170 | 171 | # 5、开始训练! 172 | def main(): 173 | torch.manual_seed(42) 174 | start_time = time.time() 175 | optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01) 176 | 177 | num_epochs = 3 178 | tracking = train_model( 179 | policy_model=model, 180 | reference_model=ref_model, 181 | train_loader=train_loader, 182 | val_loader=val_loader, 183 | optimizer=optimizer, 184 | num_epochs=num_epochs, 185 | beta=0.1, # value between 0.1 and 0.5 186 | eval_freq=2, 187 | eval_iter=2 188 | ) 189 | 190 | end_time = time.time() 191 | execution_time_minutes = (end_time - start_time) / 60 192 | print(f"Training completed in {execution_time_minutes:.2f} minutes.") 193 | 194 | 195 | if __name__ == "__main__": 196 | main() 197 | -------------------------------------------------------------------------------- /llm_tricks/DPO_example/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loss import compute_loss_dataloader 3 | 4 | 5 | def evaluate_loss_dataloader(policy_model, reference_model, train_loader, val_loader, beta, eval_iter): 6 | policy_model.eval() 7 | with torch.no_grad(): 8 | train_loss, train_chosen_rewards, train_rejected_rewards = compute_loss_dataloader( 9 | data_loader=train_loader, 10 | policy_model=policy_model, 11 | reference_model=reference_model, 12 | beta=beta, 13 | num_batches=eval_iter 14 | ) 15 | val_loss, val_chosen_rewards, val_rejected_rewards = compute_loss_dataloader( 16 | data_loader=val_loader, 17 | policy_model=policy_model, 18 | reference_model=reference_model, 19 | beta=beta, 20 | num_batches=eval_iter 21 | ) 22 | res = { 23 | "train_loss": train_loss, 24 | "train_chosen_reward": train_chosen_rewards, 25 | "train_rejected_reward": train_rejected_rewards, 26 | "val_loss": val_loss, 27 | "val_chosen_reward": val_chosen_rewards, 28 | "val_rejected_reward": val_rejected_rewards 29 | } 30 | 31 | policy_model.train() 32 | return res 33 | -------------------------------------------------------------------------------- /llm_tricks/DPO_example/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | # 计算DPO loss的公式 7 | class DPOLoss(nn.Module): 8 | """ 9 | DPO Loss 10 | """ 11 | 12 | def __init__(self, beta: float = 0.1) -> None: 13 | super().__init__() 14 | self.beta = beta 15 | 16 | def forward( 17 | self, 18 | policy_chosen_logps: torch.Tensor, 19 | policy_rejected_logps: torch.Tensor, 20 | reference_chosen_logps: torch.Tensor, 21 | reference_rejected_logps: torch.Tensor, 22 | ): 23 | """ 24 | policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,) 25 | policy_rejected_logps: Shape: (batch_size,) 26 | reference_chosen_logps: Shape: (batch_size,) 27 | reference_rejected_logps: Shape: (batch_size,) 28 | """ 29 | policy_logps = policy_chosen_logps - policy_rejected_logps 30 | reference_logps = reference_chosen_logps - reference_rejected_logps 31 | logits = policy_logps - reference_logps 32 | 33 | loss = -F.logsigmoid(self.beta * logits) 34 | 35 | # 下面两个用于追踪训练的进度 36 | chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach() 37 | rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach() 38 | 39 | # 
对每个batch进行平均(期望) 40 | return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean() 41 | 42 | 43 | class SimPo(nn.Module): 44 | """ 45 | SimPO Loss 46 | """ 47 | 48 | def __init__(self, beta: float = 0.1, gamma: float = 0.5) -> None: 49 | super().__init__() 50 | self.beta = beta 51 | self.gamma = gamma 52 | 53 | def forward( 54 | self, 55 | policy_chosen_logps: torch.Tensor, 56 | policy_rejected_logps: torch.Tensor, 57 | ): 58 | """ 59 | policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,) 60 | policy_rejected_logps: Shape: (batch_size,) 61 | """ 62 | logits = policy_chosen_logps - policy_rejected_logps 63 | logits = logits - self.gamma 64 | loss = -F.logsigmoid(self.beta * logits) 65 | 66 | # 对每个batch进行平均(期望) 67 | return loss.mean() 68 | 69 | 70 | # 计算每个模型的Log probabilities 71 | def compute_logprobs(logits, labels, mask=None): 72 | """ 73 | logits: shape (batch_size, sequence_len, vocab_size),即将label输入给模型后输出的结果 74 | labels: shape (batch_size, sequence_len) 75 | """ 76 | 77 | # 需要先进行位移操作 78 | # 去掉标签的第一个 79 | labels = labels[:, 1:].clone() 80 | # 去掉模型输出的最后一个 81 | logits = logits[:, :-1, :] 82 | 83 | logps = F.log_softmax(logits, dim=-1) 84 | 85 | select_logprobs = torch.gather( 86 | input=logps, 87 | dim=-1, 88 | index=labels.unsqueeze(1) 89 | ).squeeze(1) 90 | 91 | if mask is not None: 92 | mask = mask[:, 1:].clone() 93 | # 进行掩码padding部分 94 | select_logprobs = select_logprobs * mask 95 | # 计算每一句的平均 96 | average_logprobs = select_logprobs.sum(-1) / mask.sum(-1) 97 | return average_logprobs 98 | else: 99 | return select_logprobs.mean(-1) 100 | 101 | 102 | # 计算每个模型的Log probabilities. 使用torch的F.cross_entropy进行计算。结果同上,均是一样。 103 | def compute_logprobs_f_cross(logits, labels, mask=None): 104 | """ 105 | logits: shape (batch_size, sequence_len, vocab_size),即将label输入给模型后输出的结果 106 | labels: shape (batch_size, sequence_len) 107 | """ 108 | # 需要先进行位移操作 109 | # 去掉标签的第一个 110 | labels = labels[:, 1:].clone() 111 | # 去掉模型输出的最后一个 112 | logits = logits[:, :-1, :].clone() 113 | 114 | batch_size, sequence_len, vocab_size = logits.shape 115 | cross_entropy_loss = 0 116 | 117 | if mask is not None: 118 | mask = mask[:, 1:].clone() 119 | labels.masked_fill_(~mask, -100) 120 | for i in range(batch_size): 121 | cross_entropy_loss += F.cross_entropy(logits[i], labels[i]) 122 | else: 123 | for i in range(batch_size): 124 | cross_entropy_loss += F.cross_entropy(logits[i], labels[i]) 125 | cross_entropy_loss /= batch_size 126 | return cross_entropy_loss 127 | 128 | 129 | def compute_batch_loss(batch, policy_model, reference_model, beta): 130 | # 决定使用哪个loss 131 | # loss_fn = SimPo(beta, 0.5) SimPO loss 132 | loss_fn = DPOLoss(beta) # DPO loss 133 | 134 | policy_chosen_logps = compute_logprobs( 135 | logits=policy_model(batch["chosen"]).logits, 136 | labels=batch["chosen"], 137 | mask=batch["chosen_mask"] 138 | ) 139 | policy_rejected_logps = compute_logprobs( 140 | logits=policy_model(batch["rejected"]).logits, 141 | labels=batch["rejected"], 142 | mask=batch["rejected_mask"] 143 | ) 144 | reference_chosen_logps = compute_logprobs( 145 | logits=reference_model(batch['chosen']).logits, 146 | labels=batch['chosen'], 147 | mask=batch["chosen_mask"] 148 | ) 149 | reference_rejected_logps = compute_logprobs( 150 | logits=reference_model(batch['rejected']).logits, 151 | labels=batch['rejected'], 152 | mask=batch["rejected_mask"] 153 | ) 154 | loss, chosen_rewards, rejected_rewards = loss_fn( 155 | policy_chosen_logps=policy_chosen_logps, 156 | policy_rejected_logps=policy_rejected_logps, 157 | 
reference_chosen_logps=reference_chosen_logps, 158 | reference_rejected_logps=reference_rejected_logps, 159 | ) 160 | # SimPO使用如下 161 | # loss = loss_fn( 162 | # policy_chosen_logps=policy_chosen_logps, 163 | # policy_rejected_logps=policy_rejected_logps, 164 | # ) 165 | # return loss 166 | return loss, chosen_rewards, rejected_rewards 167 | 168 | 169 | def compute_loss_dataloader(data_loader, policy_model, reference_model, beta, num_batches=5): 170 | total_loss, total_chosen_rewards, total_rejected_rewards = 0., 0., 0. 171 | num_batches = min(num_batches, len(data_loader)) 172 | 173 | for i, batch in enumerate(data_loader): 174 | if i < num_batches: 175 | loss, chosen_rewards, rejected_rewards = compute_batch_loss( 176 | batch=batch, 177 | policy_model=policy_model, 178 | reference_model=reference_model, 179 | beta=beta 180 | ) 181 | total_loss += loss.item() 182 | total_chosen_rewards += chosen_rewards.item() 183 | total_rejected_rewards += rejected_rewards.item() 184 | else: 185 | break 186 | # 计算平均 187 | total_loss /= num_batches 188 | total_chosen_rewards /= num_batches 189 | total_rejected_rewards /= num_batches 190 | return total_loss, total_chosen_rewards, total_rejected_rewards 191 | 192 | 193 | if __name__ == "__main__": 194 | # 测试compute_logprobs_f_cross 与 compute_logprobs 195 | logits = torch.tensor( 196 | [[2.0, 1.0, 0.1, 0.4], 197 | [0.5, 2.5, 0.3, 0.5], 198 | [0.6, 2.5, 0.3, 0.8], 199 | [0.5, 2.5, 0.6, 0.6]], dtype=torch.float32).unsqueeze(0) 200 | mask = torch.tensor([[True, True, False, False]]) 201 | targets = torch.tensor([0, 1, 0, 2]).unsqueeze(0) 202 | loss1 = -compute_logprobs(logits, targets, mask) 203 | loss2 = compute_logprobs_f_cross(logits, targets, mask) 204 | print(loss1, loss2) 205 | -------------------------------------------------------------------------------- /llm_tricks/dora/READEME.md: -------------------------------------------------------------------------------- 1 | # DoRA: Weight-Decomposed Low-Rank Adaptation 2 | 3 | 此为Dora微调方法的实现(目前**huggingface也已集成dora**,故使用可以直接使用huggingface如下,本模块可以作为详细的**理论学习**)⚽ 4 | 5 | huggingface中使用如下,基于lora的基础上,增加use_dora参数即可。本项目的训练框架也支持dora训练。 6 | ```python 7 | from peft import LoraConfig 8 | 9 | # Initialize DoRA configuration 10 | config = LoraConfig( 11 | use_dora=True, ... 
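    # the remaining LoraConfig fields (r, lora_alpha, target_modules, ...) are set the same way as for a plain LoRA run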
12 | ) 13 | ``` 14 | 15 | 16 | 17 | 18 | Implementation of "DoRA: Weight-Decomposed Low-Rank Adaptation" (Liu et al, 2024) https://arxiv.org/pdf/2402.09353.pdf 19 | 20 | 21 | ## 😸技术博客链接 22 | 23 | - [知乎:Dora原理及代码讲解](https://zhuanlan.zhihu.com/p/695269522) 24 | 25 | ## Tips: 26 | Dora是基于Lora的变体,故也对Lora进行了简单的示例。 27 | 28 | 29 | DoRA可以分两步描述,其中第一步是将预训练的权重矩阵分解为幅度向量(m)和方向矩阵(V)。第二步是将LoRA应用于方向矩阵V并单独训练幅度向量m。 30 | 31 | ## 如何使用 32 | 33 | 34 | dora_example.py 中有详细完整的 LoRA及DoRA训练与验证,建立了一个小的模型从训练到验证等全部过程。 35 | 36 | lora_and_dora.ipynb 用于自己调试及学习,可以在其中逐步运行以理解其原理。 37 | 38 | 运行以下代码可得到实验结果 39 | ```shell 40 | python dora_example.py 41 | ``` 42 | 43 | ## 实验结果如下: 44 | 运行 dora_example.py。超参数设置参考文件内。小模型具有局限性,具体dora和lora的实际效果对比还需要更多的实验。 45 | 46 | ```python 47 | Epoch: 001/001 | Batch 000/938 | Loss: 2.3010 48 | Epoch: 001/001 | Batch 400/938 | Loss: 0.4533 49 | Epoch: 001/001 | Batch 800/938 | Loss: 0.0464 50 | Epoch: 001/001 training accuracy: 95.31% 51 | Time elapsed: 0.11 min 52 | Total Training Time: 0.11 min 53 | Test accuracy: 96.88% 54 | Epoch: 001/002 | Batch 000/938 | Loss: 0.1734 55 | Epoch: 001/002 | Batch 400/938 | Loss: 0.0447 56 | Epoch: 001/002 | Batch 800/938 | Loss: 0.1270 57 | Epoch: 001/002 training accuracy: 96.88% 58 | Time elapsed: 0.11 min 59 | Epoch: 002/002 | Batch 000/938 | Loss: 0.0626 60 | Epoch: 002/002 | Batch 400/938 | Loss: 0.2149 61 | Epoch: 002/002 | Batch 800/938 | Loss: 0.1430 62 | Epoch: 002/002 training accuracy: 95.31% 63 | Time elapsed: 0.23 min 64 | Total Training Time: 0.23 min 65 | Test accuracy LoRA finetune: 96.88% 66 | Epoch: 001/002 | Batch 000/938 | Loss: 0.1588 67 | Epoch: 001/002 | Batch 400/938 | Loss: 0.1235 68 | Epoch: 001/002 | Batch 800/938 | Loss: 0.0506 69 | Epoch: 001/002 training accuracy: 100.00% 70 | Time elapsed: 0.11 min 71 | Epoch: 002/002 | Batch 000/938 | Loss: 0.1374 72 | Epoch: 002/002 | Batch 400/938 | Loss: 0.0892 73 | Epoch: 002/002 | Batch 800/938 | Loss: 0.0606 74 | Epoch: 002/002 training accuracy: 95.31% 75 | Time elapsed: 0.23 min 76 | Total Training Time: 0.23 min 77 | Test accuracy DoRA finetune: 98.44% 78 | ``` 79 | -------------------------------------------------------------------------------- /llm_tricks/dora/dora_example.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from torchvision import datasets 4 | from torchvision import transforms 5 | from torch.utils.data import DataLoader 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | import torch 9 | import copy 10 | 11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | BATCH_SIZE = 64 13 | 14 | # --------data process------------------------------------------------ 15 | 16 | train_dataset = datasets.MNIST(root='data', 17 | train=True, 18 | transform=transforms.ToTensor(), 19 | download=True) 20 | test_dataset = datasets.MNIST(root='data', 21 | train=False, 22 | transform=transforms.ToTensor()) 23 | 24 | train_loader = DataLoader(dataset=train_dataset, 25 | batch_size=BATCH_SIZE, 26 | shuffle=True) 27 | 28 | test_loader = DataLoader(dataset=test_dataset, 29 | batch_size=BATCH_SIZE, 30 | shuffle=False) 31 | 32 | # ----------------Hyperparameters------------------------------------- 33 | random_seed = 123 34 | learning_rate = 0.005 35 | num_epochs = 1 36 | 37 | # ----------------Architecture----------------------------------------- 38 | num_features = 784 39 | num_hidden_1 = 32 40 | num_hidden_2 = 64 41 | num_classes = 10 42 | 43 | torch.manual_seed(random_seed) 44 | 45 
| 46 | # ---------------Model----------------------------------------------- 47 | class TestMLP(nn.Module): 48 | def __init__(self, num_features, num_hidden1, num_hidden2, num_class): 49 | super().__init__() 50 | self.layers = nn.Sequential( 51 | nn.Linear(num_features, num_hidden1), 52 | nn.ReLU(), 53 | nn.Linear(num_hidden1, num_hidden2), 54 | nn.ReLU(), 55 | 56 | nn.Linear(num_hidden2, num_class) 57 | ) 58 | 59 | def forward(self, x): 60 | x = self.layers(x) 61 | return x 62 | 63 | 64 | model = TestMLP( 65 | num_features=num_features, num_hidden1=num_hidden_1, num_hidden2=num_hidden_2, num_class=num_classes 66 | ) 67 | 68 | model.to(DEVICE) 69 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 70 | 71 | 72 | # ---------------------Eval--------------------------------------------- 73 | 74 | def computer_metrics(model, data_loader, device): 75 | model.eval() 76 | correct_pred, num_examples = 0, 0 77 | with torch.no_grad(): 78 | for images, labels in data_loader: 79 | # Image batch dimensions: torch.Size([64, 1, 28, 28]) 80 | # Image label dimensions: torch.Size([64]) 81 | 82 | images = images.view(-1, 28 * 28).to(device) 83 | labels = labels.to(device) 84 | logits = model(images) 85 | _, predicted_labels = torch.max(logits, 1) 86 | num_examples = labels.size(0) 87 | correct_pred += (predicted_labels == labels).sum() 88 | return correct_pred.float() / num_examples * 100 89 | 90 | 91 | # ---------------------Train--------------------------------------------- 92 | 93 | 94 | def train(epochs, model, optimizer, train_loader, device): 95 | start_time = time.time() 96 | for epoch in range(epochs): 97 | model.train() 98 | for batch_idx, (images, labels) in enumerate(train_loader): 99 | images = images.view(-1, 28 * 28).to(device) 100 | labels = labels.to(device) 101 | 102 | # forward and back 103 | logits = model(images) 104 | loss = F.cross_entropy(logits, labels) 105 | optimizer.zero_grad() 106 | 107 | loss.backward() 108 | 109 | # UPDATE MODEL PARAMETERS 110 | optimizer.step() 111 | 112 | # LOGGING 113 | if not batch_idx % 400: 114 | print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f' 115 | % (epoch + 1, epochs, batch_idx, 116 | len(train_loader), loss)) 117 | with torch.set_grad_enabled(False): 118 | print('Epoch: %03d/%03d training accuracy: %.2f%%' % ( 119 | epoch + 1, epochs, 120 | computer_metrics(model, train_loader, device))) 121 | print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60)) 122 | print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60)) 123 | 124 | 125 | # ---------------------Lora Model--------------------------------------------- 126 | class LoRALayer(nn.Module): 127 | def __init__(self, in_dim, out_dim, rank, alpha): 128 | super().__init__() 129 | std_dev = 1 / torch.sqrt(torch.tensor(rank).float()) 130 | self.A = nn.Parameter(torch.rand(in_dim, rank) * std_dev) 131 | self.B = nn.Parameter(torch.zeros(rank, out_dim)) 132 | self.alpha = alpha 133 | 134 | def forward(self, x): 135 | x = self.alpha * (x @ self.A @ self.B) 136 | return x 137 | 138 | 139 | class LinearWithLoRA(nn.Module): 140 | def __init__(self, linear, rank, alpha): 141 | super().__init__() 142 | self.linear = linear 143 | self.lora = LoRALayer( 144 | linear.in_features, 145 | linear.out_features, 146 | rank, 147 | alpha 148 | ) 149 | 150 | def forward(self, x): 151 | return self.linear(x) + self.lora(x) 152 | 153 | 154 | # ---------------------DoRA Model--------------------------------------------- 155 | class LinearWithDoRA(nn.Module): 156 | def 
__init__(self, linear, rank, alpha): 157 | super().__init__() 158 | self.linear = linear 159 | self.lora = LoRALayer( 160 | linear.in_features, linear.out_features, rank, alpha 161 | ) 162 | self.m = nn.Parameter(torch.ones(1, linear.out_features)) 163 | 164 | def forward(self, x): 165 | linear_out = self.linear(x) 166 | lora_out = self.lora(x) 167 | lora_out_norm = lora_out / (lora_out.norm(p=2, dim=1, keepdim=True) + 1e-9) 168 | dora_modification = self.m * lora_out_norm 169 | return linear_out + dora_modification 170 | 171 | 172 | # 冻结模型的线性层,即可达到lora的只训练额外的lora层 173 | def freeze_linear_layers(model): 174 | for child in model.children(): 175 | if isinstance(child, nn.Linear): 176 | for param in child.parameters(): 177 | param.requires_grad = False 178 | else: 179 | # Recursively freeze linear layers in children modules 180 | freeze_linear_layers(child) 181 | 182 | 183 | # 将模型中的linear层替换为 LinearWithLoRA 184 | def convert_lora_layers(model): 185 | for name, module in model.named_children(): 186 | if isinstance(module, nn.Linear): 187 | setattr(model, name, LinearWithLoRA(module, rank=4, alpha=8)) 188 | else: 189 | convert_lora_layers(module) 190 | 191 | 192 | # 将模型中的linear层替换为 LinearWithDoRA 193 | def convert_dora_layers(model): 194 | for name, module in model.named_children(): 195 | if isinstance(module, nn.Linear): 196 | setattr(model, name, LinearWithDoRA(module, rank=4, alpha=8)) 197 | else: 198 | convert_dora_layers(module) # 递归处理子模块,同样替换为DoRA 199 | 200 | 201 | if __name__ == '__main__': 202 | train(num_epochs, model, optimizer, train_loader, DEVICE) 203 | print(f'Test accuracy: {computer_metrics(model, test_loader, DEVICE):.2f}%') 204 | 205 | # 复制两份模型,以供lora 和 dora分别实验 206 | model_lora = copy.deepcopy(model) 207 | model_dora = copy.deepcopy(model) 208 | 209 | # lora 训练 210 | convert_lora_layers(model_lora) 211 | freeze_linear_layers(model_lora) 212 | model_lora.to(DEVICE) 213 | optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate) 214 | train(2, model_lora, optimizer_lora, train_loader, DEVICE) 215 | print(f'Test accuracy LoRA finetune: {computer_metrics(model_lora, test_loader, DEVICE):.2f}%') 216 | 217 | # dora 训练 218 | convert_dora_layers(model_dora) 219 | freeze_linear_layers(model_dora) 220 | model_dora.to(DEVICE) 221 | optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate) 222 | train(2, model_dora, optimizer_dora, train_loader, DEVICE) 223 | print(f'Test accuracy DoRA finetune: {computer_metrics(model_dora, test_loader, DEVICE):.2f}%') 224 | -------------------------------------------------------------------------------- /llm_tricks/moe/READEME.md: -------------------------------------------------------------------------------- 1 | # Make MOE step by step 2 | 3 | 从零构建一个MOE代码存放于 **make_moe_step_by_step.ipynb**文件下。其中有详细的代码注释,推荐结合技术博客阅读,因为博客中手画了许多图以更好地理解。 4 | 5 | ## 😸技术博客链接 6 | 7 | - [从零构建一个MOE](https://zhuanlan.zhihu.com/p/701777558) 8 | 9 | 10 | 11 | ## 补充 12 | 13 | 博客中没提到的一点是 Expert Capacity。大概意思就是为了防止所有tokens都被一个或几个expert处理,我们需要设置一个专家容量。如果某个专家处理超过容量的tokens后就会给他截断,下面给出一个简单的代码示例,实际生产中会有更高级复杂的策略, 14 | 例如在https://arxiv.org/abs/2101.03961 中讨论的switch transformer架构。 15 | 16 | 我们简单的介绍代码如下,与我们技术博客中讲的SparseMoE基本相同,只是加了两个部分,在代码注释中也已标明。 17 | ```python 18 | class SparseMoE(nn.Module): 19 | def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0): 20 | super(SparseMoE, self).__init__() 21 | self.router = NoisyTopkRouter(n_embed, num_experts, top_k) 22 | self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)]) 23 | self.top_k = 
top_k 24 | self.capacity_factor = capacity_factor 25 | self.num_experts = num_experts 26 | 27 | def forward(self, x): 28 | batch_size, seq_len, _ = x.shape 29 | gating_output, indices = self.router(x) 30 | final_output = torch.zeros_like(x) 31 | 32 | flat_x = x.view(-1, x.size(-1)) 33 | flat_gating_output = gating_output.view(-1, gating_output.size(-1)) 34 | 35 | tokens_per_batch = batch_size * seq_len * self.top_k 36 | # 定义专家容量 37 | expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor) 38 | 39 | updates = torch.zeros_like(flat_x) 40 | 41 | for i, expert in enumerate(self.experts): 42 | expert_mask = (indices == i).any(dim=-1) 43 | flat_mask = expert_mask.view(-1) 44 | selected_indices = torch.nonzero(flat_mask).squeeze(-1) 45 | 46 | # 进行容量判断 47 | limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices 48 | if limited_indices.numel() > 0: 49 | expert_input = flat_x[limited_indices] 50 | expert_output = expert(expert_input) 51 | 52 | gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1) 53 | weighted_output = expert_output * gating_scores 54 | 55 | updates.index_add_(0, limited_indices, weighted_output) 56 | 57 | # Reshape updates to match the original dimensions of x 58 | final_output += updates.view(batch_size, seq_len, -1) 59 | 60 | return final_output 61 | 62 | ``` -------------------------------------------------------------------------------- /llm_tricks/transformer/README.md: -------------------------------------------------------------------------------- 1 | # Transformer代码详解:从头开始实现 2 | 3 | 全部代码:https://github.com/mst272/transformer-pytorch -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import random 4 | from typing import Optional 5 | 6 | from loguru import logger 7 | import torch 8 | import torch.nn as nn 9 | from datasets import load_dataset 10 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, \ 11 | BitsAndBytesConfig, HfArgumentParser, set_seed 12 | from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, cast_mixed_precision_params 13 | from train_args import sft_TrainArgument 14 | import bitsandbytes as bnb 15 | from utils.data_process import MultiRoundDataProcess 16 | from utils.data_collator import SftDataCollator 17 | from train_args.common_args import CommonArgs 18 | from utils.eval.configs import EvaluationConfig, GenerationConfig 19 | from utils.eval.callback import EvaluationCallback 20 | from utils.eval.eval_metric import create_metric 21 | 22 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 23 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 24 | 25 | def initial_args(): 26 | parser = HfArgumentParser((CommonArgs,)) 27 | args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True) 28 | if args.train_args_path == "sft_args": 29 | if args.use_eval_in_train: 30 | parser_b = HfArgumentParser((sft_TrainArgument, EvaluationConfig, GenerationConfig)) 31 | train_args, eval_args, gen_config = parser_b.parse_args_into_dataclasses(args=remaining_args) 32 | else: 33 | parser_b = HfArgumentParser((sft_TrainArgument,)) 34 | train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args) 35 | else: 36 | raise ValueError("Invalid train_args_path choice") 37 | 38 | if not os.path.exists(train_args.output_dir): 
39 | os.makedirs(train_args.output_dir, exist_ok=True) 40 | set_seed(train_args.seed) 41 | 42 | assert sum([train_args.fp16, train_args.bf16]) == 1, "only one of fp16 and bf16 can be True" 43 | if args.use_eval_in_train: 44 | return args, train_args, eval_args, gen_config 45 | return args, train_args 46 | 47 | 48 | def find_all_linear_names(model, train_mode): 49 | """ 50 | 找出所有全连接层,为所有全连接添加adapter 51 | """ 52 | assert train_mode in ['lora', 'qlora'] 53 | cls = bnb.nn.Linear4bit if train_mode == 'qlora' else nn.Linear 54 | lora_module_names = set() 55 | for name, module in model.named_modules(): 56 | if isinstance(module, cls): 57 | names = name.split('.') 58 | lora_module_names.add(names[-1]) 59 | 60 | if 'lm_head' in lora_module_names: # needed for 16-bit 61 | lora_module_names.remove('lm_head') 62 | lora_module_names = list(lora_module_names) 63 | return lora_module_names 64 | 65 | 66 | def create_tokenizer(args): 67 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True) 68 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True, 69 | # llama不支持fast 70 | use_fast=False if config.model_type == 'llama' else True 71 | ) 72 | 73 | # QWenTokenizer比较特殊,pad_token_id、bos_token_id、eos_token_id均为None。eod_id对应的token为<|endoftext|> 74 | if tokenizer.__class__.__name__ == 'QWenTokenizer': 75 | tokenizer.pad_token_id = tokenizer.eod_id 76 | tokenizer.bos_token_id = tokenizer.eod_id 77 | tokenizer.eos_token_id = tokenizer.eod_id 78 | if tokenizer.bos_token is None: # qwen没有bos_token,要设置一下,不然dpo train时会报错。 79 | tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token}) 80 | tokenizer.bos_token_id = tokenizer.eos_token_id 81 | 82 | assert tokenizer.pad_token_id is not None, "pad_token_id should not be None" 83 | assert tokenizer.eos_token_id is not None, "eos_token_id should not be None" 84 | 85 | return tokenizer 86 | 87 | 88 | def create_model(args, train_args): 89 | target_modules = None 90 | # 确定训练的精度 91 | torch_dtype = torch.bfloat16 if train_args.bf16 else torch.float32 92 | model_kwargs = dict( 93 | trust_remote_code=True, 94 | torch_dtype=torch_dtype, 95 | use_cache=False if train_args.gradient_checkpointing else True, # The cache is only used for generation, 96 | # fix bug 97 | # device_map='auto' 98 | ) 99 | 100 | def load_model(model_kwargs): 101 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs) 102 | return model 103 | 104 | if args.train_mode == 'qlora': 105 | # 基本的qlora可以直接在加载模型中设置参数,也可以通过BitsAndBytesConfig进行一些设置 106 | quantization_config = BitsAndBytesConfig( 107 | load_in_4bit=True, # 是否在4位精度下加载模型。如果设置为True,则在4位精度下加载模型。 108 | bnb_4bit_compute_dtype=torch.float16 if train_args.fp16 else torch.bfloat16, # 4位精度计算的数据类型。 109 | bnb_4bit_quant_type="nf4", # 4位精度量化的类型。这里设置为"nf4",表示使用nf4量化类型。 110 | bnb_4bit_use_double_quant=True # 是否使用双精度量化。如果设置为True,则使用双精度量化。 111 | ) 112 | model_kwargs.update(quantization_config=quantization_config) 113 | model = load_model(model_kwargs) 114 | if args.task_type in ['pretrain', 'sft']: # 如果是dpo的话就不执行 115 | # QLoRA: casts all the non int8 modules to full precision (fp32) for stability 116 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=train_args.gradient_checkpointing) 117 | 118 | elif args.train_mode == 'lora': 119 | model = load_model(model_kwargs) 120 | if hasattr(model, 'enable_input_require_grads'): 121 | # 不加可能报错 122 | model.enable_input_require_grads() 123 | elif args.train_mode == 'full': 124 | model = 
load_model(model_kwargs) 125 | 126 | if args.train_mode == 'full': 127 | peft_config = None 128 | else: 129 | # peft_config配置 130 | target_modules = find_all_linear_names(model, args.train_mode) 131 | peft_config = LoraConfig( 132 | r=args.lora_rank, 133 | lora_alpha=args.lora_alpha, 134 | target_modules=target_modules, 135 | lora_dropout=args.lora_dropout, 136 | task_type=TaskType.CAUSAL_LM, 137 | use_dora=args.use_dora 138 | ) 139 | 140 | # peft_model 配置 141 | if args.train_mode in ['lora', 'qlora'] and args.task_type in ['pretrain', 'sft']: 142 | model = get_peft_model(model, peft_config) 143 | if not train_args.bf16: 144 | cast_mixed_precision_params(model, dtype=torch.float16) 145 | 146 | # logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB') 147 | # model.print_trainable_parameters() 148 | 149 | return { 150 | 'model': model, 151 | 'peft_config': peft_config, 152 | 'target_modules': target_modules 153 | } 154 | 155 | 156 | def load_sft_dataset(args, tokenizer): 157 | train_dataset = MultiRoundDataProcess(args.train_data_path, tokenizer, args.max_len, args.auto_adapt) 158 | return train_dataset 159 | 160 | 161 | def create_trainer(args, train_args, eval_args: Optional[EvaluationConfig] = None, gen_config: Optional[GenerationConfig] = None): 162 | """" 163 | Create Trainer,支持可选的评估功能 164 | Args: 165 | args: 通用参数 166 | train_args: 训练相关参数 167 | eval_args: 评估相关参数(可选) 168 | gen_config: 评估相关参数(可选) 169 | """ 170 | # 1.Basic component initialization 171 | tokenizer = create_tokenizer(args) 172 | model_dict = create_model(args, train_args) 173 | model = model_dict['model'] 174 | # peft_config = model_dict['peft_config'] 175 | 176 | # 2. dataset process 177 | if args.task_type == 'sft': 178 | train_dataset = load_sft_dataset(args, tokenizer) 179 | data_collator = SftDataCollator(tokenizer, args.max_len) 180 | elif args.task_type == 'pretrain': 181 | pass 182 | 183 | # 3. log configuration 184 | log_out(args, train_args, tokenizer, train_dataset, model, model_dict['target_modules'], eval_args, gen_config) 185 | 186 | # 4. sft or pretrain 187 | if args.task_type == 'sft': 188 | trainer = Trainer( 189 | model=model, 190 | args=train_args, 191 | train_dataset=train_dataset, 192 | data_collator=data_collator, 193 | processing_class=tokenizer 194 | ) 195 | elif args.task_type == 'pretrain': 196 | pass 197 | # 5. 
Add evaluation callbacks if eval_args is provided 198 | if eval_args is not None: 199 | test_datasets = load_dataset( 200 | path="json", 201 | data_files=args.test_datasets_path 202 | )['train'] 203 | 204 | # 创建评估回调 205 | metrics = create_metric(eval_args) 206 | eval_callback = EvaluationCallback( 207 | trainer=trainer, 208 | test_datasets=test_datasets, 209 | generation_config=gen_config, 210 | num_samples=eval_args.num_samples, 211 | freq=eval_args.freq, 212 | metrics=metrics, 213 | max_checkpoints=eval_args.max_checkpoints, 214 | per_device_test_batch_size=eval_args.per_device_test_batch_size, 215 | higher_better=eval_args.higher_better, 216 | start_update_best_checkpoints=eval_args.start_update_best_checkpoints, 217 | use_vllm=eval_args.use_vllm, 218 | gather_deepspeed3_params=gen_config.gather_deepspeed3_params, 219 | prompts_apply_chat=eval_args.prompts_apply_chat, 220 | vllm_server_host=eval_args.vllm_server_host, 221 | vllm_server_port=eval_args.vllm_server_port, 222 | vllm_server_timeout=eval_args.vllm_server_timeout 223 | ) 224 | trainer.add_callback(eval_callback) 225 | 226 | return trainer 227 | 228 | 229 | def log_out(args, train_args, tokenizer, train_dataset, model, target_modules, eval_args, gen_config): 230 | total = sum(p.numel() for p in model.parameters()) 231 | logger.add(join(train_args.output_dir, 'train.log')) 232 | if train_args.local_rank == 0: 233 | logger.info("train_args:{}".format(train_args)) 234 | logger.info("common_args:{}".format(args)) 235 | logger.info("\neval_args:{}".format(eval_args)) 236 | logger.info("\ngen_config:{}".format(gen_config)) 237 | logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}') 238 | logger.info(f'Loading model from base model: {args.model_name_or_path}') 239 | logger.info("Total model params: %.2fM" % (total / 1e6)) 240 | logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB') 241 | if args.train_mode != 'full': 242 | trainable_params, all_param = model.get_nb_trainable_parameters() 243 | logger.info( 244 | f"trainable params: {trainable_params:,d} || " 245 | f"all params: {all_param:,d} || " 246 | f"trainable%: {100 * trainable_params / all_param:.4f}" 247 | ) 248 | logger.info(f'Train model with {args.task_type} task') 249 | logger.info(f'Train model with {args.train_mode}') 250 | logger.info(f'LoRA target module names: {target_modules}') 251 | logger.info(f'Loading data: {args.train_data_path}') 252 | logger.info(f"Training dataset samples:{len(train_dataset)}") 253 | for index in random.sample(range(len(train_dataset)), 3): 254 | logger.info( 255 | f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['target_mask']}.") 256 | logger.info( 257 | f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.") 258 | 259 | 260 | def main(): 261 | # args, train_args = initial_args() 262 | # # 加载trainer 263 | # trainer = create_trainer(args, train_args) 264 | result = initial_args() 265 | if len(result) == 4: 266 | args, train_args, eval_args, gen_config = result 267 | # 加载trainer,需要传入eval_args 268 | trainer = create_trainer(args, train_args, eval_args, gen_config) 269 | else: 270 | args, train_args = result 271 | # 原有的trainer创建方式 272 | trainer = create_trainer(args, train_args) 273 | # 开始训练 274 | if train_args.local_rank == 0: 275 | logger.info("*** starting training ***") 276 | train_result = trainer.train() 277 | # Transformers 更新了自动保存最后训练结果 278 | # final_save_path = join(train_args.output_dir) 
279 | # trainer.save_model(final_save_path) 280 | 281 | # 保存训练指标 282 | metrics = train_result.metrics 283 | trainer.log_metrics("train", metrics) 284 | trainer.save_metrics("train", metrics) 285 | trainer.save_state() 286 | 287 | 288 | if __name__ == "__main__": 289 | main() 290 | -------------------------------------------------------------------------------- /pic/pic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mst272/LLM-Dojo/6397eff480feca95f3b6016495b30d1ae308c9dc/pic/pic.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | transformers>=4.44.2 3 | peft 4 | bitsandbytes 5 | loguru 6 | numpy==1.26.4 7 | pandas 8 | tqdm 9 | deepspeed==0.16.3 10 | sentencepiece 11 | transformers-stream-generator 12 | tiktoken 13 | einops 14 | torch==2.5.0 15 | datasets 16 | trl 17 | wandb 18 | PIL 19 | vllm==0.7.2 -------------------------------------------------------------------------------- /rlhf/README.md: -------------------------------------------------------------------------------- 1 | # RLHF 强化学习框架 2 | 3 | 本框架使用简洁的代码基于Huggingface对各种强化学习方法进行了集成,便于自己修改与使用,是一个轻量化的强化学习框架。 4 | 5 | 主要资源是在1-8张40G A100上进行实验,支持lora qlora 及deepspeed单卡或多卡训练。 6 | 7 | 主要包括三类: 8 | 9 | **1、RLHF** 10 | 11 | **2、Knowledge Distillation (知识蒸馏)** 12 | 13 | **3、Rejected Sampling (拒绝采样) :待更新** 14 | 15 | ## 目录 16 | 17 | - [RLHF](#rlhf) 18 | - [目前支持的RLHF](#目前支持的rlhf) 19 | - [Quick Star](#quick-star) 20 | - [数据格式要求](#数据格式要求) 21 | - [数据格式选择](#数据格式选择) 22 | - [启动训练](#启动训练) 23 | - [注意事项](#注意事项) 24 | - [显存实验](#显存实验) 25 | - [Knowledge Distillation](#knowledge-distillation) 26 | - [Quick Star](#quick-star-1) 27 | - [感谢](#感谢) 28 | 29 | ## RLHF 30 | ### 目前支持的RLHF 31 | 实践来看主要的训练方式即为单轮。 32 | 33 | - ✅ Reward模型的训练 34 | - ✅ RLOO 35 | - ✅ PPO(暂时不可用) 36 | - ✅ SimPO 37 | - ✅ CPO 38 | - ✅ CPO-SimPO 39 | - ✅ DPO 40 | - ✅ KTO 41 | 42 | ### 🚀Quick Star 43 | 44 | 若有问题请尝试 deepspeed==0.15.4/python==3.10, 或者出现loss、rewards/chosen为nan时,请查看当前目录下的requirements.txt,按照此版本安装看是否能解决。 45 | 46 | 一些潜在的问题,暂时还没得到解决或者潜在的解决方案: 47 | 48 | https://github.com/huggingface/alignment-handbook/issues/57 49 | 50 | https://github.com/microsoft/DeepSpeed/issues/6793#issuecomment-2502620884 51 | 52 | https://github.com/ymcui/Chinese-LLaMA-Alpaca-3/issues/29 53 | 54 | #### 数据格式要求 55 | ✅ DPO、CPO、SimPO、CPO-SimPO: 56 | 57 | 需要有如下字段: 58 | - prompt 59 | - chosen 60 | - rejected 61 | 62 | ```json lines 63 | {"prompt":[{"role":"user","content":"How are you?"}],"chosen":[{"role":"assistant","content":"fine"}],"rejected":[{"role":"assistant","content":"no"}]} 64 | ``` 65 | ✅ KTO: 66 | - prompt 67 | - completion 68 | - label 69 | 70 | 比较特殊,相当于chosen的label为true,rejected的label为false: 71 | ```json lines 72 | {"prompt":[{"role":"user","content":"How are you?"}],"completion":[{"role":"assistant","content":"fine"}],"label":true} 73 | ``` 74 | 75 | ✅ Reward: 76 | - chosen 77 | - rejected 78 | 79 | ```json lines 80 | {"chosen":[{"role":"user","content":"How are you?"},{"role":"assistant","content":"fine"}],"rejected":[{"role":"user","content":"How are you?"},{"role":"assistant","content":"no"}]} 81 | ``` 82 | ✅ DPO、RLOO: 83 | - prompt 84 | 85 | ```json lines 86 | {"prompt":[{"role":"user","content":"How are you?"}]} 87 | ``` 88 | 89 | #### 数据格式选择 90 | 91 | **1.自动适配Chat Template格式**: 输入数据需为user assistant标准模式,具体可见上述数据格式要求。 92 | 93 | **2.不使用Chat格式**: 输入数据直接改为相应字段格式即可,例如: 94 | ```json lines 
95 | {"prompt":"How are you?","chosen":"fine", "rejected": "no"} 96 | ``` 97 | 98 | ```json lines 99 | {"chosen":"How are you? fine", "rejected": "How are you? no"} 100 | ``` 101 | 训练时便不会进行适配,采用原始输入进行训练。 102 | 103 | 104 | #### 启动训练 105 | 106 | 两个参数配置文件,第一个为```common_args.py```, 其余不同方法的配置在```rlhf_args```文件夹内 107 | 108 | 建议使用deepspeed启动,启动脚本在```rlhf_run.sh``` 109 | ```bash 110 | bash rlhf_run.sh 111 | ``` 112 | 113 | - rlhf_type: [PPO,RLOO,CPO,DPO,SimPO,CPOSimPO,Reward] 114 | - train_mode: [lora, qlora, full] 115 | 116 | #### 注意事项 117 | 1、需要自己去看AutoModelForSequenceClassification是否可以加载其Classification模型,不能的话需要在其config文件中映射。 118 | 119 | 2、涉及到reward模型时,需要两个模型的tokenizer相同。 120 | 121 | 3、使用deepspeed时需要通过accelerate进行使用,直接deepspeed的话会报错(目前似乎没有很好的解决方案) 122 | 123 | 4、一般来说trl的trainer是不支持使用deepspeed的optimizer和scheduler的 124 | 125 | 5、不支持Qlora和deepspeed zero-3,支持Qlora和deepspeed zero-2 126 | 127 | 6、训练Qwen2时遇到报错,提示```no padding token is defined```。需要在qwen2 ```config.json```中添加pad_token_id,在tokenizer中设置没用。 128 | 129 | 7、PPO/RLOO参数解释: 130 | 131 | See:https://github.com/huggingface/trl/issues/1740 132 | 133 | The ``num_train_epochs`` and ``num_ppo_epochs`` are actually two different things. The num_train_epochs means how many epochs do we go over the dataset, the num_ppo_epochs means the number of epochs we perform PPO updates on a batch of data. So, there is a subtle but meaningful difference here. 134 | 135 | 8、CPO系列不支持fp16,支持bf16 136 | 137 | #### 显存实验 138 | res_length为64 139 | 140 | | **RLHF** | **deepspeed** | **方式** | **Reward Model** | **SFT Model** | **显存占用** | 141 | |----------|---------------|--------|------------------|----------------|------------------------| 142 | | RLOO | Zero 3 | Lora | QWEN2(7B) | QWEN2(7B) | 2 x A100(40GB): 15~30G | 143 | | RLOO | Zero 3 | Full | QWEN2(7B) | QWEN2(7B) | 2 x A100(40GB): 速度很慢 | 144 | | RLOO | Zero 2 | Qlora | QWEN2(7B) | QWEN2(7B) | 2 x A100(40GB): 30~40G | 145 | | PPO | Zero 2 | Lora | MiniCPM(2B) | Deepseek(6.7B) | 2 x A100(40GB): OOM | 146 | | PPO | Zero 3 | Lora | MiniCPM(2B) | Deepseek(6.7B) | 2 x A100(40GB): 20-25G | 147 | | PPO | Zero 2 | Qlora | MiniCPM(2B) | Deepseek(6.7B) | 2 x A100(40GB): 30G | 148 | 149 | ## Knowledge Distillation 150 | 目前支持三种类型的知识蒸馏,GKD效果最好: 151 | - Supervised KD(off-policy) 152 | - SeqKD(off-policy) 153 | - GKD(on-policy) 154 | 155 | 具体介绍可参见文章:[知识蒸馏](https://zhuanlan.zhihu.com/p/1064724364) 156 | 157 | ### Quick Star 158 | 进入script目录下bash运行```gkd_run.sh```即可,修改对应参数运行。同样支持Deepspeed. 
159 | 160 | 161 | ```bash 162 | bash gkd_run.sh 163 | ``` 164 | 165 | **参数介绍**: 166 | - lmbda:0时为Supervised KD,1时为GKD。可在[0,1]范围内选择,这样就会混合比例 167 | - beta: 0时loss为KLD, 1时为JSD。可在[0,1]范围内选择,这样就会混合比例 168 | - seq_kd: True时Supervised KD将替换为Seq KD,默认为False,其他不变。 169 | - model_name_or_path:Student Model,即你需要训练的模型 170 | - teacher_model_name_or_path:Teacher Model, 不训练。 171 | 172 | ## Rejected Sampling 173 | 待更新 174 | 175 | ## 感谢 176 | 177 | 特别感谢huggingface trl做出的强大贡献,通过 trl 我们真的可以很容易简洁的实现RLHF。 -------------------------------------------------------------------------------- /rlhf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mst272/LLM-Dojo/6397eff480feca95f3b6016495b30d1ae308c9dc/rlhf/__init__.py -------------------------------------------------------------------------------- /rlhf/common_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class CommonArgs: 7 | """ 8 | 一些常用的自定义参数 9 | """ 10 | train_data_path: str = field(default='', metadata={"help": "训练数据路径"}) 11 | # 微调方法相关选择与配置 12 | rlhf_type: str = field(default="DPO", 13 | metadata={"help": "选择使用的RLHF方法,目前支持[PPO,RLOO,DPO,CPO,SimPO,CPOSimPO,Reward]"}) 14 | train_mode: str = field(default='lora', metadata={"help": "选择采用的训练方式:[qlora, lora, full]"}) 15 | 16 | # model qlora lora相关配置 17 | model_name_or_path: str = './' 18 | use_dora: bool = field(default=False, metadata={"help": "仅在train_mode==lora时可以使用。是否使用Dora(一个基于Lora的变体)"}) 19 | lora_rank: Optional[int] = field(default=32, metadata={"help": "lora rank"}) 20 | lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"}) 21 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"}) 22 | -------------------------------------------------------------------------------- /rlhf/ds_config/ds_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: cpu 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: 'bf16' 13 | num_machines: 1 14 | num_processes: 2 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | main_process_port: 29501 -------------------------------------------------------------------------------- /rlhf/ds_config/ds_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: cpu 6 | offload_param_device: cpu 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 2 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | main_process_port: 29501 -------------------------------------------------------------------------------- /rlhf/gkd_run.sh: 
-------------------------------------------------------------------------------- 1 | # 使用显卡数量需在yaml文件中修改num_processes参数 2 | 3 | # Lora模式, 如需QLora或者全参略微修改参数即可 4 | CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_config/ds_zero3.yaml ./train_gkd.py \ 5 | --model_name_or_path deepseek-coder-6.7b-instruct \ 6 | --teacher_model_name_or_path deepseek-coder-33b-instruct\ 7 | --dataset_name ../data/gkd_data.jsonl \ 8 | --learning_rate 2e-5 \ 9 | --per_device_train_batch_size 4 \ 10 | --gradient_accumulation_steps 8 \ 11 | --output_dir gkd-model2 \ 12 | --logging_steps 2 \ 13 | --num_train_epochs 1 \ 14 | --gradient_checkpointing \ 15 | --lmbda 0.5 \ 16 | --beta 0.5 \ 17 | --use_peft \ 18 | --lora_r 32 \ 19 | --lora_alpha 16 \ 20 | --trust_remote_code \ 21 | --bf16 \ 22 | --save_strategy "steps" \ 23 | --save_steps 180 \ 24 | --save_total_limit 5 \ 25 | --warmup_steps 10 \ 26 | --lr_scheduler_type "cosine" \ 27 | --torch_dtype bfloat16 > logs.log 2>&1 & -------------------------------------------------------------------------------- /rlhf/rejected_sampling/README.md: -------------------------------------------------------------------------------- 1 | # Rejected Sampling 2 | 3 | ## 1、Generate 4 | 5 | ### Data format 6 | 7 | jsonl格式,包含如下字段: 8 | - prompt 9 | - answer:可为空字符串,作为参考答案 10 | ```json lines 11 | {"prompt":"Hellow","answer":"nice"} 12 | ``` 13 | 14 | 如果使用可以直接apply_chat的messages格式,也无需修改,直接传入即可(需要无system字段),同样assistant的回答当做参考回答,后续可选择是否将参考回答放入打分名单: 15 | 16 | ```json lines 17 | {"message": [{"role": "user", "content": "How many helicopters can a human eat in one sitting"},{"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together"}]} 18 | ``` 19 | 20 | 采用分块存储,最终生成文件数量为原文件的n倍(生成n个回答)的字段包括: 21 | - messages: 可以直接输入训练且apply_chat_template的messages格式,其中每个assistant为n个生成中的一个 22 | - model_completion: 模型本次生成的结果 23 | - reference_completion: 你的原始数据的参考答案 24 | 25 | 26 | ## 2、Rejected sampling评测阶段 27 | 28 | 目前只支持通过api进行评测选择,传统的classification模型几乎用不到了,所以就进行了去除。 29 | 30 | 参数待配置,目前还只是一个草案,需要在make_request函数中自行配置相关API信息。 31 | 32 | 33 | 34 | ## 致谢 35 | 1、实现借鉴了open-instruct,感谢开源! 
-------------------------------------------------------------------------------- /rlhf/rejected_sampling/genetate.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass, asdict 3 | import os 4 | import json 5 | from typing import Dict, List 6 | 7 | import torch 8 | from datasets import load_dataset, concatenate_datasets, Dataset 9 | from transformers import AutoTokenizer 10 | from vllm import LLM, SamplingParams 11 | from rlhf.utils.util import ArgumentParserPlus 12 | 13 | 14 | @dataclass 15 | class Args: 16 | dataset_name: str = './test.jsonl' # 数据集 17 | model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr" 18 | save_filename: str = "completions.jsonl" 19 | auto_adapt: bool = True # if apply chat template 20 | system: str = '' # chat template default 21 | chunk_size: int = 50000 22 | 23 | 24 | @dataclass 25 | class GenerationArgs: 26 | num_completions: int = 3 27 | temperature: float = 0.8 28 | response_length: int = 4096 29 | top_p: float = 0.9 30 | tensor_parallel_size: int = 1 31 | dtype: torch.dtype = torch.bfloat16 32 | 33 | 34 | def load_datasets(data_files: str, shuffle: bool): 35 | """ 36 | 读取数据集,单jsonl文件或者目录 37 | """ 38 | if os.path.isfile(data_files): 39 | # 如果是单个文件,直接读取 40 | if not data_files.endswith('.jsonl'): 41 | raise ValueError(f"文件 '{data_files}' 不是JSONL文件") 42 | datasets = load_dataset("json", data_files=data_files) 43 | else: 44 | # 如果是目录,读取所有JSONL文件 45 | datasets = [] 46 | jsonl_files = [f for f in os.listdir(data_files) if f.endswith('.jsonl')] 47 | for file_name in jsonl_files: 48 | dataset = load_dataset("json", data_files=file_name) 49 | datasets.append(dataset) 50 | datasets = concatenate_datasets(datasets) 51 | if shuffle: 52 | datasets = datasets.shuffle(seed=42) 53 | return datasets['train'] 54 | 55 | 56 | def save_jsonl(save_filename: str, table: Dict[str, List]): 57 | first_key = list(table.keys())[0] 58 | os.makedirs(os.path.dirname(save_filename), exist_ok=True) 59 | with open(save_filename, "w") as outfile: 60 | for i in range(len(table[first_key])): 61 | json.dump({key: table[key][i] for key in table}, outfile) 62 | outfile.write("\n") 63 | 64 | 65 | def save_jsonl_in_chunks_to_files(base_filename: str, table: Dict[str, List], chunksize: int): 66 | """ 67 | 将字典数据按指定的 chunksize 分块保存为多个 JSONL 文件。 68 | 69 | Args: 70 | base_filename: 保存的文件名的基本名称(不包含 chunk 编号)。 71 | table: 包含数据的字典,其中 values 是等长的列表。 72 | chunksize: 每个 chunk 文件保存的行数。 73 | """ 74 | first_key = list(table.keys())[0] 75 | num_rows = len(table[first_key]) 76 | os.makedirs(os.path.dirname(base_filename), exist_ok=True) 77 | chunk_number = 0 78 | for i in range(0, num_rows, chunksize): 79 | chunk_number += 1 80 | save_filename = f"{base_filename}_chunk_{chunk_number}.jsonl" 81 | with open(save_filename, "w") as outfile: 82 | for j in range(i, min(i + chunksize, num_rows)): 83 | json.dump({key: table[key][j] for key in table}, outfile) 84 | outfile.write("\n") 85 | 86 | 87 | def generate_with_vllm(model_name_or_path: str, prompt_token_ids: List[int], gen_args: GenerationArgs): 88 | llm = LLM( 89 | model=model_name_or_path, 90 | tensor_parallel_size=gen_args.tensor_parallel_size, 91 | max_model_len=gen_args.response_length, 92 | dytype=gen_args.dtype 93 | ) 94 | 95 | # filter out prompts which are beyond the model's max token length 96 | max_model_len = llm.llm_engine.scheduler_config.max_model_len 97 | prompt_token_ids_len = len(prompt_token_ids) 98 | prompt_token_ids = 
[item for item in prompt_token_ids if len(item) < max_model_len] 99 | if len(prompt_token_ids) != prompt_token_ids_len: 100 | print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceed max token length") 101 | 102 | outputs = llm.generate( 103 | prompt_token_ids=prompt_token_ids, 104 | sampling_params=SamplingParams( 105 | n=gen_args.num_completions, 106 | temperature=gen_args.temperature, 107 | top_p=gen_args.top_p, 108 | max_tokens=gen_args.response_length, 109 | include_stop_str_in_output=True, 110 | ), 111 | ) 112 | 113 | return [ 114 | { 115 | "outputs": [asdict(out) for out in output.outputs], 116 | "prompt": output.prompt, 117 | "prompt_logprobs": output.prompt_logprobs, 118 | "metrics": output.metrics, 119 | } 120 | for output in outputs 121 | ] 122 | 123 | 124 | def tokenize(dataset: Dataset, auto_adapt: bool, system: str, tokenizer): 125 | def tokenize_fn(row): 126 | answer = row['answer'] if 'answer' in row else row['messages'][1]['content'] 127 | prompt = row['prompt'] if 'prompt' in row else row['messages'][0]['content'] 128 | # system prompt(如有)需要放在user message之前 129 | messages = [{"role": "system", "content": system}] if system else [] 130 | messages.append( 131 | {"role": "user", "content": prompt} 132 | ) 133 | outputs = tokenizer.apply_chat_template( 134 | messages, 135 | tokenize=True, 136 | add_generation_prompt=True 137 | ) 138 | return {"input_ids": outputs, "prompt": prompt, "answer": answer} 139 | 140 | def tokenize_fn_origin(row): 141 | prompt = row['prompt'] if 'prompt' in row else row['messages'][0]['content'] 142 | answer = row['answer'] if 'answer' in row else row['messages'][1]['content'] 143 | outputs = tokenizer.encode(prompt) 144 | return {"input_ids": outputs, "prompt": prompt, "answer": answer} 145 | 146 | return dataset.map( 147 | tokenize_fn if auto_adapt else tokenize_fn_origin, 148 | desc="Tokenizing and reformatting rejected sampling data", 149 | ) 150 | 151 | 152 | def main(args: Args, gen_args: GenerationArgs): 153 | dataset = load_datasets(data_files=args.dataset_name, shuffle=True) 154 | 155 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 156 | dataset = tokenize(dataset=dataset, auto_adapt=args.auto_adapt, system=args.system, tokenizer=tokenizer) 157 | prompt_token_ids = dataset['input_ids'] 158 | outputs = generate_with_vllm(args.model_name_or_path, prompt_token_ids, gen_args) 159 | 160 | # Assuming we generate n=3 completions per prompt; the outputs will look like: 161 | # prompt | completions 162 | # -------|------------ 163 | # q1 | a1 164 | # q1 | a2 165 | # q1 | a3 166 | # q2 | a1 167 | # ... 
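# 说明:下方循环会把每个prompt的n个生成结果逐条展开,最终以如下字段分块写入jsonl(与README中的字段说明一致):
# - messages: [{"role": "user", ...}, {"role": "assistant", ...}],可直接用于后续训练
# - model_completion: 模型本次生成的文本
# - reference_completion: 原始数据中的参考答案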
168 | table = defaultdict(list) 169 | num_prompt_with_identical_completions = 0 170 | for output, answer, prompt in zip(outputs, dataset["answer"], dataset['prompt']): 171 | # if the model completions are exactly the same across all completions per prompt, we can skip this 172 | if len(set(tuple(item["text"]) for item in output["outputs"])) == 1: 173 | num_prompt_with_identical_completions += 1 174 | continue 175 | 176 | for item in output["outputs"]: 177 | new_messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": item["text"]}] 178 | table["messages"].append(new_messages) 179 | table["model_completion"].append(item["text"]) 180 | table["reference_completion"].append(answer) 181 | 182 | print(f"Number prompts with identical completions: {num_prompt_with_identical_completions}") 183 | # save_jsonl(args.save_filename, table) 184 | save_jsonl_in_chunks_to_files(args.save_filename, table, args.chunk_size) 185 | 186 | 187 | if __name__ == "__main__": 188 | parser = ArgumentParserPlus((Args, GenerationArgs)) 189 | main(*parser.parse()) 190 | -------------------------------------------------------------------------------- /rlhf/rejected_sampling/rejected_sampling.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import aiohttp 3 | from datetime import datetime 4 | import re 5 | from tqdm.auto import tqdm 6 | import string 7 | import json 8 | from typing import List, Dict 9 | import glob 10 | import orjson 11 | import os 12 | 13 | 14 | def load_completions(input_path: str) -> List[dict]: 15 | """ 16 | 从 JSONL 文件或文件夹中加载 completions 数据 (使用 orjson)。 17 | 18 | Args: 19 | input_path: 文件路径或文件夹路径。 20 | 21 | Returns: 22 | 包含所有加载的 completions 数据的列表。 23 | """ 24 | all_completions = [] 25 | if os.path.isfile(input_path): 26 | with open(input_path, 'rb') as f: # 注意使用二进制读取模式 'rb' 27 | for line in f: 28 | all_completions.append(orjson.loads(line)) 29 | elif os.path.isdir(input_path): 30 | for filename in glob.glob(os.path.join(input_path, '*.jsonl')): 31 | print(f"正在加载文件: {filename}") 32 | with open(filename, 'rb') as f: # 注意使用二进制读取模式 'rb' 33 | for line in f: 34 | all_completions.append(orjson.loads(line)) 35 | else: 36 | print(f"错误: 输入路径 '{input_path}' 不是有效的文件或文件夹。") 37 | return all_completions 38 | 39 | 40 | def reason_post_process(code, index): 41 | """ 42 | Args: 43 | code (str): 输入字符串。 44 | index (int/str): 当前字符串的序号 (索引)。 45 | 46 | Returns: 47 | str 或 int: 如果找到代码块,则返回最后一个代码块字符串; 48 | 否则,返回输入的字符串序号 (index)。 49 | """ 50 | 51 | # Look for code blocks 52 | code_pattern = r'```(?:python|go|ts|php|csharp|bash|javascript|cpp|cs|java)(.*?)```' 53 | code_match = re.findall(code_pattern, code, re.DOTALL) 54 | 55 | if code_match: 56 | # If code block exists, return its content (excluding the ``` markers) 57 | return code_match[-1].strip() 58 | else: 59 | # If no code block, return the solution content directly 60 | print('---', index) 61 | # print(code) 62 | return code 63 | 64 | 65 | def create_dynamic_comparison_prompt(prompt: str, responses: list[str]) -> str: 66 | """ 67 | 根据给定的 prompt 和一个包含多个代码响应的列表, 68 | 生成一个用于比较这些代码片段的动态提示模板。 69 | 70 | Args: 71 | prompt: 描述问题的字符串。 72 | responses: 包含多个代码片段(字符串)的列表。 73 | 74 | Returns: 75 | 一个格式化好的、用于大模型评估的完整提示字符串。 76 | 如果 responses 为空,则返回错误信息。 77 | """ 78 | if not responses: 79 | return "错误:未提供任何代码响应进行比较。" 80 | 81 | num_responses = len(responses) 82 | 83 | # 1. 
构建模板的静态开头部分 84 | # 修改引言以适应多个片段 85 | header = f"""Compare the following {num_responses} code snippets that aim to solve the given problem. 86 | Evaluate each snippet based on efficiency, readability, and adherence to best practices. 87 | Identify the preferred snippet or rank them if applicable. 88 | 89 | ### Problem: 90 | {prompt} 91 | """ 92 | 93 | # 2. 动态构建每个代码片段的部分 94 | snippets_section = "" 95 | for i, response in enumerate(responses): 96 | # 生成标签:Code A, Code B, ... 97 | if i < 26: # 最多支持到 Z 98 | label = string.ascii_uppercase[i] 99 | else: # 如果超过 26 个,就用数字编号 100 | label = str(i + 1) 101 | 102 | snippets_section += f"\n### Code {label}:\n{response}\n" # 每个代码块前后加换行 103 | 104 | # 3. 构建模板的静态结尾部分 105 | footer = """ 106 | Code Analysis (Provide a brief analysis for each snippet, discussing its pros and cons regarding efficiency, readability, and best practices): 107 | 108 | Preferred Code (Output only the single letter label of the most preferred code snippet in 【】 below this line, e.g., 【answer here】): 109 | """ 110 | 111 | # 4. 组合所有部分 112 | full_prompt = header + snippets_section + footer 113 | return full_prompt 114 | 115 | 116 | async def make_request(session: aiohttp.ClientSession, prompt: str, index: int, api_key: str, post_url: str, 117 | cookie: str) -> Dict: 118 | url = post_url 119 | headers = { 120 | "Authorization": api_key, 121 | "Content-Type": "application/json", 122 | "Cookie": cookie 123 | } 124 | 125 | payload = { 126 | "stream": False, 127 | "model": "default", 128 | "messages": [ 129 | { 130 | "role": "user", 131 | "content": prompt 132 | } 133 | ], 134 | "max_tokens": 4096, 135 | "temperature": 0.0, 136 | "n": 1 137 | } 138 | 139 | try: 140 | async with session.post(url, headers=headers, json=payload, timeout=1000) as response: 141 | response.raise_for_status() 142 | json_response = await response.json() 143 | return { 144 | 'index': index, 145 | 'status': 'success', 146 | 'prompt': prompt, 147 | 'response': reason_post_process(json_response['choices'][0]['message']['content'], index) 148 | } 149 | except aiohttp.ClientError as e: 150 | return { 151 | 'index': index, 152 | 'status': 'error', 153 | 'prompt': prompt, 154 | 'response': f"请求失败:{str(e)}" 155 | } 156 | except json.JSONDecodeError as e: 157 | return { 158 | 'index': index, 159 | 'status': 'error', 160 | 'prompt': prompt, 161 | 'response': f"JSON解析错误:{str(e)}" 162 | } 163 | 164 | 165 | def extract_answer(text): 166 | """ 167 | 提取字符串中第一个【】内的内容,并返回字符串。 168 | 169 | Args: 170 | text: 输入字符串。 171 | 172 | Returns: 173 | 第一个【】内的内容字符串,如果没有任何匹配,则返回 None。 174 | """ 175 | pattern = r'【(.*?)】' 176 | match = re.search(pattern, text) 177 | if match: 178 | return match.group(1) 179 | else: 180 | return None 181 | 182 | 183 | ### 优化 184 | async def process_group(session: aiohttp.ClientSession, group: dict) -> dict: 185 | """处理单个分组的异步函数""" 186 | prompt = group['prompt'] 187 | responses = group['responses'] 188 | full_prompt = create_dynamic_comparison_prompt(prompt, responses) 189 | 190 | # 使用已存在的session发送请求 191 | result = await make_request(session, full_prompt, 0) # 索引在这里不重要 192 | selected_label = extract_answer(result['response']) 193 | 194 | return { 195 | 'prompt': prompt, 196 | 'responses': responses, 197 | 'indices': group['indices'], 198 | 'selected_label': selected_label 199 | } 200 | 201 | 202 | async def process_comparisons_async(completions, num_per_group=3, max_concurrent=5): 203 | start_time = datetime.now() 204 | grouped_data = [] 205 | current_group = { 206 | 'prompt': None, 207 | 'responses': [], 208 | 
'indices': [] 209 | } 210 | 211 | # 1. 分组数据并记录原始索引 212 | for idx, item in enumerate(completions): 213 | prompt = item['messages'][0]['content'] 214 | response = item['messages'][1]['content'] 215 | 216 | if idx % num_per_group == 0 and idx != 0: 217 | grouped_data.append(current_group) 218 | current_group = { 219 | 'prompt': prompt, 220 | 'responses': [response], 221 | 'indices': [idx] 222 | } 223 | else: 224 | if not current_group['responses']: 225 | current_group['prompt'] = prompt 226 | current_group['responses'].append(response) 227 | current_group['indices'].append(idx) 228 | 229 | # 处理最后一组 230 | if current_group['responses']: 231 | grouped_data.append(current_group) 232 | 233 | print(f"总共需要处理 {len(grouped_data)} 个分组") 234 | 235 | # 2. 并发处理所有分组 236 | async with aiohttp.ClientSession() as session: 237 | tasks = [] 238 | # 使用信号量控制并发数 239 | semaphore = asyncio.Semaphore(max_concurrent) 240 | 241 | # 将计数器移动到外层 242 | success_count = 0 243 | error_count = 0 244 | 245 | # 修改嵌套函数的结构,确保正确使用 nonlocal 246 | async def process_with_semaphore(group, pbar): 247 | nonlocal success_count, error_count # 正确的 nonlocal 声明位置 248 | 249 | async with semaphore: 250 | try: 251 | result = await process_group(session, group) 252 | success_count += 1 253 | pbar.update(1) 254 | pbar.set_postfix({'成功': success_count, '失败': error_count}) 255 | return result 256 | except Exception as e: 257 | error_count += 1 258 | pbar.update(1) 259 | pbar.set_postfix({'成功': success_count, '失败': error_count}) 260 | return { 261 | 'prompt': group['prompt'], 262 | 'responses': group['responses'], 263 | 'indices': group['indices'], 264 | 'selected_label': None, 265 | 'error': str(e) 266 | } 267 | 268 | with tqdm(total=len(grouped_data), desc="处理进度") as pbar: 269 | tasks = [process_with_semaphore(group, pbar) for group in grouped_data] 270 | results = await asyncio.gather(*tasks) 271 | 272 | # 打印统计信息 273 | end_time = datetime.now() 274 | print(f"\n处理完成时间: {end_time.strftime('%Y-%m-%d %H:%M:%S')}") 275 | print(f"处理结果统计:") 276 | print(f"- 成功: {success_count}") 277 | print(f"- 失败: {error_count}") 278 | print(f"- 总计: {len(grouped_data)}") 279 | print(f"- 耗时: {end_time - start_time}") 280 | 281 | # 3. 
更新原始数据 282 | for result in results: 283 | if result.get('error'): 284 | # 处理错误情况 285 | for original_idx in result['indices']: 286 | completions[original_idx]['comparison'] = { 287 | 'error': result['error'] 288 | } 289 | continue 290 | 291 | selected_label = result['selected_label'].strip().upper() if result['selected_label'] else None 292 | for pos, original_idx in enumerate(result['indices']): 293 | label = string.ascii_uppercase[pos] 294 | completions[original_idx]['comparison'] = { 295 | 'group_prompt': result['prompt'], 296 | 'position_label': label, 297 | 'is_best': label == selected_label if selected_label else False, 298 | 'best_label': selected_label, 299 | 'compared_with': len(result['responses']) 300 | } 301 | 302 | return completions 303 | 304 | 305 | def _save_chunk_to_jsonl(chunk: List[dict], output_filename: str) -> bool: 306 | """ 307 | 保存一个数据块为 JSONL 格式的文件 (使用 orjson)。 308 | 309 | Args: 310 | chunk: 要保存的数据块。 311 | output_filename: 输出文件名。 312 | 313 | Returns: 314 | True 如果保存成功,False 如果发生错误。 315 | """ 316 | try: 317 | with open(output_filename, 'wb') as f: # 注意使用二进制写入模式 'wb' 318 | for item in chunk: 319 | f.write(orjson.dumps(item) + b'\n') # orjson.dumps 返回 bytes 320 | return True 321 | except Exception as e: 322 | print(f"保存 chunk 到文件 '{output_filename}' 时发生错误: {str(e)}") 323 | return False 324 | 325 | 326 | def save_results_in_chunks(completions: List[dict], output_prefix: str = 'output', chunksize: int = 1000): 327 | """ 328 | 将处理结果按 chunksize 分块保存为多个 JSONL 文件 (使用 orjson)。 329 | 330 | Args: 331 | completions: 处理后的完整数据列表。 332 | output_prefix: 输出文件名的前缀。 333 | chunksize: 每个文件保存的数据条数。 334 | """ 335 | # 提取目录路径 336 | output_dir = os.path.dirname(output_prefix) 337 | # 如果目录不存在,则创建目录 338 | if output_dir and not os.path.exists(output_dir): 339 | os.makedirs(output_dir, exist_ok=True) 340 | 341 | num_chunks = (len(completions) + chunksize - 1) // chunksize 342 | for i in range(num_chunks): 343 | start_index = i * chunksize 344 | end_index = min((i + 1) * chunksize, len(completions)) 345 | chunk = completions[start_index:end_index] 346 | output_filename = f"{output_prefix}_part_{i + 1}.jsonl" 347 | if _save_chunk_to_jsonl(chunk, output_filename): 348 | print(f"已保存 chunk {i + 1} 到文件: {output_filename}") 349 | 350 | 351 | if __name__ == "__main__": 352 | completions = load_completions('/rejected') 353 | # completions = completions[:10] 354 | processed_data = asyncio.run(process_comparisons_async(completions, 3, 60)) 355 | output_file = './v3_eval/v3eval' 356 | 357 | save_results_in_chunks(processed_data, output_file, 16000) 358 | -------------------------------------------------------------------------------- /rlhf/rejected_sampling/run_generate.sh: -------------------------------------------------------------------------------- 1 | dataset_name='' 2 | model_name_or_path='' 3 | save_filename='rejected_generate' 4 | 5 | nohup python generate.py \ 6 | --dataset_name "$dataset_name" \ 7 | --model_name_or_path "$model_name_or_path" \ 8 | --save_filename "$save_filename" \ 9 | --auto_adapt True \ 10 | --num_completions 3 \ 11 | --temperature 0.8 \ 12 | --response_length 4096 \ 13 | --top_p 0.9 \ 14 | --tensor_parallel_size 8 \ 15 | --chunk_size 50000 -------------------------------------------------------------------------------- /rlhf/rejected_sampling/template.py: -------------------------------------------------------------------------------- 1 | # prompt_templates.py 2 | 3 | DEFAULT_SKILL = "summarization" 4 | 5 | GENERATION_TEMPLATES = { 6 | "chat": """ 7 | You are an expert 
assistant and your goal is to provide the most helpful and accurate response to the following chat. The chat maybe between two users User A and User B or it can be a single turn User A. Please ensure that your response is clear, concise, and addresses all aspects of the prompt. 8 | Don't answer by simply one word. Try to make your answer diverse and interesting. 9 | 10 | ### Prompt: 11 | {prompt} 12 | 13 | ### Response: 14 | """, 15 | "summarization": """ 16 | Please provide a concise summary of the following text, highlighting the most important points without including unimportant or irrelevant details. 17 | 18 | ### Text to Summarize: 19 | {prompt} 20 | 21 | Summary: 22 | """, 23 | "code_generation": """ 24 | Please write a Python function that solves the following problem. Ensure the code is efficient, readable, and follows best practices. 25 | 26 | ### Problem: 27 | {prompt} 28 | 29 | Python Code: 30 | """, 31 | "safety": """ 32 | Please provide a safe and appropriate response to the following scenario or question. Ensure your response adheres to ethical guidelines and promotes user safety. 33 | 34 | ### Scenario: 35 | {prompt} 36 | 37 | Safe Response: 38 | """, 39 | } 40 | 41 | JUDGMENT_TEMPLATES = { 42 | "chat": """ 43 | You are an evaluator tasked with assessing the response to the prompt based on three criteria: accuracy, instruction following, and tone/presentation. 44 | For each response provided to the prompt, assign a score from 1 to 3 for each criterion, where 1 is poor, 2 is average, and 3 is excellent. 45 | Don't give full scores unless you're mind blown by the response. 46 | 47 | ### Prompt: 48 | {prompt} 49 | 50 | ### Response: 51 | {response} 52 | 53 | ### Evaluation Criteria: 54 | 1. **Accuracy**: How factually correct and relevant is the response? 55 | 2. **Instruction Following**: How well does the response adhere to the given instructions? 56 | 3. **Tone/Presentation**: How clear, appropriate, and well-structured is the response? 57 | 58 | ### Scoring Template: 59 | Response: 60 | - **Accuracy**: [Score 1-3] 61 | - **Instruction Following**: [Score 1-3] 62 | - **Tone/Presentation**: [Score 1-3] 63 | 64 | Total score: 65 | """, 66 | "summarization": """ 67 | Compare the following summaries of the given text. Which one does a better job of summarizing the most important points without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence. 68 | 69 | ### Original Text: 70 | {prompt} 71 | 72 | ### Summary A: 73 | {response1} 74 | 75 | ### Summary B: 76 | {response2} 77 | 78 | Comparison: 79 | Preferred Summary: 80 | """, 81 | "code_generation": """ 82 | Compare the following two code snippets that aim to solve the given problem. Which one is more efficient, readable, and follows best practices? 83 | 84 | ### Problem: 85 | {prompt} 86 | 87 | ### Code A: 88 | {response1} 89 | 90 | ### Code B: 91 | {response2} 92 | 93 | Code Analysis: 94 | Preferred Code: 95 | """, 96 | "safety": """ 97 | Evaluate the safety implications of the following two responses to the given scenario. Which one is safer and more appropriate? 
98 | 99 | ### Scenario: 100 | {prompt} 101 | 102 | ### Response A: 103 | {response1} 104 | 105 | ### Response B: 106 | {response2} 107 | 108 | Safety Analysis: 109 | Preferred Response: 110 | """, 111 | } 112 | 113 | 114 | def get_generation_template(skill: str) -> str: 115 | return GENERATION_TEMPLATES.get(skill, GENERATION_TEMPLATES[DEFAULT_SKILL]) 116 | 117 | 118 | def get_judgment_template(skill: str) -> str: 119 | return JUDGMENT_TEMPLATES.get(skill, JUDGMENT_TEMPLATES[DEFAULT_SKILL]) 120 | -------------------------------------------------------------------------------- /rlhf/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.1 2 | datasets==3.1.0 3 | peft==0.13.2 4 | transformers==4.46.3 5 | trl==0.12.1 6 | deepspeed==0.15.4 7 | torch==2.5.1 -------------------------------------------------------------------------------- /rlhf/rlhf_args/base_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from transformers import TrainingArguments 3 | 4 | 5 | @dataclass 6 | class BaseConfig(TrainingArguments): 7 | """ 8 | 训练参数 9 | """ 10 | output_dir: str = field(default='./output', metadata={"help": "模型训练完成后的保存路径"}) 11 | num_train_epochs: int = 1 12 | 13 | per_device_train_batch_size: int = 2 14 | gradient_checkpointing: bool = True 15 | gradient_accumulation_steps: int = 16 16 | 17 | learning_rate: float = 2e-4 18 | logging_steps: int = 10 19 | save_steps: int = 500 20 | save_strategy: str = "steps" 21 | save_total_limit: int = 2 22 | lr_scheduler_type: str = "cosine" 23 | warmup_steps: int = 10 24 | optim: str = 'adamw_torch' 25 | report_to: str = 'tensorboard' 26 | remove_unused_columns: bool = False 27 | bf16: bool = False 28 | fp16: bool = False -------------------------------------------------------------------------------- /rlhf/rlhf_args/cpo-simpo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | from cpo_config import CPOConfig 4 | 5 | 6 | @dataclass 7 | class CPOSimPOConfig(CPOConfig): 8 | """ 9 | 基于CPOConfig,只需修改loss_type为simpo且cpo_alpha不为0即可 10 | """ 11 | loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "simpo" 12 | """The type of loss to use.""" 13 | cpo_alpha: float = 0.5 14 | """Combined use of CPO and SimPO, which enables more stable training and improved performance. Requires a non-zero 15 | cpo_alpha.""" 16 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/cpo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from rlhf_args.base_config import BaseConfig 3 | from trl import CPOConfig as TrlCPOConfig 4 | 5 | 6 | @dataclass 7 | class CPOConfig(BaseConfig, TrlCPOConfig): 8 | pass 9 | 10 | 11 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/dpo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from rlhf_args.base_config import BaseConfig 3 | from typing import Literal 4 | from trl import DPOConfig as TrlDPOConfig 5 | 6 | 7 | @dataclass 8 | class DPOConfig(BaseConfig, TrlDPOConfig): 9 | """ 10 | 训练参数, 可直接在此修改. 
想看更多参数可直接在TrlDPOConfig中去看 11 | """ 12 | beta: float = 0.1 13 | label_smoothing: float = 0.0 14 | loss_type: Literal[ 15 | "sigmoid", 16 | "hinge", 17 | "ipo", 18 | "exo_pair", 19 | "nca_pair", 20 | "robust", 21 | "bco_pair", 22 | "sppo_hard", 23 | "aot", 24 | "aot_pair", 25 | "apo_zero", 26 | "apo_down", 27 | ] = "sigmoid" 28 | label_pad_token_id: int = -100 29 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/kto_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from rlhf_args.base_config import BaseConfig 3 | from trl import KTOConfig as TrlKTOConfig 4 | 5 | 6 | @dataclass 7 | class KTOConfig(BaseConfig, TrlKTOConfig): 8 | desirable_weight: float = 1.0 9 | undesirable_weight: float = 1.0 10 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/ppo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from trl import PPOConfig as TrlPPOConfig 3 | from rlhf_args.base_config import BaseConfig 4 | 5 | 6 | @dataclass 7 | class PPOConfig(BaseConfig, TrlPPOConfig): 8 | eval_samples = 30 # eval数量 9 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/reward_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from rlhf_args.base_config import BaseConfig 3 | from trl import RewardConfig as TrlRewardConfig 4 | 5 | 6 | @dataclass 7 | class RewardConfig(BaseConfig, TrlRewardConfig): 8 | pass 9 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/rloo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from rlhf_args.base_config import BaseConfig 4 | from trl import RLOOConfig as TrlPLOOConfig 5 | 6 | 7 | # 支持直接通过total_episodes确定训练步数,也支持通过在TrainingArguments中配置num_train_epochs确定训练步数。 8 | @dataclass 9 | class RLOOConfig(BaseConfig, TrlPLOOConfig): 10 | reward_model_path: str = "./" 11 | sft_model_path: str = "./" 12 | total_episodes: Optional[int] = None 13 | eval_samples = 30 # eval数量 14 | -------------------------------------------------------------------------------- /rlhf/rlhf_args/simpo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | from cpo_config import CPOConfig 4 | 5 | 6 | @dataclass 7 | class SimPOConfig(CPOConfig): 8 | """ 9 | 基于CPOConfig,只需修改CPO的loss type为simpo,cpo_alpha设为0即可 10 | """ 11 | loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "simpo" 12 | """The type of loss to use.""" 13 | cpo_alpha: float = 0 14 | """A hyperparameter that controls the strength of the BC regularizer in CPO training.""" 15 | simpo_gamma: float = 0.5 16 | """A target reward margin for the SimPO loss, used only when the "simpo" option is enabled.""" 17 | 18 | -------------------------------------------------------------------------------- /rlhf/rlhf_run.sh: -------------------------------------------------------------------------------- 1 | # 使用显卡数量需在yaml文件中修改num_processes参数 2 | 3 | # rlhf_type:[PPO,RLOO,CPO,DPO,SimPO,CPOSimPO,Reward] 4 | # train_mode:[lora, qlora, full] 5 | 6 | TRAIN_DATA='./' 7 | MODEL_PATH='./' 8 | OUTPUT_PATH='./' 9 | 10 | 
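# Note: the rlhf_type passed below selects both the argument dataclass that train_rlhf.py loads
# from rlhf_args/ (e.g. DPO -> rlhf_args/dpo_config.py) and the matching trl trainer (see the
# train_args_path and trainer_map dicts in train_rlhf.py); any remaining CLI flags are parsed
# into that config.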
CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_config/ds_zero2.yaml ./train_rlhf.py \ 11 | --model_name_or_path "$MODEL_PATH" \ 12 | --train_data_path "$TRAIN_DATA" \ 13 | --output_dir "$OUTPUT_PATH" \ 14 | --rlhf_type "DPO" \ 15 | --train_mode "lora" \ 16 | --learning_rate 2e-5 \ 17 | --per_device_train_batch_size 2 \ 18 | --gradient_checkpointing \ 19 | --gradient_accumulation_steps 8 \ 20 | --logging_steps 2 \ 21 | --num_train_epochs 1 \ 22 | --bf16 \ 23 | --save_strategy "steps" \ 24 | --report_to "wandb" \ 25 | --save_steps 180 \ 26 | --save_total_limit 5 \ 27 | --warmup_steps 10 \ 28 | --remove_unused_columns False\ 29 | --lr_scheduler_type "cosine" 30 | 31 | # [CPO,DPO,SimPO,CPOSimPO,Reward] 可直接使用上述运行 32 | 33 | # [PPO,RLOO] 需要额外添加如下参数: 34 | # --reward_model_path './'\ 35 | # --local_rollout_forward_batch_size 1\ 36 | # --missing_eos_penalty 1.0\ 37 | # --num_ppo_epochs 1 \ 38 | # --num_mini_batches 1 -------------------------------------------------------------------------------- /rlhf/train_gkd.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import random 3 | from transformers import AutoTokenizer 4 | from trl import ( 5 | GKDConfig, 6 | GKDTrainer, 7 | LogCompletionsCallback, 8 | ModelConfig, 9 | ScriptArguments, 10 | TrlParser, 11 | get_kbit_device_map, 12 | get_peft_config, 13 | get_quantization_config, 14 | ) 15 | from accelerate import PartialState 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig)) 20 | args, training_args, model_config = parser.parse_args_and_config() 21 | 22 | ################ 23 | # Model & Tokenizer 24 | ################ 25 | quantization_config = get_quantization_config(model_config) 26 | model_kwargs = dict( 27 | revision=model_config.model_revision, 28 | trust_remote_code=model_config.trust_remote_code, 29 | attn_implementation=model_config.attn_implementation, 30 | torch_dtype=model_config.torch_dtype, 31 | use_cache=False if training_args.gradient_checkpointing else True, 32 | device_map=get_kbit_device_map() if quantization_config is not None else None, 33 | quantization_config=quantization_config, 34 | ) 35 | training_args.model_init_kwargs = model_kwargs 36 | 37 | teacher_model_kwargs = dict( 38 | revision=model_config.model_revision, 39 | trust_remote_code=model_config.trust_remote_code, 40 | attn_implementation=model_config.attn_implementation, 41 | torch_dtype=model_config.torch_dtype, 42 | use_cache=True, 43 | device_map=get_kbit_device_map() if quantization_config is not None else None, 44 | quantization_config=quantization_config, 45 | ) 46 | training_args.teacher_model_init_kwargs = teacher_model_kwargs 47 | 48 | tokenizer = AutoTokenizer.from_pretrained( 49 | model_config.model_name_or_path, 50 | trust_remote_code=model_config.trust_remote_code, 51 | padding_side="left", 52 | ) 53 | if tokenizer.pad_token is None: 54 | tokenizer.pad_token = tokenizer.eos_token 55 | 56 | ################ 57 | # Dataset 58 | ################ 59 | dataset = load_dataset(data_files=args.dataset_name, path='json') # 适配jsonl格式 60 | 61 | with PartialState().local_main_process_first(): 62 | dataset = dataset.map( 63 | lambda x: { 64 | "prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True) 65 | }, 66 | num_proc=training_args.dataset_num_proc, 67 | ) 68 | train_data = dataset['train'] 69 | test_data = train_data.select(random.sample(range(len(train_data)), 20)) 70 | 71 | 
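    # Note: test_data above is simply 20 examples randomly sampled from the training split itself;
    # it is only used as eval_dataset for in-training evaluation and completion logging, not as a
    # held-out test set.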
################ 72 | # Training 73 | ################ 74 | trainer = GKDTrainer( 75 | model=model_config.model_name_or_path, 76 | teacher_model=training_args.teacher_model_name_or_path, 77 | args=training_args, 78 | train_dataset=dataset[args.dataset_train_split], 79 | eval_dataset=test_data, 80 | processing_class=tokenizer, 81 | peft_config=get_peft_config(model_config), 82 | ) 83 | completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8) 84 | trainer.add_callback(completions_callback) 85 | trainer.train() 86 | 87 | # Save 88 | trainer.save_model(training_args.output_dir) 89 | -------------------------------------------------------------------------------- /rlhf/train_rlhf.py: -------------------------------------------------------------------------------- 1 | import deepspeed 2 | deepspeed.ops.op_builder.CPUAdamBuilder().load() 3 | import importlib 4 | import os 5 | from peft import LoraConfig, TaskType 6 | from datasets import load_dataset 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoTokenizer, 10 | AutoModelForSequenceClassification, 11 | HfArgumentParser, 12 | BitsAndBytesConfig, 13 | ) 14 | import torch 15 | from accelerate import PartialState 16 | from trl import DPOTrainer, CPOTrainer, PPOTrainer, RLOOTrainer, RewardTrainer, KTOTrainer 17 | from common_args import CommonArgs 18 | from utils.util import find_all_linear_names 19 | from loguru import logger 20 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template 21 | 22 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 23 | 24 | WITH_REWARD_MODEL = ['RLOO', 'PPO'] 25 | USE_REF_MODEL = ['DPO', 'RLOO', 'KTO'] 26 | 27 | trainer_map = { 28 | 'PPO': PPOTrainer, 29 | "RLOO": RLOOTrainer, 30 | "DPO": DPOTrainer, 31 | "CPO": CPOTrainer, 32 | "SimPO": CPOTrainer, 33 | "CPOSimPO": CPOTrainer, 34 | 'Reward': RewardTrainer, 35 | "KTO": KTOTrainer 36 | } 37 | 38 | train_args_path = { 39 | 'PPO': 'rlhf_args/ppo_config.py', 40 | "RLOO": 'rlhf_args/rloo_config.py', 41 | "DPO": 'rlhf_args/dpo_config.py', 42 | "CPO": 'rlhf_args/cpo_config.py', 43 | "SimPO": 'rlhf_args/simpo_config.py', 44 | "CPOSimPO": 'rlhf_args/cpo-simpo_config.py', 45 | 'Reward': 'rlhf_args/reward_config.py', 46 | 'KTO': 'rlhf_args/kto_config.py' 47 | } 48 | 49 | 50 | def load_config(args, remaining_args): 51 | # 根据config_option加载相应的配置 52 | module_path = train_args_path[args.rlhf_type].replace("/", ".").rstrip(".py") 53 | # 动态导入模块 54 | module = importlib.import_module(module_path) 55 | # 每个模块导入的类名均为TrainArgument 56 | class_name = args.rlhf_type + "Config" 57 | # 使用getattr获取模块中的类 58 | argument = getattr(module, class_name) 59 | 60 | parser_b = HfArgumentParser((argument,)) 61 | train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args) 62 | return train_args 63 | 64 | 65 | def load_judge_reward(): 66 | pass 67 | 68 | 69 | def load_classification_reward(path, model_kwargs): 70 | try: 71 | reward_model = AutoModelForSequenceClassification.from_pretrained(path, num_labels=1, 72 | **model_kwargs) 73 | return reward_model 74 | except Exception as e: 75 | assert False, "模型不支持AutoModelForSequenceClassification需要在对应config文件中添加映射" 76 | 77 | 78 | def load_tokenizer(path): 79 | tokenizer = AutoTokenizer.from_pretrained( 80 | path, 81 | padding_side="left", 82 | trust_remote_code=True, 83 | ) 84 | 85 | if tokenizer.pad_token is None: 86 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 87 | return tokenizer 88 | 89 | 90 | def prepare_dataset(dataset, tokenizer): 91 | 
"""pre-tokenize the dataset before training; only collate during training""" 92 | 93 | def tokenize(element): 94 | outputs = tokenizer( 95 | element['prompt'], 96 | padding=False, 97 | ) 98 | return {"input_ids": outputs["input_ids"]} 99 | 100 | return dataset.map( 101 | tokenize, 102 | batched=True 103 | ) 104 | 105 | 106 | def main(): 107 | parser = HfArgumentParser((CommonArgs,)) 108 | args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True) 109 | # 根据CommonArgs中的config_option动态加载配置 110 | config = load_config(args, remaining_args) 111 | 112 | ################ 113 | # Data 114 | ################ 115 | train_dataset = load_dataset(data_files=args.train_data_path, path='json') 116 | 117 | ################ 118 | # Model & Tokenizer 119 | ################ 120 | tokenizer = load_tokenizer(args.model_name_or_path) 121 | 122 | model_kwargs = dict( 123 | trust_remote_code=True, 124 | torch_dtype=torch.float16 if config.fp16 else torch.bfloat16, 125 | ) 126 | 127 | if args.train_mode == 'qlora': 128 | quantization_config = BitsAndBytesConfig( 129 | load_in_4bit=True, 130 | bnb_4bit_compute_dtype=torch.float16 if config.fp16 else torch.bfloat16, 131 | bnb_4bit_use_double_quant=True, 132 | bnb_4bit_quant_type="nf4", 133 | llm_int8_threshold=6.0, 134 | llm_int8_has_fp16_weight=False, 135 | ) 136 | model_kwargs.update(quantization_config=quantization_config) 137 | 138 | # 加载policy model 139 | if args.rlhf_type == 'Reward': 140 | policy = load_classification_reward(args.model_name_or_path, model_kwargs) 141 | else: 142 | policy = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs) 143 | 144 | if args.train_mode in ['lora', 'qlora']: 145 | lora_config = LoraConfig( 146 | task_type=TaskType.SEQ_CLS if args.rlhf_type == 'Reward' else TaskType.CAUSAL_LM, 147 | target_modules=find_all_linear_names(policy), 148 | r=args.lora_rank, # Lora 秩 149 | lora_alpha=args.lora_alpha, # Lora alpha,具体作用参见 Lora 原理 150 | lora_dropout=args.lora_dropout, # Dropout 比例 151 | use_dora=args.use_dora 152 | ) 153 | ref_model = None # if peft, the model with a disabled adapter 154 | elif args.rlhf_type in USE_REF_MODEL: 155 | ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs) 156 | lora_config = None 157 | 158 | # 决定是否加载Reward model 159 | if args.rlhf_type in WITH_REWARD_MODEL: 160 | # 如果模型不支持AutoModelForSequenceClassification需要在对应config文件中添加映射 161 | reward_model = load_classification_reward(config.reward_model_path, model_kwargs) 162 | 163 | # data process 164 | # Compute that only on the main process for faster data processing. 
165 | # see: https://github.com/huggingface/trl/pull/1255 166 | train_dataset = train_dataset.select(range(len(train_dataset) - config.eval_samples)) 167 | eval_dataset = train_dataset.select(range(len(train_dataset) - config.eval_samples, len(train_dataset))) 168 | with PartialState().local_main_process_first(): 169 | train_dataset = prepare_dataset(train_dataset, tokenizer) 170 | eval_dataset = prepare_dataset(eval_dataset, tokenizer) 171 | if ref_model is None: 172 | ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True, 173 | torch_dtype=torch.float16 if config.fp16 else torch.bfloat16) 174 | 175 | ################ 176 | # Training 177 | ################ 178 | trainer_kwargs_map = { 179 | "DPO": dict( 180 | model=policy, 181 | ref_model=ref_model, 182 | args=config, 183 | train_dataset=train_dataset['train'], 184 | eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None, 185 | processing_class=tokenizer, 186 | peft_config=lora_config, 187 | ) 188 | if args.rlhf_type in ['DPO', 'KTO'] 189 | else dict() 190 | , 191 | 'CPO': dict( 192 | model=policy, 193 | args=config, 194 | train_dataset=train_dataset['train'], 195 | eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None, 196 | processing_class=tokenizer, 197 | peft_config=lora_config, 198 | ) 199 | if args.rlhf_type in ['CPO', 'SimPO', 'CPOSimPO'] 200 | else dict() 201 | , 202 | "PPO": dict( 203 | ), 204 | "RLOO": dict( 205 | config=config, 206 | processing_class=tokenizer, 207 | policy=policy, 208 | ref_policy=ref_model, 209 | reward_model=reward_model, 210 | train_dataset=train_dataset, 211 | eval_dataset=eval_dataset, 212 | ) 213 | if args.rlhf_type == 'RLOO' 214 | else dict() 215 | , 216 | 'Reward': dict( 217 | model=policy, 218 | processing_class=tokenizer, 219 | args=config, 220 | train_dataset=train_dataset['train'], 221 | eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None, 222 | peft_config=lora_config, 223 | ) 224 | if args.rlhf_type == 'Reward' 225 | else dict() 226 | } 227 | 228 | trainer_kwargs_map['SimPO'] = trainer_kwargs_map['CPO'].copy() 229 | trainer_kwargs_map['CPOSimPO'] = trainer_kwargs_map['CPO'].copy() 230 | trainer_kwargs_map['KTO'] = trainer_kwargs_map['DPO'].copy() 231 | 232 | # 从字典中获取相应的 Trainer 类 233 | trainer_kwargs = trainer_kwargs_map.get(args.rlhf_type) 234 | TrainerClass = trainer_map.get(args.rlhf_type) 235 | if TrainerClass is None: 236 | raise ValueError(f"Unknown trainer type: {args.rlhf_type}") 237 | 238 | trainer = TrainerClass(**trainer_kwargs) 239 | trainer.train() 240 | # trainer.save_model(config.output_dir) 241 | 242 | 243 | if __name__ == "__main__": 244 | main() 245 | -------------------------------------------------------------------------------- /rlhf/utils/util.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import sys 4 | from dataclasses import dataclass 5 | from typing import List, Dict, Optional, Tuple, Union 6 | import copy 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import HfArgumentParser 11 | from transformers.hf_argparser import DataClassType 12 | 13 | 14 | class ArgumentParserPlus(HfArgumentParser): 15 | def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]: 16 | """ 17 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. 
18 | 19 | Args: 20 | yaml_arg (`str`): 21 | The path to the config file used 22 | other_args (`List[str]`, *optional`): 23 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. 24 | 25 | Returns: 26 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line 27 | """ 28 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) 29 | 30 | outputs = [] 31 | # strip other args list into dict of key-value pairs 32 | other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args} 33 | used_args = {} 34 | 35 | # overwrite the default/loaded value with the value provided to the command line 36 | # noqa adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 37 | for data_yaml, data_class in zip(arg_list, self.dataclass_types): 38 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} 39 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} 40 | for arg, val in other_args.items(): 41 | # add only if in keys 42 | 43 | if arg in keys: 44 | base_type = data_yaml.__dataclass_fields__[arg].type 45 | inputs[arg] = val 46 | 47 | # cast type for ints, floats (default to strings) 48 | if base_type in [int, float]: 49 | inputs[arg] = base_type(val) 50 | 51 | if base_type == List[str]: 52 | inputs[arg] = [str(v) for v in val.split(",")] 53 | 54 | # bool of a non-empty string is True, so we manually check for bools 55 | if base_type == bool: 56 | if val in ["true", "True"]: 57 | inputs[arg] = True 58 | else: 59 | inputs[arg] = False 60 | 61 | # add to used-args so we can check if double add 62 | if arg not in used_args: 63 | used_args[arg] = val 64 | else: 65 | raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior") 66 | 67 | obj = data_class(**inputs) 68 | outputs.append(obj) 69 | 70 | return outputs 71 | 72 | def parse(self) -> Union[DataClassType, Tuple[DataClassType]]: 73 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 74 | # If we pass only one argument to the script and it's the path to a YAML file, 75 | # let's parse it to get our arguments. 76 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) 77 | # parse command line args and yaml file 78 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): 79 | output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) 80 | # parse command line args only 81 | else: 82 | output = self.parse_args_into_dataclasses() 83 | 84 | if len(output) == 1: 85 | output = output[0] 86 | return output 87 | 88 | 89 | def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor: 90 | """ 91 | Finds the index of the first `True` value in each row of a boolean tensor. If no `True` value exists in a row, 92 | it returns the length of the row. 93 | 94 | Args: 95 | bools (torch.Tensor): A boolean tensor of shape (batch_size, sequence_length), where `True` values indicate 96 | the positions of interest. 97 | dtype (torch.dtype): The data type to use for the output indices (default is torch.long). 98 | 99 | Returns: 100 | torch.Tensor: A tensor of shape (batch_size,) containing the index of the first `True` value in each row. 101 | If a row has no `True` value, the index will be the length of the row. 
102 | """ 103 | 104 | # Get the length of each row (i.e., the number of columns in the last dimension) 105 | # row_len is a scalar representing the length of each sequence (sequence_length) 106 | row_len = bools.size(-1) 107 | 108 | # Calculate the index positions for the first `True` in each row 109 | # ~bools: Invert the boolean values (True becomes False and vice versa) 110 | # ~bools.type(dtype): Convert the inverted boolean tensor to the specified dtype (0 for True, 1 for False) 111 | # row_len * (~bools).type(dtype): For `False` values, this will give `row_len`, for `True` values it gives 0. 112 | # torch.arange(row_len, dtype=dtype, device=bools.device): Generates a tensor with values [0, 1, 2, ..., row_len-1] 113 | # for each row. Shape: (sequence_length,) 114 | # zero_or_index: Shape (batch_size, sequence_length). This tensor contains the indices for `True` values and `row_len` 115 | # for `False` values. 116 | zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) 117 | 118 | # Return the minimum value in each row (i.e., the first `True` index or `row_len` if none exist) 119 | # torch.min(zero_or_index, dim=-1).values: This returns the minimum value in each row, which corresponds to the first 120 | # `True` value's index or `row_len` if there is no `True` in that row. 121 | # The returned tensor has shape (batch_size,) 122 | return torch.min(zero_or_index, dim=-1).values 123 | 124 | 125 | def get_reward( 126 | model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int 127 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 128 | """ 129 | This function computes reward scores for a batch of query responses based on a pre-trained reward model. 130 | 131 | Args: 132 | model (torch.nn.Module): The pre-trained reward model. 133 | query_responses (torch.Tensor): Tensor containing the tokenized responses for which to compute rewards. 134 | Shape: (batch_size, sequence_length) 135 | pad_token_id (int): The ID used for padding tokens in the tokenized sequences. 136 | context_length (int): The length of the prompt or context preceding the completions. 137 | 138 | Returns: 139 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: 140 | - reward_logits: The logits output from the model for all tokens in the sequences. 141 | Shape: (batch_size, sequence_length) 142 | - final_scores: The final reward scores, one for each sequence, after adjusting for sequence lengths. 143 | Shape: (batch_size,) 144 | - sequence_lengths: The lengths of each sequence (excluding padding). 145 | Shape: (batch_size,) 146 | 147 | For example: 148 | query_responses = torch.tensor([ 149 | [token0, token1, token2, token3, 0, 0], 150 | [token0, token1, token4, 0, 0, 0] 151 | ]) # 形状: (2, 6) 152 | 153 | attention_mask = query_responses != 0 154 | # [[1, 1, 1, 1, 0, 0], 155 | # [1, 1, 1, 0, 0, 0]] 156 | 157 | position_ids = attention_mask.cumsum(1) - attention_mask.long() 158 | # [[0, 1, 2, 3, 4, 4], 159 | # [0, 1, 2, 3, 3, 3]] 160 | 161 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) 162 | # 在此例中,填充 token 已为 0,无变化 163 | 164 | reward_logits = torch.tensor([ 165 | [[r0_0], [r0_1], [r0_2], [r0_3], [r0_4], [r0_5]], 166 | [[r1_0], [r1_1], [r1_2], [r1_3], [r1_4], [r1_5]] 167 | ]) # 形状: (2, 6, 1) 168 | 169 | query_responses[:, 2:] == 0 170 | # [[False, False, True, True], 171 | # [False, True, True, True]] 172 | 173 | sequence_lengths = first_true_indices(...) 
- 1 + 2 174 | # first_true_indices = [2, 1] 175 | # sequence_lengths = [2-1+2, 1-1+2] = [3, 2] 176 | 177 | final_scores = reward_logits[torch.arange(2), [3, 2]].squeeze(-1) 178 | # = reward_logits[[0,1], [3, 2]] ---> reward_logits[0, 3, :], reward_logits[1, 2, :] 179 | # = [r0_3, r1_2],形状: (2,) 180 | """ 181 | 182 | # Create an attention mask where tokens that are not padding have a value of 1, and padding tokens have a value of 0 183 | # Shape: (batch_size, sequence_length) 184 | attention_mask = query_responses != pad_token_id 185 | 186 | # Calculate position IDs for each token, considering the cumulative sum of the attention mask (to exclude padding) 187 | # Shape: (batch_size, sequence_length) 188 | position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum 189 | 190 | # Access the LM backbone from the reward model using its base model prefix 191 | lm_backbone = getattr(model, model.base_model_prefix) 192 | 193 | # Replace padding tokens with zeros in the input IDs (so padding tokens won't affect the model's processing) 194 | # Shape: (batch_size, sequence_length) 195 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) 196 | output = lm_backbone( 197 | input_ids=input_ids, 198 | attention_mask=attention_mask, 199 | position_ids=position_ids, 200 | return_dict=True, 201 | output_hidden_states=True, 202 | use_cache=False, # otherwise mistral-based RM would error out 203 | ) 204 | reward_logits = model.score(output.hidden_states[-1]) # (batch_size, sequence_length) 205 | 206 | # Calculate the length of each sequence by finding the first occurrence of a padding token after the context 207 | # sequence_lengths shape: (batch_size,) 208 | sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length 209 | assert ( 210 | reward_logits.shape[-1] == 1 211 | ), "Reward model should output a single scalar per token. Check if you added `num_labels=1` when doing `AutoModelForSequenceClassification.from_pretrained(...)`." 212 | # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 213 | 214 | # Return the reward logits for all tokens, the final reward scores for each sequence, and the sequence lengths 215 | return ( 216 | # reward_logits shape: (batch_size, sequence_length) 217 | reward_logits, 218 | # final_scores shape: (batch_size,) 219 | reward_logits[ 220 | torch.arange(reward_logits.size(0), device=reward_logits.device), 221 | sequence_lengths, 222 | ].squeeze( 223 | -1 224 | ), # Shape: (batch_size,) 225 | sequence_lengths, 226 | ) 227 | 228 | 229 | def find_all_linear_names(model): 230 | """ 231 | 找出所有全连接层,为所有全连接添加adapter 232 | """ 233 | cls = nn.Linear 234 | lora_module_names = set() 235 | for name, module in model.named_modules(): 236 | if isinstance(module, cls): 237 | names = name.split('.') 238 | lora_module_names.add(names[-1]) 239 | 240 | if 'lm_head' in lora_module_names: # needed for 16-bit 241 | lora_module_names.remove('lm_head') 242 | lora_module_names = list(lora_module_names) 243 | return lora_module_names 244 | 245 | 246 | def is_right_apply_chat(tokenizer, prompt: List[Dict[str, str]], assistant_content: List[Dict[str, str]]) -> bool: 247 | """ 248 | Checks if the assistant's content is correctly applied to the prompt in a chat template. 249 | Args: 250 | tokenizer: The tokenizer. 251 | prompt: The initial prompt message. 252 | assistant_content: The content provided by the assistant. 
253 | Returns: 254 | bool: True if the assistant's content is correctly applied, False otherwise. 255 | """ 256 | try: 257 | test_assistant = tokenizer.apply_chat_template(assistant_content, tokenize=False) 258 | test_prompt = tokenizer.apply_chat_template(prompt, tokenize=False) 259 | conversation = copy.deepcopy(prompt) 260 | conversation.append(assistant_content[0]) 261 | if tokenizer.apply_chat_template(conversation) == test_prompt + test_assistant: 262 | return True 263 | else: 264 | return False 265 | except Exception as e: 266 | return False 267 | 268 | 269 | def fix_chat_template_if_needed(tokenizer, prompt: List[Dict[str, str]], chosen: List[Dict[str, str]], 270 | rejected: List[Dict[str, str]]): 271 | """ 272 | Fixes the chat template if needed. 273 | Args: 274 | tokenizer: The tokenizer. 275 | prompt: The initial prompt message. 276 | chosen: The chosen response, a list containing a single dictionary representing the chosen message. 277 | rejected: The rejected response, a list containing a single dictionary representing the rejected message. 278 | Returns: 279 | - tuple: A tuple containing the fixed prompt, fixed chosen response, and fixed rejected response. 280 | """ 281 | conversation_chosen = copy.deepcopy(prompt) 282 | conversation_rejected = copy.deepcopy(prompt) 283 | conversation_chosen.append(chosen[0]) 284 | conversation_rejected.append(rejected[0]) 285 | conversation_chosen = tokenizer.apply_chat_template(conversation_chosen, tokenize=False) 286 | conversation_rejected = tokenizer.apply_chat_template(conversation_rejected, tokenize=False) 287 | # find position 288 | start_position = conversation_chosen.find(chosen[0]['content'][0]) 289 | # The following is right 290 | fixed_prompt = conversation_chosen[:start_position] 291 | fixed_chosen = conversation_chosen[start_position:] 292 | fixed_rejected = conversation_rejected[start_position:] 293 | return fixed_prompt, fixed_chosen, fixed_rejected 294 | -------------------------------------------------------------------------------- /run_eval_test.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH='' 2 | OUTPUT_PATH="" 3 | MODEL_PATH="" 4 | 5 | BASE_CMD="CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 nohup accelerate launch --config_file rlhf/ds_config/ds_zero3.yaml main_train.py \ 6 | --train_data_path "$DATA_PATH" \ 7 | --model_name_or_path "$MODEL_PATH" \ 8 | --max_len 4096 \ 9 | --num_train_epochs 1 \ 10 | --per_device_train_batch_size 2 \ 11 | --gradient_accumulation_steps 8 \ 12 | --task_type "sft" \ 13 | --train_mode "full" \ 14 | --output_dir "$OUTPUT_PATH" \ 15 | --save_strategy "steps" \ 16 | --save_steps 100 \ 17 | --save_total_limit 3 \ 18 | --learning_rate 2e-5 \ 19 | --warmup_steps 16 \ 20 | --logging_steps 1 \ 21 | --lr_scheduler_type "cosine" \ 22 | --gradient_checkpointing True \ 23 | --report_to "wandb" \ 24 | --bf16 True \ 25 | --auto_adapt True" 26 | 27 | # 评估相关参数(仅在use_eval_in_train为True时使用) 28 | EVAL_ARGS="--use_eval_in_train True \ 29 | --test_datasets_path "data/test.jsonl" \ 30 | --max_new_tokens 4096 \ 31 | --freq 4 \ 32 | --metrics "code" \ 33 | --vllm_server_port 8001 \ 34 | --vllm_server_timeout 30 \ 35 | --save_best_checkpoints True \ 36 | --max_checkpoints 2 \ 37 | --start_update_best_checkpoints 4 \ 38 | --prompts_apply_chat True \ 39 | --use_vllm True" 40 | 41 | # 根据是否需要评估来构建完整命令 42 | if [ "$1" = "--eval" ]; then 43 | FULL_CMD="$BASE_CMD $EVAL_ARGS" 44 | else 45 | FULL_CMD="$BASE_CMD --use_eval_in_train False" 46 | fi 47 | 48 | # 执行命令 49 | eval $FULL_CMD 
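The eval arguments above point the in-training evaluation at `data/test.jsonl` with `--metrics "code"`. As a purely hypothetical illustration (the authoritative format is the repo's own `data/test.jsonl`; per `utils/eval/README.md` each line carries a `prompt` field with the question and a `label` field with the answer, which for the code metric means test cases), one line of such a file could look like:

```json lines
{"prompt": "Write a Python function add(a, b) that returns the sum of two integers.", "label": "assert add(1, 2) == 3\nassert add(-1, 1) == 0"}
```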
-------------------------------------------------------------------------------- /run_example.sh: -------------------------------------------------------------------------------- 1 | 2 | DATA_PATH='' 3 | OUTPUT_PATH="" 4 | MODEL_PATH="" 5 | 6 | # task_type:[sft] pretrain正在开发 7 | # train_mode:[qlora, lora, full] 8 | # train_args_path: [sft_args,dpo_args] 9 | 10 | # deepspeed 启动 11 | deepspeed --master_port 29507 --include localhost:0,1 main_train.py\ 12 | --train_data_path "$DATA_PATH" \ 13 | --model_name_or_path "$MODEL_PATH" \ 14 | --max_len 1024 \ 15 | --num_train_epochs 1 \ 16 | --per_device_train_batch_size 8 \ 17 | --per_device_eval_batch_size 1 \ 18 | --gradient_accumulation_steps 4 \ 19 | --task_type "sft" \ 20 | --train_mode "qlora" \ 21 | --output_dir "$OUTPUT_PATH" \ 22 | --save_strategy "steps" \ 23 | --save_steps 500 \ 24 | --save_total_limit 5 \ 25 | --learning_rate 2e-4 \ 26 | --warmup_steps 10 \ 27 | --logging_steps 1 \ 28 | --lr_scheduler_type "cosine_with_min_lr" \ 29 | --gradient_checkpointing True \ 30 | --report_to "wandb" \ 31 | --deepspeed './train_args/deepspeed_config/ds_config_zero2.json' \ 32 | --bf16 True \ 33 | --auto_adapt True \ 34 | --use_eval_in_train True \ 35 | --test_datasets_path "./" \ 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | # python main_train.py --train_data_path 数据集路径 --model_name_or_path 模型路径 ......同上述传入参数 44 | -------------------------------------------------------------------------------- /run_vlm_example.sh: -------------------------------------------------------------------------------- 1 | 2 | DATA_PATH='' 3 | OUTPUT_PATH="" 4 | MODEL_PATH="" 5 | 6 | 7 | deepspeed --master_port 29507 --include localhost:0,1 vlm_train.py\ 8 | --train_data_path "$DATA_PATH" \ 9 | --model_name_or_path "$MODEL_PATH" \ 10 | --max_seq_length 1024 \ 11 | --num_train_epochs 1 \ 12 | --per_device_train_batch_size 2 \ 13 | --gradient_accumulation_steps 2 \ 14 | --task_type "QA" \ 15 | --train_mode "lora" \ 16 | --output_dir "$OUTPUT_PATH" \ 17 | --save_strategy "steps" \ 18 | --save_steps 20 \ 19 | --save_total_limit 5 \ 20 | --learning_rate 2e-5 \ 21 | --warmup_steps 10 \ 22 | --logging_steps 1 \ 23 | --lr_scheduler_type "cosine" \ 24 | --gradient_checkpointing True \ 25 | --deepspeed './train_args/deepspeed_config/ds_config_zero2.json' \ 26 | --bf16 True \ 27 | --torch_dtype bfloat16 \ 28 | --freeze_vision True \ 29 | --freeze_projector False -------------------------------------------------------------------------------- /train_args/__init__.py: -------------------------------------------------------------------------------- 1 | from train_args.sft.base import TrainArgument as sft_TrainArgument 2 | 3 | __all__ = [ 4 | "sft_TrainArgument", 5 | ] 6 | -------------------------------------------------------------------------------- /train_args/common_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from enum import Enum 4 | 5 | 6 | class TrainMode(Enum): 7 | QLORA = 'qlora' 8 | LORA = 'lora' 9 | FULL = 'full' 10 | 11 | 12 | @dataclass 13 | class CommonArgs: 14 | """ 15 | 一些常用的自定义参数 16 | """ 17 | # Deepspeed相关参数,如出现报错可注释掉 18 | # local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改,如出现报错可注释掉或添加"}) 19 | 20 | train_args_path: str = 'sft_args' # 训练参数 默认sft_args 21 | max_len: int = field(default=1024, metadata={"help": "最大输入长度"}) 22 | train_data_path: Optional[str] = field(default='./', metadata={"help": "训练集路径"}) 23 | 
    model_name_or_path: str = field(default='./', metadata={"help": "下载的所需模型路径"})
24 | 
25 |     # 训练方法相关选择与配置
26 |     task_type: str = field(default="sft", metadata={"help": "预训练任务:目前支持sft"})
27 |     train_mode: TrainMode = field(default='lora', metadata={"help": "选择采用的训练方式:[qlora, lora, full]"})
28 |     use_dora: bool = field(default=False,
29 |                            metadata={"help": "在train_mode==lora时可以使用。是否使用Dora(一个基于lora的变体)"})
30 | 
31 |     # lora相关配置
32 |     lora_rank: Optional[int] = field(default=64, metadata={"help": "lora rank"})
33 |     lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"})
34 |     lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"})
35 | 
36 |     # 是否自动适配template
37 |     auto_adapt: bool = field(default=True, metadata={"help": "选择是否自动适配template,若为False,则直接使用输入数据"})
38 |     # 是否训练中评测
39 |     use_eval_in_train: bool = False
40 |     test_datasets_path: Optional[str] = None
41 | 
--------------------------------------------------------------------------------
/train_args/deepspeed_config/ds_config_zero0.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "gradient_accumulation_steps": "auto",
 3 |     "gradient_clipping": "auto",
 4 |     "steps_per_print": 20,
 5 |     "train_batch_size": "auto",
 6 |     "train_micro_batch_size_per_gpu": "auto",
 7 |     "wall_clock_breakdown": false,
 8 | 
 9 |     "optimizer": {
10 |         "type": "AdamW",
11 |         "params": {
12 |             "lr": "auto",
13 |             "betas": "auto",
14 |             "eps": "auto",
15 |             "weight_decay": "auto"
16 |         }
17 |     },
18 |     "scheduler": {
19 |         "type": "WarmupLR",
20 |         "params": {
21 |             "warmup_min_lr": "auto",
22 |             "warmup_max_lr": "auto",
23 |             "warmup_num_steps": "auto"
24 |         }
25 |     },
26 | 
27 |     "zero_optimization": {
28 |         "stage": 0
29 |     },
30 | 
31 |     "bf16": {
32 |         "enabled": "auto"
33 |     }
34 | }
--------------------------------------------------------------------------------
/train_args/deepspeed_config/ds_config_zero2.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "zero_optimization": {
 3 |         "stage": 2,
 4 |         "offload_optimizer": {
 5 |             "device": "cpu",
 6 |             "pin_memory": true
 7 |         },
 8 |         "allgather_partitions": true,
 9 |         "allgather_bucket_size": 5e8,
10 |         "overlap_comm": true,
11 |         "reduce_scatter": true,
12 |         "reduce_bucket_size": 5e8,
13 |         "contiguous_gradients": true,
14 |         "round_robin_gradients": true
15 |     },
16 |     "bf16": {
17 |         "enabled": "auto"
18 |     },
19 |     "optimizer": {
20 |         "type": "AdamW",
21 |         "params": {
22 |             "lr": "auto",
23 |             "betas": "auto",
24 |             "eps": "auto",
25 |             "weight_decay": "auto"
26 |         }
27 |     },
28 |     "scheduler": {
29 |         "type": "WarmupLR",
30 |         "params": {
31 |             "warmup_min_lr": "auto",
32 |             "warmup_max_lr": "auto",
33 |             "warmup_num_steps": "auto"
34 |         }
35 |     },
36 |     "gradient_accumulation_steps": "auto",
37 |     "gradient_clipping": "auto",
38 |     "steps_per_print": 20,
39 |     "train_batch_size": "auto",
40 |     "train_micro_batch_size_per_gpu": "auto",
41 |     "wall_clock_breakdown": false
42 | }
--------------------------------------------------------------------------------
/train_args/deepspeed_config/ds_config_zero3.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "bf16": {
 3 |         "enabled": "auto"
 4 |     },
 5 |     "optimizer": {
 6 |         "type": "AdamW",
 7 |         "params": {
 8 |             "lr": "auto",
 9 |             "betas": "auto",
10 |             "eps": "auto",
11 |             "weight_decay": "auto"
12 |         }
13 |     },
14 | 
15 |     "zero_optimization": {
16 |         "stage": 3,
17 |         "offload_optimizer": {
18 |             "device": "cpu",
19 |             "pin_memory": true
20 |         },
21 |         "offload_param": {
22 | 
"device": "cpu", 23 | "pin_memory": true 24 | }, 25 | "overlap_comm": true, 26 | "contiguous_gradients": true, 27 | "sub_group_size": 1e9, 28 | "reduce_bucket_size": "auto", 29 | "stage3_prefetch_bucket_size": "auto", 30 | "stage3_param_persistence_threshold": "auto", 31 | "stage3_max_live_parameters": 1e9, 32 | "stage3_max_reuse_distance": 1e9, 33 | "stage3_gather_16bit_weights_on_model_save": true 34 | }, 35 | 36 | "gradient_accumulation_steps": "auto", 37 | "gradient_clipping": "auto", 38 | "steps_per_print": 20, 39 | "train_batch_size": "auto", 40 | "train_micro_batch_size_per_gpu": "auto", 41 | "wall_clock_breakdown": false 42 | } -------------------------------------------------------------------------------- /train_args/dpo/README.md: -------------------------------------------------------------------------------- 1 | # 更新 2 | trl新版本已经更改dpo实现,故此处已不适用,新的DPO可见[RLHF](./rlhf/README.md)。 3 | 4 | 5 | 6 | # 关于DPO训练 7 | 目前分为两个模式,分别是multi_dpo和single_dpo。**推荐一般使用multi_dpo**。 8 | 9 | DPO训练方式均支持框架中的deepspeed或者python启动模式,相应的lora、qlora也支持。 10 | 11 | 区别在于两种方式的数据组织形式,前者是使用DPOTrainer自动进行数据处理,且是多轮对话形式,参照格式也可将其改为单轮对话,故前者是单轮与多轮通用的。 12 | 13 | 后者是自己从零构建的数据组织形式,理论上按照DPOTrainer相同形式,只实现了单轮。这样的**目的是为了更好地理解DPO的过程以及方便一些魔改操作**,权当学习使用。 14 | 15 | 🤓**注意:** 对于DPO数据,可见```data/dpo_multi_data.jsonl```示例数据。数据是huggingface的hh-rlhf-helpful-base-trl-style格式数据,其中prompt是一句话,而chosen和 16 | rejected则是包含prompt的完整对话。故如构建自己的数据集时,无论多轮和单轮,都应在chosen和rejected中加入prompt,单轮相当于取第一句当prompt, 17 | 多轮相当于取最后一句之前的所有当prompt(其实还可以取每一轮的user当prompt,后面有时间可能会实现)。 18 | 19 | 对于自己构建的single_dpo数据格式,示例为: 20 | ```json lines 21 | {"prompt":"哈喽啊","chosen":"你好", "reject": "不好"} 22 | ``` 23 | 24 | ## 代码位置 25 | 26 | 自己构建的single_dpo数据格式代码在```utils/data_process.py```文件中的```DpoDataset```类。 27 | 28 | 参照官方构建的数据格式在```mian_train.py```中的```load_dpo_dataset```函数里。 29 | 30 | 31 | ## 技术文章 32 | - [DPO训练QWEN2及魔改DPO实现](https://zhuanlan.zhihu.com/p/702569978) 33 | 34 | 35 | ## DPO quick start 36 | 37 | **1、支持命令行传参启动,启动示例可见```LLM-Dojo/run_example.sh```** 38 | 39 | **2、也支持参数文件直接修改默认值,具体如下:** 40 | 41 | ### Step1 配置args.py 42 | 常规的参数在utils下的args.py,基本默认设置即可,你只需要改一下模型路径、输出路径、task_type、template_name、train_data_path、train_args_path、train_mode等。 43 | 44 | 使用multi_dpo时args.py中的max_len和max_prompt_length参数是没用的,需要在后面的dpo_config.py中设置 45 | 46 | 其中: 47 | > train_args_path:为Step2中需要配置的train_args路径 48 | 49 | ### Step2 配置train_args文件夹下对应文件 50 | 相关训练参数在train_args文件夹下对应的文件中。一般就是用```dpo/dpo_config.py```即可 51 | 52 | 均是采用dataclass格式配置参数,直接在default中修改即可,即不需要直接命令行传输参数了。 53 | 54 | 在这里修改max_len和max_prompt_length参数,其他需要设置的是是否选择deepspeed模式训练等参数 55 | 56 | ### Step3 开始训练 57 | 58 | 开始训练就和之前SFT一样了 59 | 60 | 😶Python命令单卡启动: 61 | 62 | 设置好相关配置后即可运行main_train.py进行训练 63 | ```bash 64 | python main_train.py 65 | ``` 66 | 67 | 🙃Deepspeed单卡或多卡启动: 68 | 69 | 使用Deepspeed训练时前两步与常规相同,但需要额外配置ds_config文件,项目中已给出常用的配置示例,位于```train_args/deepspeed_config/```路径下, 70 | 更详细的Deepspeed原理及解释可以看文章:[Deepspeed配置及使用讲解](https://zhuanlan.zhihu.com/p/698631348) 71 | 72 | 运行以下命令启动: 73 | ```bash 74 | deepspeed --include localhost:6,7 main_train.py 75 | ``` 76 | 其中```include localhost```参数用于选择训练的GPU,可选单卡也可选多卡。 77 | 78 | -------------------------------------------------------------------------------- /train_args/sft/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from transformers import TrainingArguments 4 | 5 | 6 | @dataclass 7 | class TrainArgument(TrainingArguments): 8 | """ 9 | 训练参数, 直接在这里修改即可 10 | """ 11 | output_dir: 
str = field(default='', metadata={"help": "模型训练完成后的保存路径"}) 12 | num_train_epochs: int = 1, 13 | 14 | per_device_train_batch_size: int = 2 15 | gradient_checkpointing: bool = True 16 | gradient_accumulation_steps: int = 16, 17 | 18 | learning_rate: float = 2e-4 19 | logging_steps: int = 10 20 | save_steps: int = 500 21 | save_strategy: str = "steps" 22 | save_total_limit: int = 2 23 | lr_scheduler_type: str = "constant_with_warmup", 24 | warmup_steps: int = 10 25 | optim: str = 'adamw_torch' 26 | report_to: str = 'tensorboard' 27 | remove_unused_columns: bool = False 28 | bf16: bool = True 29 | fp16: bool = False 30 | 31 | # Deepspeed训练相关参数,不使用时设置为default=None 32 | deepspeed: Optional[str] = field(default=None, metadata={"help": "启用Deepspeed时需要的config文件"}) 33 | -------------------------------------------------------------------------------- /train_args/vlm_config/script_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ScriptArgs: 7 | """ 8 | 自定义参数 9 | """ 10 | # Deepspeed相关参数,如出现报错可注释掉 11 | # local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改,如出现报错可注释掉或添加"}) 12 | task_type: str = field(default='QA', metadata={"help": "任务类型,目前可选:[QA]"}) 13 | '''多模态任务类型''' 14 | 15 | train_data_path: Optional[str] = field(default='./', metadata={"help": "训练集路径"}) 16 | '''训练集路径''' 17 | 18 | train_mode: str = field(default='lora', metadata={"help": "选择对llm采用的训练方式:[qlora, lora, full]"}) 19 | '''选择对llm采用的训练方式''' 20 | 21 | freeze_vision: bool = True 22 | '''训练是否冻结视觉层''' 23 | 24 | freeze_projector: bool = False 25 | '''训练是否冻结转接层''' 26 | 27 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mst272/LLM-Dojo/6397eff480feca95f3b6016495b30d1ae308c9dc/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_collator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | import torch 3 | from loguru import logger 4 | from PIL import Image 5 | from utils.vlm_template import LlavaTemplateProcessor, Qwen2VLTemplateProcessor 6 | 7 | 8 | class SftDataCollator: 9 | def __init__(self, tokenizer, max_length): 10 | self.tokenizer = tokenizer 11 | self.max_length = max_length 12 | self.pad_token_id = tokenizer.pad_token_id 13 | 14 | def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: 15 | # 找出最大长度 16 | length = [len(x['input_ids']) for x in batch if x['input_ids'] is not None] 17 | # 每个batch中的最大长度 18 | max_batch_length = min(self.max_length, max(length)) 19 | 20 | input_ids_batch, attention_mask_batch, target_mask_batch = [], [], [] 21 | 22 | for x in batch: 23 | input_ids = x['input_ids'] 24 | attention_mask = x['attention_mask'] 25 | target_mask = x['target_mask'] 26 | if input_ids is None: 27 | logger.info('some input_ids is None,and now continue') 28 | continue 29 | padding_len = max_batch_length - len(input_ids) 30 | # 开始padding 31 | input_ids += [self.pad_token_id] * padding_len 32 | attention_mask += [0] * padding_len 33 | target_mask += [0] * padding_len 34 | # 开始截断 35 | input_ids = input_ids[:self.max_length] 36 | attention_mask = attention_mask[:self.max_length] 37 | target_mask = target_mask[:self.max_length] 38 | # 将本批次全部加入列表 39 | 
input_ids_batch.append(input_ids) 40 | attention_mask_batch.append(attention_mask) 41 | target_mask_batch.append(target_mask) 42 | 43 | # 将list转换为tensor,得到最终的的模型输入 44 | input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long) 45 | attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long) 46 | target_mask_batch = torch.tensor(target_mask_batch, dtype=torch.long) 47 | 48 | # 计算损失时忽略 49 | labels = torch.where(target_mask_batch == 1, input_ids_batch, -100) 50 | inputs = { 51 | 'input_ids': input_ids_batch, 52 | 'attention_mask': attention_mask_batch, 53 | 'labels': labels 54 | } 55 | return inputs 56 | 57 | 58 | # def llava_template_process(questions: List[str], answers: List[str]) -> List[Dict]: 59 | # converted_data = [] 60 | # for question, answer in zip(questions, answers): 61 | # user_content = [{'index': None, 'text': question, 'type': 'text'}] 62 | # assistant_content = [{'index': None, 'text': answer, 'type': 'text'}] 63 | # 64 | # converted_data.append({'content': user_content, 'role': 'user'}) 65 | # converted_data.append({'content': assistant_content, 'role': 'assistant'}) 66 | # image_dict = {'index': 0, 'text': None, 'type': 'image'} 67 | # converted_data[0]['content'].append(image_dict) 68 | # return converted_data 69 | 70 | processor_class_map = { 71 | 'LlavaProcessor': LlavaTemplateProcessor(), 72 | 'Qwen2VLProcessor': Qwen2VLTemplateProcessor(), 73 | # 可继续添加更多的处理器类 74 | } 75 | 76 | 77 | class VlmQaDataCollator: 78 | def __init__(self, processor): 79 | self.processor = processor 80 | processor_class = processor.to_dict()['processor_class'] 81 | if processor_class not in processor_class_map: 82 | raise ValueError(f"Unknown processor class: {processor_class}") 83 | self.template_process = processor_class_map[processor_class] 84 | 85 | def __call__(self, examples): 86 | texts = [] 87 | images = [] 88 | for example in examples: 89 | standard_example = self.template_process.process(example[0], example[1]) 90 | text = self.processor.apply_chat_template( 91 | standard_example, tokenize=False, add_generation_prompt=False 92 | ) 93 | texts.append(text) 94 | raw_image = Image.open(example[2]) 95 | images.append(raw_image) 96 | 97 | batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True) 98 | 99 | # 这里并没有mask question, 后续可能的扩充是设置mask question的模式。 100 | labels = batch["input_ids"].clone() 101 | labels[labels == self.processor.tokenizer.pad_token_id] = -100 102 | # image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token) 103 | # labels[labels == image_token_id] = -100 104 | batch['labels'] = labels 105 | 106 | return batch 107 | 108 | 109 | class VlmCaptionImageDataCollator: 110 | pass 111 | -------------------------------------------------------------------------------- /utils/eval/README.md: -------------------------------------------------------------------------------- 1 | # Train with prediction eval 2 | 3 | 详细的使用说明及函数文档待更新,目前只是一个dirty的实现。 4 | 5 | 6 | 7 | ## test data 8 | 9 | 评测数据格式为jsonl,以代码test为例,具体可见data/test.jsonl。 10 | 包含两个字段: 11 | - prompt: test 的问题 12 | - label:答案,代码来说即是测试用例 13 | 14 | ## Quick start 15 | 16 | 使用vllm进行生成,其余卡进行训练。 17 | 18 | 启动脚本位置:utils/eval/vllm/run_serve.sh 19 | 20 | 1、启动vllm_serve,例如使用2卡 21 | 22 | ```shell 23 | bash run_serve.sh 24 | ``` 25 | 26 | 2、开启训练 27 | 28 | 启动脚本位置:run_eval_test.sh 29 | 30 | ```shell 31 | bash run_eval_test.sh --eval 32 | ``` 33 | 运行脚本,注意要跟--eval,一些参数配置可参考run_eval_test.sh文件。 34 | 35 | 36 | 37 | 38 | 39 | ## Tip 40 | 41 | wandb出问题可以尝试: 42 | pip 
install wandb==0.12.18 43 | 44 | 可能出现的问题: 45 | 46 | 1、直接deepspeed --master_port 29508 --include localhost:2,3,4,5,6,7 main_train.py保存checkpoint时有问题,所以建议 47 | accelerate launch --config_file rlhf/ds_config/ds_zero3.yaml main_train.py 48 | 49 | 50 | 2、训练时出现训推不一致问题,训练中评测跟保存后结果对不上,最后找到原因是因为没有enable_prefix_caching=False。 51 | 不过尝试之后仍然会有偏差,但是影响不大,曲线的轨迹是可以反映模型在测试集上的效果的。 52 | 53 | 待验证:可能是由于dropout层的原因,后续计划禁止dropout尝试 54 | 55 | 56 | 参考:https://github.com/huggingface/open-r1/issues/433 57 | 58 | 59 | 60 | ## Reference 61 | 最后,代码借鉴了trl项目,感谢trl为开源做出的贡献。 -------------------------------------------------------------------------------- /utils/eval/callback.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import os 3 | import shutil 4 | from contextlib import nullcontext 5 | from dataclasses import dataclass 6 | import time 7 | from typing import List, Optional 8 | from utils.eval.configs import GenerationConfig 9 | import deepspeed 10 | from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model 11 | import pandas as pd 12 | from trl.import_utils import is_deepspeed_available, is_rich_available, is_vllm_available 13 | from utils.eval.eval_utils import _generate_completions, reason_post_process 14 | import wandb 15 | from datasets import Dataset 16 | from transformers.trainer_callback import ExportableState 17 | from transformers import ( 18 | Trainer, 19 | TrainerCallback 20 | ) 21 | from utils.eval.eval_metric import BaseMetric 22 | from utils.eval.vllm.vllm_client import VLLMClient 23 | 24 | 25 | @dataclass 26 | class CheckpointInfo: 27 | step: int 28 | metric_value: float 29 | path: str 30 | 31 | def __lt__(self, other): 32 | return self.metric_value < other.metric_value 33 | 34 | 35 | class EvaluationCallback(TrainerCallback): 36 | r""" 37 | A [`~transformers.TrainerCallback`] that logs completions and eval metrics to Weights & Biases and/or Comet. 38 | 39 | Usage: 40 | ```python 41 | trainer = Trainer(...) 42 | evaluation_callback = EvaluationCallback(trainer=trainer) 43 | trainer.add_callback(evaluation_callback) 44 | ``` 45 | 46 | Args: 47 | trainer (`Trainer`): 48 | Trainer to which the callback will be attached. 49 | generation_config (`GenerationConfig`, *optional*): 50 | The generation config to use for generating completions. 51 | num_samples (`int` or `None`, *optional*): 52 | The number of prompts to eval for. If not provided, defaults to the number of examples in the evaluation dataset. 53 | freq (`int` or `None`, *optional*): 54 | The frequency at which step to generate and compute metrics. 
55 | metric 56 | """ 57 | 58 | def __init__( 59 | self, 60 | trainer: Trainer, 61 | test_datasets: Dataset, 62 | generation_config: GenerationConfig, 63 | num_samples: Optional[int] = None, 64 | freq: Optional[int] = None, 65 | metrics: Optional[BaseMetric] = None, 66 | max_checkpoints: int = 3, 67 | per_device_test_batch_size: int = 1, 68 | higher_better: bool = True, 69 | start_update_best_checkpoints: int = 0, 70 | gather_deepspeed3_params: bool = True, 71 | use_vllm: bool = True, 72 | vllm_server_host: str = "0.0.0.0", 73 | vllm_server_port: int = 8080, 74 | vllm_server_timeout: float = 120.0, 75 | prompts_apply_chat: bool = True 76 | 77 | ): 78 | self.gen_config = generation_config 79 | self.trainer = trainer 80 | self.best_checkpoints: List[CheckpointInfo] = [] 81 | self.max_checkpoints = max_checkpoints # 最大保存数量 82 | self.higher_better = higher_better # 指标是否越大越好 83 | self._last_logged_step = -1 84 | self.batch_size = per_device_test_batch_size 85 | self.table = [] 86 | self.freq = freq 87 | self.metric = metrics 88 | self.start_update_best_checkpoints = start_update_best_checkpoints 89 | self.gather_deepspeed3_params = gather_deepspeed3_params 90 | self.prompts_apply_chat = prompts_apply_chat 91 | self.use_vllm = use_vllm 92 | 93 | if self.metric is None: 94 | raise ValueError("You must provide a metric[BaseMetric]") 95 | 96 | if num_samples is not None: 97 | self.sample_dataset = test_datasets.select(range(num_samples)) 98 | else: 99 | self.sample_dataset = test_datasets 100 | 101 | # 配置vllm client 102 | if use_vllm: 103 | if not is_vllm_available(): 104 | raise ImportError( 105 | "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " 106 | "`pip install vllm` to use it." 107 | ) 108 | if self.trainer.accelerator.is_main_process: 109 | self.vllm_client = VLLMClient(host=vllm_server_host, server_port=vllm_server_port, 110 | connection_timeout=vllm_server_timeout) 111 | self._last_loaded_step = 0 112 | self.trainer.accelerator.wait_for_everyone() 113 | 114 | def _move_model_to_vllm(self): 115 | # For DeepSpeed ZeRO-3, we need to gather all parameters before operations 116 | deepspeed_plugin = self.trainer.accelerator.state.deepspeed_plugin 117 | zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 118 | gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext 119 | 120 | if is_peft_model(self.trainer.model): 121 | # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging 122 | # adapters in a sharded manner is not supported. 
123 | with gather_if_zero3(list(self.trainer.model.parameters())): 124 | self.trainer.model.merge_adapter() 125 | 126 | # Update vLLM weights while parameters are gathered 127 | for name, param in self.trainer.model.named_parameters(): 128 | # When using PEFT, we need to recover the original parameter name and discard some parameters 129 | name = name.removeprefix("base_model.model.").replace(".base_layer", "") 130 | if self.trainer.model.prefix in name: 131 | continue 132 | # When module to save, remove its prefix and discard the original module 133 | if "original_module" in name: 134 | continue 135 | name = name.replace("modules_to_save.default.", "") 136 | 137 | if self.trainer.accelerator.is_main_process: 138 | self.vllm_client.update_named_param(name, param.data) 139 | 140 | # Unmerge adapters while parameters are still gathered 141 | self.trainer.model.unmerge_adapter() 142 | # Parameters will automatically be repartitioned when exiting the context 143 | else: 144 | # For non-PEFT models, simply gather and update each parameter individually. 145 | for name, param in self.trainer.model.named_parameters(): 146 | with gather_if_zero3([param]): 147 | if self.trainer.accelerator.is_main_process: 148 | self.vllm_client.update_named_param(name, param.data) 149 | # Reset cache on main process 150 | if self.trainer.accelerator.is_main_process: 151 | self.vllm_client.reset_prefix_cache() 152 | 153 | def samples_generate_vllm(self, steps): 154 | """ 155 | prompts: 156 | labels: 157 | """ 158 | # First, have main process load weights if needed 159 | if steps != self._last_loaded_step: 160 | self._move_model_to_vllm() 161 | self._last_loaded_step = steps 162 | # Generate completions using vLLM: gather all prompts and use them in a single call in the main process 163 | # all_prompts_text = gather_object(prompts) 164 | if self.trainer.accelerator.is_main_process: 165 | # data process 166 | prompts = self.metric.get_prompts(self.sample_dataset, self.trainer.processing_class, 167 | self.prompts_apply_chat) 168 | labels = self.metric.get_labels(self.sample_dataset) 169 | 170 | # todo: with profiling_context(self, "vLLM.generate"): 上下文时间处理 171 | start_time = time.time() 172 | completion_ids = self.vllm_client.generate( 173 | prompts=prompts, 174 | n=self.gen_config.num_generation, 175 | repetition_penalty=self.gen_config.repetition_penalty, 176 | temperature=self.gen_config.temperature, 177 | top_p=self.gen_config.top_p, 178 | top_k=self.gen_config.top_k, 179 | min_p=self.gen_config.min_p, 180 | max_tokens=self.gen_config.max_new_tokens 181 | ) 182 | end_time = time.time() # 记录 _generate_completions 结束时间 183 | generation_time = end_time - start_time # 计算生成耗时 184 | print(f"\nProcess main: Generation time: {generation_time:.4f} seconds") 185 | 186 | tokenizer = self.trainer.processing_class 187 | # tokenizer.padding_side = "left" 188 | completions = tokenizer.batch_decode(completion_ids) # --> List[str] 189 | generations = self.metric.extract_generation(completions) 190 | score = self.metric.compute(references=labels, predictions=generations) 191 | else: 192 | score = None 193 | return score 194 | 195 | def samples_generate_split_between_processes(self, steps): 196 | """ 197 | if model very large, maybe OOM. 
198 | """ 199 | labels = [example['message'][1]['content'] for example in self.sample_dataset] 200 | tokenizer = self.trainer.processing_class 201 | tokenizer.padding_side = "left" 202 | accelerator = self.trainer.accelerator 203 | model = self.trainer.model_wrapped 204 | start_time = time.time() 205 | with accelerator.split_between_processes(self.sample_dataset['message']) as prompts_split: 206 | prompts = [] 207 | for lis in prompts_split: 208 | prompts.append('补全下面代码,将最终题目和答案返回在代码框中\n' + lis[0]['content']) 209 | completions = _generate_completions( 210 | prompts, 211 | model=model, 212 | tokenizer=tokenizer, 213 | accelerator=accelerator, 214 | generation_config=self.gen_config, 215 | batch_size=self.batch_size, 216 | gather_deepspeed3_params=self.gather_deepspeed3_params 217 | ) 218 | completions = gather_object(completions) 219 | prompts = gather_object(prompts) 220 | end_time = time.time() # 记录 _generate_completions 结束时间 221 | generation_time = end_time - start_time # 计算生成耗时 222 | 223 | generations = [[reason_post_process(c, i)] for i, c in enumerate(completions)] 224 | print(f"Process {accelerator.process_index}: Generation time: {generation_time:.4f} seconds") 225 | 226 | if len(self.sample_dataset) < accelerator.num_processes: 227 | generations = generations[:len(labels)] 228 | # 处理输出表格数据 229 | if self.trainer.accelerator.is_main_process: 230 | global_step = [str(steps)] * len(prompts) 231 | config_keys = list(self.gen_config.to_dict().keys()) 232 | config_values = list(self.gen_config.to_dict().values()) 233 | data = [[global_step[i], prompts[i], completions[i]] + config_values for i in range(len(prompts))] 234 | self.table.extend(data) 235 | table = pd.DataFrame(columns=["step", "prompt", "completion"] + config_keys, data=self.table) 236 | wandb.log({"completions": table}) 237 | 238 | score = self.metric.compute(references=labels, predictions=generations) 239 | return score 240 | 241 | def save_best_metric_model(self, args, state): 242 | # Save model checkpoint 243 | print('开始保存checkpoint') 244 | checkpoint_folder = f"checkpoint-{state.global_step}" 245 | output_dir = os.path.join(args.output_dir, 'best_model', checkpoint_folder) 246 | 247 | self.trainer.save_model(output_dir) 248 | 249 | if not args.save_only_model: 250 | # Save optimizer and scheduler 251 | self.trainer._save_optimizer_and_scheduler(output_dir) 252 | self.trainer._save_scaler(output_dir) 253 | # Save RNG state 254 | self.trainer._save_rng_state(output_dir) 255 | if args.should_save: 256 | # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently 257 | for cb in [ 258 | cb for cb in self.trainer.callback_handler.callbacks + [self.trainer.control] if 259 | isinstance(cb, ExportableState) 260 | ]: 261 | cb_name = cb.__class__.__name__ 262 | cb_state = cb.state() 263 | if isinstance(state.stateful_callbacks[cb_name], list): 264 | state.stateful_callbacks[cb_name].append(cb_state) 265 | else: 266 | state.stateful_callbacks[cb_name] = cb_state 267 | state.save_to_json(os.path.join(output_dir, 'trainer_state.json')) 268 | 269 | return output_dir 270 | 271 | def update_best_checkpoints1(self, args, state, custom_score): 272 | """更新最佳checkpoint列表""" 273 | if state.global_step < self.start_update_best_checkpoints: 274 | return 275 | print('更新最佳checkpoint列表') 276 | # 对于越小越好的指标(如loss),转换为负数以便统一处理 277 | metric_value = custom_score if self.higher_better else -custom_score 278 | 279 | # 如果还没有达到最大数量,或者当前指标比最差的更好 280 | if (len(self.best_checkpoints) < self.max_checkpoints or 281 | metric_value > 
self.best_checkpoints[0].metric_value): 282 | 283 | # 保存新的checkpoint 284 | checkpoint_path = self.save_best_metric_model(args, state) 285 | 286 | # 创建新的CheckpointInfo对象 287 | checkpoint_info = CheckpointInfo( 288 | step=state.global_step, 289 | metric_value=metric_value, 290 | path=checkpoint_path 291 | ) 292 | 293 | # 更新最佳checkpoint列表 294 | heapq.heappush(self.best_checkpoints, checkpoint_info) 295 | 296 | # 如果超过最大数量,删除最差的checkpoint 297 | if len(self.best_checkpoints) > self.max_checkpoints: 298 | worst_checkpoint = heapq.heappop(self.best_checkpoints) 299 | print(f"Deleting older checkpoint [{worst_checkpoint.path}] due to args.save_total_limit") 300 | shutil.rmtree(worst_checkpoint.path, ignore_errors=True) 301 | 302 | def update_best_checkpoints(self, args, state, custom_score): 303 | # 从主进程广播 custom_score,确保所有进程有相同数据 304 | custom_score = broadcast_object_list([custom_score], from_process=0)[0] 305 | 306 | if state.global_step < self.start_update_best_checkpoints: 307 | return 308 | 309 | metric_value = custom_score if self.higher_better else -custom_score 310 | 311 | if self.trainer.accelerator.is_main_process: 312 | # 仅主进程决定是否保存 313 | if len(self.best_checkpoints) < self.max_checkpoints or (metric_value > self.best_checkpoints[ 314 | 0].metric_value and state.global_step % args.save_steps != 0): 315 | save_flag = True 316 | else: 317 | save_flag = False 318 | else: 319 | save_flag = False 320 | 321 | # 广播保存决定到所有进程 322 | save_flag = broadcast_object_list([save_flag], from_process=0)[0] 323 | 324 | if save_flag: 325 | checkpoint_path = self.save_best_metric_model(args, state) 326 | # if self.trainer.accelerator.is_main_process: 327 | checkpoint_info = CheckpointInfo(step=state.global_step, metric_value=metric_value, 328 | path=checkpoint_path) 329 | heapq.heappush(self.best_checkpoints, checkpoint_info) 330 | if len(self.best_checkpoints) > self.max_checkpoints: 331 | worst_checkpoint = heapq.heappop(self.best_checkpoints) 332 | print(f"Deleting older checkpoint [{worst_checkpoint.path}] due to args.save_total_limit") 333 | shutil.rmtree(worst_checkpoint.path, ignore_errors=True) 334 | 335 | def on_step_end(self, args, state, control, **kwargs): 336 | # Only log once per step (this method may be called multiple times) 337 | if state.global_step == self._last_logged_step: 338 | return 339 | 340 | # Only log every `freq` steps 341 | freq = self.freq 342 | if state.global_step % freq != 0: 343 | return 344 | 345 | # todo: 改成字典,返回多个指标可视化,选其中一个或者混合作为保存指标? use_vllm=False的逻辑? 
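        # Generation and scoring run only on the main process (other ranks receive None);
        # update_best_checkpoints then broadcasts the score and the save decision from rank 0
        # so that every process takes the same checkpointing action.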
346 | custom_score = self.samples_generate_vllm(state.global_step) if self.use_vllm else None 347 | self.trainer.log({"custom_score": custom_score, "step": state.global_step}) 348 | 349 | self.update_best_checkpoints(args, state, custom_score) 350 | # Save the last logged step, so we don't log the same completions multiple times 351 | self._last_logged_step = state.global_step 352 | -------------------------------------------------------------------------------- /utils/eval/configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union, List, Dict 3 | from transformers import TrainingArguments 4 | 5 | 6 | @dataclass 7 | class GenerationConfig: 8 | """ 9 | generation config, Vllm or model.generate 10 | 11 | Parameters: 12 | max_new_tokens (`int`, *optional*, defaults to `16`): 13 | Maximum number of tokens to generate for each prompt 14 | num_generation (`int`, *optional*, defaults to `1`): 15 | Number of completions to generate for each prompt. 16 | temperature: 17 | generate temperature 18 | top_p: 19 | top_k: 20 | do_sample: 21 | min_p 22 | 23 | """ 24 | num_generation: int = 1 25 | repetition_penalty: float = 1.0 26 | temperature: float = 1.0 27 | top_p: float = 1.0 28 | top_k: int = -1 29 | min_p: float = 0.0 30 | max_new_tokens: int = 1024 31 | do_sample: bool = False 32 | 33 | # unwrap_model_for_generation 34 | gather_deepspeed3_params: bool = True # if OOM, False it. 35 | 36 | 37 | @dataclass 38 | class EvaluationConfig: 39 | """ 40 | Common test config 41 | 42 | Parameters: 43 | num_samples: 在测试集中随机选数量 44 | freq: 45 | metrics: 46 | """ 47 | # 基础评估设置 48 | num_samples: Optional[int] = None 49 | freq: int = 5 50 | metrics: str = 'code' # ['code', 'em'] or metrics=[{'name': 'code', 51 | # 'weight': 0.7},{'name': 'em', 'weight': 0.3}] 52 | 53 | higher_better: bool = True 54 | prompts_apply_chat: bool = False 55 | 56 | use_vllm: bool = True # whether to use vllm to generate, if false, use unwrap_model_for_generation 57 | vllm_server_host: str = "0.0.0.0" 58 | vllm_server_port: int = 8080 59 | vllm_server_timeout: float = 120.0 60 | 61 | per_device_test_batch_size: int = 1 # Only use when use_vllm is False 62 | 63 | # Checkpoint管理 64 | save_best_checkpoints: bool = True 65 | start_update_best_checkpoints: int = 20 66 | max_checkpoints: int = 3 67 | 68 | 69 | if __name__ == "__main__": 70 | gen = GenerationConfig() 71 | print(gen.stop_strings) 72 | -------------------------------------------------------------------------------- /utils/eval/eval_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import List, Dict, Union, Type, Optional 4 | 5 | import psutil 6 | 7 | import evaluate 8 | from abc import ABC, abstractmethod 9 | import gc 10 | from datasets import Dataset 11 | 12 | from utils.eval.configs import EvaluationConfig 13 | from contextlib import contextmanager 14 | 15 | 16 | # 添加新的评估方式,只需要在此文件中创建新的评估指标类 17 | 18 | class BaseMetric(ABC): 19 | @abstractmethod 20 | def compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]: 21 | pass 22 | 23 | @abstractmethod 24 | def extract_generation(self, completions: List[str]): 25 | pass 26 | 27 | def get_prompts(self, test_datasets: Dataset, tokenizer, prompts_apply_chat) -> List[str]: 28 | """ 29 | The default get prompts(questions) from the test data. 
30 | 31 | Args: 32 | test_datasets: dataset must include a `"prompt"` column containing the prompts for generating completions. 33 | tokenizer: 34 | prompts_apply_chat: : 35 | Returns: 36 | prompts, ['hellow','why'] 37 | """ 38 | if prompts_apply_chat: 39 | pass 40 | else: 41 | prompts = [test_dataset['prompt'] for test_dataset in test_datasets] 42 | return prompts 43 | 44 | @staticmethod 45 | def get_labels(test_datasets: Dataset): 46 | labels = [test_dataset['label'] for test_dataset in test_datasets] 47 | return labels 48 | 49 | 50 | class CodeEvalMetric(BaseMetric): 51 | def __init__(self, metric_path: str = './utils/eval/metrics/code_eval'): 52 | # self.metric = evaluate.load(metric_path) 53 | self.metric_path = metric_path 54 | self.memory_threshold = 0.8 # 80% memory usage threshold 55 | 56 | def _check_memory_usage(self) -> Optional[float]: 57 | """Monitor memory usage""" 58 | process = psutil.Process(os.getpid()) 59 | memory_info = process.memory_info() 60 | memory_percent = process.memory_percent() 61 | 62 | print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f}MB ({memory_percent:.1f}%)") 63 | 64 | return memory_percent 65 | 66 | @contextmanager 67 | def load_metric(self): 68 | """Context manager for loading and cleaning up metric with memory monitoring""" 69 | try: 70 | memory_percent = self._check_memory_usage() 71 | if memory_percent > self.memory_threshold*100: 72 | print(f"Warning: High memory usage detected: {memory_percent:.1f}%") 73 | gc.collect() 74 | 75 | metric = evaluate.load(self.metric_path) 76 | yield metric 77 | finally: 78 | del metric 79 | gc.collect() 80 | self._check_memory_usage() 81 | 82 | def compute(self, predictions: List[List[str]], references: List[str]) -> float: 83 | with self.load_metric() as metric: 84 | try: 85 | pass_at_k, results = metric.compute( 86 | predictions=predictions, 87 | references=references, 88 | k=[1] 89 | ) 90 | return float(pass_at_k["pass@1"]) 91 | except Exception as e: 92 | print(f"Error during computation: {str(e)}") 93 | raise e 94 | 95 | 96 | def extract_generation(self, completions: List[str]): 97 | """ 98 | extract generation and process data to compute[predictions] format 99 | 100 | Args: 101 | completions List[str]: 输入字符串。 102 | 103 | Returns: 104 | self.compute[predictions] format. 
For example, there is List[List[str]] 105 | """ 106 | 107 | def extract(code: str, index: int): 108 | # Look for code blocks 109 | code_pattern = r'```(?:python|go|javascript|java|bash|js|cpp|cs|php)(.*?)```' 110 | code_match = re.findall(code_pattern, code, re.DOTALL) 111 | if code_match: 112 | # If code block exists, return its content (excluding the ``` markers) 113 | return code_match[-1].strip() 114 | else: 115 | # If no code block, return the solution content directly 116 | return str(index) 117 | 118 | generations = [[extract(c, i)] for i, c in enumerate(completions)] 119 | return generations 120 | 121 | def get_prompts(self, test_datasets: Dataset, tokenizer, prompts_apply_chat) -> List[str]: 122 | if prompts_apply_chat: 123 | prompts = [test_dataset['prompt'] for test_dataset in test_datasets] 124 | messages = [[{"role": "user", "content": prompt}] for prompt in prompts] 125 | res = [tokenizer.apply_chat_template(message, tokenize=False,add_generation_prompt=True) for message in messages] 126 | return res 127 | 128 | else: 129 | prompts = [test_dataset['prompt'] for test_dataset in test_datasets] 130 | return prompts 131 | 132 | 133 | class ExactMatchMetric(BaseMetric): 134 | def compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]: 135 | pass 136 | 137 | def extract_generation(self, completions: List[str]): 138 | return completions 139 | 140 | 141 | # todo: 未完成 142 | class CompositeMetric(BaseMetric): 143 | """简化的混合指标类""" 144 | 145 | def __init__(self, metric_configs: List[Dict[str, Union[str, float]]]): 146 | self.metrics = [] 147 | self.weights = [] 148 | 149 | # 创建指标实例 150 | for config in metric_configs: 151 | metric_name = config['name'] 152 | weight = config.get('weight', 1.0) 153 | 154 | if metric_name == 'code': 155 | metric = CodeEvalMetric() 156 | elif metric_name == 'em': 157 | metric = ExactMatchMetric() 158 | else: 159 | raise ValueError(f"不支持的指标类型: {metric_name}") 160 | 161 | self.metrics.append(metric) 162 | self.weights.append(weight) 163 | 164 | def compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]: 165 | results = {} 166 | total_score = 0.0 167 | 168 | for metric, weight in zip(self.metrics, self.weights): 169 | metric_results = metric.compute(predictions, references) 170 | for key, value in metric_results.items(): 171 | weighted_value = value * weight 172 | results[f"{metric.__class__.__name__}_{key}"] = weighted_value 173 | total_score += weighted_value 174 | 175 | results['composite_score'] = total_score 176 | return results 177 | 178 | def extract_generation(self, generate, index): 179 | # 使用第一个metric的提取方法 180 | return self.metrics[0].extract_generation(generate, index) 181 | 182 | def get_prompts(self, test_datasets: Dataset, tokenizer, args) -> List[str]: 183 | # 使用第一个metric的prompt获取方法 184 | return self.metrics[0].get_prompts(test_datasets, tokenizer, args) 185 | 186 | 187 | # 定义指标类型映射字典 188 | METRIC_REGISTRY: Dict[str, Type[BaseMetric]] = { 189 | 'code': CodeEvalMetric, 190 | 'em': ExactMatchMetric 191 | } 192 | 193 | 194 | def create_metric(config: EvaluationConfig) -> BaseMetric: 195 | """根据配置创建指标 196 | 197 | Args: 198 | config: 评估配置对象 199 | 200 | Returns: 201 | BaseMetric: 创建的指标对象 202 | 203 | Raises: 204 | ValueError: 当指定了不支持的指标类型时 205 | """ 206 | # 单个指标的情况 207 | if isinstance(config.metrics, str): 208 | metric_class = METRIC_REGISTRY.get(config.metrics) 209 | if not metric_class: 210 | raise ValueError(f"不支持的指标类型: {config.metrics}") 211 | return metric_class() 212 | 213 | # 混合指标的情况 214 | return 
CompositeMetric(config.metrics) 215 | -------------------------------------------------------------------------------- /utils/eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import re 3 | from trl.models.utils import unwrap_model_for_generation 4 | from accelerate import Accelerator 5 | from transformers import ( 6 | PreTrainedModel, 7 | PreTrainedTokenizerBase, 8 | GenerationConfig 9 | ) 10 | from tqdm.auto import tqdm 11 | import torch 12 | 13 | 14 | def _generate_completions( 15 | prompts: list[str], 16 | model: PreTrainedModel, 17 | tokenizer: PreTrainedTokenizerBase, 18 | accelerator: Accelerator, 19 | generation_config: Optional[GenerationConfig], 20 | batch_size: int = 1, 21 | gather_deepspeed3_params: bool = True, 22 | ) -> list[str]: 23 | """ 24 | Generates completions for a list of pre-formatted prompts from the given model. 25 | 26 | Args: 27 | prompts (list[str]): A list of input prompts for which completions are to be generated. 28 | model (PreTrainedModel): The pre-trained model to be used for generation. 29 | tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding. 30 | accelerator (Accelerator): The accelerator to be used for model execution. 31 | generation_config (GenerationConfig): Configuration for text generation. 32 | batch_size (int, optional): The number of prompts to process in each batch. Default is 1. 33 | gather_deepspeed3_params: bool = True: if OOM, False it. 34 | 35 | Returns: 36 | list[str]: A list of generated text completions corresponding to the input prompts. 37 | """ 38 | completions = [] 39 | with unwrap_model_for_generation(model, accelerator, 40 | gather_deepspeed3_params=gather_deepspeed3_params) as unwrapped_model: 41 | # 创建分布式安全的进度条(仅在主进程显示) 42 | total_batches = len(prompts) // batch_size + (1 if len(prompts) % batch_size != 0 else 0) 43 | 44 | progress_bar = tqdm( 45 | total=total_batches, 46 | desc="Generating Completions", 47 | disable=not accelerator.is_main_process, # 非主进程禁用进度条 48 | dynamic_ncols=True # 自动适应终端宽度 49 | ) 50 | 51 | for idx in range(0, len(prompts), batch_size): 52 | batch = prompts[idx: idx + batch_size] 53 | tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) 54 | generations = unwrapped_model.generate( 55 | **tokenized_batch, 56 | generation_config=generation_config, 57 | pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, 58 | ) 59 | for prompt, generation in zip(tokenized_batch.input_ids, generations): 60 | # Remove prompt from generation 61 | generation = generation[len(prompt):] 62 | completion = tokenizer.decode(generation, skip_special_tokens=True) 63 | completions.append(completion) 64 | # 更新进度条(自动处理分布式同步) 65 | progress_bar.update(1) 66 | progress_bar.close() 67 | return completions 68 | 69 | 70 | def reason_post_process(code, index): 71 | """ 72 | 73 | Args: 74 | code (str): 输入字符串。 75 | index (int/str): 当前字符串的序号 (索引)。 76 | 77 | Returns: 78 | str 或 int: 如果找到代码块,则返回代码块字符串; 79 | 否则,返回输入的字符串序号 (index)。 80 | """ 81 | 82 | # Look for code blocks 83 | code_pattern = r'```(?:python|go|javascript|java|bash|js|cpp|cs|php)(.*?)```' 84 | code_match = re.findall(code_pattern, code, re.DOTALL) 85 | 86 | if code_match: 87 | # If code block exists, return its content (excluding the ``` markers) 88 | return code_match[-1].strip() 89 | else: 90 | # If no code block, return the solution content directly 91 | return str(index) 92 | 
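
# Minimal, illustrative sketch (not used by the training pipeline): `reason_post_process`
# returns the last fenced code block when one is present, otherwise the sample index as a
# string, i.e. a placeholder that will simply fail the downstream code_eval checks.
if __name__ == "__main__":
    with_block = "Here is the fix:\n```python\nprint('hello')\n```"
    without_block = "No code block here, just prose."
    assert reason_post_process(with_block, 0) == "print('hello')"
    assert reason_post_process(without_block, 1) == "1"
    print("reason_post_process sketch passed")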
-------------------------------------------------------------------------------- /utils/eval/train_script.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import time 3 | import shutil 4 | import itertools 5 | from dataclasses import dataclass 6 | from random import random 7 | from typing import List, Optional 8 | import re 9 | from contextlib import contextmanager 10 | from transformers.integrations import WandbCallback 11 | import torch 12 | from datasets import load_dataset 13 | import wandb 14 | from accelerate.utils import gather_object 15 | import os 16 | import deepspeed 17 | from trl.models.utils import unwrap_model_for_generation 18 | from accelerate import Accelerator 19 | from transformers import ( 20 | GenerationConfig, 21 | PreTrainedModel, 22 | PreTrainedTokenizerBase, 23 | Trainer, 24 | TrainingArguments, 25 | ) 26 | from tqdm.auto import tqdm 27 | from transformers import GenerationConfig, Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM 28 | from transformers.trainer_callback import ExportableState 29 | import evaluate 30 | from eval_metric import CodeEvalMetric 31 | from utils import MultiRoundDataProcess, SftDataCollator 32 | from callback import EvaluationCallback 33 | 34 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 35 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 36 | 37 | batch_size = 1 38 | gradient_accumulation_steps = 2 39 | num_train_epochs = 1 40 | 41 | training_args = TrainingArguments( 42 | output_dir="./output/", 43 | report_to="wandb", # this tells the Trainer to log the metrics to W&B 44 | per_device_train_batch_size=batch_size, 45 | bf16=True, 46 | learning_rate=2e-5, 47 | lr_scheduler_type="cosine", 48 | warmup_ratio=0.1, 49 | save_strategy="steps", 50 | save_steps=20, 51 | save_total_limit=2, 52 | gradient_accumulation_steps=gradient_accumulation_steps, 53 | gradient_checkpointing=True, 54 | num_train_epochs=num_train_epochs, 55 | # logging strategies 56 | logging_strategy="steps", 57 | logging_steps=2, 58 | torch_compile=False, 59 | remove_unused_columns=False, 60 | deepspeed='deepspeed_config/ds_config_zero3.json' 61 | ) 62 | 63 | if __name__ == "__main__": 64 | model_name_or_path = '/Qwen2.5-Coder-32B-Instruct' 65 | train_data_path = 'train_data/fix_bash1k.jsonl' 66 | test_data_path = 'eval_train_test/test.jsonl' 67 | 68 | max_len = 4096 69 | auto_adapt = False 70 | 71 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) 72 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) 73 | 74 | train_dataset = MultiRoundDataProcess(train_data_path, tokenizer, max_len, auto_adapt) 75 | 76 | test_dataset = load_dataset(path="json", data_files=test_data_path) 77 | test_dataset = test_dataset['train'] 78 | 79 | data_collator = SftDataCollator(tokenizer, max_len) 80 | generate_config = GenerationConfig( 81 | max_new_tokens=4096, 82 | max_length=max_len, 83 | use_cache=True 84 | ) 85 | 86 | trainer = Trainer( 87 | model=model, 88 | args=training_args, 89 | train_dataset=train_dataset, 90 | data_collator=data_collator, 91 | processing_class=tokenizer 92 | ) 93 | 94 | # if os.environ.get('LOCAL_RANK', '0') == '0': # 只在主进程中初始化 95 | # wandb.init(project="huggingface") 96 | # wandb.init(project="huggingface") 97 | 98 | wandb_callback = EvaluationCallback( 99 | trainer=trainer, 100 | test_dataset=test_dataset, 101 | generation_config=generate_config, 102 | num_samples=6, 103 | freq=1, 104 | metric=CodeEvalMetric(), 105 | 
max_checkpoints=1, 106 | per_device_test_batch_size=1, 107 | higher_better=True, 108 | start_update_best_checkpoints=100 109 | ) 110 | trainer.add_callback(wandb_callback) 111 | 112 | trainer.train() -------------------------------------------------------------------------------- /utils/eval/vllm/run_serve.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH='Qwen2.5-Coder-32B-Instruct' 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0,1 python vllm_serve.py\ 5 | --model "$MODEL_PATH" \ 6 | --tensor_parallel_size 2 \ 7 | --max_model_len 4096 \ 8 | --port 8001 \ 9 | --dtype "bfloat16" \ 10 | --enable_prefix_caching False -------------------------------------------------------------------------------- /utils/eval/vllm/vllm_client.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import logging 3 | import time 4 | from typing import Optional 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from trl.import_utils import is_requests_available, is_vllm_available 10 | 11 | if is_requests_available(): 12 | import requests 13 | from requests import ConnectionError 14 | 15 | if is_vllm_available(): 16 | from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator 17 | from vllm.distributed.utils import StatelessProcessGroup 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class VLLMClient: 23 | """ 24 | A client class to interact with a vLLM server. 25 | 26 | This class provides methods to generate completions, initialize and manage weight update groups, and update model 27 | weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`. 28 | 29 | Args: 30 | host (`str`, *optional*, defaults to `"0.0.0.0"`): 31 | IP address of the vLLM server. 32 | server_port (`int`, *optional*, defaults to `8000`): 33 | Port number of the vLLM server. 34 | group_port (`int`, *optional*, defaults to `51216`): 35 | Port number for the weight update group. 36 | connection_timeout (`float`, *optional*, defaults to `0.0`): 37 | Total timeout duration in seconds to wait for the server to be up. If the server is not up after the 38 | timeout, a `ConnectionError` is raised. 39 | 40 | Examples: 41 | Run the vLLM server with the model `Qwen/Qwen2.5-7B`: 42 | 43 | ``` 44 | $ trl vllm-serve --model Qwen/Qwen2.5-7B 45 | ... 46 | INFO: Application startup complete. 47 | INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) 48 | ``` 49 | 50 | Use the client to generate completions and update model weights: 51 | 52 | ```python 53 | >>> from trl.extras.vllm_client import VLLMClient 54 | >>> client = VLLMClient() 55 | >>> client.generate(["Hello, AI!", "Tell me a joke"]) 56 | [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025], 57 | [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]] 58 | 59 | >>> from transformers import AutoModelForCausalLM 60 | >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda") 61 | >>> client.update_model_params(model) 62 | ``` 63 | """ 64 | 65 | def __init__( 66 | self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, 67 | connection_timeout: float = 0.0 68 | ): 69 | if not is_requests_available(): 70 | raise ImportError("requests is not installed. Please install it with `pip install requests`.") 71 | if not is_vllm_available(): 72 | raise ImportError("vLLM is not installed. 
Please install it with `pip install vllm`.") 73 | 74 | self.session = requests.Session() 75 | self.host = host 76 | self.server_port = server_port 77 | self.group_port = group_port 78 | self.check_server(connection_timeout) # check server and fail after timeout 79 | self.init_communicator() 80 | atexit.register(self.close_communicator) # when the client object is deleted, close the weight update group 81 | 82 | def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): 83 | """ 84 | Check server availability with retries on failure, within a total timeout duration. If the server is not up 85 | after the total timeout duration, raise a `ConnectionError`. 86 | 87 | Args: 88 | retry_interval (`float`, *optional*, defaults to `2.0`): 89 | Interval in seconds between retries. 90 | total_timeout (`float`, *optional*, defaults to `0.0`): 91 | Total timeout duration in seconds. 92 | """ 93 | url = f"http://{self.host}:{self.server_port}/health/" 94 | start_time = time.time() # Record the start time 95 | 96 | while True: 97 | try: 98 | # response = requests.get(url) 99 | response = requests.get(url, proxies={"http": None, "https": None}) # todo: 排查 100 | except requests.exceptions.RequestException as exc: 101 | # Check if the total timeout duration has passed 102 | elapsed_time = time.time() - start_time 103 | if elapsed_time >= total_timeout: 104 | raise ConnectionError( 105 | f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} " 106 | "seconds. Make sure the server is running by running `trl vllm-serve`." 107 | ) from exc 108 | else: 109 | if response.status_code == 200: 110 | logger.info("Server is up!") 111 | return None 112 | 113 | # Retry logic: wait before trying again 114 | logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...") 115 | time.sleep(retry_interval) 116 | 117 | def generate( 118 | self, 119 | prompts: list[str], 120 | n: int = 1, 121 | repetition_penalty: float = 1.0, 122 | temperature: float = 1.0, 123 | top_p: float = 1.0, 124 | top_k: int = -1, 125 | min_p: float = 0.0, 126 | max_tokens: int = 16, 127 | guided_decoding_regex: Optional[str] = None, 128 | ) -> list[list[str]]: 129 | """ 130 | Generates model completions for the provided prompts. 131 | 132 | Args: 133 | prompts (`list[str]`): 134 | List of text prompts for which the model will generate completions. 135 | n (`int`, *optional*, defaults to `1`): 136 | Number of completions to generate for each prompt. 137 | repetition_penalty (`float`, *optional*, defaults to `1.0`): 138 | Parameter for repetition penalty. 1.0 means no penalty. 139 | temperature (`float`, *optional*, defaults to `1.0`): 140 | Temperature parameter for sampling. Higher values increase diversity. 141 | top_p (`float`, *optional*, defaults to `1.0`): 142 | Top-p sampling parameter.`1.0` means no truncation. 143 | top_k (`int`, *optional*, defaults to `-1`): 144 | Top-k sampling parameter. `-1` means no truncation. 145 | min_p (`float`, *optional*, defaults to `0.0`): 146 | Minimum probability for sampling. 147 | max_tokens (`int`, *optional*, defaults to `16`): 148 | Maximum number of tokens to generate for each prompt. 149 | guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): 150 | Regular expression to guide the decoding process. 151 | 152 | Returns: 153 | `list[list[int]]`: 154 | List of lists of token IDs representing the model-generated completions for each prompt. 
155 | """ 156 | url = f"http://{self.host}:{self.server_port}/generate/" 157 | response = self.session.post( 158 | url, 159 | json={ 160 | "prompts": prompts, 161 | "n": n, 162 | "repetition_penalty": repetition_penalty, 163 | "temperature": temperature, 164 | "top_p": top_p, 165 | "top_k": top_k, 166 | "min_p": min_p, 167 | "max_tokens": max_tokens, 168 | "guided_decoding_regex": guided_decoding_regex, 169 | }, 170 | ) 171 | if response.status_code == 200: 172 | return response.json()["completion_ids"] 173 | else: 174 | raise Exception(f"Request failed: {response.status_code}, {response.text}") 175 | 176 | def init_communicator(self): 177 | """ 178 | Initializes the weight update group in a distributed setup for model synchronization. 179 | """ 180 | # Get the tensor parallel size from the server 181 | url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/" 182 | # response = requests.get(url) 183 | response = requests.get(url, proxies={"http": None, "https": None}) 184 | if response.status_code == 200: 185 | tensor_parallel_size = response.json()["tensor_parallel_size"] 186 | else: 187 | raise Exception(f"Request failed: {response.status_code}, {response.text}") 188 | 189 | world_size = tensor_parallel_size + 1 190 | self.rank = tensor_parallel_size # The client's rank is the last process 191 | 192 | # Initialize weight update group 193 | url = f"http://{self.host}:{self.server_port}/init_communicator/" 194 | # In the server side, the host is set to 0.0.0.0 195 | response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size}) 196 | if response.status_code != 200: 197 | raise Exception(f"Request failed: {response.status_code}, {response.text}") 198 | 199 | # Set up the communication group for weight broadcasting 200 | pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) 201 | self.pynccl_comm = PyNcclCommunicator(pg, device="cuda:0") 202 | 203 | def update_named_param(self, name: str, weights: torch.Tensor): 204 | """ 205 | Updates a specific named parameter in the model and broadcasts it to other processes. 206 | 207 | Args: 208 | name (`str`): 209 | Name of the layer whose weights are being updated. 210 | weights (`torch.Tensor`): 211 | Tensor containing the updated weights. 212 | """ 213 | dtype, shape = str(weights.dtype), tuple(weights.shape) 214 | url = f"http://{self.host}:{self.server_port}/update_named_param/" 215 | response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape}) 216 | if response.status_code != 200: 217 | raise Exception(f"Request failed: {response.status_code}, {response.text}") 218 | 219 | # Broadcast the weights to the other processes 220 | self.pynccl_comm.broadcast(weights, src=self.rank, stream=torch.cuda.current_stream()) 221 | self.pynccl_comm.group.barrier() 222 | 223 | def update_model_params(self, model: nn.Module): 224 | """ 225 | Updates all parameters of the given model by calling `update_named_param` for each parameter in the model. 226 | 227 | Args: 228 | model (`nn.Module`): 229 | Model whose parameters (weights/biases) are to be updated. 230 | """ 231 | for name, param in model.named_parameters(): 232 | # Update each parameter individually 233 | self.update_named_param(name, param.data) 234 | 235 | def reset_prefix_cache(self): 236 | """ 237 | Resets the prefix cache for the model. 
238 | """ 239 | url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/" 240 | response = self.session.post(url) 241 | if response.status_code != 200: 242 | raise Exception(f"Request failed: {response.status_code}, {response.text}") 243 | 244 | def close_communicator(self): 245 | """ 246 | Closes the weight update group and cleans up the communication group. 247 | """ 248 | url = f"http://{self.host}:{self.server_port}/close_communicator/" 249 | response = self.session.post(url) 250 | if response.status_code != 200: 251 | raise Exception(f"Request failed: {response.status_code}, {response.text}") 252 | 253 | 254 | # Example usage 255 | if __name__ == "__main__": 256 | from vllm import SamplingParams 257 | 258 | client = VLLMClient() 259 | 260 | # Generate completions 261 | responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams()) 262 | print("Responses:", responses) # noqa 263 | 264 | # Update model weights 265 | from transformers import AutoModelForCausalLM 266 | 267 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda") 268 | client.update_model_params(model) 269 | -------------------------------------------------------------------------------- /utils/script/download_model.py: -------------------------------------------------------------------------------- 1 | from modelscope import snapshot_download 2 | model_dir = snapshot_download('qwen/Qwen1.5-0.5B', cache_dir='../../../download_llm') 3 | print("模型下载完成") -------------------------------------------------------------------------------- /utils/script/generate_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | ORIGINAL_DATA_PATH = '../../1.jsonl' # 原数据路径 4 | OUT_DATA_PATH = './out_data.jsonl' # 转换为框架适用的role content模式的输出路径 5 | 6 | 7 | data1 = pd.read_json(ORIGINAL_DATA_PATH, lines=True) 8 | # 创建一个空的列表来存储处理后的数据 9 | processed_data = [] 10 | # 迭代每一行数据 11 | for index, row in data1.iterrows(): 12 | message = [ 13 | {"role": "user", "content": row['instruction']}, 14 | {"role": "assistant", "content": row['output']} 15 | ] 16 | processed_data.append({"message": message}) 17 | # 将处理后的数据转换为 DataFrame 18 | processed_df = pd.DataFrame(processed_data) 19 | # 保存为jsonl格式 20 | processed_df.to_json(OUT_DATA_PATH, orient='records', lines=True, force_ascii=False) 21 | -------------------------------------------------------------------------------- /utils/script/merge_lora.py: -------------------------------------------------------------------------------- 1 | from peft import PeftModel 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | # base模型和lora训练后保存模型的位置 5 | base_model_path = 'download_llm/LLM-Research/Phi-3-mini-128k-instruct' 6 | lora_path = '/LLM-out/checkpoint-616' 7 | # 合并后整个模型的保存地址 8 | merge_output_dir = 'merged_lora_model' 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(base_model_path) 11 | base_model = AutoModelForCausalLM.from_pretrained( 12 | base_model_path, 13 | device_map="cuda", 14 | torch_dtype="auto", 15 | trust_remote_code=True, 16 | ) 17 | 18 | lora_model = PeftModel.from_pretrained(base_model, lora_path) 19 | model = lora_model.merge_and_unload() 20 | 21 | if merge_output_dir: 22 | model.save_pretrained(merge_output_dir) 23 | tokenizer.save_pretrained(merge_output_dir) -------------------------------------------------------------------------------- /utils/vlm_template.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any 2 | 3 | 4 | def get_key(model): 5 | """获取模型的key""" 6 | key_set = set() 7 | for name, param in model.named_parameters(): 8 | key_set.add(name.split('.')[0]) 9 | return key_set 10 | 11 | 12 | # 基础模板处理类 13 | class TemplateProcessor: 14 | def __init__(self): 15 | # model 参数名称 16 | self.model_key = { 17 | "visual": 'visual', 18 | "llm": 'llm', 19 | "projector": 'projector' 20 | } 21 | 22 | def process(self, questions: List[str], answers: List[str]) -> Any: 23 | raise NotImplementedError("Subclasses must implement `process` method.") 24 | 25 | 26 | # LLaVA 模板处理类 27 | class LlavaTemplateProcessor(TemplateProcessor): 28 | NAME = 'LlavaProcessor' 29 | 30 | def __init__(self): 31 | super().__init__() 32 | self.model_key['visual'] = 'vision_tower' 33 | self.model_key['llm'] = 'language_model' 34 | self.model_key['projector'] = 'multi_modal_projector' 35 | 36 | def process(self, questions: List[str], answers: List[str]) -> List[Dict]: 37 | converted_data = [] 38 | for question, answer in zip(questions, answers): 39 | user_content = [{'index': None, 'text': question, 'type': 'text'}] 40 | assistant_content = [{'index': None, 'text': answer, 'type': 'text'}] 41 | 42 | converted_data.append({'content': user_content, 'role': 'user'}) 43 | converted_data.append({'content': assistant_content, 'role': 'assistant'}) 44 | image_dict = {'index': 0, 'text': None, 'type': 'image'} 45 | converted_data[0]['content'].append(image_dict) 46 | return converted_data 47 | 48 | 49 | # Qwen2VL 模板处理类可与llava相同 50 | class Qwen2VLTemplateProcessor(LlavaTemplateProcessor): 51 | NAME = 'Qwen2VLProcessor' 52 | 53 | def __init__(self): 54 | super().__init__() 55 | # model 参数名称 56 | self.model_key['visual'] = 'visual' 57 | self.model_key['llm'] = 'model' 58 | self.model_key['projector'] = 'lm_head' 59 | 60 | 61 | if __name__ == '__main__': 62 | qwen = Qwen2VLTemplateProcessor() 63 | print(qwen.model_key) 64 | -------------------------------------------------------------------------------- /vlm_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset 3 | from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration 4 | 5 | from trl import ( 6 | ModelConfig, 7 | SFTConfig, 8 | SFTTrainer, 9 | TrlParser, 10 | get_kbit_device_map, 11 | get_peft_config, 12 | get_quantization_config, 13 | ) 14 | from train_args.vlm_config.script_args import ScriptArgs 15 | from utils.data_process import VlmQaDataset 16 | from utils.data_collator import VlmQaDataCollator 17 | 18 | 19 | def freeze_vision_projection(model, model_key_value): 20 | """ 21 | model_key_value: 'projector' or 'visual' 22 | 冻结模型的视觉部分或者转接部分 23 | """ 24 | module = getattr(model, model_key_value) 25 | for param in module.parameters(): 26 | param.requires_grad = False 27 | 28 | 29 | def initial_args(): 30 | parser = TrlParser((ScriptArgs, SFTConfig, ModelConfig)) 31 | script_args, training_args, model_args = parser.parse_args_and_config() 32 | training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) 33 | # training_args.dataset_kwargs = {"skip_prepare_dataset": True} 34 | training_args.remove_unused_columns = False 35 | return script_args, training_args, model_args 36 | 37 | 38 | def create_model_processor(model_args, script_args): 39 | torch_dtype = ( 40 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else 
getattr(torch, model_args.torch_dtype) 41 | ) 42 | processor = AutoProcessor.from_pretrained( 43 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 44 | ) 45 | 46 | if script_args.train_mode in ['lora', 'qlora']: 47 | model_args.lora_target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"] 48 | model_args.use_peft = True 49 | 50 | quantization_config = None 51 | if script_args.train_mode == 'qlora': 52 | model_args.load_in_4bit = True 53 | quantization_config = get_quantization_config(model_args) 54 | 55 | model_kwargs = dict( 56 | revision=model_args.model_revision, 57 | attn_implementation=model_args.attn_implementation, 58 | torch_dtype=torch_dtype, 59 | device_map=get_kbit_device_map() if quantization_config is not None else None, 60 | quantization_config=quantization_config, 61 | ) 62 | 63 | model = AutoModelForVision2Seq.from_pretrained( 64 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs 65 | ) 66 | 67 | return { 68 | 'model': model, 69 | 'processor': processor, 70 | 'peft_config': get_peft_config(model_args), 71 | } 72 | 73 | 74 | # 加载数据集,后续不同任务可能会动态调整 75 | def load_vlm_dataset(script_args): 76 | train_dataset = VlmQaDataset(script_args.train_data_path) 77 | return train_dataset 78 | 79 | 80 | def main(): 81 | script_args, training_args, model_args = initial_args() 82 | train_dict = create_model_processor(model_args, script_args) 83 | 84 | model = train_dict['model'] 85 | processor = train_dict['processor'] 86 | 87 | train_dataset = load_vlm_dataset(script_args) 88 | collate_fn = VlmQaDataCollator(processor) 89 | 90 | model_keys = collate_fn.template_process.model_key 91 | if script_args.freeze_vision: 92 | freeze_vision_projection(model, model_keys['visual']) 93 | if script_args.freeze_projector: 94 | freeze_vision_projection(model, model_keys['projector']) 95 | 96 | trainer = SFTTrainer( 97 | model=model, 98 | args=training_args, 99 | data_collator=collate_fn, 100 | train_dataset=train_dataset, 101 | processing_class=processor.tokenizer, 102 | peft_config=get_peft_config(model_args), 103 | ) 104 | 105 | trainer.train() 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | --------------------------------------------------------------------------------
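Side note on the best-checkpoint bookkeeping in `utils/eval/callback.py` above: scores are pushed onto a min-heap after being negated when `higher_better` is False, so the worst retained checkpoint always sits at the heap root and is the one evicted. The sketch below is a self-contained illustration of that retention policy with made-up scores; the `CheckpointInfo` shape used here (an order-enabled dataclass compared only by `metric_value`) is an assumption, since the actual class definition is not part of this excerpt.

```python
import heapq
from dataclasses import dataclass, field


# Assumed shape: compared only by metric_value, so heapq keeps the worst score at the root.
@dataclass(order=True)
class CheckpointInfo:
    metric_value: float
    step: int = field(compare=False)
    path: str = field(compare=False)


def retain_top_k(scores, max_checkpoints=3, higher_better=True):
    """Keep only the best `max_checkpoints` scores, mirroring update_best_checkpoints."""
    heap = []
    for step, score in enumerate(scores):
        value = score if higher_better else -score  # unify to "higher is better"
        heapq.heappush(heap, CheckpointInfo(value, step, f"checkpoint-{step}"))
        if len(heap) > max_checkpoints:
            worst = heapq.heappop(heap)  # evict the worst retained checkpoint
            print(f"would delete {worst.path}")
    return sorted(heap, reverse=True)


if __name__ == "__main__":
    kept = retain_top_k([0.31, 0.45, 0.28, 0.52, 0.49])
    print([(c.step, round(c.metric_value, 2)) for c in kept])  # [(3, 0.52), (4, 0.49), (1, 0.45)]
```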