├── LICENSE.txt ├── README.md ├── config.py ├── data └── train_data.json ├── dpo.py ├── inference.py ├── main.py ├── model ├── model.py └── reference_model.py ├── requirements.txt └── utils └── data_load.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A Manual Implementation of the DPO Algorithm

## Blog post: https://zhuanlan.zhihu.com/p/702774357

## Project description
This repository implements the DPO (Direct Preference Optimization) algorithm. Because of limited personal hardware, the model being aligned is Qwen 0.5B, and LoRA is used to adapt the model parameters.

## Code organization
### Training data
data/train_data.json: hand-written data, intended for learning purposes only. Even though there are only 2 samples, fine-tuning for 3 epochs still gives a reasonable result; the outputs after training are shown below.
### model directory
model/model.py: the policy model that gets optimized.
model/reference_model.py: the reference model, which is simply the original Qwen model.
### utils directory
This directory contains utils/data_load.py, which loads and collates the training data.
### config.py: the configuration file
### dpo.py: the core DPO implementation
### main.py: the training code
### inference.py: after training finishes, use this file to check the results

## Comparison before and after fine-tuning
Input question: "Instructions: design a logo for a website about recycling."
instruction: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
### Output before fine-tuning
Dear user, thank you for your support! To better meet your needs, we have designed a logo for you.
First, we need to decide on the theme of the logo. If it is an environmental theme, we can choose colors such as green, blue, or white; if it is a recycling theme, we can choose recycling colors such as black, gray, or brown.
Next, we need to consider the size and placement of the logo. In general, the logo should sit at the center of the website so that users can find and use it quickly.
After that, we need to think about the design elements. We can use simple lines to represent recycling, such as recycling shapes and colors, and we can also add details such as a recycling symbol or pattern to make the logo more visually appealing.
Finally, we need to make sure the logo is clear and easy to read. We should avoid overly complex fonts or graphics so that users are not confused.
I hope these suggestions help you design a successful logo. If you have any questions, feel free to ask me at any time.

### Output after fine-tuning
Reply: Thank you for your idea! We are working on an eco-friendly and sustainable design that encourages people to take action to protect the environment. We can use green plants, blue, or white as the background color and add recyclable elements to the logo, such as recycled cardboard or bottles.

## With such limited training data and compute, more optimization can actually make the output worse; this project is for learning and reference only.

## One last point: during both training and inference, be sure to use greedy decoding.
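
## The objective in one formula
For reference, the objective that dpo.py implements (the blog post linked above walks through the derivation) is the standard DPO loss, where $y_w$ is the chosen response, $y_l$ the rejected response, and $\beta$ the coefficient set in config.py:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})
= -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\left(\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}-\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right)\right]
$$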
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import torch
from typing import List
from dataclasses import dataclass, field


class Config:
    # Model parameters ###########################
    # Text-generation model; download from https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat
    gpt_model = "E:\\ai_model\\model\\qwen0.5"
    data_path = "E:\\ai_model\\RLHF\\my_code\\RLHF_DPO/data/train_data.json"
    save_lora_path = "E:\\ai_model\\model\\dpo\\save_lora"
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    batch_size = 2
    epochs = 3
    lr = 0.001
    # DPO parameters ############################
    dpo_epochs = 3
    beta = 0.1


@dataclass
class LoraArguments:
    lora_r: int = 2
    lora_alpha: int = 8
    lora_dropout: float = 0
    lora_target_modules: List[str] = field(
        default_factory=lambda: ['k_proj', 'v_proj']
    )
    # lora_target_modules = None
    lora_weight_path: str = ""
    q_lora: bool = False
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    is_reload_trained_params: bool = False  # whether to resume from previously saved LoRA weights
--------------------------------------------------------------------------------
/data/train_data.json:
--------------------------------------------------------------------------------
[
    {
        "instruction": "以下是描述任务的说明。编写适当地完成请求的响应。",
        "input": "操作说明:为一个关于回收的网站设计一个标志。",
        "output": {
            "chosen": "要为一个关于回收的网站设计一个标志,我会首先集思广益代表回收理念的概念,比如传统的回收箭头符号、树叶或树木,以及大地色调。\n\n徽标的草图可能有由绿叶制成的回收符号,网站的名称位于符号下方,字体为粗体,年份字体为绿色或棕色等大地色调。另一种可能性是用一个地球仪来代表地球,并用绿色或蓝色的对比阴影包裹它的回收箭头。\n\n文本应该是平衡的和有吸引力的,清晰易读。一些首选字体可能是 Sans-serif,例如 Montserrat 或 Open Sans,但此选项将取决于网站的名称。最终结果应该是一个干净、现代和可识别的标志,反映回收网站的价值。",
            "rejected": "我很抱歉,但我无法为您的回收网站设计一个标志。设计标志需要专业的设计技能和工具,而我主要提供文本方面的帮助。建议您联系专业的平面设计师来完成这个任务。"
        }
    },
    {
        "instruction": "你好,以下是描述任务的说明。编写适当地完成请求的响应。",
        "input": "操作说明:为一个关于回收的网站设计一个标志。",
        "output": {
            "chosen": "当然!为回收网站设计一个标志是一个很棒的创意。您可以考虑使用绿色和蓝色的色调,这些颜色通常代表环保和可持续性。可以加入回收标志、地球或树木等图案,以传达环保的理念。如果您有一些具体的想法或元素,可以与设计师讨论,他们可以帮助您把这些想法变成现实。祝您设计出一个令人满意的标志!",
            "rejected": "设计一个关于回收的网站的标志并不是我擅长的领域。标志设计需要专业的平面设计技能和经验,建议您联系一位专业的设计师,他们能够为您提供更好的帮助。"
        }
    }
]
--------------------------------------------------------------------------------
/dpo.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


class DPO:
    def __init__(self, model, model_opt, config):
        self.model = model
        self.model_opt = model_opt
        self.beta = config.beta
        self.dpo_epochs = config.dpo_epochs

    def train(self, inputs_ids, attention_mask, ref_logits, labels_mask):
        # Log-probs under the reference model (computed once per batch, without gradients)
        ref_token_logps = self.probs_from_logits(ref_logits[:, :-1, :], inputs_ids[:, 1:])
        ref_logps = self.filter_mask(ref_token_logps, labels_mask)
        # Train on the same batch several times so the reference model only has to be run once
        for dpo_epoch in range(self.dpo_epochs):
            # Log-probs under the policy model
            logits = self.model(inputs_ids, attention_mask)
            policy_token_logps = self.probs_from_logits(logits[:, :-1, :], inputs_ids[:, 1:])
            policy_logps = self.filter_mask(policy_token_logps, labels_mask)
            loss = self.dpo_loss(policy_logps, ref_logps)
            self.model_opt.zero_grad()
            loss.backward()
            self.model_opt.step()
            print(loss)

    def dpo_loss(self, policy_logps, ref_logps):
        """
        Computes the DPO objective
            L_DPO(pi_theta; pi_ref) = -E[ log sigmoid( beta * ( log(pi_theta(y_w|x) / pi_ref(y_w|x))
                                                                - log(pi_theta(y_l|x) / pi_ref(y_l|x)) ) ) ]
        where y_w is the chosen response and y_l the rejected one.
        A step-by-step derivation is in the blog post: https://zhuanlan.zhihu.com/p/702774357
        :param policy_logps: log-probs of the policy model
        :param ref_logps: log-probs of the reference model
        :return: loss
        """

        def concat_probs(logps):
            """
            Split the log-probs into the rejected half and the chosen half
            (the collate function puts all rejected sequences first, then all chosen ones).
            :param logps: log-probs from the reference or the policy model
            :return: rejected and chosen log-probs
            """
            len_chosen = int(len(logps) / 2)
            rejected_data = torch.cat(logps[:len_chosen])
            chosen_data = torch.cat(logps[len_chosen:])
            return rejected_data, chosen_data

        policy_rejected_logps, policy_chosen_logps = concat_probs(policy_logps)  # split into rejected / chosen log-probs
        ref_rejected_logps, ref_chosen_logps = concat_probs(ref_logps)
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        # Difference between the policy's chosen-vs-rejected margin and the reference's margin
        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits)
        return loss.mean()

    @staticmethod
    def probs_from_logits(logits, labels):
        log_probs = F.log_softmax(logits, dim=2)
        probs = torch.gather(log_probs, 2, labels.unsqueeze(2)).squeeze(-1)
        return probs

    @staticmethod
    def filter_mask(values, labels_masks):
        """
        :param values: usually per-token quantities such as prob_old, prob_ref or value estimates
        :param labels_masks: the mask corresponding to the labels
        :return: per-sequence sums with the prompt and padding positions removed
        """
        return [value[one_response_ids_mask[:-1] == 1].sum().unsqueeze(0) for value, one_response_ids_mask in
                zip(values, labels_masks)]
--------------------------------------------------------------------------------
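Not part of the repository, but a minimal, self-contained sketch of the quantity `dpo_loss` computes, with made-up log-probabilities for a single chosen/rejected pair (the list layout mirrors what `filter_mask` returns: rejected first, then chosen):

```python
import torch
import torch.nn.functional as F

beta = 0.1  # same value as Config.beta

# Hypothetical summed response log-probs for one preference pair.
policy_logps = [torch.tensor([-12.0]), torch.tensor([-3.0])]  # policy: rejected, chosen
ref_logps = [torch.tensor([-10.0]), torch.tensor([-5.0])]     # reference: rejected, chosen

pi_logratio = policy_logps[1] - policy_logps[0]   # chosen - rejected under the policy    ->  9.0
ref_logratio = ref_logps[1] - ref_logps[0]        # chosen - rejected under the reference ->  5.0

loss = -F.logsigmoid(beta * (pi_logratio - ref_logratio)).mean()
print(loss)  # ~0.513; the loss falls as the policy favors the chosen answer more strongly than the reference does
```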
/inference.py:
--------------------------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import Config
from peft import LoraConfig, PeftModel
from config import LoraArguments


class LoraDPOModel(PeftModel):
    def __init__(self, config: Config):
        self.config = config
        model = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device)
        self.tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)
        lora_args = LoraArguments()
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            task_type="CAUSAL_LM",
        )
        super().__init__(model, lora_config)
        # Load the trained LoRA weights and merge them into the base model
        model = super().from_pretrained(model, config.save_lora_path)
        self.lora_dpo_model = model.merge_and_unload()
        # Keep an untouched copy of the original model for comparison
        self.raw_model = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device)

    def forward(self, query, instruction):
        messages = [
            {"role": "system", "content": "你是一个非常有帮助和智能的助手。"},
            {"role": "instruction", "content": instruction},
            {"role": "user", "content": query}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.config.device)
        lora_dpo_response = self.predict(model_inputs, self.lora_dpo_model, self.tokenizer)
        raw_response = self.predict(model_inputs, self.raw_model, self.tokenizer)
        return lora_dpo_response, raw_response

    @staticmethod
    def predict(model_inputs, model, tokenizer):
        # Greedy decoding, as recommended in the README
        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            num_beams=1,
            do_sample=False
        )
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response


if __name__ == '__main__':
    lora_dpo_model = LoraDPOModel(Config())
    lora_dpo_response, raw_response = lora_dpo_model("操作说明:为一个关于回收的网站设计一个标志。", "以下是描述任务的说明。编写适当地完成请求的响应。")
    print(f"lora_dpo_response:{lora_dpo_response}")
    print(f"raw_response:{raw_response}")
--------------------------------------------------------------------------------
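The class above wires the adapter in by subclassing `PeftModel`. For reference, a sketch of a more conventional (and functionally similar) way to load the trained adapter and an untouched baseline for the before/after comparison; this is an assumed alternative, not how `LoraDPOModel` is written:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

from config import Config

config = Config()

# Base model plus the saved LoRA adapter, merged back into a plain model for generation.
base = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device)
tuned = PeftModel.from_pretrained(base, config.save_lora_path).merge_and_unload()

# A second, untouched copy of the base model for comparison.
raw = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device)
```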
/main.py:
--------------------------------------------------------------------------------
from config import Config
from torch.utils.data import DataLoader
from torch.optim import Adam
from model.model import Model
from model.reference_model import ReferenceModel
from utils.data_load import CustomDataset
from dpo import DPO


class TrainDpo:
    def __init__(self):
        self.config = Config()
        # Policy model (the model being optimized)
        self.model = Model(self.config)
        self.tokenizer = self.model.tokenizer
        # Optimizer for the policy model; LoRA is used, so only the adapter parameters are trained
        self.model_opt = Adam(self.model.parameters(), lr=self.config.lr)
        # Reference model
        self.reference_model = ReferenceModel(self.config)
        # Training data
        dataset = CustomDataset(self.config.data_path, self.tokenizer)
        self.data_loader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True,
                                      collate_fn=dataset.collate_fn)
        self.dpo = DPO(self.model, self.model_opt, self.config)

    def train_dpo(self):
        for epoch in range(self.config.epochs):
            for batch_data in self.data_loader:
                ref_logits = self.reference_model(batch_data["inputs_ids"], batch_data["inputs_masks"])  # logits from the reference model
                self.dpo.train(batch_data["inputs_ids"], batch_data["inputs_masks"], ref_logits,
                               batch_data["labels_mask"])

        self.save_model()

    def save_model(self):
        # Save only the LoRA weights
        self.model.model.save_pretrained(self.config.save_lora_path, safe_serialization=False)


if __name__ == '__main__':
    train_dpo = TrainDpo()
    train_dpo.train_dpo()
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import Config
from peft import LoraConfig, PeftModel
from config import LoraArguments


class LoraModel(PeftModel):
    def __init__(self, config: Config, model):
        lora_args = LoraArguments()
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            task_type="CAUSAL_LM",
        )
        super().__init__(model, lora_config)
        if lora_args.is_reload_trained_params:
            # Optionally resume from previously saved LoRA weights
            super().from_pretrained(model, config.save_lora_path)
        # Make sure only the LoRA parameters are trainable
        for name, module in self.named_modules():
            if 'lora_' in name:
                for param in module.parameters():
                    param.requires_grad = True

    def forward(self, input_ids, attention_mask):
        res = super().forward(input_ids, attention_mask, output_hidden_states=True)
        return res.logits


class Model(torch.nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(config.gpt_model).to(config.device).eval()
        self.model = LoraModel(config, model)
        self.tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)

    def forward(self, input_ids, attention_mask):
        logits = self.model(input_ids, attention_mask)
        return logits

--------------------------------------------------------------------------------
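The repository attaches LoRA by subclassing `PeftModel` directly (above). As an optional sanity check that only the `k_proj`/`v_proj` adapter matrices end up trainable, the more common `get_peft_model` helper from peft can be used instead; this sketch assumes the Qwen checkpoint referenced in `config.py` is available locally:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

from config import Config, LoraArguments

lora_args = LoraArguments()
base = AutoModelForCausalLM.from_pretrained(Config.gpt_model)
peft_model = get_peft_model(
    base,
    LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        task_type="CAUSAL_LM",
    ),
)
# Prints something like "trainable params: ... || all params: ... || trainable%: ...",
# and only a tiny fraction of the 0.5B weights should be trainable.
peft_model.print_trainable_parameters()
```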
/model/reference_model.py:
--------------------------------------------------------------------------------
from transformers import AutoModelForCausalLM
import torch.nn as nn
import torch


class ReferenceModel(nn.Module):
    def __init__(self, config):
        super(ReferenceModel, self).__init__()
        self.config = config
        # The reference model is the original, frozen Qwen model
        self.reference_model = AutoModelForCausalLM.from_pretrained(self.config.gpt_model, torch_dtype="auto").to(
            self.config.device)

    @torch.no_grad()
    def forward(self, input_ids, attention_mask):
        logits = self.reference_model(input_ids=input_ids,
                                      attention_mask=attention_mask).logits
        return logits

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers==4.40.0
torch==2.3.0
peft==0.10.0
--------------------------------------------------------------------------------
/utils/data_load.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset, DataLoader
import json


class CustomDataset(Dataset):
    def __init__(self, data_file, tokenizer):
        self.tokenizer = tokenizer
        with open(data_file, 'r', encoding="utf-8") as f:
            self.data = json.load(f)
        self.total_samples = len(self.data)

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # Load one sample by index; any per-sample preprocessing can be added here
        line = self.data[idx]
        query = line["input"]
        instruction = line["instruction"]
        rejected = line["output"]["rejected"]
        chosen = line["output"]["chosen"]
        messages = [
            {"role": "system", "content": "你是一个非常有帮助和智能的助手。"},
            {"role": "instruction", "content": instruction},
            {"role": "user", "content": query},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        prompt_inputs = self.tokenizer.encode(text=text)
        rejected_inputs = self.tokenizer.encode(text=rejected)
        chosen_inputs = self.tokenizer.encode(text=chosen)
        return [prompt_inputs, rejected_inputs, chosen_inputs]

    def collate_fn(self, batch):
        def pad_and_mask(data, labels_mask):
            # Pad every sequence to the longest one in the batch and build the matching masks
            max_length = max([len(i) for i in data])
            attention_masks = []
            return_ids = []
            labels_masks = []
            for one, mask in zip(data, labels_mask):
                padding_num = max_length - len(one)
                return_ids.append(one + [self.tokenizer.pad_token_id] * padding_num)
                labels_masks.append(mask + [0] * padding_num)
                attention_masks.append([1] * len(one) + [0] * padding_num)
            return return_ids, attention_masks, labels_masks

        inputs_ids = []
        labels_mask = []
        # i == 1 selects the rejected responses, i == 2 the chosen ones,
        # so the first half of the batch is rejected and the second half is chosen
        for i in range(1, 3):
            for one in batch:
                prompt_input = one[0]
                res = one[i]
                inputs_ids.append(prompt_input + res)
                # labels_mask is 1 only on the response tokens, 0 on the prompt
                labels_mask.append([0] * len(prompt_input) + [1] * len(res))
        inputs_ids, inputs_masks, labels_masks = pad_and_mask(inputs_ids, labels_mask)
        return {"inputs_ids": torch.tensor(inputs_ids), "inputs_masks": torch.tensor(inputs_masks),
                "labels_mask": torch.tensor(labels_masks)}


if __name__ == '__main__':
    from config import Config
    from transformers import AutoTokenizer

    config = Config()
    tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)
    # Build the custom dataset
    dataset = CustomDataset(config.data_path, tokenizer)

    # Build the data loader with the desired batch size
    batch_size = 2
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate_fn)

    # Iterate over the batches to inspect the collated tensors
    for batch in data_loader:
        # batch contains one collated batch of samples; model training would happen here
        print(batch["inputs_ids"].shape)
--------------------------------------------------------------------------------
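Not part of the repository, but a small sketch of the batch layout that `collate_fn` produces, which is the ordering `DPO.dpo_loss` relies on when it splits the batch into a rejected half and a chosen half (the paths come from `config.py` and must point at a local Qwen checkpoint and the training JSON):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from config import Config
from utils.data_load import CustomDataset

config = Config()
tokenizer = AutoTokenizer.from_pretrained(config.gpt_model)
dataset = CustomDataset(config.data_path, tokenizer)
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)

batch = next(iter(loader))
# With batch_size=2 the collated tensors hold 4 rows:
#   row 0: prompt_0 + rejected_0      row 2: prompt_0 + chosen_0
#   row 1: prompt_1 + rejected_1      row 3: prompt_1 + chosen_1
# "inputs_masks" marks real tokens vs. padding, while "labels_mask" marks only the response
# tokens, which is what filter_mask() sums over before dpo_loss() splits the two halves.
print(batch["inputs_ids"].shape, batch["inputs_masks"].shape, batch["labels_mask"].shape)
```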