├── assets
│   ├── wechat.jpg
│   ├── easyr1_grpo.png
│   ├── traj_reward.png
│   └── qwen2_5_vl_7b_geo.png
├── .gitmodules
├── examples
│   ├── runtime_env.yaml
│   ├── baselines
│   │   ├── qwen2_5_vl_3b_geoqa8k.sh
│   │   └── qwen2_5_vl_3b_clevr.sh
│   ├── osworld_full_arpo.sh
│   ├── osworld_subset32.sh
│   └── config.yaml
├── requirements.txt
├── verl
│   ├── utils
│   │   ├── __init__.py
│   │   ├── logger
│   │   │   ├── __init__.py
│   │   │   ├── gen_logger.py
│   │   │   └── logger.py
│   │   ├── checkpoint
│   │   │   ├── __init__.py
│   │   │   ├── checkpoint_manager.py
│   │   │   └── fsdp_checkpoint_manager.py
│   │   ├── reward_score
│   │   │   ├── __init__.py
│   │   │   ├── math.py
│   │   │   └── r1v.py
│   │   ├── tokenizer.py
│   │   ├── torch_dtypes.py
│   │   ├── model_utils.py
│   │   ├── py_functional.py
│   │   ├── fsdp_utils.py
│   │   ├── flops_counter.py
│   │   ├── dataset.py
│   │   └── seqlen_balancing.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   └── flash_attention_utils.py
│   │   └── monkey_patch.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── replay_buffer.py
│   │   ├── config.py
│   │   ├── main.py
│   │   └── metrics.py
│   ├── workers
│   │   ├── __init__.py
│   │   ├── reward
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── custom.py
│   │   ├── rollout
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── config.py
│   │   │   └── vllm_rollout_spmd.py
│   │   ├── critic
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── config.py
│   │   │   └── dp_critic.py
│   │   ├── sharding_manager
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── fsdp_ulysses.py
│   │   │   └── fsdp_vllm.py
│   │   ├── actor
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── config.py
│   │   └── config.py
│   ├── single_controller
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── register_center
│   │   │   │   ├── __init__.py
│   │   │   │   └── ray.py
│   │   │   ├── __init__.py
│   │   │   ├── worker.py
│   │   │   ├── worker_group.py
│   │   │   └── decorator.py
│   │   └── ray
│   │       └── __init__.py
│   └── __init__.py
├── .pre-commit-config.yaml
├── start_server.sh
├── pyproject.toml
├── .github
│   ├── workflows
│   │   └── tests.yml
│   ├── CONTRIBUTING.md
│   └── CODE_OF_CONDUCT.md
├── setup.py
├── .gitignore
├── README.md
└── scripts
    └── model_merger.py

/assets/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/wechat.jpg
--------------------------------------------------------------------------------
/assets/easyr1_grpo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/easyr1_grpo.png
--------------------------------------------------------------------------------
/assets/traj_reward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/traj_reward.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "OSWorld"]
2 | 	path = OSWorld
3 | 	url = https://github.com/FanbinLu/OSWorld
--------------------------------------------------------------------------------
/assets/qwen2_5_vl_7b_geo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/qwen2_5_vl_7b_geo.png
--------------------------------------------------------------------------------
/examples/runtime_env.yaml:
--------------------------------------------------------------------------------
1 | working_dir: ./
2 | excludes: ["/.git/"]
3 | env_vars:
4 |   TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
5 | 
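For context, `runtime_env.yaml` above follows Ray's runtime-environment schema (`working_dir`, `excludes`, `env_vars` are standard Ray fields). A minimal, hedged sketch of how such a file could be consumed; the explicit `ray.init` call and the file path here are illustrative assumptions, not taken from this repo:

```python
# Hedged sketch: feeding examples/runtime_env.yaml to Ray.
# Assumes a Ray cluster is reachable and PyYAML is installed
# (it is pulled in via the requirements below).
import ray
import yaml

with open("examples/runtime_env.yaml") as f:
    runtime_env = yaml.safe_load(f)

# Workers started under this runtime env inherit the working dir and env vars.
ray.init(runtime_env=runtime_env)
```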
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | wandb 3 | tqdm 4 | accelerate==1.4.0 5 | codetiming==1.4.0 6 | datasets==3.3.2 7 | filelock==3.18.0 8 | flash_attn==2.7.4.post1 9 | liger_kernel==0.5.6 10 | mathruler==0.1.0 11 | mlflow==2.22.0 12 | omegaconf==2.3.0 13 | Pillow==11.2.1 14 | psutil==6.1.0 15 | PyYAML==6.0.2 16 | qwen_vl_utils==0.0.11 17 | swanlab==0.5.8 18 | tensordict==0.5.0 19 | torch==2.5.1 20 | torchdata==0.11.0 21 | vllm==0.7.3 22 | -------------------------------------------------------------------------------- /verl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/single_controller/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __version__ = "0.2.0.dev" 16 | -------------------------------------------------------------------------------- /verl/single_controller/base/register_center/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v5.0.0
4 |     hooks:
5 |       - id: check-ast
6 |       - id: check-added-large-files
7 |         args: ['--maxkb=25000']
8 |       - id: check-merge-conflict
9 |       - id: check-yaml
10 |       - id: debug-statements
11 |       - id: end-of-file-fixer
12 |       - id: requirements-txt-fixer
13 |       - id: trailing-whitespace
14 |         args: [--markdown-linebreak-ext=md]
15 |       - id: no-commit-to-branch
16 |         args: ['--branch', 'main']
17 | 
18 |   - repo: https://github.com/asottile/pyupgrade
19 |     rev: v3.17.0
20 |     hooks:
21 |       - id: pyupgrade
22 |         args: [--py38-plus]
--------------------------------------------------------------------------------
/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from .logger import Tracker
17 | 
18 | 
19 | __all__ = ["Tracker"]
--------------------------------------------------------------------------------
/start_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | model=PATH_TO_MODEL
4 | model_name=ui-tars
5 | num_images=16
6 | 
7 | port=9000
8 | 
9 | # Function to clean up processes on exit
10 | cleanup() {
11 |     echo "Stopping all processes..."
12 |     pkill -P $$  # Kill all child processes of this script
13 |     exit 0
14 | }
15 | 
16 | # Trap SIGINT (Ctrl+C) and SIGTERM to run cleanup function
17 | trap cleanup SIGINT SIGTERM
18 | 
19 | # Start processes
20 | for i in {0..7}; do
21 |     CUDA_VISIBLE_DEVICES=$i python -m vllm.entrypoints.openai.api_server \
22 |         --served-model-name $model_name \
23 |         --model $model \
24 |         --limit-mm-per-prompt image=$num_images \
25 |         -tp=1 \
26 |         --port $((port + i)) &
27 | done
28 | 
29 | # Wait to keep the script running
30 | wait
--------------------------------------------------------------------------------
/verl/workers/reward/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .config import RewardConfig 16 | from .custom import CustomRewardManager 17 | 18 | 19 | __all__ = ["CustomRewardManager", "RewardConfig"] 20 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt 16 | 17 | 18 | __all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"] 19 | -------------------------------------------------------------------------------- /verl/workers/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .config import RolloutConfig 17 | from .vllm_rollout_spmd import vLLMRollout 18 | 19 | 20 | __all__ = ["RolloutConfig", "vLLMRollout"] 21 | -------------------------------------------------------------------------------- /verl/utils/reward_score/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from .math import math_compute_score 17 | from .r1v import r1v_compute_score 18 | 19 | 20 | __all__ = ["math_compute_score", "r1v_compute_score"] 21 | -------------------------------------------------------------------------------- /verl/single_controller/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .worker import Worker 16 | from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup 17 | 18 | 19 | __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"] 20 | -------------------------------------------------------------------------------- /verl/single_controller/ray/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls 16 | 17 | 18 | __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"] 19 | -------------------------------------------------------------------------------- /verl/workers/critic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .base import BasePPOCritic 16 | from .config import CriticConfig, ModelConfig 17 | from .dp_critic import DataParallelPPOCritic 18 | 19 | 20 | __all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"] 21 | -------------------------------------------------------------------------------- /verl/workers/reward/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Reward config 16 | """ 17 | 18 | from dataclasses import dataclass 19 | 20 | 21 | @dataclass 22 | class RewardConfig: 23 | reward_type: str = "function" 24 | score_function: str = "math" 25 | skip_special_tokens: bool = True 26 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from .base import BaseShardingManager 17 | from .fsdp_ulysses import FSDPUlyssesShardingManager 18 | from .fsdp_vllm import FSDPVLLMShardingManager 19 | 20 | 21 | __all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"] 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "verl" 7 | dynamic = [ 8 | "version", 9 | "dependencies", 10 | "optional-dependencies", 11 | "requires-python", 12 | "authors", 13 | "description", 14 | "readme", 15 | "license" 16 | ] 17 | 18 | [tool.ruff] 19 | target-version = "py39" 20 | line-length = 119 21 | indent-width = 4 22 | 23 | [tool.ruff.lint] 24 | ignore = ["C901", "E501", "E741", "W605", "C408"] 25 | select = ["C", "E", "F", "I", "W", "RUF022"] 26 | 27 | [tool.ruff.lint.per-file-ignores] 28 | "__init__.py" = ["E402", "F401", "F403", "F811"] 29 | 30 | [tool.ruff.lint.isort] 31 | lines-after-imports = 2 32 | known-first-party = ["verl"] 33 | known-third-party = ["torch", "transformers", "wandb"] 34 | 35 | [tool.ruff.format] 36 | quote-style = "double" 37 | indent-style = "space" 38 | skip-magic-trailing-comma = false 39 | line-ending = "auto" 40 | -------------------------------------------------------------------------------- /verl/workers/rollout/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | 17 | from ...protocol import DataProto 18 | 19 | 20 | __all__ = ["BaseRollout"] 21 | 22 | 23 | class BaseRollout(ABC): 24 | @abstractmethod 25 | def generate_sequences(self, prompts: DataProto) -> DataProto: 26 | """Generate sequences""" 27 | pass 28 | -------------------------------------------------------------------------------- /examples/baselines/qwen2_5_vl_3b_geoqa8k.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path 4 | 5 | FORMAT_PROMPT="""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant 6 | first thinks about the reasoning process in the mind and then provides the user with the answer. 
The reasoning
7 | process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
8 | <think> reasoning process here </think><answer> answer here </answer>"""
9 | 
10 | python3 -m verl.trainer.main \
11 |     config=examples/config.yaml \
12 |     data.train_files=leonardPKU/GEOQA_8K_R1V@train \
13 |     data.val_files=leonardPKU/GEOQA_8K_R1V@test \
14 |     data.format_prompt="${FORMAT_PROMPT}" \
15 |     worker.actor.model.model_path=${MODEL_PATH} \
16 |     worker.rollout.tensor_parallel_size=1 \
17 |     worker.reward.score_function=r1v \
18 |     trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
19 |     trainer.n_gpus_per_node=8
20 | 
--------------------------------------------------------------------------------
/examples/baselines/qwen2_5_vl_3b_clevr.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | 
3 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct  # replace it with your local file path
4 | 
5 | FORMAT_PROMPT="""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant
6 | first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning
7 | process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
8 | <think> reasoning process here </think><answer> answer here </answer>"""
9 | 
10 | python3 -m verl.trainer.main \
11 |     config=examples/config.yaml \
12 |     data.train_files=BUAADreamer/clevr_count_70k@train \
13 |     data.val_files=BUAADreamer/clevr_count_70k@test \
14 |     data.format_prompt="${FORMAT_PROMPT}" \
15 |     worker.actor.model.model_path=${MODEL_PATH} \
16 |     worker.rollout.tensor_parallel_size=1 \
17 |     worker.reward.score_function=r1v \
18 |     trainer.experiment_name=qwen2_5_vl_3b_clevr \
19 |     trainer.n_gpus_per_node=2
20 | 
--------------------------------------------------------------------------------
/verl/workers/actor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from .base import BasePPOActor
16 | from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
17 | from .dp_actor import DataParallelPPOActor
18 | 
19 | 
20 | __all__ = [
21 |     "ActorConfig",
22 |     "BasePPOActor",
23 |     "DataParallelPPOActor",
24 |     "FSDPConfig",
25 |     "ModelConfig",
26 |     "OptimConfig",
27 |     "RefConfig",
28 | ]
29 | 
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ray 16 | 17 | 18 | @ray.remote 19 | class WorkerGroupRegisterCenter: 20 | def __init__(self, rank_zero_info): 21 | self.rank_zero_info = rank_zero_info 22 | 23 | def get_rank_zero_info(self): 24 | return self.rank_zero_info 25 | 26 | 27 | def create_worker_group_register_center(name, info): 28 | return WorkerGroupRegisterCenter.options(name=name).remote(info) 29 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Sharding manager to implement HybridEngine 16 | """ 17 | 18 | from ...protocol import DataProto 19 | 20 | 21 | class BaseShardingManager: 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, exc_type, exc_value, traceback): 26 | pass 27 | 28 | def preprocess_data(self, data: DataProto) -> DataProto: 29 | return data 30 | 31 | def postprocess_data(self, data: DataProto) -> DataProto: 32 | return data 33 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | paths: 8 | - "**.py" 9 | - "requirements.txt" 10 | - ".github/workflows/*.yml" 11 | pull_request: 12 | branches: 13 | - "main" 14 | paths: 15 | - "**.py" 16 | - "requirements.txt" 17 | - ".github/workflows/*.yml" 18 | 19 | jobs: 20 | tests: 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | python-version: 25 | - "3.11" 26 | os: 27 | - "ubuntu-latest" 28 | 29 | runs-on: ${{ matrix.os }} 30 | 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v4 34 | 35 | - name: Set up Python 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | cache: "pip" 40 | cache-dependency-path: "setup.py" 41 | 42 | - name: Install dependencies 43 | run: | 44 | python -m pip install --upgrade pip 45 | python -m pip install ruff 46 | 47 | - name: Check quality 48 | run: | 49 | make style && make quality 50 | -------------------------------------------------------------------------------- /verl/workers/critic/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Base class for Critic 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict 20 | 21 | import torch 22 | 23 | from ...protocol import DataProto 24 | from .config import CriticConfig 25 | 26 | 27 | __all__ = ["BasePPOCritic"] 28 | 29 | 30 | class BasePPOCritic(ABC): 31 | def __init__(self, config: CriticConfig): 32 | self.config = config 33 | 34 | @abstractmethod 35 | def compute_values(self, data: DataProto) -> torch.Tensor: 36 | """Compute values""" 37 | pass 38 | 39 | @abstractmethod 40 | def update_critic(self, data: DataProto) -> Dict[str, Any]: 41 | """Update the critic""" 42 | pass 43 | -------------------------------------------------------------------------------- /verl/workers/critic/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Critic config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | 20 | from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig 21 | 22 | 23 | @dataclass 24 | class CriticConfig: 25 | strategy: str = "fsdp" 26 | global_batch_size: int = 256 27 | micro_batch_size_per_device_for_update: int = 4 28 | micro_batch_size_per_device_for_experience: int = 16 29 | max_grad_norm: float = 1.0 30 | cliprange_value: float = 0.5 31 | ppo_epochs: int = 1 32 | padding_free: bool = False 33 | ulysses_sequence_parallel_size: int = 1 34 | model: ModelConfig = field(default_factory=ModelConfig) 35 | optim: OptimConfig = field(default_factory=OptimConfig) 36 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 37 | offload: OffloadConfig = field(default_factory=OffloadConfig) 38 | """auto keys""" 39 | global_batch_size_per_device: int = field(default=-1, init=False) 40 | -------------------------------------------------------------------------------- /verl/workers/rollout/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Rollout config
16 | """
17 | 
18 | from dataclasses import asdict, dataclass, field
19 | from typing import Any, Dict
20 | 
21 | 
22 | @dataclass
23 | class RolloutConfig:
24 |     name: str = "vllm"
25 |     n: int = 1
26 |     temperature: float = 1.0
27 |     top_p: float = 1.0
28 |     top_k: int = -1
29 |     limit_images: int = 0
30 |     dtype: str = "bf16"
31 |     gpu_memory_utilization: float = 0.6
32 |     ignore_eos: bool = False
33 |     enforce_eager: bool = False
34 |     enable_chunked_prefill: bool = False  # only for v0 engine
35 |     tensor_parallel_size: int = 2
36 |     max_num_batched_tokens: int = 8192
37 |     max_num_seqs: int = 1024
38 |     disable_log_stats: bool = True
39 |     val_override_config: Dict[str, Any] = field(default_factory=dict)
40 |     """auto keys"""
41 |     prompt_length: int = field(default=-1, init=False)
42 |     response_length: int = field(default=-1, init=False)
43 | 
44 |     def to_dict(self):
45 |         return asdict(self)
46 | 
--------------------------------------------------------------------------------
/verl/utils/reward_score/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import re
16 | from typing import Dict
17 | 
18 | from mathruler.grader import extract_boxed_content, grade_answer
19 | 
20 | 
21 | def math_format_reward(predict_str: str) -> float:
22 |     pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
23 |     format_match = re.fullmatch(pattern, predict_str)
24 |     return 1.0 if format_match else 0.0
25 | 
26 | 
27 | def math_acc_reward(predict_str: str, ground_truth: str) -> float:
28 |     answer = extract_boxed_content(predict_str)
29 |     return 1.0 if grade_answer(answer, ground_truth) else 0.0
30 | 
31 | 
32 | def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
33 |     predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)  # handle qwen2.5vl-32b format
34 |     format = math_format_reward(predict_str)
35 |     accuracy = math_acc_reward(predict_str, ground_truth)
36 |     return {
37 |         "overall": 0.9 * accuracy + 0.1 * format,
38 |         "format": format,
39 |         "accuracy": accuracy,
40 |     }
41 | 
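A hedged usage sketch for the scorer above. The function names come from this file; the sample prediction string is illustrative:

```python
# The format reward wants a <think>...</think> block followed by a \boxed{}
# answer; the accuracy reward grades the boxed content against the label.
from verl.utils.reward_score import math_compute_score

scores = math_compute_score("<think>reasoning</think> The answer is \\boxed{42}.", "42")
# Expected: {"overall": 1.0, "format": 1.0, "accuracy": 1.0}
```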
--------------------------------------------------------------------------------
/verl/models/monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
17 | 
18 | from .transformers.flash_attention_utils import flash_attention_forward
19 | from .transformers.qwen2_vl import qwen2_vl_attn_forward, qwen_2_mixed_modality_forward
20 | 
21 | 
22 | def apply_ulysses_patch(model_type: str) -> None:
23 |     if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
24 |         ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
25 |     elif model_type in ("qwen2_vl", "qwen2_5_vl"):
26 |         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
27 |         from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
28 | 
29 |         Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
30 |         Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
31 | 
32 |         # from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
33 |         # Qwen2VLForConditionalGeneration.forward = qwen_2_mixed_modality_forward
34 |     else:
35 |         raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
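A hedged sketch of the patch entry point above; the exact call site inside the FSDP workers is an assumption, not shown in this excerpt:

```python
# Swap in the Ulysses-aware attention forward for the given architecture.
# Must run before the HuggingFace model is instantiated so the patched
# attention classes are the ones that get constructed.
from verl.models.monkey_patch import apply_ulysses_patch

apply_ulysses_patch("qwen2_5_vl")
```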
--------------------------------------------------------------------------------
/verl/utils/reward_score/r1v.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import re
16 | from typing import Dict
17 | 
18 | from mathruler.grader import grade_answer
19 | 
20 | 
21 | def r1v_format_reward(predict_str: str) -> float:
22 |     pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
23 |     format_match = re.fullmatch(pattern, predict_str)
24 |     return 1.0 if format_match else 0.0
25 | 
26 | 
27 | def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
28 |     try:
29 |         ground_truth = ground_truth.strip()
30 |         content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
31 |         given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
32 |         if grade_answer(given_answer, ground_truth):
33 |             return 1.0
34 | 
35 |     except Exception:
36 |         pass
37 | 
38 |     return 0.0
39 | 
40 | 
41 | def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
42 |     format = r1v_format_reward(predict_str)
43 |     accuracy = r1v_accuracy_reward(predict_str, ground_truth)
44 |     return {
45 |         "overall": 0.5 * accuracy + 0.5 * format,
46 |         "format": format,
47 |         "accuracy": accuracy,
48 |     }
49 | 
--------------------------------------------------------------------------------
/verl/trainer/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import json
4 | import random
5 | from collections import defaultdict
6 | 
7 | from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto, collate_fn
8 | 
9 | class ReplayBuffer():
10 | 
11 |     def __init__(self, json_path, buffer_size):
12 |         self.buffer_size = buffer_size  # max stored successes per task
13 | 
14 |         self.pos_dataset = defaultdict(list)
15 | 
16 |         # if json_path is not None:
17 |         #     with open(json_path, 'r') as f:
18 |         #         replay_data = json.load(f)
19 | 
20 |         #     for data in replay_data:
21 |         #         # task_id, history_images, history_messages, eval_result
22 |         #         task_id = data['task_id']
23 |         #         eval_result = data['eval_result']
24 |         #         if eval_result > 0.1:
25 |         #             self.pos_dataset[task_id].append(data)
26 | 
27 | 
28 |     def update_replay_buffer(self, task_config, batch_item, eval_result):
29 |         task_id = task_config["task_id"]
30 |         if eval_result > 0.1:  # only successful trajectories are kept
31 |             task_replay_buffer = self.pos_dataset[task_id]
32 |         else:
33 |             return
34 | 
35 |         task_replay_buffer.append(batch_item)
36 | 
37 |         if len(task_replay_buffer) > self.buffer_size:
38 |             task_replay_buffer.pop(0)  # FIFO eviction beyond buffer_size
39 | 
40 |     def update_replay_buffer_batch(self, task_configs, batch):
41 |         eval_results = batch.batch['eval_results'].tolist()
42 | 
43 |         for task_config, batch_item, eval_result in zip(task_configs, batch, eval_results):
44 |             self.update_replay_buffer(task_config, batch_item, eval_result)
45 | 
46 |     def get_pos(self, task_id, num_samples=1):
47 |         if task_id not in self.pos_dataset:
48 |             return DataProto()
49 |         else:
50 |             datalist = random.choices(self.pos_dataset[task_id], k=num_samples)
51 |             return collate_fn(datalist)
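For orientation, a hedged sketch of how this buffer appears to be driven by the trainer. The `task_id` key and the `eval_results` tensor key come from the method bodies above; the call order and the example task id are assumptions:

```python
# Hypothetical usage of verl.trainer.replay_buffer.ReplayBuffer.
# The buffer keeps up to `buffer_size` successful (eval_result > 0.1)
# trajectories per task and re-serves them as extra positive samples.
from verl.trainer.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(json_path=None, buffer_size=4)

# After a rollout batch has been evaluated (shapes assumed):
#   buffer.update_replay_buffer_batch(task_configs, batch)

# Later, mix stored successes for a task back into training:
positives = buffer.get_pos("libreoffice_calc_task_01", num_samples=2)
```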
--------------------------------------------------------------------------------
/verl/workers/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | ActorRolloutRef config
16 | """
17 | 
18 | from dataclasses import dataclass, field
19 | 
20 | from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
21 | from .critic import CriticConfig
22 | from .reward import RewardConfig
23 | from .rollout import RolloutConfig
24 | 
25 | 
26 | __all__ = [
27 |     "ActorConfig",
28 |     "CriticConfig",
29 |     "FSDPConfig",
30 |     "ModelConfig",
31 |     "OptimConfig",
32 |     "RefConfig",
33 |     "RewardConfig",
34 |     "RolloutConfig",
35 |     "WorkerConfig",
36 | ]
37 | 
38 | 
39 | @dataclass
40 | class WorkerConfig:
41 |     hybrid_engine: bool = True
42 |     actor: ActorConfig = field(default_factory=ActorConfig)
43 |     critic: CriticConfig = field(default_factory=CriticConfig)
44 |     ref: RefConfig = field(default_factory=RefConfig)
45 |     reward: RewardConfig = field(default_factory=RewardConfig)
46 |     rollout: RolloutConfig = field(default_factory=RolloutConfig)
47 | 
48 |     def post_init(self):
49 |         self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience
50 |         self.ref.padding_free = self.actor.padding_free
51 |         self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size
52 |         self.ref.use_torch_compile = self.actor.use_torch_compile
53 | 
--------------------------------------------------------------------------------
/examples/osworld_full_arpo.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | 
3 | MODEL_PATH=UI-TARS-1.5-7B
4 | 
5 | SYSTEM_PROMPT="""You are helpful assistant."""
6 | 
7 | NUM_GPUS=8
8 | NUM_ENVS=128
9 | ROLLOUT_N=8
10 | 
11 | ((ROLLOUT_BSZ = NUM_ENVS/ROLLOUT_N))
12 | 
13 | 
14 | python3 -m verl.trainer.main \
15 |     config=examples/config.yaml \
16 |     data.format_prompt="${SYSTEM_PROMPT}" \
17 |     data.train_files=evaluation_examples/test_success_uitars1.5_wo_impossible.json \
18 |     data.val_files=evaluation_examples/test_success_uitars1.5_wo_impossible.json \
19 |     data.max_prompt_length=64000 \
20 |     data.max_response_length=8192 \
21 |     data.max_pixels=2116800 \
22 |     data.min_pixels=256 \
23 |     data.rollout_batch_size=16 \
24 |     worker.actor.fsdp.torch_dtype=bf16 \
25 |     worker.actor.optim.strategy=adamw_bf16 \
26 |     worker.actor.max_grad_norm=1.0 \
27 |     worker.actor.optim.lr=1e-6 \
28 |     worker.actor.ulysses_sequence_parallel_size=1 \
29 |     worker.actor.padding_free=true \
30 |     worker.actor.ppo_epochs=1 \
31 |     worker.actor.clip_ratio_low=0.2 \
32 |     worker.actor.clip_ratio_high=0.3 \
33 |     worker.actor.global_batch_size=8 \
34 |     worker.actor.micro_batch_size_per_device_for_update=1 \
35 |     worker.actor.micro_batch_size_per_device_for_experience=1 \
36 |     worker.actor.model.model_path=${MODEL_PATH} \
37 |     worker.rollout.gpu_memory_utilization=0.6 \
38 |     worker.rollout.temperature=1.0 \
39 |     worker.rollout.n=$ROLLOUT_N \
40 |     worker.rollout.limit_images=15 \
41 |     worker.rollout.tensor_parallel_size=1 \
42 |     worker.rollout.max_num_batched_tokens=128000 \
43 |     algorithm.disable_kl=True \
44 |     algorithm.kl_coef=0 \
45 |     env.num_envs=$NUM_ENVS \
46 |     env.max_steps=15 \
47 |     trainer.experiment_name=osworld_cot_7b_nokl_twonodes_onlinereplay \
48 |     trainer.n_gpus_per_node=$NUM_GPUS \
49 |     trainer.nnodes=2 \
50 |     trainer.save_freq=8 \
51 |     trainer.save_limit=3 \
52 |     trainer.val_before_train=True \
53 |     trainer.val_freq=8 \
54 |     trainer.total_episodes=15
55 | 
--------------------------------------------------------------------------------
/examples/osworld_subset32.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | 
3 | 
4 | MODEL_PATH=UI-TARS-1.5-7B
5 | 
6 | SYSTEM_PROMPT="""You are helpful assistant."""
7 | 
8 | NUM_GPUS=8
9 | NUM_ENVS=16
10 | ROLLOUT_N=8
11 | 
12 | ((ROLLOUT_BSZ = NUM_ENVS/ROLLOUT_N))
13 | 
14 | python3 -m verl.trainer.main \
15 |     config=examples/config.yaml \
16 |     data.format_prompt="${SYSTEM_PROMPT}" \
17 |     data.train_files=evaluation_examples/test_success_middle_difficult.json \
18 |     data.val_files=evaluation_examples/test_success_middle_difficult.json \
19 |     data.max_prompt_length=64000 \
20 |     data.max_response_length=8192 \
21 |     data.max_pixels=2116800 \
22 |     data.min_pixels=256 \
23 |     data.rollout_batch_size=2 \
24 |     worker.actor.fsdp.torch_dtype=bf16 \
25 |     worker.actor.optim.strategy=adamw_bf16 \
26 |     worker.actor.max_grad_norm=1.0 \
27 |     worker.actor.optim.lr=1e-6 \
28 |     worker.actor.optim.lr_warmup_ratio=0.05 \
29 |     worker.actor.ulysses_sequence_parallel_size=1 \
30 |     worker.actor.padding_free=true \
31 |     worker.actor.ppo_epochs=1 \
32 |     worker.actor.clip_ratio_low=0.2 \
33 |     worker.actor.clip_ratio_high=0.3 \
34 |     worker.actor.global_batch_size=1 \
35 |     worker.actor.micro_batch_size_per_device_for_update=1 \
36 |     worker.actor.micro_batch_size_per_device_for_experience=1 \
37 |     worker.actor.model.model_path=${MODEL_PATH} \
38 |     worker.rollout.gpu_memory_utilization=0.6 \
39 |     worker.rollout.temperature=1.0 \
40 |     worker.rollout.n=$ROLLOUT_N \
41 |     worker.rollout.limit_images=15 \
42 |     worker.rollout.tensor_parallel_size=1 \
43 |     worker.rollout.max_num_batched_tokens=128000 \
44 |     algorithm.disable_kl=True \
45 |     algorithm.kl_coef=0 \
46 |     algorithm.enable_replay=True \
47 |     env.num_envs=$NUM_ENVS \
48 |     env.max_steps=15 \
49 |     trainer.experiment_name=osworld_cot_7b_nokl_subset32 \
50 |     trainer.n_gpus_per_node=$NUM_GPUS \
51 |     trainer.nnodes=1 \
52 |     trainer.save_freq=16 \
53 |     trainer.save_limit=3 \
54 |     trainer.val_before_train=True \
55 |     trainer.val_freq=16 \
56 |     trainer.total_episodes=15
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import os
16 | import re
17 | 
18 | from setuptools import find_packages, setup
19 | 
20 | 
21 | def get_version() -> str:
22 |     with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
23 |         file_content = f.read()
24 |         pattern = r"__version__\W*=\W*\"([^\"]+)\""
25 |         (version,) = re.findall(pattern, file_content)
26 |         return version
27 | 
28 | 
29 | def get_requires() -> list[str]:
30 |     with open("requirements.txt", encoding="utf-8") as f:
31 |         file_content = f.read()
32 |         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
33 |         return lines
34 | 
35 | 
36 | extra_require = {
37 |     "dev": ["pre-commit", "ruff"],
38 | }
39 | 
40 | 
41 | def main():
42 |     setup(
43 |         name="verl",
44 |         version=get_version(),
45 |         description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
46 |         long_description=open("README.md", encoding="utf-8").read(),
47 |         long_description_content_type="text/markdown",
48 |         author="verl",
49 |         author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
50 |         license="Apache 2.0 License",
51 |         url="https://github.com/volcengine/verl",
52 |         package_dir={"": "."},
53 |         packages=find_packages(where="."),
54 |         python_requires=">=3.9.0",
55 |         install_requires=get_requires(),
56 |         extras_require=extra_require,
57 |     )
58 | 
59 | 
60 | if __name__ == "__main__":
61 |     main()
62 | 
--------------------------------------------------------------------------------
/verl/workers/actor/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for Actor
16 | """
17 | 
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 | 
21 | import torch
22 | 
23 | from ...protocol import DataProto
24 | from .config import ActorConfig
25 | 
26 | 
27 | __all__ = ["BasePPOActor"]
28 | 
29 | 
30 | class BasePPOActor(ABC):
31 |     def __init__(self, config: ActorConfig):
32 |         """The base class for PPO actor
33 | 
34 |         Args:
35 |             config (ActorConfig): a config passed to the PPOActor.
36 |         """
37 |         self.config = config
38 | 
39 |     @abstractmethod
40 |     def compute_log_prob(self, data: DataProto) -> torch.Tensor:
41 |         """Compute logits given a batch of data.
42 | 
43 |         Args:
44 |             data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,
45 |                 ```attention_mask``` and ```position_ids```.
46 | 
47 |         Returns:
48 |             DataProto: a DataProto containing the key ```log_probs```
49 |         """
50 |         pass
51 | 
52 |     @abstractmethod
53 |     def update_policy(self, data: DataProto) -> Dict[str, Any]:
54 |         """Update the policy with an iterator of DataProto
55 | 
56 |         Args:
57 |             data (DataProto): an iterator over the DataProto that is returned by
58 |                 ```make_minibatch_iterator```
59 | 
60 |         Returns:
61 |             Dict: a dictionary containing anything. Typically, it contains the statistics during updating the model,
62 |                 such as ```loss```, ```grad_norm```, etc.
63 |         """
64 |         pass
65 | 
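A hedged sketch of the contract a concrete actor must satisfy. `DataParallelPPOActor` in `dp_actor.py` is the real implementation; this toy subclass only illustrates the abstract interface, and its return values are placeholders:

```python
# Toy subclass of the abstract base above; not a working actor.
import torch

from verl.protocol import DataProto
from verl.workers.actor import BasePPOActor


class NoOpActor(BasePPOActor):
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        return torch.zeros(1)  # placeholder log-probs (real impl runs the policy)

    def update_policy(self, data: DataProto) -> dict:
        return {"loss": 0.0, "grad_norm": 0.0}  # placeholder update statistics
```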
--------------------------------------------------------------------------------
/verl/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for tokenization."""
15 | 
16 | from typing import Optional
17 | 
18 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
19 | 
20 | 
21 | def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
22 |     """Create a huggingface pretrained tokenizer."""
23 |     tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
24 | 
25 |     if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
26 |         # the EOS token in gemma2 & gemma3 is ambiguous, which may worsen RL performance.
27 |         # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
28 |         print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
29 |         tokenizer.eos_token = "<end_of_turn>"
30 | 
31 |     if tokenizer.pad_token_id is None:
32 |         print("Pad token is None. Set it to eos_token.")
33 |         tokenizer.pad_token = tokenizer.eos_token
34 | 
35 |     return tokenizer
36 | 
37 | 
38 | def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
39 |     """Create a huggingface pretrained processor."""
40 |     try:
41 |         processor = AutoProcessor.from_pretrained(model_path, **kwargs)
42 |     except Exception:
43 |         processor = None
44 | 
45 |     # Avoid loading the tokenizer, see:
46 |     # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
47 |     if processor is not None and "Processor" not in processor.__class__.__name__:
48 |         processor = None
49 | 
50 |     return processor
51 | 
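A short usage sketch for the helpers above; the model path is one of the baseline models from the example scripts, used here for illustration:

```python
# get_tokenizer normalizes gemma EOS handling and pad tokens;
# get_processor returns None for text-only models.
from verl.utils.tokenizer import get_processor, get_tokenizer

tokenizer = get_tokenizer("Qwen/Qwen2.5-VL-3B-Instruct")
processor = get_processor("Qwen/Qwen2.5-VL-3B-Instruct")
```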
25 | 26 | >>> PrecisionType.HALF == "16" 27 | True 28 | >>> PrecisionType.HALF in (16, "16") 29 | True 30 | """ 31 | 32 | HALF = "16" 33 | FLOAT = "32" 34 | FULL = "64" 35 | BFLOAT = "bf16" 36 | MIXED = "mixed" 37 | 38 | @staticmethod 39 | def is_fp16(precision): 40 | return precision in HALF_LIST 41 | 42 | @staticmethod 43 | def is_fp32(precision): 44 | return precision in FLOAT_LIST 45 | 46 | @staticmethod 47 | def is_bf16(precision): 48 | return precision in BFLOAT_LIST 49 | 50 | @staticmethod 51 | def to_dtype(precision) -> torch.dtype: 52 | if precision in HALF_LIST: 53 | return torch.float16 54 | elif precision in FLOAT_LIST: 55 | return torch.float32 56 | elif precision in BFLOAT_LIST: 57 | return torch.bfloat16 58 | else: 59 | raise RuntimeError(f"unexpected precision: {precision}") 60 | 61 | @staticmethod 62 | def to_str(precision: torch.dtype) -> str: 63 | if precision == torch.float16: 64 | return "float16" 65 | elif precision == torch.float32: 66 | return "float32" 67 | elif precision == torch.bfloat16: 68 | return "bfloat16" 69 | else: 70 | raise RuntimeError(f"unexpected precision: {precision}") 71 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to EasyR1 2 | 3 | Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable. 4 | 5 | It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. 6 | 7 | However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md). 8 | 9 | **This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).** 10 | 11 | ## Ways to contribute 12 | 13 | There are several ways you can contribute to EasyR1: 14 | 15 | * Fix outstanding issues with the existing code. 16 | * Submit issues related to bugs or desired new features. 17 | * Contribute to the examples or to the documentation. 18 | 19 | ### Style guide 20 | 21 | EasyR1 follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html); check it for details. 22 | 23 | ### Create a Pull Request 24 | 25 | 1. Fork the [repository](https://github.com/hiyouga/EasyR1) by clicking on the [Fork](https://github.com/hiyouga/EasyR1/fork) button on the repository's page. This creates a copy of the code under your GitHub user account. 26 | 27 | 2. Clone your fork to your local disk, and add the base repository as a remote: 28 | 29 | ```bash 30 | git clone git@github.com:[username]/EasyR1.git 31 | cd EasyR1 32 | git remote add upstream https://github.com/hiyouga/EasyR1.git 33 | ``` 34 | 35 | 3. Create a new branch to hold your development changes: 36 | 37 | ```bash 38 | git checkout -b dev_your_branch 39 | ``` 40 | 41 | 4. Set up a development environment by running the following command in a virtual environment: 42 | 43 | ```bash 44 | pip install -e ".[dev]" 45 | ``` 46 | 47 | 5. Check the code before committing: 48 | 49 | ```bash 50 | make commit 51 | make style && make quality 52 | ``` 53 | 54 | 6. Submit changes: 55 | 56 | ```bash 57 | git add .
58 | git commit -m "commit message" 59 | git fetch upstream 60 | git rebase upstream/main 61 | git push -u origin dev_your_branch 62 | ``` 63 | 64 | 7. Create a pull request from your branch `dev_your_branch` to the [original repo](https://github.com/hiyouga/EasyR1). 65 | -------------------------------------------------------------------------------- /verl/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utilities to create common models 16 | """ 17 | 18 | from functools import lru_cache 19 | from typing import Optional, Tuple 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch import nn 24 | 25 | 26 | @lru_cache 27 | def is_rank0() -> bool: 28 | return (not dist.is_initialized()) or (dist.get_rank() == 0) 29 | 30 | 31 | def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None: 32 | """Report the current GPU VRAM usage.""" 33 | if is_rank0(): 34 | free_mem, total_mem = torch.cuda.mem_get_info() 35 | print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.") 36 | 37 | 38 | def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]: 39 | """Compute the model size.""" 40 | n_params = sum(p.numel() for p in model.parameters()) 41 | 42 | if scale == "auto": 43 | if n_params > 1e9: 44 | scale = "B" 45 | elif n_params > 1e6: 46 | scale = "M" 47 | elif n_params > 1e3: 48 | scale = "K" 49 | else: 50 | scale = "" 51 | 52 | if scale == "B": 53 | n_params = n_params / 1e9 54 | elif scale == "M": 55 | n_params = n_params / 1e6 56 | elif scale == "K": 57 | n_params = n_params / 1e3 58 | elif scale == "": 59 | pass 60 | else: 61 | raise NotImplementedError(f"Unknown scale {scale}.") 62 | 63 | return n_params, scale 64 | 65 | 66 | def print_model_size(model: nn.Module, name: Optional[str] = None) -> None: 67 | """Print the model size.""" 68 | if is_rank0(): 69 | n_params, scale = _get_model_size(model, scale="auto") 70 | if name is None: 71 | name = model.__class__.__name__ 72 | 73 | print(f"{name} contains {n_params:.2f}{scale} parameters.") 74 | -------------------------------------------------------------------------------- /examples/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_files: hiyouga/math12k@train 3 | val_files: hiyouga/math12k@test 4 | prompt_key: problem 5 | answer_key: answer 6 | image_key: images 7 | max_prompt_length: 2048 8 | max_response_length: 2048 9 | rollout_batch_size: 512 10 | val_batch_size: -1 11 | shuffle: true 12 | seed: 1 13 | max_pixels: 4194304 14 | min_pixels: 262144 15 | 16 | algorithm: 17 | adv_estimator: grpo 18 | disable_kl: false 19 | use_kl_loss: true 20 | kl_penalty: low_var_kl 21 | kl_coef: 1.0e-2 22 | 23 | worker: 24 | actor: 25 | global_batch_size: 128 26 | micro_batch_size_per_device_for_update: 4
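    # Note on the batch-size knobs above and below (an illustrative reading, not
    # an authoritative one): global_batch_size is the number of samples consumed
    # per optimizer update, while the two micro_batch_size_per_device_* values
    # only split that work into smaller per-GPU chunks (gradient accumulation
    # for updates, forward chunking for experience), trading speed for memory.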
27 | micro_batch_size_per_device_for_experience: 16 28 | max_grad_norm: 1.0 29 | padding_free: true 30 | ulysses_sequence_parallel_size: 1 31 | model: 32 | model_path: Qwen/Qwen2.5-7B-Instruct 33 | enable_gradient_checkpointing: true 34 | trust_remote_code: false 35 | freeze_vision_tower: false 36 | optim: 37 | lr: 1.0e-6 38 | weight_decay: 1.0e-2 39 | strategy: adamw # {adamw, adamw_bf16} 40 | lr_warmup_ratio: 0.0 41 | fsdp: 42 | enable_full_shard: true 43 | enable_cpu_offload: false 44 | enable_rank0_init: true 45 | offload: 46 | offload_params: true # true: more CPU memory; false: more GPU memory 47 | offload_optimizer: true # true: more CPU memory; false: more GPU memory 48 | 49 | rollout: 50 | temperature: 1.0 51 | n: 5 52 | gpu_memory_utilization: 0.6 53 | enforce_eager: false 54 | enable_chunked_prefill: false 55 | tensor_parallel_size: 2 56 | limit_images: 0 57 | val_override_config: 58 | temperature: 0.5 59 | n: 1 60 | 61 | ref: 62 | fsdp: 63 | enable_full_shard: true 64 | enable_cpu_offload: true # true: more CPU memory; false: more GPU memory 65 | enable_rank0_init: true 66 | offload: 67 | offload_params: false 68 | 69 | reward: 70 | reward_type: function 71 | score_function: math 72 | skip_special_tokens: true 73 | 74 | env: 75 | num_envs: 32 76 | screen_size: [1920, 1080] 77 | 78 | trainer: 79 | total_episodes: 15 80 | logger: ["console", "wandb"] 81 | project_name: easy_r1 82 | experiment_name: qwen2_5_7b_math_grpo 83 | n_gpus_per_node: 8 84 | nnodes: 1 85 | val_freq: 5 # -1 to disable 86 | val_before_train: true 87 | val_only: false 88 | val_generations_to_log: 3 89 | save_freq: 5 # -1 to disable 90 | save_limit: 3 # -1 to disable 91 | save_checkpoint_path: null 92 | load_checkpoint_path: null 93 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/fsdp_ulysses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
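# Usage sketch (illustrative only; assumes the trainer drives the manager defined
# below as a context manager, which is not shown in this file):
#
#     manager = FSDPUlyssesShardingManager(device_mesh)
#     with manager:                                # installs the Ulysses SP group
#         data = manager.preprocess_data(data)     # all-gather across SP ranks
#         ...                                      # sequence-parallel compute
#         data = manager.postprocess_data(data)    # chunk back to this SP rank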
14 | """ 15 | Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT 16 | """ 17 | 18 | from torch.distributed.device_mesh import DeviceMesh 19 | 20 | from ...protocol import DataProto, all_gather_data_proto 21 | from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group 22 | from .base import BaseShardingManager 23 | 24 | 25 | class FSDPUlyssesShardingManager(BaseShardingManager): 26 | """ 27 | Sharding manager to support data resharding when using FSDP + Ulysses 28 | """ 29 | 30 | def __init__(self, device_mesh: DeviceMesh): 31 | super().__init__() 32 | self.device_mesh = device_mesh 33 | 34 | def __enter__(self): 35 | if self.device_mesh is not None: 36 | self.prev_sp_group = get_ulysses_sequence_parallel_group() 37 | set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group()) 38 | 39 | def __exit__(self, exc_type, exc_value, traceback): 40 | if self.device_mesh is not None: 41 | set_ulysses_sequence_parallel_group(self.prev_sp_group) 42 | 43 | def preprocess_data(self, data: DataProto) -> DataProto: 44 | """ 45 | AllGather data from sp region 46 | This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE 47 | In Ulysses, we need to make sure the same data is used across a SP group 48 | """ 49 | if self.device_mesh is not None: 50 | sp_size = self.device_mesh["sp"].size() 51 | sp_group = self.device_mesh["sp"].get_group() 52 | all_gather_data_proto(data, size=sp_size, group=sp_group) 53 | 54 | return data 55 | 56 | def postprocess_data(self, data: DataProto) -> DataProto: 57 | """ 58 | Split the data to follow FSDP partition 59 | """ 60 | if self.device_mesh is not None: 61 | sp_size = self.device_mesh["sp"].size() 62 | sp_rank = self.device_mesh["sp"].get_local_rank() 63 | data = data.chunk(chunks=sp_size)[sp_rank] 64 | 65 | return data 66 | -------------------------------------------------------------------------------- /verl/workers/reward/custom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from collections import defaultdict 17 | from typing import Callable, Dict, List, Tuple, TypedDict 18 | 19 | import torch 20 | from transformers import PreTrainedTokenizer 21 | 22 | from ...protocol import DataProto 23 | from ...utils.reward_score import math_compute_score, r1v_compute_score 24 | from .config import RewardConfig 25 | 26 | 27 | class RewardScore(TypedDict): 28 | overall: float 29 | format: float 30 | accuracy: float 31 | 32 | 33 | class CustomRewardManager: 34 | def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig): 35 | self.config = config 36 | self.tokenizer = tokenizer 37 | if config.score_function == "math": 38 | self.compute_score: Callable[[str, str], RewardScore] = math_compute_score 39 | elif config.score_function == "r1v": 40 | self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score 41 | else: 42 | raise NotImplementedError(f"Unknown score function {config.score_function}.") 43 | 44 | def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: 45 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) 46 | reward_metrics = defaultdict(list) 47 | for i in range(len(data)): 48 | data_item = data[i] # DataProtoItem 49 | response_ids = data_item.batch["responses"] 50 | response_mask = data_item.batch["response_mask"] 51 | valid_response_length = response_mask.sum() 52 | valid_response_ids = response_ids[:valid_response_length] 53 | 54 | response_str = self.tokenizer.decode( 55 | valid_response_ids, skip_special_tokens=self.config.skip_special_tokens 56 | ) 57 | ground_truth = data_item.non_tensor_batch["ground_truth"] 58 | 59 | score = self.compute_score(response_str, ground_truth) 60 | reward_tensor[i, valid_response_length - 1] = score["overall"] 61 | for key, value in score.items(): 62 | reward_metrics[key].append(value) 63 | 64 | return reward_tensor, reward_metrics 65 | -------------------------------------------------------------------------------- /verl/utils/py_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contains small Python utility functions 16 | """ 17 | 18 | import importlib.util 19 | import re 20 | from functools import lru_cache 21 | from typing import Any, Dict, List, Union 22 | 23 | import numpy as np 24 | import yaml 25 | from yaml import Dumper 26 | 27 | 28 | def is_sci_notation(number: float) -> bool: 29 | pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$") 30 | return bool(pattern.match(str(number))) 31 | 32 | 33 | def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]): 34 | if is_sci_notation(number): 35 | value = str(number) 36 | if "."
not in value and "e" in value: 37 | value = value.replace("e", ".0e", 1) 38 | else: 39 | value = str(round(number, 3)) 40 | 41 | return dumper.represent_scalar("tag:yaml.org,2002:float", value) 42 | 43 | 44 | yaml.add_representer(float, float_representer) 45 | yaml.add_representer(np.float32, float_representer) 46 | yaml.add_representer(np.float64, float_representer) 47 | 48 | 49 | @lru_cache 50 | def is_package_available(name: str) -> bool: 51 | return importlib.util.find_spec(name) is not None 52 | 53 | 54 | def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]: 55 | """Union two dicts. Raises an error if the same key maps to different values in the two dicts.""" 56 | for key in dict2.keys(): 57 | if key in dict1: 58 | assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 does not have the same value" 59 | 60 | dict1[key] = dict2[key] 61 | 62 | return dict1 63 | 64 | 65 | def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None: 66 | """Append a dict to a dict of lists.""" 67 | for key, val in new_data.items(): 68 | if key not in data: 69 | data[key] = [] 70 | 71 | data[key].append(val) 72 | 73 | 74 | def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]: 75 | unflattened = {} 76 | for key, value in data.items(): 77 | pieces = key.split(sep) 78 | pointer = unflattened 79 | for piece in pieces[:-1]: 80 | if piece not in pointer: 81 | pointer[piece] = {} 82 | 83 | pointer = pointer[piece] 84 | 85 | pointer[pieces[-1]] = value 86 | 87 | return unflattened 88 | 89 | 90 | def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]: 91 | flattened = {} 92 | for key, value in data.items(): 93 | new_key = parent_key + sep + key if parent_key else key 94 | if isinstance(value, dict): 95 | flattened.update(flatten_dict(value, new_key, sep=sep)) 96 | else: 97 | flattened[new_key] = value 98 | 99 | return flattened 100 | 101 | 102 | def convert_dict_to_str(data: Dict[str, Any]) -> str: 103 | return yaml.dump(data, indent=2) 104 | -------------------------------------------------------------------------------- /verl/workers/actor/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
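# Construction sketch (illustrative): the configs below are plain dataclasses, so
# they can be built directly in Python as well as overridden from YAML:
#
#     config = ActorConfig(
#         global_batch_size=128,
#         model=ModelConfig(model_path="Qwen/Qwen2.5-7B-Instruct"),
#         optim=OptimConfig(lr=1.0e-6),
#     )
#     config.model.post_init()  # tokenizer_path falls back to model_path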
14 | """ 15 | Actor config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | from typing import Any, Dict, Optional, Tuple 20 | 21 | 22 | @dataclass 23 | class ModelConfig: 24 | model_path: Optional[str] = None 25 | tokenizer_path: Optional[str] = None 26 | override_config: Dict[str, Any] = field(default_factory=dict) 27 | enable_gradient_checkpointing: bool = True 28 | trust_remote_code: bool = True 29 | freeze_vision_tower: bool = False 30 | 31 | def post_init(self): 32 | if self.tokenizer_path is None: 33 | self.tokenizer_path = self.model_path 34 | 35 | 36 | @dataclass 37 | class OptimConfig: 38 | lr: float = 1e-6 39 | betas: Tuple[float, float] = (0.9, 0.999) 40 | weight_decay: float = 1e-2 41 | strategy: str = "adamw" 42 | lr_warmup_ratio: float = 0.0 43 | min_lr_ratio: Optional[float] = None 44 | warmup_style: str = "constant" 45 | """auto keys""" 46 | training_steps: int = field(default=-1, init=False) 47 | 48 | 49 | @dataclass 50 | class FSDPConfig: 51 | enable_full_shard: bool = True 52 | enable_cpu_offload: bool = False 53 | enable_rank0_init: bool = False 54 | use_orig_params: bool = False 55 | torch_dtype: Optional[str] = None 56 | fsdp_size: int = -1 57 | mp_param_dtype: str = "bf16" 58 | mp_reduce_dtype: str = "fp32" 59 | mp_buffer_dtype: str = "fp32" 60 | 61 | 62 | @dataclass 63 | class OffloadConfig: 64 | offload_params: bool = False 65 | offload_optimizer: bool = False 66 | 67 | 68 | @dataclass 69 | class ActorConfig: 70 | strategy: str = "fsdp" 71 | global_batch_size: int = 256 72 | micro_batch_size_per_device_for_update: int = 4 73 | micro_batch_size_per_device_for_experience: int = 16 74 | max_grad_norm: float = 1.0 75 | clip_ratio_low: float = 0.2 76 | clip_ratio_high: float = 0.3 77 | clip_ratio_dual: float = 3.0 78 | ppo_epochs: int = 1 79 | padding_free: bool = False 80 | ulysses_sequence_parallel_size: int = 1 81 | use_torch_compile: bool = True 82 | model: ModelConfig = field(default_factory=ModelConfig) 83 | optim: OptimConfig = field(default_factory=OptimConfig) 84 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 85 | offload: OffloadConfig = field(default_factory=OffloadConfig) 86 | """auto keys""" 87 | global_batch_size_per_device: int = field(default=-1, init=False) 88 | disable_kl: bool = field(default=False, init=False) 89 | use_kl_loss: bool = field(default=False, init=False) 90 | kl_penalty: str = field(default="kl", init=False) 91 | kl_coef: float = field(default=0.0, init=False) 92 | 93 | 94 | @dataclass 95 | class RefConfig: 96 | strategy: str = "fsdp" 97 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 98 | offload: OffloadConfig = field(default_factory=OffloadConfig) 99 | """auto keys""" 100 | micro_batch_size_per_device_for_experience: int = field(default=-1, init=False) 101 | padding_free: bool = field(default=False, init=False) 102 | ulysses_sequence_parallel_size: int = field(default=1, init=False) 103 | use_torch_compile: bool = field(default=True, init=False) 104 | -------------------------------------------------------------------------------- /verl/utils/logger/gen_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | from dataclasses import dataclass 18 | from typing import List, Tuple 19 | 20 | from ..py_functional import is_package_available 21 | 22 | 23 | if is_package_available("wandb"): 24 | import wandb # type: ignore 25 | 26 | 27 | if is_package_available("swanlab"): 28 | import swanlab # type: ignore 29 | 30 | 31 | @dataclass 32 | class GenerationLogger(ABC): 33 | @abstractmethod 34 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: ... 35 | 36 | 37 | @dataclass 38 | class ConsoleGenerationLogger(GenerationLogger): 39 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 40 | for inp, out, lab, score in samples: 41 | print(f"[prompt] {inp}\n[output] {out}\n[ground_truth] {lab}\n[score] {score}\n") 42 | 43 | 44 | @dataclass 45 | class WandbGenerationLogger(GenerationLogger): 46 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 47 | # Create column names for all samples 48 | columns = ["step"] + sum( 49 | [[f"input_{i + 1}", f"output_{i + 1}", f"label_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], 50 | [], 51 | ) 52 | 53 | if not hasattr(self, "validation_table"): 54 | # Initialize the table on first call 55 | self.validation_table = wandb.Table(columns=columns) 56 | 57 | # Create a new table with same columns and existing data 58 | # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 59 | new_table = wandb.Table(columns=columns, data=self.validation_table.data) 60 | 61 | # Add new row with all data 62 | row_data = [step] 63 | for sample in samples: 64 | row_data.extend(sample) 65 | 66 | new_table.add_data(*row_data) 67 | wandb.log({"val/generations": new_table}, step=step) 68 | self.validation_table = new_table 69 | 70 | 71 | @dataclass 72 | class SwanlabGenerationLogger(GenerationLogger): 73 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 74 | swanlab_text_list = [] 75 | for i, sample in enumerate(samples): 76 | row_text = "\n\n---\n\n".join( 77 | (f"input: {sample[0]}", f"output: {sample[1]}", f"label: {sample[2]}", f"score: {sample[3]}") 78 | ) 79 | swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) 80 | 81 | swanlab.log({"val/generations": swanlab_text_list}, step=step) 82 | 83 | 84 | GEN_LOGGERS = { 85 | "console": ConsoleGenerationLogger, 86 | "wandb": WandbGenerationLogger, 87 | "swanlab": SwanlabGenerationLogger, 88 | } 89 | 90 | 91 | @dataclass 92 | class AggregateGenerationsLogger: 93 | def __init__(self, loggers: List[str]): 94 | self.loggers: List[GenerationLogger] = [] 95 | 96 | for logger in loggers: 97 | if logger in GEN_LOGGERS: 98 | self.loggers.append(GEN_LOGGERS[logger]()) 99 | 100 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 101 | for logger in self.loggers: 102 | logger.log(samples, step) 103 | -------------------------------------------------------------------------------- /verl/trainer/config.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | PPO config 16 | """ 17 | 18 | import os 19 | from dataclasses import asdict, dataclass, field, fields, is_dataclass 20 | from typing import Optional, Tuple 21 | 22 | from ..workers.config import WorkerConfig 23 | 24 | 25 | def recursive_post_init(dataclass_obj): 26 | if hasattr(dataclass_obj, "post_init"): 27 | dataclass_obj.post_init() 28 | 29 | for attr in fields(dataclass_obj): 30 | if is_dataclass(getattr(dataclass_obj, attr.name)): 31 | recursive_post_init(getattr(dataclass_obj, attr.name)) 32 | 33 | 34 | @dataclass 35 | class DataConfig: 36 | train_files: str = "" 37 | val_files: str = "" 38 | prompt_key: str = "prompt" 39 | answer_key: str = "answer" 40 | image_key: str = "images" 41 | max_prompt_length: int = 512 42 | max_response_length: int = 512 43 | rollout_batch_size: int = 512 44 | val_batch_size: int = -1 45 | format_prompt: Optional[str] = None 46 | shuffle: bool = True 47 | seed: int = 1 48 | max_pixels: int = 4194304 49 | min_pixels: int = 262144 50 | 51 | 52 | @dataclass 53 | class AlgorithmConfig: 54 | gamma: float = 1.0 55 | lam: float = 1.0 56 | adv_estimator: str = "grpo" 57 | disable_kl: bool = False 58 | use_kl_loss: bool = False 59 | kl_penalty: str = "kl" 60 | kl_coef: float = 1e-3 61 | kl_type: str = "fixed" 62 | kl_horizon: float = 0.0 63 | kl_target: float = 0.0 64 | enable_replay: bool = False 65 | 66 | 67 | @dataclass 68 | class TrainerConfig: 69 | total_episodes: int = 10 70 | max_steps: Optional[int] = None 71 | project_name: str = "easy_r1" 72 | experiment_name: str = "demo" 73 | logger: Tuple[str, ...] = ("console", "wandb") 74 | nnodes: int = 1 75 | n_gpus_per_node: int = 8 76 | critic_warmup: int = 0 77 | val_freq: int = -1 78 | val_before_train: bool = True 79 | val_only: bool = False 80 | val_generations_to_log: int = 0 81 | save_freq: int = -1 82 | save_limit: int = -1 83 | save_checkpoint_path: Optional[str] = None 84 | load_checkpoint_path: Optional[str] = None 85 | 86 | def post_init(self): 87 | if self.save_checkpoint_path is None: 88 | self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name) 89 | 90 | @dataclass 91 | class EnvConfig: 92 | num_envs: int = 32 93 | max_steps: int = 15 94 | screen_size: Tuple[int, int] = (1920, 1080) 95 | 96 | @dataclass 97 | class PPOConfig: 98 | data: DataConfig = field(default_factory=DataConfig) 99 | worker: WorkerConfig = field(default_factory=WorkerConfig) 100 | algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig) 101 | trainer: TrainerConfig = field(default_factory=TrainerConfig) 102 | env: EnvConfig = field(default_factory=EnvConfig) 103 | 104 | def post_init(self): 105 | self.worker.rollout.prompt_length = self.data.max_prompt_length 106 | self.worker.rollout.response_length = self.data.max_response_length 107 |
self.worker.actor.disable_kl = self.algorithm.disable_kl 108 | self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss 109 | self.worker.actor.kl_penalty = self.algorithm.kl_penalty 110 | self.worker.actor.kl_coef = self.algorithm.kl_coef 111 | 112 | def deep_post_init(self): 113 | recursive_post_init(self) 114 | 115 | def to_dict(self): 116 | return asdict(self) 117 | -------------------------------------------------------------------------------- /verl/trainer/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Note that we don't combine main with ray_trainer here, since ray_trainer is also used by other entry points. 16 | """ 17 | 18 | import json 19 | 20 | import ray 21 | from omegaconf import OmegaConf 22 | 23 | from ..single_controller.ray import RayWorkerGroup 24 | from ..utils.tokenizer import get_processor, get_tokenizer 25 | from ..workers.fsdp_workers import FSDPWorker 26 | from ..workers.reward import CustomRewardManager 27 | from .config import PPOConfig 28 | from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role 29 | 30 | 31 | @ray.remote(num_cpus=1) 32 | class Runner: 33 | """A runner for RL training.""" 34 | 35 | def run(self, config: PPOConfig): 36 | # print config 37 | config.deep_post_init() 38 | print(json.dumps(config.to_dict(), indent=2)) 39 | 40 | # instantiate tokenizer 41 | tokenizer = get_tokenizer( 42 | config.worker.actor.model.model_path, 43 | trust_remote_code=config.worker.actor.model.trust_remote_code, 44 | use_fast=True, 45 | ) 46 | processor = get_processor( 47 | config.worker.actor.model.model_path, 48 | trust_remote_code=config.worker.actor.model.trust_remote_code, 49 | use_fast=True, 50 | ) 51 | 52 | # define worker classes 53 | ray_worker_group_cls = RayWorkerGroup 54 | role_worker_mapping = { 55 | Role.ActorRollout: ray.remote(FSDPWorker), 56 | Role.Critic: ray.remote(FSDPWorker), 57 | Role.RefPolicy: ray.remote(FSDPWorker), 58 | } 59 | global_pool_id = "global_pool" 60 | resource_pool_spec = { 61 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, 62 | } 63 | mapping = { 64 | Role.ActorRollout: global_pool_id, 65 | Role.Critic: global_pool_id, 66 | Role.RefPolicy: global_pool_id, 67 | } 68 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) 69 | 70 | reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward) 71 | val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward) 72 | 73 | trainer = RayPPOTrainer( 74 | config=config, 75 | tokenizer=tokenizer, 76 | processor=processor, 77 | role_worker_mapping=role_worker_mapping, 78 | resource_pool_manager=resource_pool_manager, 79 | ray_worker_group_cls=ray_worker_group_cls, 80 | reward_fn=reward_fn, 81 | val_reward_fn=val_reward_fn, 82 | ) 83 | trainer.init_workers() 84 | trainer.fit() 85 | 86 | 87
| def main(): 88 | cli_args = OmegaConf.from_cli() 89 | default_config = OmegaConf.structured(PPOConfig()) 90 | 91 | if hasattr(cli_args, "config"): 92 | config_path = cli_args.pop("config", None) 93 | file_config = OmegaConf.load(config_path) 94 | default_config = OmegaConf.merge(default_config, file_config) 95 | 96 | ppo_config = OmegaConf.merge(default_config, cli_args) 97 | ppo_config = OmegaConf.to_object(ppo_config) 98 | 99 | if not ray.is_initialized(): 100 | # this is for local ray cluster 101 | ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) 102 | 103 | print(ray.cluster_resources().keys()) 104 | 105 | runner = Runner.remote() 106 | ray.get(runner.run.remote(ppo_config)) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # outputs 174 | outputs/ 175 | checkpoints/ 176 | wandb/ 177 | -------------------------------------------------------------------------------- /verl/utils/logger/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A unified tracking interface that supports logging data to different backend 16 | """ 17 | 18 | import os 19 | from abc import ABC, abstractmethod 20 | from typing import Any, Dict, List, Optional, Tuple, Union 21 | 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict 25 | from .gen_logger import AggregateGenerationsLogger 26 | 27 | 28 | if is_package_available("mlflow"): 29 | import mlflow # type: ignore 30 | 31 | 32 | if is_package_available("wandb"): 33 | import wandb # type: ignore 34 | 35 | 36 | if is_package_available("swanlab"): 37 | import swanlab # type: ignore 38 | 39 | 40 | class Logger(ABC): 41 | @abstractmethod 42 | def __init__(self, config: Dict[str, Any]) -> None: ... 
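    # Implementation sketch (illustrative, not one of the shipped backends): a
    # concrete logger only needs __init__ and log; finish() is optional.
    #
    #     class PrintLogger(Logger):
    #         def __init__(self, config): print(convert_dict_to_str(config))
    #         def log(self, data, step): print(step, data)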
43 | 44 | @abstractmethod 45 | def log(self, data: Dict[str, Any], step: int) -> None: ... 46 | 47 | def finish(self) -> None: 48 | pass 49 | 50 | 51 | class ConsoleLogger(Logger): 52 | def __init__(self, config: Dict[str, Any]) -> None: 53 | print("Config\n" + convert_dict_to_str(config)) 54 | 55 | def log(self, data: Dict[str, Any], step: int) -> None: 56 | print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data))) 57 | 58 | 59 | class MlflowLogger(Logger): 60 | def __init__(self, config: Dict[str, Any]) -> None: 61 | mlflow.start_run(run_name=config["trainer"]["experiment_name"]) 62 | mlflow.log_params(flatten_dict(config)) 63 | 64 | def log(self, data: Dict[str, Any], step: int) -> None: 65 | mlflow.log_metrics(metrics=data, step=step) 66 | 67 | 68 | class TensorBoardLogger(Logger): 69 | def __init__(self, config: Dict[str, Any]) -> None: 70 | tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log") 71 | os.makedirs(tensorboard_dir, exist_ok=True) 72 | print(f"Saving tensorboard log to {tensorboard_dir}.") 73 | self.writer = SummaryWriter(tensorboard_dir) 74 | self.writer.add_hparams(flatten_dict(config), metric_dict={})  # add_hparams requires a metric_dict 75 | 76 | def log(self, data: Dict[str, Any], step: int) -> None: 77 | for key, value in data.items(): 78 | self.writer.add_scalar(key, value, step) 79 | 80 | def finish(self): 81 | self.writer.close() 82 | 83 | 84 | class WandbLogger(Logger): 85 | def __init__(self, config: Dict[str, Any]) -> None: 86 | wandb.init( 87 | project=config["trainer"]["project_name"], 88 | name=config["trainer"]["experiment_name"], 89 | config=config, 90 | ) 91 | 92 | def log(self, data: Dict[str, Any], step: int) -> None: 93 | wandb.log(data=data, step=step) 94 | 95 | def finish(self) -> None: 96 | wandb.finish() 97 | 98 | 99 | class SwanlabLogger(Logger): 100 | def __init__(self, config: Dict[str, Any]) -> None: 101 | swanlab_key = os.getenv("SWANLAB_API_KEY") 102 | swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log") 103 | swanlab_mode = os.getenv("SWANLAB_MODE", "cloud") 104 | if swanlab_key: 105 | swanlab.login(swanlab_key) 106 | 107 | swanlab.init( 108 | project=config["trainer"]["project_name"], 109 | experiment_name=config["trainer"]["experiment_name"], 110 | config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config}, 111 | logdir=swanlab_dir, 112 | mode=swanlab_mode, 113 | ) 114 | 115 | def log(self, data: Dict[str, Any], step: int) -> None: 116 | swanlab.log(data=data, step=step) 117 | 118 | def finish(self) -> None: 119 | swanlab.finish() 120 | 121 | 122 | LOGGERS = { 123 | "wandb": WandbLogger, 124 | "mlflow": MlflowLogger, 125 | "tensorboard": TensorBoardLogger, 126 | "console": ConsoleLogger, 127 | "swanlab": SwanlabLogger, 128 | } 129 | 130 | 131 | class Tracker: 132 | def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None): 133 | if isinstance(loggers, str): 134 | loggers = [loggers] 135 | 136 | self.loggers: List[Logger] = [] 137 | for logger in loggers: 138 | if logger not in LOGGERS: 139 | raise ValueError(f"{logger} is not supported.") 140 | 141 | self.loggers.append(LOGGERS[logger](config)) 142 | 143 | self.gen_logger = AggregateGenerationsLogger(loggers) 144 | 145 | def log(self, data: Dict[str, Any], step: int) -> None: 146 | for logger in self.loggers: 147 | logger.log(data=data, step=step) 148 | 149 | def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 150 | self.gen_logger.log(samples, step) 151 | 152 | def __del__(self): 153 | for logger in self.loggers: 154 |
logger.finish() 155 | -------------------------------------------------------------------------------- /verl/utils/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import gc 16 | from collections import defaultdict 17 | from functools import partial 18 | from typing import Callable, Union 19 | 20 | import torch 21 | from torch import nn 22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 23 | from torch.distributed.fsdp._runtime_utils import _lazy_init 24 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 25 | from torch.optim import Optimizer 26 | from transformers import PreTrainedModel 27 | from transformers.trainer_pt_utils import get_module_class_from_name 28 | 29 | 30 | def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]: 31 | param_occurrence = defaultdict(int) 32 | for _, param in model.named_parameters(remove_duplicate=False): 33 | param_occurrence[param] += 1 34 | 35 | duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1} 36 | materialized_params = {} 37 | 38 | def init_fn(module: nn.Module): 39 | for name, param in module.named_parameters(recurse=False): 40 | if param in duplicated_params: 41 | module._parameters[name] = materialized_params.setdefault( 42 | param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad) 43 | ) 44 | else: 45 | module._parameters[name] = nn.Parameter( 46 | torch.empty_like(param.data, device=device), requires_grad=param.requires_grad 47 | ) 48 | 49 | return init_fn 50 | 51 | 52 | def get_fsdp_wrap_policy(model: PreTrainedModel): 53 | """Get FSDP wrap policy for the model. 
54 | 55 | Args: 56 | model: The pretrained model to get the wrap policy for 57 | """ 58 | transformer_cls_to_wrap = set() 59 | for module in model._no_split_modules: 60 | transformer_cls = get_module_class_from_name(model, module) 61 | if transformer_cls is None: 62 | raise Exception(f"Cannot find {module} in pretrained model.") 63 | else: 64 | transformer_cls_to_wrap.add(transformer_cls) 65 | 66 | return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap) 67 | 68 | 69 | @torch.no_grad() 70 | def offload_fsdp_model(model: FSDP, empty_cache: bool = True): 71 | # lazy init FSDP model 72 | _lazy_init(model, model) 73 | assert model._is_root, "Only support root model offloading to CPU" 74 | for handle in model._all_handles: 75 | if handle._offload_params: 76 | continue 77 | 78 | flat_param = handle.flat_param 79 | assert ( 80 | flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() 81 | and id(flat_param.data) != id(flat_param._local_shard) 82 | and flat_param.data.size() == flat_param._local_shard.size() 83 | ) 84 | handle.flat_param_to("cpu", non_blocking=True) 85 | # the following still keeps id(._local_shard) != id(.data) 86 | flat_param._local_shard = flat_param.data 87 | assert id(flat_param._local_shard) != id(flat_param.data) 88 | 89 | if empty_cache: 90 | torch.cuda.empty_cache() 91 | 92 | 93 | @torch.no_grad() 94 | def load_fsdp_model(model: FSDP, empty_cache: bool = True): 95 | # lazy init FSDP model 96 | _lazy_init(model, model) 97 | assert model._is_root, "Only support root model loading to GPU" 98 | for handle in model._all_handles: 99 | if handle._offload_params: 100 | continue 101 | 102 | flat_param = handle.flat_param 103 | handle.flat_param_to("cuda", non_blocking=True) 104 | # the following still keeps id(._local_shard) != id(.data) 105 | flat_param._local_shard = flat_param.data 106 | 107 | if empty_cache: 108 | gc.collect() 109 | 110 | 111 | @torch.no_grad() 112 | def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True): 113 | if not optimizer.state: 114 | return 115 | 116 | for param_group in optimizer.param_groups: 117 | for param in param_group["params"]: 118 | state = optimizer.state[param] 119 | for key, value in state.items(): 120 | if isinstance(value, torch.Tensor): 121 | state[key] = value.to("cpu", non_blocking=True) 122 | 123 | if empty_cache: 124 | torch.cuda.empty_cache() 125 | 126 | 127 | @torch.no_grad() 128 | def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True): 129 | if not optimizer.state: 130 | return 131 | 132 | for param_group in optimizer.param_groups: 133 | for param in param_group["params"]: 134 | state = optimizer.state[param] 135 | for key, value in state.items(): 136 | if isinstance(value, torch.Tensor): 137 | state[key] = value.to("cuda", non_blocking=True) 138 | 139 | if empty_cache: 140 | gc.collect() 141 | -------------------------------------------------------------------------------- /verl/utils/flops_counter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, List, Tuple 16 | 17 | import torch 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers.models.llama.configuration_llama import LlamaConfig 22 | 23 | 24 | VALID_MODEL_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"} 25 | 26 | 27 | def get_device_flops(unit: str = "T") -> float: 28 | def unit_convert(number: float, level: str): 29 | units = ["B", "K", "M", "G", "T", "P"] 30 | if number <= 0: 31 | return number 32 | 33 | ptr = 0 34 | while ptr < len(units) and units[ptr] != level: 35 | number /= 1000 36 | ptr += 1 37 | 38 | return number 39 | 40 | device_name = torch.cuda.get_device_name() 41 | flops = float("inf") # INF flops for unknown gpu type 42 | if "H100" in device_name or "H800" in device_name: 43 | flops = 989e12 44 | elif "A100" in device_name or "A800" in device_name: 45 | flops = 312e12 46 | elif "L40" in device_name: 47 | flops = 181.05e12 48 | elif "L20" in device_name: 49 | flops = 119.5e12 50 | elif "H20" in device_name: 51 | flops = 148e12 52 | elif "910B" in device_name: 53 | flops = 354e12 54 | flops_unit = unit_convert(flops, unit) 55 | return flops_unit 56 | 57 | 58 | class FlopsCounter: 59 | """ 60 | Used to count MFU during the training loop 61 | 62 | Example: 63 | flops_counter = FlopsCounter(config) 64 | flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time) 65 | """ 66 | 67 | def __init__(self, config: "LlamaConfig"): 68 | if config.model_type not in VALID_MODEL_TYPE: 69 | print(f"Only support {VALID_MODEL_TYPE}, but got {config.model_type}.
MFU will always be zero.") 70 | 71 | self.estimate_func = { 72 | "llama": self._estimate_llama_flops, 73 | "qwen2": self._estimate_llama_flops, 74 | "qwen2_vl": self._estimate_llama_flops, 75 | "qwen2_5_vl": self._estimate_llama_flops, 76 | } 77 | self.config = config 78 | 79 | def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: 80 | return 0 81 | 82 | def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: 83 | hidden_size = self.config.hidden_size 84 | vocab_size = self.config.vocab_size 85 | num_hidden_layers = self.config.num_hidden_layers 86 | num_key_value_heads = self.config.num_key_value_heads 87 | num_attention_heads = self.config.num_attention_heads 88 | intermediate_size = self.config.intermediate_size 89 | 90 | head_dim = hidden_size // num_attention_heads 91 | q_size = num_attention_heads * head_dim 92 | k_size = num_key_value_heads * head_dim 93 | v_size = num_key_value_heads * head_dim 94 | 95 | # non-attn per-layer params 96 | # Qwen2/Llama use SwiGLU, whose MLP has gate, up, and down linear layers 97 | mlp_N = hidden_size * intermediate_size * 3 98 | attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) 99 | emd_and_lm_head_N = vocab_size * hidden_size * 2 100 | # non-attn all-layer params 101 | dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N 102 | # non-attn all-layer & all-token fwd & bwd flops 103 | dense_N_flops = 6 * dense_N * tokens_sum 104 | 105 | # attn all-layer & all-token fwd & bwd flops 106 | seqlen_square_sum = 0 107 | for seqlen in batch_seqlens: 108 | seqlen_square_sum += seqlen * seqlen 109 | 110 | attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers 111 | 112 | # all-layer & all-token fwd & bwd flops 113 | flops_all_token = dense_N_flops + attn_qkv_flops 114 | flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 115 | return flops_achieved 116 | 117 | def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]: 118 | """ 119 | Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. 120 | 121 | Args: 122 | batch_seqlens (List[int]): A list where each element is the number of valid tokens in one sequence of the current batch. 123 | delta_time (float): The time taken to process the batch, in seconds. 124 | 125 | Returns: 126 | estimated_flops (float): The estimated FLOPS based on the input tokens and time. 127 | promised_flops (float): The expected FLOPS of the current device.
128 | """ 129 | tokens_sum = sum(batch_seqlens) 130 | func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) 131 | estimated_flops = func(tokens_sum, batch_seqlens, delta_time) 132 | promised_flops = get_device_flops() 133 | return estimated_flops, promised_flops 134 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | `hoshihiyouga AT gmail DOT com`. 64 | All complaints will be reviewed and investigated promptly and fairly. 
65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import random 17 | import re 18 | import shutil 19 | import tempfile 20 | from abc import ABC, abstractmethod 21 | from typing import Any, Dict, Optional, Union 22 | 23 | import numpy as np 24 | import torch 25 | import torch.distributed as dist 26 | from filelock import FileLock 27 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 28 | from transformers import PreTrainedTokenizer, ProcessorMixin 29 | 30 | 31 | CHECKPOINT_TRACKER = "latest_global_step.txt" 32 | 33 | 34 | class BaseCheckpointManager(ABC): 35 | """ 36 | A checkpoint manager that saves and loads 37 | - model 38 | - optimizer 39 | - lr_scheduler 40 | - extra_states 41 | in a SPMD way. 42 | 43 | We save 44 | - sharded model states and optimizer states 45 | - full lr_scheduler states 46 | - huggingface tokenizer and config for ckpt merge 47 | """ 48 | 49 | def __init__( 50 | self, 51 | model: FSDP, 52 | optimizer: torch.optim.Optimizer, 53 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler, 54 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin], 55 | ): 56 | self.model = model 57 | self.optimizer = optimizer 58 | self.lr_scheduler = lr_scheduler 59 | self.processing_class = processing_class 60 | 61 | assert isinstance(self.model, FSDP) 62 | self.rank = dist.get_rank() 63 | self.world_size = dist.get_world_size() 64 | 65 | @abstractmethod 66 | def load_checkpoint(self, *args, **kwargs): 67 | raise NotImplementedError 68 | 69 | @abstractmethod 70 | def save_checkpoint(self, *args, **kwargs): 71 | raise NotImplementedError 72 | 73 | @staticmethod 74 | def local_mkdir(path: str) -> str: 75 | if not os.path.isabs(path): 76 | working_dir = os.getcwd() 77 | path = os.path.join(working_dir, path) 78 | 79 | # Using hash value of path as lock file name to avoid long file name 80 | lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" 81 | lock_path = os.path.join(tempfile.gettempdir(), lock_filename) 82 | 83 | try: 84 | with FileLock(lock_path, timeout=60): 85 | os.makedirs(path, exist_ok=True) 86 | except Exception as e: 87 | print(f"Warning: Failed to acquire lock for {path}: {e}") 88 | os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory 89 | 90 | return path 91 | 92 | @staticmethod 93 | def get_rng_state() -> Dict[str, Any]: 94 | rng_state = { 95 | "cpu": torch.get_rng_state(), 96 | "cuda": torch.cuda.get_rng_state(), 97 | "numpy": np.random.get_state(), 98 | "random": random.getstate(), 99 | } 100 | return rng_state 101 | 102 | @staticmethod 103 | def load_rng_state(rng_state: Dict[str, Any]): 104 | torch.set_rng_state(rng_state["cpu"]) 105 | torch.cuda.set_rng_state(rng_state["cuda"]) 106 | np.random.set_state(rng_state["numpy"]) 107 | random.setstate(rng_state["random"]) 108 | 109 | 110 | def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]: 111 | if path is None: 112 | return None 113 | 114 | tracker_file = get_checkpoint_tracker_filename(path) 115 | if not os.path.exists(tracker_file): 116 | print("Checkpoint tracker file does 
not exist: %s" % tracker_file)
117 |         return None
118 | 
119 |     with open(tracker_file, "rb") as f:
120 |         iteration = int(f.read().decode())
121 | 
122 |     ckpt_path = os.path.join(path, directory_format.format(iteration))
123 |     if not os.path.exists(ckpt_path):
124 |         print(f"Checkpoint does not exist: {ckpt_path}")
125 |         return None
126 | 
127 |     print(f"Found checkpoint: {ckpt_path}")
128 |     return ckpt_path
129 | 
130 | 
131 | def get_checkpoint_tracker_filename(root_path: str) -> str:
132 |     """
133 |     The tracker file records the latest checkpoint during training, so that a run can restart from it.
134 |     """
135 |     return os.path.join(root_path, CHECKPOINT_TRACKER)
136 | 
137 | 
138 | def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
139 |     """
140 |     Remove obsolete checkpoints so that, including the latest one, at most `save_limit` checkpoints are kept.
141 |     """
142 |     if save_limit <= 0:
143 |         return
144 | 
145 |     if not os.path.exists(path):
146 |         return
147 | 
148 |     pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
149 |     ckpt_folders = []
150 |     for folder in os.listdir(path):
151 |         if match := re.match(pattern, folder):
152 |             step = int(match.group(1))
153 |             if step < global_step:
154 |                 ckpt_folders.append((step, folder))
155 | 
156 |     ckpt_folders.sort(reverse=True)
157 |     for _, folder in ckpt_folders[save_limit - 1 :]:
158 |         folder_path = os.path.join(path, folder)
159 |         shutil.rmtree(folder_path, ignore_errors=True)
160 |         print(f"Removed obsolete checkpoint: {folder_path}")
161 | 
--------------------------------------------------------------------------------
/verl/trainer/metrics.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
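# A minimal usage sketch (illustrative): this module turns a rollout batch into scalar
# logging metrics. `reduce_metrics`, defined below, simply averages the list collected
# for each key:
#
#     reduce_metrics({"actor/loss": [0.4, 0.2]})  # -> {"actor/loss": 0.3}
#
# `compute_data_metrics` additionally expects a DataProto batch whose "token_level_scores",
# "advantages", "returns", "labels", and "attention_mask" tensors are already populated.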
14 | 15 | from typing import Any, Dict, List 16 | 17 | import numpy as np 18 | import torch 19 | 20 | from ..protocol import DataProto 21 | 22 | 23 | def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: 24 | return {key: np.mean(value) for key, value in metrics.items()} 25 | 26 | 27 | def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]: 28 | sequence_score = batch.batch["token_level_scores"].sum(-1) 29 | sequence_reward = batch.batch["token_level_rewards"].sum(-1) 30 | 31 | advantages = batch.batch["advantages"] 32 | returns = batch.batch["returns"] 33 | 34 | max_response_length = batch.batch["responses"].size(-1) 35 | 36 | # prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() 37 | # response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() 38 | prompt_mask = (torch.logical_and(batch.batch["attention_mask"], batch.batch["labels"] == -100)).bool() 39 | response_mask = (batch.batch["labels"] != -100).bool() 40 | 41 | max_prompt_length = prompt_mask.size(-1) 42 | prompt_length = prompt_mask.sum(-1).float() 43 | response_length = response_mask.sum(-1).float() 44 | num_images = (batch.batch["input_ids"] == 151655).bool().sum(-1).float() // 2691 # image_pad 45 | response_length = response_length / num_images # average response length per action 46 | 47 | valid_adv = torch.masked_select(advantages, response_mask) 48 | valid_returns = torch.masked_select(returns, response_mask) 49 | 50 | if use_critic: 51 | values = batch.batch["values"] 52 | valid_values = torch.masked_select(values, response_mask) 53 | return_diff_var = torch.var(valid_returns - valid_values) 54 | return_var = torch.var(valid_returns) 55 | 56 | metrics = { 57 | # score 58 | "critic/score/mean": torch.mean(sequence_score).detach().item(), 59 | "critic/score/max": torch.max(sequence_score).detach().item(), 60 | "critic/score/min": torch.min(sequence_score).detach().item(), 61 | # reward 62 | "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), 63 | "critic/rewards/max": torch.max(sequence_reward).detach().item(), 64 | "critic/rewards/min": torch.min(sequence_reward).detach().item(), 65 | # adv 66 | "critic/advantages/mean": torch.mean(valid_adv).detach().item(), 67 | "critic/advantages/max": torch.max(valid_adv).detach().item(), 68 | "critic/advantages/min": torch.min(valid_adv).detach().item(), 69 | # returns 70 | "critic/returns/mean": torch.mean(valid_returns).detach().item(), 71 | "critic/returns/max": torch.max(valid_returns).detach().item(), 72 | "critic/returns/min": torch.min(valid_returns).detach().item(), 73 | **( 74 | { 75 | # values 76 | "critic/values/mean": torch.mean(valid_values).detach().item(), 77 | "critic/values/max": torch.max(valid_values).detach().item(), 78 | "critic/values/min": torch.min(valid_values).detach().item(), 79 | # vf explained var 80 | "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), 81 | } 82 | if use_critic 83 | else {} 84 | ), 85 | # response length 86 | "response_length/num_images": torch.mean(num_images).detach().item(), 87 | "response_length/mean": torch.mean(response_length).detach().item(), 88 | "response_length/max": torch.max(response_length).detach().item(), 89 | "response_length/min": torch.min(response_length).detach().item(), 90 | "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) 91 | .detach() 92 | .item(), 93 | # prompt length 94 | "prompt_length/mean": 
torch.mean(prompt_length).detach().item(), 95 | "prompt_length/max": torch.max(prompt_length).detach().item(), 96 | "prompt_length/min": torch.min(prompt_length).detach().item(), 97 | "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), 98 | } 99 | return metrics 100 | 101 | 102 | def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: 103 | num_response_tokens = torch.sum(batch.batch["response_mask"]).item() 104 | num_overall_tokens = sum(batch.meta_info["global_token_num"]) 105 | num_tokens_of_section = { 106 | **dict.fromkeys(["gen", "reward"], num_response_tokens), 107 | **dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens), 108 | } 109 | return { 110 | **{f"timing_s/{name}": value for name, value in timing_raw.items()}, 111 | **{ 112 | f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] 113 | for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) 114 | }, 115 | } 116 | 117 | 118 | def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]: 119 | total_num_tokens = sum(batch.meta_info["global_token_num"]) 120 | time = timing_raw["step"] 121 | return { 122 | "perf/total_num_tokens": total_num_tokens, 123 | "perf/time_per_step": time, 124 | "perf/throughput": total_num_tokens / (time * n_gpus), 125 | } 126 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/fsdp_vllm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
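# A minimal usage sketch (illustrative; `fsdp_model`, `llm`, `mesh`, and `rollout` are
# assumed to be constructed elsewhere). The manager is a context manager: FSDP weights
# are synced into vLLM on enter, and vLLM memory is released again on exit:
#
#     manager = FSDPVLLMShardingManager(module=fsdp_model, inference_engine=llm, device_mesh=mesh)
#     with manager:
#         batch = manager.preprocess_data(batch)      # all-gather across the tp group
#         output = rollout.generate_sequences(batch)  # hypothetical rollout call
#         output = manager.postprocess_data(output)   # keep only this tp rank's chunk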
14 | 
15 | import warnings
16 | from typing import Dict, Iterable, Optional, Tuple, Union
17 | 
18 | import torch
19 | import torch.distributed as dist
20 | from torch.distributed._tensor import DTensor
21 | from torch.distributed.device_mesh import DeviceMesh
22 | from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
23 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
24 | from vllm import LLM
25 | from vllm.distributed import parallel_state as vllm_ps
26 | 
27 | from ...protocol import DataProto, all_gather_data_proto
28 | from ...utils.model_utils import print_gpu_memory_usage
29 | from .base import BaseShardingManager
30 | 
31 | 
32 | class FSDPVLLMShardingManager(BaseShardingManager):
33 |     def __init__(
34 |         self,
35 |         module: FSDP,
36 |         inference_engine: LLM,
37 |         device_mesh: Optional[DeviceMesh] = None,
38 |     ):
39 |         self.module = module
40 |         self.inference_engine = inference_engine
41 |         self.device_mesh = device_mesh
42 |         with warnings.catch_warnings():
43 |             warnings.simplefilter("ignore")
44 |             FSDP.set_state_dict_type(
45 |                 self.module,
46 |                 state_dict_type=StateDictType.SHARDED_STATE_DICT,
47 |                 state_dict_config=ShardedStateDictConfig(),
48 |             )
49 | 
50 |         self.world_size = dist.get_world_size()
51 |         self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
52 |         self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
53 |         self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
54 | 
55 |         # Record freed bytes to estimate memory usage correctly
56 |         # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
57 |         self.freed_bytes = 0
58 | 
59 |         # Note that torch_random_states may be different on each dp rank
60 |         self.torch_random_states = torch.cuda.get_rng_state()
61 |         # derive a dedicated rng state for generation
62 |         if self.device_mesh is not None:
63 |             gen_dp_rank = self.device_mesh["dp"].get_local_rank()
64 |             torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
65 |             self.gen_random_states = torch.cuda.get_rng_state()
66 |             torch.cuda.set_rng_state(self.torch_random_states)
67 |         else:
68 |             self.gen_random_states = None
69 | 
70 |     def _make_weight_iterator(
71 |         self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
72 |     ) -> Iterable[Tuple[str, torch.Tensor]]:
73 |         for name, tensor in actor_weights.items():
74 |             yield name, tensor.full_tensor() if self.world_size != 1 else tensor
75 | 
76 |     def __enter__(self):
77 |         # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
78 |         # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
79 |         # Outside the vllm scope, we should avoid emptying the cache so that PyTorch can
80 |         # reuse its caching allocator to speed up memory allocations.
81 |         #
82 |         # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
83 |         # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
84 |         torch.cuda.empty_cache()
85 |         print_gpu_memory_usage("Before state_dict() in sharding manager")
86 |         actor_weights = self.module.state_dict()
87 |         print_gpu_memory_usage("After state_dict() in sharding manager")
88 | 
89 |         self.inference_engine.wake_up()
90 |         model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
91 |         model.load_weights(self._make_weight_iterator(actor_weights))
92 |         print_gpu_memory_usage("After sync model weights in sharding manager")
93 | 
94 |         del actor_weights
95 |         torch.cuda.empty_cache()
96 |         print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
97 |         # important: need to manually set the random states of each tp to be identical.
98 |         if self.device_mesh is not None:
99 |             self.torch_random_states = torch.cuda.get_rng_state()
100 |             torch.cuda.set_rng_state(self.gen_random_states)
101 | 
102 |     def __exit__(self, exc_type, exc_value, traceback):
103 |         print_gpu_memory_usage("Before vllm offload in sharding manager")
104 |         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
105 |         self.inference_engine.sleep(level=1)
106 |         free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
107 |         self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
108 |         print_gpu_memory_usage("After vllm offload in sharding manager")
109 | 
110 |         self.module.train()
111 |         torch.cuda.empty_cache()  # add empty cache after each compute
112 | 
113 |         # restore random states
114 |         if self.device_mesh is not None:
115 |             self.gen_random_states = torch.cuda.get_rng_state()
116 |             torch.cuda.set_rng_state(self.torch_random_states)
117 | 
118 |     def preprocess_data(self, data: DataProto) -> DataProto:
119 |         """All-gather across the tp group so that every rank has identical input."""
120 |         all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
121 |         return data
122 | 
123 |     def postprocess_data(self, data: DataProto) -> DataProto:
124 |         """Take this tp rank's chunk of the data, since we all-gathered in preprocess_data."""
125 |         if self.tp_size > 1:
126 |             data = data.chunk(chunks=self.tp_size)[self.tp_rank]
127 | 
128 |         return data
129 | 
--------------------------------------------------------------------------------
/verl/utils/checkpoint/fsdp_checkpoint_manager.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
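# A minimal usage sketch (illustrative; the trainer normally drives this):
#
#     manager = FSDPCheckpointManager(model, optimizer, lr_scheduler, processing_class=tokenizer)
#     manager.save_checkpoint("checkpoints/global_step_100")  # sharded files per rank + huggingface/ on rank 0
#     manager.load_checkpoint("checkpoints/global_step_100")  # restores model, optim, lr_scheduler, and rng
#
# The "checkpoints/global_step_100" path is a made-up example.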
14 | 15 | import os 16 | import warnings 17 | from typing import Optional, Union 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 22 | from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType 23 | from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin 24 | 25 | from .checkpoint_manager import BaseCheckpointManager 26 | 27 | 28 | class FSDPCheckpointManager(BaseCheckpointManager): 29 | """ 30 | A checkpoint manager that saves and loads 31 | - model 32 | - optimizer 33 | - lr_scheduler 34 | - extra_states 35 | in a SPMD way. 36 | 37 | We save 38 | - sharded model states and optimizer states 39 | - full lr_scheduler states 40 | - huggingface tokenizer and config for ckpt merge 41 | """ 42 | 43 | def __init__( 44 | self, 45 | model: FSDP, 46 | optimizer: torch.optim.Optimizer, 47 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler, 48 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin], 49 | ): 50 | super().__init__(model, optimizer, lr_scheduler, processing_class) 51 | 52 | def load_checkpoint(self, path: Optional[str] = None): 53 | if path is None: 54 | return 55 | 56 | # every rank download its own checkpoint 57 | model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") 58 | optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") 59 | extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") 60 | print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.") 61 | model_state_dict = torch.load(model_path, weights_only=False) 62 | optimizer_state_dict = torch.load(optim_path, weights_only=False) 63 | extra_state_dict = torch.load(extra_state_path, weights_only=False) 64 | lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] 65 | 66 | state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) 67 | optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True) 68 | with warnings.catch_warnings(): 69 | warnings.simplefilter("ignore") 70 | with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config): 71 | self.model.load_state_dict(model_state_dict) 72 | if self.optimizer is not None: 73 | self.optimizer.load_state_dict(optimizer_state_dict) 74 | 75 | if self.lr_scheduler is not None: 76 | self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) 77 | 78 | # recover random state 79 | if "rng" in extra_state_dict: 80 | self.load_rng_state(extra_state_dict["rng"]) 81 | 82 | def save_checkpoint(self, path: str): 83 | path = self.local_mkdir(path) 84 | dist.barrier() 85 | 86 | # every rank will save its own model and optim shard 87 | state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) 88 | optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True) 89 | with warnings.catch_warnings(): 90 | warnings.simplefilter("ignore") 91 | with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config): 92 | model_state_dict = self.model.state_dict() 93 | if self.optimizer is not None: 94 | optimizer_state_dict = self.optimizer.state_dict() 95 | else: 96 | optimizer_state_dict = None 97 | 98 | if self.lr_scheduler is not None: 99 | lr_scheduler_state_dict = self.lr_scheduler.state_dict() 100 | else: 101 | lr_scheduler_state_dict = None 102 | 103 | extra_state_dict = { 104 | "lr_scheduler": 
lr_scheduler_state_dict,
105 |             "rng": self.get_rng_state(),
106 |         }
107 |         model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
108 |         optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
109 |         extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
110 | 
111 |         print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
112 |         print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
113 |         print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
114 |         torch.save(model_state_dict, model_path)
115 |         if self.optimizer is not None:
116 |             torch.save(optimizer_state_dict, optim_path)
117 | 
118 |         torch.save(extra_state_dict, extra_path)
119 | 
120 |         # wait for everyone to dump to local
121 |         dist.barrier()
122 | 
123 |         if self.rank == 0:
124 |             hf_path = os.path.join(path, "huggingface")
125 |             os.makedirs(hf_path, exist_ok=True)
126 |             assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
127 |             self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
128 |             self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
129 |             self.processing_class.save_pretrained(hf_path)
130 | 
131 |         dist.barrier()
132 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # ARPO: End-to-End Policy Optimization for GUI Agents with Experience Replay
 3 | 
 4 | This repository contains the code and models for the paper:
 5 | 
 6 | > **ARPO: End-to-End Policy Optimization for GUI Agents with Experience Replay**
 7 | > *Fanbin Lu, Zhisheng Zhong, Shu Liu, Chi-Wing Fu, Jiaya Jia*
 8 | > CUHK, SmartMore, HKUST
 9 | > [[Paper](https://arxiv.org/abs/2505.16282)] • [[Project Page](https://github.com/dvlab-research/ARPO)] • [[Model on HF](https://huggingface.co/Fanbin/ARPO_UITARS1.5_7B)]
10 | 
11 | ## Overview
12 | 
13 | **ARPO (Agentic Replay Policy Optimization)** is a novel reinforcement learning framework designed to train **vision-language GUI agents** to complete **long-horizon desktop tasks**. It builds upon **Group Relative Policy Optimization (GRPO)** and introduces:
14 | 
15 | - **Distributed Rollouts**: Scalable task execution across parallel OSWorld environments with Docker.
16 | - **Multi-modal Input Support**: Processes long histories (15 steps) of screenshots + actions end to end.
17 | 
18 | Access our [model](https://huggingface.co/Fanbin/ARPO_UITARS1.5_7B) on Hugging Face and view the [training logs](https://wandb.ai/fanbinlu/arpo) on Weights & Biases.
19 | 
20 | 

21 | Trajectory reward during training
22 | 

23 | 
24 | ## 📊 Results on OSWorld
25 | 
26 | | Model                         | 128 training tasks | OSWorld overall |
27 | |-------------------------------|--------------------|-----------------|
28 | | UI-Tars-1.5                   | 68.7%              | 23.5%           |
29 | | UI-Tars-1.5 + GRPO            | 72.9%              | 26.0%           |
30 | | **UI-Tars-1.5 + ARPO (Ours)** | 83.9%              | **29.9%**       |
31 | 
32 | > Evaluated with a max of **15 steps per trajectory**.
33 | 
34 | ---
35 | 
36 | ## 🛠 Installation
37 | 
38 | ### 1. Clone the repository and create the environment
39 | 
40 | ```bash
41 | git clone --recurse-submodules https://github.com/dvlab-research/ARPO.git
42 | cd ARPO
43 | 
44 | # Create and activate the Conda environment
45 | conda create -n arpo python=3.10
46 | conda activate arpo
47 | 
48 | # Install Python dependencies
49 | pip install -r requirements.txt
50 | ```
51 | 
52 | ### 2. Install OSWorld
53 | Follow the original installation guide of [OSWorld](https://github.com/xlang-ai/OSWorld) if you only want to evaluate the model. To train with GRPO, you must also install it as a Python package:
54 | ```bash
55 | cd OSWorld
56 | pip install -e .
57 | cd ..
58 | ```
59 | 
60 | > 💡 We strongly recommend running a full evaluation **with Docker** before training, to prepare the required Docker image, Ubuntu VM data, and cache_dir.
61 | 
62 | 
63 | ## ⚙️ Setup for Evaluation with OSWorld
64 | 
65 | To evaluate ARPO on the OSWorld benchmark with the [released model](https://huggingface.co/Fanbin/ARPO_UITARS1.5_7B) using Docker-based virtual environments, follow these steps:
66 | 
67 | ### 1. **Prepare the Environment**
68 | 
69 | Ensure you have correctly installed [OSWorld](https://github.com/xlang-ai/OSWorld) by following its Docker setup instructions. Once OSWorld is set up, launch the model server:
70 | 
71 | ```bash
72 | nohup bash start_server.sh &
73 | ```
74 | 
75 | ### 2. **Run the Evaluation Script**
76 | 
77 | Navigate into the OSWorld directory and execute the evaluation script:
78 | 
79 | ```bash
80 | cd OSWorld
81 | 
82 | python run_multienv_uitars.py \
83 |     --headless \
84 |     --observation_type screenshot \
85 |     --max_steps 15 \
86 |     --max_trajectory_length 15 \
87 |     --temperature 0.6 \
88 |     --model ui-tars \
89 |     --action_space pyautogui \
90 |     --num_envs 8 \
91 |     --result_dir ./results/ \
92 |     --test_all_meta_path ./evaluation_examples/test_all.json \
93 |     --trial-id 0 \
94 |     --server_ip http://127.0.0.1
95 | ```
96 | 
97 | ### ✅ Parameters Explained
98 | 
99 | - `--headless`: Enables headless mode (no GUI rendering).
100 | - `--observation_type screenshot`: Use visual observations for the agent.
101 | - `--max_steps` / `--max_trajectory_length`: Limit per-task interaction steps.
102 | - `--temperature`: Sampling temperature for model output.
103 | - `--model`: Name of the model.
104 | - `--num_envs`: Number of parallel environments (VMs).
105 | - `--result_dir`: Directory to store evaluation results.
106 | - `--test_all_meta_path`: JSON file with evaluation task metadata.
107 | - `--trial-id`: ID for the evaluation trial.
108 | - `--server_ip`: IP of the evaluation server (usually localhost).
109 | 
110 | > You will find the vmware_vm_data/, docker_vm_data/, and cache/ folders under the OSWorld directory after the evaluation.
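As a quick sanity check after a run, you can tally the per-task scores. The sketch below assumes OSWorld's default layout, where each finished task directory under `--result_dir` contains a `result.txt` holding that task's score; adjust the paths if your layout differs:

```bash
# Count finished tasks and compute their mean score
# (layout assumption: one result.txt per finished task under ./results)
find ./results -name "result.txt" -exec cat {} + \
  | awk '{ s += $1; n += 1 } END { if (n) printf "finished: %d, mean score: %.3f\n", n, s / n }'
```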
111 | ---
112 | 
113 | ## ⚙️ Setup for GRPO Training
114 | 
115 | ```bash
116 | # Link evaluation examples and cache
117 | ln -s $(pwd)/OSWorld/evaluation_examples ./
118 | mkdir cache_dirs/
119 | ln -s $(pwd)/OSWorld/cache ./cache_dirs/cache_0
120 | ln -s $(pwd)/OSWorld/vmware_vm_data ./
121 | ln -s $(pwd)/OSWorld/docker_vm_data ./
122 | ```
123 | 
124 | To run Docker without `sudo`:
125 | 
126 | ```bash
127 | sudo usermod -aG docker $USER
128 | newgrp docker
129 | ```
130 | 
131 | ---
132 | 
133 | ## Training ARPO/GRPO with OSWorld
134 | 
135 | ### Single Node (subset training: 32 tasks)
136 | If you only have one node, we suggest training on a subset of OSWorld tasks with at most 16 Docker environments.
137 | ```bash
138 | RAY_PORT=2468
139 | RAY_HEAD_IP=<head_node_ip>  # placeholder: set this to your node's IP
140 | ray start --head --port=$RAY_PORT --resources='{"docker:'$RAY_HEAD_IP'": 128}'
141 | bash ./examples/osworld_subset32.sh
142 | ```
143 | 
144 | ### Multi-Node Setup with Ray (e.g. 8 nodes, 128 envs)
145 | 
146 | On the **Ray master node**:
147 | 
148 | ```bash
149 | RAY_PORT=2468
150 | RAY_HEAD_IP=<head_node_ip>  # placeholder: set this to the master node's IP
151 | ray start --head --port=$RAY_PORT --resources='{"docker:'$RAY_HEAD_IP'": 128}'
152 | ```
153 | 
154 | On the **Ray slave nodes** (with GPUs):
155 | 
156 | ```bash
157 | ray start --address=$RAY_HEAD_IP:$RAY_PORT --num-gpus=8 --resources='{"docker:'$CURRENT_IP'": 128}'  # set CURRENT_IP to this node's IP
158 | ```
159 | 
160 | Or (CPU only):
161 | 
162 | ```bash
163 | ray start --address=$RAY_HEAD_IP:$RAY_PORT --resources='{"docker:'$CURRENT_IP'": 128}'  # set CURRENT_IP to this node's IP
164 | ```
165 | 
166 | Then run:
167 | 
168 | ```bash
169 | bash ./examples/osworld_full_arpo.sh
170 | ```
171 | 
172 | ---
173 | 
174 | ## 🔗 Related Projects
175 | 
176 | - [OSWorld](https://github.com/FanbinLu/OSWorld) — Realistic GUI environments for multimodal agents, modified for GRPO training.
177 | - [EasyR1](https://github.com/hiyouga/EasyR1) — An efficient, scalable, multi-modality RL training framework based on veRL, supporting advanced VLMs and algorithms such as GRPO.
178 | ---
179 | 
180 | ## 📄 Citation
181 | 
182 | If you find ARPO useful, please consider citing our work.
--------------------------------------------------------------------------------
/scripts/model_merger.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
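# A minimal usage sketch (illustrative): merge the per-rank FSDP shards written by the
# checkpoint manager back into a single Hugging Face checkpoint. `--local_dir` must
# contain the `model_world_size_{W}_rank_{r}.pt` files together with the `huggingface/`
# config folder saved by rank 0; the merged weights are written back into that folder:
#
#     python scripts/model_merger.py --local_dir checkpoints/arpo/global_step_100/actor
#
# (the checkpoint path above is a made-up example)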
14 | 
15 | import argparse
16 | import os
17 | import re
18 | from concurrent.futures import ThreadPoolExecutor
19 | from typing import Dict, List, Tuple
20 | 
21 | import torch
22 | from torch.distributed._tensor import DTensor, Placement, Shard
23 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
24 | 
25 | 
26 | def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
27 |     if placement.is_replicate():
28 |         return tensors[0]
29 |     elif placement.is_partial():
30 |         raise NotImplementedError("Partial placement is not supported yet")
31 |     elif placement.is_shard():
32 |         return torch.cat(tensors, dim=placement.dim).contiguous()
33 |     else:
34 |         raise ValueError(f"Unsupported placement: {placement}")
35 | 
36 | 
37 | if __name__ == "__main__":
38 |     parser = argparse.ArgumentParser()
39 |     parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
40 |     parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
41 |     args = parser.parse_args()
42 | 
43 |     assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
44 |     local_dir = args.local_dir
45 | 
46 |     # inspect the rank-zero shard to find the (dp, fsdp) sharding layout
47 |     rank = 0
48 |     world_size = 0
49 |     for filename in os.listdir(local_dir):
50 |         match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
51 |         if match:
52 |             world_size = int(match.group(1))
53 |             break
54 |     assert world_size, "No model file with the proper format"
55 | 
56 |     state_dict = torch.load(
57 |         os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
58 |     )
59 |     pivot_key = sorted(state_dict.keys())[0]
60 |     weight = state_dict[pivot_key]
61 |     assert isinstance(weight, torch.distributed._tensor.DTensor)
62 |     # get sharding info
63 |     device_mesh = weight.device_mesh
64 |     mesh = device_mesh.mesh
65 |     mesh_dim_names = device_mesh.mesh_dim_names
66 | 
67 |     print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
68 | 
69 |     assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}"
70 | 
71 |     if "tp" in mesh_dim_names:
72 |         # fsdp * tp
73 |         total_shards = mesh.shape[-1] * mesh.shape[-2]
74 |         mesh_shape = (mesh.shape[-2], mesh.shape[-1])
75 |     else:
76 |         # fsdp
77 |         total_shards = mesh.shape[-1]
78 |         mesh_shape = (mesh.shape[-1],)
79 | 
80 |     print(f"Processing {total_shards} model shards with mesh shape {mesh_shape} in total")
81 | 
82 |     model_state_dict_lst = []
83 |     model_state_dict_lst.append(state_dict)
84 |     model_state_dict_lst.extend([None] * (total_shards - 1))  # placeholders filled in by the loader threads
85 | 
86 |     def process_one_shard(rank):
87 |         model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
88 |         state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
89 |         model_state_dict_lst[rank] = state_dict
90 |         return state_dict
91 | 
92 |     with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
93 |         for rank in range(1, total_shards):
94 |             executor.submit(process_one_shard, rank)
95 |     state_dict = {}
96 |     param_placements: Dict[str, List[Placement]] = {}
97 |     keys = set(model_state_dict_lst[0].keys())
98 |     for key in keys:
99 |         state_dict[key] = []
100 |         for model_state_dict in model_state_dict_lst:
101 |             try:
102 |                 tensor = model_state_dict.pop(key)
103 |             except Exception:
104 |                 print(f"Key {key} is missing from a shard; the checkpoint shards are inconsistent.")
105 |                 raise
106 |             if isinstance(tensor, DTensor):
107 | 
state_dict[key].append(tensor._local_tensor.bfloat16()) 108 | placements = tuple(tensor.placements) 109 | # replicated placement at dp dimension can be discarded 110 | if mesh_dim_names[0] == "dp": 111 | placements = placements[1:] 112 | if key not in param_placements: 113 | param_placements[key] = placements 114 | else: 115 | assert param_placements[key] == placements 116 | else: 117 | state_dict[key] = tensor.bfloat16() 118 | 119 | del model_state_dict_lst 120 | 121 | for key in sorted(state_dict): 122 | if not isinstance(state_dict[key], list): 123 | print(f"No need to merge key {key}") 124 | continue 125 | # merge shards 126 | placements: Tuple[Shard] = param_placements[key] 127 | if len(mesh_shape) == 1: 128 | # 1-D list, FSDP without TP 129 | assert len(placements) == 1 130 | shards = state_dict[key] 131 | state_dict[key] = merge_by_placement(shards, placements[0]) 132 | else: 133 | # 2-D list, FSDP + TP 134 | raise NotImplementedError("FSDP + TP is not supported yet") 135 | 136 | print("Writing to local disk") 137 | hf_path = os.path.join(local_dir, "huggingface") 138 | config = AutoConfig.from_pretrained(hf_path) 139 | 140 | if "ForTokenClassification" in config.architectures[0]: 141 | auto_model = AutoModelForTokenClassification 142 | elif "ForCausalLM" in config.architectures[0]: 143 | auto_model = AutoModelForCausalLM 144 | elif "ForConditionalGeneration" in config.architectures[0]: 145 | auto_model = AutoModelForVision2Seq 146 | else: 147 | raise NotImplementedError(f"Unknown architecture {config.architectures}") 148 | 149 | with torch.device("meta"): 150 | model = auto_model.from_config(config, torch_dtype=torch.bfloat16) 151 | 152 | model.to_empty(device="cpu") 153 | 154 | print(f"Saving model to {hf_path}") 155 | model.save_pretrained(hf_path, state_dict=state_dict) 156 | del state_dict 157 | del model 158 | if args.hf_upload_path: 159 | # Push to hugging face 160 | from huggingface_hub import HfApi 161 | 162 | api = HfApi() 163 | api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True) 164 | api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model") 165 | -------------------------------------------------------------------------------- /verl/single_controller/base/worker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
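# How a Worker boots (illustrative summary): each process reads its distributed context
# from environment variables that the WorkerGroup machinery sets up beforehand, e.g.
#
#     WORLD_SIZE=8 RANK=3 MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 <worker entrypoint>
#
# Rank 0 additionally publishes MASTER_ADDR/MASTER_PORT through a Ray-based register
# center so the other ranks can discover them (see _configure_before_init below).
# The command line above is a schematic placeholder, not a real entrypoint in this repo.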
14 | """ 15 | the class for Worker 16 | """ 17 | 18 | import os 19 | import socket 20 | from dataclasses import dataclass 21 | from typing import Tuple 22 | 23 | import ray 24 | import torch 25 | 26 | from .decorator import Dispatch, Execute, register 27 | from .register_center.ray import create_worker_group_register_center 28 | 29 | 30 | @dataclass 31 | class DistRankInfo: 32 | tp_rank: int 33 | dp_rank: int 34 | pp_rank: int 35 | 36 | 37 | @dataclass 38 | class DistGlobalInfo: 39 | tp_size: int 40 | dp_size: int 41 | pp_size: int 42 | 43 | 44 | class WorkerHelper: 45 | def _get_node_ip(self) -> str: 46 | host_ipv4 = os.getenv("MY_HOST_IP", None) 47 | host_ipv6 = os.getenv("MY_HOST_IPV6", None) 48 | host_ip_by_env = host_ipv4 or host_ipv6 49 | host_ip_by_sdk = ray._private.services.get_node_ip_address() 50 | 51 | host_ip = host_ip_by_env or host_ip_by_sdk 52 | return host_ip 53 | 54 | def _get_free_port(self) -> int: 55 | with socket.socket() as sock: 56 | sock.bind(("", 0)) 57 | return sock.getsockname()[1] 58 | 59 | def get_availale_master_addr_port(self) -> Tuple[str, str]: 60 | return self._get_node_ip(), str(self._get_free_port()) 61 | 62 | def _get_pid(self): 63 | return 64 | 65 | 66 | class WorkerMeta: 67 | keys = [ 68 | "WORLD_SIZE", 69 | "RANK", 70 | "LOCAL_WORLD_SIZE", 71 | "LOCAL_RANK", 72 | "MASTER_ADDR", 73 | "MASTER_PORT", 74 | "CUDA_VISIBLE_DEVICES", 75 | ] 76 | 77 | def __init__(self, store) -> None: 78 | self._store = store 79 | 80 | def to_dict(self): 81 | return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys} 82 | 83 | 84 | # we assume that in each WorkerGroup, there is a Master Worker 85 | class Worker(WorkerHelper): 86 | """A (distributed) worker.""" 87 | 88 | _world_size: int 89 | _rank: int 90 | _local_world_size: int 91 | _local_rank: int 92 | _master_addr: str 93 | _master_port: str 94 | _cuda_visible_devices: str 95 | 96 | def __new__(cls, *args, **kwargs): 97 | instance = super().__new__(cls) 98 | 99 | # note that here we use int to distinguish 100 | disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0)) 101 | if disable_worker_init: 102 | return instance 103 | 104 | rank = os.getenv("RANK", None) 105 | worker_group_prefix = os.getenv("WG_PREFIX", None) 106 | 107 | # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init 108 | if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: 109 | instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) 110 | 111 | return instance 112 | 113 | def _configure_before_init(self, register_center_name: str, rank: int): 114 | assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" 115 | 116 | if rank == 0: 117 | master_addr, master_port = self.get_availale_master_addr_port() 118 | rank_zero_info = { 119 | "MASTER_ADDR": master_addr, 120 | "MASTER_PORT": master_port, 121 | } 122 | self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) 123 | os.environ.update(rank_zero_info) 124 | 125 | def __init__(self, cuda_visible_devices=None) -> None: 126 | # construct a meta from envrionment variable. 
Note that the import must be inside the class because it is executed remotely 127 | world_size = int(os.getenv("WORLD_SIZE")) 128 | rank = int(os.getenv("RANK")) 129 | self._rank = rank 130 | self._world_size = world_size 131 | 132 | if "AMD" in torch.cuda.get_device_name(): 133 | os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES") 134 | os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK") 135 | cuda_visible_devices = os.getenv("LOCAL_RANK", "0") 136 | torch.cuda.set_device(int(cuda_visible_devices)) 137 | 138 | master_addr = os.getenv("MASTER_ADDR") 139 | master_port = os.getenv("MASTER_PORT") 140 | 141 | local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) 142 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 143 | 144 | store = { 145 | "_world_size": world_size, 146 | "_rank": rank, 147 | "_local_world_size": local_world_size, 148 | "_local_rank": local_rank, 149 | "_master_addr": master_addr, 150 | "_master_port": master_port, 151 | } 152 | if cuda_visible_devices is not None: 153 | store["_cuda_visible_devices"] = cuda_visible_devices 154 | 155 | meta = WorkerMeta(store=store) 156 | self._configure_with_meta(meta=meta) 157 | 158 | def _configure_with_meta(self, meta: WorkerMeta): 159 | """ 160 | This function should only be called inside by WorkerGroup 161 | """ 162 | assert isinstance(meta, WorkerMeta) 163 | self.__dict__.update(meta.to_dict()) # this is hacky 164 | # print(f"__dict__: {self.__dict__}") 165 | for key in WorkerMeta.keys: 166 | val = self.__dict__.get(f"_{key.lower()}", None) 167 | if val is not None: 168 | # print(f"set {key} to {val}") 169 | os.environ[key] = str(val) 170 | 171 | os.environ["REDIS_STORE_SERVER_HOST"] = ( 172 | str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" 173 | ) 174 | 175 | def get_master_addr_port(self): 176 | return self._master_addr, self._master_port 177 | 178 | def get_cuda_visible_devices(self): 179 | cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set") 180 | return cuda_visible_devices 181 | 182 | def print_rank0(self, *args, **kwargs): 183 | if self.rank == 0: 184 | print(*args, **kwargs) 185 | 186 | @property 187 | def world_size(self): 188 | return self._world_size 189 | 190 | @property 191 | def rank(self): 192 | return self._rank 193 | 194 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) 195 | def execute_with_func_generator(self, func, *args, **kwargs): 196 | ret_proto = func(self, *args, **kwargs) 197 | return ret_proto 198 | 199 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) 200 | def execute_func_rank_zero(self, func, *args, **kwargs): 201 | result = func(*args, **kwargs) 202 | return result 203 | -------------------------------------------------------------------------------- /verl/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import os 17 | from collections import defaultdict 18 | from io import BytesIO 19 | from typing import Any, Dict, List, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from datasets import load_dataset 24 | from PIL import Image 25 | from PIL.Image import Image as ImageObject 26 | from torch.utils.data import Dataset 27 | from transformers import PreTrainedTokenizer, ProcessorMixin 28 | 29 | from ..models.transformers.qwen2_vl import get_rope_index 30 | from . import torch_functional as VF 31 | 32 | 33 | def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]: 34 | tensors = defaultdict(list) 35 | non_tensors = defaultdict(list) 36 | for feature in features: 37 | for key, value in feature.items(): 38 | if isinstance(value, torch.Tensor): 39 | tensors[key].append(value) 40 | else: 41 | non_tensors[key].append(value) 42 | 43 | for key, value in tensors.items(): 44 | tensors[key] = torch.stack(value, dim=0) 45 | 46 | for key, value in non_tensors.items(): 47 | non_tensors[key] = np.array(value, dtype=object) 48 | 49 | return {**tensors, **non_tensors} 50 | 51 | 52 | class ImageProcessMixin: 53 | max_pixels: int 54 | min_pixels: int 55 | 56 | def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject: 57 | if isinstance(image, dict): 58 | image = Image.open(BytesIO(image["bytes"])) 59 | elif isinstance(image, bytes): 60 | image = Image.open(BytesIO(image)) 61 | 62 | if (image.width * image.height) > self.max_pixels: 63 | resize_factor = math.sqrt(self.max_pixels / (image.width * image.height)) 64 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 65 | image = image.resize((width, height)) 66 | 67 | if (image.width * image.height) < self.min_pixels: 68 | resize_factor = math.sqrt(self.min_pixels / (image.width * image.height)) 69 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 70 | image = image.resize((width, height)) 71 | 72 | if image.mode != "RGB": 73 | image = image.convert("RGB") 74 | 75 | return image 76 | 77 | 78 | class RLHFDataset(Dataset, ImageProcessMixin): 79 | """ 80 | We assume the dataset contains a column that contains prompts and other information 81 | """ 82 | 83 | def __init__( 84 | self, 85 | data_path: str, 86 | tokenizer: PreTrainedTokenizer, 87 | processor: Optional[ProcessorMixin], 88 | prompt_key: str = "prompt", 89 | answer_key: str = "answer", 90 | image_key: str = "images", 91 | max_prompt_length: int = 1024, 92 | truncation: str = "error", 93 | format_prompt: str = None, 94 | max_pixels: int = None, 95 | min_pixels: int = None, 96 | ): 97 | self.tokenizer = tokenizer 98 | self.processor = processor 99 | self.prompt_key = prompt_key 100 | self.answer_key = answer_key 101 | self.image_key = image_key 102 | self.max_prompt_length = max_prompt_length 103 | self.truncation = truncation 104 | self.format_prompt = format_prompt 105 | self.max_pixels = max_pixels 106 | self.min_pixels = min_pixels 107 | 108 | if "@" in data_path: 109 | data_path, data_split = data_path.split("@") 110 | else: 111 | data_split = "train" 112 | 113 | if os.path.isdir(data_path): 114 | self.dataset = load_dataset("parquet", data_dir=data_path, split="train") 115 | elif os.path.isfile(data_path): 116 | self.dataset = load_dataset("parquet", data_files=data_path, split="train") 117 | else: # remote dataset 118 | self.dataset = 
load_dataset(data_path, split=data_split)
119 | 
120 |     def __len__(self):
121 |         return len(self.dataset)
122 | 
123 |     def __getitem__(self, index):
124 |         row_dict: dict = self.dataset[index]
125 |         prompt_str: str = row_dict[self.prompt_key]
126 |         if self.format_prompt:
127 |             prompt_str = prompt_str + " " + self.format_prompt.strip()
128 | 
129 |         if self.image_key in row_dict:
130 |             # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
131 |             content_list = []
132 |             for i, content in enumerate(prompt_str.split("<image>")):  # split on the <image> placeholder
133 |                 if i != 0:
134 |                     content_list.append({"type": "image"})
135 | 
136 |                 if content:
137 |                     content_list.append({"type": "text", "text": content})
138 | 
139 |             messages = [{"role": "user", "content": content_list}]
140 |             prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
141 |             images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
142 |             model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
143 |             input_ids = model_inputs.pop("input_ids")[0]
144 |             attention_mask = model_inputs.pop("attention_mask")[0]
145 |             row_dict["multi_modal_data"] = {"image": images}
146 |             row_dict["multi_modal_inputs"] = dict(model_inputs)
147 | 
148 |             # qwen2vl mrope
149 |             position_ids = get_rope_index(
150 |                 self.processor,
151 |                 input_ids=input_ids,
152 |                 image_grid_thw=model_inputs["image_grid_thw"],
153 |                 attention_mask=attention_mask,
154 |             )  # (3, seq_length)
155 |         else:
156 |             messages = [{"role": "user", "content": prompt_str}]
157 |             prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
158 |             model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
159 |             input_ids = model_inputs.pop("input_ids")[0]
160 |             attention_mask = model_inputs.pop("attention_mask")[0]
161 |             position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)
162 | 
163 |         input_ids, attention_mask, position_ids = VF.postprocess_data(
164 |             input_ids=input_ids,
165 |             attention_mask=attention_mask,
166 |             position_ids=position_ids,
167 |             max_length=self.max_prompt_length,
168 |             pad_token_id=self.tokenizer.pad_token_id,
169 |             left_pad=True,
170 |             truncation=self.truncation,
171 |         )
172 |         row_dict["input_ids"] = input_ids
173 |         row_dict["attention_mask"] = attention_mask
174 |         row_dict["position_ids"] = position_ids
175 |         row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
176 |         row_dict["ground_truth"] = row_dict.pop(self.answer_key)
177 |         return row_dict
178 | 
--------------------------------------------------------------------------------
/verl/single_controller/base/worker_group.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
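# A minimal usage sketch for ResourcePool (illustrative): it records how many worker
# processes run on each node, so
#
#     pool = ResourcePool(process_on_nodes=[8, 8])
#     pool.world_size         # -> 16
#     pool.local_rank_list()  # -> [0, 1, ..., 7, 0, 1, ..., 7]
#
# WorkerGroup then binds every @register-decorated method of the user's Worker class
# onto itself, so that one dispatched call fans out to all workers in the pool.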
14 | """
15 | The WorkerGroup class and its supporting utilities.
16 | """
17 | 
18 | import logging
19 | import signal
20 | import threading
21 | import time
22 | from typing import Any, Callable, Dict, List, Optional
23 | 
24 | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
25 | 
26 | 
27 | class ResourcePool:
28 |     """The resource pool with meta info such as world size."""
29 | 
30 |     def __init__(
31 |         self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
32 |     ) -> None:
33 |         if process_on_nodes is None:
34 |             process_on_nodes = []
35 | 
36 |         self._store = process_on_nodes
37 |         self.max_collocate_count = max_collocate_count
38 |         self.n_gpus_per_node = n_gpus_per_node  # reserved for future Huawei hardware that provides 16 devices per node
39 | 
40 |     def add_node(self, process_count):
41 |         self._store.append(process_count)
42 | 
43 |     @property
44 |     def world_size(self):
45 |         return sum(self._store)
46 | 
47 |     def __call__(self) -> Any:
48 |         return self._store
49 | 
50 |     @property
51 |     def store(self):
52 |         return self._store
53 | 
54 |     def local_world_size_list(self) -> List[int]:
55 |         nested_local_world_size_list = [
56 |             [local_world_size for _ in range(local_world_size)] for local_world_size in self._store
57 |         ]
58 |         return [item for row in nested_local_world_size_list for item in row]
59 | 
60 |     def local_rank_list(self) -> List[int]:
61 |         nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]  # noqa: C416
62 |         return [item for row in nested_local_rank_list for item in row]
63 | 
64 | 
65 | class ClassWithInitArgs:
66 |     """
67 |     This class stores a class constructor and the args/kwargs needed to construct the class.
68 |     It is used to instantiate the remote class.
69 |     """
70 | 
71 |     def __init__(self, cls, *args, **kwargs) -> None:
72 |         self.cls = cls
73 |         self.args = args
74 |         self.kwargs = kwargs
75 | 
76 |     def __call__(self) -> Any:
77 |         return self.cls(*self.args, **self.kwargs)
78 | 
79 | 
80 | def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
81 |     while True:
82 |         for worker in workers:
83 |             if not is_alive(worker):
84 |                 logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
85 |                 signal.raise_signal(signal.SIGABRT)
86 | 
87 |         time.sleep(gap_time)
88 | 
89 | 
90 | class WorkerGroup:
91 |     """A group of workers"""
92 | 
93 |     def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
94 |         self._is_init_with_detached_workers = resource_pool is None
95 | 
96 |         if resource_pool is not None:
97 |             # handle the case when the WorkerGroup is attached to an existing one
98 |             self._procecss_dispatch_config = resource_pool()
99 |         else:
100 |             self._procecss_dispatch_config = None
101 | 
102 |         self._workers = []
103 |         self._worker_names = []
104 | 
105 |         self._master_addr = None
106 |         self._master_port = None
107 | 
108 |         self._checker_thread: Optional[threading.Thread] = None
109 | 
110 |     def _is_worker_alive(self, worker):
111 |         raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")
112 | 
113 |     def _block_until_all_workers_alive(self) -> None:
114 |         while True:
115 |             all_state = [self._is_worker_alive(worker) for worker in self._workers]
116 |             if False in all_state:
117 |                 time.sleep(1)
118 |             else:
119 |                 break
120 | 
121 |     def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
122 |         # before starting the aliveness check, make sure all workers are already alive
123 |         self._block_until_all_workers_alive()
124 | 
125 |         self._checker_thread = threading.Thread(
126 |             target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)
127 |         )
128 |         self._checker_thread.start()
129 | 
130 |     @property
131 |     def world_size(self):
132 |         return len(self._workers)
133 | 
134 |     def _bind_worker_method(self, user_defined_cls, func_generator):
135 |         """
136 |         Bind the registered methods of the user-defined worker class to this WorkerGroup.
137 |         """
138 |         for method_name in dir(user_defined_cls):
139 |             try:
140 |                 method = getattr(user_defined_cls, method_name)
141 |                 assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
142 |             except Exception:
143 |                 # properties fail here because the class object has no instance properties
144 |                 continue
145 | 
146 |             if hasattr(method, MAGIC_ATTR):
147 |                 # this method is decorated by register
148 |                 attribute = getattr(method, MAGIC_ATTR)
149 |                 assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
150 |                 assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
151 | 
152 |                 dispatch_mode = attribute["dispatch_mode"]
153 |                 execute_mode = attribute["execute_mode"]
154 |                 blocking = attribute["blocking"]
155 | 
156 |                 # get dispatch fn
157 |                 if isinstance(dispatch_mode, Dispatch):
158 |                     # get default dispatch fn
159 |                     fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
160 |                     dispatch_fn = fn["dispatch_fn"]
161 |                     collect_fn = fn["collect_fn"]
162 |                 else:
163 |                     assert isinstance(dispatch_mode, dict)
164 |                     assert "dispatch_fn" in dispatch_mode
165 |                     assert "collect_fn" in dispatch_mode
166 |                     dispatch_fn = dispatch_mode["dispatch_fn"]
167 |                     collect_fn = dispatch_mode["collect_fn"]
168 | 
169 |                 # get execute_fn_name
170 |                 execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
171 |                 wg_execute_fn_name = execute_mode["execute_fn_name"]
172 | 
173 |                 # get execute_fn from string
174 |                 try:
175 |                     execute_fn = getattr(self, wg_execute_fn_name)
176 |                     assert callable(execute_fn), "execute_fn must be callable"
177 |                 except Exception:
178 |                     print(f"execute_fn {wg_execute_fn_name} is invalid")
179 |                     raise
180 | 
181 |                 # bind a new method to this worker group
182 |                 func = func_generator(
183 |                     self,
184 |                     method_name,
185 |                     dispatch_fn=dispatch_fn,
186 |                     collect_fn=collect_fn,
187 |                     execute_fn=execute_fn,
188 |                     blocking=blocking,
189 |                 )
190 | 
191 |                 try:
192 |                     setattr(self, method_name, func)
193 |                 except Exception:
194 |                     raise ValueError(f"Failed to set method {method_name}")
195 | 
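A minimal sketch of how ResourcePool, register, and _bind_worker_method compose. This is illustrative only: `MyWorker` is a hypothetical worker class, and the actual binding is performed by a concrete WorkerGroup subclass (e.g. the Ray-backed one under verl/single_controller/ray/).

    from verl.single_controller.base.decorator import Dispatch, register
    from verl.single_controller.base.worker_group import ResourcePool


    class MyWorker:  # in practice this would derive from verl's Worker base class
        @register(dispatch_mode=Dispatch.ONE_TO_ALL)
        def init_model(self, model_path: str):
            ...  # runs on every rank with the same argument


    pool = ResourcePool(process_on_nodes=[8, 8])  # two nodes with eight processes each
    assert pool.world_size == 16
    # A concrete subclass calls `self._bind_worker_method(MyWorker, func_generator)`,
    # after which `worker_group.init_model(...)` dispatches the call to all workers
    # and collects the results according to the registered dispatch mode.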
--------------------------------------------------------------------------------
/verl/single_controller/base/decorator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from enum import Enum, auto
16 | from functools import wraps
17 | from types import FunctionType
18 | from typing import TYPE_CHECKING, Dict, List, Literal, Union
19 | 
20 | import ray
21 | 
22 | from ...protocol import DataProto, DataProtoFuture
23 | 
24 | 
25 | if TYPE_CHECKING:
26 |     from .worker_group import WorkerGroup
27 | 
28 | 
29 | # we use a magic attribute name here to avoid clashing with an attribute a user-defined function may already have
30 | MAGIC_ATTR = "attrs_3141562937"
31 | 
32 | 
33 | class Dispatch(Enum):
34 |     RANK_ZERO = auto()
35 |     ONE_TO_ALL = auto()
36 |     ALL_TO_ALL = auto()
37 |     DP_COMPUTE = auto()
38 |     DP_COMPUTE_PROTO = auto()
39 |     DP_COMPUTE_PROTO_WITH_FUNC = auto()
40 |     DP_COMPUTE_METRIC = auto()
41 | 
42 | 
43 | class Execute(Enum):
44 |     ALL = 0
45 |     RANK_ZERO = 1
46 | 
47 | 
48 | def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
49 |     splitted_args = []
50 |     for arg in args:
51 |         assert isinstance(arg, (DataProto, DataProtoFuture))
52 |         splitted_args.append(arg.chunk(chunks=chunks))
53 | 
54 |     splitted_kwargs = {}
55 |     for key, value in kwargs.items():
56 |         assert isinstance(value, (DataProto, DataProtoFuture))
57 |         splitted_kwargs[key] = value.chunk(chunks=chunks)
58 | 
59 |     return splitted_args, splitted_kwargs
60 | 
61 | 
62 | def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
63 |     args = tuple([arg] * worker_group.world_size for arg in args)
64 |     kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
65 |     return args, kwargs
66 | 
67 | 
68 | def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
69 |     return args, kwargs
70 | 
71 | 
72 | def collect_all_to_all(worker_group: "WorkerGroup", output):
73 |     return output
74 | 
75 | 
76 | def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
77 |     # make sure all the elements in outputs have the same type
78 |     for output in outputs:
79 |         assert type(output) is type(outputs[0])
80 | 
81 |     output = outputs[0]
82 | 
83 |     if isinstance(output, DataProto):
84 |         return DataProto.concat(outputs)
85 |     elif isinstance(output, ray.ObjectRef):
86 |         return DataProtoFuture.concat(outputs)
87 |     else:
88 |         raise NotImplementedError
89 | 
90 | 
91 | def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
92 |     for arg in args:
93 |         assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
94 | 
95 |     for value in kwargs.values():
96 |         assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
97 | 
98 |     return args, kwargs
99 | 
100 | 
101 | def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
102 |     assert len(outputs) == worker_group.world_size
103 |     return outputs
104 | 
105 | 
106 | def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
107 |     splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
108 |     return splitted_args, splitted_kwargs
109 | 
110 | 
111 | def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
112 |     assert type(args[0]) is FunctionType  # NOTE: the first element of args must be a function!
113 |     splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
114 |     splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
115 |     return splitted_args_with_func, splitted_kwargs
116 | 
117 | 
118 | def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
119 |     for output in outputs:
120 |         assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"
121 | 
122 |     outputs = collect_dp_compute(worker_group, outputs)
123 |     return _concat_data_proto_or_future(outputs)
124 | 
125 | 
126 | def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
127 |     predefined_dispatch_mode_fn = {
128 |         Dispatch.ONE_TO_ALL: {
129 |             "dispatch_fn": dispatch_one_to_all,
130 |             "collect_fn": collect_all_to_all,
131 |         },
132 |         Dispatch.ALL_TO_ALL: {
133 |             "dispatch_fn": dispatch_all_to_all,
134 |             "collect_fn": collect_all_to_all,
135 |         },
136 |         Dispatch.DP_COMPUTE: {
137 |             "dispatch_fn": dispatch_dp_compute,
138 |             "collect_fn": collect_dp_compute,
139 |         },
140 |         Dispatch.DP_COMPUTE_PROTO: {
141 |             "dispatch_fn": dispatch_dp_compute_data_proto,
142 |             "collect_fn": collect_dp_compute_data_proto,
143 |         },
144 |         Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
145 |             "dispatch_fn": dispatch_dp_compute_data_proto_with_func,
146 |             "collect_fn": collect_dp_compute_data_proto,
147 |         },
148 |         Dispatch.DP_COMPUTE_METRIC: {
149 |             "dispatch_fn": dispatch_dp_compute_data_proto,
150 |             "collect_fn": collect_dp_compute,
151 |         },
152 |     }
153 |     return predefined_dispatch_mode_fn[dispatch_mode]
154 | 
155 | 
156 | def get_predefined_execute_fn(execute_mode: Execute):
157 |     """
158 |     Note that only execute_all and execute_rank_zero are required to be implemented.
159 |     How these two functions handle the 'blocking' argument is left to the implementer.
160 |     """
161 |     predefined_execute_mode_fn = {
162 |         Execute.ALL: {"execute_fn_name": "execute_all"},
163 |         Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
164 |     }
165 |     return predefined_execute_mode_fn[execute_mode]
166 | 
167 | 
168 | def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
169 |     assert isinstance(dispatch_mode, (Dispatch, dict)), (
170 |         f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
171 |     )
172 |     if isinstance(dispatch_mode, dict):
173 |         necessary_keys = ["dispatch_fn", "collect_fn"]
174 |         for key in necessary_keys:
175 |             assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
176 | 
177 | 
178 | def _check_execute_mode(execute_mode: Execute):
179 |     assert isinstance(execute_mode, Execute), f"execute_mode must be an Execute. Got {execute_mode}"
180 | 
181 | 
182 | def _materialize_futures(*args, **kwargs):
183 |     new_args = []
184 |     for arg in args:
185 |         if isinstance(arg, DataProtoFuture):
186 |             arg = arg.get()
187 |         # add more types to materialize here if needed
188 |         new_args.append(arg)
189 | 
190 |     for key, value in kwargs.items():
191 |         if isinstance(value, DataProtoFuture):
192 |             kwargs[key] = value.get()
193 | 
194 |     new_args = tuple(new_args)
195 |     return new_args, kwargs
196 | 
197 | 
198 | def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
199 |     _check_dispatch_mode(dispatch_mode=dispatch_mode)
200 |     _check_execute_mode(execute_mode=execute_mode)
201 | 
202 |     def decorator(func):
203 |         @wraps(func)
204 |         def inner(*args, **kwargs):
205 |             if materialize_futures:
206 |                 args, kwargs = _materialize_futures(*args, **kwargs)
207 |             return func(*args, **kwargs)
208 | 
209 |         attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
210 |         setattr(inner, MAGIC_ATTR, attrs)
211 |         return inner
212 | 
213 |     return decorator
214 | 
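Beyond the predefined Dispatch modes, `register` accepts a custom dispatch described as a dict with `dispatch_fn` and `collect_fn` keys, which is exactly what `_check_dispatch_mode` validates. A sketch under that assumption (`my_dispatch`, `my_collect`, and `MyWorker` are hypothetical user code):

    from verl.single_controller.base.decorator import register


    def my_dispatch(worker_group, *args, **kwargs):
        # broadcast the same arguments to every worker, like dispatch_one_to_all
        args = tuple([arg] * worker_group.world_size for arg in args)
        kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
        return args, kwargs


    def my_collect(worker_group, output):
        return output[0]  # e.g. keep only rank zero's result


    class MyWorker:
        @register(dispatch_mode={"dispatch_fn": my_dispatch, "collect_fn": my_collect}, blocking=False)
        def save_checkpoint(self, path: str):
            ...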
--------------------------------------------------------------------------------
/verl/models/transformers/flash_attention_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | import inspect
18 | import os
19 | from typing import Optional, Tuple
20 | 
21 | import torch
22 | import torch.distributed as dist
23 | from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
24 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25 | 
26 | from ...utils.ulysses import (
27 |     gather_heads_scatter_seq,
28 |     gather_seq_scatter_heads,
29 |     get_ulysses_sequence_parallel_group,
30 |     get_ulysses_sequence_parallel_world_size,
31 | )
32 | 
33 | 
34 | if is_flash_attn_2_available():
35 |     from flash_attn import flash_attn_func, flash_attn_varlen_func
36 | 
37 |     _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
38 |     _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
39 |     _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
40 |     _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
41 | 
42 | 
43 | def prepare_fa2_from_position_ids(
44 |     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
45 | ):
46 |     query = query.view(-1, query.size(-2), query.size(-1))
47 |     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
48 |     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
49 |     position_ids = position_ids.flatten()
50 |     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
51 |     cu_seqlens = torch.cat(
52 |         (
53 |             indices_q[position_ids == 0],
54 |             torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
55 |         )
56 |     )
57 |     max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
58 |     return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
59 | 
60 | 
61 | def _custom_flash_attention_forward(
62 |     query_states: torch.Tensor,
63 |     key_states: torch.Tensor,
64 |     value_states: torch.Tensor,
65 |     attention_mask: Optional[torch.Tensor],
66 |     query_length: int,
67 |     is_causal: bool = True,
68 |     position_ids: Optional[torch.Tensor] = None,
69 |     sliding_window: Optional[int] = None,
70 |     use_top_left_mask: bool = False,
71 |     deterministic: Optional[bool] = None,
72 |     **kwargs,
73 | ):
74 |     """
75 |     Patches flash attention forward to handle the 3D position ids produced by mrope, shaped (3, batch_size, seq_length).
76 |     """
77 |     if not use_top_left_mask:
78 |         causal = is_causal
79 |     else:
80 |         causal = is_causal and query_length != 1
81 | 
82 |     # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
83 |     use_sliding_windows = (
84 |         _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
85 |     )
86 |     flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
87 | 
88 |     if _flash_supports_deterministic:
89 |         flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
90 | 
91 |     if kwargs.get("softcap") is not None:
92 |         flash_kwargs["softcap"] = kwargs.pop("softcap")
93 | 
94 |     query_states, key_states, value_states = fa_peft_integration_check(
95 |         query_states, key_states, value_states, target_dtype=torch.bfloat16
96 |     )
97 | 
98 |     sp_size = get_ulysses_sequence_parallel_world_size()
99 |     if sp_size > 1:
100 |         # (batch_size, seq_length, num_head, head_size)
101 |         query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
102 |         key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
103 |         value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
104 |         position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
105 |         dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())  # fills position_ids_lst in place
106 |         position_ids = torch.cat(position_ids_lst, dim=-1)  # (..., batch_size, seq_length)
107 | 
108 |     if position_ids is not None and position_ids.dim() == 3:  # qwen2vl mrope
109 |         position_ids = position_ids[0]
110 | 
111 |     if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
112 |         batch_size = query_states.size(0)
113 |         query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
114 |             query_states, key_states, value_states, position_ids
115 |         )
116 |         cu_seqlens_q, cu_seqlens_k = cu_seq_lens
117 |         max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
118 |         attn_output = flash_attn_varlen_func(
119 |             query_states,
120 |             key_states,
121 |             value_states,
122 |             cu_seqlens_q=cu_seqlens_q,
123 |             cu_seqlens_k=cu_seqlens_k,
124 |             max_seqlen_q=max_seqlen_in_batch_q,
125 |             max_seqlen_k=max_seqlen_in_batch_k,
126 |             dropout_p=kwargs.pop("dropout", 0.0),
127 |             softmax_scale=kwargs.pop("softmax_scale", None),
128 |             causal=causal,
129 |             **flash_kwargs,
130 |         )
131 |         attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
132 |     else:
133 |         attn_output = _flash_attention_forward(
134 |             query_states,
135 |             key_states,
136 |             value_states,
137 |             attention_mask,
138 |             query_length,
139 |             is_causal=is_causal,
140 |             sliding_window=sliding_window,
141 |             use_top_left_mask=use_top_left_mask,
142 |             deterministic=deterministic,
143 |             **kwargs,
144 |         )  # do not pass position_ids to the old flash_attention_forward
145 | 
146 |     if sp_size > 1:
147 |         # (batch_size, seq_length, num_head, head_size)
148 |         attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
149 | 
150 |     return attn_output
151 | 
152 | 
153 | def flash_attention_forward(
154 |     module: torch.nn.Module,
155 |     query: torch.Tensor,
156 |     key: torch.Tensor,
157 |     value: torch.Tensor,
158 |     attention_mask: Optional[torch.Tensor],
159 |     dropout: float = 0.0,
160 |     scaling: Optional[float] = None,
161 |     sliding_window: Optional[int] = None,
162 |     softcap: Optional[float] = None,
163 |     **kwargs,
164 | ) -> Tuple[torch.Tensor, None]:
165 |     # take the query length before the transpose below
166 |     q_len = query.shape[2]
167 | 
168 |     # FA2 uses non-transposed inputs
169 |     query = query.transpose(1, 2)
170 |     key = key.transpose(1, 2)
171 |     value = value.transpose(1, 2)
172 | 
173 |     # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
174 |     kwargs.pop("is_causal", None)
175 | 
176 |     attn_output = _custom_flash_attention_forward(
177 |         query,
178 |         key,
179 |         value,
180 |         attention_mask,
181 |         query_length=q_len,
182 |         is_causal=True,
183 |         dropout=dropout,
184 |         softmax_scale=scaling,
185 |         sliding_window=sliding_window,
186 |         softcap=softcap,
187 |         use_top_left_mask=_flash_use_top_left_mask,
188 |         **kwargs,
189 |     )
190 | 
191 |     return attn_output, None
192 | 
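To make the varlen path concrete, here is a toy trace of how prepare_fa2_from_position_ids derives cu_seqlens: a zero in the flattened position ids marks the start of a packed sequence (the numbers below are illustrative):

    import torch

    position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])  # three packed sequences
    indices = torch.arange(position_ids.size(0), dtype=torch.int32)
    cu_seqlens = torch.cat((indices[position_ids == 0], torch.tensor([9], dtype=torch.int32)))
    assert cu_seqlens.tolist() == [0, 3, 5, 9]  # the varlen offsets given to flash_attn_varlen_func
    assert cu_seqlens.diff().max().item() == 4  # the max_length used for both q and k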
--------------------------------------------------------------------------------
/verl/workers/rollout/vllm_rollout_spmd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The vLLM rollout that can be applied with different backends.
16 | When working with FSDP:
17 | - Use the DTensor weight loader (recommended) or the HF weight loader
18 | - Utilize the state_dict from FSDP to synchronize the weights among the tp ranks in vLLM
19 | """
20 | 
21 | import os
22 | from contextlib import contextmanager
23 | from typing import Any, List, Union
24 | 
25 | import numpy as np
26 | import torch
27 | import torch.distributed
28 | from tensordict import TensorDict
29 | from transformers import PreTrainedTokenizer
30 | from vllm import LLM, RequestOutput, SamplingParams
31 | 
32 | from ...protocol import DataProto
33 | from ...utils import torch_functional as VF
34 | from ...utils.torch_dtypes import PrecisionType
35 | from .base import BaseRollout
36 | from .config import RolloutConfig
37 | 
38 | 
39 | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
40 |     if isinstance(value, torch.Tensor):
41 |         return value.repeat_interleave(repeats, dim=0)
42 |     else:
43 |         return np.repeat(value, repeats, axis=0)
44 | 
45 | 
46 | class vLLMRollout(BaseRollout):
47 |     def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
48 |         """A vLLM rollout. It requires that the model is supported by vLLM.
49 | 
50 |         Args:
51 |             model_path: the path of the model, which follows the huggingface APIs
52 |             config: the RolloutConfig
53 |             tokenizer: the task/model tokenizer
54 |         """
55 |         super().__init__()
56 |         self.rank = int(os.getenv("RANK", "0"))
57 |         self.config = config
58 |         self.pad_token_id = tokenizer.pad_token_id
59 |         if config.tensor_parallel_size > torch.distributed.get_world_size():
60 |             raise ValueError("Tensor parallelism size must not exceed the world size.")
61 | 
62 |         if config.max_num_batched_tokens < config.prompt_length + config.response_length:
63 |             raise ValueError("max_num_batched_tokens must be at least prompt_length + response_length.")
64 | 
65 |         vllm_init_kwargs = {}
66 |         if config.limit_images > 0:
67 |             vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}
68 | 
69 |         self.inference_engine = LLM(
70 |             model=model_path,
71 |             skip_tokenizer_init=False,
72 |             tensor_parallel_size=config.tensor_parallel_size,
73 |             dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
74 |             gpu_memory_utilization=config.gpu_memory_utilization,
75 |             enforce_eager=config.enforce_eager,
76 |             # max_model_len=config.prompt_length + config.response_length,
77 |             # max_num_batched_tokens=config.max_num_batched_tokens,
78 |             enable_sleep_mode=True,
79 |             distributed_executor_backend="external_launcher",
80 |             disable_custom_all_reduce=True,
81 |             disable_mm_preprocessor_cache=True,
82 |             disable_log_stats=config.disable_log_stats,
83 |             enable_chunked_prefill=config.enable_chunked_prefill,
84 |             **vllm_init_kwargs,
85 |         )
86 | 
87 |         # Offload the vllm model to reduce peak memory usage
88 |         self.inference_engine.sleep(level=1)
89 | 
90 |         sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
91 |         default_sampling_params = SamplingParams()
92 |         for key in config.to_dict().keys():
93 |             if hasattr(default_sampling_params, key):
94 |                 sampling_kwargs[key] = getattr(config, key)
95 | 
96 |         print(f"Sampling params: {sampling_kwargs}.")
97 |         self.sampling_params = SamplingParams(**sampling_kwargs)
98 | 
99 |         print("Reset sampling_params.n=1")
100 |         self.sampling_params.n = 1
101 | 
102 |     @contextmanager
103 |     def update_sampling_params(self, **kwargs):
104 |         # update sampling params
105 |         old_sampling_params_args = {}
106 |         if kwargs:
107 |             for key, value in kwargs.items():
108 |                 if hasattr(self.sampling_params, key):
109 |                     old_value = getattr(self.sampling_params, key)
110 |                     old_sampling_params_args[key] = old_value
111 |                     setattr(self.sampling_params, key, value)
112 | 
113 |         yield
114 |         # roll back to the previous sampling params
115 |         for key, value in old_sampling_params_args.items():
116 |             setattr(self.sampling_params, key, value)
117 | 
118 |     @torch.no_grad()
119 |     def generate_sequences(self, prompts: DataProto) -> DataProto:
120 |         # left-padded attention_mask
121 |         input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
122 |         attention_mask: torch.Tensor = prompts.batch["attention_mask"]
123 |         position_ids: torch.Tensor = prompts.batch["position_ids"]
124 |         eos_token_id: int = prompts.meta_info["eos_token_id"]
125 |         batch_size = input_ids.size(0)
126 | 
127 |         non_tensor_batch = prompts.non_tensor_batch
128 |         if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
129 |             raise RuntimeError("The vllm sharding manager is not working properly.")
130 | 
131 |         if "multi_modal_data" in non_tensor_batch:
132 |             vllm_inputs = []
133 |             for raw_prompt_ids, multi_modal_data in zip(
134 |                 non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
135 |             ):
136 |                 vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
137 |         else:
138 |             vllm_inputs = [
139 |                 {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
140 |             ]
141 | 
142 |         # users can customize the sampling params for each run
143 |         with self.update_sampling_params(**prompts.meta_info):
144 |             completions: List[RequestOutput] = self.inference_engine.generate(
145 |                 prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
146 |             )
147 |             response_ids = [output.token_ids for completion in completions for output in completion.outputs]
148 |             response_ids = VF.pad_2d_list_to_length(
149 |                 response_ids, self.pad_token_id, max_length=self.config.response_length
150 |             ).to(input_ids.device)
151 | 
152 |             if self.sampling_params.n > 1:
153 |                 batch_size = batch_size * self.sampling_params.n
154 |                 input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
155 |                 attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
156 |                 position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
157 |                 if "multi_modal_inputs" in non_tensor_batch.keys():
158 |                     non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
159 |                         non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
160 |                     )
161 | 
162 |         sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
163 |         response_length = response_ids.size(1)
164 |         delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
165 |         delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
166 |         if position_ids.dim() == 3:  # qwen2vl mrope
167 |             delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
168 | 
169 |         # prompt: left pad + response: right pad
170 |         # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
171 |         # position_ids:   [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
172 |         response_position_ids = position_ids[..., -1:] + delta_position_id
173 |         position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
174 |         response_mask = VF.get_response_mask(
175 |             response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
176 |         )
177 |         attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
178 | 
179 |         # all the tp ranks should contain the same data here; the data in all ranks is valid
180 |         batch = TensorDict(
181 |             {
182 |                 "prompts": input_ids,
183 |                 "responses": response_ids,
184 |                 "input_ids": sequence_ids,  # here input_ids becomes the whole sequence
185 |                 "attention_mask": attention_mask,
186 |                 "response_mask": response_mask,
187 |                 "position_ids": position_ids,
188 |             },
189 |             batch_size=batch_size,
190 |         )
191 |         return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
192 | 
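A quick sketch of the update_sampling_params contract, assuming `rollout` is an already-constructed vLLMRollout with the defaults set in __init__ above; note that generate_sequences applies prompts.meta_info through this same context manager:

    with rollout.update_sampling_params(temperature=1.0, n=5):
        assert rollout.sampling_params.n == 5  # overridden inside the block
    assert rollout.sampling_params.n == 1  # restored on exit, as set in __init__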
--------------------------------------------------------------------------------
/verl/workers/critic/dp_critic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Implementation of the data-parallel PPO critic
16 | """
17 | 
18 | import os
19 | from collections import defaultdict
20 | from typing import Any, Dict
21 | 
22 | import torch
23 | from ray.experimental.tqdm_ray import tqdm
24 | from torch import nn
25 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
26 | 
27 | from ...protocol import DataProto
28 | from ...trainer import core_algos
29 | from ...utils import torch_functional as VF
30 | from ...utils.py_functional import append_to_dict
31 | from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
32 | from .base import BasePPOCritic
33 | from .config import CriticConfig
34 | 
35 | 
36 | try:
37 |     from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
38 | except ImportError:
39 |     pass
40 | 
41 | 
42 | __all__ = ["DataParallelPPOCritic"]
43 | 
44 | 
45 | class DataParallelPPOCritic(BasePPOCritic):
46 |     def __init__(self, config: CriticConfig, critic_module: nn.Module, critic_optimizer: torch.optim.Optimizer):
47 |         super().__init__(config)
48 |         self.rank = int(os.getenv("RANK", "0"))
49 |         self.critic_module = critic_module
50 |         self.critic_optimizer = critic_optimizer
51 | 
52 |     def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
53 |         input_ids = micro_batch["input_ids"]
54 |         batch_size, seqlen = input_ids.shape
55 |         attention_mask = micro_batch["attention_mask"]
56 |         position_ids = micro_batch["position_ids"]
57 |         responses = micro_batch["responses"]
58 |         response_length = responses.size(-1)
59 |         if position_ids.dim() == 3:  # qwen2vl mrope
60 |             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
61 | 
62 |         multi_modal_inputs = {}
63 |         if "multi_modal_inputs" in micro_batch:
64 |             for key in micro_batch["multi_modal_inputs"][0].keys():
65 |                 multi_modal_inputs[key] = torch.cat(
66 |                     [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
67 |                 )
68 | 
69 |         if self.config.padding_free:
70 |             input_ids_rmpad, indices, *_ = unpad_input(
71 |                 input_ids.unsqueeze(-1), attention_mask
72 |             )  # input_ids_rmpad (total_nnz, ...)
73 |             input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
74 | 
75 |             # unpad the position_ids to align the rotary embedding
76 |             if position_ids.dim() == 3:
77 |                 position_ids_rmpad = (
78 |                     index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
79 |                     .transpose(0, 1)
80 |                     .unsqueeze(1)
81 |                 )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
82 |             else:
83 |                 position_ids_rmpad = index_first_axis(
84 |                     rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
85 |                 ).transpose(0, 1)
86 | 
87 |             # pad and slice the inputs if sp > 1
88 |             if self.config.ulysses_sequence_parallel_size > 1:
89 |                 input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
90 |                     input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
91 |                 )
92 | 
93 |             # only pass input_ids and position_ids to enable flash_attn_varlen
94 |             output = self.critic_module(
95 |                 input_ids=input_ids_rmpad,
96 |                 attention_mask=None,
97 |                 position_ids=position_ids_rmpad,
98 |                 **multi_modal_inputs,
99 |                 use_cache=False,
100 |             )  # use_cache=False prevents the model from thinking we are generating
101 |             values_rmpad = output.logits
102 |             values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)
103 | 
104 |             # gather the output if sp > 1
105 |             if self.config.ulysses_sequence_parallel_size > 1:
106 |                 values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
107 | 
108 |             # pad it back
109 |             values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
110 |             values = values[:, -response_length - 1 : -1]
111 |         else:
112 |             output = self.critic_module(
113 |                 input_ids=input_ids,
114 |                 attention_mask=attention_mask,
115 |                 position_ids=position_ids,
116 |                 **multi_modal_inputs,
117 |                 use_cache=False,
118 |             )
119 |             values: torch.Tensor = output.logits
120 |             values = values[:, -response_length - 1 : -1].squeeze(-1)  # (bsz, response_length)
121 | 
122 |         return values
123 | 
124 |     def _optimizer_step(self) -> torch.Tensor:
125 |         if isinstance(self.critic_module, FSDP):
126 |             grad_norm = self.critic_module.clip_grad_norm_(self.config.max_grad_norm)
127 |         else:
128 |             grad_norm = torch.nn.utils.clip_grad_norm_(
129 |                 self.critic_module.parameters(), max_norm=self.config.max_grad_norm
130 |             )
131 | 
132 |         if not torch.isfinite(grad_norm):
133 |             print("Gradient norm is not finite. Skipping update.")
134 |         else:
135 |             self.critic_optimizer.step()
136 | 
137 |         self.critic_optimizer.zero_grad()
138 |         return grad_norm
139 | 
140 |     @torch.no_grad()
141 |     def compute_values(self, data: DataProto) -> torch.Tensor:
142 |         self.critic_module.eval()
143 | 
144 |         select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
145 |         if "multi_modal_inputs" in data.non_tensor_batch.keys():
146 |             non_tensor_select_keys = ["multi_modal_inputs"]
147 |         else:
148 |             non_tensor_select_keys = []
149 | 
150 |         micro_batches = data.select(select_keys, non_tensor_select_keys).split(
151 |             self.config.micro_batch_size_per_device_for_experience
152 |         )
153 |         values_lst = []
154 |         if self.rank == 0:
155 |             micro_batches = tqdm(micro_batches, desc="Compute values", position=2)
156 | 
157 |         for micro_batch in micro_batches:
158 |             model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
159 |             values = self._forward_micro_batch(model_inputs)
160 |             values_lst.append(values)
161 | 
162 |         values = torch.concat(values_lst, dim=0)
163 |         responses = data.batch["responses"]
164 |         attention_mask = data.batch["attention_mask"]
165 |         response_length = responses.size(1)
166 |         values = values * attention_mask[:, -response_length - 1 : -1]
167 |         return values
168 | 
169 |     def update_critic(self, data: DataProto) -> Dict[str, Any]:
170 |         self.critic_module.train()
171 | 
172 |         select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
173 |         if "multi_modal_inputs" in data.non_tensor_batch.keys():
174 |             non_tensor_select_keys = ["multi_modal_inputs"]
175 |         else:
176 |             non_tensor_select_keys = []
177 | 
178 |         # Split to make the minibatch iterator for updating the critic
179 |         # See the PPO paper for details: https://arxiv.org/abs/1707.06347
180 |         mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
181 | 
182 |         metrics = defaultdict(list)
183 |         for _ in range(self.config.ppo_epochs):
184 |             if self.rank == 0:
185 |                 mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
186 | 
187 |             for mini_batch in mini_batches:
188 |                 gradient_accumulation = (
189 |                     self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
190 |                 )
191 |                 micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
192 |                 if self.rank == 0:
193 |                     micro_batches = tqdm(micro_batches, desc="Update critic", position=3)
194 | 
195 |                 for micro_batch in micro_batches:
196 |                     model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
197 |                     responses = model_inputs["responses"]
198 |                     attention_mask = model_inputs["attention_mask"]
199 |                     values = model_inputs["values"]
200 |                     returns = model_inputs["returns"]
201 |                     response_length = responses.size(1)
202 |                     action_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
203 | 
204 |                     vpreds = self._forward_micro_batch(model_inputs)
205 |                     vf_loss, vf_clipfrac = core_algos.compute_value_loss(
206 |                         vpreds=vpreds,
207 |                         returns=returns,
208 |                         values=values,
209 |                         action_mask=action_mask,
210 |                         cliprange_value=self.config.cliprange_value,
211 |                     )
212 |                     loss = vf_loss / gradient_accumulation
213 |                     loss.backward()
214 | 
215 |                     batch_metrics = {
216 |                         "critic/vf_loss": vf_loss.detach().item(),
217 |                         "critic/vf_clipfrac": vf_clipfrac.detach().item(),
218 |                         "critic/vpred_mean": VF.masked_mean(vpreds, action_mask).detach().item(),
219 |                     }
220 |                     append_to_dict(metrics, batch_metrics)
221 | 
222 |                 grad_norm = self._optimizer_step()
223 |                 append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
224 | 
225 |         return metrics
226 | 
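The repeated `[:, -response_length - 1 : -1]` slice deserves a note: the value for a response token is read from the hidden state one position earlier. A toy illustration with hypothetical lengths:

    seqlen, response_length = 8, 3
    positions = list(range(seqlen))  # positions 0..7; the response occupies 5, 6, 7
    value_positions = positions[-response_length - 1 : -1]
    assert value_positions == [4, 5, 6]  # each response token is scored from the preceding position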
--------------------------------------------------------------------------------
/verl/utils/seqlen_balancing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import copy
16 | import heapq
17 | from typing import List, Tuple
18 | 
19 | import torch
20 | from tensordict import TensorDict
21 | from torch import distributed as dist
22 | 
23 | 
24 | class Set:
25 |     def __init__(self) -> None:
26 |         self.sum = 0
27 |         self.items = []
28 | 
29 |     def add(self, idx: int, val: int):
30 |         self.items.append((idx, val))
31 |         self.sum += val
32 | 
33 |     def merge(self, other):
34 |         for idx, val in other.items:
35 |             self.items.append((idx, val))
36 |             self.sum += val
37 | 
38 |     def __lt__(self, other):
39 |         if self.sum != other.sum:
40 |             return self.sum < other.sum
41 |         if len(self.items) != len(other.items):
42 |             return len(self.items) < len(other.items)
43 |         return self.items < other.items
44 | 
45 | 
46 | class State:
47 |     def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
48 |         self.k = k
49 |         # the sets are always kept in decreasing order
50 |         self.sets = [Set() for _ in range(k)]
51 |         assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
52 |         for i, (idx, seqlen) in enumerate(items):
53 |             self.sets[i].add(idx=idx, val=seqlen)
54 |         self.sets = sorted(self.sets, reverse=True)
55 | 
56 |     def get_partitions(self):
57 |         partitions = []
58 |         for i in range(len(self.sets)):
59 |             cur_partition = []
60 |             for idx, _ in self.sets[i].items:
61 |                 cur_partition.append(idx)
62 |             partitions.append(cur_partition)
63 |         return partitions
64 | 
65 |     def merge(self, other):
66 |         for i in range(self.k):
67 |             self.sets[i].merge(other.sets[self.k - 1 - i])
68 |         self.sets = sorted(self.sets, reverse=True)
69 | 
70 |     @property
71 |     def spread(self) -> int:
72 |         return self.sets[0].sum - self.sets[-1].sum
73 | 
74 |     def __lt__(self, other):
75 |         # heapq is a min-heap, so the comparison is inverted: the state with the
76 |         # largest spread is popped first; if the spreads are equal, the state
77 |         # whose largest set is bigger is popped first.
78 |         if self.spread != other.spread:
79 |             return self.spread > other.spread
80 |         return self.sets[0] > other.sets[0]
81 | 
82 |     def __repr__(self) -> str:
83 |         repr_str = "["
84 |         for i in range(self.k):
85 |             if i > 0:
86 |                 repr_str += ","
87 |             repr_str += "{"
88 |             for j, (_, seqlen) in enumerate(self.sets[i].items):
89 |                 if j > 0:
90 |                     repr_str += ","
91 |                 repr_str += str(seqlen)
92 |             repr_str += "}"
93 |         repr_str += "]"
94 |         return repr_str
95 | 
96 | 
97 | def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
98 |     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
99 |     sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
100 |     states_pq: List[State] = []
101 |     if equal_size:
102 |         assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
103 |         for offset in range(0, len(sorted_seqlen_list), k_partitions):
104 |             items = []
105 |             for i in range(k_partitions):
106 |                 seqlen, idx = sorted_seqlen_list[offset + i]
107 |                 items.append((idx, seqlen))
108 |             heapq.heappush(states_pq, State(items=items, k=k_partitions))
109 |     else:
110 |         for seqlen, idx in sorted_seqlen_list:
111 |             heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))
112 | 
113 |     while len(states_pq) > 1:
114 |         state0 = heapq.heappop(states_pq)
115 |         state1 = heapq.heappop(states_pq)
116 |         # merge the two states
117 |         state0.merge(state1)
118 |         heapq.heappush(states_pq, state0)
119 | 
120 |     final_state = states_pq[0]
121 |     partitions = final_state.get_partitions()
122 |     if equal_size:
123 |         for i, partition in enumerate(partitions):
124 |             assert len(partition) * k_partitions == len(seqlen_list), (
125 |                 f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
126 |             )
127 |     return partitions
128 | 
129 | 
130 | def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
131 |     bias = sum(seqlen_list) + 1 if equal_size else 0
132 |     sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
133 |     partitions = [[] for _ in range(k_partitions)]
134 |     partition_sums = [0 for _ in range(k_partitions)]
135 |     for seqlen, i in sorted_seqlen:
136 |         min_idx = None
137 |         for j in range(k_partitions):
138 |             if min_idx is None or partition_sums[j] < partition_sums[min_idx]:
139 |                 min_idx = j
140 |         partitions[min_idx].append(i)
141 |         partition_sums[min_idx] += seqlen
142 |     if equal_size:
143 |         for i, partition in enumerate(partitions):
144 |             assert len(partition) * k_partitions == len(seqlen_list), (
145 |                 f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
146 |             )
147 |     return partitions
148 | 
149 | 
150 | def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
151 |     """Get a partitioning of the items that balances the sum of sequence lengths;
152 |     this is used for balancing the total sequence length across dp ranks and micro batches.
153 |     Parameters:
154 |         seqlen_list (List[int]):
155 |             seq lengths of the items
156 |         k_partitions (int):
157 |             resulting number of partitions
158 |         equal_size (bool):
159 |             if True, the number of items in each partition must be equal.
160 |             if False, only the sums are balanced, and each partition can have
161 |             a variable number of items
162 |     Returns:
163 |         partitions (List[List[int]]):
164 |             k_partitions lists containing the indices of the items.
165 |     """
166 |     assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
167 | 
168 |     def _check_and_sort_partitions(partitions):
169 |         assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
170 |         seen_idx = set()
171 |         sorted_partitions = [None] * k_partitions
172 |         for i, partition in enumerate(partitions):
173 |             assert len(partition) > 0, f"the {i}-th partition is empty"
174 |             for idx in partition:
175 |                 seen_idx.add(idx)
176 |             sorted_partitions[i] = sorted(partition)
177 |         assert seen_idx == set(range(len(seqlen_list)))
178 |         return sorted_partitions
179 | 
180 |     partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
181 |     return _check_and_sort_partitions(partitions)
182 | 
183 | 
184 | def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
185 |     # add some metrics about the seqlen sums on dp ranks
186 |     k_partition = len(partitions)
187 |     # assert len(seqlen_list) % k_partition == 0
188 |     batch_size = len(seqlen_list) // k_partition
189 |     min_sum_seqlen = None
190 |     max_sum_seqlen = None
191 |     total_sum_seqlen = 0
192 |     for offset in range(0, len(seqlen_list), batch_size):
193 |         cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
194 |         if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
195 |             min_sum_seqlen = cur_sum_seqlen
196 |         if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
197 |             max_sum_seqlen = cur_sum_seqlen
198 |         total_sum_seqlen += cur_sum_seqlen
199 | 
200 |     balanced_sum_seqlen_list = []
201 |     for partition in partitions:
202 |         cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
203 |         balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)
204 |     # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
205 |     min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
206 |     max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
207 | 
208 |     return {
209 |         f"{prefix}/min": min_sum_seqlen,
210 |         f"{prefix}/max": max_sum_seqlen,
211 |         f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen,
212 |         f"{prefix}/balanced_min": min_sum_seqlen_balanced,
213 |         f"{prefix}/balanced_max": max_sum_seqlen_balanced,
214 |         f"{prefix}/mean": total_sum_seqlen / len(partitions),
215 |     }
216 | 
217 | 
218 | def ceildiv(a, b):
219 |     return -(a // -b)
220 | 
221 | 
222 | def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
223 |     """Split the batch into a list of micro batches, where the token count of each micro batch
224 |     is no larger than max_token_len and the number of valid tokens per micro batch is well balanced.
225 |     """
226 |     # max_token_len is a per-local-micro-batch budget
227 |     max_seq_len = batch["attention_mask"].shape[-1]
228 |     assert max_token_len >= max_seq_len, (
229 |         f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
230 |     )
231 | 
232 |     seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
233 |     total_seqlen = seq_len_effective.sum().item()
234 |     num_micro_batches = ceildiv(total_seqlen, max_token_len)
235 |     if dist.is_initialized():
236 |         num_micro_batches = torch.tensor([num_micro_batches], device="cuda")
237 |         dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
238 |         num_micro_batches = num_micro_batches.cpu().item()
239 | 
240 |     seq_len_effective = seq_len_effective.tolist()
241 |     assert num_micro_batches <= len(seq_len_effective)
242 | 
243 |     micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
244 | 
245 |     micro_batches = []
246 | 
247 |     for partition in micro_bsz_idx:
248 |         curr_micro_batch = []
249 |         for idx in partition:
250 |             curr_micro_batch.append(batch[idx : idx + 1])
251 |         curr_micro_batch = torch.cat(curr_micro_batch)
252 | 
253 |         micro_batches.append(curr_micro_batch)
254 | 
255 |     return micro_batches, micro_bsz_idx
256 | 
257 | 
258 | def get_reverse_idx(idx_map):
259 |     reverse_idx_map = copy.deepcopy(idx_map)
260 | 
261 |     for i, idx in enumerate(idx_map):
262 |         reverse_idx_map[idx] = i
263 | 
264 |     return reverse_idx_map
265 | 
--------------------------------------------------------------------------------
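A small worked example of the partitioning utilities above (the numbers are illustrative; this assumes verl is importable as a package):

    from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions

    seqlens = [7, 1, 6, 2, 5, 3]  # total = 24
    parts = get_seqlen_balanced_partitions(seqlens, k_partitions=2, equal_size=True)
    sums = [sum(seqlens[i] for i in p) for p in parts]
    # Karmarkar-Karp pairs large items with small ones: each partition gets three
    # indices, and for this input both sums come out perfectly balanced at 12.
    assert sorted(sums) == [12, 12]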