├── assets
├── wechat.jpg
├── easyr1_grpo.png
├── traj_reward.png
└── qwen2_5_vl_7b_geo.png
├── .gitmodules
├── examples
├── runtime_env.yaml
├── baselines
│ ├── qwen2_5_vl_3b_geoqa8k.sh
│ └── qwen2_5_vl_3b_clevr.sh
├── osworld_full_arpo.sh
├── osworld_subset32.sh
└── config.yaml
├── requirements.txt
├── verl
├── utils
│ ├── __init__.py
│ ├── logger
│ │ ├── __init__.py
│ │ ├── gen_logger.py
│ │ └── logger.py
│ ├── checkpoint
│ │ ├── __init__.py
│ │ ├── checkpoint_manager.py
│ │ └── fsdp_checkpoint_manager.py
│ ├── reward_score
│ │ ├── __init__.py
│ │ ├── math.py
│ │ └── r1v.py
│ ├── tokenizer.py
│ ├── torch_dtypes.py
│ ├── model_utils.py
│ ├── py_functional.py
│ ├── fsdp_utils.py
│ ├── flops_counter.py
│ ├── dataset.py
│ └── seqlen_balancing.py
├── models
│ ├── __init__.py
│ ├── transformers
│ │ ├── __init__.py
│ │ └── flash_attention_utils.py
│ └── monkey_patch.py
├── trainer
│ ├── __init__.py
│ ├── replay_buffer.py
│ ├── config.py
│ ├── main.py
│ └── metrics.py
├── workers
│ ├── __init__.py
│ ├── reward
│ │ ├── __init__.py
│ │ ├── config.py
│ │ └── custom.py
│ ├── rollout
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── config.py
│ │ └── vllm_rollout_spmd.py
│ ├── critic
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── config.py
│ │ └── dp_critic.py
│ ├── sharding_manager
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── fsdp_ulysses.py
│ │ └── fsdp_vllm.py
│ ├── actor
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── config.py
│ └── config.py
├── single_controller
│ ├── __init__.py
│ ├── base
│ │ ├── register_center
│ │ │ ├── __init__.py
│ │ │ └── ray.py
│ │ ├── __init__.py
│ │ ├── worker.py
│ │ ├── worker_group.py
│ │ └── decorator.py
│ └── ray
│ │ └── __init__.py
└── __init__.py
├── .pre-commit-config.yaml
├── start_server.sh
├── pyproject.toml
├── .github
├── workflows
│ └── tests.yml
├── CONTRIBUTING.md
└── CODE_OF_CONDUCT.md
├── setup.py
├── .gitignore
├── README.md
└── scripts
└── model_merger.py
/assets/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/wechat.jpg
--------------------------------------------------------------------------------
/assets/easyr1_grpo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/easyr1_grpo.png
--------------------------------------------------------------------------------
/assets/traj_reward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/traj_reward.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "OSWorld"]
2 | path = OSWorld
3 | url = https://github.com/FanbinLu/OSWorld
4 |
--------------------------------------------------------------------------------
/assets/qwen2_5_vl_7b_geo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ARPO/HEAD/assets/qwen2_5_vl_7b_geo.png
--------------------------------------------------------------------------------
/examples/runtime_env.yaml:
--------------------------------------------------------------------------------
1 | working_dir: ./
2 | excludes: ["/.git/"]
3 | env_vars:
4 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | wandb
3 | tqdm
4 | accelerate==1.4.0
5 | codetiming==1.4.0
6 | datasets==3.3.2
7 | filelock==3.18.0
8 | flash_attn==2.7.4.post1
9 | liger_kernel==0.5.6
10 | mathruler==0.1.0
11 | mlflow==2.22.0
12 | omegaconf==2.3.0
13 | Pillow==11.2.1
14 | psutil==6.1.0
15 | PyYAML==6.0.2
16 | qwen_vl_utils==0.0.11
17 | swanlab==0.5.8
18 | tensordict==0.5.0
19 | torch==2.5.1
20 | torchdata==0.11.0
21 | vllm==0.7.3
22 |
--------------------------------------------------------------------------------
/verl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __version__ = "0.2.0.dev"
16 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-ast
6 | - id: check-added-large-files
7 | args: ['--maxkb=25000']
8 | - id: check-merge-conflict
9 | - id: check-yaml
10 | - id: debug-statements
11 | - id: end-of-file-fixer
12 | - id: requirements-txt-fixer
13 | - id: trailing-whitespace
14 | args: [--markdown-linebreak-ext=md]
15 | - id: no-commit-to-branch
16 | args: ['--branch', 'main']
17 |
18 | - repo: https://github.com/asottile/pyupgrade
19 | rev: v3.17.0
20 | hooks:
21 | - id: pyupgrade
22 | args: [--py38-plus]
23 |
--------------------------------------------------------------------------------
/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .logger import Tracker
17 |
18 |
19 | __all__ = ["Tracker"]
20 |
--------------------------------------------------------------------------------
/start_server.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Launch eight vLLM OpenAI API server processes, one per GPU index 0-7,
# each in the background and each listening on its own port.

model=PATH_TO_MODEL
model_name=ui-tars
num_images=16

# Base port: the server pinned to GPU i listens on port + i.
port=9000

# Function to clean up processes on exit
cleanup() {
    echo "Stopping all processes..."
    pkill -P $$  # Kill all child processes of this script
    exit 0
}

# Trap SIGINT (Ctrl+C) and SIGTERM to run cleanup function
trap cleanup SIGINT SIGTERM

# Start processes
for i in {0..7}; do
    # Expansions are quoted so a model path containing spaces stays one argument.
    CUDA_VISIBLE_DEVICES=$i python -m vllm.entrypoints.openai.api_server \
        --served-model-name "$model_name" \
        --model "$model" \
        --limit-mm-per-prompt "image=$num_images" \
        -tp=1 \
        --port $((port + i)) &
done

# Wait to keep the script running
wait
--------------------------------------------------------------------------------
/verl/workers/reward/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .config import RewardConfig
16 | from .custom import CustomRewardManager
17 |
18 |
19 | __all__ = ["CustomRewardManager", "RewardConfig"]
20 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
16 |
17 |
18 | __all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
19 |
--------------------------------------------------------------------------------
/verl/workers/rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .config import RolloutConfig
17 | from .vllm_rollout_spmd import vLLMRollout
18 |
19 |
20 | __all__ = ["RolloutConfig", "vLLMRollout"]
21 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .math import math_compute_score
17 | from .r1v import r1v_compute_score
18 |
19 |
20 | __all__ = ["math_compute_score", "r1v_compute_score"]
21 |
--------------------------------------------------------------------------------
/verl/single_controller/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .worker import Worker
16 | from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
17 |
18 |
19 | __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
20 |
--------------------------------------------------------------------------------
/verl/single_controller/ray/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls
16 |
17 |
18 | __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"]
19 |
--------------------------------------------------------------------------------
/verl/workers/critic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOCritic
16 | from .config import CriticConfig, ModelConfig
17 | from .dp_critic import DataParallelPPOCritic
18 |
19 |
20 | __all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
21 |
--------------------------------------------------------------------------------
/verl/workers/reward/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Reward config
16 | """
17 |
18 | from dataclasses import dataclass
19 |
20 |
@dataclass
class RewardConfig:
    """Configuration values for the reward worker.

    With no arguments the config selects the ``"function"`` reward type,
    the ``"math"`` score function, and ``skip_special_tokens=True``.
    """

    # Reward type selector.
    reward_type: str = "function"
    # Name of the score function to use.
    score_function: str = "math"
    # NOTE(review): presumably controls whether special tokens are dropped
    # when responses are decoded for scoring -- confirm against the reward manager.
    skip_special_tokens: bool = True
26 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .base import BaseShardingManager
17 | from .fsdp_ulysses import FSDPUlyssesShardingManager
18 | from .fsdp_vllm import FSDPVLLMShardingManager
19 |
20 |
21 | __all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"]
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "verl"
7 | dynamic = [
8 | "version",
9 | "dependencies",
10 | "optional-dependencies",
11 | "requires-python",
12 | "authors",
13 | "description",
14 | "readme",
15 | "license"
16 | ]
17 |
18 | [tool.ruff]
19 | target-version = "py39"
20 | line-length = 119
21 | indent-width = 4
22 |
23 | [tool.ruff.lint]
24 | ignore = ["C901", "E501", "E741", "W605", "C408"]
25 | select = ["C", "E", "F", "I", "W", "RUF022"]
26 |
27 | [tool.ruff.lint.per-file-ignores]
28 | "__init__.py" = ["E402", "F401", "F403", "F811"]
29 |
30 | [tool.ruff.lint.isort]
31 | lines-after-imports = 2
32 | known-first-party = ["verl"]
33 | known-third-party = ["torch", "transformers", "wandb"]
34 |
35 | [tool.ruff.format]
36 | quote-style = "double"
37 | indent-style = "space"
38 | skip-magic-trailing-comma = false
39 | line-ending = "auto"
40 |
--------------------------------------------------------------------------------
/verl/workers/rollout/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC, abstractmethod
16 |
17 | from ...protocol import DataProto
18 |
19 |
20 | __all__ = ["BaseRollout"]
21 |
22 |
class BaseRollout(ABC):
    """Abstract interface for rollout implementations.

    Concrete subclasses must provide :meth:`generate_sequences`.
    """

    @abstractmethod
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences for ``prompts`` and return them as a ``DataProto``."""
        return None
28 |
--------------------------------------------------------------------------------
/examples/baselines/qwen2_5_vl_3b_geoqa8k.sh:
--------------------------------------------------------------------------------
set -x

MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct  # replace it with your local file path

# The prompt must spell out the <think>/<answer> tags literally: it tells the
# model which output format to produce for the r1v score function chosen below.
FORMAT_PROMPT="""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant
first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning
process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think><answer> answer here </answer>"""

python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=leonardPKU/GEOQA_8K_R1V@train \
    data.val_files=leonardPKU/GEOQA_8K_R1V@test \
    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path="${MODEL_PATH}" \
    worker.rollout.tensor_parallel_size=1 \
    worker.reward.score_function=r1v \
    trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
    trainer.n_gpus_per_node=8
20 |
--------------------------------------------------------------------------------
/examples/baselines/qwen2_5_vl_3b_clevr.sh:
--------------------------------------------------------------------------------
set -x

MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct  # replace it with your local file path

# The prompt must spell out the <think>/<answer> tags literally: it tells the
# model which output format to produce for the r1v score function chosen below.
FORMAT_PROMPT="""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant
first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning
process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think><answer> answer here </answer>"""

python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=BUAADreamer/clevr_count_70k@train \
    data.val_files=BUAADreamer/clevr_count_70k@test \
    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path="${MODEL_PATH}" \
    worker.rollout.tensor_parallel_size=1 \
    worker.reward.score_function=r1v \
    trainer.experiment_name=qwen2_5_vl_3b_clevr \
    trainer.n_gpus_per_node=2
20 |
--------------------------------------------------------------------------------
/verl/workers/actor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOActor
16 | from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
17 | from .dp_actor import DataParallelPPOActor
18 |
19 |
20 | __all__ = [
21 | "ActorConfig",
22 | "BasePPOActor",
23 | "DataParallelPPOActor",
24 | "FSDPConfig",
25 | "ModelConfig",
26 | "OptimConfig",
27 | "RefConfig",
28 | ]
29 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 |
@ray.remote
class WorkerGroupRegisterCenter:
    """Ray remote class that stores one value and returns it on request."""

    def __init__(self, rank_zero_info):
        # Kept exactly as received; get_rank_zero_info() returns it unchanged.
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the ``rank_zero_info`` value supplied at construction."""
        stored_info = self.rank_zero_info
        return stored_info
25 |
26 |
def create_worker_group_register_center(name, info):
    """Create a named ``WorkerGroupRegisterCenter`` holding ``info``.

    Args:
        name: Name forwarded to ``options(name=...)``.
        info: Value forwarded to the ``WorkerGroupRegisterCenter`` constructor.

    Returns:
        The result of ``.remote(info)`` on the named class.
    """
    named_center_cls = WorkerGroupRegisterCenter.options(name=name)
    return named_center_cls.remote(info)
29 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Sharding manager to implement HybridEngine
16 | """
17 |
18 | from ...protocol import DataProto
19 |
20 |
class BaseShardingManager:
    """Sharding manager whose every hook is a no-op.

    Entering and exiting the context do nothing, and both data hooks return
    their argument unchanged.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None means an exception raised inside the ``with`` body propagates.
        return None

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Return ``data`` unchanged."""
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return ``data`` unchanged."""
        return data
33 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | paths:
8 | - "**.py"
9 | - "requirements.txt"
10 | - ".github/workflows/*.yml"
11 | pull_request:
12 | branches:
13 | - "main"
14 | paths:
15 | - "**.py"
16 | - "requirements.txt"
17 | - ".github/workflows/*.yml"
18 |
19 | jobs:
20 | tests:
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | python-version:
25 | - "3.11"
26 | os:
27 | - "ubuntu-latest"
28 |
29 | runs-on: ${{ matrix.os }}
30 |
31 | steps:
32 | - name: Checkout
33 | uses: actions/checkout@v4
34 |
35 | - name: Set up Python
36 | uses: actions/setup-python@v5
37 | with:
38 | python-version: ${{ matrix.python-version }}
39 | cache: "pip"
40 | cache-dependency-path: "setup.py"
41 |
42 | - name: Install dependencies
43 | run: |
44 | python -m pip install --upgrade pip
45 | python -m pip install ruff
46 |
47 | - name: Check quality
48 | run: |
49 | make style && make quality
50 |
--------------------------------------------------------------------------------
/verl/workers/critic/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Base class for Critic
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 |
21 | import torch
22 |
23 | from ...protocol import DataProto
24 | from .config import CriticConfig
25 |
26 |
27 | __all__ = ["BasePPOCritic"]
28 |
29 |
class BasePPOCritic(ABC):
    """Abstract interface for a PPO critic.

    The constructor only stores the config; concrete critics must implement
    both abstract methods below.
    """

    def __init__(self, config: CriticConfig):
        self.config = config

    @abstractmethod
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Compute values for the given batch."""
        ...

    @abstractmethod
    def update_critic(self, data: DataProto) -> Dict[str, Any]:
        """Update the critic with the given batch and return a dict of results."""
        ...
43 |
--------------------------------------------------------------------------------
/verl/workers/critic/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Critic config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 |
20 | from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
21 |
22 |
@dataclass
class CriticConfig:
    """Configuration values for the critic worker."""

    # Scalar training settings.
    strategy: str = "fsdp"
    global_batch_size: int = 256
    micro_batch_size_per_device_for_update: int = 4
    micro_batch_size_per_device_for_experience: int = 16
    max_grad_norm: float = 1.0
    cliprange_value: float = 0.5
    ppo_epochs: int = 1
    padding_free: bool = False
    ulysses_sequence_parallel_size: int = 1

    # Nested configs; each instance gets its own default object.
    model: ModelConfig = field(default_factory=ModelConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)
    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)

    # auto keys: not accepted by __init__ (init=False); starts at -1.
    # NOTE(review): presumably overwritten by the trainer at runtime; confirm.
    global_batch_size_per_device: int = field(default=-1, init=False)
40 |
--------------------------------------------------------------------------------
/verl/workers/rollout/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Rollout config
16 | """
17 |
18 | from dataclasses import asdict, dataclass, field
19 | from typing import Any, Dict
20 |
21 |
@dataclass
class RolloutConfig:
    """Configuration values for the rollout engine."""

    # Engine selection and sampling settings.
    name: str = "vllm"
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    limit_images: int = 0

    # Engine resource / execution settings.
    dtype: str = "bf16"
    gpu_memory_utilization: float = 0.6
    ignore_eos: bool = False
    enforce_eager: bool = False
    enable_chunked_prefill: bool = False  # only for v0 engine
    tensor_parallel_size: int = 2
    max_num_batched_tokens: int = 8192
    max_num_seqs: int = 1024
    disable_log_stats: bool = True

    # Overrides kept as a free-form mapping; each instance gets its own dict.
    val_override_config: Dict[str, Any] = field(default_factory=dict)

    # auto keys: not accepted by __init__ (init=False); both start at -1.
    prompt_length: int = field(default=-1, init=False)
    response_length: int = field(default=-1, init=False)

    def to_dict(self):
        """Return all fields (including the ``init=False`` ones) as a new dict."""
        as_mapping = asdict(self)
        return as_mapping
46 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from typing import Dict
17 |
18 | from mathruler.grader import extract_boxed_content, grade_answer
19 |
20 |
def math_format_reward(predict_str: str) -> float:
    r"""Return 1.0 if ``predict_str`` contains a ``\boxed{...}`` span, else 0.0.

    The whole string must match the pattern; ``re.DOTALL`` lets the wildcards
    span newlines, so the boxed span may appear anywhere in a multi-line
    response.
    """
    # A single leading ``.*`` accepts exactly the same strings as the former
    # ``.*.*`` prefix, but avoids quadratic backtracking on long responses that
    # contain no ``\boxed``.
    # NOTE(review): the doubled wildcard suggests literal markers may once have
    # sat between the two ``.*`` -- confirm against the upstream source.
    pattern = re.compile(r".*\\boxed\{.*\}.*", re.DOTALL)
    return 1.0 if pattern.fullmatch(predict_str) else 0.0
25 |
26 |
def math_acc_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 if the boxed content of ``predict_str`` grades as correct, else 0.0.

    The content is extracted with ``extract_boxed_content`` and compared to
    ``ground_truth`` with ``grade_answer``.
    """
    boxed = extract_boxed_content(predict_str)
    if grade_answer(boxed, ground_truth):
        return 1.0
    return 0.0
30 |
31 |
def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
    """Score a response on format and accuracy.

    Returns:
        A dict with ``"format"`` and ``"accuracy"`` (each 0.0 or 1.0) and
        ``"overall"``, weighted 0.9 accuracy / 0.1 format.
    """
    # handle qwen2.5vl-32b format: drop whitespace around "<", ">" and "/"
    normalized = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)
    format_score = math_format_reward(normalized)
    accuracy_score = math_acc_reward(normalized, ground_truth)
    overall_score = 0.9 * accuracy_score + 0.1 * format_score
    return {"overall": overall_score, "format": format_score, "accuracy": accuracy_score}
41 |
--------------------------------------------------------------------------------
/verl/models/monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
17 |
18 | from .transformers.flash_attention_utils import flash_attention_forward
19 | from .transformers.qwen2_vl import qwen2_vl_attn_forward, qwen_2_mixed_modality_forward
20 |
21 |
def apply_ulysses_patch(model_type: str) -> None:
    """Install the patched attention forward for the given model type.

    Args:
        model_type: architecture name, e.g. ``"llama"`` or ``"qwen2_vl"``.

    Raises:
        NotImplementedError: if ``model_type`` is not one of the handled names.
    """
    text_model_types = ("llama", "gemma", "gemma2", "mistral", "qwen2")
    if model_type in text_model_types:
        # These models go through the shared attention-function registry.
        ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
        return

    if model_type not in ("qwen2_vl", "qwen2_5_vl"):
        raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")

    # Imported lazily so the Qwen2-VL modules are only loaded when needed.
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2

    # Both attention classes are patched regardless of which of the two was requested.
    for attention_cls in (Qwen2VLFlashAttention2, Qwen2_5_VLFlashAttention2):
        attention_cls.forward = qwen2_vl_attn_forward

    # The mixed-modality forward patch is intentionally left disabled:
    # from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
    # Qwen2VLForConditionalGeneration.forward = qwen_2_mixed_modality_forward
