├── model ├── __init__.py ├── tokenizer_config.json └── MokioModel.py ├── dataset ├── __init__.py └── lm_dataset.py ├── .python-version ├── torch_methods ├── unsqueeze_d.py ├── transpose_d.py ├── outer_d.py ├── arange_d.py ├── Dropout_d.py ├── view_d.py ├── cat_d.py ├── where_d.py └── Linear_d.py ├── main.py ├── .vscode └── settings.json ├── .gitignore ├── README.md ├── pyproject.toml ├── eval.py └── trainer ├── trainer_utils.py ├── train_lora.py ├── train_full_sft.py ├── train_pretrain.py ├── train_dpo.py ├── train_ppo.py └── train_grpo.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.13 2 | -------------------------------------------------------------------------------- /torch_methods/unsqueeze_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | t1 = torch.Tensor([1, 2, 3]) 3 | t2 = t1.unsqueeze(0) 4 | print(t2) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | def main(): 2 | print("Hello from mokiomind!") 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /torch_methods/transpose_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | t1=torch.Tensor([[1,2,3],[4,5,6]]) 4 | t1=t1.transpose(0,1) 5 | print(t1) -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python-envs.defaultEnvManager": "ms-python.python:venv", 3 | "python-envs.pythonProjects": [] 4 | } -------------------------------------------------------------------------------- /torch_methods/outer_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | v1=torch.tensor([1,2,3]) 3 | v2=torch.tensor([4,5,6]) 4 | result=torch.outer(v1,v2) 5 | print(result) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | 9 | # Virtual environments 10 | .venv 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Please support the original project minimind: https://github.com/jingyaogong/minimind 2 | 3 | ### This project still has incomplete and incorrect parts; follow the videos first, and the GitHub side will be filled in gradually 4 | 5 | ### Please also give the other projects a star 🌟 6 | 7 | ### Thank you for your support 8 | -------------------------------------------------------------------------------- /torch_methods/arange_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | t=torch.arange(0,10,2) 4 | print(t) # Output: tensor([0, 2, 4, 6, 8]) 5 | 6 | t2=torch.arange(5,0,-1) 7 | print(t2) # Output: tensor([5, 4, 3, 2, 1]) 
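8 | # Note: torch.arange excludes the end value (like Python's range); a negative step, as in the second example, counts down.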
-------------------------------------------------------------------------------- /torch_methods/Dropout_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | dropout_layer = nn.Dropout(p=0.5) 5 | 6 | t1=torch.Tensor([1,2,3]) 7 | t2=dropout_layer(t1) 8 | # 这里Dropout丢弃了1,为了保持期望不变,将1和3扩大两倍 9 | print(t2) -------------------------------------------------------------------------------- /torch_methods/view_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | t = torch.tensor([[ 1, 2, 3, 4, 5, 6], 3 | [ 7, 8, 9, 10, 11, 12]]) 4 | t_view1 = t.view(3, 4) 5 | print(t_view1) 6 | t_view2 = t.view(4, 3) 7 | print(t_view2) -------------------------------------------------------------------------------- /torch_methods/cat_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | t1=torch.tensor([[1,2,3],[4,5,6]]) 4 | t2=torch.tensor([[7,8,9],[10,11,12]]) 5 | result=torch.cat((t1,t2),dim=0) 6 | print(result) 7 | 8 | result2=torch.cat((t1,t2),dim=1) 9 | print(result2) -------------------------------------------------------------------------------- /torch_methods/where_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | x=torch.tensor([1, 2, 3, 4, 5]) 4 | y=torch.tensor([10, 20, 30, 40, 50]) 5 | 6 | condition=(x>3) 7 | 8 | result=torch.where(condition,x,y) 9 | 10 | print(result) # Output: tensor([10, 20, 30, 4, 5]) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mokiomind" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.13" 7 | dependencies = [ 8 | "numpy>=2.3.4", 9 | "pandas>=2.3.3", 10 | "torch>=2.9.0", 11 | "transformers>=4.57.1", 12 | ] 13 | -------------------------------------------------------------------------------- /torch_methods/Linear_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | layer = nn.Linear(in_features=3, out_features=5, bias=True) 5 | t1 = torch.Tensor([1, 2, 3]) # shape: (3,) 6 | 7 | t2 = torch.Tensor([[1, 2, 3]]) # shape: (1, 3) 8 | # 这里应用的w和b是随机的,真实训练里会在optimizer上更新 9 | output2 = layer(t2) # shape: (1, 5) 10 | print(output2) -------------------------------------------------------------------------------- /dataset/lm_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from torch.utils.data import Dataset 4 | import torch 5 | import os 6 | 7 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 8 | 9 | 10 | class PretrainDataset(Dataset): 11 | def __init__(self, data_path, tokenizer, max_length=512): 12 | super().__init__() 13 | self.tokenizer = tokenizer 14 | self.max_length = max_length 15 | self.samples = self.load_data(data_path) 16 | 17 | def load_data(self, path): 18 | samples = [] 19 | with open(path, "r", encoding="utf-8") as f: 20 | for line_num, line in enumerate(f, 1): 21 | # 提取每一行内容放到sample 22 | data = json.loads(line.strip()) 23 | samples.append(data) 24 | return samples 25 | 26 | def __len__(self): 27 | return len(self.samples) 28 | 29 | def __getitem__(self, index): 30 | sample = self.samples[index] 31 | # 用tokenizer进行编码 32 | # 超过max_length的截断,不到的填充 
33 | encoding = self.tokenizer( 34 | str(sample["text"]), 35 | max_length=self.max_length, 36 | padding="max_length", 37 | truncation=True, 38 | return_tensors="pt", 39 | ) 40 | 41 | input_ids = encoding.input_ids.squeeze() 42 | # 忽略padding产生的Y 43 | loss_mask = input_ids != self.tokenizer.pad_token_id 44 | # 第一个到倒数第二个token 45 | X = torch.tensor(input_ids[:-1], dtype=torch.long) 46 | # 第二个到最后一个token 47 | Y = torch.tensor(input_ids[1:], dtype=torch.long) 48 | loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) 49 | return X, Y, loss_mask 50 | -------------------------------------------------------------------------------- /model/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": false, 3 | "add_eos_token": false, 4 | "add_prefix_space": false, 5 | "added_tokens_decoder": { 6 | "0": { 7 | "content": "<|endoftext|>", 8 | "lstrip": false, 9 | "normalized": false, 10 | "rstrip": false, 11 | "single_word": false, 12 | "special": true 13 | }, 14 | "1": { 15 | "content": "<|im_start|>", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false, 20 | "special": true 21 | }, 22 | "2": { 23 | "content": "<|im_end|>", 24 | "lstrip": false, 25 | "normalized": false, 26 | "rstrip": false, 27 | "single_word": false, 28 | "special": true 29 | } 30 | }, 31 | "additional_special_tokens": [], 32 | "bos_token": "<|im_start|>", 33 | "clean_up_tokenization_spaces": false, 34 | "eos_token": "<|im_end|>", 35 | "legacy": true, 36 | "model_max_length": 32768, 37 | "pad_token": "<|endoftext|>", 38 | "sp_model_kwargs": {}, 39 | "spaces_between_special_tokens": false, 40 | "tokenizer_class": "PreTrainedTokenizerFast", 41 | "unk_token": "<|endoftext|>", 42 | "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call 
in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}" 43 | } -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import warnings 4 | import numpy as np 5 | import torch 6 | from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer 7 | from model.MokioModel import MokioMindConfig, MokioMindForCausalLM 8 | from trainer.trainer_utils import setup_seed 9 | 10 | warnings.filterwarnings("ignore") 11 | 12 | 13 | def init_model(args): 14 | tokenizer = AutoTokenizer.from_pretrained(args.load_from) 15 | if "model" in args.load_from: 16 | model = MokioMindForCausalLM( 17 | MokioMindConfig( 18 | hidden_size=args.hidden_size, 19 | num_hidden_layers=args.num_hidden_layers, 20 | inference_rope_scaling=args.inference_rope_scaling, 21 | ) 22 | ) 23 | moe_suffix = "_moe" if hasattr(args, "use_moe") and args.use_moe else "" 24 | ckp = f"./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth" 25 | model.load_state_dict(torch.load(ckp, map_location=args.device), strict=False) 26 | else: 27 | model = AutoModelForCausalLM.from_pretrained( 28 | args.load_from, trust_remote_code=True 29 | ) 30 | print( 31 | f"MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)" 32 | ) 33 | return model.eval().to(args.device), tokenizer 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser(description="MiniMind模型推理与对话") 37 | parser.add_argument( 38 | "--load_from", 39 | default="model", 40 | type=str, 41 | help="模型加载路径(model=原生torch权重,其他路径=transformers格式)", 42 | ) 43 | parser.add_argument("--save_dir", default="out", type=str, help="模型权重目录") 44 | parser.add_argument( 45 | "--weight", 46 | default="full_sft", 47 | type=str, 48 | help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)", 49 | ) 50 | parser.add_argument( 51 | "--lora_weight", 52 | default="None", 53 | type=str, 54 | help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)", 55 | ) 56 | parser.add_argument( 57 | "--hidden_size", 58 | default=512, 59 | type=int, 60 | help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)", 61 | ) 62 | parser.add_argument( 63 | "--num_hidden_layers", 64 | default=8, 65 | type=int, 66 | help="隐藏层数量(Small/MoE=8, Base=16)", 67 | ) 68 | parser.add_argument( 69 | "--use_moe", 70 | default=0, 71 | type=int, 72 | choices=[0, 1], 73 | help="是否使用MoE架构(0=否,1=是)", 74 | ) 75 | parser.add_argument( 76 | "--inference_rope_scaling", 77 | default=False, 78 | 
action="store_true", 79 | help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)", 80 | ) 81 | parser.add_argument( 82 | "--max_new_tokens", 83 | default=8192, 84 | type=int, 85 | help="最大生成长度(注意:并非模型实际长文本能力)", 86 | ) 87 | parser.add_argument( 88 | "--temperature", 89 | default=0.85, 90 | type=float, 91 | help="生成温度,控制随机性(0-1,越大越随机)", 92 | ) 93 | parser.add_argument( 94 | "--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)" 95 | ) 96 | parser.add_argument( 97 | "--historys", 98 | default=0, 99 | type=int, 100 | help="携带历史对话轮数(需为偶数,0表示不携带历史)", 101 | ) 102 | parser.add_argument( 103 | "--device", 104 | default="cuda" if torch.cuda.is_available() else "cpu", 105 | type=str, 106 | help="运行设备", 107 | ) 108 | args = parser.parse_args() 109 | 110 | prompts = [ 111 | "你有什么特长?", 112 | "为什么天空是蓝色的", 113 | "请用Python写一个计算斐波那契数列的函数", 114 | '解释一下"光合作用"的基本过程', 115 | "如果明天下雨,我应该如何出门", 116 | "比较一下猫和狗作为宠物的优缺点", 117 | "解释什么是机器学习", 118 | "推荐一些中国的美食", 119 | ] 120 | 121 | conversation = [] 122 | model, tokenizer = init_model(args) 123 | input_mode = int(input("[0] 自动测试\n[1] 手动输入\n")) 124 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 125 | 126 | prompt_iter = prompts if input_mode == 0 else iter(lambda: input("👶: "), "") 127 | for prompt in prompt_iter: 128 | setup_seed(2026) # or setup_seed(random.randint(0, 2048)) 129 | if input_mode == 0: 130 | print(f"👶: {prompt}") 131 | conversation = conversation[-args.historys :] if args.historys else [] 132 | conversation.append({"role": "user", "content": prompt}) 133 | 134 | templates = { 135 | "conversation": conversation, 136 | "tokenize": False, 137 | "add_generation_prompt": True, 138 | } 139 | if args.weight == "reason": 140 | templates["enable_thinking"] = True # 仅Reason模型使用 141 | inputs = ( 142 | tokenizer.apply_chat_template(**templates) 143 | if args.weight != "pretrain" 144 | else (tokenizer.bos_token + prompt) 145 | ) 146 | inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device) 147 | 148 | print("🤖️: ", end="") 149 | generated_ids = model.generate( 150 | inputs=inputs["input_ids"], 151 | attention_mask=inputs["attention_mask"], 152 | max_new_tokens=args.max_new_tokens, 153 | do_sample=True, 154 | streamer=streamer, 155 | pad_token_id=tokenizer.pad_token_id, 156 | eos_token_id=tokenizer.eos_token_id, 157 | top_p=args.top_p, 158 | temperature=args.temperature, 159 | repetition_penalty=1.0, 160 | ) 161 | response = tokenizer.decode( 162 | generated_ids[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True 163 | ) 164 | conversation.append({"role": "assistant", "content": response}) 165 | print("\n\n") 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /trainer/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data import Sampler 8 | 9 | # 检查是否是主进程 10 | def is_main_process(): 11 | return not dist.is_initialized() or dist.get_rank() == 0 12 | 13 | # 日志 14 | def Logger(content): 15 | if is_main_process(): 16 | print(content) 17 | 18 | # 动态学习率计算 19 | def get_lr(current_step, total_steps, lr): 20 | return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps)) 21 | 22 | # 初始化分布式 23 | def init_distributed_mode(): 24 | if int(os.environ.get("RANK", -1)) == -1: 25 | return 0 # 非DDP模式 26 | 27 | 
dist.init_process_group(backend="nccl") 28 | local_rank = int(os.environ["LOCAL_RANK"]) 29 | torch.cuda.set_device(local_rank) 30 | return local_rank 31 | 32 | # 设置种子 33 | def setup_seed(seed: int): 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | 42 | # 设置检查点 43 | def lm_checkpoint( 44 | lm_config, 45 | weight="full_sft", 46 | model=None, 47 | optimizer=None, 48 | epoch=0, 49 | step=0, 50 | wandb=None, 51 | save_dir="checkpoints", 52 | **kwargs, 53 | ): 54 | os.makedirs(save_dir, exist_ok=True) 55 | 56 | moe_path = "_moe" if hasattr(lm_config, "use_moe") and lm_config.use_moe else "" 57 | ckp_path = f"{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth" 58 | resume_path = f"{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth" 59 | 60 | if model is not None: 61 | from torch.nn.parallel import DistributedDataParallel 62 | 63 | if isinstance(model, DistributedDataParallel): 64 | state_dict = model.module.state_dict() 65 | else: 66 | state_dict = model.state_dict() 67 | 68 | ckp_tmp = ckp_path + ".tmp" 69 | torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp) 70 | os.replace(ckp_tmp, ckp_path) 71 | 72 | wandb_id = None 73 | if wandb: 74 | if hasattr(wandb, "get_run"): 75 | run = wandb.get_run() 76 | wandb_id = getattr(run, "id", None) if run else None 77 | else: 78 | wandb_id = getattr(wandb, "id", None) 79 | 80 | resume_data = { 81 | "model": state_dict, 82 | "optimizer": optimizer.state_dict(), 83 | "epoch": epoch, 84 | "step": step, 85 | "world_size": dist.get_world_size() if dist.is_initialized() else 1, 86 | "wandb_id": wandb_id, 87 | } 88 | 89 | for key, value in kwargs.items(): 90 | if value is not None: 91 | if hasattr(value, "state_dict"): 92 | if isinstance(value, DistributedDataParallel): 93 | resume_data[key] = value.module.state_dict() 94 | else: 95 | resume_data[key] = value.state_dict() 96 | else: 97 | resume_data[key] = value 98 | 99 | resume_tmp = resume_path + ".tmp" 100 | torch.save(resume_data, resume_tmp) 101 | os.replace(resume_tmp, resume_path) 102 | 103 | else: # 加载模式 104 | if os.path.exists(resume_path): 105 | ckp_data = torch.load(resume_path, map_location="cpu") 106 | saved_ws = ckp_data.get("world_size", 1) 107 | current_ws = dist.get_world_size() if dist.is_initialized() else 1 108 | 109 | if saved_ws != current_ws: 110 | ckp_data["step"] = ckp_data["step"] * saved_ws // current_ws 111 | Logger( 112 | f"GPU数量变化({saved_ws}→{current_ws}),step已自动转换为{ckp_data['step']}" 113 | ) 114 | 115 | return ckp_data 116 | return None 117 | 118 | # 初始化模型 119 | def init_model( 120 | lm_config, 121 | from_weight="pretrain", 122 | tokenizer_path=None, 123 | save_dir="out", 124 | device="cuda", 125 | ): 126 | from transformers import AutoTokenizer 127 | from model.MokioModel import MokioMindForCausalLM 128 | 129 | # 如果没有指定 tokenizer_path,使用项目根目录下的 model 文件夹 130 | if tokenizer_path is None: 131 | # 获取当前文件所在目录的父目录(项目根目录) 132 | current_dir = os.path.dirname(os.path.abspath(__file__)) 133 | project_root = os.path.dirname(current_dir) 134 | tokenizer_path = os.path.join(project_root, "model") 135 | 136 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 137 | 138 | model = MokioMindForCausalLM(lm_config) 139 | 140 | if from_weight != "none": 141 | moe_suffix = ( 142 | "_moe" if hasattr(lm_config, "use_moe") and lm_config.use_moe else "" 143 | ) 144 | 
weight_path = ( 145 | f"{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth" 146 | ) 147 | 148 | weights = torch.load(weight_path, map_location=device) 149 | 150 | model.load_state_dict(weights, strict=False) 151 | 152 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 153 | Logger(f"所加载Model可训练参数:{total_params / 1e6:.3f} 百万") 154 | 155 | return model.to(device), tokenizer 156 | 157 | 158 | class SkipBatchSampler(Sampler): 159 | def __init__(self, sampler, batch_size, skip_batches=0): 160 | self.sampler = sampler # 161 | self.batch_size = batch_size 162 | self.skip_batches = skip_batches 163 | 164 | def __iter__(self): 165 | batch = [] # 当前批次 166 | skipped = 0 # 已跳过的批次数 167 | 168 | for idx in self.sampler: 169 | batch.append(idx) # 添加样本到当前批次 170 | 171 | if len(batch) == self.batch_size: 172 | if skipped < self.skip_batches: 173 | skipped += 1 # 增加跳过计数 174 | batch = [] # 清空批次,不返回 175 | continue # 跳过这个批次 176 | 177 | yield batch 178 | batch = [] # 重置批次 179 | 180 | if len(batch) > 0 and skipped >= self.skip_batches: 181 | yield batch 182 | 183 | def __len__(self): 184 | total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size 185 | 186 | return max(0, total_batches - self.skip_batches) 187 | -------------------------------------------------------------------------------- /trainer/train_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | __package__ = "trainer" 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 6 | 7 | import argparse 8 | import time 9 | import warnings 10 | import torch 11 | import torch.distributed as dist 12 | from contextlib import nullcontext 13 | from torch import optim, nn 14 | from torch.nn.parallel import DistributedDataParallel 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from model.model_minimind import MiniMindConfig 17 | from dataset.lm_dataset import SFTDataset 18 | from model.model_lora import save_lora, apply_lora 19 | from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler 20 | 21 | warnings.filterwarnings('ignore') 22 | 23 | def train_epoch(epoch,loader,iters,lora_params,start_step=0,wandb=None): 24 | # reduction='none'表示不对损失进行平均处理 25 | loss_fct=nn.CrossEntropyLoss(reduction='none') 26 | start_time=time.time() 27 | for step,(X,Y,loss_mask) in enumerate(loader,start=start_step+1): 28 | X=X.to(args.device) 29 | Y=Y.to(args.device) 30 | loss_mask=loss_mask.to(args.device) 31 | # 动态调整学习率 32 | # 余弦退火 33 | lr=get_lr(epoch*iters+step,args.epochs*iters,args.learning_rate) 34 | 35 | 36 | for param_group in optimizer.param_groups: 37 | param_group['lr']=lr 38 | 39 | # 混合精度上下文训练 40 | with autocast_ctx: 41 | # 前向传播 42 | res=model(X) 43 | 44 | # 损失计算 45 | loss=loss_fct(res.logits.view(-1,res.logits.size(-1)),Y.view(-1)) 46 | 47 | loss = (loss * loss_mask).sum() / loss_mask.sum() 48 | loss += res.aux_loss 49 | loss = loss / args.accumulation_steps 50 | # 混合精度反向传播 51 | scaler.scale(loss).backward() 52 | if (step+1)%args.accumulation_steps==0: 53 | scaler.unscale_(optimizer) 54 | # 梯度裁剪,防止梯度爆炸 55 | torch.nn.utils.clip_grad_norm_(lora_params,args.grad_clip) 56 | 57 | scaler.step(optimizer) 58 | scaler.update() 59 | 60 | optimizer.zero_grad(set_to_none=True) 61 | 62 | # 每log_interval步或者最后一步打印一次日志 63 | if step % args.log_interval == 0 or step == iters - 1: 64 | spend_time = time.time() - start_time 
65 | current_loss = loss.item() * args.accumulation_steps 66 | current_lr = optimizer.param_groups[-1]['lr'] 67 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 68 | 69 | Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:') 70 | 71 | if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min}) 72 | 73 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): 74 | # 评估测评模型 75 | model.eval() 76 | 77 | # LoRA保存,只保存AB矩阵 78 | lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth' 79 | # LoRA只保存LoRA权重 80 | save_lora(model, lora_save_path) 81 | lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') 82 | model.train() 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning") 87 | parser.add_argument("--save_dir", type=str, default="../out/lora", help="模型保存目录") 88 | parser.add_argument("--lora_name", type=str, default="lora_identity", help="LoRA权重名称(如lora_identity/lora_medical等)") 89 | parser.add_argument("--epochs", type=int, default=50, help="训练轮数") 90 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 91 | parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率") 92 | parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") 93 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") 94 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") 95 | parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") 96 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") 97 | parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔") 98 | parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔") 99 | parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") 100 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") 101 | parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度") 102 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") 103 | parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径") 104 | parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft") 105 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)") 106 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") 107 | parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名") 108 | args = parser.parse_args() 109 | 110 | # ========== 1. 初始化环境和随机种子 ========== 111 | local_rank = init_distributed_mode() 112 | if dist.is_initialized(): args.device = f"cuda:{local_rank}" 113 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) 114 | 115 | # ========== 2. 
配置目录、模型参数、检查ckp ========== 116 | os.makedirs(args.save_dir, exist_ok=True) 117 | lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) 118 | ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None 119 | 120 | # ========== 3. 设置混合精度 ========== 121 | device_type = "cuda" if "cuda" in args.device else "cpu" 122 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 123 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) 124 | 125 | # ========== 4. 配wandb ========== 126 | wandb = None 127 | if args.use_wandb and is_main_process(): 128 | import swanlab as wandb 129 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None 130 | resume = 'must' if wandb_id else None 131 | wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" 132 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) 133 | 134 | # ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ========== 135 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) 136 | apply_lora(model) 137 | 138 | # 统计参数 139 | total_params = sum(p.numel() for p in model.parameters()) 140 | lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) 141 | Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M") 142 | Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M") 143 | Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%") 144 | 145 | # 冻结非LoRA参数,收集LoRA参数 146 | lora_params = [] 147 | for name, param in model.named_parameters(): 148 | if 'lora' in name: 149 | param.requires_grad = True 150 | lora_params.append(param) 151 | else: 152 | param.requires_grad = False 153 | 154 | # ========== 6. 定义数据和优化器 ========== 155 | train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len) 156 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None 157 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) 158 | optimizer = optim.AdamW(lora_params, lr=args.learning_rate) 159 | 160 | # ========== 7. 从ckp恢复状态 ========== 161 | start_epoch, start_step = 0, 0 162 | if ckp_data: 163 | model.load_state_dict(ckp_data['model'], strict=False) 164 | optimizer.load_state_dict(ckp_data['optimizer']) 165 | scaler.load_state_dict(ckp_data['scaler']) 166 | start_epoch = ckp_data['epoch'] 167 | start_step = ckp_data.get('step', 0) 168 | 169 | # ========== 8. DDP包模型 ========== 170 | if dist.is_initialized(): 171 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} 172 | model = DistributedDataParallel(model, device_ids=[local_rank]) 173 | 174 | # ========== 9. 
开始训练 ========== 175 | for epoch in range(start_epoch, args.epochs): 176 | train_sampler and train_sampler.set_epoch(epoch) 177 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 178 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) 179 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) 180 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') 181 | train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb) 182 | else: # 默认从头开始 183 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) 184 | train_epoch(epoch, loader, len(loader), lora_params, 0, wandb) 185 | -------------------------------------------------------------------------------- /trainer/train_full_sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # 📚 Python模块系统和路径管理 5 | __package__ = "trainer" 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | 8 | import argparse # 命令行参数解析 9 | import time # 时间测量 10 | import warnings # 警告控制 11 | import torch # PyTorch深度学习框架 12 | import torch.distributed as dist # 分布式训练 13 | from contextlib import nullcontext # 上下文管理器 14 | from torch import optim, nn # 优化器和神经网络 15 | from torch.nn.parallel import DistributedDataParallel # 分布式并行 16 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载 17 | 18 | from model.model_minimind import MiniMindConfig # 模型配置 19 | from dataset.lm_dataset import SFTDataset # SFT数据集 20 | from trainer.trainer_utils import ( # 训练工具 21 | get_lr, Logger, is_main_process, lm_checkpoint, 22 | init_distributed_mode, setup_seed, init_model, SkipBatchSampler 23 | ) 24 | 25 | warnings.filterwarnings('ignore') 26 | 27 | def train_epoch(epoch,loader,iters,start_step=0,wandb=None): 28 | loss_fct=nn.CrossEntropyLoss(reduction='none') 29 | start_time=time.time() 30 | 31 | for step,(X,Y,loss_mask) in enumerate(loader,start=start_step+1): 32 | X=X.to(args.device) 33 | Y=Y.to(args.device) 34 | loss_mask=loss_mask.to(args.device) 35 | 36 | # 动态调整学习率 37 | lr=get_lr(epoch*iters+step,args.epochs*iters,args.learning_rate) 38 | for param_group in optimizer.param_groups: 39 | param_group['lr']=lr 40 | 41 | with autocast_ctx: 42 | # 前向传播 43 | res=model(X) 44 | 45 | # 损失计算 46 | loss=loss_fct(res.logits.view(-1,res.logits.size(-1)),Y.view(-1)).view(Y.size()) 47 | 48 | loss= (loss * loss_mask).sum() / loss_mask.sum() 49 | 50 | loss+=res.aux_loss 51 | 52 | loss=loss/args.accumulation_steps 53 | 54 | scaler.scale(loss).backward() 55 | 56 | if (step+1)%args.accumulation_steps==0: 57 | scaler.unscale_(optimizer) 58 | 59 | torch.nn.utils.clip_grad_norm_(model.parameters(),args.grad_clip) 60 | 61 | scaler.step(optimizer) 62 | scaler.update() 63 | 64 | optimizer.zero_grad(set_to_none=True) 65 | 66 | if step % args.log_interval == 0 or step == iters - 1: 67 | spend_time = time.time() - start_time 68 | current_loss = loss.item() * args.accumulation_steps # 恢复真实损失值 69 | current_lr = optimizer.param_groups[-1]['lr'] 70 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 71 | 72 | Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:') 73 | 74 | # 记录到实验跟踪系统 75 | if wandb: 76 | wandb.log({"loss": current_loss, 
"lr": current_lr, "epoch_Time": eta_min}) 77 | 78 | # 📚 SFT模型检查点保存 79 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): 80 | model.eval() # 切换到评估模式 81 | 82 | # 构建SFT模型保存路径 83 | moe_suffix = '_moe' if lm_config.use_moe else '' 84 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' 85 | 86 | # 处理分布式模型的状态字典 87 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 88 | state_dict = model.module.state_dict() 89 | else: 90 | state_dict = model.state_dict() 91 | 92 | # 半精度保存节省空间 93 | state_dict = {k: v.half() for k, v in state_dict.items()} 94 | torch.save(state_dict, ckp) 95 | # 保存完整训练状态 96 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 97 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler) 98 | model.train() # 恢复训练模式 99 | 100 | if __name__ == "__main__": 101 | """ 102 | SFT主函数:监督微调脚本的入口点 103 | 104 | 📚 SFT与预训练的参数差异: 105 | - 学习率更小:5e-7 vs 5e-4(预训练) 106 | - 训练轮数较少:2轮 vs 多轮 107 | - batch_size可以更小:已有基础能力,不需要大batch 108 | - 累积步数通常为1:SFT数据质量高,不需要太多累积 109 | """ 110 | 111 | parser = argparse.ArgumentParser(description="MiniMind Full SFT") 112 | 113 | # ========== 基础训练参数 ========== 114 | parser.add_argument("--save_dir", type=str, default="../out", 115 | help="模型保存目录") 116 | parser.add_argument('--save_weight', default='full_sft', type=str, 117 | help="保存权重的前缀名") 118 | parser.add_argument("--epochs", type=int, default=2, 119 | help="训练轮数(SFT通常2-5轮即可)") 120 | parser.add_argument("--batch_size", type=int, default=16, 121 | help="batch size(SFT可以使用较小的batch)") 122 | 123 | # 📚 SFT学习率设置知识点 124 | # SFT学习率通常比预训练小1-2个数量级 125 | # 因为模型已经有了基础能力,只需要微调 126 | parser.add_argument("--learning_rate", type=float, default=5e-7, 127 | help="初始学习率(比预训练小很多)") 128 | 129 | # ========== 硬件配置 ========== 130 | parser.add_argument("--device", type=str, 131 | default="cuda:0" if torch.cuda.is_available() else "cpu", 132 | help="训练设备") 133 | parser.add_argument("--dtype", type=str, default="bfloat16", 134 | help="混合精度类型") 135 | parser.add_argument("--num_workers", type=int, default=1, 136 | help="数据加载线程数") 137 | 138 | # ========== 训练策略 ========== 139 | # 📚 SFT梯度累积知识点 140 | # SFT数据质量高,通常不需要大量梯度累积 141 | # accumulation_steps=1 意味着每个batch都更新参数 142 | parser.add_argument("--accumulation_steps", type=int, default=1, 143 | help="梯度累积步数(SFT通常设为1)") 144 | parser.add_argument("--grad_clip", type=float, default=1.0, 145 | help="梯度裁剪阈值") 146 | parser.add_argument("--log_interval", type=int, default=100, 147 | help="日志打印间隔") 148 | parser.add_argument("--save_interval", type=int, default=100, 149 | help="模型保存间隔") 150 | 151 | # ========== 模型架构参数 ========== 152 | parser.add_argument('--hidden_size', default=512, type=int, 153 | help="隐藏层维度") 154 | parser.add_argument('--num_hidden_layers', default=8, type=int, 155 | help="隐藏层数量") 156 | parser.add_argument('--max_seq_len', default=512, type=int, 157 | help="训练的最大截断长度") 158 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], 159 | help="是否使用MoE架构(0=否,1=是)") 160 | 161 | # ========== SFT数据和恢复参数 ========== 162 | # 📚 SFT数据路径知识点 163 | # SFT数据通常是结构化的问答对或对话数据 164 | # 包含instruction和response两部分 165 | parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", 166 | help="SFT训练数据路径") 167 | 168 | # 📚 SFT权重继承知识点 169 | # SFT通常从预训练模型开始,而不是从零开始 170 | # 'pretrain'表示从预训练权重开始微调 171 | parser.add_argument('--from_weight', default='pretrain', type=str, 172 | help="基于哪个权重训练(通常从预训练权重开始)") 173 | parser.add_argument('--from_resume', 
default=0, type=int, choices=[0, 1], 174 | help="是否自动检测&续训(0=否,1=是)") 175 | 176 | # ========== 实验跟踪 ========== 177 | parser.add_argument("--use_wandb", action="store_true", 178 | help="是否使用wandb") 179 | parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", 180 | help="wandb项目名") 181 | 182 | args = parser.parse_args() 183 | 184 | # ========== 1. 初始化环境和随机种子 ========== 185 | local_rank = init_distributed_mode() 186 | if dist.is_initialized(): args.device = f"cuda:{local_rank}" 187 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) 188 | 189 | # ========== 2. 配置目录、模型参数、检查ckp ========== 190 | os.makedirs(args.save_dir, exist_ok=True) 191 | lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) 192 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None 193 | 194 | # ========== 3. 设置混合精度 ========== 195 | device_type = "cuda" if "cuda" in args.device else "cpu" 196 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 197 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) 198 | 199 | # ========== 4. 配wandb ========== 200 | wandb = None 201 | if args.use_wandb and is_main_process(): 202 | import swanlab as wandb 203 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None 204 | resume = 'must' if wandb_id else None 205 | wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" 206 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) 207 | 208 | # ========== 5. 定义模型、数据、优化器 ========== 209 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) 210 | train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len) 211 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None 212 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) 213 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) 214 | 215 | # ========== 6. 从ckp恢复状态 ========== 216 | start_epoch, start_step = 0, 0 217 | if ckp_data: 218 | model.load_state_dict(ckp_data['model']) 219 | optimizer.load_state_dict(ckp_data['optimizer']) 220 | scaler.load_state_dict(ckp_data['scaler']) 221 | start_epoch = ckp_data['epoch'] 222 | start_step = ckp_data.get('step', 0) 223 | 224 | # ========== 7. DDP包模型 ========== 225 | if dist.is_initialized(): 226 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} 227 | model = DistributedDataParallel(model, device_ids=[local_rank]) 228 | 229 | # ========== 8. 
开始训练 ========== 230 | for epoch in range(start_epoch, args.epochs): 231 | train_sampler and train_sampler.set_epoch(epoch) 232 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 233 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) 234 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) 235 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') 236 | train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb) 237 | else: # 默认从头开始 238 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) 239 | train_epoch(epoch, loader, len(loader), 0, wandb) 240 | -------------------------------------------------------------------------------- /trainer/train_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | __package__ = "trainer" 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 7 | 8 | import argparse # 命令行参数解析 9 | import time # 时间统计 10 | import warnings # 警告控制 11 | import torch 12 | import torch.distributed as dist # 分布式训练支持 13 | from contextlib import nullcontext # 上下文管理器 14 | from torch import optim, nn # 优化器和神经网络模块 15 | from torch.nn.parallel import DistributedDataParallel # 分布式数据并行 16 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载器 17 | 18 | from model.MokioModel import MokioMindConfig 19 | from dataset.lm_dataset import PretrainDataset 20 | from trainer.trainer_utils import ( # 训练工具函数 21 | get_lr, 22 | Logger, 23 | is_main_process, 24 | lm_checkpoint, 25 | init_distributed_mode, 26 | setup_seed, 27 | init_model, 28 | SkipBatchSampler, 29 | ) 30 | 31 | # 忽略警告信息,保持输出清洁 32 | warnings.filterwarnings("ignore") 33 | 34 | 35 | def train_epoch(epoch, loader, iters, start_step=0, wandb=None): 36 | loss_fct = nn.CrossEntropyLoss(reduction="none") 37 | start_time = time.time() # 记录开始时间 38 | 39 | # 遍历数据批次 40 | for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1): 41 | X = X.to(args.device) 42 | Y = Y.to(args.device) 43 | loss_mask = loss_mask.to(args.device) 44 | 45 | lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) 46 | 47 | for param_group in optimizer.param_groups: 48 | param_group["lr"] = lr 49 | 50 | with autocast_ctx: 51 | # 前向传播 52 | res = model(X) 53 | 54 | loss = loss_fct( 55 | res.logits.view(-1, res.logits.size(-1)), # [batch*seq, vocab_size] 56 | Y.view(-1), # [batch*seq] 57 | ).view(Y.size()) # 恢复为 [batch_size, seq_len] 58 | 59 | loss = (loss * loss_mask).sum() / loss_mask.sum() 60 | 61 | loss+=res.aux_loss 62 | 63 | loss = loss / args.accumulation_steps 64 | 65 | scaler.scale(loss).backward() 66 | 67 | if (step + 1) % args.accumulation_steps == 0: 68 | # scaler.unscale_(): 还原梯度的真实值 69 | scaler.unscale_(optimizer) 70 | 71 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 72 | 73 | # 📚 优化器更新知识点 74 | # scaler.step(): 执行参数更新 75 | # scaler.update(): 更新scaler的缩放因子 76 | scaler.step(optimizer) 77 | scaler.update() 78 | 79 | optimizer.zero_grad(set_to_none=True) 80 | 81 | if step % args.log_interval == 0 or step == iters - 1: 82 | spend_time = time.time() - start_time 83 | current_loss = loss.item() * args.accumulation_steps # 恢复真实损失值 84 | current_lr = optimizer.param_groups[-1]["lr"] # 当前学习率 85 | 86 | eta_min = 
spend_time / (step + 1) * iters // 60 - spend_time // 60 87 | 88 | Logger( 89 | f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:" 90 | ) 91 | 92 | # 记录到实验跟踪系统 93 | if wandb: 94 | wandb.log( 95 | {"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min} 96 | ) 97 | 98 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): 99 | model.eval() # 切换到评估模式 100 | 101 | # 构建保存路径 102 | moe_suffix = ( 103 | "_moe" if hasattr(lm_config, "use_moe") and lm_config.use_moe else "" 104 | ) 105 | ckp = f"{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth" 106 | 107 | # 📚 分布式模型保存知识点 108 | # DDP模型需要通过.module访问真正的模型 109 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 110 | state_dict = model.module.state_dict() 111 | else: 112 | state_dict = model.state_dict() 113 | 114 | # 📚 半精度保存知识点 115 | # 将float32参数转为float16,减少存储空间 116 | state_dict = {k: v.half() for k, v in state_dict.items()} 117 | torch.save(state_dict, ckp) 118 | 119 | # 保存完整训练状态 120 | lm_checkpoint( 121 | lm_config, 122 | weight=args.save_weight, 123 | model=model, 124 | optimizer=optimizer, 125 | scaler=scaler, 126 | epoch=epoch, 127 | step=step, 128 | wandb=wandb, 129 | save_dir="checkpoints", 130 | ) 131 | 132 | model.train() # 恢复训练模式 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description="MiniMind Pretraining") 137 | 138 | # ========== 基础训练参数 ========== 139 | parser.add_argument("--save_dir", type=str, default="out", help="模型保存目录") 140 | parser.add_argument( 141 | "--save_weight", default="pretrain", type=str, help="保存权重的前缀名" 142 | ) 143 | parser.add_argument( 144 | "--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)" 145 | ) 146 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 147 | parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率") 148 | 149 | # ========== 硬件和性能参数 ========== 150 | parser.add_argument( 151 | "--device", 152 | type=str, 153 | default="cuda:0" if torch.cuda.is_available() else "cpu", 154 | help="训练设备", 155 | ) 156 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") 157 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") 158 | 159 | # ========== 训练策略参数 ========== 160 | parser.add_argument( 161 | "--accumulation_steps", type=int, default=8, help="梯度累积步数" 162 | ) 163 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") 164 | parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") 165 | parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔") 166 | 167 | # ========== 模型架构参数 ========== 168 | parser.add_argument("--hidden_size", default=512, type=int, help="隐藏层维度") 169 | parser.add_argument("--num_hidden_layers", default=8, type=int, help="隐藏层数量") 170 | parser.add_argument( 171 | "--max_seq_len", default=512, type=int, help="训练的最大截断长度" 172 | ) 173 | parser.add_argument( 174 | "--use_moe", 175 | default=0, 176 | type=int, 177 | choices=[0, 1], 178 | help="是否使用MoE架构(0=否,1=是)", 179 | ) 180 | 181 | # ========== 数据和恢复参数 ========== 182 | parser.add_argument( 183 | "--data_path", 184 | type=str, 185 | default="dataset/pretrain_hq.jsonl", 186 | help="预训练数据路径", 187 | ) 188 | parser.add_argument( 189 | "--from_weight", 190 | default="none", 191 | type=str, 192 | help="基于哪个权重训练,为none则从头开始", 193 | ) 194 | parser.add_argument( 195 | "--from_resume", 196 | default=0, 197 | type=int, 198 
| choices=[0, 1], 199 | help="是否自动检测&续训(0=否,1=是)", 200 | ) 201 | 202 | # ========== 实验跟踪参数 ========== 203 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") 204 | parser.add_argument( 205 | "--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名" 206 | ) 207 | 208 | # 解析命令行参数 209 | args = parser.parse_args() 210 | 211 | # ========== 1. 初始化环境和随机种子 ========== 212 | """ 213 | 📚 分布式训练初始化知识点: 214 | - local_rank: 当前进程在本机上的GPU编号 215 | - 随机种子: 确保不同进程有不同但可复现的随机序列 216 | - 这样既保证了随机性,又保证了可复现性 217 | """ 218 | local_rank = init_distributed_mode() 219 | if dist.is_initialized(): 220 | args.device = f"cuda:{local_rank}" # 分布式训练时使用对应的GPU 221 | 222 | # 📚 随机种子设置知识点 223 | # 不同进程使用不同的种子,避免数据采样完全相同 224 | # 42是基础种子,每个进程加上自己的rank保证不同 225 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) 226 | 227 | # ========== 2. 配置目录、模型参数、检查点 ========== 228 | """ 229 | 📚 模型配置和检查点管理: 230 | - 创建保存目录 231 | - 构建模型配置对象 232 | - 尝试加载断点续训数据 233 | """ 234 | os.makedirs(args.save_dir, exist_ok=True) # 确保保存目录存在 235 | 236 | # 创建MiniMind模型配置 237 | lm_config = MokioMindConfig( 238 | hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,use_moe=bool(args.use_moe) 239 | ) 240 | 241 | # 📚 断点续训知识点 242 | # 如果开启了断点续训,尝试加载之前的训练状态 243 | ckp_data = ( 244 | lm_checkpoint(lm_config, weight=args.save_weight, save_dir="checkpoints") 245 | if args.from_resume == 1 246 | else None 247 | ) 248 | 249 | # ========== 3. 设置混合精度 ========== 250 | """ 251 | 📚 混合精度训练知识点: 252 | - bfloat16: Google开发,数值范围大,更稳定 253 | - float16: 标准半精度,节省内存但可能溢出 254 | - autocast: 自动选择精度,关键运算用float32 255 | """ 256 | device_type = "cuda" if "cuda" in args.device else "cpu" 257 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 258 | 259 | # 📚 上下文管理器知识点 260 | # CPU不支持autocast,使用nullcontext作为空操作 261 | autocast_ctx = ( 262 | nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) 263 | ) 264 | 265 | # ========== 4. 配置WandB实验跟踪 ========== 266 | """ 267 | 📚 实验跟踪系统知识点: 268 | - WandB: 实验管理平台,记录训练过程 269 | - SwanLab: 国产替代方案 270 | - 支持断点续训时恢复到同一个实验 271 | """ 272 | wandb = None 273 | if args.use_wandb and is_main_process(): 274 | # 使用SwanLab作为WandB的替代 275 | import swanlab as wandb 276 | 277 | # 📚 实验恢复知识点 278 | # 如果有检查点数据,获取之前的wandb_id来恢复实验 279 | wandb_id = ckp_data.get("wandb_id") if ckp_data else None 280 | resume = "must" if wandb_id else None # 必须恢复到指定实验 281 | 282 | # 构建实验名称,包含关键超参数 283 | wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" 284 | wandb.init( 285 | project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume 286 | ) 287 | 288 | # ========== 5. 
定义模型、数据、优化器 ========== 289 | """ 290 | 📚 训练组件初始化: 291 | - 模型: 根据配置创建MiniMind模型 292 | - 数据集: 加载预训练数据 293 | - 采样器: 分布式训练的数据分配 294 | - 优化器: AdamW优化器 295 | - 缩放器: 混合精度训练的梯度缩放 296 | """ 297 | # 初始化模型和分词器 298 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) 299 | 300 | train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len) 301 | 302 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None 303 | 304 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == "float16")) 305 | 306 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) 307 | 308 | start_epoch, start_step = 0, 0 309 | if ckp_data: 310 | # 恢复模型参数 311 | model.load_state_dict(ckp_data["model"]) 312 | # 恢复优化器状态(动量、方差估计等) 313 | optimizer.load_state_dict(ckp_data["optimizer"]) 314 | # 恢复梯度缩放器状态 315 | scaler.load_state_dict(ckp_data["scaler"]) 316 | # 恢复训练进度 317 | start_epoch = ckp_data["epoch"] 318 | start_step = ckp_data.get("step", 0) 319 | 320 | if dist.is_initialized(): 321 | # 📚 RoPE位置编码特殊处理 322 | # freqs_cos, freqs_sin是位置编码缓存,不需要梯度同步 323 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} 324 | model = DistributedDataParallel(model, device_ids=[local_rank]) 325 | 326 | for epoch in range(start_epoch, args.epochs): 327 | # 📚 分布式采样器epoch设置 328 | # 每个epoch设置不同的随机种子,确保数据顺序随机化 329 | if train_sampler: 330 | train_sampler.set_epoch(epoch) 331 | 332 | # 📚 断点续训逻辑 333 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 334 | # 使用跳批采样器,跳过已训练的数据 335 | batch_sampler = SkipBatchSampler( 336 | train_sampler or range(len(train_ds)), args.batch_size, start_step + 1 337 | ) 338 | loader = DataLoader( 339 | train_ds, 340 | batch_sampler=batch_sampler, 341 | num_workers=args.num_workers, 342 | pin_memory=True, 343 | ) 344 | Logger( 345 | f"Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始" 346 | ) 347 | train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb) 348 | else: # 默认从头开始 349 | loader = DataLoader( 350 | train_ds, 351 | batch_size=args.batch_size, 352 | shuffle=(train_sampler is None), 353 | sampler=train_sampler, 354 | num_workers=args.num_workers, 355 | pin_memory=True, 356 | ) 357 | train_epoch(epoch, loader, len(loader), 0, wandb) 358 | -------------------------------------------------------------------------------- /trainer/train_dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # 📚 Python模块系统 5 | __package__ = "trainer" 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | 8 | import argparse # 命令行参数解析 9 | import time # 时间统计 10 | import warnings # 警告控制 11 | import torch # PyTorch深度学习框架 12 | import torch.nn.functional as F # 神经网络函数 13 | import torch.distributed as dist # 分布式训练支持 14 | from contextlib import nullcontext # 上下文管理器 15 | from torch import optim # 优化器 16 | from torch.nn.parallel import DistributedDataParallel # 分布式数据并行 17 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载 18 | 19 | # MiniMind相关组件 20 | from model.MokioModel import MokioMindConfig # 模型配置 21 | from dataset.lm_dataset import DPODataset # DPO数据集 22 | from trainer.trainer_utils import ( # 训练工具函数 23 | get_lr, Logger, is_main_process, lm_checkpoint, 24 | init_distributed_mode, setup_seed, init_model, SkipBatchSampler 25 | ) 26 | 27 | def logits_to_log_probs(logits,labels): 28 | #词表logits转换为log概率 29 | log_probs=F.log_softmax(logits,dim=2) 30 | #从log词表概率里选出label对应的log概率 
31 | #也就是从拿到token在其对应位置的概率 32 | log_probs_per_token=torch.gather(log_probs,dim=2,index=labels.unsqueeze(2)).squeeze(-1) 33 | return log_probs_per_token 34 | 35 | # DPO的loss计算 36 | # 公式:L = -log(σ(β * (π(y_w) - π(y_l) - (π_ref(y_w) - π_ref(y_l))))) 37 | def dpo_loss(ref_log_probs,policy_log_probs,mask,beta): 38 | 39 | seq_lengths=mask.sum(dim=1,keepdim=True) 40 | seq_lengths=seq_lengths.clamp_min(1e-8) 41 | # 计算ref和policy的序列log概率均值 42 | ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze() 43 | policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze() 44 | 45 | # 分别获取chosen和rejected的ref和policy的log概率 46 | batch_size = ref_log_probs.shape[0] 47 | chosen_ref_log_probs = ref_log_probs[:batch_size // 2] 48 | reject_ref_log_probs = ref_log_probs[batch_size // 2:] 49 | chosen_policy_log_probs = policy_log_probs[:batch_size // 2] 50 | reject_policy_log_probs = policy_log_probs[batch_size // 2:] 51 | # 计算策略模型的log概率差异 52 | pi_logratios = chosen_policy_log_probs - reject_policy_log_probs 53 | # 参考模型的log概率差异 54 | ref_logratios = chosen_ref_log_probs - reject_ref_log_probs 55 | # DPO损失计算 56 | logits = pi_logratios - ref_logratios 57 | loss = -F.logsigmoid(beta * logits) 58 | return loss.mean() 59 | 60 | def train_epoch(epoch,loader,iters,ref_model,lm_config,start_step=0,wandb=None,beta=0.1): 61 | start_time = time.time() 62 | for step,batch in enumerate(loader,start=start_step+1): 63 | x_chosen = batch['x_chosen'].to(args.device) 64 | x_rejected = batch['x_rejected'].to(args.device) 65 | y_chosen = batch['y_chosen'].to(args.device) 66 | y_rejected = batch['y_rejected'].to(args.device) 67 | mask_chosen = batch['mask_chosen'].to(args.device) 68 | mask_rejected = batch['mask_rejected'].to(args.device) 69 | 70 | x = torch.cat([x_chosen, x_rejected], dim=0) 71 | y = torch.cat([y_chosen, y_rejected], dim=0) 72 | mask = torch.cat([mask_chosen, mask_rejected], dim=0) 73 | 74 | # 📚 学习率调度 75 | lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) 76 | for param_group in optimizer.param_groups: 77 | param_group['lr'] = lr 78 | 79 | with autocast_ctx: 80 | # 📚 参考模型前向传播 81 | # 参考模型冻结,只用于计算baseline概率 82 | with torch.no_grad(): 83 | ref_outputs = ref_model(x) 84 | ref_logits = ref_outputs.logits 85 | ref_log_probs = logits_to_log_probs(ref_logits, y) 86 | 87 | # 📚 策略模型前向传播 88 | # 策略模型是需要优化的主要模型 89 | outputs = model(x) 90 | logits = outputs.logits 91 | policy_log_probs = logits_to_log_probs(logits, y) 92 | 93 | # 📚 DPO损失计算 94 | loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta) 95 | loss = loss / args.accumulation_steps 96 | 97 | # 📚 反向传播 98 | scaler.scale(loss).backward() 99 | 100 | if (step + 1) % args.accumulation_steps == 0: 101 | scaler.unscale_(optimizer) 102 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 103 | scaler.step(optimizer) 104 | scaler.update() 105 | optimizer.zero_grad(set_to_none=True) 106 | 107 | # 📚 训练日志 108 | if step % args.log_interval == 0 or step == iters - 1: 109 | spend_time = time.time() - start_time 110 | current_loss = loss.item() * args.accumulation_steps 111 | current_lr = optimizer.param_groups[-1]['lr'] 112 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 113 | 114 | Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:') 115 | 116 | if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min}) 117 | 118 | # 📚 模型保存 119 | if (step % args.save_interval == 0 or step == iters - 1) and 
is_main_process(): 120 | model.eval() 121 | moe_suffix = '_moe' if lm_config.use_moe else '' 122 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' 123 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 124 | state_dict = model.module.state_dict() 125 | else: 126 | state_dict = model.state_dict() 127 | state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存 128 | torch.save(state_dict, ckp) 129 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') 130 | model.train() 131 | 132 | if __name__ == "__main__": 133 | """ 134 | DPO主函数:直接偏好优化脚本的入口点 135 | 136 | 📚 DPO训练流程: 137 | 1. 准备策略模型和参考模型 138 | 2. 加载偏好数据(chosen vs rejected) 139 | 3. 同时前向传播计算两种模型的概率 140 | 4. 计算DPO损失并优化策略模型 141 | 5. 迭代直到收敛 142 | """ 143 | 144 | # 📚 命令行参数解析 145 | parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)") 146 | 147 | # ========== 基础训练参数 ========== 148 | parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") 149 | parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名") 150 | parser.add_argument("--epochs", type=int, default=1, help="训练轮数(DPO通常1-2轮)") 151 | parser.add_argument("--batch_size", type=int, default=4, help="batch size(DPO batch较小)") 152 | 153 | # 📚 DPO学习率知识点 154 | # DPO学习率通常很小,避免过度优化导致遗忘 155 | # 建议不超过5e-8 156 | parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)") 157 | 158 | # ========== 硬件配置 ========== 159 | parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") 160 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") 161 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") 162 | 163 | # ========== 训练策略 ========== 164 | parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") 165 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") 166 | parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") 167 | parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔") 168 | 169 | # ========== 模型架构参数 ========== 170 | parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") 171 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") 172 | parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度") 173 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") 174 | 175 | # ========== DPO数据和模型参数 ========== 176 | # 📚 DPO数据格式知识点 177 | # 数据包含chosen(偏好)和rejected(不偏好)回答配对 178 | parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径") 179 | 180 | # 📚 DPO权重继承知识点 181 | # DPO通常基于SFT模型进行对齐优化 182 | parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练(通常是SFT模型)") 183 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)") 184 | 185 | # 📚 DPO beta参数知识点 186 | # beta控制优化强度,0.1-0.5是常见范围 187 | parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数(控制优化强度)") 188 | 189 | # ========== 实验跟踪 ========== 190 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") 191 | parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名") 192 | 193 | args = parser.parse_args() 194 | 195 | # ========== 
1. 初始化环境和随机种子 ========== 196 | local_rank = init_distributed_mode() 197 | if dist.is_initialized(): args.device = f"cuda:{local_rank}" 198 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) 199 | 200 | # ========== 2. 配置目录、模型参数、检查ckp ========== 201 | os.makedirs(args.save_dir, exist_ok=True) 202 | lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) 203 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None 204 | 205 | # ========== 3. 设置混合精度 ========== 206 | device_type = "cuda" if "cuda" in args.device else "cpu" 207 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 208 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) 209 | 210 | # ========== 4. 配置wandb ========== 211 | wandb = None 212 | if args.use_wandb and is_main_process(): 213 | import swanlab as wandb 214 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None 215 | resume = 'must' if wandb_id else None 216 | wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" 217 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) 218 | 219 | # ========== 5. 定义模型和参考模型 ========== 220 | # 📚 DPO双模型架构 221 | # 策略模型:需要优化的模型 222 | # 参考模型:冻结的baseline模型 223 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) 224 | Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M') 225 | 226 | # 📚 参考模型初始化 227 | # 参考模型与策略模型初始权重相同,但完全冻结 228 | ref_model, _ = init_model(lm_config, args.from_weight, device=args.device) 229 | ref_model.eval() # 设为评估模式 230 | ref_model.requires_grad_(False) # 冻结所有参数 231 | Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M') 232 | 233 | # 📚 DPO数据集 234 | train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len) 235 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None 236 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) 237 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) 238 | 239 | # ========== 6. 从ckp恢复状态 ========== 240 | start_epoch, start_step = 0, 0 241 | if ckp_data: 242 | model.load_state_dict(ckp_data['model']) 243 | optimizer.load_state_dict(ckp_data['optimizer']) 244 | scaler.load_state_dict(ckp_data['scaler']) 245 | start_epoch = ckp_data['epoch'] 246 | start_step = ckp_data.get('step', 0) 247 | 248 | # ========== 7. DDP包装模型 ========== 249 | if dist.is_initialized(): 250 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} 251 | model = DistributedDataParallel(model, device_ids=[local_rank]) 252 | 253 | # ========== 8. 
开始训练 ========== 254 | for epoch in range(start_epoch, args.epochs): 255 | train_sampler and train_sampler.set_epoch(epoch) 256 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 257 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) 258 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) 259 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') 260 | train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta) 261 | else: # 默认从头开始 262 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) 263 | train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta) 264 | -------------------------------------------------------------------------------- /trainer/train_ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # 📚 Python模块系统 5 | __package__ = "trainer" 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | 8 | import argparse # 命令行参数解析 9 | import re # 正则表达式,用于奖励计算 10 | import warnings # 警告控制 11 | import torch # PyTorch深度学习框架 12 | import torch.distributed as dist # 分布式训练支持 13 | import torch.nn.functional as F # 神经网络函数 14 | from transformers import AutoTokenizer # HuggingFace分词器 15 | from contextlib import nullcontext # 上下文管理器 16 | from torch import optim, nn # 优化器和神经网络 17 | from torch.nn.parallel import DistributedDataParallel # 分布式并行 18 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载 19 | from torch.nn.utils import clip_grad_norm_ # 梯度裁剪 20 | from torch.optim.lr_scheduler import CosineAnnealingLR # 余弦退火学习率调度 21 | from transformers import AutoModel # HuggingFace模型加载 22 | from model.MokioModel import MokioMindConfig, MokioMindForCausalLM # MiniMind模型 23 | from dataset.lm_dataset import RLAIFDataset # RL数据集 24 | from trainer.trainer_utils import ( # 训练工具函数 25 | Logger, is_main_process, lm_checkpoint, init_distributed_mode, 26 | setup_seed, SkipBatchSampler, init_model 27 | ) 28 | 29 | warnings.filterwarnings('ignore') 30 | #==========Critic Model部分========== 31 | 32 | class CriticModel(MokioMindForCausalLM): 33 | def __init__(self,params): 34 | super().__init__(params) 35 | # 价值头,用于输出每个token位置的状态价值 36 | self.value_head=nn.Linear(params.hidden_size,1) 37 | 38 | def forward(self,input_ids=None,attention_mask=None,**kwargs): 39 | outputs=self.model(input_ids=input_ids,attention_mask=attention_mask,**kwargs) 40 | hidden_states=self.model.norm(outputs[0]) 41 | 42 | values=self.value_head(hidden_states).squeeze(-1) 43 | return values 44 | 45 | #==========奖励计算部分========== 46 | def calculate_rewards(prompts,responses,reward_model,reward_tokenizer): 47 | def reasoning_model_reward(rewards): 48 | # 使用正则表达式匹配思考-回答格式 49 | pattern = r"^\n.*?\n\n\n.*?\n$" 50 | # 多了一个\n,考虑到think和answer之间有空行的情况 51 | pattern2 = r"^\n.*?\n\n\n\n.*?\n$" 52 | # 通过正则表达式计算奖励,如果回答符合格式则奖励0.5,否则0.0 53 | matches_pattern = [re.match(pattern, response, re.S) for response in responses] 54 | matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses] 55 | 56 | format_rewards = [] 57 | for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2): 58 | if match_pattern: 59 | format_rewards.append(0.5) 60 | elif match_pattern2: 61 | 
format_rewards.append(0.5) 62 | else: 63 | format_rewards.append(0.0) 64 | rewards += torch.tensor(format_rewards, device=args.device) 65 | 66 | def mark_num(text): 67 | reward=0 68 | if text.count("")==1: 69 | reward+=0.25 70 | if text.count("")==1: 71 | reward+=0.25 72 | if text.count("")==1: 73 | reward+=0.25 74 | if text.count("")==1: 75 | reward+=0.25 76 | return reward 77 | 78 | mark_rewards=[mark_num(response) for response in responses] 79 | rewards+=torch.tensor(mark_rewards,device=args.device) 80 | return rewards 81 | rewards=torch.zeros(len(responses),device=args.device) 82 | 83 | if args.reasoning==1: 84 | rewards=reasoning_model_reward(rewards) 85 | #==========Reward模型评分部分========== 86 | with torch.no_grad(): 87 | reward_model_scores = [] 88 | for prompt,response in zip(prompts,responses): 89 | 90 | pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>" 91 | matches = re.findall(pattern, prompt, re.DOTALL) 92 | messages = [{"role": role, "content": content.strip()} for role, content in matches] 93 | 94 | tmp_chat=messages+[{"role":"assistant","content":response}] 95 | score=reward_model.get_reward(tmp_chat,reward_tokenizer) 96 | 97 | scale=3.0 98 | score=max(min(score,scale),-scale) 99 | 100 | if args.reasoning==1: 101 | answer_match = re.search(r'(.*?)', response, re.DOTALL) 102 | if answer_match: 103 | answer_content = answer_match.group(1).strip() 104 | # 对answer内容单独计算reward 105 | tmp_chat = messages + [{"role": "assistant", "content": answer_content}] 106 | answer_score = reward_model.get_score(reward_tokenizer, tmp_chat) 107 | answer_score = max(min(answer_score, scale), -scale) 108 | # 📚 加权组合 109 | score = score * 0.4 + answer_score * 0.6 110 | reward_model_scores.append(score) 111 | 112 | reward_model_scores=torch.tensor(reward_model_scores,device=args.device) 113 | rewards+=reward_model_scores 114 | 115 | return rewards 116 | 117 | #==========PPO训练一个Epoch部分========== 118 | def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None): 119 | # 切换actor和critic模型到训练模式 120 | actor_model.train() 121 | critic_model.train() 122 | 123 | for step,batch in enumerate(loader,start=start_step+1): 124 | prompts=batch['prompt'] 125 | # 编码输入 126 | enc=tokenizer(prompts,return_tensors='pt',padding=True,truncation=True,max_length=args.max_seq_len).to(args.device) 127 | # 计算每个prompt的长度(用于后续处理) 128 | prompt_lengths=enc.attention_mask.sum(dim=1) 129 | 130 | with torch.no_grad(): 131 | model_for_gen=actor_model.module if isinstance(actor_model,DistributedDataParallel) else actor_model 132 | 133 | gen_out=model_for_gen.generate( 134 | input_ids=enc.input_ids, 135 | attention_mask=enc.attention_mask, 136 | max_new_tokens=args.max_gen_len, 137 | do_sample=True, 138 | temperature=0.8, 139 | pad_token_id=tokenizer.eos_token_id, 140 | eos_token_id=tokenizer.eos_token_id 141 | ) 142 | 143 | # 解码生成的响应 144 | responses_text=[tokenizer.decode(gen_out[i,prompt_lengths[i]:],skip_special_tokens=True) for i in range(len(prompts))] 145 | 146 | # 计算奖励 147 | rewards=calculate_rewards(prompts,responses_text,reward_model,reward_tokenizer) 148 | 149 | # 创建一个mask,用于标记哪些位置上是有效token 150 | full_mask=(gen_out!=tokenizer.pad_token_id).long() 151 | # critic模型进行价值估计 152 | value_seq=critic_model(input_ids=gen_out,attention_mask=full_mask) 153 | # 拿到最后一个非pad位置的索引 154 | last_indices=full_mask.sum(dim=1)-1 155 | # 获取每条序列最后token的value 156 | values=value_seq[torch.arange(len(last_indices)),last_indices] 157 | # 
advantage=reward-估计的value 158 | advantages = rewards - values.detach() # [B] 159 | 160 | # 计算actor log,表示actor对这个答案的“信心” 161 | # 先生成logits 162 | logits=actor_model(input_ids=gen_out,attention_mask=full_mask).logits # [B, L, V] 163 | # label是生成的token序列,去掉第一个token(因为logits是预测下一个token的概率) 164 | labels=gen_out[:,1:].clone() 165 | # 使用log_softmax计算log概率 166 | logp_tokens=F.log_softmax(logits[:,:-1,:],dim=-1).gather(2,labels.unsqueeze(-1)).squeeze(-1) # [B, L-1] 167 | seq_len=gen_out.size(1)-1 168 | # 只关心response部分的概率,所以要把prompts部分的mask掉 169 | resp_mask=torch.arange(seq_len,device=gen_out.device).unsqueeze(0)>=prompt_lengths.unsqueeze(1) 170 | 171 | final_mask=resp_mask&(~labels.eq(tokenizer.pad_token_id)) 172 | # 把所有response部分的log概率加起来,得到每条序列的总log概率 173 | actor_logp=(logp_tokens*final_mask).sum(dim=1) 174 | 175 | # 计算old和ref log的概率 176 | # old用于防止策略更新过大,ref用于计算KL惩罚,防止模型忘本 177 | with torch.no_grad(): 178 | old_logits = old_actor_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V] 179 | old_logp_tokens = F.log_softmax(old_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1] 180 | old_logp = (old_logp_tokens * final_mask).sum(dim=1) # [B] 181 | 182 | ref_logits = ref_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V] 183 | ref_logp_tokens = F.log_softmax(ref_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1] 184 | ref_logp = (ref_logp_tokens * final_mask).sum(dim=1) # [B] 185 | 186 | # 计算KL散度和ratio 187 | kl=(actor_logp - old_logp).mean() 188 | kl_ref=(actor_logp - ref_logp).mean() 189 | ratio=torch.exp(actor_logp - old_logp) # [B] 190 | 191 | # PPO裁剪损失 192 | surr1=ratio*advantages # [B] 193 | surr2=torch.clamp(ratio,1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon)*advantages # [B] 194 | policy_loss=-torch.min(surr1,surr2).mean() 195 | 196 | # 价值函数损失 197 | value_loss=F.mse_loss(values,rewards) 198 | # 总损失 199 | loss = policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref # scalar 200 | loss.backward() 201 | 202 | # 更新参数 203 | if (step + 1) % args.accumulation_steps == 0: 204 | clip_grad_norm_(actor_model.parameters(), args.grad_clip) 205 | clip_grad_norm_(critic_model.parameters(), args.grad_clip) 206 | actor_optimizer.step() 207 | critic_optimizer.step() 208 | actor_scheduler.step() 209 | critic_scheduler.step() 210 | actor_optimizer.zero_grad() 211 | critic_optimizer.zero_grad() 212 | 213 | # 📚 日志记录 214 | if is_main_process(): 215 | response_ids = gen_out[:, enc.input_ids.shape[1]:] 216 | is_eos = (response_ids == tokenizer.eos_token_id) 217 | eos_indices = torch.argmax(is_eos.int(), dim=1) 218 | has_eos = is_eos.any(dim=1) 219 | lengths = torch.where(has_eos, eos_indices + 1, torch.tensor(response_ids.shape[1], device=is_eos.device)) 220 | avg_len = lengths.float().mean() 221 | 222 | actor_loss_val = policy_loss.item() 223 | critic_loss_val = value_loss.item() 224 | reward_val = rewards.mean().item() 225 | kl_val = kl.item() 226 | kl_ref_val = kl_ref.item() 227 | avg_len_val = avg_len.item() 228 | actor_lr = actor_optimizer.param_groups[0]['lr'] 229 | critic_lr = critic_optimizer.param_groups[0]['lr'] 230 | 231 | if wandb is not None: 232 | wandb.log({ 233 | "actor_loss": actor_loss_val, 234 | "critic_loss": critic_loss_val, 235 | "reward": reward_val, 236 | "kl": kl_val, 237 | "kl_ref": kl_ref_val, 238 | "avg_response_len": avg_len_val, 239 | "actor_lr": actor_lr, 240 | }) 241 | 242 | Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, " 243 | f"Actor Loss: {actor_loss_val:.6f}, Critic 
Loss: {critic_loss_val:.6f}, " 244 | f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, " 245 | f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}") 246 | 247 | # 📚 更新old actor 248 | if (step + 1) % args.update_old_actor_freq == 0: 249 | state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict() 250 | old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()}) 251 | old_actor_model.to(args.device) 252 | 253 | # 📚 模型保存 254 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): 255 | actor_model.eval() 256 | moe_suffix = '_moe' if lm_config.use_moe else '' 257 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' 258 | actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict() 259 | torch.save({k: v.half() for k, v in actor_state.items()}, ckp) 260 | 261 | # 使用 lm_checkpoint 保存完整状态(包括 critic) 262 | lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer, 263 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', 264 | scheduler=actor_scheduler, critic_model=critic_model, 265 | critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler) 266 | actor_model.train() 267 | 268 | 269 | 270 | if __name__ == "__main__": 271 | """ 272 | PPO主函数:近端策略优化脚本的入口点 273 | 274 | 📚 PPO训练架构: 275 | 1. Actor模型:生成策略,输出动作概率 276 | 2. Critic模型:价值函数,估计状态价值 277 | 3. Reward模型:奖励函数,评估生成质量 278 | 4. Old Actor:用于重要性采样的旧策略 279 | 5. Reference:用于KL惩罚的参考策略 280 | """ 281 | 282 | # 📚 命令行参数解析 283 | parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)") 284 | 285 | # ========== 基础训练参数 ========== 286 | parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") 287 | parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名") 288 | parser.add_argument("--epochs", type=int, default=1, help="训练轮数") 289 | parser.add_argument("--batch_size", type=int, default=2, help="batch size(PPO batch较小)") 290 | 291 | # 📚 PPO学习率设置 292 | # PPO学习率通常很小,避免策略剧烈变化 293 | parser.add_argument("--learning_rate", type=float, default=8e-8, help="Actor学习率") 294 | parser.add_argument("--critic_learning_rate", type=float, default=8e-8, help="Critic学习率") 295 | 296 | # ========== 硬件配置 ========== 297 | parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") 298 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") 299 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数") 300 | 301 | # ========== 训练策略 ========== 302 | parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") 303 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") 304 | parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔") 305 | parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔") 306 | 307 | # ========== 模型架构参数 ========== 308 | parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") 309 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") 310 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") 311 | 312 | # ========== PPO生成参数 ========== 313 | parser.add_argument('--max_seq_len', default=66, type=int, 
help="Prompt最大长度") 314 | parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度") 315 | 316 | # ========== 数据和模型参数 ========== 317 | parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径") 318 | 319 | # 📚 PPO超参数 320 | parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数(控制策略更新幅度)") 321 | parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数") 322 | parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数") 323 | 324 | # 📚 推理模型配置 325 | parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型(0=普通模型,1=推理模型)') 326 | parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率") 327 | 328 | # 📚 Reward模型路径 329 | parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径") 330 | 331 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)") 332 | 333 | # ========== 实验跟踪 ========== 334 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") 335 | parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名") 336 | 337 | args = parser.parse_args() 338 | # ========== 1. 初始化环境和随机种子 ========== 339 | local_rank = init_distributed_mode() 340 | if dist.is_initialized(): args.device = f"cuda:{local_rank}" 341 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) 342 | 343 | # ========== 2. 配置目录、模型参数、检查ckp ========== 344 | os.makedirs(args.save_dir, exist_ok=True) 345 | lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) 346 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None 347 | 348 | # ========== 3. 设置混合精度 ========== 349 | device_type = "cuda" if "cuda" in args.device else "cpu" 350 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 351 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) 352 | 353 | # ========== 4. 配置wandb ========== 354 | wandb = None 355 | if args.use_wandb and is_main_process(): 356 | import swanlab as wandb 357 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None 358 | resume = 'must' if wandb_id else None 359 | wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}" 360 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) 361 | # ========== 5. 
初始化模型和数据 ========== 362 | # 📚 PPO模型架构 363 | base_weight = "reason" if args.reasoning == 1 else "full_sft" 364 | 365 | # 📚 Actor模型(策略模型) 366 | actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device) 367 | tokenizer.padding_side = 'left' # PPO需要左侧padding 368 | 369 | # 📚 Old Actor模型(用于重要性采样) 370 | old_actor_model, _ = init_model(lm_config, base_weight, device=args.device) 371 | old_actor_model = old_actor_model.eval().requires_grad_(False) 372 | 373 | # 📚 Reference模型(用于KL惩罚) 374 | ref_model, _ = init_model(lm_config, base_weight, device=args.device) 375 | ref_model = ref_model.eval().requires_grad_(False) 376 | 377 | # 📚 Critic模型(价值函数) 378 | moe_suffix = '_moe' if lm_config.use_moe else '' 379 | ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth' 380 | state_dict = torch.load(ckp, map_location=args.device) 381 | critic_model = CriticModel(lm_config) 382 | critic_model.load_state_dict(state_dict, strict=False) 383 | critic_model = critic_model.to(args.device) 384 | 385 | # 📚 Reward模型(奖励函数) 386 | reward_model = AutoModel.from_pretrained( 387 | args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True 388 | ) 389 | reward_model = reward_model.to(args.device).eval().requires_grad_(False) 390 | reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True) 391 | 392 | # 📚 数据和优化器 393 | train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len)) 394 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None 395 | actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate) 396 | critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate) 397 | loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler) 398 | iters = len(loader_for_count) 399 | total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs 400 | actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) 401 | critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10) 402 | 403 | # ========== 6. 从ckp恢复状态 ========== 404 | start_epoch, start_step = 0, 0 405 | if ckp_data: 406 | actor_model.load_state_dict(ckp_data['model']) 407 | critic_model.load_state_dict(ckp_data['critic_model']) 408 | actor_optimizer.load_state_dict(ckp_data['optimizer']) 409 | critic_optimizer.load_state_dict(ckp_data['critic_optimizer']) 410 | actor_scheduler.load_state_dict(ckp_data['scheduler']) 411 | critic_scheduler.load_state_dict(ckp_data['critic_scheduler']) 412 | start_epoch = ckp_data['epoch'] 413 | start_step = ckp_data.get('step', 0) 414 | 415 | # ========== 7. DDP包装模型 ========== 416 | if dist.is_initialized(): 417 | actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} 418 | critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} 419 | actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank]) 420 | critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank]) 421 | old_actor_model.to(args.device) 422 | 423 | # ========== 8. 
开始训练 ========== 424 | for epoch in range(start_epoch, args.epochs): 425 | train_sampler and train_sampler.set_epoch(epoch) 426 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点 427 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1) 428 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) 429 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') 430 | ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model, 431 | actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb) 432 | else: # 默认从头开始 433 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), 434 | sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) 435 | ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model, 436 | actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb) -------------------------------------------------------------------------------- /model/MokioModel.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class MokioMindConfig(PretrainedConfig): 5 | model_type = "mokiomind" 6 | 7 | def __init__( 8 | self, 9 | dropout: float = 0.0, 10 | bos_token_id: int = 1, 11 | eos_token_id: int = 2, 12 | hidden_act: str = "silu", 13 | hidden_size: int = 512, 14 | intermediate_size: int = None, 15 | max_position_embeddings: int = 32768, 16 | num_attention_heads: int = 8, 17 | num_hidden_layers: int = 8, 18 | num_key_value_heads: int = 2, 19 | vocab_size: int = 6400, 20 | rms_norm_eps: float = 1e-05, 21 | rope_theta: int = 1000000, 22 | inference_rope_scaling: bool = False, 23 | flash_attention: bool = True, 24 | 25 | ############ MoE ############ 26 | use_moe:bool=False, 27 | num_experts_per_tok:int=2, 28 | n_routed_experts:int=4, 29 | n_shared_experts:int=1, 30 | scoring_func:str='softmax', 31 | aux_loss_alpha:float=0.1, 32 | seq_aux:bool=True, 33 | norm_topk_prob:bool=True, 34 | **kwargs, 35 | ): 36 | super().__init__(**kwargs) 37 | 38 | self.dropout = dropout 39 | self.bos_token_id = bos_token_id 40 | self.eos_token_id = eos_token_id 41 | self.hidden_act = hidden_act 42 | self.hidden_size = hidden_size 43 | self.intermediate_size = intermediate_size 44 | self.max_position_embeddings = max_position_embeddings 45 | self.num_attention_heads = num_attention_heads 46 | self.num_hidden_layers = num_hidden_layers 47 | self.num_key_value_heads = num_key_value_heads 48 | self.vocab_size = vocab_size 49 | self.rms_norm_eps = rms_norm_eps 50 | self.rope_theta = rope_theta 51 | self.inference_rope_scaling = inference_rope_scaling 52 | self.flash_attention = flash_attention 53 | self.use_moe=use_moe 54 | self.num_experts_per_tok=num_experts_per_tok 55 | self.n_routed_experts=n_routed_experts 56 | self.n_shared_experts=n_shared_experts 57 | self.seq_aux=seq_aux 58 | self.norm_topk_prob=norm_topk_prob 59 | self.aux_loss_alpha=aux_loss_alpha 60 | self.scoring_func=scoring_func 61 | 62 | self.rope_scaling = ( 63 | { 64 | "beta_fast": 4, 65 | "beta_slow": 1, 66 | "factor": 4, 67 | "original_max_position_embeddings": 2048, 68 | "type": "yarn", 69 | } 70 | if self.inference_rope_scaling 71 | else None 72 | ) 73 | 74 | 75 | import torch 76 | import math 77 | import torch.nn as nn 78 | from typing import Optional, Tuple,List,Union 79 | import 
torch.nn.functional as F 80 | from transformers.activations import ACT2FN 81 | from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig 82 | from transformers.modeling_outputs import CausalLMOutputWithPast 83 | 84 | 85 | class RMSNorm(nn.Module): 86 | def __init__(self, dim: int, eps: float = 1e-5): 87 | super().__init__() 88 | self.dim = dim 89 | self.eps = eps 90 | self.weight = nn.Parameter(torch.ones(dim)) 91 | 92 | def _norm(self, x): 93 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 94 | 95 | def forward(self, x): 96 | return self.weight * self._norm(x.float()).typed_as(x) 97 | 98 | 99 | def precompute_freqs( 100 | dim: int, 101 | end: int = int(32 * 1024), 102 | rope_base: float = 1e6, 103 | rope_scaling: Optional[dict] = None, 104 | ): 105 | freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 106 | 107 | if rope_scaling is not None: 108 | original_max, factor, beta_fast, beta_slow = ( 109 | rope_scaling.get("original_max_position_embeddings", 2048), 110 | rope_scaling.get("factor", 4), 111 | rope_scaling.get("beta_fast", 4.0), 112 | rope_scaling.get("beta_slow", 1.0), 113 | ) 114 | 115 | if end / original_max > 1.0: 116 | corr_dim = next( 117 | (i for i in range(dim // 2) if 2 * math.pi / freqs[i] > original_max), 118 | dim // 2, 119 | ) 120 | 121 | power = torch.arange(0, dim // 2, device=freqs.device).float() / max( 122 | dim // 2 - 1, 1 123 | ) 124 | 125 | beta = beta_slow + (beta_fast - beta_slow) * power 126 | 127 | scale = torch.where( 128 | torch.arange(dim // 2, device=freqs.device) < corr_dim, 129 | (beta * factor - beta + 1) / (beta * factor), 130 | 1.0 / factor, 131 | ) 132 | 133 | freqs = freqs * scale 134 | 135 | t = torch.arange(end, device=freqs.device) 136 | freqs = torch.outer(t, freqs).float() 137 | 138 | freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) 139 | freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) 140 | 141 | return freqs_cos, freqs_sin 142 | 143 | 144 | def apply_rotary_pos_emb( 145 | q, k, cos, sin, position_ids=None, unsqueeze_dim=1 146 | ) -> Tuple[torch.Tensor, torch.Tensor]: 147 | def rotate_half(x): 148 | return torch.cat( 149 | (-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1 150 | ) 151 | 152 | q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + ( 153 | rotate_half(q) * sin.unsqueeze(unsqueeze_dim) 154 | ) 155 | k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + ( 156 | rotate_half(k) * sin.unsqueeze(unsqueeze_dim) 157 | ) 158 | return q_embed, k_embed 159 | 160 | 161 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 162 | bs, slen, num_key_value_heads, head_dim = x.shape 163 | if n_rep == 1: 164 | return x 165 | 166 | return ( 167 | x[:, :, :, None, :] 168 | .expand(bs, slen, num_key_value_heads, n_rep, head_dim) 169 | .reshape(bs, slen, num_key_value_heads * n_rep, head_dim) 170 | ) 171 | 172 | 173 | class Attention(nn.Module): 174 | def __init__(self, args: MokioMindConfig): 175 | super().__init__() 176 | 177 | self.num_key_value_heads = ( 178 | args.num_attention_heads 179 | if args.num_key_value_heads is None 180 | else args.num_key_value_heads 181 | ) 182 | 183 | assert args.num_attention_heads % self.num_key_value_heads == 0 184 | 185 | self.n_local_heads = args.num_attention_heads 186 | self.n_local_kv_heads = self.num_key_value_heads 187 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 188 | self.head_dim = args.hidden_size // args.num_attention_heads 189 | 190 | self.q_proj = nn.Linear( 191 | 
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False 192 | ) 193 | self.k_proj = nn.Linear( 194 | args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False 195 | ) 196 | self.v_proj = nn.Linear( 197 | args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False 198 | ) 199 | self.o_proj = nn.Linear( 200 | args.num_attention_heads * self.head_dim, args.hidden_size, bias=False 201 | ) 202 | 203 | self.attn_dropout = nn.Dropout(args.dropout) 204 | self.resid_dropout = nn.Dropout(args.dropout) 205 | self.dropout = args.dropout 206 | self.flash = ( 207 | hasattr(torch.nn.functional, "scaled_dot_product_attention") 208 | and args.flash_attention 209 | ) 210 | 211 | def forward( 212 | self, 213 | x: torch.Tensor, 214 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 215 | past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 216 | use_cache=False, 217 | attention_mask: Optional[torch.Tensor] = None, 218 | ): 219 | bsz, seq_len, _ = x.shape 220 | 221 | xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) 222 | 223 | xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) 224 | xk = xk.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) 225 | xv = xv.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) 226 | 227 | cos, sin = position_embeddings 228 | xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len]) 229 | 230 | if past_key_value is not None: 231 | xk = torch.cat([past_key_value[0], xk], dim=1) 232 | xv = torch.cat([past_key_value[1], xv], dim=1) 233 | 234 | past_kv = (xk, xv) if use_cache else None 235 | 236 | xq = xq.transpose(1, 2) 237 | xk = repeat_kv(xk, self.n_rep).transpose(1, 2) 238 | xv = repeat_kv(xv, self.n_rep).transpose(1, 2) 239 | 240 | if ( 241 | self.flash 242 | and seq_len > 1 243 | and (attention_mask is None or torch.all(attention_mask == 1)) 244 | ): 245 | attn_mask = ( 246 | None 247 | if attention_mask is None 248 | else attention_mask.view(bsz, 1, 1, -1) 249 | .expand(bsz, self.n_local_heads, seq_len, -1) 250 | .bool() 251 | ) 252 | output = F.scaled_dot_product_attention( 253 | xq, 254 | xk, 255 | xv, 256 | attn_mask=attn_mask, 257 | dropout_p=self.dropout if self.training else 0.0, 258 | is_causal=True, # 自回归(因果)注意力 259 | ) 260 | else: 261 | scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) 262 | 263 | causal_mask = torch.triu( 264 | torch.full((seq_len, seq_len), float("-inf"), device=scores.device), 265 | diagonal=-1, 266 | ) 267 | 268 | scores=scores+causal_mask.unsqueeze(0).unsqueeze(0) 269 | 270 | if attention_mask is not None: 271 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 272 | extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 273 | scores = scores + extended_attention_mask 274 | 275 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 276 | scores=self.attn_dropout(scores) 277 | output=scores@xv 278 | 279 | output=output.transpose(1,2).reshape( 280 | bsz,seq_len,-1 281 | ) 282 | output=self.resid_dropout(self.o_proj(output)) 283 | return output,past_kv 284 | 285 | class FeedForward(nn.Module): 286 | def __init__(self, config: MokioMindConfig): 287 | super().__init__() 288 | if config.intermediate_size is None: 289 | intermediate_size = int(config.hidden_size * 8 / 3) 290 | config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64) 291 | 292 | self.gate_proj = nn.Linear( 293 | config.hidden_size, config.intermediate_size, bias=False 294 | ) 295 | self.down_proj = nn.Linear( 296 | 
config.intermediate_size, config.hidden_size, bias=False 297 | ) 298 | self.up_proj = nn.Linear( 299 | config.hidden_size, config.intermediate_size, bias=False 300 | ) 301 | self.dropout = nn.Dropout(config.dropout) 302 | self.act_fn = ACT2FN[config.hidden_act] 303 | 304 | def forward(self, x): 305 | gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x) 306 | return self.dropout(self.down_proj(gated)) 307 | 308 | class MoEGate(nn.Module): 309 | def __init__(self,config:MokioMindConfig): 310 | super().__init__() 311 | self.config=config 312 | self.top_k=config.num_experts_per_tok 313 | 314 | self.scoring_func=config.scoring_func 315 | self.alpha=config.aux_loss_alpha 316 | self.seq_aux=config.seq_aux 317 | 318 | self.norm_topk_prob=config.norm_topk_prob 319 | self.gating_dim=config.hidden_size 320 | self.weight=nn.Parameter(torch.empty((self.n_routed_experts,self.gating_dim))) 321 | self.reset_parameters() 322 | 323 | def reset_parameters(self)->None: 324 | init.kaiming_uniform_(self.weight,a=math.sqrt(5)) 325 | 326 | def forward(self,hidden_states): 327 | bsz,seq_len,h=hidden_states.shape 328 | hidden_states=hidden_states.view(-1,h) 329 | logits=F.linear(hidden_states,self.weight,None) 330 | 331 | if self.scoring_func == 'softmax': 332 | scores = logits.softmax(dim=-1) 333 | else: 334 | raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') 335 | 336 | topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) 337 | 338 | if self.top_k > 1 and self.norm_topk_prob: 339 | denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 340 | topk_weight = topk_weight / denominator 341 | 342 | if self.training and self.alpha > 0.0: 343 | scores_for_aux = scores 344 | aux_topk = self.top_k 345 | topk_idx_for_aux_loss = topk_idx.view(bsz, -1) 346 | if self.seq_aux: 347 | scores_for_seq_aux=scores_for_aux.view(bsz,seq_len,-1) 348 | ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) 349 | ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)) 350 | ce = ce.div(seq_len * aux_topk / self.n_routed_experts) 351 | aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha 352 | else: 353 | mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) 354 | ce = mask_ce.float().mean(0) 355 | Pi = scores_for_aux.mean(0) 356 | fi = ce * self.n_routed_experts 357 | aux_loss = (Pi * fi).sum() * self.alpha 358 | else: 359 | aux_loss = 0 360 | return topk_weight, topk_idx, aux_loss 361 | 362 | class MoEFeedForaward(nn.Module): 363 | def __init__(self,config:MokioMindConfig): 364 | super().__init__() 365 | self.config=config 366 | # 专家层 367 | self.experts=nn.ModuleList( 368 | [FeedForward(config) 369 | for _ in range(config.n_routed_experts)] 370 | ) 371 | # 门控层 372 | self.gate=MoEGate(config) 373 | if config.n_shared_experts>0: 374 | self.shared_experts=nn.ModuleList( 375 | [FeedForward(config) 376 | for _ in range(config.n_shared_experts)] 377 | ) 378 | def forward(self,x): 379 | identity=x 380 | orig_shape=x.shape 381 | bsz,seq_len,h=orig_shape 382 | 383 | # 使用门控机制旋转专家 384 | topk_weight, topk_idx, aux_loss = self.gate(x) 385 | # 展开x以便处理 386 | x=x.view(-1,x.shape[-1]) 387 | 388 | flat_topk_idx=topk_idx.view(-1) 389 | if self.training: 390 | # 按照定义的num_experts_per_tok重复输入token 391 | # 每个token安排num_experts_per_tok个专家处理 392 | x=x.repeat_interleave(self.config.num_experts_per_tok,dim=0) 393 | # y是空张量,和x形状相同 394 | 
y=torch.empty_like(x,dtype=torch.float32) 395 | # 遍历所有专家 396 | for i,expert in enumerate(self.experts): 397 | # 找到所有指向专家i的token 398 | # 然后将这些token输入专家i进行处理 399 | # 最后将结果放回y对应位置 400 | y[flat_topk_idx==i]=expert(x[flat_topk_idx==i]).to(y.dtype) 401 | # 加权求和 402 | # 最后的y意义是每个token经过专家处理后的加权结果 403 | y=(y.view(*topk_weight.shape,-1)*topk_weight.unsqueeze(-1).sum(dim=1)) 404 | y=y.view(*orig_shape) 405 | # 如果是推理阶段 406 | else: 407 | y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) 408 | if self.config.n_shared_experts > 0: 409 | for expert in self.shared_experts: 410 | y = y + expert(identity) 411 | self.aux_loss = aux_loss 412 | return y 413 | 414 | @torch.no_grad() 415 | # MoE推理方法 416 | def moe_infer(self, x, flat_expert_indices, flat_expert_weights): 417 | # 使用cache,创建一个和x形状相同的零张量 418 | expert_cache = torch.zeros_like(x) 419 | # 对专家索引进行排序,最后是[0,0,0,1,1,2,2,2,...]这样的顺序 420 | # 分拣 421 | idxs = flat_expert_indices.argsort() 422 | # 统计每个专家被分配到的token数量 423 | # 打包 424 | tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) 425 | # 计算每个token对应的专家索引 426 | token_idxs = idxs // self.config.num_experts_per_tok 427 | # 对每个打包好的包进行处理 428 | for i, end_idx in enumerate(tokens_per_expert): 429 | # 计算当前包的起始位置 430 | start_idx = 0 if i == 0 else tokens_per_expert[i - 1] 431 | if start_idx == end_idx: 432 | continue 433 | # 取出当前包对应的专家 434 | expert = self.experts[i] 435 | # 取出token对应的原始id 436 | exp_token_idx = token_idxs[start_idx:end_idx] 437 | # 取出token对应的数据 438 | expert_tokens = x[exp_token_idx] 439 | # 计算专家输出,一次性处理当前包的所有token 440 | expert_out = expert(expert_tokens).to(expert_cache.dtype) 441 | # 加权 442 | expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) 443 | # 将结果散点加到缓存中对应位置 444 | expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out) 445 | 446 | return expert_cache 447 | 448 | 449 | 450 | class MokioMindBlock(nn.Module): 451 | def __init__(self, layer_id: int, config: MokioMindConfig): 452 | super().__init__() 453 | self.num_attention_heads = config.num_attention_heads 454 | self.hidden_size = config.hidden_size 455 | self.head_dim = config.hidden_size // config.num_attention_heads 456 | self.self_attention = Attention(config) 457 | 458 | self.layer_id = layer_id 459 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 460 | self.post_attention_layernorm = RMSNorm( 461 | config.hidden_size, eps=config.rms_norm_eps 462 | ) 463 | self.mlp = FeedForward(config)if not config.use_moe else MoEFeedForaward(config) 464 | 465 | def forward( 466 | self, 467 | hidden_states, 468 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 469 | past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 470 | use_cache=False, 471 | attention_mask: Optional[torch.Tensor] = None, 472 | ): 473 | res = hidden_states 474 | 475 | hidden_states, present_key_value = self.self_attention( 476 | self.input_layernorm(hidden_states), # pre-norm 477 | position_embeddings, 478 | past_key_value, 479 | use_cache, 480 | attention_mask, 481 | ) 482 | 483 | hidden_states = res + hidden_states 484 | 485 | hidden_states = hidden_states + self.mlp( 486 | self.post_attention_layernorm(hidden_states) 487 | ) 488 | return hidden_states, present_key_value 489 | 490 | class MokioMindModel(nn.Module): 491 | def __init__(self, config: MokioMindConfig): 492 | super().__init__() 493 | self.config = config 494 | self.vocab_size, self.num_hidden_layers = ( 495 | config.vocab_size, 496 | config.num_hidden_layers, 497 | ) 498 
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 499 | self.dropout = nn.Dropout(config.dropout) 500 | self.layers = nn.ModuleList( 501 | [MokioMindBlock(l, config) for l in range(self.num_hidden_layers)] 502 | ) 503 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 504 | 505 | freqs_cos, freqs_sin = precompute_freqs( 506 | dim=config.hidden_size // config.num_attention_heads, 507 | end=config.max_position_embeddings, 508 | rope_base=config.rope_theta, 509 | rope_scaling=config.rope_scaling, 510 | ) 511 | self.register_buffer("freqs_cos", freqs_cos, persistent=False) 512 | self.register_buffer("freqs_sin", freqs_sin, persistent=False) 513 | 514 | def forward( 515 | self, 516 | input_ids: Optional[torch.Tensor] = None, 517 | attention_mask: Optional[torch.Tensor] = None, 518 | past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, 519 | use_cache: bool = False, 520 | **kwargs, 521 | ): 522 | # input_ids: [bsz, seq_len] 523 | batch_size, seq_length = input_ids.shape 524 | 525 | if hasattr(past_key_values, "layers"): 526 | past_key_values = None 527 | 528 | past_key_values = past_key_values or [None] * len(self.layers) 529 | 530 | # 计算start_pos:如果存在past,则start_pos为已有past序列长度 531 | start_pos = ( 532 | past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0 533 | ) 534 | 535 | # Embedding + dropout 536 | hidden_states = self.dropout( 537 | self.embed_tokens(input_ids) 538 | ) # [bsz, seq_len, hidden] 539 | 540 | position_embeddings = ( 541 | self.freqs_cos[start_pos : start_pos + seq_length], 542 | self.freqs_sin[start_pos : start_pos + seq_length], 543 | ) 544 | presents = [] 545 | for layer_idx, (layer, past_key_value) in enumerate( 546 | zip(self.layers, past_key_values) 547 | ): 548 | hidden_states, present = layer( 549 | hidden_states, 550 | position_embeddings, 551 | past_key_value=past_key_value, 552 | use_cache=use_cache, 553 | attention_mask=attention_mask, 554 | ) 555 | presents.append(present) 556 | 557 | hidden_states = self.norm(hidden_states) 558 | 559 | aux_loss=sum(layer.mlp.aux_loss for layer in self.layers if isinstance(layer.mlp,MoEFeedForaward)) 560 | 561 | return hidden_states, presents,aux_loss 562 | 563 | class MokioMindForCausalLM(PreTrainedModel, GenerationMixin): 564 | config_class = MokioMindConfig 565 | 566 | def __init__(self, config: MokioMindConfig): 567 | super().__init__(config) 568 | self.model = MokioMindModel(config) 569 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 570 | self.model.embed_tokens.weight = self.lm_head.weight 571 | 572 | def forward( 573 | self, 574 | input_ids: Optional[torch.Tensor] = None, 575 | attention_mask: Optional[torch.Tensor] = None, 576 | past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, 577 | use_cache: bool = False, 578 | logits_to_keep: Union[int, torch.Tensor] = 0, 579 | **args, 580 | ): 581 | h, past_kvs,aux_loss = self.model( 582 | input_ids=input_ids, 583 | attention_mask=attention_mask, 584 | past_key_values=past_key_values, 585 | use_cache=use_cache, 586 | **args, 587 | ) 588 | 589 | slice_indices = ( 590 | slice(-logits_to_keep, None) 591 | if isinstance(logits_to_keep, int) 592 | else logits_to_keep 593 | ) 594 | logits = self.lm_head(h[:, slice_indices, :]) 595 | 596 | return CausalLMOutputWithPast( 597 | logits=logits, 598 | past_key_values=past_kvs, 599 | hidden_states=h, 600 | ) 601 | -------------------------------------------------------------------------------- 
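The MoE feed-forward path above is the least obvious part of MokioModel.py: MoEGate softmax-scores all routed experts, keeps the top-k per token, renormalizes those weights, and MoEFeedForaward then mixes the selected experts' outputs (plus the shared experts). The following self-contained sketch uses toy sizes and plain nn.Linear layers as stand-in experts; it illustrates the routing idea only and is not a drop-in for the repository classes.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration: 6 tokens, hidden dim 8, 4 routed experts, top-2 routing.
num_tokens, hidden, n_experts, top_k = 6, 8, 4, 2

x = torch.randn(num_tokens, hidden)                    # flattened token hidden states
gate_weight = torch.randn(n_experts, hidden)           # gating matrix (like MoEGate.weight)
experts = [torch.nn.Linear(hidden, hidden) for _ in range(n_experts)]  # stand-in experts

# Score every expert per token, keep the top-k, renormalize the kept weights.
scores = F.linear(x, gate_weight).softmax(dim=-1)              # [num_tokens, n_experts]
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1)    # each [num_tokens, top_k]
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

# Dense reference dispatch: run each selected expert on its tokens and
# combine the outputs with the renormalized gate weights.
y = torch.zeros_like(x)
for slot in range(top_k):
    for e in range(n_experts):
        mask = topk_idx[:, slot] == e                  # tokens whose slot-th choice is expert e
        if mask.any():
            y[mask] += topk_weight[mask, slot].unsqueeze(-1) * experts[e](x[mask])

print(y.shape)  # torch.Size([6, 8]) — one routed-expert mixture per token
```

The repository version differs in that the experts are the SwiGLU FeedForward blocks, an auxiliary load-balancing loss is accumulated during training, shared experts are always added on top, and inference batches tokens per expert via moe_infer instead of looping over routing slots.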
/trainer/train_grpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # 📚 Python模块系统 5 | __package__ = "trainer" 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | 8 | import argparse # 命令行参数解析 9 | import re # 正则表达式,用于奖励计算 10 | import gc # 垃圾回收,手动释放内存 11 | import warnings # 警告控制 12 | import torch # PyTorch深度学习框架 13 | import torch.distributed as dist # 分布式训练支持 14 | from transformers import AutoTokenizer # HuggingFace分词器 15 | from contextlib import nullcontext # 上下文管理器 16 | from torch import optim # 优化器 17 | from torch.nn.parallel import DistributedDataParallel # 分布式并行 18 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载 19 | from torch.optim.lr_scheduler import CosineAnnealingLR # 余弦退火学习率调度 20 | from transformers import AutoModel # HuggingFace模型加载 21 | from model.MokioModel import MokioMindConfig, MokioMindForCausalLM # MokioMind模型 22 | from dataset.lm_dataset import RLAIFDataset # RL数据集 23 | from trainer.trainer_utils import ( # 训练工具函数 24 | Logger, is_main_process, lm_checkpoint, init_distributed_mode, 25 | setup_seed, SkipBatchSampler, init_model 26 | ) 27 | 28 | warnings.filterwarnings('ignore') 29 | 30 | def calculate_rewards(prompts, responses, reward_model, reward_tokenizer): 31 | # 整合所有奖励函数计算总奖励 32 | def reasoning_model_reward(rewards): 33 | # 先计算推理格式奖励 34 | pattern = r"^\n.*?\n\n\n.*?\n$" 35 | pattern2 = r"^\n.*?\n\n\n\n.*?\n$" 36 | matches_pattern = [re.match(pattern, response, re.S) for response in responses] 37 | matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses] 38 | 39 | format_rewards = [] 40 | for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2): 41 | if match_pattern or match_pattern2: 42 | format_rewards.append(0.5) 43 | else: 44 | format_rewards.append(0.0) 45 | rewards += torch.tensor(format_rewards, device=args.device) 46 | def mark_num(text): 47 | reward = 0 48 | if text.count("") == 1: reward += 0.25 49 | if text.count("") == 1: reward += 0.25 50 | if text.count("") == 1: reward += 0.25 51 | if text.count("") == 1: reward += 0.25 52 | return reward 53 | 54 | mark_rewards = [mark_num(response) for response in responses] 55 | rewards += torch.tensor(mark_rewards, device=args.device) 56 | return rewards 57 | rewards = torch.zeros(len(responses), device=args.device) 58 | 59 | if args.reasoning == 1: 60 | rewards = reasoning_model_reward(rewards) 61 | 62 | with torch.no_grad(): 63 | reward_model_scores = [] 64 | batch_size = len(prompts) 65 | scale = 3.0 66 | 67 | # 📚 批处理评分 68 | for i in range(batch_size): 69 | for j in range(args.num_generations): 70 | response_idx = i * args.num_generations + j 71 | response = responses[response_idx] 72 | prompt = prompts[i] 73 | 74 | # 对话格式解析 75 | pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>" 76 | matches = re.findall(pattern, prompt, re.DOTALL) 77 | messages = [{"role": role, "content": content.strip()} for role, content in matches] 78 | 79 | # 构建完整对话 80 | tmp_chat = messages + [{"role": "assistant", "content": response}] 81 | score = reward_model.get_score(reward_tokenizer, tmp_chat) 82 | score = max(min(score, scale), -scale) 83 | 84 | # 推理模型额外奖励 85 | if args.reasoning == 1: 86 | answer_match = re.search(r'(.*?)', response, re.DOTALL) 87 | if answer_match: 88 | answer_content = answer_match.group(1).strip() 89 | tmp_chat = messages + [{"role": "assistant", "content": answer_content}] 90 | answer_score = 
reward_model.get_score(reward_tokenizer, tmp_chat) 91 | answer_score = max(min(answer_score, scale), -scale) 92 | score = score * 0.4 + answer_score * 0.6 93 | 94 | reward_model_scores.append(score) 95 | 96 | reward_model_scores = torch.tensor(reward_model_scores, device=args.device) 97 | rewards += reward_model_scores 98 | 99 | return rewards 100 | 101 | def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None): 102 | for step,batch in enumerate(loader,start=start_step+1): 103 | prompts = batch['prompt'] # list[str], length B 104 | 105 | # 📚 分词和编码 106 | prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False, 107 | padding_side="left", add_special_tokens=False).to(args.device) 108 | if args.max_seq_len: 109 | prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:] 110 | prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:] 111 | 112 | with torch.no_grad(): 113 | model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model 114 | outputs = model_for_gen.generate( 115 | **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8, 116 | num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) 117 | completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] 118 | def get_per_token_logps(mdl, input_ids, n_keep): 119 | input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids 120 | logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :] 121 | per_token_logps = [] 122 | for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]): 123 | ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row 124 | per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1)) 125 | return torch.stack(per_token_logps) 126 | per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R] 127 | with torch.no_grad(): 128 | ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R] 129 | 130 | # 📚 解码响应文本 131 | completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) 132 | 133 | # 📚 计算奖励 134 | rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) 135 | 136 | grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen] 137 | mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen] 138 | std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen] 139 | advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10) 140 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 141 | 142 | is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R] 143 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device) 144 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] 145 | completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() 146 | 147 | kl_div = ref_per_token_logps - per_token_logps 148 | per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R] 149 | per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R] 150 | loss = ((per_token_loss 
* completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar 151 | loss.backward() 152 | 153 | if (step + 1) % args.accumulation_steps == 0: 154 | if args.grad_clip > 0: 155 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 156 | optimizer.step() 157 | scheduler.step() 158 | optimizer.zero_grad() 159 | 160 | # 📚 日志记录 161 | if step % args.log_interval == 0 or step == iters: 162 | policy_loss_val = loss.item() 163 | avg_reward_val = rewards.mean().item() 164 | avg_len_val = completion_mask.sum(dim=1).float().mean().item() 165 | current_lr = optimizer.param_groups[0]['lr'] 166 | 167 | Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, ' 168 | f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, ' 169 | f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}') 170 | 171 | if wandb and is_main_process(): 172 | wandb.log({ 173 | "policy_loss": policy_loss_val, 174 | "reward": avg_reward_val, 175 | "avg_response_len": avg_len_val, 176 | "advantages_mean": advantages.mean().item(), 177 | "learning_rate": current_lr 178 | }) 179 | 180 | # 📚 模型保存 181 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): 182 | model.eval() 183 | moe_suffix = '_moe' if lm_config.use_moe else '' 184 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' 185 | state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() 186 | torch.save({k: v.half() for k, v in state_dict.items()}, ckp) 187 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 188 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler) 189 | model.train() 190 | 191 | # 📚 内存清理 192 | del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps 193 | del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask 194 | torch.cuda.empty_cache() 195 | gc.collect() 196 | 197 | 198 | 199 | if __name__ == "__main__": 200 | """ 201 | GRPO主函数:组相对策略优化脚本的入口点 202 | 203 | 📚 GRPO训练架构: 204 | 1. Policy模型:需要优化的策略网络 205 | 2. Reference模型:冻结的参考策略 206 | 3. Reward模型:评估生成质量 207 | 4. 组内比较:每个prompt生成多个样本 208 | 5. 
211 |     # 📚 Parse command-line arguments
212 |     parser = argparse.ArgumentParser(description="MokioMind GRPO (Group Relative Policy Optimization)")
213 | 
214 |     # ========== Basic training parameters ==========
215 |     parser.add_argument("--save_dir", type=str, default="../out", help="Directory for saved models")
216 |     parser.add_argument('--save_weight', default='grpo', type=str, help="Prefix for saved weight files")
217 |     parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
218 |     parser.add_argument("--batch_size", type=int, default=2, help="Batch size (GRPO typically uses a small batch)")
219 | 
220 |     # 📚 GRPO learning-rate setting
221 |     # The GRPO learning rate is usually very small to avoid drastic policy shifts
222 |     parser.add_argument("--learning_rate", type=float, default=8e-8, help="Initial learning rate")
223 | 
224 |     # ========== Hardware configuration ==========
225 |     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Training device")
226 |     parser.add_argument("--dtype", type=str, default="bfloat16", help="Mixed-precision dtype")
227 |     parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers")
228 | 
229 |     # ========== Training strategy ==========
230 |     parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
231 |     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
232 |     parser.add_argument("--log_interval", type=int, default=1, help="Logging interval (steps)")
233 |     parser.add_argument("--save_interval", type=int, default=10, help="Checkpoint saving interval (steps)")
234 | 
235 |     # ========== Model architecture ==========
236 |     parser.add_argument('--hidden_size', default=512, type=int, help="Hidden size")
237 |     parser.add_argument('--num_hidden_layers', default=8, type=int, help="Number of hidden layers")
238 |     parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="Use the MoE architecture (0=no, 1=yes)")
239 | 
240 |     # ========== GRPO generation parameters ==========
241 |     parser.add_argument('--max_seq_len', default=66, type=int, help="Maximum prompt length")
242 |     parser.add_argument("--max_gen_len", type=int, default=1536, help="Maximum generation length")
243 | 
244 |     # ========== Data and model paths ==========
245 |     parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="Path to the RLAIF data")
246 | 
247 |     # 📚 Key GRPO parameters
248 |     parser.add_argument("--num_generations", type=int, default=8, help="Completions sampled per prompt (group size)")
249 |     parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
250 | 
251 |     # 📚 Reasoning-model configuration
252 |     parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='Model type (0=plain model, 1=reasoning model)')
253 | 
254 |     # 📚 Reward model path
255 |     parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Path to the reward model")
256 | 
257 |     parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="Auto-detect a checkpoint and resume training (0=no, 1=yes)")
258 | 
259 |     # ========== Experiment tracking ==========
260 |     parser.add_argument("--use_wandb", action="store_true", help="Enable wandb logging")
261 |     parser.add_argument("--wandb_project", type=str, default="MokioMind-GRPO", help="wandb project name")
262 | 
263 |     args = parser.parse_args()
264 | 
265 |     # ========== 1. Initialize environment and random seed ==========
266 |     local_rank = init_distributed_mode()
267 |     if dist.is_initialized(): args.device = f"cuda:{local_rank}"
268 |     setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
269 | 
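    # Usage sketch (assumed launch commands; run from the trainer/ directory so that
    # relative defaults such as ../out and ../dataset resolve):
    #   single GPU: python train_grpo.py --batch_size 2 --num_generations 8
    #   multi GPU : torchrun --nproc_per_node 2 train_grpo.py --batch_size 2 --num_generations 8
    # init_distributed_mode() is expected to pick up the environment variables set by torchrun.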
270 |     # ========== 2. Set up directories, model config, and check for a checkpoint ==========
271 |     os.makedirs(args.save_dir, exist_ok=True)
272 |     lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
273 |                                 max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
274 |     ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
275 | 
276 |     # ========== 3. Set up mixed precision ==========
277 |     device_type = "cuda" if "cuda" in args.device else "cpu"
278 |     dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
279 |     autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type="cuda", dtype=dtype)
280 | 
281 |     # ========== 4. Configure wandb ==========
282 |     wandb = None
283 |     if args.use_wandb and is_main_process():
284 |         import swanlab as wandb
285 |         wandb_id = ckp_data.get('wandb_id') if ckp_data else None
286 |         resume = 'must' if wandb_id else None
287 |         wandb_run_name = f"MokioMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
288 |         wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
289 | 
290 |     # ========== 5. Initialize models and data ==========
291 |     # 📚 GRPO model setup
292 |     base_weight = "reason" if args.reasoning == 1 else "full_sft"
293 | 
294 |     # 📚 Policy model (the trainable policy)
295 |     model, tokenizer = init_model(lm_config, base_weight, device=args.device)
296 | 
297 |     # 📚 Reference model (frozen, used for the KL penalty)
298 |     ref_model, _ = init_model(lm_config, base_weight, device=args.device)
299 |     ref_model = ref_model.eval().requires_grad_(False)
300 | 
301 |     # 📚 Reward model (the reward function)
302 |     reward_model = AutoModel.from_pretrained(
303 |         args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
304 |     )
305 |     reward_model = reward_model.to(args.device).eval().requires_grad_(False)
306 |     reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
307 | 
308 |     # 📚 Data and optimizer
309 |     train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
310 |     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
311 |     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
312 |     loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
313 |     iters = len(loader_for_count)
314 |     total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
315 |     scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
316 | 
317 |     # ========== 6. Restore state from the checkpoint ==========
318 |     start_epoch, start_step = 0, 0
319 |     if ckp_data:
320 |         model.load_state_dict(ckp_data['model'])
321 |         optimizer.load_state_dict(ckp_data['optimizer'])
322 |         scheduler.load_state_dict(ckp_data['scheduler'])
323 |         start_epoch = ckp_data['epoch']
324 |         start_step = ckp_data.get('step', 0)
325 | 
326 |     # ========== 7. Wrap the model with DDP ==========
327 |     if dist.is_initialized():
328 |         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
329 |         model = DistributedDataParallel(model, device_ids=[local_rank])
330 | 
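    # Memory/compute sketch for one training step (derived from the settings above):
    # three models are resident at once - the trainable policy, the frozen reference
    # copy for the KL penalty, and the frozen external reward model - and each step
    # rolls out batch_size * num_generations completions (default 2 * 8 = 16), each
    # up to max_seq_len + max_gen_len tokens, before a single backward pass on the policy.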
331 |     # ========== 8. Start training ==========
332 |     for epoch in range(start_epoch, args.epochs):
333 |         if train_sampler: train_sampler.set_epoch(epoch)
334 |         if epoch == start_epoch and start_step > 0:  # first epoch and a checkpoint exists
335 |             batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
336 |             loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
337 |             Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
338 |             grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
339 |         else:  # default: start from the beginning
340 |             loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
341 |                                 drop_last=False, shuffle=(train_sampler is None),
342 |                                 num_workers=args.num_workers, sampler=train_sampler)
343 |             grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
344 | 
345 | 
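    # Artifacts written during training (sketch based on the code above):
    #   {save_dir}/{save_weight}_{hidden_size}[_moe].pth  - fp16 state_dict for evaluation
    #   ../checkpoints/...                                - resume state via lm_checkpoint()
    #     (model, optimizer, scheduler, epoch/step and the wandb run id), reloaded when
    #     the script is restarted with --from_resume 1.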
--------------------------------------------------------------------------------