├── eval ├── utils │ ├── __init__.py │ ├── data_loaders.py │ └── processing.py ├── example_eval.sh └── main.py ├── assets ├── teaser_comparison.png ├── noisyrollout_workflow.png └── noisyrollout_workflow_caption.png ├── requirements.txt ├── Makefile ├── verl ├── utils │ ├── __init__.py │ ├── logger │ │ ├── __init__.py │ │ └── aggregate_logger.py │ ├── checkpoint │ │ ├── __init__.py │ │ ├── checkpoint_manager.py │ │ └── fsdp_checkpoint_manager.py │ ├── reward_score │ │ ├── __init__.py │ │ ├── math.py │ │ └── r1v.py │ ├── py_functional.py │ ├── tokenizer.py │ ├── torch_dtypes.py │ ├── model_utils.py │ ├── image_aug.py │ ├── fsdp_utils.py │ ├── flops_counter.py │ ├── tracking.py │ └── dataset.py ├── models │ ├── __init__.py │ ├── transformers │ │ ├── __init__.py │ │ ├── qwen2_vl.py │ │ └── flash_attention_utils.py │ └── monkey_patch.py ├── trainer │ ├── __init__.py │ ├── config.py │ ├── main.py │ └── metrics.py ├── workers │ ├── __init__.py │ ├── rollout │ │ ├── __init__.py │ │ ├── vllm_rollout │ │ │ ├── __init__.py │ │ │ └── vllm_rollout_spmd.py │ │ ├── base.py │ │ └── config.py │ ├── reward │ │ ├── __init__.py │ │ ├── config.py │ │ └── custom.py │ ├── critic │ │ ├── __init__.py │ │ ├── base.py │ │ └── config.py │ ├── sharding_manager │ │ ├── __init__.py │ │ ├── base.py │ │ ├── fsdp_ulysses.py │ │ └── fsdp_vllm.py │ ├── actor │ │ ├── __init__.py │ │ ├── base.py │ │ └── config.py │ └── config.py ├── single_controller │ ├── __init__.py │ ├── base │ │ ├── register_center │ │ │ ├── __init__.py │ │ │ └── ray.py │ │ ├── __init__.py │ │ ├── worker.py │ │ ├── decorator.py │ │ └── worker_group.py │ └── ray │ │ └── __init__.py └── __init__.py ├── pyproject.toml ├── training_scripts ├── README.md ├── qwen2_5_vl_7b_geo3k_grpo.sh ├── qwen2_5_vl_7b_k12_grpo.sh ├── qwen2_5_vl_7b_geo3k_noisyrollout.sh ├── qwen2_5_vl_7b_k12_noisyrollout.sh └── config.yaml ├── setup.py ├── .gitignore ├── README.md └── scripts └── model_merger.py /eval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/teaser_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-TRAIL/NoisyRollout/HEAD/assets/teaser_comparison.png -------------------------------------------------------------------------------- /assets/noisyrollout_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-TRAIL/NoisyRollout/HEAD/assets/noisyrollout_workflow.png -------------------------------------------------------------------------------- /assets/noisyrollout_workflow_caption.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-TRAIL/NoisyRollout/HEAD/assets/noisyrollout_workflow_caption.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | codetiming 3 | datasets 4 | flash-attn>=2.4.3 5 | liger-kernel 6 | mathruler 7 | numpy 8 | omegaconf 9 | pandas 10 | peft 11 | pillow 12 | pyarrow>=15.0.0 13 | pylatexenc 14 | qwen-vl-utils 15 | ray 16 | tensordict 17 | torchdata 18 | transformers==4.49.0 19 | wandb 20 | vllm>=0.7.3 21 | torch==2.5.1 22 | torchvision==0.20.1 23 | torchaudio==2.5.1 24 | numpy==1.26.4 
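The pinned requirements above are meant to be installed into the `noisyrollout` conda environment that the eval and training scripts activate (`source ~/miniconda3/bin/activate noisyrollout`). A minimal setup sketch under those assumptions — the Python version below is an assumption (`pyproject.toml` only requires >=3.8), and flash-attn usually has to be built against an already-installed torch:

```bash
# Sketch only: environment name matches the one activated in eval/example_eval.sh and training_scripts/*.sh
conda create -n noisyrollout python=3.10 -y        # 3.10 is an assumption, not a repo requirement
conda activate noisyrollout

pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1   # install torch first so flash-attn can build against it
pip install -r requirements.txt                                  # transformers==4.49.0, vllm>=0.7.3, flash-attn>=2.4.3, ...
pip install --no-build-isolation flash-attn                      # only if the requirements install fails on flash-attn
pip install -e .                                                 # local verl package from setup.py / pyproject.toml
```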
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build commit quality style 2 | 3 | check_dirs := scripts verl setup.py 4 | 5 | build: 6 | python3 setup.py sdist bdist_wheel 7 | 8 | commit: 9 | pre-commit install 10 | pre-commit run --all-files 11 | 12 | quality: 13 | ruff check $(check_dirs) 14 | ruff format --check $(check_dirs) 15 | 16 | style: 17 | ruff check $(check_dirs) --fix 18 | ruff format $(check_dirs) 19 | -------------------------------------------------------------------------------- /verl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/utils/logger/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/single_controller/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/single_controller/base/register_center/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .protocol import DataProto 16 | 17 | 18 | __all__ = ["DataProto"] 19 | __version__ = "0.2.0.dev" 20 | -------------------------------------------------------------------------------- /verl/workers/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .config import RolloutConfig 17 | 18 | 19 | __all__ = ["RolloutConfig"] 20 | -------------------------------------------------------------------------------- /verl/workers/rollout/vllm_rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .vllm_rollout_spmd import vLLMRollout 16 | 17 | 18 | __all__ = ["vLLMRollout"] 19 | -------------------------------------------------------------------------------- /verl/workers/reward/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 PRIME team and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .config import RewardConfig 16 | from .custom import CustomRewardManager 17 | 18 | 19 | __all__ = ["CustomRewardManager", "RewardConfig"] 20 | -------------------------------------------------------------------------------- /verl/workers/reward/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Reward config 16 | """ 17 | 18 | from dataclasses import dataclass 19 | 20 | 21 | @dataclass 22 | class RewardConfig: 23 | reward_type: str = "function" 24 | compute_score: str = "math" 25 | -------------------------------------------------------------------------------- /verl/single_controller/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .worker import Worker 16 | from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup 17 | 18 | 19 | __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"] 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "verl" 7 | dynamic = ["version", "dependencies", "optional-dependencies", "readme", "license"] 8 | requires-python = ">=3.8" 9 | 10 | [tool.ruff] 11 | target-version = "py38" 12 | line-length = 119 13 | indent-width = 4 14 | 15 | [tool.ruff.lint] 16 | ignore = ["C901", "E501", "E741", "W605", "C408"] 17 | select = ["C", "E", "F", "I", "W", "RUF022"] 18 | 19 | [tool.ruff.lint.per-file-ignores] 20 | "__init__.py" = ["E402", "F401", "F403", "F811"] 21 | 22 | [tool.ruff.lint.isort] 23 | lines-after-imports = 2 24 | known-first-party = ["verl"] 25 | known-third-party = ["torch", "transformers", "wandb"] 26 | 27 | [tool.ruff.format] 28 | quote-style = "double" 29 | indent-style = "space" 30 | skip-magic-trailing-comma = false 31 | line-ending = "auto" 32 | -------------------------------------------------------------------------------- /verl/utils/reward_score/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .math import math_compute_score, pure_math_compute_score 17 | from .r1v import r1v_compute_score 18 | 19 | 20 | __all__ = ["math_compute_score", "r1v_compute_score", "pure_math_compute_score"] 21 | -------------------------------------------------------------------------------- /verl/single_controller/ray/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls 16 | 17 | 18 | __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"] 19 | -------------------------------------------------------------------------------- /verl/workers/critic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import BasePPOCritic 16 | from .config import CriticConfig, ModelConfig 17 | from .dp_critic import DataParallelPPOCritic 18 | 19 | 20 | __all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"] 21 | -------------------------------------------------------------------------------- /training_scripts/README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | * In our current implementation of NoisyRollout, $n_1$ and $n_2$ are set to the same value. If you set `worker.rollout.n` to 6 in our NoisyRollout training scripts, it means $n_1=n_2=6$, resulting in a total of 12 rollouts. 6 | * We empirically freeze the vision encoder and remove the KL loss for better performance. 7 | * For `worker.actor.aug_type`, please only use `gaussian` for now, as the performance of other choices is not guaranteed. 8 | * For `worker.actor.gaussian_noise_step` (i.e., initial noise strength $\alpha_0$ in our paper), it's a critical hyper-parameter when applying NoisyRollout to other datasets. Suggested values include 400, 450, and 500. 9 | * We also suggest setting `worker.actor.decay_mode=sigmoid` with an appropriate `worker.actor.decay_sig_mid_step` for stable training. -------------------------------------------------------------------------------- /verl/workers/sharding_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .base import BaseShardingManager 17 | from .fsdp_ulysses import FSDPUlyssesShardingManager 18 | from .fsdp_vllm import FSDPVLLMShardingManager 19 | 20 | 21 | __all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"] 22 | -------------------------------------------------------------------------------- /verl/workers/rollout/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from abc import ABC, abstractmethod 16 | 17 | from ...protocol import DataProto 18 | 19 | 20 | __all__ = ["BaseRollout"] 21 | 22 | 23 | class BaseRollout(ABC): 24 | @abstractmethod 25 | def generate_sequences(self, prompts: DataProto) -> DataProto: 26 | """Generate sequences""" 27 | pass 28 | -------------------------------------------------------------------------------- /eval/example_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/.bashrc 3 | source ~/miniconda3/bin/activate noisyrollout 4 | 5 | export VLLM_ATTENTION_BACKEND=XFORMERS 6 | export VLLM_USE_V1=0 7 | export GOOGLE_API_KEY="xxx" 8 | 9 | # Define list of model paths to evaluate 10 | HF_MODEL_PATH="" 11 | RESULTS_DIR="results/" 12 | EVAL_DIR="$HOME/NoisyRollout/eval" 13 | DATA_DIR="$HOME/NoisyRollout/eval/data" 14 | 15 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. 16 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.""" 17 | 18 | cd $EVAL_DIR 19 | python main.py \ 20 | --model $HF_MODEL_PATH \ 21 | --output-dir $RESULTS_DIR \ 22 | --data-path $DATA_DIR \ 23 | --datasets geo3k,hallubench,mathvista,wemath,mathverse,mathvision \ 24 | --tensor-parallel-size 2 \ 25 | --system-prompt="$SYSTEM_PROMPT" \ 26 | --min-pixels 262144 \ 27 | --max-pixels 1000000 \ 28 | --max-model-len 8192 \ 29 | --temperature 0.0 \ 30 | --eval-threads 24 -------------------------------------------------------------------------------- /verl/workers/actor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import BasePPOActor 16 | from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig 17 | from .dp_actor import DataParallelPPOActor 18 | 19 | 20 | __all__ = [ 21 | "ActorConfig", 22 | "BasePPOActor", 23 | "DataParallelPPOActor", 24 | "FSDPConfig", 25 | "ModelConfig", 26 | "OptimConfig", 27 | "RefConfig", 28 | ] 29 | -------------------------------------------------------------------------------- /verl/single_controller/base/register_center/ray.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ray 16 | 17 | 18 | @ray.remote 19 | class WorkerGroupRegisterCenter: 20 | def __init__(self, rank_zero_info): 21 | self.rank_zero_info = rank_zero_info 22 | 23 | def get_rank_zero_info(self): 24 | return self.rank_zero_info 25 | 26 | 27 | def create_worker_group_register_center(name, info): 28 | return WorkerGroupRegisterCenter.options(name=name).remote(info) 29 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Sharding manager to implement HybridEngine 16 | """ 17 | 18 | from ...protocol import DataProto 19 | 20 | 21 | class BaseShardingManager: 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, exc_type, exc_value, traceback): 26 | pass 27 | 28 | def preprocess_data(self, data: DataProto) -> DataProto: 29 | return data 30 | 31 | def postprocess_data(self, data: DataProto) -> DataProto: 32 | return data 33 | -------------------------------------------------------------------------------- /verl/workers/critic/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Base class for Critic 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict 20 | 21 | import torch 22 | 23 | from ...protocol import DataProto 24 | from .config import CriticConfig 25 | 26 | 27 | __all__ = ["BasePPOCritic"] 28 | 29 | 30 | class BasePPOCritic(ABC): 31 | def __init__(self, config: CriticConfig): 32 | self.config = config 33 | 34 | @abstractmethod 35 | def compute_values(self, data: DataProto) -> torch.Tensor: 36 | """Compute values""" 37 | pass 38 | 39 | @abstractmethod 40 | def update_critic(self, data: DataProto) -> Dict[str, Any]: 41 | """Update the critic""" 42 | pass 43 | -------------------------------------------------------------------------------- /verl/utils/logger/aggregate_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A Ray logger will receive logging info from different processes. 16 | """ 17 | 18 | import numbers 19 | from typing import Any, Dict 20 | 21 | 22 | def concat_dict_to_str(dict: Dict[str, Any], step: int) -> str: 23 | output = [f"step {step}:"] 24 | for k, v in dict.items(): 25 | if isinstance(v, numbers.Number): 26 | output.append(f"{k}:{v:.3f}") 27 | 28 | output_str = " - ".join(output) 29 | return output_str 30 | 31 | 32 | class LocalLogger: 33 | def __init__(self): 34 | pass 35 | 36 | def flush(self): 37 | pass 38 | 39 | def log(self, data: Dict[str, Any], step: int) -> None: 40 | print(concat_dict_to_str(data, step=step), flush=True) 41 | -------------------------------------------------------------------------------- /verl/utils/py_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contain small python utility functions 16 | """ 17 | 18 | from typing import Any, Dict, List 19 | 20 | 21 | def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]: 22 | """Union two dict. Will throw an error if there is an item not the same object with the same key.""" 23 | for key in dict2.keys(): 24 | if key in dict1: 25 | assert dict1[key] == dict2[key], f"{key} in meta_dict1 and meta_dict2 are not the same object" 26 | 27 | dict1[key] = dict2[key] 28 | 29 | return dict1 30 | 31 | 32 | def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None: 33 | for key, val in new_data.items(): 34 | if key not in data: 35 | data[key] = [] 36 | 37 | data[key].append(val) 38 | -------------------------------------------------------------------------------- /verl/utils/reward_score/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | from mathruler.grader import extract_boxed_content, grade_answer 18 | 19 | 20 | def math_format_reward(predict_str: str) -> float: 21 | pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL) 22 | format_match = re.fullmatch(pattern, predict_str) 23 | return 1.0 if format_match else 0.0 24 | 25 | 26 | def math_acc_reward(predict_str: str, ground_truth: str) -> float: 27 | answer = extract_boxed_content(predict_str) 28 | return 1.0 if grade_answer(answer, ground_truth) else 0.0 29 | 30 | 31 | def pure_math_compute_score(predict_str: str, ground_truth: str) -> float: 32 | return math_acc_reward(predict_str, ground_truth) 33 | 34 | 35 | def math_compute_score(predict_str: str, ground_truth: str) -> float: 36 | return 0.9 * math_acc_reward(predict_str, ground_truth) + 0.1 * math_format_reward(predict_str) -------------------------------------------------------------------------------- /verl/workers/rollout/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Rollout config 16 | """ 17 | 18 | from dataclasses import asdict, dataclass, field 19 | 20 | 21 | @dataclass 22 | class RolloutConfig: 23 | name: str = "vllm" 24 | temperature: float = 1.0 25 | top_k: int = -1 26 | top_p: float = 1.0 27 | dtype: str = "bf16" 28 | gpu_memory_utilization: float = 0.5 29 | ignore_eos: bool = False 30 | enforce_eager: bool = False 31 | free_cache_engine: bool = False 32 | enable_chunked_prefill: bool = False 33 | tensor_parallel_size: int = 2 34 | max_num_batched_tokens: int = 8192 35 | max_num_seqs: int = 1024 36 | disable_log_stats: bool = True 37 | do_sample: bool = True 38 | n: int = 1 39 | limit_images: int = 0 40 | """auto keys""" 41 | prompt_length: int = field(default=-1, init=False) 42 | response_length: int = field(default=-1, init=False) 43 | 44 | def to_dict(self): 45 | return asdict(self) 46 | -------------------------------------------------------------------------------- /verl/models/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS 17 | 18 | from .transformers.flash_attention_utils import flash_attention_forward 19 | from .transformers.qwen2_vl import qwen2_vl_attn_forward 20 | 21 | 22 | def apply_ulysses_patch(model_type: str) -> None: 23 | if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"): 24 | ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward 25 | elif model_type in ("qwen2_vl", "qwen2_5_vl"): 26 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 27 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 28 | 29 | Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward 30 | Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward 31 | else: 32 | raise NotImplementedError(f"Model architecture {model_type} is not supported yet.") 33 | -------------------------------------------------------------------------------- /verl/workers/critic/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Critic config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | 20 | from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig 21 | 22 | 23 | @dataclass 24 | class CriticConfig: 25 | strategy: str = "fsdp" 26 | global_batch_size: int = 256 27 | micro_batch_size_per_device_for_update: int = 4 28 | micro_batch_size_per_device_for_experience: int = 16 29 | max_grad_norm: float = 1.0 30 | cliprange_value: float = 0.5 31 | ppo_epochs: int = 1 32 | padding_free: bool = False 33 | ulysses_sequence_parallel_size: int = 1 34 | model: ModelConfig = field(default_factory=ModelConfig) 35 | optim: OptimConfig = field(default_factory=OptimConfig) 36 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 37 | offload: OffloadConfig = field(default_factory=OffloadConfig) 38 | """auto keys""" 39 | global_batch_size_per_device: int = field(default=-1, init=False) 40 | -------------------------------------------------------------------------------- /verl/utils/reward_score/r1v.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | from mathruler.grader import grade_answer 18 | 19 | 20 | def r1v_format_reward(predict_str: str) -> float: 21 | pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL) 22 | format_match = re.fullmatch(pattern, predict_str) 23 | return 1.0 if format_match else 0.0 24 | 25 | 26 | def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float: 27 | try: 28 | ground_truth = ground_truth.strip() 29 | content_match = re.search(r"<answer>(.*?)</answer>", predict_str) 30 | given_answer = content_match.group(1).strip() if content_match else predict_str.strip() 31 | if grade_answer(given_answer, ground_truth): 32 | return 1.0 33 | except Exception: 34 | pass 35 | 36 | return 0.0 37 | 38 | 39 | def r1v_compute_score(predict_str: str, ground_truth: str) -> float: 40 | return 0.5 * r1v_accuracy_reward(predict_str, ground_truth) + 0.5 * r1v_format_reward(predict_str) 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | 16 | from setuptools import find_packages, setup 17 | 18 | 19 | def get_requires(): 20 | with open("requirements.txt", encoding="utf-8") as f: 21 | file_content = f.read() 22 | lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] 23 | return lines 24 | 25 | 26 | extra_require = { 27 | "dev": ["pre-commit", "ruff"], 28 | } 29 | 30 | 31 | def main(): 32 | setup( 33 | name="verl", 34 | version="0.2.0.dev0", 35 | package_dir={"": "."}, 36 | packages=find_packages(where="."), 37 | url="https://github.com/volcengine/verl", 38 | license="Apache 2.0", 39 | author="verl", 40 | author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn", 41 | description="", 42 | install_requires=get_requires(), 43 | extras_require=extra_require, 44 | long_description=open("README.md", encoding="utf-8").read(), 45 | long_description_content_type="text/markdown", 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /training_scripts/qwen2_5_vl_7b_geo3k_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | source ~/.bashrc 4 | source ~/miniconda3/bin/activate noisyrollout 5 | cd ~/NoisyRollout 6 | 7 | export VLLM_ATTENTION_BACKEND=XFORMERS 8 | export VLLM_USE_V1=0 9 | 10 | export WANDB_BASE_URL=https://api.wandb.ai 11 | export WANDB_API_KEY="xxx" 12 | 13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct 14 | EXPERIMENT_NAME=qwen2_5_vl_7b_geo3k_grpo 15 | PROJECT_NAME=noisy_rollout 16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}" 17 | 18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. 19 | The reasoning process MUST BE enclosed within <think> </think> tags.
The final answer MUST BE put in \boxed{}.""" 20 | 21 | python3 -m verl.trainer.main \ 22 | config=training_scripts/config.yaml \ 23 | data.train_files=xyliu6/geometry3k@train \ 24 | data.val_files=xyliu6/geometry3k@test \ 25 | data.system_prompt="${SYSTEM_PROMPT}" \ 26 | data.max_response_length=2048 \ 27 | data.max_pixels=1000000 \ 28 | worker.actor.micro_batch_size_per_device_for_update=2 \ 29 | worker.actor.micro_batch_size_per_device_for_experience=4 \ 30 | worker.actor.model.freeze_vision_tower=true \ 31 | worker.actor.model.model_path=${MODEL_PATH} \ 32 | worker.actor.use_kl_loss=false \ 33 | worker.actor.offload.offload_params=true \ 34 | worker.actor.offload.offload_optimizer=true \ 35 | worker.reward.compute_score=math \ 36 | worker.rollout.gpu_memory_utilization=0.35 \ 37 | worker.rollout.tensor_parallel_size=4 \ 38 | worker.rollout.n=12 \ 39 | worker.rollout.enable_chunked_prefill=false \ 40 | trainer.experiment_name=${EXPERIMENT_NAME} \ 41 | trainer.project_name=${PROJECT_NAME} \ 42 | trainer.n_gpus_per_node=8 \ 43 | trainer.save_freq=20 \ 44 | worker.actor.is_noisy=false 45 | -------------------------------------------------------------------------------- /training_scripts/qwen2_5_vl_7b_k12_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | source ~/.bashrc 4 | source ~/miniconda3/bin/activate noisyrollout 5 | cd ~/NoisyRollout 6 | 7 | export VLLM_ATTENTION_BACKEND=XFORMERS 8 | export VLLM_USE_V1=0 9 | 10 | export WANDB_BASE_URL=https://api.wandb.ai 11 | export WANDB_API_KEY="xxx" 12 | 13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct 14 | EXPERIMENT_NAME=qwen2_5_vl_7b_k12_grpo 15 | PROJECT_NAME=noisy_rollout 16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}" 17 | 18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. 19 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.""" 20 | 21 | python3 -m verl.trainer.main \ 22 | config=training_scripts/config.yaml \ 23 | data.train_files=xyliu6/k12-freeform@mini_train \ 24 | data.val_files=xyliu6/k12-freeform@test \ 25 | data.system_prompt="${SYSTEM_PROMPT}" \ 26 | data.max_response_length=2048 \ 27 | data.max_pixels=1000000 \ 28 | worker.actor.micro_batch_size_per_device_for_update=2 \ 29 | worker.actor.micro_batch_size_per_device_for_experience=4 \ 30 | worker.actor.model.freeze_vision_tower=true \ 31 | worker.actor.model.model_path=${MODEL_PATH} \ 32 | worker.actor.use_kl_loss=false \ 33 | worker.actor.offload.offload_params=true \ 34 | worker.actor.offload.offload_optimizer=true \ 35 | worker.reward.compute_score=math \ 36 | worker.rollout.gpu_memory_utilization=0.35 \ 37 | worker.rollout.tensor_parallel_size=4 \ 38 | worker.rollout.n=12 \ 39 | worker.rollout.enable_chunked_prefill=false \ 40 | trainer.experiment_name=${EXPERIMENT_NAME} \ 41 | trainer.project_name=${PROJECT_NAME} \ 42 | trainer.n_gpus_per_node=8 \ 43 | trainer.save_freq=20 \ 44 | worker.actor.is_noisy=false -------------------------------------------------------------------------------- /verl/workers/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | ActorRolloutRef config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | 20 | from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig 21 | from .critic import CriticConfig 22 | from .reward import RewardConfig 23 | from .rollout import RolloutConfig 24 | 25 | 26 | __all__ = [ 27 | "ActorConfig", 28 | "CriticConfig", 29 | "FSDPConfig", 30 | "ModelConfig", 31 | "OptimConfig", 32 | "RefConfig", 33 | "RewardConfig", 34 | "RolloutConfig", 35 | "WorkerConfig", 36 | ] 37 | 38 | 39 | @dataclass 40 | class WorkerConfig: 41 | hybrid_engine: bool = True 42 | actor: ActorConfig = field(default_factory=ActorConfig) 43 | critic: CriticConfig = field(default_factory=CriticConfig) 44 | ref: RefConfig = field(default_factory=RefConfig) 45 | reward: RewardConfig = field(default_factory=RewardConfig) 46 | rollout: RolloutConfig = field(default_factory=RolloutConfig) 47 | 48 | def post_init(self): 49 | self.ref.padding_free = self.actor.padding_free 50 | self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size 51 | self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience 52 | -------------------------------------------------------------------------------- /training_scripts/qwen2_5_vl_7b_geo3k_noisyrollout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | source ~/.bashrc 4 | source ~/miniconda3/bin/activate noisyrollout 5 | cd ~/NoisyRollout 6 | 7 | export VLLM_ATTENTION_BACKEND=XFORMERS 8 | export VLLM_USE_V1=0 9 | 10 | export WANDB_BASE_URL=https://api.wandb.ai 11 | export WANDB_API_KEY="xxx" 12 | 13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct 14 | EXPERIMENT_NAME=qwen2_5_vl_7b_geo3k_noisyrollout 15 | PROJECT_NAME=noisy_rollout 16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}" 17 | 18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. 19 | The reasoning process MUST BE enclosed within <think> </think> tags.
The final answer MUST BE put in \boxed{}.""" 20 | 21 | python3 -m verl.trainer.main \ 22 | config=training_scripts/config.yaml \ 23 | data.train_files=xyliu6/geometry3k@train \ 24 | data.val_files=xyliu6/geometry3k@test \ 25 | data.system_prompt="${SYSTEM_PROMPT}" \ 26 | data.max_response_length=2048 \ 27 | data.max_pixels=1000000 \ 28 | worker.actor.micro_batch_size_per_device_for_update=2 \ 29 | worker.actor.micro_batch_size_per_device_for_experience=4 \ 30 | worker.actor.model.freeze_vision_tower=true \ 31 | worker.actor.model.model_path=${MODEL_PATH} \ 32 | worker.actor.use_kl_loss=false \ 33 | worker.actor.offload.offload_params=true \ 34 | worker.actor.offload.offload_optimizer=true \ 35 | worker.reward.compute_score=math \ 36 | worker.rollout.gpu_memory_utilization=0.35 \ 37 | worker.rollout.tensor_parallel_size=4 \ 38 | worker.rollout.n=6 \ 39 | worker.rollout.enable_chunked_prefill=false \ 40 | trainer.experiment_name=${EXPERIMENT_NAME} \ 41 | trainer.project_name=${PROJECT_NAME} \ 42 | trainer.n_gpus_per_node=8 \ 43 | trainer.save_freq=20 \ 44 | worker.actor.is_noisy=true \ 45 | worker.actor.aug_type=gaussian \ 46 | worker.actor.gaussian_noise_step=500 \ 47 | worker.actor.decay_mode=sigmoid \ 48 | worker.actor.decay_coef=30 \ 49 | worker.actor.decay_sig_mid_step=40 50 | 51 | -------------------------------------------------------------------------------- /training_scripts/qwen2_5_vl_7b_k12_noisyrollout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | source ~/.bashrc 4 | source ~/miniconda3/bin/activate noisyrollout 5 | cd ~/NoisyRollout 6 | 7 | export VLLM_ATTENTION_BACKEND=XFORMERS 8 | export VLLM_USE_V1=0 9 | 10 | export WANDB_BASE_URL=https://api.wandb.ai 11 | export WANDB_API_KEY="xxx" 12 | 13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct 14 | EXPERIMENT_NAME=qwen2_5_vl_7b_k12_noisyrollout 15 | PROJECT_NAME=noisy_rollout 16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}" 17 | 18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. 19 | The reasoning process MUST BE enclosed within <think> </think> tags.
The final answer MUST BE put in \boxed{}.""" 20 | 21 | python3 -m verl.trainer.main \ 22 | config=training_scripts/config.yaml \ 23 | data.train_files=xyliu6/k12-freeform@train \ 24 | data.val_files=xyliu6/k12-freeform@test \ 25 | data.system_prompt="${SYSTEM_PROMPT}" \ 26 | data.max_response_length=2048 \ 27 | data.max_pixels=1000000 \ 28 | worker.actor.micro_batch_size_per_device_for_update=2 \ 29 | worker.actor.micro_batch_size_per_device_for_experience=4 \ 30 | worker.actor.model.freeze_vision_tower=true \ 31 | worker.actor.model.model_path=${MODEL_PATH} \ 32 | worker.actor.use_kl_loss=false \ 33 | worker.actor.offload.offload_params=true \ 34 | worker.actor.offload.offload_optimizer=true \ 35 | worker.reward.compute_score=math \ 36 | worker.rollout.gpu_memory_utilization=0.35 \ 37 | worker.rollout.tensor_parallel_size=4 \ 38 | worker.rollout.n=6 \ 39 | worker.rollout.enable_chunked_prefill=false \ 40 | trainer.experiment_name=${EXPERIMENT_NAME} \ 41 | trainer.project_name=${PROJECT_NAME} \ 42 | trainer.n_gpus_per_node=8 \ 43 | trainer.save_freq=20 \ 44 | trainer.total_episodes=10 \ 45 | worker.actor.is_noisy=true \ 46 | worker.actor.aug_type=gaussian \ 47 | worker.actor.gaussian_noise_step=450 \ 48 | worker.actor.decay_mode=sigmoid \ 49 | worker.actor.decay_coef=60 \ 50 | worker.actor.decay_sig_mid_step=40 51 | 52 | -------------------------------------------------------------------------------- /verl/workers/actor/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The base class for Actor 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict 20 | 21 | import torch 22 | 23 | from ...protocol import DataProto 24 | from .config import ActorConfig 25 | 26 | 27 | __all__ = ["BasePPOActor"] 28 | 29 | 30 | class BasePPOActor(ABC): 31 | def __init__(self, config: ActorConfig): 32 | """The base class for PPO actor 33 | 34 | Args: 35 | config (ActorConfig): a config passed to the PPOActor. 36 | """ 37 | self.config = config 38 | 39 | @abstractmethod 40 | def compute_log_prob(self, data: DataProto) -> torch.Tensor: 41 | """Compute logits given a batch of data. 42 | 43 | Args: 44 | data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```, 45 | ```attention_mask``` and ```position_ids```. 46 | 47 | Returns: 48 | DataProto: a DataProto containing the key ```log_probs``` 49 | """ 50 | pass 51 | 52 | @abstractmethod 53 | def update_policy(self, data: DataProto) -> Dict[str, Any]: 54 | """Update the policy with an iterator of DataProto 55 | 56 | Args: 57 | data (DataProto): an iterator over the DataProto that returns by 58 | ```make_minibatch_iterator``` 59 | 60 | Returns: 61 | Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model 62 | such as ```loss```, ```grad_norm```, etc,. 
63 | """ 64 | pass 65 | -------------------------------------------------------------------------------- /verl/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for tokenization.""" 15 | 16 | from typing import Optional 17 | 18 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin 19 | 20 | 21 | def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer: 22 | """Create a huggingface pretrained tokenizer.""" 23 | tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs) 24 | 25 | if tokenizer.bos_token == "" and tokenizer.eos_token == "": 26 | # the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance. 27 | # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a 28 | print("Found gemma model. Set eos_token and eos_token_id to and 107.") 29 | tokenizer.eos_token = "" 30 | 31 | if tokenizer.pad_token_id is None: 32 | print("Pad token is None. Set it to eos_token.") 33 | tokenizer.pad_token = tokenizer.eos_token 34 | 35 | return tokenizer 36 | 37 | 38 | def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]: 39 | """Create a huggingface pretrained processor.""" 40 | try: 41 | processor = AutoProcessor.from_pretrained(model_path, **kwargs) 42 | except Exception: 43 | processor = None 44 | 45 | # Avoid load tokenizer, see: 46 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 47 | if processor is not None and "Processor" not in processor.__class__.__name__: 48 | processor = None 49 | 50 | return processor 51 | -------------------------------------------------------------------------------- /verl/utils/torch_dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | HALF_LIST = [16, "16", "fp16", "float16"] 19 | FLOAT_LIST = [32, "32", "fp32", "float32"] 20 | BFLOAT_LIST = ["bf16", "bfloat16"] 21 | 22 | 23 | class PrecisionType: 24 | """Type of precision used. 
25 | 26 | >>> PrecisionType.HALF == 16 27 | True 28 | >>> PrecisionType.HALF in (16, "16") 29 | True 30 | """ 31 | 32 | HALF = "16" 33 | FLOAT = "32" 34 | FULL = "64" 35 | BFLOAT = "bf16" 36 | MIXED = "mixed" 37 | 38 | @staticmethod 39 | def is_fp16(precision): 40 | return precision in HALF_LIST 41 | 42 | @staticmethod 43 | def is_fp32(precision): 44 | return precision in FLOAT_LIST 45 | 46 | @staticmethod 47 | def is_bf16(precision): 48 | return precision in BFLOAT_LIST 49 | 50 | @staticmethod 51 | def to_dtype(precision) -> torch.dtype: 52 | if precision in HALF_LIST: 53 | return torch.float16 54 | elif precision in FLOAT_LIST: 55 | return torch.float32 56 | elif precision in BFLOAT_LIST: 57 | return torch.bfloat16 58 | else: 59 | raise RuntimeError(f"unexpected precision: {precision}") 60 | 61 | @staticmethod 62 | def to_str(precision: torch.dtype) -> str: 63 | if precision == torch.float16: 64 | return "float16" 65 | elif precision == torch.float32: 66 | return "float32" 67 | elif precision == torch.bfloat16: 68 | return "bfloat16" 69 | else: 70 | raise RuntimeError(f"unexpected precision: {precision}") 71 | -------------------------------------------------------------------------------- /verl/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
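For reference, a minimal usage sketch of the PrecisionType helper from verl/utils/torch_dtypes.py above; this mirrors how mixed-precision strings such as the FSDP config's mp_param_dtype: bf16 would be resolved into torch dtypes (the calling code is assumed, not shown here):

import torch
from verl.utils.torch_dtypes import PrecisionType

param_dtype = PrecisionType.to_dtype("bf16")      # torch.bfloat16
reduce_dtype = PrecisionType.to_dtype("fp32")     # torch.float32
assert PrecisionType.is_bf16("bfloat16") and PrecisionType.is_fp32(32)
print(PrecisionType.to_str(param_dtype))          # "bfloat16"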
14 | """ 15 | Utilities to create common models 16 | """ 17 | 18 | from functools import lru_cache 19 | from typing import Tuple 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch import nn 24 | 25 | 26 | @lru_cache 27 | def is_rank0() -> int: 28 | return (not dist.is_initialized()) or (dist.get_rank() == 0) 29 | 30 | 31 | def print_gpu_memory_usage(prefix: str) -> None: 32 | if is_rank0(): 33 | memory_allocated = torch.cuda.memory_allocated() / (1024**3) 34 | memory_reserved = torch.cuda.memory_reserved() / (1024**3) 35 | print(f"{prefix} memory allocated: {memory_allocated:.2f} GB, memory reserved: {memory_reserved:.2f} GB.") 36 | 37 | 38 | def get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]: 39 | n_params = sum(p.numel() for p in model.parameters()) 40 | 41 | if scale == "auto": 42 | if n_params > 1e9: 43 | scale = "B" 44 | elif n_params > 1e6: 45 | scale = "M" 46 | elif n_params > 1e3: 47 | scale = "K" 48 | else: 49 | scale = "" 50 | 51 | if scale == "B": 52 | n_params = n_params / 1e9 53 | elif scale == "M": 54 | n_params = n_params / 1e6 55 | elif scale == "K": 56 | n_params = n_params / 1e3 57 | elif scale == "": 58 | pass 59 | else: 60 | raise NotImplementedError(f"Unknown scale {scale}.") 61 | 62 | return n_params, scale 63 | 64 | 65 | def print_model_size(model: nn.Module, name: str = None) -> None: 66 | n_params, scale = get_model_size(model, scale="auto") 67 | if name is None: 68 | name = model.__class__.__name__ 69 | 70 | print(f"{name} contains {n_params:.2f}{scale} parameters") 71 | -------------------------------------------------------------------------------- /training_scripts/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_files: hiyouga/math12k@train 3 | val_files: hiyouga/math12k@test 4 | prompt_key: problem 5 | answer_key: answer 6 | image_key: images 7 | max_prompt_length: 2048 8 | max_response_length: 2048 9 | rollout_batch_size: 512 10 | shuffle: true 11 | seed: 1 12 | max_pixels: 4194304 13 | min_pixels: 262144 14 | 15 | algorithm: 16 | adv_estimator: grpo 17 | kl_coef: 0.0 18 | 19 | worker: 20 | actor: 21 | global_batch_size: 128 22 | micro_batch_size_per_device_for_update: 4 23 | micro_batch_size_per_device_for_experience: 16 24 | max_grad_norm: 1.0 25 | entropy_coeff: 1.0e-3 26 | use_kl_loss: true 27 | kl_loss_coef: 1.0e-2 28 | kl_loss_type: low_var_kl 29 | padding_free: true 30 | ulysses_sequence_parallel_size: 1 31 | is_noisy: false 32 | aug_type: gaussian 33 | gaussian_noise_step: 500 34 | crop_size: 0.0 35 | rotate_angle: 15 36 | decay_coef: 1.0 37 | decay_mode: none 38 | decay_sig_mid_step: 40 39 | model: 40 | model_path: Qwen/Qwen2.5-7B-Instruct 41 | enable_gradient_checkpointing: true 42 | trust_remote_code: false 43 | freeze_vision_tower: false 44 | freeze_language_model: false 45 | optim: 46 | lr: 1.0e-6 47 | weight_decay: 1.0e-2 48 | lr_warmup_ratio: 0.0 49 | fsdp: 50 | enable_full_shard: true 51 | enable_cpu_offload: false 52 | enable_rank0_init: true 53 | offload: 54 | offload_params: true 55 | offload_optimizer: true 56 | 57 | rollout: 58 | temperature: 1.0 59 | n: 5 60 | gpu_memory_utilization: 0.5 61 | enforce_eager: false 62 | enable_chunked_prefill: false 63 | tensor_parallel_size: 2 64 | limit_images: 0 65 | 66 | ref: 67 | fsdp: 68 | enable_full_shard: true 69 | enable_cpu_offload: true 70 | enable_rank0_init: true 71 | offload: 72 | offload_params: false 73 | 74 | reward: 75 | reward_type: function 76 | compute_score: math 77 
| 78 | trainer: 79 | total_episodes: 15 80 | logger: ["console", "wandb"] 81 | project_name: noisy_rollout 82 | experiment_name: qwen2_5_7b_math 83 | n_gpus_per_node: 8 84 | nnodes: 1 85 | val_freq: 5 86 | val_before_train: true 87 | val_only: false 88 | val_generations_to_log: 1 89 | save_freq: 5 90 | remove_previous_ckpt: false 91 | remove_ckpt_after_load: false 92 | save_checkpoint_path: null 93 | load_checkpoint_path: null 94 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/fsdp_ulysses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT 16 | """ 17 | 18 | from torch.distributed.device_mesh import DeviceMesh 19 | 20 | from ...protocol import DataProto, all_gather_data_proto 21 | from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group 22 | from .base import BaseShardingManager 23 | 24 | 25 | class FSDPUlyssesShardingManager(BaseShardingManager): 26 | """ 27 | Sharding manager to support data resharding when using FSDP + Ulysses 28 | """ 29 | 30 | def __init__(self, device_mesh: DeviceMesh): 31 | super().__init__() 32 | self.device_mesh = device_mesh 33 | 34 | def __enter__(self): 35 | if self.device_mesh is not None: 36 | self.prev_sp_group = get_ulysses_sequence_parallel_group() 37 | set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group()) 38 | 39 | def __exit__(self, exc_type, exc_value, traceback): 40 | if self.device_mesh is not None: 41 | set_ulysses_sequence_parallel_group(self.prev_sp_group) 42 | 43 | def preprocess_data(self, data: DataProto) -> DataProto: 44 | """ 45 | AllGather data from sp region 46 | This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE 47 | In Ulysses, we need to make sure the same data is used across a SP group 48 | """ 49 | if self.device_mesh is not None: 50 | sp_size = self.device_mesh["sp"].size() 51 | sp_group = self.device_mesh["sp"].get_group() 52 | all_gather_data_proto(data, size=sp_size, group=sp_group) 53 | 54 | return data 55 | 56 | def postprocess_data(self, data: DataProto) -> DataProto: 57 | """ 58 | Split the data to follow FSDP partition 59 | """ 60 | if self.device_mesh is not None: 61 | sp_size = self.device_mesh["sp"].size() 62 | sp_rank = self.device_mesh["sp"].get_local_rank() 63 | data = data.chunk(chunks=sp_size)[sp_rank] 64 | 65 | return data 66 | -------------------------------------------------------------------------------- /verl/workers/reward/custom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from transformers import PreTrainedTokenizer 18 | 19 | from ...protocol import DataProto 20 | from ...utils.reward_score import math_compute_score, r1v_compute_score, pure_math_compute_score 21 | 22 | 23 | class CustomRewardManager: 24 | def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str): 25 | self.tokenizer = tokenizer 26 | self.num_examine = num_examine 27 | if compute_score == "math": 28 | self.compute_score = math_compute_score 29 | elif compute_score == "r1v": 30 | self.compute_score = r1v_compute_score 31 | elif compute_score == "pure_math": 32 | self.compute_score = pure_math_compute_score 33 | else: 34 | raise NotImplementedError() 35 | 36 | def __call__(self, data: DataProto) -> torch.Tensor: 37 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) 38 | already_print = 0 39 | 40 | for i in range(len(data)): 41 | data_item = data[i] # DataProtoItem 42 | 43 | prompt_ids = data_item.batch["prompts"] 44 | prompt_length = prompt_ids.shape[-1] 45 | 46 | valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() 47 | valid_prompt_ids = prompt_ids[-valid_prompt_length:] 48 | 49 | response_ids = data_item.batch["responses"] 50 | valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() 51 | valid_response_ids = response_ids[:valid_response_length] 52 | 53 | # decode 54 | prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) 55 | response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) 56 | 57 | ground_truth = data_item.non_tensor_batch["ground_truth"] 58 | 59 | score = self.compute_score(response_str, ground_truth) 60 | reward_tensor[i, valid_response_length - 1] = score 61 | 62 | if already_print < self.num_examine: 63 | already_print += 1 64 | print("[prompt]", prompt_str) 65 | print("[response]", response_str) 66 | print("[ground_truth]", ground_truth) 67 | print("[score]", score) 68 | 69 | return reward_tensor 70 | -------------------------------------------------------------------------------- /verl/trainer/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
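To make the reward layout produced by CustomRewardManager in verl/workers/reward/custom.py above concrete: each response receives a single scalar score written at its last valid token, so the advantage estimator sees a sparse, terminal-only reward. A self-contained sketch with made-up lengths and scores (the tokenizer and DataProto plumbing are omitted):

import torch

batch_size, max_response_len = 2, 8
reward_tensor = torch.zeros(batch_size, max_response_len)

valid_response_lengths = [5, 3]   # hypothetical numbers of non-padding response tokens
scores = [1.0, 0.0]               # hypothetical outputs of compute_score(response, ground_truth)
for i, (length, score) in enumerate(zip(valid_response_lengths, scores)):
    reward_tensor[i, length - 1] = score   # reward only at the final valid token

print(reward_tensor)
# tensor([[0., 0., 0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0., 0., 0.]])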
14 | """ 15 | PPO config 16 | """ 17 | 18 | import os 19 | from dataclasses import asdict, dataclass, field, fields, is_dataclass 20 | from typing import Optional, Tuple 21 | 22 | from ..workers.config import WorkerConfig 23 | 24 | 25 | def recursive_post_init(dataclass_obj): 26 | if hasattr(dataclass_obj, "post_init"): 27 | dataclass_obj.post_init() 28 | 29 | for attr in fields(dataclass_obj): 30 | if is_dataclass(getattr(dataclass_obj, attr.name)): 31 | recursive_post_init(getattr(dataclass_obj, attr.name)) 32 | 33 | 34 | @dataclass 35 | class DataConfig: 36 | train_files: str = "" 37 | val_files: str = "" 38 | prompt_key: str = "prompt" 39 | answer_key: str = "answer" 40 | image_key: str = "images" 41 | max_prompt_length: int = 512 42 | max_response_length: int = 512 43 | rollout_batch_size: int = 512 44 | system_prompt: Optional[str] = None 45 | shuffle: bool = True 46 | seed: int = 1 47 | max_pixels: int = 4194304 48 | min_pixels: int = 262144 49 | 50 | 51 | @dataclass 52 | class AlgorithmConfig: 53 | gamma: float = 1.0 54 | lam: float = 1.0 55 | adv_estimator: str = "grpo" 56 | kl_penalty: str = "kl" 57 | kl_type: str = "fixed" 58 | kl_coef: float = 1e-3 59 | kl_horizon: float = 0.0 60 | kl_target: float = 0.0 61 | 62 | 63 | @dataclass 64 | class TrainerConfig: 65 | total_episodes: int = 10 66 | max_steps: Optional[int] = None 67 | project_name: str = "easy_r1" 68 | experiment_name: str = "demo" 69 | logger: Tuple[str] = ("console", "wandb") 70 | nnodes: int = 1 71 | n_gpus_per_node: int = 8 72 | critic_warmup: int = 0 73 | val_freq: int = -1 74 | val_before_train: bool = True 75 | val_only: bool = False 76 | val_generations_to_log: int = 1 77 | save_freq: int = -1 78 | remove_previous_ckpt: bool = False 79 | remove_ckpt_after_load: bool = False 80 | save_checkpoint_path: Optional[str] = None 81 | load_checkpoint_path: Optional[str] = None 82 | def post_init(self): 83 | if self.save_checkpoint_path is None: 84 | self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name) 85 | 86 | 87 | @dataclass 88 | class PPOConfig: 89 | data: DataConfig = field(default_factory=DataConfig) 90 | worker: WorkerConfig = field(default_factory=WorkerConfig) 91 | algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig) 92 | trainer: TrainerConfig = field(default_factory=TrainerConfig) 93 | 94 | def post_init(self): 95 | self.worker.rollout.prompt_length = self.data.max_prompt_length 96 | self.worker.rollout.response_length = self.data.max_response_length 97 | 98 | def deep_post_init(self): 99 | recursive_post_init(self) 100 | 101 | def to_dict(self): 102 | return asdict(self) 103 | -------------------------------------------------------------------------------- /verl/workers/actor/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | Actor config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | from typing import Any, Dict, Optional, Tuple 20 | 21 | 22 | @dataclass 23 | class ModelConfig: 24 | model_path: Optional[str] = None 25 | tokenizer_path: Optional[str] = None 26 | override_config: Dict[str, Any] = field(default_factory=dict) 27 | enable_gradient_checkpointing: bool = True 28 | trust_remote_code: bool = True 29 | freeze_vision_tower: bool = False 30 | freeze_language_model: bool = False 31 | 32 | def post_init(self): 33 | if self.tokenizer_path is None: 34 | self.tokenizer_path = self.model_path 35 | 36 | 37 | @dataclass 38 | class OptimConfig: 39 | lr: float = 1e-6 40 | betas: Tuple[float, float] = (0.9, 0.999) 41 | weight_decay: float = 1e-2 42 | lr_warmup_ratio: float = 0.0 43 | min_lr_ratio: Optional[float] = None 44 | warmup_style: str = "constant" 45 | """auto keys""" 46 | training_steps: int = field(default=-1, init=False) 47 | 48 | 49 | @dataclass 50 | class FSDPConfig: 51 | enable_full_shard: bool = True 52 | enable_cpu_offload: bool = False 53 | enable_rank0_init: bool = False 54 | use_orig_params: bool = False 55 | torch_dtype: Optional[str] = None 56 | fsdp_size: int = -1 57 | mp_param_dtype: str = "bf16" 58 | mp_reduce_dtype: str = "fp32" 59 | mp_buffer_dtype: str = "fp32" 60 | 61 | 62 | @dataclass 63 | class OffloadConfig: 64 | offload_params: bool = False 65 | offload_optimizer: bool = False 66 | 67 | 68 | @dataclass 69 | class ActorConfig: 70 | strategy: str = "fsdp" 71 | global_batch_size: int = 256 72 | micro_batch_size_per_device_for_update: int = 4 73 | micro_batch_size_per_device_for_experience: int = 16 74 | max_grad_norm: float = 1.0 75 | clip_ratio: float = 0.2 76 | entropy_coeff: float = 1e-3 77 | use_kl_loss: bool = True 78 | kl_loss_coef: float = 1e-3 79 | kl_loss_type: str = "low_var_kl" 80 | ppo_epochs: int = 1 81 | padding_free: bool = False 82 | ulysses_sequence_parallel_size: int = 1 83 | is_noisy: bool = False 84 | aug_type: str = "gaussian" # aug 85 | gaussian_noise_step: int = 500 # aug 86 | crop_size: float = 0.0 # aug 87 | rotate_angle: int = 15 # aug 88 | decay_coef: float = 1.0 # decay 89 | decay_mode: str = "exp" # decay 90 | decay_sig_mid_step: int = 40 # decay 91 | model: ModelConfig = field(default_factory=ModelConfig) 92 | optim: OptimConfig = field(default_factory=OptimConfig) 93 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 94 | offload: OffloadConfig = field(default_factory=OffloadConfig) 95 | """auto keys""" 96 | global_batch_size_per_device: int = field(default=-1, init=False) 97 | 98 | 99 | @dataclass 100 | class RefConfig: 101 | strategy: str = "fsdp" 102 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 103 | offload: OffloadConfig = field(default_factory=OffloadConfig) 104 | """auto keys""" 105 | micro_batch_size_per_device_for_experience: int = field(default=-1, init=False) 106 | padding_free: bool = field(default=False, init=False) 107 | ulysses_sequence_parallel_size: int = field(default=1, init=False) 108 | -------------------------------------------------------------------------------- /verl/trainer/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. 16 | """ 17 | 18 | import json 19 | 20 | import ray 21 | from omegaconf import OmegaConf 22 | 23 | from ..single_controller.ray import RayWorkerGroup 24 | from ..utils.tokenizer import get_processor, get_tokenizer 25 | from ..workers.fsdp_workers import FSDPWorker 26 | from ..workers.reward import CustomRewardManager 27 | from .config import PPOConfig 28 | from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role 29 | 30 | 31 | @ray.remote(num_cpus=1) 32 | def main_task(config: PPOConfig): 33 | # please make sure main_task is not scheduled on head 34 | # print config 35 | config.deep_post_init() 36 | print(json.dumps(config.to_dict(), indent=2)) 37 | 38 | # instantiate tokenizer 39 | tokenizer = get_tokenizer( 40 | config.worker.actor.model.model_path, 41 | trust_remote_code=config.worker.actor.model.trust_remote_code, 42 | use_fast=True, 43 | ) 44 | processor = get_processor( 45 | config.worker.actor.model.model_path, 46 | trust_remote_code=config.worker.actor.model.trust_remote_code, 47 | use_fast=True, 48 | ) 49 | 50 | # define worker classes 51 | ray_worker_group_cls = RayWorkerGroup 52 | role_worker_mapping = { 53 | Role.ActorRollout: ray.remote(FSDPWorker), 54 | Role.Critic: ray.remote(FSDPWorker), 55 | Role.RefPolicy: ray.remote(FSDPWorker), 56 | } 57 | global_pool_id = "global_pool" 58 | resource_pool_spec = { 59 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, 60 | } 61 | mapping = { 62 | Role.ActorRollout: global_pool_id, 63 | Role.Critic: global_pool_id, 64 | Role.RefPolicy: global_pool_id, 65 | } 66 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) 67 | 68 | reward_fn = CustomRewardManager( 69 | tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score 70 | ) 71 | val_reward_fn = CustomRewardManager( 72 | tokenizer=tokenizer, num_examine=1, compute_score="pure_math" 73 | ) 74 | 75 | trainer = RayPPOTrainer( 76 | config=config, 77 | tokenizer=tokenizer, 78 | processor=processor, 79 | role_worker_mapping=role_worker_mapping, 80 | resource_pool_manager=resource_pool_manager, 81 | ray_worker_group_cls=ray_worker_group_cls, 82 | reward_fn=reward_fn, 83 | val_reward_fn=val_reward_fn, 84 | ) 85 | trainer.init_workers() 86 | trainer.fit() 87 | 88 | 89 | def main(): 90 | cli_args = OmegaConf.from_cli() 91 | file_config = OmegaConf.load(getattr(cli_args, "config")) 92 | cli_args.pop("config", None) 93 | 94 | default_config = OmegaConf.structured(PPOConfig()) 95 | ppo_config = OmegaConf.merge(default_config, file_config, cli_args) 96 | ppo_config = OmegaConf.to_object(ppo_config) 97 | 98 | if not ray.is_initialized(): 99 | # this is for local ray cluster 100 | ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) 101 | 102 | ray.get(main_task.remote(ppo_config)) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | 
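The training scripts above drive this entry point purely through dotted CLI overrides; main() builds the structured PPOConfig defaults, loads the YAML given by config=..., and merges the remaining CLI arguments on top, so later sources win. A minimal, trainer-independent sketch of that precedence, where OmegaConf.from_dotlist stands in for the real OmegaConf.from_cli and the demo dataclasses are hypothetical:

from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class RolloutDemo:
    n: int = 5
    temperature: float = 1.0

@dataclass
class DemoConfig:
    rollout: RolloutDemo = field(default_factory=RolloutDemo)

default_cfg = OmegaConf.structured(DemoConfig())                 # code defaults
file_cfg = OmegaConf.create({"rollout": {"n": 6}})               # stands in for config.yaml
cli_cfg = OmegaConf.from_dotlist(["rollout.temperature=1.2"])    # stands in for CLI overrides

merged = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
print(merged.rollout.n, merged.rollout.temperature)              # 6 1.2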
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # outputs 174 | outputs/ 175 | checkpoints/ 176 | wandb/ 177 | 178 | # data 179 | eval/data 180 | eval/results -------------------------------------------------------------------------------- /verl/utils/checkpoint/checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import random 17 | import shutil 18 | import tempfile 19 | from abc import ABC, abstractmethod 20 | from typing import Union 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributed as dist 25 | from filelock import FileLock 26 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 27 | from transformers import PreTrainedTokenizer, ProcessorMixin 28 | 29 | 30 | class BaseCheckpointManager(ABC): 31 | """ 32 | A checkpoint manager that saves and loads 33 | - model 34 | - optimizer 35 | - lr_scheduler 36 | - extra_states 37 | in a SPMD way. 
38 | 39 | We save 40 | - sharded model states and optimizer states 41 | - full lr_scheduler states 42 | - huggingface tokenizer and config for ckpt merge 43 | """ 44 | 45 | def __init__( 46 | self, 47 | model: FSDP, 48 | optimizer: torch.optim.Optimizer, 49 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler, 50 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin], 51 | ): 52 | self.previous_global_step = None 53 | self.previous_save_local_path = None 54 | 55 | self.model = model 56 | self.optimizer = optimizer 57 | self.lr_scheduler = lr_scheduler 58 | self.processing_class = processing_class 59 | 60 | assert isinstance(self.model, FSDP) 61 | self.rank = dist.get_rank() 62 | self.world_size = dist.get_world_size() 63 | 64 | @abstractmethod 65 | def load_checkpoint(self, *args, **kwargs): 66 | raise NotImplementedError 67 | 68 | @abstractmethod 69 | def save_checkpoint(self, *args, **kwargs): 70 | raise NotImplementedError 71 | 72 | def remove_previous_save_local_path(self): 73 | if not self.previous_save_local_path: 74 | return 75 | 76 | abs_path = os.path.abspath(self.previous_save_local_path) 77 | print(f"Checkpoint manager remove previous save local path: {abs_path}") 78 | if not os.path.exists(abs_path): 79 | return 80 | 81 | # remove previous local_path 82 | shutil.rmtree(abs_path, ignore_errors=True) 83 | 84 | @staticmethod 85 | def local_mkdir(path): 86 | if not os.path.isabs(path): 87 | working_dir = os.getcwd() 88 | path = os.path.join(working_dir, path) 89 | 90 | # Using hash value of path as lock file name to avoid long file name 91 | lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" 92 | lock_path = os.path.join(tempfile.gettempdir(), lock_filename) 93 | 94 | try: 95 | with FileLock(lock_path, timeout=60): # Add timeout 96 | # make a new dir 97 | os.makedirs(path, exist_ok=True) 98 | except Exception as e: 99 | print(f"Warning: Failed to acquire lock for {path}: {e}") 100 | # Even if the lock is not acquired, try to create the directory 101 | os.makedirs(path, exist_ok=True) 102 | 103 | return path 104 | 105 | @staticmethod 106 | def get_rng_state(): 107 | rng_state = { 108 | "cpu": torch.get_rng_state(), 109 | "cuda": torch.cuda.get_rng_state(), 110 | "numpy": np.random.get_state(), 111 | "random": random.getstate(), 112 | } 113 | return rng_state 114 | 115 | @staticmethod 116 | def load_rng_state(rng_state): 117 | torch.set_rng_state(rng_state["cpu"]) 118 | torch.cuda.set_rng_state(rng_state["cuda"]) 119 | np.random.set_state(rng_state["numpy"]) 120 | random.setstate(rng_state["random"]) 121 | 122 | 123 | def find_latest_ckpt_path(path, directory_format="global_step_{}"): 124 | if path is None: 125 | return None 126 | 127 | tracker_file = get_checkpoint_tracker_filename(path) 128 | if not os.path.exists(tracker_file): 129 | print("Checkpoint tracker file does not exist: %s", tracker_file) 130 | return None 131 | 132 | with open(tracker_file, "rb") as f: 133 | iteration = int(f.read().decode()) 134 | ckpt_path = os.path.join(path, directory_format.format(iteration)) 135 | if not os.path.exists(ckpt_path): 136 | print("Checkpoint does not exist: %s", ckpt_path) 137 | return None 138 | 139 | print("Found checkpoint: %s", ckpt_path) 140 | return ckpt_path 141 | 142 | 143 | def get_checkpoint_tracker_filename(root_path: str): 144 | """ 145 | Tracker file rescords the latest chckpoint during training to restart from. 
146 | """ 147 | return os.path.join(root_path, "latest_checkpointed_iteration.txt") 148 | -------------------------------------------------------------------------------- /verl/utils/image_aug.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import random 5 | from PIL import Image 6 | import torchvision.transforms as T 7 | from ..workers.actor.config import ActorConfig 8 | 9 | class ImageAugmenter: 10 | def __init__(self, config: ActorConfig=None): 11 | self.config = config or {} 12 | self.to_tensor = T.ToTensor() 13 | self.to_pil = T.ToPILImage() 14 | 15 | # Set default configurations 16 | self.aug_type = self.config.aug_type 17 | self.noise_step = self.config.gaussian_noise_step # Gaussian noise steps 18 | self.crop_size = self.config.crop_size # Random occlusion region ratio 19 | self.rotate_angle = self.config.rotate_angle # Maximum rotation angle 20 | self.decay_sig_mid_step = self.config.decay_sig_mid_step 21 | 22 | # Augmentation method mapping 23 | self.aug_methods = { 24 | 'gaussian': self.apply_gaussian_noise, 25 | 'crop_fill': self.apply_crop_fill, 26 | 'rotate': self.apply_rotation, 27 | } 28 | 29 | def augment(self, image, step=0, total_steps=1): 30 | # Handle decay 31 | decay = 1.0 32 | if hasattr(self.config, 'decay_mode') and hasattr(self.config, 'decay_coef'): 33 | decay_mode = self.config.decay_mode 34 | decay_coef = self.config.decay_coef 35 | norm_step = step / total_steps 36 | 37 | if decay_mode == 'exp': 38 | decay = 1.0 - decay_coef ** (total_steps - step) 39 | elif decay_mode == 'pow': 40 | decay = 1.0 - norm_step ** decay_coef 41 | elif decay_mode == 'linear': 42 | decay = 1.0 - norm_step 43 | elif decay_mode == 'sigmoid': 44 | x = decay_coef * (norm_step - self.decay_sig_mid_step / total_steps) 45 | decay = 1.0 - (1 / (1 + math.exp(-x))) 46 | 47 | aug_method = self.aug_methods.get(self.aug_type) 48 | if aug_method is None: 49 | return image 50 | 51 | return aug_method(image, decay) 52 | 53 | def apply_gaussian_noise(self, image, decay=1.0): 54 | """Apply gaussian noise to image""" 55 | image_tensor = self.to_tensor(image) 56 | noise_step = int(self.noise_step * decay) 57 | noisy_tensor = self._add_gaussian_noise(image_tensor, noise_step) 58 | noisy_tensor = torch.clamp(noisy_tensor, 0.0, 1.0) 59 | return self.to_pil(noisy_tensor) 60 | 61 | def _add_gaussian_noise(self, image_tensor, noise_step): 62 | """Implementation of gaussian noise""" 63 | num_steps = 1000 64 | betas = torch.linspace(-6, 6, num_steps) 65 | betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5 66 | alphas = 1 - betas 67 | alphas_prod = torch.cumprod(alphas, dim=0) 68 | alphas_bar_sqrt = torch.sqrt(alphas_prod) 69 | one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod) 70 | 71 | def q_x(x_0, t): 72 | noise = torch.randn_like(x_0) 73 | alphas_t = alphas_bar_sqrt[t] 74 | alphas_1_m_t = one_minus_alphas_bar_sqrt[t] 75 | return (alphas_t * x_0 + alphas_1_m_t * noise) 76 | 77 | return q_x(image_tensor, noise_step) 78 | 79 | def apply_rotation(self, image, decay=1.0): 80 | """Apply rotation to image""" 81 | max_angle = self.rotate_angle * decay 82 | angle = random.uniform(-max_angle, max_angle) 83 | return image.rotate(angle, resample=Image.BILINEAR, expand=True) 84 | 85 | def apply_crop_fill(self, image, decay=1.0): 86 | """Apply random region filling (occlusion)""" 87 | width, height = image.size 88 | crop_size_scaled = self.crop_size * decay 89 | crop_w = int(width * crop_size_scaled) 90 | crop_h = 
int(height * crop_size_scaled) 91 | 92 | if width > crop_w and height > crop_h: 93 | x = random.randint(0, width - crop_w) 94 | y = random.randint(0, height - crop_h) 95 | else: 96 | x, y = 0, 0 97 | crop_w = min(crop_w, width) 98 | crop_h = min(crop_h, height) 99 | 100 | img_np = np.array(image) 101 | img_np[y:y+crop_h, x:x+crop_w] = 0 102 | 103 | return Image.fromarray(img_np) 104 | 105 | 106 | def augment_images(images, config, step=0, total_steps=1): 107 | augmenter = ImageAugmenter(config) 108 | return [augmenter.augment(img, step, total_steps) for img in images] 109 | 110 | 111 | def augment_batch(batch, config, step=0, total_steps=1): 112 | if "multi_modal_data" not in batch.non_tensor_batch: 113 | return batch 114 | 115 | from copy import deepcopy 116 | new_batch = deepcopy(batch) 117 | 118 | for i, item in enumerate(new_batch.non_tensor_batch["multi_modal_data"]): 119 | if "image" in item: 120 | image_list = item["image"] 121 | augmented_images = augment_images( 122 | image_list, 123 | config, 124 | step, 125 | total_steps 126 | ) 127 | new_batch.non_tensor_batch["multi_modal_data"][i]["image"] = augmented_images 128 | 129 | return new_batch 130 | -------------------------------------------------------------------------------- /verl/utils/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
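To see what the sigmoid decay in verl/utils/image_aug.py above does to the Gaussian noise schedule, the snippet below recomputes the decay factor and the effective diffusion step outside the trainer. The constants mirror the geo3k NoisyRollout script (gaussian_noise_step=500, decay_coef=30, decay_sig_mid_step=40); total_steps is an assumed training length, since the real value comes from the trainer:

import math

gaussian_noise_step = 500    # maximum diffusion step used for the noisy rollouts
decay_coef = 30              # sigmoid sharpness (decay_coef in ActorConfig)
decay_sig_mid_step = 40      # training step at which decay crosses 0.5
total_steps = 75             # assumed total number of training steps

for step in (0, 20, 40, 60):
    norm_step = step / total_steps
    x = decay_coef * (norm_step - decay_sig_mid_step / total_steps)
    decay = 1.0 - 1.0 / (1.0 + math.exp(-x))
    effective_step = int(gaussian_noise_step * decay)
    print(f"step {step:3d}: decay={decay:.3f}, effective noise step={effective_step}")
# the injected noise starts near gaussian_noise_step and fades toward 0 after decay_sig_mid_step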
14 | 15 | import gc 16 | from collections import defaultdict 17 | from functools import partial 18 | from typing import Callable, Union 19 | 20 | import torch 21 | from torch import nn 22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 23 | from torch.distributed.fsdp._runtime_utils import _lazy_init 24 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 25 | from torch.optim import Optimizer 26 | from transformers import PreTrainedModel 27 | from transformers.trainer_pt_utils import get_module_class_from_name 28 | 29 | 30 | def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]: 31 | param_occurrence = defaultdict(int) 32 | for _, param in model.named_parameters(remove_duplicate=False): 33 | param_occurrence[param] += 1 34 | 35 | duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1} 36 | materialized_params = {} 37 | 38 | def init_fn(module: nn.Module): 39 | for name, param in module.named_parameters(recurse=False): 40 | if param in duplicated_params: 41 | module._parameters[name] = materialized_params.setdefault( 42 | param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad) 43 | ) 44 | else: 45 | module._parameters[name] = nn.Parameter( 46 | torch.empty_like(param.data, device=device), requires_grad=param.requires_grad 47 | ) 48 | 49 | return init_fn 50 | 51 | 52 | def get_fsdp_wrap_policy(model: PreTrainedModel): 53 | """Get FSDP wrap policy for the model. 54 | 55 | Args: 56 | module: The module to get wrap policy for 57 | """ 58 | transformer_cls_to_wrap = set() 59 | for module in model._no_split_modules: 60 | transformer_cls = get_module_class_from_name(model, module) 61 | if transformer_cls is None: 62 | raise Exception(f"Cannot find {module} in pretrained model.") 63 | else: 64 | transformer_cls_to_wrap.add(transformer_cls) 65 | 66 | return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap) 67 | 68 | 69 | @torch.no_grad() 70 | def offload_fsdp_model(model: FSDP, empty_cache: bool = True): 71 | # lazy init FSDP model 72 | _lazy_init(model, model) 73 | assert model._is_root, "Only support root model offloading to CPU" 74 | for handle in model._all_handles: 75 | if handle._offload_params: 76 | continue 77 | 78 | flat_param = handle.flat_param 79 | assert ( 80 | flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() 81 | and id(flat_param.data) != id(flat_param._local_shard) 82 | and flat_param.data.size() == flat_param._local_shard.size() 83 | ) 84 | handle.flat_param_to("cpu", non_blocking=True) 85 | # the following still keeps id(._local_shard) != id(.data) 86 | flat_param._local_shard = flat_param.data 87 | assert id(flat_param._local_shard) != id(flat_param.data) 88 | 89 | if empty_cache: 90 | torch.cuda.empty_cache() 91 | 92 | 93 | @torch.no_grad() 94 | def load_fsdp_model(model: FSDP, empty_cache: bool = True): 95 | # lazy init FSDP model 96 | _lazy_init(model, model) 97 | assert model._is_root, "Only support root model loading to GPU" 98 | for handle in model._all_handles: 99 | if handle._offload_params: 100 | continue 101 | 102 | flat_param = handle.flat_param 103 | handle.flat_param_to("cuda", non_blocking=True) 104 | # the following still keeps id(._local_shard) != id(.data) 105 | flat_param._local_shard = flat_param.data 106 | 107 | if empty_cache: 108 | gc.collect() 109 | 110 | 111 | @torch.no_grad() 112 | def offload_fsdp_optimizer(optimizer: Optimizer, 
empty_cache: bool = True): 113 | if not optimizer.state: 114 | return 115 | 116 | for param_group in optimizer.param_groups: 117 | for param in param_group["params"]: 118 | state = optimizer.state[param] 119 | for key, value in state.items(): 120 | if isinstance(value, torch.Tensor): 121 | state[key] = value.to("cpu", non_blocking=True) 122 | 123 | if empty_cache: 124 | torch.cuda.empty_cache() 125 | 126 | 127 | @torch.no_grad() 128 | def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True): 129 | if not optimizer.state: 130 | return 131 | 132 | for param_group in optimizer.param_groups: 133 | for param in param_group["params"]: 134 | state = optimizer.state[param] 135 | for key, value in state.items(): 136 | if isinstance(value, torch.Tensor): 137 | state[key] = value.to("cuda", non_blocking=True) 138 | 139 | if empty_cache: 140 | gc.collect() 141 | -------------------------------------------------------------------------------- /verl/utils/flops_counter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, List, Tuple 16 | 17 | import torch 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers.models.llama.configuration_llama import LlamaConfig 22 | 23 | 24 | VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"} 25 | 26 | 27 | def get_device_flops(unit: str = "T") -> float: 28 | def unit_convert(number: float, level: str): 29 | units = ["B", "K", "M", "G", "T", "P"] 30 | if number <= 0: 31 | return number 32 | 33 | ptr = 0 34 | while ptr < len(units) and units[ptr] != level: 35 | number /= 1000 36 | ptr += 1 37 | 38 | return number 39 | 40 | device_name = torch.cuda.get_device_name() 41 | flops = float("inf") # INF flops for unkown gpu type 42 | if "H100" in device_name or "H800" in device_name: 43 | flops = 989e12 44 | elif "A100" in device_name or "A800" in device_name: 45 | flops = 312e12 46 | elif "L40" in device_name: 47 | flops = 181.05e12 48 | elif "L20" in device_name: 49 | flops = 119.5e12 50 | elif "H20" in device_name: 51 | flops = 148e12 52 | elif "910B" in device_name: 53 | flops = 354e12 54 | flops_unit = unit_convert(flops, unit) 55 | return flops_unit 56 | 57 | 58 | class FlopsCounter: 59 | """ 60 | Used to count mfu during training loop 61 | 62 | Example: 63 | flops_counter = FlopsCounter(config) 64 | flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time) 65 | """ 66 | 67 | def __init__(self, config: "LlamaConfig"): 68 | if config.model_type not in VALID_MODLE_TYPE: 69 | print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. 
MFU will always be zero.") 70 | 71 | self.estimate_func = { 72 | "llama": self._estimate_llama_flops, 73 | "qwen2": self._estimate_llama_flops, 74 | "qwen2_vl": self._estimate_llama_flops, 75 | "qwen2_5_vl": self._estimate_llama_flops, 76 | } 77 | self.config = config 78 | 79 | def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: 80 | return 0 81 | 82 | def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: 83 | hidden_size = self.config.hidden_size 84 | vocab_size = self.config.vocab_size 85 | num_hidden_layers = self.config.num_hidden_layers 86 | num_key_value_heads = self.config.num_key_value_heads 87 | num_attention_heads = self.config.num_attention_heads 88 | intermediate_size = self.config.intermediate_size 89 | 90 | head_dim = hidden_size // num_attention_heads 91 | q_size = num_attention_heads * head_dim 92 | k_size = num_key_value_heads * head_dim 93 | v_size = num_key_value_heads * head_dim 94 | 95 | # non-attn per layer parm 96 | # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp 97 | mlp_N = hidden_size * intermediate_size * 3 98 | attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) 99 | emd_and_lm_head_N = vocab_size * hidden_size * 2 100 | # non-attn all_layer parm 101 | dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N 102 | # non-attn all_layer & all_token fwd & bwd flops 103 | dense_N_flops = 6 * dense_N * tokens_sum 104 | 105 | # attn all_layer & all_token fwd & bwd flops 106 | seqlen_square_sum = 0 107 | for seqlen in batch_seqlens: 108 | seqlen_square_sum += seqlen * seqlen 109 | 110 | attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers 111 | 112 | # all_layer & all_token fwd & bwd flops 113 | flops_all_token = dense_N_flops + attn_qkv_flops 114 | flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 115 | return flops_achieved 116 | 117 | def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]: 118 | """ 119 | Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. 120 | 121 | Args: 122 | batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch. 123 | delta_time (float): The time taken to process the batch, in seconds. 124 | 125 | Returns: 126 | estimated_flops (float): The estimated FLOPS based on the input tokens and time. 127 | promised_flops (float): The expected FLOPS of the current device. 128 | """ 129 | tokens_sum = sum(batch_seqlens) 130 | func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) 131 | estimated_flops = func(tokens_sum, batch_seqlens, delta_time) 132 | promised_flops = get_device_flops() 133 | return estimated_flops, promised_flops 134 | -------------------------------------------------------------------------------- /verl/utils/tracking.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A unified tracking interface that supports logging data to different backend 16 | """ 17 | 18 | import os 19 | from dataclasses import dataclass 20 | from typing import List, Tuple, Union 21 | 22 | from .logger.aggregate_logger import LocalLogger 23 | 24 | 25 | class Tracking: 26 | supported_backend = ["wandb", "mlflow", "swanlab", "console"] 27 | 28 | def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None): 29 | if isinstance(default_backend, str): 30 | default_backend = [default_backend] 31 | 32 | for backend in default_backend: 33 | assert backend in self.supported_backend, f"{backend} is not supported" 34 | 35 | self.logger = {} 36 | 37 | if "wandb" in default_backend: 38 | import wandb # type: ignore 39 | 40 | wandb.init(project=project_name, name=experiment_name, config=config) 41 | self.logger["wandb"] = wandb 42 | 43 | if "mlflow" in default_backend: 44 | import mlflow # type: ignore 45 | 46 | mlflow.start_run(run_name=experiment_name) 47 | mlflow.log_params(config) 48 | self.logger["mlflow"] = _MlflowLoggingAdapter() 49 | 50 | if "swanlab" in default_backend: 51 | import swanlab # type: ignore 52 | 53 | SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None) 54 | SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog") 55 | SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") 56 | if SWANLAB_API_KEY: 57 | swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten 58 | 59 | swanlab.init( 60 | project=project_name, 61 | experiment_name=experiment_name, 62 | config=config, 63 | logdir=SWANLAB_LOG_DIR, 64 | mode=SWANLAB_MODE, 65 | ) 66 | self.logger["swanlab"] = swanlab 67 | 68 | if "console" in default_backend: 69 | self.console_logger = LocalLogger() 70 | self.logger["console"] = self.console_logger 71 | 72 | def log(self, data, step, backend=None): 73 | for default_backend, logger_instance in self.logger.items(): 74 | if backend is None or default_backend in backend: 75 | logger_instance.log(data=data, step=step) 76 | 77 | def __del__(self): 78 | if "wandb" in self.logger: 79 | self.logger["wandb"].finish(exit_code=0) 80 | 81 | if "swanlab" in self.logger: 82 | self.logger["swanlab"].finish() 83 | 84 | 85 | class _MlflowLoggingAdapter: 86 | def log(self, data, step): 87 | import mlflow # type: ignore 88 | 89 | mlflow.log_metrics(metrics=data, step=step) 90 | 91 | 92 | @dataclass 93 | class ValGenerationsLogger: 94 | def log(self, loggers: List[str], samples: List[Tuple[str, str, float]], step: int): 95 | if "wandb" in loggers: 96 | self.log_generations_to_wandb(samples, step) 97 | if "swanlab" in loggers: 98 | self.log_generations_to_swanlab(samples, step) 99 | 100 | def log_generations_to_wandb(self, samples: List[Tuple[str, str, float]], step: int) -> None: 101 | """Log samples to wandb as a table""" 102 | import wandb # type: ignore 103 | 104 | # Create column names for all samples 105 | columns = ["step"] + sum( 106 | [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] 107 | ) 108 | 109 | if 
not hasattr(self, "validation_table"): 110 | # Initialize the table on first call 111 | self.validation_table = wandb.Table(columns=columns) 112 | 113 | # Create a new table with same columns and existing data 114 | # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 115 | new_table = wandb.Table(columns=columns, data=self.validation_table.data) 116 | 117 | # Add new row with all data 118 | row_data = [] 119 | row_data.append(step) 120 | for sample in samples: 121 | row_data.extend(sample) 122 | 123 | new_table.add_data(*row_data) 124 | 125 | # Update reference and log 126 | wandb.log({"val/generations": new_table}, step=step) 127 | self.validation_table = new_table 128 | 129 | def log_generations_to_swanlab(self, samples: List[Tuple[str, str, float]], step: int) -> None: 130 | """Log samples to swanlab as text""" 131 | import swanlab # type: ignore 132 | 133 | swanlab_text_list = [] 134 | for i, sample in enumerate(samples): 135 | row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}" 136 | swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) 137 | 138 | # Log to swanlab 139 | swanlab.log({"val/generations": swanlab_text_list}, step=step) 140 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/fsdp_vllm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from typing import Dict, Iterable, Tuple, Union 17 | 18 | import torch 19 | import torch.distributed as dist 20 | from torch.distributed._tensor import DTensor 21 | from torch.distributed.device_mesh import DeviceMesh 22 | from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType 23 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP 24 | from vllm import LLM 25 | from vllm.distributed import parallel_state as vllm_ps 26 | 27 | from ...protocol import DataProto, all_gather_data_proto 28 | from ...utils.model_utils import print_gpu_memory_usage 29 | from .base import BaseShardingManager 30 | 31 | 32 | class FSDPVLLMShardingManager(BaseShardingManager): 33 | def __init__( 34 | self, 35 | module: FSDP, 36 | inference_engine: LLM, 37 | device_mesh: DeviceMesh = None, 38 | ): 39 | self.module = module 40 | self.inference_engine = inference_engine 41 | self.device_mesh = device_mesh 42 | with warnings.catch_warnings(): 43 | warnings.simplefilter("ignore") 44 | FSDP.set_state_dict_type( 45 | self.module, 46 | state_dict_type=StateDictType.SHARDED_STATE_DICT, 47 | state_dict_config=ShardedStateDictConfig(), 48 | ) 49 | 50 | self.world_size = dist.get_world_size() 51 | self.tp_size = vllm_ps.get_tensor_model_parallel_world_size() 52 | self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() 53 | self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group 54 | 55 | # Note that torch_random_states may be different on each dp rank 56 | self.torch_random_states = torch.cuda.get_rng_state() 57 | # get a random rng states 58 | if self.device_mesh is not None: 59 | gen_dp_rank = self.device_mesh["dp"].get_local_rank() 60 | torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states 61 | self.gen_random_states = torch.cuda.get_rng_state() 62 | torch.cuda.set_rng_state(self.torch_random_states) 63 | else: 64 | self.gen_random_states = None 65 | 66 | def _make_weight_iterator( 67 | self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]] 68 | ) -> Iterable[Tuple[str, torch.Tensor]]: 69 | for name, tensor in actor_weights.items(): 70 | yield name, tensor.full_tensor() if self.world_size != 1 else tensor 71 | 72 | def __enter__(self): 73 | # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and 74 | # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. 75 | # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory 76 | # to speed up memory allocations. 77 | # 78 | # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management 79 | # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 80 | torch.cuda.empty_cache() 81 | print_gpu_memory_usage("Before state_dict() in sharding manager") 82 | actor_weights = self.module.state_dict() 83 | print_gpu_memory_usage("After state_dict() in sharding manager") 84 | 85 | self.inference_engine.wake_up() 86 | model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model 87 | model.load_weights(self._make_weight_iterator(actor_weights)) 88 | print_gpu_memory_usage("After sync model weights in sharding manager") 89 | 90 | del actor_weights 91 | print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager") 92 | # important: need to manually set the random states of each tp to be identical. 
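# All tp ranks within a dp group were seeded identically in __init__ (gen_random_states),
# so swapping that state in here keeps sampling consistent across tp ranks during rollout;
# __exit__ swaps the original training RNG state back in.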
93 | if self.device_mesh is not None: 94 | self.torch_random_states = torch.cuda.get_rng_state() 95 | torch.cuda.set_rng_state(self.gen_random_states) 96 | 97 | def __exit__(self, exc_type, exc_value, traceback): 98 | print_gpu_memory_usage("Before vllm offload in sharding manager") 99 | self.inference_engine.sleep(level=1) 100 | print_gpu_memory_usage("After vllm offload in sharding manager") 101 | 102 | self.module.train() 103 | torch.cuda.empty_cache() # add empty cache after each compute 104 | 105 | # restore random states 106 | if self.device_mesh is not None: 107 | self.gen_random_states = torch.cuda.get_rng_state() 108 | torch.cuda.set_rng_state(self.torch_random_states) 109 | 110 | def preprocess_data(self, data: DataProto) -> DataProto: 111 | """All gather across tp group to make each rank has identical input.""" 112 | all_gather_data_proto(data, size=self.tp_size, group=self.tp_group) 113 | return data 114 | 115 | def postprocess_data(self, data: DataProto) -> DataProto: 116 | """Get chunk data of this tp rank since we do all gather in preprocess.""" 117 | if self.tp_size > 1: 118 | data = data.chunk(chunks=self.tp_size)[self.tp_rank] 119 | 120 | return data 121 | -------------------------------------------------------------------------------- /eval/utils/data_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from typing import List, Dict 7 | from datasets import load_dataset 8 | 9 | def load_geo3k_dataset(data_path: str) -> List[Dict]: 10 | """Load Geo3K dataset""" 11 | data_path = os.path.join(data_path, "geometry3k/test") 12 | dataset = [] 13 | folders = [f for f in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, f))] 14 | 15 | for folder in tqdm(folders, desc="Loading Geo3K data"): 16 | folder_path = os.path.join(data_path, folder) 17 | image_path = os.path.join(folder_path, "img_diagram.png") 18 | json_path = os.path.join(folder_path, "data.json") 19 | 20 | if not os.path.exists(image_path) or not os.path.exists(json_path): 21 | continue 22 | 23 | with open(json_path, "r", encoding="utf-8") as f: 24 | data = json.load(f) 25 | 26 | mapping = {'A': 0, 'B': 1, 'C': 2, 'D': 3} 27 | 28 | dataset.append({ 29 | "id": data["id"], 30 | "image_path": image_path, 31 | "question": data["annotat_text"], 32 | "answer": data["choices"][mapping[data["answer"]]], 33 | "dataset": "geo3k" 34 | }) 35 | 36 | return dataset 37 | 38 | def load_wemath_dataset(data_path: str) -> List[Dict]: 39 | """Load WeMath dataset""" 40 | image_root = os.path.join(data_path, "wemath/images") 41 | data_path = os.path.join(data_path, "wemath/testmini.json") 42 | with open(data_path, "r", encoding="utf-8") as f: 43 | data = json.load(f) 44 | 45 | dataset = [] 46 | for item in data: 47 | # Determine the image path 48 | image_path = os.path.join(image_root, item["image_path"]) 49 | 50 | dataset.append({ 51 | "id": item["ID"] + "@" + item["key"], 52 | "image_path": image_path, 53 | "question": f"{item['question']}\n\nOptions: {item['option']}", 54 | "answer": item["answer"], 55 | "dataset": "wemath" 56 | }) 57 | 58 | return dataset 59 | 60 | def load_mathvista_dataset(data_path: str) -> List[Dict]: 61 | """Load MathVista dataset""" 62 | image_base_dir = os.path.join(data_path, "mathvista") 63 | dataset_raw = load_dataset("AI4Math/MathVista", split="testmini") 64 | 65 | dataset = [] 66 | mapping = { 67 | "0": "A", "1": "B", "2": "C", "3": "D", 68 | "4": 
"E", "5": "F", "6": "G", "7": "H" 69 | } 70 | 71 | for item in dataset_raw: 72 | if item["question_type"] == "multi_choice": 73 | idx = item["choices"].index(item["answer"]) 74 | answer = mapping[str(idx)] 75 | else: 76 | answer = item["answer"] 77 | 78 | dataset.append({ 79 | "id": item.get("pid", ""), 80 | "image_path": os.path.join(image_base_dir, item["image"]), 81 | "question": item["query"], 82 | "answer": answer, 83 | "task": item["metadata"]["task"], 84 | "dataset": "mathvista" 85 | }) 86 | 87 | return dataset 88 | 89 | def load_mathverse_dataset(data_path: str) -> List[Dict]: 90 | """Load MathVerse dataset""" 91 | image_base_dir = os.path.join(data_path, "mathverse/images") 92 | data_path = os.path.join(data_path, "mathverse/testmini.json") 93 | 94 | with open(data_path, "r", encoding="utf-8") as f: 95 | data = json.load(f) 96 | 97 | dataset = [] 98 | for item in data: 99 | dataset.append({ 100 | "id": item.get("sample_index", ""), 101 | "image_path": os.path.join(image_base_dir, item["image"]), 102 | "question": item["query_cot"], 103 | "question_for_eval": item["question_for_eval"], 104 | "answer": item["answer"], 105 | "problem_version": item["problem_version"], 106 | "dataset": "mathverse" 107 | }) 108 | 109 | return dataset 110 | 111 | def load_mathvision_dataset(data_path: str) -> List[Dict]: 112 | """Load MathVision dataset""" 113 | image_base_dir = os.path.join(data_path, "mathvision/images") 114 | data_path = os.path.join(data_path, "mathvision/MathVision.tsv") 115 | df = pd.read_csv(data_path, sep='\t') 116 | 117 | dataset = [] 118 | for _, row in df.iterrows(): 119 | dataset.append({ 120 | "id": row.get("index", ""), 121 | "image_path": os.path.join(image_base_dir, f"{row['index']}.jpg"), 122 | "question": row["question"], 123 | "answer": row["answer"], 124 | "subject": row.get("category", "unknown"), 125 | "dataset": "mathvision" 126 | }) 127 | 128 | return dataset 129 | 130 | def load_hallubench_dataset(data_path: str) -> List[Dict]: 131 | """Load Hallubench dataset""" 132 | image_base_dir = os.path.join(data_path, "hallubench/images") 133 | data_path = os.path.join(data_path, "hallubench/HallusionBench.json") 134 | 135 | with open(data_path, "r", encoding="utf-8") as f: 136 | data = json.load(f) 137 | 138 | dataset = [] 139 | for item in data: 140 | if not item["filename"]: 141 | continue 142 | 143 | if "?" in item["question"]: 144 | question = item["question"].split("?")[:-1][0] 145 | else: 146 | question = item["question"] 147 | question += "? You final answer can only be \\boxed{yes} or \\boxed{no}." 148 | gt_answer = "yes" if int(item["gt_answer"]) == 1 else "no" 149 | sid, fid, qid = item["set_id"], item["figure_id"], item["question_id"] 150 | dataset.append({ 151 | "id": f"{sid}_{fid}_{qid}", 152 | "image_path": os.path.join(image_base_dir, item["filename"].replace("./", "")), 153 | "question": question, 154 | "question_for_eval": question, 155 | "answer": gt_answer, 156 | "problem_version": item["subcategory"], 157 | "dataset": "hallubench" 158 | }) 159 | 160 | return dataset -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # NoisyRollout: Reinforcing Visual Reasoning with Data Augmentation 4 | [![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2504.13055) [![Hugging Face Collection](https://img.shields.io/badge/Model_&_Dataset-HuggingFace-yellow?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/collections/xyliu6/noisyrollout-67ff992d1cf251087fe021a2) 5 | 6 |
7 | 8 | ## ⚡ Updates 9 | * 18/09/2025: 🎊 NoisyRollout has been accepted to NeurIPS 2025. See you in San Diego! 10 | * 20/05/2025: 🔥 We updated the checkpoints of our trained models (larger model sizes, more training data)! 11 | * 18/04/2025: 🎉 We released our paper, models and codebase. 12 | 13 | ## 🚀 TL;DR 14 |

15 | 16 |

17 | 18 | **NoisyRollout** is a simple and effective data augmentation strategy for the RL training of VLMs that improves visual reasoning through better policy exploration. It introduces *targeted rollout diversity* by mixing rollouts from both clean and moderately distorted images, encouraging the model to learn more robust behaviors. Moreover, a *noise annealing schedule* is implemented to ensure early-stage exploration and late-stage training stability. 19 | 20 | 🎯 **Key Benefits**: 21 | - **No additional cost** — only the rollout strategy is modified 22 | - **Easy to adopt** — no changes to the model architecture or RL objective required 23 | - **Superior generalization** — achieves state-of-the-art results on **5** out-of-domain benchmarks (e.g., **MathVerse: 53.2%**, **HallusionBench: 72.1%**) with just **2.1K** RL samples 24 | 25 | 🫱 No complicated changes — just smarter rollouts and better training! 26 | 27 |
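The sketch below illustrates the core recipe. It is a minimal illustration under stated assumptions, not the repository's actual implementation: `policy_generate`, the Gaussian-noise distortion, and the linear annealing schedule are placeholders standing in for the real augmentation and rollout logic in the training code.

```python
# Minimal sketch of the NoisyRollout idea (illustrative only).
import numpy as np
from PIL import Image


def distort(image: Image.Image, strength: float) -> Image.Image:
    """Apply a mild pixel-level distortion; additive Gaussian noise is used here as an example."""
    arr = np.asarray(image).astype(np.float32)
    noisy = arr + np.random.normal(0.0, 255.0 * strength, size=arr.shape)
    return Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))


def noise_strength(step: int, total_steps: int, init_strength: float = 0.1) -> float:
    """Anneal the distortion strength toward zero: explore early, stabilize late."""
    return init_strength * max(0.0, 1.0 - step / total_steps)


def collect_rollouts(policy_generate, image, prompt, n, step, total_steps):
    """Mix rollouts from the clean image and a distorted copy of it.

    `policy_generate(image, prompt)` stands in for the VLM policy's sampling call.
    Both groups of rollouts are scored and used for advantage estimation together.
    """
    noisy_image = distort(image, noise_strength(step, total_steps))
    clean_rollouts = [policy_generate(image, prompt) for _ in range(n // 2)]
    noisy_rollouts = [policy_generate(noisy_image, prompt) for _ in range(n - n // 2)]
    return clean_rollouts + noisy_rollouts
```

Only how the rollouts are collected changes; the reward, the advantage estimation, and the model itself stay exactly as in vanilla GRPO, which is what keeps the method cheap to adopt.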

28 | 29 |

30 | 31 | ## 🛠️ Usage 32 | ### (Step 1) Install 33 | First, download the wheel of vllm from this [link](https://drive.google.com/file/d/1tO6BQ4omkeXTQhDBTAFi7U7qR8vF55wP/view?usp=sharing). 34 | ```bash 35 | conda create -n noisyrollout python=3.11 -y && conda activate noisyrollout 36 | 37 | pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 transformers==4.49.0 numpy==1.26.4 38 | pip3 install google-generativeai 39 | 40 | # Use this version of vLLM to avoid memory leaks. 41 | pip3 install vllm-0.7.4.dev65+g22757848-cp38-abi3-manylinux1_x86_64.whl 42 | git clone -b verl_v1 https://github.com/hiyouga/vllm.git 43 | cp -r vllm/vllm/ ~/miniconda3/envs/noisyrollout/lib/python3.11/site-packages/ 44 | 45 | pip3 install -e . 46 | ``` 47 | 48 | ### (Step 2) Training 49 | ```bash 50 | # Geo3K (NoisyRollout) 51 | bash training_scripts/qwen2_5_vl_7b_geo3k_noisyrollout.sh 52 | # Geo3K (Vanilla GRPO) 53 | bash training_scripts/qwen2_5_vl_7b_geo3k_grpo.sh 54 | 55 | # K12 (NoisyRollout) 56 | bash training_scripts/qwen2_5_vl_7b_k12_noisyrollout.sh 57 | # K12 (Vanilla GRPO) 58 | bash training_scripts/qwen2_5_vl_7b_k12_grpo.sh 59 | ``` 60 | ### (Step 3) Evaluation 61 | Before running the evaluation, please download the evaluation datasets from [🤗 NoisyRollout Evaluation](https://huggingface.co/datasets/xyliu6/noisyrollout_evaluation_data). Then, create a directory by running `mkdir -p ~/NoisyRollout/eval/data`, upload the `eval_data.zip` file to the `data` folder, and unzip it there. 62 | ```bash 63 | #!/bin/bash 64 | source ~/.bashrc 65 | source ~/miniconda3/bin/activate noisyrollout 66 | 67 | export VLLM_ATTENTION_BACKEND=XFORMERS # remove it when using 32b models 68 | export VLLM_USE_V1=0 # remove it when using 32b models 69 | export GOOGLE_API_KEY="xxx" # put your api key here 70 | 71 | HF_MODEL_PATH="xyliu6/NoisyRollout-Geo3K-7B" 72 | RESULTS_DIR="results/" 73 | EVAL_DIR="~/NoisyRollout/eval" 74 | DATA_DIR="~/NoisyRollout/eval/data" 75 | 76 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.""" 77 | 78 | cd $EVAL_DIR 79 | python main.py \ 80 | --model $HF_MODEL_PATH \ 81 | --output-dir $RESULTS_DIR \ 82 | --data-path $DATA_DIR \ 83 | --datasets geo3k,hallubench,mathvista,wemath,mathverse,mathvision \ 84 | --tensor-parallel-size 2 \ 85 | --system-prompt="$SYSTEM_PROMPT" \ 86 | --min-pixels 262144 \ 87 | --max-pixels 1000000 \ 88 | --max-model-len 8192 \ 89 | --temperature 0.0 \ 90 | --eval-threads 24 \ 91 | --version="7b" # change it to `32b` when using 32b models 92 | ``` 93 | > 🚧 Currently, only `Gemini-2.0-Flash-001` is supported for parsing generated responses. Support for additional models will be introduced in future updates. 94 | 95 | ## Citation 96 | If you find our work useful for your research, please consider citing: 97 | ```bibtex 98 | @article{liu2025noisyrollout, 99 | title={Noisyrollout: Reinforcing visual reasoning with data augmentation}, 100 | author={Liu, Xiangyan and Ni, Jinjie and Wu, Zijian and Du, Chao and Dou, Longxu and Wang, Haonan and Pang, Tianyu and Shieh, Michael Qizhe}, 101 | journal={arXiv preprint arXiv:2504.13055}, 102 | year={2025} 103 | } 104 | ``` 105 | 106 | ## Acknowledgement 107 | * This work is supported by [Sea AI Lab](https://sail.sea.com/) for computing resources.
108 | * The training codes are built on [EasyR1](https://github.com/hiyouga/EasyR1), and the evaluation suite employs [vLLM](https://github.com/vllm-project/vllm) for acceleration. 109 | * The base models are from [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) and [Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct). 110 | * The original training datasets are from [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) and [K12](https://huggingface.co/datasets/FanqingM/MM-Eureka-Dataset). 111 | * The evaluation datasets are from [MathVerse](https://huggingface.co/datasets/AI4Math/MathVerse), [MathVision](https://huggingface.co/datasets/MathLLMs/MathVision), [MathVista](https://huggingface.co/datasets/AI4Math/MathVista), [WeMath](https://huggingface.co/datasets/We-Math/We-Math), and [HallusionBench](https://github.com/tianyi-lab/HallusionBench). 112 | -------------------------------------------------------------------------------- /verl/trainer/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Dict, List 16 | 17 | import numpy as np 18 | import torch 19 | 20 | from ..protocol import DataProto 21 | 22 | 23 | def _compute_response_info(batch: DataProto) -> Dict[str, Any]: 24 | response_length = batch.batch["responses"].shape[-1] 25 | prompt_mask = batch.batch["attention_mask"][:, :-response_length] 26 | response_mask = batch.batch["attention_mask"][:, -response_length:] 27 | prompt_length = prompt_mask.sum(-1).float() 28 | response_length = response_mask.sum(-1).float() # (batch_size,) 29 | return dict( 30 | response_mask=response_mask, 31 | prompt_length=prompt_length, 32 | response_length=response_length, 33 | ) 34 | 35 | 36 | def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: 37 | return {key: np.mean(value) for key, value in metrics.items()} 38 | 39 | 40 | def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]: 41 | sequence_score = batch.batch["token_level_scores"].sum(-1) 42 | sequence_reward = batch.batch["token_level_rewards"].sum(-1) 43 | 44 | advantages = batch.batch["advantages"] 45 | returns = batch.batch["returns"] 46 | 47 | max_response_length = batch.batch["responses"].size(-1) 48 | 49 | prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() 50 | response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() 51 | 52 | max_prompt_length = prompt_mask.size(-1) 53 | 54 | response_info = _compute_response_info(batch) 55 | prompt_length = response_info["prompt_length"] 56 | response_length = response_info["response_length"] 57 | 58 | valid_adv = torch.masked_select(advantages, response_mask) 59 | valid_returns = torch.masked_select(returns, response_mask) 60 | 61 | if use_critic: 62 | values = batch.batch["values"] 63 
| valid_values = torch.masked_select(values, response_mask) 64 | return_diff_var = torch.var(valid_returns - valid_values) 65 | return_var = torch.var(valid_returns) 66 | 67 | metrics = { 68 | # score 69 | "critic/score/mean": torch.mean(sequence_score).detach().item(), 70 | "critic/score/max": torch.max(sequence_score).detach().item(), 71 | "critic/score/min": torch.min(sequence_score).detach().item(), 72 | # reward 73 | "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), 74 | "critic/rewards/max": torch.max(sequence_reward).detach().item(), 75 | "critic/rewards/min": torch.min(sequence_reward).detach().item(), 76 | # adv 77 | "critic/advantages/mean": torch.mean(valid_adv).detach().item(), 78 | "critic/advantages/max": torch.max(valid_adv).detach().item(), 79 | "critic/advantages/min": torch.min(valid_adv).detach().item(), 80 | # returns 81 | "critic/returns/mean": torch.mean(valid_returns).detach().item(), 82 | "critic/returns/max": torch.max(valid_returns).detach().item(), 83 | "critic/returns/min": torch.min(valid_returns).detach().item(), 84 | **( 85 | { 86 | # values 87 | "critic/values/mean": torch.mean(valid_values).detach().item(), 88 | "critic/values/max": torch.max(valid_values).detach().item(), 89 | "critic/values/min": torch.min(valid_values).detach().item(), 90 | # vf explained var 91 | "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), 92 | } 93 | if use_critic 94 | else {} 95 | ), 96 | # response length 97 | "response_length/mean": torch.mean(response_length).detach().item(), 98 | "response_length/max": torch.max(response_length).detach().item(), 99 | "response_length/min": torch.min(response_length).detach().item(), 100 | "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) 101 | .detach() 102 | .item(), 103 | # prompt length 104 | "prompt_length/mean": torch.mean(prompt_length).detach().item(), 105 | "prompt_length/max": torch.max(prompt_length).detach().item(), 106 | "prompt_length/min": torch.min(prompt_length).detach().item(), 107 | "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), 108 | } 109 | return metrics 110 | 111 | 112 | def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: 113 | response_info = _compute_response_info(batch) 114 | num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() 115 | num_response_tokens = torch.sum(response_info["response_length"]).item() 116 | num_overall_tokens = num_prompt_tokens + num_response_tokens 117 | num_tokens_of_section = { 118 | "gen": num_response_tokens, 119 | **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, 120 | } 121 | return { 122 | **{f"timing_s/{name}": value for name, value in timing_raw.items()}, 123 | **{ 124 | f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] 125 | for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) 126 | }, 127 | } 128 | 129 | 130 | def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]: 131 | total_num_tokens = sum(batch.meta_info["global_token_num"]) 132 | time = timing_raw["step"] 133 | return { 134 | "perf/total_num_tokens": total_num_tokens, 135 | "perf/time_per_step": time, 136 | "perf/throughput": total_num_tokens / (time * n_gpus), 137 | } 138 | 
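# Usage sketch (illustrative; assumes a populated DataProto batch and a timing dict):
#
#   metrics = compute_data_metrics(batch, use_critic=False)
#   metrics.update(compute_timing_metrics(batch, timing_raw))
#   metrics.update(compute_throughout_metrics(batch, timing_raw, n_gpus=8))
#   tracker.log(data=metrics, step=global_step)
#
# Worked example for the throughput metric: 1,200,000 prompt+response tokens processed in a
# 30 s step on 8 GPUs gives perf/throughput = 1_200_000 / (30 * 8) = 5000 tokens/s per GPU.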
-------------------------------------------------------------------------------- /eval/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import torch 5 | from vllm import LLM, SamplingParams 6 | from utils.data_loaders import ( 7 | load_geo3k_dataset, 8 | load_wemath_dataset, 9 | load_mathvista_dataset, 10 | load_mathverse_dataset, 11 | load_mathvision_dataset, 12 | load_hallubench_dataset 13 | ) 14 | from utils.processing import ( 15 | prepare_prompts, 16 | process_outputs, 17 | calculate_metrics 18 | ) 19 | 20 | def parse_arguments(): 21 | parser = argparse.ArgumentParser(description="Unified evaluation for multimodal math datasets") 22 | 23 | # Model and runtime parameters 24 | parser.add_argument("--model", type=str, required=True, help="Path to the model") 25 | parser.add_argument("--output-dir", type=str, required=True, help="Directory to save results") 26 | parser.add_argument("--max-tokens", type=int, default=2048, help="Maximum number of tokens to generate") 27 | parser.add_argument("--min-pixels", type=int, default=262144) 28 | parser.add_argument("--max-pixels", type=int, default=1000000) 29 | parser.add_argument("--max-model-len", type=int, default=8192) 30 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature") 31 | parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling") 32 | parser.add_argument("--repetition-penalty", type=float, default=1.0, help="Repetition penalty") 33 | parser.add_argument("--tensor-parallel-size", type=int, default=2, help="Number of GPUs for tensor parallelism") 34 | parser.add_argument("--eval-threads", type=int, default=32, help="Number of threads for evaluation") 35 | parser.add_argument("--system-prompt", type=str, default="You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within tags. 
The final answer MUST BE put in \\boxed{}.", help="System prompt for the model") 36 | parser.add_argument("--version", type=str, default="7b") 37 | 38 | # Dataset selection 39 | parser.add_argument("--datasets", type=str, default="all", help="Comma-separated list of datasets to evaluate: geo3k,wemath,mathvista,mathverse,mathvision or 'all'") 40 | 41 | # Dataset-specific paths 42 | parser.add_argument("--data-path", type=str, default="NoisyRollout/eval/data", help="") 43 | 44 | return parser.parse_args() 45 | 46 | def main(): 47 | args = parse_arguments() 48 | 49 | # Create output directory if it doesn't exist 50 | os.makedirs(args.output_dir, exist_ok=True) 51 | 52 | # Determine which datasets to evaluate 53 | datasets_to_eval = args.datasets.split(",") if args.datasets != "all" else [ 54 | "geo3k", "wemath", "mathvista", "mathverse", "mathvision", "hallubench" 55 | ] 56 | 57 | # Dictionary to store all samples 58 | all_samples = {} 59 | 60 | # Load datasets based on selection 61 | for dataset_name in datasets_to_eval: 62 | if dataset_name == "geo3k": 63 | all_samples["geo3k"] = load_geo3k_dataset(args.data_path) 64 | print(f"Loaded {len(all_samples['geo3k'])} samples from Geo3K") 65 | 66 | elif dataset_name == "wemath": 67 | all_samples["wemath"] = load_wemath_dataset(args.data_path) 68 | print(f"Loaded {len(all_samples['wemath'])} samples from WeMath") 69 | 70 | elif dataset_name == "mathvista": 71 | all_samples["mathvista"] = load_mathvista_dataset(args.data_path) 72 | print(f"Loaded {len(all_samples['mathvista'])} samples from MathVista") 73 | 74 | elif dataset_name == "mathverse": 75 | all_samples["mathverse"] = load_mathverse_dataset(args.data_path) 76 | print(f"Loaded {len(all_samples['mathverse'])} samples from MathVerse") 77 | 78 | elif dataset_name == "mathvision": 79 | all_samples["mathvision"] = load_mathvision_dataset(args.data_path) 80 | print(f"Loaded {len(all_samples['mathvision'])} samples from MathVision") 81 | 82 | elif dataset_name == "hallubench": 83 | all_samples["hallubench"] = load_hallubench_dataset(args.data_path) 84 | print(f"Loaded {len(all_samples['hallubench'])} samples from HalluBench") 85 | 86 | if not all_samples: 87 | print("No datasets loaded. 
Please check the paths and dataset names.") 88 | return 89 | 90 | # Initialize model 91 | print(f"Initializing model from {args.model}") 92 | llm = LLM( 93 | model=args.model, 94 | tensor_parallel_size=args.tensor_parallel_size, 95 | dtype=torch.bfloat16, 96 | gpu_memory_utilization=0.7, 97 | max_model_len=args.max_model_len 98 | ) 99 | 100 | # Configure sampling parameters 101 | sampling_params = SamplingParams( 102 | temperature=args.temperature, 103 | top_p=args.top_p, 104 | max_tokens=args.max_tokens, 105 | repetition_penalty=args.repetition_penalty, 106 | ) 107 | 108 | # Process in batches 109 | all_results = {} 110 | for dataset_name in all_samples.keys(): 111 | all_results[dataset_name] = [] 112 | 113 | for dataset_name, samples in all_samples.items(): 114 | prompts, metadata = prepare_prompts(dataset_name, samples, args) 115 | 116 | outputs = llm.generate(prompts, sampling_params) 117 | 118 | # Process outputs 119 | results = process_outputs(outputs, metadata, args.eval_threads) 120 | all_results[dataset_name] = results 121 | 122 | metrics = calculate_metrics(results) 123 | 124 | output_dict = { 125 | "results": results, 126 | "metrics": metrics, 127 | "config": vars(args) 128 | } 129 | 130 | output_path = os.path.join(args.output_dir, f"{dataset_name}.json") 131 | with open(output_path, 'w', encoding='utf-8') as f: 132 | json.dump(output_dict, f, ensure_ascii=False, indent=2) 133 | 134 | print(f"{dataset_name.upper()} Results:") 135 | print(f" Total samples: {len(results)}") 136 | print(f" Accuracy: {metrics['accuracy']:.4f}") 137 | if 'sub_accuracies' in metrics: 138 | print(" Task/Category Accuracies:") 139 | for task, acc in metrics['sub_accuracies'].items(): 140 | print(f" {task}: {acc:.4f}") 141 | print() 142 | 143 | print(f"All results saved to {args.output_dir}") 144 | 145 | if __name__ == "__main__": 146 | main() -------------------------------------------------------------------------------- /verl/single_controller/base/worker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | the class for Worker 16 | """ 17 | 18 | import os 19 | import socket 20 | from dataclasses import dataclass 21 | 22 | import ray 23 | 24 | from .decorator import Dispatch, Execute, register 25 | from .register_center.ray import create_worker_group_register_center 26 | 27 | 28 | @dataclass 29 | class DistRankInfo: 30 | tp_rank: int 31 | dp_rank: int 32 | pp_rank: int 33 | 34 | 35 | @dataclass 36 | class DistGlobalInfo: 37 | tp_size: int 38 | dp_size: int 39 | pp_size: int 40 | 41 | 42 | class WorkerHelper: 43 | def _get_node_ip(self): 44 | host_ipv4 = os.getenv("MY_HOST_IP", None) 45 | host_ipv6 = os.getenv("MY_HOST_IPV6", None) 46 | host_ip_by_env = host_ipv4 or host_ipv6 47 | host_ip_by_sdk = ray._private.services.get_node_ip_address() 48 | 49 | host_ip = host_ip_by_env or host_ip_by_sdk 50 | return host_ip 51 | 52 | def _get_free_port(self): 53 | with socket.socket() as sock: 54 | sock.bind(("", 0)) 55 | return sock.getsockname()[1] 56 | 57 | def get_availale_master_addr_port(self): 58 | return self._get_node_ip(), str(self._get_free_port()) 59 | 60 | def _get_pid(self): 61 | return 62 | 63 | 64 | class WorkerMeta: 65 | keys = [ 66 | "WORLD_SIZE", 67 | "RANK", 68 | "LOCAL_WORLD_SIZE", 69 | "LOCAL_RANK", 70 | "MASTER_ADDR", 71 | "MASTER_PORT", 72 | "CUDA_VISIBLE_DEVICES", 73 | ] 74 | 75 | def __init__(self, store) -> None: 76 | self._store = store 77 | 78 | def to_dict(self): 79 | return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys} 80 | 81 | 82 | # we assume that in each WorkerGroup, there is a Master Worker 83 | class Worker(WorkerHelper): 84 | def __new__(cls, *args, **kwargs): 85 | instance = super().__new__(cls) 86 | 87 | # note that here we use int to distinguish 88 | disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) 89 | if disable_worker_init: 90 | return instance 91 | 92 | rank = os.environ.get("RANK", None) 93 | worker_group_prefix = os.environ.get("WG_PREFIX", None) 94 | 95 | # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init 96 | if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: 97 | instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) 98 | 99 | return instance 100 | 101 | def _configure_before_init(self, register_center_name: str, rank: int): 102 | assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" 103 | 104 | if rank == 0: 105 | master_addr, master_port = self.get_availale_master_addr_port() 106 | rank_zero_info = { 107 | "MASTER_ADDR": master_addr, 108 | "MASTER_PORT": master_port, 109 | } 110 | self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) 111 | os.environ.update(rank_zero_info) 112 | 113 | def __init__(self, cuda_visible_devices=None) -> None: 114 | # construct a meta from envrionment variable. 
Note that the import must be inside the class because it is executed remotely 115 | world_size = int(os.environ["WORLD_SIZE"]) 116 | rank = int(os.environ["RANK"]) 117 | self._rank = rank 118 | self._world_size = world_size 119 | 120 | master_addr = os.environ["MASTER_ADDR"] 121 | master_port = os.environ["MASTER_PORT"] 122 | 123 | local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) 124 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 125 | 126 | store = { 127 | "_world_size": world_size, 128 | "_rank": rank, 129 | "_local_world_size": local_world_size, 130 | "_local_rank": local_rank, 131 | "_master_addr": master_addr, 132 | "_master_port": master_port, 133 | } 134 | if cuda_visible_devices is not None: 135 | store["_cuda_visible_devices"] = cuda_visible_devices 136 | 137 | meta = WorkerMeta(store=store) 138 | self._configure_with_meta(meta=meta) 139 | 140 | def _configure_with_meta(self, meta: WorkerMeta): 141 | """ 142 | This function should only be called inside by WorkerGroup 143 | """ 144 | assert isinstance(meta, WorkerMeta) 145 | self.__dict__.update(meta.to_dict()) # this is hacky 146 | # print(f"__dict__: {self.__dict__}") 147 | for key in WorkerMeta.keys: 148 | val = self.__dict__.get(f"_{key.lower()}", None) 149 | if val is not None: 150 | # print(f"set {key} to {val}") 151 | os.environ[key] = str(val) 152 | 153 | os.environ["REDIS_STORE_SERVER_HOST"] = ( 154 | str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" 155 | ) 156 | 157 | def get_master_addr_port(self): 158 | return self._master_addr, self._master_port 159 | 160 | def get_cuda_visible_devices(self): 161 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") 162 | return cuda_visible_devices 163 | 164 | def print_rank0(self, *args, **kwargs): 165 | if self.rank == 0: 166 | print(*args, **kwargs) 167 | 168 | @property 169 | def world_size(self): 170 | return self._world_size 171 | 172 | @property 173 | def rank(self): 174 | return self._rank 175 | 176 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) 177 | def execute_with_func_generator(self, func, *args, **kwargs): 178 | ret_proto = func(self, *args, **kwargs) 179 | return ret_proto 180 | 181 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) 182 | def execute_func_rank_zero(self, func, *args, **kwargs): 183 | result = func(*args, **kwargs) 184 | return result 185 | -------------------------------------------------------------------------------- /verl/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
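# Illustrative construction of the dataset below (argument values are placeholders, not project defaults):
#
#   dataset = RLHFDataset(
#       data_path="hiyouga/geometry3k@train",   # local parquet dir/file, or a HF hub dataset[@split]
#       tokenizer=tokenizer,
#       processor=processor,
#       max_prompt_length=2048,
#       system_prompt=SYSTEM_PROMPT,
#       min_pixels=262144,
#       max_pixels=1000000,
#   )
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)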
14 | 15 | import math 16 | import os 17 | from collections import defaultdict 18 | from io import BytesIO 19 | from typing import Any, Dict, List, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from datasets import load_dataset 24 | from PIL import Image 25 | from PIL.Image import Image as ImageObject 26 | from torch.utils.data import Dataset 27 | from transformers import PreTrainedTokenizer, ProcessorMixin 28 | 29 | from ..models.transformers.qwen2_vl import get_rope_index 30 | from . import torch_functional as VF 31 | 32 | 33 | def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]: 34 | tensors = defaultdict(list) 35 | non_tensors = defaultdict(list) 36 | for feature in features: 37 | for key, value in feature.items(): 38 | if isinstance(value, torch.Tensor): 39 | tensors[key].append(value) 40 | else: 41 | non_tensors[key].append(value) 42 | 43 | for key, value in tensors.items(): 44 | tensors[key] = torch.stack(value, dim=0) 45 | 46 | for key, value in non_tensors.items(): 47 | non_tensors[key] = np.array(value, dtype=object) 48 | 49 | return {**tensors, **non_tensors} 50 | 51 | 52 | def process_image(image: Union[Dict[str, Any], ImageObject], max_pixels: int, min_pixels: int) -> ImageObject: 53 | if isinstance(image, dict): 54 | image = Image.open(BytesIO(image["bytes"])) 55 | 56 | if (image.width * image.height) > max_pixels: 57 | resize_factor = math.sqrt(max_pixels / (image.width * image.height)) 58 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 59 | image = image.resize((width, height)) 60 | 61 | if (image.width * image.height) < min_pixels: 62 | resize_factor = math.sqrt(min_pixels / (image.width * image.height)) 63 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 64 | image = image.resize((width, height)) 65 | 66 | if image.mode != "RGB": 67 | image = image.convert("RGB") 68 | 69 | return image 70 | 71 | 72 | class RLHFDataset(Dataset): 73 | """ 74 | We assume the dataset contains a column that contains prompts and other information 75 | """ 76 | 77 | def __init__( 78 | self, 79 | data_path: str, 80 | tokenizer: PreTrainedTokenizer, 81 | processor: Optional[ProcessorMixin], 82 | prompt_key: str = "prompt", 83 | answer_key: str = "answer", 84 | image_key: str = "images", 85 | max_prompt_length: int = 1024, 86 | truncation: str = "error", 87 | system_prompt: str = None, 88 | max_pixels: int = None, 89 | min_pixels: int = None, 90 | ): 91 | self.tokenizer = tokenizer 92 | self.processor = processor 93 | self.prompt_key = prompt_key 94 | self.answer_key = answer_key 95 | self.image_key = image_key 96 | self.max_prompt_length = max_prompt_length 97 | self.truncation = truncation 98 | self.system_prompt = system_prompt 99 | self.max_pixels = max_pixels 100 | self.min_pixels = min_pixels 101 | 102 | if "@" in data_path: 103 | data_path, data_split = data_path.split("@") 104 | else: 105 | data_split = "train" 106 | 107 | if os.path.isdir(data_path): 108 | self.dataset = load_dataset("parquet", data_dir=data_path, split="train") 109 | elif os.path.isfile(data_path): 110 | self.dataset = load_dataset("parquet", data_files=data_path, split="train") 111 | else: # remote dataset 112 | self.dataset = load_dataset(data_path, split=data_split) 113 | 114 | def __len__(self): 115 | return len(self.dataset) 116 | 117 | def __getitem__(self, index): 118 | row_dict: dict = self.dataset[index] 119 | messages = [{"role": "user", "content": row_dict[self.prompt_key]}] 120 | if self.system_prompt: 
121 | messages.insert(0, {"role": "system", "content": self.system_prompt}) 122 | 123 | prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 124 | 125 | if self.image_key in row_dict: 126 | prompt = prompt.replace("", "<|vision_start|><|image_pad|><|vision_end|>") 127 | row_dict["multi_modal_data"] = { 128 | "image": [ 129 | process_image(image, self.max_pixels, self.min_pixels) for image in row_dict.pop(self.image_key) 130 | ] 131 | } 132 | model_inputs = self.processor(row_dict["multi_modal_data"]["image"], prompt, return_tensors="pt") 133 | input_ids = model_inputs.pop("input_ids")[0] 134 | attention_mask = model_inputs.pop("attention_mask")[0] 135 | row_dict["multi_modal_inputs"] = dict(model_inputs) 136 | position_ids = get_rope_index( 137 | self.processor, 138 | input_ids=input_ids, 139 | image_grid_thw=model_inputs["image_grid_thw"], 140 | attention_mask=attention_mask, 141 | ) # (3, seq_length) 142 | else: 143 | model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt") 144 | input_ids = model_inputs.pop("input_ids")[0] 145 | attention_mask = model_inputs.pop("attention_mask")[0] 146 | position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,) 147 | 148 | input_ids, attention_mask, position_ids = VF.postprocess_data( 149 | input_ids=input_ids, 150 | attention_mask=attention_mask, 151 | position_ids=position_ids, 152 | max_length=self.max_prompt_length, 153 | pad_token_id=self.tokenizer.pad_token_id, 154 | left_pad=True, 155 | truncation=self.truncation, 156 | ) 157 | row_dict["input_ids"] = input_ids 158 | row_dict["attention_mask"] = attention_mask 159 | row_dict["position_ids"] = position_ids 160 | row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False) 161 | row_dict["ground_truth"] = row_dict.pop(self.answer_key) 162 | row_dict["processed_prompt"] = prompt 163 | return row_dict 164 | -------------------------------------------------------------------------------- /scripts/model_merger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
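# Typical invocation of this script (paths are illustrative):
#
#   python scripts/model_merger.py \
#       --local_dir path/to/checkpoint_dir \
#       --hf_upload_path your-org/your-model   # optional: also push the merged model to the Hub
#
# It loads every sharded `model_world_size_{world_size}_rank_{rank}.pt` file, merges the
# DTensor shards, and writes a HuggingFace-format model into `<local_dir>/huggingface`,
# which must already contain the saved config from checkpointing.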
14 | 15 | import argparse 16 | import os 17 | import re 18 | from concurrent.futures import ThreadPoolExecutor 19 | from typing import Dict, List, Tuple 20 | 21 | import torch 22 | from torch.distributed._tensor import DTensor, Placement, Shard 23 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq 24 | 25 | 26 | def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): 27 | if placement.is_replicate(): 28 | return tensors[0] 29 | elif placement.is_partial(): 30 | raise NotImplementedError("Partial placement is not supported yet") 31 | elif placement.is_shard(): 32 | return torch.cat(tensors, dim=placement.dim).contiguous() 33 | else: 34 | raise ValueError(f"Unsupported placement: {placement}") 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model") 40 | parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") 41 | args = parser.parse_args() 42 | 43 | assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface" 44 | local_dir = args.local_dir 45 | 46 | # copy rank zero to find the shape of (dp, fsdp) 47 | rank = 0 48 | world_size = 0 49 | for filename in os.listdir(local_dir): 50 | match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) 51 | if match: 52 | world_size = match.group(1) 53 | break 54 | assert world_size, "No model file with the proper format" 55 | 56 | state_dict = torch.load( 57 | os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu" 58 | ) 59 | pivot_key = sorted(state_dict.keys())[0] 60 | weight = state_dict[pivot_key] 61 | assert isinstance(weight, torch.distributed._tensor.DTensor) 62 | # get sharding info 63 | device_mesh = weight.device_mesh 64 | mesh = device_mesh.mesh 65 | mesh_dim_names = device_mesh.mesh_dim_names 66 | 67 | print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") 68 | 69 | assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}" 70 | 71 | if "tp" in mesh_dim_names: 72 | # fsdp * tp 73 | total_shards = mesh.shape[-1] * mesh.shape[-2] 74 | mesh_shape = (mesh.shape[-2], mesh.shape[-1]) 75 | else: 76 | # fsdp 77 | total_shards = mesh.shape[-1] 78 | mesh_shape = (mesh.shape[-1],) 79 | 80 | print(f"Processing model shards with {total_shards} {mesh_shape} in total") 81 | 82 | model_state_dict_lst = [] 83 | model_state_dict_lst.append(state_dict) 84 | model_state_dict_lst.extend([""] * (total_shards - 1)) 85 | 86 | def process_one_shard(rank): 87 | model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") 88 | state_dict = torch.load(model_path, map_location="cpu", weights_only=False) 89 | model_state_dict_lst[rank] = state_dict 90 | return state_dict 91 | 92 | with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: 93 | for rank in range(1, total_shards): 94 | executor.submit(process_one_shard, rank) 95 | state_dict = {} 96 | param_placements: Dict[str, List[Placement]] = {} 97 | keys = set(model_state_dict_lst[0].keys()) 98 | for key in keys: 99 | state_dict[key] = [] 100 | for model_state_dict in model_state_dict_lst: 101 | try: 102 | tensor = model_state_dict.pop(key) 103 | except Exception: 104 | print("-" * 30) 105 | print(model_state_dict) 106 | if isinstance(tensor, DTensor): 107 | 
state_dict[key].append(tensor._local_tensor.bfloat16()) 108 | placements = tuple(tensor.placements) 109 | # replicated placement at dp dimension can be discarded 110 | if mesh_dim_names[0] == "dp": 111 | placements = placements[1:] 112 | if key not in param_placements: 113 | param_placements[key] = placements 114 | else: 115 | assert param_placements[key] == placements 116 | else: 117 | state_dict[key] = tensor.bfloat16() 118 | 119 | del model_state_dict_lst 120 | 121 | for key in sorted(state_dict): 122 | if not isinstance(state_dict[key], list): 123 | print(f"No need to merge key {key}") 124 | continue 125 | # merge shards 126 | placements: Tuple[Shard] = param_placements[key] 127 | if len(mesh_shape) == 1: 128 | # 1-D list, FSDP without TP 129 | assert len(placements) == 1 130 | shards = state_dict[key] 131 | state_dict[key] = merge_by_placement(shards, placements[0]) 132 | else: 133 | # 2-D list, FSDP + TP 134 | raise NotImplementedError("FSDP + TP is not supported yet") 135 | 136 | print("Writing to local disk") 137 | hf_path = os.path.join(local_dir, "huggingface") 138 | config = AutoConfig.from_pretrained(hf_path) 139 | 140 | if "ForTokenClassification" in config.architectures[0]: 141 | auto_model = AutoModelForTokenClassification 142 | elif "ForCausalLM" in config.architectures[0]: 143 | auto_model = AutoModelForCausalLM 144 | elif "ForConditionalGeneration" in config.architectures[0]: 145 | auto_model = AutoModelForVision2Seq 146 | else: 147 | raise NotImplementedError(f"Unknown architecture {config.architectures}") 148 | 149 | with torch.device("meta"): 150 | model = auto_model.from_config(config, torch_dtype=torch.bfloat16) 151 | 152 | model.to_empty(device="cpu") 153 | 154 | print(f"Saving model to {hf_path}") 155 | model.save_pretrained(hf_path, state_dict=state_dict) 156 | del state_dict 157 | del model 158 | if args.hf_upload_path: 159 | # Push to hugging face 160 | from huggingface_hub import HfApi 161 | 162 | api = HfApi() 163 | api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True) 164 | api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model") 165 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/fsdp_checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
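# Illustrative usage inside an FSDP worker (object names are placeholders):
#
#   ckpt = FSDPCheckpointManager(model, optimizer, lr_scheduler, processing_class=tokenizer)
#   ckpt.save_checkpoint(local_path=f"ckpts/global_step_{step}", global_step=step)
#   ...
#   ckpt.load_checkpoint(path="ckpts/global_step_100")   # every rank loads its own shard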
14 | 15 | import os 16 | import warnings 17 | from typing import Union 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 22 | from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType 23 | from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin 24 | 25 | from .checkpoint_manager import BaseCheckpointManager 26 | 27 | 28 | class FSDPCheckpointManager(BaseCheckpointManager): 29 | """ 30 | A checkpoint manager that saves and loads 31 | - model 32 | - optimizer 33 | - lr_scheduler 34 | - extra_states 35 | in a SPMD way. 36 | 37 | We save 38 | - sharded model states and optimizer states 39 | - full lr_scheduler states 40 | - huggingface tokenizer and config for ckpt merge 41 | """ 42 | 43 | def __init__( 44 | self, 45 | model: FSDP, 46 | optimizer: torch.optim.Optimizer, 47 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler, 48 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin], 49 | ): 50 | super().__init__(model, optimizer, lr_scheduler, processing_class) 51 | 52 | def load_checkpoint(self, path: str = None, remove_ckpt_after_load: bool = False): 53 | if path is None: 54 | return 55 | 56 | # every rank download its own checkpoint 57 | local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") 58 | local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") 59 | local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") 60 | print( 61 | f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}" 62 | ) 63 | model_state_dict = torch.load(local_model_path, weights_only=False) 64 | optimizer_state_dict = torch.load(local_optim_path, weights_only=False) 65 | extra_state_dict = torch.load(local_extra_state_path, weights_only=False) 66 | 67 | if remove_ckpt_after_load: 68 | try: 69 | os.remove(local_model_path) 70 | os.remove(local_optim_path) 71 | os.remove(local_extra_state_path) 72 | except Exception as e: 73 | print(f"[rank-{self.rank}]: remove ckpt file after loading failed, exception {e} will be ignored.") 74 | 75 | lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] 76 | state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) 77 | optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True) 78 | with warnings.catch_warnings(): 79 | warnings.simplefilter("ignore") 80 | with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config): 81 | self.model.load_state_dict(model_state_dict) 82 | if self.optimizer is not None: 83 | self.optimizer.load_state_dict(optimizer_state_dict) 84 | 85 | # recover random state 86 | if "rng" in extra_state_dict: 87 | self.load_rng_state(extra_state_dict["rng"]) 88 | 89 | if self.lr_scheduler is not None: 90 | self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) 91 | 92 | def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt: bool = False): 93 | # record the previous global step 94 | self.previous_global_step = global_step 95 | 96 | # remove previous local_path 97 | if remove_previous_ckpt: 98 | self.remove_previous_save_local_path() 99 | 100 | local_path = self.local_mkdir(local_path) 101 | dist.barrier() 102 | 103 | # every rank will save its own model and optim shard 104 | state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) 105 | 
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True) 106 | with warnings.catch_warnings(): 107 | warnings.simplefilter("ignore") 108 | with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config): 109 | model_state_dict = self.model.state_dict() 110 | if self.optimizer is not None: 111 | optimizer_state_dict = self.optimizer.state_dict() 112 | else: 113 | optimizer_state_dict = None 114 | 115 | if self.lr_scheduler is not None: 116 | lr_scheduler_state_dict = self.lr_scheduler.state_dict() 117 | else: 118 | lr_scheduler_state_dict = None 119 | 120 | extra_state_dict = { 121 | "lr_scheduler": lr_scheduler_state_dict, 122 | "rng": self.get_rng_state(), 123 | } 124 | model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") 125 | optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") 126 | extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") 127 | 128 | print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.") 129 | print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.") 130 | print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.") 131 | torch.save(model_state_dict, model_path) 132 | if self.optimizer is not None: 133 | torch.save(optimizer_state_dict, optim_path) 134 | 135 | torch.save(extra_state_dict, extra_path) 136 | 137 | # wait for everyone to dump to local 138 | dist.barrier() 139 | 140 | if self.rank == 0: 141 | hf_local_path = os.path.join(local_path, "huggingface") 142 | os.makedirs(hf_local_path, exist_ok=True) 143 | assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel) 144 | self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path) 145 | self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_local_path) 146 | self.processing_class.save_pretrained(hf_local_path) 147 | 148 | dist.barrier() 149 | self.previous_save_local_path = local_path 150 | -------------------------------------------------------------------------------- /verl/single_controller/base/decorator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
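# Illustrative usage of the register decorator on a Worker subclass (names are placeholders):
#
#   class ActorWorker(Worker):
#       @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
#       def update_actor(self, data: DataProto) -> DataProto:
#           ...  # receives this rank's chunk; outputs are concatenated by the collect_fn
#
# The decorator only attaches dispatch/execute metadata (MAGIC_ATTR) to the method; the
# worker group reads that metadata to decide how to split inputs and gather outputs.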
14 | 15 | from enum import Enum, auto 16 | from functools import wraps 17 | from types import FunctionType 18 | from typing import TYPE_CHECKING, List 19 | 20 | import ray 21 | 22 | from ...protocol import DataProto, DataProtoFuture 23 | 24 | 25 | if TYPE_CHECKING: 26 | from .worker_group import WorkerGroup 27 | 28 | 29 | # here we add a magic number of avoid user-defined function already have this attribute 30 | MAGIC_ATTR = "attrs_3141562937" 31 | 32 | 33 | class Dispatch(Enum): 34 | RANK_ZERO = auto() 35 | ONE_TO_ALL = auto() 36 | ALL_TO_ALL = auto() 37 | DP_COMPUTE = auto() 38 | DP_COMPUTE_PROTO = auto() 39 | DP_COMPUTE_PROTO_WITH_FUNC = auto() 40 | DP_COMPUTE_METRIC = auto() 41 | 42 | 43 | class Execute(Enum): 44 | ALL = 0 45 | RANK_ZERO = 1 46 | 47 | 48 | def _split_args_kwargs_data_proto(chunks, *args, **kwargs): 49 | splitted_args = [] 50 | for arg in args: 51 | assert isinstance(arg, (DataProto, DataProtoFuture)) 52 | splitted_args.append(arg.chunk(chunks=chunks)) 53 | 54 | splitted_kwargs = {} 55 | for key, val in kwargs.items(): 56 | assert isinstance(val, (DataProto, DataProtoFuture)) 57 | splitted_kwargs[key] = val.chunk(chunks=chunks) 58 | 59 | return splitted_args, splitted_kwargs 60 | 61 | 62 | def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs): 63 | args = tuple([arg] * worker_group.world_size for arg in args) 64 | kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} 65 | return args, kwargs 66 | 67 | 68 | def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs): 69 | return args, kwargs 70 | 71 | 72 | def collect_all_to_all(worker_group: "WorkerGroup", output): 73 | return output 74 | 75 | 76 | def _concat_data_proto_or_future(output: List): 77 | # make sure all the elements in output has the same type 78 | for o in output: 79 | assert type(o) is type(output[0]) 80 | 81 | o = output[0] 82 | 83 | if isinstance(o, DataProto): 84 | return DataProto.concat(output) 85 | elif isinstance(o, ray.ObjectRef): 86 | return DataProtoFuture.concat(output) 87 | else: 88 | raise NotImplementedError 89 | 90 | 91 | def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs): 92 | for value in args: 93 | assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size 94 | 95 | for value in kwargs.values(): 96 | assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size 97 | 98 | return args, kwargs 99 | 100 | 101 | def collect_dp_compute(worker_group: "WorkerGroup", output): 102 | assert len(output) == worker_group.world_size 103 | return output 104 | 105 | 106 | def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs): 107 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) 108 | return splitted_args, splitted_kwargs 109 | 110 | 111 | def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs): 112 | assert type(args[0]) is FunctionType # NOTE: The first one args is a function! 113 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) 114 | splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args 115 | return splitted_args_with_func, splitted_kwargs 116 | 117 | 118 | def collect_dp_compute_data_proto(worker_group: "WorkerGroup", output): 119 | for o in output: 120 | assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}." 
121 | 122 | output = collect_dp_compute(worker_group, output) 123 | return _concat_data_proto_or_future(output) 124 | 125 | 126 | def get_predefined_dispatch_fn(dispatch_mode): 127 | predefined_dispatch_mode_fn = { 128 | Dispatch.ONE_TO_ALL: { 129 | "dispatch_fn": dispatch_one_to_all, 130 | "collect_fn": collect_all_to_all, 131 | }, 132 | Dispatch.ALL_TO_ALL: { 133 | "dispatch_fn": dispatch_all_to_all, 134 | "collect_fn": collect_all_to_all, 135 | }, 136 | Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute}, 137 | Dispatch.DP_COMPUTE_PROTO: { 138 | "dispatch_fn": dispatch_dp_compute_data_proto, 139 | "collect_fn": collect_dp_compute_data_proto, 140 | }, 141 | Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { 142 | "dispatch_fn": dispatch_dp_compute_data_proto_with_func, 143 | "collect_fn": collect_dp_compute_data_proto, 144 | }, 145 | Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, 146 | } 147 | return predefined_dispatch_mode_fn[dispatch_mode] 148 | 149 | 150 | def get_predefined_execute_fn(execute_mode): 151 | """ 152 | Note that here we only asks execute_all and execute_rank_zero to be implemented 153 | Leave the choice of how these two functions handle argument 'blocking' to users 154 | """ 155 | predefined_execute_mode_fn = { 156 | Execute.ALL: {"execute_fn_name": "execute_all"}, 157 | Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, 158 | } 159 | return predefined_execute_mode_fn[execute_mode] 160 | 161 | 162 | def _check_dispatch_mode(dispatch_mode): 163 | assert isinstance(dispatch_mode, (Dispatch, dict)), ( 164 | f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" 165 | ) 166 | if isinstance(dispatch_mode, dict): 167 | necessary_keys = ["dispatch_fn", "collect_fn"] 168 | for key in necessary_keys: 169 | assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" 170 | 171 | 172 | def _check_execute_mode(execute_mode): 173 | assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" 174 | 175 | 176 | def _materialize_futures(*args, **kwargs): 177 | new_args = [] 178 | for arg in args: 179 | if isinstance(arg, DataProtoFuture): 180 | arg = arg.get() 181 | # add more type to materialize 182 | new_args.append(arg) 183 | for k, v in kwargs.items(): 184 | if isinstance(v, DataProtoFuture): 185 | kwargs[k] = v.get() 186 | 187 | new_args = tuple(new_args) 188 | return new_args, kwargs 189 | 190 | 191 | def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): 192 | _check_dispatch_mode(dispatch_mode=dispatch_mode) 193 | _check_execute_mode(execute_mode=execute_mode) 194 | 195 | def decorator(func): 196 | @wraps(func) 197 | def inner(*args, **kwargs): 198 | if materialize_futures: 199 | args, kwargs = _materialize_futures(*args, **kwargs) 200 | return func(*args, **kwargs) 201 | 202 | attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} 203 | setattr(inner, MAGIC_ATTR, attrs) 204 | return inner 205 | 206 | return decorator 207 | -------------------------------------------------------------------------------- /verl/single_controller/base/worker_group.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | the class of WorkerGroup 16 | """ 17 | 18 | import logging 19 | import signal 20 | import threading 21 | import time 22 | from typing import Any, Callable, Dict, List 23 | 24 | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn 25 | 26 | 27 | class ResourcePool: 28 | """The resource pool with meta info such as world_size.""" 29 | 30 | def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None: 31 | if process_on_nodes is None: 32 | process_on_nodes = [] 33 | 34 | self._store = process_on_nodes 35 | self.max_collocate_count = max_collocate_count 36 | self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node 37 | 38 | def add_node(self, process_count): 39 | self._store.append(process_count) 40 | 41 | @property 42 | def world_size(self): 43 | return sum(self._store) 44 | 45 | def __call__(self) -> Any: 46 | return self._store 47 | 48 | @property 49 | def store(self): 50 | return self._store 51 | 52 | def local_world_size_list(self) -> List[int]: 53 | nested_local_world_size_list = [ 54 | [local_world_size for _ in range(local_world_size)] for local_world_size in self._store 55 | ] 56 | return [item for row in nested_local_world_size_list for item in row] 57 | 58 | def local_rank_list(self) -> List[int]: 59 | nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] # noqa: C416 60 | return [item for row in nested_local_rank_list for item in row] 61 | 62 | 63 | class ClassWithInitArgs: 64 | """ 65 | This class stores a class constructor and the args/kwargs to construct the class. 66 | It is used to instantiate the remote class. 
67 | """ 68 | 69 | def __init__(self, cls, *args, **kwargs) -> None: 70 | self.cls = cls 71 | self.args = args 72 | self.kwargs = kwargs 73 | 74 | def __call__(self) -> Any: 75 | return self.cls(*self.args, **self.kwargs) 76 | 77 | 78 | def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: 79 | import time 80 | 81 | while True: 82 | for worker in workers: 83 | if not is_alive(worker): 84 | logging.warning(f"Worker {worker} is not alive, sending signal to main thread") 85 | signal.raise_signal(signal.SIGABRT) 86 | 87 | time.sleep(gap_time) 88 | 89 | 90 | class WorkerGroup: 91 | """A group of workers""" 92 | 93 | def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: 94 | self._is_init_with_detached_workers = True if resource_pool is None else False 95 | 96 | if resource_pool is not None: 97 | # handle the case when WorkGroup is attached to an existing one 98 | self._procecss_dispatch_config = resource_pool() 99 | else: 100 | self._procecss_dispatch_config = None 101 | 102 | self._workers = [] 103 | self._worker_names = [] 104 | 105 | self._master_addr = None 106 | self._master_port = None 107 | 108 | self._checker_thread: threading.Thread = None 109 | 110 | def _is_worker_alive(self, worker): 111 | raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") 112 | 113 | def _block_until_all_workers_alive(self) -> None: 114 | while True: 115 | all_state = [self._is_worker_alive(worker) for worker in self._workers] 116 | if False in all_state: 117 | time.sleep(1) 118 | else: 119 | break 120 | 121 | def start_worker_aliveness_check(self, every_n_seconds=1) -> None: 122 | # before starting checking worker aliveness, make sure all workers are already alive 123 | self._block_until_all_workers_alive() 124 | 125 | self._checker_thread = threading.Thread( 126 | target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) 127 | ) 128 | self._checker_thread.start() 129 | 130 | @property 131 | def world_size(self): 132 | return len(self._workers) 133 | 134 | def _bind_worker_method(self, user_defined_cls, func_generator): 135 | """ 136 | Bind the worker method to the WorkerGroup 137 | """ 138 | for method_name in dir(user_defined_cls): 139 | try: 140 | method = getattr(user_defined_cls, method_name) 141 | assert callable(method), f"{method_name} in {user_defined_cls} is not callable" 142 | except Exception: 143 | # if it is a property, it will fail because Class doesn't have instance property 144 | continue 145 | 146 | if hasattr(method, MAGIC_ATTR): 147 | # this method is decorated by register 148 | attribute = getattr(method, MAGIC_ATTR) 149 | assert isinstance(attribute, Dict), f"attribute must be a dictionary. 
Got {type(attribute)}" 150 | assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" 151 | 152 | dispatch_mode = attribute["dispatch_mode"] 153 | execute_mode = attribute["execute_mode"] 154 | blocking = attribute["blocking"] 155 | 156 | # get dispatch fn 157 | if isinstance(dispatch_mode, Dispatch): 158 | # get default dispatch fn 159 | fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) 160 | dispatch_fn = fn["dispatch_fn"] 161 | collect_fn = fn["collect_fn"] 162 | else: 163 | assert isinstance(dispatch_mode, dict) 164 | assert "dispatch_fn" in dispatch_mode 165 | assert "collect_fn" in dispatch_mode 166 | dispatch_fn = dispatch_mode["dispatch_fn"] 167 | collect_fn = dispatch_mode["collect_fn"] 168 | 169 | # get execute_fn_name 170 | execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) 171 | wg_execute_fn_name = execute_mode["execute_fn_name"] 172 | 173 | # get execute_fn from string 174 | try: 175 | execute_fn = getattr(self, wg_execute_fn_name) 176 | assert callable(execute_fn), "execute_fn must be callable" 177 | except Exception: 178 | print(f"execute_fn {wg_execute_fn_name} is invalid") 179 | raise 180 | 181 | # bind a new method to the RayWorkerGroup 182 | func = func_generator( 183 | self, 184 | method_name, 185 | dispatch_fn=dispatch_fn, 186 | collect_fn=collect_fn, 187 | execute_fn=execute_fn, 188 | blocking=blocking, 189 | ) 190 | 191 | try: 192 | setattr(self, method_name, func) 193 | except Exception: 194 | raise ValueError(f"Fail to set method_name {method_name}") 195 | -------------------------------------------------------------------------------- /eval/utils/processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from PIL import Image 4 | from typing import List, Dict, Tuple 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | 8 | from utils.model_parser import llm_eval_score 9 | from mathruler.grader import extract_boxed_content, grade_answer 10 | 11 | def load_image(image_path: str, min_pixels: int, max_pixels: int) -> Image.Image: 12 | """Load and preprocess an image""" 13 | try: 14 | image = Image.open(image_path).convert("RGB") 15 | 16 | # Resize if too large or too small 17 | if (image.width * image.height) > max_pixels: 18 | resize_factor = math.sqrt(max_pixels / (image.width * image.height)) 19 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 20 | image = image.resize((width, height)) 21 | 22 | if (image.width * image.height) < min_pixels: 23 | resize_factor = math.sqrt(min_pixels / (image.width * image.height)) 24 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 25 | image = image.resize((width, height)) 26 | 27 | return image 28 | except Exception as e: 29 | print(f"Error processing image {image_path}: {str(e)}") 30 | return None 31 | 32 | def prepare_prompts(dataset_name: str, samples: List[Dict], args) -> Tuple[List[Dict], List[Dict]]: 33 | """Prepare prompts for all samples""" 34 | prompts = [] 35 | metadata = [] 36 | 37 | for item in tqdm(samples, desc=f"Preparing {dataset_name} prompts"): 38 | # Skip if image doesn't exist 39 | if not os.path.exists(item["image_path"]): 40 | continue 41 | 42 | # Load image 43 | image = load_image(item["image_path"], args.min_pixels, args.max_pixels) 44 | if image is None: 45 | continue 46 | 47 | # Create prompt 48 | if args.version == "7b": 49 | prompt_text = 
f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{item['question']}<|im_end|>\n<|im_start|>assistant\n" 50 | elif args.version == "32b": 51 | prompt_text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{args.system_prompt} {item['question']}<|im_end|>\n<|im_start|>assistant\n" 52 | else: 53 | raise 54 | 55 | prompts.append({ 56 | "prompt": prompt_text, 57 | "multi_modal_data": {"image": image}, 58 | }) 59 | 60 | metadata.append({ 61 | "dataset": dataset_name, 62 | "id": item["id"], 63 | "question": item["question"], 64 | "answer": item["answer"], 65 | "prompt": prompt_text, 66 | **{k: v for k, v in item.items() if k not in ["image_path", "dataset", "id", "question", "answer"]} 67 | }) 68 | 69 | return prompts, metadata 70 | 71 | def evaluate_prediction(prediction: str, answer: str, dataset: str, question: str = "") -> float: 72 | """Evaluate a prediction against the ground truth""" 73 | if dataset == "geo3k": 74 | extracted_answer = extract_boxed_content(prediction) 75 | return 1.0 if grade_answer(extracted_answer, answer) else 0.0 76 | 77 | elif dataset == "mathvista" or dataset == "mathverse" or dataset == "mathvision" or dataset == "wemath": 78 | try: 79 | score = llm_eval_score(question, prediction, answer, dataset) 80 | except: 81 | import time 82 | time.sleep(10) 83 | score = llm_eval_score(question, prediction, answer, dataset) 84 | return score 85 | 86 | if dataset == "hallubench": 87 | prediction = prediction.replace("\\boxed{}", "") 88 | extracted_answer = extract_boxed_content(prediction) 89 | return 1.0 if extracted_answer.lower() == answer else 0.0 90 | # return 1.0 if answer.lower() in prediction.lower() else 0.0 91 | 92 | else: 93 | # Default evaluation 94 | return 1.0 if extracted_answer == answer else 0.0 95 | 96 | def process_outputs(outputs, metadata, max_workers: int) -> Dict[str, List[Dict]]: 97 | """Process model outputs and calculate metrics""" 98 | results = [] 99 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 100 | futures = [] 101 | 102 | for i, output in enumerate(outputs): 103 | prediction = output.outputs[0].text.strip() 104 | meta = metadata[i] 105 | dataset = meta["dataset"] 106 | if "question_for_eval" in meta: 107 | question = meta["question_for_eval"] 108 | else: 109 | question = meta["question"] 110 | 111 | future = executor.submit( 112 | evaluate_prediction, 113 | prediction, 114 | meta["answer"], 115 | dataset, 116 | question 117 | ) 118 | futures.append((future, i, prediction, meta)) 119 | 120 | for future, i, prediction, meta in tqdm(futures, desc="Evaluating predictions"): 121 | try: 122 | accuracy = future.result() 123 | 124 | result = { 125 | "id": meta["id"], 126 | "question": meta["question"], 127 | "answer": meta["answer"], 128 | "prediction": prediction, 129 | "accuracy": accuracy, 130 | "correct": accuracy > 0, 131 | **{k: v for k, v in meta.items() if k not in ["dataset", "id", "question", "answer"]} 132 | } 133 | 134 | results.append(result) 135 | except Exception as e: 136 | print(f"Error evaluating prediction {i}: {str(e)}") 137 | 138 | return results 139 | 140 | def calculate_metrics(results: List[Dict]) -> Dict: 141 | """Calculate evaluation metrics""" 142 | if not results: 143 | return {"accuracy": 0.0} 144 | 145 | accuracy = sum(1 for r in results if r["correct"]) / len(results) 146 | metrics = {"accuracy": accuracy} 147 | 148 | # Calculate task-specific accuracies if available 
149 | if any("task" in r for r in results): 150 | task_results = {} 151 | for r in results: 152 | if "task" in r: 153 | task = r["task"] 154 | if task not in task_results: 155 | task_results[task] = [] 156 | task_results[task].append(r["correct"]) 157 | 158 | task_accuracies = {task: sum(results) / len(results) for task, results in task_results.items()} 159 | metrics["sub_accuracies"] = task_accuracies 160 | 161 | # Calculate problem version accuracies if available 162 | if any("problem_version" in r for r in results): 163 | version_results = {} 164 | for r in results: 165 | if "problem_version" in r: 166 | version = r["problem_version"] 167 | if version not in version_results: 168 | version_results[version] = [] 169 | version_results[version].append(r["correct"]) 170 | 171 | version_accuracies = {version: sum(results) / len(results) for version, results in version_results.items()} 172 | metrics["sub_accuracies"] = version_accuracies 173 | 174 | # Calculate subject accuracies if available 175 | if any("subject" in r for r in results): 176 | subject_results = {} 177 | for r in results: 178 | if "subject" in r: 179 | subject = r["subject"] 180 | if subject not in subject_results: 181 | subject_results[subject] = [] 182 | subject_results[subject].append(r["correct"]) 183 | 184 | subject_accuracies = {subject: sum(results) / len(results) for subject, results in subject_results.items()} 185 | metrics["sub_accuracies"] = subject_accuracies 186 | 187 | return metrics -------------------------------------------------------------------------------- /verl/models/transformers/qwen2_vl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team 2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 3 | # Based on: 4 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from typing import Optional, Tuple 19 | 20 | import torch 21 | 22 | from .flash_attention_utils import flash_attention_forward 23 | 24 | 25 | try: 26 | from transformers.models.qwen2_vl.modeling_qwen2_vl import ( 27 | Qwen2VLAttention, 28 | apply_multimodal_rotary_pos_emb, 29 | repeat_kv, 30 | ) 31 | from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor 32 | except ImportError: 33 | pass 34 | 35 | 36 | def get_rope_index( 37 | processor: "Qwen2VLProcessor", 38 | input_ids: torch.Tensor, 39 | image_grid_thw: Optional[torch.Tensor] = None, 40 | video_grid_thw: Optional[torch.Tensor] = None, 41 | second_per_grid_ts: Optional[torch.Tensor] = None, 42 | attention_mask: Optional[torch.Tensor] = None, 43 | ) -> torch.Tensor: 44 | """ 45 | Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. 46 | The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. 
47 | https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546 48 | """ 49 | spatial_merge_size = processor.image_processor.merge_size 50 | tokens_per_second = 2 51 | image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") 52 | video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") 53 | vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") 54 | if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): 55 | if attention_mask is None: 56 | attention_mask = torch.ones_like(input_ids) 57 | 58 | position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) 59 | image_index, video_index = 0, 0 60 | input_ids = input_ids[attention_mask == 1] 61 | image_nums, video_nums = 0, 0 62 | vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) 63 | vision_tokens = input_ids[vision_start_indices + 1] 64 | image_nums = (vision_tokens == image_token_id).sum() 65 | video_nums = (vision_tokens == video_token_id).sum() 66 | input_tokens = input_ids.tolist() 67 | llm_pos_ids_list: list = [] 68 | st = 0 69 | remain_images, remain_videos = image_nums, video_nums 70 | for _ in range(image_nums + video_nums): 71 | if image_token_id in input_tokens and remain_images > 0: 72 | ed_image = input_tokens.index(image_token_id, st) 73 | else: 74 | ed_image = len(input_tokens) + 1 75 | if video_token_id in input_tokens and remain_videos > 0: 76 | ed_video = input_tokens.index(video_token_id, st) 77 | else: 78 | ed_video = len(input_tokens) + 1 79 | if ed_image < ed_video: 80 | t, h, w = ( 81 | image_grid_thw[image_index][0], 82 | image_grid_thw[image_index][1], 83 | image_grid_thw[image_index][2], 84 | ) 85 | second_per_grid_t = 0 86 | image_index += 1 87 | remain_images -= 1 88 | ed = ed_image 89 | else: 90 | t, h, w = ( 91 | video_grid_thw[video_index][0], 92 | video_grid_thw[video_index][1], 93 | video_grid_thw[video_index][2], 94 | ) 95 | if second_per_grid_ts is not None: 96 | second_per_grid_t = second_per_grid_ts[video_index] 97 | else: 98 | second_per_grid_t = 1.0 99 | 100 | video_index += 1 101 | remain_videos -= 1 102 | ed = ed_video 103 | 104 | llm_grid_t, llm_grid_h, llm_grid_w = ( 105 | t.item(), 106 | h.item() // spatial_merge_size, 107 | w.item() // spatial_merge_size, 108 | ) 109 | text_len = ed - st 110 | 111 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 112 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) 113 | 114 | t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) 115 | t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() 116 | h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() 117 | w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() 118 | llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) 119 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 120 | 121 | if st < len(input_tokens): 122 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 123 | text_len = len(input_tokens) - st 124 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) 125 | 126 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 127 | position_ids[..., attention_mask == 1] = 
llm_positions.to(position_ids.device) 128 | else: 129 | if attention_mask is not None: 130 | position_ids = attention_mask.long().cumsum(-1) - 1 131 | position_ids.masked_fill_(attention_mask == 0, 1) 132 | position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) 133 | else: 134 | position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) 135 | 136 | return position_ids 137 | 138 | 139 | def qwen2_vl_attn_forward( 140 | self: "Qwen2VLAttention", 141 | hidden_states: torch.Tensor, 142 | attention_mask: Optional[torch.Tensor] = None, 143 | position_ids: Optional[torch.LongTensor] = None, 144 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 145 | **kwargs, 146 | ) -> Tuple[torch.Tensor, None, None]: 147 | bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size 148 | query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) 149 | key_states = self.k_proj(hidden_states) 150 | value_states = self.v_proj(hidden_states) 151 | 152 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 153 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 154 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 155 | 156 | # Because the input can be padded, the absolute sequence length depends on the max position id. 157 | if position_embeddings is None: 158 | cos, sin = self.rotary_emb(value_states, position_ids) 159 | else: 160 | cos, sin = position_embeddings 161 | 162 | query_states, key_states = apply_multimodal_rotary_pos_emb( 163 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] 164 | ) 165 | key_states = repeat_kv(key_states, self.num_key_value_groups) 166 | value_states = repeat_kv(value_states, self.num_key_value_groups) 167 | dropout_rate = 0.0 if not self.training else self.attention_dropout 168 | 169 | sliding_window = None 170 | if ( 171 | self.config.use_sliding_window 172 | and getattr(self.config, "sliding_window", None) is not None 173 | and self.layer_idx >= self.config.max_window_layers 174 | ): 175 | sliding_window = self.config.sliding_window 176 | 177 | attn_output, _ = flash_attention_forward( 178 | self, 179 | query_states, 180 | key_states, 181 | value_states, 182 | attention_mask, 183 | dropout=dropout_rate, 184 | sliding_window=sliding_window, 185 | position_ids=position_ids, # important: pass position ids 186 | ) # (batch_size, seq_length, num_head / sp_size, head_size) 187 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 188 | attn_output = self.o_proj(attn_output) 189 | return attn_output, None, None 190 | -------------------------------------------------------------------------------- /verl/models/transformers/flash_attention_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team 2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 3 | # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import inspect 18 | import os 19 | from typing import Optional, Tuple 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check 24 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 25 | 26 | from ...utils.ulysses import ( 27 | gather_heads_scatter_seq, 28 | gather_seq_scatter_heads, 29 | get_ulysses_sequence_parallel_group, 30 | get_ulysses_sequence_parallel_world_size, 31 | ) 32 | 33 | 34 | if is_flash_attn_2_available(): 35 | from flash_attn import flash_attn_func, flash_attn_varlen_func 36 | 37 | _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters 38 | _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters 39 | _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" 40 | _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 41 | 42 | 43 | def prepare_fa2_from_position_ids( 44 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor 45 | ): 46 | query = query.view(-1, query.size(-2), query.size(-1)) 47 | key = key.contiguous().view(-1, key.size(-2), key.size(-1)) 48 | value = value.contiguous().view(-1, value.size(-2), value.size(-1)) 49 | position_ids = position_ids.flatten() 50 | indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) 51 | cu_seqlens = torch.cat( 52 | ( 53 | indices_q[position_ids == 0], 54 | torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), 55 | ) 56 | ) 57 | max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope 58 | return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) 59 | 60 | 61 | def _custom_flash_attention_forward( 62 | query_states: torch.Tensor, 63 | key_states: torch.Tensor, 64 | value_states: torch.Tensor, 65 | attention_mask: Optional[torch.Tensor], 66 | query_length: int, 67 | is_causal: bool = True, 68 | position_ids: Optional[torch.Tensor] = None, 69 | sliding_window: Optional[int] = None, 70 | use_top_left_mask: bool = False, 71 | deterministic: Optional[bool] = None, 72 | **kwargs, 73 | ): 74 | """ 75 | Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) 76 | """ 77 | if not use_top_left_mask: 78 | causal = is_causal 79 | else: 80 | causal = is_causal and query_length != 1 81 | 82 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 
83 | use_sliding_windows = ( 84 | _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window 85 | ) 86 | flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} 87 | 88 | if _flash_supports_deterministic: 89 | if deterministic is None: 90 | deterministic = _flash_deterministic_enabled 91 | flash_kwargs["deterministic"] = deterministic 92 | 93 | if kwargs.get("softcap") is not None: 94 | flash_kwargs["softcap"] = kwargs.pop("softcap") 95 | 96 | query_states, key_states, value_states = fa_peft_integration_check( 97 | query_states, key_states, value_states, target_dtype=torch.bfloat16 98 | ) 99 | 100 | sp_size = get_ulysses_sequence_parallel_world_size() 101 | if sp_size > 1: 102 | # (batch_size, seq_length, num_head, head_size) 103 | query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) 104 | key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) 105 | value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) 106 | position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)] 107 | position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group()) 108 | position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length) 109 | 110 | if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope 111 | position_ids = position_ids[0] 112 | 113 | if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): 114 | batch_size = query_states.size(0) 115 | query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( 116 | query_states, key_states, value_states, position_ids 117 | ) # remove channel dimension 118 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 119 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 120 | attn_output = flash_attn_varlen_func( 121 | query_states, 122 | key_states, 123 | value_states, 124 | cu_seqlens_q=cu_seqlens_q, 125 | cu_seqlens_k=cu_seqlens_k, 126 | max_seqlen_q=max_seqlen_in_batch_q, 127 | max_seqlen_k=max_seqlen_in_batch_k, 128 | dropout_p=kwargs.pop("dropout", 0.0), 129 | softmax_scale=kwargs.pop("softmax_scale", None), 130 | causal=causal, 131 | **flash_kwargs, 132 | ) 133 | attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) 134 | else: 135 | attn_output = _flash_attention_forward( 136 | query_states, 137 | key_states, 138 | value_states, 139 | attention_mask, 140 | query_length, 141 | is_causal=is_causal, 142 | sliding_window=sliding_window, 143 | use_top_left_mask=use_top_left_mask, 144 | deterministic=deterministic, 145 | **kwargs, 146 | ) # do not pass position_ids to old flash_attention_forward 147 | 148 | if sp_size > 1: 149 | # (batch_size, seq_length, num_head, head_size) 150 | attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) 151 | 152 | return attn_output 153 | 154 | 155 | def flash_attention_forward( 156 | module: torch.nn.Module, 157 | query: torch.Tensor, 158 | key: torch.Tensor, 159 | value: torch.Tensor, 160 | attention_mask: Optional[torch.Tensor], 161 | dropout: float = 0.0, 162 | scaling: Optional[float] = None, 163 | sliding_window: Optional[int] = None, 164 | softcap: Optional[float] = None, 165 | **kwargs, 166 | ) -> Tuple[torch.Tensor, None]: 167 | # This is before the transpose 168 | q_len = query.shape[2] 169 | 170 | # FA2 uses non-transposed inputs 171 | query = 
query.transpose(1, 2) 172 | key = key.transpose(1, 2) 173 | value = value.transpose(1, 2) 174 | 175 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 176 | # therefore the input hidden states gets silently casted in float32. Hence, we need 177 | # cast them back in the correct dtype just to be sure everything works as expected. 178 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 179 | # in fp32. (usually our RMSNorm modules handle it correctly) 180 | target_dtype = None 181 | if query.dtype == torch.float32: 182 | if torch.is_autocast_enabled(): 183 | target_dtype = torch.get_autocast_gpu_dtype() 184 | # Handle the case where the model is quantized 185 | elif hasattr(module.config, "_pre_quantization_dtype"): 186 | target_dtype = module.config._pre_quantization_dtype 187 | else: 188 | target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype 189 | 190 | # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice 191 | kwargs.pop("is_causal", None) 192 | 193 | attn_output = _custom_flash_attention_forward( 194 | query, 195 | key, 196 | value, 197 | attention_mask, 198 | query_length=q_len, 199 | is_causal=True, 200 | dropout=dropout, 201 | softmax_scale=scaling, 202 | sliding_window=sliding_window, 203 | softcap=softcap, 204 | use_top_left_mask=_flash_use_top_left_mask, 205 | target_dtype=target_dtype, 206 | **kwargs, 207 | ) 208 | 209 | return attn_output, None 210 | -------------------------------------------------------------------------------- /verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
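# A minimal sketch of how prepare_fa2_from_position_ids (flash_attention_utils.py above)
# derives cu_seqlens for flash_attn_varlen_func: packed sequences are detected by the
# positions where position_id == 0. The tensor below is a toy example, not real data.
import torch

position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4])  # three packed sequences
indices = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seqlens = torch.cat(
    (indices[position_ids == 0], torch.tensor([position_ids.size(0)], dtype=torch.int32))
)
print(cu_seqlens)               # tensor([ 0,  4,  7, 12], dtype=torch.int32)
print(cu_seqlens.diff().max())  # 5 -> the max_seqlen passed to flash_attn_varlen_func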
14 | """ 15 | The vllm_rollout that can be applied in different backend 16 | When working with FSDP: 17 | - Use DTensor weight loader (recommended) or HF weight loader 18 | - Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM 19 | """ 20 | 21 | from contextlib import contextmanager 22 | from typing import Any, List, Union 23 | 24 | import numpy as np 25 | import torch 26 | import torch.distributed 27 | from tensordict import TensorDict 28 | from transformers import PreTrainedTokenizer 29 | from vllm import LLM, RequestOutput, SamplingParams 30 | 31 | from ....protocol import DataProto 32 | from ....utils import torch_functional as VF 33 | from ....utils.torch_dtypes import PrecisionType 34 | from ..base import BaseRollout 35 | from ..config import RolloutConfig 36 | 37 | 38 | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]: 39 | if isinstance(value, torch.Tensor): 40 | return value.repeat_interleave(repeats, dim=0) 41 | else: 42 | return np.repeat(value, repeats, axis=0) 43 | 44 | 45 | class vLLMRollout(BaseRollout): 46 | def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer): 47 | """A vLLM rollout. It requires the module is supported by the vllm. 48 | 49 | Args: 50 | module: module here follows huggingface APIs 51 | config: DictConfig 52 | tokenizer: the task/model tokenizer 53 | """ 54 | super().__init__() 55 | self.config = config 56 | self.pad_token_id = tokenizer.pad_token_id 57 | if config.tensor_parallel_size > torch.distributed.get_world_size(): 58 | raise ValueError("Tensor parallelism size should be less than world size.") 59 | 60 | if not config.enforce_eager and config.free_cache_engine: 61 | raise ValueError("CUDA graph should be disabled when `free_cache_engine` is True.") 62 | 63 | if config.max_num_batched_tokens < config.prompt_length + config.response_length: 64 | raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.") 65 | 66 | vllm_init_kwargs = {} 67 | if config.limit_images > 0: 68 | vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}} 69 | 70 | self.inference_engine = LLM( 71 | model=model_path, 72 | skip_tokenizer_init=False, 73 | tensor_parallel_size=config.tensor_parallel_size, 74 | dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)), 75 | gpu_memory_utilization=config.gpu_memory_utilization, 76 | enforce_eager=config.enforce_eager, 77 | max_model_len=config.prompt_length + config.response_length, 78 | max_num_batched_tokens=config.max_num_batched_tokens, 79 | enable_sleep_mode=True, 80 | distributed_executor_backend="external_launcher", 81 | disable_custom_all_reduce=True, 82 | disable_log_stats=config.disable_log_stats, 83 | enable_chunked_prefill=config.enable_chunked_prefill, 84 | **vllm_init_kwargs, 85 | ) 86 | 87 | # Offload vllm model to reduce peak memory usage 88 | self.inference_engine.sleep(level=1) 89 | 90 | sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False} 91 | default_sampling_params = SamplingParams() 92 | for key in config.to_dict().keys(): 93 | if hasattr(default_sampling_params, key): 94 | sampling_kwargs[key] = getattr(config, key) 95 | 96 | print(f"Sampling params: {sampling_kwargs}.") 97 | self.sampling_params = SamplingParams(**sampling_kwargs) 98 | 99 | @contextmanager 100 | def update_sampling_params(self, **kwargs): 101 | # update sampling params 102 | old_sampling_params_args = {} 103 | if kwargs: 104 | 
for key, value in kwargs.items(): 105 | if hasattr(self.sampling_params, key): 106 | old_value = getattr(self.sampling_params, key) 107 | old_sampling_params_args[key] = old_value 108 | setattr(self.sampling_params, key, value) 109 | 110 | yield 111 | # roll back to previous sampling params 112 | for key, value in old_sampling_params_args.items(): 113 | setattr(self.sampling_params, key, value) 114 | 115 | @torch.no_grad() 116 | def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: 117 | # left-padded attention_mask 118 | input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length) 119 | attention_mask: torch.Tensor = prompts.batch["attention_mask"] 120 | position_ids: torch.Tensor = prompts.batch["position_ids"] 121 | eos_token_id: int = prompts.meta_info["eos_token_id"] 122 | batch_size = input_ids.size(0) 123 | 124 | do_sample = prompts.meta_info.get("do_sample", True) 125 | if not do_sample: 126 | kwargs = { 127 | "n": 1, 128 | "temperature": 0.0, 129 | "top_p": 1.0, 130 | "top_k": -1, 131 | "min_p": 0.0, 132 | } 133 | 134 | non_tensor_batch = prompts.non_tensor_batch 135 | if batch_size != len(non_tensor_batch["raw_prompt_ids"]): 136 | raise RuntimeError("vllm sharding manager is not working properly.") 137 | 138 | if "multi_modal_data" in non_tensor_batch: 139 | vllm_inputs = [] 140 | for raw_prompt_ids, multi_modal_data in zip( 141 | non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data") 142 | ): 143 | vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data}) 144 | else: 145 | vllm_inputs = [ 146 | {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") 147 | ] 148 | 149 | # users can customize different sampling_params for different runs 150 | with self.update_sampling_params(**kwargs): 151 | completions: List[RequestOutput] = self.inference_engine.generate( 152 | prompts=vllm_inputs, sampling_params=self.sampling_params 153 | ) 154 | 155 | response_ids = [] 156 | for completion in completions: 157 | for output in completion.outputs: 158 | response_ids.append(output.token_ids) 159 | 160 | response_ids = VF.pad_2d_list_to_length( 161 | response_ids, self.pad_token_id, max_length=self.config.response_length 162 | ).to(input_ids.device) 163 | 164 | if self.config.n > 1 and do_sample: 165 | batch_size = batch_size * self.config.n 166 | input_ids = _repeat_interleave(input_ids, self.config.n) 167 | attention_mask = _repeat_interleave(attention_mask, self.config.n) 168 | position_ids = _repeat_interleave(position_ids, self.config.n) 169 | if "multi_modal_inputs" in non_tensor_batch.keys(): 170 | non_tensor_batch["multi_modal_inputs"] = _repeat_interleave( 171 | non_tensor_batch["multi_modal_inputs"], self.config.n 172 | ) 173 | 174 | sequence_ids = torch.cat([input_ids, response_ids], dim=-1) 175 | response_length = response_ids.size(1) 176 | delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) 177 | delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1) 178 | if position_ids.dim() == 3: # qwen2vl mrope 179 | delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) 180 | 181 | # prompt: left pad + response: right pad 182 | # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0] 183 | # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11] 184 | response_position_ids = position_ids[..., -1:] + delta_position_id 185 | position_ids = torch.cat([position_ids, response_position_ids],
dim=-1) 186 | response_attention_mask = VF.get_eos_mask( 187 | response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype 188 | ) 189 | attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) 190 | 191 | # all the tp ranks should contain the same data here. data in all ranks are valid 192 | batch = TensorDict( 193 | { 194 | "prompts": input_ids, 195 | "responses": response_ids, 196 | "input_ids": sequence_ids, # here input_ids become the whole sentences 197 | "attention_mask": attention_mask, 198 | "position_ids": position_ids, 199 | }, 200 | batch_size=batch_size, 201 | ) 202 | return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) 203 | --------------------------------------------------------------------------------
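# A minimal, self-contained sketch of the save/override/restore pattern that
# vLLMRollout.update_sampling_params (above) implements on its SamplingParams. The override
# values mirror the greedy-decoding branch used when do_sample=False; the try/finally is
# added here for safety and is not part of the original method.
from contextlib import contextmanager

from vllm import SamplingParams


@contextmanager
def override_sampling_params(sampling_params: SamplingParams, **kwargs):
    old_values = {}
    for key, value in kwargs.items():
        if hasattr(sampling_params, key):
            old_values[key] = getattr(sampling_params, key)
            setattr(sampling_params, key, value)
    try:
        yield sampling_params
    finally:
        for key, value in old_values.items():  # roll back to the previous sampling params
            setattr(sampling_params, key, value)


params = SamplingParams(temperature=1.0, top_p=0.99, max_tokens=2048)
with override_sampling_params(params, n=1, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0):
    print(params.temperature, params.top_p)  # 0.0 1.0 -> greedy decoding for evaluation
print(params.temperature, params.top_p)      # 1.0 0.99 -> sampling settings restored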