36 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/r1v.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from typing import Dict
17 |
18 | from mathruler.grader import grade_answer
19 |
20 |
def r1v_format_reward(predict_str: str) -> float:
    """Return 1.0 if the whole response is ``<think>...</think><answer>...</answer>``.

    Whitespace is allowed between the two sections and ``re.DOTALL`` lets each
    section span several lines. Anything else scores 0.0.

    NOTE(review): the previous pattern (``.*?\\s*.*?``) matched every string,
    so this reward was a constant 1.0. The tag delimiters are restored on the
    assumption that they were stripped from the literal -- confirm against the
    prompt format used for training.
    """
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return 1.0 if pattern.fullmatch(predict_str) else 0.0
25 |
26 |
def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 if the answer in ``predict_str`` grades as correct, else 0.0.

    The answer is the text inside the first ``<answer>...</answer>`` pair; when
    no such pair exists the whole (stripped) response is graded instead.
    Any exception raised while extracting or grading yields 0.0.

    NOTE(review): the previous pattern (``(.*?)``) always matched the empty
    string, so the graded answer was always ``""`` and the whole-response
    fallback could never run. The tag delimiters are restored on the assumption
    that they were stripped from the literal -- confirm against the prompt
    format used for training.
    """
    try:
        ground_truth = ground_truth.strip()
        content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
        if grade_answer(given_answer, ground_truth):
            return 1.0

    except Exception:
        # Deliberate best-effort: a grading failure counts as a wrong answer.
        pass

    return 0.0
39 |
40 |
def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
    """Score a response on format and accuracy.

    Returns:
        A dict with ``"format"`` and ``"accuracy"`` plus ``"overall"``, the
        equally weighted (0.5 / 0.5) combination of the two.
    """
    format_score = r1v_format_reward(predict_str)
    accuracy_score = r1v_accuracy_reward(predict_str, ground_truth)
    overall_score = 0.5 * accuracy_score + 0.5 * format_score
    return {"overall": overall_score, "format": format_score, "accuracy": accuracy_score}
49 |
--------------------------------------------------------------------------------
/verl/trainer/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import json
4 | import random
5 | from collections import defaultdict
6 |
7 | from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto, collate_fn
8 |
class ReplayBuffer():
    """Per-task FIFO store of rollout items whose eval result was positive.

    Items are grouped by ``task_id`` in ``pos_dataset``; each task keeps at
    most ``buffer_size`` items, dropping the oldest first.
    """

    # An item is stored only if its eval result is strictly above this value.
    POSITIVE_THRESHOLD = 0.1

    def __init__(self, json_path, buffer_size):
        """Create an empty buffer.

        Args:
            json_path: currently ignored; the code that pre-loaded the buffer
                from this JSON file is disabled.
            buffer_size: maximum number of items kept per task.
        """
        self.buffer_size = buffer_size
        self.pos_dataset = defaultdict(list)

    def update_replay_buffer(self, task_config, batch_item, eval_result):
        """Store ``batch_item`` under its task if ``eval_result`` is positive.

        Args:
            task_config: mapping that must contain ``"task_id"``.
            batch_item: the item to store.
            eval_result: numeric score; values not strictly above
                ``POSITIVE_THRESHOLD`` are discarded.
        """
        task_id = task_config["task_id"]
        # Written as "not >" so that NaN is discarded, as before.
        if not eval_result > self.POSITIVE_THRESHOLD:
            return

        task_replay_buffer = self.pos_dataset[task_id]
        task_replay_buffer.append(batch_item)

        # Evict the oldest item once the per-task cap is exceeded.
        if len(task_replay_buffer) > self.buffer_size:
            task_replay_buffer.pop(0)

    def update_replay_buffer_batch(self, task_configs, batch):
        """Store every item of ``batch`` using its entry in ``batch.batch['eval_results']``.

        ``task_configs``, the items yielded by iterating ``batch`` and the eval
        results are consumed in lockstep.
        """
        eval_results = batch.batch['eval_results'].tolist()

        for task_config, batch_item, eval_result in zip(task_configs, batch, eval_results):
            self.update_replay_buffer(task_config, batch_item, eval_result)

    def get_pos(self, task_id, num_samples=1):
        """Sample ``num_samples`` stored items (with replacement) for ``task_id``.

        Returns:
            The sampled items combined by ``collate_fn``, or an empty
            ``DataProto()`` when nothing is stored for the task.
        """
        # .get() avoids creating an entry in the defaultdict, and the emptiness
        # check covers a key that exists with no items (random.choices would
        # raise IndexError on an empty list).
        stored_items = self.pos_dataset.get(task_id)
        if not stored_items:
            return DataProto()

        datalist = random.choices(stored_items, k=num_samples)
        return collate_fn(datalist)
52 |
--------------------------------------------------------------------------------
/verl/workers/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | ActorRolloutRef config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 |
20 | from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
21 | from .critic import CriticConfig
22 | from .reward import RewardConfig
23 | from .rollout import RolloutConfig
24 |
25 |
26 | __all__ = [
27 | "ActorConfig",
28 | "CriticConfig",
29 | "FSDPConfig",
30 | "ModelConfig",
31 | "OptimConfig",
32 | "RefConfig",
33 | "RewardConfig",
34 | "RolloutConfig",
35 | "WorkerConfig",
36 | ]
37 |
38 |
@dataclass
class WorkerConfig:
    """Bundle of the per-role worker configs (actor, critic, ref, reward, rollout)."""

    hybrid_engine: bool = True
    actor: ActorConfig = field(default_factory=ActorConfig)
    critic: CriticConfig = field(default_factory=CriticConfig)
    ref: RefConfig = field(default_factory=RefConfig)
    reward: RewardConfig = field(default_factory=RewardConfig)
    rollout: RolloutConfig = field(default_factory=RolloutConfig)

    def post_init(self):
        """Copy the settings the ref config shares with the actor config.

        This is a plain method, not ``__post_init__``, so the dataclass does
        not invoke it automatically; it must be called explicitly.
        """
        mirrored_fields = (
            "micro_batch_size_per_device_for_experience",
            "padding_free",
            "ulysses_sequence_parallel_size",
            "use_torch_compile",
        )
        for field_name in mirrored_fields:
            setattr(self.ref, field_name, getattr(self.actor, field_name))
53 |
--------------------------------------------------------------------------------
/examples/osworld_full_arpo.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launches verl.trainer.main with examples/config.yaml plus the overrides below.
# Uses bash arithmetic, so run it with bash.
set -x

MODEL_PATH=UI-TARS-1.5-7B

SYSTEM_PROMPT="""You are helpful assistant."""

NUM_GPUS=8
NUM_ENVS=128
ROLLOUT_N=8

# Rollout batch size is derived from the two values above (128 / 8 = 16) so it
# stays consistent when NUM_ENVS or ROLLOUT_N is changed.
((ROLLOUT_BSZ = NUM_ENVS/ROLLOUT_N))

# NOTE(review): the experiment name mentions "onlinereplay" but this script does
# not set algorithm.enable_replay (osworld_subset32.sh does) -- confirm the
# config default is what is intended.
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.format_prompt="${SYSTEM_PROMPT}" \
    data.train_files=evaluation_examples/test_success_uitars1.5_wo_impossible.json \
    data.val_files=evaluation_examples/test_success_uitars1.5_wo_impossible.json \
    data.max_prompt_length=64000 \
    data.max_response_length=8192 \
    data.max_pixels=2116800 \
    data.min_pixels=256 \
    data.rollout_batch_size=${ROLLOUT_BSZ} \
    worker.actor.fsdp.torch_dtype=bf16 \
    worker.actor.optim.strategy=adamw_bf16 \
    worker.actor.max_grad_norm=1.0 \
    worker.actor.optim.lr=1e-6 \
    worker.actor.ulysses_sequence_parallel_size=1 \
    worker.actor.padding_free=true \
    worker.actor.ppo_epochs=1 \
    worker.actor.clip_ratio_low=0.2 \
    worker.actor.clip_ratio_high=0.3 \
    worker.actor.global_batch_size=8 \
    worker.actor.micro_batch_size_per_device_for_update=1 \
    worker.actor.micro_batch_size_per_device_for_experience=1 \
    worker.actor.model.model_path="${MODEL_PATH}" \
    worker.rollout.gpu_memory_utilization=0.6 \
    worker.rollout.temperature=1.0 \
    worker.rollout.n=$ROLLOUT_N \
    worker.rollout.limit_images=15 \
    worker.rollout.tensor_parallel_size=1 \
    worker.rollout.max_num_batched_tokens=128000 \
    algorithm.disable_kl=True \
    algorithm.kl_coef=0 \
    env.num_envs=$NUM_ENVS \
    env.max_steps=15 \
    trainer.experiment_name=osworld_cot_7b_nokl_twonodes_onlinereplay \
    trainer.n_gpus_per_node=$NUM_GPUS \
    trainer.nnodes=2 \
    trainer.save_freq=8 \
    trainer.save_limit=3 \
    trainer.val_before_train=True \
    trainer.val_freq=8 \
    trainer.total_episodes=15
55 |
--------------------------------------------------------------------------------
/examples/osworld_subset32.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launches verl.trainer.main with examples/config.yaml plus the overrides below.
# Uses bash arithmetic, so run it with bash.
set -x


MODEL_PATH=UI-TARS-1.5-7B

SYSTEM_PROMPT="""You are helpful assistant."""

NUM_GPUS=8
NUM_ENVS=16
ROLLOUT_N=8

# Rollout batch size is derived from the two values above (16 / 8 = 2) so it
# stays consistent when NUM_ENVS or ROLLOUT_N is changed.
((ROLLOUT_BSZ = NUM_ENVS/ROLLOUT_N))

python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.format_prompt="${SYSTEM_PROMPT}" \
    data.train_files=evaluation_examples/test_success_middle_difficult.json \
    data.val_files=evaluation_examples/test_success_middle_difficult.json \
    data.max_prompt_length=64000 \
    data.max_response_length=8192 \
    data.max_pixels=2116800 \
    data.min_pixels=256 \
    data.rollout_batch_size=${ROLLOUT_BSZ} \
    worker.actor.fsdp.torch_dtype=bf16 \
    worker.actor.optim.strategy=adamw_bf16 \
    worker.actor.max_grad_norm=1.0 \
    worker.actor.optim.lr=1e-6 \
    worker.actor.optim.lr_warmup_ratio=0.05 \
    worker.actor.ulysses_sequence_parallel_size=1 \
    worker.actor.padding_free=true \
    worker.actor.ppo_epochs=1 \
    worker.actor.clip_ratio_low=0.2 \
    worker.actor.clip_ratio_high=0.3 \
    worker.actor.global_batch_size=1 \
    worker.actor.micro_batch_size_per_device_for_update=1 \
    worker.actor.micro_batch_size_per_device_for_experience=1 \
    worker.actor.model.model_path="${MODEL_PATH}" \
    worker.rollout.gpu_memory_utilization=0.6 \
    worker.rollout.temperature=1.0 \
    worker.rollout.n=$ROLLOUT_N \
    worker.rollout.limit_images=15 \
    worker.rollout.tensor_parallel_size=1 \
    worker.rollout.max_num_batched_tokens=128000 \
    algorithm.disable_kl=True \
    algorithm.kl_coef=0 \
    algorithm.enable_replay=True \
    env.num_envs=$NUM_ENVS \
    env.max_steps=15 \
    trainer.experiment_name=osworld_cot_7b_nokl_subset32 \
    trainer.n_gpus_per_node=$NUM_GPUS \
    trainer.nnodes=1 \
    trainer.save_freq=16 \
    trainer.save_limit=3 \
    trainer.val_before_train=True \
    trainer.val_freq=16 \
    trainer.total_episodes=15
57 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 |
18 | from setuptools import find_packages, setup
19 |
20 |
def get_version() -> str:
    """Return the version string assigned to ``__version__`` in ``verl/__init__.py``.

    Raises:
        ValueError: if the file does not contain exactly one matching
            ``__version__ = "..."`` assignment (from the tuple unpacking).
    """
    init_path = os.path.join("verl", "__init__.py")
    with open(init_path, encoding="utf-8") as handle:
        source_text = handle.read()

    matches = re.findall(r"__version__\W*=\W*\"([^\"]+)\"", source_text)
    # Unpacking into a 1-tuple enforces "exactly one match".
    (version,) = matches
    return version
27 |
28 |
def get_requires() -> list[str]:
    """Return the requirement specifiers listed in ``requirements.txt``.

    Each line is stripped first; blank lines and lines whose first
    non-whitespace character is ``#`` are skipped, so no empty strings or
    indented comments end up in ``install_requires``.
    """
    with open("requirements.txt", encoding="utf-8") as f:
        stripped_lines = [line.strip() for line in f]

    return [line for line in stripped_lines if line and not line.startswith("#")]
34 |
35 |
# Optional dependency groups, passed to setup() as ``extras_require``.
extra_require = {"dev": ["pre-commit", "ruff"]}
39 |
40 |
def main():
    """Run ``setuptools.setup`` for the ``verl`` package.

    The version comes from ``get_version()``, the requirements from
    ``get_requires()`` and the long description from ``README.md``; all paths
    are relative to the current working directory.
    """
    # Read the README inside a context manager so the file handle is closed.
    with open("README.md", encoding="utf-8") as readme_file:
        long_description = readme_file.read()

    setup(
        name="verl",
        version=get_version(),
        description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
        long_description=long_description,
        long_description_content_type="text/markdown",
        author="verl",
        author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
        license="Apache 2.0 License",
        url="https://github.com/volcengine/verl",
        package_dir={"": "."},
        packages=find_packages(where="."),
        python_requires=">=3.9.0",
        install_requires=get_requires(),
        extras_require=extra_require,
    )


if __name__ == "__main__":
    main()
62 |
--------------------------------------------------------------------------------
/verl/workers/actor/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for Actor
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 |
21 | import torch
22 |
23 | from ...protocol import DataProto
24 | from .config import ActorConfig
25 |
26 |
27 | __all__ = ["BasePPOActor"]
28 |
29 |
class BasePPOActor(ABC):
    """Abstract interface for a PPO actor; concrete actors implement both methods."""

    def __init__(self, config: ActorConfig):
        """Store the actor config.

        Args:
            config (ActorConfig): a config passed to the PPOActor.
        """
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log probabilities for a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It is expected to contain
                the keys ```input_ids```, ```attention_mask``` and ```position_ids```.

        Returns:
            torch.Tensor: the log probabilities.
                NOTE(review): the exact shape is defined by the concrete implementation -- confirm there.
        """
        pass

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict[str, Any]:
        """Update the policy with a batch of data.

        Args:
            data (DataProto): the data used for the update.

        Returns:
            Dict: a dictionary that may contain anything. Typically, it contains the statistics
                gathered while updating the model, such as ```loss```, ```grad_norm```, etc.
        """
        pass
65 |
--------------------------------------------------------------------------------
/verl/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for tokenization."""
15 |
16 | from typing import Optional
17 |
18 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
19 |
20 |
def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
    """Create a huggingface pretrained tokenizer.

    Args:
        model_path: passed to ``AutoTokenizer.from_pretrained``.
        **kwargs: forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer, with the EOS token adjusted for gemma models and
        the pad token defaulted to the EOS token when none is set.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)

    # NOTE(review): the token literals in this branch were empty strings, which
    # made the branch a no-op (it assigned "" to an eos_token that was already
    # ""). They are restored from the gemma convention referenced below, where
    # <end_of_turn> has id 107 -- confirm against the tokenizer files.
    if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
        # the EOS token in gemma2 & gemma3 is ambiguous, which may worsen RL performance.
        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
        print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
        tokenizer.eos_token = "<end_of_turn>"

    if tokenizer.pad_token_id is None:
        print("Pad token is None. Set it to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer
36 |
37 |
def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
    """Create a huggingface pretrained processor, or return ``None``.

    ``None`` is returned when loading raises any exception, or when the loaded
    object's class name does not contain ``"Processor"``.
    """
    try:
        loaded = AutoProcessor.from_pretrained(model_path, **kwargs)
    except Exception:
        return None

    if loaded is None:
        return None

    # Avoid load tokenizer, see:
    # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
    class_name = loaded.__class__.__name__
    return loaded if "Processor" in class_name else None
51 |
--------------------------------------------------------------------------------
/verl/utils/torch_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
# Accepted spellings for each precision; membership in these lists drives the
# PrecisionType helpers below.
HALF_LIST = [16, "16", "fp16", "float16"]
FLOAT_LIST = [32, "32", "fp32", "float32"]
BFLOAT_LIST = ["bf16", "bfloat16"]


class PrecisionType:
    """Type of precision used.

    The class attributes are strings, so they compare equal to the string
    spelling only; use the ``is_*`` helpers to accept every spelling.

    >>> PrecisionType.HALF == "16"
    True
    >>> PrecisionType.HALF == 16
    False
    >>> PrecisionType.is_fp16(16)
    True
    """

    HALF = "16"
    FLOAT = "32"
    FULL = "64"
    BFLOAT = "bf16"
    MIXED = "mixed"

    @staticmethod
    def is_fp16(precision):
        """Return True if ``precision`` is one of the ``HALF_LIST`` spellings."""
        return precision in HALF_LIST

    @staticmethod
    def is_fp32(precision):
        """Return True if ``precision`` is one of the ``FLOAT_LIST`` spellings."""
        return precision in FLOAT_LIST

    @staticmethod
    def is_bf16(precision):
        """Return True if ``precision`` is one of the ``BFLOAT_LIST`` spellings."""
        return precision in BFLOAT_LIST

    @staticmethod
    def to_dtype(precision) -> torch.dtype:
        """Map a precision spelling to the matching ``torch.dtype``.

        Raises:
            RuntimeError: if ``precision`` is not a half, float or bfloat spelling.
        """
        spellings_to_dtype = (
            (HALF_LIST, torch.float16),
            (FLOAT_LIST, torch.float32),
            (BFLOAT_LIST, torch.bfloat16),
        )
        for spellings, dtype in spellings_to_dtype:
            if precision in spellings:
                return dtype
        raise RuntimeError(f"unexpected precision: {precision}")

    @staticmethod
    def to_str(precision: torch.dtype) -> str:
        """Map ``torch.float16`` / ``float32`` / ``bfloat16`` to its name.

        Raises:
            RuntimeError: for any other value.
        """
        dtype_names = (
            (torch.float16, "float16"),
            (torch.float32, "float32"),
            (torch.bfloat16, "bfloat16"),
        )
        for dtype, dtype_name in dtype_names:
            if precision == dtype:
                return dtype_name
        raise RuntimeError(f"unexpected precision: {precision}")
71 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to EasyR1
2 |
3 | Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
4 |
5 | It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
6 |
7 | However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
8 |
9 | **This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
10 |
11 | ## Ways to contribute
12 |
13 | There are several ways you can contribute to EasyR1:
14 |
15 | * Fix outstanding issues with the existing code.
16 | * Submit issues related to bugs or desired new features.
17 | * Contribute to the examples or to the documentation.
18 |
19 | ### Style guide
20 |
21 | EasyR1 follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
22 |
23 | ### Create a Pull Request
24 |
25 | 1. Fork the [repository](https://github.com/hiyouga/EasyR1) by clicking on the [Fork](https://github.com/hiyouga/EasyR1/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
26 |
27 | 2. Clone your fork to your local disk, and add the base repository as a remote:
28 |
29 | ```bash
30 | git clone git@github.com:[username]/EasyR1.git
31 | cd EasyR1
32 | git remote add upstream https://github.com/hiyouga/EasyR1.git
33 | ```
34 |
35 | 3. Create a new branch to hold your development changes:
36 |
37 | ```bash
38 | git checkout -b dev_your_branch
39 | ```
40 |
41 | 4. Set up a development environment by running the following command in a virtual environment:
42 |
43 | ```bash
44 | pip install -e ".[dev]"
45 | ```
46 |
47 | 5. Check code before commit:
48 |
49 | ```bash
50 | make commit
51 | make style && make quality
52 | ```
53 |
54 | 6. Submit changes:
55 |
56 | ```bash
57 | git add .
58 | git commit -m "commit message"
59 | git fetch upstream
60 | git rebase upstream/main
61 | git push -u origin dev_your_branch
62 | ```
63 |
64 | 7. Create a pull request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/EasyR1).
65 |
--------------------------------------------------------------------------------
/verl/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities to create common models
16 | """
17 |
18 | from functools import lru_cache
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch import nn
24 |
25 |
@lru_cache
def _get_global_rank() -> int:
    """Return this process's rank in the default process group.

    Only called once the process group is initialized, so the cached value is never
    computed from the "not initialized yet" state.
    """
    return dist.get_rank()


def is_rank0() -> bool:
    """Return True when this process should act as rank 0.

    A process counts as rank 0 when torch.distributed is not initialized, or when its
    rank is 0. The "not initialized" answer is deliberately NOT cached: caching it would
    make every rank keep reporting True after the process group is created later.

    Returns:
        True if torch.distributed is uninitialized or the current rank is 0.
    """
    if not dist.is_initialized():
        return True

    return _get_global_rank() == 0
29 |
30 |
def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None:
    """Print used / total GPU memory in GB as reported by ``torch.cuda.mem_get_info()``.

    Nothing is printed when ``is_rank0()`` is False.

    Args:
        prefix: Text placed in front of the memory figures.
    """
    if not is_rank0():
        return

    bytes_per_gb = 1024**3
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    used_gb = (total_bytes - free_bytes) / bytes_per_gb
    total_gb = total_bytes / bytes_per_gb
    print(f"{prefix}: {used_gb:.2f} GB / {total_gb:.2f} GB.")
36 |
37 |
38 | def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
39 | """Compute the model size."""
40 | n_params = sum(p.numel() for p in model.parameters())
41 |
42 | if scale == "auto":
43 | if n_params > 1e9:
44 | scale = "B"
45 | elif n_params > 1e6:
46 | scale = "M"
47 | elif n_params > 1e3:
48 | scale = "K"
49 | else:
50 | scale = ""
51 |
52 | if scale == "B":
53 | n_params = n_params / 1e9
54 | elif scale == "M":
55 | n_params = n_params / 1e6
56 | elif scale == "K":
57 | n_params = n_params / 1e3
58 | elif scale == "":
59 | pass
60 | else:
61 | raise NotImplementedError(f"Unknown scale {scale}.")
62 |
63 | return n_params, scale
64 |
65 |
def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
    """Print how many parameters ``model`` contains, using an automatically chosen unit.

    Nothing is printed when ``is_rank0()`` is False.

    Args:
        model: Module to measure.
        name: Label used in the message; defaults to the model's class name.
    """
    if not is_rank0():
        return

    size, unit = _get_model_size(model, scale="auto")
    label = model.__class__.__name__ if name is None else name
    print(f"{label} contains {size:.2f}{unit} parameters.")
74 |
--------------------------------------------------------------------------------
/examples/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_files: hiyouga/math12k@train
3 | val_files: hiyouga/math12k@test
4 | prompt_key: problem
5 | answer_key: answer
6 | image_key: images
7 | max_prompt_length: 2048
8 | max_response_length: 2048
9 | rollout_batch_size: 512
10 | val_batch_size: -1
11 | shuffle: true
12 | seed: 1
13 | max_pixels: 4194304
14 | min_pixels: 262144
15 |
16 | algorithm:
17 | adv_estimator: grpo
18 | disable_kl: false
19 | use_kl_loss: true
20 | kl_penalty: low_var_kl
21 | kl_coef: 1.0e-2
22 |
23 | worker:
24 | actor:
25 | global_batch_size: 128
26 | micro_batch_size_per_device_for_update: 4
27 | micro_batch_size_per_device_for_experience: 16
28 | max_grad_norm: 1.0
29 | padding_free: true
30 | ulysses_sequence_parallel_size: 1
31 | model:
32 | model_path: Qwen/Qwen2.5-7B-Instruct
33 | enable_gradient_checkpointing: true
34 | trust_remote_code: false
35 | freeze_vision_tower: false
36 | optim:
37 | lr: 1.0e-6
38 | weight_decay: 1.0e-2
39 | strategy: adamw # {adamw, adamw_bf16}
40 | lr_warmup_ratio: 0.0
41 | fsdp:
42 | enable_full_shard: true
43 | enable_cpu_offload: false
44 | enable_rank0_init: true
45 | offload:
46 | offload_params: true # true: more CPU memory; false: more GPU memory
47 | offload_optimizer: true # true: more CPU memory; false: more GPU memory
48 |
49 | rollout:
50 | temperature: 1.0
51 | n: 5
52 | gpu_memory_utilization: 0.6
53 | enforce_eager: false
54 | enable_chunked_prefill: false
55 | tensor_parallel_size: 2
56 | limit_images: 0
57 | val_override_config:
58 | temperature: 0.5
59 | n: 1
60 |
61 | ref:
62 | fsdp:
63 | enable_full_shard: true
64 | enable_cpu_offload: true # true: more CPU memory; false: more GPU memory
65 | enable_rank0_init: true
66 | offload:
67 | offload_params: false
68 |
69 | reward:
70 | reward_type: function
71 | score_function: math
72 | skip_special_tokens: true
73 |
74 | env:
75 | num_envs: 32
76 | screen_size: [1920, 1080]
77 |
78 | trainer:
79 | total_episodes: 15
80 | logger: ["console", "wandb"]
81 | project_name: easy_r1
82 | experiment_name: qwen2_5_7b_math_grpo
83 | n_gpus_per_node: 8
84 | nnodes: 1
85 | val_freq: 5 # -1 to disable
86 | val_before_train: true
87 | val_only: false
88 | val_generations_to_log: 3
89 | save_freq: 5 # -1 to disable
90 | save_limit: 3 # -1 to disable
91 | save_checkpoint_path: null
92 | load_checkpoint_path: null
93 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/fsdp_ulysses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contains a sharding manager that reshards data across Ulysses sequence-parallel groups when using FSDP
16 | """
17 |
18 | from torch.distributed.device_mesh import DeviceMesh
19 |
20 | from ...protocol import DataProto, all_gather_data_proto
21 | from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
22 | from .base import BaseShardingManager
23 |
24 |
class FSDPUlyssesShardingManager(BaseShardingManager):
    """
    Sharding manager to support data resharding when using FSDP + Ulysses.

    Used as a context manager: on entry the "sp" sub-mesh's process group becomes the
    active Ulysses sequence-parallel group, and on exit the previously active group is
    restored. Every method is a no-op when ``device_mesh`` is None.
    """

    def __init__(self, device_mesh: DeviceMesh):
        super().__init__()
        self.device_mesh = device_mesh

    def __enter__(self):
        if self.device_mesh is None:
            return

        # Remember the currently active group so __exit__ can put it back.
        self.prev_sp_group = get_ulysses_sequence_parallel_group()
        set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group())

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_mesh is None:
            return

        set_ulysses_sequence_parallel_group(self.prev_sp_group)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """
        AllGather data from sp region
        This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE
        In Ulysses, we need to make sure the same data is used across a SP group
        """
        if self.device_mesh is None:
            return data

        # The helper's return value is not used; presumably it updates `data` in place — confirm in protocol.py.
        all_gather_data_proto(
            data,
            size=self.device_mesh["sp"].size(),
            group=self.device_mesh["sp"].get_group(),
        )
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """
        Split the data to follow FSDP partition: keep only the chunk whose index equals
        this process's local rank in the "sp" sub-mesh.
        """
        if self.device_mesh is None:
            return data

        num_chunks = self.device_mesh["sp"].size()
        local_rank = self.device_mesh["sp"].get_local_rank()
        return data.chunk(chunks=num_chunks)[local_rank]
