├── .github
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   └── workflows
│       └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── Dockerfile.legacy
├── LICENSE
├── Makefile
├── README.md
├── assets
│   ├── baselines.md
│   ├── easyr1_grpo.png
│   ├── qwen2_5_vl_7b_geo.png
│   └── wechat.jpg
├── examples
│   ├── baselines
│   │   ├── qwen2_5_vl_3b_clevr.sh
│   │   └── qwen2_5_vl_3b_geoqa8k.sh
│   ├── config.yaml
│   ├── format_prompt
│   │   ├── math_format.jinja
│   │   └── r1v_format.jinja
│   ├── qwen2_5_7b_math_grpo.sh
│   ├── qwen2_5_vl_32b_geo3k_grpo.sh
│   ├── qwen2_5_vl_3b_geo3k_grpo.sh
│   ├── qwen2_5_vl_7b_geo3k_grpo.sh
│   ├── qwen2_5_vl_7b_geo3k_reinforce.sh
│   ├── qwen2_5_vl_7b_geo3k_swanlab.sh
│   ├── qwen2_5_vl_7b_multi_image.sh
│   ├── qwen3_4b_math_grpo.sh
│   ├── reward_function
│   │   ├── math.py
│   │   └── r1v.py
│   └── runtime_env.yaml
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── model_merger.py
├── setup.py
└── verl
    ├── __init__.py
    ├── models
    │   ├── __init__.py
    │   ├── monkey_patch.py
    │   └── transformers
    │       ├── __init__.py
    │       ├── flash_attention_utils.py
    │       └── qwen2_vl.py
    ├── protocol.py
    ├── single_controller
    │   ├── __init__.py
    │   ├── base
    │   │   ├── __init__.py
    │   │   ├── decorator.py
    │   │   ├── register_center
    │   │   │   ├── __init__.py
    │   │   │   └── ray.py
    │   │   ├── worker.py
    │   │   └── worker_group.py
    │   └── ray
    │       ├── __init__.py
    │       └── base.py
    ├── trainer
    │   ├── __init__.py
    │   ├── config.py
    │   ├── core_algos.py
    │   ├── data_loader.py
    │   ├── main.py
    │   ├── metrics.py
    │   └── ray_trainer.py
    ├── utils
    │   ├── __init__.py
    │   ├── checkpoint
    │   │   ├── __init__.py
    │   │   ├── checkpoint_manager.py
    │   │   └── fsdp_checkpoint_manager.py
    │   ├── dataset.py
    │   ├── flops_counter.py
    │   ├── fsdp_utils.py
    │   ├── logger
    │   │   ├── __init__.py
    │   │   ├── gen_logger.py
    │   │   └── logger.py
    │   ├── model_utils.py
    │   ├── py_functional.py
    │   ├── seqlen_balancing.py
    │   ├── tokenizer.py
    │   ├── torch_dtypes.py
    │   ├── torch_functional.py
    │   └── ulysses.py
    └── workers
        ├── __init__.py
        ├── actor
        │   ├── __init__.py
        │   ├── base.py
        │   ├── config.py
        │   └── dp_actor.py
        ├── config.py
        ├── critic
        │   ├── __init__.py
        │   ├── base.py
        │   ├── config.py
        │   └── dp_critic.py
        ├── fsdp_workers.py
        ├── reward
        │   ├── __init__.py
        │   ├── config.py
        │   └── function.py
        ├── rollout
        │   ├── __init__.py
        │   ├── base.py
        │   ├── config.py
        │   └── vllm_rollout_spmd.py
        └── sharding_manager
            ├── __init__.py
            ├── base.py
            ├── fsdp_ulysses.py
            └── fsdp_vllm.py
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | `hoshihiyouga AT gmail DOT com`.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to EasyR1
2 |
3 | Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
4 |
5 | It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
6 |
7 | However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
8 |
9 | **This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
10 |
11 | ## Ways to contribute
12 |
13 | There are several ways you can contribute to EasyR1:
14 |
15 | * Fix outstanding issues with the existing code.
16 | * Submit issues related to bugs or desired new features.
17 | * Contribute to the examples or to the documentation.
18 |
19 | ### Style guide
20 |
21 | EasyR1 follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
22 |
23 | ### Create a Pull Request
24 |
25 | 1. Fork the [repository](https://github.com/hiyouga/EasyR1) by clicking on the [Fork](https://github.com/hiyouga/EasyR1/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
26 |
27 | 2. Clone your fork to your local disk, and add the base repository as a remote:
28 |
29 | ```bash
30 | git clone git@github.com:[username]/EasyR1.git
31 | cd EasyR1
32 | git remote add upstream https://github.com/hiyouga/EasyR1.git
33 | ```
34 |
35 | 3. Create a new branch to hold your development changes:
36 |
37 | ```bash
38 | git checkout -b dev_your_branch
39 | ```
40 |
41 | 4. Set up a development environment by running the following command in a virtual environment:
42 |
43 | ```bash
44 | pip install -e ".[dev]"
45 | ```
46 |
47 | 5. Check your code before committing:
48 |
49 | ```bash
50 | make commit
51 | make style && make quality
52 | ```
53 |
54 | 6. Submit changes:
55 |
56 | ```bash
57 | git add .
58 | git commit -m "commit message"
59 | git fetch upstream
60 | git rebase upstream/main
61 | git push -u origin dev_your_branch
62 | ```
63 |
64 | 7. Create a pull request from your branch `dev_your_branch` against the [origin repo](https://github.com/hiyouga/EasyR1).
65 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | paths:
8 | - "**.py"
9 | - "requirements.txt"
10 | - ".github/workflows/*.yml"
11 | pull_request:
12 | branches:
13 | - "main"
14 | paths:
15 | - "**.py"
16 | - "requirements.txt"
17 | - ".github/workflows/*.yml"
18 |
19 | jobs:
20 | tests:
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | python-version:
25 | - "3.11"
26 | os:
27 | - "ubuntu-latest"
28 |
29 | runs-on: ${{ matrix.os }}
30 |
31 | steps:
32 | - name: Checkout
33 | uses: actions/checkout@v4
34 |
35 | - name: Set up Python
36 | uses: actions/setup-python@v5
37 | with:
38 | python-version: ${{ matrix.python-version }}
39 | cache: "pip"
40 | cache-dependency-path: "setup.py"
41 |
42 | - name: Install dependencies
43 | run: |
44 | python -m pip install --upgrade pip
45 | python -m pip install ruff
46 |
47 | - name: Check quality
48 | run: |
49 | make style && make quality
50 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # outputs
174 | outputs/
175 | checkpoints/
176 | wandb/
177 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-ast
6 | - id: check-added-large-files
7 | args: ['--maxkb=25000']
8 | - id: check-merge-conflict
9 | - id: check-yaml
10 | - id: debug-statements
11 | - id: end-of-file-fixer
12 | - id: requirements-txt-fixer
13 | - id: trailing-whitespace
14 | args: [--markdown-linebreak-ext=md]
15 | - id: no-commit-to-branch
16 | args: ['--branch', 'main']
17 |
18 | - repo: https://github.com/asottile/pyupgrade
19 | rev: v3.17.0
20 | hooks:
21 | - id: pyupgrade
22 | args: [--py38-plus]
23 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)
2 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
3 | FROM nvcr.io/nvidia/pytorch:24.08-py3
4 |
5 | # Define environments
6 | ENV MAX_JOBS=32
7 | ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
8 | ENV DEBIAN_FRONTEND=noninteractive
9 | ENV NODE_OPTIONS=""
10 | ENV PIP_ROOT_USER_ACTION=ignore
11 | ENV HF_HUB_ENABLE_HF_TRANSFER="1"
12 |
13 | # Define installation arguments
14 | ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
15 | ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
16 |
17 | # Set apt source
18 | RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
19 | { \
20 | echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
21 | echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
22 | echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
23 | echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
24 | } > /etc/apt/sources.list
25 |
26 | # Install systemctl
27 | RUN apt-get update && \
28 | apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
29 | apt-get clean
30 |
31 | # Install tini
32 | RUN apt-get update && \
33 | apt-get install -y tini && \
34 | apt-get clean
35 |
36 | # Change pip source
37 | RUN pip config set global.index-url "${PIP_INDEX}" && \
38 | pip config set global.extra-index-url "${PIP_INDEX}" && \
39 | python -m pip install --upgrade pip
40 |
41 | # Uninstall nv-pytorch fork
42 | RUN pip uninstall -y torch torchvision torchaudio \
43 | pytorch-quantization pytorch-triton torch-tensorrt \
44 | transformer_engine flash_attn apex megatron-core \
45 | xgboost opencv grpcio
46 |
47 | # Fix cv2
48 | RUN rm -rf /usr/local/lib/python3.10/dist-packages/cv2
49 |
50 | # Install torch-2.6.0+cu124 + vllm-0.8.4
51 | # torch-2.6.0+cu124: cxx11abi=False
52 | # torch-2.6.0+cu126: cxx11abi=True
53 | # see https://github.com/flashinfer-ai/flashinfer/issues/911
54 | RUN pip install --no-cache-dir "vllm==0.8.4" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" tensordict torchdata \
55 | "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \
56 | "numpy<2.0.0" "pyarrow>=15.0.0" pandas \
57 | ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb liger-kernel mathruler \
58 | pytest yapf py-spy pyext pre-commit ruff
59 |
60 | # Install flash-attn-2.7.4.post1 (cxx11abi=False)
61 | RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
62 | pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
63 |
64 | # Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False)
65 | # vllm-0.8.3 does not support flashinfer>=0.2.3
66 | # see https://github.com/vllm-project/vllm/pull/15777
67 | RUN wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \
68 | pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl
69 |
70 | # Fix packages
71 | RUN pip uninstall -y pynvml nvidia-ml-py && \
72 | pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1"
73 |
74 | # Reset pip config
75 | RUN pip config unset global.index-url && \
76 | pip config unset global.extra-index-url
77 |
--------------------------------------------------------------------------------
/Dockerfile.legacy:
--------------------------------------------------------------------------------
1 | # Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)
2 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
3 | FROM nvcr.io/nvidia/pytorch:24.08-py3
4 |
5 | # Define environments
6 | ENV MAX_JOBS=32
7 | ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
8 | ENV DEBIAN_FRONTEND=noninteractive
9 | ENV NODE_OPTIONS=""
10 | ENV PIP_ROOT_USER_ACTION=ignore
11 | ENV HF_HUB_ENABLE_HF_TRANSFER="1"
12 |
13 | # Define installation arguments
14 | ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
15 | ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
16 | ARG VLLM_COMMIT=227578480d71fc94ef46ca77fb69496412158d68
17 |
18 | # Set apt source
19 | RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
20 | { \
21 | echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
22 | echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
23 | echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
24 | echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
25 | } > /etc/apt/sources.list
26 |
27 | # Install systemctl
28 | RUN apt-get update && \
29 | apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
30 | apt-get clean
31 |
32 | # Install tini
33 | RUN apt-get update && \
34 | apt-get install -y tini && \
35 | apt-get clean
36 |
37 | # Change pip source
38 | RUN pip config set global.index-url "${PIP_INDEX}" && \
39 | pip config set global.extra-index-url "${PIP_INDEX}" && \
40 | python -m pip install --upgrade pip
41 |
42 | # Uninstall nv-pytorch fork
43 | RUN pip uninstall -y torch torchvision torchaudio \
44 | pytorch-quantization pytorch-triton torch-tensorrt \
45 | transformer_engine flash_attn apex megatron-core \
46 | xgboost opencv grpcio
47 |
48 | # Fix cv2
49 | RUN rm -rf /usr/local/lib/python3.10/dist-packages/cv2
50 |
51 | # Install vllm-0.7.4-nightly
52 | RUN pip install --no-cache-dir vllm --pre --extra-index-url "https://wheels.vllm.ai/${VLLM_COMMIT}" && \
53 | git clone -b verl_v1 https://github.com/hiyouga/vllm.git && \
54 | cp -r vllm/vllm/ /usr/local/lib/python3.10/dist-packages/
55 |
56 | # Install torch-2.5.1
57 | RUN pip install --no-cache-dir "torch==2.5.1" "torchvision==0.20.1" "torchaudio==2.5.1" tensordict torchdata \
58 | "transformers>=4.49.0" accelerate datasets peft hf-transfer \
59 | ray[default] codetiming hydra-core pandas "pyarrow>=15.0.0" pylatexenc qwen-vl-utils wandb liger-kernel mathruler \
60 | pytest yapf py-spy pyext pre-commit ruff
61 |
62 | # Install flash_attn-2.7.4.post1
63 | RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
64 | pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
65 |
66 | # Fix packages
67 | RUN pip uninstall -y pynvml nvidia-ml-py && \
68 | pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1"
69 |
70 | # Reset pip config
71 | RUN pip config unset global.index-url && \
72 | pip config unset global.extra-index-url
73 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: build commit quality style
2 |
3 | check_dirs := examples scripts verl setup.py
4 |
5 | build:
6 | python3 setup.py sdist bdist_wheel
7 |
8 | commit:
9 | pre-commit install
10 | pre-commit run --all-files
11 |
12 | quality:
13 | ruff check $(check_dirs)
14 | ruff format --check $(check_dirs)
15 |
16 | style:
17 | ruff check $(check_dirs) --fix
18 | ruff format $(check_dirs)
19 |
--------------------------------------------------------------------------------
/assets/baselines.md:
--------------------------------------------------------------------------------
1 | # Baselines
2 |
3 | Environment: [hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0](https://hub.docker.com/layers/hiyouga/verl/ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0/images/sha256-335ed6cd1fe73090e458409cfa4394d6abf4cd0503ca44dbafdc28ff72e5ed20)
4 |
5 | EasyR1 version: [v0.3.0](https://github.com/hiyouga/EasyR1/tree/v0.3.0)
6 |
7 | Contributions of new data points are welcome!
8 |
9 | ## Algorithm Baselines
10 |
11 | ### [Qwen2.5-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) on [Math12k](https://huggingface.co/datasets/hiyouga/math12k)
12 |
13 | | Size | Algorithm | Bits | LR | KL | Test Score |
14 | | ---- | ----------- | ---- | ---- | ---- | ---------- |
15 | | 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.73->0.79 |
16 |
17 | ### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
18 |
19 | | Size | Algorithm | Bits | LR | KL | Test Score |
20 | | ---- | ----------- | ---- | ---- | ---- | ---------- |
21 | | 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.39->0.52 |
22 | | 7B | GRPO | BF16 | 1e-6 | 1e-2 | 0.39->0.52 |
23 | | 7B | GRPO | AMP | 1e-6 | 1e-3 | 0.39->0.52 |
24 | | 7B | RLOO | AMP | 1e-6 | 1e-2 | 0.39->0.53 |
25 | | 3B | GRPO | AMP | 1e-6 | 1e-2 | 0.27->0.44 |
26 | | 32B | GRPO | BF16 | 1e-6 | 1e-2 | 0.46->0.61 |
27 |
28 | > [!NOTE]
29 | > The hyper-parameters not listed are all the same as the default values.
30 |
31 | ## Performance Baselines
32 |
33 | ### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
34 |
35 | | Size | GPU Type | Bits | Batch Size | vLLM Util | vLLM TP | Peak Mem | Peak VRAM | Throughput | Sec per step | Actor MFU |
36 | | ---- | ------------- | ---- | ---------- | --------- | ------- | -------- | --------- | ---------- | ------------ | --------- |
37 | | 3B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 120GB | 35GB | 1200 | 180s | 6.3% |
38 | | 7B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 140GB | 60GB | 1200 | 180s | 13.6% |
39 | | 7B | 8 * H100 80GB | AMP | 10 / 20 | 0.6 | 2 | 150GB | 75GB | 1400 | 170s | 19.2% |
40 | | 7B | 8 * L20 48GB | AMP | 4 / 16 | 0.6 | 2 | 150GB | 44GB | 410 | 580s | 26.5% |
41 | | 7B | 8 * H100 80GB | BF16 | 4 / 16 | 0.6 | 2 | 150GB | 50GB | 1280 | 190s | 13.9% |
42 | | 32B | 8 * H100 80GB | BF16 | 1 / 8 | 0.6 | 8 | 240GB | 68GB | 360 | 860s | 11.2% |
43 |
44 | - Batch Size: micro_batch_size_per_device_for_update / micro_batch_size_per_device_for_experience
45 | - vLLM Util: rollout.gpu_memory_utilization
46 | - vLLM TP: rollout.tensor_parallel_size
47 | - Peak Mem: Peak CPU memory usage
48 | - Peak VRAM: Peak GPU memory usage
49 | - Throughput: Number of tokens processed per second per GPU in one training step
50 | - Sec per step: Average time per step in seconds
51 |
52 | > [!NOTE]
53 | > The hyper-parameters not listed are all the same as the default values.
54 |
--------------------------------------------------------------------------------
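A note on reading the performance table: the "Batch Size" column pairs `worker.actor.micro_batch_size_per_device_for_update` with `worker.actor.micro_batch_size_per_device_for_experience`, and throughput is reported per GPU. As a rough, illustrative calculation (the formula is an assumption, not from the source), the tokens processed in one step can be estimated from the table:

```python
# Illustrative estimate of tokens per training step from the 7B / 8 x H100 / AMP row above.
throughput_per_gpu = 1200   # tokens / second / GPU (from the table)
sec_per_step = 180          # seconds per training step (from the table)
num_gpus = 8

tokens_per_step = throughput_per_gpu * sec_per_step * num_gpus
print(f"~{tokens_per_step / 1e6:.1f}M tokens per step")  # ~1.7M
```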
/assets/easyr1_grpo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiyouga/EasyR1/c835ae63c2f302b1491529527fc39127c6e2379b/assets/easyr1_grpo.png
--------------------------------------------------------------------------------
/assets/qwen2_5_vl_7b_geo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiyouga/EasyR1/c835ae63c2f302b1491529527fc39127c6e2379b/assets/qwen2_5_vl_7b_geo.png
--------------------------------------------------------------------------------
/assets/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiyouga/EasyR1/c835ae63c2f302b1491529527fc39127c6e2379b/assets/wechat.jpg
--------------------------------------------------------------------------------
/examples/baselines/qwen2_5_vl_3b_clevr.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=BUAADreamer/clevr_count_70k@train \
12 | data.val_files=BUAADreamer/clevr_count_70k@test \
13 | data.format_prompt=./examples/format_prompt/r1v_format.jinja \
14 | worker.actor.model.model_path=${MODEL_PATH} \
15 | worker.rollout.tensor_parallel_size=1 \
16 | worker.reward.reward_type=sequential \
17 | worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
18 | trainer.experiment_name=qwen2_5_vl_3b_clevr \
19 | trainer.n_gpus_per_node=2
20 |
--------------------------------------------------------------------------------
/examples/baselines/qwen2_5_vl_3b_geoqa8k.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=leonardPKU/GEOQA_8K_R1V@train \
12 | data.val_files=leonardPKU/GEOQA_8K_R1V@test \
13 | data.format_prompt=./examples/format_prompt/r1v_format.jinja \
14 | worker.actor.model.model_path=${MODEL_PATH} \
15 | worker.rollout.tensor_parallel_size=1 \
16 | worker.reward.reward_type=sequential \
17 | worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
18 | trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
19 | trainer.n_gpus_per_node=8
20 |
--------------------------------------------------------------------------------
/examples/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_files: hiyouga/math12k@train
3 | val_files: hiyouga/math12k@test
4 | prompt_key: problem
5 | answer_key: answer
6 | image_key: images
7 | max_prompt_length: 2048
8 | max_response_length: 2048
9 | rollout_batch_size: 512
10 | val_batch_size: 1024
11 | format_prompt: ./examples/format_prompt/math_format.jinja
12 | override_chat_template: null
13 | shuffle: true
14 | seed: 1
15 | max_pixels: 4194304
16 | min_pixels: 262144
17 | filter_overlong_prompts: true
18 |
19 | algorithm:
20 | adv_estimator: grpo
21 | disable_kl: false
22 | use_kl_loss: true
23 | kl_penalty: low_var_kl
24 | kl_coef: 1.0e-2
25 |
26 | worker:
27 | actor:
28 | global_batch_size: 128
29 | micro_batch_size_per_device_for_update: 4
30 | micro_batch_size_per_device_for_experience: 16
31 | max_grad_norm: 1.0
32 | padding_free: true
33 | ulysses_sequence_parallel_size: 1
34 | model:
35 | model_path: Qwen/Qwen2.5-7B-Instruct
36 | enable_gradient_checkpointing: true
37 | trust_remote_code: false
38 | freeze_vision_tower: false
39 | optim:
40 | lr: 1.0e-6
41 | weight_decay: 1.0e-2
42 | strategy: adamw # {adamw, adamw_bf16}
43 | lr_warmup_ratio: 0.0
44 | fsdp:
45 | enable_full_shard: true
46 | enable_cpu_offload: false
47 | enable_rank0_init: true
48 | offload:
49 | offload_params: true # true: more CPU memory; false: more GPU memory
50 | offload_optimizer: true # true: more CPU memory; false: more GPU memory
51 |
52 | rollout:
53 | n: 5
54 | temperature: 1.0
55 | top_p: 0.99
56 | gpu_memory_utilization: 0.6
57 | enforce_eager: false
58 | enable_chunked_prefill: false
59 | tensor_parallel_size: 2
60 | limit_images: 0
61 | val_override_config:
62 | temperature: 0.5
63 | n: 1
64 |
65 | ref:
66 | fsdp:
67 | enable_full_shard: true
68 | enable_cpu_offload: true # true: more CPU memory; false: more GPU memory
69 | enable_rank0_init: true
70 | offload:
71 | offload_params: false
72 |
73 | reward:
74 | reward_type: batch
75 | reward_function: ./examples/reward_function/math.py:compute_score
76 |
77 | trainer:
78 | total_epochs: 15
79 | max_steps: null
80 | project_name: easy_r1
81 | experiment_name: qwen2_5_7b_math_grpo
82 | logger: ["console", "wandb"]
83 | nnodes: 1
84 | n_gpus_per_node: 8
85 | val_freq: 5 # -1 to disable
86 | val_before_train: true
87 | val_only: false
88 | val_generations_to_log: 3
89 | save_freq: 5 # -1 to disable
90 | save_limit: 3 # -1 to disable
91 | save_checkpoint_path: null
92 | load_checkpoint_path: null
93 |
--------------------------------------------------------------------------------
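The example shell scripts below override entries of this YAML with dot-notation arguments such as `worker.rollout.tensor_parallel_size=1`. A minimal sketch of how such overrides compose, assuming OmegaConf-style merging (OmegaConf is listed in requirements.txt; the trainer's actual entry point may wire this up differently):

```python
from omegaconf import OmegaConf

# Load the base config and merge CLI-style dot-list overrides on top of it.
# The keys mirror the ones used in the example scripts; this merge logic is
# an illustration only, not the trainer's real implementation.
base = OmegaConf.load("examples/config.yaml")
overrides = OmegaConf.from_dotlist(
    [
        "worker.actor.model.model_path=Qwen/Qwen2.5-VL-7B-Instruct",
        "worker.rollout.tensor_parallel_size=1",
        "trainer.n_gpus_per_node=2",
    ]
)
config = OmegaConf.merge(base, overrides)
print(config.worker.rollout.tensor_parallel_size)  # 1
```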
/examples/format_prompt/math_format.jinja:
--------------------------------------------------------------------------------
1 | {{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
2 |
--------------------------------------------------------------------------------
/examples/format_prompt/r1v_format.jinja:
--------------------------------------------------------------------------------
1 | {{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>
2 |
--------------------------------------------------------------------------------
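Both format prompts are Jinja templates with a single `content` variable that is filled with the question text (see `data.format_prompt` in examples/config.yaml). A minimal sketch of rendering one of them with plain Jinja2 (the dataset loader may apply the template differently):

```python
from jinja2 import Template

# Render the math format prompt for one problem. Reading the file and calling
# Template(...).render(content=...) directly is an illustration only.
with open("examples/format_prompt/math_format.jinja", encoding="utf-8") as f:
    template = Template(f.read())

prompt = template.render(content="Compute 1 + 1.")
print(prompt)
```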
/examples/qwen2_5_7b_math_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-7B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | worker.actor.model.model_path=${MODEL_PATH}
12 |
--------------------------------------------------------------------------------
/examples/qwen2_5_vl_32b_geo3k_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-32B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=hiyouga/geometry3k@train \
12 | data.val_files=hiyouga/geometry3k@test \
13 | worker.actor.model.model_path=${MODEL_PATH} \
14 | worker.actor.micro_batch_size_per_device_for_update=1 \
15 | worker.actor.micro_batch_size_per_device_for_experience=8 \
16 | worker.actor.fsdp.torch_dtype=bf16 \
17 | worker.actor.optim.strategy=adamw_bf16 \
18 | worker.rollout.tensor_parallel_size=8 \
19 | trainer.experiment_name=qwen2_5_vl_32b_geo_grpo \
20 | trainer.n_gpus_per_node=8
21 |
--------------------------------------------------------------------------------
/examples/qwen2_5_vl_3b_geo3k_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=hiyouga/geometry3k@train \
12 | data.val_files=hiyouga/geometry3k@test \
13 | worker.actor.model.model_path=${MODEL_PATH} \
14 | worker.rollout.tensor_parallel_size=1 \
15 | trainer.experiment_name=qwen2_5_vl_3b_geo_grpo \
16 | trainer.n_gpus_per_node=2
17 |
--------------------------------------------------------------------------------
/examples/qwen2_5_vl_7b_geo3k_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=hiyouga/geometry3k@train \
12 | data.val_files=hiyouga/geometry3k@test \
13 | worker.actor.model.model_path=${MODEL_PATH} \
14 | trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \
15 | trainer.n_gpus_per_node=8
16 |
--------------------------------------------------------------------------------
/examples/qwen2_5_vl_7b_geo3k_reinforce.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=hiyouga/geometry3k@train \
12 | data.val_files=hiyouga/geometry3k@test \
13 | worker.actor.model.model_path=${MODEL_PATH} \
14 | algorithm.adv_estimator=reinforce_plus_plus \
15 | algorithm.use_kl_loss=false \
16 | algorithm.kl_penalty=kl \
17 | algorithm.kl_coef=1.0e-3 \
18 | trainer.experiment_name=qwen2_5_vl_7b_geo_reinforce_pp \
19 | trainer.n_gpus_per_node=8
20 |
--------------------------------------------------------------------------------
/examples/qwen2_5_vl_7b_geo3k_swanlab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.train_files=hiyouga/geometry3k@train \
12 | data.val_files=hiyouga/geometry3k@test \
13 | worker.actor.model.model_path=${MODEL_PATH} \
14 | trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \
15 | trainer.logger=['console','swanlab'] \
16 | trainer.n_gpus_per_node=8
17 |
--------------------------------------------------------------------------------
/examples/qwen2_5_vl_7b_multi_image.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # REMINDER: this script uses the test data split and should ONLY be used for debugging. DO NOT use it for training.
3 |
4 | set -x
5 |
6 | export PYTHONUNBUFFERED=1
7 |
8 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
9 |
10 | python3 -m verl.trainer.main \
11 | config=examples/config.yaml \
12 | data.train_files=hiyouga/journeybench-multi-image-vqa@train \
13 | data.val_files=hiyouga/journeybench-multi-image-vqa@test \
14 | data.rollout_batch_size=256 \
15 | worker.actor.model.model_path=${MODEL_PATH} \
16 | worker.rollout.limit_images=2 \
17 | trainer.experiment_name=qwen2_5_vl_7b_multi_image \
18 | trainer.n_gpus_per_node=8
19 |
--------------------------------------------------------------------------------
/examples/qwen3_4b_math_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONUNBUFFERED=1
6 |
7 | MODEL_PATH=Qwen/Qwen3-4B # replace it with your local file path
8 |
9 | python3 -m verl.trainer.main \
10 | config=examples/config.yaml \
11 | data.max_response_length=4096 \
12 | worker.actor.model.model_path=${MODEL_PATH} \
13 | trainer.experiment_name=qwen3_4b_math_grpo
14 |
--------------------------------------------------------------------------------
/examples/reward_function/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from typing import Dict, List
17 |
18 | from mathruler.grader import extract_boxed_content, grade_answer
19 |
20 |
21 | def format_reward(predict: str) -> float:
22 |     pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
23 | format_match = re.fullmatch(pattern, predict)
24 | return 1.0 if format_match else 0.0
25 |
26 |
27 | def accuracy_reward(predict: str, ground_truth: str) -> float:
28 | answer = extract_boxed_content(predict)
29 | return 1.0 if grade_answer(answer, ground_truth) else 0.0
30 |
31 |
32 | def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
33 | scores = []
34 | for predict, ground_truth in zip(predicts, ground_truths):
35 | predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict) # handle qwen2.5vl-32b format
36 | format_score = format_reward(predict)
37 | accuracy_score = accuracy_reward(predict, ground_truth)
38 | scores.append(
39 | {
40 | "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
41 | "format": format_score,
42 | "accuracy": accuracy_score,
43 | }
44 | )
45 |
46 | return scores
47 |
--------------------------------------------------------------------------------
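`compute_score` above operates on batches (lists of predictions and ground truths), matching `reward_type: batch` in examples/config.yaml. A minimal usage sketch with made-up strings (the import path assumes the repo root is on PYTHONPATH):

```python
# Illustration only: a toy prediction in the expected <think>...</think> + \boxed{} format.
from examples.reward_function.math import compute_score  # assumes repo root on PYTHONPATH

predicts = ["<think>2 + 2 = 4</think> The answer is \\boxed{4}."]
ground_truths = ["4"]
print(compute_score(predicts, ground_truths, format_weight=0.1))
# e.g. [{'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}]
```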
/examples/reward_function/r1v.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from typing import Dict
17 |
18 | from mathruler.grader import grade_answer
19 |
20 |
21 | def format_reward(predict: str) -> float:
22 |     pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
23 | format_match = re.fullmatch(pattern, predict)
24 | return 1.0 if format_match else 0.0
25 |
26 |
27 | def accuracy_reward(predict: str, ground_truth: str) -> float:
28 | try:
29 |         content_match = re.search(r"<answer>(.*?)</answer>", predict)
30 | given_answer = content_match.group(1).strip() if content_match else predict.strip()
31 | if grade_answer(given_answer, ground_truth.strip()):
32 | return 1.0
33 |
34 | except Exception:
35 | pass
36 |
37 | return 0.0
38 |
39 |
40 | def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
41 | format_score = format_reward(predict)
42 | accuracy_score = accuracy_reward(predict, ground_truth)
43 | return {
44 | "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
45 | "format": format_score,
46 | "accuracy": accuracy_score,
47 | }
48 |
--------------------------------------------------------------------------------
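Unlike math.py, this `compute_score` scores a single prediction at a time, which is why the CLEVR and GeoQA baseline scripts pair `worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score` with `worker.reward.reward_type=sequential`. A minimal sketch with a toy prediction (the import path assumes the repo root is on PYTHONPATH):

```python
from examples.reward_function.r1v import compute_score  # assumes repo root on PYTHONPATH

# Illustration only: one prediction in the <think>...</think> <answer>...</answer> format.
predict = "<think>There are three objects.</think> <answer>3</answer>"
print(compute_score(predict, "3"))
# e.g. {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}
```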
/examples/runtime_env.yaml:
--------------------------------------------------------------------------------
1 | working_dir: ./
2 | excludes: ["/.git/"]
3 | env_vars:
4 | TOKENIZERS_PARALLELISM: "true"
5 | NCCL_DEBUG: "WARN"
6 | VLLM_LOGGING_LEVEL: "WARN"
7 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
8 | PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
9 | PYTHONUNBUFFERED: "1"
10 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "verl"
7 | dynamic = [
8 | "version",
9 | "dependencies",
10 | "optional-dependencies",
11 | "requires-python",
12 | "authors",
13 | "description",
14 | "readme",
15 | "license"
16 | ]
17 |
18 | [tool.ruff]
19 | target-version = "py39"
20 | line-length = 119
21 | indent-width = 4
22 |
23 | [tool.ruff.lint]
24 | ignore = ["C901", "E501", "E741", "W605", "C408"]
25 | select = ["C", "E", "F", "I", "W", "RUF022"]
26 |
27 | [tool.ruff.lint.per-file-ignores]
28 | "__init__.py" = ["E402", "F401", "F403", "F811"]
29 |
30 | [tool.ruff.lint.isort]
31 | lines-after-imports = 2
32 | known-first-party = ["verl"]
33 | known-third-party = ["torch", "transformers", "wandb"]
34 |
35 | [tool.ruff.format]
36 | quote-style = "double"
37 | indent-style = "space"
38 | skip-magic-trailing-comma = false
39 | line-ending = "auto"
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | codetiming
3 | datasets
4 | flash-attn>=2.4.3
5 | liger-kernel
6 | mathruler
7 | numpy
8 | omegaconf
9 | pandas
10 | peft
11 | pillow
12 | pyarrow>=15.0.0
13 | pylatexenc
14 | qwen-vl-utils
15 | ray[default]
16 | tensordict
17 | torchdata
18 | transformers>=4.51.0
19 | vllm>=0.7.3
20 | wandb
21 |
--------------------------------------------------------------------------------
/scripts/model_merger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import re
18 | from concurrent.futures import ThreadPoolExecutor
19 | from typing import Dict, List, Tuple
20 |
21 | import numpy as np
22 | import torch
23 | from torch.distributed._tensor import DTensor, Placement, Shard
24 | from transformers import (
25 | AutoConfig,
26 | AutoModelForCausalLM,
27 | AutoModelForTokenClassification,
28 | AutoModelForVision2Seq,
29 | PretrainedConfig,
30 | PreTrainedModel,
31 | )
32 |
33 |
34 | def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
35 | if placement.is_replicate():
36 | return tensors[0]
37 | elif placement.is_partial():
38 | raise NotImplementedError("Partial placement is not supported yet")
39 | elif placement.is_shard():
40 | return torch.cat(tensors, dim=placement.dim).contiguous()
41 | else:
42 | raise ValueError(f"Unsupported placement: {placement}")
43 |
44 |
45 | def upload_model_to_huggingface(local_path: str, remote_path: str):
46 | # Push to hugging face
47 | from huggingface_hub import HfApi
48 |
49 | api = HfApi()
50 | api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
51 | api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
57 | parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
58 | args = parser.parse_args()
59 | local_dir: str = args.local_dir
60 |
61 | assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."
62 |
63 | # copy rank zero to find the shape of (dp, fsdp)
64 | rank = 0
65 | world_size = 0
66 | for filename in os.listdir(local_dir):
67 | match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
68 | if match:
69 | world_size = match.group(1)
70 | break
71 |
72 | assert world_size, "No model file with the proper format."
73 |
74 | rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
75 | state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
76 | pivot_key = sorted(state_dict.keys())[0]
77 | weight = state_dict[pivot_key]
78 | if isinstance(weight, DTensor):
79 | # get sharding info
80 | device_mesh = weight.device_mesh
81 | mesh = device_mesh.mesh
82 | mesh_dim_names = device_mesh.mesh_dim_names
83 | else:
84 | # for non-DTensor
85 | mesh = np.array([int(world_size)], dtype=np.int64)
86 | mesh_dim_names = ("fsdp",)
87 |
88 | print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
89 |
90 | assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."
91 |
92 | if "tp" in mesh_dim_names:
93 | # fsdp * tp
94 | total_shards = mesh.shape[-1] * mesh.shape[-2]
95 | mesh_shape = (mesh.shape[-2], mesh.shape[-1])
96 | else:
97 | # fsdp
98 | total_shards = mesh.shape[-1]
99 | mesh_shape = (mesh.shape[-1],)
100 |
101 | print(f"Processing {total_shards} model shards in total.")
102 | model_state_dict_lst = []
103 | model_state_dict_lst.append(state_dict)
104 | model_state_dict_lst.extend([""] * (total_shards - 1))
105 |
106 | def process_one_shard(rank, model_state_dict_lst):
107 | model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
108 | state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
109 | model_state_dict_lst[rank] = state_dict
110 | return state_dict
111 |
112 | with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
113 | for rank in range(1, total_shards):
114 | executor.submit(process_one_shard, rank, model_state_dict_lst)
115 |
116 | state_dict: Dict[str, List[torch.Tensor]] = {}
117 | param_placements: Dict[str, List[Placement]] = {}
118 | keys = set(model_state_dict_lst[0].keys())
119 | for key in keys:
120 | state_dict[key] = []
121 | for model_state_dict in model_state_dict_lst:
122 | try:
123 | tensor = model_state_dict.pop(key)
124 | except Exception:
125 | print(f"Cannot find key {key} in rank {rank}.")
126 |
127 | if isinstance(tensor, DTensor):
128 | state_dict[key].append(tensor._local_tensor.bfloat16())
129 | placements = tuple(tensor.placements)
130 | # replicated placement at ddp dimension can be discarded
131 | if mesh_dim_names[0] == "ddp":
132 | placements = placements[1:]
133 |
134 | if key not in param_placements:
135 | param_placements[key] = placements
136 | else:
137 | assert param_placements[key] == placements
138 | else:
139 | state_dict[key].append(tensor.bfloat16())
140 |
141 | del model_state_dict_lst
142 |
143 | for key in sorted(state_dict):
144 | if not isinstance(state_dict[key], list):
145 | print(f"No need to merge key {key}")
146 | continue
147 |
148 | if key in param_placements:
149 | # merge shards
150 | placements: Tuple[Shard] = param_placements[key]
151 | if len(mesh_shape) == 1:
152 | # 1-D list, FSDP without TP
153 | assert len(placements) == 1
154 | shards = state_dict[key]
155 | state_dict[key] = merge_by_placement(shards, placements[0])
156 | else:
157 | # 2-D list, FSDP + TP
158 | raise NotImplementedError("FSDP + TP is not supported yet.")
159 | else:
160 | state_dict[key] = torch.cat(state_dict[key], dim=0)
161 |
162 | print("Merge completed.")
163 | hf_path = os.path.join(local_dir, "huggingface")
164 | config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
165 | architectures: List[str] = getattr(config, "architectures", ["Unknown"])
166 |
167 | if "ForTokenClassification" in architectures[0]:
168 | AutoClass = AutoModelForTokenClassification
169 | elif "ForCausalLM" in architectures[0]:
170 | AutoClass = AutoModelForCausalLM
171 | elif "ForConditionalGeneration" in architectures[0]:
172 | AutoClass = AutoModelForVision2Seq
173 | else:
174 | raise NotImplementedError(f"Unknown architecture {architectures}.")
175 |
176 | with torch.device("meta"):
177 | model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)
178 |
179 | assert isinstance(model, PreTrainedModel)
180 | model.to_empty(device="cpu")
181 |
182 | print(f"Saving model to {hf_path}...")
183 | model.save_pretrained(hf_path, state_dict=state_dict)
184 | del state_dict, model
185 |
186 | if args.hf_upload_path:
187 | upload_model_to_huggingface(hf_path, args.hf_upload_path)
188 |
--------------------------------------------------------------------------------
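The merge step above concatenates each parameter's FSDP shards along the dimension recorded in its placement. A minimal, self-contained sketch of that idea with toy tensors (the real script works on DTensor local shards loaded from the checkpoint files):

```python
import torch
from torch.distributed._tensor import Shard

# Two toy shards of a (4, 3) weight, split along dim 0 as Shard(0) records.
shard0 = torch.zeros(2, 3)
shard1 = torch.ones(2, 3)
placement = Shard(0)

# merge_by_placement does the same thing for the shard placement case.
merged = torch.cat([shard0, shard1], dim=placement.dim).contiguous()
print(merged.shape)  # torch.Size([4, 3])
```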
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 |
18 | from setuptools import find_packages, setup
19 |
20 |
21 | def get_version() -> str:
22 | with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
23 | file_content = f.read()
24 | pattern = r"__version__\W*=\W*\"([^\"]+)\""
25 | (version,) = re.findall(pattern, file_content)
26 | return version
27 |
28 |
29 | def get_requires() -> list[str]:
30 | with open("requirements.txt", encoding="utf-8") as f:
31 | file_content = f.read()
32 | lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
33 | return lines
34 |
35 |
36 | extra_require = {
37 | "dev": ["pre-commit", "ruff"],
38 | }
39 |
40 |
41 | def main():
42 | setup(
43 | name="verl",
44 | version=get_version(),
45 | description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
46 | long_description=open("README.md", encoding="utf-8").read(),
47 | long_description_content_type="text/markdown",
48 | author="verl",
49 | author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
50 | license="Apache 2.0 License",
51 | url="https://github.com/volcengine/verl",
52 | package_dir={"": "."},
53 | packages=find_packages(where="."),
54 | python_requires=">=3.9.0",
55 | install_requires=get_requires(),
56 | extras_require=extra_require,
57 | )
58 |
59 |
60 | if __name__ == "__main__":
61 | main()
62 |
--------------------------------------------------------------------------------
/verl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | from .utils.py_functional import is_package_available
18 |
19 |
20 | if is_package_available("modelscope"):
21 | from modelscope.utils.hf_util import patch_hub # type: ignore
22 |
23 |
24 | __version__ = "0.3.1.dev0"
25 |
26 |
27 | if os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "y", "1"]:
28 |     # Patch the HF hub so that models are downloaded from ModelScope, which speeds up downloads in some regions.
29 |     if not is_package_available("modelscope"):
30 |         raise ImportError("You are using the ModelScope hub, please install it via `pip install modelscope`.")
31 |
32 | patch_hub()
33 |
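A minimal usage sketch, assuming the optional `modelscope` package is installed: the environment variable has to be set before `verl` is imported, because `patch_hub()` runs at import time.

import os

os.environ["USE_MODELSCOPE_HUB"] = "1"  # must be set before importing verl

import verl  # patch_hub() is applied during this import

print(verl.__version__)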
--------------------------------------------------------------------------------
/verl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
17 |
18 | from .transformers.flash_attention_utils import flash_attention_forward
19 | from .transformers.qwen2_vl import qwen2_vl_attn_forward
20 |
21 |
22 | def apply_ulysses_patch(model_type: str) -> None:
23 | if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
24 | ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
25 | elif model_type in ("qwen2_vl", "qwen2_5_vl"):
26 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
27 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
28 |
29 | Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
30 | Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
31 | else:
32 | raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
33 |
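A minimal sketch of the intended usage: call the patch once with the Hugging Face `model_type` before the model forward is used; unsupported architectures raise `NotImplementedError` (the "bert" call below only shows the error path).

from verl.models.monkey_patch import apply_ulysses_patch

# Language models listed above route flash_attention_2 through the Ulysses-aware
# forward; Qwen2-VL variants get their attention forward method replaced directly.
apply_ulysses_patch("qwen2_5_vl")

try:
    apply_ulysses_patch("bert")
except NotImplementedError as err:
    print(err)  # Model architecture bert is not supported yet.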
--------------------------------------------------------------------------------
/verl/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/transformers/flash_attention_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import inspect
18 | import os
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
24 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25 |
26 | from ...utils.ulysses import (
27 | gather_heads_scatter_seq,
28 | gather_seq_scatter_heads,
29 | get_ulysses_sequence_parallel_group,
30 | get_ulysses_sequence_parallel_world_size,
31 | )
32 |
33 |
34 | if is_flash_attn_2_available():
35 | from flash_attn import flash_attn_func, flash_attn_varlen_func
36 |
37 | _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
38 | _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
39 | _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
40 | _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
41 |
42 |
43 | def prepare_fa2_from_position_ids(
44 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
45 | ):
46 | query = query.view(-1, query.size(-2), query.size(-1))
47 | key = key.contiguous().view(-1, key.size(-2), key.size(-1))
48 | value = value.contiguous().view(-1, value.size(-2), value.size(-1))
49 | position_ids = position_ids.flatten()
50 | indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
51 | cu_seqlens = torch.cat(
52 | (
53 | indices_q[position_ids == 0],
54 | torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
55 | )
56 | )
57 | max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
58 | return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
59 |
60 |
61 | def _custom_flash_attention_forward(
62 | query_states: torch.Tensor,
63 | key_states: torch.Tensor,
64 | value_states: torch.Tensor,
65 | attention_mask: Optional[torch.Tensor],
66 | query_length: int,
67 | is_causal: bool = True,
68 | position_ids: Optional[torch.Tensor] = None,
69 | sliding_window: Optional[int] = None,
70 | use_top_left_mask: bool = False,
71 | deterministic: Optional[bool] = None,
72 | **kwargs,
73 | ):
74 | """
75 | Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
76 | """
77 | if not use_top_left_mask:
78 | causal = is_causal
79 | else:
80 | causal = is_causal and query_length != 1
81 |
82 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
83 | use_sliding_windows = (
84 | _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
85 | )
86 | flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
87 |
88 | if _flash_supports_deterministic:
89 | flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
90 |
91 | if kwargs.get("softcap") is not None:
92 | flash_kwargs["softcap"] = kwargs.pop("softcap")
93 |
94 | query_states, key_states, value_states = fa_peft_integration_check(
95 | query_states, key_states, value_states, target_dtype=torch.bfloat16
96 | )
97 |
98 | sp_size = get_ulysses_sequence_parallel_world_size()
99 | if sp_size > 1:
100 | # (batch_size, seq_length, num_head, head_size)
101 | query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
102 | key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
103 | value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
104 | position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
105 |         dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
106 | position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
107 |
108 | if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope
109 | position_ids = position_ids[0]
110 |
111 | if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
112 | batch_size = query_states.size(0)
113 | query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
114 | query_states, key_states, value_states, position_ids
115 | )
116 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
117 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
118 | attn_output = flash_attn_varlen_func(
119 | query_states,
120 | key_states,
121 | value_states,
122 | cu_seqlens_q=cu_seqlens_q,
123 | cu_seqlens_k=cu_seqlens_k,
124 | max_seqlen_q=max_seqlen_in_batch_q,
125 | max_seqlen_k=max_seqlen_in_batch_k,
126 | dropout_p=kwargs.pop("dropout", 0.0),
127 | softmax_scale=kwargs.pop("softmax_scale", None),
128 | causal=causal,
129 | **flash_kwargs,
130 | )
131 | attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
132 | else:
133 | attn_output = _flash_attention_forward(
134 | query_states,
135 | key_states,
136 | value_states,
137 | attention_mask,
138 | query_length,
139 | is_causal=is_causal,
140 | sliding_window=sliding_window,
141 | use_top_left_mask=use_top_left_mask,
142 | deterministic=deterministic,
143 | **kwargs,
144 | ) # do not pass position_ids to old flash_attention_forward
145 |
146 | if sp_size > 1:
147 | # (batch_size, seq_length, num_head, head_size)
148 | attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
149 |
150 | return attn_output
151 |
152 |
153 | def flash_attention_forward(
154 | module: torch.nn.Module,
155 | query: torch.Tensor,
156 | key: torch.Tensor,
157 | value: torch.Tensor,
158 | attention_mask: Optional[torch.Tensor],
159 | dropout: float = 0.0,
160 | scaling: Optional[float] = None,
161 | sliding_window: Optional[int] = None,
162 | softcap: Optional[float] = None,
163 | **kwargs,
164 | ) -> Tuple[torch.Tensor, None]:
165 | # This is before the transpose
166 | q_len = query.shape[2]
167 |
168 | # FA2 uses non-transposed inputs
169 | query = query.transpose(1, 2)
170 | key = key.transpose(1, 2)
171 | value = value.transpose(1, 2)
172 |
173 | # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
174 | kwargs.pop("is_causal", None)
175 |
176 | attn_output = _custom_flash_attention_forward(
177 | query,
178 | key,
179 | value,
180 | attention_mask,
181 | query_length=q_len,
182 | is_causal=True,
183 | dropout=dropout,
184 | softmax_scale=scaling,
185 | sliding_window=sliding_window,
186 | softcap=softcap,
187 | use_top_left_mask=_flash_use_top_left_mask,
188 | **kwargs,
189 | )
190 |
191 | return attn_output, None
192 |
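The varlen path above hinges on `prepare_fa2_from_position_ids`: wherever packed position ids restart at zero, a new sequence begins, and those restart points become the `cu_seqlens` boundaries. A standalone sketch of just that derivation on a toy packed batch:

import torch

# Two packed sequences of lengths 3 and 2; position ids restart at 0 per sequence.
position_ids = torch.tensor([0, 1, 2, 0, 1])
indices = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seqlens = torch.cat(
    (indices[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
)
print(cu_seqlens)               # tensor([0, 3, 5], dtype=torch.int32)
print(cu_seqlens.diff().max())  # tensor(3, dtype=torch.int32) -> max_seqlen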
--------------------------------------------------------------------------------
/verl/models/transformers/qwen2_vl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | # Based on:
4 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from typing import Optional, Tuple
19 |
20 | import torch
21 |
22 | from .flash_attention_utils import flash_attention_forward
23 |
24 |
25 | try:
26 | from transformers.models.qwen2_vl.modeling_qwen2_vl import (
27 | Qwen2VLAttention,
28 | apply_multimodal_rotary_pos_emb,
29 | repeat_kv,
30 | )
31 | from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
32 | except ImportError:
33 | pass
34 |
35 |
36 | def get_rope_index(
37 | processor: "Qwen2VLProcessor",
38 | input_ids: torch.Tensor,
39 | image_grid_thw: Optional[torch.Tensor] = None,
40 | video_grid_thw: Optional[torch.Tensor] = None,
41 | second_per_grid_ts: Optional[torch.Tensor] = None,
42 | attention_mask: Optional[torch.Tensor] = None,
43 | ) -> torch.Tensor:
44 | """
45 |     Gets the position ids for Qwen2-VL. They should be generated before sharding the sequence.
46 | The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
47 | https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
48 | """
49 | spatial_merge_size = processor.image_processor.merge_size
50 | tokens_per_second = 2
51 | image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
52 | video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
53 | vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
54 | if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
55 | if attention_mask is None:
56 | attention_mask = torch.ones_like(input_ids)
57 |
58 | position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
59 | image_index, video_index = 0, 0
60 | input_ids = input_ids[attention_mask == 1]
61 | image_nums, video_nums = 0, 0
62 | vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
63 | vision_tokens = input_ids[vision_start_indices + 1]
64 | image_nums = (vision_tokens == image_token_id).sum()
65 | video_nums = (vision_tokens == video_token_id).sum()
66 | input_tokens = input_ids.tolist()
67 | llm_pos_ids_list: list = []
68 | st = 0
69 | remain_images, remain_videos = image_nums, video_nums
70 | for _ in range(image_nums + video_nums):
71 | if image_token_id in input_tokens and remain_images > 0:
72 | ed_image = input_tokens.index(image_token_id, st)
73 | else:
74 | ed_image = len(input_tokens) + 1
75 | if video_token_id in input_tokens and remain_videos > 0:
76 | ed_video = input_tokens.index(video_token_id, st)
77 | else:
78 | ed_video = len(input_tokens) + 1
79 | if ed_image < ed_video:
80 | t, h, w = (
81 | image_grid_thw[image_index][0],
82 | image_grid_thw[image_index][1],
83 | image_grid_thw[image_index][2],
84 | )
85 | second_per_grid_t = 0
86 | image_index += 1
87 | remain_images -= 1
88 | ed = ed_image
89 | else:
90 | t, h, w = (
91 | video_grid_thw[video_index][0],
92 | video_grid_thw[video_index][1],
93 | video_grid_thw[video_index][2],
94 | )
95 | if second_per_grid_ts is not None:
96 | second_per_grid_t = second_per_grid_ts[video_index]
97 | else:
98 | second_per_grid_t = 1.0
99 |
100 | video_index += 1
101 | remain_videos -= 1
102 | ed = ed_video
103 |
104 | llm_grid_t, llm_grid_h, llm_grid_w = (
105 | t.item(),
106 | h.item() // spatial_merge_size,
107 | w.item() // spatial_merge_size,
108 | )
109 | text_len = ed - st
110 |
111 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
112 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
113 |
114 | t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
115 | t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
116 | h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
117 | w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
118 | llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
119 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w
120 |
121 | if st < len(input_tokens):
122 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
123 | text_len = len(input_tokens) - st
124 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
125 |
126 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
127 | position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
128 | else:
129 | if attention_mask is not None:
130 | position_ids = attention_mask.long().cumsum(-1) - 1
131 | position_ids.masked_fill_(attention_mask == 0, 1)
132 | position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
133 | else:
134 | position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
135 |
136 | return position_ids
137 |
138 |
139 | def qwen2_vl_attn_forward(
140 | self: "Qwen2VLAttention",
141 | hidden_states: torch.Tensor,
142 | attention_mask: Optional[torch.Tensor] = None,
143 | position_ids: Optional[torch.LongTensor] = None,
144 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
145 | **kwargs,
146 | ) -> Tuple[torch.Tensor, None, None]:
147 | bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
148 | query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
149 | key_states = self.k_proj(hidden_states)
150 | value_states = self.v_proj(hidden_states)
151 |
152 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
153 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
154 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
155 |
156 | # Because the input can be padded, the absolute sequence length depends on the max position id.
157 | if position_embeddings is None:
158 | cos, sin = self.rotary_emb(value_states, position_ids)
159 | else:
160 | cos, sin = position_embeddings
161 |
162 | query_states, key_states = apply_multimodal_rotary_pos_emb(
163 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
164 | )
165 | key_states = repeat_kv(key_states, self.num_key_value_groups)
166 | value_states = repeat_kv(value_states, self.num_key_value_groups)
167 | dropout_rate = 0.0 if not self.training else self.attention_dropout
168 |
169 | sliding_window = None
170 | if (
171 | self.config.use_sliding_window
172 | and getattr(self.config, "sliding_window", None) is not None
173 | and self.layer_idx >= self.config.max_window_layers
174 | ):
175 | sliding_window = self.config.sliding_window
176 |
177 | attn_output, _ = flash_attention_forward(
178 | self,
179 | query_states,
180 | key_states,
181 | value_states,
182 | attention_mask,
183 | dropout=dropout_rate,
184 | sliding_window=sliding_window,
185 | position_ids=position_ids, # important: pass position ids
186 | ) # (batch_size, seq_length, num_head / sp_size, head_size)
187 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
188 | attn_output = self.o_proj(attn_output)
189 | return attn_output, None, None
190 |
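For intuition, the text-only fallback of `get_rope_index` broadcasts ordinary 1D position ids across the three mrope axes (temporal, height, width). A standalone sketch of that branch with a toy attention mask:

import torch

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1)  # (3, seqlen)
print(position_ids)
# tensor([[0, 1, 2, 3, 1, 1],
#         [0, 1, 2, 3, 1, 1],
#         [0, 1, 2, 3, 1, 1]])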
--------------------------------------------------------------------------------
/verl/single_controller/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .worker import Worker
16 | from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
17 |
18 |
19 | __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
20 |
--------------------------------------------------------------------------------
/verl/single_controller/base/decorator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from enum import Enum, auto
16 | from functools import wraps
17 | from types import FunctionType
18 | from typing import TYPE_CHECKING, Dict, List, Literal, Union
19 |
20 | import ray
21 |
22 | from ...protocol import DataProto, DataProtoFuture
23 |
24 |
25 | if TYPE_CHECKING:
26 | from .worker_group import WorkerGroup
27 |
28 |
29 | # here we add a magic number to avoid clashing with an attribute that a user-defined function may already have
30 | MAGIC_ATTR = "attrs_3141562937"
31 |
32 |
33 | class Dispatch(Enum):
34 | RANK_ZERO = auto()
35 | ONE_TO_ALL = auto()
36 | ALL_TO_ALL = auto()
37 | DP_COMPUTE = auto()
38 | DP_COMPUTE_PROTO = auto()
39 | DP_COMPUTE_PROTO_WITH_FUNC = auto()
40 | DP_COMPUTE_METRIC = auto()
41 |
42 |
43 | class Execute(Enum):
44 | ALL = 0
45 | RANK_ZERO = 1
46 |
47 |
48 | def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
49 | splitted_args = []
50 | for arg in args:
51 | assert isinstance(arg, (DataProto, DataProtoFuture))
52 | splitted_args.append(arg.chunk(chunks=chunks))
53 |
54 | splitted_kwargs = {}
55 | for key, value in kwargs.items():
56 | assert isinstance(value, (DataProto, DataProtoFuture))
57 | splitted_kwargs[key] = value.chunk(chunks=chunks)
58 |
59 | return splitted_args, splitted_kwargs
60 |
61 |
62 | def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
63 | args = tuple([arg] * worker_group.world_size for arg in args)
64 | kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
65 | return args, kwargs
66 |
67 |
68 | def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
69 | return args, kwargs
70 |
71 |
72 | def collect_all_to_all(worker_group: "WorkerGroup", output):
73 | return output
74 |
75 |
76 | def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
77 |     # make sure all the elements in outputs have the same type
78 | for output in outputs:
79 | assert type(output) is type(outputs[0])
80 |
81 | output = outputs[0]
82 |
83 | if isinstance(output, DataProto):
84 | return DataProto.concat(outputs)
85 | elif isinstance(output, ray.ObjectRef):
86 | return DataProtoFuture.concat(outputs)
87 | else:
88 | raise NotImplementedError
89 |
90 |
91 | def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
92 | for arg in args:
93 | assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
94 |
95 | for value in kwargs.values():
96 | assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
97 |
98 | return args, kwargs
99 |
100 |
101 | def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
102 | assert len(outputs) == worker_group.world_size
103 | return outputs
104 |
105 |
106 | def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
107 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
108 | return splitted_args, splitted_kwargs
109 |
110 |
111 | def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
112 |     assert type(args[0]) is FunctionType  # NOTE: The first positional arg must be a function!
113 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
114 | splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
115 | return splitted_args_with_func, splitted_kwargs
116 |
117 |
118 | def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
119 | for output in outputs:
120 | assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"
121 |
122 | outputs = collect_dp_compute(worker_group, outputs)
123 | return _concat_data_proto_or_future(outputs)
124 |
125 |
126 | def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
127 | predefined_dispatch_mode_fn = {
128 | Dispatch.ONE_TO_ALL: {
129 | "dispatch_fn": dispatch_one_to_all,
130 | "collect_fn": collect_all_to_all,
131 | },
132 | Dispatch.ALL_TO_ALL: {
133 | "dispatch_fn": dispatch_all_to_all,
134 | "collect_fn": collect_all_to_all,
135 | },
136 | Dispatch.DP_COMPUTE: {
137 | "dispatch_fn": dispatch_dp_compute,
138 | "collect_fn": collect_dp_compute,
139 | },
140 | Dispatch.DP_COMPUTE_PROTO: {
141 | "dispatch_fn": dispatch_dp_compute_data_proto,
142 | "collect_fn": collect_dp_compute_data_proto,
143 | },
144 | Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
145 | "dispatch_fn": dispatch_dp_compute_data_proto_with_func,
146 | "collect_fn": collect_dp_compute_data_proto,
147 | },
148 | Dispatch.DP_COMPUTE_METRIC: {
149 | "dispatch_fn": dispatch_dp_compute_data_proto,
150 | "collect_fn": collect_dp_compute,
151 | },
152 | }
153 | return predefined_dispatch_mode_fn[dispatch_mode]
154 |
155 |
156 | def get_predefined_execute_fn(execute_mode: Execute):
157 | """
158 |     Note that here we only ask for execute_all and execute_rank_zero to be implemented.
159 |     The choice of how these two functions handle the 'blocking' argument is left to users.
160 | """
161 | predefined_execute_mode_fn = {
162 | Execute.ALL: {"execute_fn_name": "execute_all"},
163 | Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
164 | }
165 | return predefined_execute_mode_fn[execute_mode]
166 |
167 |
168 | def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
169 | assert isinstance(dispatch_mode, (Dispatch, dict)), (
170 | f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
171 | )
172 | if isinstance(dispatch_mode, dict):
173 | necessary_keys = ["dispatch_fn", "collect_fn"]
174 | for key in necessary_keys:
175 | assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
176 |
177 |
178 | def _check_execute_mode(execute_mode: Execute):
179 |     assert isinstance(execute_mode, Execute), f"execute_mode must be an Execute. Got {execute_mode}"
180 |
181 |
182 | def _materialize_futures(*args, **kwargs):
183 | new_args = []
184 | for arg in args:
185 | if isinstance(arg, DataProtoFuture):
186 | arg = arg.get()
187 |             # add more types to materialize here if needed
188 | new_args.append(arg)
189 |
190 | for key, value in kwargs.items():
191 | if isinstance(value, DataProtoFuture):
192 | kwargs[key] = value.get()
193 |
194 | new_args = tuple(new_args)
195 | return new_args, kwargs
196 |
197 |
198 | def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
199 | _check_dispatch_mode(dispatch_mode=dispatch_mode)
200 | _check_execute_mode(execute_mode=execute_mode)
201 |
202 | def decorator(func):
203 | @wraps(func)
204 | def inner(*args, **kwargs):
205 | if materialize_futures:
206 | args, kwargs = _materialize_futures(*args, **kwargs)
207 | return func(*args, **kwargs)
208 |
209 | attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
210 | setattr(inner, MAGIC_ATTR, attrs)
211 | return inner
212 |
213 | return decorator
214 |
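A minimal sketch of how `register` is meant to be used on a worker method (the `EchoWorker` class and `ping` method are hypothetical): the decorator only tags the function with dispatch/execute/blocking attributes, and `WorkerGroup._bind_worker_method` later reads them to wire up the dispatch, collect, and execute functions.

from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, Execute, register
from verl.single_controller.base.worker import Worker


class EchoWorker(Worker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.ALL, blocking=True)
    def ping(self, msg: str) -> str:
        return f"rank {self.rank}: {msg}"


# The attributes live on the function itself, not on any instance.
print(getattr(EchoWorker.ping, MAGIC_ATTR))
# {'dispatch_mode': <Dispatch.ONE_TO_ALL: ...>, 'execute_mode': <Execute.ALL: ...>, 'blocking': True}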
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 |
18 | @ray.remote
19 | class WorkerGroupRegisterCenter:
20 | def __init__(self, rank_zero_info):
21 | self.rank_zero_info = rank_zero_info
22 |
23 | def get_rank_zero_info(self):
24 | return self.rank_zero_info
25 |
26 |
27 | def create_worker_group_register_center(name, info):
28 | return WorkerGroupRegisterCenter.options(name=name).remote(info)
29 |
--------------------------------------------------------------------------------
/verl/single_controller/base/worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base Worker class.
16 | """
17 |
18 | import os
19 | import socket
20 | from dataclasses import dataclass
21 | from typing import Tuple
22 |
23 | import ray
24 | import torch
25 |
26 | from .decorator import Dispatch, Execute, register
27 | from .register_center.ray import create_worker_group_register_center
28 |
29 |
30 | @dataclass
31 | class DistRankInfo:
32 | tp_rank: int
33 | dp_rank: int
34 | pp_rank: int
35 |
36 |
37 | @dataclass
38 | class DistGlobalInfo:
39 | tp_size: int
40 | dp_size: int
41 | pp_size: int
42 |
43 |
44 | class WorkerHelper:
45 | def _get_node_ip(self) -> str:
46 | host_ipv4 = os.getenv("MY_HOST_IP", None)
47 | host_ipv6 = os.getenv("MY_HOST_IPV6", None)
48 | host_ip_by_env = host_ipv4 or host_ipv6
49 | host_ip_by_sdk = ray._private.services.get_node_ip_address()
50 |
51 | host_ip = host_ip_by_env or host_ip_by_sdk
52 | return host_ip
53 |
54 | def _get_free_port(self) -> int:
55 | with socket.socket() as sock:
56 | sock.bind(("", 0))
57 | return sock.getsockname()[1]
58 |
59 | def get_availale_master_addr_port(self) -> Tuple[str, str]:
60 | return self._get_node_ip(), str(self._get_free_port())
61 |
62 | def _get_pid(self):
63 | return
64 |
65 |
66 | class WorkerMeta:
67 | keys = [
68 | "WORLD_SIZE",
69 | "RANK",
70 | "LOCAL_WORLD_SIZE",
71 | "LOCAL_RANK",
72 | "MASTER_ADDR",
73 | "MASTER_PORT",
74 | "CUDA_VISIBLE_DEVICES",
75 | ]
76 |
77 | def __init__(self, store) -> None:
78 | self._store = store
79 |
80 | def to_dict(self):
81 | return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
82 |
83 |
84 | # we assume that in each WorkerGroup, there is a Master Worker
85 | class Worker(WorkerHelper):
86 | """A (distributed) worker."""
87 |
88 | _world_size: int
89 | _rank: int
90 | _local_world_size: int
91 | _local_rank: int
92 | _master_addr: str
93 | _master_port: str
94 | _cuda_visible_devices: str
95 |
96 | def __new__(cls, *args, **kwargs):
97 | instance = super().__new__(cls)
98 |
99 |         # note that here we use an int conversion to interpret the env flag
100 | disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
101 | if disable_worker_init:
102 | return instance
103 |
104 | rank = os.getenv("RANK", None)
105 | worker_group_prefix = os.getenv("WG_PREFIX", None)
106 |
107 |         # when the @ray.remote decorator is applied, __new__ is called, but we don't want to run _configure_before_init there
108 | if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
109 | instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
110 |
111 | return instance
112 |
113 | def _configure_before_init(self, register_center_name: str, rank: int):
114 | assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
115 |
116 | if rank == 0:
117 | master_addr, master_port = self.get_availale_master_addr_port()
118 | rank_zero_info = {
119 | "MASTER_ADDR": master_addr,
120 | "MASTER_PORT": master_port,
121 | }
122 | self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
123 | os.environ.update(rank_zero_info)
124 |
125 | def __init__(self, cuda_visible_devices=None) -> None:
126 |         # construct a meta from environment variables. Note that this must happen inside the class because it is executed remotely
127 | world_size = int(os.getenv("WORLD_SIZE"))
128 | rank = int(os.getenv("RANK"))
129 | self._rank = rank
130 | self._world_size = world_size
131 |
132 | if "AMD" in torch.cuda.get_device_name():
133 | os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
134 | os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
135 | cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
136 | torch.cuda.set_device(int(cuda_visible_devices))
137 |
138 | master_addr = os.getenv("MASTER_ADDR")
139 | master_port = os.getenv("MASTER_PORT")
140 |
141 | local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
142 | local_rank = int(os.getenv("LOCAL_RANK", "0"))
143 |
144 | store = {
145 | "_world_size": world_size,
146 | "_rank": rank,
147 | "_local_world_size": local_world_size,
148 | "_local_rank": local_rank,
149 | "_master_addr": master_addr,
150 | "_master_port": master_port,
151 | }
152 | if cuda_visible_devices is not None:
153 | store["_cuda_visible_devices"] = cuda_visible_devices
154 |
155 | meta = WorkerMeta(store=store)
156 | self._configure_with_meta(meta=meta)
157 |
158 | def _configure_with_meta(self, meta: WorkerMeta):
159 | """
160 | This function should only be called inside by WorkerGroup
161 | """
162 | assert isinstance(meta, WorkerMeta)
163 | self.__dict__.update(meta.to_dict()) # this is hacky
164 | # print(f"__dict__: {self.__dict__}")
165 | for key in WorkerMeta.keys:
166 | val = self.__dict__.get(f"_{key.lower()}", None)
167 | if val is not None:
168 | # print(f"set {key} to {val}")
169 | os.environ[key] = str(val)
170 |
171 | os.environ["REDIS_STORE_SERVER_HOST"] = (
172 | str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
173 | )
174 |
175 | def get_master_addr_port(self):
176 | return self._master_addr, self._master_port
177 |
178 | def get_cuda_visible_devices(self):
179 | cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
180 | return cuda_visible_devices
181 |
182 | def print_rank0(self, *args, **kwargs):
183 | if self.rank == 0:
184 | print(*args, **kwargs)
185 |
186 | @property
187 | def world_size(self):
188 | return self._world_size
189 |
190 | @property
191 | def rank(self):
192 | return self._rank
193 |
194 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
195 | def execute_with_func_generator(self, func, *args, **kwargs):
196 | ret_proto = func(self, *args, **kwargs)
197 | return ret_proto
198 |
199 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
200 | def execute_func_rank_zero(self, func, *args, **kwargs):
201 | result = func(*args, **kwargs)
202 | return result
203 |
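A small sketch of the `WorkerMeta` bookkeeping (the store values are hypothetical): missing keys resolve to None, and `_configure_with_meta` later exports the non-None entries back into environment variables.

from verl.single_controller.base.worker import WorkerMeta

store = {"_world_size": 8, "_rank": 3, "_master_addr": "10.0.0.1", "_master_port": "29500"}
meta = WorkerMeta(store=store)
print(meta.to_dict())
# {'_world_size': 8, '_rank': 3, '_local_world_size': None, '_local_rank': None,
#  '_master_addr': '10.0.0.1', '_master_port': '29500', '_cuda_visible_devices': None}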
--------------------------------------------------------------------------------
/verl/single_controller/base/worker_group.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The WorkerGroup class.
16 | """
17 |
18 | import logging
19 | import signal
20 | import threading
21 | import time
22 | from typing import Any, Callable, Dict, List, Optional
23 |
24 | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
25 |
26 |
27 | class ResourcePool:
28 | """The resource pool with meta info such as world size."""
29 |
30 | def __init__(
31 | self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8
32 | ) -> None:
33 | if process_on_nodes is None:
34 | process_on_nodes = []
35 |
36 | self._store = process_on_nodes
37 | self.max_colocate_count = max_colocate_count
38 |         self.n_gpus_per_node = n_gpus_per_node  # this is left for future Huawei GPUs that contain 16 GPUs per node
39 |
40 | def add_node(self, process_count):
41 | self._store.append(process_count)
42 |
43 | @property
44 | def world_size(self):
45 | return sum(self._store)
46 |
47 | def __call__(self) -> Any:
48 | return self._store
49 |
50 | @property
51 | def store(self):
52 | return self._store
53 |
54 | def local_world_size_list(self) -> List[int]:
55 | nested_local_world_size_list = [
56 | [local_world_size for _ in range(local_world_size)] for local_world_size in self._store
57 | ]
58 | return [item for row in nested_local_world_size_list for item in row]
59 |
60 | def local_rank_list(self) -> List[int]:
61 | nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] # noqa: C416
62 | return [item for row in nested_local_rank_list for item in row]
63 |
64 |
65 | class ClassWithInitArgs:
66 | """
67 | This class stores a class constructor and the args/kwargs to construct the class.
68 | It is used to instantiate the remote class.
69 | """
70 |
71 | def __init__(self, cls, *args, **kwargs) -> None:
72 | self.cls = cls
73 | self.args = args
74 | self.kwargs = kwargs
75 |
76 | def __call__(self) -> Any:
77 | return self.cls(*self.args, **self.kwargs)
78 |
79 |
80 | def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
81 | while True:
82 | for worker in workers:
83 | if not is_alive(worker):
84 | logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
85 | signal.raise_signal(signal.SIGABRT)
86 |
87 | time.sleep(gap_time)
88 |
89 |
90 | class WorkerGroup:
91 | """A group of workers"""
92 |
93 | def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
94 | self._is_init_with_detached_workers = True if resource_pool is None else False
95 |
96 | if resource_pool is not None:
97 |             # build the process dispatch config from the resource pool
98 | self._procecss_dispatch_config = resource_pool()
99 | else:
100 | self._procecss_dispatch_config = None
101 |
102 | self._workers = []
103 | self._worker_names = []
104 |
105 | self._master_addr = None
106 | self._master_port = None
107 |
108 | self._checker_thread: threading.Thread = None
109 |
110 | def _is_worker_alive(self, worker):
111 | raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")
112 |
113 | def _block_until_all_workers_alive(self) -> None:
114 | while True:
115 | all_state = [self._is_worker_alive(worker) for worker in self._workers]
116 | if False in all_state:
117 | time.sleep(1)
118 | else:
119 | break
120 |
121 | def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
122 | # before starting checking worker aliveness, make sure all workers are already alive
123 | self._block_until_all_workers_alive()
124 |
125 | self._checker_thread = threading.Thread(
126 | target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)
127 | )
128 | self._checker_thread.start()
129 |
130 | @property
131 | def world_size(self):
132 | return len(self._workers)
133 |
134 | def _bind_worker_method(self, user_defined_cls, func_generator):
135 | """
136 | Bind the worker method to the WorkerGroup
137 | """
138 | for method_name in dir(user_defined_cls):
139 | try:
140 | method = getattr(user_defined_cls, method_name)
141 | assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
142 | except Exception:
143 |                 # if it is a property, this will fail because the class itself has no instance properties
144 | continue
145 |
146 | if hasattr(method, MAGIC_ATTR):
147 | # this method is decorated by register
148 | attribute = getattr(method, MAGIC_ATTR)
149 | assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
150 | assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
151 |
152 | dispatch_mode = attribute["dispatch_mode"]
153 | execute_mode = attribute["execute_mode"]
154 | blocking = attribute["blocking"]
155 |
156 | # get dispatch fn
157 | if isinstance(dispatch_mode, Dispatch):
158 | # get default dispatch fn
159 | fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
160 | dispatch_fn = fn["dispatch_fn"]
161 | collect_fn = fn["collect_fn"]
162 | else:
163 | assert isinstance(dispatch_mode, dict)
164 | assert "dispatch_fn" in dispatch_mode
165 | assert "collect_fn" in dispatch_mode
166 | dispatch_fn = dispatch_mode["dispatch_fn"]
167 | collect_fn = dispatch_mode["collect_fn"]
168 |
169 | # get execute_fn_name
170 | execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
171 | wg_execute_fn_name = execute_mode["execute_fn_name"]
172 |
173 | # get execute_fn from string
174 | try:
175 | execute_fn = getattr(self, wg_execute_fn_name)
176 | assert callable(execute_fn), "execute_fn must be callable"
177 | except Exception:
178 | print(f"execute_fn {wg_execute_fn_name} is invalid")
179 | raise
180 |
181 | # bind a new method to the RayWorkerGroup
182 | func = func_generator(
183 | self,
184 | method_name,
185 | dispatch_fn=dispatch_fn,
186 | collect_fn=collect_fn,
187 | execute_fn=execute_fn,
188 | blocking=blocking,
189 | )
190 |
191 | try:
192 | setattr(self, method_name, func)
193 | except Exception:
194 | raise ValueError(f"Fail to set method_name {method_name}")
195 |
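A standalone sketch of the `ResourcePool` bookkeeping, assuming two nodes with 8 processes each:

from verl.single_controller.base.worker_group import ResourcePool

pool = ResourcePool(process_on_nodes=[8, 8])
print(pool.world_size)                   # 16
print(pool.local_world_size_list()[:3])  # [8, 8, 8]
print(pool.local_rank_list()[6:10])      # [6, 7, 0, 1]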
--------------------------------------------------------------------------------
/verl/single_controller/ray/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls
16 |
17 |
18 | __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"]
19 |
--------------------------------------------------------------------------------
/verl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/trainer/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | PPO config
16 | """
17 |
18 | import os
19 | from dataclasses import asdict, dataclass, field, fields, is_dataclass
20 | from typing import Optional, Tuple
21 |
22 | from ..workers.config import WorkerConfig
23 |
24 |
25 | def recursive_post_init(dataclass_obj):
26 | if hasattr(dataclass_obj, "post_init"):
27 | dataclass_obj.post_init()
28 |
29 | for attr in fields(dataclass_obj):
30 | if is_dataclass(getattr(dataclass_obj, attr.name)):
31 | recursive_post_init(getattr(dataclass_obj, attr.name))
32 |
33 |
34 | @dataclass
35 | class DataConfig:
36 | train_files: str = ""
37 | val_files: str = ""
38 | prompt_key: str = "prompt"
39 | answer_key: str = "answer"
40 | image_key: str = "images"
41 | max_prompt_length: int = 512
42 | max_response_length: int = 512
43 | rollout_batch_size: int = 512
44 | val_batch_size: int = -1
45 | format_prompt: Optional[str] = None
46 | override_chat_template: Optional[str] = None
47 | shuffle: bool = True
48 | seed: int = 1
49 | max_pixels: int = 4194304
50 | min_pixels: int = 262144
51 | filter_overlong_prompts: bool = True
52 |
53 | def post_init(self):
54 | if self.format_prompt is not None:
55 | if os.path.exists(self.format_prompt): # ray job uses absolute path
56 | self.format_prompt = os.path.abspath(self.format_prompt)
57 | else:
58 | self.format_prompt = None
59 |
60 |
61 | @dataclass
62 | class AlgorithmConfig:
63 | gamma: float = 1.0
64 | lam: float = 1.0
65 | adv_estimator: str = "grpo"
66 | disable_kl: bool = False
67 | use_kl_loss: bool = False
68 | kl_penalty: str = "kl"
69 | kl_coef: float = 1e-3
70 | kl_type: str = "fixed"
71 | kl_horizon: float = 0.0
72 | kl_target: float = 0.0
73 |
74 |
75 | @dataclass
76 | class TrainerConfig:
77 | total_epochs: int = 10
78 | max_steps: Optional[int] = None
79 | project_name: str = "easy_r1"
80 | experiment_name: str = "demo"
81 | logger: Tuple[str] = ("console", "wandb")
82 | nnodes: int = 1
83 | n_gpus_per_node: int = 8
84 | critic_warmup: int = 0
85 | val_freq: int = -1
86 | val_before_train: bool = True
87 | val_only: bool = False
88 | val_generations_to_log: int = 0
89 | save_freq: int = -1
90 | save_limit: int = -1
91 | save_checkpoint_path: Optional[str] = None
92 | load_checkpoint_path: Optional[str] = None
93 |
94 | def post_init(self):
95 | if self.save_checkpoint_path is None:
96 | self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
97 |
98 | self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job uses absolute path
99 | if self.load_checkpoint_path is not None:
100 | self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
101 |
102 |
103 | @dataclass
104 | class PPOConfig:
105 | data: DataConfig = field(default_factory=DataConfig)
106 | worker: WorkerConfig = field(default_factory=WorkerConfig)
107 | algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
108 | trainer: TrainerConfig = field(default_factory=TrainerConfig)
109 |
110 | def post_init(self):
111 | self.worker.rollout.prompt_length = self.data.max_prompt_length
112 | self.worker.rollout.response_length = self.data.max_response_length
113 | self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code
114 | self.worker.actor.disable_kl = self.algorithm.disable_kl
115 | self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
116 | self.worker.actor.kl_penalty = self.algorithm.kl_penalty
117 | self.worker.actor.kl_coef = self.algorithm.kl_coef
118 |
119 | def deep_post_init(self):
120 | recursive_post_init(self)
121 |
122 | def to_dict(self):
123 | return asdict(self)
124 |
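A minimal sketch of the post-init behavior for `TrainerConfig` on its own (the values shown are the defaults; the path is resolved against the current working directory). In the full `PPOConfig`, `deep_post_init()` applies `recursive_post_init`, which calls `post_init` on every nested dataclass field.

from verl.trainer.config import TrainerConfig

cfg = TrainerConfig(project_name="easy_r1", experiment_name="demo")
cfg.post_init()
# save_checkpoint_path defaults to <cwd>/checkpoints/<project_name>/<experiment_name>
print(cfg.save_checkpoint_path)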
--------------------------------------------------------------------------------
/verl/trainer/data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional, Tuple
16 |
17 | import torch
18 | from torch.utils.data import RandomSampler, SequentialSampler
19 | from torchdata.stateful_dataloader import StatefulDataLoader
20 | from transformers import PreTrainedTokenizer, ProcessorMixin
21 |
22 | from ..utils.dataset import RLHFDataset, collate_fn
23 | from .config import DataConfig
24 |
25 |
26 | def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]) -> Tuple[StatefulDataLoader, StatefulDataLoader]:
27 | train_dataset = RLHFDataset(
28 | data_path=config.train_files,
29 | tokenizer=tokenizer,
30 | processor=processor,
31 | prompt_key=config.prompt_key,
32 | answer_key=config.answer_key,
33 | image_key=config.image_key,
34 | max_prompt_length=config.max_prompt_length,
35 | truncation="right",
36 | format_prompt=config.format_prompt,
37 | min_pixels=config.min_pixels,
38 | max_pixels=config.max_pixels,
39 | filter_overlong_prompts=config.filter_overlong_prompts,
40 | )
41 | # use sampler for better ckpt resume
42 | if config.shuffle:
43 | train_dataloader_generator = torch.Generator()
44 | train_dataloader_generator.manual_seed(config.seed)
45 | sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
46 | else:
47 | sampler = SequentialSampler(data_source=train_dataset)
48 |
49 | train_dataloader = StatefulDataLoader(
50 | dataset=train_dataset,
51 | batch_size=config.rollout_batch_size,
52 | sampler=sampler,
53 | num_workers=8,
54 | collate_fn=collate_fn,
55 | pin_memory=False,
56 | drop_last=True,
57 | )
58 |
59 | val_dataset = RLHFDataset(
60 | data_path=config.val_files,
61 | tokenizer=tokenizer,
62 | processor=processor,
63 | prompt_key=config.prompt_key,
64 | answer_key=config.answer_key,
65 | image_key=config.image_key,
66 | max_prompt_length=config.max_prompt_length,
67 | truncation="right",
68 | format_prompt=config.format_prompt,
69 | min_pixels=config.min_pixels,
70 | max_pixels=config.max_pixels,
71 | filter_overlong_prompts=config.filter_overlong_prompts,
72 | )
73 | val_dataloader = StatefulDataLoader(
74 | dataset=val_dataset,
75 | batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size,
76 | shuffle=False,
77 | num_workers=8,
78 | collate_fn=collate_fn,
79 | pin_memory=False,
80 | drop_last=False,
81 | )
82 |
83 | assert len(train_dataloader) >= 1
84 | assert len(val_dataloader) >= 1
85 | print(f"Size of train dataloader: {len(train_dataloader)}")
86 | print(f"Size of val dataloader: {len(val_dataloader)}")
87 | return train_dataloader, val_dataloader
88 |
--------------------------------------------------------------------------------
/verl/trainer/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 |
17 | import ray
18 | from omegaconf import OmegaConf
19 |
20 | from ..single_controller.ray import RayWorkerGroup
21 | from ..utils.tokenizer import get_processor, get_tokenizer
22 | from ..workers.fsdp_workers import FSDPWorker
23 | from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager
24 | from .config import PPOConfig
25 | from .data_loader import create_dataloader
26 | from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
27 |
28 |
29 | # please make sure the runner task is not scheduled on the head node
30 | @ray.remote(num_cpus=1)
31 | class Runner:
32 | """A runner for RL training."""
33 |
34 | def run(self, config: PPOConfig):
35 | # print config
36 | print(json.dumps(config.to_dict(), indent=2))
37 |
38 | # instantiate tokenizer
39 | tokenizer = get_tokenizer(
40 | config.worker.actor.model.model_path,
41 | override_chat_template=config.data.override_chat_template,
42 | trust_remote_code=config.worker.actor.model.trust_remote_code,
43 | use_fast=True,
44 | )
45 | processor = get_processor(
46 | config.worker.actor.model.model_path,
47 | override_chat_template=config.data.override_chat_template,
48 | trust_remote_code=config.worker.actor.model.trust_remote_code,
49 | use_fast=True,
50 | )
51 |
52 | # define worker classes
53 | ray_worker_group_cls = RayWorkerGroup
54 | role_worker_mapping = {
55 | Role.ActorRollout: ray.remote(FSDPWorker),
56 | Role.Critic: ray.remote(FSDPWorker),
57 | Role.RefPolicy: ray.remote(FSDPWorker),
58 | }
59 | global_pool_id = "global_pool"
60 | resource_pool_spec = {
61 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
62 | }
63 | mapping = {
64 | Role.ActorRollout: global_pool_id,
65 | Role.Critic: global_pool_id,
66 | Role.RefPolicy: global_pool_id,
67 | }
68 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
69 |
70 | if config.worker.reward.reward_type == "sequential":
71 | RewardManager = SequentialFunctionRewardManager
72 | elif config.worker.reward.reward_type == "batch":
73 | RewardManager = BatchFunctionRewardManager
74 | else:
75 | raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.")
76 |
77 | RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
78 | reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
79 | val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
80 |
81 | train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)
82 |
83 | trainer = RayPPOTrainer(
84 | config=config,
85 | tokenizer=tokenizer,
86 | processor=processor,
87 | train_dataloader=train_dataloader,
88 | val_dataloader=val_dataloader,
89 | role_worker_mapping=role_worker_mapping,
90 | resource_pool_manager=resource_pool_manager,
91 | ray_worker_group_cls=ray_worker_group_cls,
92 | reward_fn=reward_fn,
93 | val_reward_fn=val_reward_fn,
94 | )
95 | trainer.init_workers()
96 | trainer.fit()
97 |
98 |
99 | def main():
100 | cli_args = OmegaConf.from_cli()
101 | default_config = OmegaConf.structured(PPOConfig())
102 |
103 | if hasattr(cli_args, "config"):
104 | config_path = cli_args.pop("config", None)
105 | file_config = OmegaConf.load(config_path)
106 | default_config = OmegaConf.merge(default_config, file_config)
107 |
108 | ppo_config = OmegaConf.merge(default_config, cli_args)
109 | ppo_config: PPOConfig = OmegaConf.to_object(ppo_config)
110 | ppo_config.deep_post_init()
111 |
112 | if not ray.is_initialized():
113 | runtime_env = {
114 | "env_vars": {
115 | "TOKENIZERS_PARALLELISM": "true",
116 | "NCCL_DEBUG": "WARN",
117 | "VLLM_LOGGING_LEVEL": "WARN",
118 | "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
119 | "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
120 | "PYTHONUNBUFFERED": "1",
121 | }
122 | }
123 | ray.init(runtime_env=runtime_env)
124 |
125 | runner = Runner.remote()
126 | ray.get(runner.run.remote(ppo_config))
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
--------------------------------------------------------------------------------
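For reference, a short sketch of the three-way merge `main()` performs, reproduced with an explicit dotlist instead of `OmegaConf.from_cli`; the YAML path matches `examples/config.yaml` in the repo and the override keys are illustrative.

```python
# Sketch of the config merge order used by main(): structured defaults -> YAML file -> CLI overrides.
from omegaconf import OmegaConf

from verl.trainer.config import PPOConfig

default_config = OmegaConf.structured(PPOConfig())
file_config = OmegaConf.load("examples/config.yaml")
cli_overrides = OmegaConf.from_dotlist(["trainer.n_gpus_per_node=8", "data.rollout_batch_size=512"])

merged = OmegaConf.merge(default_config, file_config, cli_overrides)
ppo_config: PPOConfig = OmegaConf.to_object(merged)
ppo_config.deep_post_init()
print(ppo_config.trainer.n_gpus_per_node)
```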
/verl/trainer/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List
16 |
17 | import numpy as np
18 | import torch
19 |
20 | from ..protocol import DataProto
21 |
22 |
23 | def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
24 | return {key: np.mean(value) for key, value in metrics.items()}
25 |
26 |
27 | def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
28 | sequence_score = batch.batch["token_level_scores"].sum(-1)
29 | sequence_reward = batch.batch["token_level_rewards"].sum(-1)
30 |
31 | advantages = batch.batch["advantages"]
32 | returns = batch.batch["returns"]
33 |
34 | max_response_length = batch.batch["responses"].size(-1)
35 |
36 | prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
37 | response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
38 |
39 | max_prompt_length = prompt_mask.size(-1)
40 | prompt_length = prompt_mask.sum(-1).float()
41 | response_length = response_mask.sum(-1).float()
42 |
43 | valid_adv = torch.masked_select(advantages, response_mask)
44 | valid_returns = torch.masked_select(returns, response_mask)
45 |
46 | if use_critic:
47 | values = batch.batch["values"]
48 | valid_values = torch.masked_select(values, response_mask)
49 | return_diff_var = torch.var(valid_returns - valid_values)
50 | return_var = torch.var(valid_returns)
51 |
52 | metrics = {
53 | # score
54 | "critic/score/mean": torch.mean(sequence_score).detach().item(),
55 | "critic/score/max": torch.max(sequence_score).detach().item(),
56 | "critic/score/min": torch.min(sequence_score).detach().item(),
57 | # reward
58 | "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
59 | "critic/rewards/max": torch.max(sequence_reward).detach().item(),
60 | "critic/rewards/min": torch.min(sequence_reward).detach().item(),
61 | # adv
62 | "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
63 | "critic/advantages/max": torch.max(valid_adv).detach().item(),
64 | "critic/advantages/min": torch.min(valid_adv).detach().item(),
65 | # returns
66 | "critic/returns/mean": torch.mean(valid_returns).detach().item(),
67 | "critic/returns/max": torch.max(valid_returns).detach().item(),
68 | "critic/returns/min": torch.min(valid_returns).detach().item(),
69 | **(
70 | {
71 | # values
72 | "critic/values/mean": torch.mean(valid_values).detach().item(),
73 | "critic/values/max": torch.max(valid_values).detach().item(),
74 | "critic/values/min": torch.min(valid_values).detach().item(),
75 | # vf explained var
76 | "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
77 | }
78 | if use_critic
79 | else {}
80 | ),
81 | # response length
82 | "response_length/mean": torch.mean(response_length).detach().item(),
83 | "response_length/max": torch.max(response_length).detach().item(),
84 | "response_length/min": torch.min(response_length).detach().item(),
85 | "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
86 | .detach()
87 | .item(),
88 | # prompt length
89 | "prompt_length/mean": torch.mean(prompt_length).detach().item(),
90 | "prompt_length/max": torch.max(prompt_length).detach().item(),
91 | "prompt_length/min": torch.min(prompt_length).detach().item(),
92 | "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
93 | }
94 | return metrics
95 |
96 |
97 | def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
98 | num_response_tokens = torch.sum(batch.batch["response_mask"]).item()
99 | num_overall_tokens = sum(batch.meta_info["global_token_num"])
100 | num_tokens_of_section = {
101 | **dict.fromkeys(["gen", "reward"], num_response_tokens),
102 | **dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens),
103 | }
104 | return {
105 | **{f"timing_s/{name}": value for name, value in timing_raw.items()},
106 | **{
107 | f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
108 | for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
109 | },
110 | }
111 |
112 |
113 | def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]:
114 | total_num_tokens = sum(batch.meta_info["global_token_num"])
115 | time = timing_raw["step"]
116 | return {
117 | "perf/total_num_tokens": total_num_tokens,
118 | "perf/time_per_step": time,
119 | "perf/throughput": total_num_tokens / (time * num_gpus),
120 | }
121 |
--------------------------------------------------------------------------------
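A minimal, self-contained example of `reduce_metrics`: per-micro-batch lists are collapsed to their mean, which is the form the trainer logs each step. The metric keys are illustrative.

```python
from verl.trainer.metrics import reduce_metrics

raw_metrics = {
    "actor/pg_loss": [0.12, 0.10, 0.08],  # one value per micro-batch
    "actor/entropy": [1.4, 1.3, 1.5],
}
print(reduce_metrics(raw_metrics))  # approximately {'actor/pg_loss': 0.10, 'actor/entropy': 1.40}
```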
/verl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
16 |
17 |
18 | __all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
19 |
--------------------------------------------------------------------------------
/verl/utils/checkpoint/checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import random
17 | import re
18 | import shutil
19 | import tempfile
20 | from abc import ABC, abstractmethod
21 | from typing import Any, Dict, Optional, Union
22 |
23 | import numpy as np
24 | import torch
25 | import torch.distributed as dist
26 | from filelock import FileLock
27 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
28 | from transformers import PreTrainedTokenizer, ProcessorMixin
29 |
30 |
31 | CHECKPOINT_TRACKER = "latest_global_step.txt"
32 |
33 |
34 | class BaseCheckpointManager(ABC):
35 | """
36 | A checkpoint manager that saves and loads
37 | - model
38 | - optimizer
39 | - lr_scheduler
40 | - extra_states
41 | in a SPMD way.
42 |
43 | We save
44 | - sharded model states and optimizer states
45 | - full lr_scheduler states
46 | - huggingface tokenizer and config for ckpt merge
47 | """
48 |
49 | def __init__(
50 | self,
51 | model: FSDP,
52 | optimizer: torch.optim.Optimizer,
53 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
54 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
55 | ):
56 | self.model = model
57 | self.optimizer = optimizer
58 | self.lr_scheduler = lr_scheduler
59 | self.processing_class = processing_class
60 |
61 | assert isinstance(self.model, FSDP)
62 | self.rank = dist.get_rank()
63 | self.world_size = dist.get_world_size()
64 |
65 | @abstractmethod
66 | def load_checkpoint(self, *args, **kwargs):
67 | raise NotImplementedError
68 |
69 | @abstractmethod
70 | def save_checkpoint(self, *args, **kwargs):
71 | raise NotImplementedError
72 |
73 | @staticmethod
74 | def local_mkdir(path: str) -> str:
75 | if not os.path.isabs(path):
76 | working_dir = os.getcwd()
77 | path = os.path.join(working_dir, path)
78 |
79 | # Using hash value of path as lock file name to avoid long file name
80 | lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock"
81 | lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
82 |
83 | try:
84 | with FileLock(lock_path, timeout=60):
85 | os.makedirs(path, exist_ok=True)
86 | except Exception as e:
87 | print(f"Warning: Failed to acquire lock for {path}: {e}")
88 | os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory
89 |
90 | return path
91 |
92 | @staticmethod
93 | def get_rng_state() -> Dict[str, Any]:
94 | rng_state = {
95 | "cpu": torch.get_rng_state(),
96 | "cuda": torch.cuda.get_rng_state(),
97 | "numpy": np.random.get_state(),
98 | "random": random.getstate(),
99 | }
100 | return rng_state
101 |
102 | @staticmethod
103 | def load_rng_state(rng_state: Dict[str, Any]):
104 | torch.set_rng_state(rng_state["cpu"])
105 | torch.cuda.set_rng_state(rng_state["cuda"])
106 | np.random.set_state(rng_state["numpy"])
107 | random.setstate(rng_state["random"])
108 |
109 |
110 | def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
111 | if path is None:
112 | return None
113 |
114 | tracker_file = get_checkpoint_tracker_filename(path)
115 | if not os.path.exists(tracker_file):
116 |         print(f"Checkpoint tracker file does not exist: {tracker_file}")
117 | return None
118 |
119 | with open(tracker_file, "rb") as f:
120 | iteration = int(f.read().decode())
121 |
122 | ckpt_path = os.path.join(path, directory_format.format(iteration))
123 | if not os.path.exists(ckpt_path):
124 |         print(f"Checkpoint does not exist: {ckpt_path}")
125 | return None
126 |
127 |     print(f"Found checkpoint: {ckpt_path}")
128 | return ckpt_path
129 |
130 |
131 | def get_checkpoint_tracker_filename(root_path: str) -> str:
132 | """
133 |     The tracker file records the latest checkpoint during training to restart from.
134 | """
135 | return os.path.join(root_path, CHECKPOINT_TRACKER)
136 |
137 |
138 | def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
139 | """
140 | Remove the obsolete checkpoints that exceed the save_limit.
141 | """
142 | if save_limit <= 0:
143 | return
144 |
145 | if not os.path.exists(path):
146 | return
147 |
148 | pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
149 | ckpt_folders = []
150 | for folder in os.listdir(path):
151 | if match := re.match(pattern, folder):
152 | step = int(match.group(1))
153 | if step < global_step:
154 | ckpt_folders.append((step, folder))
155 |
156 | ckpt_folders.sort(reverse=True)
157 | for _, folder in ckpt_folders[save_limit - 1 :]:
158 | folder_path = os.path.join(path, folder)
159 | shutil.rmtree(folder_path, ignore_errors=True)
160 | print(f"Removed obsolete checkpoint: {folder_path}")
161 |
--------------------------------------------------------------------------------
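A small sketch of the retention behavior of `remove_obsolete_ckpt`: with `save_limit=3` and a new checkpoint about to be written at step 400, only the two most recent earlier steps are kept. The `checkpoints/demo` directory is a scratch placeholder.

```python
import os

from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, remove_obsolete_ckpt

save_dir = "checkpoints/demo"  # placeholder scratch directory
for step in (100, 200, 300):
    os.makedirs(os.path.join(save_dir, f"global_step_{step}"), exist_ok=True)

remove_obsolete_ckpt(save_dir, global_step=400, save_limit=3)
print(sorted(os.listdir(save_dir)))     # ['global_step_200', 'global_step_300']
print(find_latest_ckpt_path(save_dir))  # None until latest_global_step.txt is written
```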
/verl/utils/checkpoint/fsdp_checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from typing import Optional, Union
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
21 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
22 | from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
23 |
24 | from .checkpoint_manager import BaseCheckpointManager
25 |
26 |
27 | class FSDPCheckpointManager(BaseCheckpointManager):
28 | """
29 | A checkpoint manager that saves and loads
30 | - model
31 | - optimizer
32 | - lr_scheduler
33 | - extra_states
34 | in a SPMD way.
35 |
36 | We save
37 | - sharded model states and optimizer states
38 | - full lr_scheduler states
39 | - huggingface tokenizer and config for ckpt merge
40 | """
41 |
42 | def __init__(
43 | self,
44 | model: FSDP,
45 | optimizer: torch.optim.Optimizer,
46 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
47 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
48 | ):
49 | super().__init__(model, optimizer, lr_scheduler, processing_class)
50 |
51 | def load_checkpoint(self, path: Optional[str] = None):
52 | if path is None:
53 | return
54 |
55 | # every rank download its own checkpoint
56 | model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
57 | optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
58 | extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
59 | print(f"[rank-{self.rank}]: Loading model from {os.path.abspath(model_path)}.")
60 | print(f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}.")
61 | print(f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}.")
62 | model_state_dict = torch.load(model_path, weights_only=False)
63 | optim_state_dict = torch.load(optim_path, weights_only=False)
64 | extra_state_dict = torch.load(extra_path, weights_only=False)
65 |
66 | state_dict_options = StateDictOptions(cpu_offload=True)
67 | set_state_dict(
68 | model=self.model,
69 | optimizers=self.optimizer,
70 | model_state_dict=model_state_dict,
71 | optim_state_dict=optim_state_dict,
72 | options=state_dict_options,
73 | )
74 | self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"])
75 |
76 | # recover random state
77 | if "rng" in extra_state_dict:
78 | self.load_rng_state(extra_state_dict["rng"])
79 |
80 | def save_checkpoint(self, path: str):
81 | path = self.local_mkdir(path)
82 | dist.barrier()
83 |
84 | # every rank will save its own model and optim shard
85 | state_dict_options = StateDictOptions(cpu_offload=True)
86 | model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
87 | extra_state_dict = {
88 | "lr_scheduler": self.lr_scheduler.state_dict(),
89 | "rng": self.get_rng_state(),
90 | }
91 | model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
92 | optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
93 | extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
94 |
95 | print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
96 | print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
97 | print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
98 | torch.save(model_state_dict, model_path)
99 | torch.save(optim_state_dict, optim_path)
100 | torch.save(extra_state_dict, extra_path)
101 |
102 | # wait for everyone to dump to local
103 | dist.barrier()
104 |
105 | if self.rank == 0:
106 | hf_path = os.path.join(path, "huggingface")
107 | os.makedirs(hf_path, exist_ok=True)
108 | assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
109 | self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
110 | self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
111 | self.processing_class.save_pretrained(hf_path)
112 |
113 | dist.barrier()
114 |
--------------------------------------------------------------------------------
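For orientation, a sketch of the per-rank file layout `save_checkpoint` above produces for a hypothetical `world_size` of 2; the rank-0-only `huggingface/` folder holds the config, generation config, and tokenizer/processor files kept for later checkpoint merging.

```python
# Expected contents of a checkpoint directory written by FSDPCheckpointManager (world_size=2).
world_size = 2
for rank in range(world_size):
    for prefix in ("model", "optim", "extra_state"):
        print(f"{prefix}_world_size_{world_size}_rank_{rank}.pt")
print("huggingface/")  # config and tokenizer/processor files, written by rank 0 only
```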
/verl/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import os
17 | from collections import defaultdict
18 | from io import BytesIO
19 | from typing import Any, Dict, List, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | from datasets import load_dataset
24 | from jinja2 import Template
25 | from PIL import Image
26 | from PIL.Image import Image as ImageObject
27 | from torch.utils.data import Dataset
28 | from transformers import PreTrainedTokenizer, ProcessorMixin
29 |
30 | from ..models.transformers.qwen2_vl import get_rope_index
31 | from . import torch_functional as VF
32 |
33 |
34 | def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
35 | tensors = defaultdict(list)
36 | non_tensors = defaultdict(list)
37 | for feature in features:
38 | for key, value in feature.items():
39 | if isinstance(value, torch.Tensor):
40 | tensors[key].append(value)
41 | else:
42 | non_tensors[key].append(value)
43 |
44 | for key, value in tensors.items():
45 | tensors[key] = torch.stack(value, dim=0)
46 |
47 | for key, value in non_tensors.items():
48 | non_tensors[key] = np.array(value, dtype=object)
49 |
50 | return {**tensors, **non_tensors}
51 |
52 |
53 | class ImageProcessMixin:
54 | max_pixels: int
55 | min_pixels: int
56 |
57 |     def process_image(self, image: Union[Dict[str, Any], bytes, ImageObject]) -> ImageObject:
58 | if isinstance(image, dict):
59 | image = Image.open(BytesIO(image["bytes"]))
60 | elif isinstance(image, bytes):
61 | image = Image.open(BytesIO(image))
62 |
63 | if (image.width * image.height) > self.max_pixels:
64 | resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
65 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
66 | image = image.resize((width, height))
67 |
68 | if (image.width * image.height) < self.min_pixels:
69 | resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
70 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
71 | image = image.resize((width, height))
72 |
73 | if image.mode != "RGB":
74 | image = image.convert("RGB")
75 |
76 | return image
77 |
78 |
79 | class RLHFDataset(Dataset, ImageProcessMixin):
80 | """
81 |     We assume the dataset contains a prompt column along with other information.
82 | """
83 |
84 | def __init__(
85 | self,
86 | data_path: str,
87 | tokenizer: PreTrainedTokenizer,
88 | processor: Optional[ProcessorMixin],
89 | prompt_key: str = "prompt",
90 | answer_key: str = "answer",
91 | image_key: str = "images",
92 | max_prompt_length: int = 1024,
93 | truncation: str = "error",
94 | format_prompt: Optional[str] = None,
95 | max_pixels: Optional[int] = None,
96 | min_pixels: Optional[int] = None,
97 | filter_overlong_prompts: bool = True,
98 | ):
99 | self.tokenizer = tokenizer
100 | self.processor = processor
101 | self.prompt_key = prompt_key
102 | self.answer_key = answer_key
103 | self.image_key = image_key
104 | self.max_prompt_length = max_prompt_length
105 | self.truncation = truncation
106 | self.max_pixels = max_pixels
107 | self.min_pixels = min_pixels
108 | self.filter_overlong_prompts = filter_overlong_prompts
109 |
110 | if "@" in data_path:
111 | data_path, data_split = data_path.split("@")
112 | else:
113 | data_split = "train"
114 |
115 | if os.path.isdir(data_path):
116 | # when we use dataset builder, we should always refer to the train split
117 | self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
118 | elif os.path.isfile(data_path):
119 | self.dataset = load_dataset("parquet", data_files=data_path, split="train")
120 | else:
121 | # load remote dataset from huggingface hub
122 | self.dataset = load_dataset(data_path, split=data_split)
123 |
124 | self.format_prompt = None
125 | if format_prompt:
126 | with open(format_prompt, encoding="utf-8") as f:
127 | self.format_prompt = f.read()
128 |
129 | if self.filter_overlong_prompts:
130 | self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")
131 |
132 | def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
133 | prompt_str: str = example[self.prompt_key]
134 | if self.format_prompt:
135 | format_prompt = Template(self.format_prompt.strip())
136 | prompt_str = format_prompt.render(content=prompt_str)
137 |
138 | if self.image_key in example:
139 | # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
140 | content_list = []
141 |             for i, content in enumerate(prompt_str.split("<image>")):
142 | if i != 0:
143 | content_list.append({"type": "image"})
144 |
145 | if content:
146 | content_list.append({"type": "text", "text": content})
147 |
148 | return [{"role": "user", "content": content_list}]
149 | else:
150 | return [{"role": "user", "content": prompt_str}]
151 |
152 | def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
153 | messages = self._build_messages(example)
154 | processing_class = self.processor if self.processor is not None else self.tokenizer
155 | return (
156 | len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
157 | )
158 |
159 | def __len__(self):
160 | return len(self.dataset)
161 |
162 | def __getitem__(self, index):
163 | example: dict = self.dataset[index]
164 | messages = self._build_messages(example)
165 |
166 | if self.image_key in example:
167 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
168 | images = [self.process_image(image) for image in example.pop(self.image_key)]
169 | model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
170 | input_ids = model_inputs.pop("input_ids")[0]
171 | attention_mask = model_inputs.pop("attention_mask")[0]
172 | example["multi_modal_data"] = {"image": images}
173 | example["multi_modal_inputs"] = dict(model_inputs)
174 | else:
175 | prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
176 | model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
177 | input_ids = model_inputs.pop("input_ids")[0]
178 | attention_mask = model_inputs.pop("attention_mask")[0]
179 |
180 | if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
181 | # qwen2vl mrope
182 | position_ids = get_rope_index(
183 | self.processor,
184 | input_ids=input_ids,
185 | image_grid_thw=model_inputs.get("image_grid_thw"),
186 | attention_mask=attention_mask,
187 | ) # (3, seq_length)
188 | else:
189 | position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,)
190 |
191 | input_ids, attention_mask, position_ids = VF.postprocess_data(
192 | input_ids=input_ids,
193 | attention_mask=attention_mask,
194 | position_ids=position_ids,
195 | max_length=self.max_prompt_length,
196 | pad_token_id=self.tokenizer.pad_token_id,
197 | left_pad=True,
198 | truncation=self.truncation,
199 | )
200 | raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
201 | if len(raw_prompt_ids) > self.max_prompt_length:
202 | if self.truncation == "left":
203 | raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
204 | elif self.truncation == "right":
205 | raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
206 | elif self.truncation == "error":
207 | raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
208 |
209 | example["input_ids"] = input_ids
210 | example["attention_mask"] = attention_mask
211 | example["position_ids"] = position_ids
212 | example["raw_prompt_ids"] = raw_prompt_ids
213 | example["ground_truth"] = example.pop(self.answer_key)
214 | return example
215 |
--------------------------------------------------------------------------------
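A minimal example of `collate_fn`: tensor features are stacked along a new batch dimension, while non-tensor features (such as variable-length raw prompt ids) are kept as an object-dtype numpy array.

```python
import torch

from verl.utils.dataset import collate_fn

features = [
    {"input_ids": torch.ones(4, dtype=torch.long), "raw_prompt_ids": [1, 2, 3]},
    {"input_ids": torch.zeros(4, dtype=torch.long), "raw_prompt_ids": [4, 5]},
]
batch = collate_fn(features)
print(batch["input_ids"].shape)       # torch.Size([2, 4])
print(batch["raw_prompt_ids"].dtype)  # object
```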
/verl/utils/flops_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING, List, Tuple
16 |
17 | import torch
18 |
19 |
20 | if TYPE_CHECKING:
21 | from transformers.models.llama.configuration_llama import LlamaConfig
22 |
23 |
24 | VALID_MODEL_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}
25 |
26 |
27 | def get_device_flops(unit: str = "T") -> float:
28 | def unit_convert(number: float, level: str):
29 | units = ["B", "K", "M", "G", "T", "P"]
30 | if number <= 0:
31 | return number
32 |
33 | ptr = 0
34 | while ptr < len(units) and units[ptr] != level:
35 | number /= 1000
36 | ptr += 1
37 |
38 | return number
39 |
40 | device_name = torch.cuda.get_device_name()
41 |     flops = float("inf")  # INF flops for unknown gpu type
42 | if "H100" in device_name or "H800" in device_name:
43 | flops = 989e12
44 | elif "A100" in device_name or "A800" in device_name:
45 | flops = 312e12
46 | elif "L40" in device_name:
47 | flops = 181.05e12
48 | elif "L20" in device_name:
49 | flops = 119.5e12
50 | elif "H20" in device_name:
51 | flops = 148e12
52 | elif "910B" in device_name:
53 | flops = 354e12
54 | flops_unit = unit_convert(flops, unit)
55 | return flops_unit
56 |
57 |
58 | class FlopsCounter:
59 | """
60 |     Used to compute MFU (model FLOPs utilization) during the training loop.
61 |
62 | Example:
63 | flops_counter = FlopsCounter(config)
64 | flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
65 | """
66 |
67 | def __init__(self, config: "LlamaConfig"):
68 |         if config.model_type not in VALID_MODEL_TYPE:
69 |             print(f"Only {VALID_MODEL_TYPE} are supported, but got {config.model_type}. MFU will always be zero.")
70 |
71 | self.estimate_func = {
72 | "llama": self._estimate_llama_flops,
73 | "qwen2": self._estimate_llama_flops,
74 | "qwen2_vl": self._estimate_llama_flops,
75 | "qwen2_5_vl": self._estimate_llama_flops,
76 | }
77 | self.config = config
78 |
79 | def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
80 | return 0
81 |
82 | def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
83 | hidden_size = self.config.hidden_size
84 | vocab_size = self.config.vocab_size
85 | num_hidden_layers = self.config.num_hidden_layers
86 | num_key_value_heads = self.config.num_key_value_heads
87 | num_attention_heads = self.config.num_attention_heads
88 | intermediate_size = self.config.intermediate_size
89 |
90 | head_dim = hidden_size // num_attention_heads
91 | q_size = num_attention_heads * head_dim
92 | k_size = num_key_value_heads * head_dim
93 | v_size = num_key_value_heads * head_dim
94 |
95 |         # non-attention parameters per layer
96 |         # Qwen2/LLaMA use SwiGLU: the MLP has gate, up and down linear layers
97 | mlp_N = hidden_size * intermediate_size * 3
98 | attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
99 | emd_and_lm_head_N = vocab_size * hidden_size * 2
100 | # non-attn all_layer parm
101 | dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
102 | # non-attn all_layer & all_token fwd & bwd flops
103 | dense_N_flops = 6 * dense_N * tokens_sum
104 |
105 | # attn all_layer & all_token fwd & bwd flops
106 | seqlen_square_sum = 0
107 | for seqlen in batch_seqlens:
108 | seqlen_square_sum += seqlen * seqlen
109 |
110 | attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
111 |
112 | # all_layer & all_token fwd & bwd flops
113 | flops_all_token = dense_N_flops + attn_qkv_flops
114 | flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
115 | return flops_achieved
116 |
117 | def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
118 | """
119 | Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
120 |
121 | Args:
122 |             batch_seqlens (List[int]): A list where each element is the number of valid tokens in one sequence of the current batch.
123 | delta_time (float): The time taken to process the batch, in seconds.
124 |
125 | Returns:
126 |             estimated_flops (float): The estimated achieved TFLOPS based on the input tokens and time.
127 |             promised_flops (float): The peak TFLOPS of the current device.
128 | """
129 | tokens_sum = sum(batch_seqlens)
130 | func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
131 | estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
132 | promised_flops = get_device_flops()
133 | return estimated_flops, promised_flops
134 |
--------------------------------------------------------------------------------
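A hedged sketch of turning `FlopsCounter` output into an MFU figure for one step; it assumes a CUDA device is available (otherwise the promised FLOPS is infinite and MFU reads as zero), and the model name, sequence lengths, and timing are illustrative.

```python
from transformers import AutoConfig

from verl.utils.flops_counter import FlopsCounter

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder model
flops_counter = FlopsCounter(config)

batch_seqlens = [1024, 2048, 512]  # valid tokens per sequence in the step
delta_time = 3.5                   # seconds spent on the step
achieved_tflops, promised_tflops = flops_counter.estimate_flops(batch_seqlens, delta_time)
print(f"MFU: {achieved_tflops / promised_tflops:.2%}")
```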
/verl/utils/fsdp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gc
16 | from collections import defaultdict
17 | from functools import partial
18 | from typing import Callable, Union
19 |
20 | import torch
21 | from torch import nn
22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
23 | from torch.distributed.fsdp._runtime_utils import _lazy_init
24 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
25 | from torch.optim import Optimizer
26 | from transformers import PreTrainedModel
27 | from transformers.trainer_pt_utils import get_module_class_from_name
28 |
29 |
30 | def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]:
31 | param_occurrence = defaultdict(int)
32 | for _, param in model.named_parameters(remove_duplicate=False):
33 | param_occurrence[param] += 1
34 |
35 | duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1}
36 | materialized_params = {}
37 |
38 | def init_fn(module: nn.Module):
39 | for name, param in module.named_parameters(recurse=False):
40 | if param in duplicated_params:
41 | module._parameters[name] = materialized_params.setdefault(
42 | param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad)
43 | )
44 | else:
45 | module._parameters[name] = nn.Parameter(
46 | torch.empty_like(param.data, device=device), requires_grad=param.requires_grad
47 | )
48 |
49 | return init_fn
50 |
51 |
52 | def get_fsdp_wrap_policy(model: PreTrainedModel):
53 | """Get FSDP wrap policy for the model.
54 |
55 | Args:
56 |         model: The pretrained model to get the wrap policy for.
57 | """
58 | transformer_cls_to_wrap = set()
59 | for module in model._no_split_modules:
60 | transformer_cls = get_module_class_from_name(model, module)
61 | if transformer_cls is None:
62 | raise Exception(f"Cannot find {module} in pretrained model.")
63 | else:
64 | transformer_cls_to_wrap.add(transformer_cls)
65 |
66 | return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap)
67 |
68 |
69 | @torch.no_grad()
70 | def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
71 | # lazy init FSDP model
72 | _lazy_init(model, model)
73 | assert model._is_root, "Only support root model offloading to CPU"
74 | for handle in model._all_handles:
75 | if handle._offload_params:
76 | continue
77 |
78 | flat_param = handle.flat_param
79 | assert (
80 | flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
81 | and id(flat_param.data) != id(flat_param._local_shard)
82 | and flat_param.data.size() == flat_param._local_shard.size()
83 | )
84 | handle.flat_param_to("cpu", non_blocking=True)
85 | # the following still keeps id(._local_shard) != id(.data)
86 | flat_param._local_shard = flat_param.data
87 | assert id(flat_param._local_shard) != id(flat_param.data)
88 |
89 | if empty_cache:
90 | torch.cuda.empty_cache()
91 |
92 |
93 | @torch.no_grad()
94 | def load_fsdp_model(model: FSDP, empty_cache: bool = True):
95 | # lazy init FSDP model
96 | _lazy_init(model, model)
97 | assert model._is_root, "Only support root model loading to GPU"
98 | for handle in model._all_handles:
99 | if handle._offload_params:
100 | continue
101 |
102 | flat_param = handle.flat_param
103 | handle.flat_param_to("cuda", non_blocking=True)
104 | # the following still keeps id(._local_shard) != id(.data)
105 | flat_param._local_shard = flat_param.data
106 |
107 | if empty_cache:
108 | gc.collect()
109 |
110 |
111 | @torch.no_grad()
112 | def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
113 | if not optimizer.state:
114 | return
115 |
116 | for param_group in optimizer.param_groups:
117 | for param in param_group["params"]:
118 | state = optimizer.state[param]
119 | for key, value in state.items():
120 | if isinstance(value, torch.Tensor):
121 | state[key] = value.to("cpu", non_blocking=True)
122 |
123 | if empty_cache:
124 | torch.cuda.empty_cache()
125 |
126 |
127 | @torch.no_grad()
128 | def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
129 | if not optimizer.state:
130 | return
131 |
132 | for param_group in optimizer.param_groups:
133 | for param in param_group["params"]:
134 | state = optimizer.state[param]
135 | for key, value in state.items():
136 | if isinstance(value, torch.Tensor):
137 | state[key] = value.to("cuda", non_blocking=True)
138 |
139 | if empty_cache:
140 | gc.collect()
141 |
--------------------------------------------------------------------------------
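A hedged sketch of how these helpers pair up when alternating between rollout and update phases: parameters and optimizer state move to CPU while the GPU serves generation, then move back before the gradient step. `fsdp_model` and `optimizer` are assumed to be an already-constructed FSDP module and its optimizer.

```python
from verl.utils.fsdp_utils import (
    load_fsdp_model,
    load_fsdp_optimizer,
    offload_fsdp_model,
    offload_fsdp_optimizer,
)


def free_gpu_for_rollout(fsdp_model, optimizer):
    offload_fsdp_model(fsdp_model)     # move flat parameters to CPU
    offload_fsdp_optimizer(optimizer)  # move optimizer states to CPU


def prepare_for_update(fsdp_model, optimizer):
    load_fsdp_model(fsdp_model)        # move flat parameters back to GPU
    load_fsdp_optimizer(optimizer)     # move optimizer states back to GPU
```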
/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .logger import Tracker
17 |
18 |
19 | __all__ = ["Tracker"]
20 |
--------------------------------------------------------------------------------
/verl/utils/logger/gen_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from abc import ABC, abstractmethod
17 | from dataclasses import dataclass
18 | from typing import List, Tuple
19 |
20 | from ..py_functional import is_package_available
21 |
22 |
23 | if is_package_available("wandb"):
24 | import wandb # type: ignore
25 |
26 |
27 | if is_package_available("swanlab"):
28 | import swanlab # type: ignore
29 |
30 |
31 | @dataclass
32 | class GenerationLogger(ABC):
33 | @abstractmethod
34 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: ...
35 |
36 |
37 | @dataclass
38 | class ConsoleGenerationLogger(GenerationLogger):
39 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
40 | for inp, out, lab, score in samples:
41 | print(f"[prompt] {inp}\n[output] {out}\n[ground_truth] {lab}\n[score] {score}\n")
42 |
43 |
44 | @dataclass
45 | class WandbGenerationLogger(GenerationLogger):
46 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
47 | # Create column names for all samples
48 | columns = ["step"] + sum(
49 | [[f"input_{i + 1}", f"output_{i + 1}", f"label_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))],
50 | [],
51 | )
52 |
53 | if not hasattr(self, "validation_table"):
54 | # Initialize the table on first call
55 | self.validation_table = wandb.Table(columns=columns)
56 |
57 | # Create a new table with same columns and existing data
58 | # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
59 | new_table = wandb.Table(columns=columns, data=self.validation_table.data)
60 |
61 | # Add new row with all data
62 | row_data = [step]
63 | for sample in samples:
64 | row_data.extend(sample)
65 |
66 | new_table.add_data(*row_data)
67 | wandb.log({"val/generations": new_table}, step=step)
68 | self.validation_table = new_table
69 |
70 |
71 | @dataclass
72 | class SwanlabGenerationLogger(GenerationLogger):
73 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
74 | swanlab_text_list = []
75 | for i, sample in enumerate(samples):
76 | row_text = "\n\n---\n\n".join(
77 | (f"input: {sample[0]}", f"output: {sample[1]}", f"label: {sample[2]}", f"score: {sample[3]}")
78 | )
79 | swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
80 |
81 | swanlab.log({"val/generations": swanlab_text_list}, step=step)
82 |
83 |
84 | GEN_LOGGERS = {
85 | "console": ConsoleGenerationLogger,
86 | "wandb": WandbGenerationLogger,
87 | "swanlab": SwanlabGenerationLogger,
88 | }
89 |
90 |
91 | @dataclass
92 | class AggregateGenerationsLogger:
93 | def __init__(self, loggers: List[str]):
94 | self.loggers: List[GenerationLogger] = []
95 |
96 | for logger in loggers:
97 | if logger in GEN_LOGGERS:
98 | self.loggers.append(GEN_LOGGERS[logger]())
99 |
100 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
101 | for logger in self.loggers:
102 | logger.log(samples, step)
103 |
--------------------------------------------------------------------------------
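A minimal example of `AggregateGenerationsLogger` with only the console backend: each sample is a `(prompt, output, ground_truth, score)` tuple fanned out to every configured backend.

```python
from verl.utils.logger.gen_logger import AggregateGenerationsLogger

gen_logger = AggregateGenerationsLogger(["console"])
samples = [("What is 2 + 3?", "2 + 3 = 5. The answer is 5.", "5", 1.0)]
gen_logger.log(samples, step=10)
```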
/verl/utils/logger/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A unified tracking interface that supports logging data to different backend
16 | """
17 |
18 | import os
19 | from abc import ABC, abstractmethod
20 | from typing import Any, Dict, List, Optional, Tuple, Union
21 |
22 | from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict
23 | from .gen_logger import AggregateGenerationsLogger
24 |
25 |
26 | if is_package_available("mlflow"):
27 | import mlflow # type: ignore
28 |
29 |
30 | if is_package_available("tensorboard"):
31 | from torch.utils.tensorboard import SummaryWriter
32 |
33 |
34 | if is_package_available("wandb"):
35 | import wandb # type: ignore
36 |
37 |
38 | if is_package_available("swanlab"):
39 | import swanlab # type: ignore
40 |
41 |
42 | class Logger(ABC):
43 | @abstractmethod
44 | def __init__(self, config: Dict[str, Any]) -> None: ...
45 |
46 | @abstractmethod
47 | def log(self, data: Dict[str, Any], step: int) -> None: ...
48 |
49 | def finish(self) -> None:
50 | pass
51 |
52 |
53 | class ConsoleLogger(Logger):
54 | def __init__(self, config: Dict[str, Any]) -> None:
55 | print("Config\n" + convert_dict_to_str(config))
56 |
57 | def log(self, data: Dict[str, Any], step: int) -> None:
58 | print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data)))
59 |
60 |
61 | class MlflowLogger(Logger):
62 | def __init__(self, config: Dict[str, Any]) -> None:
63 | mlflow.start_run(run_name=config["trainer"]["experiment_name"])
64 | mlflow.log_params(flatten_dict(config))
65 |
66 | def log(self, data: Dict[str, Any], step: int) -> None:
67 | mlflow.log_metrics(metrics=data, step=step)
68 |
69 |
70 | class SwanlabLogger(Logger):
71 | def __init__(self, config: Dict[str, Any]) -> None:
72 | swanlab_key = os.getenv("SWANLAB_API_KEY")
73 | swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
74 | swanlab_mode = os.getenv("SWANLAB_MODE", "cloud")
75 | if swanlab_key:
76 | swanlab.login(swanlab_key)
77 |
78 | swanlab.init(
79 | project=config["trainer"]["project_name"],
80 | experiment_name=config["trainer"]["experiment_name"],
81 | config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
82 | logdir=swanlab_dir,
83 | mode=swanlab_mode,
84 | )
85 |
86 | def log(self, data: Dict[str, Any], step: int) -> None:
87 | swanlab.log(data=data, step=step)
88 |
89 | def finish(self) -> None:
90 | swanlab.finish()
91 |
92 |
93 | class TensorBoardLogger(Logger):
94 | def __init__(self, config: Dict[str, Any]) -> None:
95 | tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log")
96 | os.makedirs(tensorboard_dir, exist_ok=True)
97 | print(f"Saving tensorboard log to {tensorboard_dir}.")
98 | self.writer = SummaryWriter(tensorboard_dir)
99 | # self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={"placeholder": 0})
100 |
101 | def log(self, data: Dict[str, Any], step: int) -> None:
102 | for key, value in data.items():
103 | self.writer.add_scalar(key, value, step)
104 |
105 | def finish(self):
106 | self.writer.close()
107 |
108 |
109 | class WandbLogger(Logger):
110 | def __init__(self, config: Dict[str, Any]) -> None:
111 | wandb.init(
112 | project=config["trainer"]["project_name"],
113 | name=config["trainer"]["experiment_name"],
114 | config=config,
115 | )
116 |
117 | def log(self, data: Dict[str, Any], step: int) -> None:
118 | wandb.log(data=data, step=step)
119 |
120 | def finish(self) -> None:
121 | wandb.finish()
122 |
123 |
124 | LOGGERS = {
125 | "console": ConsoleLogger,
126 | "mlflow": MlflowLogger,
127 | "swanlab": SwanlabLogger,
128 | "tensorboard": TensorBoardLogger,
129 | "wandb": WandbLogger,
130 | }
131 |
132 |
133 | class Tracker:
134 | def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
135 | if isinstance(loggers, str):
136 | loggers = [loggers]
137 |
138 | self.loggers: List[Logger] = []
139 | for logger in loggers:
140 | if logger not in LOGGERS:
141 | raise ValueError(f"{logger} is not supported.")
142 |
143 | self.loggers.append(LOGGERS[logger](config))
144 |
145 | self.gen_logger = AggregateGenerationsLogger(loggers)
146 |
147 | def log(self, data: Dict[str, Any], step: int) -> None:
148 | for logger in self.loggers:
149 | logger.log(data=data, step=step)
150 |
151 | def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
152 | self.gen_logger.log(samples, step)
153 |
154 | def __del__(self):
155 | for logger in self.loggers:
156 | logger.finish()
157 |
--------------------------------------------------------------------------------
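A minimal example of `Tracker` with the console backend: metrics are passed as flattened `section/name` keys, which `ConsoleLogger` unflattens before printing; the config dict here is a toy placeholder.

```python
from verl.utils.logger import Tracker

tracker = Tracker(loggers="console", config={"trainer": {"experiment_name": "demo"}})
tracker.log({"critic/score/mean": 0.62, "actor/lr": 1e-6}, step=1)
```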
/verl/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities to create common models
16 | """
17 |
18 | from functools import lru_cache
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch import nn
24 |
25 |
26 | @lru_cache
27 | def is_rank0() -> int:
28 | return (not dist.is_initialized()) or (dist.get_rank() == 0)
29 |
30 |
31 | def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None:
32 | """Report the current GPU VRAM usage."""
33 | if is_rank0():
34 | free_mem, total_mem = torch.cuda.mem_get_info()
35 | print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.")
36 |
37 |
38 | def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
39 | """Compute the model size."""
40 | n_params = sum(p.numel() for p in model.parameters())
41 |
42 | if scale == "auto":
43 | if n_params > 1e9:
44 | scale = "B"
45 | elif n_params > 1e6:
46 | scale = "M"
47 | elif n_params > 1e3:
48 | scale = "K"
49 | else:
50 | scale = ""
51 |
52 | if scale == "B":
53 | n_params = n_params / 1e9
54 | elif scale == "M":
55 | n_params = n_params / 1e6
56 | elif scale == "K":
57 | n_params = n_params / 1e3
58 | elif scale == "":
59 | pass
60 | else:
61 | raise NotImplementedError(f"Unknown scale {scale}.")
62 |
63 | return n_params, scale
64 |
65 |
66 | def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
67 | """Print the model size."""
68 | if is_rank0():
69 | n_params, scale = _get_model_size(model, scale="auto")
70 | if name is None:
71 | name = model.__class__.__name__
72 |
73 | print(f"{name} contains {n_params:.2f}{scale} parameters.")
74 |
--------------------------------------------------------------------------------
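A minimal example of `print_model_size` on a toy module; a `nn.Linear(1024, 1024)` has roughly 1.05M parameters, so the auto scale resolves to "M".

```python
from torch import nn

from verl.utils.model_utils import print_model_size

model = nn.Linear(1024, 1024)
print_model_size(model, name="toy_linear")  # toy_linear contains 1.05M parameters.
```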
/verl/utils/py_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contain small python utility functions
16 | """
17 |
18 | import importlib.util
19 | import re
20 | from contextlib import contextmanager
21 | from functools import lru_cache
22 | from typing import Any, Dict, List, Union
23 |
24 | import numpy as np
25 | import yaml
26 | from codetiming import Timer
27 | from yaml import Dumper
28 |
29 |
30 | def is_sci_notation(number: float) -> bool:
31 | pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$")
32 | return bool(pattern.match(str(number)))
33 |
34 |
35 | def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
36 | if is_sci_notation(number):
37 | value = str(number)
38 | if "." not in value and "e" in value:
39 | value = value.replace("e", ".0e", 1)
40 | else:
41 | value = str(round(number, 3))
42 |
43 | return dumper.represent_scalar("tag:yaml.org,2002:float", value)
44 |
45 |
46 | yaml.add_representer(float, float_representer)
47 | yaml.add_representer(np.float32, float_representer)
48 | yaml.add_representer(np.float64, float_representer)
49 |
50 |
51 | @lru_cache
52 | def is_package_available(name: str) -> bool:
53 | return importlib.util.find_spec(name) is not None
54 |
55 |
56 | def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
57 | """Union two dict. Will throw an error if there is an item not the same object with the same key."""
58 | for key in dict2.keys():
59 | if key in dict1:
60 | assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object"
61 |
62 | dict1[key] = dict2[key]
63 |
64 | return dict1
65 |
66 |
67 | def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
68 | """Append dict to a dict of list."""
69 | for key, val in new_data.items():
70 | if key not in data:
71 | data[key] = []
72 |
73 | data[key].append(val)
74 |
75 |
76 | def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
77 | unflattened = {}
78 | for key, value in data.items():
79 | pieces = key.split(sep)
80 | pointer = unflattened
81 | for piece in pieces[:-1]:
82 | if piece not in pointer:
83 | pointer[piece] = {}
84 |
85 | pointer = pointer[piece]
86 |
87 | pointer[pieces[-1]] = value
88 |
89 | return unflattened
90 |
91 |
92 | def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
93 | flattened = {}
94 | for key, value in data.items():
95 | new_key = parent_key + sep + key if parent_key else key
96 | if isinstance(value, dict):
97 | flattened.update(flatten_dict(value, new_key, sep=sep))
98 | else:
99 | flattened[new_key] = value
100 |
101 | return flattened
102 |
103 |
104 | def convert_dict_to_str(data: Dict[str, Any]) -> str:
105 | return yaml.dump(data, indent=2)
106 |
107 |
108 | @contextmanager
109 | def timer(name: str, timing_raw: Dict[str, float]):
110 | with Timer(name=name, logger=None) as timer:
111 | yield
112 |
113 | timing_raw[name] = timer.last
114 |
--------------------------------------------------------------------------------
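A minimal example showing that `flatten_dict` and `unflatten_dict` are inverses for the nested metric dicts used by the loggers.

```python
from verl.utils.py_functional import flatten_dict, unflatten_dict

nested = {"critic": {"score": {"mean": 0.5, "max": 1.0}}, "step": 3}
flat = flatten_dict(nested)
print(flat)                            # {'critic/score/mean': 0.5, 'critic/score/max': 1.0, 'step': 3}
print(unflatten_dict(flat) == nested)  # True
```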
/verl/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for tokenization."""
15 |
16 | from typing import Optional
17 |
18 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
19 |
20 |
21 | def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer:
22 | """Create a huggingface pretrained tokenizer."""
23 | tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
24 | if override_chat_template is not None:
25 | tokenizer.chat_template = override_chat_template
26 |
27 | if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
28 | # the EOS token in gemma2 & gemma3 is ambiguous, which may worsen RL performance.
29 | # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
30 | print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
31 | tokenizer.eos_token = "<end_of_turn>"
32 |
33 | if tokenizer.pad_token_id is None:
34 | print("Pad token is None. Set it to eos_token.")
35 | tokenizer.pad_token = tokenizer.eos_token
36 |
37 | return tokenizer
38 |
39 |
40 | def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]:
41 | """Create a huggingface pretrained processor."""
42 | processor = AutoProcessor.from_pretrained(model_path, **kwargs)
43 | if override_chat_template is not None:
44 | processor.chat_template = override_chat_template
45 |
46 | # Avoid returning a plain tokenizer as the processor, see:
47 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
48 | if processor is not None and "Processor" not in processor.__class__.__name__:
49 | processor = None
50 |
51 | return processor
52 |
--------------------------------------------------------------------------------
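
A usage sketch for get_tokenizer/get_processor (the model id below is only a placeholder; any Hugging Face repo id or local path should work, and downloading the model is assumed to be possible):

from verl.utils.tokenizer import get_processor, get_tokenizer

model_path = "Qwen/Qwen2.5-VL-7B-Instruct"  # placeholder; any HF repo id or local path
tokenizer = get_tokenizer(model_path, trust_remote_code=True)
processor = get_processor(model_path, trust_remote_code=True)  # None for text-only models

print(tokenizer.pad_token_id, tokenizer.eos_token_id)
print(type(processor).__name__ if processor is not None else "no processor")
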
/verl/utils/torch_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
18 | HALF_LIST = ["fp16", "float16"]
19 | FLOAT_LIST = ["fp32", "float32"]
20 | BFLOAT_LIST = ["bf16", "bfloat16"]
21 |
22 |
23 | class PrecisionType:
24 | """Type of precision used."""
25 |
26 | @staticmethod
27 | def is_fp16(precision: str) -> bool:
28 | return precision in HALF_LIST
29 |
30 | @staticmethod
31 | def is_fp32(precision: str) -> bool:
32 | return precision in FLOAT_LIST
33 |
34 | @staticmethod
35 | def is_bf16(precision: str) -> bool:
36 | return precision in BFLOAT_LIST
37 |
38 | @staticmethod
39 | def to_dtype(precision: str) -> torch.dtype:
40 | if precision in HALF_LIST:
41 | return torch.float16
42 | elif precision in FLOAT_LIST:
43 | return torch.float32
44 | elif precision in BFLOAT_LIST:
45 | return torch.bfloat16
46 | else:
47 | raise RuntimeError(f"Unexpected precision: {precision}")
48 |
49 | @staticmethod
50 | def to_str(precision: torch.dtype) -> str:
51 | if precision == torch.float16:
52 | return "float16"
53 | elif precision == torch.float32:
54 | return "float32"
55 | elif precision == torch.bfloat16:
56 | return "bfloat16"
57 | else:
58 | raise RuntimeError(f"Unexpected precision: {precision}")
59 |
--------------------------------------------------------------------------------
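
A short round-trip sketch for PrecisionType (illustrative only; assumes torch is installed):

import torch

from verl.utils.torch_dtypes import PrecisionType

dtype = PrecisionType.to_dtype("bf16")            # torch.bfloat16
assert dtype is torch.bfloat16
assert PrecisionType.to_str(dtype) == "bfloat16"
assert PrecisionType.is_fp32("float32") and not PrecisionType.is_fp16("bf16")
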
/verl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/workers/actor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
16 |
17 |
18 | __all__ = [
19 | "ActorConfig",
20 | "FSDPConfig",
21 | "ModelConfig",
22 | "OptimConfig",
23 | "RefConfig",
24 | ]
25 |
--------------------------------------------------------------------------------
/verl/workers/actor/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for Actor
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 |
21 | import torch
22 |
23 | from ...protocol import DataProto
24 | from .config import ActorConfig
25 |
26 |
27 | __all__ = ["BasePPOActor"]
28 |
29 |
30 | class BasePPOActor(ABC):
31 | def __init__(self, config: ActorConfig):
32 | """The base class for PPO actor
33 |
34 | Args:
35 | config (ActorConfig): a config passed to the PPOActor.
36 | """
37 | self.config = config
38 |
39 | @abstractmethod
40 | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
41 | """Compute logits given a batch of data.
42 |
43 | Args:
44 | data (DataProto): a batch of data represented by DataProto. It must contain the keys ```input_ids```,
45 | ```attention_mask``` and ```position_ids```.
46 |
47 | Returns:
48 | torch.Tensor: the log probabilities of the response tokens
49 | """
50 | pass
51 |
52 | @abstractmethod
53 | def update_policy(self, data: DataProto) -> Dict[str, Any]:
54 | """Update the policy with an iterator of DataProto
55 |
56 | Args:
57 | data (DataProto): an iterator over the DataProto returned by
58 | ```make_minibatch_iterator```
59 |
60 | Returns:
61 | Dict: a dictionary that may contain anything. Typically, it contains statistics collected while
62 | updating the model, such as ```loss```, ```grad_norm```, etc.
63 | """
64 | pass
65 |
--------------------------------------------------------------------------------
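
An interface-only sketch of a BasePPOActor subclass; the name DummyActor and its placeholder return values are illustrative and are not the repo's dp_actor implementation, which runs the real forward/backward passes:

from typing import Any, Dict

import torch

from verl.protocol import DataProto
from verl.workers.actor.base import BasePPOActor


class DummyActor(BasePPOActor):
    """Satisfies the abstract interface with no-op behavior (illustration only)."""

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        # a real actor runs a model forward pass; here we just return zeros shaped like the responses
        return torch.zeros_like(data.batch["responses"], dtype=torch.float32)

    def update_policy(self, data: DataProto) -> Dict[str, Any]:
        # a real actor runs PPO updates and reports loss/grad_norm statistics
        return {"actor/loss": 0.0, "actor/grad_norm": 0.0}
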
/verl/workers/actor/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Actor config
16 | """
17 |
18 | import os
19 | from dataclasses import dataclass, field
20 | from typing import Any, Dict, Optional, Tuple
21 |
22 |
23 | @dataclass
24 | class ModelConfig:
25 | model_path: Optional[str] = None
26 | tokenizer_path: Optional[str] = None
27 | override_config: Dict[str, Any] = field(default_factory=dict)
28 | enable_gradient_checkpointing: bool = True
29 | trust_remote_code: bool = True
30 | freeze_vision_tower: bool = False
31 |
32 | def post_init(self):
33 | if self.tokenizer_path is None:
34 | self.tokenizer_path = self.model_path
35 |
36 | if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path
37 | self.model_path = os.path.abspath(self.model_path)
38 |
39 | if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
40 | self.tokenizer_path = os.path.abspath(self.tokenizer_path)
41 |
42 |
43 | @dataclass
44 | class OptimConfig:
45 | lr: float = 1e-6
46 | betas: Tuple[float, float] = (0.9, 0.999)
47 | weight_decay: float = 1e-2
48 | strategy: str = "adamw"
49 | lr_warmup_ratio: float = 0.0
50 | min_lr_ratio: Optional[float] = None
51 | warmup_style: str = "constant"
52 | """auto keys"""
53 | training_steps: int = field(default=-1, init=False)
54 |
55 |
56 | @dataclass
57 | class FSDPConfig:
58 | enable_full_shard: bool = True
59 | enable_cpu_offload: bool = False
60 | enable_rank0_init: bool = False
61 | use_orig_params: bool = False
62 | torch_dtype: Optional[str] = None
63 | fsdp_size: int = -1
64 | mp_param_dtype: str = "bf16"
65 | mp_reduce_dtype: str = "fp32"
66 | mp_buffer_dtype: str = "fp32"
67 |
68 |
69 | @dataclass
70 | class OffloadConfig:
71 | offload_params: bool = False
72 | offload_optimizer: bool = False
73 |
74 |
75 | @dataclass
76 | class ActorConfig:
77 | strategy: str = "fsdp"
78 | global_batch_size: int = 256
79 | micro_batch_size_per_device_for_update: int = 4
80 | micro_batch_size_per_device_for_experience: int = 16
81 | max_grad_norm: float = 1.0
82 | clip_ratio_low: float = 0.2
83 | clip_ratio_high: float = 0.3
84 | clip_ratio_dual: float = 3.0
85 | ppo_epochs: int = 1
86 | padding_free: bool = False
87 | ulysses_sequence_parallel_size: int = 1
88 | use_torch_compile: bool = True
89 | model: ModelConfig = field(default_factory=ModelConfig)
90 | optim: OptimConfig = field(default_factory=OptimConfig)
91 | fsdp: FSDPConfig = field(default_factory=FSDPConfig)
92 | offload: OffloadConfig = field(default_factory=OffloadConfig)
93 | """auto keys"""
94 | global_batch_size_per_device: int = field(default=-1, init=False)
95 | disable_kl: bool = field(default=False, init=False)
96 | use_kl_loss: bool = field(default=False, init=False)
97 | kl_penalty: str = field(default="kl", init=False)
98 | kl_coef: float = field(default=0.0, init=False)
99 |
100 |
101 | @dataclass
102 | class RefConfig:
103 | strategy: str = "fsdp"
104 | fsdp: FSDPConfig = field(default_factory=FSDPConfig)
105 | offload: OffloadConfig = field(default_factory=OffloadConfig)
106 | """auto keys"""
107 | micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
108 | padding_free: bool = field(default=False, init=False)
109 | ulysses_sequence_parallel_size: int = field(default=1, init=False)
110 | use_torch_compile: bool = field(default=True, init=False)
111 |
--------------------------------------------------------------------------------
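
A construction sketch showing how the nested actor dataclasses fit together; the values and the model path are illustrative, not recommended settings:

from verl.workers.actor.config import ActorConfig, ModelConfig, OptimConfig

config = ActorConfig(
    global_batch_size=128,
    model=ModelConfig(model_path="Qwen/Qwen2.5-7B-Instruct"),  # placeholder path
    optim=OptimConfig(lr=1e-6, lr_warmup_ratio=0.03),
)
config.model.post_init()            # tokenizer_path falls back to model_path when unset
print(config.model.tokenizer_path)  # "Qwen/Qwen2.5-7B-Instruct"
print(config.fsdp.mp_param_dtype)   # defaults to "bf16"
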
/verl/workers/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | ActorRolloutRef config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 |
20 | from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
21 | from .critic import CriticConfig
22 | from .reward import RewardConfig
23 | from .rollout import RolloutConfig
24 |
25 |
26 | __all__ = [
27 | "ActorConfig",
28 | "CriticConfig",
29 | "FSDPConfig",
30 | "ModelConfig",
31 | "OptimConfig",
32 | "RefConfig",
33 | "RewardConfig",
34 | "RolloutConfig",
35 | "WorkerConfig",
36 | ]
37 |
38 |
39 | @dataclass
40 | class WorkerConfig:
41 | hybrid_engine: bool = True
42 | actor: ActorConfig = field(default_factory=ActorConfig)
43 | critic: CriticConfig = field(default_factory=CriticConfig)
44 | ref: RefConfig = field(default_factory=RefConfig)
45 | reward: RewardConfig = field(default_factory=RewardConfig)
46 | rollout: RolloutConfig = field(default_factory=RolloutConfig)
47 |
48 | def post_init(self):
49 | self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience
50 | self.ref.padding_free = self.actor.padding_free
51 | self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size
52 | self.ref.use_torch_compile = self.actor.use_torch_compile
53 |
--------------------------------------------------------------------------------
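
A sketch of how WorkerConfig.post_init mirrors actor settings onto the reference worker; importing verl.workers.config pulls in the rollout package, so the full dependency stack (including vLLM) is assumed to be installed:

from verl.workers.config import WorkerConfig

config = WorkerConfig()
config.actor.padding_free = True
config.actor.ulysses_sequence_parallel_size = 2
config.post_init()

assert config.ref.padding_free is True
assert config.ref.ulysses_sequence_parallel_size == 2
assert config.ref.use_torch_compile == config.actor.use_torch_compile
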
/verl/workers/critic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .config import CriticConfig
16 |
17 |
18 | __all__ = ["CriticConfig"]
19 |
--------------------------------------------------------------------------------
/verl/workers/critic/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Base class for Critic
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict
20 |
21 | import torch
22 |
23 | from ...protocol import DataProto
24 | from .config import CriticConfig
25 |
26 |
27 | __all__ = ["BasePPOCritic"]
28 |
29 |
30 | class BasePPOCritic(ABC):
31 | def __init__(self, config: CriticConfig):
32 | self.config = config
33 |
34 | @abstractmethod
35 | def compute_values(self, data: DataProto) -> torch.Tensor:
36 | """Compute values"""
37 | pass
38 |
39 | @abstractmethod
40 | def update_critic(self, data: DataProto) -> Dict[str, Any]:
41 | """Update the critic"""
42 | pass
43 |
--------------------------------------------------------------------------------
/verl/workers/critic/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Critic config
16 | """
17 |
18 | from dataclasses import dataclass, field
19 |
20 | from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
21 |
22 |
23 | @dataclass
24 | class CriticConfig:
25 | strategy: str = "fsdp"
26 | global_batch_size: int = 256
27 | micro_batch_size_per_device_for_update: int = 4
28 | micro_batch_size_per_device_for_experience: int = 16
29 | max_grad_norm: float = 1.0
30 | cliprange_value: float = 0.5
31 | ppo_epochs: int = 1
32 | padding_free: bool = False
33 | ulysses_sequence_parallel_size: int = 1
34 | model: ModelConfig = field(default_factory=ModelConfig)
35 | optim: OptimConfig = field(default_factory=OptimConfig)
36 | fsdp: FSDPConfig = field(default_factory=FSDPConfig)
37 | offload: OffloadConfig = field(default_factory=OffloadConfig)
38 | """auto keys"""
39 | global_batch_size_per_device: int = field(default=-1, init=False)
40 |
--------------------------------------------------------------------------------
/verl/workers/reward/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .config import RewardConfig
16 | from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager
17 |
18 |
19 | __all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"]
20 |
--------------------------------------------------------------------------------
/verl/workers/reward/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Reward config
16 | """
17 |
18 | import os
19 | from dataclasses import dataclass, field
20 | from typing import Optional
21 |
22 |
23 | @dataclass
24 | class RewardConfig:
25 | reward_type: str = "batch"
26 | reward_function: Optional[str] = None
27 | reward_function_kwargs: dict = field(default_factory=dict)
28 | skip_special_tokens: bool = True
29 | num_cpus: int = 1
30 | """auto keys"""
31 | reward_function_name: Optional[str] = field(default=None, init=False)
32 |
33 | def post_init(self):
34 | if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main
35 | if ":" not in self.reward_function:
36 | self.reward_function_name = "main"
37 | else:
38 | self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)
39 |
40 | if os.path.exists(self.reward_function): # ray job uses absolute path
41 | self.reward_function = os.path.abspath(self.reward_function)
42 | else:
43 | self.reward_function = None
44 |
--------------------------------------------------------------------------------
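
A sketch of the "path:function" convention resolved by post_init; the file path and the function name compute_score below are placeholders, not a guaranteed entry point:

from verl.workers.reward.config import RewardConfig

config = RewardConfig(reward_function="./examples/reward_function/math.py:compute_score")
config.post_init()

# If the file exists, reward_function is turned into an absolute path and
# reward_function_name becomes "compute_score"; otherwise reward_function is reset to None.
print(config.reward_function, config.reward_function_name)
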
/verl/workers/reward/function.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib.util
16 | import os
17 | import sys
18 | from abc import ABC, abstractmethod
19 | from collections import defaultdict
20 | from functools import partial
21 | from typing import Callable, Dict, List, Optional, Tuple, TypedDict
22 |
23 | import torch
24 | from transformers import PreTrainedTokenizer
25 |
26 | from ...protocol import DataProto
27 | from .config import RewardConfig
28 |
29 |
30 | class RewardScore(TypedDict):
31 | overall: float
32 | format: Optional[float]
33 | accuracy: Optional[float]
34 |
35 |
36 | SequentialRewardFunction = Callable[[str, str], RewardScore]
37 |
38 | BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]
39 |
40 |
41 | class FunctionRewardManager(ABC):
42 | """Reward manager for rule-based reward."""
43 |
44 | def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
45 | if config.reward_function is None:
46 | raise ValueError("Reward function is not provided.")
47 |
48 | if not os.path.exists(config.reward_function):
49 | raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")
50 |
51 | spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function)
52 | module = importlib.util.module_from_spec(spec)
53 | try:
54 | sys.modules["custom_reward_fn"] = module
55 | spec.loader.exec_module(module)
56 | except Exception as e:
57 | raise RuntimeError(f"Failed to load reward function: {e}")
58 |
59 | if not hasattr(module, config.reward_function_name):
60 | raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
61 |
62 | reward_fn = getattr(module, config.reward_function_name)
63 | print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
64 | self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
65 | self.config = config
66 | self.tokenizer = tokenizer
67 |
68 | @abstractmethod
69 | def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
70 | """Compute reward for a batch of data."""
71 | ...
72 |
73 |
74 | class SequentialFunctionRewardManager(FunctionRewardManager):
75 | reward_fn: SequentialRewardFunction
76 |
77 | def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
78 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
79 | reward_metrics = defaultdict(list)
80 | response_ids = data.batch["responses"]
81 | response_length = data.batch["response_mask"].sum(dim=-1)
82 | for i in range(len(data)):
83 | valid_response_ids = response_ids[i][: response_length[i]]
84 | response_str = self.tokenizer.decode(
85 | valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
86 | )
87 | ground_truth = data.non_tensor_batch["ground_truth"][i]
88 |
89 | score = self.reward_fn(response_str, ground_truth)
90 | reward_tensor[i, response_length[i] - 1] = score["overall"]
91 | for key, value in score.items():
92 | reward_metrics[key].append(value)
93 |
94 | return reward_tensor, reward_metrics
95 |
96 |
97 | class BatchFunctionRewardManager(FunctionRewardManager):
98 | reward_fn: BatchRewardFunction
99 |
100 | def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
101 | response_str, ground_truth = [], []
102 | response_ids = data.batch["responses"]
103 | response_length = data.batch["response_mask"].sum(dim=-1)
104 | for i in range(len(data)):
105 | valid_response_ids = response_ids[i][: response_length[i]]
106 | response_str.append(
107 | self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
108 | )
109 | ground_truth.append(data.non_tensor_batch["ground_truth"][i])
110 |
111 | scores = self.reward_fn(response_str, ground_truth)
112 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
113 | reward_metrics = defaultdict(list)
114 | for i, score in enumerate(scores):
115 | reward_tensor[i, response_length[i] - 1] = score["overall"]
116 | for key, value in score.items():
117 | reward_metrics[key].append(value)
118 |
119 | return reward_tensor, reward_metrics
120 |
--------------------------------------------------------------------------------
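
A minimal, self-contained reward function sketch matching the SequentialRewardFunction signature above; the exact-match criterion and the name exact_match_reward are illustrative and are not the repo's math/r1v rewards. RewardScore is redefined locally only so the file stands alone:

from typing import Optional, TypedDict


class RewardScore(TypedDict):
    overall: float
    format: Optional[float]
    accuracy: Optional[float]


def exact_match_reward(response: str, ground_truth: str) -> RewardScore:
    """Score 1.0 when the stripped response equals the ground truth, else 0.0."""
    accuracy = 1.0 if response.strip() == ground_truth.strip() else 0.0
    return {"overall": accuracy, "format": None, "accuracy": accuracy}

Pointing the reward_function config at such a file (for example, a hypothetical ./my_reward.py:exact_match_reward) would let SequentialFunctionRewardManager call it once per decoded response.
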
/verl/workers/rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .config import RolloutConfig
17 | from .vllm_rollout_spmd import vLLMRollout
18 |
19 |
20 | __all__ = ["RolloutConfig", "vLLMRollout"]
21 |
--------------------------------------------------------------------------------
/verl/workers/rollout/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC, abstractmethod
16 |
17 | from ...protocol import DataProto
18 |
19 |
20 | __all__ = ["BaseRollout"]
21 |
22 |
23 | class BaseRollout(ABC):
24 | @abstractmethod
25 | def generate_sequences(self, prompts: DataProto) -> DataProto:
26 | """Generate sequences"""
27 | pass
28 |
--------------------------------------------------------------------------------
/verl/workers/rollout/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Rollout config
16 | """
17 |
18 | from dataclasses import asdict, dataclass, field
19 | from typing import Any, Dict, Optional
20 |
21 |
22 | @dataclass
23 | class RolloutConfig:
24 | name: str = "vllm"
25 | n: int = 1
26 | temperature: float = 1.0
27 | top_p: float = 1.0
28 | top_k: int = -1
29 | seed: int = 1
30 | limit_images: int = 0
31 | dtype: str = "bf16"
32 | gpu_memory_utilization: float = 0.6
33 | ignore_eos: bool = False
34 | enforce_eager: bool = False
35 | enable_chunked_prefill: bool = False # only for v0 engine
36 | tensor_parallel_size: int = 2
37 | max_model_len: Optional[int] = None
38 | max_num_batched_tokens: int = 8192
39 | disable_log_stats: bool = True
40 | val_override_config: Dict[str, Any] = field(default_factory=dict)
41 | """auto keys"""
42 | prompt_length: int = field(default=-1, init=False)
43 | response_length: int = field(default=-1, init=False)
44 | trust_remote_code: bool = field(default=False, init=False)
45 |
46 | def to_dict(self):
47 | return asdict(self)
48 |
--------------------------------------------------------------------------------
/verl/workers/rollout/vllm_rollout_spmd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from contextlib import contextmanager
17 | from typing import Any, Dict, List, Optional, Union
18 |
19 | import numpy as np
20 | import torch
21 | import torch.distributed
22 | from tensordict import TensorDict
23 | from transformers import PreTrainedTokenizer
24 | from vllm import LLM, RequestOutput, SamplingParams
25 |
26 | from ...protocol import DataProto
27 | from ...utils import torch_functional as VF
28 | from ...utils.tokenizer import get_processor
29 | from ...utils.torch_dtypes import PrecisionType
30 | from .base import BaseRollout
31 | from .config import RolloutConfig
32 |
33 |
34 | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
35 | if isinstance(value, torch.Tensor):
36 | return value.repeat_interleave(repeats, dim=0)
37 | else:
38 | return np.repeat(value, repeats, axis=0)
39 |
40 |
41 | def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]:
42 | processor = get_processor(model_path, trust_remote_code=trust_remote_code)
43 | if processor is not None and hasattr(processor, "image_token"):
44 | image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
45 | return {image_token_id: -100}
46 | else:
47 | return None
48 |
49 |
50 | class vLLMRollout(BaseRollout):
51 | def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
52 | """A vLLM rollout. It requires the module is supported by the vllm.
53 |
54 | Args:
55 | model_path: the huggingface model path or repo id to load into vLLM
56 | config: RolloutConfig
57 | tokenizer: the task/model tokenizer
58 | """
59 | super().__init__()
60 | self.rank = int(os.getenv("RANK", "0"))
61 | self.config = config
62 | self.pad_token_id = tokenizer.pad_token_id
63 | if config.tensor_parallel_size > torch.distributed.get_world_size():
64 | raise ValueError("Tensor parallelism size should be less than world size.")
65 |
66 | if config.max_num_batched_tokens < config.prompt_length + config.response_length:
67 | raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
68 |
69 | self.inference_engine = LLM(
70 | model=model_path,
71 | skip_tokenizer_init=False,
72 | trust_remote_code=config.trust_remote_code,
73 | load_format="dummy",
74 | dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
75 | seed=config.seed,
76 | max_model_len=config.max_model_len or config.prompt_length + config.response_length,
77 | distributed_executor_backend="external_launcher",
78 | tensor_parallel_size=config.tensor_parallel_size,
79 | gpu_memory_utilization=config.gpu_memory_utilization,
80 | max_num_batched_tokens=config.max_num_batched_tokens,
81 | disable_log_stats=config.disable_log_stats,
82 | enforce_eager=config.enforce_eager,
83 | disable_custom_all_reduce=True,
84 | limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
85 | disable_mm_preprocessor_cache=True,
86 | enable_chunked_prefill=config.enable_chunked_prefill,
87 | enable_sleep_mode=True,
88 | )
89 |
90 | # Offload vllm model to reduce peak memory usage
91 | self.inference_engine.sleep(level=1)
92 |
93 | sampling_kwargs = {
94 | "max_tokens": config.response_length,
95 | "detokenize": False,
96 | "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code),
97 | }
98 | default_sampling_params = SamplingParams()
99 | for key in config.to_dict().keys():
100 | if hasattr(default_sampling_params, key):
101 | sampling_kwargs[key] = getattr(config, key)
102 |
103 | print(f"Sampling params: {sampling_kwargs}.")
104 | self.sampling_params = SamplingParams(**sampling_kwargs)
105 |
106 | @contextmanager
107 | def update_sampling_params(self, **kwargs):
108 | # update sampling params
109 | old_sampling_params_args = {}
110 | if kwargs:
111 | for key, value in kwargs.items():
112 | if hasattr(self.sampling_params, key):
113 | old_value = getattr(self.sampling_params, key)
114 | old_sampling_params_args[key] = old_value
115 | setattr(self.sampling_params, key, value)
116 |
117 | yield
118 | # roll back to previous sampling params
119 | for key, value in old_sampling_params_args.items():
120 | setattr(self.sampling_params, key, value)
121 |
122 | @torch.no_grad()
123 | def generate_sequences(self, prompts: DataProto) -> DataProto:
124 | # left-padded attention_mask
125 | input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length)
126 | attention_mask: torch.Tensor = prompts.batch["attention_mask"]
127 | position_ids: torch.Tensor = prompts.batch["position_ids"]
128 | eos_token_id: int = prompts.meta_info["eos_token_id"]
129 | batch_size = input_ids.size(0)
130 |
131 | non_tensor_batch = prompts.non_tensor_batch
132 | if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
133 | raise RuntimeError("vllm sharding manager is not working properly.")
134 |
135 | if "multi_modal_data" in non_tensor_batch:
136 | vllm_inputs = []
137 | for raw_prompt_ids, multi_modal_data in zip(
138 | non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
139 | ):
140 | vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
141 | else:
142 | vllm_inputs = [
143 | {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
144 | ]
145 |
146 | # users can customize different sampling_params for each run
147 | with self.update_sampling_params(**prompts.meta_info):
148 | completions: List[RequestOutput] = self.inference_engine.generate(
149 | prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
150 | )
151 | response_ids = [output.token_ids for completion in completions for output in completion.outputs]
152 | response_ids = VF.pad_2d_list_to_length(
153 | response_ids, self.pad_token_id, max_length=self.config.response_length
154 | ).to(input_ids.device)
155 |
156 | if self.sampling_params.n > 1:
157 | batch_size = batch_size * self.sampling_params.n
158 | input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
159 | attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
160 | position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
161 |
162 | sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
163 | response_length = response_ids.size(1)
164 | delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
165 | delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
166 | if position_ids.dim() == 3: # qwen2vl mrope
167 | delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
168 |
169 | # prompt: left pad + response: right pad
170 | # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
171 | # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
172 | response_position_ids = position_ids[..., -1:] + delta_position_id
173 | position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
174 | response_mask = VF.get_response_mask(
175 | response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
176 | )
177 | attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
178 |
179 | # all the tp ranks should contain the same data here. data in all ranks are valid
180 | batch = TensorDict(
181 | {
182 | "prompts": input_ids,
183 | "responses": response_ids,
184 | "input_ids": sequence_ids, # here input_ids become the whole sentences
185 | "attention_mask": attention_mask,
186 | "response_mask": response_mask,
187 | "position_ids": position_ids,
188 | },
189 | batch_size=batch_size,
190 | )
191 | return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
192 |
--------------------------------------------------------------------------------
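
A quick sketch of _repeat_interleave, which keeps tensors and object arrays aligned when each prompt is sampled n times; importing this module requires vLLM to be installed since it is imported at module level:

import numpy as np
import torch

from verl.workers.rollout.vllm_rollout_spmd import _repeat_interleave

input_ids = torch.tensor([[1, 2], [3, 4]])
raw_prompts = np.array(["a", "b"], dtype=object)

print(_repeat_interleave(input_ids, 2))    # rows duplicated: [[1, 2], [1, 2], [3, 4], [3, 4]]
print(_repeat_interleave(raw_prompts, 2))  # ['a' 'a' 'b' 'b']
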
/verl/workers/sharding_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .base import BaseShardingManager
17 | from .fsdp_ulysses import FSDPUlyssesShardingManager
18 | from .fsdp_vllm import FSDPVLLMShardingManager
19 |
20 |
21 | __all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"]
22 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Sharding manager to implement HybridEngine
16 | """
17 |
18 | from ...protocol import DataProto
19 |
20 |
21 | class BaseShardingManager:
22 | def __enter__(self):
23 | pass
24 |
25 | def __exit__(self, exc_type, exc_value, traceback):
26 | pass
27 |
28 | def preprocess_data(self, data: DataProto) -> DataProto:
29 | return data
30 |
31 | def postprocess_data(self, data: DataProto) -> DataProto:
32 | return data
33 |
--------------------------------------------------------------------------------
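
A sketch of the intended usage pattern: a sharding manager wraps a rollout step as a context manager and reshapes data on the way in and out. The function name rollout_step is illustrative, and the imports assume the full dependency stack (torch, vLLM) is installed:

from verl.protocol import DataProto
from verl.workers.sharding_manager.base import BaseShardingManager


def rollout_step(manager: BaseShardingManager, rollout, prompts: DataProto) -> DataProto:
    with manager:                                  # subclasses sync weights / wake the engine here
        batch = manager.preprocess_data(prompts)   # e.g. all-gather across the tp group
        outputs = rollout.generate_sequences(batch)  # any BaseRollout-like object
        return manager.postprocess_data(outputs)   # e.g. take this rank's chunk back
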
/verl/workers/sharding_manager/fsdp_ulysses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contains a sharding manager that reshards data when combining FSDP with Ulysses sequence parallelism
16 | """
17 |
18 | from torch.distributed.device_mesh import DeviceMesh
19 |
20 | from ...protocol import DataProto, all_gather_data_proto
21 | from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
22 | from .base import BaseShardingManager
23 |
24 |
25 | class FSDPUlyssesShardingManager(BaseShardingManager):
26 | """
27 | Sharding manager to support data resharding when using FSDP + Ulysses
28 | """
29 |
30 | def __init__(self, device_mesh: DeviceMesh):
31 | super().__init__()
32 | self.device_mesh = device_mesh
33 |
34 | def __enter__(self):
35 | if self.device_mesh is not None:
36 | self.prev_sp_group = get_ulysses_sequence_parallel_group()
37 | set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group())
38 |
39 | def __exit__(self, exc_type, exc_value, traceback):
40 | if self.device_mesh is not None:
41 | set_ulysses_sequence_parallel_group(self.prev_sp_group)
42 |
43 | def preprocess_data(self, data: DataProto) -> DataProto:
44 | """
45 | AllGather data from the SP region.
46 | This is needed because the data is first sharded along the FSDP dimension (DP_COMPUTE dispatch).
47 | In Ulysses, we need to make sure the same data is used across an SP group.
48 | """
49 | if self.device_mesh is not None:
50 | sp_size = self.device_mesh["sp"].size()
51 | sp_group = self.device_mesh["sp"].get_group()
52 | all_gather_data_proto(data, size=sp_size, group=sp_group)
53 |
54 | return data
55 |
56 | def postprocess_data(self, data: DataProto) -> DataProto:
57 | """
58 | Split the data to follow FSDP partition
59 | """
60 | if self.device_mesh is not None:
61 | sp_size = self.device_mesh["sp"].size()
62 | sp_rank = self.device_mesh["sp"].get_local_rank()
63 | data = data.chunk(chunks=sp_size)[sp_rank]
64 |
65 | return data
66 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/fsdp_vllm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from typing import Dict, Iterable, Tuple, Union
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from torch.distributed._tensor import DTensor
21 | from torch.distributed.checkpoint.state_dict import get_model_state_dict
22 | from torch.distributed.device_mesh import DeviceMesh
23 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
24 | from vllm import LLM
25 | from vllm.distributed import parallel_state as vllm_ps
26 |
27 | from ...protocol import DataProto, all_gather_data_proto
28 | from ...utils.model_utils import print_gpu_memory_usage
29 | from .base import BaseShardingManager
30 |
31 |
32 | class FSDPVLLMShardingManager(BaseShardingManager):
33 | def __init__(
34 | self,
35 | module: FSDP,
36 | inference_engine: LLM,
37 | device_mesh: DeviceMesh,
38 | ):
39 | self.module = module
40 | self.inference_engine = inference_engine
41 | self.device_mesh = device_mesh
42 |
43 | self.world_size = dist.get_world_size()
44 | self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
45 | self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
46 | self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
47 |
48 | # Record freed bytes to estimate memory usage correctly
49 | # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
50 | self.freed_bytes = 0
51 |
52 | # Note that torch_random_states may be different on each dp rank
53 | self.torch_random_states = torch.cuda.get_rng_state()
54 | # get a random rng state
55 | gen_dp_rank = self.device_mesh["dp"].get_local_rank()
56 | torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
57 | self.gen_random_states = torch.cuda.get_rng_state()
58 | torch.cuda.set_rng_state(self.torch_random_states)
59 |
60 | def _make_weight_iterator(
61 | self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
62 | ) -> Iterable[Tuple[str, torch.Tensor]]:
63 | for name, tensor in actor_weights.items():
64 | yield name, tensor.full_tensor() if self.world_size != 1 else tensor
65 |
66 | def __enter__(self):
67 | # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
68 | # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
69 | # Outside the vllm scope, we should avoid emptying the cache so that pytorch can use its caching
70 | # allocator to speed up memory allocations.
71 | #
72 | # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
73 | # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
74 | torch.cuda.empty_cache()
75 | print_gpu_memory_usage("Before state_dict() in sharding manager")
76 | actor_weights = get_model_state_dict(self.module)
77 | print_gpu_memory_usage("After state_dict() in sharding manager")
78 |
79 | if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
80 | self.inference_engine.wake_up(tags=["weights"])
81 | else:
82 | self.inference_engine.wake_up()
83 |
84 | model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
85 | model.load_weights(self._make_weight_iterator(actor_weights))
86 | print_gpu_memory_usage("After sync model weights in sharding manager")
87 |
88 | del actor_weights
89 | torch.cuda.empty_cache()
90 |
91 | if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
92 | self.inference_engine.wake_up(tags=["kv_cache"])
93 |
94 | print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
95 | # important: need to manually set the random states of each tp to be identical.
96 | if self.device_mesh is not None:
97 | self.torch_random_states = torch.cuda.get_rng_state()
98 | torch.cuda.set_rng_state(self.gen_random_states)
99 |
100 | def __exit__(self, exc_type, exc_value, traceback):
101 | print_gpu_memory_usage("Before vllm offload in sharding manager")
102 | free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
103 | self.inference_engine.sleep(level=1)
104 | free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
105 | self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
106 | print_gpu_memory_usage("After vllm offload in sharding manager")
107 |
108 | self.module.train()
109 | torch.cuda.empty_cache() # add empty cache after each compute
110 |
111 | # restore random states
112 | if self.device_mesh is not None:
113 | self.gen_random_states = torch.cuda.get_rng_state()
114 | torch.cuda.set_rng_state(self.torch_random_states)
115 |
116 | def preprocess_data(self, data: DataProto) -> DataProto:
117 | """All gather across tp group to make each rank has identical input."""
118 | all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
119 | return data
120 |
121 | def postprocess_data(self, data: DataProto) -> DataProto:
122 | """Get chunk data of this tp rank since we do all gather in preprocess."""
123 | if self.tp_size > 1:
124 | data = data.chunk(chunks=self.tp_size)[self.tp_rank]
125 |
126 | return data
127 |
--------------------------------------------------------------------------------