├── eval
│   ├── utils
│   │   ├── __init__.py
│   │   ├── data_loaders.py
│   │   └── processing.py
│   ├── example_eval.sh
│   └── main.py
├── assets
│   ├── teaser_comparison.png
│   ├── noisyrollout_workflow.png
│   └── noisyrollout_workflow_caption.png
├── requirements.txt
├── Makefile
├── verl
│   ├── utils
│   │   ├── __init__.py
│   │   ├── logger
│   │   │   ├── __init__.py
│   │   │   └── aggregate_logger.py
│   │   ├── checkpoint
│   │   │   ├── __init__.py
│   │   │   ├── checkpoint_manager.py
│   │   │   └── fsdp_checkpoint_manager.py
│   │   ├── reward_score
│   │   │   ├── __init__.py
│   │   │   ├── math.py
│   │   │   └── r1v.py
│   │   ├── py_functional.py
│   │   ├── tokenizer.py
│   │   ├── torch_dtypes.py
│   │   ├── model_utils.py
│   │   ├── image_aug.py
│   │   ├── fsdp_utils.py
│   │   ├── flops_counter.py
│   │   ├── tracking.py
│   │   └── dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── qwen2_vl.py
│   │   │   └── flash_attention_utils.py
│   │   └── monkey_patch.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── main.py
│   │   └── metrics.py
│   ├── workers
│   │   ├── __init__.py
│   │   ├── rollout
│   │   │   ├── __init__.py
│   │   │   ├── vllm_rollout
│   │   │   │   ├── __init__.py
│   │   │   │   └── vllm_rollout_spmd.py
│   │   │   ├── base.py
│   │   │   └── config.py
│   │   ├── reward
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── custom.py
│   │   ├── critic
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── config.py
│   │   ├── sharding_manager
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── fsdp_ulysses.py
│   │   │   └── fsdp_vllm.py
│   │   ├── actor
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── config.py
│   │   └── config.py
│   ├── single_controller
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── register_center
│   │   │   │   ├── __init__.py
│   │   │   │   └── ray.py
│   │   │   ├── __init__.py
│   │   │   ├── worker.py
│   │   │   ├── decorator.py
│   │   │   └── worker_group.py
│   │   └── ray
│   │       └── __init__.py
│   └── __init__.py
├── pyproject.toml
├── training_scripts
│   ├── README.md
│   ├── qwen2_5_vl_7b_geo3k_grpo.sh
│   ├── qwen2_5_vl_7b_k12_grpo.sh
│   ├── qwen2_5_vl_7b_geo3k_noisyrollout.sh
│   ├── qwen2_5_vl_7b_k12_noisyrollout.sh
│   └── config.yaml
├── setup.py
├── .gitignore
├── README.md
└── scripts
    └── model_merger.py
/eval/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/teaser_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NUS-TRAIL/NoisyRollout/HEAD/assets/teaser_comparison.png
--------------------------------------------------------------------------------
/assets/noisyrollout_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NUS-TRAIL/NoisyRollout/HEAD/assets/noisyrollout_workflow.png
--------------------------------------------------------------------------------
/assets/noisyrollout_workflow_caption.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NUS-TRAIL/NoisyRollout/HEAD/assets/noisyrollout_workflow_caption.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | codetiming
3 | datasets
4 | flash-attn>=2.4.3
5 | liger-kernel
6 | mathruler
7 | numpy
8 | omegaconf
9 | pandas
10 | peft
11 | pillow
12 | pyarrow>=15.0.0
13 | pylatexenc
14 | qwen-vl-utils
15 | ray
16 | tensordict
17 | torchdata
18 | transformers==4.49.0
19 | wandb
20 | vllm>=0.7.3
21 | torch==2.5.1
22 | torchvision==0.20.1
23 | torchaudio==2.5.1
24 | numpy==1.26.4
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: build commit quality style
2 |
3 | check_dirs := scripts verl setup.py
4 |
5 | build:
6 | 	python3 setup.py sdist bdist_wheel
7 | 
8 | commit:
9 | 	pre-commit install
10 | 	pre-commit run --all-files
11 | 
12 | quality:
13 | 	ruff check $(check_dirs)
14 | 	ruff format --check $(check_dirs)
15 | 
16 | style:
17 | 	ruff check $(check_dirs) --fix
18 | 	ruff format $(check_dirs)
19 |
--------------------------------------------------------------------------------
/verl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .protocol import DataProto
16 |
17 |
18 | __all__ = ["DataProto"]
19 | __version__ = "0.2.0.dev"
20 |
--------------------------------------------------------------------------------
/verl/workers/rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .config import RolloutConfig
17 |
18 |
19 | __all__ = ["RolloutConfig"]
20 |
--------------------------------------------------------------------------------
/verl/workers/rollout/vllm_rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .vllm_rollout_spmd import vLLMRollout
16 |
17 |
18 | __all__ = ["vLLMRollout"]
19 |
--------------------------------------------------------------------------------
/verl/workers/reward/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .config import RewardConfig
16 | from .custom import CustomRewardManager
17 |
18 |
19 | __all__ = ["CustomRewardManager", "RewardConfig"]
20 |
--------------------------------------------------------------------------------
/verl/workers/reward/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Reward config
16 | """
17 |
18 | from dataclasses import dataclass
19 |
20 |
21 | @dataclass
22 | class RewardConfig:
23 | reward_type: str = "function"
24 | compute_score: str = "math"
25 |
--------------------------------------------------------------------------------
/verl/single_controller/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .worker import Worker
16 | from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
17 |
18 |
19 | __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
20 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "verl"
7 | dynamic = ["version", "dependencies", "optional-dependencies", "readme", "license"]
8 | requires-python = ">=3.8"
9 |
10 | [tool.ruff]
11 | target-version = "py38"
12 | line-length = 119
13 | indent-width = 4
14 |
15 | [tool.ruff.lint]
16 | ignore = ["C901", "E501", "E741", "W605", "C408"]
17 | select = ["C", "E", "F", "I", "W", "RUF022"]
18 |
19 | [tool.ruff.lint.per-file-ignores]
20 | "__init__.py" = ["E402", "F401", "F403", "F811"]
21 |
22 | [tool.ruff.lint.isort]
23 | lines-after-imports = 2
24 | known-first-party = ["verl"]
25 | known-third-party = ["torch", "transformers", "wandb"]
26 |
27 | [tool.ruff.format]
28 | quote-style = "double"
29 | indent-style = "space"
30 | skip-magic-trailing-comma = false
31 | line-ending = "auto"
32 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .math import math_compute_score, pure_math_compute_score
17 | from .r1v import r1v_compute_score
18 |
19 |
20 | __all__ = ["math_compute_score", "r1v_compute_score", "pure_math_compute_score"]
21 |
--------------------------------------------------------------------------------
/verl/single_controller/ray/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls
16 |
17 |
18 | __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"]
19 |
--------------------------------------------------------------------------------
/verl/workers/critic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOCritic
16 | from .config import CriticConfig, ModelConfig
17 | from .dp_critic import DataParallelPPOCritic
18 |
19 |
20 | __all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
21 |
--------------------------------------------------------------------------------
/training_scripts/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | * In our current implementation of NoisyRollout, $n_1$ and $n_2$ are set to the same value. If you set `worker.rollout.n` to 6 in our NoisyRollout training scripts, it means $n_1=n_2=6$, resulting in a total of 12 rollouts.
6 | * We empirically freeze the vision encoder and remove the KL loss for better performance.
7 | * For `worker.actor.aug_type`, please only use `gaussian` for now, as the performance of other choices is not guaranteed.
8 | * For `worker.actor.gaussian_noise_step` (i.e., initial noise strength $\alpha_0$ in our paper), it's a critical hyper-parameter when applying NoisyRollout to other datasets. Suggested values include 400, 450, and 500.
9 | * We also suggest setting `worker.actor.decay_mode=sigmoid` with an appropriate `worker.actor.decay_sig_mid_step` for stable training; a minimal sketch of such a decay schedule is shown below.
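10 | 
11 | As a rough illustration of the sigmoid decay referenced above (a minimal sketch only: the helper name and exact formula below are hypothetical and may differ from the actual actor implementation), the noise strength starts near `worker.actor.gaussian_noise_step` and anneals toward 0 around `worker.actor.decay_sig_mid_step`:
12 | 
13 | ```python
14 | import math
15 | 
16 | def sigmoid_decayed_noise_step(step: int, init_noise_step: int = 500, decay_coef: float = 30.0, mid_step: int = 40) -> float:
17 |     """Hypothetical sigmoid annealing of the initial noise strength (cf. worker.actor.gaussian_noise_step)."""
18 |     # Decays smoothly from ~init_noise_step early in training toward 0 after mid_step;
19 |     # decay_coef controls how sharp the transition around mid_step is.
20 |     decay = 1.0 / (1.0 + math.exp(decay_coef * (step - mid_step) / mid_step))
21 |     return init_noise_step * decay
22 | ```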
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .base import BaseShardingManager
17 | from .fsdp_ulysses import FSDPUlyssesShardingManager
18 | from .fsdp_vllm import FSDPVLLMShardingManager
19 |
20 |
21 | __all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"]
22 |
--------------------------------------------------------------------------------
/verl/workers/rollout/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC, abstractmethod
16 |
17 | from ...protocol import DataProto
18 |
19 |
20 | __all__ = ["BaseRollout"]
21 |
22 |
23 | class BaseRollout(ABC):
24 | @abstractmethod
25 | def generate_sequences(self, prompts: DataProto) -> DataProto:
26 | """Generate sequences"""
27 | pass
28 |
--------------------------------------------------------------------------------
/eval/example_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source ~/.bashrc
3 | source ~/miniconda3/bin/activate noisyrollout
4 |
5 | export VLLM_ATTENTION_BACKEND=XFORMERS
6 | export VLLM_USE_V1=0
7 | export GOOGLE_API_KEY="xxx"
8 |
9 | # Define the model path to evaluate
10 | HF_MODEL_PATH=""
11 | RESULTS_DIR="results/"
12 | EVAL_DIR="~/NoisyRollout/eval"
13 | DATA_DIR="~/NoisyRollout/eval/data"
14 |
15 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
16 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
17 |
18 | cd $EVAL_DIR
19 | python main.py \
20 | --model $HF_MODEL_PATH \
21 | --output-dir $RESULTS_DIR \
22 | --data-path $DATA_DIR \
23 | --datasets geo3k,hallubench,mathvista,wemath,mathverse,mathvision \
24 | --tensor-parallel-size 2 \
25 | --system-prompt="$SYSTEM_PROMPT" \
26 | --min-pixels 262144 \
27 | --max-pixels 1000000 \
28 | --max-model-len 8192 \
29 | --temperature 0.0 \
30 | --eval-threads 24
--------------------------------------------------------------------------------
/verl/workers/actor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOActor
16 | from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
17 | from .dp_actor import DataParallelPPOActor
18 |
19 |
20 | __all__ = [
21 | "ActorConfig",
22 | "BasePPOActor",
23 | "DataParallelPPOActor",
24 | "FSDPConfig",
25 | "ModelConfig",
26 | "OptimConfig",
27 | "RefConfig",
28 | ]
29 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 |
18 | @ray.remote
19 | class WorkerGroupRegisterCenter:
20 | def __init__(self, rank_zero_info):
21 | self.rank_zero_info = rank_zero_info
22 |
23 | def get_rank_zero_info(self):
24 | return self.rank_zero_info
25 |
26 |
27 | def create_worker_group_register_center(name, info):
28 | return WorkerGroupRegisterCenter.options(name=name).remote(info)
29 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Sharding manager to implement HybridEngine
16 | """
17 |
18 | from ...protocol import DataProto
19 |
20 |
21 | class BaseShardingManager:
22 | def __enter__(self):
23 | pass
24 |
25 | def __exit__(self, exc_type, exc_value, traceback):
26 | pass
27 |
28 | def preprocess_data(self, data: DataProto) -> DataProto:
29 | return data
30 |
31 | def postprocess_data(self, data: DataProto) -> DataProto:
32 | return data
33 |
--------------------------------------------------------------------------------
/verl/workers/critic/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Base class for Critic
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 |
21 | import torch
22 |
23 | from ...protocol import DataProto
24 | from .config import CriticConfig
25 |
26 |
27 | __all__ = ["BasePPOCritic"]
28 |
29 |
30 | class BasePPOCritic(ABC):
31 | def __init__(self, config: CriticConfig):
32 | self.config = config
33 |
34 | @abstractmethod
35 | def compute_values(self, data: DataProto) -> torch.Tensor:
36 | """Compute values"""
37 | pass
38 |
39 | @abstractmethod
40 | def update_critic(self, data: DataProto) -> Dict[str, Any]:
41 | """Update the critic"""
42 | pass
43 |
--------------------------------------------------------------------------------
/verl/utils/logger/aggregate_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A Ray logger that receives logging info from different processes.
16 | """
17 |
18 | import numbers
19 | from typing import Any, Dict
20 |
21 |
22 | def concat_dict_to_str(dict: Dict[str, Any], step: int) -> str:
23 | output = [f"step {step}:"]
24 | for k, v in dict.items():
25 | if isinstance(v, numbers.Number):
26 | output.append(f"{k}:{v:.3f}")
27 |
28 | output_str = " - ".join(output)
29 | return output_str
30 |
31 |
32 | class LocalLogger:
33 | def __init__(self):
34 | pass
35 |
36 | def flush(self):
37 | pass
38 |
39 | def log(self, data: Dict[str, Any], step: int) -> None:
40 | print(concat_dict_to_str(data, step=step), flush=True)
41 |
--------------------------------------------------------------------------------
/verl/utils/py_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contain small python utility functions
16 | """
17 |
18 | from typing import Any, Dict, List
19 |
20 |
21 | def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
22 | """Union two dict. Will throw an error if there is an item not the same object with the same key."""
23 | for key in dict2.keys():
24 | if key in dict1:
25 |             assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 does not match"
26 |
27 | dict1[key] = dict2[key]
28 |
29 | return dict1
30 |
31 |
32 | def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
33 | for key, val in new_data.items():
34 | if key not in data:
35 | data[key] = []
36 |
37 | data[key].append(val)
38 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 | from mathruler.grader import extract_boxed_content, grade_answer
18 |
19 |
20 | def math_format_reward(predict_str: str) -> float:
21 |     pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
22 | format_match = re.fullmatch(pattern, predict_str)
23 | return 1.0 if format_match else 0.0
24 |
25 |
26 | def math_acc_reward(predict_str: str, ground_truth: str) -> float:
27 | answer = extract_boxed_content(predict_str)
28 | return 1.0 if grade_answer(answer, ground_truth) else 0.0
29 |
30 |
31 | def pure_math_compute_score(predict_str: str, ground_truth: str) -> float:
32 | return math_acc_reward(predict_str, ground_truth)
33 |
34 |
35 | def math_compute_score(predict_str: str, ground_truth: str) -> float:
36 | return 0.9 * math_acc_reward(predict_str, ground_truth) + 0.1 * math_format_reward(predict_str)
--------------------------------------------------------------------------------
/verl/workers/rollout/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Rollout config
16 | """
17 |
18 | from dataclasses import asdict, dataclass, field
19 |
20 |
21 | @dataclass
22 | class RolloutConfig:
23 | name: str = "vllm"
24 | temperature: float = 1.0
25 | top_k: int = -1
26 | top_p: float = 1.0
27 | dtype: str = "bf16"
28 | gpu_memory_utilization: float = 0.5
29 | ignore_eos: bool = False
30 | enforce_eager: bool = False
31 | free_cache_engine: bool = False
32 | enable_chunked_prefill: bool = False
33 | tensor_parallel_size: int = 2
34 | max_num_batched_tokens: int = 8192
35 | max_num_seqs: int = 1024
36 | disable_log_stats: bool = True
37 | do_sample: bool = True
38 | n: int = 1
39 | limit_images: int = 0
40 | """auto keys"""
41 | prompt_length: int = field(default=-1, init=False)
42 | response_length: int = field(default=-1, init=False)
43 |
44 | def to_dict(self):
45 | return asdict(self)
46 |
--------------------------------------------------------------------------------
/verl/models/monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
17 |
18 | from .transformers.flash_attention_utils import flash_attention_forward
19 | from .transformers.qwen2_vl import qwen2_vl_attn_forward
20 |
21 |
22 | def apply_ulysses_patch(model_type: str) -> None:
23 | if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
24 | ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
25 | elif model_type in ("qwen2_vl", "qwen2_5_vl"):
26 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
27 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
28 |
29 | Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
30 | Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
31 | else:
32 | raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
33 |
--------------------------------------------------------------------------------
/verl/workers/critic/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Critic config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 |
20 | from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
21 |
22 |
23 | @dataclass
24 | class CriticConfig:
25 | strategy: str = "fsdp"
26 | global_batch_size: int = 256
27 | micro_batch_size_per_device_for_update: int = 4
28 | micro_batch_size_per_device_for_experience: int = 16
29 | max_grad_norm: float = 1.0
30 | cliprange_value: float = 0.5
31 | ppo_epochs: int = 1
32 | padding_free: bool = False
33 | ulysses_sequence_parallel_size: int = 1
34 | model: ModelConfig = field(default_factory=ModelConfig)
35 | optim: OptimConfig = field(default_factory=OptimConfig)
36 | fsdp: FSDPConfig = field(default_factory=FSDPConfig)
37 | offload: OffloadConfig = field(default_factory=OffloadConfig)
38 | """auto keys"""
39 | global_batch_size_per_device: int = field(default=-1, init=False)
40 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/r1v.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 | from mathruler.grader import grade_answer
18 |
19 |
20 | def r1v_format_reward(predict_str: str) -> float:
21 |     pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
22 | format_match = re.fullmatch(pattern, predict_str)
23 | return 1.0 if format_match else 0.0
24 |
25 |
26 | def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
27 | try:
28 | ground_truth = ground_truth.strip()
29 |         content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
30 | given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
31 | if grade_answer(given_answer, ground_truth):
32 | return 1.0
33 | except Exception:
34 | pass
35 |
36 | return 0.0
37 |
38 |
39 | def r1v_compute_score(predict_str: str, ground_truth: str) -> float:
40 | return 0.5 * r1v_accuracy_reward(predict_str, ground_truth) + 0.5 * r1v_format_reward(predict_str)
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from setuptools import find_packages, setup
17 |
18 |
19 | def get_requires():
20 | with open("requirements.txt", encoding="utf-8") as f:
21 | file_content = f.read()
22 | lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
23 | return lines
24 |
25 |
26 | extra_require = {
27 | "dev": ["pre-commit", "ruff"],
28 | }
29 |
30 |
31 | def main():
32 | setup(
33 | name="verl",
34 | version="0.2.0.dev0",
35 | package_dir={"": "."},
36 | packages=find_packages(where="."),
37 | url="https://github.com/volcengine/verl",
38 | license="Apache 2.0",
39 | author="verl",
40 | author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
41 | description="",
42 | install_requires=get_requires(),
43 | extras_require=extra_require,
44 | long_description=open("README.md", encoding="utf-8").read(),
45 | long_description_content_type="text/markdown",
46 | )
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/training_scripts/qwen2_5_vl_7b_geo3k_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | source ~/.bashrc
4 | source ~/miniconda3/bin/activate noisyrollout
5 | cd ~/NoisyRollout
6 |
7 | export VLLM_ATTENTION_BACKEND=XFORMERS
8 | export VLLM_USE_V1=0
9 |
10 | export WANDB_BASE_URL=https://api.wandb.ai
11 | export WANDB_API_KEY="xxx"
12 |
13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
14 | EXPERIMENT_NAME=qwen2_5_vl_7b_geo3k_grpo
15 | PROJECT_NAME=noisy_rollout
16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}"
17 |
18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
19 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
20 |
21 | python3 -m verl.trainer.main \
22 | config=training_scripts/config.yaml \
23 | data.train_files=xyliu6/geometry3k@train \
24 | data.val_files=xyliu6/geometry3k@test \
25 | data.system_prompt="${SYSTEM_PROMPT}" \
26 | data.max_response_length=2048 \
27 | data.max_pixels=1000000 \
28 | worker.actor.micro_batch_size_per_device_for_update=2 \
29 | worker.actor.micro_batch_size_per_device_for_experience=4 \
30 | worker.actor.model.freeze_vision_tower=true \
31 | worker.actor.model.model_path=${MODEL_PATH} \
32 | worker.actor.use_kl_loss=false \
33 | worker.actor.offload.offload_params=true \
34 | worker.actor.offload.offload_optimizer=true \
35 | worker.reward.compute_score=math \
36 | worker.rollout.gpu_memory_utilization=0.35 \
37 | worker.rollout.tensor_parallel_size=4 \
38 | worker.rollout.n=12 \
39 | worker.rollout.enable_chunked_prefill=false \
40 | trainer.experiment_name=${EXPERIMENT_NAME} \
41 | trainer.project_name=${PROJECT_NAME} \
42 | trainer.n_gpus_per_node=8 \
43 | trainer.save_freq=20 \
44 | worker.actor.is_noisy=false
45 |
--------------------------------------------------------------------------------
/training_scripts/qwen2_5_vl_7b_k12_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | source ~/.bashrc
4 | source ~/miniconda3/bin/activate noisyrollout
5 | cd ~/NoisyRollout
6 |
7 | export VLLM_ATTENTION_BACKEND=XFORMERS
8 | export VLLM_USE_V1=0
9 |
10 | export WANDB_BASE_URL=https://api.wandb.ai
11 | export WANDB_API_KEY="xxx"
12 |
13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
14 | EXPERIMENT_NAME=qwen2_5_vl_7b_k12_grpo
15 | PROJECT_NAME=noisy_rollout
16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}"
17 |
18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
19 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
20 |
21 | python3 -m verl.trainer.main \
22 | config=training_scripts/config.yaml \
23 | data.train_files=xyliu6/k12-freeform@mini_train \
24 | data.val_files=xyliu6/k12-freeform@test \
25 | data.system_prompt="${SYSTEM_PROMPT}" \
26 | data.max_response_length=2048 \
27 | data.max_pixels=1000000 \
28 | worker.actor.micro_batch_size_per_device_for_update=2 \
29 | worker.actor.micro_batch_size_per_device_for_experience=4 \
30 | worker.actor.model.freeze_vision_tower=true \
31 | worker.actor.model.model_path=${MODEL_PATH} \
32 | worker.actor.use_kl_loss=false \
33 | worker.actor.offload.offload_params=true \
34 | worker.actor.offload.offload_optimizer=true \
35 | worker.reward.compute_score=math \
36 | worker.rollout.gpu_memory_utilization=0.35 \
37 | worker.rollout.tensor_parallel_size=4 \
38 | worker.rollout.n=12 \
39 | worker.rollout.enable_chunked_prefill=false \
40 | trainer.experiment_name=${EXPERIMENT_NAME} \
41 | trainer.project_name=${PROJECT_NAME} \
42 | trainer.n_gpus_per_node=8 \
43 | trainer.save_freq=20 \
44 | worker.actor.is_noisy=false
--------------------------------------------------------------------------------
/verl/workers/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | ActorRolloutRef config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 |
20 | from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
21 | from .critic import CriticConfig
22 | from .reward import RewardConfig
23 | from .rollout import RolloutConfig
24 |
25 |
26 | __all__ = [
27 | "ActorConfig",
28 | "CriticConfig",
29 | "FSDPConfig",
30 | "ModelConfig",
31 | "OptimConfig",
32 | "RefConfig",
33 | "RewardConfig",
34 | "RolloutConfig",
35 | "WorkerConfig",
36 | ]
37 |
38 |
39 | @dataclass
40 | class WorkerConfig:
41 | hybrid_engine: bool = True
42 | actor: ActorConfig = field(default_factory=ActorConfig)
43 | critic: CriticConfig = field(default_factory=CriticConfig)
44 | ref: RefConfig = field(default_factory=RefConfig)
45 | reward: RewardConfig = field(default_factory=RewardConfig)
46 | rollout: RolloutConfig = field(default_factory=RolloutConfig)
47 |
48 | def post_init(self):
49 | self.ref.padding_free = self.actor.padding_free
50 | self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size
51 | self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience
52 |
--------------------------------------------------------------------------------
/training_scripts/qwen2_5_vl_7b_geo3k_noisyrollout.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | source ~/.bashrc
4 | source ~/miniconda3/bin/activate noisyrollout
5 | cd ~/NoisyRollout
6 |
7 | export VLLM_ATTENTION_BACKEND=XFORMERS
8 | export VLLM_USE_V1=0
9 |
10 | export WANDB_BASE_URL=https://api.wandb.ai
11 | export WANDB_API_KEY="xxx"
12 |
13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
14 | EXPERIMENT_NAME=qwen2_5_vl_7b_geo3k_noisyrollout
15 | PROJECT_NAME=noisy_rollout
16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}"
17 |
18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
19 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
20 |
21 | python3 -m verl.trainer.main \
22 | config=training_scripts/config.yaml \
23 | data.train_files=xyliu6/geometry3k@train \
24 | data.val_files=xyliu6/geometry3k@test \
25 | data.system_prompt="${SYSTEM_PROMPT}" \
26 | data.max_response_length=2048 \
27 | data.max_pixels=1000000 \
28 | worker.actor.micro_batch_size_per_device_for_update=2 \
29 | worker.actor.micro_batch_size_per_device_for_experience=4 \
30 | worker.actor.model.freeze_vision_tower=true \
31 | worker.actor.model.model_path=${MODEL_PATH} \
32 | worker.actor.use_kl_loss=false \
33 | worker.actor.offload.offload_params=true \
34 | worker.actor.offload.offload_optimizer=true \
35 | worker.reward.compute_score=math \
36 | worker.rollout.gpu_memory_utilization=0.35 \
37 | worker.rollout.tensor_parallel_size=4 \
38 | worker.rollout.n=6 \
39 | worker.rollout.enable_chunked_prefill=false \
40 | trainer.experiment_name=${EXPERIMENT_NAME} \
41 | trainer.project_name=${PROJECT_NAME} \
42 | trainer.n_gpus_per_node=8 \
43 | trainer.save_freq=20 \
44 | worker.actor.is_noisy=true \
45 | worker.actor.aug_type=gaussian \
46 | worker.actor.gaussian_noise_step=500 \
47 | worker.actor.decay_mode=sigmoid \
48 | worker.actor.decay_coef=30 \
49 | worker.actor.decay_sig_mid_step=40
50 |
51 |
--------------------------------------------------------------------------------
/training_scripts/qwen2_5_vl_7b_k12_noisyrollout.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | source ~/.bashrc
4 | source ~/miniconda3/bin/activate noisyrollout
5 | cd ~/NoisyRollout
6 |
7 | export VLLM_ATTENTION_BACKEND=XFORMERS
8 | export VLLM_USE_V1=0
9 |
10 | export WANDB_BASE_URL=https://api.wandb.ai
11 | export WANDB_API_KEY="xxx"
12 |
13 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
14 | EXPERIMENT_NAME=qwen2_5_vl_7b_k12_noisyrollout
15 | PROJECT_NAME=noisy_rollout
16 | CHECKPOINT_DIR="checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}"
17 |
18 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
19 | The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
20 |
21 | python3 -m verl.trainer.main \
22 | config=training_scripts/config.yaml \
23 | data.train_files=xyliu6/k12-freeform@train \
24 | data.val_files=xyliu6/k12-freeform@test \
25 | data.system_prompt="${SYSTEM_PROMPT}" \
26 | data.max_response_length=2048 \
27 | data.max_pixels=1000000 \
28 | worker.actor.micro_batch_size_per_device_for_update=2 \
29 | worker.actor.micro_batch_size_per_device_for_experience=4 \
30 | worker.actor.model.freeze_vision_tower=true \
31 | worker.actor.model.model_path=${MODEL_PATH} \
32 | worker.actor.use_kl_loss=false \
33 | worker.actor.offload.offload_params=true \
34 | worker.actor.offload.offload_optimizer=true \
35 | worker.reward.compute_score=math \
36 | worker.rollout.gpu_memory_utilization=0.35 \
37 | worker.rollout.tensor_parallel_size=4 \
38 | worker.rollout.n=6 \
39 | worker.rollout.enable_chunked_prefill=false \
40 | trainer.experiment_name=${EXPERIMENT_NAME} \
41 | trainer.project_name=${PROJECT_NAME} \
42 | trainer.n_gpus_per_node=8 \
43 | trainer.save_freq=20 \
44 | trainer.total_episodes=10 \
45 | worker.actor.is_noisy=true \
46 | worker.actor.aug_type=gaussian \
47 | worker.actor.gaussian_noise_step=450 \
48 | worker.actor.decay_mode=sigmoid \
49 | worker.actor.decay_coef=60 \
50 | worker.actor.decay_sig_mid_step=40
51 |
52 |
--------------------------------------------------------------------------------
/verl/workers/actor/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for Actor
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 |
21 | import torch
22 |
23 | from ...protocol import DataProto
24 | from .config import ActorConfig
25 |
26 |
27 | __all__ = ["BasePPOActor"]
28 |
29 |
30 | class BasePPOActor(ABC):
31 | def __init__(self, config: ActorConfig):
32 | """The base class for PPO actor
33 |
34 | Args:
35 | config (ActorConfig): a config passed to the PPOActor.
36 | """
37 | self.config = config
38 |
39 | @abstractmethod
40 | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
41 |         """Compute log probabilities given a batch of data.
42 |
43 | Args:
44 | data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,
45 | ```attention_mask``` and ```position_ids```.
46 |
47 | Returns:
48 |             torch.Tensor: the log probabilities of the responses
49 | """
50 | pass
51 |
52 | @abstractmethod
53 | def update_policy(self, data: DataProto) -> Dict[str, Any]:
54 | """Update the policy with an iterator of DataProto
55 |
56 | Args:
57 |             data (DataProto): an iterator over the DataProto returned by
58 | ```make_minibatch_iterator```
59 |
60 | Returns:
61 |             Dict: a dictionary containing anything. Typically, it contains statistics collected while updating the model,
62 |                 such as ```loss```, ```grad_norm```, etc.
63 | """
64 | pass
65 |
--------------------------------------------------------------------------------
/verl/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for tokenization."""
15 |
16 | from typing import Optional
17 |
18 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
19 |
20 |
21 | def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
22 | """Create a huggingface pretrained tokenizer."""
23 | tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
24 |
25 |     if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
26 |         # the EOS token in gemma2 & gemma3 is ambiguous, which may worsen RL performance.
27 |         # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
28 |         print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
29 |         tokenizer.eos_token = "<end_of_turn>"
30 |
31 | if tokenizer.pad_token_id is None:
32 | print("Pad token is None. Set it to eos_token.")
33 | tokenizer.pad_token = tokenizer.eos_token
34 |
35 | return tokenizer
36 |
37 |
38 | def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
39 | """Create a huggingface pretrained processor."""
40 | try:
41 | processor = AutoProcessor.from_pretrained(model_path, **kwargs)
42 | except Exception:
43 | processor = None
44 |
45 | # Avoid load tokenizer, see:
46 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
47 | if processor is not None and "Processor" not in processor.__class__.__name__:
48 | processor = None
49 |
50 | return processor
51 |
--------------------------------------------------------------------------------
/verl/utils/torch_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
18 | HALF_LIST = [16, "16", "fp16", "float16"]
19 | FLOAT_LIST = [32, "32", "fp32", "float32"]
20 | BFLOAT_LIST = ["bf16", "bfloat16"]
21 |
22 |
23 | class PrecisionType:
24 | """Type of precision used.
25 |
26 |     >>> PrecisionType.HALF == "16"
27 | True
28 | >>> PrecisionType.HALF in (16, "16")
29 | True
30 | """
31 |
32 | HALF = "16"
33 | FLOAT = "32"
34 | FULL = "64"
35 | BFLOAT = "bf16"
36 | MIXED = "mixed"
37 |
38 | @staticmethod
39 | def is_fp16(precision):
40 | return precision in HALF_LIST
41 |
42 | @staticmethod
43 | def is_fp32(precision):
44 | return precision in FLOAT_LIST
45 |
46 | @staticmethod
47 | def is_bf16(precision):
48 | return precision in BFLOAT_LIST
49 |
50 | @staticmethod
51 | def to_dtype(precision) -> torch.dtype:
52 | if precision in HALF_LIST:
53 | return torch.float16
54 | elif precision in FLOAT_LIST:
55 | return torch.float32
56 | elif precision in BFLOAT_LIST:
57 | return torch.bfloat16
58 | else:
59 | raise RuntimeError(f"unexpected precision: {precision}")
60 |
61 | @staticmethod
62 | def to_str(precision: torch.dtype) -> str:
63 | if precision == torch.float16:
64 | return "float16"
65 | elif precision == torch.float32:
66 | return "float32"
67 | elif precision == torch.bfloat16:
68 | return "bfloat16"
69 | else:
70 | raise RuntimeError(f"unexpected precision: {precision}")
71 |
--------------------------------------------------------------------------------
/verl/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities to create common models
16 | """
17 |
18 | from functools import lru_cache
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch import nn
24 |
25 |
26 | @lru_cache
27 | def is_rank0() -> bool:
28 | return (not dist.is_initialized()) or (dist.get_rank() == 0)
29 |
30 |
31 | def print_gpu_memory_usage(prefix: str) -> None:
32 | if is_rank0():
33 | memory_allocated = torch.cuda.memory_allocated() / (1024**3)
34 | memory_reserved = torch.cuda.memory_reserved() / (1024**3)
35 | print(f"{prefix} memory allocated: {memory_allocated:.2f} GB, memory reserved: {memory_reserved:.2f} GB.")
36 |
37 |
38 | def get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
39 | n_params = sum(p.numel() for p in model.parameters())
40 |
41 | if scale == "auto":
42 | if n_params > 1e9:
43 | scale = "B"
44 | elif n_params > 1e6:
45 | scale = "M"
46 | elif n_params > 1e3:
47 | scale = "K"
48 | else:
49 | scale = ""
50 |
51 | if scale == "B":
52 | n_params = n_params / 1e9
53 | elif scale == "M":
54 | n_params = n_params / 1e6
55 | elif scale == "K":
56 | n_params = n_params / 1e3
57 | elif scale == "":
58 | pass
59 | else:
60 | raise NotImplementedError(f"Unknown scale {scale}.")
61 |
62 | return n_params, scale
63 |
64 |
65 | def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
66 | n_params, scale = get_model_size(model, scale="auto")
67 | if name is None:
68 | name = model.__class__.__name__
69 |
70 | print(f"{name} contains {n_params:.2f}{scale} parameters")
71 |
--------------------------------------------------------------------------------
/training_scripts/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_files: hiyouga/math12k@train
3 | val_files: hiyouga/math12k@test
4 | prompt_key: problem
5 | answer_key: answer
6 | image_key: images
7 | max_prompt_length: 2048
8 | max_response_length: 2048
9 | rollout_batch_size: 512
10 | shuffle: true
11 | seed: 1
12 | max_pixels: 4194304
13 | min_pixels: 262144
14 |
15 | algorithm:
16 | adv_estimator: grpo
17 | kl_coef: 0.0
18 |
19 | worker:
20 | actor:
21 | global_batch_size: 128
22 | micro_batch_size_per_device_for_update: 4
23 | micro_batch_size_per_device_for_experience: 16
24 | max_grad_norm: 1.0
25 | entropy_coeff: 1.0e-3
26 | use_kl_loss: true
27 | kl_loss_coef: 1.0e-2
28 | kl_loss_type: low_var_kl
29 | padding_free: true
30 | ulysses_sequence_parallel_size: 1
31 | is_noisy: false
32 | aug_type: gaussian
33 | gaussian_noise_step: 500
34 | crop_size: 0.0
35 | rotate_angle: 15
36 | decay_coef: 1.0
37 | decay_mode: none
38 | decay_sig_mid_step: 40
39 | model:
40 | model_path: Qwen/Qwen2.5-7B-Instruct
41 | enable_gradient_checkpointing: true
42 | trust_remote_code: false
43 | freeze_vision_tower: false
44 | freeze_language_model: false
45 | optim:
46 | lr: 1.0e-6
47 | weight_decay: 1.0e-2
48 | lr_warmup_ratio: 0.0
49 | fsdp:
50 | enable_full_shard: true
51 | enable_cpu_offload: false
52 | enable_rank0_init: true
53 | offload:
54 | offload_params: true
55 | offload_optimizer: true
56 |
57 | rollout:
58 | temperature: 1.0
59 | n: 5
60 | gpu_memory_utilization: 0.5
61 | enforce_eager: false
62 | enable_chunked_prefill: false
63 | tensor_parallel_size: 2
64 | limit_images: 0
65 |
66 | ref:
67 | fsdp:
68 | enable_full_shard: true
69 | enable_cpu_offload: true
70 | enable_rank0_init: true
71 | offload:
72 | offload_params: false
73 |
74 | reward:
75 | reward_type: function
76 | compute_score: math
77 |
78 | trainer:
79 | total_episodes: 15
80 | logger: ["console", "wandb"]
81 | project_name: noisy_rollout
82 | experiment_name: qwen2_5_7b_math
83 | n_gpus_per_node: 8
84 | nnodes: 1
85 | val_freq: 5
86 | val_before_train: true
87 | val_only: false
88 | val_generations_to_log: 1
89 | save_freq: 5
90 | remove_previous_ckpt: false
91 | remove_ckpt_after_load: false
92 | save_checkpoint_path: null
93 | load_checkpoint_path: null
94 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/fsdp_ulysses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contains a sharding manager that reshards data between FSDP and Ulysses sequence parallelism
16 | """
17 |
18 | from torch.distributed.device_mesh import DeviceMesh
19 |
20 | from ...protocol import DataProto, all_gather_data_proto
21 | from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
22 | from .base import BaseShardingManager
23 |
24 |
25 | class FSDPUlyssesShardingManager(BaseShardingManager):
26 | """
27 | Sharding manager to support data resharding when using FSDP + Ulysses
28 | """
29 |
30 | def __init__(self, device_mesh: DeviceMesh):
31 | super().__init__()
32 | self.device_mesh = device_mesh
33 |
34 | def __enter__(self):
35 | if self.device_mesh is not None:
36 | self.prev_sp_group = get_ulysses_sequence_parallel_group()
37 | set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group())
38 |
39 | def __exit__(self, exc_type, exc_value, traceback):
40 | if self.device_mesh is not None:
41 | set_ulysses_sequence_parallel_group(self.prev_sp_group)
42 |
43 | def preprocess_data(self, data: DataProto) -> DataProto:
44 | """
45 |         AllGather data from the SP region.
46 |         The data is first sharded along the FSDP dimension because we utilize DP_COMPUTE,
47 |         so in Ulysses we need to make sure the same data is used across an SP group.
48 | """
49 | if self.device_mesh is not None:
50 | sp_size = self.device_mesh["sp"].size()
51 | sp_group = self.device_mesh["sp"].get_group()
52 | all_gather_data_proto(data, size=sp_size, group=sp_group)
53 |
54 | return data
55 |
56 | def postprocess_data(self, data: DataProto) -> DataProto:
57 | """
58 | Split the data to follow FSDP partition
59 | """
60 | if self.device_mesh is not None:
61 | sp_size = self.device_mesh["sp"].size()
62 | sp_rank = self.device_mesh["sp"].get_local_rank()
63 | data = data.chunk(chunks=sp_size)[sp_rank]
64 |
65 | return data
66 |
--------------------------------------------------------------------------------
/verl/workers/reward/custom.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | from transformers import PreTrainedTokenizer
18 |
19 | from ...protocol import DataProto
20 | from ...utils.reward_score import math_compute_score, r1v_compute_score, pure_math_compute_score
21 |
22 |
23 | class CustomRewardManager:
24 | def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str):
25 | self.tokenizer = tokenizer
26 | self.num_examine = num_examine
27 | if compute_score == "math":
28 | self.compute_score = math_compute_score
29 | elif compute_score == "r1v":
30 | self.compute_score = r1v_compute_score
31 | elif compute_score == "pure_math":
32 | self.compute_score = pure_math_compute_score
33 | else:
34 | raise NotImplementedError()
35 |
36 | def __call__(self, data: DataProto) -> torch.Tensor:
37 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
38 | already_print = 0
39 |
40 | for i in range(len(data)):
41 | data_item = data[i] # DataProtoItem
42 |
43 | prompt_ids = data_item.batch["prompts"]
44 | prompt_length = prompt_ids.shape[-1]
45 |
46 | valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
47 | valid_prompt_ids = prompt_ids[-valid_prompt_length:]
48 |
49 | response_ids = data_item.batch["responses"]
50 | valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
51 | valid_response_ids = response_ids[:valid_response_length]
52 |
53 | # decode
54 | prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
55 | response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
56 |
57 | ground_truth = data_item.non_tensor_batch["ground_truth"]
58 |
59 | score = self.compute_score(response_str, ground_truth)
60 | reward_tensor[i, valid_response_length - 1] = score
61 |
62 | if already_print < self.num_examine:
63 | already_print += 1
64 | print("[prompt]", prompt_str)
65 | print("[response]", response_str)
66 | print("[ground_truth]", ground_truth)
67 | print("[score]", score)
68 |
69 | return reward_tensor
70 |
--------------------------------------------------------------------------------
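
For readers unfamiliar with the token-level reward layout used above: the sequence-level score returned by `compute_score` is written only at the position of the last valid response token, and every other position stays zero. The toy example below reproduces just that indexing step with hand-made tensors (the real code operates on `DataProto` batches):

import torch

responses = torch.tensor([[11, 12, 13, 0, 0]])                 # one padded response of length 5
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0]])   # 4 prompt tokens + 3 valid response tokens
prompt_length = 4

reward_tensor = torch.zeros_like(responses, dtype=torch.float32)
valid_response_length = attention_mask[0, prompt_length:].sum()
reward_tensor[0, valid_response_length - 1] = 1.0              # the score from compute_score(...)
print(reward_tensor)                                            # tensor([[0., 0., 1., 0., 0.]])
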
/verl/trainer/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | PPO config
16 | """
17 |
18 | import os
19 | from dataclasses import asdict, dataclass, field, fields, is_dataclass
20 | from typing import Optional, Tuple
21 |
22 | from ..workers.config import WorkerConfig
23 |
24 |
25 | def recursive_post_init(dataclass_obj):
26 | if hasattr(dataclass_obj, "post_init"):
27 | dataclass_obj.post_init()
28 |
29 | for attr in fields(dataclass_obj):
30 | if is_dataclass(getattr(dataclass_obj, attr.name)):
31 | recursive_post_init(getattr(dataclass_obj, attr.name))
32 |
33 |
34 | @dataclass
35 | class DataConfig:
36 | train_files: str = ""
37 | val_files: str = ""
38 | prompt_key: str = "prompt"
39 | answer_key: str = "answer"
40 | image_key: str = "images"
41 | max_prompt_length: int = 512
42 | max_response_length: int = 512
43 | rollout_batch_size: int = 512
44 | system_prompt: Optional[str] = None
45 | shuffle: bool = True
46 | seed: int = 1
47 | max_pixels: int = 4194304
48 | min_pixels: int = 262144
49 |
50 |
51 | @dataclass
52 | class AlgorithmConfig:
53 | gamma: float = 1.0
54 | lam: float = 1.0
55 | adv_estimator: str = "grpo"
56 | kl_penalty: str = "kl"
57 | kl_type: str = "fixed"
58 | kl_coef: float = 1e-3
59 | kl_horizon: float = 0.0
60 | kl_target: float = 0.0
61 |
62 |
63 | @dataclass
64 | class TrainerConfig:
65 | total_episodes: int = 10
66 | max_steps: Optional[int] = None
67 | project_name: str = "easy_r1"
68 | experiment_name: str = "demo"
69 | logger: Tuple[str] = ("console", "wandb")
70 | nnodes: int = 1
71 | n_gpus_per_node: int = 8
72 | critic_warmup: int = 0
73 | val_freq: int = -1
74 | val_before_train: bool = True
75 | val_only: bool = False
76 | val_generations_to_log: int = 1
77 | save_freq: int = -1
78 | remove_previous_ckpt: bool = False
79 | remove_ckpt_after_load: bool = False
80 | save_checkpoint_path: Optional[str] = None
81 | load_checkpoint_path: Optional[str] = None
82 | def post_init(self):
83 | if self.save_checkpoint_path is None:
84 | self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
85 |
86 |
87 | @dataclass
88 | class PPOConfig:
89 | data: DataConfig = field(default_factory=DataConfig)
90 | worker: WorkerConfig = field(default_factory=WorkerConfig)
91 | algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
92 | trainer: TrainerConfig = field(default_factory=TrainerConfig)
93 |
94 | def post_init(self):
95 | self.worker.rollout.prompt_length = self.data.max_prompt_length
96 | self.worker.rollout.response_length = self.data.max_response_length
97 |
98 | def deep_post_init(self):
99 | recursive_post_init(self)
100 |
101 | def to_dict(self):
102 | return asdict(self)
103 |
--------------------------------------------------------------------------------
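
A small sketch of the `post_init` behaviour defined above: when `save_checkpoint_path` is left unset, it is derived from the project and experiment names (this assumes the package is importable as `verl`, matching the repository layout):

from verl.trainer.config import TrainerConfig

cfg = TrainerConfig(project_name="noisy_rollout", experiment_name="qwen2_5_vl_7b_geo3k_noisyrollout")
cfg.post_init()
print(cfg.save_checkpoint_path)  # checkpoints/noisy_rollout/qwen2_5_vl_7b_geo3k_noisyrollout
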
/verl/workers/actor/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Actor config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 | from typing import Any, Dict, Optional, Tuple
20 |
21 |
22 | @dataclass
23 | class ModelConfig:
24 | model_path: Optional[str] = None
25 | tokenizer_path: Optional[str] = None
26 | override_config: Dict[str, Any] = field(default_factory=dict)
27 | enable_gradient_checkpointing: bool = True
28 | trust_remote_code: bool = True
29 | freeze_vision_tower: bool = False
30 | freeze_language_model: bool = False
31 |
32 | def post_init(self):
33 | if self.tokenizer_path is None:
34 | self.tokenizer_path = self.model_path
35 |
36 |
37 | @dataclass
38 | class OptimConfig:
39 | lr: float = 1e-6
40 | betas: Tuple[float, float] = (0.9, 0.999)
41 | weight_decay: float = 1e-2
42 | lr_warmup_ratio: float = 0.0
43 | min_lr_ratio: Optional[float] = None
44 | warmup_style: str = "constant"
45 | """auto keys"""
46 | training_steps: int = field(default=-1, init=False)
47 |
48 |
49 | @dataclass
50 | class FSDPConfig:
51 | enable_full_shard: bool = True
52 | enable_cpu_offload: bool = False
53 | enable_rank0_init: bool = False
54 | use_orig_params: bool = False
55 | torch_dtype: Optional[str] = None
56 | fsdp_size: int = -1
57 | mp_param_dtype: str = "bf16"
58 | mp_reduce_dtype: str = "fp32"
59 | mp_buffer_dtype: str = "fp32"
60 |
61 |
62 | @dataclass
63 | class OffloadConfig:
64 | offload_params: bool = False
65 | offload_optimizer: bool = False
66 |
67 |
68 | @dataclass
69 | class ActorConfig:
70 | strategy: str = "fsdp"
71 | global_batch_size: int = 256
72 | micro_batch_size_per_device_for_update: int = 4
73 | micro_batch_size_per_device_for_experience: int = 16
74 | max_grad_norm: float = 1.0
75 | clip_ratio: float = 0.2
76 | entropy_coeff: float = 1e-3
77 | use_kl_loss: bool = True
78 | kl_loss_coef: float = 1e-3
79 | kl_loss_type: str = "low_var_kl"
80 | ppo_epochs: int = 1
81 | padding_free: bool = False
82 | ulysses_sequence_parallel_size: int = 1
83 | is_noisy: bool = False
84 | aug_type: str = "gaussian" # aug
85 | gaussian_noise_step: int = 500 # aug
86 | crop_size: float = 0.0 # aug
87 | rotate_angle: int = 15 # aug
88 | decay_coef: float = 1.0 # decay
89 | decay_mode: str = "exp" # decay
90 | decay_sig_mid_step: int = 40 # decay
91 | model: ModelConfig = field(default_factory=ModelConfig)
92 | optim: OptimConfig = field(default_factory=OptimConfig)
93 | fsdp: FSDPConfig = field(default_factory=FSDPConfig)
94 | offload: OffloadConfig = field(default_factory=OffloadConfig)
95 | """auto keys"""
96 | global_batch_size_per_device: int = field(default=-1, init=False)
97 |
98 |
99 | @dataclass
100 | class RefConfig:
101 | strategy: str = "fsdp"
102 | fsdp: FSDPConfig = field(default_factory=FSDPConfig)
103 | offload: OffloadConfig = field(default_factory=OffloadConfig)
104 | """auto keys"""
105 | micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
106 | padding_free: bool = field(default=False, init=False)
107 | ulysses_sequence_parallel_size: int = field(default=1, init=False)
108 |
--------------------------------------------------------------------------------
/verl/trainer/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Note that we don't combine this main with ray_trainer, as ray_trainer is used by other entry points.
16 | """
17 |
18 | import json
19 |
20 | import ray
21 | from omegaconf import OmegaConf
22 |
23 | from ..single_controller.ray import RayWorkerGroup
24 | from ..utils.tokenizer import get_processor, get_tokenizer
25 | from ..workers.fsdp_workers import FSDPWorker
26 | from ..workers.reward import CustomRewardManager
27 | from .config import PPOConfig
28 | from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
29 |
30 |
31 | @ray.remote(num_cpus=1)
32 | def main_task(config: PPOConfig):
33 | # please make sure main_task is not scheduled on head
34 | # print config
35 | config.deep_post_init()
36 | print(json.dumps(config.to_dict(), indent=2))
37 |
38 | # instantiate tokenizer
39 | tokenizer = get_tokenizer(
40 | config.worker.actor.model.model_path,
41 | trust_remote_code=config.worker.actor.model.trust_remote_code,
42 | use_fast=True,
43 | )
44 | processor = get_processor(
45 | config.worker.actor.model.model_path,
46 | trust_remote_code=config.worker.actor.model.trust_remote_code,
47 | use_fast=True,
48 | )
49 |
50 | # define worker classes
51 | ray_worker_group_cls = RayWorkerGroup
52 | role_worker_mapping = {
53 | Role.ActorRollout: ray.remote(FSDPWorker),
54 | Role.Critic: ray.remote(FSDPWorker),
55 | Role.RefPolicy: ray.remote(FSDPWorker),
56 | }
57 | global_pool_id = "global_pool"
58 | resource_pool_spec = {
59 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
60 | }
61 | mapping = {
62 | Role.ActorRollout: global_pool_id,
63 | Role.Critic: global_pool_id,
64 | Role.RefPolicy: global_pool_id,
65 | }
66 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
67 |
68 | reward_fn = CustomRewardManager(
69 | tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
70 | )
71 | val_reward_fn = CustomRewardManager(
72 | tokenizer=tokenizer, num_examine=1, compute_score="pure_math"
73 | )
74 |
75 | trainer = RayPPOTrainer(
76 | config=config,
77 | tokenizer=tokenizer,
78 | processor=processor,
79 | role_worker_mapping=role_worker_mapping,
80 | resource_pool_manager=resource_pool_manager,
81 | ray_worker_group_cls=ray_worker_group_cls,
82 | reward_fn=reward_fn,
83 | val_reward_fn=val_reward_fn,
84 | )
85 | trainer.init_workers()
86 | trainer.fit()
87 |
88 |
89 | def main():
90 | cli_args = OmegaConf.from_cli()
91 | file_config = OmegaConf.load(getattr(cli_args, "config"))
92 | cli_args.pop("config", None)
93 |
94 | default_config = OmegaConf.structured(PPOConfig())
95 | ppo_config = OmegaConf.merge(default_config, file_config, cli_args)
96 | ppo_config = OmegaConf.to_object(ppo_config)
97 |
98 | if not ray.is_initialized():
99 | # this is for local ray cluster
100 | ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
101 |
102 | ray.get(main_task.remote(ppo_config))
103 |
104 |
105 | if __name__ == "__main__":
106 | main()
107 |
--------------------------------------------------------------------------------
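
The `main()` entry point above merges three configuration sources with increasing priority: the structured `PPOConfig` defaults, the YAML file passed via `config=...`, and the remaining dot-notation CLI overrides (exactly what the training scripts do). The stand-alone sketch below demonstrates that precedence with a toy dataclass standing in for `PPOConfig`; the field names are illustrative only:

from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class Rollout:
    n: int = 5
    temperature: float = 1.0

default_cfg = OmegaConf.structured(Rollout())
file_cfg = OmegaConf.create({"n": 6})                   # stands in for training_scripts/config.yaml
cli_cfg = OmegaConf.from_dotlist(["temperature=0.9"])   # stands in for the dot-notation CLI arguments
merged = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
print(OmegaConf.to_yaml(merged))                        # n: 6, temperature: 0.9
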
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # outputs
174 | outputs/
175 | checkpoints/
176 | wandb/
177 |
178 | # data
179 | eval/data
180 | eval/results
--------------------------------------------------------------------------------
/verl/utils/checkpoint/checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import random
17 | import shutil
18 | import tempfile
19 | from abc import ABC, abstractmethod
20 | from typing import Union
21 |
22 | import numpy as np
23 | import torch
24 | import torch.distributed as dist
25 | from filelock import FileLock
26 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
27 | from transformers import PreTrainedTokenizer, ProcessorMixin
28 |
29 |
30 | class BaseCheckpointManager(ABC):
31 | """
32 | A checkpoint manager that saves and loads
33 | - model
34 | - optimizer
35 | - lr_scheduler
36 | - extra_states
37 | in a SPMD way.
38 |
39 | We save
40 | - sharded model states and optimizer states
41 | - full lr_scheduler states
42 | - huggingface tokenizer and config for ckpt merge
43 | """
44 |
45 | def __init__(
46 | self,
47 | model: FSDP,
48 | optimizer: torch.optim.Optimizer,
49 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
50 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
51 | ):
52 | self.previous_global_step = None
53 | self.previous_save_local_path = None
54 |
55 | self.model = model
56 | self.optimizer = optimizer
57 | self.lr_scheduler = lr_scheduler
58 | self.processing_class = processing_class
59 |
60 | assert isinstance(self.model, FSDP)
61 | self.rank = dist.get_rank()
62 | self.world_size = dist.get_world_size()
63 |
64 | @abstractmethod
65 | def load_checkpoint(self, *args, **kwargs):
66 | raise NotImplementedError
67 |
68 | @abstractmethod
69 | def save_checkpoint(self, *args, **kwargs):
70 | raise NotImplementedError
71 |
72 | def remove_previous_save_local_path(self):
73 | if not self.previous_save_local_path:
74 | return
75 |
76 | abs_path = os.path.abspath(self.previous_save_local_path)
77 | print(f"Checkpoint manager remove previous save local path: {abs_path}")
78 | if not os.path.exists(abs_path):
79 | return
80 |
81 | # remove previous local_path
82 | shutil.rmtree(abs_path, ignore_errors=True)
83 |
84 | @staticmethod
85 | def local_mkdir(path):
86 | if not os.path.isabs(path):
87 | working_dir = os.getcwd()
88 | path = os.path.join(working_dir, path)
89 |
90 |         # Use a hash of the path as the lock file name to avoid overly long file names
91 | lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock"
92 | lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
93 |
94 | try:
95 | with FileLock(lock_path, timeout=60): # Add timeout
96 | # make a new dir
97 | os.makedirs(path, exist_ok=True)
98 | except Exception as e:
99 | print(f"Warning: Failed to acquire lock for {path}: {e}")
100 | # Even if the lock is not acquired, try to create the directory
101 | os.makedirs(path, exist_ok=True)
102 |
103 | return path
104 |
105 | @staticmethod
106 | def get_rng_state():
107 | rng_state = {
108 | "cpu": torch.get_rng_state(),
109 | "cuda": torch.cuda.get_rng_state(),
110 | "numpy": np.random.get_state(),
111 | "random": random.getstate(),
112 | }
113 | return rng_state
114 |
115 | @staticmethod
116 | def load_rng_state(rng_state):
117 | torch.set_rng_state(rng_state["cpu"])
118 | torch.cuda.set_rng_state(rng_state["cuda"])
119 | np.random.set_state(rng_state["numpy"])
120 | random.setstate(rng_state["random"])
121 |
122 |
123 | def find_latest_ckpt_path(path, directory_format="global_step_{}"):
124 | if path is None:
125 | return None
126 |
127 | tracker_file = get_checkpoint_tracker_filename(path)
128 | if not os.path.exists(tracker_file):
129 |         print(f"Checkpoint tracker file does not exist: {tracker_file}")
130 | return None
131 |
132 | with open(tracker_file, "rb") as f:
133 | iteration = int(f.read().decode())
134 | ckpt_path = os.path.join(path, directory_format.format(iteration))
135 | if not os.path.exists(ckpt_path):
136 |         print(f"Checkpoint does not exist: {ckpt_path}")
137 | return None
138 |
139 |     print(f"Found checkpoint: {ckpt_path}")
140 | return ckpt_path
141 |
142 |
143 | def get_checkpoint_tracker_filename(root_path: str):
144 | """
145 |     The tracker file records the latest checkpoint during training so it can be restarted from.
146 | """
147 | return os.path.join(root_path, "latest_checkpointed_iteration.txt")
148 |
--------------------------------------------------------------------------------
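
The resume logic above relies on a plain-text tracker file. The toy walk-through below shows the on-disk convention that `find_latest_ckpt_path` expects (`latest_checkpointed_iteration.txt` holds the step number, and the checkpoint lives under `global_step_<step>/`); the temporary directory is created purely for illustration:

import os
import tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "global_step_40"))
with open(os.path.join(root, "latest_checkpointed_iteration.txt"), "w") as f:
    f.write("40")

with open(os.path.join(root, "latest_checkpointed_iteration.txt")) as f:
    iteration = int(f.read())
print(os.path.join(root, f"global_step_{iteration}"))  # the directory find_latest_ckpt_path would return
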
/verl/utils/image_aug.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import random
5 | from PIL import Image
6 | import torchvision.transforms as T
7 | from ..workers.actor.config import ActorConfig
8 |
9 | class ImageAugmenter:
10 | def __init__(self, config: ActorConfig=None):
11 | self.config = config or {}
12 | self.to_tensor = T.ToTensor()
13 | self.to_pil = T.ToPILImage()
14 |
15 | # Set default configurations
16 | self.aug_type = self.config.aug_type
17 | self.noise_step = self.config.gaussian_noise_step # Gaussian noise steps
18 | self.crop_size = self.config.crop_size # Random occlusion region ratio
19 | self.rotate_angle = self.config.rotate_angle # Maximum rotation angle
20 | self.decay_sig_mid_step = self.config.decay_sig_mid_step
21 |
22 | # Augmentation method mapping
23 | self.aug_methods = {
24 | 'gaussian': self.apply_gaussian_noise,
25 | 'crop_fill': self.apply_crop_fill,
26 | 'rotate': self.apply_rotation,
27 | }
28 |
29 | def augment(self, image, step=0, total_steps=1):
30 | # Handle decay
31 | decay = 1.0
32 | if hasattr(self.config, 'decay_mode') and hasattr(self.config, 'decay_coef'):
33 | decay_mode = self.config.decay_mode
34 | decay_coef = self.config.decay_coef
35 | norm_step = step / total_steps
36 |
37 | if decay_mode == 'exp':
38 | decay = 1.0 - decay_coef ** (total_steps - step)
39 | elif decay_mode == 'pow':
40 | decay = 1.0 - norm_step ** decay_coef
41 | elif decay_mode == 'linear':
42 | decay = 1.0 - norm_step
43 | elif decay_mode == 'sigmoid':
44 | x = decay_coef * (norm_step - self.decay_sig_mid_step / total_steps)
45 | decay = 1.0 - (1 / (1 + math.exp(-x)))
46 |
47 | aug_method = self.aug_methods.get(self.aug_type)
48 | if aug_method is None:
49 | return image
50 |
51 | return aug_method(image, decay)
52 |
53 | def apply_gaussian_noise(self, image, decay=1.0):
54 | """Apply gaussian noise to image"""
55 | image_tensor = self.to_tensor(image)
56 | noise_step = int(self.noise_step * decay)
57 | noisy_tensor = self._add_gaussian_noise(image_tensor, noise_step)
58 | noisy_tensor = torch.clamp(noisy_tensor, 0.0, 1.0)
59 | return self.to_pil(noisy_tensor)
60 |
61 | def _add_gaussian_noise(self, image_tensor, noise_step):
62 | """Implementation of gaussian noise"""
63 | num_steps = 1000
64 | betas = torch.linspace(-6, 6, num_steps)
65 | betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
66 | alphas = 1 - betas
67 | alphas_prod = torch.cumprod(alphas, dim=0)
68 | alphas_bar_sqrt = torch.sqrt(alphas_prod)
69 | one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
70 |
71 | def q_x(x_0, t):
72 | noise = torch.randn_like(x_0)
73 | alphas_t = alphas_bar_sqrt[t]
74 | alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
75 | return (alphas_t * x_0 + alphas_1_m_t * noise)
76 |
77 | return q_x(image_tensor, noise_step)
78 |
79 | def apply_rotation(self, image, decay=1.0):
80 | """Apply rotation to image"""
81 | max_angle = self.rotate_angle * decay
82 | angle = random.uniform(-max_angle, max_angle)
83 | return image.rotate(angle, resample=Image.BILINEAR, expand=True)
84 |
85 | def apply_crop_fill(self, image, decay=1.0):
86 | """Apply random region filling (occlusion)"""
87 | width, height = image.size
88 | crop_size_scaled = self.crop_size * decay
89 | crop_w = int(width * crop_size_scaled)
90 | crop_h = int(height * crop_size_scaled)
91 |
92 | if width > crop_w and height > crop_h:
93 | x = random.randint(0, width - crop_w)
94 | y = random.randint(0, height - crop_h)
95 | else:
96 | x, y = 0, 0
97 | crop_w = min(crop_w, width)
98 | crop_h = min(crop_h, height)
99 |
100 | img_np = np.array(image)
101 | img_np[y:y+crop_h, x:x+crop_w] = 0
102 |
103 | return Image.fromarray(img_np)
104 |
105 |
106 | def augment_images(images, config, step=0, total_steps=1):
107 | augmenter = ImageAugmenter(config)
108 | return [augmenter.augment(img, step, total_steps) for img in images]
109 |
110 |
111 | def augment_batch(batch, config, step=0, total_steps=1):
112 | if "multi_modal_data" not in batch.non_tensor_batch:
113 | return batch
114 |
115 | from copy import deepcopy
116 | new_batch = deepcopy(batch)
117 |
118 | for i, item in enumerate(new_batch.non_tensor_batch["multi_modal_data"]):
119 | if "image" in item:
120 | image_list = item["image"]
121 | augmented_images = augment_images(
122 | image_list,
123 | config,
124 | step,
125 | total_steps
126 | )
127 | new_batch.non_tensor_batch["multi_modal_data"][i]["image"] = augmented_images
128 |
129 | return new_batch
130 |
--------------------------------------------------------------------------------
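
A minimal usage sketch of the augmenter above, assuming the package is importable as `verl` and that torchvision/Pillow are installed (both are in requirements.txt). A `SimpleNamespace` stands in for `ActorConfig`, since `ImageAugmenter` only reads attribute-style fields; the values mirror the NoisyRollout training scripts:

from types import SimpleNamespace

from PIL import Image

from verl.utils.image_aug import augment_images

config = SimpleNamespace(
    aug_type="gaussian",
    gaussian_noise_step=500,
    crop_size=0.0,
    rotate_angle=15,
    decay_mode="sigmoid",
    decay_coef=30,
    decay_sig_mid_step=40,
)
images = [Image.new("RGB", (64, 64), color=(200, 200, 200))]     # synthetic placeholder image
noisy = augment_images(images, config, step=10, total_steps=80)  # step/total_steps drive the decay
print(noisy[0].size)  # (64, 64)
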
/verl/utils/fsdp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gc
16 | from collections import defaultdict
17 | from functools import partial
18 | from typing import Callable, Union
19 |
20 | import torch
21 | from torch import nn
22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
23 | from torch.distributed.fsdp._runtime_utils import _lazy_init
24 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
25 | from torch.optim import Optimizer
26 | from transformers import PreTrainedModel
27 | from transformers.trainer_pt_utils import get_module_class_from_name
28 |
29 |
30 | def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]:
31 | param_occurrence = defaultdict(int)
32 | for _, param in model.named_parameters(remove_duplicate=False):
33 | param_occurrence[param] += 1
34 |
35 | duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1}
36 | materialized_params = {}
37 |
38 | def init_fn(module: nn.Module):
39 | for name, param in module.named_parameters(recurse=False):
40 | if param in duplicated_params:
41 | module._parameters[name] = materialized_params.setdefault(
42 | param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad)
43 | )
44 | else:
45 | module._parameters[name] = nn.Parameter(
46 | torch.empty_like(param.data, device=device), requires_grad=param.requires_grad
47 | )
48 |
49 | return init_fn
50 |
51 |
52 | def get_fsdp_wrap_policy(model: PreTrainedModel):
53 | """Get FSDP wrap policy for the model.
54 |
55 | Args:
56 | module: The module to get wrap policy for
57 | """
58 | transformer_cls_to_wrap = set()
59 | for module in model._no_split_modules:
60 | transformer_cls = get_module_class_from_name(model, module)
61 | if transformer_cls is None:
62 | raise Exception(f"Cannot find {module} in pretrained model.")
63 | else:
64 | transformer_cls_to_wrap.add(transformer_cls)
65 |
66 | return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap)
67 |
68 |
69 | @torch.no_grad()
70 | def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
71 | # lazy init FSDP model
72 | _lazy_init(model, model)
73 | assert model._is_root, "Only support root model offloading to CPU"
74 | for handle in model._all_handles:
75 | if handle._offload_params:
76 | continue
77 |
78 | flat_param = handle.flat_param
79 | assert (
80 | flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
81 | and id(flat_param.data) != id(flat_param._local_shard)
82 | and flat_param.data.size() == flat_param._local_shard.size()
83 | )
84 | handle.flat_param_to("cpu", non_blocking=True)
85 | # the following still keeps id(._local_shard) != id(.data)
86 | flat_param._local_shard = flat_param.data
87 | assert id(flat_param._local_shard) != id(flat_param.data)
88 |
89 | if empty_cache:
90 | torch.cuda.empty_cache()
91 |
92 |
93 | @torch.no_grad()
94 | def load_fsdp_model(model: FSDP, empty_cache: bool = True):
95 | # lazy init FSDP model
96 | _lazy_init(model, model)
97 | assert model._is_root, "Only support root model loading to GPU"
98 | for handle in model._all_handles:
99 | if handle._offload_params:
100 | continue
101 |
102 | flat_param = handle.flat_param
103 | handle.flat_param_to("cuda", non_blocking=True)
104 | # the following still keeps id(._local_shard) != id(.data)
105 | flat_param._local_shard = flat_param.data
106 |
107 | if empty_cache:
108 | gc.collect()
109 |
110 |
111 | @torch.no_grad()
112 | def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
113 | if not optimizer.state:
114 | return
115 |
116 | for param_group in optimizer.param_groups:
117 | for param in param_group["params"]:
118 | state = optimizer.state[param]
119 | for key, value in state.items():
120 | if isinstance(value, torch.Tensor):
121 | state[key] = value.to("cpu", non_blocking=True)
122 |
123 | if empty_cache:
124 | torch.cuda.empty_cache()
125 |
126 |
127 | @torch.no_grad()
128 | def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
129 | if not optimizer.state:
130 | return
131 |
132 | for param_group in optimizer.param_groups:
133 | for param in param_group["params"]:
134 | state = optimizer.state[param]
135 | for key, value in state.items():
136 | if isinstance(value, torch.Tensor):
137 | state[key] = value.to("cuda", non_blocking=True)
138 |
139 | if empty_cache:
140 | gc.collect()
141 |
--------------------------------------------------------------------------------
/verl/utils/flops_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING, List, Tuple
16 |
17 | import torch
18 |
19 |
20 | if TYPE_CHECKING:
21 | from transformers.models.llama.configuration_llama import LlamaConfig
22 |
23 |
24 | VALID_MODEL_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
25 |
26 |
27 | def get_device_flops(unit: str = "T") -> float:
28 | def unit_convert(number: float, level: str):
29 | units = ["B", "K", "M", "G", "T", "P"]
30 | if number <= 0:
31 | return number
32 |
33 | ptr = 0
34 | while ptr < len(units) and units[ptr] != level:
35 | number /= 1000
36 | ptr += 1
37 |
38 | return number
39 |
40 | device_name = torch.cuda.get_device_name()
41 |     flops = float("inf") # INF flops for unknown GPU type
42 | if "H100" in device_name or "H800" in device_name:
43 | flops = 989e12
44 | elif "A100" in device_name or "A800" in device_name:
45 | flops = 312e12
46 | elif "L40" in device_name:
47 | flops = 181.05e12
48 | elif "L20" in device_name:
49 | flops = 119.5e12
50 | elif "H20" in device_name:
51 | flops = 148e12
52 | elif "910B" in device_name:
53 | flops = 354e12
54 | flops_unit = unit_convert(flops, unit)
55 | return flops_unit
56 |
57 |
58 | class FlopsCounter:
59 | """
60 |     Used to estimate MFU (model FLOPs utilization) during the training loop
61 |
62 | Example:
63 | flops_counter = FlopsCounter(config)
64 | flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
65 | """
66 |
67 | def __init__(self, config: "LlamaConfig"):
68 |         if config.model_type not in VALID_MODEL_TYPE:
69 |             print(f"Only support {VALID_MODEL_TYPE}, but got {config.model_type}. MFU will always be zero.")
70 |
71 | self.estimate_func = {
72 | "llama": self._estimate_llama_flops,
73 | "qwen2": self._estimate_llama_flops,
74 | "qwen2_vl": self._estimate_llama_flops,
75 | "qwen2_5_vl": self._estimate_llama_flops,
76 | }
77 | self.config = config
78 |
79 | def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
80 | return 0
81 |
82 | def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
83 | hidden_size = self.config.hidden_size
84 | vocab_size = self.config.vocab_size
85 | num_hidden_layers = self.config.num_hidden_layers
86 | num_key_value_heads = self.config.num_key_value_heads
87 | num_attention_heads = self.config.num_attention_heads
88 | intermediate_size = self.config.intermediate_size
89 |
90 | head_dim = hidden_size // num_attention_heads
91 | q_size = num_attention_heads * head_dim
92 | k_size = num_key_value_heads * head_dim
93 | v_size = num_key_value_heads * head_dim
94 |
95 |         # non-attention parameters per layer
96 |         # Qwen2/LLaMA use SwiGLU: gate, up and down linear layers in the MLP
97 | mlp_N = hidden_size * intermediate_size * 3
98 | attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
99 | emd_and_lm_head_N = vocab_size * hidden_size * 2
100 |         # non-attention parameters across all layers
101 | dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
102 | # non-attn all_layer & all_token fwd & bwd flops
103 | dense_N_flops = 6 * dense_N * tokens_sum
104 |
105 | # attn all_layer & all_token fwd & bwd flops
106 | seqlen_square_sum = 0
107 | for seqlen in batch_seqlens:
108 | seqlen_square_sum += seqlen * seqlen
109 |
110 | attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
111 |
112 | # all_layer & all_token fwd & bwd flops
113 | flops_all_token = dense_N_flops + attn_qkv_flops
114 | flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
115 | return flops_achieved
116 |
117 | def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
118 | """
119 | Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
120 |
121 | Args:
122 | batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch.
123 | delta_time (float): The time taken to process the batch, in seconds.
124 |
125 | Returns:
126 | estimated_flops (float): The estimated FLOPS based on the input tokens and time.
127 | promised_flops (float): The expected FLOPS of the current device.
128 | """
129 | tokens_sum = sum(batch_seqlens)
130 | func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
131 | estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
132 | promised_flops = get_device_flops()
133 | return estimated_flops, promised_flops
134 |
--------------------------------------------------------------------------------
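
`estimate_flops` returns an achieved figure (in TFLOP/s, see the `/ 1e12` above) alongside the per-device peak from `get_device_flops`. A toy MFU calculation from those two values might look like the sketch below; the numbers are made up, and how the peak is scaled across devices depends on the trainer's metrics code, which is not part of this file:

estimated_flops = 250.0  # hypothetical achieved TFLOP/s for the whole job
promised_flops = 989.0   # e.g. H800/H100 peak from get_device_flops()
n_gpus = 8

mfu = estimated_flops / (promised_flops * n_gpus)
print(f"MFU: {mfu:.2%}")  # ~3.16% with these made-up numbers
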
/verl/utils/tracking.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A unified tracking interface that supports logging data to different backends
16 | """
17 |
18 | import os
19 | from dataclasses import dataclass
20 | from typing import List, Tuple, Union
21 |
22 | from .logger.aggregate_logger import LocalLogger
23 |
24 |
25 | class Tracking:
26 | supported_backend = ["wandb", "mlflow", "swanlab", "console"]
27 |
28 | def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None):
29 | if isinstance(default_backend, str):
30 | default_backend = [default_backend]
31 |
32 | for backend in default_backend:
33 | assert backend in self.supported_backend, f"{backend} is not supported"
34 |
35 | self.logger = {}
36 |
37 | if "wandb" in default_backend:
38 | import wandb # type: ignore
39 |
40 | wandb.init(project=project_name, name=experiment_name, config=config)
41 | self.logger["wandb"] = wandb
42 |
43 | if "mlflow" in default_backend:
44 | import mlflow # type: ignore
45 |
46 | mlflow.start_run(run_name=experiment_name)
47 | mlflow.log_params(config)
48 | self.logger["mlflow"] = _MlflowLoggingAdapter()
49 |
50 | if "swanlab" in default_backend:
51 | import swanlab # type: ignore
52 |
53 | SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None)
54 | SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog")
55 | SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud")
56 | if SWANLAB_API_KEY:
57 | swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten
58 |
59 | swanlab.init(
60 | project=project_name,
61 | experiment_name=experiment_name,
62 | config=config,
63 | logdir=SWANLAB_LOG_DIR,
64 | mode=SWANLAB_MODE,
65 | )
66 | self.logger["swanlab"] = swanlab
67 |
68 | if "console" in default_backend:
69 | self.console_logger = LocalLogger()
70 | self.logger["console"] = self.console_logger
71 |
72 | def log(self, data, step, backend=None):
73 | for default_backend, logger_instance in self.logger.items():
74 | if backend is None or default_backend in backend:
75 | logger_instance.log(data=data, step=step)
76 |
77 | def __del__(self):
78 | if "wandb" in self.logger:
79 | self.logger["wandb"].finish(exit_code=0)
80 |
81 | if "swanlab" in self.logger:
82 | self.logger["swanlab"].finish()
83 |
84 |
85 | class _MlflowLoggingAdapter:
86 | def log(self, data, step):
87 | import mlflow # type: ignore
88 |
89 | mlflow.log_metrics(metrics=data, step=step)
90 |
91 |
92 | @dataclass
93 | class ValGenerationsLogger:
94 | def log(self, loggers: List[str], samples: List[Tuple[str, str, float]], step: int):
95 | if "wandb" in loggers:
96 | self.log_generations_to_wandb(samples, step)
97 | if "swanlab" in loggers:
98 | self.log_generations_to_swanlab(samples, step)
99 |
100 | def log_generations_to_wandb(self, samples: List[Tuple[str, str, float]], step: int) -> None:
101 | """Log samples to wandb as a table"""
102 | import wandb # type: ignore
103 |
104 | # Create column names for all samples
105 | columns = ["step"] + sum(
106 | [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
107 | )
108 |
109 | if not hasattr(self, "validation_table"):
110 | # Initialize the table on first call
111 | self.validation_table = wandb.Table(columns=columns)
112 |
113 | # Create a new table with same columns and existing data
114 | # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
115 | new_table = wandb.Table(columns=columns, data=self.validation_table.data)
116 |
117 | # Add new row with all data
118 | row_data = []
119 | row_data.append(step)
120 | for sample in samples:
121 | row_data.extend(sample)
122 |
123 | new_table.add_data(*row_data)
124 |
125 | # Update reference and log
126 | wandb.log({"val/generations": new_table}, step=step)
127 | self.validation_table = new_table
128 |
129 | def log_generations_to_swanlab(self, samples: List[Tuple[str, str, float]], step: int) -> None:
130 | """Log samples to swanlab as text"""
131 | import swanlab # type: ignore
132 |
133 | swanlab_text_list = []
134 | for i, sample in enumerate(samples):
135 | row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
136 | swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
137 |
138 | # Log to swanlab
139 | swanlab.log({"val/generations": swanlab_text_list}, step=step)
140 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/fsdp_vllm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from typing import Dict, Iterable, Tuple, Union
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from torch.distributed._tensor import DTensor
21 | from torch.distributed.device_mesh import DeviceMesh
22 | from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
23 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
24 | from vllm import LLM
25 | from vllm.distributed import parallel_state as vllm_ps
26 |
27 | from ...protocol import DataProto, all_gather_data_proto
28 | from ...utils.model_utils import print_gpu_memory_usage
29 | from .base import BaseShardingManager
30 |
31 |
32 | class FSDPVLLMShardingManager(BaseShardingManager):
33 | def __init__(
34 | self,
35 | module: FSDP,
36 | inference_engine: LLM,
37 | device_mesh: DeviceMesh = None,
38 | ):
39 | self.module = module
40 | self.inference_engine = inference_engine
41 | self.device_mesh = device_mesh
42 | with warnings.catch_warnings():
43 | warnings.simplefilter("ignore")
44 | FSDP.set_state_dict_type(
45 | self.module,
46 | state_dict_type=StateDictType.SHARDED_STATE_DICT,
47 | state_dict_config=ShardedStateDictConfig(),
48 | )
49 |
50 | self.world_size = dist.get_world_size()
51 | self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
52 | self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
53 | self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
54 |
55 | # Note that torch_random_states may be different on each dp rank
56 | self.torch_random_states = torch.cuda.get_rng_state()
57 |         # set up the rng states used for generation
58 | if self.device_mesh is not None:
59 | gen_dp_rank = self.device_mesh["dp"].get_local_rank()
60 | torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
61 | self.gen_random_states = torch.cuda.get_rng_state()
62 | torch.cuda.set_rng_state(self.torch_random_states)
63 | else:
64 | self.gen_random_states = None
65 |
66 | def _make_weight_iterator(
67 | self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
68 | ) -> Iterable[Tuple[str, torch.Tensor]]:
69 | for name, tensor in actor_weights.items():
70 | yield name, tensor.full_tensor() if self.world_size != 1 else tensor
71 |
72 | def __enter__(self):
73 | # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
74 | # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
75 |         # Outside the vllm scope, we should avoid calling empty_cache so that PyTorch can
76 |         # keep using its caching allocator to speed up memory allocations.
77 | #
78 | # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
79 | # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
80 | torch.cuda.empty_cache()
81 | print_gpu_memory_usage("Before state_dict() in sharding manager")
82 | actor_weights = self.module.state_dict()
83 | print_gpu_memory_usage("After state_dict() in sharding manager")
84 |
85 | self.inference_engine.wake_up()
86 | model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
87 | model.load_weights(self._make_weight_iterator(actor_weights))
88 | print_gpu_memory_usage("After sync model weights in sharding manager")
89 |
90 | del actor_weights
91 |         print_gpu_memory_usage("After del state_dict in sharding manager")
92 | # important: need to manually set the random states of each tp to be identical.
93 | if self.device_mesh is not None:
94 | self.torch_random_states = torch.cuda.get_rng_state()
95 | torch.cuda.set_rng_state(self.gen_random_states)
96 |
97 | def __exit__(self, exc_type, exc_value, traceback):
98 | print_gpu_memory_usage("Before vllm offload in sharding manager")
99 | self.inference_engine.sleep(level=1)
100 | print_gpu_memory_usage("After vllm offload in sharding manager")
101 |
102 | self.module.train()
103 | torch.cuda.empty_cache() # add empty cache after each compute
104 |
105 | # restore random states
106 | if self.device_mesh is not None:
107 | self.gen_random_states = torch.cuda.get_rng_state()
108 | torch.cuda.set_rng_state(self.torch_random_states)
109 |
110 | def preprocess_data(self, data: DataProto) -> DataProto:
111 | """All gather across tp group to make each rank has identical input."""
112 | all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
113 | return data
114 |
115 | def postprocess_data(self, data: DataProto) -> DataProto:
116 | """Get chunk data of this tp rank since we do all gather in preprocess."""
117 | if self.tp_size > 1:
118 | data = data.chunk(chunks=self.tp_size)[self.tp_rank]
119 |
120 | return data
121 |
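# Illustrative call pattern (the variable names below are hypothetical; the real call
# sites live in the rollout workers). Entering the manager syncs the FSDP weights into
# vllm, generation runs inside the `with` block, and exiting puts vllm back to sleep
# and restores the training RNG state:
#
#     sharding_manager = FSDPVLLMShardingManager(actor_module, vllm_engine, device_mesh)
#     with sharding_manager:
#         batch = sharding_manager.preprocess_data(batch)      # all-gather across the tp group
#         output = rollout.generate_sequences(batch)           # uses the freshly synced weights
#         output = sharding_manager.postprocess_data(output)   # keep only this tp rank's chunk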
--------------------------------------------------------------------------------
/eval/utils/data_loaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | from PIL import Image
5 | from tqdm import tqdm
6 | from typing import List, Dict
7 | from datasets import load_dataset
8 |
9 | def load_geo3k_dataset(data_path: str) -> List[Dict]:
10 | """Load Geo3K dataset"""
11 | data_path = os.path.join(data_path, "geometry3k/test")
12 | dataset = []
13 | folders = [f for f in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, f))]
14 |
15 | for folder in tqdm(folders, desc="Loading Geo3K data"):
16 | folder_path = os.path.join(data_path, folder)
17 | image_path = os.path.join(folder_path, "img_diagram.png")
18 | json_path = os.path.join(folder_path, "data.json")
19 |
20 | if not os.path.exists(image_path) or not os.path.exists(json_path):
21 | continue
22 |
23 | with open(json_path, "r", encoding="utf-8") as f:
24 | data = json.load(f)
25 |
26 | mapping = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
27 |
28 | dataset.append({
29 | "id": data["id"],
30 | "image_path": image_path,
31 | "question": data["annotat_text"],
32 | "answer": data["choices"][mapping[data["answer"]]],
33 | "dataset": "geo3k"
34 | })
35 |
36 | return dataset
37 |
38 | def load_wemath_dataset(data_path: str) -> List[Dict]:
39 | """Load WeMath dataset"""
40 | image_root = os.path.join(data_path, "wemath/images")
41 | data_path = os.path.join(data_path, "wemath/testmini.json")
42 | with open(data_path, "r", encoding="utf-8") as f:
43 | data = json.load(f)
44 |
45 | dataset = []
46 | for item in data:
47 | # Determine the image path
48 | image_path = os.path.join(image_root, item["image_path"])
49 |
50 | dataset.append({
51 | "id": item["ID"] + "@" + item["key"],
52 | "image_path": image_path,
53 | "question": f"{item['question']}\n\nOptions: {item['option']}",
54 | "answer": item["answer"],
55 | "dataset": "wemath"
56 | })
57 |
58 | return dataset
59 |
60 | def load_mathvista_dataset(data_path: str) -> List[Dict]:
61 | """Load MathVista dataset"""
62 | image_base_dir = os.path.join(data_path, "mathvista")
63 | dataset_raw = load_dataset("AI4Math/MathVista", split="testmini")
64 |
65 | dataset = []
66 | mapping = {
67 | "0": "A", "1": "B", "2": "C", "3": "D",
68 | "4": "E", "5": "F", "6": "G", "7": "H"
69 | }
70 |
71 | for item in dataset_raw:
72 | if item["question_type"] == "multi_choice":
73 | idx = item["choices"].index(item["answer"])
74 | answer = mapping[str(idx)]
75 | else:
76 | answer = item["answer"]
77 |
78 | dataset.append({
79 | "id": item.get("pid", ""),
80 | "image_path": os.path.join(image_base_dir, item["image"]),
81 | "question": item["query"],
82 | "answer": answer,
83 | "task": item["metadata"]["task"],
84 | "dataset": "mathvista"
85 | })
86 |
87 | return dataset
88 |
89 | def load_mathverse_dataset(data_path: str) -> List[Dict]:
90 | """Load MathVerse dataset"""
91 | image_base_dir = os.path.join(data_path, "mathverse/images")
92 | data_path = os.path.join(data_path, "mathverse/testmini.json")
93 |
94 | with open(data_path, "r", encoding="utf-8") as f:
95 | data = json.load(f)
96 |
97 | dataset = []
98 | for item in data:
99 | dataset.append({
100 | "id": item.get("sample_index", ""),
101 | "image_path": os.path.join(image_base_dir, item["image"]),
102 | "question": item["query_cot"],
103 | "question_for_eval": item["question_for_eval"],
104 | "answer": item["answer"],
105 | "problem_version": item["problem_version"],
106 | "dataset": "mathverse"
107 | })
108 |
109 | return dataset
110 |
111 | def load_mathvision_dataset(data_path: str) -> List[Dict]:
112 | """Load MathVision dataset"""
113 | image_base_dir = os.path.join(data_path, "mathvision/images")
114 | data_path = os.path.join(data_path, "mathvision/MathVision.tsv")
115 | df = pd.read_csv(data_path, sep='\t')
116 |
117 | dataset = []
118 | for _, row in df.iterrows():
119 | dataset.append({
120 | "id": row.get("index", ""),
121 | "image_path": os.path.join(image_base_dir, f"{row['index']}.jpg"),
122 | "question": row["question"],
123 | "answer": row["answer"],
124 | "subject": row.get("category", "unknown"),
125 | "dataset": "mathvision"
126 | })
127 |
128 | return dataset
129 |
130 | def load_hallubench_dataset(data_path: str) -> List[Dict]:
131 | """Load Hallubench dataset"""
132 | image_base_dir = os.path.join(data_path, "hallubench/images")
133 | data_path = os.path.join(data_path, "hallubench/HallusionBench.json")
134 |
135 | with open(data_path, "r", encoding="utf-8") as f:
136 | data = json.load(f)
137 |
138 | dataset = []
139 | for item in data:
140 | if not item["filename"]:
141 | continue
142 |
143 | if "?" in item["question"]:
144 | question = item["question"].split("?")[:-1][0]
145 | else:
146 | question = item["question"]
147 |         question += "? Your final answer can only be \\boxed{yes} or \\boxed{no}."
148 | gt_answer = "yes" if int(item["gt_answer"]) == 1 else "no"
149 | sid, fid, qid = item["set_id"], item["figure_id"], item["question_id"]
150 | dataset.append({
151 | "id": f"{sid}_{fid}_{qid}",
152 | "image_path": os.path.join(image_base_dir, item["filename"].replace("./", "")),
153 | "question": question,
154 | "question_for_eval": question,
155 | "answer": gt_answer,
156 | "problem_version": item["subcategory"],
157 | "dataset": "hallubench"
158 | })
159 |
160 | return dataset
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # NoisyRollout: Reinforcing Visual Reasoning with Data Augmentation
4 | [](https://arxiv.org/pdf/2504.13055) [](https://huggingface.co/collections/xyliu6/noisyrollout-67ff992d1cf251087fe021a2)
5 |
6 |
7 |
8 | ## ⚡ Updates
9 | * 18/09/2025: 🎊 NoisyRollout has been accepted to NeurIPS 2025. See you in San Diego!
10 | * 20/05/2025: 🔥 We updated the checkpoints of our trained models (larger model sizes, more training data)!
11 | * 18/04/2025: 🎉 We released our paper, models and codebase.
12 |
13 | ## 🚀 TL;DR
14 |
15 |
16 |
17 |
18 | **NoisyRollout** is a simple and effective data augmentation strategy for the RL training of VLMs that improves visual reasoning through better policy exploration. It introduces *targeted rollout diversity* by mixing rollouts from both clean and moderately distorted images, encouraging the model to learn more robust behaviors. Moreover, a *noise annealing schedule* is implemented to ensure early-stage exploration and late-stage training stability.
19 |
20 | 🎯 **Key Benefits**:
21 | - **No additional cost** — only the rollout strategy is modified
22 | - **Easy to adopt** — no changes to the model architecture or RL objective required
23 | - **Superior generalization** — achieves state-of-the-art results on **5** out-of-domain benchmarks (e.g., **MathVerse: 53.2%**, **HallusionBench: 72.1%**) with just **2.1K** RL samples
24 |
25 | 🫱 No complicated changes — just smarter rollouts and better training!
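
As a rough sketch of the rollout side, the snippet below distorts a copy of the input image with Gaussian pixel noise whose strength is annealed over training, and part of each prompt's rollouts are then sampled from the distorted copy. This is an illustration only: the function names, the linear schedule, and the noise parameters are placeholders rather than the exact settings used by our training scripts.

```python
import numpy as np
from PIL import Image


def anneal_noise_std(step: int, total_steps: int, std_max: float = 0.15) -> float:
    """Linearly decay the noise strength from std_max to 0 over training."""
    return std_max * max(0.0, 1.0 - step / total_steps)


def distort_image(image: Image.Image, std: float) -> Image.Image:
    """Return a copy of the clean image with zero-mean Gaussian pixel noise."""
    arr = np.asarray(image).astype(np.float32) / 255.0
    noisy = np.clip(arr + np.random.normal(0.0, std, size=arr.shape), 0.0, 1.0)
    return Image.fromarray((noisy * 255.0).astype(np.uint8))


# For each prompt, sample e.g. half of the rollouts from the clean image and half from
# distort_image(image, anneal_noise_std(step, total_steps)), then optimize all of them
# together with the usual GRPO objective.
```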
26 |
27 |
28 |
29 |
30 |
31 | ## 🛠️ Usage
32 | ### (Step 1) Install
33 | First, download the wheel of vllm from this [link](https://drive.google.com/file/d/1tO6BQ4omkeXTQhDBTAFi7U7qR8vF55wP/view?usp=sharing).
34 | ```bash
35 | conda create -n noisyrollout python=3.11 -y && conda activate noisyrollout
36 |
37 | pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 transformers==4.49.0 numpy==1.26.4
38 | pip3 install google-generativeai
39 |
40 | # Use this version of vLLM to avoid memory leaks.
41 | pip3 install vllm-0.7.4.dev65+g22757848-cp38-abi3-manylinux1_x86_64.whl
42 | git clone -b verl_v1 https://github.com/hiyouga/vllm.git
43 | cp -r vllm/vllm/ ~/miniconda3/envs/noisyrollout/lib/python3.11/site-packages/
44 |
45 | pip3 install -e .
46 | ```
47 |
48 | ### (Step 2) Training
49 | ```bash
50 | # Geo3K (NoisyRollout)
51 | bash training_scripts/qwen2_5_vl_7b_geo3k_noisyrollout.sh
52 | # Geo3K (Vanilla GRPO)
53 | bash training_scripts/qwen2_5_vl_7b_geo3k_grpo.sh
54 |
55 | # K12 (NoisyRollout)
56 | bash training_scripts/qwen2_5_vl_7b_k12_noisyrollout.sh
57 | # K12 (Vanilla GRPO)
58 | bash training_scripts/qwen2_5_vl_7b_k12_grpo.sh
59 | ```
60 | ### (Step 3) Evaluation
61 | Before running the evaluation, please download the evaluation datasets from [🤗 NoisyRollout Evaluation](https://huggingface.co/datasets/xyliu6/noisyrollout_evaluation_data). Then, create a directory by running `mkdir -p ~/NoisyRollout/eval/data`, upload the `eval_data.zip` file to the `data` folder, and unzip it there.
62 | ```bash
63 | #!/bin/bash
64 | source ~/.bashrc
65 | source ~/miniconda3/bin/activate noisyrollout
66 |
67 | export VLLM_ATTENTION_BACKEND=XFORMERS # remove it when using 32b models
68 | export VLLM_USE_V1=0 # remove it when using 32b models
69 | export GOOGLE_API_KEY="xxx" # put your api key here
70 |
71 | HF_MODEL_PATH="xyliu6/NoisyRollout-Geo3K-7B"
72 | RESULTS_DIR="results/"
73 | EVAL_DIR="~/NoisyRollout/eval"
74 | DATA_DIR="~/NoisyRollout/eval/data"
75 |
76 | SYSTEM_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
77 |
78 | cd $EVAL_DIR
79 | python main.py \
80 | --model $HF_MODEL_PATH \
81 | --output-dir $RESULTS_DIR \
82 | --data-path $DATA_DIR \
83 | --datasets geo3k,hallubench,mathvista,wemath,mathverse,mathvision \
84 | --tensor-parallel-size 2 \
85 | --system-prompt="$SYSTEM_PROMPT" \
86 | --min-pixels 262144 \
87 | --max-pixels 1000000 \
88 | --max-model-len 8192 \
89 | --temperature 0.0 \
90 | --eval-threads 24 \
91 | --version="7b" # change it to `32b` when using 32b models
92 | ```
93 | > 🚧 Currently, only `Gemini-2.0-Flash-001` is supported for parsing generated responses. Support for additional models will be introduced in future updates.
94 |
95 | ## Citation
96 | If you find our works useful for your research, please consider citing:
97 | ```bibtex
98 | @article{liu2025noisyrollout,
99 | title={Noisyrollout: Reinforcing visual reasoning with data augmentation},
100 | author={Liu, Xiangyan and Ni, Jinjie and Wu, Zijian and Du, Chao and Dou, Longxu and Wang, Haonan and Pang, Tianyu and Shieh, Michael Qizhe},
101 | journal={arXiv preprint arXiv:2504.13055},
102 | year={2025}
103 | }
104 | ```
105 |
106 | ## Acknowledgement
107 | * This work is supported by [Sea AI Lab](https://sail.sea.com/) with computing resources.
108 | * The training codes are built on [EasyR1](https://github.com/hiyouga/EasyR1), and the evaluation suite employs [vLLM](https://github.com/vllm-project/vllm) for acceleration.
109 | * The base models are from [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) and [Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct).
110 | * The original training datasets are from [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) and [K12](https://huggingface.co/datasets/FanqingM/MM-Eureka-Dataset).
111 | * The evaluation datasets are from [MathVerse](https://huggingface.co/datasets/AI4Math/MathVerse), [MathVision](https://huggingface.co/datasets/MathLLMs/MathVision), [MathVista](https://huggingface.co/datasets/AI4Math/MathVista), [WeMath](https://huggingface.co/datasets/We-Math/We-Math), and [HallusionBench](https://github.com/tianyi-lab/HallusionBench).
112 |
--------------------------------------------------------------------------------
/verl/trainer/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List
16 |
17 | import numpy as np
18 | import torch
19 |
20 | from ..protocol import DataProto
21 |
22 |
23 | def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
24 | response_length = batch.batch["responses"].shape[-1]
25 | prompt_mask = batch.batch["attention_mask"][:, :-response_length]
26 | response_mask = batch.batch["attention_mask"][:, -response_length:]
27 | prompt_length = prompt_mask.sum(-1).float()
28 | response_length = response_mask.sum(-1).float() # (batch_size,)
29 | return dict(
30 | response_mask=response_mask,
31 | prompt_length=prompt_length,
32 | response_length=response_length,
33 | )
34 |
35 |
36 | def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
37 | return {key: np.mean(value) for key, value in metrics.items()}
38 |
39 |
40 | def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
41 | sequence_score = batch.batch["token_level_scores"].sum(-1)
42 | sequence_reward = batch.batch["token_level_rewards"].sum(-1)
43 |
44 | advantages = batch.batch["advantages"]
45 | returns = batch.batch["returns"]
46 |
47 | max_response_length = batch.batch["responses"].size(-1)
48 |
49 | prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
50 | response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
51 |
52 | max_prompt_length = prompt_mask.size(-1)
53 |
54 | response_info = _compute_response_info(batch)
55 | prompt_length = response_info["prompt_length"]
56 | response_length = response_info["response_length"]
57 |
58 | valid_adv = torch.masked_select(advantages, response_mask)
59 | valid_returns = torch.masked_select(returns, response_mask)
60 |
61 | if use_critic:
62 | values = batch.batch["values"]
63 | valid_values = torch.masked_select(values, response_mask)
64 | return_diff_var = torch.var(valid_returns - valid_values)
65 | return_var = torch.var(valid_returns)
66 |
67 | metrics = {
68 | # score
69 | "critic/score/mean": torch.mean(sequence_score).detach().item(),
70 | "critic/score/max": torch.max(sequence_score).detach().item(),
71 | "critic/score/min": torch.min(sequence_score).detach().item(),
72 | # reward
73 | "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
74 | "critic/rewards/max": torch.max(sequence_reward).detach().item(),
75 | "critic/rewards/min": torch.min(sequence_reward).detach().item(),
76 | # adv
77 | "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
78 | "critic/advantages/max": torch.max(valid_adv).detach().item(),
79 | "critic/advantages/min": torch.min(valid_adv).detach().item(),
80 | # returns
81 | "critic/returns/mean": torch.mean(valid_returns).detach().item(),
82 | "critic/returns/max": torch.max(valid_returns).detach().item(),
83 | "critic/returns/min": torch.min(valid_returns).detach().item(),
84 | **(
85 | {
86 | # values
87 | "critic/values/mean": torch.mean(valid_values).detach().item(),
88 | "critic/values/max": torch.max(valid_values).detach().item(),
89 | "critic/values/min": torch.min(valid_values).detach().item(),
90 | # vf explained var
91 | "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
92 | }
93 | if use_critic
94 | else {}
95 | ),
96 | # response length
97 | "response_length/mean": torch.mean(response_length).detach().item(),
98 | "response_length/max": torch.max(response_length).detach().item(),
99 | "response_length/min": torch.min(response_length).detach().item(),
100 | "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
101 | .detach()
102 | .item(),
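        # clip_ratio: fraction of responses that reached max_response_length, i.e. were truncated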
103 | # prompt length
104 | "prompt_length/mean": torch.mean(prompt_length).detach().item(),
105 | "prompt_length/max": torch.max(prompt_length).detach().item(),
106 | "prompt_length/min": torch.min(prompt_length).detach().item(),
107 | "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
108 | }
109 | return metrics
110 |
111 |
112 | def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
113 | response_info = _compute_response_info(batch)
114 | num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
115 | num_response_tokens = torch.sum(response_info["response_length"]).item()
116 | num_overall_tokens = num_prompt_tokens + num_response_tokens
117 | num_tokens_of_section = {
118 | "gen": num_response_tokens,
119 | **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]},
120 | }
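    # "gen" is normalized by response tokens only, while the other stages process the
    # full prompt + response sequence, hence num_overall_tokens.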
121 | return {
122 | **{f"timing_s/{name}": value for name, value in timing_raw.items()},
123 | **{
124 | f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
125 | for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
126 | },
127 | }
128 |
129 |
130 | def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
131 | total_num_tokens = sum(batch.meta_info["global_token_num"])
132 | time = timing_raw["step"]
133 | return {
134 | "perf/total_num_tokens": total_num_tokens,
135 | "perf/time_per_step": time,
136 | "perf/throughput": total_num_tokens / (time * n_gpus),
137 | }
138 |
--------------------------------------------------------------------------------
/eval/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import torch
5 | from vllm import LLM, SamplingParams
6 | from utils.data_loaders import (
7 | load_geo3k_dataset,
8 | load_wemath_dataset,
9 | load_mathvista_dataset,
10 | load_mathverse_dataset,
11 | load_mathvision_dataset,
12 | load_hallubench_dataset
13 | )
14 | from utils.processing import (
15 | prepare_prompts,
16 | process_outputs,
17 | calculate_metrics
18 | )
19 |
20 | def parse_arguments():
21 | parser = argparse.ArgumentParser(description="Unified evaluation for multimodal math datasets")
22 |
23 | # Model and runtime parameters
24 | parser.add_argument("--model", type=str, required=True, help="Path to the model")
25 | parser.add_argument("--output-dir", type=str, required=True, help="Directory to save results")
26 | parser.add_argument("--max-tokens", type=int, default=2048, help="Maximum number of tokens to generate")
27 | parser.add_argument("--min-pixels", type=int, default=262144)
28 | parser.add_argument("--max-pixels", type=int, default=1000000)
29 | parser.add_argument("--max-model-len", type=int, default=8192)
30 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
31 | parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling")
32 | parser.add_argument("--repetition-penalty", type=float, default=1.0, help="Repetition penalty")
33 | parser.add_argument("--tensor-parallel-size", type=int, default=2, help="Number of GPUs for tensor parallelism")
34 | parser.add_argument("--eval-threads", type=int, default=32, help="Number of threads for evaluation")
35 |     parser.add_argument("--system-prompt", type=str, default="You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \\boxed{}.", help="System prompt for the model")
36 | parser.add_argument("--version", type=str, default="7b")
37 |
38 | # Dataset selection
39 |     parser.add_argument("--datasets", type=str, default="all", help="Comma-separated list of datasets to evaluate: geo3k,wemath,mathvista,mathverse,mathvision,hallubench or 'all'")
40 |
41 | # Dataset-specific paths
42 |     parser.add_argument("--data-path", type=str, default="NoisyRollout/eval/data", help="Root directory containing the downloaded evaluation datasets")
43 |
44 | return parser.parse_args()
45 |
46 | def main():
47 | args = parse_arguments()
48 |
49 | # Create output directory if it doesn't exist
50 | os.makedirs(args.output_dir, exist_ok=True)
51 |
52 | # Determine which datasets to evaluate
53 | datasets_to_eval = args.datasets.split(",") if args.datasets != "all" else [
54 | "geo3k", "wemath", "mathvista", "mathverse", "mathvision", "hallubench"
55 | ]
56 |
57 | # Dictionary to store all samples
58 | all_samples = {}
59 |
60 | # Load datasets based on selection
61 | for dataset_name in datasets_to_eval:
62 | if dataset_name == "geo3k":
63 | all_samples["geo3k"] = load_geo3k_dataset(args.data_path)
64 | print(f"Loaded {len(all_samples['geo3k'])} samples from Geo3K")
65 |
66 | elif dataset_name == "wemath":
67 | all_samples["wemath"] = load_wemath_dataset(args.data_path)
68 | print(f"Loaded {len(all_samples['wemath'])} samples from WeMath")
69 |
70 | elif dataset_name == "mathvista":
71 | all_samples["mathvista"] = load_mathvista_dataset(args.data_path)
72 | print(f"Loaded {len(all_samples['mathvista'])} samples from MathVista")
73 |
74 | elif dataset_name == "mathverse":
75 | all_samples["mathverse"] = load_mathverse_dataset(args.data_path)
76 | print(f"Loaded {len(all_samples['mathverse'])} samples from MathVerse")
77 |
78 | elif dataset_name == "mathvision":
79 | all_samples["mathvision"] = load_mathvision_dataset(args.data_path)
80 | print(f"Loaded {len(all_samples['mathvision'])} samples from MathVision")
81 |
82 | elif dataset_name == "hallubench":
83 | all_samples["hallubench"] = load_hallubench_dataset(args.data_path)
84 | print(f"Loaded {len(all_samples['hallubench'])} samples from HalluBench")
85 |
86 | if not all_samples:
87 | print("No datasets loaded. Please check the paths and dataset names.")
88 | return
89 |
90 | # Initialize model
91 | print(f"Initializing model from {args.model}")
92 | llm = LLM(
93 | model=args.model,
94 | tensor_parallel_size=args.tensor_parallel_size,
95 | dtype=torch.bfloat16,
96 | gpu_memory_utilization=0.7,
97 | max_model_len=args.max_model_len
98 | )
99 |
100 | # Configure sampling parameters
101 | sampling_params = SamplingParams(
102 | temperature=args.temperature,
103 | top_p=args.top_p,
104 | max_tokens=args.max_tokens,
105 | repetition_penalty=args.repetition_penalty,
106 | )
107 |
108 | # Process in batches
109 | all_results = {}
110 | for dataset_name in all_samples.keys():
111 | all_results[dataset_name] = []
112 |
113 | for dataset_name, samples in all_samples.items():
114 | prompts, metadata = prepare_prompts(dataset_name, samples, args)
115 |
116 | outputs = llm.generate(prompts, sampling_params)
117 |
118 | # Process outputs
119 | results = process_outputs(outputs, metadata, args.eval_threads)
120 | all_results[dataset_name] = results
121 |
122 | metrics = calculate_metrics(results)
123 |
124 | output_dict = {
125 | "results": results,
126 | "metrics": metrics,
127 | "config": vars(args)
128 | }
129 |
130 | output_path = os.path.join(args.output_dir, f"{dataset_name}.json")
131 | with open(output_path, 'w', encoding='utf-8') as f:
132 | json.dump(output_dict, f, ensure_ascii=False, indent=2)
133 |
134 | print(f"{dataset_name.upper()} Results:")
135 | print(f" Total samples: {len(results)}")
136 | print(f" Accuracy: {metrics['accuracy']:.4f}")
137 | if 'sub_accuracies' in metrics:
138 | print(" Task/Category Accuracies:")
139 | for task, acc in metrics['sub_accuracies'].items():
140 | print(f" {task}: {acc:.4f}")
141 | print()
142 |
143 | print(f"All results saved to {args.output_dir}")
144 |
145 | if __name__ == "__main__":
146 | main()
--------------------------------------------------------------------------------
/verl/single_controller/base/worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | the class for Worker
16 | """
17 |
18 | import os
19 | import socket
20 | from dataclasses import dataclass
21 |
22 | import ray
23 |
24 | from .decorator import Dispatch, Execute, register
25 | from .register_center.ray import create_worker_group_register_center
26 |
27 |
28 | @dataclass
29 | class DistRankInfo:
30 | tp_rank: int
31 | dp_rank: int
32 | pp_rank: int
33 |
34 |
35 | @dataclass
36 | class DistGlobalInfo:
37 | tp_size: int
38 | dp_size: int
39 | pp_size: int
40 |
41 |
42 | class WorkerHelper:
43 | def _get_node_ip(self):
44 | host_ipv4 = os.getenv("MY_HOST_IP", None)
45 | host_ipv6 = os.getenv("MY_HOST_IPV6", None)
46 | host_ip_by_env = host_ipv4 or host_ipv6
47 | host_ip_by_sdk = ray._private.services.get_node_ip_address()
48 |
49 | host_ip = host_ip_by_env or host_ip_by_sdk
50 | return host_ip
51 |
52 | def _get_free_port(self):
53 | with socket.socket() as sock:
54 | sock.bind(("", 0))
55 | return sock.getsockname()[1]
56 |
57 | def get_availale_master_addr_port(self):
58 | return self._get_node_ip(), str(self._get_free_port())
59 |
60 | def _get_pid(self):
61 | return
62 |
63 |
64 | class WorkerMeta:
65 | keys = [
66 | "WORLD_SIZE",
67 | "RANK",
68 | "LOCAL_WORLD_SIZE",
69 | "LOCAL_RANK",
70 | "MASTER_ADDR",
71 | "MASTER_PORT",
72 | "CUDA_VISIBLE_DEVICES",
73 | ]
74 |
75 | def __init__(self, store) -> None:
76 | self._store = store
77 |
78 | def to_dict(self):
79 | return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
80 |
81 |
82 | # we assume that in each WorkerGroup, there is a Master Worker
83 | class Worker(WorkerHelper):
84 | def __new__(cls, *args, **kwargs):
85 | instance = super().__new__(cls)
86 |
87 | # note that here we use int to distinguish
88 | disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0))
89 | if disable_worker_init:
90 | return instance
91 |
92 | rank = os.environ.get("RANK", None)
93 | worker_group_prefix = os.environ.get("WG_PREFIX", None)
94 |
95 |         # when the @ray.remote decorator is applied, __new__ is also called on the generated ActorClass, where we don't want to run _configure_before_init
96 | if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
97 | instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
98 |
99 | return instance
100 |
101 | def _configure_before_init(self, register_center_name: str, rank: int):
102 | assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
103 |
104 | if rank == 0:
105 | master_addr, master_port = self.get_availale_master_addr_port()
106 | rank_zero_info = {
107 | "MASTER_ADDR": master_addr,
108 | "MASTER_PORT": master_port,
109 | }
110 | self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
111 | os.environ.update(rank_zero_info)
112 |
113 | def __init__(self, cuda_visible_devices=None) -> None:
114 |         # construct the meta from environment variables. Note that the import must be inside the class because it is executed remotely
115 | world_size = int(os.environ["WORLD_SIZE"])
116 | rank = int(os.environ["RANK"])
117 | self._rank = rank
118 | self._world_size = world_size
119 |
120 | master_addr = os.environ["MASTER_ADDR"]
121 | master_port = os.environ["MASTER_PORT"]
122 |
123 | local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
124 | local_rank = int(os.getenv("LOCAL_RANK", "0"))
125 |
126 | store = {
127 | "_world_size": world_size,
128 | "_rank": rank,
129 | "_local_world_size": local_world_size,
130 | "_local_rank": local_rank,
131 | "_master_addr": master_addr,
132 | "_master_port": master_port,
133 | }
134 | if cuda_visible_devices is not None:
135 | store["_cuda_visible_devices"] = cuda_visible_devices
136 |
137 | meta = WorkerMeta(store=store)
138 | self._configure_with_meta(meta=meta)
139 |
140 | def _configure_with_meta(self, meta: WorkerMeta):
141 | """
142 | This function should only be called inside by WorkerGroup
143 | """
144 | assert isinstance(meta, WorkerMeta)
145 | self.__dict__.update(meta.to_dict()) # this is hacky
146 | # print(f"__dict__: {self.__dict__}")
147 | for key in WorkerMeta.keys:
148 | val = self.__dict__.get(f"_{key.lower()}", None)
149 | if val is not None:
150 | # print(f"set {key} to {val}")
151 | os.environ[key] = str(val)
152 |
153 | os.environ["REDIS_STORE_SERVER_HOST"] = (
154 | str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
155 | )
156 |
157 | def get_master_addr_port(self):
158 | return self._master_addr, self._master_port
159 |
160 | def get_cuda_visible_devices(self):
161 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
162 | return cuda_visible_devices
163 |
164 | def print_rank0(self, *args, **kwargs):
165 | if self.rank == 0:
166 | print(*args, **kwargs)
167 |
168 | @property
169 | def world_size(self):
170 | return self._world_size
171 |
172 | @property
173 | def rank(self):
174 | return self._rank
175 |
176 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
177 | def execute_with_func_generator(self, func, *args, **kwargs):
178 | ret_proto = func(self, *args, **kwargs)
179 | return ret_proto
180 |
181 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
182 | def execute_func_rank_zero(self, func, *args, **kwargs):
183 | result = func(*args, **kwargs)
184 | return result
185 |
--------------------------------------------------------------------------------
/verl/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import os
17 | from collections import defaultdict
18 | from io import BytesIO
19 | from typing import Any, Dict, List, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | from datasets import load_dataset
24 | from PIL import Image
25 | from PIL.Image import Image as ImageObject
26 | from torch.utils.data import Dataset
27 | from transformers import PreTrainedTokenizer, ProcessorMixin
28 |
29 | from ..models.transformers.qwen2_vl import get_rope_index
30 | from . import torch_functional as VF
31 |
32 |
33 | def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
34 | tensors = defaultdict(list)
35 | non_tensors = defaultdict(list)
36 | for feature in features:
37 | for key, value in feature.items():
38 | if isinstance(value, torch.Tensor):
39 | tensors[key].append(value)
40 | else:
41 | non_tensors[key].append(value)
42 |
43 | for key, value in tensors.items():
44 | tensors[key] = torch.stack(value, dim=0)
45 |
46 | for key, value in non_tensors.items():
47 | non_tensors[key] = np.array(value, dtype=object)
48 |
49 | return {**tensors, **non_tensors}
50 |
51 |
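# Worked example: a 2000x1500 image (3,000,000 px) with max_pixels=1,000,000 gets
# resize_factor = sqrt(1/3) ≈ 0.577, i.e. roughly 1154x866 ≈ 999k px, which stays
# within the budget while preserving the aspect ratio.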
52 | def process_image(image: Union[Dict[str, Any], ImageObject], max_pixels: int, min_pixels: int) -> ImageObject:
53 | if isinstance(image, dict):
54 | image = Image.open(BytesIO(image["bytes"]))
55 |
56 | if (image.width * image.height) > max_pixels:
57 | resize_factor = math.sqrt(max_pixels / (image.width * image.height))
58 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
59 | image = image.resize((width, height))
60 |
61 | if (image.width * image.height) < min_pixels:
62 | resize_factor = math.sqrt(min_pixels / (image.width * image.height))
63 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
64 | image = image.resize((width, height))
65 |
66 | if image.mode != "RGB":
67 | image = image.convert("RGB")
68 |
69 | return image
70 |
71 |
72 | class RLHFDataset(Dataset):
73 | """
74 |     We assume that the dataset contains a column with the prompts and other information
75 | """
76 |
77 | def __init__(
78 | self,
79 | data_path: str,
80 | tokenizer: PreTrainedTokenizer,
81 | processor: Optional[ProcessorMixin],
82 | prompt_key: str = "prompt",
83 | answer_key: str = "answer",
84 | image_key: str = "images",
85 | max_prompt_length: int = 1024,
86 | truncation: str = "error",
87 | system_prompt: str = None,
88 | max_pixels: int = None,
89 | min_pixels: int = None,
90 | ):
91 | self.tokenizer = tokenizer
92 | self.processor = processor
93 | self.prompt_key = prompt_key
94 | self.answer_key = answer_key
95 | self.image_key = image_key
96 | self.max_prompt_length = max_prompt_length
97 | self.truncation = truncation
98 | self.system_prompt = system_prompt
99 | self.max_pixels = max_pixels
100 | self.min_pixels = min_pixels
101 |
102 | if "@" in data_path:
103 | data_path, data_split = data_path.split("@")
104 | else:
105 | data_split = "train"
106 |
107 | if os.path.isdir(data_path):
108 | self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
109 | elif os.path.isfile(data_path):
110 | self.dataset = load_dataset("parquet", data_files=data_path, split="train")
111 | else: # remote dataset
112 | self.dataset = load_dataset(data_path, split=data_split)
113 |
114 | def __len__(self):
115 | return len(self.dataset)
116 |
117 | def __getitem__(self, index):
118 | row_dict: dict = self.dataset[index]
119 | messages = [{"role": "user", "content": row_dict[self.prompt_key]}]
120 | if self.system_prompt:
121 | messages.insert(0, {"role": "system", "content": self.system_prompt})
122 |
123 | prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
124 |
125 | if self.image_key in row_dict:
126 | prompt = prompt.replace("", "<|vision_start|><|image_pad|><|vision_end|>")
127 | row_dict["multi_modal_data"] = {
128 | "image": [
129 | process_image(image, self.max_pixels, self.min_pixels) for image in row_dict.pop(self.image_key)
130 | ]
131 | }
132 | model_inputs = self.processor(row_dict["multi_modal_data"]["image"], prompt, return_tensors="pt")
133 | input_ids = model_inputs.pop("input_ids")[0]
134 | attention_mask = model_inputs.pop("attention_mask")[0]
135 | row_dict["multi_modal_inputs"] = dict(model_inputs)
136 | position_ids = get_rope_index(
137 | self.processor,
138 | input_ids=input_ids,
139 | image_grid_thw=model_inputs["image_grid_thw"],
140 | attention_mask=attention_mask,
141 | ) # (3, seq_length)
142 | else:
143 | model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
144 | input_ids = model_inputs.pop("input_ids")[0]
145 | attention_mask = model_inputs.pop("attention_mask")[0]
146 | position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,)
147 |
148 | input_ids, attention_mask, position_ids = VF.postprocess_data(
149 | input_ids=input_ids,
150 | attention_mask=attention_mask,
151 | position_ids=position_ids,
152 | max_length=self.max_prompt_length,
153 | pad_token_id=self.tokenizer.pad_token_id,
154 | left_pad=True,
155 | truncation=self.truncation,
156 | )
157 | row_dict["input_ids"] = input_ids
158 | row_dict["attention_mask"] = attention_mask
159 | row_dict["position_ids"] = position_ids
160 | row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
161 | row_dict["ground_truth"] = row_dict.pop(self.answer_key)
162 | row_dict["processed_prompt"] = prompt
163 | return row_dict
164 |
--------------------------------------------------------------------------------
/scripts/model_merger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import re
18 | from concurrent.futures import ThreadPoolExecutor
19 | from typing import Dict, List, Tuple
20 |
21 | import torch
22 | from torch.distributed._tensor import DTensor, Placement, Shard
23 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
24 |
25 |
26 | def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
27 | if placement.is_replicate():
28 | return tensors[0]
29 | elif placement.is_partial():
30 | raise NotImplementedError("Partial placement is not supported yet")
31 | elif placement.is_shard():
32 | return torch.cat(tensors, dim=placement.dim).contiguous()
33 | else:
34 | raise ValueError(f"Unsupported placement: {placement}")
35 |
36 |
37 | if __name__ == "__main__":
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
40 | parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
41 | args = parser.parse_args()
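    # Example invocation (the path below is hypothetical): point --local_dir at a directory
    # written by FSDPCheckpointManager.save_checkpoint, i.e. one containing
    # model_world_size_*_rank_*.pt shards and a huggingface/ subfolder:
    #   python scripts/model_merger.py --local_dir /path/to/checkpoints/global_step_xx/actor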
42 |
43 | assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
44 | local_dir = args.local_dir
45 |
46 | # copy rank zero to find the shape of (dp, fsdp)
47 | rank = 0
48 | world_size = 0
49 | for filename in os.listdir(local_dir):
50 | match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
51 | if match:
52 | world_size = match.group(1)
53 | break
54 | assert world_size, "No model file with the proper format"
55 |
56 | state_dict = torch.load(
57 | os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
58 | )
59 | pivot_key = sorted(state_dict.keys())[0]
60 | weight = state_dict[pivot_key]
61 | assert isinstance(weight, torch.distributed._tensor.DTensor)
62 | # get sharding info
63 | device_mesh = weight.device_mesh
64 | mesh = device_mesh.mesh
65 | mesh_dim_names = device_mesh.mesh_dim_names
66 |
67 | print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
68 |
69 | assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}"
70 |
71 | if "tp" in mesh_dim_names:
72 | # fsdp * tp
73 | total_shards = mesh.shape[-1] * mesh.shape[-2]
74 | mesh_shape = (mesh.shape[-2], mesh.shape[-1])
75 | else:
76 | # fsdp
77 | total_shards = mesh.shape[-1]
78 | mesh_shape = (mesh.shape[-1],)
79 |
80 |     print(f"Processing {total_shards} model shards with mesh shape {mesh_shape}")
81 |
82 | model_state_dict_lst = []
83 | model_state_dict_lst.append(state_dict)
84 | model_state_dict_lst.extend([""] * (total_shards - 1))
85 |
86 | def process_one_shard(rank):
87 | model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
88 | state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
89 | model_state_dict_lst[rank] = state_dict
90 | return state_dict
91 |
92 | with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
93 | for rank in range(1, total_shards):
94 | executor.submit(process_one_shard, rank)
95 | state_dict = {}
96 | param_placements: Dict[str, List[Placement]] = {}
97 | keys = set(model_state_dict_lst[0].keys())
98 | for key in keys:
99 | state_dict[key] = []
100 | for model_state_dict in model_state_dict_lst:
101 | try:
102 | tensor = model_state_dict.pop(key)
103 | except Exception:
104 | print("-" * 30)
105 | print(model_state_dict)
106 | if isinstance(tensor, DTensor):
107 | state_dict[key].append(tensor._local_tensor.bfloat16())
108 | placements = tuple(tensor.placements)
109 | # replicated placement at dp dimension can be discarded
110 | if mesh_dim_names[0] == "dp":
111 | placements = placements[1:]
112 | if key not in param_placements:
113 | param_placements[key] = placements
114 | else:
115 | assert param_placements[key] == placements
116 | else:
117 | state_dict[key] = tensor.bfloat16()
118 |
119 | del model_state_dict_lst
120 |
121 | for key in sorted(state_dict):
122 | if not isinstance(state_dict[key], list):
123 | print(f"No need to merge key {key}")
124 | continue
125 | # merge shards
126 | placements: Tuple[Shard] = param_placements[key]
127 | if len(mesh_shape) == 1:
128 | # 1-D list, FSDP without TP
129 | assert len(placements) == 1
130 | shards = state_dict[key]
131 | state_dict[key] = merge_by_placement(shards, placements[0])
132 | else:
133 | # 2-D list, FSDP + TP
134 | raise NotImplementedError("FSDP + TP is not supported yet")
135 |
136 | print("Writing to local disk")
137 | hf_path = os.path.join(local_dir, "huggingface")
138 | config = AutoConfig.from_pretrained(hf_path)
139 |
140 | if "ForTokenClassification" in config.architectures[0]:
141 | auto_model = AutoModelForTokenClassification
142 | elif "ForCausalLM" in config.architectures[0]:
143 | auto_model = AutoModelForCausalLM
144 | elif "ForConditionalGeneration" in config.architectures[0]:
145 | auto_model = AutoModelForVision2Seq
146 | else:
147 | raise NotImplementedError(f"Unknown architecture {config.architectures}")
148 |
149 | with torch.device("meta"):
150 | model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
151 |
152 | model.to_empty(device="cpu")
153 |
154 | print(f"Saving model to {hf_path}")
155 | model.save_pretrained(hf_path, state_dict=state_dict)
156 | del state_dict
157 | del model
158 | if args.hf_upload_path:
159 | # Push to hugging face
160 | from huggingface_hub import HfApi
161 |
162 | api = HfApi()
163 | api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
164 | api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
165 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/fsdp_checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import warnings
17 | from typing import Union
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
22 | from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
23 | from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
24 |
25 | from .checkpoint_manager import BaseCheckpointManager
26 |
27 |
28 | class FSDPCheckpointManager(BaseCheckpointManager):
29 | """
30 | A checkpoint manager that saves and loads
31 | - model
32 | - optimizer
33 | - lr_scheduler
34 | - extra_states
35 | in a SPMD way.
36 |
37 | We save
38 | - sharded model states and optimizer states
39 | - full lr_scheduler states
40 | - huggingface tokenizer and config for ckpt merge
41 | """
42 |
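    # A checkpoint directory written by save_checkpoint below therefore contains
    # (a world size of 8 is used purely as an example):
    #   model_world_size_8_rank_{0..7}.pt        - sharded FSDP model states
    #   optim_world_size_8_rank_{0..7}.pt        - sharded optimizer states
    #   extra_state_world_size_8_rank_{0..7}.pt  - lr_scheduler and rng states
    #   huggingface/                             - config, generation_config, tokenizer/processor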
43 | def __init__(
44 | self,
45 | model: FSDP,
46 | optimizer: torch.optim.Optimizer,
47 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
48 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
49 | ):
50 | super().__init__(model, optimizer, lr_scheduler, processing_class)
51 |
52 | def load_checkpoint(self, path: str = None, remove_ckpt_after_load: bool = False):
53 | if path is None:
54 | return
55 |
56 | # every rank download its own checkpoint
57 | local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
58 | local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
59 | local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
60 | print(
61 | f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}"
62 | )
63 | model_state_dict = torch.load(local_model_path, weights_only=False)
64 | optimizer_state_dict = torch.load(local_optim_path, weights_only=False)
65 | extra_state_dict = torch.load(local_extra_state_path, weights_only=False)
66 |
67 | if remove_ckpt_after_load:
68 | try:
69 | os.remove(local_model_path)
70 | os.remove(local_optim_path)
71 | os.remove(local_extra_state_path)
72 | except Exception as e:
73 | print(f"[rank-{self.rank}]: remove ckpt file after loading failed, exception {e} will be ignored.")
74 |
75 | lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
76 | state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
77 | optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
78 | with warnings.catch_warnings():
79 | warnings.simplefilter("ignore")
80 | with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
81 | self.model.load_state_dict(model_state_dict)
82 | if self.optimizer is not None:
83 | self.optimizer.load_state_dict(optimizer_state_dict)
84 |
85 | # recover random state
86 | if "rng" in extra_state_dict:
87 | self.load_rng_state(extra_state_dict["rng"])
88 |
89 | if self.lr_scheduler is not None:
90 | self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
91 |
92 | def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt: bool = False):
93 | # record the previous global step
94 | self.previous_global_step = global_step
95 |
96 | # remove previous local_path
97 | if remove_previous_ckpt:
98 | self.remove_previous_save_local_path()
99 |
100 | local_path = self.local_mkdir(local_path)
101 | dist.barrier()
102 |
103 | # every rank will save its own model and optim shard
104 | state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
105 | optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
106 | with warnings.catch_warnings():
107 | warnings.simplefilter("ignore")
108 | with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
109 | model_state_dict = self.model.state_dict()
110 | if self.optimizer is not None:
111 | optimizer_state_dict = self.optimizer.state_dict()
112 | else:
113 | optimizer_state_dict = None
114 |
115 | if self.lr_scheduler is not None:
116 | lr_scheduler_state_dict = self.lr_scheduler.state_dict()
117 | else:
118 | lr_scheduler_state_dict = None
119 |
120 | extra_state_dict = {
121 | "lr_scheduler": lr_scheduler_state_dict,
122 | "rng": self.get_rng_state(),
123 | }
124 | model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
125 | optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
126 | extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
127 |
128 | print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
129 |         print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
130 | print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
131 | torch.save(model_state_dict, model_path)
132 | if self.optimizer is not None:
133 | torch.save(optimizer_state_dict, optim_path)
134 |
135 | torch.save(extra_state_dict, extra_path)
136 |
137 | # wait for everyone to dump to local
138 | dist.barrier()
139 |
140 | if self.rank == 0:
141 | hf_local_path = os.path.join(local_path, "huggingface")
142 | os.makedirs(hf_local_path, exist_ok=True)
143 | assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
144 | self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
145 | self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_local_path)
146 | self.processing_class.save_pretrained(hf_local_path)
147 |
148 | dist.barrier()
149 | self.previous_save_local_path = local_path
150 |
--------------------------------------------------------------------------------
/verl/single_controller/base/decorator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from enum import Enum, auto
16 | from functools import wraps
17 | from types import FunctionType
18 | from typing import TYPE_CHECKING, List
19 |
20 | import ray
21 |
22 | from ...protocol import DataProto, DataProtoFuture
23 |
24 |
25 | if TYPE_CHECKING:
26 | from .worker_group import WorkerGroup
27 |
28 |
29 | # here we add a magic number to avoid colliding with an attribute that a user-defined function may already have
30 | MAGIC_ATTR = "attrs_3141562937"
31 |
32 |
33 | class Dispatch(Enum):
34 | RANK_ZERO = auto()
35 | ONE_TO_ALL = auto()
36 | ALL_TO_ALL = auto()
37 | DP_COMPUTE = auto()
38 | DP_COMPUTE_PROTO = auto()
39 | DP_COMPUTE_PROTO_WITH_FUNC = auto()
40 | DP_COMPUTE_METRIC = auto()
41 |
42 |
43 | class Execute(Enum):
44 | ALL = 0
45 | RANK_ZERO = 1
46 |
47 |
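# Splits every DataProto / DataProtoFuture argument into `chunks` pieces. For example,
# with a world size of 8, a DataProto holding 64 samples becomes 8 chunks of 8 samples,
# one chunk per worker.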
48 | def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
49 | splitted_args = []
50 | for arg in args:
51 | assert isinstance(arg, (DataProto, DataProtoFuture))
52 | splitted_args.append(arg.chunk(chunks=chunks))
53 |
54 | splitted_kwargs = {}
55 | for key, val in kwargs.items():
56 | assert isinstance(val, (DataProto, DataProtoFuture))
57 | splitted_kwargs[key] = val.chunk(chunks=chunks)
58 |
59 | return splitted_args, splitted_kwargs
60 |
61 |
62 | def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
63 | args = tuple([arg] * worker_group.world_size for arg in args)
64 | kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
65 | return args, kwargs
66 |
67 |
68 | def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
69 | return args, kwargs
70 |
71 |
72 | def collect_all_to_all(worker_group: "WorkerGroup", output):
73 | return output
74 |
75 |
76 | def _concat_data_proto_or_future(output: List):
77 | # make sure all the elements in output have the same type
78 | for o in output:
79 | assert type(o) is type(output[0])
80 |
81 | o = output[0]
82 |
83 | if isinstance(o, DataProto):
84 | return DataProto.concat(output)
85 | elif isinstance(o, ray.ObjectRef):
86 | return DataProtoFuture.concat(output)
87 | else:
88 | raise NotImplementedError
89 |
90 |
91 | def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
92 | for value in args:
93 | assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
94 |
95 | for value in kwargs.values():
96 | assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
97 |
98 | return args, kwargs
99 |
100 |
101 | def collect_dp_compute(worker_group: "WorkerGroup", output):
102 | assert len(output) == worker_group.world_size
103 | return output
104 |
105 |
106 | def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
107 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
108 | return splitted_args, splitted_kwargs
109 |
110 |
111 | def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
112 | assert type(args[0]) is FunctionType  # NOTE: the first arg must be a function!
113 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
114 | splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
115 | return splitted_args_with_func, splitted_kwargs
116 |
117 |
118 | def collect_dp_compute_data_proto(worker_group: "WorkerGroup", output):
119 | for o in output:
120 | assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}."
121 |
122 | output = collect_dp_compute(worker_group, output)
123 | return _concat_data_proto_or_future(output)
124 |
125 |
126 | def get_predefined_dispatch_fn(dispatch_mode):
127 | predefined_dispatch_mode_fn = {
128 | Dispatch.ONE_TO_ALL: {
129 | "dispatch_fn": dispatch_one_to_all,
130 | "collect_fn": collect_all_to_all,
131 | },
132 | Dispatch.ALL_TO_ALL: {
133 | "dispatch_fn": dispatch_all_to_all,
134 | "collect_fn": collect_all_to_all,
135 | },
136 | Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute},
137 | Dispatch.DP_COMPUTE_PROTO: {
138 | "dispatch_fn": dispatch_dp_compute_data_proto,
139 | "collect_fn": collect_dp_compute_data_proto,
140 | },
141 | Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
142 | "dispatch_fn": dispatch_dp_compute_data_proto_with_func,
143 | "collect_fn": collect_dp_compute_data_proto,
144 | },
145 | Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute},
146 | }
147 | return predefined_dispatch_mode_fn[dispatch_mode]
148 |
149 |
150 | def get_predefined_execute_fn(execute_mode):
151 | """
152 | Note that here we only ask execute_all and execute_rank_zero to be implemented.
153 | The choice of how these two functions handle the 'blocking' argument is left to users.
154 | """
155 | predefined_execute_mode_fn = {
156 | Execute.ALL: {"execute_fn_name": "execute_all"},
157 | Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
158 | }
159 | return predefined_execute_mode_fn[execute_mode]
160 |
161 |
162 | def _check_dispatch_mode(dispatch_mode):
163 | assert isinstance(dispatch_mode, (Dispatch, dict)), (
164 | f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
165 | )
166 | if isinstance(dispatch_mode, dict):
167 | necessary_keys = ["dispatch_fn", "collect_fn"]
168 | for key in necessary_keys:
169 | assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
170 |
171 |
172 | def _check_execute_mode(execute_mode):
173 | assert isinstance(execute_mode, Execute), f"execute_mode must be an Execute. Got {execute_mode}"
174 |
175 |
176 | def _materialize_futures(*args, **kwargs):
177 | new_args = []
178 | for arg in args:
179 | if isinstance(arg, DataProtoFuture):
180 | arg = arg.get()
181 | # add more types to materialize here if needed
182 | new_args.append(arg)
183 | for k, v in kwargs.items():
184 | if isinstance(v, DataProtoFuture):
185 | kwargs[k] = v.get()
186 |
187 | new_args = tuple(new_args)
188 | return new_args, kwargs
189 |
190 |
191 | def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
192 | _check_dispatch_mode(dispatch_mode=dispatch_mode)
193 | _check_execute_mode(execute_mode=execute_mode)
194 |
195 | def decorator(func):
196 | @wraps(func)
197 | def inner(*args, **kwargs):
198 | if materialize_futures:
199 | args, kwargs = _materialize_futures(*args, **kwargs)
200 | return func(*args, **kwargs)
201 |
202 | attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
203 | setattr(inner, MAGIC_ATTR, attrs)
204 | return inner
205 |
206 | return decorator
207 |
--------------------------------------------------------------------------------
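A minimal usage sketch of the register decorator defined above. The worker class and method names below are hypothetical and not part of this repo; they only illustrate how dispatch_mode/execute_mode attach dispatch and collect behavior to a worker method.

from verl.single_controller.base.decorator import Dispatch, Execute, register

class MyActorWorker:  # hypothetical worker that a WorkerGroup would wrap
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, execute_mode=Execute.ALL, blocking=True)
    def update_policy(self, data):
        # data arrives as the DataProto chunk dispatched to this rank;
        # per-rank outputs are concatenated by collect_dp_compute_data_proto.
        return data
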
/verl/single_controller/base/worker_group.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The WorkerGroup class and resource pool utilities.
16 | """
17 |
18 | import logging
19 | import signal
20 | import threading
21 | import time
22 | from typing import Any, Callable, Dict, List, Optional
23 |
24 | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
25 |
26 |
27 | class ResourcePool:
28 | """The resource pool with meta info such as world_size."""
29 |
30 | def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
31 | if process_on_nodes is None:
32 | process_on_nodes = []
33 |
34 | self._store = process_on_nodes
35 | self.max_collocate_count = max_collocate_count
36 | self.n_gpus_per_node = n_gpus_per_node  # left for future Huawei hardware that has 16 devices per node
37 |
38 | def add_node(self, process_count):
39 | self._store.append(process_count)
40 |
41 | @property
42 | def world_size(self):
43 | return sum(self._store)
44 |
45 | def __call__(self) -> Any:
46 | return self._store
47 |
48 | @property
49 | def store(self):
50 | return self._store
51 |
52 | def local_world_size_list(self) -> List[int]:
53 | nested_local_world_size_list = [
54 | [local_world_size for _ in range(local_world_size)] for local_world_size in self._store
55 | ]
56 | return [item for row in nested_local_world_size_list for item in row]
57 |
58 | def local_rank_list(self) -> List[int]:
59 | nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] # noqa: C416
60 | return [item for row in nested_local_rank_list for item in row]
61 |
62 |
63 | class ClassWithInitArgs:
64 | """
65 | This class stores a class constructor and the args/kwargs to construct the class.
66 | It is used to instantiate the remote class.
67 | """
68 |
69 | def __init__(self, cls, *args, **kwargs) -> None:
70 | self.cls = cls
71 | self.args = args
72 | self.kwargs = kwargs
73 |
74 | def __call__(self) -> Any:
75 | return self.cls(*self.args, **self.kwargs)
76 |
77 |
78 | def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
79 | import time
80 |
81 | while True:
82 | for worker in workers:
83 | if not is_alive(worker):
84 | logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
85 | signal.raise_signal(signal.SIGABRT)
86 |
87 | time.sleep(gap_time)
88 |
89 |
90 | class WorkerGroup:
91 | """A group of workers"""
92 |
93 | def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
94 | self._is_init_with_detached_workers = resource_pool is None
95 |
96 | if resource_pool is not None:
97 | # handle the case when WorkerGroup is attached to an existing one
98 | self._procecss_dispatch_config = resource_pool()
99 | else:
100 | self._procecss_dispatch_config = None
101 |
102 | self._workers = []
103 | self._worker_names = []
104 |
105 | self._master_addr = None
106 | self._master_port = None
107 |
108 | self._checker_thread: Optional[threading.Thread] = None
109 |
110 | def _is_worker_alive(self, worker):
111 | raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")
112 |
113 | def _block_until_all_workers_alive(self) -> None:
114 | while True:
115 | all_state = [self._is_worker_alive(worker) for worker in self._workers]
116 | if False in all_state:
117 | time.sleep(1)
118 | else:
119 | break
120 |
121 | def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
122 | # before starting checking worker aliveness, make sure all workers are already alive
123 | self._block_until_all_workers_alive()
124 |
125 | self._checker_thread = threading.Thread(
126 | target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)
127 | )
128 | self._checker_thread.start()
129 |
130 | @property
131 | def world_size(self):
132 | return len(self._workers)
133 |
134 | def _bind_worker_method(self, user_defined_cls, func_generator):
135 | """
136 | Bind the worker method to the WorkerGroup
137 | """
138 | for method_name in dir(user_defined_cls):
139 | try:
140 | method = getattr(user_defined_cls, method_name)
141 | assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
142 | except Exception:
143 | # if it is a property, getattr will fail because the class has no instance properties
144 | continue
145 |
146 | if hasattr(method, MAGIC_ATTR):
147 | # this method is decorated by register
148 | attribute = getattr(method, MAGIC_ATTR)
149 | assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
150 | assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
151 |
152 | dispatch_mode = attribute["dispatch_mode"]
153 | execute_mode = attribute["execute_mode"]
154 | blocking = attribute["blocking"]
155 |
156 | # get dispatch fn
157 | if isinstance(dispatch_mode, Dispatch):
158 | # get default dispatch fn
159 | fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
160 | dispatch_fn = fn["dispatch_fn"]
161 | collect_fn = fn["collect_fn"]
162 | else:
163 | assert isinstance(dispatch_mode, dict)
164 | assert "dispatch_fn" in dispatch_mode
165 | assert "collect_fn" in dispatch_mode
166 | dispatch_fn = dispatch_mode["dispatch_fn"]
167 | collect_fn = dispatch_mode["collect_fn"]
168 |
169 | # get execute_fn_name
170 | execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
171 | wg_execute_fn_name = execute_mode["execute_fn_name"]
172 |
173 | # get execute_fn from string
174 | try:
175 | execute_fn = getattr(self, wg_execute_fn_name)
176 | assert callable(execute_fn), "execute_fn must be callable"
177 | except Exception:
178 | print(f"execute_fn {wg_execute_fn_name} is invalid")
179 | raise
180 |
181 | # bind a new method to the RayWorkerGroup
182 | func = func_generator(
183 | self,
184 | method_name,
185 | dispatch_fn=dispatch_fn,
186 | collect_fn=collect_fn,
187 | execute_fn=execute_fn,
188 | blocking=blocking,
189 | )
190 |
191 | try:
192 | setattr(self, method_name, func)
193 | except Exception:
194 | raise ValueError(f"Failed to set method_name {method_name}")
195 |
--------------------------------------------------------------------------------
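A small illustration (made-up values) of the ResourcePool bookkeeping above, assuming two nodes with eight processes each:

from verl.single_controller.base.worker_group import ResourcePool

pool = ResourcePool(process_on_nodes=[8, 8])
assert pool.world_size == 16
assert pool.local_world_size_list() == [8] * 16
assert pool.local_rank_list() == list(range(8)) + list(range(8))
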
/eval/utils/processing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from PIL import Image
4 | from typing import List, Dict, Tuple
5 | from concurrent.futures import ThreadPoolExecutor
6 | from tqdm import tqdm
7 |
8 | from utils.model_parser import llm_eval_score
9 | from mathruler.grader import extract_boxed_content, grade_answer
10 |
11 | def load_image(image_path: str, min_pixels: int, max_pixels: int) -> Image.Image:
12 | """Load and preprocess an image"""
13 | try:
14 | image = Image.open(image_path).convert("RGB")
15 |
16 | # Resize if too large or too small
17 | if (image.width * image.height) > max_pixels:
18 | resize_factor = math.sqrt(max_pixels / (image.width * image.height))
19 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
20 | image = image.resize((width, height))
21 |
22 | if (image.width * image.height) < min_pixels:
23 | resize_factor = math.sqrt(min_pixels / (image.width * image.height))
24 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
25 | image = image.resize((width, height))
26 |
27 | return image
28 | except Exception as e:
29 | print(f"Error processing image {image_path}: {str(e)}")
30 | return None
31 |
32 | def prepare_prompts(dataset_name: str, samples: List[Dict], args) -> Tuple[List[Dict], List[Dict]]:
33 | """Prepare prompts for all samples"""
34 | prompts = []
35 | metadata = []
36 |
37 | for item in tqdm(samples, desc=f"Preparing {dataset_name} prompts"):
38 | # Skip if image doesn't exist
39 | if not os.path.exists(item["image_path"]):
40 | continue
41 |
42 | # Load image
43 | image = load_image(item["image_path"], args.min_pixels, args.max_pixels)
44 | if image is None:
45 | continue
46 |
47 | # Create prompt
48 | if args.version == "7b":
49 | prompt_text = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{item['question']}<|im_end|>\n<|im_start|>assistant\n"
50 | elif args.version == "32b":
51 | prompt_text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{args.system_prompt} {item['question']}<|im_end|>\n<|im_start|>assistant\n"
52 | else:
53 | raise ValueError(f"Unsupported model version: {args.version}")
54 |
55 | prompts.append({
56 | "prompt": prompt_text,
57 | "multi_modal_data": {"image": image},
58 | })
59 |
60 | metadata.append({
61 | "dataset": dataset_name,
62 | "id": item["id"],
63 | "question": item["question"],
64 | "answer": item["answer"],
65 | "prompt": prompt_text,
66 | **{k: v for k, v in item.items() if k not in ["image_path", "dataset", "id", "question", "answer"]}
67 | })
68 |
69 | return prompts, metadata
70 |
71 | def evaluate_prediction(prediction: str, answer: str, dataset: str, question: str = "") -> float:
72 | """Evaluate a prediction against the ground truth"""
73 | if dataset == "geo3k":
74 | extracted_answer = extract_boxed_content(prediction)
75 | return 1.0 if grade_answer(extracted_answer, answer) else 0.0
76 |
77 | elif dataset in ("mathvista", "mathverse", "mathvision", "wemath"):
78 | try:
79 | score = llm_eval_score(question, prediction, answer, dataset)
80 | except Exception:  # retry once after a short pause
81 | import time
82 | time.sleep(10)
83 | score = llm_eval_score(question, prediction, answer, dataset)
84 | return score
85 |
86 | elif dataset == "hallubench":
87 | prediction = prediction.replace("\\boxed{}", "")
88 | extracted_answer = extract_boxed_content(prediction)
89 | return 1.0 if extracted_answer.lower() == answer.lower() else 0.0
90 | # return 1.0 if answer.lower() in prediction.lower() else 0.0
91 | 
92 | else:
93 | # Default evaluation: exact match on the extracted boxed answer
94 | return 1.0 if extract_boxed_content(prediction) == answer else 0.0
95 |
96 | def process_outputs(outputs, metadata, max_workers: int) -> Dict[str, List[Dict]]:
97 | """Process model outputs and calculate metrics"""
98 | results = []
99 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
100 | futures = []
101 |
102 | for i, output in enumerate(outputs):
103 | prediction = output.outputs[0].text.strip()
104 | meta = metadata[i]
105 | dataset = meta["dataset"]
106 | if "question_for_eval" in meta:
107 | question = meta["question_for_eval"]
108 | else:
109 | question = meta["question"]
110 |
111 | future = executor.submit(
112 | evaluate_prediction,
113 | prediction,
114 | meta["answer"],
115 | dataset,
116 | question
117 | )
118 | futures.append((future, i, prediction, meta))
119 |
120 | for future, i, prediction, meta in tqdm(futures, desc="Evaluating predictions"):
121 | try:
122 | accuracy = future.result()
123 |
124 | result = {
125 | "id": meta["id"],
126 | "question": meta["question"],
127 | "answer": meta["answer"],
128 | "prediction": prediction,
129 | "accuracy": accuracy,
130 | "correct": accuracy > 0,
131 | **{k: v for k, v in meta.items() if k not in ["dataset", "id", "question", "answer"]}
132 | }
133 |
134 | results.append(result)
135 | except Exception as e:
136 | print(f"Error evaluating prediction {i}: {str(e)}")
137 |
138 | return results
139 |
140 | def calculate_metrics(results: List[Dict]) -> Dict:
141 | """Calculate evaluation metrics"""
142 | if not results:
143 | return {"accuracy": 0.0}
144 |
145 | accuracy = sum(1 for r in results if r["correct"]) / len(results)
146 | metrics = {"accuracy": accuracy}
147 |
148 | # Calculate task-specific accuracies if available
149 | if any("task" in r for r in results):
150 | task_results = {}
151 | for r in results:
152 | if "task" in r:
153 | task = r["task"]
154 | if task not in task_results:
155 | task_results[task] = []
156 | task_results[task].append(r["correct"])
157 |
158 | task_accuracies = {task: sum(results) / len(results) for task, results in task_results.items()}
159 | metrics["sub_accuracies"] = task_accuracies
160 |
161 | # Calculate problem version accuracies if available
162 | if any("problem_version" in r for r in results):
163 | version_results = {}
164 | for r in results:
165 | if "problem_version" in r:
166 | version = r["problem_version"]
167 | if version not in version_results:
168 | version_results[version] = []
169 | version_results[version].append(r["correct"])
170 |
171 | version_accuracies = {version: sum(results) / len(results) for version, results in version_results.items()}
172 | metrics["sub_accuracies"] = version_accuracies
173 |
174 | # Calculate subject accuracies if available
175 | if any("subject" in r for r in results):
176 | subject_results = {}
177 | for r in results:
178 | if "subject" in r:
179 | subject = r["subject"]
180 | if subject not in subject_results:
181 | subject_results[subject] = []
182 | subject_results[subject].append(r["correct"])
183 |
184 | subject_accuracies = {subject: sum(results) / len(results) for subject, results in subject_results.items()}
185 | metrics["sub_accuracies"] = subject_accuracies
186 |
187 | return metrics
--------------------------------------------------------------------------------
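A hypothetical example of the metrics dict produced by calculate_metrics above; the two records are made up, and the import assumes the script is run from the eval directory as in example_eval.sh:

from utils.processing import calculate_metrics

results = [
    {"correct": True, "task": "algebra"},
    {"correct": False, "task": "geometry"},
]
print(calculate_metrics(results))
# {'accuracy': 0.5, 'sub_accuracies': {'algebra': 1.0, 'geometry': 0.0}}
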
/verl/models/transformers/qwen2_vl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | # Based on:
4 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from typing import Optional, Tuple
19 |
20 | import torch
21 |
22 | from .flash_attention_utils import flash_attention_forward
23 |
24 |
25 | try:
26 | from transformers.models.qwen2_vl.modeling_qwen2_vl import (
27 | Qwen2VLAttention,
28 | apply_multimodal_rotary_pos_emb,
29 | repeat_kv,
30 | )
31 | from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
32 | except ImportError:
33 | pass
34 |
35 |
36 | def get_rope_index(
37 | processor: "Qwen2VLProcessor",
38 | input_ids: torch.Tensor,
39 | image_grid_thw: Optional[torch.Tensor] = None,
40 | video_grid_thw: Optional[torch.Tensor] = None,
41 | second_per_grid_ts: Optional[torch.Tensor] = None,
42 | attention_mask: Optional[torch.Tensor] = None,
43 | ) -> torch.Tensor:
44 | """
45 | Gets the position ids for Qwen2-VL; they should be generated before sharding the sequence.
46 | The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
47 | https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
48 | """
49 | spatial_merge_size = processor.image_processor.merge_size
50 | tokens_per_second = 2
51 | image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
52 | video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
53 | vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
54 | if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
55 | if attention_mask is None:
56 | attention_mask = torch.ones_like(input_ids)
57 |
58 | position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
59 | image_index, video_index = 0, 0
60 | input_ids = input_ids[attention_mask == 1]
61 | image_nums, video_nums = 0, 0
62 | vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
63 | vision_tokens = input_ids[vision_start_indices + 1]
64 | image_nums = (vision_tokens == image_token_id).sum()
65 | video_nums = (vision_tokens == video_token_id).sum()
66 | input_tokens = input_ids.tolist()
67 | llm_pos_ids_list: list = []
68 | st = 0
69 | remain_images, remain_videos = image_nums, video_nums
70 | for _ in range(image_nums + video_nums):
71 | if image_token_id in input_tokens and remain_images > 0:
72 | ed_image = input_tokens.index(image_token_id, st)
73 | else:
74 | ed_image = len(input_tokens) + 1
75 | if video_token_id in input_tokens and remain_videos > 0:
76 | ed_video = input_tokens.index(video_token_id, st)
77 | else:
78 | ed_video = len(input_tokens) + 1
79 | if ed_image < ed_video:
80 | t, h, w = (
81 | image_grid_thw[image_index][0],
82 | image_grid_thw[image_index][1],
83 | image_grid_thw[image_index][2],
84 | )
85 | second_per_grid_t = 0
86 | image_index += 1
87 | remain_images -= 1
88 | ed = ed_image
89 | else:
90 | t, h, w = (
91 | video_grid_thw[video_index][0],
92 | video_grid_thw[video_index][1],
93 | video_grid_thw[video_index][2],
94 | )
95 | if second_per_grid_ts is not None:
96 | second_per_grid_t = second_per_grid_ts[video_index]
97 | else:
98 | second_per_grid_t = 1.0
99 |
100 | video_index += 1
101 | remain_videos -= 1
102 | ed = ed_video
103 |
104 | llm_grid_t, llm_grid_h, llm_grid_w = (
105 | t.item(),
106 | h.item() // spatial_merge_size,
107 | w.item() // spatial_merge_size,
108 | )
109 | text_len = ed - st
110 |
111 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
112 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
113 |
114 | t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
115 | t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
116 | h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
117 | w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
118 | llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
119 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w
120 |
121 | if st < len(input_tokens):
122 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
123 | text_len = len(input_tokens) - st
124 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
125 |
126 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
127 | position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
128 | else:
129 | if attention_mask is not None:
130 | position_ids = attention_mask.long().cumsum(-1) - 1
131 | position_ids.masked_fill_(attention_mask == 0, 1)
132 | position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
133 | else:
134 | position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
135 |
136 | return position_ids
137 |
138 |
139 | def qwen2_vl_attn_forward(
140 | self: "Qwen2VLAttention",
141 | hidden_states: torch.Tensor,
142 | attention_mask: Optional[torch.Tensor] = None,
143 | position_ids: Optional[torch.LongTensor] = None,
144 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
145 | **kwargs,
146 | ) -> Tuple[torch.Tensor, None, None]:
147 | bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
148 | query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
149 | key_states = self.k_proj(hidden_states)
150 | value_states = self.v_proj(hidden_states)
151 |
152 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
153 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
154 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
155 |
156 | # Because the input can be padded, the absolute sequence length depends on the max position id.
157 | if position_embeddings is None:
158 | cos, sin = self.rotary_emb(value_states, position_ids)
159 | else:
160 | cos, sin = position_embeddings
161 |
162 | query_states, key_states = apply_multimodal_rotary_pos_emb(
163 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
164 | )
165 | key_states = repeat_kv(key_states, self.num_key_value_groups)
166 | value_states = repeat_kv(value_states, self.num_key_value_groups)
167 | dropout_rate = 0.0 if not self.training else self.attention_dropout
168 |
169 | sliding_window = None
170 | if (
171 | self.config.use_sliding_window
172 | and getattr(self.config, "sliding_window", None) is not None
173 | and self.layer_idx >= self.config.max_window_layers
174 | ):
175 | sliding_window = self.config.sliding_window
176 |
177 | attn_output, _ = flash_attention_forward(
178 | self,
179 | query_states,
180 | key_states,
181 | value_states,
182 | attention_mask,
183 | dropout=dropout_rate,
184 | sliding_window=sliding_window,
185 | position_ids=position_ids, # important: pass position ids
186 | ) # (batch_size, seq_length, num_head / sp_size, head_size)
187 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
188 | attn_output = self.o_proj(attn_output)
189 | return attn_output, None, None
190 |
--------------------------------------------------------------------------------
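A sketch of how get_rope_index above is typically called. The processor and model_inputs variables below are assumed to come from a Qwen2-VL processor call and are not defined in this snippet:

from verl.models.transformers.qwen2_vl import get_rope_index

# model_inputs = processor(text=[prompt], images=[image], return_tensors="pt")  # assumed
position_ids = get_rope_index(
    processor,
    input_ids=model_inputs["input_ids"][0],             # 1D tensor, batch dim removed
    image_grid_thw=model_inputs.get("image_grid_thw"),
    attention_mask=model_inputs["attention_mask"][0],
)  # -> (3, seq_len) temporal/height/width position ids for mrope
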
/verl/models/transformers/flash_attention_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import inspect
18 | import os
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
24 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25 |
26 | from ...utils.ulysses import (
27 | gather_heads_scatter_seq,
28 | gather_seq_scatter_heads,
29 | get_ulysses_sequence_parallel_group,
30 | get_ulysses_sequence_parallel_world_size,
31 | )
32 |
33 |
34 | if is_flash_attn_2_available():
35 | from flash_attn import flash_attn_func, flash_attn_varlen_func
36 |
37 | _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
38 | _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
39 | _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
40 | _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
41 |
42 |
43 | def prepare_fa2_from_position_ids(
44 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
45 | ):
46 | query = query.view(-1, query.size(-2), query.size(-1))
47 | key = key.contiguous().view(-1, key.size(-2), key.size(-1))
48 | value = value.contiguous().view(-1, value.size(-2), value.size(-1))
49 | position_ids = position_ids.flatten()
50 | indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
51 | cu_seqlens = torch.cat(
52 | (
53 | indices_q[position_ids == 0],
54 | torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
55 | )
56 | )
57 | max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
58 | return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
59 |
60 |
61 | def _custom_flash_attention_forward(
62 | query_states: torch.Tensor,
63 | key_states: torch.Tensor,
64 | value_states: torch.Tensor,
65 | attention_mask: Optional[torch.Tensor],
66 | query_length: int,
67 | is_causal: bool = True,
68 | position_ids: Optional[torch.Tensor] = None,
69 | sliding_window: Optional[int] = None,
70 | use_top_left_mask: bool = False,
71 | deterministic: Optional[bool] = None,
72 | **kwargs,
73 | ):
74 | """
75 | Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
76 | """
77 | if not use_top_left_mask:
78 | causal = is_causal
79 | else:
80 | causal = is_causal and query_length != 1
81 |
82 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
83 | use_sliding_windows = (
84 | _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
85 | )
86 | flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
87 |
88 | if _flash_supports_deterministic:
89 | if deterministic is None:
90 | deterministic = _flash_deterministic_enabled
91 | flash_kwargs["deterministic"] = deterministic
92 |
93 | if kwargs.get("softcap") is not None:
94 | flash_kwargs["softcap"] = kwargs.pop("softcap")
95 |
96 | query_states, key_states, value_states = fa_peft_integration_check(
97 | query_states, key_states, value_states, target_dtype=torch.bfloat16
98 | )
99 |
100 | sp_size = get_ulysses_sequence_parallel_world_size()
101 | if sp_size > 1:
102 | # (batch_size, seq_length, num_head, head_size)
103 | query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
104 | key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
105 | value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
106 | position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
107 | dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())  # gathers in-place into position_ids_lst
108 | position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
109 |
110 | if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope
111 | position_ids = position_ids[0]
112 |
113 | if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
114 | batch_size = query_states.size(0)
115 | query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
116 | query_states, key_states, value_states, position_ids
117 | ) # remove channel dimension
118 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
119 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
120 | attn_output = flash_attn_varlen_func(
121 | query_states,
122 | key_states,
123 | value_states,
124 | cu_seqlens_q=cu_seqlens_q,
125 | cu_seqlens_k=cu_seqlens_k,
126 | max_seqlen_q=max_seqlen_in_batch_q,
127 | max_seqlen_k=max_seqlen_in_batch_k,
128 | dropout_p=kwargs.pop("dropout", 0.0),
129 | softmax_scale=kwargs.pop("softmax_scale", None),
130 | causal=causal,
131 | **flash_kwargs,
132 | )
133 | attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
134 | else:
135 | attn_output = _flash_attention_forward(
136 | query_states,
137 | key_states,
138 | value_states,
139 | attention_mask,
140 | query_length,
141 | is_causal=is_causal,
142 | sliding_window=sliding_window,
143 | use_top_left_mask=use_top_left_mask,
144 | deterministic=deterministic,
145 | **kwargs,
146 | ) # do not pass position_ids to old flash_attention_forward
147 |
148 | if sp_size > 1:
149 | # (batch_size, seq_length, num_head, head_size)
150 | attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
151 |
152 | return attn_output
153 |
154 |
155 | def flash_attention_forward(
156 | module: torch.nn.Module,
157 | query: torch.Tensor,
158 | key: torch.Tensor,
159 | value: torch.Tensor,
160 | attention_mask: Optional[torch.Tensor],
161 | dropout: float = 0.0,
162 | scaling: Optional[float] = None,
163 | sliding_window: Optional[int] = None,
164 | softcap: Optional[float] = None,
165 | **kwargs,
166 | ) -> Tuple[torch.Tensor, None]:
167 | # This is before the transpose
168 | q_len = query.shape[2]
169 |
170 | # FA2 uses non-transposed inputs
171 | query = query.transpose(1, 2)
172 | key = key.transpose(1, 2)
173 | value = value.transpose(1, 2)
174 |
175 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
176 | # therefore the input hidden states gets silently casted in float32. Hence, we need
177 | # cast them back in the correct dtype just to be sure everything works as expected.
178 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms
179 | # in fp32. (usually our RMSNorm modules handle it correctly)
180 | target_dtype = None
181 | if query.dtype == torch.float32:
182 | if torch.is_autocast_enabled():
183 | target_dtype = torch.get_autocast_gpu_dtype()
184 | # Handle the case where the model is quantized
185 | elif hasattr(module.config, "_pre_quantization_dtype"):
186 | target_dtype = module.config._pre_quantization_dtype
187 | else:
188 | target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
189 |
190 | # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
191 | kwargs.pop("is_causal", None)
192 |
193 | attn_output = _custom_flash_attention_forward(
194 | query,
195 | key,
196 | value,
197 | attention_mask,
198 | query_length=q_len,
199 | is_causal=True,
200 | dropout=dropout,
201 | softmax_scale=scaling,
202 | sliding_window=sliding_window,
203 | softcap=softcap,
204 | use_top_left_mask=_flash_use_top_left_mask,
205 | target_dtype=target_dtype,
206 | **kwargs,
207 | )
208 |
209 | return attn_output, None
210 |
--------------------------------------------------------------------------------
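A toy check (made-up values) of how prepare_fa2_from_position_ids derives cu_seqlens from packed position ids: two sequences of lengths 3 and 2 packed back-to-back give boundaries [0, 3, 5]:

import torch

position_ids = torch.tensor([0, 1, 2, 0, 1])  # two packed sequences
indices = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seqlens = torch.cat(
    (indices[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
)
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)
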
/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The vLLM rollout that can be applied with different backends.
16 | When working with FSDP:
17 | - Use DTensor weight loader (recommended) or HF weight loader
18 | - Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
19 | """
20 |
21 | from contextlib import contextmanager
22 | from typing import Any, List, Union
23 |
24 | import numpy as np
25 | import torch
26 | import torch.distributed
27 | from tensordict import TensorDict
28 | from transformers import PreTrainedTokenizer
29 | from vllm import LLM, RequestOutput, SamplingParams
30 |
31 | from ....protocol import DataProto
32 | from ....utils import torch_functional as VF
33 | from ....utils.torch_dtypes import PrecisionType
34 | from ..base import BaseRollout
35 | from ..config import RolloutConfig
36 |
37 |
38 | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, np.ndarray]:
39 | if isinstance(value, torch.Tensor):
40 | return value.repeat_interleave(repeats, dim=0)
41 | else:
42 | return np.repeat(value, repeats, axis=0)
43 |
44 |
45 | class vLLMRollout(BaseRollout):
46 | def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
47 | """A vLLM rollout. It requires the module is supported by the vllm.
48 |
49 | Args:
50 | model_path: path to the model weights, following the Hugging Face layout
51 | config: the RolloutConfig for this rollout worker
52 | tokenizer: the task/model tokenizer
53 | """
54 | super().__init__()
55 | self.config = config
56 | self.pad_token_id = tokenizer.pad_token_id
57 | if config.tensor_parallel_size > torch.distributed.get_world_size():
58 | raise ValueError("Tensor parallelism size should be less than world size.")
59 |
60 | if not config.enforce_eager and config.free_cache_engine:
61 | raise ValueError("CUDA graph should be disabled when `free_cache_engine` is True.")
62 |
63 | if config.max_num_batched_tokens < config.prompt_length + config.response_length:
64 | raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
65 |
66 | vllm_init_kwargs = {}
67 | if config.limit_images > 0:
68 | vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}
69 |
70 | self.inference_engine = LLM(
71 | model=model_path,
72 | skip_tokenizer_init=False,
73 | tensor_parallel_size=config.tensor_parallel_size,
74 | dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
75 | gpu_memory_utilization=config.gpu_memory_utilization,
76 | enforce_eager=config.enforce_eager,
77 | max_model_len=config.prompt_length + config.response_length,
78 | max_num_batched_tokens=config.max_num_batched_tokens,
79 | enable_sleep_mode=True,
80 | distributed_executor_backend="external_launcher",
81 | disable_custom_all_reduce=True,
82 | disable_log_stats=config.disable_log_stats,
83 | enable_chunked_prefill=config.enable_chunked_prefill,
84 | **vllm_init_kwargs,
85 | )
86 |
87 | # Offload vllm model to reduce peak memory usage
88 | self.inference_engine.sleep(level=1)
89 |
90 | sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
91 | default_sampling_params = SamplingParams()
92 | for key in config.to_dict().keys():
93 | if hasattr(default_sampling_params, key):
94 | sampling_kwargs[key] = getattr(config, key)
95 |
96 | print(f"Sampling params: {sampling_kwargs}.")
97 | self.sampling_params = SamplingParams(**sampling_kwargs)
98 |
99 | @contextmanager
100 | def update_sampling_params(self, **kwargs):
101 | # update sampling params
102 | old_sampling_params_args = {}
103 | if kwargs:
104 | for key, value in kwargs.items():
105 | if hasattr(self.sampling_params, key):
106 | old_value = getattr(self.sampling_params, key)
107 | old_sampling_params_args[key] = old_value
108 | setattr(self.sampling_params, key, value)
109 |
110 | yield
111 | # roll back to previous sampling params
112 | for key, value in old_sampling_params_args.items():
113 | setattr(self.sampling_params, key, value)
114 |
115 | @torch.no_grad()
116 | def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
117 | # left-padded attention_mask
118 | input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length)
119 | attention_mask: torch.Tensor = prompts.batch["attention_mask"]
120 | position_ids: torch.Tensor = prompts.batch["position_ids"]
121 | eos_token_id: int = prompts.meta_info["eos_token_id"]
122 | batch_size = input_ids.size(0)
123 |
124 | do_sample = prompts.meta_info.get("do_sample", True)
125 | if not do_sample:
126 | kwargs = {
127 | "n": 1,
128 | "temperature": 0.0,
129 | "top_p": 1.0,
130 | "top_k": -1,
131 | "min_p": 0.0,
132 | }
133 |
134 | non_tensor_batch = prompts.non_tensor_batch
135 | if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
136 | raise RuntimeError("vLLM sharding manager is not working properly.")
137 |
138 | if "multi_modal_data" in non_tensor_batch:
139 | vllm_inputs = []
140 | for raw_prompt_ids, multi_modal_data in zip(
141 | non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
142 | ):
143 | vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data})
144 | else:
145 | vllm_inputs = [
146 | {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
147 | ]
148 |
149 | # users can customize different sampling_params at different run
150 | with self.update_sampling_params(**kwargs):
151 | completions: List[RequestOutput] = self.inference_engine.generate(
152 | prompts=vllm_inputs, sampling_params=self.sampling_params
153 | )
154 |
155 | response_ids = []
156 | for completion in completions:
157 | for output in completion.outputs:
158 | response_ids.append(output.token_ids)
159 |
160 | response_ids = VF.pad_2d_list_to_length(
161 | response_ids, self.pad_token_id, max_length=self.config.response_length
162 | ).to(input_ids.device)
163 |
164 | if self.config.n > 1 and do_sample:
165 | batch_size = batch_size * self.config.n
166 | input_ids = _repeat_interleave(input_ids, self.config.n)
167 | attention_mask = _repeat_interleave(attention_mask, self.config.n)
168 | position_ids = _repeat_interleave(position_ids, self.config.n)
169 | if "multi_modal_inputs" in non_tensor_batch.keys():
170 | non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
171 | non_tensor_batch["multi_modal_inputs"], self.config.n
172 | )
173 |
174 | sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
175 | response_length = response_ids.size(1)
176 | delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
177 | delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
178 | if position_ids.dim() == 3: # qwen2vl mrope
179 | delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
180 |
181 | # prompt: left pad + response: right pad
182 | # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
183 | # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
184 | response_position_ids = position_ids[..., -1:] + delta_position_id
185 | position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
186 | response_attention_mask = VF.get_eos_mask(
187 | response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype
188 | )
189 | attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
190 |
191 | # all the tp ranks should contain the same data here. data in all ranks are valid
192 | batch = TensorDict(
193 | {
194 | "prompts": input_ids,
195 | "responses": response_ids,
196 | "input_ids": sequence_ids, # here input_ids become the whole sentences
197 | "attention_mask": attention_mask,
198 | "position_ids": position_ids,
199 | },
200 | batch_size=batch_size,
201 | )
202 | return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
203 |
--------------------------------------------------------------------------------
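A sketch of the update_sampling_params context manager in use; rollout and vllm_inputs are assumed to already exist. Parameters are overridden inside the block and restored on exit, which is how generate_sequences switches to greedy decoding when do_sample is False:

with rollout.update_sampling_params(temperature=0.0, top_p=1.0, n=1):
    completions = rollout.inference_engine.generate(
        prompts=vllm_inputs, sampling_params=rollout.sampling_params
    )
# outside the block, the original temperature/top_p/n are in effect again
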