66 |
--------------------------------------------------------------------------------
/verl/workers/reward/custom.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from collections import defaultdict
17 | from typing import Callable, Dict, List, Tuple, TypedDict
18 |
19 | import torch
20 | from transformers import PreTrainedTokenizer
21 |
22 | from ...protocol import DataProto
23 | from ...utils.reward_score import math_compute_score, r1v_compute_score
24 | from .config import RewardConfig
25 |
26 |
class RewardScore(TypedDict):
    """Score dictionary returned by a reward score function for one response."""

    # Value that CustomRewardManager writes into the reward tensor.
    overall: float
    # Presumably a format-compliance score — confirm against the score functions.
    format: float
    # Presumably an answer-correctness score — confirm against the score functions.
    accuracy: float
31 |
32 |
class CustomRewardManager:
    """Compute rewards for a batch of responses with a named score function."""

    def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
        """Store the tokenizer/config and select the score function named by ``config.score_function``.

        Raises:
            NotImplementedError: If ``config.score_function`` is neither "math" nor "r1v".
        """
        self.config = config
        self.tokenizer = tokenizer

        score_functions: Dict[str, Callable[[str, str], RewardScore]] = {
            "math": math_compute_score,
            "r1v": r1v_compute_score,
        }
        if config.score_function not in score_functions:
            raise NotImplementedError(f"Unknown score function {config.score_function}.")

        self.compute_score: Callable[[str, str], RewardScore] = score_functions[config.score_function]

    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Score every item in ``data``.

        Returns:
            A float32 tensor shaped like ``data.batch["responses"]`` that is zero except for
            each item's "overall" score, and a dict mapping every score key to the list of
            per-item values.
        """
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_metrics = defaultdict(list)

        for row in range(len(data)):
            item = data[row]  # DataProtoItem
            token_ids = item.batch["responses"]
            num_valid = item.batch["response_mask"].sum()

            # Decode only the first `num_valid` response tokens.
            response_str = self.tokenizer.decode(
                token_ids[:num_valid], skip_special_tokens=self.config.skip_special_tokens
            )
            score = self.compute_score(response_str, item.non_tensor_batch["ground_truth"])

            # The overall score sits at index `num_valid - 1`; with an all-zero mask this is -1, the last position.
            reward_tensor[row, num_valid - 1] = score["overall"]
            for metric_name, metric_value in score.items():
                reward_metrics[metric_name].append(metric_value)

        return reward_tensor, reward_metrics
65 |
--------------------------------------------------------------------------------
/verl/utils/py_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contain small python utility functions
16 | """
17 |
18 | import importlib.util
19 | import re
20 | from functools import lru_cache
21 | from typing import Any, Dict, List, Union
22 |
23 | import numpy as np
24 | import yaml
25 | from yaml import Dumper
26 |
27 |
def is_sci_notation(number: float) -> bool:
    """Return True when ``str(number)`` is written in scientific notation (e.g. ``1e-06``)."""
    text = str(number)
    return re.match(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$", text) is not None
31 |
32 |
def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
    """Represent a float as a YAML float scalar with compact, readable formatting.

    - NaN and infinities use YAML's canonical spellings (``.nan``, ``.inf``, ``-.inf``)
      instead of Python's ``nan`` / ``inf``, which are not YAML float literals.
    - Values whose ``str()`` is in scientific notation keep that notation.
    - Everything else is rounded to 3 decimal places.

    Args:
        dumper: The YAML dumper performing the serialization.
        number: The float value to represent.

    Returns:
        The scalar node produced by ``dumper.represent_scalar``.
    """
    if number != number:  # NaN is the only value that differs from itself
        value = ".nan"
    elif number == float("inf"):
        value = ".inf"
    elif number == float("-inf"):
        value = "-.inf"
    elif is_sci_notation(number):
        value = str(number)
        # Give the mantissa a decimal point ("1e-06" -> "1.0e-06"), as YAML's float pattern requires.
        if "." not in value and "e" in value:
            value = value.replace("e", ".0e", 1)
    else:
        value = str(round(number, 3))

    return dumper.represent_scalar("tag:yaml.org,2002:float", value)
42 |
43 |
44 | yaml.add_representer(float, float_representer)
45 | yaml.add_representer(np.float32, float_representer)
46 | yaml.add_representer(np.float64, float_representer)
47 |
48 |
@lru_cache
def is_package_available(name: str) -> bool:
    """Return True if a module spec can be found for ``name`` (result is cached per name)."""
    spec = importlib.util.find_spec(name)
    return spec is not None
52 |
53 |
def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    A key present in both dicts must map to equal values (compared with ``==``);
    otherwise an AssertionError is raised. Keys merged before the conflict stay merged.
    """
    for key, incoming in dict2.items():
        assert key not in dict1 or dict1[key] == incoming, f"{key} in dict1 and dict2 are not the same object"
        dict1[key] = incoming

    return dict1
63 |
64 |
def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
    """Append each value of ``new_data`` to the list stored under the same key in ``data``.

    Keys missing from ``data`` start with an empty list. ``data`` is modified in place.
    """
    for key in new_data:
        data.setdefault(key, []).append(new_data[key])
72 |
73 |
def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
    """Turn a flat dict with ``sep``-joined keys into a nested dict.

    Example: ``{"a/b": 1, "c": 2}`` becomes ``{"a": {"b": 1}, "c": 2}``.
    """
    nested: Dict[str, Any] = {}
    for flat_key, value in data.items():
        *parents, leaf = flat_key.split(sep)
        node = nested
        for part in parents:
            # Descend into the child dict, creating it on first use.
            node = node.setdefault(part, {})

        node[leaf] = value

    return nested
88 |
89 |
def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
    """Flatten a nested dict into one level, joining nested keys with ``sep``.

    Example: ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a/b": 1, "c": 2}``.
    Empty nested dicts contribute no entries.
    """
    flattened = {}
    for key, value in data.items():
        path = key if not parent_key else parent_key + sep + key
        if not isinstance(value, dict):
            flattened[path] = value
            continue

        flattened.update(flatten_dict(value, path, sep=sep))

    return flattened
100 |
101 |
def convert_dict_to_str(data: Dict[str, Any]) -> str:
    """Serialize ``data`` to a YAML string using an indent of 2."""
    yaml_text = yaml.dump(data, indent=2)
    return yaml_text
104 |
--------------------------------------------------------------------------------
/verl/workers/actor/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Actor config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 | from typing import Any, Dict, Optional, Tuple
20 |
21 |
@dataclass
class ModelConfig:
    """Model loading options."""

    model_path: Optional[str] = None
    # Falls back to model_path in post_init() when left as None.
    tokenizer_path: Optional[str] = None
    override_config: Dict[str, Any] = field(default_factory=dict)
    enable_gradient_checkpointing: bool = True
    trust_remote_code: bool = True
    freeze_vision_tower: bool = False

    def post_init(self):
        """Default ``tokenizer_path`` to ``model_path`` when no tokenizer path was given.

        This is a plain method (not ``__post_init__``), so it only runs when called explicitly.
        """
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
34 |
35 |
@dataclass
class OptimConfig:
    """Optimizer and learning-rate schedule options."""

    lr: float = 1e-6
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 1e-2
    strategy: str = "adamw"
    lr_warmup_ratio: float = 0.0
    min_lr_ratio: Optional[float] = None
    warmup_style: str = "constant"
    """auto keys"""
    # init=False: not a constructor argument; presumably assigned at runtime — TODO confirm where.
    training_steps: int = field(default=-1, init=False)
47 |
48 |
@dataclass
class FSDPConfig:
    """FSDP options.

    NOTE(review): the ``mp_*`` fields presumably select mixed-precision dtypes and
    ``fsdp_size = -1`` presumably means "use the default size" — confirm in the worker
    code that consumes this config.
    """

    enable_full_shard: bool = True
    enable_cpu_offload: bool = False
    enable_rank0_init: bool = False
    use_orig_params: bool = False
    torch_dtype: Optional[str] = None
    fsdp_size: int = -1
    mp_param_dtype: str = "bf16"
    mp_reduce_dtype: str = "fp32"
    mp_buffer_dtype: str = "fp32"
60 |
61 |
@dataclass
class OffloadConfig:
    """Switches for offloading parameters and optimizer state; both are off by default."""

    offload_params: bool = False
    offload_optimizer: bool = False
66 |
67 |
@dataclass
class ActorConfig:
    """Actor worker options: batch sizes, clipping, and nested model/optim/FSDP/offload configs."""

    strategy: str = "fsdp"
    global_batch_size: int = 256
    micro_batch_size_per_device_for_update: int = 4
    micro_batch_size_per_device_for_experience: int = 16
    max_grad_norm: float = 1.0
    clip_ratio_low: float = 0.2
    clip_ratio_high: float = 0.3
    clip_ratio_dual: float = 3.0
    ppo_epochs: int = 1
    padding_free: bool = False
    ulysses_sequence_parallel_size: int = 1
    use_torch_compile: bool = True
    model: ModelConfig = field(default_factory=ModelConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)
    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    """auto keys"""
    # The fields below use init=False, so they are not constructor arguments.
    # Presumably derived at runtime — TODO confirm where it is assigned.
    global_batch_size_per_device: int = field(default=-1, init=False)
    # The four KL fields are presumably filled from the trainer's algorithm config — confirm in trainer/config.py.
    disable_kl: bool = field(default=False, init=False)
    use_kl_loss: bool = field(default=False, init=False)
    kl_penalty: str = field(default="kl", init=False)
    kl_coef: float = field(default=0.0, init=False)
92 |
93 |
@dataclass
class RefConfig:
    """Reference-policy worker options."""

    strategy: str = "fsdp"
    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    """auto keys"""
    # init=False fields: not constructor arguments. They mirror ActorConfig field names, so they are
    # presumably copied from the actor config at runtime — TODO confirm where.
    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
    padding_free: bool = field(default=False, init=False)
    ulysses_sequence_parallel_size: int = field(default=1, init=False)
    use_torch_compile: bool = field(default=True, init=False)
104 |
--------------------------------------------------------------------------------
/verl/utils/logger/gen_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from abc import ABC, abstractmethod
17 | from dataclasses import dataclass
18 | from typing import List, Tuple
19 |
20 | from ..py_functional import is_package_available
21 |
22 |
23 | if is_package_available("wandb"):
24 | import wandb # type: ignore
25 |
26 |
27 | if is_package_available("swanlab"):
28 | import swanlab # type: ignore
29 |
30 |
@dataclass
class GenerationLogger(ABC):
    """Interface for backends that record sampled generations."""

    @abstractmethod
    def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Record ``samples`` for the given ``step``; each sample is a 4-tuple ending in a float score."""
35 |
36 |
@dataclass
class ConsoleGenerationLogger(GenerationLogger):
    """Generation logger that prints every sample to stdout."""

    def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Print the prompt, output, ground truth and score of each sample; ``step`` is not used."""
        for prompt, output, ground_truth, score in samples:
            print(f"[prompt] {prompt}\n[output] {output}\n[ground_truth] {ground_truth}\n[score] {score}\n")
42 |
43 |
@dataclass
class WandbGenerationLogger(GenerationLogger):
    """Generation logger that accumulates one wandb table row per call under "val/generations"."""

    def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Append a row holding ``step`` and all sample fields, then log the table to wandb.

        NOTE(review): column names are derived from ``len(samples)`` on every call, so
        calls with different sample counts presumably clash with previously stored rows.
        """
        # One "step" column followed by four columns per sample.
        columns = ["step"]
        for number in range(1, len(samples) + 1):
            columns += [f"input_{number}", f"output_{number}", f"label_{number}", f"score_{number}"]

        if not hasattr(self, "validation_table"):
            # First call: start from an empty table.
            self.validation_table = wandb.Table(columns=columns)

        # Re-create the table from the existing rows on every call.
        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
        table = wandb.Table(columns=columns, data=self.validation_table.data)

        # The new row is the step followed by every field of every sample, in order.
        table.add_data(step, *(value for sample in samples for value in sample))
        wandb.log({"val/generations": table}, step=step)
        self.validation_table = table
69 |
70 |
@dataclass
class SwanlabGenerationLogger(GenerationLogger):
    """Generation logger that sends each sample to swanlab as a captioned text entry."""

    def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Log one ``swanlab.Text`` per sample under "val/generations" for ``step``."""
        divider = "\n\n---\n\n"
        entries = []
        for number, sample in enumerate(samples, start=1):
            body = divider.join(
                (f"input: {sample[0]}", f"output: {sample[1]}", f"label: {sample[2]}", f"score: {sample[3]}")
            )
            entries.append(swanlab.Text(body, caption=f"sample {number}"))

        swanlab.log({"val/generations": entries}, step=step)
82 |
83 |
# Maps a logger name to the generation-logger class used for it; names missing here
# are skipped by AggregateGenerationsLogger.
GEN_LOGGERS = {
    "console": ConsoleGenerationLogger,
    "wandb": WandbGenerationLogger,
    "swanlab": SwanlabGenerationLogger,
}
89 |
90 |
class AggregateGenerationsLogger:
    """Fan out generation samples to every configured generation logger.

    This is intentionally a plain class: it declares no dataclass fields and defines its
    own ``__init__``, so a ``@dataclass`` decorator would only add a field-less ``__eq__``
    (making all instances compare equal) and make instances unhashable.
    """

    def __init__(self, loggers: List[str]):
        """Instantiate one generation logger per recognized name.

        Args:
            loggers: Logger names; names without an entry in ``GEN_LOGGERS`` are ignored.
        """
        self.loggers: List[GenerationLogger] = []

        for logger in loggers:
            if logger in GEN_LOGGERS:
                self.loggers.append(GEN_LOGGERS[logger]())

    def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Forward ``samples`` and ``step`` to every underlying logger, in registration order."""
        for logger in self.loggers:
            logger.log(samples, step)
103 |
--------------------------------------------------------------------------------
/verl/trainer/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | PPO config
16 | """
17 |
18 | import os
19 | from dataclasses import asdict, dataclass, field, fields, is_dataclass
20 | from typing import Optional, Tuple
21 |
22 | from ..workers.config import WorkerConfig
23 |
24 |
def recursive_post_init(dataclass_obj):
    """Run ``post_init()`` on ``dataclass_obj`` (when it has one), then on its dataclass-valued fields.

    The parent's ``post_init`` runs before its children's, and children are visited in
    field declaration order.
    """
    if hasattr(dataclass_obj, "post_init"):
        dataclass_obj.post_init()

    # Lazy on purpose: each field value is read just before it is visited.
    children = (getattr(dataclass_obj, spec.name) for spec in fields(dataclass_obj))
    for child in children:
        if is_dataclass(child):
            recursive_post_init(child)
32 |
33 |
@dataclass
class DataConfig:
    """Dataset and batching options."""

    train_files: str = ""
    val_files: str = ""
    prompt_key: str = "prompt"
    answer_key: str = "answer"
    image_key: str = "images"
    # NOTE(review): the rollout config presumably receives these two lengths from the trainer config — confirm.
    max_prompt_length: int = 512
    max_response_length: int = 512
    rollout_batch_size: int = 512
    val_batch_size: int = -1
    format_prompt: Optional[str] = None
    shuffle: bool = True
    seed: int = 1
    max_pixels: int = 4194304
    min_pixels: int = 262144
50 |
51 |
@dataclass
class AlgorithmConfig:
    """Algorithm options: advantage estimation and KL settings."""

    gamma: float = 1.0
    lam: float = 1.0
    adv_estimator: str = "grpo"
    # NOTE(review): the actor config presumably receives the next four fields from the trainer config — confirm.
    disable_kl: bool = False
    use_kl_loss: bool = False
    kl_penalty: str = "kl"
    kl_coef: float = 1e-3
    kl_type: str = "fixed"
    kl_horizon: float = 0.0
    kl_target: float = 0.0
    enable_replay: bool = False
65 |
66 |
@dataclass
class TrainerConfig:
    """Training-loop options: run length, logging backends, cluster size, validation and checkpointing."""

    total_episodes: int = 10
    max_steps: Optional[int] = None
    project_name: str = "easy_r1"
    experiment_name: str = "demo"
    logger: Tuple[str] = ("console", "wandb")
    nnodes: int = 1
    n_gpus_per_node: int = 8
    critic_warmup: int = 0
    # -1 presumably disables periodic validation — confirm in the trainer.
    val_freq: int = -1
    val_before_train: bool = True
    val_only: bool = False
    val_generations_to_log: int = 0
    # -1 presumably disables checkpoint saving / the checkpoint limit — confirm in the trainer.
    save_freq: int = -1
    save_limit: int = -1
    # Defaults to checkpoints/<project_name>/<experiment_name> in post_init() when left as None.
    save_checkpoint_path: Optional[str] = None
    load_checkpoint_path: Optional[str] = None

    def post_init(self):
        """Fill in the default checkpoint directory when none was configured."""
        if self.save_checkpoint_path is None:
            self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
89 |
@dataclass
class EnvConfig:
    """Environment options.

    NOTE(review): ``screen_size`` is presumably (width, height) in pixels — confirm with
    the code that consumes it.
    """

    num_envs: int = 32
    max_steps: int = 15
    screen_size: Tuple[int, int] = (1920, 1080)
95 |
@dataclass
class PPOConfig:
    """Top-level configuration grouping the data, worker, algorithm, trainer and env sections."""

    data: DataConfig = field(default_factory=DataConfig)
    worker: WorkerConfig = field(default_factory=WorkerConfig)
    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    env: EnvConfig = field(default_factory=EnvConfig)

    def post_init(self):
        """Propagate shared settings into the worker configs.

        Copies the prompt/response length limits from ``data`` to the rollout config and
        the KL settings from ``algorithm`` to the actor config.
        """
        self.worker.rollout.prompt_length = self.data.max_prompt_length
        self.worker.rollout.response_length = self.data.max_response_length
        self.worker.actor.disable_kl = self.algorithm.disable_kl
        self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
        self.worker.actor.kl_penalty = self.algorithm.kl_penalty
        self.worker.actor.kl_coef = self.algorithm.kl_coef

    def deep_post_init(self):
        """Run ``post_init()`` on this config and, recursively, on every nested dataclass that defines it."""
        recursive_post_init(self)

    def to_dict(self):
        """Return the whole configuration as a nested dict (via ``dataclasses.asdict``)."""
        return asdict(self)
117 |
--------------------------------------------------------------------------------
/verl/trainer/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
16 | """
17 |
18 | import json
19 |
20 | import ray
21 | from omegaconf import OmegaConf
22 |
23 | from ..single_controller.ray import RayWorkerGroup
24 | from ..utils.tokenizer import get_processor, get_tokenizer
25 | from ..workers.fsdp_workers import FSDPWorker
26 | from ..workers.reward import CustomRewardManager
27 | from .config import PPOConfig
28 | from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
29 |
30 |
@ray.remote(num_cpus=1)
class Runner:
    """A runner for RL training."""

    def run(self, config: PPOConfig):
        """Finalize ``config``, build the tokenizer, workers and reward functions, then train.

        Args:
            config: The fully merged PPO configuration; its ``deep_post_init()`` is run here.
        """
        # print config
        # deep_post_init() must run first so derived values are filled in before use.
        config.deep_post_init()
        print(json.dumps(config.to_dict(), indent=2))

        # instantiate tokenizer
        # NOTE(review): both loaders receive model_path; a separately configured
        # tokenizer_path is not used here — confirm this is intended.
        tokenizer = get_tokenizer(
            config.worker.actor.model.model_path,
            trust_remote_code=config.worker.actor.model.trust_remote_code,
            use_fast=True,
        )
        processor = get_processor(
            config.worker.actor.model.model_path,
            trust_remote_code=config.worker.actor.model.trust_remote_code,
            use_fast=True,
        )

        # define worker classes
        # Every role is served by the same FSDPWorker class.
        ray_worker_group_cls = RayWorkerGroup
        role_worker_mapping = {
            Role.ActorRollout: ray.remote(FSDPWorker),
            Role.Critic: ray.remote(FSDPWorker),
            Role.RefPolicy: ray.remote(FSDPWorker),
        }
        # A single resource pool with n_gpus_per_node entries per node, shared by all roles.
        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
            Role.Critic: global_pool_id,
            Role.RefPolicy: global_pool_id,
        }
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

        # Training and validation use separate reward manager instances built from the same config.
        reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
        val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)

        trainer = RayPPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        trainer.init_workers()
        trainer.fit()
