├── assets
│   ├── fig1.png
│   ├── fig2.png
│   ├── fig3.png
│   └── fig4.png
├── README.md
└── r1-video-grpo_trainer_fixbug.py
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hui-design/R1-Video-fixbug/HEAD/assets/fig1.png
--------------------------------------------------------------------------------
/assets/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hui-design/R1-Video-fixbug/HEAD/assets/fig2.png
--------------------------------------------------------------------------------
/assets/fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hui-design/R1-Video-fixbug/HEAD/assets/fig3.png
--------------------------------------------------------------------------------
/assets/fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hui-design/R1-Video-fixbug/HEAD/assets/fig4.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # R1-Video-fixbug
2 |
3 |
4 | Recently, many awesome open-source projects have been dedicated to applying DeepSeek-R1/GRPO to multimodal tasks. Among them, [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video/) is an awesome project that applies it to video understanding. During our reproduction, we found a bug in the code (**as of 2025-02-22**). This bug causes problems for the reference model **when executing the `get_per_token_logps` function, resulting in an incorrect KL divergence term**.
5 | In addition, Open-R1-Video is not the only project affected: [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal) has the same problem. This made us wonder whether the bug is real, so we did some exploration.
6 |
7 | ## 🔥 News
8 | 2025-02-23: [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video/) has merged our pull request and reports performance gains with the fixed version.
9 |
10 |
11 | ## What is the bug?
12 | In lines 444-453 of Open-R1-Video/src/open_r1_video/trainer/grpo_trainer.py, both the current model ($$\pi_{\theta}$$) and the reference model ($$\pi_{ref}$$) pass through the `get_per_token_logps` function, which runs a model forward pass to obtain per-token log probabilities (logps); the helper is reproduced at the end of this section.
13 |
14 | - In GRPO, $$\pi_{\theta}$$ and $$\pi_{ref}$$ differ only in their parameters; their inputs should be identical.
15 |
16 |
17 | ![fig1](assets/fig1.png)
18 |
19 |
20 |
21 |
22 | - However, in the grpo_trainer implementations of both [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video/blob/main/src/open_r1_video/trainer/grpo_trainer.py) and [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal/blob/main/src/open_r1/trainer/grpo_trainer.py), the inputs differ: the call for the model passes **`**prompt_inputs`**, while the call for ref_model does not.
23 |
24 |
25 | ![fig2](assets/fig2.png)
26 |
27 |
28 |
29 |
30 | - We further verified the implementation in [R1-V](https://github.com/Deep-Agent/R1-V/) and confirmed that the inputs to the model and ref_model are indeed identical.
31 |
32 |
33 | ![fig3](assets/fig3.png)
34 |
35 |
36 |
37 |
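For reference, here is the helper itself, condensed from `r1-video-grpo_trainer_fixbug.py` (lines 414-424 below). The visual tensors reach the model only through `**kwargs`, so whatever is passed here determines whether the forward pass sees the video at all:

```python
def get_per_token_logps(model, input_ids, **kwargs):
    logits = model(input_ids, **kwargs).logits  # (B, L, V)
    logits = logits[:, :-1, :]    # drop the last logit: it predicts the token after the sequence
    input_ids = input_ids[:, 1:]  # drop the first token: there is no logit for it
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)
```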
38 |
39 | ## What does the bug affect?
40 | Intuitively, at least for this part, Open-R1-Video and open-r1-multimodal are incorrect, while R1-V is correct. So, what are the impacts of this bug?
41 | ### 1. Issue with the inputs_embeds for the Reference Model:
42 | In the code, `**prompt_inputs` mainly contains two keys: `"pixel_values_videos"` and `"video_grid_thw"`. These two variables carry the video input (in open-r1-multimodal, this is an image). When they are passed into `get_per_token_logps`, they enter the forward method of Qwen2VL (specifically, lines 1667-1703 of `transformers/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py`, in `Qwen2VLForConditionalGeneration.forward`).
43 | If both `pixel_values_videos` and `pixel_values` are `None`, `inputs_embeds` will simply be the text embedding of `input_ids` (including the video placeholder tokens), rather than the embedding obtained from the video frames through the vision tower. **In this case, the reference model does not see any visual information, leading to an erroneous reference**.
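
The relevant branch can be paraphrased as the following sketch (simplified; the names follow the `transformers` implementation, but this is not the exact source):

```python
def build_inputs_embeds(input_ids, embed_tokens, visual, video_token_id,
                        pixel_values_videos=None, video_grid_thw=None):
    """Simplified paraphrase of how Qwen2VL builds inputs_embeds; not the exact transformers code."""
    inputs_embeds = embed_tokens(input_ids)  # text embeddings are always computed
    if pixel_values_videos is not None:      # the vision tower only runs when pixels are provided
        video_embeds = visual(pixel_values_videos, grid_thw=video_grid_thw)
        video_mask = (input_ids == video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
    # If pixel_values_videos (and pixel_values) are None, the branch above is skipped entirely:
    # inputs_embeds stays purely textual, so the reference model never sees the video.
    return inputs_embeds
```

Because both repos omit `**prompt_inputs` in the reference call, the reference model always takes the text-only path above.
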
44 | ### 2. Impact on KL Loss:
45 | The KL loss is affected because the KL term in GRPO is computed as $$\mathrm{KL}(\pi_{\theta}\,\|\,\pi_{ref})$$. Since the logps output by the reference model ($$\pi_{ref}$$) are incorrect, the KL divergence becomes problematic. Specifically, at initialization the parameters of the model and the reference model are identical, so $$\pi_{\theta}$$ and $$\pi_{ref}$$ should produce the same values. **Therefore, the correct initial value of the KL divergence should be 0. However, in Open-R1-Video, due to the incorrect logps from the reference model, the initial KL divergence is not 0**.
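
This is easy to sanity-check with the per-token KL estimator used in the trainer, $$\exp(d) - d - 1$$ with $$d = \log\pi_{ref} - \log\pi_{\theta}$$ (ignoring the clamp applied in the code): when the two sets of logps are identical, every term is exactly 0. A minimal standalone snippet (variable names are ours, not from the trainer):

```python
import torch

cur_logps = torch.randn(2, 8)    # stand-in per-token logps from the policy model
ref_logps = cur_logps.clone()    # at initialization the reference model equals the policy
d = ref_logps - cur_logps
per_token_kl = torch.exp(d) - d - 1
print(per_token_kl.abs().max())  # tensor(0.) -> the logged KL should start at 0
```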
46 |
47 |
48 | ![fig4](assets/fig4.png)
49 |
50 |
51 | Since the weight ($$\beta$$) of the KL loss is only 0.04, the impact of this bug is not particularly significant, but there is still some effect.
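
For reference, this is how the trainer combines the terms (copied from `r1-video-grpo_trainer_fixbug.py`, lines 530-532 below); the erroneous KL enters the loss only after being scaled by `self.beta`:

```python
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)  # beta defaults to 0.04
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```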
52 |
53 |
54 | ## The fixed version
55 | To resolve the issue, add `**prompt_inputs` to the `get_per_token_logps` call for the reference model. This fixes the bug in the code as of February 22, 2025.
56 | ```python
57 | per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
58 | per_token_logps = per_token_logps[:, prompt_length - 1 :]
59 | with torch.inference_mode():
60 |     if self.ref_model is not None:
61 |         """ Fix Bug
62 |         From:
63 |             ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
64 |         To:
65 |             ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
66 |         """
67 |         ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
68 |     else:
69 |         with self.accelerator.unwrap_model(model).disable_adapter():
70 |             """ Fix Bug
71 |             From:
72 |                 ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
73 |             To:
74 |                 ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
75 |             """
76 |             ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
77 | ```
78 |
79 |
80 |
--------------------------------------------------------------------------------
/r1-video-grpo_trainer_fixbug.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import os
17 | import textwrap
18 | from collections import defaultdict
19 | from typing import Any, Callable, Optional, Union
20 |
21 | import torch
22 | import torch.utils.data
23 | import transformers
24 | from datasets import Dataset, IterableDataset
25 | from packaging import version
26 | from transformers import (
27 | AriaForConditionalGeneration,
28 | AriaProcessor,
29 | AutoModelForCausalLM,
30 | AutoModelForSequenceClassification,
31 | AutoProcessor,
32 | AutoTokenizer,
33 | GenerationConfig,
34 | PreTrainedModel,
35 | PreTrainedTokenizerBase,
36 | Qwen2VLForConditionalGeneration,
37 | Trainer,
38 | TrainerCallback,
39 | is_wandb_available,
40 | )
41 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
42 | from transformers.utils import is_peft_available
43 |
44 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
45 | from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
46 | from trl.trainer.grpo_config import GRPOConfig
47 | from trl.trainer.utils import generate_model_card, get_comet_experiment_url
48 |
49 |
50 | from qwen_vl_utils import process_vision_info
51 |
52 | if is_peft_available():
53 | from peft import PeftConfig, get_peft_model
54 |
55 | if is_wandb_available():
56 | import wandb
57 |
58 | # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
59 | # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
60 | RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
61 |
62 |
63 | class Qwen2VLGRPOTrainer(Trainer):
64 | """
65 | Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
66 | paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
67 |
68 | Example:
69 |
70 | ```python
71 | from datasets import load_dataset
72 | from trl import GRPOTrainer
73 |
74 | dataset = load_dataset("trl-lib/tldr", split="train")
75 |
76 | trainer = GRPOTrainer(
77 | model="Qwen/Qwen2-0.5B-Instruct",
78 | reward_funcs="weqweasdas/RM-Gemma-2B",
79 | train_dataset=dataset,
80 | )
81 |
82 | trainer.train()
83 | ```
84 |
85 | Args:
86 | model (`Union[str, PreTrainedModel]`):
87 | Model to be trained. Can be either:
88 |
89 | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
90 | a path to a *directory* containing model weights saved using
91 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
92 | loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
93 | in `args.model_init_kwargs`.
94 | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
95 | reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
96 | Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
97 | functions with the prompts and completions and sum the rewards. Can be either:
98 |
99 | - A single reward function, such as:
100 | - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
101 | path to a *directory* containing model weights saved using
102 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
103 | using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
104 | keyword arguments in `args.model_init_kwargs`.
105 | - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
106 | - A custom reward function: The function is provided with the prompts and the generated completions,
107 | plus any additional columns in the dataset. It should return a list of rewards. For more details, see
108 | [Using a custom reward function](#using-a-custom-reward-function).
109 | - A list of reward functions, where each item can independently be any of the above types. Mixing different
110 | types within the list (e.g., a string model ID and a custom reward function) is allowed.
111 | args ([`GRPOConfig`], *optional*, defaults to `None`):
112 | Configuration for this trainer. If `None`, a default configuration is used.
113 | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
114 | Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
115 | ignored. The format of the samples can be either:
116 |
117 | - [Standard](dataset_formats#standard): Each sample contains plain text.
118 | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
119 | and content).
120 | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
121 | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
122 | processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
123 | Processing class used to process the data. The padding side must be set to "left". If `None`, the
124 | processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
125 | reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
126 | Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
127 |
128 | - A single processing class: Used when `reward_funcs` contains only one reward function.
129 | - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
130 | If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
131 | `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
132 | For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
133 | the corresponding entries in `reward_processing_classes` are ignored.
134 | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
135 | List of callbacks to customize the training loop. Will add those to the list of default callbacks
136 | detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
137 |
138 | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
139 | method.
140 | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
141 | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
142 | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
143 | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
144 | PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
145 | """
146 |
147 | def __init__(
148 | self,
149 | model: Union[str, PreTrainedModel],
150 | reward_funcs: Union[RewardFunc, list[RewardFunc]],
151 | args: GRPOConfig = None,
152 | train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
153 | eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
154 | processing_class: Optional[PreTrainedTokenizerBase] = None,
155 | reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
156 | callbacks: Optional[list[TrainerCallback]] = None,
157 | optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
158 | peft_config: Optional["PeftConfig"] = None,
159 | max_pixels: Optional[int] = 12845056,
160 | min_pixels: Optional[int] = 3136,
161 | attn_implementation: str = "flash_attention_2",
162 | ):
163 | # Args
164 | if args is None:
165 | model_name = model if isinstance(model, str) else model.config._name_or_path
166 | model_name = model_name.split("/")[-1]
167 | args = GRPOConfig(f"{model_name}-GRPO")
168 |
169 | # Models
170 | # Trained model
171 | model_init_kwargs = args.model_init_kwargs or {}
172 | model_init_kwargs["attn_implementation"] = attn_implementation
173 | if isinstance(model, str):
174 | model_id = model
175 | torch_dtype = model_init_kwargs.get("torch_dtype")
176 | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
177 | pass # torch_dtype is already a torch.dtype or "auto" or None
178 | elif isinstance(torch_dtype, str): # it's a str, but not "auto"
179 | torch_dtype = getattr(torch, torch_dtype)
180 | model_init_kwargs["torch_dtype"] = torch_dtype
181 | else:
182 | raise ValueError(
183 | "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
184 | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
185 | )
186 | # Disable caching if gradient checkpointing is enabled (not supported)
187 | model_init_kwargs["use_cache"] = (
188 | False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
189 | )
190 | if "Qwen2-VL" in model_id:
191 | model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
192 | elif "Aria" in model_id:
193 | model_init_kwargs.pop("use_cache")
194 | model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
195 | else:
196 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
197 | else:
198 | model_id = model.config._name_or_path
199 | if args.model_init_kwargs is not None:
200 | raise ValueError(
201 | "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
202 | "This argument can only be used when the `model` argument is a string."
203 | )
204 |
205 | if peft_config is not None:
206 | model = get_peft_model(model, peft_config)
207 |
208 | # Reference model
209 | if is_deepspeed_zero3_enabled():
210 | if "Qwen2-VL" in model_id:
211 | self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
212 | elif "Aria" in model_id:
213 | self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
214 | else:
215 | self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
216 | elif peft_config is None:
217 | # If PEFT configuration is not provided, create a reference model based on the initial model.
218 | self.ref_model = create_reference_model(model)
219 | else:
220 | # If PEFT is used, the reference model is not needed since the adapter can be disabled
221 | # to revert to the initial model.
222 | self.ref_model = None
223 |
224 | # Processing class
225 | if processing_class is None:
226 | if "Qwen2-VL" in model_id or "Aria" in model_id:
227 | processing_class = AutoProcessor.from_pretrained(model_id)
228 | pad_token_id = processing_class.tokenizer.pad_token_id
229 | processing_class.pad_token_id = pad_token_id
230 | processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
231 | if "Qwen2-VL" in model_id:
232 | processing_class.image_processor.max_pixels = max_pixels
233 | processing_class.image_processor.min_pixels = min_pixels
234 | else:
235 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
236 | pad_token_id = processing_class.pad_token_id
237 |
238 | # Reward functions
239 | if not isinstance(reward_funcs, list):
240 | reward_funcs = [reward_funcs]
241 | for i, reward_func in enumerate(reward_funcs):
242 | if isinstance(reward_func, str):
243 | reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
244 | reward_func, num_labels=1, **model_init_kwargs
245 | )
246 | self.reward_funcs = reward_funcs
247 |
248 | # Reward processing class
249 | if reward_processing_classes is None:
250 | reward_processing_classes = [None] * len(reward_funcs)
251 | elif not isinstance(reward_processing_classes, list):
252 | reward_processing_classes = [reward_processing_classes]
253 | else:
254 | if len(reward_processing_classes) != len(reward_funcs):
255 | raise ValueError("The number of reward processing classes must match the number of reward functions.")
256 |
257 | for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
258 | if isinstance(reward_func, PreTrainedModel):
259 | if reward_processing_class is None:
260 | reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
261 | if reward_processing_class.pad_token_id is None:
262 | reward_processing_class.pad_token = reward_processing_class.eos_token
263 | # The reward model computes the reward for the latest non-padded token in the input sequence.
264 | # So it's important to set the pad token ID to the padding token ID of the processing class.
265 | reward_func.config.pad_token_id = reward_processing_class.pad_token_id
266 | reward_processing_classes[i] = reward_processing_class
267 | self.reward_processing_classes = reward_processing_classes
268 |
269 | # Data collator
270 | def data_collator(features): # No data collation is needed in GRPO
271 | return features
272 |
273 | # Training arguments
274 | self.max_prompt_length = args.max_prompt_length
275 | self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
276 | self.num_generations = args.num_generations # = G in the GRPO paper
277 | self.generation_config = GenerationConfig(
278 | max_new_tokens=self.max_completion_length,
279 | do_sample=True,
280 | temperature=1, # HACK
281 | num_return_sequences=self.num_generations,
282 | pad_token_id=pad_token_id,
283 | )
284 | self.beta = args.beta
285 |
286 | # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
287 | # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
288 | # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
289 | # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
290 | # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
291 | # This acts as a flag to indicate that the warning has already been issued.
292 | model.warnings_issued["estimate_tokens"] = True
293 |
294 | # Initialize the metrics
295 | self._metrics = defaultdict(list)
296 |
297 | super().__init__(
298 | model=model,
299 | args=args,
300 | data_collator=data_collator,
301 | train_dataset=train_dataset,
302 | eval_dataset=eval_dataset,
303 | processing_class=processing_class,
304 | callbacks=callbacks,
305 | optimizers=optimizers,
306 | )
307 |
308 | # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
309 | # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
310 | # self.model_accepts_loss_kwargs to False to enable scaling.
311 | self.model_accepts_loss_kwargs = False
312 |
313 | if self.ref_model is not None:
314 | if self.is_deepspeed_enabled:
315 | self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
316 | else:
317 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
318 |
319 | for i, reward_func in enumerate(self.reward_funcs):
320 | if isinstance(reward_func, PreTrainedModel):
321 | self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
322 |
323 | def _set_signature_columns_if_needed(self):
324 | # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
325 | # By default, this method sets `self._signature_columns` to the model's expected inputs.
326 | # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
327 | # Instead, we set them to the columns expected by the `training_step` method, hence the override.
328 | if self._signature_columns is None:
329 | self._signature_columns = ["prompt"]
330 |
331 | # Trainer "prepares" the inputs before calling `compute_loss`: it converts them to tensors and moves them to the device.
332 | # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
333 | def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
334 | return inputs
335 |
336 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
337 | if return_outputs:
338 | raise ValueError("The GRPOTrainer does not support returning outputs")
339 |
340 | # TODO if "video" in inputs sample
341 | # import pdb; pdb.set_trace()
342 |
343 | prompts = [x["prompt"] for x in inputs]
344 | prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
345 |
346 | if "image" in inputs[0]:
347 | images = [x["image"] for x in inputs]
348 | elif "video" in inputs[0]:
349 | videos = [x["video"] for x in inputs]
350 | video_inputs = []
351 | for (inp_idx, inp) in enumerate(inputs):
352 | new_inp = inp.copy()
353 | new_inp['prompt'][0]['content'][0]['text'] = inputs[inp_idx]["video"]
354 | video_inputs.append(process_vision_info(new_inp["prompt"])[0])
355 |
356 | # import pdb; pdb.set_trace()
357 |
358 | prompt_inputs = self.processing_class(
359 | text=prompts_text,
360 | images=images if "image" in inputs[0] else None,
361 | videos=video_inputs if "video" in inputs[0] else None,
362 | return_tensors="pt",
363 | padding=True,
364 | padding_side="left",
365 | add_special_tokens=False,
366 | )
367 | prompt_inputs = super()._prepare_inputs(prompt_inputs)
368 |
369 | if self.max_prompt_length is not None:
370 | prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
371 | prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
372 |
373 | # Generate completions
374 | with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
375 | # prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
376 |
377 | # Generate num_generations times, one sequence per call with temp_generation_config; stack the output ids into prompt_completion_ids and pad shorter sequences with the tokenizer's pad token id
378 | num_generations = self.generation_config.num_return_sequences
379 | temp_generation_config = copy.deepcopy(self.generation_config)
380 | temp_generation_config.num_return_sequences = 1
381 |
382 | all_completions = []
383 |
384 | for i in range(num_generations):  # generate one sequence per iteration
385 | completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
386 | all_completions.append(completion)
387 |
388 | # Stack all completions and pad if needed
389 | max_length = max(completion.size(1) for completion in all_completions)
390 | padded_completions = []
391 |
392 | for completion in all_completions:
393 | if completion.size(1) < max_length:
394 | padding = torch.full(
395 | (completion.size(0), max_length - completion.size(1)),
396 | self.processing_class.tokenizer.pad_token_id,
397 | dtype=completion.dtype,
398 | device=completion.device,
399 | )
400 | padded_completion = torch.cat([completion, padding], dim=1)
401 | else:
402 | padded_completion = completion
403 | padded_completions.append(padded_completion)
404 |
405 | # Stack all padded completions
406 | prompt_completion_ids = torch.cat(padded_completions, dim=0)
407 |
408 | prompt_length = prompt_inputs["input_ids"].size(1)
409 | completion_ids = prompt_completion_ids[:, prompt_length:]
410 |
411 | # import pdb; pdb.set_trace()
412 |
413 | # Get the per-token log probabilities for the completions for the model and the reference model
414 | def get_per_token_logps(model, input_ids, **kwargs):
415 | logits = model(input_ids, **kwargs).logits # (B, L, V)
416 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
417 | input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
418 | # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
419 | per_token_logps = []
420 | for logits_row, input_ids_row in zip(logits, input_ids):
421 | log_probs = logits_row.log_softmax(dim=-1)
422 | token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
423 | per_token_logps.append(token_log_prob)
424 | return torch.stack(per_token_logps)
425 |
426 | prompt_inputs.pop("input_ids")
427 | prompt_inputs.pop("attention_mask")
428 | # Assumes the inputs come from the Qwen2VL processor; repeat the visual tensors
429 | # once per generated completion so the model and ref_model see the same visual context
430 | if "image" in inputs[0]:
431 | prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
432 | prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
433 | # import pdb; pdb.set_trace()
434 |
435 | # XXX if input video
436 | # image_grid_thw is from image_process_qwen2_vl
437 | # https://github.com/huggingface/transformers/blob/dd16acb8a3e93b643aa374c9fb80749f5235c1a6/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L414
438 | # automatic process
439 | if "video" in inputs[0]:
440 | prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
441 | prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
442 |
443 |
444 | per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
445 | # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
446 | per_token_logps = per_token_logps[:, prompt_length - 1 :]
447 |
448 | with torch.inference_mode():
449 | if self.ref_model is not None:
450 | """ Fix Bug
451 | From:
452 | ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
453 | To:
454 | ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
455 | """
456 | ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
457 | else:
458 | with self.accelerator.unwrap_model(model).disable_adapter():
459 | """ Fix Bug
460 | From:
461 | ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
462 | To:
463 | ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
464 | """
465 | ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
466 | ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
467 |
468 | # Compute the KL divergence between the model and the reference model
469 | diff = ref_per_token_logps - per_token_logps
470 | diff = torch.clamp(diff, min=-11.0, max=11.0)
471 |
472 | per_token_kl = torch.exp(diff) - (diff) - 1
473 |
474 | # Mask everything after the first EOS token
475 | is_eos = completion_ids == self.processing_class.eos_token_id
476 | device = self.accelerator.device
477 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
478 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
479 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
480 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
481 |
482 | # Decode the generated completions
483 | completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
484 | if is_conversational(inputs[0]):
485 | completions = [[{"role": "assistant", "content": completion}] for completion in completions]
486 |
487 | # Compute the rewards
488 | prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
489 |
490 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
491 | for i, (reward_func, reward_processing_class) in enumerate(
492 | zip(self.reward_funcs, self.reward_processing_classes)
493 | ):
494 | # import pdb; pdb.set_trace()
495 | if isinstance(reward_func, PreTrainedModel):
496 | if is_conversational(inputs[0]): # true
497 | messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
498 | texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
499 | else:
500 | texts = [p + c for p, c in zip(prompts, completions)]
501 | reward_inputs = reward_processing_class(
502 | texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
503 | )
504 | reward_inputs = super()._prepare_inputs(reward_inputs)
505 | with torch.inference_mode():
506 | rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
507 | else:
508 | # Repeat all input columns (but "prompt" and "completion") to match the number of generations
509 | reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
510 | for key in reward_kwargs:
511 | for example in inputs:
512 | # Repeat each value in the column for `num_generations` times
513 | reward_kwargs[key].extend([example[key]] * self.num_generations)
514 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
515 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
516 |
517 | # Sum the rewards from all reward functions
518 | rewards = rewards_per_func.sum(dim=1)
519 |
520 | # Compute grouped-wise rewards
521 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
522 | std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
523 |
524 | # Normalize the rewards to compute the advantages
525 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
526 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
527 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
528 |
529 | # x - x.detach() allows for preserving gradients from x
530 | per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
531 | per_token_loss = -(per_token_loss - self.beta * per_token_kl) # default 0.04
532 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
533 |
534 | # import pdb; pdb.set_trace()
535 |
536 | # Log the metrics
537 | completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
538 | self._metrics["completion_length"].append(completion_length)
539 |
540 | reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
541 | for i, reward_func in enumerate(self.reward_funcs):
542 | if isinstance(reward_func, PreTrainedModel):
543 | reward_func_name = reward_func.config._name_or_path.split("/")[-1]
544 | else:
545 | reward_func_name = reward_func.__name__
546 | self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
547 |
548 | self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
549 |
550 | self._metrics["advantages"].append(self.accelerator.gather_for_metrics(advantages).mean().item())
551 |
552 | self._metrics["reward_mean"].append(self.accelerator.gather_for_metrics(mean_grouped_rewards).mean().item())
553 |
554 | self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
555 |
556 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
557 | self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
558 |
559 | # import pdb; pdb.set_trace()
560 |
561 | return loss
562 |
563 | def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
564 | metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
565 | logs = {**logs, **metrics}
566 | if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
567 | super().log(logs, start_time)
568 | else: # transformers<=4.46
569 | super().log(logs)
570 | self._metrics.clear()
571 |
572 | def create_model_card(
573 | self,
574 | model_name: Optional[str] = None,
575 | dataset_name: Optional[str] = None,
576 | tags: Union[str, list[str], None] = None,
577 | ):
578 | """
579 | Creates a draft of a model card using the information available to the `Trainer`.
580 |
581 | Args:
582 | model_name (`str` or `None`, *optional*, defaults to `None`):
583 | Name of the model.
584 | dataset_name (`str` or `None`, *optional*, defaults to `None`):
585 | Name of the dataset used for training.
586 | tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
587 | Tags to be associated with the model card.
588 | """
589 | if not self.is_world_process_zero():
590 | return
591 |
592 | if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
593 | base_model = self.model.config._name_or_path
594 | else:
595 | base_model = None
596 |
597 | tags = tags or []
598 | if isinstance(tags, str):
599 | tags = [tags]
600 |
601 | if hasattr(self.model.config, "unsloth_version"):
602 | tags.append("unsloth")
603 |
604 | citation = textwrap.dedent(
605 | """\
606 | @article{zhihong2024deepseekmath,
607 | title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
608 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
609 | year = 2024,
610 | eprint = {arXiv:2402.03300},
611 | """
612 | )
613 |
614 | model_card = generate_model_card(
615 | base_model=base_model,
616 | model_name=model_name,
617 | hub_model_id=self.hub_model_id,
618 | dataset_name=dataset_name,
619 | tags=tags,
620 | wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
621 | comet_url=get_comet_experiment_url(),
622 | trainer_name="GRPO",
623 | trainer_citation=citation,
624 | paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
625 | paper_id="2402.03300",
626 | )
627 |
628 | model_card.save(os.path.join(self.args.output_dir, "README.md"))
--------------------------------------------------------------------------------