├── readme.md ├── image ├── GRPO.png ├── GRPO-kl.png ├── GRPO-loss.png ├── GRPO-reward.png └── Snipaste_2025-03-31_20-41-35.png ├── for_ppo.py └── for_ppo.ipynb /readme.md: -------------------------------------------------------------------------------- 1 | 先手撕ppo 2 | 3 | ppo难度最大 4 | 5 | 再去手撕其他方法 -------------------------------------------------------------------------------- /image/GRPO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlll123456/study_rlhf/HEAD/image/GRPO.png -------------------------------------------------------------------------------- /image/GRPO-kl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlll123456/study_rlhf/HEAD/image/GRPO-kl.png -------------------------------------------------------------------------------- /image/GRPO-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlll123456/study_rlhf/HEAD/image/GRPO-loss.png -------------------------------------------------------------------------------- /image/GRPO-reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlll123456/study_rlhf/HEAD/image/GRPO-reward.png -------------------------------------------------------------------------------- /image/Snipaste_2025-03-31_20-41-35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlll123456/study_rlhf/HEAD/image/Snipaste_2025-03-31_20-41-35.png -------------------------------------------------------------------------------- /for_ppo.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # 代码实现ppo 3 | 4 | # %% [markdown] 5 | # trl代码中的对于ppo的实现 6 | # https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py 7 | 8 | # %% [markdown] 9 | # 下面为你解释这些参数的含义: 10 | # 11 | # ### 模型架构相关参数 12 | # 1. **`vocab_size = 10`** 13 | # 词汇表的大小代表了模型能够识别的不同词汇的数量。举例来说,若你正在处理的是一个简单的数字文本任务,其中仅有 0 - 9 这 10 个数字,那么 `vocab_size` 就会被设定为 10。 14 | # 15 | # 2. **`hidden_size = 128`** 16 | # 隐藏层的维度大小表明了模型中每个隐藏层神经元的数量。在神经网络里,隐藏层会对输入数据进行特征提取与转换。`hidden_size` 越大,模型所能学习到的特征就越复杂,不过这也会使计算量和内存需求增加。 17 | # 18 | # 3. **`intermediate_size = 256`** 19 | # 在 Transformer 架构里,`intermediate_size` 指的是前馈神经网络(FFN)中间层的维度。FFN 一般由两个线性层构成,中间层的维度通常会比输入输出层的维度大,这样有助于模型学习到更丰富的特征。 20 | # 21 | # 4. **`num_hidden_layers = 2`** 22 | # 隐藏层的数量意味着模型中堆叠的隐藏层的层数。层数越多,模型的表达能力就越强,能够学习到更复杂的模式,但同时也会增加过拟合的风险以及训练的难度。 23 | # 24 | # 5. **`num_attention_heads = 4`** 25 | # 注意力头的数量是指在多头注意力机制中并行的注意力头的个数。多头注意力机制能够让模型从不同的表示子空间中捕捉特征,提升模型的表达能力。 26 | # 27 | # 6. **`num_key_value_heads = 4`** 28 | # 键值对注意力头的数量在某些改进的注意力机制中会用到,它决定了用于计算键(key)和值(value)的注意力头的数量。在标准的多头注意力机制里,`num_key_value_heads` 通常和 `num_attention_heads` 相等。 29 | # 30 | # ### 数据处理和生成相关参数 31 | # 7. **`batch_size = 5`** 32 | # 批量大小代表了在一次训练或者推理过程中同时处理的样本数量。使用较大的批量大小能够提升训练效率,但会增加内存的需求;而较小的批量大小则可以减少内存使用,但会使训练速度变慢。 33 | # 34 | # 8. **`length_x = 5`** 35 | # 输入序列的长度指的是每个输入样本的长度。在处理文本时,它代表的是输入文本中词元(token)的数量。 36 | # 37 | # 9. **`max_new_tokens = 5`** 38 | # 最大新生成的词元数量表示在文本生成任务中,模型最多可以生成的词元数量。例如在文本续写任务里,这个参数会限制模型生成的文本长度。 39 | 40 | # %% 41 | vocab_size = 10 #当前教程实际使用的时候是词汇表实际大小 42 | hidden_size = 128 43 | intermediate_size = 256 44 | num_hidden_layers = 2 45 | num_attention_heads = 4 46 | batch_size = 3 47 | length_x = 5 48 | max_new_tokens = 5 49 | 50 | # %% [markdown] 51 | # ## 初始化actor模型 52 | # 53 | # 以GPT2为例,初始化模型 54 | 55 | # %% 56 | import torch 57 | from transformers import GPT2Config, GPT2LMHeadModel 58 | 59 | torch.manual_seed(1) 60 | 61 | # 定义参数 62 | vocab_size = 10 63 | hidden_size = 128 64 | intermediate_size = 256 65 | num_hidden_layers = 2 66 | num_attention_heads = 4 67 | 68 | # 加载模型配置 69 | config = GPT2Config( 70 | vocab_size=50257, 71 | n_embd=hidden_size, 72 | n_inner=intermediate_size, 73 | n_layer=num_hidden_layers, 74 | n_head=num_attention_heads 75 | ) 76 | 77 | # 初始化 GPT - 2 模型 78 | model = GPT2LMHeadModel(config) 79 | 80 | # %% [markdown] 81 | # ## model generate 82 | # 83 | # 主要看下inputs_ids和attention_mask的含义 84 | 85 | # %% [markdown] 86 | # ### inputs_ids 87 | # 88 | # input_ids:它是一个张量(tensor),表示文本被分词后每个词(token)对应的 ID。比如在第一行 [20015, 232, 25465, ...] 中,每个数字都是原文本中一个词被 GPT - 2 分词器转换后的唯一标识。不同模型的词表不同,这些 ID 对应的具体词汇也不一样。这里第一行可能对应一句中文文本分词结果,第二行 [14150, 257, 922, ...] 前半部分对应英文文本,后半部分 50256 一般是填充值 ,表示补齐固定长度。 89 | # 90 | # 91 | # attention_mask:同样是张量,用于指示哪些位置是有效的词(值为 1),哪些位置是填充的(值为 0) 。比如第二行 [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] 表示前 4 个词是有效输入,后面是填充的,模型在处理时会忽略填充位置。 92 | 93 | # %% [markdown] 94 | # inputs_ids可以认为是要输入的文本经过tokenizer处理后的结果,而attention_mask则是用于指示哪些位置是有效的词(值为 1),哪些位置是填充的(值为 0) 。 95 | 96 | # %% 97 | from transformers import GPT2Tokenizer 98 | import torch 99 | 100 | # 初始化 GPT - 2 分词器 101 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 102 | # 设置padding token 103 | tokenizer.pad_token = tokenizer.eos_token # 使用EOS token作为padding token 104 | 105 | # 输入文本 106 | inputs = ['今天天气不错', 'have a good day'] 107 | 108 | # 对输入进行分词处理 109 | inputs = tokenizer(inputs, return_tensors='pt',padding=True, truncation=True) 110 | 111 | print(inputs) 112 | 113 | # %% 114 | output_ids = model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens) 115 | 116 | print(output_ids) 117 | 118 | 119 | # %% 120 | output_ids = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 121 | print(output_ids) 122 | 123 | # %% [markdown] 124 | # 填充左边和右边会导致input_ids中padding_id的位置不一样,导致attention_mask中padding_id的位置不一样,导致模型在处理时会忽略填充位置。 125 | 126 | # %% 127 | tokenizer.padding_side = 'left' 128 | inputs = ['今天天气不错', 'have a good day'] 129 | inputs = tokenizer(inputs, return_tensors='pt',padding=True, truncation=True) 130 | 131 | print(inputs) 132 | 133 | output_ids = model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens) 134 | 135 | print(output_ids) 136 | 137 | output_ids = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 138 | print(output_ids) 139 | 140 | # %% [markdown] 141 | # ## 初始化reward model 142 | 143 | # %% [markdown] 144 | # 根据之前的定义,奖励模型可以从模型的输出中提取出最后一个token的隐藏状态,然后通过一个线性层计算奖励。 145 | 146 | # %% [markdown] 147 | # 假设batch_size = 2, sequence_length = 4 148 | # input_ids = torch.tensor([ 149 | # [1, 2, 3, 4], # 第一个序列 150 | # [5, 6, 7, 8] # 第二个序列 151 | # ]) 152 | # 153 | # attention_mask = torch.tensor([ 154 | # [1, 1, 1, 0], # 第一个序列有效长度为3 155 | # [1, 1, 1, 1] # 第二个序列有效长度为4 156 | # ]) 157 | # 158 | # sequence_length = attention_mask.sum(dim=1).long() - 1 159 | # 160 | # 结果: tensor([2, 3]) 161 | # 162 | # 第一个序列:3-1=2(索引从0开始) 163 | # 164 | # 第二个序列:4-1=3 165 | # 166 | # batch_indices = torch.arange(batch_size) 167 | # 168 | # 结果: tensor([0, 1]) 169 | # 170 | # 假设hidden_size = 2 171 | # 172 | # last_hidden_state = torch.tensor([ 173 | # [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1]], # 第一个序列 174 | # [[5.0, 5.1], [6.0, 6.1], [7.0, 7.1], [8.0, 8.1]] # 第二个序列 175 | # ]) 176 | # 177 | # 使用batch_indices和sequence_length提取 178 | # 179 | # result = last_hidden_state[batch_indices, sequence_length] 180 | # 181 | # 结果: tensor([[3.0, 3.1], # 第一个序列的第2个位置(索引从0开始) 182 | # 183 | # [8.0, 8.1]]) # 第二个序列的第3个位置 184 | 185 | # %% 186 | class GPTRewardModel(torch.nn.Module): 187 | def __init__(self, gpt_model, reward_head): 188 | super(GPTRewardModel, self).__init__() 189 | self.gpt_model = gpt_model 190 | self.reward_head = reward_head 191 | 192 | def forward(self, input_ids, attention_mask): 193 | # 获取模型的输出 194 | outputs = self.gpt_model(input_ids=input_ids, attention_mask=attention_mask) 195 | # 通常取最后一个隐藏状态作为输出 196 | last_hidden_state = outputs.hidden_states[-1] 197 | batch_size = input_ids.shape[0] 198 | # 确保sequence_length是long类型 199 | sequence_length = attention_mask.sum(dim=1).long() - 1 200 | 201 | # 使用torch.arange并确保在正确的设备上 202 | batch_indices = torch.arange(batch_size, device=input_ids.device).long() 203 | last_hidden_state = last_hidden_state[batch_indices, sequence_length] 204 | 205 | # 计算奖励 206 | rewards = self.reward_head(last_hidden_state) 207 | return rewards 208 | 209 | # 重新初始化模型 210 | model.config.output_hidden_states = True 211 | rm_model = GPTRewardModel(model, torch.nn.Linear(hidden_size, 1)) 212 | 213 | # %% 214 | reward = rm_model(inputs['input_ids'], inputs['attention_mask']) 215 | print(reward) 216 | 217 | # %% [markdown] 218 | # ## 简化版ppo 219 | # 从以上过程可以看出,我们输入给模型的其实是input_ids和attention_mask,所以我们现在为了展示方便,构造一个没有实际意义的输入,输入给模型,然后输出奖励。 220 | 221 | # %% 222 | prompt = torch.randint(0, vocab_size, (batch_size, length_x)) 223 | response = torch.randint(0, vocab_size, (batch_size, length_x + max_new_tokens)) 224 | 225 | # %% 226 | print(prompt) 227 | print(response) 228 | 229 | # %% [markdown] 230 | # 我们希望让模型只关注response,所以对prompt对应的mask置为0 231 | 232 | # %% 233 | attention_mask = torch.ones(batch_size, length_x+max_new_tokens) 234 | attention_mask[:, :length_x] = 0 235 | print(attention_mask) 236 | 237 | 238 | # %% 239 | prompt_attention_mask = torch.ones(batch_size, length_x) 240 | prompt_attention_mask 241 | 242 | # %% [markdown] 243 | # 创建几个模型 244 | # 245 | # 246 | # model_ref 和model的配置一样 247 | # 248 | # reward model和value model的配置大体一样 249 | # 250 | # value model的输出是所有token的隐藏状态所得到的value 251 | 252 | # %% 253 | # 初始化 GPT - 2 模型 254 | model_ref = GPT2LMHeadModel(config) 255 | 256 | # %% [markdown] 257 | # 查看区别 258 | 259 | # %% 260 | print(model_ref) 261 | print(model) 262 | 263 | # %% [markdown] 264 | # ## 初始化value model 265 | 266 | # %% [markdown] 267 | # 假设我们有以下维度的数据: 268 | # 269 | # last_hidden_state 的形状是 [batch_size, sequence_length, hidden_size] 270 | # 271 | # 比如 [5, 10, 128],表示批次大小为5,序列长度为10,隐藏层维度为128 272 | # 273 | # self.value_head 是一个线性层 Linear(hidden_size, 1) 274 | # 275 | # 输入维度是128,输出维度是1 276 | # 277 | # 处理过程: 278 | # 279 | # self.value_head(last_hidden_state) 的操作: 280 | # 281 | # 输入: [5, 10, 128] 282 | # 283 | # 输出: [5, 10, 1] # 线性层将最后一个维度从128转换为1 284 | # 285 | # [:, :, 0] 的操作: 286 | # 287 | # 取最后一个维度的第0个元素 288 | # 289 | # 结果形状变为: [5, 10] 290 | 291 | # %% 292 | class GPTValueModel(torch.nn.Module): 293 | def __init__(self, gpt_model, value_head): 294 | super().__init__() 295 | self.gpt_model = gpt_model 296 | self.value_head = value_head 297 | 298 | def forward(self, input_ids, attention_mask): 299 | outputs = self.gpt_model(input_ids=input_ids, attention_mask=attention_mask) 300 | last_hidden_state = outputs.hidden_states[-1] 301 | 302 | values = self.value_head(last_hidden_state)[:, :, 0] 303 | 304 | return values 305 | 306 | model.config.output_hidden_states = True 307 | vm_model = GPTValueModel(model,torch.nn.Linear(hidden_size, 1)) 308 | 309 | # %% 310 | print(rm_model) 311 | print(vm_model) 312 | 313 | # %% [markdown] 314 | # ## ppo前向过程 315 | 316 | # %% [markdown] 317 | # 创建几个model的函数 318 | 319 | # %% 320 | def get_response(model, prompt, max_new_tokens): 321 | inputs = {'input_ids': prompt} # ignore mask,好像不需要mask 322 | y = model.generate(**inputs, 323 | max_new_tokens=max_new_tokens, 324 | # forced_eos_token_id=True 325 | ) 326 | return y 327 | 328 | def get_reward(model, response, attention_mask): 329 | inputs = {'input_ids': response, 'attention_mask': attention_mask} # ignore mask 330 | y = model(inputs['input_ids'], inputs['attention_mask']) 331 | return y 332 | 333 | def get_value(model, prompt, attention_mask): 334 | inputs = {'input_ids': prompt, 'attention_mask': attention_mask} # ignore mask 335 | y = model(inputs['input_ids'], inputs['attention_mask']) 336 | return y 337 | 338 | # %% 339 | prompt 340 | 341 | # %% 342 | response 343 | 344 | # %% 345 | prompt_attention_mask 346 | 347 | # %% 348 | attention_mask 349 | 350 | # %% 351 | print(get_response(model, prompt, max_new_tokens, prompt_attention_mask)) 352 | print(get_reward(rm_model, response, attention_mask)) 353 | print(get_value(vm_model, response, attention_mask)) 354 | 355 | 356 | # %% [markdown] 357 | # PPO 相关设置 358 | 359 | # %% [markdown] 360 | # 封装几个ppo的model 361 | 362 | # %% 363 | class PPOModels(): 364 | def __init__(self, model_actor, model_ref, model_rm, model_critic): 365 | self.actor = model_actor 366 | self.ref = model_ref 367 | self.rm = model_rm 368 | self.critic = model_critic 369 | 370 | 371 | model_ref.eval() 372 | rm_model.eval() 373 | models = PPOModels(model, model_ref, rm_model, vm_model) 374 | 375 | 376 | # %% [markdown] 377 | # 设置ppo的超参数 378 | 379 | # %% [markdown] 380 | # 1. ppo_epochs在每次策略更新时,PPO 算法对收集到的数据进行迭代训练的次数。 381 | # 382 | # 2. mini_batch_size每个训练步骤中,从收集到的数据里选取的小批量数据的样本数量。 383 | # 384 | # 3. epochs整个训练过程中,算法对所有收集到的数据进行完整遍历的次数。 385 | # 386 | # 4. kl_ctlKL 散度惩罚项的系数,用于控制新旧策略之间的差异程度。 387 | # 388 | # 5. vf_coef价值函数损失的系数,用于平衡策略损失和价值函数损失在总损失中的权重。 389 | # 390 | # 6. lam广义优势估计(GAE)中的 \(\lambda\) 参数,用于平衡优势估计的偏差和方差。 391 | # 392 | # 7. gamma折扣因子,用于计算未来奖励的折现值,决定未来奖励在当前价值估计中的重要程度。 393 | # 394 | # 8. cliprange_value价值函数裁剪范围的参数,用于限制价值函数更新的幅度 395 | 396 | # %% 397 | class PPOConfig(): 398 | def __init__(self): 399 | self.ppo_epochs = 5 400 | self.mini_batch_size = 2 401 | self.epochs = 4 402 | self.kl_ctl = 0.1 403 | self.vf_coef = 0.1 404 | self.lam = 0.9 405 | self.gamma = 0.9 406 | self.cliprange_value = 0.2 407 | 408 | def __str__(self): 409 | return f'ppo_epochs:{self.ppo_epochs}\nmini_batch_size:{self.mini_batch_size}\nepochs:{self.epochs}\nkl_ctl:{self.kl_ctl}' 410 | 411 | 412 | ppo_config = PPOConfig() 413 | 414 | # %% [markdown] 415 | # 在每一步中ppo都在干什么 416 | 417 | # %% [markdown] 418 | # 首先要有个列表来记录每一步的采样 419 | 420 | # %% 421 | ppo_old_batchs = { 422 | 'prompt': None, 423 | 'response': None, 424 | 'mask': None, 425 | 'logprobs_ref': None, 426 | 'logprobs_old': None, 427 | 'logprobs': None, 428 | 'values_old': None, 429 | 'values': None, 430 | 'rewards': None, 431 | 'rewards_kl': None, 432 | 'loss': None, 433 | 'logits': None, 434 | } 435 | 436 | ppo_old_batchs['prompt'] = prompt 437 | ppo_old_batchs['response'] = response 438 | ppo_old_batchs['mask'] = attention_mask 439 | 440 | # %% 441 | ppo_old_batchs 442 | 443 | # %% [markdown] 444 | # 前向推理,得到token的logprobs 445 | 446 | # %% [markdown] 447 | # logprobs = F.log_softmax(logits, dim=-1)第一步:对logits进行softmax并取log 448 | # 449 | # torch.gather是一个用于从张量中按索引收集值的操作 450 | # 451 | # 假设我们有: 452 | # 453 | # logp.shape = [1, 5, 32] # [batch_size, seq_len, vocab_size] 454 | # 455 | # labels.shape = [1, 5] # [batch_size, seq_len] 456 | # 457 | # 1. labels.unsqueeze(2) 458 | # 459 | # 在最后增加一个维度 460 | # 461 | # labels_expanded = labels.unsqueeze(2) # shape变为[1, 5, 1] 462 | # 463 | # 2. torch.gather(logp, 2, labels_expanded) 464 | # 465 | # dim=2表示在词表维度(第3维)上收集值 466 | # 467 | # gathered = torch.gather(logp, 2, labels_expanded) # shape为[1, 5, 1] 468 | # 469 | # 3. squeeze(-1) 470 | # 471 | # 去掉最后一个维度 472 | # 473 | # logpy = gathered.squeeze(-1) # 最终shape为[1, 5] 474 | 475 | # %% 476 | import torch.nn.functional as F 477 | 478 | def get_logits(model, input_ids): 479 | # 得到logits 480 | outputs = model(input_ids=input_ids) 481 | logits = outputs.logits 482 | return logits 483 | 484 | def get_logprobs(model, response, attention_mask): 485 | # 得到logprobs 486 | logits = get_logits(model, response) 487 | # F.log_softmax() 是先进行softmax运算然后再取对数(log) 488 | all_token_logprobs = F.log_softmax(logits, dim=-1) 489 | # 使用torch.gather() 从logprobs中收集response的值 490 | gathered = torch.gather(all_token_logprobs, 2, response.unsqueeze(2)) 491 | # 去掉最后一个维度 492 | response_logprobs = gathered.squeeze(-1) 493 | return response_logprobs 494 | 495 | logprobs_ref = get_logprobs(models.ref, ppo_old_batchs['response'], ppo_old_batchs['mask']) 496 | logprobs_old = get_logprobs(models.actor, ppo_old_batchs['response'], ppo_old_batchs['mask']) 497 | logprobs = get_logprobs(models.actor, ppo_old_batchs['response'], ppo_old_batchs['mask']) 498 | 499 | print(logprobs_ref.shape) 500 | print(logprobs_old.shape) 501 | print(logprobs.shape) 502 | 503 | 504 | # %% 505 | response.shape 506 | 507 | # %% 508 | logprobs 509 | 510 | # %% [markdown] 511 | # 计算kl 512 | 513 | # %% 514 | def get_kl(logprobs_ref, logprobs_old, kl_ctl): 515 | kl = logprobs_ref - logprobs_old 516 | kl = kl * kl_ctl 517 | return kl 518 | 519 | kl = get_kl(logprobs_ref, logprobs_old, ppo_config.kl_ctl) 520 | print(kl) 521 | 522 | 523 | # %% [markdown] 524 | # 计算reward_kl 525 | # 526 | 527 | # %% 528 | def get_reward_with_kl(logprobs_ref, logprobs_old, kl_ctl, reward): 529 | kl = logprobs_ref - logprobs_old 530 | kl = kl * kl_ctl 531 | kl[:, -1] += reward[:, 0] 532 | return kl 533 | 534 | print(kl) 535 | rewards = get_reward(models.rm, ppo_old_batchs['response'], ppo_old_batchs['mask']) 536 | print(rewards) 537 | 538 | kl_reward = get_reward_with_kl(logprobs_ref, logprobs_old, ppo_config.kl_ctl, rewards) 539 | print(kl_reward) 540 | 541 | 542 | # %% 543 | values = get_value(models.critic, ppo_old_batchs['response'], ppo_old_batchs['mask']) 544 | 545 | # %% 546 | values 547 | 548 | # %% 549 | ppo_old_batchs['logprobs_ref'] = logprobs_ref 550 | ppo_old_batchs['logprobs_old'] = logprobs_old 551 | ppo_old_batchs['logprobs'] = logprobs 552 | ppo_old_batchs['values_old'] = values 553 | ppo_old_batchs['rewards'] = rewards 554 | ppo_old_batchs['rewards_kl'] = kl_reward 555 | 556 | ppo_old_batchs 557 | 558 | # %% [markdown] 559 | # ## 计算loss 560 | 561 | # %% [markdown] 562 | # rewards:一个张量,代表在每个时间步获得的奖励。 563 | # 564 | # mask:一个掩码张量,用于标识哪些时间步是有效的(例如,用于处理终止状态)。 565 | # 566 | # values:一个张量,代表每个时间步的状态价值估计。 567 | # 568 | # gamma:折扣因子,用于计算未来奖励的折现值,取值范围通常在 [0, 1] 之间。 569 | # 570 | # lam:GAE 中的 \(\lambda\) 参数,用于平衡偏差和方差,取值范围同样在 [0, 1] 之间。 571 | 572 | # %% 573 | def get_GAE(rewards, attention_mask, values, gemma, lam): 574 | lastgae = 0 #初始化为 0,用于存储上一个时间步的广义优势估计值。 575 | advantages_recersed = [] 576 | response_len = rewards.shape[-1] 577 | 578 | values = values * attention_mask 579 | rewards = rewards * attention_mask 580 | 581 | for t in reversed(range(response_len)): 582 | nextvalues = values[:, t + 1] if t < response_len - 1 else 0.0 583 | # 计算时间步 t 的 TD 误差(Temporal Difference error),即当前奖励加上折扣后的下一个时间步的价值估计,再减去当前时间步的价值估计。 584 | delta = rewards[:, t] + gemma * nextvalues - values[:, t] 585 | # 根据 GAE 的递推公式,计算当前时间步的广义优势估计值。 586 | lastgae = delta + gemma * lam * lastgae 587 | advantages_recersed.append(lastgae) 588 | # 将 advantages_reversed 列表反转,使其按时间步的正序排列。 589 | advantages = torch.stack(advantages_recersed[::-1]).transpose(0, 1) 590 | return advantages 591 | 592 | 593 | # %% 594 | ppo_old_batchs 595 | 596 | # %% 597 | gae = get_GAE(ppo_old_batchs['rewards'], ppo_old_batchs['mask'], ppo_old_batchs['values_old'], ppo_config.gamma, ppo_config.lam) 598 | gae 599 | 600 | # %% 601 | gae = get_GAE(ppo_old_batchs['rewards_kl'], ppo_old_batchs['mask'], ppo_old_batchs['values_old'], ppo_config.gamma, ppo_config.lam) 602 | gae 603 | 604 | 605 | # %% [markdown] 606 | # 计算value loss 607 | # 608 | 609 | # %% [markdown] 610 | # advantages:优势函数的估计值,用于计算回报。 611 | # 612 | # 613 | # values:当前价值函数的估计值。 614 | # 615 | # values_old:旧的价值函数估计值。 616 | # 617 | # mask:掩码张量,用于指定哪些元素参与损失计算。 618 | # 619 | # cliprange_value:裁剪范围,用于限制价值函数的更新幅度。 620 | 621 | # %% [markdown] 622 | # https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L561C29-L567C30 623 | 624 | # %% 625 | def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis = None) -> torch.Tensor: 626 | """Compute mean of tensor with a masked values.""" 627 | if axis is not None: 628 | return (values * mask).sum(axis=axis) / mask.sum(axis=axis) 629 | else: 630 | return (values * mask).sum() / mask.sum() 631 | 632 | def get_value_loss(advantages, values, values_old, attention_mask, cliprange_value): 633 | returns = values_old + advantages 634 | advantages = advantages.detach() 635 | 636 | vpredclipped = torch.clamp(values, values_old - cliprange_value, values_old + cliprange_value) 637 | 638 | vf_losses1 = torch.square(vpredclipped - returns) 639 | vf_losses2 = torch.square(values - returns) 640 | vf_loss_max = torch.max(vf_losses1, vf_losses2) 641 | vf_loss = 0.5 * masked_mean(vf_loss_max, attention_mask) 642 | return vf_loss 643 | 644 | 645 | 646 | # %% 647 | ppo_old_batchs['values'] = ppo_old_batchs['values_old'] + 0.5 648 | 649 | # %% 650 | value_loss = get_value_loss(gae, ppo_old_batchs['values'], ppo_old_batchs['values_old'], ppo_old_batchs['mask'], ppo_config.cliprange_value) 651 | value_loss 652 | 653 | # %% [markdown] 654 | # 计算policy loss 655 | # https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L569-L574 656 | 657 | # %% 658 | def get_policy_loss(advantages, logprobs, logprobs_old, mask, cliprange): 659 | # 重要性采样 660 | ratio = torch.exp(logprobs - logprobs_old) 661 | # 计算策略损失 662 | pg_losses = -advantages * ratio 663 | pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) 664 | pg_loss_max = torch.max(pg_losses, pg_losses2) 665 | pg_loss = masked_mean(pg_loss_max, mask) 666 | return pg_loss 667 | 668 | 669 | 670 | # %% 671 | pg_loss = get_policy_loss(gae, ppo_old_batchs['logprobs'], ppo_old_batchs['logprobs_old'], ppo_old_batchs['mask'], ppo_config.cliprange_value) 672 | 673 | # %% 674 | pg_loss 675 | 676 | # %% [markdown] 677 | # 计算熵损失 678 | # https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L582-L583 679 | 680 | # %% [markdown] 681 | # entropy(熵)没有直接参与到模型的损失(loss) 682 | # 683 | # 在计算完损失并进行反向传播和参数更新后,代码计算了 entropy 684 | # 685 | # 这里计算的 entropy 被记录到 entropy_stats 张量中,用于后续的统计和记录,但没有用于损失计算。 686 | 687 | # %% 688 | logits = get_logits(models.actor, ppo_old_batchs['response']) 689 | ppo_old_batchs['logits'] = logits 690 | 691 | # %% 692 | def get_entropy_loss(logits, mask): 693 | prob_dist = torch.nn.functional.softmax(logits, dim=-1) 694 | entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) 695 | return entropy 696 | 697 | entropy = get_entropy_loss(ppo_old_batchs['logits'], ppo_old_batchs['mask']) 698 | entropy 699 | 700 | 701 | # %% 702 | loss = pg_loss + ppo_config.vf_coef * value_loss 703 | 704 | # %% 705 | def get_loss(batchs, ppo_config): 706 | gae = get_GAE(batchs['rewards_kl'], 707 | batchs['mask'], 708 | batchs['values'], 709 | ppo_config.gamma, 710 | ppo_config.lam) 711 | value_loss = get_value_loss(gae, 712 | batchs['values'], 713 | batchs['values_old'], 714 | batchs['mask'], 715 | ppo_config.cliprange_value) 716 | pg_loss = get_policy_loss( 717 | gae, 718 | batchs['logprobs'], 719 | batchs['logprobs_old'], 720 | batchs['mask'], 721 | ppo_config.cliprange_value) 722 | entropy = get_entropy_loss(batchs['logits'], batchs['mask']) 723 | loss = pg_loss + ppo_config.vf_coef * value_loss 724 | return loss 725 | 726 | # %% 727 | loss = get_loss(ppo_old_batchs, ppo_config) 728 | loss 729 | 730 | # %% 731 | ppo_old_batchs 732 | 733 | # %% [markdown] 734 | # ## PPO训练 735 | # 736 | # https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L529-L538 737 | 738 | # %% [markdown] 739 | # 将一个完整的批次数据 ppo_batchs 按照指定的 batch_size 和 mini_batch_size 划分成多个小批次数据 740 | 741 | # %% 742 | import numpy as np 743 | def get_minibatch(ppo_batchs, batch_size, mini_batch_size): 744 | # 计算需要多少个小批次 745 | step = batch_size // mini_batch_size 746 | ppo_batchs_iter = [] 747 | 748 | # 随机打乱索引以提高训练效果 749 | b_inds = np.random.permutation(batch_size) 750 | 751 | # 根据索引创建小批次 752 | for i in range(step): 753 | start_idx = i * mini_batch_size 754 | end_idx = start_idx + mini_batch_size 755 | batch_inds = b_inds[start_idx:end_idx] 756 | 757 | # 创建当前小批次的数据 758 | mini_batch = {} 759 | for key, value in ppo_batchs.items(): 760 | if value is not None and isinstance(value, torch.Tensor) and value.size(0) == batch_size: 761 | mini_batch[key] = value[batch_inds] 762 | else: 763 | mini_batch[key] = value 764 | 765 | ppo_batchs_iter.append(mini_batch) 766 | 767 | return ppo_batchs_iter 768 | 769 | # %% 770 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 771 | 772 | # %% 773 | ppo_old_batchs 774 | 775 | # %% 776 | def ppo_train_step(models, ppo_batchs, ppo_config, get_loss, optimizer): 777 | losses = [] 778 | 779 | 780 | # 多轮PPO训练 781 | for i in range(ppo_config.ppo_epochs): 782 | # 获取小批次数据 783 | ppo_batchs_iter = get_minibatch( 784 | ppo_batchs, batch_size, ppo_config.mini_batch_size) 785 | 786 | # 对每个小批次进行训练 787 | for mini_batchs in ppo_batchs_iter: 788 | # 获取当前策略的输出 789 | optimizer.zero_grad() 790 | # 重新计算所有中间结果,而不是重用之前的计算图 791 | with torch.set_grad_enabled(True): 792 | logits = get_logits(models.actor, mini_batchs['prompt']) 793 | """ 794 | 省略了 795 | """ 796 | 797 | 798 | # 计算损失 799 | loss= get_loss( 800 | mini_batchs, ppo_config) 801 | 802 | # 在实际训练中应该进行反向传播 803 | loss.backward() 804 | optimizer.step() 805 | 806 | # 记录损失 807 | losses.append(loss) 808 | 809 | # 更新批次数据中的损失 810 | ppo_batchs['loss'] = losses 811 | 812 | print(losses) 813 | 814 | 815 | 816 | 817 | -------------------------------------------------------------------------------- /for_ppo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 代码实现ppo" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "先把本教程中的mask忽略,加入了一些mask写的有点乱" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "trl代码中的对于ppo的实现\n", 22 | "https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py\n", 23 | "\n", 24 | "https://mp.weixin.qq.com/s/S72LO26IsZ8AED8sQKIWnQ\n", 25 | "\n", 26 | "讲了PPO loss max https://zhuanlan.zhihu.com/p/28223597805\n", 27 | "\n", 28 | "https://zhuanlan.zhihu.com/p/677607581" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "下面为你解释这些参数的含义:\n", 36 | "\n", 37 | "### 模型架构相关参数\n", 38 | "1. **`vocab_size = 10`**\n", 39 | "词汇表的大小代表了模型能够识别的不同词汇的数量。举例来说,若你正在处理的是一个简单的数字文本任务,其中仅有 0 - 9 这 10 个数字,那么 `vocab_size` 就会被设定为 10。\n", 40 | "\n", 41 | "2. **`hidden_size = 128`**\n", 42 | "隐藏层的维度大小表明了模型中每个隐藏层神经元的数量。在神经网络里,隐藏层会对输入数据进行特征提取与转换。`hidden_size` 越大,模型所能学习到的特征就越复杂,不过这也会使计算量和内存需求增加。\n", 43 | "\n", 44 | "3. **`intermediate_size = 256`**\n", 45 | "在 Transformer 架构里,`intermediate_size` 指的是前馈神经网络(FFN)中间层的维度。FFN 一般由两个线性层构成,中间层的维度通常会比输入输出层的维度大,这样有助于模型学习到更丰富的特征。\n", 46 | "\n", 47 | "4. **`num_hidden_layers = 2`**\n", 48 | "隐藏层的数量意味着模型中堆叠的隐藏层的层数。层数越多,模型的表达能力就越强,能够学习到更复杂的模式,但同时也会增加过拟合的风险以及训练的难度。\n", 49 | "\n", 50 | "5. **`num_attention_heads = 4`**\n", 51 | "注意力头的数量是指在多头注意力机制中并行的注意力头的个数。多头注意力机制能够让模型从不同的表示子空间中捕捉特征,提升模型的表达能力。\n", 52 | "\n", 53 | "6. **`num_key_value_heads = 4`**\n", 54 | "键值对注意力头的数量在某些改进的注意力机制中会用到,它决定了用于计算键(key)和值(value)的注意力头的数量。在标准的多头注意力机制里,`num_key_value_heads` 通常和 `num_attention_heads` 相等。\n", 55 | "\n", 56 | "### 数据处理和生成相关参数\n", 57 | "7. **`batch_size = 5`**\n", 58 | "批量大小代表了在一次训练或者推理过程中同时处理的样本数量。使用较大的批量大小能够提升训练效率,但会增加内存的需求;而较小的批量大小则可以减少内存使用,但会使训练速度变慢。\n", 59 | "\n", 60 | "8. **`length_x = 5`**\n", 61 | "输入序列的长度指的是每个输入样本的长度。在处理文本时,它代表的是输入文本中词元(token)的数量。\n", 62 | "\n", 63 | "9. **`max_new_tokens = 5`**\n", 64 | "最大新生成的词元数量表示在文本生成任务中,模型最多可以生成的词元数量。例如在文本续写任务里,这个参数会限制模型生成的文本长度。 " 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 2, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "vocab_size = 10 #当前教程实际使用的时候是词汇表实际大小\n", 74 | "hidden_size = 128\n", 75 | "intermediate_size = 256\n", 76 | "num_hidden_layers = 2\n", 77 | "num_attention_heads = 4\n", 78 | "batch_size = 3\n", 79 | "length_x = 5\n", 80 | "max_new_tokens = 5" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## 初始化actor模型\n", 88 | "\n", 89 | "以GPT2为例,初始化模型" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stderr", 99 | "output_type": "stream", 100 | "text": [ 101 | "/opt/anaconda3/envs/llm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 102 | " from .autonotebook import tqdm as notebook_tqdm\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "import torch\n", 108 | "from transformers import GPT2Config, GPT2LMHeadModel\n", 109 | "\n", 110 | "torch.manual_seed(1)\n", 111 | "\n", 112 | "# 定义参数\n", 113 | "vocab_size = 10\n", 114 | "hidden_size = 128\n", 115 | "intermediate_size = 256\n", 116 | "num_hidden_layers = 2\n", 117 | "num_attention_heads = 4\n", 118 | "\n", 119 | "# 加载模型配置\n", 120 | "config = GPT2Config(\n", 121 | " vocab_size=50257,\n", 122 | " n_embd=hidden_size,\n", 123 | " n_inner=intermediate_size,\n", 124 | " n_layer=num_hidden_layers,\n", 125 | " n_head=num_attention_heads\n", 126 | ")\n", 127 | "\n", 128 | "# 初始化 GPT - 2 模型\n", 129 | "model = GPT2LMHeadModel(config)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## model generate\n", 137 | "\n", 138 | "主要看下inputs_ids和attention_mask的含义" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "### inputs_ids\n", 146 | "\n", 147 | "input_ids:它是一个张量(tensor),表示文本被分词后每个词(token)对应的 ID。比如在第一行 [20015, 232, 25465, ...] 中,每个数字都是原文本中一个词被 GPT - 2 分词器转换后的唯一标识。不同模型的词表不同,这些 ID 对应的具体词汇也不一样。这里第一行可能对应一句中文文本分词结果,第二行 [14150, 257, 922, ...] 前半部分对应英文文本,后半部分 50256 一般是填充值 ,表示补齐固定长度。\n", 148 | "\n", 149 | "\n", 150 | "attention_mask:同样是张量,用于指示哪些位置是有效的词(值为 1),哪些位置是填充的(值为 0) 。比如第二行 [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] 表示前 4 个词是有效输入,后面是填充的,模型在处理时会忽略填充位置。" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "inputs_ids可以认为是要输入的文本经过tokenizer处理后的结果,而attention_mask则是用于指示哪些位置是有效的词(值为 1),哪些位置是填充的(值为 0) 。" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 4, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "{'input_ids': tensor([[20015, 232, 25465, 25465, 36365, 242, 38834, 165, 242, 247],\n", 170 | " [14150, 257, 922, 1110, 50256, 50256, 50256, 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 171 | " [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])}\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "from transformers import GPT2Tokenizer\n", 177 | "import torch\n", 178 | "\n", 179 | "# 初始化 GPT - 2 分词器\n", 180 | "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n", 181 | "# 设置padding token\n", 182 | "tokenizer.pad_token = tokenizer.eos_token # 使用EOS token作为padding token\n", 183 | "\n", 184 | "# 输入文本\n", 185 | "inputs = ['今天天气不错', 'have a good day']\n", 186 | "\n", 187 | "# 对输入进行分词处理\n", 188 | "inputs = tokenizer(inputs, return_tensors='pt',padding=True, truncation=True)\n", 189 | "\n", 190 | "print(inputs)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 5, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", 203 | "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n", 204 | "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", 205 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 206 | ] 207 | }, 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "tensor([[20015, 232, 25465, 25465, 36365, 242, 38834, 165, 242, 247,\n", 213 | " 247, 247, 247, 247, 247],\n", 214 | " [14150, 257, 922, 1110, 50256, 50256, 50256, 50256, 50256, 50256,\n", 215 | " 50256, 50256, 50256, 50256, 50256]])\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "output_ids = model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)\n", 221 | "print(output_ids)\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 6, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stdout", 231 | "output_type": "stream", 232 | "text": [ 233 | "['今天天气不错�����', 'have a good day']\n" 234 | ] 235 | } 236 | ], 237 | "source": [ 238 | "output_ids = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n", 239 | "print(output_ids)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "填充左边和右边会导致input_ids中padding_id的位置不一样,导致attention_mask中padding_id的位置不一样,导致模型在处理时会忽略填充位置。" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 7, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stderr", 256 | "output_type": "stream", 257 | "text": [ 258 | "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", 259 | "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n" 260 | ] 261 | }, 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "{'input_ids': tensor([[20015, 232, 25465, 25465, 36365, 242, 38834, 165, 242, 247],\n", 267 | " [50256, 50256, 50256, 50256, 50256, 50256, 14150, 257, 922, 1110]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 268 | " [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])}\n", 269 | "tensor([[20015, 232, 25465, 25465, 36365, 242, 38834, 165, 242, 247,\n", 270 | " 247, 247, 247, 247, 247],\n", 271 | " [50256, 50256, 50256, 50256, 50256, 50256, 14150, 257, 922, 1110,\n", 272 | " 1110, 1110, 1110, 1110, 1110]])\n", 273 | "['今天天气不错�����', 'have a good day day day day day day']\n" 274 | ] 275 | } 276 | ], 277 | "source": [ 278 | "tokenizer.padding_side = 'left'\n", 279 | "inputs = ['今天天气不错', 'have a good day']\n", 280 | "inputs = tokenizer(inputs, return_tensors='pt',padding=True, truncation=True)\n", 281 | "\n", 282 | "print(inputs)\n", 283 | "\n", 284 | "output_ids = model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)\n", 285 | "\n", 286 | "print(output_ids)\n", 287 | "\n", 288 | "output_ids = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n", 289 | "print(output_ids)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "# 现在开始正式讲rlhf流程" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": {}, 302 | "source": [ 303 | "## 初始化reward model" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "根据之前的定义,奖励模型可以从模型的输出中提取出最后一个token的隐藏状态,然后通过一个线性层计算奖励。" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "假设batch_size = 2, sequence_length = 4\n", 318 | "input_ids = torch.tensor([\n", 319 | " [1, 2, 3, 4], # 第一个序列\n", 320 | " [5, 6, 7, 8] # 第二个序列\n", 321 | "])\n", 322 | "\n", 323 | "attention_mask = torch.tensor([\n", 324 | " [1, 1, 1, 0], # 第一个序列有效长度为3\n", 325 | " [1, 1, 1, 1] # 第二个序列有效长度为4\n", 326 | "])\n", 327 | "\n", 328 | "sequence_length = attention_mask.sum(dim=1).long() - 1\n", 329 | "\n", 330 | "结果: tensor([2, 3])\n", 331 | "\n", 332 | "第一个序列:3-1=2(索引从0开始)\n", 333 | "\n", 334 | "第二个序列:4-1=3\n", 335 | "\n", 336 | "batch_indices = torch.arange(batch_size)\n", 337 | "\n", 338 | "结果: tensor([0, 1])\n", 339 | "\n", 340 | "假设hidden_size = 2\n", 341 | "\n", 342 | "last_hidden_state = torch.tensor([\n", 343 | " [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1]], # 第一个序列\n", 344 | " [[5.0, 5.1], [6.0, 6.1], [7.0, 7.1], [8.0, 8.1]] # 第二个序列\n", 345 | "])\n", 346 | "\n", 347 | "使用batch_indices和sequence_length提取\n", 348 | "\n", 349 | "result = last_hidden_state[batch_indices, sequence_length]\n", 350 | "\n", 351 | "结果: tensor([[3.0, 3.1], # 第一个序列的第2个位置(索引从0开始)\n", 352 | "\n", 353 | "[8.0, 8.1]]) # 第二个序列的第3个位置" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 8, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "class GPTRewardModel(torch.nn.Module):\n", 363 | " def __init__(self, gpt_model, reward_head):\n", 364 | " super(GPTRewardModel, self).__init__()\n", 365 | " self.gpt_model = gpt_model\n", 366 | " self.reward_head = reward_head\n", 367 | " \n", 368 | " def forward(self, input_ids, attention_mask):\n", 369 | " # 获取模型的输出\n", 370 | " outputs = self.gpt_model(input_ids=input_ids, attention_mask=attention_mask)\n", 371 | " # 通常取最后一个隐藏状态作为输出\n", 372 | " last_hidden_state = outputs.hidden_states[-1]\n", 373 | " batch_size = input_ids.shape[0]\n", 374 | " # 确保sequence_length是long类型\n", 375 | " sequence_length = attention_mask.sum(dim=1).long() - 1\n", 376 | " # 使用torch.arange并确保在正确的设备上\n", 377 | " batch_indices = torch.arange(batch_size, device=input_ids.device).long()\n", 378 | " last_hidden_state = last_hidden_state[batch_indices, sequence_length]\n", 379 | " print(f\"last_hidden_state shape: {last_hidden_state.shape}, sequence_length: {sequence_length.shape}\")\n", 380 | " # 计算奖励\n", 381 | " rewards = self.reward_head(last_hidden_state)\n", 382 | " return rewards\n", 383 | "\n", 384 | "# 重新初始化模型\n", 385 | "model.config.output_hidden_states = True\n", 386 | "rm_model = GPTRewardModel(model, torch.nn.Linear(hidden_size, 1)) ## 这里的reward_head是一个线性层,将最后一个隐藏状态映射到奖励值" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 9, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "data": { 396 | "text/plain": [ 397 | "tensor([[20015, 232, 25465, 25465, 36365, 242, 38834, 165, 242, 247],\n", 398 | " [50256, 50256, 50256, 50256, 50256, 50256, 14150, 257, 922, 1110]])" 399 | ] 400 | }, 401 | "execution_count": 9, 402 | "metadata": {}, 403 | "output_type": "execute_result" 404 | } 405 | ], 406 | "source": [ 407 | "inputs['input_ids']" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 10, 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "name": "stdout", 417 | "output_type": "stream", 418 | "text": [ 419 | "last_hidden_state shape: torch.Size([2, 128]), sequence_length: torch.Size([2])\n", 420 | "tensor([[-0.1647],\n", 421 | " [-0.2839]], grad_fn=)\n" 422 | ] 423 | } 424 | ], 425 | "source": [ 426 | "reward = rm_model(inputs['input_ids'], inputs['attention_mask'])\n", 427 | "print(reward)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": {}, 433 | "source": [ 434 | "## 简化版ppo\n", 435 | "从以上过程可以看出,我们输入给模型的其实是input_ids和attention_mask,所以我们现在为了展示方便,构造一个没有实际意义的输入,输入给模型,然后输出奖励。" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 11, 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "prompt = torch.randint(0, vocab_size, (batch_size, length_x))\n", 445 | "response = torch.randint(0, vocab_size, (batch_size, length_x + max_new_tokens))" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 12, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "name": "stdout", 455 | "output_type": "stream", 456 | "text": [ 457 | "tensor([[5, 0, 0, 1, 0],\n", 458 | " [4, 8, 1, 4, 1],\n", 459 | " [9, 6, 7, 0, 5]])\n", 460 | "tensor([[4, 8, 5, 2, 9, 5, 5, 0, 6, 3],\n", 461 | " [0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 462 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8]])\n" 463 | ] 464 | } 465 | ], 466 | "source": [ 467 | "print(prompt)\n", 468 | "print(response)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "我们希望让模型只关注response,所以对prompt对应的mask置为0" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 13, 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 488 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 489 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]])\n" 490 | ] 491 | } 492 | ], 493 | "source": [ 494 | "attention_mask = torch.ones(batch_size, length_x+max_new_tokens)\n", 495 | "attention_mask[:, :length_x] = 0\n", 496 | "print(attention_mask)\n" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 14, 502 | "metadata": {}, 503 | "outputs": [ 504 | { 505 | "data": { 506 | "text/plain": [ 507 | "tensor([[1., 1., 1., 1., 1.],\n", 508 | " [1., 1., 1., 1., 1.],\n", 509 | " [1., 1., 1., 1., 1.]])" 510 | ] 511 | }, 512 | "execution_count": 14, 513 | "metadata": {}, 514 | "output_type": "execute_result" 515 | } 516 | ], 517 | "source": [ 518 | "prompt_attention_mask = torch.ones(batch_size, length_x)\n", 519 | "prompt_attention_mask" 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "metadata": {}, 525 | "source": [ 526 | "创建几个模型\n", 527 | "\n", 528 | "\n", 529 | "model_ref 和model的配置一样\n", 530 | "\n", 531 | "reward model和value model的配置大体一样\n", 532 | "\n", 533 | "value model的输出是所有token的隐藏状态所得到的value" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 15, 539 | "metadata": {}, 540 | "outputs": [ 541 | { 542 | "name": "stderr", 543 | "output_type": "stream", 544 | "text": [ 545 | "/opt/anaconda3/envs/llm/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:774: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_hidden_states` is. When `return_dict_in_generate` is not `True`, `output_hidden_states` is ignored.\n", 546 | " warnings.warn(\n" 547 | ] 548 | } 549 | ], 550 | "source": [ 551 | "# 初始化 GPT - 2 模型\n", 552 | "model_ref = GPT2LMHeadModel(config)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "metadata": {}, 558 | "source": [ 559 | "查看区别" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 16, 565 | "metadata": {}, 566 | "outputs": [ 567 | { 568 | "name": "stdout", 569 | "output_type": "stream", 570 | "text": [ 571 | "GPT2LMHeadModel(\n", 572 | " (transformer): GPT2Model(\n", 573 | " (wte): Embedding(50257, 128)\n", 574 | " (wpe): Embedding(1024, 128)\n", 575 | " (drop): Dropout(p=0.1, inplace=False)\n", 576 | " (h): ModuleList(\n", 577 | " (0-1): 2 x GPT2Block(\n", 578 | " (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 579 | " (attn): GPT2SdpaAttention(\n", 580 | " (c_attn): Conv1D(nf=384, nx=128)\n", 581 | " (c_proj): Conv1D(nf=128, nx=128)\n", 582 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n", 583 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n", 584 | " )\n", 585 | " (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 586 | " (mlp): GPT2MLP(\n", 587 | " (c_fc): Conv1D(nf=256, nx=128)\n", 588 | " (c_proj): Conv1D(nf=128, nx=256)\n", 589 | " (act): NewGELUActivation()\n", 590 | " (dropout): Dropout(p=0.1, inplace=False)\n", 591 | " )\n", 592 | " )\n", 593 | " )\n", 594 | " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 595 | " )\n", 596 | " (lm_head): Linear(in_features=128, out_features=50257, bias=False)\n", 597 | ")\n", 598 | "GPT2LMHeadModel(\n", 599 | " (transformer): GPT2Model(\n", 600 | " (wte): Embedding(50257, 128)\n", 601 | " (wpe): Embedding(1024, 128)\n", 602 | " (drop): Dropout(p=0.1, inplace=False)\n", 603 | " (h): ModuleList(\n", 604 | " (0-1): 2 x GPT2Block(\n", 605 | " (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 606 | " (attn): GPT2SdpaAttention(\n", 607 | " (c_attn): Conv1D(nf=384, nx=128)\n", 608 | " (c_proj): Conv1D(nf=128, nx=128)\n", 609 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n", 610 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n", 611 | " )\n", 612 | " (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 613 | " (mlp): GPT2MLP(\n", 614 | " (c_fc): Conv1D(nf=256, nx=128)\n", 615 | " (c_proj): Conv1D(nf=128, nx=256)\n", 616 | " (act): NewGELUActivation()\n", 617 | " (dropout): Dropout(p=0.1, inplace=False)\n", 618 | " )\n", 619 | " )\n", 620 | " )\n", 621 | " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 622 | " )\n", 623 | " (lm_head): Linear(in_features=128, out_features=50257, bias=False)\n", 624 | ")\n" 625 | ] 626 | } 627 | ], 628 | "source": [ 629 | "print(model_ref)\n", 630 | "print(model)" 631 | ] 632 | }, 633 | { 634 | "cell_type": "markdown", 635 | "metadata": {}, 636 | "source": [ 637 | "## 初始化value model" 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "metadata": {}, 643 | "source": [ 644 | "假设我们有以下维度的数据:\n", 645 | "\n", 646 | "last_hidden_state 的形状是 [batch_size, sequence_length, hidden_size]\n", 647 | "\n", 648 | "比如 [5, 10, 128],表示批次大小为5,序列长度为10,隐藏层维度为128\n", 649 | "\n", 650 | "self.value_head 是一个线性层 Linear(hidden_size, 1)\n", 651 | "\n", 652 | "输入维度是128,输出维度是1\n", 653 | "\n", 654 | "处理过程:\n", 655 | "\n", 656 | "self.value_head(last_hidden_state) 的操作:\n", 657 | "\n", 658 | "输入: [5, 10, 128]\n", 659 | "\n", 660 | "输出: [5, 10, 1] # 线性层将最后一个维度从128转换为1\n", 661 | "\n", 662 | "[:, :, 0] 的操作:\n", 663 | "\n", 664 | "取最后一个维度的第0个元素\n", 665 | "\n", 666 | "结果形状变为: [5, 10]" 667 | ] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "execution_count": 17, 672 | "metadata": {}, 673 | "outputs": [], 674 | "source": [ 675 | "class GPTValueModel(torch.nn.Module):\n", 676 | " def __init__(self, gpt_model, value_head):\n", 677 | " super().__init__()\n", 678 | " self.gpt_model = gpt_model\n", 679 | " self.value_head = value_head\n", 680 | " \n", 681 | " def forward(self, input_ids, attention_mask):\n", 682 | " outputs = self.gpt_model(input_ids=input_ids, attention_mask=attention_mask)\n", 683 | " last_hidden_state = outputs.hidden_states[-1]\n", 684 | " values = self.value_head(last_hidden_state)[:, :, 0]\n", 685 | " return values\n", 686 | " \n", 687 | "model.config.output_hidden_states = True\n", 688 | "vm_model = GPTValueModel(model,torch.nn.Linear(hidden_size, 1))" 689 | ] 690 | }, 691 | { 692 | "cell_type": "code", 693 | "execution_count": 18, 694 | "metadata": {}, 695 | "outputs": [ 696 | { 697 | "name": "stdout", 698 | "output_type": "stream", 699 | "text": [ 700 | "GPTRewardModel(\n", 701 | " (gpt_model): GPT2LMHeadModel(\n", 702 | " (transformer): GPT2Model(\n", 703 | " (wte): Embedding(50257, 128)\n", 704 | " (wpe): Embedding(1024, 128)\n", 705 | " (drop): Dropout(p=0.1, inplace=False)\n", 706 | " (h): ModuleList(\n", 707 | " (0-1): 2 x GPT2Block(\n", 708 | " (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 709 | " (attn): GPT2SdpaAttention(\n", 710 | " (c_attn): Conv1D(nf=384, nx=128)\n", 711 | " (c_proj): Conv1D(nf=128, nx=128)\n", 712 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n", 713 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n", 714 | " )\n", 715 | " (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 716 | " (mlp): GPT2MLP(\n", 717 | " (c_fc): Conv1D(nf=256, nx=128)\n", 718 | " (c_proj): Conv1D(nf=128, nx=256)\n", 719 | " (act): NewGELUActivation()\n", 720 | " (dropout): Dropout(p=0.1, inplace=False)\n", 721 | " )\n", 722 | " )\n", 723 | " )\n", 724 | " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 725 | " )\n", 726 | " (lm_head): Linear(in_features=128, out_features=50257, bias=False)\n", 727 | " )\n", 728 | " (reward_head): Linear(in_features=128, out_features=1, bias=True)\n", 729 | ")\n", 730 | "GPTValueModel(\n", 731 | " (gpt_model): GPT2LMHeadModel(\n", 732 | " (transformer): GPT2Model(\n", 733 | " (wte): Embedding(50257, 128)\n", 734 | " (wpe): Embedding(1024, 128)\n", 735 | " (drop): Dropout(p=0.1, inplace=False)\n", 736 | " (h): ModuleList(\n", 737 | " (0-1): 2 x GPT2Block(\n", 738 | " (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 739 | " (attn): GPT2SdpaAttention(\n", 740 | " (c_attn): Conv1D(nf=384, nx=128)\n", 741 | " (c_proj): Conv1D(nf=128, nx=128)\n", 742 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n", 743 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n", 744 | " )\n", 745 | " (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 746 | " (mlp): GPT2MLP(\n", 747 | " (c_fc): Conv1D(nf=256, nx=128)\n", 748 | " (c_proj): Conv1D(nf=128, nx=256)\n", 749 | " (act): NewGELUActivation()\n", 750 | " (dropout): Dropout(p=0.1, inplace=False)\n", 751 | " )\n", 752 | " )\n", 753 | " )\n", 754 | " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 755 | " )\n", 756 | " (lm_head): Linear(in_features=128, out_features=50257, bias=False)\n", 757 | " )\n", 758 | " (value_head): Linear(in_features=128, out_features=1, bias=True)\n", 759 | ")\n" 760 | ] 761 | } 762 | ], 763 | "source": [ 764 | "print(rm_model)\n", 765 | "print(vm_model)" 766 | ] 767 | }, 768 | { 769 | "cell_type": "markdown", 770 | "metadata": {}, 771 | "source": [ 772 | "## ppo前向过程" 773 | ] 774 | }, 775 | { 776 | "cell_type": "markdown", 777 | "metadata": {}, 778 | "source": [ 779 | "创建几个model的函数" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 19, 785 | "metadata": {}, 786 | "outputs": [], 787 | "source": [ 788 | "def get_response(model, prompt, max_new_tokens, attention_mask):\n", 789 | " inputs = {'input_ids': prompt, 'attention_mask': attention_mask} # ignore mask,好像不需要mask\n", 790 | " y = model.generate(**inputs,\n", 791 | " max_new_tokens=max_new_tokens,\n", 792 | " # forced_eos_token_id=True\n", 793 | " )\n", 794 | " return y\n", 795 | "\n", 796 | "def get_reward(model, response, attention_mask):\n", 797 | " inputs = {'input_ids': response, 'attention_mask': attention_mask} # ignore mask\n", 798 | " y = model(inputs['input_ids'], inputs['attention_mask'])\n", 799 | " return y\n", 800 | "\n", 801 | "def get_value(model, prompt, attention_mask):\n", 802 | " inputs = {'input_ids': prompt, 'attention_mask': attention_mask} # ignore mask\n", 803 | " y = model(inputs['input_ids'], inputs['attention_mask'])\n", 804 | " return y" 805 | ] 806 | }, 807 | { 808 | "cell_type": "code", 809 | "execution_count": 20, 810 | "metadata": {}, 811 | "outputs": [ 812 | { 813 | "data": { 814 | "text/plain": [ 815 | "tensor([[5, 0, 0, 1, 0],\n", 816 | " [4, 8, 1, 4, 1],\n", 817 | " [9, 6, 7, 0, 5]])" 818 | ] 819 | }, 820 | "execution_count": 20, 821 | "metadata": {}, 822 | "output_type": "execute_result" 823 | } 824 | ], 825 | "source": [ 826 | "prompt" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 21, 832 | "metadata": {}, 833 | "outputs": [ 834 | { 835 | "data": { 836 | "text/plain": [ 837 | "tensor([[4, 8, 5, 2, 9, 5, 5, 0, 6, 3],\n", 838 | " [0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 839 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8]])" 840 | ] 841 | }, 842 | "execution_count": 21, 843 | "metadata": {}, 844 | "output_type": "execute_result" 845 | } 846 | ], 847 | "source": [ 848 | "response" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 22, 854 | "metadata": {}, 855 | "outputs": [ 856 | { 857 | "data": { 858 | "text/plain": [ 859 | "tensor([[1., 1., 1., 1., 1.],\n", 860 | " [1., 1., 1., 1., 1.],\n", 861 | " [1., 1., 1., 1., 1.]])" 862 | ] 863 | }, 864 | "execution_count": 22, 865 | "metadata": {}, 866 | "output_type": "execute_result" 867 | } 868 | ], 869 | "source": [ 870 | "prompt_attention_mask" 871 | ] 872 | }, 873 | { 874 | "cell_type": "code", 875 | "execution_count": 23, 876 | "metadata": {}, 877 | "outputs": [ 878 | { 879 | "data": { 880 | "text/plain": [ 881 | "tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 882 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 883 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]])" 884 | ] 885 | }, 886 | "execution_count": 23, 887 | "metadata": {}, 888 | "output_type": "execute_result" 889 | } 890 | ], 891 | "source": [ 892 | "attention_mask" 893 | ] 894 | }, 895 | { 896 | "cell_type": "markdown", 897 | "metadata": {}, 898 | "source": [ 899 | "在这里就可以看到,ppo流程中的reward只是在最后一个token上得到的,但是我的value model要在每一个token上得到一个价值" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": 25, 905 | "metadata": {}, 906 | "outputs": [ 907 | { 908 | "name": "stderr", 909 | "output_type": "stream", 910 | "text": [ 911 | "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n" 912 | ] 913 | }, 914 | { 915 | "name": "stdout", 916 | "output_type": "stream", 917 | "text": [ 918 | "tensor([[ 5, 0, 0, 1, 0, 0, 0, 0, 0, 0],\n", 919 | " [ 4, 8, 1, 4, 1, 10998, 10998, 10998, 10998, 10998],\n", 920 | " [ 9, 6, 7, 0, 5, 5, 5, 5, 5, 5]])\n", 921 | "last_hidden_state shape: torch.Size([3, 128]), sequence_length: torch.Size([3])\n", 922 | "tensor([[-0.4702],\n", 923 | " [-1.0223],\n", 924 | " [-0.6396]], grad_fn=)\n", 925 | "tensor([[ 0.1054, -0.1810, -0.2179, -0.4633, -0.1662, 0.0374, -0.7071, -0.7640,\n", 926 | " -1.3427, 0.2779],\n", 927 | " [ 0.0424, -0.0425, -1.1631, -0.1351, 0.2049, 0.0207, -0.9090, 0.4028,\n", 928 | " -0.1427, 0.6911],\n", 929 | " [ 0.1912, -0.2840, 0.1110, 0.6809, -0.4596, -0.1590, -0.2637, -0.3191,\n", 930 | " -0.1446, 0.9440]], grad_fn=)\n" 931 | ] 932 | } 933 | ], 934 | "source": [ 935 | "print(get_response(model, prompt, max_new_tokens, prompt_attention_mask))\n", 936 | "print(get_reward(rm_model, response, attention_mask))\n", 937 | "print(get_value(vm_model, response, attention_mask))\n" 938 | ] 939 | }, 940 | { 941 | "cell_type": "markdown", 942 | "metadata": {}, 943 | "source": [ 944 | "PPO 相关设置" 945 | ] 946 | }, 947 | { 948 | "cell_type": "markdown", 949 | "metadata": {}, 950 | "source": [ 951 | "封装几个ppo的model" 952 | ] 953 | }, 954 | { 955 | "cell_type": "code", 956 | "execution_count": 26, 957 | "metadata": {}, 958 | "outputs": [], 959 | "source": [ 960 | "class PPOModels():\n", 961 | " def __init__(self, model_actor, model_ref, model_rm, model_critic):\n", 962 | " self.actor = model_actor\n", 963 | " self.ref = model_ref\n", 964 | " self.rm = model_rm\n", 965 | " self.critic = model_critic\n", 966 | "\n", 967 | "\n", 968 | "model_ref.eval()\n", 969 | "rm_model.eval()\n", 970 | "models = PPOModels(model, model_ref, rm_model, vm_model)\n" 971 | ] 972 | }, 973 | { 974 | "cell_type": "markdown", 975 | "metadata": {}, 976 | "source": [ 977 | "设置ppo的超参数" 978 | ] 979 | }, 980 | { 981 | "cell_type": "markdown", 982 | "metadata": {}, 983 | "source": [ 984 | "1. ppo_epochs在每次策略更新时,PPO 算法对收集到的数据进行迭代训练的次数。\n", 985 | "\n", 986 | "2. mini_batch_size每个训练步骤中,从收集到的数据里选取的小批量数据的样本数量。\n", 987 | "\n", 988 | "3. epochs整个训练过程中,算法对所有收集到的数据进行完整遍历的次数。\n", 989 | "\n", 990 | "4. kl_ctlKL 散度惩罚项的系数,用于控制新旧策略之间的差异程度。\n", 991 | "\n", 992 | "5. vf_coef价值函数损失的系数,用于平衡策略损失和价值函数损失在总损失中的权重。\n", 993 | "\n", 994 | "6. lam广义优势估计(GAE)中的 \\(\\lambda\\) 参数,用于平衡优势估计的偏差和方差。\n", 995 | "\n", 996 | "7. gamma折扣因子,用于计算未来奖励的折现值,决定未来奖励在当前价值估计中的重要程度。\n", 997 | "\n", 998 | "8. cliprange_value价值函数裁剪范围的参数,用于限制价值函数更新的幅度" 999 | ] 1000 | }, 1001 | { 1002 | "cell_type": "code", 1003 | "execution_count": 27, 1004 | "metadata": {}, 1005 | "outputs": [], 1006 | "source": [ 1007 | "class PPOConfig():\n", 1008 | " def __init__(self):\n", 1009 | " self.ppo_epochs = 5\n", 1010 | " self.mini_batch_size = 2\n", 1011 | " self.epochs = 4\n", 1012 | " self.kl_ctl = 0.1\n", 1013 | " self.vf_coef = 0.1\n", 1014 | " self.lam = 0.9\n", 1015 | " self.gamma = 0.9\n", 1016 | " self.cliprange_value = 0.2\n", 1017 | "\n", 1018 | " def __str__(self):\n", 1019 | " return f'ppo_epochs:{self.ppo_epochs}\\nmini_batch_size:{self.mini_batch_size}\\nepochs:{self.epochs}\\nkl_ctl:{self.kl_ctl}'\n", 1020 | "\n", 1021 | "\n", 1022 | "ppo_config = PPOConfig()" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "markdown", 1027 | "metadata": {}, 1028 | "source": [ 1029 | "在每一步中ppo都在干什么" 1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "markdown", 1034 | "metadata": {}, 1035 | "source": [ 1036 | "首先要有个列表来记录每一步的采样" 1037 | ] 1038 | }, 1039 | { 1040 | "cell_type": "code", 1041 | "execution_count": 28, 1042 | "metadata": {}, 1043 | "outputs": [], 1044 | "source": [ 1045 | "ppo_old_batchs = {\n", 1046 | " 'prompt': None,\n", 1047 | " 'response': None,\n", 1048 | " 'mask': None,\n", 1049 | " 'logprobs_ref': None,\n", 1050 | " 'logprobs_old': None,\n", 1051 | " 'logprobs': None,\n", 1052 | " 'values_old': None,\n", 1053 | " 'values': None,\n", 1054 | " 'rewards': None,\n", 1055 | " 'rewards_kl': None,\n", 1056 | " 'loss': None,\n", 1057 | " 'logits': None,\n", 1058 | "}\n", 1059 | "\n", 1060 | "ppo_old_batchs['prompt'] = prompt\n", 1061 | "ppo_old_batchs['response'] = response\n", 1062 | "ppo_old_batchs['mask'] = attention_mask" 1063 | ] 1064 | }, 1065 | { 1066 | "cell_type": "code", 1067 | "execution_count": 29, 1068 | "metadata": {}, 1069 | "outputs": [ 1070 | { 1071 | "data": { 1072 | "text/plain": [ 1073 | "{'prompt': tensor([[5, 0, 0, 1, 0],\n", 1074 | " [4, 8, 1, 4, 1],\n", 1075 | " [9, 6, 7, 0, 5]]),\n", 1076 | " 'response': tensor([[4, 8, 5, 2, 9, 5, 5, 0, 6, 3],\n", 1077 | " [0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 1078 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8]]),\n", 1079 | " 'mask': tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 1080 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 1081 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]),\n", 1082 | " 'logprobs_ref': None,\n", 1083 | " 'logprobs_old': None,\n", 1084 | " 'logprobs': None,\n", 1085 | " 'values_old': None,\n", 1086 | " 'values': None,\n", 1087 | " 'rewards': None,\n", 1088 | " 'rewards_kl': None,\n", 1089 | " 'loss': None,\n", 1090 | " 'logits': None}" 1091 | ] 1092 | }, 1093 | "execution_count": 29, 1094 | "metadata": {}, 1095 | "output_type": "execute_result" 1096 | } 1097 | ], 1098 | "source": [ 1099 | "ppo_old_batchs" 1100 | ] 1101 | }, 1102 | { 1103 | "cell_type": "markdown", 1104 | "metadata": {}, 1105 | "source": [ 1106 | "前向推理,得到token的logprobs" 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "markdown", 1111 | "metadata": {}, 1112 | "source": [ 1113 | "logprobs = F.log_softmax(logits, dim=-1)第一步:对logits进行softmax并取log\n", 1114 | "\n", 1115 | "torch.gather是一个用于从张量中按索引收集值的操作 \n", 1116 | "\n", 1117 | "假设我们有:\n", 1118 | "\n", 1119 | "logp.shape = [1, 5, 32] # [batch_size, seq_len, vocab_size]\n", 1120 | "\n", 1121 | "labels.shape = [1, 5] # [batch_size, seq_len]\n", 1122 | "\n", 1123 | "1. labels.unsqueeze(2)\n", 1124 | "\n", 1125 | "在最后增加一个维度\n", 1126 | "\n", 1127 | "labels_expanded = labels.unsqueeze(2) # shape变为[1, 5, 1]\n", 1128 | "\n", 1129 | "2. torch.gather(logp, 2, labels_expanded)\n", 1130 | "\n", 1131 | "dim=2表示在词表维度(第3维)上收集值\n", 1132 | "\n", 1133 | "gathered = torch.gather(logp, 2, labels_expanded) # shape为[1, 5, 1]\n", 1134 | "\n", 1135 | "3. squeeze(-1)\n", 1136 | "\n", 1137 | "去掉最后一个维度\n", 1138 | "\n", 1139 | "logpy = gathered.squeeze(-1) # 最终shape为[1, 5]" 1140 | ] 1141 | }, 1142 | { 1143 | "cell_type": "code", 1144 | "execution_count": 31, 1145 | "metadata": {}, 1146 | "outputs": [ 1147 | { 1148 | "name": "stdout", 1149 | "output_type": "stream", 1150 | "text": [ 1151 | "inputs_ids shape: torch.Size([3, 10])\n", 1152 | "logits shape: torch.Size([3, 10, 50257])\n", 1153 | "logits shape: torch.Size([3, 10, 50257]), response shape: torch.Size([3, 10]), attention_mask shape: torch.Size([3, 10])\n", 1154 | "all_token_logprobs shape: torch.Size([3, 10, 50257])\n", 1155 | "gathered shape: torch.Size([3, 10, 1]), response shape: torch.Size([3, 10])\n", 1156 | "response_logprobs shape: torch.Size([3, 10])\n", 1157 | "\n", 1158 | "\n", 1159 | "inputs_ids shape: torch.Size([3, 10])\n", 1160 | "logits shape: torch.Size([3, 10, 50257])\n", 1161 | "logits shape: torch.Size([3, 10, 50257]), response shape: torch.Size([3, 10]), attention_mask shape: torch.Size([3, 10])\n", 1162 | "all_token_logprobs shape: torch.Size([3, 10, 50257])\n", 1163 | "gathered shape: torch.Size([3, 10, 1]), response shape: torch.Size([3, 10])\n", 1164 | "response_logprobs shape: torch.Size([3, 10])\n", 1165 | "\n", 1166 | "\n", 1167 | "inputs_ids shape: torch.Size([3, 10])\n", 1168 | "logits shape: torch.Size([3, 10, 50257])\n", 1169 | "logits shape: torch.Size([3, 10, 50257]), response shape: torch.Size([3, 10]), attention_mask shape: torch.Size([3, 10])\n", 1170 | "all_token_logprobs shape: torch.Size([3, 10, 50257])\n", 1171 | "gathered shape: torch.Size([3, 10, 1]), response shape: torch.Size([3, 10])\n", 1172 | "response_logprobs shape: torch.Size([3, 10])\n", 1173 | "torch.Size([3, 10])\n", 1174 | "torch.Size([3, 10])\n", 1175 | "torch.Size([3, 10])\n" 1176 | ] 1177 | } 1178 | ], 1179 | "source": [ 1180 | "import torch.nn.functional as F\n", 1181 | "\n", 1182 | "def get_logits(model, input_ids):\n", 1183 | " # 得到logits\n", 1184 | " outputs = model(input_ids=input_ids)\n", 1185 | " print(f\"inputs_ids shape: {input_ids.shape}\")\n", 1186 | " logits = outputs.logits\n", 1187 | " print(f\"logits shape: {logits.shape}\")\n", 1188 | " return logits\n", 1189 | "\n", 1190 | "def get_logprobs(model, response, attention_mask):\n", 1191 | " # 得到logprobs\n", 1192 | " logits = get_logits(model, response)\n", 1193 | " print(f\"logits shape: {logits.shape}, response shape: {response.shape}, attention_mask shape: {attention_mask.shape}\")\n", 1194 | " # F.log_softmax() 是先进行softmax运算然后再取对数(log)\n", 1195 | " all_token_logprobs = F.log_softmax(logits, dim=-1)\n", 1196 | " print(f\"all_token_logprobs shape: {all_token_logprobs.shape}\")\n", 1197 | " # 使用torch.gather() 从logprobs中收集response的值\n", 1198 | " gathered = torch.gather(all_token_logprobs, 2, response.unsqueeze(2))\n", 1199 | " print(f\"gathered shape: {gathered.shape}, response shape: {response.shape}\")\n", 1200 | " # 去掉最后一个维度\n", 1201 | " response_logprobs = gathered.squeeze(-1)\n", 1202 | " print(f\"response_logprobs shape: {response_logprobs.shape}\")\n", 1203 | " return response_logprobs\n", 1204 | "\n", 1205 | "logprobs_ref = get_logprobs(models.ref, ppo_old_batchs['response'], ppo_old_batchs['mask'])\n", 1206 | "print('\\n')\n", 1207 | "logprobs_old = get_logprobs(models.actor, ppo_old_batchs['response'], ppo_old_batchs['mask'])\n", 1208 | "print('\\n')\n", 1209 | "logprobs = get_logprobs(models.actor, ppo_old_batchs['response'], ppo_old_batchs['mask'])\n", 1210 | "\n", 1211 | "print(logprobs_ref.shape)\n", 1212 | "print(logprobs_old.shape)\n", 1213 | "print(logprobs.shape) \n" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "code", 1218 | "execution_count": 32, 1219 | "metadata": {}, 1220 | "outputs": [ 1221 | { 1222 | "data": { 1223 | "text/plain": [ 1224 | "torch.Size([3, 10])" 1225 | ] 1226 | }, 1227 | "execution_count": 32, 1228 | "metadata": {}, 1229 | "output_type": "execute_result" 1230 | } 1231 | ], 1232 | "source": [ 1233 | "response.shape" 1234 | ] 1235 | }, 1236 | { 1237 | "cell_type": "code", 1238 | "execution_count": 33, 1239 | "metadata": {}, 1240 | "outputs": [ 1241 | { 1242 | "data": { 1243 | "text/plain": [ 1244 | "tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 1245 | " -9.3174, -9.8408, -9.5032],\n", 1246 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 1247 | " -9.6053, -9.3741, -9.4720],\n", 1248 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 1249 | " -9.9353, -9.3172, -9.8445]], grad_fn=)" 1250 | ] 1251 | }, 1252 | "execution_count": 33, 1253 | "metadata": {}, 1254 | "output_type": "execute_result" 1255 | } 1256 | ], 1257 | "source": [ 1258 | "logprobs" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "markdown", 1263 | "metadata": {}, 1264 | "source": [ 1265 | "计算kl" 1266 | ] 1267 | }, 1268 | { 1269 | "cell_type": "code", 1270 | "execution_count": 34, 1271 | "metadata": {}, 1272 | "outputs": [ 1273 | { 1274 | "name": "stdout", 1275 | "output_type": "stream", 1276 | "text": [ 1277 | "tensor([[-0.0130, 0.0095, -0.0262, -0.0021, -0.0283, -0.0148, -0.0134, -0.0258,\n", 1278 | " 0.0089, -0.0307],\n", 1279 | " [-0.0315, -0.0049, -0.0047, -0.0323, 0.0020, -0.0178, 0.0170, -0.0316,\n", 1280 | " -0.0339, -0.0369],\n", 1281 | " [-0.0574, 0.0419, -0.0651, -0.0085, -0.0412, -0.0019, -0.0238, 0.0211,\n", 1282 | " -0.0333, 0.0152]], grad_fn=)\n" 1283 | ] 1284 | } 1285 | ], 1286 | "source": [ 1287 | "def get_kl(logprobs_ref, logprobs_old, kl_ctl):\n", 1288 | " kl = logprobs_ref - logprobs_old\n", 1289 | " kl = kl * kl_ctl\n", 1290 | " return kl\n", 1291 | "\n", 1292 | "kl = get_kl(logprobs_ref, logprobs_old, ppo_config.kl_ctl)\n", 1293 | "print(kl)\n" 1294 | ] 1295 | }, 1296 | { 1297 | "cell_type": "markdown", 1298 | "metadata": {}, 1299 | "source": [ 1300 | "计算reward_kl\n" 1301 | ] 1302 | }, 1303 | { 1304 | "cell_type": "code", 1305 | "execution_count": 35, 1306 | "metadata": {}, 1307 | "outputs": [ 1308 | { 1309 | "name": "stdout", 1310 | "output_type": "stream", 1311 | "text": [ 1312 | "tensor([[-0.0130, 0.0095, -0.0262, -0.0021, -0.0283, -0.0148, -0.0134, -0.0258,\n", 1313 | " 0.0089, -0.0307],\n", 1314 | " [-0.0315, -0.0049, -0.0047, -0.0323, 0.0020, -0.0178, 0.0170, -0.0316,\n", 1315 | " -0.0339, -0.0369],\n", 1316 | " [-0.0574, 0.0419, -0.0651, -0.0085, -0.0412, -0.0019, -0.0238, 0.0211,\n", 1317 | " -0.0333, 0.0152]], grad_fn=)\n", 1318 | "last_hidden_state shape: torch.Size([3, 128]), sequence_length: torch.Size([3])\n", 1319 | "tensor([[-0.7784],\n", 1320 | " [-0.9515],\n", 1321 | " [-0.9003]], grad_fn=)\n", 1322 | "tensor([[-0.0130, 0.0095, -0.0262, -0.0021, -0.0283, -0.0148, -0.0134, -0.0258,\n", 1323 | " 0.0089, -0.8090],\n", 1324 | " [-0.0315, -0.0049, -0.0047, -0.0323, 0.0020, -0.0178, 0.0170, -0.0316,\n", 1325 | " -0.0339, -0.9884],\n", 1326 | " [-0.0574, 0.0419, -0.0651, -0.0085, -0.0412, -0.0019, -0.0238, 0.0211,\n", 1327 | " -0.0333, -0.8852]], grad_fn=)\n" 1328 | ] 1329 | } 1330 | ], 1331 | "source": [ 1332 | "def get_reward_with_kl(logprobs_ref, logprobs_old, kl_ctl, reward):\n", 1333 | " kl = logprobs_ref - logprobs_old\n", 1334 | " kl = kl * kl_ctl\n", 1335 | " kl[:, -1] += reward[:, 0]\n", 1336 | " return kl\n", 1337 | "\n", 1338 | "print(kl)\n", 1339 | "rewards = get_reward(models.rm, ppo_old_batchs['response'], ppo_old_batchs['mask'])\n", 1340 | "print(rewards)\n", 1341 | "\n", 1342 | "kl_reward = get_reward_with_kl(logprobs_ref, logprobs_old, ppo_config.kl_ctl, rewards)\n", 1343 | "print(kl_reward)\n" 1344 | ] 1345 | }, 1346 | { 1347 | "cell_type": "code", 1348 | "execution_count": 36, 1349 | "metadata": {}, 1350 | "outputs": [], 1351 | "source": [ 1352 | "values = get_value(models.critic, ppo_old_batchs['response'], ppo_old_batchs['mask'])" 1353 | ] 1354 | }, 1355 | { 1356 | "cell_type": "code", 1357 | "execution_count": 37, 1358 | "metadata": {}, 1359 | "outputs": [ 1360 | { 1361 | "data": { 1362 | "text/plain": [ 1363 | "tensor([[ 0.1939, -0.0731, -0.0170, -0.4315, 0.0534, -0.2046, -0.6074, -0.7700,\n", 1364 | " -1.2505, 0.1553],\n", 1365 | " [ 0.0511, -0.2098, -0.8512, -0.1117, 0.2560, -0.0967, -0.9718, 0.2660,\n", 1366 | " -0.1777, 0.4735],\n", 1367 | " [ 0.2042, -0.6096, -0.0284, 0.2577, -0.3757, -0.3134, -0.5433, -0.2487,\n", 1368 | " -0.2369, 1.0747]], grad_fn=)" 1369 | ] 1370 | }, 1371 | "execution_count": 37, 1372 | "metadata": {}, 1373 | "output_type": "execute_result" 1374 | } 1375 | ], 1376 | "source": [ 1377 | "values" 1378 | ] 1379 | }, 1380 | { 1381 | "cell_type": "code", 1382 | "execution_count": 38, 1383 | "metadata": {}, 1384 | "outputs": [ 1385 | { 1386 | "data": { 1387 | "text/plain": [ 1388 | "{'prompt': tensor([[5, 0, 0, 1, 0],\n", 1389 | " [4, 8, 1, 4, 1],\n", 1390 | " [9, 6, 7, 0, 5]]),\n", 1391 | " 'response': tensor([[4, 8, 5, 2, 9, 5, 5, 0, 6, 3],\n", 1392 | " [0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 1393 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8]]),\n", 1394 | " 'mask': tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 1395 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 1396 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]),\n", 1397 | " 'logprobs_ref': tensor([[ -9.7659, -9.9431, -9.7075, -9.8018, -9.6310, -9.6916, -9.7483,\n", 1398 | " -9.5755, -9.7520, -9.8097],\n", 1399 | " [ -9.9691, -9.7657, -9.7810, -9.7806, -9.8304, -9.9382, -9.6816,\n", 1400 | " -9.9212, -9.7132, -9.8413],\n", 1401 | " [-10.4189, -9.7863, -10.1431, -9.8084, -9.5995, -9.5113, -9.8666,\n", 1402 | " -9.7238, -9.6501, -9.6926]], grad_fn=),\n", 1403 | " 'logprobs_old': tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 1404 | " -9.3174, -9.8408, -9.5032],\n", 1405 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 1406 | " -9.6053, -9.3741, -9.4720],\n", 1407 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 1408 | " -9.9353, -9.3172, -9.8445]], grad_fn=),\n", 1409 | " 'logprobs': tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 1410 | " -9.3174, -9.8408, -9.5032],\n", 1411 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 1412 | " -9.6053, -9.3741, -9.4720],\n", 1413 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 1414 | " -9.9353, -9.3172, -9.8445]], grad_fn=),\n", 1415 | " 'values_old': tensor([[ 0.1939, -0.0731, -0.0170, -0.4315, 0.0534, -0.2046, -0.6074, -0.7700,\n", 1416 | " -1.2505, 0.1553],\n", 1417 | " [ 0.0511, -0.2098, -0.8512, -0.1117, 0.2560, -0.0967, -0.9718, 0.2660,\n", 1418 | " -0.1777, 0.4735],\n", 1419 | " [ 0.2042, -0.6096, -0.0284, 0.2577, -0.3757, -0.3134, -0.5433, -0.2487,\n", 1420 | " -0.2369, 1.0747]], grad_fn=),\n", 1421 | " 'values': None,\n", 1422 | " 'rewards': tensor([[-0.7784],\n", 1423 | " [-0.9515],\n", 1424 | " [-0.9003]], grad_fn=),\n", 1425 | " 'rewards_kl': tensor([[-0.0130, 0.0095, -0.0262, -0.0021, -0.0283, -0.0148, -0.0134, -0.0258,\n", 1426 | " 0.0089, -0.8090],\n", 1427 | " [-0.0315, -0.0049, -0.0047, -0.0323, 0.0020, -0.0178, 0.0170, -0.0316,\n", 1428 | " -0.0339, -0.9884],\n", 1429 | " [-0.0574, 0.0419, -0.0651, -0.0085, -0.0412, -0.0019, -0.0238, 0.0211,\n", 1430 | " -0.0333, -0.8852]], grad_fn=),\n", 1431 | " 'loss': None,\n", 1432 | " 'logits': None}" 1433 | ] 1434 | }, 1435 | "execution_count": 38, 1436 | "metadata": {}, 1437 | "output_type": "execute_result" 1438 | } 1439 | ], 1440 | "source": [ 1441 | "ppo_old_batchs['logprobs_ref'] = logprobs_ref\n", 1442 | "ppo_old_batchs['logprobs_old'] = logprobs_old\n", 1443 | "ppo_old_batchs['logprobs'] = logprobs\n", 1444 | "ppo_old_batchs['values_old'] = values\n", 1445 | "ppo_old_batchs['rewards'] = rewards\n", 1446 | "ppo_old_batchs['rewards_kl'] = kl_reward\n", 1447 | "\n", 1448 | "ppo_old_batchs" 1449 | ] 1450 | }, 1451 | { 1452 | "cell_type": "markdown", 1453 | "metadata": {}, 1454 | "source": [ 1455 | "## 计算loss" 1456 | ] 1457 | }, 1458 | { 1459 | "cell_type": "markdown", 1460 | "metadata": {}, 1461 | "source": [ 1462 | "rewards:一个张量,代表在每个时间步获得的奖励。\n", 1463 | "\n", 1464 | "mask:一个掩码张量,用于标识哪些时间步是有效的(例如,用于处理终止状态)。\n", 1465 | "\n", 1466 | "values:一个张量,代表每个时间步的状态价值估计。\n", 1467 | "\n", 1468 | "gamma:折扣因子,用于计算未来奖励的折现值,取值范围通常在 [0, 1] 之间。\n", 1469 | "\n", 1470 | "lam:GAE 中的 \\(\\lambda\\) 参数,用于平衡偏差和方差,取值范围同样在 [0, 1] 之间。" 1471 | ] 1472 | }, 1473 | { 1474 | "cell_type": "markdown", 1475 | "metadata": {}, 1476 | "source": [] 1477 | }, 1478 | { 1479 | "cell_type": "markdown", 1480 | "metadata": {}, 1481 | "source": [ 1482 | "# PPO 中的 GAE 公式\n", 1483 | "\n", 1484 | "在PPO(Proximal Policy Optimization)算法中,优势函数和价值损失是连接价值估计与策略优化的核心组件。\n", 1485 | "\n", 1486 | "## 优势函数(Advantage Function)\n", 1487 | "\n", 1488 | "优势函数衡量在某一状态下采取特定动作的**相对价值**,定义为:\n", 1489 | "\n", 1490 | "$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$\n", 1491 | "\n", 1492 | "状态 - 动作价值函数(Q 函数),表示在状态 \\(s_t\\) 采取动作 \\(a_t\\) 后,从后续轨迹中获得的总折扣回报的期望。\n", 1493 | "\n", 1494 | "状态价值函数(V 函数),表示在状态 \\(s_t\\) 下,遵循当前策略时获得的总折扣回报的期望(即 “平均收益”)。\n", 1495 | "\n", 1496 | "优势函数的本质是回答:\n", 1497 | "\n", 1498 | "在状态 \\(s_t\\) 下选择动作 \\(a_t\\),比‘按当前策略随机选一个动作’好多少?”\n", 1499 | "\n", 1500 | "若 \\(A(s_t, a_t) > 0\\):动作 \\(a_t\\) 优于平均水平,值得鼓励(策略应提高该动作的概率)\n", 1501 | "\n", 1502 | "若 \\(A(s_t, a_t) < 0\\):动作 \\(a_t\\) 劣于平均水平,应抑制(策略应降低该动作的概率)。\n", 1503 | "\n", 1504 | "优势函数将 “绝对价值” 转化为 “相对价值”,减少了估计偏差(例如,即使 \\(Q(s_t, a_t)\\) 和 \\(V(s_t)\\) 都有误差,两者的差值可能更稳定)\n", 1505 | "\n", 1506 | "在实际训练中,Q 和 V 无法直接获得,PPO 通常使用GAE(Generalized Advantage Estimation) 来估计优势函数\n", 1507 | "\n", 1508 | "GAE(Generalized Advantage Estimation)的时序差分残差公式:\n", 1509 | "\n", 1510 | "$$\\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t)$$\n", 1511 | "\n", 1512 | "其中,$r_t$ 是时间步 $t$ 的奖励,$\\gamma$ 是折扣因子,$V(s_t)$ 是状态 $s_t$ 的价值估计。\n", 1513 | "\n", 1514 | "GAE 优势估计的递归形式:\n", 1515 | "\n", 1516 | "$$\\hat{A}_t = \\delta_t + \\gamma \\lambda \\hat{A}_{t+1}$$\n", 1517 | "\n", 1518 | "其中 $\\lambda$ 是 GAE 的衰减参数($0 \\leq \\lambda \\leq 1$)。" 1519 | ] 1520 | }, 1521 | { 1522 | "cell_type": "code", 1523 | "execution_count": 39, 1524 | "metadata": {}, 1525 | "outputs": [], 1526 | "source": [ 1527 | "def get_GAE(rewards, attention_mask, values, gemma, lam):\n", 1528 | " lastgae = 0 #初始化为 0,用于存储上一个时间步的广义优势估计值。\n", 1529 | " advantages_recersed = []\n", 1530 | " response_len = rewards.shape[-1]\n", 1531 | "\n", 1532 | " values = values * attention_mask\n", 1533 | " rewards = rewards * attention_mask\n", 1534 | "\n", 1535 | " for t in reversed(range(response_len)):\n", 1536 | " nextvalues = values[:, t + 1] if t < response_len - 1 else 0.0\n", 1537 | " # 计算时间步 t 的 TD 误差(Temporal Difference error),即当前奖励加上折扣后的下一个时间步的价值估计,再减去当前时间步的价值估计。\n", 1538 | " delta = rewards[:, t] + gemma * nextvalues - values[:, t]\n", 1539 | " # 根据 GAE 的递推公式,计算当前时间步的广义优势估计值。\n", 1540 | " lastgae = delta + gemma * lam * lastgae\n", 1541 | " advantages_recersed.append(lastgae)\n", 1542 | " # 将 advantages_reversed 列表反转,使其按时间步的正序排列。\n", 1543 | " advantages = torch.stack(advantages_recersed[::-1]).transpose(0, 1)\n", 1544 | " return advantages\n" 1545 | ] 1546 | }, 1547 | { 1548 | "cell_type": "code", 1549 | "execution_count": 40, 1550 | "metadata": {}, 1551 | "outputs": [ 1552 | { 1553 | "data": { 1554 | "text/plain": [ 1555 | "{'prompt': tensor([[5, 0, 0, 1, 0],\n", 1556 | " [4, 8, 1, 4, 1],\n", 1557 | " [9, 6, 7, 0, 5]]),\n", 1558 | " 'response': tensor([[4, 8, 5, 2, 9, 5, 5, 0, 6, 3],\n", 1559 | " [0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 1560 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8]]),\n", 1561 | " 'mask': tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 1562 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 1563 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]),\n", 1564 | " 'logprobs_ref': tensor([[ -9.7659, -9.9431, -9.7075, -9.8018, -9.6310, -9.6916, -9.7483,\n", 1565 | " -9.5755, -9.7520, -9.8097],\n", 1566 | " [ -9.9691, -9.7657, -9.7810, -9.7806, -9.8304, -9.9382, -9.6816,\n", 1567 | " -9.9212, -9.7132, -9.8413],\n", 1568 | " [-10.4189, -9.7863, -10.1431, -9.8084, -9.5995, -9.5113, -9.8666,\n", 1569 | " -9.7238, -9.6501, -9.6926]], grad_fn=),\n", 1570 | " 'logprobs_old': tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 1571 | " -9.3174, -9.8408, -9.5032],\n", 1572 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 1573 | " -9.6053, -9.3741, -9.4720],\n", 1574 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 1575 | " -9.9353, -9.3172, -9.8445]], grad_fn=),\n", 1576 | " 'logprobs': tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 1577 | " -9.3174, -9.8408, -9.5032],\n", 1578 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 1579 | " -9.6053, -9.3741, -9.4720],\n", 1580 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 1581 | " -9.9353, -9.3172, -9.8445]], grad_fn=),\n", 1582 | " 'values_old': tensor([[ 0.1939, -0.0731, -0.0170, -0.4315, 0.0534, -0.2046, -0.6074, -0.7700,\n", 1583 | " -1.2505, 0.1553],\n", 1584 | " [ 0.0511, -0.2098, -0.8512, -0.1117, 0.2560, -0.0967, -0.9718, 0.2660,\n", 1585 | " -0.1777, 0.4735],\n", 1586 | " [ 0.2042, -0.6096, -0.0284, 0.2577, -0.3757, -0.3134, -0.5433, -0.2487,\n", 1587 | " -0.2369, 1.0747]], grad_fn=),\n", 1588 | " 'values': None,\n", 1589 | " 'rewards': tensor([[-0.7784],\n", 1590 | " [-0.9515],\n", 1591 | " [-0.9003]], grad_fn=),\n", 1592 | " 'rewards_kl': tensor([[-0.0130, 0.0095, -0.0262, -0.0021, -0.0283, -0.0148, -0.0134, -0.0258,\n", 1593 | " 0.0089, -0.8090],\n", 1594 | " [-0.0315, -0.0049, -0.0047, -0.0323, 0.0020, -0.0178, 0.0170, -0.0316,\n", 1595 | " -0.0339, -0.9884],\n", 1596 | " [-0.0574, 0.0419, -0.0651, -0.0085, -0.0412, -0.0019, -0.0238, 0.0211,\n", 1597 | " -0.0333, -0.8852]], grad_fn=),\n", 1598 | " 'loss': None,\n", 1599 | " 'logits': None}" 1600 | ] 1601 | }, 1602 | "execution_count": 40, 1603 | "metadata": {}, 1604 | "output_type": "execute_result" 1605 | } 1606 | ], 1607 | "source": [ 1608 | "ppo_old_batchs" 1609 | ] 1610 | }, 1611 | { 1612 | "cell_type": "code", 1613 | "execution_count": 42, 1614 | "metadata": {}, 1615 | "outputs": [ 1616 | { 1617 | "data": { 1618 | "text/plain": [ 1619 | "tensor([[-0.2043, -0.2523, -0.3115, -0.3845, -0.4747, -0.3587, -0.0023, 0.1193,\n", 1620 | " 0.6180, -0.9643],\n", 1621 | " [-0.1865, -0.2303, -0.2843, -0.3509, -0.4333, -0.4275, 0.4546, -0.9550,\n", 1622 | " -0.6142, -1.4619],\n", 1623 | " [-0.1640, -0.2025, -0.2500, -0.3087, -0.3811, -0.1223, 0.0682, -0.2809,\n", 1624 | " -0.4166, -1.9599]], grad_fn=)" 1625 | ] 1626 | }, 1627 | "execution_count": 42, 1628 | "metadata": {}, 1629 | "output_type": "execute_result" 1630 | } 1631 | ], 1632 | "source": [ 1633 | "gae = get_GAE(ppo_old_batchs['rewards_kl'], ppo_old_batchs['mask'], ppo_old_batchs['values_old'], ppo_config.gamma, ppo_config.lam)\n", 1634 | "gae\n" 1635 | ] 1636 | }, 1637 | { 1638 | "cell_type": "markdown", 1639 | "metadata": {}, 1640 | "source": [ 1641 | "计算value loss\n" 1642 | ] 1643 | }, 1644 | { 1645 | "cell_type": "markdown", 1646 | "metadata": {}, 1647 | "source": [ 1648 | "advantages:优势函数的估计值,用于计算回报。\n", 1649 | "\n", 1650 | "\n", 1651 | "values:当前价值函数的估计值。\n", 1652 | "\n", 1653 | "values_old:旧的价值函数估计值。\n", 1654 | "\n", 1655 | "mask:掩码张量,用于指定哪些元素参与损失计算。\n", 1656 | "\n", 1657 | "cliprange_value:裁剪范围,用于限制价值函数的更新幅度。" 1658 | ] 1659 | }, 1660 | { 1661 | "cell_type": "markdown", 1662 | "metadata": {}, 1663 | "source": [ 1664 | "https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L561C29-L567C30" 1665 | ] 1666 | }, 1667 | { 1668 | "cell_type": "code", 1669 | "execution_count": 44, 1670 | "metadata": {}, 1671 | "outputs": [], 1672 | "source": [ 1673 | "def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis = None) -> torch.Tensor:\n", 1674 | " \"\"\"Compute mean of tensor with a masked values.\"\"\"\n", 1675 | " if axis is not None:\n", 1676 | " return (values * mask).sum(axis=axis) / mask.sum(axis=axis)\n", 1677 | " else:\n", 1678 | " return (values * mask).sum() / mask.sum()\n", 1679 | "\n", 1680 | "def get_value_loss(advantages, values, values_old, attention_mask, cliprange_value):\n", 1681 | " # 目标回报 = 旧价值估计 + 优势估计\n", 1682 | " # 这是因为优势函数的定义为:A = Q - V,因此 Q = V + A,这里用returns表示目标 Q 值\n", 1683 | " returns = values_old + advantages\n", 1684 | " advantages = advantages.detach()\n", 1685 | " # 对新的价值估计values进行裁剪,限制其与旧价值估计values_old的差异不超过cliprange_value\n", 1686 | " vpredclipped = torch.clamp(values, values_old - cliprange_value, values_old + cliprange_value)\n", 1687 | "\n", 1688 | " vf_losses1 = torch.square(vpredclipped - returns) # 裁剪后的价值估计与目标回报的平方误差\n", 1689 | " vf_losses2 = torch.square(values - returns) # 未裁剪的价值估计与目标回报的平方误差\n", 1690 | " vf_loss_max = torch.max(vf_losses1, vf_losses2)\n", 1691 | " vf_loss = 0.5 * masked_mean(vf_loss_max, attention_mask)\n", 1692 | " return vf_loss\n", 1693 | "\n" 1694 | ] 1695 | }, 1696 | { 1697 | "cell_type": "code", 1698 | "execution_count": 45, 1699 | "metadata": {}, 1700 | "outputs": [], 1701 | "source": [ 1702 | "ppo_old_batchs['values'] = ppo_old_batchs['values_old'] + 0.5" 1703 | ] 1704 | }, 1705 | { 1706 | "cell_type": "code", 1707 | "execution_count": 46, 1708 | "metadata": {}, 1709 | "outputs": [ 1710 | { 1711 | "data": { 1712 | "text/plain": [ 1713 | "tensor(0.6554, grad_fn=)" 1714 | ] 1715 | }, 1716 | "execution_count": 46, 1717 | "metadata": {}, 1718 | "output_type": "execute_result" 1719 | } 1720 | ], 1721 | "source": [ 1722 | "value_loss = get_value_loss(gae, ppo_old_batchs['values'], ppo_old_batchs['values_old'], ppo_old_batchs['mask'], ppo_config.cliprange_value)\n", 1723 | "value_loss" 1724 | ] 1725 | }, 1726 | { 1727 | "cell_type": "markdown", 1728 | "metadata": {}, 1729 | "source": [ 1730 | "计算policy loss\n", 1731 | "https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L569-L574" 1732 | ] 1733 | }, 1734 | { 1735 | "cell_type": "markdown", 1736 | "metadata": {}, 1737 | "source": [ 1738 | "markdown\n", 1739 | "# PPO(Proximal Policy Optimization)核心公式与实现\n", 1740 | "\n", 1741 | "PPO算法的核心是通过策略损失和价值损失的联合优化来更新智能体策略,以下是完整的公式说明与代码实现。\n", 1742 | "\n", 1743 | "## 1. 策略损失(Policy Loss)\n", 1744 | "\n", 1745 | "### 核心公式\n", 1746 | "\n", 1747 | "策略损失的计算基于重要性采样和裁剪机制:\n", 1748 | "\n", 1749 | "1. **重要性采样比率** \n", 1750 | " $$\\text{ratio}_t = \\frac{\\pi_\\theta(a_t | s_t)}{\\pi_{\\theta_{\\text{old}}}(a_t | s_t)} = \\exp\\left(\\log \\pi_\\theta(a_t | s_t) - \\log \\pi_{\\theta_{\\text{old}}}(a_t | s_t)\\right)$$\n", 1751 | "\n", 1752 | "2. **未裁剪损失** \n", 1753 | " $$L_1(\\theta) = -A_t \\cdot \\text{ratio}_t$$\n", 1754 | "\n", 1755 | "3. **裁剪后损失** \n", 1756 | " $$L_2(\\theta) = -A_t \\cdot \\text{clip}(\\text{ratio}_t, 1-\\epsilon, 1+\\epsilon)$$\n", 1757 | "\n", 1758 | "4. **最终策略损失** \n", 1759 | " $$L_{\\text{policy}}(\\theta) = \\mathbb{E}\\left[ \\max(L_1(\\theta), L_2(\\theta)) \\right]$$\n", 1760 | "\n", 1761 | "其中:\n", 1762 | "- $A_t$ 是优势估计(GAE计算结果)\n", 1763 | "- $\\epsilon$ 是裁剪范围超参数(通常为0.2)\n", 1764 | "- $\\pi_\\theta$ 是当前策略,$\\pi_{\\theta_{\\text{old}}}$ 是更新前的旧策略\n" 1765 | ] 1766 | }, 1767 | { 1768 | "cell_type": "code", 1769 | "execution_count": 47, 1770 | "metadata": {}, 1771 | "outputs": [], 1772 | "source": [ 1773 | "def get_policy_loss(advantages, logprobs, logprobs_old, mask, cliprange):\n", 1774 | " # 重要性采样\n", 1775 | " ratio = torch.exp(logprobs - logprobs_old)\n", 1776 | " # 计算策略损失\n", 1777 | " pg_losses = -advantages * ratio\n", 1778 | " pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)\n", 1779 | " pg_loss_max = torch.max(pg_losses, pg_losses2)\n", 1780 | " pg_loss = masked_mean(pg_loss_max, mask)\n", 1781 | " return pg_loss\n", 1782 | "\n" 1783 | ] 1784 | }, 1785 | { 1786 | "cell_type": "code", 1787 | "execution_count": 48, 1788 | "metadata": {}, 1789 | "outputs": [], 1790 | "source": [ 1791 | "pg_loss = get_policy_loss(gae, ppo_old_batchs['logprobs'], ppo_old_batchs['logprobs_old'], ppo_old_batchs['mask'], ppo_config.cliprange_value)" 1792 | ] 1793 | }, 1794 | { 1795 | "cell_type": "code", 1796 | "execution_count": 49, 1797 | "metadata": {}, 1798 | "outputs": [ 1799 | { 1800 | "data": { 1801 | "text/plain": [ 1802 | "tensor(0.4202, grad_fn=)" 1803 | ] 1804 | }, 1805 | "execution_count": 49, 1806 | "metadata": {}, 1807 | "output_type": "execute_result" 1808 | } 1809 | ], 1810 | "source": [ 1811 | "pg_loss" 1812 | ] 1813 | }, 1814 | { 1815 | "cell_type": "markdown", 1816 | "metadata": {}, 1817 | "source": [ 1818 | "计算熵损失\n", 1819 | "https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L582-L583" 1820 | ] 1821 | }, 1822 | { 1823 | "cell_type": "markdown", 1824 | "metadata": {}, 1825 | "source": [ 1826 | "entropy(熵)没有直接参与到模型的损失(loss)\n", 1827 | "\n", 1828 | "在计算完损失并进行反向传播和参数更新后,代码计算了 entropy\n", 1829 | "\n", 1830 | "这里计算的 entropy 被记录到 entropy_stats 张量中,用于后续的统计和记录,但没有用于损失计算。" 1831 | ] 1832 | }, 1833 | { 1834 | "cell_type": "code", 1835 | "execution_count": 50, 1836 | "metadata": {}, 1837 | "outputs": [ 1838 | { 1839 | "name": "stdout", 1840 | "output_type": "stream", 1841 | "text": [ 1842 | "inputs_ids shape: torch.Size([3, 10])\n", 1843 | "logits shape: torch.Size([3, 10, 50257])\n" 1844 | ] 1845 | } 1846 | ], 1847 | "source": [ 1848 | "logits = get_logits(models.actor, ppo_old_batchs['response'])\n", 1849 | "ppo_old_batchs['logits'] = logits" 1850 | ] 1851 | }, 1852 | { 1853 | "cell_type": "markdown", 1854 | "metadata": {}, 1855 | "source": [ 1856 | "# PPO中的熵损失(Entropy Loss)计算\n", 1857 | "\n", 1858 | "熵损失用于衡量策略的随机性,在PPO中通常作为总损失的一部分,鼓励智能体保持探索行为。\n", 1859 | "\n", 1860 | "## 熵计算函数\n", 1861 | "\n", 1862 | "```python\n", 1863 | "def get_entropy_loss(logits, mask):\n", 1864 | " # 将logits转换为概率分布(softmax归一化)\n", 1865 | " prob_dist = torch.nn.functional.softmax(logits, dim=-1)\n", 1866 | " \n", 1867 | " # 计算熵: H(p) = -Σ(p_i * log(p_i))\n", 1868 | " # 等价于: log(Σ(exp(logits_i))) - Σ(p_i * logits_i)\n", 1869 | " entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)\n", 1870 | " \n", 1871 | " return entropy\n", 1872 | "\n", 1873 | "# 计算旧批次数据的熵\n", 1874 | "entropy = get_entropy_loss(ppo_old_batchs['logits'], ppo_old_batchs['mask'])\n", 1875 | "entropy # 返回每个样本的熵值" 1876 | ] 1877 | }, 1878 | { 1879 | "cell_type": "code", 1880 | "execution_count": 52, 1881 | "metadata": {}, 1882 | "outputs": [ 1883 | { 1884 | "name": "stdout", 1885 | "output_type": "stream", 1886 | "text": [ 1887 | "logits shape: torch.Size([3, 10, 50257]), mask shape: torch.Size([3, 10])\n", 1888 | "prob_dist shape: torch.Size([3, 10, 50257]), logits shape: torch.Size([3, 10, 50257])\n" 1889 | ] 1890 | }, 1891 | { 1892 | "data": { 1893 | "text/plain": [ 1894 | "tensor([[10.7993, 10.7995, 10.7994, 10.7994, 10.7990, 10.7992, 10.7994, 10.7994,\n", 1895 | " 10.7997, 10.7995],\n", 1896 | " [10.7995, 10.7994, 10.7994, 10.7996, 10.7995, 10.7992, 10.7994, 10.7995,\n", 1897 | " 10.7993, 10.7996],\n", 1898 | " [10.7992, 10.7996, 10.7994, 10.7993, 10.7995, 10.7993, 10.7994, 10.7994,\n", 1899 | " 10.7996, 10.7994]], grad_fn=)" 1900 | ] 1901 | }, 1902 | "execution_count": 52, 1903 | "metadata": {}, 1904 | "output_type": "execute_result" 1905 | } 1906 | ], 1907 | "source": [ 1908 | "def get_entropy_loss(logits, mask):\n", 1909 | " prob_dist = torch.nn.functional.softmax(logits, dim=-1)\n", 1910 | " print(f\"prob_dist shape: {prob_dist.shape}, logits shape: {logits.shape}\")\n", 1911 | " # 计算熵\n", 1912 | " # 使用torch.logsumexp计算logits的对数和,然后减去每个概率分布乘以logits的和\n", 1913 | " # 这里的熵计算公式是 H(X) = log(sum(exp(logits))) - sum(prob_dist * logits)\n", 1914 | " \n", 1915 | " entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)\n", 1916 | " return entropy\n", 1917 | "print(f\"logits shape: {logits.shape}, mask shape: {ppo_old_batchs['mask'].shape}\")\n", 1918 | "entropy = get_entropy_loss(ppo_old_batchs['logits'], ppo_old_batchs['mask'])\n", 1919 | "entropy\n", 1920 | " " 1921 | ] 1922 | }, 1923 | { 1924 | "cell_type": "code", 1925 | "execution_count": 53, 1926 | "metadata": {}, 1927 | "outputs": [], 1928 | "source": [ 1929 | "loss = pg_loss + ppo_config.vf_coef * value_loss" 1930 | ] 1931 | }, 1932 | { 1933 | "cell_type": "code", 1934 | "execution_count": 54, 1935 | "metadata": {}, 1936 | "outputs": [], 1937 | "source": [ 1938 | "def get_loss(batchs, ppo_config):\n", 1939 | " gae = get_GAE(batchs['rewards_kl'],\n", 1940 | " batchs['mask'],\n", 1941 | " batchs['values'],\n", 1942 | " ppo_config.gamma,\n", 1943 | " ppo_config.lam)\n", 1944 | " value_loss = get_value_loss(gae,\n", 1945 | " batchs['values'],\n", 1946 | " batchs['values_old'],\n", 1947 | " batchs['mask'],\n", 1948 | " ppo_config.cliprange_value)\n", 1949 | " pg_loss = get_policy_loss(\n", 1950 | " gae,\n", 1951 | " batchs['logprobs'],\n", 1952 | " batchs['logprobs_old'],\n", 1953 | " batchs['mask'],\n", 1954 | " ppo_config.cliprange_value)\n", 1955 | " entropy = get_entropy_loss(batchs['logits'], batchs['mask'])\n", 1956 | " loss = pg_loss + ppo_config.vf_coef * value_loss\n", 1957 | " return loss" 1958 | ] 1959 | }, 1960 | { 1961 | "cell_type": "code", 1962 | "execution_count": 55, 1963 | "metadata": {}, 1964 | "outputs": [ 1965 | { 1966 | "name": "stdout", 1967 | "output_type": "stream", 1968 | "text": [ 1969 | "prob_dist shape: torch.Size([3, 10, 50257]), logits shape: torch.Size([3, 10, 50257])\n" 1970 | ] 1971 | }, 1972 | { 1973 | "data": { 1974 | "text/plain": [ 1975 | "tensor(0.9609, grad_fn=)" 1976 | ] 1977 | }, 1978 | "execution_count": 55, 1979 | "metadata": {}, 1980 | "output_type": "execute_result" 1981 | } 1982 | ], 1983 | "source": [ 1984 | "loss = get_loss(ppo_old_batchs, ppo_config)\n", 1985 | "loss" 1986 | ] 1987 | }, 1988 | { 1989 | "cell_type": "code", 1990 | "execution_count": 56, 1991 | "metadata": {}, 1992 | "outputs": [ 1993 | { 1994 | "data": { 1995 | "text/plain": [ 1996 | "{'prompt': tensor([[5, 0, 0, 1, 0],\n", 1997 | " [4, 8, 1, 4, 1],\n", 1998 | " [9, 6, 7, 0, 5]]),\n", 1999 | " 'response': tensor([[4, 8, 5, 2, 9, 5, 5, 0, 6, 3],\n", 2000 | " [0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 2001 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8]]),\n", 2002 | " 'mask': tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 2003 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 2004 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]),\n", 2005 | " 'logprobs_ref': tensor([[ -9.7659, -9.9431, -9.7075, -9.8018, -9.6310, -9.6916, -9.7483,\n", 2006 | " -9.5755, -9.7520, -9.8097],\n", 2007 | " [ -9.9691, -9.7657, -9.7810, -9.7806, -9.8304, -9.9382, -9.6816,\n", 2008 | " -9.9212, -9.7132, -9.8413],\n", 2009 | " [-10.4189, -9.7863, -10.1431, -9.8084, -9.5995, -9.5113, -9.8666,\n", 2010 | " -9.7238, -9.6501, -9.6926]], grad_fn=),\n", 2011 | " 'logprobs_old': tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 2012 | " -9.3174, -9.8408, -9.5032],\n", 2013 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 2014 | " -9.6053, -9.3741, -9.4720],\n", 2015 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 2016 | " -9.9353, -9.3172, -9.8445]], grad_fn=),\n", 2017 | " 'logprobs': tensor([[ -9.6364, -10.0382, -9.4454, -9.7810, -9.3484, -9.5437, -9.6146,\n", 2018 | " -9.3174, -9.8408, -9.5032],\n", 2019 | " [ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 2020 | " -9.6053, -9.3741, -9.4720],\n", 2021 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 2022 | " -9.9353, -9.3172, -9.8445]], grad_fn=),\n", 2023 | " 'values_old': tensor([[ 0.1939, -0.0731, -0.0170, -0.4315, 0.0534, -0.2046, -0.6074, -0.7700,\n", 2024 | " -1.2505, 0.1553],\n", 2025 | " [ 0.0511, -0.2098, -0.8512, -0.1117, 0.2560, -0.0967, -0.9718, 0.2660,\n", 2026 | " -0.1777, 0.4735],\n", 2027 | " [ 0.2042, -0.6096, -0.0284, 0.2577, -0.3757, -0.3134, -0.5433, -0.2487,\n", 2028 | " -0.2369, 1.0747]], grad_fn=),\n", 2029 | " 'values': tensor([[ 0.6939, 0.4269, 0.4830, 0.0685, 0.5534, 0.2954, -0.1074, -0.2700,\n", 2030 | " -0.7505, 0.6553],\n", 2031 | " [ 0.5511, 0.2902, -0.3512, 0.3883, 0.7560, 0.4033, -0.4718, 0.7660,\n", 2032 | " 0.3223, 0.9735],\n", 2033 | " [ 0.7042, -0.1096, 0.4716, 0.7577, 0.1243, 0.1866, -0.0433, 0.2513,\n", 2034 | " 0.2631, 1.5747]], grad_fn=),\n", 2035 | " 'rewards': tensor([[-0.7784],\n", 2036 | " [-0.9515],\n", 2037 | " [-0.9003]], grad_fn=),\n", 2038 | " 'rewards_kl': tensor([[-0.0130, 0.0095, -0.0262, -0.0021, -0.0283, -0.0148, -0.0134, -0.0258,\n", 2039 | " 0.0089, -0.8090],\n", 2040 | " [-0.0315, -0.0049, -0.0047, -0.0323, 0.0020, -0.0178, 0.0170, -0.0316,\n", 2041 | " -0.0339, -0.9884],\n", 2042 | " [-0.0574, 0.0419, -0.0651, -0.0085, -0.0412, -0.0019, -0.0238, 0.0211,\n", 2043 | " -0.0333, -0.8852]], grad_fn=),\n", 2044 | " 'loss': None,\n", 2045 | " 'logits': tensor([[[-1.4843e-01, -3.8199e-01, 1.5566e-01, ..., 6.0343e-01,\n", 2046 | " -3.5546e-01, -2.5944e-01],\n", 2047 | " [-1.6893e-01, -2.5384e-03, -8.4530e-03, ..., 8.7142e-02,\n", 2048 | " -2.0942e-01, -8.3370e-02],\n", 2049 | " [-4.3086e-01, 5.5402e-02, -4.6384e-01, ..., 9.1063e-02,\n", 2050 | " -8.1510e-02, 1.6532e-01],\n", 2051 | " ...,\n", 2052 | " [ 1.5315e+00, -8.8365e-02, -1.9262e-01, ..., -2.3480e-01,\n", 2053 | " 7.7313e-02, -1.3036e-02],\n", 2054 | " [-9.2542e-02, -2.2912e-01, 8.3747e-02, ..., 2.8154e-03,\n", 2055 | " -1.3022e-01, 6.1364e-02],\n", 2056 | " [-4.0653e-01, -2.8789e-02, -1.5729e-01, ..., 2.5900e-01,\n", 2057 | " -3.2773e-01, -1.3417e-01]],\n", 2058 | " \n", 2059 | " [[ 1.1939e+00, -3.4385e-01, 1.8697e-01, ..., 8.9561e-02,\n", 2060 | " -1.3423e-01, -5.1387e-05],\n", 2061 | " [ 1.3593e-01, -2.1616e-01, 1.7281e-01, ..., 5.4955e-02,\n", 2062 | " -2.8100e-01, -9.6232e-02],\n", 2063 | " [ 1.1163e+00, -4.0199e-01, -5.8994e-02, ..., -4.4124e-02,\n", 2064 | " 8.6503e-02, -4.1281e-02],\n", 2065 | " ...,\n", 2066 | " [ 7.9734e-02, -4.3286e-01, 1.4872e-01, ..., -5.1665e-03,\n", 2067 | " -7.4853e-02, -2.7805e-02],\n", 2068 | " [ 3.4729e-01, -2.8876e-01, 3.5831e-02, ..., 1.3297e-01,\n", 2069 | " -8.0469e-03, 5.7139e-02],\n", 2070 | " [-3.4550e-01, -1.6689e-01, -1.2459e-01, ..., 2.8532e-01,\n", 2071 | " -3.9113e-01, -1.1683e-01]],\n", 2072 | " \n", 2073 | " [[ 6.7189e-03, -4.6148e-02, 1.0041e+00, ..., 5.6802e-01,\n", 2074 | " -1.4841e-01, -1.4218e-01],\n", 2075 | " [-8.0866e-02, -2.3968e-01, 1.6320e-01, ..., 6.1787e-02,\n", 2076 | " 1.6179e-02, 2.5040e-01],\n", 2077 | " [-3.4248e-01, -1.3313e-01, -4.3621e-01, ..., 3.2381e-01,\n", 2078 | " 1.3221e-02, 5.6685e-02],\n", 2079 | " ...,\n", 2080 | " [ 3.4316e-01, -8.4548e-04, -3.4696e-01, ..., -6.7568e-02,\n", 2081 | " -1.2948e-01, -1.6340e-01],\n", 2082 | " [-3.2091e-02, -6.8572e-01, 2.5836e-01, ..., 2.4276e-01,\n", 2083 | " -1.0186e-01, -1.8865e-01],\n", 2084 | " [-4.6698e-01, -2.5016e-01, -1.1452e-01, ..., 6.8086e-02,\n", 2085 | " -3.2970e-01, -7.7348e-02]]], grad_fn=)}" 2086 | ] 2087 | }, 2088 | "execution_count": 56, 2089 | "metadata": {}, 2090 | "output_type": "execute_result" 2091 | } 2092 | ], 2093 | "source": [ 2094 | "ppo_old_batchs" 2095 | ] 2096 | }, 2097 | { 2098 | "cell_type": "markdown", 2099 | "metadata": {}, 2100 | "source": [ 2101 | "## PPO训练\n", 2102 | "\n", 2103 | "https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/ppo_trainer.py#L529-L538" 2104 | ] 2105 | }, 2106 | { 2107 | "cell_type": "markdown", 2108 | "metadata": {}, 2109 | "source": [ 2110 | "将一个完整的批次数据 ppo_batchs 按照指定的 batch_size 和 mini_batch_size 划分成多个小批次数据" 2111 | ] 2112 | }, 2113 | { 2114 | "cell_type": "code", 2115 | "execution_count": 88, 2116 | "metadata": {}, 2117 | "outputs": [], 2118 | "source": [ 2119 | "import numpy as np\n", 2120 | "def get_minibatch(ppo_batchs, batch_size, mini_batch_size):\n", 2121 | " # 计算需要多少个小批次\n", 2122 | " step = batch_size // mini_batch_size\n", 2123 | " ppo_batchs_iter = []\n", 2124 | " \n", 2125 | " # 随机打乱索引以提高训练效果\n", 2126 | " b_inds = np.random.permutation(batch_size)\n", 2127 | " \n", 2128 | " # 根据索引创建小批次\n", 2129 | " for i in range(step):\n", 2130 | " start_idx = i * mini_batch_size\n", 2131 | " end_idx = start_idx + mini_batch_size\n", 2132 | " batch_inds = b_inds[start_idx:end_idx]\n", 2133 | " \n", 2134 | " # 创建当前小批次的数据\n", 2135 | " mini_batch = {}\n", 2136 | " for key, value in ppo_batchs.items():\n", 2137 | " if value is not None and isinstance(value, torch.Tensor) and value.size(0) == batch_size:\n", 2138 | " mini_batch[key] = value[batch_inds]\n", 2139 | " else:\n", 2140 | " mini_batch[key] = value\n", 2141 | " \n", 2142 | " ppo_batchs_iter.append(mini_batch)\n", 2143 | " \n", 2144 | " return ppo_batchs_iter" 2145 | ] 2146 | }, 2147 | { 2148 | "cell_type": "code", 2149 | "execution_count": 74, 2150 | "metadata": {}, 2151 | "outputs": [], 2152 | "source": [ 2153 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)" 2154 | ] 2155 | }, 2156 | { 2157 | "cell_type": "code", 2158 | "execution_count": 75, 2159 | "metadata": {}, 2160 | "outputs": [ 2161 | { 2162 | "data": { 2163 | "text/plain": [ 2164 | "{'prompt': tensor([[5, 0, 0, 1, 0],\n", 2165 | " [4, 8, 1, 4, 1],\n", 2166 | " [9, 6, 7, 0, 5],\n", 2167 | " [4, 8, 5, 2, 9],\n", 2168 | " [5, 5, 0, 6, 3]]),\n", 2169 | " 'response': tensor([[0, 3, 0, 4, 8, 2, 6, 4, 9, 3],\n", 2170 | " [2, 6, 7, 5, 0, 0, 3, 3, 4, 8],\n", 2171 | " [0, 8, 8, 2, 6, 0, 6, 0, 5, 8],\n", 2172 | " [8, 1, 4, 6, 2, 7, 5, 5, 9, 5],\n", 2173 | " [7, 4, 9, 5, 6, 6, 6, 1, 9, 8]]),\n", 2174 | " 'mask': tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 2175 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 2176 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 2177 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n", 2178 | " [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]),\n", 2179 | " 'logprobs_ref': tensor([[ -9.7657, -9.5145, -9.7403, -9.4521, -9.8023, -9.8455, -9.8040,\n", 2180 | " -9.5040, -9.9263, -9.4373],\n", 2181 | " [-10.0543, -9.8124, -9.6533, -9.7472, -9.6888, -9.7347, -9.5207,\n", 2182 | " -9.2883, -9.4406, -9.7164],\n", 2183 | " [ -9.7657, -10.3167, -9.8208, -9.8356, -9.5770, -9.7337, -9.7759,\n", 2184 | " -9.6341, -9.4780, -9.8436],\n", 2185 | " [-10.3158, -9.5739, -9.6799, -9.8827, -10.0626, -9.6075, -9.7284,\n", 2186 | " -9.6707, -9.9424, -9.6236],\n", 2187 | " [ -9.8277, -9.9490, -9.4426, -9.7313, -9.5943, -9.7917, -9.6991,\n", 2188 | " -9.7685, -9.9496, -9.7640]], grad_fn=),\n", 2189 | " 'logprobs_old': tensor([[ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 2190 | " -9.6053, -9.3741, -9.4720],\n", 2191 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 2192 | " -9.9353, -9.3172, -9.8445],\n", 2193 | " [ -9.6546, -10.1063, -9.7419, -9.6142, -9.8585, -9.5115, -9.7855,\n", 2194 | " -9.2093, -9.4475, -9.7984],\n", 2195 | " [-10.0683, -9.7843, -9.6151, -9.7731, -9.4803, -9.1821, -9.4697,\n", 2196 | " -9.6959, -9.3579, -9.5344],\n", 2197 | " [ -9.6822, -9.6050, -9.5979, -9.7321, -10.0195, -10.2095, -9.9384,\n", 2198 | " -9.7428, -9.4144, -9.9008]], grad_fn=),\n", 2199 | " 'logprobs': tensor([[ -9.6546, -9.7166, -9.7343, -9.4578, -9.8507, -9.7604, -9.8515,\n", 2200 | " -9.6053, -9.3741, -9.4720],\n", 2201 | " [ -9.8447, -10.2057, -9.4921, -9.7237, -9.1873, -9.4923, -9.6284,\n", 2202 | " -9.9353, -9.3172, -9.8445],\n", 2203 | " [ -9.6546, -10.1063, -9.7419, -9.6142, -9.8585, -9.5115, -9.7855,\n", 2204 | " -9.2093, -9.4475, -9.7984],\n", 2205 | " [-10.0683, -9.7843, -9.6151, -9.7731, -9.4803, -9.1821, -9.4697,\n", 2206 | " -9.6959, -9.3579, -9.5344],\n", 2207 | " [ -9.6822, -9.6050, -9.5979, -9.7321, -10.0195, -10.2095, -9.9384,\n", 2208 | " -9.7428, -9.4144, -9.9008]], grad_fn=),\n", 2209 | " 'values_old': tensor([[ 1.2677, 0.5070, 0.9766, -0.4549, 0.5805, -0.4866, 0.5283, -0.2907,\n", 2210 | " 0.0779, -0.1667],\n", 2211 | " [ 0.3226, -0.0667, -0.7088, -0.4413, 0.6490, 0.8188, 1.3689, 0.6129,\n", 2212 | " 0.8584, -0.0860],\n", 2213 | " [ 1.2112, 0.0672, 0.4946, -0.7344, 0.5928, 0.8188, 1.0112, 0.7424,\n", 2214 | " 1.3459, -0.0567],\n", 2215 | " [ 0.5810, -0.2458, 0.0620, -0.9607, -0.0040, -1.0716, 0.5418, -0.1127,\n", 2216 | " -0.0043, -0.3484],\n", 2217 | " [-0.4887, -0.2443, -0.6051, -0.6362, 0.2427, -0.0520, 0.6208, 0.1293,\n", 2218 | " 0.1234, -0.2866]], grad_fn=),\n", 2219 | " 'values': tensor([[ 1.7677, 1.0070, 1.4766, 0.0451, 1.0805, 0.0134, 1.0283, 0.2093,\n", 2220 | " 0.5779, 0.3333],\n", 2221 | " [ 0.8226, 0.4333, -0.2088, 0.0587, 1.1490, 1.3188, 1.8689, 1.1129,\n", 2222 | " 1.3584, 0.4140],\n", 2223 | " [ 1.7112, 0.5672, 0.9946, -0.2344, 1.0928, 1.3188, 1.5112, 1.2424,\n", 2224 | " 1.8459, 0.4433],\n", 2225 | " [ 1.0810, 0.2542, 0.5620, -0.4607, 0.4960, -0.5716, 1.0418, 0.3873,\n", 2226 | " 0.4957, 0.1516],\n", 2227 | " [ 0.0113, 0.2557, -0.1051, -0.1362, 0.7427, 0.4480, 1.1208, 0.6293,\n", 2228 | " 0.6234, 0.2134]], grad_fn=),\n", 2229 | " 'rewards': tensor([[-0.9515],\n", 2230 | " [-0.9003],\n", 2231 | " [-1.3975],\n", 2232 | " [-1.6012],\n", 2233 | " [-1.6159]], grad_fn=),\n", 2234 | " 'rewards_kl': tensor([[-1.1109e-02, 2.0212e-02, -5.9538e-04, 5.7230e-04, 4.8371e-03,\n", 2235 | " -8.5035e-03, 4.7553e-03, 1.0133e-02, -5.5222e-02, -9.4801e-01],\n", 2236 | " [-2.0961e-02, 3.9329e-02, -1.6113e-02, -2.3515e-03, -5.0153e-02,\n", 2237 | " -2.4239e-02, 1.0773e-02, 6.4699e-02, -1.2334e-02, -8.8754e-01],\n", 2238 | " [-1.1109e-02, -2.1040e-02, -7.8938e-03, -2.2135e-02, 2.8151e-02,\n", 2239 | " -2.2220e-02, 9.5730e-04, -4.2487e-02, -3.0542e-03, -1.4020e+00],\n", 2240 | " [-2.4752e-02, 2.1033e-02, -6.4787e-03, -1.0964e-02, -5.8229e-02,\n", 2241 | " -4.2533e-02, -2.5875e-02, 2.5274e-03, -5.8451e-02, -1.6101e+00],\n", 2242 | " [-1.4547e-02, -3.4400e-02, 1.5532e-02, 8.0109e-05, 4.2521e-02,\n", 2243 | " 4.1776e-02, 2.3927e-02, -2.5682e-03, -5.3520e-02, -1.6023e+00]],\n", 2244 | " grad_fn=),\n", 2245 | " 'loss': None,\n", 2246 | " 'logits': tensor([[[ 1.1939e+00, -3.4385e-01, 1.8697e-01, ..., 8.9561e-02,\n", 2247 | " -1.3423e-01, -5.1387e-05],\n", 2248 | " [ 1.3593e-01, -2.1616e-01, 1.7281e-01, ..., 5.4955e-02,\n", 2249 | " -2.8100e-01, -9.6232e-02],\n", 2250 | " [ 1.1163e+00, -4.0199e-01, -5.8994e-02, ..., -4.4124e-02,\n", 2251 | " 8.6503e-02, -4.1281e-02],\n", 2252 | " ...,\n", 2253 | " [ 7.9734e-02, -4.3286e-01, 1.4872e-01, ..., -5.1665e-03,\n", 2254 | " -7.4853e-02, -2.7805e-02],\n", 2255 | " [ 3.4729e-01, -2.8876e-01, 3.5831e-02, ..., 1.3297e-01,\n", 2256 | " -8.0469e-03, 5.7139e-02],\n", 2257 | " [-3.4550e-01, -1.6689e-01, -1.2459e-01, ..., 2.8532e-01,\n", 2258 | " -3.9113e-01, -1.1683e-01]],\n", 2259 | " \n", 2260 | " [[ 6.7189e-03, -4.6148e-02, 1.0041e+00, ..., 5.6802e-01,\n", 2261 | " -1.4841e-01, -1.4218e-01],\n", 2262 | " [-8.0866e-02, -2.3968e-01, 1.6320e-01, ..., 6.1787e-02,\n", 2263 | " 1.6179e-02, 2.5040e-01],\n", 2264 | " [-3.4248e-01, -1.3313e-01, -4.3621e-01, ..., 3.2381e-01,\n", 2265 | " 1.3221e-02, 5.6685e-02],\n", 2266 | " ...,\n", 2267 | " [ 3.4316e-01, -8.4548e-04, -3.4696e-01, ..., -6.7568e-02,\n", 2268 | " -1.2948e-01, -1.6340e-01],\n", 2269 | " [-3.2091e-02, -6.8572e-01, 2.5836e-01, ..., 2.4276e-01,\n", 2270 | " -1.0186e-01, -1.8865e-01],\n", 2271 | " [-4.6698e-01, -2.5016e-01, -1.1452e-01, ..., 6.8086e-02,\n", 2272 | " -3.2970e-01, -7.7348e-02]],\n", 2273 | " \n", 2274 | " [[ 1.1939e+00, -3.4385e-01, 1.8697e-01, ..., 8.9561e-02,\n", 2275 | " -1.3423e-01, -5.1387e-05],\n", 2276 | " [-1.0593e-01, -1.3282e-01, 2.0533e-01, ..., -1.9474e-01,\n", 2277 | " -1.6972e-01, 4.7611e-02],\n", 2278 | " [-4.1880e-01, 1.8398e-02, -5.3639e-02, ..., -1.0487e-02,\n", 2279 | " -1.2665e-01, -7.0815e-02],\n", 2280 | " ...,\n", 2281 | " [ 1.6399e+00, -2.6469e-01, -8.5538e-02, ..., -2.8674e-01,\n", 2282 | " 5.6738e-02, 8.3134e-02],\n", 2283 | " [ 1.7255e-01, -3.7670e-01, -3.0233e-01, ..., -7.1360e-02,\n", 2284 | " -9.5127e-02, 4.1914e-01],\n", 2285 | " [-4.9126e-01, -2.2191e-01, -7.8555e-03, ..., -5.6117e-03,\n", 2286 | " -3.6520e-01, 9.7580e-03]],\n", 2287 | " \n", 2288 | " [[-1.4264e-01, -2.8157e-02, 2.0611e-01, ..., 3.9266e-01,\n", 2289 | " -3.9834e-01, -2.0778e-01],\n", 2290 | " [-2.1525e-01, 1.0653e+00, 2.1692e-01, ..., 1.1699e-01,\n", 2291 | " 5.6338e-02, -1.0115e-01],\n", 2292 | " [-5.6471e-01, -2.6728e-01, 5.1792e-02, ..., 2.3630e-01,\n", 2293 | " -8.6777e-02, -2.1680e-01],\n", 2294 | " ...,\n", 2295 | " [ 5.0904e-02, 6.5761e-02, -6.5508e-01, ..., -3.1484e-01,\n", 2296 | " 5.0776e-02, 3.6046e-01],\n", 2297 | " [ 2.1136e-01, -1.6706e-01, -4.8888e-02, ..., 1.3312e-01,\n", 2298 | " 2.5565e-03, -4.6409e-02],\n", 2299 | " [-5.2850e-01, -9.6140e-02, -4.2049e-01, ..., 4.4030e-03,\n", 2300 | " -1.7598e-01, 2.3337e-01]],\n", 2301 | " \n", 2302 | " [[-1.9936e-01, -1.6945e-01, -1.9695e-01, ..., 5.2535e-01,\n", 2303 | " -1.7846e-01, -2.9423e-01],\n", 2304 | " [-3.5077e-01, -4.7752e-01, 2.0070e-01, ..., 2.2220e-01,\n", 2305 | " -8.3356e-02, -2.5743e-01],\n", 2306 | " [-2.8892e-01, 3.0952e-02, -2.3381e-01, ..., 1.5720e-01,\n", 2307 | " 9.4805e-02, -8.3954e-02],\n", 2308 | " ...,\n", 2309 | " [ 6.9169e-02, 1.1067e+00, -2.3178e-01, ..., 7.0888e-02,\n", 2310 | " 2.1960e-01, -7.3331e-02],\n", 2311 | " [ 3.3634e-01, -3.3223e-01, -1.2819e-01, ..., 1.1444e-01,\n", 2312 | " 1.8477e-01, -8.0723e-02],\n", 2313 | " [-4.0076e-01, -2.4644e-01, -2.1143e-01, ..., 1.5312e-02,\n", 2314 | " -1.8078e-01, -1.4051e-01]]], grad_fn=)}" 2315 | ] 2316 | }, 2317 | "execution_count": 75, 2318 | "metadata": {}, 2319 | "output_type": "execute_result" 2320 | } 2321 | ], 2322 | "source": [ 2323 | "ppo_old_batchs" 2324 | ] 2325 | }, 2326 | { 2327 | "cell_type": "code", 2328 | "execution_count": 155, 2329 | "metadata": {}, 2330 | "outputs": [], 2331 | "source": [ 2332 | "def ppo_train_step(models, ppo_batchs, ppo_config, get_loss, optimizer):\n", 2333 | " losses = []\n", 2334 | " \n", 2335 | " \n", 2336 | " # 多轮PPO训练\n", 2337 | " for i in range(ppo_config.ppo_epochs):\n", 2338 | " # 获取小批次数据\n", 2339 | " ppo_batchs_iter = get_minibatch(\n", 2340 | " ppo_batchs, batch_size, ppo_config.mini_batch_size)\n", 2341 | " \n", 2342 | " # 对每个小批次进行训练\n", 2343 | " for mini_batchs in ppo_batchs_iter:\n", 2344 | " # 获取当前策略的输出\n", 2345 | " optimizer.zero_grad()\n", 2346 | " # 重新计算所有中间结果,而不是重用之前的计算图\n", 2347 | " with torch.set_grad_enabled(True):\n", 2348 | " logits = get_logits(models.actor, mini_batchs['prompt'])\n", 2349 | " \"\"\"\n", 2350 | " 省略了\n", 2351 | " \"\"\"\n", 2352 | "\n", 2353 | " \n", 2354 | " # 计算损失\n", 2355 | " loss= get_loss(\n", 2356 | " mini_batchs, ppo_config)\n", 2357 | " \n", 2358 | " # 在实际训练中应该进行反向传播\n", 2359 | " loss.backward()\n", 2360 | " optimizer.step()\n", 2361 | " \n", 2362 | " # 记录损失\n", 2363 | " losses.append(loss)\n", 2364 | " \n", 2365 | " # 更新批次数据中的损失\n", 2366 | " ppo_batchs['loss'] = losses\n", 2367 | " \n", 2368 | " print(losses)\n", 2369 | "\n" 2370 | ] 2371 | } 2372 | ], 2373 | "metadata": { 2374 | "kernelspec": { 2375 | "display_name": "llm", 2376 | "language": "python", 2377 | "name": "python3" 2378 | }, 2379 | "language_info": { 2380 | "codemirror_mode": { 2381 | "name": "ipython", 2382 | "version": 3 2383 | }, 2384 | "file_extension": ".py", 2385 | "mimetype": "text/x-python", 2386 | "name": "python", 2387 | "nbconvert_exporter": "python", 2388 | "pygments_lexer": "ipython3", 2389 | "version": "3.11.8" 2390 | } 2391 | }, 2392 | "nbformat": 4, 2393 | "nbformat_minor": 2 2394 | } 2395 | --------------------------------------------------------------------------------