85 |
86 |
def main():
    """Build the PPO config from defaults, an optional config file and CLI overrides, then run training on ray.

    Merge precedence (lowest to highest): ``PPOConfig`` defaults, the file named by the
    ``config`` CLI key (if present), remaining CLI arguments.
    """
    overrides = OmegaConf.from_cli()
    merged = OmegaConf.structured(PPOConfig())

    if hasattr(overrides, "config"):
        # Remove "config" from the overrides so it is not merged as a PPOConfig field.
        file_config = OmegaConf.load(overrides.pop("config", None))
        merged = OmegaConf.merge(merged, file_config)

    ppo_config = OmegaConf.to_object(OmegaConf.merge(merged, overrides))

    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})

    print(ray.cluster_resources().keys())

    # Run the job inside the remote Runner actor and block until it finishes.
    runner = Runner.remote()
    ray.get(runner.run.remote(ppo_config))
107 |
108 |
109 | if __name__ == "__main__":
110 | main()
111 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # outputs
174 | outputs/
175 | checkpoints/
176 | wandb/
177 |
--------------------------------------------------------------------------------
/verl/utils/logger/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A unified tracking interface that supports logging data to different backend
16 | """
17 |
18 | import os
19 | from abc import ABC, abstractmethod
20 | from typing import Any, Dict, List, Optional, Tuple, Union
21 |
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict
25 | from .gen_logger import AggregateGenerationsLogger
26 |
27 |
28 | if is_package_available("mlflow"):
29 | import mlflow # type: ignore
30 |
31 |
32 | if is_package_available("wandb"):
33 | import wandb # type: ignore
34 |
35 |
36 | if is_package_available("swanlab"):
37 | import swanlab # type: ignore
38 |
39 |
class Logger(ABC):
    """Abstract interface that every tracking backend implements."""

    @abstractmethod
    def __init__(self, config: Dict[str, Any]) -> None:
        """Set up the backend from the experiment configuration."""

    @abstractmethod
    def log(self, data: Dict[str, Any], step: int) -> None:
        """Record the metrics in ``data`` for the given ``step``."""

    def finish(self) -> None:
        """Release backend resources; the default implementation does nothing."""
        return None
49 |
50 |
class ConsoleLogger(Logger):
    """Backend that prints the configuration and each metric update to stdout."""

    def __init__(self, config: Dict[str, Any]) -> None:
        rendered_config = convert_dict_to_str(config)
        print("Config\n" + rendered_config)

    def log(self, data: Dict[str, Any], step: int) -> None:
        header = f"Step {step}\n"
        # Metrics are passed through unflatten_dict before being rendered.
        nested_metrics = unflatten_dict(data)
        print(header + convert_dict_to_str(nested_metrics))
57 |
58 |
class MlflowLogger(Logger):
    """Backend that records parameters and metrics in an MLflow run."""

    def __init__(self, config: Dict[str, Any]) -> None:
        """Start a run named after the experiment and log the flattened config.

        Args:
            config: Experiment configuration; ``config["trainer"]["experiment_name"]``
                is used as the run name.
        """
        mlflow.start_run(run_name=config["trainer"]["experiment_name"])
        mlflow.log_params(flatten_dict(config))

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Log the metrics in ``data`` at ``step``."""
        mlflow.log_metrics(metrics=data, step=step)

    def finish(self) -> None:
        """End the run started in ``__init__``.

        Without this the run stays active for the rest of the process, and a
        later ``mlflow.start_run`` (e.g. from a second logger instance) fails
        because a run is already active.
        """
        mlflow.end_run()
66 |
67 |
class TensorBoardLogger(Logger):
    """Backend that writes scalars to a TensorBoard event directory."""

    def __init__(self, config: Dict[str, Any]) -> None:
        """Create the summary writer and record the config as hyperparameters.

        The log directory is read from the ``TENSORBOARD_DIR`` environment
        variable and defaults to ``tensorboard_log``.

        Args:
            config: Experiment configuration; it is flattened before logging.
        """
        tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log")
        os.makedirs(tensorboard_dir, exist_ok=True)
        print(f"Saving tensorboard log to {tensorboard_dir}.")
        self.writer = SummaryWriter(tensorboard_dir)

        # add_hparams only accepts bool/int/float/str (or tensor) values, so any
        # other config value (None, lists, ...) is stored as its string form.
        hparams = {
            key: value if isinstance(value, (bool, int, float, str)) else str(value)
            for key, value in flatten_dict(config).items()
        }
        # metric_dict is a required argument; there are no metrics yet at start-up.
        self.writer.add_hparams(hparam_dict=hparams, metric_dict={})

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Write every entry of ``data`` as a scalar at ``step``."""
        for key, value in data.items():
            self.writer.add_scalar(key, value, step)

    def finish(self) -> None:
        """Close the summary writer."""
        self.writer.close()
82 |
83 |
class WandbLogger(Logger):
    """Backend that forwards metrics to Weights & Biases."""

    def __init__(self, config: Dict[str, Any]) -> None:
        # Project and run name both come from the "trainer" section of the config.
        trainer_cfg = config["trainer"]
        init_kwargs = {
            "project": trainer_cfg["project_name"],
            "name": trainer_cfg["experiment_name"],
            "config": config,
        }
        wandb.init(**init_kwargs)

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Log ``data`` at ``step``."""
        wandb.log(data=data, step=step)

    def finish(self) -> None:
        """Finish the active wandb run."""
        wandb.finish()
97 |
98 |
class SwanlabLogger(Logger):
    """Backend that forwards metrics to SwanLab."""

    def __init__(self, config: Dict[str, Any]) -> None:
        # Credentials, log directory and mode are all taken from the environment.
        api_key = os.getenv("SWANLAB_API_KEY")
        log_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
        run_mode = os.getenv("SWANLAB_MODE", "cloud")
        if api_key:
            swanlab.login(api_key)

        # Framework tags go first so that keys from the user config override them.
        run_config = {"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL"}
        run_config.update(config)

        trainer_cfg = config["trainer"]
        swanlab.init(
            project=trainer_cfg["project_name"],
            experiment_name=trainer_cfg["experiment_name"],
            config=run_config,
            logdir=log_dir,
            mode=run_mode,
        )

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Log ``data`` at ``step``."""
        swanlab.log(data=data, step=step)

    def finish(self) -> None:
        """Finish the active SwanLab experiment."""
        swanlab.finish()
120 |
121 |
# Registry mapping each supported backend name to its Logger class.
LOGGERS = {
    "wandb": WandbLogger,
    "mlflow": MlflowLogger,
    "tensorboard": TensorBoardLogger,
    "console": ConsoleLogger,
    "swanlab": SwanlabLogger,
}
129 |
130 |
class Tracker:
    """Fan-out front end that sends metrics to one or more logging backends."""

    def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
        """Instantiate the requested backends.

        Args:
            loggers: A backend name or a list of backend names; each must be a
                key of ``LOGGERS``.
            config: Configuration passed to every backend constructor.

        Raises:
            ValueError: If a requested backend name is not in ``LOGGERS``.
        """
        backend_names = [loggers] if isinstance(loggers, str) else loggers

        self.loggers: List[Logger] = []
        # Backends are built in order; an unknown name aborts at that point.
        for backend_name in backend_names:
            if backend_name not in LOGGERS:
                raise ValueError(f"{backend_name} is not supported.")

            backend_cls = LOGGERS[backend_name]
            self.loggers.append(backend_cls(config))

        self.gen_logger = AggregateGenerationsLogger(backend_names)

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Send ``data`` for ``step`` to every backend."""
        for backend in self.loggers:
            backend.log(data=data, step=step)

    def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Forward generation samples to the generation logger."""
        self.gen_logger.log(samples, step)

    def __del__(self):
        # Give every backend the chance to flush/close when the tracker is collected.
        for backend in self.loggers:
            backend.finish()
155 |
--------------------------------------------------------------------------------
/verl/utils/fsdp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gc
16 | from collections import defaultdict
17 | from functools import partial
18 | from typing import Callable, Union
19 |
20 | import torch
21 | from torch import nn
22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
23 | from torch.distributed.fsdp._runtime_utils import _lazy_init
24 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
25 | from torch.optim import Optimizer
26 | from transformers import PreTrainedModel
27 | from transformers.trainer_pt_utils import get_module_class_from_name
28 |
29 |
def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]:
    """Build a per-module init function that materializes parameters on ``device``.

    The returned function replaces each direct parameter of the module it is
    given with an uninitialized parameter of the same shape and dtype on
    ``device``. Parameters that appear more than once in ``model`` (tied
    weights) are replaced by one shared new parameter so the tie is preserved.

    Args:
        model: Model whose parameter sharing is inspected up front.
        device: Device on which the new parameters are allocated.

    Returns:
        A callable taking a module and re-creating its direct parameters in place.
    """
    param_occurrence = defaultdict(int)
    for _, param in model.named_parameters(remove_duplicate=False):
        param_occurrence[param] += 1

    duplicated_params = {param for param, count in param_occurrence.items() if count > 1}
    materialized_params = {}

    def _materialize(param: nn.Parameter) -> nn.Parameter:
        # Same shape/dtype as the source parameter, but uninitialized storage.
        return nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad)

    def init_fn(module: nn.Module):
        for name, param in module.named_parameters(recurse=False):
            if param in duplicated_params:
                # Allocate only on first sight; dict.setdefault would build (and
                # then discard) a fresh device tensor on every later occurrence.
                if param not in materialized_params:
                    materialized_params[param] = _materialize(param)
                module._parameters[name] = materialized_params[param]
            else:
                module._parameters[name] = _materialize(param)

    return init_fn
50 |
51 |
def get_fsdp_wrap_policy(model: PreTrainedModel):
    """Build an FSDP auto-wrap policy from the model's ``_no_split_modules``.

    Args:
        model: Pretrained model whose ``_no_split_modules`` lists the class
            names to wrap.

    Returns:
        ``transformer_auto_wrap_policy`` partially applied with the resolved
        layer classes.

    Raises:
        Exception: If a listed class name cannot be resolved inside the model.
    """
    layer_classes = set()
    for class_name in model._no_split_modules:
        layer_cls = get_module_class_from_name(model, class_name)
        if layer_cls is None:
            raise Exception(f"Cannot find {class_name} in pretrained model.")

        layer_classes.add(layer_cls)

    return partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_classes)
67 |
68 |
@torch.no_grad()
def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
    """Move the flat parameters of a root FSDP model to CPU.

    Args:
        model: FSDP module; it must be the root wrapper (asserted below).
        empty_cache: When true, call ``torch.cuda.empty_cache()`` after the moves.
    """
    # lazy init FSDP model
    _lazy_init(model, model)
    assert model._is_root, "Only support root model offloading to CPU"
    for handle in model._all_handles:
        # NOTE(review): presumably these handles are already CPU-offloaded by FSDP
        # itself and must not be moved manually - confirm.
        if handle._offload_params:
            continue

        flat_param = handle.flat_param
        # Sanity check: .data and ._local_shard are distinct tensor objects that
        # share the same storage pointer and size.
        assert (
            flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
            and id(flat_param.data) != id(flat_param._local_shard)
            and flat_param.data.size() == flat_param._local_shard.size()
        )
        handle.flat_param_to("cpu", non_blocking=True)
        # the following still keeps id(._local_shard) != id(.data)
        flat_param._local_shard = flat_param.data
        assert id(flat_param._local_shard) != id(flat_param.data)

    if empty_cache:
        torch.cuda.empty_cache()
91 |
92 |
@torch.no_grad()
def load_fsdp_model(model: FSDP, empty_cache: bool = True):
    """Move the flat parameters of a root FSDP model to the CUDA device.

    Args:
        model: FSDP module; it must be the root wrapper (asserted below).
        empty_cache: When true, run ``gc.collect()`` after the moves. Note that
            this does not call ``torch.cuda.empty_cache()``.
    """
    # lazy init FSDP model
    _lazy_init(model, model)
    assert model._is_root, "Only support root model loading to GPU"
    for handle in model._all_handles:
        # NOTE(review): presumably these handles are managed by FSDP's own CPU
        # offload and must not be moved manually - confirm.
        if handle._offload_params:
            continue

        flat_param = handle.flat_param
        handle.flat_param_to("cuda", non_blocking=True)
        # the following still keeps id(._local_shard) != id(.data)
        flat_param._local_shard = flat_param.data

    if empty_cache:
        gc.collect()
109 |
110 |
@torch.no_grad()
def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
    """Move every tensor held in the optimizer state to CPU.

    Args:
        optimizer: Optimizer whose per-parameter state is moved; nothing is
            done when its state is empty.
        empty_cache: When true, call ``torch.cuda.empty_cache()`` afterwards.
    """
    if not optimizer.state:
        return

    for group in optimizer.param_groups:
        for param in group["params"]:
            param_state = optimizer.state[param]
            # Collect the keys first, then overwrite the tensor-valued entries.
            tensor_keys = [key for key, value in param_state.items() if isinstance(value, torch.Tensor)]
            for key in tensor_keys:
                param_state[key] = param_state[key].to("cpu", non_blocking=True)

    if empty_cache:
        torch.cuda.empty_cache()
125 |
126 |
@torch.no_grad()
def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
    """Move every tensor held in the optimizer state to the CUDA device.

    Args:
        optimizer: Optimizer whose per-parameter state is moved; nothing is
            done when its state is empty.
        empty_cache: When true, run ``gc.collect()`` afterwards.
    """
    if not optimizer.state:
        return

    for group in optimizer.param_groups:
        for param in group["params"]:
            param_state = optimizer.state[param]
            # Collect the keys first, then overwrite the tensor-valued entries.
            tensor_keys = [key for key, value in param_state.items() if isinstance(value, torch.Tensor)]
            for key in tensor_keys:
                param_state[key] = param_state[key].to("cuda", non_blocking=True)

    if empty_cache:
        gc.collect()
141 |
--------------------------------------------------------------------------------
/verl/utils/flops_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
from typing import TYPE_CHECKING, List, Optional, Tuple
16 |
17 | import torch
18 |
19 |
20 | if TYPE_CHECKING:
21 | from transformers.models.llama.configuration_llama import LlamaConfig
22 |
23 |
24 | VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
25 |
26 |
def get_device_flops(unit: str = "T", device_name: Optional[str] = None) -> float:
    """Return the promised peak FLOPS of a GPU model, scaled to ``unit``.

    Args:
        unit: Target unit, one of "B", "K", "M", "G", "T", "P".
        device_name: Device name to look up. Defaults to the name reported by
            ``torch.cuda.get_device_name()`` for the current device.

    Returns:
        The peak FLOPS divided by 1000 once per unit step up to ``unit``, or
        ``inf`` when the device is not in the table.
    """

    def unit_convert(number: float, level: str):
        units = ["B", "K", "M", "G", "T", "P"]
        if number <= 0:
            return number

        ptr = 0
        while ptr < len(units) and units[ptr] != level:
            number /= 1000
            ptr += 1

        return number

    if device_name is None:
        device_name = torch.cuda.get_device_name()

    # Matching is by substring and first hit wins, so "H200" has to be tested
    # before "H20" (which it contains); likewise order matters for any new entry.
    known_devices = (
        (("H100", "H800", "H200"), 989e12),
        (("A100", "A800"), 312e12),
        (("L40",), 181.05e12),
        (("L20",), 119.5e12),
        (("H20",), 148e12),
        (("910B",), 354e12),
    )
    flops = float("inf")  # INF flops for unknown gpu type
    for markers, peak in known_devices:
        if any(marker in device_name for marker in markers):
            flops = peak
            break

    return unit_convert(flops, unit)
56 |
57 |
class FlopsCounter:
    """
    Estimates achieved FLOPS so that MFU can be derived during the training loop.

    Example:
        flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
    """

    def __init__(self, config: "LlamaConfig"):
        if config.model_type not in VALID_MODLE_TYPE:
            print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.")

        # Every supported architecture currently shares the llama-style estimator.
        supported_types = ("llama", "qwen2", "qwen2_vl", "qwen2_5_vl")
        self.estimate_func = {model_type: self._estimate_llama_flops for model_type in supported_types}
        self.config = config

    def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
        """Fallback estimator for unsupported model types; always reports zero."""
        return 0

    def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
        """Estimate achieved TFLOPS (forward + backward) for a llama-style decoder."""
        cfg = self.config
        hidden = cfg.hidden_size
        vocab = cfg.vocab_size
        n_layers = cfg.num_hidden_layers
        n_kv_heads = cfg.num_key_value_heads
        n_heads = cfg.num_attention_heads
        ffn_dim = cfg.intermediate_size

        head_dim = hidden // n_heads
        q_size = n_heads * head_dim
        kv_size = n_kv_heads * head_dim

        # Per-layer dense parameters: three MLP projections (gate, up, down) plus
        # the q, k, v and output projections of attention.
        mlp_params = hidden * ffn_dim * 3
        attn_proj_params = hidden * (q_size + kv_size + kv_size + n_heads * head_dim)
        # Input embedding and LM head.
        embed_params = vocab * hidden * 2
        dense_params = (mlp_params + attn_proj_params) * n_layers + embed_params
        # Dense forward + backward cost over all tokens.
        dense_flops = 6 * dense_params * tokens_sum

        # Attention cost grows with the square of each sequence length.
        seqlen_square_sum = 0
        for length in batch_seqlens:
            seqlen_square_sum += length * length

        attn_flops = 12 * seqlen_square_sum * head_dim * n_heads * n_layers

        total_flops = dense_flops + attn_flops
        return total_flops * (1.0 / delta_time) / 1e12

    def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
        """
        Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.

        Args:
            batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch.
            delta_time (float): The time taken to process the batch, in seconds.

        Returns:
            estimated_flops (float): The estimated FLOPS based on the input tokens and time.
            promised_flops (float): The expected FLOPS of the current device.
        """
        total_tokens = sum(batch_seqlens)
        estimator = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
        achieved = estimator(total_tokens, batch_seqlens, delta_time)
        promised = get_device_flops()
        return achieved, promised
134 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | `hoshihiyouga AT gmail DOT com`.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import random
17 | import re
18 | import shutil
19 | import tempfile
20 | from abc import ABC, abstractmethod
21 | from typing import Any, Dict, Optional, Union
22 |
23 | import numpy as np
24 | import torch
25 | import torch.distributed as dist
26 | from filelock import FileLock
27 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
28 | from transformers import PreTrainedTokenizer, ProcessorMixin
29 |
30 |
31 | CHECKPOINT_TRACKER = "latest_global_step.txt"
32 |
33 |
class BaseCheckpointManager(ABC):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in a SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer and config for ckpt merge
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
        """Store the training objects and record this process's rank and world size.

        Args:
            model: FSDP-wrapped model (asserted).
            optimizer: Optimizer whose state is checkpointed.
            lr_scheduler: Learning-rate scheduler whose state is checkpointed.
            processing_class: Tokenizer or processor kept alongside the checkpoint.
        """
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.processing_class = processing_class

        assert isinstance(self.model, FSDP)
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

    @abstractmethod
    def load_checkpoint(self, *args, **kwargs):
        """Restore state from a checkpoint; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def save_checkpoint(self, *args, **kwargs):
        """Persist state to a checkpoint; implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def local_mkdir(path: str) -> str:
        """Create ``path`` (made absolute against the cwd) under a file lock.

        Args:
            path: Directory to create; relative paths are resolved against the
                current working directory.

        Returns:
            The absolute directory path.
        """
        import hashlib  # local import: the module header does not import hashlib

        if not os.path.isabs(path):
            working_dir = os.getcwd()
            path = os.path.join(working_dir, path)

        # Use a digest of the path as lock file name to avoid long file names.
        # The built-in hash() of a str is salted per process, so it would give
        # every rank a different lock file and no mutual exclusion at all.
        path_digest = hashlib.sha256(path.encode("utf-8")).hexdigest()[:8]
        lock_filename = f"ckpt_{path_digest}.lock"
        lock_path = os.path.join(tempfile.gettempdir(), lock_filename)

        try:
            with FileLock(lock_path, timeout=60):
                os.makedirs(path, exist_ok=True)
        except Exception as e:
            print(f"Warning: Failed to acquire lock for {path}: {e}")
            os.makedirs(path, exist_ok=True)  # even if the lock is not acquired, try to create the directory

        return path

    @staticmethod
    def get_rng_state() -> Dict[str, Any]:
        """Snapshot the torch CPU, torch CUDA, numpy and ``random`` RNG states."""
        rng_state = {
            "cpu": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state(),
            "numpy": np.random.get_state(),
            "random": random.getstate(),
        }
        return rng_state

    @staticmethod
    def load_rng_state(rng_state: Dict[str, Any]):
        """Restore the RNG states captured by ``get_rng_state``."""
        torch.set_rng_state(rng_state["cpu"])
        torch.cuda.set_rng_state(rng_state["cuda"])
        np.random.set_state(rng_state["numpy"])
        random.setstate(rng_state["random"])
108 |
109 |
def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
    """Locate the most recent checkpoint directory recorded by the tracker file.

    Args:
        path: Checkpoint root directory; ``None`` means there is nothing to look up.
        directory_format: Format string that turns the recorded step into the
            checkpoint directory name.

    Returns:
        The checkpoint directory path, or ``None`` when ``path`` is ``None``,
        the tracker file is missing, or the recorded directory does not exist.
    """
    if path is None:
        return None

    tracker_file = get_checkpoint_tracker_filename(path)
    if not os.path.exists(tracker_file):
        # print() does not apply %-style arguments, so format the message explicitly.
        print(f"Checkpoint tracker file does not exist: {tracker_file}")
        return None

    with open(tracker_file, "rb") as f:
        iteration = int(f.read().decode())

    ckpt_path = os.path.join(path, directory_format.format(iteration))
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint does not exist: {ckpt_path}")
        return None

    print(f"Found checkpoint: {ckpt_path}")
    return ckpt_path
129 |
130 |
def get_checkpoint_tracker_filename(root_path: str) -> str:
    """
    Return the path of the tracker file under ``root_path``; the tracker file records
    the latest checkpoint written during training so a run can restart from it.
    """
    tracker_path = os.path.join(root_path, CHECKPOINT_TRACKER)
    return tracker_path
136 |
137 |
def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
    """
    Remove the obsolete checkpoints that exceed the save_limit.

    Only folders whose whole name matches ``directory_format`` and whose step is
    below ``global_step`` are considered; the newest ``save_limit - 1`` of them
    are kept and the rest are deleted.

    Args:
        path: Checkpoint root directory; nothing happens if it does not exist.
        global_step: Current step; folders at or beyond it are never touched.
        save_limit: Retention limit; values <= 0 disable cleanup.
        directory_format: Format string of checkpoint folder names, with ``{}``
            standing for the step number.
    """
    if save_limit <= 0:
        return

    if not os.path.exists(path):
        return

    pattern = re.compile(re.escape(directory_format).replace(r"\{\}", r"(\d+)"))
    ckpt_folders = []
    for folder in os.listdir(path):
        # fullmatch, not match: a prefix match would also select unrelated folders
        # such as "global_step_2_backup" and delete them.
        if match := pattern.fullmatch(folder):
            step = int(match.group(1))
            if step < global_step:
                ckpt_folders.append((step, folder))

    # Newest first; everything past the first (save_limit - 1) entries is removed.
    ckpt_folders.sort(reverse=True)
    for _, folder in ckpt_folders[save_limit - 1 :]:
        folder_path = os.path.join(path, folder)
        shutil.rmtree(folder_path, ignore_errors=True)
        print(f"Removed obsolete checkpoint: {folder_path}")
161 |
--------------------------------------------------------------------------------
/verl/trainer/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List
16 |
17 | import numpy as np
18 | import torch
19 |
20 | from ..protocol import DataProto
21 |
22 |
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Collapse each metric's list of values into its mean."""
    reduced = {}
    for name, values in metrics.items():
        reduced[name] = np.mean(values)
    return reduced
