├── model
├── __init__.py
├── tokenizer_config.json
└── MokioModel.py
├── dataset
├── __init__.py
└── lm_dataset.py
├── .python-version
├── torch_methods
├── unsqueeze_d.py
├── transpose_d.py
├── outer_d.py
├── arange_d.py
├── Dropout_d.py
├── view_d.py
├── cat_d.py
├── where_d.py
└── Linear_d.py
├── main.py
├── .vscode
└── settings.json
├── .gitignore
├── README.md
├── pyproject.toml
├── eval.py
└── trainer
├── trainer_utils.py
├── train_lora.py
├── train_full_sft.py
├── train_pretrain.py
├── train_dpo.py
├── train_ppo.py
└── train_grpo.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.13
2 |
--------------------------------------------------------------------------------
/torch_methods/unsqueeze_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | t1 = torch.Tensor([1, 2, 3])
3 | t2 = t1.unsqueeze(0)
4 | print(t2)
--------------------------------------------------------------------------------
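A small companion sketch (editor's addition, not a file in the repo) showing that negative dims count from the end and that squeeze undoes unsqueeze:

import torch

t = torch.tensor([1.0, 2.0, 3.0])   # shape: (3,)
row = t.unsqueeze(0)                # shape: (1, 3) — new leading dim
col = t.unsqueeze(-1)               # shape: (3, 1) — new trailing dim
print(row.shape, col.shape)
print(row.squeeze(0).shape)         # torch.Size([3]) — squeeze removes the size-1 dim
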
/main.py:
--------------------------------------------------------------------------------
1 | def main():
2 | print("Hello from mokiomind!")
3 |
4 |
5 | if __name__ == "__main__":
6 | main()
7 |
--------------------------------------------------------------------------------
/torch_methods/transpose_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | t1=torch.Tensor([[1,2,3],[4,5,6]])
4 | t1=t1.transpose(0,1)
5 | print(t1)
--------------------------------------------------------------------------------
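One caveat worth noting here (editor's sketch, not from the repo): transpose returns a view over the same storage, so the result is usually non-contiguous and must be made contiguous before view can be used on it:

import torch

t = torch.Tensor([[1, 2, 3], [4, 5, 6]])
tt = t.transpose(0, 1)              # a view over the same storage
print(tt.is_contiguous())           # False
print(tt.contiguous().view(-1))     # copy into contiguous memory, then flatten
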
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python-envs.defaultEnvManager": "ms-python.python:venv",
3 | "python-envs.pythonProjects": []
4 | }
--------------------------------------------------------------------------------
/torch_methods/outer_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | v1=torch.tensor([1,2,3])
3 | v2=torch.tensor([4,5,6])
4 | result=torch.outer(v1,v2)
5 | print(result)
--------------------------------------------------------------------------------
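As a quick cross-check (illustrative only), the outer product is equivalent to broadcasting a column vector against a row vector:

import torch

v1 = torch.tensor([1, 2, 3])
v2 = torch.tensor([4, 5, 6])
by_broadcast = v1.unsqueeze(1) * v2.unsqueeze(0)        # (3,1) * (1,3) -> (3,3)
print(torch.equal(by_broadcast, torch.outer(v1, v2)))   # True
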
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python-generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 |
9 | # Virtual environments
10 | .venv
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Please support the original project, minimind: https://github.com/jingyaogong/minimind
2 | 
3 | ### The project still has incomplete and incorrect parts; follow the videos first, and the GitHub side will be filled in over time
4 | 
5 | ### A star 🌟 on the other projects would also be much appreciated
6 | 
7 | ### Thank you for your support
8 |
--------------------------------------------------------------------------------
/torch_methods/arange_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | t=torch.arange(0,10,2)
4 | print(t) # Output: tensor([0, 2, 4, 6, 8])
5 |
6 | t2=torch.arange(5,0,-1)
7 | print(t2) # Output: tensor([5, 4, 3, 2, 1])
--------------------------------------------------------------------------------
/torch_methods/Dropout_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | dropout_layer = nn.Dropout(p=0.5)
5 |
6 | t1=torch.Tensor([1,2,3])
7 | t2=dropout_layer(t1)
8 | # Dropout zeroes each element with probability p=0.5 (which ones is random); the survivors are scaled by 1/(1-p)=2 so the expectation stays unchanged
9 | print(t2)
--------------------------------------------------------------------------------
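A minimal sketch (editor's addition) of the two behaviours described in the comment above: in training mode the surviving entries are scaled by 1/(1-p), and in eval mode dropout becomes a no-op. The fixed seed is only there to make the random mask reproducible.

import torch
import torch.nn as nn

torch.manual_seed(0)          # make the random drop mask reproducible
drop = nn.Dropout(p=0.5)

x = torch.ones(8)
print(drop(x))                # surviving entries become 1 / (1 - 0.5) = 2.0, the rest 0.0

drop.eval()                   # inference mode: dropout is disabled
print(drop(x))                # identical to x
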
/torch_methods/view_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | t = torch.tensor([[ 1, 2, 3, 4, 5, 6],
3 | [ 7, 8, 9, 10, 11, 12]])
4 | t_view1 = t.view(3, 4)
5 | print(t_view1)
6 | t_view2 = t.view(4, 3)
7 | print(t_view2)
--------------------------------------------------------------------------------
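Two related points, sketched below (editor's addition): view can infer one dimension with -1, and it only works on contiguous memory, whereas reshape falls back to a copy when needed.

import torch

t = torch.arange(12).view(2, 6)
print(t.view(-1, 3).shape)        # torch.Size([4, 3]); -1 infers the missing dim
tt = t.transpose(0, 1)            # non-contiguous view of the same data
print(tt.reshape(3, 4).shape)     # works; tt.view(3, 4) would raise a RuntimeError
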
/torch_methods/cat_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | t1=torch.tensor([[1,2,3],[4,5,6]])
4 | t2=torch.tensor([[7,8,9],[10,11,12]])
5 | result=torch.cat((t1,t2),dim=0)
6 | print(result)
7 |
8 | result2=torch.cat((t1,t2),dim=1)
9 | print(result2)
--------------------------------------------------------------------------------
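For contrast (editor's sketch), cat joins tensors along an existing dimension while stack creates a new one:

import torch

t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
print(torch.cat((t1, t2), dim=0).shape)     # torch.Size([4, 3]) — existing dim grows
print(torch.stack((t1, t2), dim=0).shape)   # torch.Size([2, 2, 3]) — new leading dim
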
/torch_methods/where_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | x=torch.tensor([1, 2, 3, 4, 5])
4 | y=torch.tensor([10, 20, 30, 40, 50])
5 |
6 | condition=(x>3)
7 |
8 | result=torch.where(condition,x,y)
9 |
10 | print(result) # Output: tensor([10, 20, 30, 4, 5])
--------------------------------------------------------------------------------
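torch.where also broadcasts its arguments, so one branch can be a scalar (editor's sketch):

import torch

x = torch.tensor([1, 2, 3, 4, 5])
print(torch.where(x > 3, x, torch.tensor(0)))   # tensor([0, 0, 0, 4, 5])
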
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mokiomind"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.13"
7 | dependencies = [
8 | "numpy>=2.3.4",
9 | "pandas>=2.3.3",
10 | "torch>=2.9.0",
11 | "transformers>=4.57.1",
12 | ]
13 |
--------------------------------------------------------------------------------
/torch_methods/Linear_d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | layer = nn.Linear(in_features=3, out_features=5, bias=True)
5 | t1 = torch.Tensor([1, 2, 3]) # shape: (3,)
6 |
7 | t2 = torch.Tensor([[1, 2, 3]]) # shape: (1, 3)
8 | # The weight w and bias b here are randomly initialized; in real training an optimizer would update them
9 | output2 = layer(t2) # shape: (1, 5)
10 | print(output2)
--------------------------------------------------------------------------------
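A quick cross-check (editor's addition) that nn.Linear computes the affine map x @ W.T + b with the randomly initialized weight and bias mentioned in the comment:

import torch
import torch.nn as nn

layer = nn.Linear(in_features=3, out_features=5, bias=True)
x = torch.Tensor([[1, 2, 3]])                  # shape: (1, 3)
manual = x @ layer.weight.T + layer.bias       # the same affine map done by hand
print(torch.allclose(manual, layer(x)))        # True
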
/dataset/lm_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from torch.utils.data import Dataset
4 | import torch
5 | import os
6 |
7 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
8 |
9 |
10 | class PretrainDataset(Dataset):
11 | def __init__(self, data_path, tokenizer, max_length=512):
12 | super().__init__()
13 | self.tokenizer = tokenizer
14 | self.max_length = max_length
15 | self.samples = self.load_data(data_path)
16 |
17 | def load_data(self, path):
18 | samples = []
19 | with open(path, "r", encoding="utf-8") as f:
20 | for line_num, line in enumerate(f, 1):
21 |                 # Parse the JSONL line and collect it as a sample
22 | data = json.loads(line.strip())
23 | samples.append(data)
24 | return samples
25 |
26 | def __len__(self):
27 | return len(self.samples)
28 |
29 | def __getitem__(self, index):
30 | sample = self.samples[index]
31 |         # Encode with the tokenizer
32 |         # Truncate anything longer than max_length and pad anything shorter
33 | encoding = self.tokenizer(
34 | str(sample["text"]),
35 | max_length=self.max_length,
36 | padding="max_length",
37 | truncation=True,
38 | return_tensors="pt",
39 | )
40 |
41 | input_ids = encoding.input_ids.squeeze()
42 |         # Ignore loss on positions whose target comes from padding
43 |         loss_mask = input_ids != self.tokenizer.pad_token_id
44 |         # Input: tokens from the first to the second-to-last
45 |         X = input_ids[:-1].clone().long()
46 |         # Target: tokens from the second to the last
47 |         Y = input_ids[1:].clone().long()
48 |         loss_mask = loss_mask[1:].long()
49 | return X, Y, loss_mask
50 |
--------------------------------------------------------------------------------
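To make the X/Y shift in __getitem__ concrete, here is a toy illustration (editor's sketch; the pad id 0 and the token values are made up) of the next-token-prediction layout it produces:

import torch

input_ids = torch.tensor([5, 6, 7, 8, 0, 0])      # last two positions are padding (pad id 0 assumed)
X = input_ids[:-1]                                # tensor([5, 6, 7, 8, 0]) — model input
Y = input_ids[1:]                                 # tensor([6, 7, 8, 0, 0]) — next-token targets
loss_mask = (input_ids != 0)[1:].long()           # tensor([1, 1, 1, 0, 0]) — padded targets ignored
print(X, Y, loss_mask)
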
/model/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_bos_token": false,
3 | "add_eos_token": false,
4 | "add_prefix_space": false,
5 | "added_tokens_decoder": {
6 | "0": {
7 | "content": "<|endoftext|>",
8 | "lstrip": false,
9 | "normalized": false,
10 | "rstrip": false,
11 | "single_word": false,
12 | "special": true
13 | },
14 | "1": {
15 | "content": "<|im_start|>",
16 | "lstrip": false,
17 | "normalized": false,
18 | "rstrip": false,
19 | "single_word": false,
20 | "special": true
21 | },
22 | "2": {
23 | "content": "<|im_end|>",
24 | "lstrip": false,
25 | "normalized": false,
26 | "rstrip": false,
27 | "single_word": false,
28 | "special": true
29 | }
30 | },
31 | "additional_special_tokens": [],
32 | "bos_token": "<|im_start|>",
33 | "clean_up_tokenization_spaces": false,
34 | "eos_token": "<|im_end|>",
35 | "legacy": true,
36 | "model_max_length": 32768,
37 | "pad_token": "<|endoftext|>",
38 | "sp_model_kwargs": {},
39 | "spaces_between_special_tokens": false,
40 | "tokenizer_class": "PreTrainedTokenizerFast",
41 | "unk_token": "<|endoftext|>",
42 | "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}"
43 | }
--------------------------------------------------------------------------------
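A hedged usage sketch (editor's addition): this config is consumed by AutoTokenizer, and the chat_template above is what apply_chat_template renders. It assumes the model/ folder also contains the actual tokenizer files (e.g. tokenizer.json), which are not shown in this dump.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model")
messages = [{"role": "user", "content": "你好"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)   # <|im_start|>system ... <|im_start|>user ... <|im_start|>assistant
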
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import warnings
4 | import numpy as np
5 | import torch
6 | from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
7 | from model.MokioModel import MokioMindConfig, MokioMindForCausalLM
8 | from trainer.trainer_utils import setup_seed
9 |
10 | warnings.filterwarnings("ignore")
11 |
12 |
13 | def init_model(args):
14 | tokenizer = AutoTokenizer.from_pretrained(args.load_from)
15 | if "model" in args.load_from:
16 | model = MokioMindForCausalLM(
17 | MokioMindConfig(
18 | hidden_size=args.hidden_size,
19 | num_hidden_layers=args.num_hidden_layers,
20 | inference_rope_scaling=args.inference_rope_scaling,
21 | )
22 | )
23 | moe_suffix = "_moe" if hasattr(args, "use_moe") and args.use_moe else ""
24 | ckp = f"./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth"
25 | model.load_state_dict(torch.load(ckp, map_location=args.device), strict=False)
26 | else:
27 | model = AutoModelForCausalLM.from_pretrained(
28 | args.load_from, trust_remote_code=True
29 | )
30 | print(
31 | f"MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)"
32 | )
33 | return model.eval().to(args.device), tokenizer
34 |
35 | def main():
36 | parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
37 | parser.add_argument(
38 | "--load_from",
39 | default="model",
40 | type=str,
41 | help="模型加载路径(model=原生torch权重,其他路径=transformers格式)",
42 | )
43 | parser.add_argument("--save_dir", default="out", type=str, help="模型权重目录")
44 | parser.add_argument(
45 | "--weight",
46 | default="full_sft",
47 | type=str,
48 | help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)",
49 | )
50 | parser.add_argument(
51 | "--lora_weight",
52 | default="None",
53 | type=str,
54 | help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)",
55 | )
56 | parser.add_argument(
57 | "--hidden_size",
58 | default=512,
59 | type=int,
60 | help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)",
61 | )
62 | parser.add_argument(
63 | "--num_hidden_layers",
64 | default=8,
65 | type=int,
66 | help="隐藏层数量(Small/MoE=8, Base=16)",
67 | )
68 | parser.add_argument(
69 | "--use_moe",
70 | default=0,
71 | type=int,
72 | choices=[0, 1],
73 | help="是否使用MoE架构(0=否,1=是)",
74 | )
75 | parser.add_argument(
76 | "--inference_rope_scaling",
77 | default=False,
78 | action="store_true",
79 | help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)",
80 | )
81 | parser.add_argument(
82 | "--max_new_tokens",
83 | default=8192,
84 | type=int,
85 | help="最大生成长度(注意:并非模型实际长文本能力)",
86 | )
87 | parser.add_argument(
88 | "--temperature",
89 | default=0.85,
90 | type=float,
91 | help="生成温度,控制随机性(0-1,越大越随机)",
92 | )
93 | parser.add_argument(
94 | "--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)"
95 | )
96 | parser.add_argument(
97 | "--historys",
98 | default=0,
99 | type=int,
100 | help="携带历史对话轮数(需为偶数,0表示不携带历史)",
101 | )
102 | parser.add_argument(
103 | "--device",
104 | default="cuda" if torch.cuda.is_available() else "cpu",
105 | type=str,
106 | help="运行设备",
107 | )
108 | args = parser.parse_args()
109 |
110 | prompts = [
111 | "你有什么特长?",
112 | "为什么天空是蓝色的",
113 | "请用Python写一个计算斐波那契数列的函数",
114 | '解释一下"光合作用"的基本过程',
115 | "如果明天下雨,我应该如何出门",
116 | "比较一下猫和狗作为宠物的优缺点",
117 | "解释什么是机器学习",
118 | "推荐一些中国的美食",
119 | ]
120 |
121 | conversation = []
122 | model, tokenizer = init_model(args)
123 | input_mode = int(input("[0] 自动测试\n[1] 手动输入\n"))
124 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
125 |
126 | prompt_iter = prompts if input_mode == 0 else iter(lambda: input("👶: "), "")
127 | for prompt in prompt_iter:
128 | setup_seed(2026) # or setup_seed(random.randint(0, 2048))
129 | if input_mode == 0:
130 | print(f"👶: {prompt}")
131 | conversation = conversation[-args.historys :] if args.historys else []
132 | conversation.append({"role": "user", "content": prompt})
133 |
134 | templates = {
135 | "conversation": conversation,
136 | "tokenize": False,
137 | "add_generation_prompt": True,
138 | }
139 | if args.weight == "reason":
140 | templates["enable_thinking"] = True # 仅Reason模型使用
141 | inputs = (
142 | tokenizer.apply_chat_template(**templates)
143 | if args.weight != "pretrain"
144 | else (tokenizer.bos_token + prompt)
145 | )
146 | inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)
147 |
148 | print("🤖️: ", end="")
149 | generated_ids = model.generate(
150 | inputs=inputs["input_ids"],
151 | attention_mask=inputs["attention_mask"],
152 | max_new_tokens=args.max_new_tokens,
153 | do_sample=True,
154 | streamer=streamer,
155 | pad_token_id=tokenizer.pad_token_id,
156 | eos_token_id=tokenizer.eos_token_id,
157 | top_p=args.top_p,
158 | temperature=args.temperature,
159 | repetition_penalty=1.0,
160 | )
161 | response = tokenizer.decode(
162 | generated_ids[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
163 | )
164 | conversation.append({"role": "assistant", "content": response})
165 | print("\n\n")
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
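The manual-input branch above relies on the two-argument form of iter, which calls the lambda repeatedly until it returns the sentinel; a standalone sketch (editor's addition) of that idiom:

data = iter(["hello", "again", ""])
gen = iter(lambda: next(data), "")   # same idiom as prompt_iter in main()
print(list(gen))                     # ['hello', 'again'] — stops at the empty-string sentinel
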
/trainer/trainer_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | from torch.utils.data import Sampler
8 |
9 | # Check whether this is the main process
10 | def is_main_process():
11 | return not dist.is_initialized() or dist.get_rank() == 0
12 |
13 | # Logging helper (prints on the main process only)
14 | def Logger(content):
15 | if is_main_process():
16 | print(content)
17 |
18 | # Dynamic learning-rate schedule (cosine decay with a floor of lr / 10)
19 | def get_lr(current_step, total_steps, lr):
20 | return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
21 |
22 | # Initialize distributed training
23 | def init_distributed_mode():
24 | if int(os.environ.get("RANK", -1)) == -1:
25 |         return 0  # not running under DDP
26 |
27 | dist.init_process_group(backend="nccl")
28 | local_rank = int(os.environ["LOCAL_RANK"])
29 | torch.cuda.set_device(local_rank)
30 | return local_rank
31 |
32 | # Set random seeds for reproducibility
33 | def setup_seed(seed: int):
34 | random.seed(seed)
35 | np.random.seed(seed)
36 | torch.manual_seed(seed)
37 | torch.cuda.manual_seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 | torch.backends.cudnn.deterministic = True
40 | torch.backends.cudnn.benchmark = False
41 |
42 | # Save or load a training checkpoint
43 | def lm_checkpoint(
44 | lm_config,
45 | weight="full_sft",
46 | model=None,
47 | optimizer=None,
48 | epoch=0,
49 | step=0,
50 | wandb=None,
51 | save_dir="checkpoints",
52 | **kwargs,
53 | ):
54 | os.makedirs(save_dir, exist_ok=True)
55 |
56 | moe_path = "_moe" if hasattr(lm_config, "use_moe") and lm_config.use_moe else ""
57 | ckp_path = f"{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth"
58 | resume_path = f"{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth"
59 |
60 | if model is not None:
61 | from torch.nn.parallel import DistributedDataParallel
62 |
63 | if isinstance(model, DistributedDataParallel):
64 | state_dict = model.module.state_dict()
65 | else:
66 | state_dict = model.state_dict()
67 |
68 | ckp_tmp = ckp_path + ".tmp"
69 | torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
70 | os.replace(ckp_tmp, ckp_path)
71 |
72 | wandb_id = None
73 | if wandb:
74 | if hasattr(wandb, "get_run"):
75 | run = wandb.get_run()
76 | wandb_id = getattr(run, "id", None) if run else None
77 | else:
78 | wandb_id = getattr(wandb, "id", None)
79 |
80 | resume_data = {
81 | "model": state_dict,
82 | "optimizer": optimizer.state_dict(),
83 | "epoch": epoch,
84 | "step": step,
85 | "world_size": dist.get_world_size() if dist.is_initialized() else 1,
86 | "wandb_id": wandb_id,
87 | }
88 |
89 | for key, value in kwargs.items():
90 | if value is not None:
91 | if hasattr(value, "state_dict"):
92 | if isinstance(value, DistributedDataParallel):
93 | resume_data[key] = value.module.state_dict()
94 | else:
95 | resume_data[key] = value.state_dict()
96 | else:
97 | resume_data[key] = value
98 |
99 | resume_tmp = resume_path + ".tmp"
100 | torch.save(resume_data, resume_tmp)
101 | os.replace(resume_tmp, resume_path)
102 |
103 |     else:  # load mode: return resume data if it exists
104 | if os.path.exists(resume_path):
105 | ckp_data = torch.load(resume_path, map_location="cpu")
106 | saved_ws = ckp_data.get("world_size", 1)
107 | current_ws = dist.get_world_size() if dist.is_initialized() else 1
108 |
109 | if saved_ws != current_ws:
110 | ckp_data["step"] = ckp_data["step"] * saved_ws // current_ws
111 | Logger(
112 | f"GPU数量变化({saved_ws}→{current_ws}),step已自动转换为{ckp_data['step']}"
113 | )
114 |
115 | return ckp_data
116 | return None
117 |
118 | # Initialize the model and tokenizer
119 | def init_model(
120 | lm_config,
121 | from_weight="pretrain",
122 | tokenizer_path=None,
123 | save_dir="out",
124 | device="cuda",
125 | ):
126 | from transformers import AutoTokenizer
127 | from model.MokioModel import MokioMindForCausalLM
128 |
129 |     # If tokenizer_path is not given, use the model folder under the project root
130 | if tokenizer_path is None:
131 |         # Parent directory of this file's directory, i.e. the project root
132 | current_dir = os.path.dirname(os.path.abspath(__file__))
133 | project_root = os.path.dirname(current_dir)
134 | tokenizer_path = os.path.join(project_root, "model")
135 |
136 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
137 |
138 | model = MokioMindForCausalLM(lm_config)
139 |
140 | if from_weight != "none":
141 | moe_suffix = (
142 | "_moe" if hasattr(lm_config, "use_moe") and lm_config.use_moe else ""
143 | )
144 | weight_path = (
145 | f"{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth"
146 | )
147 |
148 | weights = torch.load(weight_path, map_location=device)
149 |
150 | model.load_state_dict(weights, strict=False)
151 |
152 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
153 | Logger(f"所加载Model可训练参数:{total_params / 1e6:.3f} 百万")
154 |
155 | return model.to(device), tokenizer
156 |
157 |
158 | class SkipBatchSampler(Sampler):
159 | def __init__(self, sampler, batch_size, skip_batches=0):
160 |         self.sampler = sampler  # underlying index sampler (e.g. DistributedSampler or range)
161 | self.batch_size = batch_size
162 | self.skip_batches = skip_batches
163 |
164 | def __iter__(self):
165 |         batch = []  # current batch of indices
166 |         skipped = 0  # number of batches skipped so far
167 | 
168 |         for idx in self.sampler:
169 |             batch.append(idx)  # add the sample index to the current batch
170 | 
171 |             if len(batch) == self.batch_size:
172 |                 if skipped < self.skip_batches:
173 |                     skipped += 1  # count this batch as skipped
174 |                     batch = []  # discard it without yielding
175 |                     continue  # move on to the next batch
176 | 
177 |                 yield batch
178 |                 batch = []  # reset for the next batch
179 |
180 | if len(batch) > 0 and skipped >= self.skip_batches:
181 | yield batch
182 |
183 | def __len__(self):
184 | total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
185 |
186 | return max(0, total_batches - self.skip_batches)
187 |
--------------------------------------------------------------------------------
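A quick, self-contained check (editor's sketch, assuming it is run from the project root so that trainer.trainer_utils is importable) of how SkipBatchSampler resumes mid-epoch and what the get_lr schedule produces:

from trainer.trainer_utils import SkipBatchSampler, get_lr

sampler = SkipBatchSampler(range(10), batch_size=3, skip_batches=2)
print(list(sampler))   # [[6, 7, 8], [9]] — the first two batches are skipped
print(len(sampler))    # 2

lr = 5e-4
print([round(get_lr(s, 100, lr), 6) for s in (0, 50, 100)])
# [0.00055, 0.0003, 5e-05] — cosine decay from 1.1*lr down to the floor of 0.1*lr
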
/trainer/train_lora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | __package__ = "trainer"
5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
6 |
7 | import argparse
8 | import time
9 | import warnings
10 | import torch
11 | import torch.distributed as dist
12 | from contextlib import nullcontext
13 | from torch import optim, nn
14 | from torch.nn.parallel import DistributedDataParallel
15 | from torch.utils.data import DataLoader, DistributedSampler
16 | from model.MokioModel import MokioMindConfig
17 | from dataset.lm_dataset import SFTDataset
18 | from model.model_lora import save_lora, apply_lora
19 | from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
20 |
21 | warnings.filterwarnings('ignore')
22 |
23 | def train_epoch(epoch,loader,iters,lora_params,start_step=0,wandb=None):
24 |     # reduction='none' keeps the per-token losses instead of averaging them
25 | loss_fct=nn.CrossEntropyLoss(reduction='none')
26 | start_time=time.time()
27 | for step,(X,Y,loss_mask) in enumerate(loader,start=start_step+1):
28 | X=X.to(args.device)
29 | Y=Y.to(args.device)
30 | loss_mask=loss_mask.to(args.device)
31 |         # Dynamically adjust the learning rate
32 |         # Cosine annealing schedule
33 | lr=get_lr(epoch*iters+step,args.epochs*iters,args.learning_rate)
34 |
35 |
36 | for param_group in optimizer.param_groups:
37 | param_group['lr']=lr
38 |
39 |         # Train inside the mixed-precision autocast context
40 |         with autocast_ctx:
41 |             # Forward pass
42 |             res=model(X)
43 | 
44 |             # Per-token loss, reshaped to [batch, seq] so it matches loss_mask
45 |             loss=loss_fct(res.logits.view(-1,res.logits.size(-1)),Y.view(-1)).view(Y.size())
46 |
47 | loss = (loss * loss_mask).sum() / loss_mask.sum()
48 | loss += res.aux_loss
49 | loss = loss / args.accumulation_steps
50 |         # Backward pass through the gradient scaler (mixed precision)
51 | scaler.scale(loss).backward()
52 | if (step+1)%args.accumulation_steps==0:
53 | scaler.unscale_(optimizer)
54 |             # Clip gradients to prevent gradient explosion
55 | torch.nn.utils.clip_grad_norm_(lora_params,args.grad_clip)
56 |
57 | scaler.step(optimizer)
58 | scaler.update()
59 |
60 | optimizer.zero_grad(set_to_none=True)
61 |
62 |         # Log every log_interval steps and on the last step
63 | if step % args.log_interval == 0 or step == iters - 1:
64 | spend_time = time.time() - start_time
65 | current_loss = loss.item() * args.accumulation_steps
66 | current_lr = optimizer.param_groups[-1]['lr']
67 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
68 |
69 | Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
70 |
71 | if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
72 |
73 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
74 |             # Switch to eval mode while saving
75 |             model.eval()
76 | 
77 |             # LoRA save: only the A/B matrices are stored
78 |             lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
79 |             # save_lora writes only the LoRA weights, not the base model
80 |             save_lora(model, lora_save_path)
81 | lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
82 | model.train()
83 |
84 |
85 | if __name__ == "__main__":
86 | parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
87 | parser.add_argument("--save_dir", type=str, default="../out/lora", help="模型保存目录")
88 | parser.add_argument("--lora_name", type=str, default="lora_identity", help="LoRA权重名称(如lora_identity/lora_medical等)")
89 | parser.add_argument("--epochs", type=int, default=50, help="训练轮数")
90 | parser.add_argument("--batch_size", type=int, default=32, help="batch size")
91 | parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
92 | parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
93 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
94 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
95 | parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
96 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
97 | parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
98 | parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔")
99 | parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
100 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
101 | parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
102 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
103 | parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
104 | parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft")
105 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
106 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
107 | parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
108 | args = parser.parse_args()
109 |
110 | # ========== 1. 初始化环境和随机种子 ==========
111 | local_rank = init_distributed_mode()
112 | if dist.is_initialized(): args.device = f"cuda:{local_rank}"
113 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
114 |
115 | # ========== 2. 配置目录、模型参数、检查ckp ==========
116 | os.makedirs(args.save_dir, exist_ok=True)
117 |     lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
118 | ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
119 |
120 | # ========== 3. 设置混合精度 ==========
121 | device_type = "cuda" if "cuda" in args.device else "cpu"
122 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
123 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
124 |
125 | # ========== 4. 配wandb ==========
126 | wandb = None
127 | if args.use_wandb and is_main_process():
128 | import swanlab as wandb
129 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None
130 | resume = 'must' if wandb_id else None
131 | wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
132 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
133 |
134 | # ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
135 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
136 | apply_lora(model)
137 |
138 | # 统计参数
139 | total_params = sum(p.numel() for p in model.parameters())
140 | lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
141 | Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
142 | Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
143 | Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
144 |
145 |     # Freeze non-LoRA parameters and collect the LoRA parameters
146 | lora_params = []
147 | for name, param in model.named_parameters():
148 | if 'lora' in name:
149 | param.requires_grad = True
150 | lora_params.append(param)
151 | else:
152 | param.requires_grad = False
153 |
154 | # ========== 6. 定义数据和优化器 ==========
155 | train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
156 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
157 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
158 | optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
159 |
160 | # ========== 7. 从ckp恢复状态 ==========
161 | start_epoch, start_step = 0, 0
162 | if ckp_data:
163 | model.load_state_dict(ckp_data['model'], strict=False)
164 | optimizer.load_state_dict(ckp_data['optimizer'])
165 | scaler.load_state_dict(ckp_data['scaler'])
166 | start_epoch = ckp_data['epoch']
167 | start_step = ckp_data.get('step', 0)
168 |
169 | # ========== 8. DDP包模型 ==========
170 | if dist.is_initialized():
171 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
172 | model = DistributedDataParallel(model, device_ids=[local_rank])
173 |
174 | # ========== 9. 开始训练 ==========
175 | for epoch in range(start_epoch, args.epochs):
176 | train_sampler and train_sampler.set_epoch(epoch)
177 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
178 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
179 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
180 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
181 | train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb)
182 | else: # 默认从头开始
183 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
184 | train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
185 |
--------------------------------------------------------------------------------
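model/model_lora.py (apply_lora, save_lora) is not included in this dump, so as a point of reference only, here is a hypothetical minimal LoRA layer illustrating the idea the comments refer to: the pretrained weight stays frozen and only the low-rank A/B matrices are trained and saved.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Hypothetical sketch of a LoRA-wrapped linear layer, not the repo's implementation."""
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():          # freeze the pretrained weight and bias
            p.requires_grad_(False)
        self.lora_A = nn.Parameter(torch.randn(base.in_features, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, base.out_features))
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.lora_A @ self.lora_B) * self.scaling

layer = LoRALinear(nn.Linear(512, 512))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)   # 8192 = 2 * 512 * 8, far fewer than the 512 * 512 frozen weights
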
/trainer/train_full_sft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # 📚 Python模块系统和路径管理
5 | __package__ = "trainer"
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 |
8 | import argparse # 命令行参数解析
9 | import time # 时间测量
10 | import warnings # 警告控制
11 | import torch # PyTorch深度学习框架
12 | import torch.distributed as dist # 分布式训练
13 | from contextlib import nullcontext # 上下文管理器
14 | from torch import optim, nn # 优化器和神经网络
15 | from torch.nn.parallel import DistributedDataParallel # 分布式并行
16 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载
17 |
18 | from model.MokioModel import MokioMindConfig  # model configuration
19 | from dataset.lm_dataset import SFTDataset # SFT数据集
20 | from trainer.trainer_utils import ( # 训练工具
21 | get_lr, Logger, is_main_process, lm_checkpoint,
22 | init_distributed_mode, setup_seed, init_model, SkipBatchSampler
23 | )
24 |
25 | warnings.filterwarnings('ignore')
26 |
27 | def train_epoch(epoch,loader,iters,start_step=0,wandb=None):
28 | loss_fct=nn.CrossEntropyLoss(reduction='none')
29 | start_time=time.time()
30 |
31 | for step,(X,Y,loss_mask) in enumerate(loader,start=start_step+1):
32 | X=X.to(args.device)
33 | Y=Y.to(args.device)
34 | loss_mask=loss_mask.to(args.device)
35 |
36 | # 动态调整学习率
37 | lr=get_lr(epoch*iters+step,args.epochs*iters,args.learning_rate)
38 | for param_group in optimizer.param_groups:
39 | param_group['lr']=lr
40 |
41 | with autocast_ctx:
42 | # 前向传播
43 | res=model(X)
44 |
45 | # 损失计算
46 | loss=loss_fct(res.logits.view(-1,res.logits.size(-1)),Y.view(-1)).view(Y.size())
47 |
48 | loss= (loss * loss_mask).sum() / loss_mask.sum()
49 |
50 | loss+=res.aux_loss
51 |
52 |             loss=loss/args.accumulation_steps
53 |
54 | scaler.scale(loss).backward()
55 |
56 | if (step+1)%args.accumulation_steps==0:
57 | scaler.unscale_(optimizer)
58 |
59 | torch.nn.utils.clip_grad_norm_(model.parameters(),args.grad_clip)
60 |
61 | scaler.step(optimizer)
62 | scaler.update()
63 |
64 | optimizer.zero_grad(set_to_none=True)
65 |
66 | if step % args.log_interval == 0 or step == iters - 1:
67 | spend_time = time.time() - start_time
68 | current_loss = loss.item() * args.accumulation_steps # 恢复真实损失值
69 | current_lr = optimizer.param_groups[-1]['lr']
70 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
71 |
72 | Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
73 |
74 | # 记录到实验跟踪系统
75 | if wandb:
76 | wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
77 |
78 | # 📚 SFT模型检查点保存
79 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
80 | model.eval() # 切换到评估模式
81 |
82 | # 构建SFT模型保存路径
83 | moe_suffix = '_moe' if lm_config.use_moe else ''
84 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
85 |
86 | # 处理分布式模型的状态字典
87 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
88 | state_dict = model.module.state_dict()
89 | else:
90 | state_dict = model.state_dict()
91 |
92 | # 半精度保存节省空间
93 | state_dict = {k: v.half() for k, v in state_dict.items()}
94 | torch.save(state_dict, ckp)
95 | # 保存完整训练状态
96 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
97 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
98 | model.train() # 恢复训练模式
99 |
100 | if __name__ == "__main__":
101 |     """
102 |     SFT main entry point: supervised fine-tuning script
103 | 
104 |     📚 How SFT parameters differ from pretraining:
105 |     - Smaller learning rate: 5e-7 vs 5e-4 (pretraining)
106 |     - Fewer epochs: about 2 instead of many
107 |     - batch_size can be smaller: the model already has basic ability and needs no large batch
108 |     - Accumulation steps usually 1: SFT data is high quality, so little accumulation is needed
109 |     """
110 |
111 | parser = argparse.ArgumentParser(description="MiniMind Full SFT")
112 |
113 | # ========== 基础训练参数 ==========
114 | parser.add_argument("--save_dir", type=str, default="../out",
115 | help="模型保存目录")
116 | parser.add_argument('--save_weight', default='full_sft', type=str,
117 | help="保存权重的前缀名")
118 | parser.add_argument("--epochs", type=int, default=2,
119 | help="训练轮数(SFT通常2-5轮即可)")
120 | parser.add_argument("--batch_size", type=int, default=16,
121 | help="batch size(SFT可以使用较小的batch)")
122 |
123 | # 📚 SFT学习率设置知识点
124 | # SFT学习率通常比预训练小1-2个数量级
125 | # 因为模型已经有了基础能力,只需要微调
126 | parser.add_argument("--learning_rate", type=float, default=5e-7,
127 | help="初始学习率(比预训练小很多)")
128 |
129 | # ========== 硬件配置 ==========
130 | parser.add_argument("--device", type=str,
131 | default="cuda:0" if torch.cuda.is_available() else "cpu",
132 | help="训练设备")
133 | parser.add_argument("--dtype", type=str, default="bfloat16",
134 | help="混合精度类型")
135 | parser.add_argument("--num_workers", type=int, default=1,
136 | help="数据加载线程数")
137 |
138 | # ========== 训练策略 ==========
139 | # 📚 SFT梯度累积知识点
140 | # SFT数据质量高,通常不需要大量梯度累积
141 | # accumulation_steps=1 意味着每个batch都更新参数
142 | parser.add_argument("--accumulation_steps", type=int, default=1,
143 | help="梯度累积步数(SFT通常设为1)")
144 | parser.add_argument("--grad_clip", type=float, default=1.0,
145 | help="梯度裁剪阈值")
146 | parser.add_argument("--log_interval", type=int, default=100,
147 | help="日志打印间隔")
148 | parser.add_argument("--save_interval", type=int, default=100,
149 | help="模型保存间隔")
150 |
151 | # ========== 模型架构参数 ==========
152 | parser.add_argument('--hidden_size', default=512, type=int,
153 | help="隐藏层维度")
154 | parser.add_argument('--num_hidden_layers', default=8, type=int,
155 | help="隐藏层数量")
156 | parser.add_argument('--max_seq_len', default=512, type=int,
157 | help="训练的最大截断长度")
158 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1],
159 | help="是否使用MoE架构(0=否,1=是)")
160 |
161 | # ========== SFT数据和恢复参数 ==========
162 | # 📚 SFT数据路径知识点
163 | # SFT数据通常是结构化的问答对或对话数据
164 | # 包含instruction和response两部分
165 | parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl",
166 | help="SFT训练数据路径")
167 |
168 | # 📚 SFT权重继承知识点
169 | # SFT通常从预训练模型开始,而不是从零开始
170 | # 'pretrain'表示从预训练权重开始微调
171 | parser.add_argument('--from_weight', default='pretrain', type=str,
172 | help="基于哪个权重训练(通常从预训练权重开始)")
173 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1],
174 | help="是否自动检测&续训(0=否,1=是)")
175 |
176 | # ========== 实验跟踪 ==========
177 | parser.add_argument("--use_wandb", action="store_true",
178 | help="是否使用wandb")
179 | parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT",
180 | help="wandb项目名")
181 |
182 | args = parser.parse_args()
183 |
184 | # ========== 1. 初始化环境和随机种子 ==========
185 | local_rank = init_distributed_mode()
186 | if dist.is_initialized(): args.device = f"cuda:{local_rank}"
187 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
188 |
189 | # ========== 2. 配置目录、模型参数、检查ckp ==========
190 | os.makedirs(args.save_dir, exist_ok=True)
191 |     lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
192 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
193 |
194 | # ========== 3. 设置混合精度 ==========
195 | device_type = "cuda" if "cuda" in args.device else "cpu"
196 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
197 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
198 |
199 | # ========== 4. 配wandb ==========
200 | wandb = None
201 | if args.use_wandb and is_main_process():
202 | import swanlab as wandb
203 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None
204 | resume = 'must' if wandb_id else None
205 | wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
206 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
207 |
208 | # ========== 5. 定义模型、数据、优化器 ==========
209 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
210 | train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
211 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
212 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
213 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
214 |
215 | # ========== 6. 从ckp恢复状态 ==========
216 | start_epoch, start_step = 0, 0
217 | if ckp_data:
218 | model.load_state_dict(ckp_data['model'])
219 | optimizer.load_state_dict(ckp_data['optimizer'])
220 | scaler.load_state_dict(ckp_data['scaler'])
221 | start_epoch = ckp_data['epoch']
222 | start_step = ckp_data.get('step', 0)
223 |
224 | # ========== 7. DDP包模型 ==========
225 | if dist.is_initialized():
226 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
227 | model = DistributedDataParallel(model, device_ids=[local_rank])
228 |
229 | # ========== 8. 开始训练 ==========
230 | for epoch in range(start_epoch, args.epochs):
231 | train_sampler and train_sampler.set_epoch(epoch)
232 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
233 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
234 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
235 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
236 | train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
237 | else: # 默认从头开始
238 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
239 | train_epoch(epoch, loader, len(loader), 0, wandb)
240 |
--------------------------------------------------------------------------------
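To see what the reduction='none' + loss_mask pattern in train_epoch computes, here is a toy-sized check (editor's sketch with made-up shapes):

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 4, 10)                          # [batch, seq, vocab]
targets = torch.randint(0, 10, (2, 4))                  # [batch, seq]
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)

loss_fct = nn.CrossEntropyLoss(reduction="none")
per_token = loss_fct(logits.view(-1, 10), targets.view(-1)).view(targets.size())
loss = (per_token * mask).sum() / mask.sum()            # mean over the 5 unmasked tokens only
print(per_token.shape, loss)
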
/trainer/train_pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 |
5 | __package__ = "trainer"
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
7 |
8 | import argparse # 命令行参数解析
9 | import time # 时间统计
10 | import warnings # 警告控制
11 | import torch
12 | import torch.distributed as dist # 分布式训练支持
13 | from contextlib import nullcontext # 上下文管理器
14 | from torch import optim, nn # 优化器和神经网络模块
15 | from torch.nn.parallel import DistributedDataParallel # 分布式数据并行
16 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载器
17 |
18 | from model.MokioModel import MokioMindConfig
19 | from dataset.lm_dataset import PretrainDataset
20 | from trainer.trainer_utils import ( # 训练工具函数
21 | get_lr,
22 | Logger,
23 | is_main_process,
24 | lm_checkpoint,
25 | init_distributed_mode,
26 | setup_seed,
27 | init_model,
28 | SkipBatchSampler,
29 | )
30 |
31 | # Silence warnings to keep the output clean
32 | warnings.filterwarnings("ignore")
33 |
34 |
35 | def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
36 | loss_fct = nn.CrossEntropyLoss(reduction="none")
37 | start_time = time.time() # 记录开始时间
38 |
39 | # 遍历数据批次
40 | for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
41 | X = X.to(args.device)
42 | Y = Y.to(args.device)
43 | loss_mask = loss_mask.to(args.device)
44 |
45 | lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
46 |
47 | for param_group in optimizer.param_groups:
48 | param_group["lr"] = lr
49 |
50 | with autocast_ctx:
51 | # 前向传播
52 | res = model(X)
53 |
54 | loss = loss_fct(
55 | res.logits.view(-1, res.logits.size(-1)), # [batch*seq, vocab_size]
56 | Y.view(-1), # [batch*seq]
57 | ).view(Y.size()) # 恢复为 [batch_size, seq_len]
58 |
59 | loss = (loss * loss_mask).sum() / loss_mask.sum()
60 |
61 | loss+=res.aux_loss
62 |
63 | loss = loss / args.accumulation_steps
64 |
65 | scaler.scale(loss).backward()
66 |
67 | if (step + 1) % args.accumulation_steps == 0:
68 | # scaler.unscale_(): 还原梯度的真实值
69 | scaler.unscale_(optimizer)
70 |
71 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
72 |
73 | # 📚 优化器更新知识点
74 | # scaler.step(): 执行参数更新
75 | # scaler.update(): 更新scaler的缩放因子
76 | scaler.step(optimizer)
77 | scaler.update()
78 |
79 | optimizer.zero_grad(set_to_none=True)
80 |
81 | if step % args.log_interval == 0 or step == iters - 1:
82 | spend_time = time.time() - start_time
83 | current_loss = loss.item() * args.accumulation_steps # 恢复真实损失值
84 | current_lr = optimizer.param_groups[-1]["lr"] # 当前学习率
85 |
86 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
87 |
88 | Logger(
89 | f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:"
90 | )
91 |
92 | # 记录到实验跟踪系统
93 | if wandb:
94 | wandb.log(
95 | {"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min}
96 | )
97 |
98 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
99 | model.eval() # 切换到评估模式
100 |
101 | # 构建保存路径
102 | moe_suffix = (
103 | "_moe" if hasattr(lm_config, "use_moe") and lm_config.use_moe else ""
104 | )
105 | ckp = f"{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth"
106 |
107 | # 📚 分布式模型保存知识点
108 | # DDP模型需要通过.module访问真正的模型
109 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
110 | state_dict = model.module.state_dict()
111 | else:
112 | state_dict = model.state_dict()
113 |
114 | # 📚 半精度保存知识点
115 | # 将float32参数转为float16,减少存储空间
116 | state_dict = {k: v.half() for k, v in state_dict.items()}
117 | torch.save(state_dict, ckp)
118 |
119 | # 保存完整训练状态
120 | lm_checkpoint(
121 | lm_config,
122 | weight=args.save_weight,
123 | model=model,
124 | optimizer=optimizer,
125 | scaler=scaler,
126 | epoch=epoch,
127 | step=step,
128 | wandb=wandb,
129 | save_dir="checkpoints",
130 | )
131 |
132 | model.train() # 恢复训练模式
133 |
134 |
135 | if __name__ == "__main__":
136 | parser = argparse.ArgumentParser(description="MiniMind Pretraining")
137 |
138 | # ========== 基础训练参数 ==========
139 | parser.add_argument("--save_dir", type=str, default="out", help="模型保存目录")
140 | parser.add_argument(
141 | "--save_weight", default="pretrain", type=str, help="保存权重的前缀名"
142 | )
143 | parser.add_argument(
144 | "--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)"
145 | )
146 | parser.add_argument("--batch_size", type=int, default=32, help="batch size")
147 | parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
148 |
149 | # ========== 硬件和性能参数 ==========
150 | parser.add_argument(
151 | "--device",
152 | type=str,
153 | default="cuda:0" if torch.cuda.is_available() else "cpu",
154 | help="训练设备",
155 | )
156 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
157 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
158 |
159 | # ========== 训练策略参数 ==========
160 | parser.add_argument(
161 | "--accumulation_steps", type=int, default=8, help="梯度累积步数"
162 | )
163 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
164 | parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
165 | parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
166 |
167 | # ========== 模型架构参数 ==========
168 | parser.add_argument("--hidden_size", default=512, type=int, help="隐藏层维度")
169 | parser.add_argument("--num_hidden_layers", default=8, type=int, help="隐藏层数量")
170 | parser.add_argument(
171 | "--max_seq_len", default=512, type=int, help="训练的最大截断长度"
172 | )
173 | parser.add_argument(
174 | "--use_moe",
175 | default=0,
176 | type=int,
177 | choices=[0, 1],
178 | help="是否使用MoE架构(0=否,1=是)",
179 | )
180 |
181 | # ========== 数据和恢复参数 ==========
182 | parser.add_argument(
183 | "--data_path",
184 | type=str,
185 | default="dataset/pretrain_hq.jsonl",
186 | help="预训练数据路径",
187 | )
188 | parser.add_argument(
189 | "--from_weight",
190 | default="none",
191 | type=str,
192 | help="基于哪个权重训练,为none则从头开始",
193 | )
194 | parser.add_argument(
195 | "--from_resume",
196 | default=0,
197 | type=int,
198 | choices=[0, 1],
199 | help="是否自动检测&续训(0=否,1=是)",
200 | )
201 |
202 | # ========== 实验跟踪参数 ==========
203 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
204 | parser.add_argument(
205 | "--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名"
206 | )
207 |
208 | # 解析命令行参数
209 | args = parser.parse_args()
210 |
211 | # ========== 1. 初始化环境和随机种子 ==========
212 |     """
213 |     📚 Distributed-training initialization notes:
214 |     - local_rank: index of this process's GPU on the local machine
215 |     - random seed: each process gets a different but reproducible random sequence
216 |     - this preserves both randomness and reproducibility
217 |     """
218 | local_rank = init_distributed_mode()
219 | if dist.is_initialized():
220 | args.device = f"cuda:{local_rank}" # 分布式训练时使用对应的GPU
221 |
222 |     # 📚 Random-seed note
223 |     # Different processes use different seeds so that data sampling is not identical
224 |     # 42 is the base seed; each process adds its own rank to stay distinct
225 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
226 |
227 | # ========== 2. 配置目录、模型参数、检查点 ==========
228 |     """
229 |     📚 Model configuration and checkpoint management:
230 |     - create the save directory
231 |     - build the model configuration object
232 |     - try to load data for resuming from a checkpoint
233 |     """
234 |     os.makedirs(args.save_dir, exist_ok=True)  # make sure the save directory exists
235 | 
236 |     # Create the MokioMind model configuration
237 | lm_config = MokioMindConfig(
238 | hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,use_moe=bool(args.use_moe)
239 | )
240 |
241 |     # 📚 Resume-from-checkpoint note
242 |     # If resume is enabled, try to load the previous training state
243 | ckp_data = (
244 | lm_checkpoint(lm_config, weight=args.save_weight, save_dir="checkpoints")
245 | if args.from_resume == 1
246 | else None
247 | )
248 |
249 | # ========== 3. 设置混合精度 ==========
250 |     """
251 |     📚 Mixed-precision training notes:
252 |     - bfloat16: developed by Google, wide numeric range, more stable
253 |     - float16: standard half precision, saves memory but can overflow
254 |     - autocast: picks precision automatically, keeps critical ops in float32
255 |     """
256 | device_type = "cuda" if "cuda" in args.device else "cpu"
257 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
258 |
259 |     # 📚 Context-manager note
260 |     # torch.cuda.amp.autocast does not apply on CPU, so nullcontext is used as a no-op there
261 | autocast_ctx = (
262 | nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
263 | )
264 |
265 | # ========== 4. 配置WandB实验跟踪 ==========
266 |     """
267 |     📚 Experiment-tracking notes:
268 |     - WandB: experiment-management platform for logging training runs
269 |     - SwanLab: a Chinese-developed alternative
270 |     - supports resuming into the same run when training restarts from a checkpoint
271 |     """
272 | wandb = None
273 | if args.use_wandb and is_main_process():
274 | # 使用SwanLab作为WandB的替代
275 | import swanlab as wandb
276 |
277 |         # 📚 Run-resume note
278 |         # If checkpoint data exists, reuse its wandb_id to resume the same run
279 |         wandb_id = ckp_data.get("wandb_id") if ckp_data else None
280 |         resume = "must" if wandb_id else None  # must resume into that exact run
281 |
282 | # 构建实验名称,包含关键超参数
283 | wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
284 | wandb.init(
285 | project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume
286 | )
287 |
288 | # ========== 5. 定义模型、数据、优化器 ==========
289 |     """
290 |     📚 Training-component initialization:
291 |     - model: build the model from the config
292 |     - dataset: load the pretraining data
293 |     - sampler: distributes data across processes for DDP
294 |     - optimizer: AdamW
295 |     - scaler: gradient scaling for mixed-precision training
296 |     """
297 |     # Initialize the model and tokenizer
298 | model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
299 |
300 | train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
301 |
302 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
303 |
304 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == "float16"))
305 |
306 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
307 |
308 | start_epoch, start_step = 0, 0
309 | if ckp_data:
310 |         # Restore model parameters
311 |         model.load_state_dict(ckp_data["model"])
312 |         # Restore optimizer state (momentum and variance estimates)
313 |         optimizer.load_state_dict(ckp_data["optimizer"])
314 |         # Restore the gradient-scaler state
315 |         scaler.load_state_dict(ckp_data["scaler"])
316 |         # Restore training progress
317 | start_epoch = ckp_data["epoch"]
318 | start_step = ckp_data.get("step", 0)
319 |
320 | if dist.is_initialized():
321 |         # 📚 Special handling for RoPE positional encodings
322 |         # freqs_cos / freqs_sin are cached position encodings and need no gradient sync
323 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
324 | model = DistributedDataParallel(model, device_ids=[local_rank])
325 |
326 | for epoch in range(start_epoch, args.epochs):
327 |         # 📚 Distributed-sampler epoch setting
328 |         # A different shuffle seed per epoch keeps the data order randomized
329 | if train_sampler:
330 | train_sampler.set_epoch(epoch)
331 |
332 |         # 📚 Resume-from-checkpoint logic
333 |         if epoch == start_epoch and start_step > 0:  # first epoch after resuming with a saved step
334 |             # Use the skip-batch sampler to skip the batches already trained on
335 | batch_sampler = SkipBatchSampler(
336 | train_sampler or range(len(train_ds)), args.batch_size, start_step + 1
337 | )
338 | loader = DataLoader(
339 | train_ds,
340 | batch_sampler=batch_sampler,
341 | num_workers=args.num_workers,
342 | pin_memory=True,
343 | )
344 | Logger(
345 | f"Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始"
346 | )
347 | train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
348 | else: # 默认从头开始
349 | loader = DataLoader(
350 | train_ds,
351 | batch_size=args.batch_size,
352 | shuffle=(train_sampler is None),
353 | sampler=train_sampler,
354 | num_workers=args.num_workers,
355 | pin_memory=True,
356 | )
357 | train_epoch(epoch, loader, len(loader), 0, wandb)
358 |
--------------------------------------------------------------------------------
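A stripped-down sketch (editor's addition, CUDA assumed; the real script falls back to nullcontext on CPU) of the autocast + GradScaler pattern used in train_epoch above:

import torch
from torch import nn, optim

device = "cuda"
model = nn.Linear(16, 4).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)          # scaling matters mainly for float16

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)
with torch.cuda.amp.autocast(dtype=torch.float16):
    loss = nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()       # scale the loss so float16 gradients do not underflow
scaler.unscale_(optimizer)          # restore true gradient values before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)              # skips the update if inf/nan gradients were detected
scaler.update()
optimizer.zero_grad(set_to_none=True)
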
/trainer/train_dpo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # 📚 Python模块系统
5 | __package__ = "trainer"
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 |
8 | import argparse # 命令行参数解析
9 | import time # 时间统计
10 | import warnings # 警告控制
11 | import torch # PyTorch深度学习框架
12 | import torch.nn.functional as F # 神经网络函数
13 | import torch.distributed as dist # 分布式训练支持
14 | from contextlib import nullcontext # 上下文管理器
15 | from torch import optim # 优化器
16 | from torch.nn.parallel import DistributedDataParallel # 分布式数据并行
17 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载
18 |
19 | # MiniMind相关组件
20 | from model.MokioModel import MokioMindConfig # 模型配置
21 | from dataset.lm_dataset import DPODataset # DPO数据集
22 | from trainer.trainer_utils import ( # 训练工具函数
23 | get_lr, Logger, is_main_process, lm_checkpoint,
24 | init_distributed_mode, setup_seed, init_model, SkipBatchSampler
25 | )
26 |
27 | def logits_to_log_probs(logits,labels):
28 |     # Convert vocabulary logits to log-probabilities
29 |     log_probs=F.log_softmax(logits,dim=2)
30 |     # Pick out the log-probability of each label token
31 |     # i.e. the probability the model assigns to the actual token at each position
32 |     log_probs_per_token=torch.gather(log_probs,dim=2,index=labels.unsqueeze(2)).squeeze(-1)
33 |     return log_probs_per_token
34 |
35 | # DPO loss computation
36 | # Formula: L = -log(σ(β * ((π(y_w) - π(y_l)) - (π_ref(y_w) - π_ref(y_l)))))
37 | def dpo_loss(ref_log_probs,policy_log_probs,mask,beta):
38 | 
39 |     # Number of non-masked tokens per sequence, clamped to avoid division by zero
40 |     seq_lengths=mask.sum(dim=1,keepdim=True).float().clamp_min(1e-8)
41 |     # Mean sequence log-probability for the reference and policy models
42 |     ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
43 |     policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
44 |
45 |     # Split the ref and policy log-probs into their chosen and rejected halves
46 | batch_size = ref_log_probs.shape[0]
47 | chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
48 | reject_ref_log_probs = ref_log_probs[batch_size // 2:]
49 | chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
50 | reject_policy_log_probs = policy_log_probs[batch_size // 2:]
51 |     # Log-probability margin of the policy model (chosen minus rejected)
52 |     pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
53 |     # Log-probability margin of the reference model
54 |     ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
55 |     # DPO loss
56 | logits = pi_logratios - ref_logratios
57 | loss = -F.logsigmoid(beta * logits)
58 | return loss.mean()
59 |
60 | def train_epoch(epoch,loader,iters,ref_model,lm_config,start_step=0,wandb=None,beta=0.1):
61 | start_time = time.time()
62 | for step,batch in enumerate(loader,start=start_step+1):
63 | x_chosen = batch['x_chosen'].to(args.device)
64 | x_rejected = batch['x_rejected'].to(args.device)
65 | y_chosen = batch['y_chosen'].to(args.device)
66 | y_rejected = batch['y_rejected'].to(args.device)
67 | mask_chosen = batch['mask_chosen'].to(args.device)
68 | mask_rejected = batch['mask_rejected'].to(args.device)
69 |
70 | x = torch.cat([x_chosen, x_rejected], dim=0)
71 | y = torch.cat([y_chosen, y_rejected], dim=0)
72 | mask = torch.cat([mask_chosen, mask_rejected], dim=0)
73 |
74 | # 📚 学习率调度
75 | lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
76 | for param_group in optimizer.param_groups:
77 | param_group['lr'] = lr
78 |
79 | with autocast_ctx:
80 | # 📚 参考模型前向传播
81 | # 参考模型冻结,只用于计算baseline概率
82 | with torch.no_grad():
83 | ref_outputs = ref_model(x)
84 | ref_logits = ref_outputs.logits
85 | ref_log_probs = logits_to_log_probs(ref_logits, y)
86 |
87 | # 📚 策略模型前向传播
88 | # 策略模型是需要优化的主要模型
89 | outputs = model(x)
90 | logits = outputs.logits
91 | policy_log_probs = logits_to_log_probs(logits, y)
92 |
93 | # 📚 DPO损失计算
94 | loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
95 | loss = loss / args.accumulation_steps
96 |
97 | # 📚 反向传播
98 | scaler.scale(loss).backward()
99 |
100 | if (step + 1) % args.accumulation_steps == 0:
101 | scaler.unscale_(optimizer)
102 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
103 | scaler.step(optimizer)
104 | scaler.update()
105 | optimizer.zero_grad(set_to_none=True)
106 |
107 | # 📚 训练日志
108 | if step % args.log_interval == 0 or step == iters - 1:
109 | spend_time = time.time() - start_time
110 | current_loss = loss.item() * args.accumulation_steps
111 | current_lr = optimizer.param_groups[-1]['lr']
112 | eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
113 |
114 | Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
115 |
116 | if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
117 |
118 | # 📚 模型保存
119 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
120 | model.eval()
121 | moe_suffix = '_moe' if lm_config.use_moe else ''
122 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
123 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
124 | state_dict = model.module.state_dict()
125 | else:
126 | state_dict = model.state_dict()
127 | state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
128 | torch.save(state_dict, ckp)
129 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
130 | model.train()
131 |
132 | if __name__ == "__main__":
133 |     """
134 |     DPO main entry: script entry point for Direct Preference Optimization
135 |
136 |     📚 DPO training flow:
137 |     1. Prepare the policy model and the reference model
138 |     2. Load preference data (chosen vs rejected pairs)
139 |     3. Forward both models to get sequence probabilities
140 |     4. Compute the DPO loss and optimize the policy model
141 |     5. Iterate until convergence
142 |     """
143 |
144 | # 📚 命令行参数解析
145 | parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
146 |
147 | # ========== 基础训练参数 ==========
148 | parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
149 | parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
150 | parser.add_argument("--epochs", type=int, default=1, help="训练轮数(DPO通常1-2轮)")
151 | parser.add_argument("--batch_size", type=int, default=4, help="batch size(DPO batch较小)")
152 |
153 | # 📚 DPO学习率知识点
154 | # DPO学习率通常很小,避免过度优化导致遗忘
155 | # 建议不超过5e-8
156 | parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
157 |
158 | # ========== 硬件配置 ==========
159 | parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
160 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
161 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
162 |
163 | # ========== 训练策略 ==========
164 | parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
165 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
166 | parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
167 | parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
168 |
169 | # ========== 模型架构参数 ==========
170 | parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
171 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
172 | parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
173 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
174 |
175 | # ========== DPO数据和模型参数 ==========
176 | # 📚 DPO数据格式知识点
177 | # 数据包含chosen(偏好)和rejected(不偏好)回答配对
178 | parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
179 |
180 | # 📚 DPO权重继承知识点
181 | # DPO通常基于SFT模型进行对齐优化
182 | parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练(通常是SFT模型)")
183 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
184 |
185 | # 📚 DPO beta参数知识点
186 | # beta控制优化强度,0.1-0.5是常见范围
187 | parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数(控制优化强度)")
188 |
189 | # ========== 实验跟踪 ==========
190 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
191 | parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
192 |
193 | args = parser.parse_args()
194 |
195 | # ========== 1. 初始化环境和随机种子 ==========
196 | local_rank = init_distributed_mode()
197 | if dist.is_initialized(): args.device = f"cuda:{local_rank}"
198 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
199 |
200 | # ========== 2. 配置目录、模型参数、检查ckp ==========
201 | os.makedirs(args.save_dir, exist_ok=True)
202 | lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
203 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
204 |
205 | # ========== 3. 设置混合精度 ==========
206 | device_type = "cuda" if "cuda" in args.device else "cpu"
207 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
208 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
209 |
210 | # ========== 4. 配置wandb ==========
211 | wandb = None
212 | if args.use_wandb and is_main_process():
213 | import swanlab as wandb
214 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None
215 | resume = 'must' if wandb_id else None
216 | wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
217 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
218 |
219 |     # ========== 5. Define the policy model and the reference model ==========
220 |     # 📚 DPO dual-model setup
221 |     # Policy model: the model being optimized
222 |     # Reference model: the frozen baseline
223 |     model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
224 |     Logger(f'Policy model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
225 |
226 |     # 📚 Reference model initialization
227 |     # The reference model starts from the same weights as the policy model but stays fully frozen
228 |     ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
229 |     ref_model.eval()  # evaluation mode
230 |     ref_model.requires_grad_(False)  # freeze all parameters
231 |     Logger(f'Reference model parameters: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
232 |
233 | # 📚 DPO数据集
234 | train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
235 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
236 | scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
237 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
238 |
239 | # ========== 6. 从ckp恢复状态 ==========
240 | start_epoch, start_step = 0, 0
241 | if ckp_data:
242 | model.load_state_dict(ckp_data['model'])
243 | optimizer.load_state_dict(ckp_data['optimizer'])
244 | scaler.load_state_dict(ckp_data['scaler'])
245 | start_epoch = ckp_data['epoch']
246 | start_step = ckp_data.get('step', 0)
247 |
248 | # ========== 7. DDP包装模型 ==========
249 | if dist.is_initialized():
250 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
251 | model = DistributedDataParallel(model, device_ids=[local_rank])
252 |
253 | # ========== 8. 开始训练 ==========
254 | for epoch in range(start_epoch, args.epochs):
255 | train_sampler and train_sampler.set_epoch(epoch)
256 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
257 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
258 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
259 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
260 | train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta)
261 | else: # 默认从头开始
262 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
263 | train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
264 |
--------------------------------------------------------------------------------
/trainer/train_ppo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # 📚 Python模块系统
5 | __package__ = "trainer"
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 |
8 | import argparse # 命令行参数解析
9 | import re # 正则表达式,用于奖励计算
10 | import warnings # 警告控制
11 | import torch # PyTorch深度学习框架
12 | import torch.distributed as dist # 分布式训练支持
13 | import torch.nn.functional as F # 神经网络函数
14 | from transformers import AutoTokenizer # HuggingFace分词器
15 | from contextlib import nullcontext # 上下文管理器
16 | from torch import optim, nn # 优化器和神经网络
17 | from torch.nn.parallel import DistributedDataParallel # 分布式并行
18 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载
19 | from torch.nn.utils import clip_grad_norm_ # 梯度裁剪
20 | from torch.optim.lr_scheduler import CosineAnnealingLR # 余弦退火学习率调度
21 | from transformers import AutoModel # HuggingFace模型加载
22 | from model.MokioModel import MokioMindConfig, MokioMindForCausalLM # MiniMind模型
23 | from dataset.lm_dataset import RLAIFDataset # RL数据集
24 | from trainer.trainer_utils import ( # 训练工具函数
25 | Logger, is_main_process, lm_checkpoint, init_distributed_mode,
26 | setup_seed, SkipBatchSampler, init_model
27 | )
28 |
29 | warnings.filterwarnings('ignore')
30 | #==========Critic Model部分==========
31 |
32 | class CriticModel(MokioMindForCausalLM):
33 | def __init__(self,params):
34 | super().__init__(params)
35 | # 价值头,用于输出每个token位置的状态价值
36 | self.value_head=nn.Linear(params.hidden_size,1)
37 |
38 | def forward(self,input_ids=None,attention_mask=None,**kwargs):
39 | outputs=self.model(input_ids=input_ids,attention_mask=attention_mask,**kwargs)
40 | hidden_states=self.model.norm(outputs[0])
41 |
42 | values=self.value_head(hidden_states).squeeze(-1)
43 | return values
44 |
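# A small sketch (assumed toy config values) of what CriticModel above returns: one
# scalar value per token position, i.e. a [batch, seq_len] tensor that the PPO loop
# later indexes at the last non-pad position of each sequence.
def _critic_shape_example():
    cfg = MokioMindConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
                          num_key_value_heads=2, vocab_size=128, max_position_embeddings=64)
    critic = CriticModel(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 8))   # [batch=2, seq_len=8] token ids
    return critic(input_ids=ids).shape               # torch.Size([2, 8])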
45 | #==========奖励计算部分==========
46 | def calculate_rewards(prompts,responses,reward_model,reward_tokenizer):
47 |     def reasoning_model_reward(rewards):
48 |         # Match the think-answer format with a regular expression
49 |         pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
50 |         # Variant with an extra \n, for responses with a blank line between think and answer
51 |         pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
52 |         # Format reward: 0.5 if the response matches either pattern, otherwise 0.0
53 |         matches_pattern = [re.match(pattern, response, re.S) for response in responses]
54 |         matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
55 |
56 |         format_rewards = []
57 |         for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
58 |             if match_pattern:
59 |                 format_rewards.append(0.5)
60 |             elif match_pattern2:
61 |                 format_rewards.append(0.5)
62 |             else:
63 |                 format_rewards.append(0.0)
64 |         rewards += torch.tensor(format_rewards, device=args.device)
65 |
66 |         def mark_num(text):
67 |             reward=0
68 |             if text.count("<think>")==1:
69 |                 reward+=0.25
70 |             if text.count("</think>")==1:
71 |                 reward+=0.25
72 |             if text.count("<answer>")==1:
73 |                 reward+=0.25
74 |             if text.count("</answer>")==1:
75 |                 reward+=0.25
76 |             return reward
77 |
78 | mark_rewards=[mark_num(response) for response in responses]
79 | rewards+=torch.tensor(mark_rewards,device=args.device)
80 | return rewards
81 | rewards=torch.zeros(len(responses),device=args.device)
82 |
83 | if args.reasoning==1:
84 | rewards=reasoning_model_reward(rewards)
85 | #==========Reward模型评分部分==========
86 | with torch.no_grad():
87 | reward_model_scores = []
88 | for prompt,response in zip(prompts,responses):
89 |
90 | pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
91 | matches = re.findall(pattern, prompt, re.DOTALL)
92 | messages = [{"role": role, "content": content.strip()} for role, content in matches]
93 |
94 | tmp_chat=messages+[{"role":"assistant","content":response}]
95 |             score=reward_model.get_score(reward_tokenizer,tmp_chat)
96 |
97 | scale=3.0
98 | score=max(min(score,scale),-scale)
99 |
100 | if args.reasoning==1:
101 |                 answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
102 |                 if answer_match:
103 |                     answer_content = answer_match.group(1).strip()
104 |                     # Score the <answer> content on its own
105 |                     tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
106 |                     answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
107 |                     answer_score = max(min(answer_score, scale), -scale)
108 |                     # 📚 Weighted combination of full-response and answer-only scores
109 | score = score * 0.4 + answer_score * 0.6
110 | reward_model_scores.append(score)
111 |
112 | reward_model_scores=torch.tensor(reward_model_scores,device=args.device)
113 | rewards+=reward_model_scores
114 |
115 | return rewards
116 |
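# A quick sketch (toy strings) of the format reward used above: a response that wraps
# its reasoning in <think>...</think> followed by <answer>...</answer> matches the
# pattern (earning the 0.5 format bonus) and contains each tag exactly once (another
# 4 x 0.25 from mark_num).
def _format_reward_example():
    sample = "<think>\nreason here\n</think>\n<answer>\n42\n</answer>"
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
    return bool(re.match(pattern, sample, re.S))     # True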
117 | #==========PPO训练一个Epoch部分==========
118 | def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None):
119 | # 切换actor和critic模型到训练模式
120 | actor_model.train()
121 | critic_model.train()
122 |
123 | for step,batch in enumerate(loader,start=start_step+1):
124 | prompts=batch['prompt']
125 | # 编码输入
126 | enc=tokenizer(prompts,return_tensors='pt',padding=True,truncation=True,max_length=args.max_seq_len).to(args.device)
127 | # 计算每个prompt的长度(用于后续处理)
128 | prompt_lengths=enc.attention_mask.sum(dim=1)
129 |
130 | with torch.no_grad():
131 | model_for_gen=actor_model.module if isinstance(actor_model,DistributedDataParallel) else actor_model
132 |
133 | gen_out=model_for_gen.generate(
134 | input_ids=enc.input_ids,
135 | attention_mask=enc.attention_mask,
136 | max_new_tokens=args.max_gen_len,
137 | do_sample=True,
138 | temperature=0.8,
139 | pad_token_id=tokenizer.eos_token_id,
140 | eos_token_id=tokenizer.eos_token_id
141 | )
142 |
143 | # 解码生成的响应
144 | responses_text=[tokenizer.decode(gen_out[i,prompt_lengths[i]:],skip_special_tokens=True) for i in range(len(prompts))]
145 |
146 | # 计算奖励
147 | rewards=calculate_rewards(prompts,responses_text,reward_model,reward_tokenizer)
148 |
149 | # 创建一个mask,用于标记哪些位置上是有效token
150 | full_mask=(gen_out!=tokenizer.pad_token_id).long()
151 | # critic模型进行价值估计
152 | value_seq=critic_model(input_ids=gen_out,attention_mask=full_mask)
153 | # 拿到最后一个非pad位置的索引
154 | last_indices=full_mask.sum(dim=1)-1
155 | # 获取每条序列最后token的value
156 | values=value_seq[torch.arange(len(last_indices)),last_indices]
157 | # advantage=reward-估计的value
158 | advantages = rewards - values.detach() # [B]
159 |
160 | # 计算actor log,表示actor对这个答案的“信心”
161 | # 先生成logits
162 | logits=actor_model(input_ids=gen_out,attention_mask=full_mask).logits # [B, L, V]
163 | # label是生成的token序列,去掉第一个token(因为logits是预测下一个token的概率)
164 | labels=gen_out[:,1:].clone()
165 | # 使用log_softmax计算log概率
166 | logp_tokens=F.log_softmax(logits[:,:-1,:],dim=-1).gather(2,labels.unsqueeze(-1)).squeeze(-1) # [B, L-1]
167 | seq_len=gen_out.size(1)-1
168 | # 只关心response部分的概率,所以要把prompts部分的mask掉
169 | resp_mask=torch.arange(seq_len,device=gen_out.device).unsqueeze(0)>=prompt_lengths.unsqueeze(1)
170 |
171 | final_mask=resp_mask&(~labels.eq(tokenizer.pad_token_id))
172 | # 把所有response部分的log概率加起来,得到每条序列的总log概率
173 | actor_logp=(logp_tokens*final_mask).sum(dim=1)
174 |
175 | # 计算old和ref log的概率
176 | # old用于防止策略更新过大,ref用于计算KL惩罚,防止模型忘本
177 | with torch.no_grad():
178 | old_logits = old_actor_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
179 | old_logp_tokens = F.log_softmax(old_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
180 | old_logp = (old_logp_tokens * final_mask).sum(dim=1) # [B]
181 |
182 | ref_logits = ref_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
183 | ref_logp_tokens = F.log_softmax(ref_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
184 | ref_logp = (ref_logp_tokens * final_mask).sum(dim=1) # [B]
185 |
186 | # 计算KL散度和ratio
187 | kl=(actor_logp - old_logp).mean()
188 | kl_ref=(actor_logp - ref_logp).mean()
189 | ratio=torch.exp(actor_logp - old_logp) # [B]
190 |
191 | # PPO裁剪损失
192 | surr1=ratio*advantages # [B]
193 | surr2=torch.clamp(ratio,1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon)*advantages # [B]
194 | policy_loss=-torch.min(surr1,surr2).mean()
195 |
196 | # 价值函数损失
197 | value_loss=F.mse_loss(values,rewards)
198 | # 总损失
199 | loss = policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref # scalar
200 | loss.backward()
201 |
202 | # 更新参数
203 | if (step + 1) % args.accumulation_steps == 0:
204 | clip_grad_norm_(actor_model.parameters(), args.grad_clip)
205 | clip_grad_norm_(critic_model.parameters(), args.grad_clip)
206 | actor_optimizer.step()
207 | critic_optimizer.step()
208 | actor_scheduler.step()
209 | critic_scheduler.step()
210 | actor_optimizer.zero_grad()
211 | critic_optimizer.zero_grad()
212 |
213 | # 📚 日志记录
214 | if is_main_process():
215 | response_ids = gen_out[:, enc.input_ids.shape[1]:]
216 | is_eos = (response_ids == tokenizer.eos_token_id)
217 | eos_indices = torch.argmax(is_eos.int(), dim=1)
218 | has_eos = is_eos.any(dim=1)
219 | lengths = torch.where(has_eos, eos_indices + 1, torch.tensor(response_ids.shape[1], device=is_eos.device))
220 | avg_len = lengths.float().mean()
221 |
222 | actor_loss_val = policy_loss.item()
223 | critic_loss_val = value_loss.item()
224 | reward_val = rewards.mean().item()
225 | kl_val = kl.item()
226 | kl_ref_val = kl_ref.item()
227 | avg_len_val = avg_len.item()
228 | actor_lr = actor_optimizer.param_groups[0]['lr']
229 | critic_lr = critic_optimizer.param_groups[0]['lr']
230 |
231 | if wandb is not None:
232 | wandb.log({
233 | "actor_loss": actor_loss_val,
234 | "critic_loss": critic_loss_val,
235 | "reward": reward_val,
236 | "kl": kl_val,
237 | "kl_ref": kl_ref_val,
238 | "avg_response_len": avg_len_val,
239 | "actor_lr": actor_lr,
240 | })
241 |
242 | Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, "
243 | f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
244 | f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
245 | f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
246 |
247 | # 📚 更新old actor
248 | if (step + 1) % args.update_old_actor_freq == 0:
249 | state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
250 | old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
251 | old_actor_model.to(args.device)
252 |
253 | # 📚 模型保存
254 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
255 | actor_model.eval()
256 | moe_suffix = '_moe' if lm_config.use_moe else ''
257 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
258 | actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
259 | torch.save({k: v.half() for k, v in actor_state.items()}, ckp)
260 |
261 | # 使用 lm_checkpoint 保存完整状态(包括 critic)
262 | lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
263 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
264 | scheduler=actor_scheduler, critic_model=critic_model,
265 | critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
266 | actor_model.train()
267 |
268 |
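# A numeric sketch (toy values, independent of the models above) of the clipped
# surrogate used in ppo_train_epoch: once the new/old probability ratio leaves
# [1 - clip_epsilon, 1 + clip_epsilon], the clipped branch caps the incentive so one
# batch cannot push the policy far from the old policy.
def _ppo_clip_example(eps=0.1):
    ratio = torch.tensor([0.5, 1.0, 1.5])            # new/old probability ratios
    adv = torch.tensor([1.0, 1.0, 1.0])              # positive advantages
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * adv
    return -torch.min(surr1, surr2)                  # tensor([-0.5000, -1.0000, -1.1000])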
269 |
270 | if __name__ == "__main__":
271 |     """
272 |     PPO main entry: script entry point for Proximal Policy Optimization
273 |
274 |     📚 PPO training architecture:
275 |     1. Actor model: the generation policy, outputs action probabilities
276 |     2. Critic model: the value function, estimates state values
277 |     3. Reward model: the reward function, scores generation quality
278 |     4. Old Actor: the previous policy, used for importance sampling
279 |     5. Reference model: the frozen policy used for the KL penalty
280 |     """
281 |
282 | # 📚 命令行参数解析
283 | parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
284 |
285 | # ========== 基础训练参数 ==========
286 | parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
287 | parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名")
288 | parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
289 | parser.add_argument("--batch_size", type=int, default=2, help="batch size(PPO batch较小)")
290 |
291 | # 📚 PPO学习率设置
292 | # PPO学习率通常很小,避免策略剧烈变化
293 | parser.add_argument("--learning_rate", type=float, default=8e-8, help="Actor学习率")
294 | parser.add_argument("--critic_learning_rate", type=float, default=8e-8, help="Critic学习率")
295 |
296 | # ========== 硬件配置 ==========
297 | parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
298 | parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
299 | parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
300 |
301 | # ========== 训练策略 ==========
302 | parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
303 | parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
304 | parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
305 | parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
306 |
307 | # ========== 模型架构参数 ==========
308 | parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
309 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
310 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
311 |
312 | # ========== PPO生成参数 ==========
313 | parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
314 | parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
315 |
316 | # ========== 数据和模型参数 ==========
317 | parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
318 |
319 | # 📚 PPO超参数
320 | parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数(控制策略更新幅度)")
321 | parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
322 | parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
323 |
324 | # 📚 推理模型配置
325 | parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型(0=普通模型,1=推理模型)')
326 | parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
327 |
328 | # 📚 Reward模型路径
329 | parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
330 |
331 | parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
332 |
333 | # ========== 实验跟踪 ==========
334 | parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
335 | parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
336 |
337 | args = parser.parse_args()
338 | # ========== 1. 初始化环境和随机种子 ==========
339 | local_rank = init_distributed_mode()
340 | if dist.is_initialized(): args.device = f"cuda:{local_rank}"
341 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
342 |
343 | # ========== 2. 配置目录、模型参数、检查ckp ==========
344 | os.makedirs(args.save_dir, exist_ok=True)
345 | lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
346 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
347 |
348 | # ========== 3. 设置混合精度 ==========
349 | device_type = "cuda" if "cuda" in args.device else "cpu"
350 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
351 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
352 |
353 | # ========== 4. 配置wandb ==========
354 | wandb = None
355 | if args.use_wandb and is_main_process():
356 | import swanlab as wandb
357 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None
358 | resume = 'must' if wandb_id else None
359 | wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
360 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
361 | # ========== 5. 初始化模型和数据 ==========
362 | # 📚 PPO模型架构
363 | base_weight = "reason" if args.reasoning == 1 else "full_sft"
364 |
365 | # 📚 Actor模型(策略模型)
366 | actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
367 | tokenizer.padding_side = 'left' # PPO需要左侧padding
368 |
369 | # 📚 Old Actor模型(用于重要性采样)
370 | old_actor_model, _ = init_model(lm_config, base_weight, device=args.device)
371 | old_actor_model = old_actor_model.eval().requires_grad_(False)
372 |
373 | # 📚 Reference模型(用于KL惩罚)
374 | ref_model, _ = init_model(lm_config, base_weight, device=args.device)
375 | ref_model = ref_model.eval().requires_grad_(False)
376 |
377 | # 📚 Critic模型(价值函数)
378 | moe_suffix = '_moe' if lm_config.use_moe else ''
379 | ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
380 | state_dict = torch.load(ckp, map_location=args.device)
381 | critic_model = CriticModel(lm_config)
382 | critic_model.load_state_dict(state_dict, strict=False)
383 | critic_model = critic_model.to(args.device)
384 |
385 | # 📚 Reward模型(奖励函数)
386 | reward_model = AutoModel.from_pretrained(
387 | args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
388 | )
389 | reward_model = reward_model.to(args.device).eval().requires_grad_(False)
390 | reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
391 |
392 | # 📚 数据和优化器
393 | train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
394 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
395 | actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
396 | critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
397 | loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
398 | iters = len(loader_for_count)
399 | total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
400 | actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
401 | critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
402 |
403 | # ========== 6. 从ckp恢复状态 ==========
404 | start_epoch, start_step = 0, 0
405 | if ckp_data:
406 | actor_model.load_state_dict(ckp_data['model'])
407 | critic_model.load_state_dict(ckp_data['critic_model'])
408 | actor_optimizer.load_state_dict(ckp_data['optimizer'])
409 | critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
410 | actor_scheduler.load_state_dict(ckp_data['scheduler'])
411 | critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
412 | start_epoch = ckp_data['epoch']
413 | start_step = ckp_data.get('step', 0)
414 |
415 | # ========== 7. DDP包装模型 ==========
416 | if dist.is_initialized():
417 | actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
418 | critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
419 | actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
420 | critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
421 | old_actor_model.to(args.device)
422 |
423 | # ========== 8. 开始训练 ==========
424 | for epoch in range(start_epoch, args.epochs):
425 | train_sampler and train_sampler.set_epoch(epoch)
426 | if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
427 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
428 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
429 | Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
430 | ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model,
431 | actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
432 | else: # 默认从头开始
433 | loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None),
434 | sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
435 | ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
436 | actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
--------------------------------------------------------------------------------
/model/MokioModel.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 |
4 | class MokioMindConfig(PretrainedConfig):
5 | model_type = "mokiomind"
6 |
7 | def __init__(
8 | self,
9 | dropout: float = 0.0,
10 | bos_token_id: int = 1,
11 | eos_token_id: int = 2,
12 | hidden_act: str = "silu",
13 | hidden_size: int = 512,
14 | intermediate_size: int = None,
15 | max_position_embeddings: int = 32768,
16 | num_attention_heads: int = 8,
17 | num_hidden_layers: int = 8,
18 | num_key_value_heads: int = 2,
19 | vocab_size: int = 6400,
20 | rms_norm_eps: float = 1e-05,
21 | rope_theta: int = 1000000,
22 | inference_rope_scaling: bool = False,
23 | flash_attention: bool = True,
24 |
25 | ############ MoE ############
26 | use_moe:bool=False,
27 | num_experts_per_tok:int=2,
28 | n_routed_experts:int=4,
29 | n_shared_experts:int=1,
30 | scoring_func:str='softmax',
31 | aux_loss_alpha:float=0.1,
32 | seq_aux:bool=True,
33 | norm_topk_prob:bool=True,
34 | **kwargs,
35 | ):
36 | super().__init__(**kwargs)
37 |
38 | self.dropout = dropout
39 | self.bos_token_id = bos_token_id
40 | self.eos_token_id = eos_token_id
41 | self.hidden_act = hidden_act
42 | self.hidden_size = hidden_size
43 | self.intermediate_size = intermediate_size
44 | self.max_position_embeddings = max_position_embeddings
45 | self.num_attention_heads = num_attention_heads
46 | self.num_hidden_layers = num_hidden_layers
47 | self.num_key_value_heads = num_key_value_heads
48 | self.vocab_size = vocab_size
49 | self.rms_norm_eps = rms_norm_eps
50 | self.rope_theta = rope_theta
51 | self.inference_rope_scaling = inference_rope_scaling
52 | self.flash_attention = flash_attention
53 | self.use_moe=use_moe
54 | self.num_experts_per_tok=num_experts_per_tok
55 | self.n_routed_experts=n_routed_experts
56 | self.n_shared_experts=n_shared_experts
57 | self.seq_aux=seq_aux
58 | self.norm_topk_prob=norm_topk_prob
59 | self.aux_loss_alpha=aux_loss_alpha
60 | self.scoring_func=scoring_func
61 |
62 | self.rope_scaling = (
63 | {
64 | "beta_fast": 4,
65 | "beta_slow": 1,
66 | "factor": 4,
67 | "original_max_position_embeddings": 2048,
68 | "type": "yarn",
69 | }
70 | if self.inference_rope_scaling
71 | else None
72 | )
73 |
74 |
75 | import torch
76 | import math
77 | import torch.nn as nn
78 | from typing import Optional, Tuple,List,Union
79 | import torch.nn.functional as F
80 | from transformers.activations import ACT2FN
81 | from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
82 | from transformers.modeling_outputs import CausalLMOutputWithPast
83 |
84 |
85 | class RMSNorm(nn.Module):
86 | def __init__(self, dim: int, eps: float = 1e-5):
87 | super().__init__()
88 | self.dim = dim
89 | self.eps = eps
90 | self.weight = nn.Parameter(torch.ones(dim))
91 |
92 | def _norm(self, x):
93 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
94 |
95 | def forward(self, x):
96 |         return self.weight * self._norm(x.float()).type_as(x)
97 |
98 |
99 | def precompute_freqs(
100 | dim: int,
101 | end: int = int(32 * 1024),
102 | rope_base: float = 1e6,
103 | rope_scaling: Optional[dict] = None,
104 | ):
105 | freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
106 |
107 | if rope_scaling is not None:
108 | original_max, factor, beta_fast, beta_slow = (
109 | rope_scaling.get("original_max_position_embeddings", 2048),
110 | rope_scaling.get("factor", 4),
111 | rope_scaling.get("beta_fast", 4.0),
112 | rope_scaling.get("beta_slow", 1.0),
113 | )
114 |
115 | if end / original_max > 1.0:
116 | corr_dim = next(
117 | (i for i in range(dim // 2) if 2 * math.pi / freqs[i] > original_max),
118 | dim // 2,
119 | )
120 |
121 | power = torch.arange(0, dim // 2, device=freqs.device).float() / max(
122 | dim // 2 - 1, 1
123 | )
124 |
125 | beta = beta_slow + (beta_fast - beta_slow) * power
126 |
127 | scale = torch.where(
128 | torch.arange(dim // 2, device=freqs.device) < corr_dim,
129 | (beta * factor - beta + 1) / (beta * factor),
130 | 1.0 / factor,
131 | )
132 |
133 | freqs = freqs * scale
134 |
135 | t = torch.arange(end, device=freqs.device)
136 | freqs = torch.outer(t, freqs).float()
137 |
138 | freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
139 | freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
140 |
141 | return freqs_cos, freqs_sin
142 |
143 |
144 | def apply_rotary_pos_emb(
145 | q, k, cos, sin, position_ids=None, unsqueeze_dim=1
146 | ) -> Tuple[torch.Tensor, torch.Tensor]:
147 | def rotate_half(x):
148 | return torch.cat(
149 | (-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1
150 | )
151 |
152 | q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (
153 | rotate_half(q) * sin.unsqueeze(unsqueeze_dim)
154 | )
155 | k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (
156 | rotate_half(k) * sin.unsqueeze(unsqueeze_dim)
157 | )
158 | return q_embed, k_embed
159 |
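# A compact sketch (toy shapes) of the two RoPE helpers above: precompute_freqs builds
# cos/sin tables of shape [end, head_dim], and apply_rotary_pos_emb rotates q/k per
# position without changing their shapes (layout [batch, seq, heads, head_dim], as in
# Attention below).
def _rope_example():
    head_dim, seq_len = 16, 6
    cos, sin = precompute_freqs(dim=head_dim, end=seq_len)
    q = torch.randn(1, seq_len, 4, head_dim)
    k = torch.randn(1, seq_len, 2, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    return q_rot.shape, cos.shape                    # ([1, 6, 4, 16], [6, 16])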
160 |
161 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
162 | bs, slen, num_key_value_heads, head_dim = x.shape
163 | if n_rep == 1:
164 | return x
165 |
166 | return (
167 | x[:, :, :, None, :]
168 | .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
169 | .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
170 | )
171 |
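# A brief sketch (toy sizes) of repeat_kv above: with 8 query heads and 2 KV heads,
# each KV head is duplicated n_rep = 4 times so every query head gets a matching
# key/value head, which is the grouped-query attention layout used by Attention below.
def _repeat_kv_example():
    kv = torch.randn(1, 5, 2, 16)          # [batch, seq_len, kv_heads, head_dim]
    return repeat_kv(kv, n_rep=4).shape    # torch.Size([1, 5, 8, 16])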
172 |
173 | class Attention(nn.Module):
174 | def __init__(self, args: MokioMindConfig):
175 | super().__init__()
176 |
177 | self.num_key_value_heads = (
178 | args.num_attention_heads
179 | if args.num_key_value_heads is None
180 | else args.num_key_value_heads
181 | )
182 |
183 | assert args.num_attention_heads % self.num_key_value_heads == 0
184 |
185 | self.n_local_heads = args.num_attention_heads
186 | self.n_local_kv_heads = self.num_key_value_heads
187 | self.n_rep = self.n_local_heads // self.n_local_kv_heads
188 | self.head_dim = args.hidden_size // args.num_attention_heads
189 |
190 | self.q_proj = nn.Linear(
191 | args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
192 | )
193 | self.k_proj = nn.Linear(
194 | args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
195 | )
196 | self.v_proj = nn.Linear(
197 | args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
198 | )
199 | self.o_proj = nn.Linear(
200 | args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
201 | )
202 |
203 | self.attn_dropout = nn.Dropout(args.dropout)
204 | self.resid_dropout = nn.Dropout(args.dropout)
205 | self.dropout = args.dropout
206 | self.flash = (
207 | hasattr(torch.nn.functional, "scaled_dot_product_attention")
208 | and args.flash_attention
209 | )
210 |
211 | def forward(
212 | self,
213 | x: torch.Tensor,
214 | position_embeddings: Tuple[torch.Tensor, torch.Tensor],
215 | past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
216 | use_cache=False,
217 | attention_mask: Optional[torch.Tensor] = None,
218 | ):
219 | bsz, seq_len, _ = x.shape
220 |
221 | xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
222 |
223 | xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
224 | xk = xk.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
225 | xv = xv.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
226 |
227 | cos, sin = position_embeddings
228 | xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len])
229 |
230 | if past_key_value is not None:
231 | xk = torch.cat([past_key_value[0], xk], dim=1)
232 | xv = torch.cat([past_key_value[1], xv], dim=1)
233 |
234 | past_kv = (xk, xv) if use_cache else None
235 |
236 | xq = xq.transpose(1, 2)
237 | xk = repeat_kv(xk, self.n_rep).transpose(1, 2)
238 | xv = repeat_kv(xv, self.n_rep).transpose(1, 2)
239 |
240 | if (
241 | self.flash
242 | and seq_len > 1
243 | and (attention_mask is None or torch.all(attention_mask == 1))
244 | ):
245 | attn_mask = (
246 | None
247 | if attention_mask is None
248 | else attention_mask.view(bsz, 1, 1, -1)
249 | .expand(bsz, self.n_local_heads, seq_len, -1)
250 | .bool()
251 | )
252 | output = F.scaled_dot_product_attention(
253 | xq,
254 | xk,
255 | xv,
256 | attn_mask=attn_mask,
257 | dropout_p=self.dropout if self.training else 0.0,
258 | is_causal=True, # 自回归(因果)注意力
259 | )
260 | else:
261 | scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
262 |
263 |             causal_mask = torch.triu(  # mask strictly future positions
264 |                 torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
265 |                 diagonal=1,
266 |             )
267 |
268 | scores=scores+causal_mask.unsqueeze(0).unsqueeze(0)
269 |
270 | if attention_mask is not None:
271 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
272 | extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
273 | scores = scores + extended_attention_mask
274 |
275 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
276 | scores=self.attn_dropout(scores)
277 | output=scores@xv
278 |
279 | output=output.transpose(1,2).reshape(
280 | bsz,seq_len,-1
281 | )
282 | output=self.resid_dropout(self.o_proj(output))
283 | return output,past_kv
284 |
285 | class FeedForward(nn.Module):
286 | def __init__(self, config: MokioMindConfig):
287 | super().__init__()
288 | if config.intermediate_size is None:
289 | intermediate_size = int(config.hidden_size * 8 / 3)
290 | config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
291 |
292 | self.gate_proj = nn.Linear(
293 | config.hidden_size, config.intermediate_size, bias=False
294 | )
295 | self.down_proj = nn.Linear(
296 | config.intermediate_size, config.hidden_size, bias=False
297 | )
298 | self.up_proj = nn.Linear(
299 | config.hidden_size, config.intermediate_size, bias=False
300 | )
301 | self.dropout = nn.Dropout(config.dropout)
302 | self.act_fn = ACT2FN[config.hidden_act]
303 |
304 | def forward(self, x):
305 | gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
306 | return self.dropout(self.down_proj(gated))
307 |
308 | class MoEGate(nn.Module):
309 |     def __init__(self,config:MokioMindConfig):
310 |         super().__init__()
311 |         self.config=config
312 |         self.top_k=config.num_experts_per_tok
313 |         self.n_routed_experts=config.n_routed_experts
314 |         self.scoring_func=config.scoring_func
315 |         self.alpha=config.aux_loss_alpha
316 |         self.seq_aux=config.seq_aux
317 |
318 |         self.norm_topk_prob=config.norm_topk_prob
319 |         self.gating_dim=config.hidden_size
320 |         self.weight=nn.Parameter(torch.empty((self.n_routed_experts,self.gating_dim)))
321 |         self.reset_parameters()
322 |
323 |     def reset_parameters(self)->None:
324 |         nn.init.kaiming_uniform_(self.weight,a=math.sqrt(5))
325 |
326 | def forward(self,hidden_states):
327 | bsz,seq_len,h=hidden_states.shape
328 | hidden_states=hidden_states.view(-1,h)
329 | logits=F.linear(hidden_states,self.weight,None)
330 |
331 | if self.scoring_func == 'softmax':
332 | scores = logits.softmax(dim=-1)
333 | else:
334 |             raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
335 |
336 | topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
337 |
338 | if self.top_k > 1 and self.norm_topk_prob:
339 | denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
340 | topk_weight = topk_weight / denominator
341 |
342 | if self.training and self.alpha > 0.0:
343 | scores_for_aux = scores
344 | aux_topk = self.top_k
345 | topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
346 | if self.seq_aux:
347 | scores_for_seq_aux=scores_for_aux.view(bsz,seq_len,-1)
348 | ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
349 | ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device))
350 | ce = ce.div(seq_len * aux_topk / self.n_routed_experts)
351 | aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
352 | else:
353 | mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
354 | ce = mask_ce.float().mean(0)
355 | Pi = scores_for_aux.mean(0)
356 | fi = ce * self.n_routed_experts
357 | aux_loss = (Pi * fi).sum() * self.alpha
358 | else:
359 | aux_loss = 0
360 | return topk_weight, topk_idx, aux_loss
361 |
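# A minimal sketch (assumed toy config values) of what MoEGate above returns: for each
# token it selects top_k experts with normalized weights, plus (in training mode) an
# auxiliary load-balancing loss that discourages routing every token to the same expert.
def _moe_gate_example():
    cfg = MokioMindConfig(hidden_size=32, n_routed_experts=4, num_experts_per_tok=2)
    gate = MoEGate(cfg)
    hidden = torch.randn(2, 3, cfg.hidden_size)      # [batch, seq_len, hidden]
    weights, experts, aux = gate(hidden)
    return weights.shape, experts.shape              # ([6, 2], [6, 2]): one row per token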
362 | class MoEFeedForaward(nn.Module):
363 | def __init__(self,config:MokioMindConfig):
364 | super().__init__()
365 | self.config=config
366 | # 专家层
367 | self.experts=nn.ModuleList(
368 | [FeedForward(config)
369 | for _ in range(config.n_routed_experts)]
370 | )
371 | # 门控层
372 | self.gate=MoEGate(config)
373 | if config.n_shared_experts>0:
374 | self.shared_experts=nn.ModuleList(
375 | [FeedForward(config)
376 | for _ in range(config.n_shared_experts)]
377 | )
378 | def forward(self,x):
379 | identity=x
380 | orig_shape=x.shape
381 | bsz,seq_len,h=orig_shape
382 |
383 | # 使用门控机制旋转专家
384 | topk_weight, topk_idx, aux_loss = self.gate(x)
385 | # 展开x以便处理
386 | x=x.view(-1,x.shape[-1])
387 |
388 | flat_topk_idx=topk_idx.view(-1)
389 | if self.training:
390 | # 按照定义的num_experts_per_tok重复输入token
391 | # 每个token安排num_experts_per_tok个专家处理
392 | x=x.repeat_interleave(self.config.num_experts_per_tok,dim=0)
393 | # y是空张量,和x形状相同
394 | y=torch.empty_like(x,dtype=torch.float32)
395 | # 遍历所有专家
396 | for i,expert in enumerate(self.experts):
397 | # 找到所有指向专家i的token
398 | # 然后将这些token输入专家i进行处理
399 | # 最后将结果放回y对应位置
400 | y[flat_topk_idx==i]=expert(x[flat_topk_idx==i]).to(y.dtype)
401 |             # Weighted sum over the top-k experts:
402 |             # y becomes each token's expert outputs combined by their gate weights
403 |             y=(y.view(*topk_weight.shape,-1)*topk_weight.unsqueeze(-1)).sum(dim=1)
404 |             y=y.view(*orig_shape)
405 | # 如果是推理阶段
406 | else:
407 | y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
408 | if self.config.n_shared_experts > 0:
409 | for expert in self.shared_experts:
410 | y = y + expert(identity)
411 | self.aux_loss = aux_loss
412 | return y
413 |
414 | @torch.no_grad()
415 | # MoE推理方法
416 | def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
417 | # 使用cache,创建一个和x形状相同的零张量
418 | expert_cache = torch.zeros_like(x)
419 | # 对专家索引进行排序,最后是[0,0,0,1,1,2,2,2,...]这样的顺序
420 | # 分拣
421 | idxs = flat_expert_indices.argsort()
422 | # 统计每个专家被分配到的token数量
423 | # 打包
424 | tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
425 | # 计算每个token对应的专家索引
426 | token_idxs = idxs // self.config.num_experts_per_tok
427 | # 对每个打包好的包进行处理
428 | for i, end_idx in enumerate(tokens_per_expert):
429 | # 计算当前包的起始位置
430 | start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
431 | if start_idx == end_idx:
432 | continue
433 | # 取出当前包对应的专家
434 | expert = self.experts[i]
435 | # 取出token对应的原始id
436 | exp_token_idx = token_idxs[start_idx:end_idx]
437 | # 取出token对应的数据
438 | expert_tokens = x[exp_token_idx]
439 | # 计算专家输出,一次性处理当前包的所有token
440 | expert_out = expert(expert_tokens).to(expert_cache.dtype)
441 | # 加权
442 | expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
443 | # 将结果散点加到缓存中对应位置
444 | expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
445 |
446 | return expert_cache
447 |
448 |
449 |
450 | class MokioMindBlock(nn.Module):
451 | def __init__(self, layer_id: int, config: MokioMindConfig):
452 | super().__init__()
453 | self.num_attention_heads = config.num_attention_heads
454 | self.hidden_size = config.hidden_size
455 | self.head_dim = config.hidden_size // config.num_attention_heads
456 | self.self_attention = Attention(config)
457 |
458 | self.layer_id = layer_id
459 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
460 | self.post_attention_layernorm = RMSNorm(
461 | config.hidden_size, eps=config.rms_norm_eps
462 | )
463 | self.mlp = FeedForward(config)if not config.use_moe else MoEFeedForaward(config)
464 |
465 | def forward(
466 | self,
467 | hidden_states,
468 | position_embeddings: Tuple[torch.Tensor, torch.Tensor],
469 | past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
470 | use_cache=False,
471 | attention_mask: Optional[torch.Tensor] = None,
472 | ):
473 | res = hidden_states
474 |
475 | hidden_states, present_key_value = self.self_attention(
476 | self.input_layernorm(hidden_states), # pre-norm
477 | position_embeddings,
478 | past_key_value,
479 | use_cache,
480 | attention_mask,
481 | )
482 |
483 | hidden_states = res + hidden_states
484 |
485 | hidden_states = hidden_states + self.mlp(
486 | self.post_attention_layernorm(hidden_states)
487 | )
488 | return hidden_states, present_key_value
489 |
490 | class MokioMindModel(nn.Module):
491 | def __init__(self, config: MokioMindConfig):
492 | super().__init__()
493 | self.config = config
494 | self.vocab_size, self.num_hidden_layers = (
495 | config.vocab_size,
496 | config.num_hidden_layers,
497 | )
498 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
499 | self.dropout = nn.Dropout(config.dropout)
500 | self.layers = nn.ModuleList(
501 | [MokioMindBlock(l, config) for l in range(self.num_hidden_layers)]
502 | )
503 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
504 |
505 | freqs_cos, freqs_sin = precompute_freqs(
506 | dim=config.hidden_size // config.num_attention_heads,
507 | end=config.max_position_embeddings,
508 | rope_base=config.rope_theta,
509 | rope_scaling=config.rope_scaling,
510 | )
511 | self.register_buffer("freqs_cos", freqs_cos, persistent=False)
512 | self.register_buffer("freqs_sin", freqs_sin, persistent=False)
513 |
514 | def forward(
515 | self,
516 | input_ids: Optional[torch.Tensor] = None,
517 | attention_mask: Optional[torch.Tensor] = None,
518 | past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
519 | use_cache: bool = False,
520 | **kwargs,
521 | ):
522 | # input_ids: [bsz, seq_len]
523 | batch_size, seq_length = input_ids.shape
524 |
525 | if hasattr(past_key_values, "layers"):
526 | past_key_values = None
527 |
528 | past_key_values = past_key_values or [None] * len(self.layers)
529 |
530 | # 计算start_pos:如果存在past,则start_pos为已有past序列长度
531 | start_pos = (
532 | past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
533 | )
534 |
535 | # Embedding + dropout
536 | hidden_states = self.dropout(
537 | self.embed_tokens(input_ids)
538 | ) # [bsz, seq_len, hidden]
539 |
540 | position_embeddings = (
541 | self.freqs_cos[start_pos : start_pos + seq_length],
542 | self.freqs_sin[start_pos : start_pos + seq_length],
543 | )
544 | presents = []
545 | for layer_idx, (layer, past_key_value) in enumerate(
546 | zip(self.layers, past_key_values)
547 | ):
548 | hidden_states, present = layer(
549 | hidden_states,
550 | position_embeddings,
551 | past_key_value=past_key_value,
552 | use_cache=use_cache,
553 | attention_mask=attention_mask,
554 | )
555 | presents.append(present)
556 |
557 | hidden_states = self.norm(hidden_states)
558 |
559 | aux_loss=sum(layer.mlp.aux_loss for layer in self.layers if isinstance(layer.mlp,MoEFeedForaward))
560 |
561 | return hidden_states, presents,aux_loss
562 |
563 | class MokioMindForCausalLM(PreTrainedModel, GenerationMixin):
564 | config_class = MokioMindConfig
565 |
566 | def __init__(self, config: MokioMindConfig):
567 | super().__init__(config)
568 | self.model = MokioMindModel(config)
569 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
570 | self.model.embed_tokens.weight = self.lm_head.weight
571 |
572 | def forward(
573 | self,
574 | input_ids: Optional[torch.Tensor] = None,
575 | attention_mask: Optional[torch.Tensor] = None,
576 | past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
577 | use_cache: bool = False,
578 | logits_to_keep: Union[int, torch.Tensor] = 0,
579 | **args,
580 | ):
581 | h, past_kvs,aux_loss = self.model(
582 | input_ids=input_ids,
583 | attention_mask=attention_mask,
584 | past_key_values=past_key_values,
585 | use_cache=use_cache,
586 | **args,
587 | )
588 |
589 | slice_indices = (
590 | slice(-logits_to_keep, None)
591 | if isinstance(logits_to_keep, int)
592 | else logits_to_keep
593 | )
594 | logits = self.lm_head(h[:, slice_indices, :])
595 |
596 | return CausalLMOutputWithPast(
597 | logits=logits,
598 | past_key_values=past_kvs,
599 | hidden_states=h,
600 | )
601 |
--------------------------------------------------------------------------------
/trainer/train_grpo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # 📚 Python模块系统
5 | __package__ = "trainer"
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 |
8 | import argparse # 命令行参数解析
9 | import re # 正则表达式,用于奖励计算
10 | import gc # 垃圾回收,手动释放内存
11 | import warnings # 警告控制
12 | import torch # PyTorch深度学习框架
13 | import torch.distributed as dist # 分布式训练支持
14 | from transformers import AutoTokenizer # HuggingFace分词器
15 | from contextlib import nullcontext # 上下文管理器
16 | from torch import optim # 优化器
17 | from torch.nn.parallel import DistributedDataParallel # 分布式并行
18 | from torch.utils.data import DataLoader, DistributedSampler # 数据加载
19 | from torch.optim.lr_scheduler import CosineAnnealingLR # 余弦退火学习率调度
20 | from transformers import AutoModel # HuggingFace模型加载
21 | from model.MokioModel import MokioMindConfig, MokioMindForCausalLM # MokioMind模型
22 | from dataset.lm_dataset import RLAIFDataset # RL数据集
23 | from trainer.trainer_utils import ( # 训练工具函数
24 | Logger, is_main_process, lm_checkpoint, init_distributed_mode,
25 | setup_seed, SkipBatchSampler, init_model
26 | )
27 |
28 | warnings.filterwarnings('ignore')
29 |
30 | def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
31 |     # Combine all reward functions into a total reward
32 |     def reasoning_model_reward(rewards):
33 |         # Format reward for the reasoning (think-answer) template
34 |         pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
35 |         pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
36 | matches_pattern = [re.match(pattern, response, re.S) for response in responses]
37 | matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
38 |
39 | format_rewards = []
40 | for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
41 | if match_pattern or match_pattern2:
42 | format_rewards.append(0.5)
43 | else:
44 | format_rewards.append(0.0)
45 | rewards += torch.tensor(format_rewards, device=args.device)
46 |         def mark_num(text):
47 |             reward = 0
48 |             if text.count("<think>") == 1: reward += 0.25
49 |             if text.count("</think>") == 1: reward += 0.25
50 |             if text.count("<answer>") == 1: reward += 0.25
51 |             if text.count("</answer>") == 1: reward += 0.25
52 |             return reward
53 |
54 | mark_rewards = [mark_num(response) for response in responses]
55 | rewards += torch.tensor(mark_rewards, device=args.device)
56 | return rewards
57 | rewards = torch.zeros(len(responses), device=args.device)
58 |
59 | if args.reasoning == 1:
60 | rewards = reasoning_model_reward(rewards)
61 |
62 | with torch.no_grad():
63 | reward_model_scores = []
64 | batch_size = len(prompts)
65 | scale = 3.0
66 |
67 | # 📚 批处理评分
68 | for i in range(batch_size):
69 | for j in range(args.num_generations):
70 | response_idx = i * args.num_generations + j
71 | response = responses[response_idx]
72 | prompt = prompts[i]
73 |
74 | # 对话格式解析
75 | pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
76 | matches = re.findall(pattern, prompt, re.DOTALL)
77 | messages = [{"role": role, "content": content.strip()} for role, content in matches]
78 |
79 | # 构建完整对话
80 | tmp_chat = messages + [{"role": "assistant", "content": response}]
81 | score = reward_model.get_score(reward_tokenizer, tmp_chat)
82 | score = max(min(score, scale), -scale)
83 |
84 |                 # Extra reward for reasoning models: score the <answer> span separately
85 |                 if args.reasoning == 1:
86 |                     answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
87 | if answer_match:
88 | answer_content = answer_match.group(1).strip()
89 | tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
90 | answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
91 | answer_score = max(min(answer_score, scale), -scale)
92 | score = score * 0.4 + answer_score * 0.6
93 |
94 | reward_model_scores.append(score)
95 |
96 | reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
97 | rewards += reward_model_scores
98 |
99 | return rewards
100 |
101 | def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
102 | for step,batch in enumerate(loader,start=start_step+1):
103 | prompts = batch['prompt'] # list[str], length B
104 |
105 | # 📚 分词和编码
106 | prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
107 | padding_side="left", add_special_tokens=False).to(args.device)
108 | if args.max_seq_len:
109 | prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
110 | prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
111 |
112 | with torch.no_grad():
113 | model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
114 | outputs = model_for_gen.generate(
115 | **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
116 | num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)
117 | completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]
118 | def get_per_token_logps(mdl, input_ids, n_keep):
119 | input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
120 | logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
121 | per_token_logps = []
122 | for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
123 | ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
124 | per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
125 | return torch.stack(per_token_logps)
126 | per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
127 | with torch.no_grad():
128 | ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
129 |
130 | # 📚 解码响应文本
131 | completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
132 |
133 | # 📚 计算奖励
134 | rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)
135 |
136 | grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
137 | mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
138 | std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
139 | advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
140 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
141 |
142 | is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
143 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
144 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
145 | completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
146 |
147 | kl_div = ref_per_token_logps - per_token_logps
148 | per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
149 | per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R]
150 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar
151 | loss.backward()
152 |
153 | if (step + 1) % args.accumulation_steps == 0:
154 | if args.grad_clip > 0:
155 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
156 | optimizer.step()
157 | scheduler.step()
158 | optimizer.zero_grad()
159 |
160 | # 📚 日志记录
161 | if step % args.log_interval == 0 or step == iters:
162 | policy_loss_val = loss.item()
163 | avg_reward_val = rewards.mean().item()
164 | avg_len_val = completion_mask.sum(dim=1).float().mean().item()
165 | current_lr = optimizer.param_groups[0]['lr']
166 |
167 | Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
168 | f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
169 | f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
170 |
171 | if wandb and is_main_process():
172 | wandb.log({
173 | "policy_loss": policy_loss_val,
174 | "reward": avg_reward_val,
175 | "avg_response_len": avg_len_val,
176 | "advantages_mean": advantages.mean().item(),
177 | "learning_rate": current_lr
178 | })
179 |
180 | # 📚 模型保存
181 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
182 | model.eval()
183 | moe_suffix = '_moe' if lm_config.use_moe else ''
184 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
185 | state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
186 | torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
187 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
188 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
189 | model.train()
190 |
191 |         # 📚 Free memory
192 | del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
193 | del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
194 | torch.cuda.empty_cache()
195 | gc.collect()
196 |
197 |
198 |
199 | if __name__ == "__main__":
200 |     """
201 |     GRPO entry point: the Group Relative Policy Optimization training script
202 | 
203 |     📚 GRPO training architecture:
204 |     1. Policy model: the policy network being optimized
205 |     2. Reference model: a frozen reference policy
206 |     3. Reward model: scores the quality of generations
207 |     4. In-group comparison: multiple samples are generated per prompt
208 |     5. Relative advantage: rewards are normalized within each group
209 |     """
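    |     # 📚 Example invocations (illustrative only; adjust paths and hyperparameters to your setup):
    |     #   single GPU:  python train_grpo.py --batch_size 2 --num_generations 8 --beta 0.02
    |     #   multi GPU:   torchrun --nproc_per_node 2 train_grpo.py --use_wandb
    |     # When launched via torchrun, init_distributed_mode below is expected to enable DDP automatically.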
210 |
211 |     # 📚 Command-line argument parsing
212 | parser = argparse.ArgumentParser(description="MokioMind GRPO (Group Relative Policy Optimization)")
213 |
214 |     # ========== Basic training parameters ==========
215 |     parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved models")
216 |     parser.add_argument('--save_weight', default='grpo', type=str, help="filename prefix for saved weights")
217 |     parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
218 |     parser.add_argument("--batch_size", type=int, default=2, help="batch size (kept small for GRPO)")
219 | 
220 |     # 📚 GRPO learning-rate setting
221 |     # GRPO typically uses a very small learning rate to avoid drastic policy shifts
222 |     parser.add_argument("--learning_rate", type=float, default=8e-8, help="initial learning rate")
223 | 
224 |     # ========== Hardware configuration ==========
225 |     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
226 |     parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
227 |     parser.add_argument("--num_workers", type=int, default=1, help="number of data-loader workers")
228 | 
229 |     # ========== Training strategy ==========
230 |     parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
231 |     parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
232 |     parser.add_argument("--log_interval", type=int, default=1, help="logging interval (steps)")
233 |     parser.add_argument("--save_interval", type=int, default=10, help="checkpoint save interval (steps)")
234 | 
235 |     # ========== Model architecture parameters ==========
236 |     parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
237 |     parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
238 |     parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture (0=no, 1=yes)")
239 | 
240 |     # ========== GRPO generation parameters ==========
241 |     parser.add_argument('--max_seq_len', default=66, type=int, help="maximum prompt length")
242 |     parser.add_argument("--max_gen_len", type=int, default=1536, help="maximum generation length")
243 | 
244 |     # ========== Data and model parameters ==========
245 |     parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="path to the RLAIF dataset")
246 | 
247 |     # 📚 Key GRPO parameters
248 |     parser.add_argument("--num_generations", type=int, default=8, help="samples generated per prompt (group size)")
249 |     parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
250 | 
251 |     # 📚 Reasoning-model configuration
252 |     parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='model type (0=standard model, 1=reasoning model)')
253 | 
254 |     # 📚 Reward model path
255 |     parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="path to the reward model")
256 | 
257 |     parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect a checkpoint and resume training (0=no, 1=yes)")
258 | 
259 |     # ========== Experiment tracking ==========
260 |     parser.add_argument("--use_wandb", action="store_true", help="whether to log with wandb")
261 |     parser.add_argument("--wandb_project", type=str, default="MokioMind-GRPO", help="wandb project name")
262 |
263 | args = parser.parse_args()
264 |
265 |     # ========== 1. Initialize the environment and random seed ==========
266 | local_rank = init_distributed_mode()
267 | if dist.is_initialized(): args.device = f"cuda:{local_rank}"
268 | setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
269 |
270 |     # ========== 2. Set up directories, model config, and check for a checkpoint ==========
271 | os.makedirs(args.save_dir, exist_ok=True)
272 | lm_config = MokioMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
273 | max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
274 | ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
275 |
276 |     # ========== 3. Set up mixed precision ==========
277 | device_type = "cuda" if "cuda" in args.device else "cpu"
278 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
279 | autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
280 |
281 |     # ========== 4. Configure wandb ==========
282 | wandb = None
283 | if args.use_wandb and is_main_process():
284 | import swanlab as wandb
285 | wandb_id = ckp_data.get('wandb_id') if ckp_data else None
286 | resume = 'must' if wandb_id else None
287 | wandb_run_name = f"MokioMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
288 | wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
289 |
290 |     # ========== 5. Initialize models and data ==========
291 |     # 📚 GRPO model architecture
292 | base_weight = "reason" if args.reasoning == 1 else "full_sft"
293 |
294 |     # 📚 Policy model (the model being optimized)
295 | model, tokenizer = init_model(lm_config, base_weight, device=args.device)
296 |
297 |     # 📚 Reference model (used for the KL penalty)
298 | ref_model, _ = init_model(lm_config, base_weight, device=args.device)
299 | ref_model = ref_model.eval().requires_grad_(False)
300 |
301 |     # 📚 Reward model (the reward function)
302 | reward_model = AutoModel.from_pretrained(
303 | args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
304 | )
305 | reward_model = reward_model.to(args.device).eval().requires_grad_(False)
306 | reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
307 |
308 |     # 📚 Data and optimizer
309 | train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
310 | train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
311 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
312 | loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
313 | iters = len(loader_for_count)
314 | total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
315 | scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
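    |     # 📚 Illustrative example: with 100 optimizer steps in total, the cosine schedule anneals the
    |     # learning rate from the initial 8e-8 (default) down to eta_min = 8e-9 over those steps.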
316 |
317 |     # ========== 6. Restore state from the checkpoint ==========
318 | start_epoch, start_step = 0, 0
319 | if ckp_data:
320 | model.load_state_dict(ckp_data['model'])
321 | optimizer.load_state_dict(ckp_data['optimizer'])
322 | scheduler.load_state_dict(ckp_data['scheduler'])
323 | start_epoch = ckp_data['epoch']
324 | start_step = ckp_data.get('step', 0)
325 |
326 |     # ========== 7. Wrap the model with DDP ==========
327 | if dist.is_initialized():
328 | model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
329 | model = DistributedDataParallel(model, device_ids=[local_rank])
330 |
331 |     # ========== 8. Start training ==========
332 | for epoch in range(start_epoch, args.epochs):
333 | train_sampler and train_sampler.set_epoch(epoch)
334 |         if epoch == start_epoch and start_step > 0: # first epoch with an existing checkpoint
335 | batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
336 | loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
337 |             Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
338 | grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
339 |         else: # start from the beginning by default
340 | loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
341 | drop_last=False, shuffle=(train_sampler is None),
342 | num_workers=args.num_workers, sampler=train_sampler)
343 | grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
344 |
345 |
346 | def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
347 |     """Combine all reward functions into a single total reward."""
348 | def reasoning_model_reward(rewards):
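    |         # 📚 Reasoning responses are assumed to follow a <think>...</think> block followed by an
    |         # <answer>...</answer> block, as checked by the two patterns below (pattern2 allows a blank
    |         # line between the blocks).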
349 |         pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
350 |         pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
351 | matches_pattern = [re.match(pattern, response, re.S) for response in responses]
352 | matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
353 |
354 | format_rewards = []
355 | for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
356 | if match_pattern or match_pattern2:
357 | format_rewards.append(0.5)
358 | else:
359 | format_rewards.append(0.0)
360 | rewards += torch.tensor(format_rewards, device=args.device)
361 |
362 | def mark_num(text):
363 | reward = 0
364 |             if text.count("<think>") == 1: reward += 0.25
365 |             if text.count("</think>") == 1: reward += 0.25
366 |             if text.count("<answer>") == 1: reward += 0.25
367 |             if text.count("</answer>") == 1: reward += 0.25
368 | return reward
369 |
370 | mark_rewards = [mark_num(response) for response in responses]
371 | rewards += torch.tensor(mark_rewards, device=args.device)
372 | return rewards
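    |     # 📚 Example (assuming the <think>/<answer> format above): a response such as
    |     # "<think>\n...\n</think>\n<answer>\n...\n</answer>" earns +0.5 for matching the overall format
    |     # and +1.0 from mark_num (0.25 for each tag appearing exactly once), before the reward-model
    |     # score is added below.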
373 |
374 | rewards = torch.zeros(len(responses), device=args.device)
375 | if args.reasoning == 1:
376 | rewards = reasoning_model_reward(rewards)
377 |
378 | with torch.no_grad():
379 | reward_model_scores = []
380 | batch_size = len(prompts)
381 | scale = 3.0
382 |
383 | for i in range(batch_size):
384 | for j in range(args.num_generations):
385 | response_idx = i * args.num_generations + j
386 | response = responses[response_idx]
387 | prompt = prompts[i]
388 |
389 | pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
390 | matches = re.findall(pattern, prompt, re.DOTALL)
391 | messages = [{"role": role, "content": content.strip()} for role, content in matches]
392 |
393 | tmp_chat = messages + [{"role": "assistant", "content": response}]
394 | score = reward_model.get_score(reward_tokenizer, tmp_chat)
395 | score = max(min(score, scale), -scale)
396 |
397 | if args.reasoning == 1:
398 |                     answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
399 | if answer_match:
400 | answer_content = answer_match.group(1).strip()
401 | tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
402 | answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
403 | answer_score = max(min(answer_score, scale), -scale)
404 | score = score * 0.4 + answer_score * 0.6
405 |
406 | reward_model_scores.append(score)
407 |
408 | reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
409 | rewards += reward_model_scores
410 |
411 | return rewards
412 |
413 |
414 | def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
415 | for step, batch in enumerate(loader, start=start_step + 1):
416 | prompts = batch['prompt'] # list[str], length B
417 | prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
418 | padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
419 | if args.max_seq_len:
420 | prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
421 | prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
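    |         # 📚 Prompts are left-padded and truncated from the left, so only the most recent
    |         # max_seq_len tokens of each prompt (66 by default) are kept.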
422 |
423 | with torch.no_grad():
424 |             # a DDP-wrapped model must be unwrapped via .module to call generate
425 | model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
426 | outputs = model_for_gen.generate(
427 | **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
428 | num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
429 |
430 | completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*num_gen, R]
431 |
432 | def get_per_token_logps(mdl, input_ids, n_keep):
433 | input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
434 | logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
435 | per_token_logps = []
436 | for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
437 | ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
438 | per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
439 | return torch.stack(per_token_logps)
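    |         # 📚 get_per_token_logps feeds the full [prompt + completion] sequences through mdl, keeps the
    |         # logits that predict each of the n_keep completion tokens, and returns a [B*num_gen, R] tensor
    |         # of their log-probabilities; it is applied to both the policy and the frozen reference model below.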
440 |
441 | per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
442 | with torch.no_grad():
443 | ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
444 |
445 | completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
446 | rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B*num_gen]
447 |
448 | grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
449 | mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
450 | std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
451 | advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
452 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*num_gen]
453 |
454 | is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
455 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
456 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
457 | completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R]
458 |
459 | kl_div = ref_per_token_logps - per_token_logps
460 | per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
461 | per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R]
462 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar
463 | loss.backward()
464 |
465 | if (step + 1) % args.accumulation_steps == 0:
466 | if args.grad_clip > 0:
467 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
468 | optimizer.step()
469 | scheduler.step()
470 | optimizer.zero_grad()
471 |
472 | if step % args.log_interval == 0 or step == iters:
473 | policy_loss_val = loss.item()
474 | avg_reward_val = rewards.mean().item()
475 | avg_len_val = completion_mask.sum(dim=1).float().mean().item()
476 | current_lr = optimizer.param_groups[0]['lr']
477 |
478 | Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
479 | f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
480 | f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
481 |
482 | if wandb and is_main_process():
483 | wandb.log({
484 | "policy_loss": policy_loss_val,
485 | "reward": avg_reward_val,
486 | "avg_response_len": avg_len_val,
487 | "advantages_mean": advantages.mean().item(),
488 | "learning_rate": current_lr
489 | })
490 |
491 | if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
492 | model.eval()
493 | moe_suffix = '_moe' if lm_config.use_moe else ''
494 | ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
495 | state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
496 | torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
497 | lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
498 | epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
499 | model.train()
500 |
501 | del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
502 | del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
503 | torch.cuda.empty_cache()
504 | gc.collect()
505 |
--------------------------------------------------------------------------------