├── LICENSE.txt ├── README.md ├── config.py ├── data └── train_data.json ├── inference.py ├── main.py ├── model ├── __pycache__ │ ├── actor_critic_model.cpython-310.pyc │ ├── reference_model.cpython-310.pyc │ └── reward_model.cpython-310.pyc ├── actor_critic_model.py ├── reference_model.py └── reward_model.py ├── ppo.py ├── requirements.txt └── utils ├── __pycache__ ├── data_load.cpython-310.pyc └── tools.cpython-310.pyc ├── data_load.py └── tools.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RLHF PPO
2 | ## Blog post (Chinese): https://zhuanlan.zhihu.com/p/701832687
3 | ## Project description
4 | This repository implements the PPO algorithm. Because of limited personal hardware, the model fine-tuned with RL is Qwen 0.5B, and LoRA is used to adapt the model parameters.
5 | The reward model is Erlangshen-Roberta-330M-Sentiment, which does not need any fine-tuning. Download: https://huggingface.co/IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment
6 | ## Code organization
7 | ### Training data
8 | data/train_data.json: this data was written by hand and is for learning purposes only. Although there are only 2 samples, I fine-tuned for 20 epochs and the result is acceptable. The post-training output is attached below.
9 | ### model directory
10 | model/actor_critic_model.py: the actor and the critic share the same backbone here; they are not split into separate models
11 | model/reference_model.py: the reference model, which is simply the original Qwen model
12 | model/reward_model.py: the reward model, Erlangshen-Roberta-330M-Sentiment
13 | ### utils directory
14 | There are two main files here: utils/data_load.py loads the training data, and utils/tools.py factors out the functionality that all of the models need to call.
15 | ### config.py: configuration file
16 | ### ppo.py: core PPO implementation
17 | ### main.py: training entry point
18 | ### inference.py: after training, use this file to check the results
19 |
20 |
21 | ## Comparison before and after fine-tuning
22 | Input question: "饭店服务员的态度太差,使用委婉积极的态度投诉" (the restaurant waiter's attitude is terrible; complain in a tactful, positive tone). System role: "你是一个有文化的文明人" (you are a cultured, civilized person).
23 | ### Output before fine-tuning
24 | Dear Manager:
25 | Hello! I recently ran into some problems while dining at your restaurant, and I would like to bring them to your attention with this letter.
26 | First, I am very dissatisfied with the service attitude. When I ordered, the waiter was unfriendly and always seemed cold and impatient. They did not seem to pay enough attention to my needs, and at times even made unreasonable suggestions or demands.
27 | Second, I also had some trouble during the meal. Some dishes did not suit my taste, and some of the staff were not particularly warm. All of this left me disappointed.
28 | Finally, I also noticed some problems with the environment. For example, the restaurant's hygiene was poor, and the tableware and cups were often dirty, which made me very uncomfortable.
29 | I hope you can understand and take measures to improve the dining experience. I believe that as long as we work together, the quality of service will improve significantly.
30 | Thank you again for taking the time to read this letter; I look forward to your reply.
31 | Best regards,
32 | [Your name]
33 | ### Output after fine-tuning
34 | Dear restaurant staff:
35 | Hello! I recently ran into some problems while dining at your restaurant. Through these questions and this feedback I hope to gain your understanding and help.
36 | First, I want to emphasize that I did not feel uncomfortable or dissatisfied during the meal. On the contrary, I think your service attitude is very professional, warm and friendly. I noticed that you pay close attention to the food and the service, and always try to provide the best experience.
37 | However, I also realize that sometimes we may behave inappropriately because of our own needs or the environment. I therefore hope you can understand my situation and try to express your views in a more polite way.
38 | If you need more information or suggestions, I am happy to listen. I believe that as long as we work together, we can find a better solution.
39 | Thank you again for your patience and understanding. I look forward to your reply.
40 | Thank you!
41 | Best regards,
42 | [Your name]
43 | ## One last point: during both training and inference, be sure to use greedy decoding
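## Quick start (sketch)
A minimal sketch of running training and then comparing the tuned model with the base model, assuming the local model paths in config.py point at your own copies of Qwen1.5-0.5B-Chat and Erlangshen-Roberta-330M-Sentiment, and that the code is run from the repository root:

```python
# Training: builds the actor/critic, reference and reward models and runs PPO
# (this simply mirrors the __main__ block of main.py).
from main import TrainPpo

TrainPpo().train_ppo()

# Inference: loads and merges the saved LoRA weights, then prints the tuned
# and the untuned response side by side (mirrors inference.py).
from config import Config
from inference import LoraPPOModel

model = LoraPPOModel(Config())
tuned, raw = model("饭店服务员的态度太差,使用委婉积极的态度投诉", "你是一个有文化的文明人")
print(tuned)
print(raw)
```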
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 | from dataclasses import dataclass, field
4 |
5 | class Config:
6 |     # model parameters ###########################
7 |     # sentiment-analysis model, download from https://huggingface.co/IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment
8 |     Sentiment_model = "E:\\ai_model\\model\\Erlangshen-Roberta-330M-Sentiment"
9 |     # text-generation model, download from https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat
10 |     gpt_model = "E:\\ai_model\\model\\qwen0.5"
11 |     data_path = "data/train_data.json"
12 |     save_lora_path = "E:\\ai_model\\model\\ppo\\save_lora"
13 |     save_v_head_path = "E:\\ai_model\\model\\ppo\\v_head\\pytorch_model.bin"
14 |     device = "cuda:0" if torch.cuda.is_available() else "cpu"
15 |     batch_size = 2
16 |     epochs = 10
17 |     lr = 0.001
18 |     # PPO parameters ############################
19 |     ppo_epochs = 3
20 |     kl_ctl_value = 0.2
21 |     gamma = 1.0  # discount factor used in the advantage computation; controls how much future rewards matter.
22 |     lam = 0.95  # lambda parameter for the advantage computation (GAE); controls how strongly future rewards are considered when combined with temporal-difference estimates.
23 |     cliprange_value = 0.2  # clipping range for the value function in the loss; clipping keeps extreme values from harming training.
24 |     cliprange = 0.2  # clipping range for the PPO policy-gradient loss; it bounds the size of each policy update to keep training stable.
25 |     vf_coef = 0.1
26 |
27 |
28 | @dataclass
29 | class LoraArguments:
30 |     lora_r: int = 2
31 |     lora_alpha: int = 8
32 |     lora_dropout: float = 0
33 |     lora_target_modules: List[str] = field(
34 |         default_factory=lambda: ['k_proj', 'v_proj']
35 |     )
36 |     # lora_target_modules = None
37 |     lora_weight_path: str = ""
38 |     q_lora: bool = False
39 |     load_in_4bit: bool = False
40 |     load_in_8bit: bool = False
41 |     is_reload_trained_params = True  # whether to resume from the previously saved parameters and continue training
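The PPO values above map directly onto the update implemented in ppo.py. A small illustrative sketch, with toy numbers that are not taken from the repository, of how kl_ctl_value shapes the per-token reward and how cliprange bounds the policy ratio:

import torch

kl_ctl_value, cliprange = 0.2, 0.2
logp_actor = torch.tensor([-1.2, -0.7, -2.0])  # per-token log-probs from the actor (toy values)
logp_ref = torch.tensor([-1.0, -0.9, -1.5])    # per-token log-probs from the frozen reference model
score = 0.8                                    # sentence-level score from the sentiment reward model

# PPO.compute_rewards: a per-token KL penalty, with the reward-model score added on the last token
rewards = -kl_ctl_value * (logp_actor - logp_ref)
rewards[-1] += score

# PPO.loss: the probability ratio is clipped to [1 - cliprange, 1 + cliprange]
ratio = torch.exp(logp_actor - logp_ref)
clipped_ratio = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
print(rewards, clipped_ratio)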
--------------------------------------------------------------------------------
/data/train_data.json:
--------------------------------------------------------------------------------
1 | [
2 |     {
3 |         "query": "饭店服务员的态度太差,使用委婉积极的态度投诉",
4 |         "system_content": "你是一个有文化的文明人"
5 |     },
6 |     {
7 |         "query": "领导故意刁难你,你想骂他娘的,使用文明语言骂他娘的",
8 |         "system_content": "你是一个有文化的文明人"
9 |     }
10 | ]
11 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from config import Config
3 | from peft import LoraConfig, PeftModel
4 | from config import LoraArguments
5 |
6 |
7 | class LoraPPOModel(PeftModel):
8 |     def __init__(self, config: Config):
9 |         self.config = config
10 |         model = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device)
11 |         self.tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)
12 |         lora_args = LoraArguments()
13 |         lora_config = LoraConfig(
14 |             r=lora_args.lora_r,
15 |             lora_alpha=lora_args.lora_alpha,
16 |             target_modules=lora_args.lora_target_modules,
17 |             lora_dropout=lora_args.lora_dropout,
18 |             task_type="CAUSAL_LM",
19 |         )
20 |         super().__init__(model, lora_config)
21 |         model = super().from_pretrained(model, config.save_lora_path)
22 |         self.lora_ppo_model = model.merge_and_unload()  # merge the trained LoRA weights into the base model
23 |         self.raw_model = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device)
24 |         print()
25 |
26 |     def forward(self, query, system_content):
27 |         messages = [
28 |             {"role": "system", "content": system_content},
29 |             {"role": "user", "content": query}
30 |         ]
31 |         text = self.tokenizer.apply_chat_template(
32 |             messages,
33 |             tokenize=False,
34 |             add_generation_prompt=True
35 |         )
36 |         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.config.device)
37 |         lora_ppo_response = self.predict(model_inputs, self.lora_ppo_model, self.tokenizer)
38 |         raw_response = self.predict(model_inputs, self.raw_model, self.tokenizer)
39 |         return lora_ppo_response, raw_response
40 |
41 |     @staticmethod
42 |     def predict(model_inputs, model, tokenizer):
43 |         generated_ids = model.generate(
44 |             model_inputs.input_ids,
45 |             max_new_tokens=512,
46 |             num_beams=1,
47 |             do_sample=False
48 |         )
49 |         generated_ids = [
50 |             output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
51 |         ]
52 |
53 |         response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
54 |         return response
55 |
56 |
57 | if __name__ == '__main__':
58 |     lora_ppo_model = LoraPPOModel(Config())
59 |     lora_ppo_response, raw_response = lora_ppo_model("饭店服务员的态度太差,使用委婉积极的态度投诉", "你是一个有文化的文明人")
60 |     print(f"lora_ppo_response:{lora_ppo_response}")
61 |     print(f"raw_response:{raw_response}")
62 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from config import Config
3 | from torch.utils.data import DataLoader
4 | from torch.optim import Adam
5 | from model.actor_critic_model import ActorCriticLoraModel
6 | from model.reward_model import RewardModel
7 | from utils.data_load import CustomDataset
8 | from model.reference_model import ReferenceModel
9 | from ppo import PPO
10 | from utils.tools import Tools
11 |
12 |
13 | class TrainPpo:
14 |     def __init__(self):
15 |         self.config = Config()
16 |         # actor and critic model (they share one backbone)
17 |         self.actor_critic_model = ActorCriticLoraModel(self.config)
18 |         self.tokenizer = self.actor_critic_model.tokenizer
19 |         # optimizer for the actor/critic model; LoRA is used, so the full set of weights is not optimized
20 |         self.actor_critic_opt = Adam(self.actor_critic_model.parameters(), lr=self.config.lr)
21 |         # reference model
22 |         self.reference_model = ReferenceModel(self.config)
23 |         # reward model
24 |         self.reward_model = RewardModel(self.config)
25 |         # training data
26 |         dataset = CustomDataset(self.config.data_path, self.tokenizer)
27 |         self.data_loader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True,
28 |                                       collate_fn=dataset.collate_fn)
29 |         self.ppo = PPO(self.actor_critic_model, self.config, self.actor_critic_opt)
30 |
31 |     def train_ppo(self):
32 |         self.save_model()
33 |         for epoch in range(self.config.epochs):
34 |             for batch_data in self.data_loader:
35 |                 # get the text generated by the actor (prompt_generate) and the ids (prompt_generate_ids, generate_ids)
36 |                 prompt_generate, prompt_generate_ids, generate_ids = self.actor_critic_model.actor_generate(
37 |                     batch_data[0])
38 |                 attention_mask = (prompt_generate_ids != self.tokenizer.pad_token_id)
39 |                 generate_ids_mask = (generate_ids[:, :-1] != self.tokenizer.pad_token_id)
40 |                 # number of generated tokens; minus 1 because the final token is the end-of-sequence token
41 |                 response_shape = generate_ids.shape[1] - 1
42 |                 # initialize the helper tools
43 |                 tools = Tools(response_shape, generate_ids_mask)
44 |                 # strip the prompt and keep only the generated text, which is used to compute the reward value
45 |                 pure_generate = [one.split("assistant\n")[1] for one in prompt_generate]
46 |                 reward = self.reward_model(pure_generate)
47 |                 # get the reference model's probs
48 |                 prob_refs = self.reference_model(prompt_generate_ids, attention_mask, tools)
49 |                 # run the PPO update; the critic supplies the per-token value estimates
50 |                 self.ppo.train(prompt_generate_ids, attention_mask, prob_refs, reward, tools)
51 |             self.save_model()
52 |
53 |     def save_model(self):
54 |         # save the LoRA parameters
55 |         self.actor_critic_model.model.save_pretrained(self.config.save_lora_path, safe_serialization=False)
56 |         # save the value-head (critic) parameters
57 |         torch.save(self.actor_critic_model.model.v_head.state_dict(), self.config.save_v_head_path)
58 |
59 |
60 | if __name__ == '__main__':
61 |     train_ppo = TrainPpo()
62 |     train_ppo.train_ppo()
63 |
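train_ppo relies on the decoded generation containing the literal marker "assistant\n" in order to cut the prompt off (main.py line 45). A small sketch of that assumption; it presumes the Qwen1.5 chat template configured in config.py and is not code from the repository:

from transformers import AutoTokenizer
from config import Config

tokenizer = AutoTokenizer.from_pretrained(Config().gpt_model)
messages = [{"role": "system", "content": "你是一个有文化的文明人"},
            {"role": "user", "content": "饭店服务员的态度太差,使用委婉积极的态度投诉"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# With the Qwen1.5 chat template the rendered prompt ends with an assistant header,
# so splitting the decoded output on "assistant\n" isolates the model's reply.
print(text.endswith("assistant\n"))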
--------------------------------------------------------------------------------
/model/__pycache__/actor_critic_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OctopusMind/RLHF_PPO/34bd4c1d91721810a459cbb707f03f8afbe3b3f0/model/__pycache__/actor_critic_model.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/reference_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OctopusMind/RLHF_PPO/34bd4c1d91721810a459cbb707f03f8afbe3b3f0/model/__pycache__/reference_model.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/reward_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OctopusMind/RLHF_PPO/34bd4c1d91721810a459cbb707f03f8afbe3b3f0/model/__pycache__/reward_model.cpython-310.pyc
--------------------------------------------------------------------------------
/model/actor_critic_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | from config import Config
4 | from utils.tools import Tools
5 | from peft import LoraConfig, get_peft_model, PeftModel
6 | from config import LoraArguments
7 |
8 |
9 | class LoraModel(PeftModel):
10 |     def __init__(self, config: Config, model):
11 |         lora_args = LoraArguments()
12 |         lora_config = LoraConfig(
13 |             r=lora_args.lora_r,
14 |             lora_alpha=lora_args.lora_alpha,
15 |             target_modules=lora_args.lora_target_modules,
16 |             lora_dropout=lora_args.lora_dropout,
17 |             task_type="CAUSAL_LM",
18 |         )
19 |         super().__init__(model, lora_config)
20 |         self.v_head = torch.nn.Linear(1024, 1, bias=False).to(config.device)  # value head; 1024 is the hidden size of Qwen1.5-0.5B
21 |         if lora_args.is_reload_trained_params:
22 |             super().from_pretrained(model, config.save_lora_path)
23 |             self.v_head.load_state_dict(torch.load(config.save_v_head_path))
24 |         for name, module in self.named_modules():
25 |             if 'lora_' in name:
26 |                 for param in module.parameters():
27 |                     param.requires_grad = True  # keep the LoRA parameters trainable
28 |
29 |     def forward(self, input_ids, attention_mask, tools: Tools):
30 |         res = super().forward(input_ids, attention_mask, output_hidden_states=True)
31 |         values = self.v_head(res.hidden_states[0]).squeeze(-1)[:, :-1]  # note: hidden_states[0] is the embedding-layer output, not the last transformer layer
32 |         values = tools.filter_mask(values)
33 |         probs = tools.probs_from_logits(res.logits[:, :-1, :], input_ids[:, 1:])
34 |         probs = tools.filter_mask(probs)
35 |         return probs, values
36 |
37 |
38 | class ActorCriticLoraModel(torch.nn.Module):
39 |     def __init__(self, config: Config):
40 |         super().__init__()
41 |         model = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device).eval()
42 |         self.model = LoraModel(config, model)
43 |         self.tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)
44 |
45 |     def forward(self, input_ids, attention_mask, tools: Tools):
46 |         probs, values = self.model(input_ids, attention_mask, tools)
47 |         return probs, values
48 |
49 |     @torch.no_grad()
50 |     def actor_generate(self, input_ids):
51 |         generated_ids = self.model.generate(input_ids, max_new_tokens=512, top_p=1.0,
52 |                                             num_beams=1,
53 |                                             do_sample=False)
54 |         response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
55 |         response_id = generated_ids[:, input_ids.shape[1]:]
56 |         return response, generated_ids, response_id
57 |
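The value head above is hard-coded to 1024 input features, which matches the hidden size of Qwen1.5-0.5B. If a different backbone is configured in config.py, the size could instead be read from the model's configuration; a sketch of that idea (an assumption, not code from the repository):

import torch
from transformers import AutoConfig
from config import Config

hidden_size = AutoConfig.from_pretrained(Config().gpt_model).hidden_size
v_head = torch.nn.Linear(hidden_size, 1, bias=False)  # same shape as LoraModel.v_head when hidden_size == 1024
print(hidden_size, v_head)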
--------------------------------------------------------------------------------
/model/reference_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM
2 | from utils.tools import Tools
3 | import torch.nn as nn
4 | import torch
5 |
6 |
7 | class ReferenceModel(nn.Module):
8 |     def __init__(self, config):
9 |         super(ReferenceModel, self).__init__()
10 |         self.config = config
11 |         self.reference_model = AutoModelForCausalLM.from_pretrained(self.config.gpt_model, torch_dtype="auto").to(
12 |             self.config.device)
13 |
14 |     @torch.no_grad()
15 |     def forward(self, input_ids, attention_mask, tools: Tools):
16 |         logits = self.reference_model(input_ids=input_ids,
17 |                                       attention_mask=attention_mask).logits
18 |         prob_refs = tools.probs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
19 |         prob_refs = tools.filter_mask(prob_refs)
20 |         return prob_refs
21 |
22 |
--------------------------------------------------------------------------------
/model/reward_model.py:
--------------------------------------------------------------------------------
1 | from transformers import BertForSequenceClassification, BertTokenizer
2 | from config import Config
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | # Reward model
8 | class RewardModel(nn.Module):
9 |     def __init__(self, config: Config):
10 |         super(RewardModel, self).__init__()
11 |         self.config = config
12 |         self.reward_tokenizer = BertTokenizer.from_pretrained(self.config.Sentiment_model)
13 |         self.reward_model = BertForSequenceClassification.from_pretrained(self.config.Sentiment_model).to(
14 |             self.config.device)
15 |
16 |     @torch.no_grad()
17 |     def forward(self, text):
18 |         input_ids, attention_mask = self.data_process(text)
19 |         output = self.reward_model(torch.tensor(input_ids).to(self.config.device),
20 |                                    torch.tensor(attention_mask).to(self.config.device))
21 |         probs = torch.softmax(output.logits, dim=1).tolist()  # output.logits is already a tensor, so it is not wrapped again
22 |         reward = [prob[0] for prob in probs]  # the classifier's probability at index 0 is used as the scalar reward
23 |         return reward
24 |
25 |     def data_process(self, texts):
26 |         attention_mask = []
27 |         input_ids = [self.reward_tokenizer.encode(text)[:512] for text in texts]
28 |         max_length = max(len(i) for i in input_ids)
29 |         res = []
30 |         for one in input_ids:
31 |             padding_num = max_length - len(one)
32 |             res.append(one + [self.reward_tokenizer.pad_token_id] * padding_num)
33 |             attention_mask.append([1] * len(one) + [0] * padding_num)
34 |         return res, attention_mask
35 |
36 |
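RewardModel.forward uses the softmax probability at index 0 of the sentiment classifier as the reward. Which index corresponds to the "positive" label depends on the checkpoint's label mapping, so it is worth checking once before training; an illustrative check (not code from the repository):

from transformers import BertForSequenceClassification
from config import Config

model = BertForSequenceClassification.from_pretrained(Config().Sentiment_model)
# Prints the checkpoint's id2label mapping; choose the index used in
# RewardModel.forward so that it rewards the desired sentiment.
print(model.config.id2label)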
--------------------------------------------------------------------------------
/ppo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from config import Config
3 | from utils.tools import Tools
4 |
5 |
6 | class PPO:
7 |     def __init__(self, actor_critic_model, config: Config, actor_critic_opt):
8 |         self.actor_critic_model = actor_critic_model
9 |         self.config = config
10 |         self.actor_critic_opt = actor_critic_opt
11 |
12 |     def train(self, prompt_generate_ids, attention_mask, prob_refs, reward, tools: Tools):
13 |         with torch.no_grad():
14 |             _, old_values = self.actor_critic_model(prompt_generate_ids, attention_mask, tools)  # value of every token before the update
15 |         for _ in range(self.config.ppo_epochs):
16 |             # get the actor/critic model's new probs and the value of each token
17 |             new_probs, new_values = self.actor_critic_model(prompt_generate_ids, attention_mask, tools)
18 |             # compute the per-token rewards
19 |             rewards, non_score_rewards = self.compute_rewards(reward, new_probs, prob_refs)  # compute reward
20 |             loss = self.loss(new_probs=new_probs, old_values=old_values, new_values=new_values,
21 |                              rewards=rewards, old_probs=prob_refs)
22 |
23 |             self.actor_critic_opt.zero_grad()
24 |             loss.backward()
25 |             self.actor_critic_opt.step()
26 |             print(loss)  # simple progress logging
27 |
28 |     def loss(self, new_probs, old_values, new_values, rewards, old_probs):
29 |         """
30 |         Compute the loss of the actor model and the critic model.
31 |         :param new_probs: probs produced by the actor model
32 |         :param old_values: values before the PPO update
33 |         :param new_values: new values produced during the PPO update
34 |         :param rewards: reward for each generated token
35 |         :param old_probs: probs produced by the reference model
36 |         :return: combined actor loss and critic loss
37 |         """
38 |         """Calculate policy and value losses."""
39 |         loss = torch.tensor(0.0)
40 |         for new_prob, old_value, new_value, reward, old_prob in zip(new_probs, old_values, new_values, rewards,
41 |                                                                     old_probs):
42 |             new_prob = new_prob.unsqueeze(0)
43 |             old_value = old_value.unsqueeze(0)
44 |             new_value = new_value.unsqueeze(0)
45 |             reward = reward.unsqueeze(0)
46 |             old_prob = old_prob.unsqueeze(0)
47 |             last_gae_lam = 0
48 |             advantages_reversed = []
49 |             gen_len = new_prob.shape[1]
50 |             # GAE advantage: (the real reward received for the current token + the critic's estimate of future value, which excludes the current token) minus the critic's value that includes the current token
51 |             # the first term (real reward now + estimated future value) is a more accurate estimate than the value that already includes the current token
52 |             for t in reversed(range(gen_len)):
53 |                 next_values = old_value[:, t + 1] if t < gen_len - 1 else 0.0
54 |                 delta = reward[:, t] + self.config.gamma * next_values - old_value[:, t]
55 |                 last_gae_lam = delta + self.config.gamma * self.config.lam * last_gae_lam
56 |                 advantages_reversed.append(last_gae_lam)
57 |             advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
58 |             returns = advantages + old_value  # Q value: the real reward for the current token + the estimated future value (excluding the current token)
59 |             advantages = self.whiten(advantages)
60 |             advantages = advantages.detach()
61 |             value_clipped = torch.clamp(new_value,
62 |                                         old_value - self.config.cliprange_value,
63 |                                         old_value + self.config.cliprange_value)  # clipping keeps the update from destabilizing training
64 |             vf_loss1 = (new_value - returns) ** 2  # error between the critic's value and the Q value, used to optimize the critic
65 |             vf_loss2 = (value_clipped - returns) ** 2
66 |             vf_loss = torch.mean(torch.max(vf_loss2, vf_loss1))
67 |
68 |             ratio = torch.exp(new_prob - old_prob)  # limits the update so the policy does not drift too far from the original model
69 |             pg_losses = -advantages * ratio  # importance sampling
70 |             pg_losses2 = -advantages * torch.clamp(ratio,
71 |                                                    1.0 - self.config.cliprange,
72 |                                                    1.0 + self.config.cliprange)  # clipping keeps the update from destabilizing training
73 |             pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
74 |             loss += pg_loss + self.config.vf_coef * vf_loss
75 |         return loss
76 |
77 |     def compute_rewards(self, scores, probs, ref_probs):
78 |         """
79 |         Compute the reward values. Individual tokens cannot be given an immediate reward of their own, so a KL-divergence penalty is used to fill in the per-token rewards.
80 |         :param scores: reward given by the reward model, one value per sentence
81 |         :param probs: probs produced by the actor model
82 |         :param ref_probs: probs produced by the reference model
83 |         :return: the reward value for each token
84 |         """
85 |         rewards, non_score_rewards = [], []
86 |         for score, prob, ref_prob in zip(scores, probs, ref_probs):
87 |             kl = prob - ref_prob  # (seq_len, )
88 |             non_score_reward = -self.config.kl_ctl_value * kl  # (seq_len, )
89 |             non_score_rewards.append(non_score_reward)
90 |             reward = non_score_reward.clone()  # the reward of every earlier token comes from the KL penalty
91 |             reward[-1] += score  # the externally given reward is added on the final token
92 |             rewards.append(reward)
93 |         return rewards, non_score_rewards  # (batch, seq_len)
94 |
95 |     @staticmethod
96 |     def whiten(values, shift_mean=True):
97 |         """
98 |         Whiten (normalize) the values.
99 |         :param values: the values to whiten
100 |         :param shift_mean: if True the mean is removed; if False the mean is added back after scaling
101 |         :return: the whitened values
102 |         """
103 |         mean, var = torch.mean(values), torch.var(values)
104 |         whitened = (values - mean) * torch.rsqrt(var + 1e-8)
105 |         if not shift_mean:
106 |             whitened += mean
107 |         return whitened
108 |
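A self-contained toy run of the GAE recursion used in PPO.loss above, with illustrative numbers only, showing how per-token rewards and the critic's old values turn into advantages and returns:

import torch

gamma, lam = 1.0, 0.95                       # same roles as Config.gamma / Config.lam
reward = torch.tensor([[0.0, -0.1, 0.9]])    # per-token rewards for one 3-token response (toy)
old_value = torch.tensor([[0.2, 0.3, 0.4]])  # critic values before the update (toy)

last_gae_lam = 0
advantages_reversed = []
gen_len = reward.shape[1]
for t in reversed(range(gen_len)):
    next_values = old_value[:, t + 1] if t < gen_len - 1 else 0.0
    delta = reward[:, t] + gamma * next_values - old_value[:, t]
    last_gae_lam = delta + gamma * lam * last_gae_lam
    advantages_reversed.append(last_gae_lam)
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
returns = advantages + old_value             # the regression target for the value head
print(advantages, returns)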
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.40.0
2 | torch==2.3.0
3 | peft==0.10.0
--------------------------------------------------------------------------------
/utils/__pycache__/data_load.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OctopusMind/RLHF_PPO/34bd4c1d91721810a459cbb707f03f8afbe3b3f0/utils/__pycache__/data_load.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OctopusMind/RLHF_PPO/34bd4c1d91721810a459cbb707f03f8afbe3b3f0/utils/__pycache__/tools.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/data_load.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import json
4 |
5 |
6 | class CustomDataset(Dataset):
7 |     def __init__(self, data_file, actor_tokenizer):
8 |         self.actor_tokenizer = actor_tokenizer
9 |         with open(data_file, 'r', encoding="utf-8") as f:
10 |             self.data = json.load(f)
11 |         self.total_samples = len(self.data)
12 |         self.actor_padding_id = actor_tokenizer.pad_token_id
13 |
14 |     def __len__(self):
15 |         return self.total_samples
16 |
17 |     def __getitem__(self, idx):
18 |         # load the sample at the given index
19 |         # data can be read from the file and preprocessed here as needed
20 |         line = self.data[idx]
21 |         query = line["query"]
22 |         system_content = line["system_content"]
23 |         messages = [
24 |             {"role": "system", "content": system_content},
25 |             {"role": "user", "content": query},
26 |         ]
27 |         text = self.actor_tokenizer.apply_chat_template(
28 |             messages,
29 |             tokenize=False,
30 |             add_generation_prompt=True
31 |         )
32 |         model_inputs = self.actor_tokenizer([text], return_tensors="pt")
33 |
34 |         return [model_inputs.input_ids.tolist()[0], model_inputs.attention_mask.tolist()[0]]
35 |
36 |     def collate_fn(self, batch):
37 |         max_length = max([len(i[0]) for i in batch])
38 |         input_ids = []
39 |         mask_attention = []
40 |         for one in batch:
41 |             padding_num = max_length - len(one[0])
42 |             input_ids.append([self.actor_padding_id] * padding_num + one[0])
43 |             mask_attention.append([0] * padding_num + one[1])
44 |         return torch.tensor(input_ids), torch.tensor(mask_attention)
45 |
46 |
47 | if __name__ == '__main__':
48 |     from config import Config
49 |     from transformers import AutoTokenizer
50 |
51 |     config = Config()
52 |     tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)
53 |     # create the custom dataset instance
54 |     dataset = CustomDataset(config.data_path, tokenizer)
55 |
56 |     # create the data loader and set the batch size
57 |     batch_size = 2
58 |     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate_fn)
59 |
60 |     # read the data on demand
61 |     for batch in data_loader:
62 |         print()
63 |         # model training would run for each batch
64 |         # batch holds one batch of samples
65 |         # training code would go here
66 |         pass
67 |
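collate_fn above pads on the left (padding ids and zeros are prepended), which is what batched generation with a decoder-only model expects in actor_generate. A toy illustration of the resulting tensors; the pad id is an assumption, the real value comes from the Qwen tokenizer:

import torch

pad_id = 151643                     # assumed pad token id, for illustration only
batch = [[[5, 6, 7], [1, 1, 1]],    # (input_ids, attention_mask) for sample 1
         [[8, 9], [1, 1]]]          # sample 2 is shorter and gets left-padded
max_length = max(len(one[0]) for one in batch)
input_ids, attention_mask = [], []
for ids, mask in batch:
    padding_num = max_length - len(ids)
    input_ids.append([pad_id] * padding_num + ids)
    attention_mask.append([0] * padding_num + mask)
print(torch.tensor(input_ids))      # rows: [5, 6, 7] and [151643, 8, 9] (left-padded)
print(torch.tensor(attention_mask)) # rows: [1, 1, 1] and [0, 1, 1] (0 marks the padded position)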
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class Tools:
6 |     def __init__(self, response_shape, response_ids_mask):
7 |         """
8 |         :param response_shape: length of the response generated by the model
9 |         :param response_ids_mask: batched generations contain padding; this mask is used to drop the padded positions
10 |         """
11 |         self.response_shape = response_shape
12 |         self.response_ids_mask = response_ids_mask
13 |
14 |     def filter_mask(self, values):
15 |         """
16 |         :param values: usually prob_old, prob_ref or value estimates
17 |         :return: the data with the padded positions removed
18 |         """
19 |         return [value[-self.response_shape:][one_response_ids_mask] for value, one_response_ids_mask in
20 |                 zip(values, self.response_ids_mask)]
21 |
22 |     @staticmethod
23 |     def probs_from_logits(logits, labels):
24 |         log_probs = F.log_softmax(logits, dim=2)
25 |         probs = torch.gather(log_probs, 2, labels.unsqueeze(2)).squeeze(-1)
26 |         return probs  # per-token log-probabilities of the given labels
27 |
--------------------------------------------------------------------------------
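Despite its name, probs_from_logits returns per-token log-probabilities (log_softmax followed by a gather on the label ids), and filter_mask keeps only the non-padding response positions. A toy demonstration, assuming it is run from the repository root:

import torch
from utils.tools import Tools

logits = torch.randn(1, 4, 10)             # (batch, seq_len, vocab_size), toy values
labels = torch.randint(0, 10, (1, 4))      # the token ids whose log-probs we want
log_probs = Tools.probs_from_logits(logits, labels)
print(log_probs.shape)                     # torch.Size([1, 4]); values are <= 0 because they are log-probabilities

# filter_mask keeps the last `response_shape` positions and drops the masked (padding) ones
tools = Tools(response_shape=3, response_ids_mask=torch.tensor([[True, True, False]]))
print(tools.filter_mask(log_probs))        # a list with one tensor holding the 2 unmasked log-probs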