25 |
26 |
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
    """Summarize scores, rewards, advantages, returns and lengths of a batch.

    Args:
        batch: Batch whose ``batch`` mapping provides ``token_level_scores``,
            ``token_level_rewards``, ``advantages``, ``returns``, ``responses``,
            ``attention_mask``, ``labels`` and ``input_ids`` (plus ``values``
            when ``use_critic`` is true).
        use_critic: When true, also report value statistics and the value
            function's explained variance.

    Returns:
        A flat mapping from metric name to a Python scalar.
    """
    # Per-sequence totals: sum the token-level values over the last dimension.
    sequence_score = batch.batch["token_level_scores"].sum(-1)
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)

    advantages = batch.batch["advantages"]
    returns = batch.batch["returns"]

    max_response_length = batch.batch["responses"].size(-1)

    # prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
    # response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
    # Masks are derived from the labels: positions labelled -100 that are attended
    # count as prompt, every other position counts as response.
    prompt_mask = (torch.logical_and(batch.batch["attention_mask"], batch.batch["labels"] == -100)).bool()
    response_mask = (batch.batch["labels"] != -100).bool()

    # Note: this is the size of the mask's last dimension, not a separate prompt limit.
    max_prompt_length = prompt_mask.size(-1)
    prompt_length = prompt_mask.sum(-1).float()
    response_length = response_mask.sum(-1).float()
    # presumably 2691 is the number of image-pad tokens (id 151655) one image
    # expands to - TODO confirm against the processor settings
    num_images = (batch.batch["input_ids"] == 151655).bool().sum(-1).float() // 2691  # image_pad
    # NOTE(review): a row with fewer than 2691 image-pad tokens gives num_images == 0,
    # which makes this division produce inf/nan - confirm that cannot happen.
    response_length = response_length / num_images  # average response length per action

    # Keep only the entries at response positions.
    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch["values"]
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    metrics = {
        # score
        "critic/score/mean": torch.mean(sequence_score).detach().item(),
        "critic/score/max": torch.max(sequence_score).detach().item(),
        "critic/score/min": torch.min(sequence_score).detach().item(),
        # reward
        "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
        "critic/rewards/max": torch.max(sequence_reward).detach().item(),
        "critic/rewards/min": torch.min(sequence_reward).detach().item(),
        # adv
        "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
        "critic/advantages/max": torch.max(valid_adv).detach().item(),
        "critic/advantages/min": torch.min(valid_adv).detach().item(),
        # returns
        "critic/returns/mean": torch.mean(valid_returns).detach().item(),
        "critic/returns/max": torch.max(valid_returns).detach().item(),
        "critic/returns/min": torch.min(valid_returns).detach().item(),
        **(
            {
                # values
                "critic/values/mean": torch.mean(valid_values).detach().item(),
                "critic/values/max": torch.max(valid_values).detach().item(),
                "critic/values/min": torch.min(valid_values).detach().item(),
                # vf explained var
                "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
            }
            if use_critic
            else {}
        ),
        # response length
        "response_length/num_images": torch.mean(num_images).detach().item(),
        "response_length/mean": torch.mean(response_length).detach().item(),
        "response_length/max": torch.max(response_length).detach().item(),
        "response_length/min": torch.min(response_length).detach().item(),
        # response_length here is the per-image average computed above
        "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
        .detach()
        .item(),
        # prompt length
        "prompt_length/mean": torch.mean(prompt_length).detach().item(),
        "prompt_length/max": torch.max(prompt_length).detach().item(),
        "prompt_length/min": torch.min(prompt_length).detach().item(),
        "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics
100 |
101 |
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
    """Turn raw per-section timings into loggable metrics.

    Every entry of ``timing_raw`` is reported as ``timing_s/<name>``. Sections
    with a known token count are additionally reported as
    ``timing_per_token_ms/<name>``: "gen" and "reward" are normalised by the
    sum of ``batch.batch["response_mask"]``, the remaining known sections by
    the sum of ``batch.meta_info["global_token_num"]``.
    """
    response_token_count = torch.sum(batch.batch["response_mask"]).item()
    overall_token_count = sum(batch.meta_info["global_token_num"])

    # Token count used as the denominator for each known section.
    tokens_per_section = {}
    for section in ("gen", "reward"):
        tokens_per_section[section] = response_token_count
    for section in ("ref", "old", "values", "adv", "update_critic", "update_actor"):
        tokens_per_section[section] = overall_token_count

    metrics = {}
    for name, seconds in timing_raw.items():
        metrics[f"timing_s/{name}"] = seconds

    # Per-token timings only exist for sections that were actually timed.
    for name, seconds in timing_raw.items():
        if name in tokens_per_section:
            metrics[f"timing_per_token_ms/{name}"] = seconds * 1000 / tokens_per_section[name]

    return metrics
116 |
117 |
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
    """Report token throughput for one training step.

    Uses the sum of ``batch.meta_info["global_token_num"]`` as the token count
    and ``timing_raw["step"]`` as the step duration; throughput is tokens per
    unit of step time per GPU.
    """
    token_count = sum(batch.meta_info["global_token_num"])
    step_seconds = timing_raw["step"]

    metrics = {
        "perf/total_num_tokens": token_count,
        "perf/time_per_step": step_seconds,
    }
    metrics["perf/throughput"] = token_count / (step_seconds * n_gpus)
    return metrics
126 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/fsdp_vllm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from typing import Dict, Iterable, Tuple, Union
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from torch.distributed._tensor import DTensor
21 | from torch.distributed.device_mesh import DeviceMesh
22 | from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
23 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
24 | from vllm import LLM
25 | from vllm.distributed import parallel_state as vllm_ps
26 |
27 | from ...protocol import DataProto, all_gather_data_proto
28 | from ...utils.model_utils import print_gpu_memory_usage
29 | from .base import BaseShardingManager
30 |
31 |
class FSDPVLLMShardingManager(BaseShardingManager):
    """Context manager that moves weights from an FSDP module into a vLLM engine.

    On ``__enter__`` the FSDP state dict is fetched, the vLLM engine is woken
    up and the weights are loaded into it. On ``__exit__`` the engine is put to
    sleep and the module is switched back to train mode. When a device mesh is
    given, the CUDA RNG state is swapped between a "training" state and a
    "generation" state at both boundaries.
    """

    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: DeviceMesh = None,
    ):
        """Store the module/engine pair and prepare state-dict and RNG settings.

        Args:
            module: FSDP-wrapped module whose weights are synced to vLLM.
            inference_engine: vLLM engine that receives the weights.
            device_mesh: optional mesh; when provided its "dp" dimension is
                used to seed the generation RNG state.
        """
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh
        # Make `module.state_dict()` return sharded tensors; warnings raised
        # while configuring this are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        self.world_size = dist.get_world_size()
        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group

        # Record freed bytes to estimate memory usage correctly
        # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
        self.freed_bytes = 0

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = torch.cuda.get_rng_state()
        # Derive a separate RNG state for generation, then restore the current one.
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def _make_weight_iterator(
        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Yield ``(name, tensor)`` pairs for vLLM's ``load_weights``.

        When the world size is not 1, ``full_tensor()`` is called on each entry
        so that a complete tensor is yielded; otherwise the tensor is passed
        through unchanged. Tensors are produced lazily, one at a time.
        """
        for name, tensor in actor_weights.items():
            yield name, tensor.full_tensor() if self.world_size != 1 else tensor

    def __enter__(self):
        """Sync the FSDP weights into the vLLM engine and switch to the generation RNG state.

        Note: nothing is returned, so ``with manager as x`` binds ``x`` to ``None``.
        """
        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory
        # to speed up memory allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        torch.cuda.empty_cache()
        print_gpu_memory_usage("Before state_dict() in sharding manager")
        actor_weights = self.module.state_dict()
        print_gpu_memory_usage("After state_dict() in sharding manager")

        self.inference_engine.wake_up()
        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        model.load_weights(self._make_weight_iterator(actor_weights))
        print_gpu_memory_usage("After sync model weights in sharding manager")

        # Drop the state dict before generation so its memory can be reclaimed.
        del actor_weights
        torch.cuda.empty_cache()
        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
        # important: need to manually set the random states of each tp to be identical.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        """Put the vLLM engine to sleep and restore the training-side state.

        Nothing is returned, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        print_gpu_memory_usage("Before vllm offload in sharding manager")
        # The change in free device memory across sleep() is recorded as `freed_bytes`.
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        self.inference_engine.sleep(level=1)
        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        print_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()  # add empty cache after each compute

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size > 1:
            data = data.chunk(chunks=self.tp_size)[self.tp_rank]

        return data
129 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/fsdp_checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import warnings
17 | from typing import Optional, Union
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
22 | from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
23 | from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
24 |
25 | from .checkpoint_manager import BaseCheckpointManager
26 |
27 |
class FSDPCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in a SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer and config for ckpt merge

    Every rank reads and writes its own files, named
    ``{model,optim,extra_state}_world_size_<W>_rank_<R>.pt``.
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
        """Forward all arguments unchanged to ``BaseCheckpointManager``."""
        super().__init__(model, optimizer, lr_scheduler, processing_class)

    def load_checkpoint(self, path: Optional[str] = None):
        """Restore this rank's model, optimizer, lr scheduler and RNG state.

        Args:
            path: checkpoint directory; when ``None`` nothing is loaded.

        The optimizer shard is only read when this manager has an optimizer,
        mirroring ``save_checkpoint`` which only writes it in that case.
        """
        if path is None:
            return

        # every rank download its own checkpoint
        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        model_state_dict = torch.load(model_path, weights_only=False)
        # The optimizer file does not exist when the checkpoint was saved without an
        # optimizer, so reading it unconditionally would fail for such checkpoints.
        if self.optimizer is not None:
            optimizer_state_dict = torch.load(optim_path, weights_only=False)
        else:
            optimizer_state_dict = None
        extra_state_dict = torch.load(extra_state_path, weights_only=False)
        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                self.model.load_state_dict(model_state_dict)
                if self.optimizer is not None:
                    self.optimizer.load_state_dict(optimizer_state_dict)

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

        # recover random state
        if "rng" in extra_state_dict:
            self.load_rng_state(extra_state_dict["rng"])

    def save_checkpoint(self, path: str):
        """Write this rank's shards and, on rank 0, the Hugging Face config files.

        Args:
            path: checkpoint directory, passed through ``self.local_mkdir``.

        All ranks synchronise before saving, after saving their shards, and
        after rank 0 has written the ``huggingface`` sub-directory.
        """
        path = self.local_mkdir(path)
        dist.barrier()

        # every rank will save its own model and optim shard
        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                model_state_dict = self.model.state_dict()
                if self.optimizer is not None:
                    optimizer_state_dict = self.optimizer.state_dict()
                else:
                    optimizer_state_dict = None

                if self.lr_scheduler is not None:
                    lr_scheduler_state_dict = self.lr_scheduler.state_dict()
                else:
                    lr_scheduler_state_dict = None

                extra_state_dict = {
                    "lr_scheduler": lr_scheduler_state_dict,
                    "rng": self.get_rng_state(),
                }
                model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
                optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
                extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

                print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
                torch.save(model_state_dict, model_path)
                if self.optimizer is not None:
                    # Log the optimizer path (the original message repeated the model path).
                    print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
                    torch.save(optimizer_state_dict, optim_path)

                print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
                torch.save(extra_state_dict, extra_path)

        # wait for everyone to dump to local
        dist.barrier()

        if self.rank == 0:
            hf_path = os.path.join(path, "huggingface")
            os.makedirs(hf_path, exist_ok=True)
            assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
            self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
            self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
            self.processing_class.save_pretrained(hf_path)

        dist.barrier()
132 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # ARPO: End-to-End Policy Optimization for GUI Agents with Experience Replay
3 |
4 | This repository contains the code and models for the paper:
5 |
6 | > **ARPO: End-to-End Policy Optimization for GUI Agents with Experience Replay**
7 | > *Fanbin Lu, Zhisheng Zhong, Shu Liu, Chi-Wing Fu, Jiaya Jia*
8 | > CUHK, SmartMore, HKUST
9 | > [[Paper](https://arxiv.org/abs/2505.16282)] • [[Project Page](https://github.com/dvlab-research/ARPO)] • [[Model on HF](https://huggingface.co/Fanbin/ARPO_UITARS1.5_7B)]
10 |
11 | ## Overview
12 |
13 | **ARPO (Agentic Replay Policy Optimization)** is a novel reinforcement learning framework designed to train **vision-language GUI agents** to complete **long-horizon desktop tasks**. It builds upon **Group Relative Policy Optimization (GRPO)** and introduces:
14 |
15 | - **Distributed Rollouts**: Scalable task execution across parallel OSWorld environments with docker.
16 | - **Multi-modal Input Support**: Processes long histories (15 steps) of screenshots + actions in an end-to-end way.
17 |
18 | Access our [model](https://huggingface.co/Fanbin/ARPO_UITARS1.5_7B) on Hugging Face and view the [training logs](https://wandb.ai/fanbinlu/arpo) on Weights & Biases.
19 |
20 |
21 |
22 |
23 |
24 | ## 📊 Results on OSWorld
25 |
26 | | Model | 128 training tasks | OSWorld overall|
27 | |-----------------------------|---------|-------|
28 | | UI-Tars-1.5 |68.7% | 23.5% |
29 | | UI-Tars-1.5 + GRPO |72.9% | 26.0% |
30 | | **UI-Tars-1.5 + ARPO (Ours)** |83.9% | **29.9%** |
31 |
32 | > Evaluated with a max of **15 steps per trajectory**.
33 |
34 | ---
35 |
36 | ## 🛠 Installation
37 |
38 | ### 1. Clone the repository and create environment
39 |
40 | ```bash
41 | git clone --recurse-submodules https://github.com/dvlab-research/ARPO.git
42 | cd ARPO
43 |
44 | # Create and activate Conda environment
45 | conda create -n arpo python=3.10
46 | conda activate arpo
47 |
48 | # Install Python dependencies
49 | pip install -r requirements.txt
50 | ```
51 |
52 | ### 2. Install OSWorld
53 | Follow the original installation guide of [OSWorld](https://github.com/xlang-ai/OSWorld) if you only want to evaluate the model. If you want to train with GRPO, you must also install it as a pip package:
54 | ```bash
55 | cd OSWorld
56 | pip install -e .
57 | cd ..
58 | ```
59 |
60 | > 💡 We strongly recommend running a full evaluation **with Docker** before training to prepare the docker image, Ubuntu VM data, and cache_dir required.
61 |
62 |
63 | ## ⚙️ Setup for Evaluation with OSWorld
64 |
65 | To evaluate ARPO on the OSWorld benchmark with the [released model](https://huggingface.co/Fanbin/ARPO_UITARS1.5_7B) using Docker-based virtual environments, follow these steps:
66 |
67 | ### 1. **Prepare the Environment**
68 |
69 | Ensure you have correctly installed [OSWorld](https://github.com/xlang-ai/OSWorld) by following its Docker setup instructions. Once OSWorld is set up:
70 |
71 | ```bash
72 | nohup bash start_server.sh &
73 | ```
74 |
75 | ### 2. **Run Evaluation Script**
76 |
77 | Navigate into the OSWorld directory and execute the evaluation script:
78 |
79 | ```bash
80 | cd OSWorld
81 |
82 | python run_multienv_uitars.py \
83 | --headless \
84 | --observation_type screenshot \
85 | --max_steps 15 \
86 | --max_trajectory_length 15 \
87 | --temperature 0.6 \
88 | --model ui-tars \
89 | --action_space pyautogui \
90 | --num_envs 8 \
91 | --result_dir ./results/ \
92 | --test_all_meta_path ./evaluation_examples/test_all.json \
93 | --trial-id 0 \
94 | --server_ip http://127.0.0.1
95 | ```
96 |
97 | ### ✅ Parameters Explained
98 |
99 | - `--headless`: Enables headless mode (no GUI rendering).
100 | - `--observation_type screenshot`: Use visual observations for the agent.
101 | - `--max_steps` / `--max_trajectory_length`: Limit per-task interaction steps.
102 | - `--temperature`: Sampling temperature for model output.
103 | - `--model`: Name of the model.
104 | - `--num_envs`: Number of parallel environments (VMs).
105 | - `--result_dir`: Directory to store evaluation results.
106 | - `--test_all_meta_path`: JSON file with evaluation task metadata.
107 | - `--trial-id`: ID for the evaluation trial.
108 | - `--server_ip`: IP of the evaluation server (usually localhost).
109 |
110 | > After evaluation, you will find the vmware_vm_data/, docker_vm_data/, and cache/ folders under the OSWorld directory.
111 | ---
112 |
113 | ## ⚙️ Setup for GRPO Training
114 |
115 | ```bash
116 | # Link evaluation examples and cache
117 | ln -s $(pwd)/OSWorld/evaluation_examples ./
118 | mkdir cache_dirs/
119 | ln -s $(pwd)/OSWorld/cache ./cache_dirs/cache_0
120 | ln -s $(pwd)/OSWorld/vmware_vm_data ./
121 | ln -s $(pwd)/OSWorld/docker_vm_data ./
122 | ```
123 |
124 | To run Docker without `sudo`:
125 |
126 | ```bash
127 | sudo usermod -aG docker $USER
128 | newgrp docker
129 | ```
130 |
131 | ---
132 |
133 | ## Training ARPO/GRPO with OSWorld
134 |
135 | ### Single Node (subset training: 32 tasks)
136 | If you only have one node, we suggest training on a subset of OSWorld tasks with at most 16 Docker environments.
137 | ```bash
138 | RAY_PORT=2468
139 | RAY_HEAD_IP=
140 | ray start --head --port=$RAY_PORT --resources='{"docker:'$RAY_HEAD_IP'": 128}'
141 | bash ./examples/osworld_subset32.sh
142 | ```
143 |
144 | ### Multi-Node Setup with Ray (e.g. 8 nodes, 128 envs)
145 |
146 | On **Ray master node**:
147 |
148 | ```bash
149 | RAY_PORT=2468
150 | RAY_HEAD_IP=
151 | ray start --head --port=$RAY_PORT --resources='{"docker:'$RAY_HEAD_IP'": 128}'
152 | ```
153 |
154 | On **Ray slave nodes** (with GPU):
155 |
156 | ```bash
157 | ray start --address=$RAY_HEAD_IP:$RAY_PORT --num-gpus=8 --resources='{"docker:'$CURRENT_IP'": 128}'
158 | ```
159 |
160 | Or (CPU only):
161 |
162 | ```bash
163 | ray start --address=$RAY_HEAD_IP:$RAY_PORT --resources='{"docker:'$CURRENT_IP'": 128}'
164 | ```
165 |
166 | Then run:
167 |
168 | ```bash
169 | bash ./examples/osworld_full_arpo.sh
170 | ```
171 |
172 | ---
173 |
174 | ## 🔗 Related Projects
175 |
176 | - [OSWorld](https://github.com/FanbinLu/OSWorld) — Realistic GUI environments for multimodal agents modified for GRPO training.
177 | - [EasyR1](https://github.com/hiyouga/EasyR1) — An efficient, scalable, multi-modality RL training framework based on veRL, supporting advanced VLMs and algorithms like GRPO.
178 | ---
179 |
180 | ## 📄 Citation
181 |
182 | If you find ARPO useful, please consider citing our work.
183 |
--------------------------------------------------------------------------------
/scripts/model_merger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import re
18 | from concurrent.futures import ThreadPoolExecutor
19 | from typing import Dict, List, Tuple
20 |
21 | import torch
22 | from torch.distributed._tensor import DTensor, Placement, Shard
23 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
24 |
25 |
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
    """Combine per-rank local tensors into one tensor according to *placement*.

    Args:
        tensors: local tensors, one per shard, in rank order.
        placement: DTensor placement describing how the tensors relate.

    Returns:
        The first tensor for a replicated placement, or the tensors
        concatenated along ``placement.dim`` (made contiguous) for a sharded one.

    Raises:
        NotImplementedError: for a partial placement.
        ValueError: for any other placement kind.
    """
    if placement.is_replicate():
        # Every rank holds the same data, so any copy will do.
        return tensors[0]
    if placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")
    if not placement.is_shard():
        raise ValueError(f"Unsupported placement: {placement}")

    merged = torch.cat(tensors, dim=placement.dim)
    return merged.contiguous()
35 |
36 |
if __name__ == "__main__":
    # Merge the per-rank FSDP shards stored in --local_dir into a single
    # Hugging Face checkpoint written to <local_dir>/huggingface, and
    # optionally upload that folder to the Hugging Face Hub.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
    args = parser.parse_args()

    assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
    local_dir = args.local_dir

    # copy rank zero to find the shape of (dp, fsdp)
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            # Kept as a string: it is only used to rebuild file names.
            world_size = match.group(1)
            break
    assert world_size, "No model file with the proper format"

    # NOTE: weights_only=False unpickles arbitrary objects; only merge trusted checkpoints.
    # It is passed explicitly (as for the other shards below) because recent PyTorch
    # releases default to weights_only=True, which is expected to reject these DTensor shards.
    state_dict = torch.load(
        os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"),
        map_location="cpu",
        weights_only=False,
    )
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    assert isinstance(weight, DTensor)
    # get sharding info
    device_mesh = weight.device_mesh
    mesh = device_mesh.mesh
    mesh_dim_names = device_mesh.mesh_dim_names

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")

    assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}"

    if "tp" in mesh_dim_names:
        # fsdp * tp
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing model shards with {total_shards} {mesh_shape} in total")

    # Slot 0 holds the already-loaded rank-0 shard; the rest are filled in by the pool.
    model_state_dict_lst = []
    model_state_dict_lst.append(state_dict)
    model_state_dict_lst.extend([""] * (total_shards - 1))

    def process_one_shard(rank):
        """Load the shard of *rank* onto the CPU and store it in its slot."""
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

    # os.cpu_count() may return None, hence the fallback to a single worker.
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count() or 1)) as executor:
        futures = [executor.submit(process_one_shard, shard_rank) for shard_rank in range(1, total_shards)]
        for future in futures:
            # Re-raise loading errors here instead of continuing with placeholder slots.
            future.result()

    state_dict = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for shard_rank, model_state_dict in enumerate(model_state_dict_lst):
            try:
                tensor = model_state_dict.pop(key)
            except KeyError as err:
                # Fail loudly: continuing would silently reuse the previous shard's tensor.
                raise KeyError(f"Key {key} is missing from the shard of rank {shard_rank}") from err
            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # replicated placement at dp dimension can be discarded
                if mesh_dim_names[0] == "dp":
                    placements = placements[1:]
                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    assert param_placements[key] == placements
            else:
                state_dict[key] = tensor.bfloat16()

    del model_state_dict_lst

    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue
        # merge shards
        placements: Tuple[Shard] = param_placements[key]
        if len(mesh_shape) == 1:
            # 1-D list, FSDP without TP
            assert len(placements) == 1
            shards = state_dict[key]
            state_dict[key] = merge_by_placement(shards, placements[0])
        else:
            # 2-D list, FSDP + TP
            raise NotImplementedError("FSDP + TP is not supported yet")

    print("Writing to local disk")
    hf_path = os.path.join(local_dir, "huggingface")
    config = AutoConfig.from_pretrained(hf_path)

    # Pick the auto class from the architecture name stored in the config.
    if "ForTokenClassification" in config.architectures[0]:
        auto_model = AutoModelForTokenClassification
    elif "ForCausalLM" in config.architectures[0]:
        auto_model = AutoModelForCausalLM
    elif "ForConditionalGeneration" in config.architectures[0]:
        auto_model = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {config.architectures}")

    # Build the model on the meta device, then allocate uninitialised CPU storage;
    # the weights written to disk come from `state_dict`.
    with torch.device("meta"):
        model = auto_model.from_config(config, torch_dtype=torch.bfloat16)

    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict
    del model
    if args.hf_upload_path:
        # Push to hugging face
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
        api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
165 |
--------------------------------------------------------------------------------
/verl/single_controller/base/worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | the class for Worker
16 | """
17 |
18 | import os
19 | import socket
20 | from dataclasses import dataclass
21 | from typing import Tuple
22 |
23 | import ray
24 | import torch
25 |
26 | from .decorator import Dispatch, Execute, register
27 | from .register_center.ray import create_worker_group_register_center
28 |
29 |
@dataclass
class DistRankInfo:
    # Ranks of one worker along three parallel dimensions.
    # NOTE(review): presumably tensor-, data- and pipeline-parallel ranks — confirm with callers.
    tp_rank: int
    dp_rank: int
    pp_rank: int
35 |
36 |
@dataclass
class DistGlobalInfo:
    # Group sizes along three parallel dimensions.
    # NOTE(review): presumably tensor-, data- and pipeline-parallel sizes — confirm with callers.
    tp_size: int
    dp_size: int
    pp_size: int
42 |
43 |
class WorkerHelper:
    """Networking and process helpers used as a base class for ``Worker``."""

    def _get_node_ip(self) -> str:
        """Return the IP address of this node.

        ``MY_HOST_IP`` takes precedence, then ``MY_HOST_IPV6``. Ray is only
        queried when neither environment variable is set, so an explicit
        override never touches Ray's private API.
        """
        host_ip_by_env = os.getenv("MY_HOST_IP", None) or os.getenv("MY_HOST_IPV6", None)
        if host_ip_by_env:
            return host_ip_by_env

        return ray._private.services.get_node_ip_address()

    def _get_free_port(self) -> int:
        """Return a port number that was free at the time of the call.

        The socket is closed before returning, so the port is not reserved.
        """
        with socket.socket() as sock:
            # Binding to port 0 lets the OS pick an unused port.
            sock.bind(("", 0))
            return sock.getsockname()[1]

    def get_availale_master_addr_port(self) -> Tuple[str, str]:
        """Return ``(node_ip, free_port)`` with the port formatted as a string."""
        return self._get_node_ip(), str(self._get_free_port())

    def _get_pid(self) -> int:
        """Return the id of the current process (the original returned ``None``)."""
        return os.getpid()
64 |
65 |
class WorkerMeta:
    """Read-only view over a worker's distributed-environment settings.

    ``keys`` lists the environment-variable names; the backing store is
    expected to hold each value under the lower-cased name with a leading
    underscore (e.g. ``"_world_size"``).
    """

    keys = [
        "WORLD_SIZE",
        "RANK",
        "LOCAL_WORLD_SIZE",
        "LOCAL_RANK",
        "MASTER_ADDR",
        "MASTER_PORT",
        "CUDA_VISIBLE_DEVICES",
    ]

    def __init__(self, store) -> None:
        self._store = store

    def to_dict(self):
        """Return ``{"_<key>": value}`` for every known key, ``None`` when absent."""
        snapshot = {}
        for env_name in WorkerMeta.keys:
            attr_name = f"_{env_name.lower()}"
            snapshot[attr_name] = self._store.get(attr_name, None)
        return snapshot
82 |
83 |
84 | # we assume that in each WorkerGroup, there is a Master Worker
85 | class Worker(WorkerHelper):
86 | """A (distributed) worker."""
87 |
88 | _world_size: int
89 | _rank: int
90 | _local_world_size: int
91 | _local_rank: int
92 | _master_addr: str
93 | _master_port: str
94 | _cuda_visible_devices: str
95 |
96 | def __new__(cls, *args, **kwargs):
97 | instance = super().__new__(cls)
98 |
99 | # note that here we use int to distinguish
100 | disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
101 | if disable_worker_init:
102 | return instance
103 |
104 | rank = os.getenv("RANK", None)
105 | worker_group_prefix = os.getenv("WG_PREFIX", None)
106 |
107 | # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
108 | if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
109 | instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
110 |
111 | return instance
112 |
113 | def _configure_before_init(self, register_center_name: str, rank: int):
114 | assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
115 |
116 | if rank == 0:
117 | master_addr, master_port = self.get_availale_master_addr_port()
118 | rank_zero_info = {
119 | "MASTER_ADDR": master_addr,
120 | "MASTER_PORT": master_port,
121 | }
122 | self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
123 | os.environ.update(rank_zero_info)
124 |
def __init__(self, cuda_visible_devices=None) -> None:
    """Build this worker's metadata from environment variables and apply it.

    Reads WORLD_SIZE and RANK (both must be set, otherwise ``int(None)`` raises), plus
    MASTER_ADDR, MASTER_PORT, LOCAL_WORLD_SIZE and LOCAL_RANK, packs them into a
    ``WorkerMeta`` and passes it to ``_configure_with_meta``.

    Args:
        cuda_visible_devices: Optional value recorded in the metadata as
            ``_cuda_visible_devices``; it is overwritten on AMD devices (see below).
    """
    # construct a meta from environment variable. Note that the import must be inside the class because it is executed remotely
    world_size = int(os.getenv("WORLD_SIZE"))
    rank = int(os.getenv("RANK"))
    self._rank = rank
    self._world_size = world_size

    # AMD devices: mirror the ROCm / Ray variables into the names the rest of the code reads.
    # NOTE(review): the nesting of all four statements under this branch is reconstructed from a
    # flattened source listing - confirm against the original file.
    if "AMD" in torch.cuda.get_device_name():
        # NOTE(review): os.getenv returns None when the variable is unset, and assigning None into
        # os.environ raises TypeError - presumably both variables are always set on AMD; verify.
        os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
        os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
        cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
        torch.cuda.set_device(int(cuda_visible_devices))

    master_addr = os.getenv("MASTER_ADDR")
    master_port = os.getenv("MASTER_PORT")

    local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))

    # Keys carry a leading underscore because _configure_with_meta copies them into self.__dict__.
    store = {
        "_world_size": world_size,
        "_rank": rank,
        "_local_world_size": local_world_size,
        "_local_rank": local_rank,
        "_master_addr": master_addr,
        "_master_port": master_port,
    }
    if cuda_visible_devices is not None:
        store["_cuda_visible_devices"] = cuda_visible_devices

    meta = WorkerMeta(store=store)
    self._configure_with_meta(meta=meta)
157 |
def _configure_with_meta(self, meta: WorkerMeta):
    """
    Copy the entries of ``meta`` onto this worker and export them as environment variables.

    This function should only be called inside by WorkerGroup.
    """
    assert isinstance(meta, WorkerMeta)
    self.__dict__.update(meta.to_dict())  # this is hacky

    # Every name in WorkerMeta.keys is exported when the matching "_<lowercase name>" attribute is set.
    for env_name in WorkerMeta.keys:
        attr_value = self.__dict__.get(f"_{env_name.lower()}", None)
        if attr_value is None:
            continue
        os.environ[env_name] = str(attr_value)

    # The square brackets are removed from the address (empty string when no address is known).
    if self._master_addr:
        redis_host = str(self._master_addr).replace("[", "").replace("]", "")
    else:
        redis_host = ""
    os.environ["REDIS_STORE_SERVER_HOST"] = redis_host
174 |
def get_master_addr_port(self):
    """Return the ``(master_addr, master_port)`` pair stored on this worker."""
    addr = self._master_addr
    port = self._master_port
    return addr, port
177 |
def get_cuda_visible_devices(self):
    """Return the current CUDA_VISIBLE_DEVICES value, or ``"not set"`` when it is undefined."""
    return os.getenv("CUDA_VISIBLE_DEVICES", "not set")
181 |
def print_rank0(self, *args, **kwargs):
    """Forward the arguments to ``print`` on rank 0; do nothing on every other rank."""
    if self.rank != 0:
        return
    print(*args, **kwargs)
185 |
@property
def world_size(self):
    """Return the world size stored on this worker (``_world_size``)."""
    return self._world_size
189 |
@property
def rank(self):
    """Return the rank stored on this worker (``_rank``)."""
    return self._rank
193 |
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
def execute_with_func_generator(self, func, *args, **kwargs):
    """Call ``func`` with this worker as its first argument and return whatever it returns."""
    return func(self, *args, **kwargs)
198 |
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
def execute_func_rank_zero(self, func, *args, **kwargs):
    """Call ``func`` with the given arguments (the worker itself is not passed) and return its result."""
    return func(*args, **kwargs)
203 |
--------------------------------------------------------------------------------
/verl/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import os
17 | from collections import defaultdict
18 | from io import BytesIO
19 | from typing import Any, Dict, List, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | from datasets import load_dataset
24 | from PIL import Image
25 | from PIL.Image import Image as ImageObject
26 | from torch.utils.data import Dataset
27 | from transformers import PreTrainedTokenizer, ProcessorMixin
28 |
29 | from ..models.transformers.qwen2_vl import get_rope_index
30 | from . import torch_functional as VF
31 |
32 |
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge a list of per-sample dicts into a single batch dict.

    Tensor values are stacked along a new leading batch dimension. Every other value is
    gathered into a 1-D ``numpy`` object array holding exactly one entry per sample.

    The object array is filled element by element on purpose: ``np.array(values, dtype=object)``
    turns equally sized sequences (for example token-id lists of the same length) into a 2-D
    array, so each sample would come back as an array row instead of the original object.

    Args:
        features: One dict per sample, mapping field name to value.

    Returns:
        A dict mapping each field name to a stacked tensor or a 1-D object array.
    """
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)
    for feature in features:
        for key, value in feature.items():
            bucket = tensors if isinstance(value, torch.Tensor) else non_tensors
            bucket[key].append(value)

    batch: Dict[str, Any] = {key: torch.stack(values, dim=0) for key, values in tensors.items()}

    for key, values in non_tensors.items():
        column = np.empty(len(values), dtype=object)
        for i, value in enumerate(values):
            # Scalar-index assignment stores the object itself, whatever its type or length.
            column[i] = value
        batch[key] = column

    return batch
50 |
51 |
class ImageProcessMixin:
    """Mixin that decodes an image and rescales it into an optional pixel-count range.

    The host class provides ``max_pixels`` and ``min_pixels``; either may be ``None`` to
    disable the corresponding bound.
    """

    max_pixels: Optional[int]
    min_pixels: Optional[int]

    @staticmethod
    def _rescale_to_pixel_count(image: ImageObject, target_pixels: int) -> ImageObject:
        """Resize ``image`` so its area is about ``target_pixels``, keeping the aspect ratio."""
        resize_factor = math.sqrt(target_pixels / (image.width * image.height))
        # Clamp to 1 so extreme aspect ratios never produce a zero-sized dimension.
        width = max(1, int(image.width * resize_factor))
        height = max(1, int(image.height * resize_factor))
        return image.resize((width, height))

    def process_image(self, image: Union[Dict[str, Any], bytes, ImageObject]) -> ImageObject:
        """Return ``image`` as an RGB PIL image within the configured pixel bounds.

        Args:
            image: A PIL image, raw encoded bytes, or a dict whose ``"bytes"`` entry holds
                the encoded bytes.

        Returns:
            The decoded image, downscaled when larger than ``max_pixels``, upscaled when
            smaller than ``min_pixels``, and converted to RGB mode when necessary.
        """
        if isinstance(image, dict):
            image = Image.open(BytesIO(image["bytes"]))
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image))

        # A bound of None means "no limit"; comparing an int with None would raise TypeError.
        if self.max_pixels is not None and (image.width * image.height) > self.max_pixels:
            image = self._rescale_to_pixel_count(image, self.max_pixels)

        if self.min_pixels is not None and (image.width * image.height) < self.min_pixels:
            image = self._rescale_to_pixel_count(image, self.min_pixels)

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image
76 |
77 |
class RLHFDataset(Dataset, ImageProcessMixin):
    """
    We assume the dataset contains a column that contains prompts and other information.

    Each item is the dataset row extended with left-padded ``input_ids``, ``attention_mask``
    and ``position_ids``, the unpadded ``raw_prompt_ids``, and ``ground_truth`` (the row's
    answer column). Rows that contain the image column additionally get
    ``multi_modal_data`` and ``multi_modal_inputs``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        prompt_key: str = "prompt",
        answer_key: str = "answer",
        image_key: str = "images",
        max_prompt_length: int = 1024,
        truncation: str = "error",
        format_prompt: Optional[str] = None,
        max_pixels: Optional[int] = None,
        min_pixels: Optional[int] = None,
    ):
        """Store the configuration and load the dataset.

        Args:
            data_path: A local parquet directory, a local parquet file, or a remote dataset
                name. An optional ``@split`` suffix selects the split of a remote dataset;
                local data is always loaded with ``split="train"``.
            tokenizer: Tokenizer used for text-only rows and for ``raw_prompt_ids``.
            processor: Processor used for rows that contain images.
            prompt_key: Column holding the prompt text.
            answer_key: Column holding the answer, returned as ``ground_truth``.
            image_key: Column holding the images of a row.
            max_prompt_length: Length passed to ``VF.postprocess_data`` as ``max_length``.
            truncation: Truncation mode passed to ``VF.postprocess_data``.
            format_prompt: Optional text appended (stripped, after a space) to every prompt.
            max_pixels: Upper pixel bound used by ``process_image``.
            min_pixels: Lower pixel bound used by ``process_image``.
        """
        self.tokenizer = tokenizer
        self.processor = processor
        self.prompt_key = prompt_key
        self.answer_key = answer_key
        self.image_key = image_key
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
        self.format_prompt = format_prompt
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels

        if "@" in data_path:
            data_path, data_split = data_path.split("@")
        else:
            data_split = "train"

        if os.path.isdir(data_path):
            self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
        elif os.path.isfile(data_path):
            self.dataset = load_dataset("parquet", data_files=data_path, split="train")
        else:  # remote dataset
            self.dataset = load_dataset(data_path, split=data_split)

    def __len__(self):
        """Return the number of rows in the loaded dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Tokenize row ``index`` and return it as a dict of model inputs and metadata."""
        row_dict: dict = self.dataset[index]
        prompt_str: str = row_dict[self.prompt_key]
        if self.format_prompt:
            prompt_str = prompt_str + " " + self.format_prompt.strip()

        if self.image_key in row_dict:
            # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
            # Every "<image>" placeholder in the prompt becomes an image entry placed between
            # the surrounding text pieces. (Splitting on an empty separator raises ValueError.)
            content_list = []
            for i, content in enumerate(prompt_str.split("<image>")):
                if i != 0:
                    content_list.append({"type": "image"})

                if content:
                    content_list.append({"type": "text", "text": content})

            messages = [{"role": "user", "content": content_list}]
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
            model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
            row_dict["multi_modal_data"] = {"image": images}
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # qwen2vl mrope
            position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids,
                image_grid_thw=model_inputs["image_grid_thw"],
                attention_mask=attention_mask,
            )  # (3, seq_length)
        else:
            messages = [{"role": "user", "content": prompt_str}]
            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
            position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)

        input_ids, attention_mask, position_ids = VF.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )
        row_dict["input_ids"] = input_ids
        row_dict["attention_mask"] = attention_mask
        row_dict["position_ids"] = position_ids
        row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
        row_dict["ground_truth"] = row_dict.pop(self.answer_key)
        return row_dict
178 |
--------------------------------------------------------------------------------
/verl/single_controller/base/worker_group.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | the class of WorkerGroup
16 | """
17 |
18 | import logging
19 | import signal
20 | import threading
21 | import time
22 | from typing import Any, Callable, Dict, List, Optional
23 |
24 | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
25 |
26 |
class ResourcePool:
    """Per-node process counts plus the meta info (such as world size) derived from them."""

    def __init__(
        self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
    ) -> None:
        # The given list is stored as-is (not copied); a fresh list is used when none is given.
        self._store = [] if process_on_nodes is None else process_on_nodes
        self.max_collocate_count = max_collocate_count
        self.n_gpus_per_node = n_gpus_per_node  # this is left for future huawei GPU that contains 16 GPUs per node

    def add_node(self, process_count):
        """Append one more node that runs ``process_count`` processes."""
        self._store.append(process_count)

    @property
    def world_size(self):
        """Total number of processes over all nodes."""
        return sum(self._store)

    def __call__(self) -> Any:
        """Return the underlying per-node process-count list."""
        return self._store

    @property
    def store(self):
        """The underlying per-node process-count list."""
        return self._store

    def local_world_size_list(self) -> List[int]:
        """Return, for every process in node order, the process count of its node."""
        sizes: List[int] = []
        for node_size in self._store:
            sizes.extend([node_size] * node_size)
        return sizes

    def local_rank_list(self) -> List[int]:
        """Return, for every process in node order, its rank within its node."""
        ranks: List[int] = []
        for node_size in self._store:
            ranks.extend(range(node_size))
        return ranks
63 |
64 |
class ClassWithInitArgs:
    """
    Bundle a class constructor with the positional and keyword arguments to call it with.
    It is used to instantiate the remote class.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        self.cls, self.args, self.kwargs = cls, args, kwargs

    def __call__(self) -> Any:
        """Construct and return a new instance from the stored class and arguments."""
        constructor = self.cls
        return constructor(*self.args, **self.kwargs)
78 |
79 |
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    """Poll the workers forever and raise SIGABRT in this process when one is not alive.

    Args:
        workers: Workers to poll on every pass.
        is_alive: Callable returning a truthy value for a live worker.
        gap_time: Seconds to sleep between two passes.
    """
    while True:
        for worker in workers:
            if is_alive(worker):
                continue
            logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
            signal.raise_signal(signal.SIGABRT)

        time.sleep(gap_time)
88 |
89 |
class WorkerGroup:
    """A group of workers"""

    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
        # Without a resource pool the group counts as initialized with detached workers.
        self._is_init_with_detached_workers = resource_pool is None

        # handle the case when WorkGroup is attached to an existing one
        self._procecss_dispatch_config = None if resource_pool is None else resource_pool()

        self._workers = []
        self._worker_names = []

        self._master_addr = None
        self._master_port = None

        self._checker_thread: threading.Thread = None

    def _is_worker_alive(self, worker):
        """Tell whether ``worker`` is alive; derived classes must implement this."""
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")

    def _block_until_all_workers_alive(self) -> None:
        """Poll once per second until no worker reports ``False`` for aliveness."""
        while False in [self._is_worker_alive(worker) for worker in self._workers]:
            time.sleep(1)

    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        """Start a background thread that runs ``check_workers_alive`` on this group's workers."""
        # before starting checking worker aliveness, make sure all workers are already alive
        self._block_until_all_workers_alive()

        checker = threading.Thread(
            target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)
        )
        self._checker_thread = checker
        checker.start()

    @property
    def world_size(self):
        """Number of workers currently held by the group."""
        return len(self._workers)

    def _bind_worker_method(self, user_defined_cls, func_generator):
        """
        Bind the worker method to the WorkerGroup

        Every callable attribute of ``user_defined_cls`` that carries the ``MAGIC_ATTR``
        marker is wrapped through ``func_generator`` and set on this group under the same name.
        """
        for method_name in dir(user_defined_cls):
            try:
                method = getattr(user_defined_cls, method_name)
                assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
            except Exception:
                # if it is a property, it will fail because Class doesn't have instance property
                continue

            if not hasattr(method, MAGIC_ATTR):
                continue

            # this method is decorated by register
            attribute = getattr(method, MAGIC_ATTR)
            assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
            assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"

            dispatch_mode = attribute["dispatch_mode"]
            execute_mode = attribute["execute_mode"]
            blocking = attribute["blocking"]

            # get dispatch fn: predefined pair for a Dispatch member, otherwise the user-given dict
            if isinstance(dispatch_mode, Dispatch):
                fn_pair = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
            else:
                assert isinstance(dispatch_mode, dict)
                assert "dispatch_fn" in dispatch_mode
                assert "collect_fn" in dispatch_mode
                fn_pair = dispatch_mode
            dispatch_fn = fn_pair["dispatch_fn"]
            collect_fn = fn_pair["collect_fn"]

            # get execute_fn_name
            wg_execute_fn_name = get_predefined_execute_fn(execute_mode=execute_mode)["execute_fn_name"]

            # get execute_fn from string
            try:
                execute_fn = getattr(self, wg_execute_fn_name)
                assert callable(execute_fn), "execute_fn must be callable"
            except Exception:
                print(f"execute_fn {wg_execute_fn_name} is invalid")
                raise

            # bind a new method to the RayWorkerGroup
            bound_method = func_generator(
                self,
                method_name,
                dispatch_fn=dispatch_fn,
                collect_fn=collect_fn,
                execute_fn=execute_fn,
                blocking=blocking,
            )

            try:
                setattr(self, method_name, bound_method)
            except Exception:
                raise ValueError(f"Fail to set method_name {method_name}")
195 |
--------------------------------------------------------------------------------
/verl/single_controller/base/decorator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from enum import Enum, auto
16 | from functools import wraps
17 | from types import FunctionType
18 | from typing import TYPE_CHECKING, Dict, List, Literal, Union
19 |
20 | import ray
21 |
22 | from ...protocol import DataProto, DataProtoFuture
23 |
24 |
25 | if TYPE_CHECKING:
26 | from .worker_group import WorkerGroup
27 |
28 |
# A magic number is appended to the attribute name to avoid clashing with an attribute that a
# user-defined function may already have.
MAGIC_ATTR = "attrs_3141562937"
31 |
32 |
class Dispatch(Enum):
    """Named modes describing how a call's arguments are dispatched and its outputs collected.

    ``get_predefined_dispatch_fn`` maps every member except RANK_ZERO to a
    dispatch-function / collect-function pair.
    """

    RANK_ZERO = auto()  # no predefined dispatch/collect pair in this module
    ONE_TO_ALL = auto()  # each argument is replicated world_size times
    ALL_TO_ALL = auto()  # arguments and outputs are passed through unchanged
    DP_COMPUTE = auto()  # arguments must already be sequences of length world_size
    DP_COMPUTE_PROTO = auto()  # arguments are chunked into world_size pieces, outputs concatenated
    DP_COMPUTE_PROTO_WITH_FUNC = auto()  # like DP_COMPUTE_PROTO, but the first argument is a function
    DP_COMPUTE_METRIC = auto()  # arguments are chunked, outputs returned as a per-worker list
41 |
42 |
class Execute(Enum):
    """Named modes selecting which worker-group execute method runs a registered call."""

    ALL = 0  # resolved to the "execute_all" method name by get_predefined_execute_fn
    RANK_ZERO = 1  # resolved to the "execute_rank_zero" method name by get_predefined_execute_fn
46 |
47 |
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
    """Chunk every positional and keyword argument into ``chunks`` pieces.

    Each argument must be a ``DataProto`` or ``DataProtoFuture``.

    Returns:
        A ``(list, dict)`` pair holding the chunked positional and keyword arguments.
    """

    def _chunked(item):
        assert isinstance(item, (DataProto, DataProtoFuture))
        return item.chunk(chunks=chunks)

    splitted_args = [_chunked(arg) for arg in args]
    splitted_kwargs = {name: _chunked(item) for name, item in kwargs.items()}
    return splitted_args, splitted_kwargs
60 |
61 |
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Replicate every argument once per worker.

    Returns:
        ``(args, kwargs)`` where each value is a list of ``world_size`` references to the
        original argument.
    """
    replicated_args = tuple([arg] * worker_group.world_size for arg in args)
    replicated_kwargs = {}
    for name, item in kwargs.items():
        replicated_kwargs[name] = [item] * worker_group.world_size
    return replicated_args, replicated_kwargs
66 |
67 |
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Pass the arguments through unchanged; ``worker_group`` is not used."""
    return args, kwargs
70 |
71 |
def collect_all_to_all(worker_group: "WorkerGroup", output):
    """Return ``output`` unchanged; ``worker_group`` is not used."""
    return output
74 |
75 |
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
    """Concatenate per-worker outputs that all share one type.

    ``DataProto`` outputs are joined with ``DataProto.concat`` and ``ray.ObjectRef`` outputs
    with ``DataProtoFuture.concat``; any other type raises ``NotImplementedError``.
    """
    first = outputs[0]

    # make sure all the elements in output has the same type
    for item in outputs:
        assert type(item) is type(first)

    if isinstance(first, DataProto):
        return DataProto.concat(outputs)
    if isinstance(first, ray.ObjectRef):
        return DataProtoFuture.concat(outputs)
    raise NotImplementedError
89 |
90 |
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
    """Check that every argument is already split per worker, then pass everything through.

    Each positional and keyword value must be a tuple or list with exactly
    ``worker_group.world_size`` entries.
    """
    for per_worker in (*args, *kwargs.values()):
        assert isinstance(per_worker, (tuple, list)) and len(per_worker) == worker_group.world_size

    return args, kwargs
99 |
100 |
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
    """Return the per-worker outputs unchanged after checking there is one per worker."""
    expected_count = worker_group.world_size
    assert len(outputs) == expected_count
    return outputs
104 |
105 |
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
    """Chunk every argument into ``worker_group.world_size`` pieces, one per worker."""
    return _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
109 |
110 |
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
    """Replicate the leading function per worker and chunk the remaining arguments.

    Returns:
        ``(args, kwargs)`` where the first positional entry is the function repeated
        ``world_size`` times and every other value is chunked into ``world_size`` pieces.
    """
    func, data_args = args[0], args[1:]
    assert type(func) is FunctionType  # NOTE: The first one args is a function!
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *data_args, **kwargs)
    replicated_func = [func] * worker_group.world_size
    return [replicated_func] + splitted_args, splitted_kwargs
116 |
117 |
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
    """Validate the per-worker outputs and concatenate them into a single result.

    Every output must be a ``DataProto`` or a ``ray.ObjectRef``, and there must be exactly
    one output per worker.
    """
    for item in outputs:
        assert isinstance(item, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(item)}"

    checked_outputs = collect_dp_compute(worker_group, outputs)
    return _concat_data_proto_or_future(checked_outputs)
124 |
125 |
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
    """Return the predefined ``{"dispatch_fn": ..., "collect_fn": ...}`` pair for a mode.

    Raises:
        KeyError: If ``dispatch_mode`` has no predefined pair.
    """
    # mode -> (dispatch function, collect function)
    fn_pairs = {
        Dispatch.ONE_TO_ALL: (dispatch_one_to_all, collect_all_to_all),
        Dispatch.ALL_TO_ALL: (dispatch_all_to_all, collect_all_to_all),
        Dispatch.DP_COMPUTE: (dispatch_dp_compute, collect_dp_compute),
        Dispatch.DP_COMPUTE_PROTO: (dispatch_dp_compute_data_proto, collect_dp_compute_data_proto),
        Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: (
            dispatch_dp_compute_data_proto_with_func,
            collect_dp_compute_data_proto,
        ),
        Dispatch.DP_COMPUTE_METRIC: (dispatch_dp_compute_data_proto, collect_dp_compute),
    }
    dispatch_fn, collect_fn = fn_pairs[dispatch_mode]
    return {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn}
154 |
155 |
def get_predefined_execute_fn(execute_mode: Execute):
    """
    Return ``{"execute_fn_name": <method name>}`` for the given execute mode.

    Note that here we only asks execute_all and execute_rank_zero to be implemented
    Leave the choice of how these two functions handle argument 'blocking' to users
    """
    method_names = {
        Execute.ALL: "execute_all",
        Execute.RANK_ZERO: "execute_rank_zero",
    }
    return {"execute_fn_name": method_names[execute_mode]}
166 |
167 |
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
    """Assert that ``dispatch_mode`` is a ``Dispatch`` member or a dict with both required keys."""
    assert isinstance(dispatch_mode, (Dispatch, dict)), (
        f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
    )
    if not isinstance(dispatch_mode, dict):
        return

    for key in ("dispatch_fn", "collect_fn"):
        assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
176 |
177 |
def _check_execute_mode(execute_mode: Execute):
    """Assert that ``execute_mode`` is a member of the ``Execute`` enum."""
    assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
180 |
181 |
def _materialize_futures(*args, **kwargs):
    """Replace every ``DataProtoFuture`` argument with the value returned by its ``get()``.

    Returns:
        ``(args, kwargs)`` with the positional arguments as a tuple; other values are
        passed through untouched.
    """

    def _resolve(item):
        # add more type to materialize
        return item.get() if isinstance(item, DataProtoFuture) else item

    new_args = tuple(_resolve(arg) for arg in args)

    for name in list(kwargs):
        kwargs[name] = _resolve(kwargs[name])

    return new_args, kwargs
196 |
197 |
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    """Decorator factory that tags a function with its dispatch/execute configuration.

    The returned wrapper carries a ``MAGIC_ATTR`` attribute holding
    ``{"dispatch_mode", "execute_mode", "blocking"}``. When ``materialize_futures`` is
    true, ``DataProtoFuture`` arguments are resolved before the wrapped function runs.
    """
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):
        @wraps(func)
        def inner(*call_args, **call_kwargs):
            if not materialize_futures:
                return func(*call_args, **call_kwargs)
            call_args, call_kwargs = _materialize_futures(*call_args, **call_kwargs)
            return func(*call_args, **call_kwargs)

        setattr(
            inner,
            MAGIC_ATTR,
            {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking},
        )
        return inner

    return decorator
214 |
--------------------------------------------------------------------------------
/verl/models/transformers/flash_attention_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import inspect
18 | import os
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
24 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25 |
26 | from ...utils.ulysses import (
27 | gather_heads_scatter_seq,
28 | gather_seq_scatter_heads,
29 | get_ulysses_sequence_parallel_group,
30 | get_ulysses_sequence_parallel_world_size,
31 | )
32 |
33 |
# Feature flags probed once at import time, only when flash-attn 2 is available.
# NOTE(review): the nesting of all assignments under this `if` is reconstructed from a flattened
# source listing - confirm against the original file.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    # Whether the installed flash_attn_func accepts these keyword arguments.
    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
    # Default for the `deterministic` flag, enabled by setting FLASH_ATTENTION_DETERMINISTIC=1.
    _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
    # Versions of flash-attn below 2.10 get use_top_left_mask=True.
    _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
41 |
42 |
def prepare_fa2_from_position_ids(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
    """Flatten q/k/v for variable-length attention and derive sequence boundaries.

    All leading dimensions of ``query``, ``key`` and ``value`` are merged into one. A new
    sequence is assumed to start wherever the flattened ``position_ids`` equals 0.

    Returns:
        ``(query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))``
        where ``indices_q`` is an int32 range over all tokens, ``cu_seqlens`` holds the int32
        start offset of every sequence followed by the total token count, and ``max_length``
        is the largest gap between consecutive offsets (a 0-dim tensor).
    """
    flat_query = query.view(-1, query.size(-2), query.size(-1))
    flat_key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    flat_value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    flat_positions = position_ids.flatten()
    device = flat_positions.device
    indices_q = torch.arange(flat_positions.size(0), device=device, dtype=torch.int32)

    sequence_starts = indices_q[flat_positions == 0]
    total_tokens = torch.tensor(flat_positions.size(), device=device, dtype=torch.int32)
    cu_seqlens = torch.cat((sequence_starts, total_tokens))

    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
    return (flat_query, flat_key, flat_value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
59 |
60 |
def _custom_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool = True,
    position_ids: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    **kwargs,
):
    """
    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)

    Two paths exist. When ``position_ids`` is given, ``query_length`` is not 1 and the position
    ids are not non-decreasing along the last dimension, the batch is treated as several packed
    sequences and ``flash_attn_varlen_func`` is called with boundaries derived from the position
    ids. Otherwise the call is delegated to transformers' ``_flash_attention_forward`` without
    position ids.

    When Ulysses sequence parallelism is active (world size > 1), q/k/v are exchanged with
    ``gather_seq_scatter_heads`` before attention, the position ids of all ranks are gathered
    and concatenated along the last dimension, and the output is converted back with
    ``gather_heads_scatter_seq``.

    Args:
        query_states, key_states, value_states: Attention inputs; indexed as
            (batch_size, seq_length, num_head, head_size).
        attention_mask: Forwarded to ``_flash_attention_forward`` on the non-packed path only.
        query_length: Query sequence length.
        is_causal: Whether to apply causal masking.
        position_ids: Optional position ids; for a 3-D tensor only the first slice is used.
        sliding_window: Optional window size, applied when supported and exceeded by the key length.
        use_top_left_mask: When true, causal masking is disabled for ``query_length == 1``.
        deterministic: Overrides the module-level default when not ``None``.
        **kwargs: ``softcap``, ``dropout`` and ``softmax_scale`` are consumed on the packed path;
            the remainder is forwarded on the non-packed path.

    Returns:
        The attention output tensor.
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if _flash_supports_deterministic:
        flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled

    if kwargs.get("softcap") is not None:
        flash_kwargs["softcap"] = kwargs.pop("softcap")

    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype=torch.bfloat16
    )

    sp_size = get_ulysses_sequence_parallel_world_size()
    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
        # all_gather fills position_ids_lst in place; its return value is not the gathered data,
        # so it must not be bound to position_ids.
        dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
        position_ids = torch.cat(position_ids_lst, dim=-1)  # (..., batch_size, seq_length)

    if position_ids is not None and position_ids.dim() == 3:  # qwen2vl mrope
        position_ids = position_ids[0]

    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        # Position ids restart somewhere: run variable-length attention over the packed sequences.
        batch_size = query_states.size(0)
        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=kwargs.pop("dropout", 0.0),
            softmax_scale=kwargs.pop("softmax_scale", None),
            causal=causal,
            **flash_kwargs,
        )
        # Restore the batch dimension that was flattened for the varlen kernel.
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=is_causal,
            sliding_window=sliding_window,
            use_top_left_mask=use_top_left_mask,
            deterministic=deterministic,
            **kwargs,
        )  # do not pass position_ids to old flash_attention_forward

    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

    return attn_output
151 |
152 |
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    """Attention entry point that routes to ``_custom_flash_attention_forward``.

    ``query``, ``key`` and ``value`` have dimensions 1 and 2 swapped before the call, and
    attention is always run as causal. ``module`` is accepted but not used.

    Returns:
        ``(attn_output, None)``.
    """
    # This is before the transpose
    q_len = query.shape[2]

    # FA2 uses non-transposed inputs
    query, key, value = (states.transpose(1, 2) for states in (query, key, value))

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    forward_options = dict(
        query_length=q_len,
        is_causal=True,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_flash_use_top_left_mask,
    )
    attn_output = _custom_flash_attention_forward(query, key, value, attention_mask, **forward_options, **kwargs)

    return attn_output, None
192 |
--------------------------------------------------------------------------------
/verl/workers/rollout/vllm_rollout_spmd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The vllm_rollout that can be applied in different backend
16 | When working with FSDP:
17 | - Use DTensor weight loader (recommended) or HF weight loader
18 | - Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
19 | """
20 |
21 | import os
22 | from contextlib import contextmanager
23 | from typing import Any, List, Union
24 |
25 | import numpy as np
26 | import torch
27 | import torch.distributed
28 | from tensordict import TensorDict
29 | from transformers import PreTrainedTokenizer
30 | from vllm import LLM, RequestOutput, SamplingParams
31 |
32 | from ...protocol import DataProto
33 | from ...utils import torch_functional as VF
34 | from ...utils.torch_dtypes import PrecisionType
35 | from .base import BaseRollout
36 | from .config import RolloutConfig
37 |
38 |
39 | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
40 | if isinstance(value, torch.Tensor):
41 | return value.repeat_interleave(repeats, dim=0)
42 | else:
43 | return np.repeat(value, repeats, axis=0)
44 |
45 |
class vLLMRollout(BaseRollout):
    """Rollout worker that produces responses with a vLLM ``LLM`` engine."""

    def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
        """Build the vLLM engine, put it to sleep, and prepare default sampling params.

        Args:
            model_path: passed to ``LLM(model=...)``.
            config: rollout configuration; its fields that match ``SamplingParams``
                attributes become the default sampling parameters.
            tokenizer: the task/model tokenizer; only ``pad_token_id`` is read here.

        Raises:
            ValueError: if ``tensor_parallel_size`` exceeds the distributed world size,
                or ``max_num_batched_tokens`` is smaller than
                ``prompt_length + response_length``.
        """
        super().__init__()
        self.rank = int(os.getenv("RANK", "0"))
        self.config = config
        self.pad_token_id = tokenizer.pad_token_id
        if config.tensor_parallel_size > torch.distributed.get_world_size():
            raise ValueError("Tensor parallelism size should be less than world size.")

        if config.max_num_batched_tokens < config.prompt_length + config.response_length:
            raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")

        # Only cap images per prompt when the config asks for it.
        vllm_init_kwargs = {}
        if config.limit_images > 0:
            vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}

        self.inference_engine = LLM(
            model=model_path,
            skip_tokenizer_init=False,
            tensor_parallel_size=config.tensor_parallel_size,
            dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
            gpu_memory_utilization=config.gpu_memory_utilization,
            enforce_eager=config.enforce_eager,
            # max_model_len=config.prompt_length + config.response_length,
            # max_num_batched_tokens=config.max_num_batched_tokens,
            enable_sleep_mode=True,
            distributed_executor_backend="external_launcher",
            disable_custom_all_reduce=True,
            disable_mm_preprocessor_cache=True,
            disable_log_stats=config.disable_log_stats,
            enable_chunked_prefill=config.enable_chunked_prefill,
            **vllm_init_kwargs,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        # Start from the fixed defaults, then copy over every config field that
        # is also an attribute of ``SamplingParams``.
        sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
        default_sampling_params = SamplingParams()
        for key in config.to_dict().keys():
            if hasattr(default_sampling_params, key):
                sampling_kwargs[key] = getattr(config, key)

        print(f"Sampling params: {sampling_kwargs}.")
        self.sampling_params = SamplingParams(**sampling_kwargs)

        # Whatever ``n`` the config carried, the default is forced back to a
        # single sample per prompt.
        print("Reset sampling_params.n=1")
        self.sampling_params.n = 1

    @contextmanager
    def update_sampling_params(self, **kwargs):
        """Temporarily override attributes of ``self.sampling_params``.

        Only keys that already exist as attributes on the sampling params are
        applied; all other keys are ignored. The previous values are restored
        when the ``with`` block exits, including when it exits via an exception,
        so a failed generation cannot leak per-call overrides into later calls.
        """
        old_sampling_params_args = {}
        for key, value in kwargs.items():
            if hasattr(self.sampling_params, key):
                old_sampling_params_args[key] = getattr(self.sampling_params, key)
                setattr(self.sampling_params, key, value)

        try:
            yield
        finally:
            # roll back to previous sampling params
            for key, value in old_sampling_params_args.items():
                setattr(self.sampling_params, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate responses for ``prompts`` and assemble prompt+response tensors.

        Reads ``input_ids``, ``attention_mask`` and ``position_ids`` from
        ``prompts.batch``, ``eos_token_id`` from ``prompts.meta_info``, and
        ``raw_prompt_ids`` (plus optional ``multi_modal_data``) from
        ``prompts.non_tensor_batch``; the latter two are popped from that dict.
        Entries of ``prompts.meta_info`` that match sampling-param attributes
        override the defaults for this call only.

        Returns:
            A ``DataProto`` whose batch holds ``prompts``, ``responses``,
            ``input_ids`` (prompt and response concatenated), ``attention_mask``,
            ``response_mask`` and ``position_ids``, together with the remaining
            non-tensor batch.

        Raises:
            RuntimeError: if the number of ``raw_prompt_ids`` differs from the
                tensor batch size.
        """
        # left-padded attention_mask
        input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
        attention_mask: torch.Tensor = prompts.batch["attention_mask"]
        position_ids: torch.Tensor = prompts.batch["position_ids"]
        eos_token_id: int = prompts.meta_info["eos_token_id"]
        batch_size = input_ids.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
            raise RuntimeError("vllm sharding manager is not work properly.")

        if "multi_modal_data" in non_tensor_batch:
            vllm_inputs = []
            for raw_prompt_ids, multi_modal_data in zip(
                non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
            ):
                vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
        else:
            vllm_inputs = [
                {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
            ]

        # users can customize different sampling_params at different run
        with self.update_sampling_params(**prompts.meta_info):
            completions: List[RequestOutput] = self.inference_engine.generate(
                prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
            )
            response_ids = [output.token_ids for completion in completions for output in completion.outputs]
            response_ids = VF.pad_2d_list_to_length(
                response_ids, self.pad_token_id, max_length=self.config.response_length
            ).to(input_ids.device)

            # This must stay inside the ``with``: ``n`` is rolled back on exit,
            # and the prompt-side tensors have to be expanded with the ``n``
            # that was actually used for generation.
            if self.sampling_params.n > 1:
                batch_size = batch_size * self.sampling_params.n
                input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
                attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
                position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
                if "multi_modal_inputs" in non_tensor_batch.keys():
                    non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
                        non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
                    )

        sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
        response_length = response_ids.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
        # Response positions continue from the last prompt position.
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_mask = VF.get_response_mask(
            response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat((attention_mask, response_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                "prompts": input_ids,
                "responses": response_ids,
                "input_ids": sequence_ids,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "response_mask": response_mask,
                "position_ids": position_ids,
            },
            batch_size=batch_size,
        )
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
192 |
--------------------------------------------------------------------------------
/verl/workers/critic/dp_critic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Implement Critic
16 | """
17 |
18 | import os
19 | from collections import defaultdict
20 | from typing import Any, Dict
21 |
22 | import torch
23 | from ray.experimental.tqdm_ray import tqdm
24 | from torch import nn
25 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
26 |
27 | from ...protocol import DataProto
28 | from ...trainer import core_algos
29 | from ...utils import torch_functional as VF
30 | from ...utils.py_functional import append_to_dict
31 | from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
32 | from .base import BasePPOCritic
33 | from .config import CriticConfig
34 |
35 |
36 | try:
37 | from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
38 | except ImportError:
39 | pass
40 |
41 |
42 | __all__ = ["DataParallelPPOCritic"]
43 |
44 |
class DataParallelPPOCritic(BasePPOCritic):
    """PPO critic that runs a value model and its optimizer on the local rank."""

    def __init__(self, config: CriticConfig, critic_module: nn.Module, critic_optimizer: torch.optim.Optimizer):
        """Store the critic module/optimizer and read this process's ``RANK`` (default 0)."""
        super().__init__(config)
        self.rank = int(os.getenv("RANK", "0"))
        self.critic_module = critic_module
        self.critic_optimizer = critic_optimizer

    def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the critic on one micro-batch and return the values for the response span.

        Reads ``input_ids``, ``attention_mask``, ``position_ids`` and ``responses``
        (plus optional ``multi_modal_inputs``) from ``micro_batch``. The returned
        values are sliced to ``[:, -response_length - 1 : -1]``, i.e. shifted one
        step left of the response tokens.
        """
        input_ids = micro_batch["input_ids"]
        batch_size, seqlen = input_ids.shape
        attention_mask = micro_batch["attention_mask"]
        position_ids = micro_batch["position_ids"]
        responses = micro_batch["responses"]
        response_length = responses.size(-1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

        # Concatenate each multi-modal field across the samples of the micro-batch.
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch:
            for key in micro_batch["multi_modal_inputs"][0].keys():
                multi_modal_inputs[key] = torch.cat(
                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
                )

        if self.config.padding_free:
            input_ids_rmpad, indices, *_ = unpad_input(
                input_ids.unsqueeze(-1), attention_mask
            )  # input_ids_rmpad (total_nnz, ...)
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

            # unpad the position_ids to align the rotary
            if position_ids.dim() == 3:
                position_ids_rmpad = (
                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                    .transpose(0, 1)
                    .unsqueeze(1)
                )  # (3, bsz, seqlen) -> presumably (3, 1, total_nnz) after selecting `indices`
            else:
                position_ids_rmpad = index_first_axis(
                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                ).transpose(0, 1)

            # pad and slice the inputs if sp > 1
            if self.config.ulysses_sequence_parallel_size > 1:
                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
                )

            # only pass input_ids and position_ids to enable flash_attn_varlen
            output = self.critic_module(
                input_ids=input_ids_rmpad,
                attention_mask=None,
                position_ids=position_ids_rmpad,
                **multi_modal_inputs,
                use_cache=False,
            )  # prevent model thinks we are generating
            values_rmpad = output.logits
            values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)

            # gather output if sp > 1
            if self.config.ulysses_sequence_parallel_size > 1:
                values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)

            # pad it back
            values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
            values = values[:, -response_length - 1 : -1]
        else:
            output = self.critic_module(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **multi_modal_inputs,
                use_cache=False,
            )
            values: torch.Tensor = output.logits
            # (bsz, response_length) once the trailing dim is squeezed
            # NOTE(review): assumes the value head emits a trailing dim of size 1 -- confirm.
            values = values[:, -response_length - 1 : -1].squeeze(-1)

        return values

    def _optimizer_step(self) -> torch.Tensor:
        """Clip gradients, step the optimizer unless the norm is non-finite, then zero the grads.

        Returns:
            The gradient norm reported by the clipping call.
        """
        if isinstance(self.critic_module, FSDP):
            grad_norm = self.critic_module.clip_grad_norm_(self.config.max_grad_norm)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.critic_module.parameters(), max_norm=self.config.max_grad_norm
            )

        if not torch.isfinite(grad_norm):
            print("Gradient norm is not finite. Skip update.")
        else:
            self.critic_optimizer.step()

        # Gradients are cleared whether or not the step was taken.
        self.critic_optimizer.zero_grad()
        return grad_norm

    @torch.no_grad()
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Compute critic values for ``data`` in eval mode, micro-batch by micro-batch.

        Returns:
            The concatenated values multiplied by the attention-mask slice
            ``[:, -response_length - 1 : -1]``.
        """
        self.critic_module.eval()

        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        if "multi_modal_inputs" in data.non_tensor_batch.keys():
            non_tensor_select_keys = ["multi_modal_inputs"]
        else:
            non_tensor_select_keys = []

        micro_batches = data.select(select_keys, non_tensor_select_keys).split(
            self.config.micro_batch_size_per_device_for_experience
        )
        values_lst = []
        if self.rank == 0:
            # Only rank 0 shows a progress bar.
            micro_batches = tqdm(micro_batches, desc="Compute values", position=2)

        for micro_batch in micro_batches:
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            values = self._forward_micro_batch(model_inputs)
            values_lst.append(values)

        values = torch.concat(values_lst, dim=0)
        responses = data.batch["responses"]
        attention_mask = data.batch["attention_mask"]
        response_length = responses.size(1)
        values = values * attention_mask[:, -response_length - 1 : -1]
        return values

    def update_critic(self, data: DataProto) -> Dict[str, Any]:
        """Train the critic on ``data`` for ``config.ppo_epochs`` epochs.

        Each mini-batch is split into micro-batches; the loss of every
        micro-batch is divided by the gradient-accumulation factor before
        ``backward`` and a single optimizer step is taken per mini-batch.

        Returns:
            A dict mapping metric names (``critic/vf_loss``, ``critic/vf_clipfrac``,
            ``critic/vpred_mean``, ``critic/grad_norm``) to lists of recorded values.
        """
        self.critic_module.train()

        select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
        if "multi_modal_inputs" in data.non_tensor_batch.keys():
            non_tensor_select_keys = ["multi_modal_inputs"]
        else:
            non_tensor_select_keys = []

        # Split to make minibatch iterator for updating the critic
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)

        metrics = defaultdict(list)
        for _ in range(self.config.ppo_epochs):
            # Wrap a per-epoch view instead of rebinding ``mini_batches`` so a
            # later epoch does not wrap an already-wrapped progress bar.
            epoch_mini_batches = mini_batches
            if self.rank == 0:
                epoch_mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)

            for mini_batch in epoch_mini_batches:
                gradient_accumulation = (
                    self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
                )
                micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
                if self.rank == 0:
                    micro_batches = tqdm(micro_batches, desc="Update critic", position=3)

                for micro_batch in micro_batches:
                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                    responses = model_inputs["responses"]
                    attention_mask = model_inputs["attention_mask"]
                    values = model_inputs["values"]
                    returns = model_inputs["returns"]
                    response_length = responses.size(1)
                    action_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation

                    vpreds = self._forward_micro_batch(model_inputs)
                    vf_loss, vf_clipfrac = core_algos.compute_value_loss(
                        vpreds=vpreds,
                        returns=returns,
                        values=values,
                        action_mask=action_mask,
                        cliprange_value=self.config.cliprange_value,
                    )
                    # Scale so the accumulated gradient matches one mini-batch step.
                    loss = vf_loss / gradient_accumulation
                    loss.backward()

                    batch_metrics = {
                        "critic/vf_loss": vf_loss.detach().item(),
                        "critic/vf_clipfrac": vf_clipfrac.detach().item(),
                        "critic/vpred_mean": VF.masked_mean(vpreds, action_mask).detach().item(),
                    }
                    append_to_dict(metrics, batch_metrics)

                grad_norm = self._optimizer_step()
                append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})

        return metrics
226 |
--------------------------------------------------------------------------------
/verl/utils/seqlen_balancing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import heapq
17 | from typing import List, Tuple
18 |
19 | import torch
20 | from tensordict import TensorDict
21 | from torch import distributed as dist
22 |
23 |
class Set:
    """A bag of ``(idx, val)`` pairs that tracks the running total of its values."""

    def __init__(self) -> None:
        self.sum = 0
        self.items = []

    def add(self, idx: int, val: int):
        """Append one ``(idx, val)`` pair and add ``val`` to the total."""
        self.sum += val
        self.items.append((idx, val))

    def merge(self, other):
        """Add every pair of ``other`` to this set, in ``other``'s order."""
        for idx, val in other.items:
            self.add(idx, val)

    def __lt__(self, other):
        # Order by total first, then by number of items, then by the item lists.
        mine = (self.sum, len(self.items), self.items)
        theirs = (other.sum, len(other.items), other.items)
        return mine < theirs
44 |
45 |
class State:
    """``k`` ``Set`` buckets kept sorted from largest to smallest."""

    def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
        self.k = k
        buckets = [Set() for _ in range(k)]
        assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
        # The i-th item seeds the i-th bucket; buckets without an item stay empty.
        for pos, (idx, seqlen) in enumerate(items):
            buckets[pos].add(idx=idx, val=seqlen)
        # sets should always be decreasing order
        self.sets = sorted(buckets, reverse=True)

    def get_partitions(self):
        """Return, for each bucket in order, the list of item indices it holds."""
        return [[idx for idx, _ in bucket.items] for bucket in self.sets]

    def merge(self, other):
        """Fold ``other`` into this state, pairing buckets in opposite order.

        Bucket ``i`` of this state absorbs bucket ``k - 1 - i`` of ``other``;
        the buckets are then re-sorted into decreasing order.
        """
        last = self.k - 1
        for pos in range(self.k):
            self.sets[pos].merge(other.sets[last - pos])
        self.sets = sorted(self.sets, reverse=True)

    @property
    def spread(self) -> int:
        """Difference between the totals of the first and the last bucket."""
        first, last = self.sets[0], self.sets[-1]
        return first.sum - last.sum

    def __lt__(self, other):
        # least heap, let the state with largest spread to be popped first,
        # if the spread is the same, let the state who has the largest set
        # to be popped first.
        mine, theirs = self.spread, other.spread
        if mine == theirs:
            return self.sets[0] > other.sets[0]
        return mine > theirs

    def __repr__(self) -> str:
        groups = []
        for pos in range(self.k):
            lengths = ",".join(str(seqlen) for _, seqlen in self.sets[pos].items)
            groups.append("{" + lengths + "}")
        return "[" + ",".join(groups) + "]"
95 |
96 |
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Partition item indices into ``k_partitions`` groups by repeatedly merging states.

    With ``equal_size`` the items are taken ``k_partitions`` at a time (in
    ascending length order) so that each initial state fills every bucket;
    otherwise every item starts in a state of its own. States are then popped
    from a heap two at a time, merged, and pushed back until one remains.

    Returns:
        The list of index lists produced by the final state.
    """
    # see: https://en.wikipedia.org/wiki/Largest_differencing_method
    ascending = sorted([(seqlen, idx) for idx, seqlen in enumerate(seqlen_list)])
    heap: List[State] = []

    if not equal_size:
        for seqlen, idx in ascending:
            heapq.heappush(heap, State(items=[(idx, seqlen)], k=k_partitions))
    else:
        assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
        for start in range(0, len(ascending), k_partitions):
            group = [ascending[start + step] for step in range(k_partitions)]
            seed = [(idx, seqlen) for seqlen, idx in group]
            heapq.heappush(heap, State(items=seed, k=k_partitions))

    # Keep combining the two states at the top of the heap until one is left.
    while len(heap) > 1:
        first = heapq.heappop(heap)
        second = heapq.heappop(heap)
        first.merge(second)
        heapq.heappush(heap, first)

    partitions = heap[0].get_partitions()
    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), (
                f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
            )
    return partitions
128 |
129 |
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Assign each item, in input order, to the partition with the smallest running sum.

    Ties go to the lowest partition index. With ``equal_size`` every item's
    weight is increased by ``sum(seqlen_list) + 1``, which exceeds the total of
    all real lengths, so the item count dominates the comparison; the result is
    then asserted to have the same number of items in every partition.

    Returns:
        ``k_partitions`` lists of item indices.
    """
    offset = sum(seqlen_list) + 1 if equal_size else 0
    partitions = [[] for _ in range(k_partitions)]
    loads = [0] * k_partitions

    for item_idx, seqlen in enumerate(seqlen_list):
        # First partition holding the minimum load so far.
        target = None
        for candidate, load in enumerate(loads):
            if target is None or load < loads[target]:
                target = candidate
        partitions[target].append(item_idx)
        loads[target] += seqlen + offset

    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), (
                f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
            )
    return partitions
148 |
149 |
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Compute index partitions whose sequence-length sums are balanced.

    Used for balancing the sum of sequence lengths across dp ranks and
    micro-batches.

    Parameters:
        seqlen_list (List[int]):
            sequence length of each item
        k_partitions (int):
            number of partitions to produce
        equal_size (bool):
            if True, every partition must contain the same number of items.
            if False, only the sums are balanced and partitions may differ
            in item count
    Returns:
        partitions (List[List[int]]):
            ``k_partitions`` lists of item indices, each sorted ascending.
    """
    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    def _check_and_sort_partitions(partitions):
        # Validate the partition count, non-emptiness and full index coverage,
        # and sort the indices inside each partition.
        assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
        covered = set()
        ordered = []
        for pos, partition in enumerate(partitions):
            assert len(partition) > 0, f"the {pos}-th partition is empty"
            covered.update(partition)
            ordered.append(sorted(partition))
        assert covered == set(range(len(seqlen_list)))
        return ordered

    raw_partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(raw_partitions)
182 |
183 |
184 | def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
185 | # add some metrics of seqlen sum on dp ranks
186 | k_partition = len(partitions)
187 | # assert len(seqlen_list) % k_partition == 0
188 | batch_size = len(seqlen_list) // k_partition
189 | min_sum_seqlen = None
190 | max_sum_seqlen = None
191 | total_sum_seqlen = 0
192 | for offset in range(0, len(seqlen_list), batch_size):
193 | cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
194 | if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
195 | min_sum_seqlen = cur_sum_seqlen
196 | if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
197 | max_sum_seqlen = cur_sum_seqlen
198 | total_sum_seqlen += cur_sum_seqlen
199 |
200 | balanced_sum_seqlen_list = []
201 | for partition in partitions:
202 | cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
203 | balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)
204 | # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
205 | min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
206 | max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
207 |
208 | return {
209 | f"{prefix}/min": min_sum_seqlen,
210 | f"{prefix}/max": max_sum_seqlen,
211 | f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen,
212 | f"{prefix}/balanced_min": min_sum_seqlen_balanced,
213 | f"{prefix}/balanced_max": max_sum_seqlen_balanced,
214 | f"{prefix}/mean": total_sum_seqlen / len(partitions),
215 | }
216 |
217 |
def ceildiv(a, b):
    """Return ``a / b`` rounded up, computed with floor division only."""
    # Flooring a / (-b) and negating the result rounds a / b towards +infinity.
    negated_floor = a // -b
    return -negated_floor
220 |
221 |
def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
    """Split ``batch`` into micro-batches with balanced numbers of valid tokens.

    The number of micro-batches is ``ceil(total valid tokens / max_token_len)``;
    when ``torch.distributed`` is initialised, the maximum of that count over
    ``dp_group`` is used instead. Valid tokens are counted from
    ``batch["attention_mask"]``.

    Returns:
        ``(micro_batches, micro_bsz_idx)`` -- the list of micro-batches and the
        list of sample indices that make up each of them.
    """
    # this is per local micro_bsz
    attention_mask = batch["attention_mask"]
    max_seq_len = attention_mask.shape[-1]
    assert max_token_len >= max_seq_len, (
        f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
    )

    # Number of valid tokens per sample.
    seq_len_effective: torch.Tensor = attention_mask.sum(dim=1)
    num_micro_batches = ceildiv(seq_len_effective.sum().item(), max_token_len)
    if dist.is_initialized():
        # Take the largest count across the group.
        count = torch.tensor([num_micro_batches], device="cuda")
        dist.all_reduce(count, op=dist.ReduceOp.MAX, group=dp_group)
        num_micro_batches = count.cpu().item()

    seq_len_effective = seq_len_effective.tolist()
    assert num_micro_batches <= len(seq_len_effective)

    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)

    # Gather the single-sample slices of each partition into one micro-batch.
    micro_batches = [torch.cat([batch[idx : idx + 1] for idx in partition]) for partition in micro_bsz_idx]

    return micro_batches, micro_bsz_idx
256 |
257 |
def get_reverse_idx(idx_map):
    """Return the inverse of ``idx_map``: ``result[idx_map[i]] == i`` for every ``i``.

    The result starts as a deep copy of ``idx_map``, so the input itself is
    left untouched.
    """
    inverse = copy.deepcopy(idx_map)
    position = 0
    for target in idx_map:
        inverse[target] = position
        position += 1
    return inverse
265 |
--------------------------------------------------------------------------------