├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── Dockerfile.legacy ├── LICENSE ├── Makefile ├── README.md ├── assets ├── baselines.md ├── easyr1_grpo.png ├── qwen2_5_vl_7b_geo.png └── wechat.jpg ├── examples ├── baselines │ ├── qwen2_5_vl_3b_clevr.sh │ └── qwen2_5_vl_3b_geoqa8k.sh ├── config.yaml ├── format_prompt │ ├── math_format.jinja │ └── r1v_format.jinja ├── qwen2_5_7b_math_grpo.sh ├── qwen2_5_vl_32b_geo3k_grpo.sh ├── qwen2_5_vl_3b_geo3k_grpo.sh ├── qwen2_5_vl_7b_geo3k_grpo.sh ├── qwen2_5_vl_7b_geo3k_reinforce.sh ├── qwen2_5_vl_7b_geo3k_swanlab.sh ├── qwen2_5_vl_7b_multi_image.sh ├── qwen3_4b_math_grpo.sh ├── reward_function │ ├── math.py │ └── r1v.py └── runtime_env.yaml ├── pyproject.toml ├── requirements.txt ├── scripts └── model_merger.py ├── setup.py └── verl ├── __init__.py ├── models ├── __init__.py ├── monkey_patch.py └── transformers │ ├── __init__.py │ ├── flash_attention_utils.py │ └── qwen2_vl.py ├── protocol.py ├── single_controller ├── __init__.py ├── base │ ├── __init__.py │ ├── decorator.py │ ├── register_center │ │ ├── __init__.py │ │ └── ray.py │ ├── worker.py │ └── worker_group.py └── ray │ ├── __init__.py │ └── base.py ├── trainer ├── __init__.py ├── config.py ├── core_algos.py ├── data_loader.py ├── main.py ├── metrics.py └── ray_trainer.py ├── utils ├── __init__.py ├── checkpoint │ ├── __init__.py │ ├── checkpoint_manager.py │ └── fsdp_checkpoint_manager.py ├── dataset.py ├── flops_counter.py ├── fsdp_utils.py ├── logger │ ├── __init__.py │ ├── gen_logger.py │ └── logger.py ├── model_utils.py ├── py_functional.py ├── seqlen_balancing.py ├── tokenizer.py ├── torch_dtypes.py ├── torch_functional.py └── ulysses.py └── workers ├── __init__.py ├── actor ├── __init__.py ├── base.py ├── config.py └── dp_actor.py ├── config.py ├── critic ├── __init__.py ├── base.py ├── config.py └── dp_critic.py ├── fsdp_workers.py ├── reward ├── __init__.py ├── config.py └── function.py ├── rollout ├── __init__.py ├── base.py ├── config.py └── vllm_rollout_spmd.py └── sharding_manager ├── __init__.py ├── base.py ├── fsdp_ulysses.py └── fsdp_vllm.py /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | `hoshihiyouga AT gmail DOT com`. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. 
This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to EasyR1 2 | 3 | Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable. 4 | 5 | It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. 6 | 7 | However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md). 8 | 9 | **This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).** 10 | 11 | ## Ways to contribute 12 | 13 | There are several ways you can contribute to EasyR1: 14 | 15 | * Fix outstanding issues with the existing code. 16 | * Submit issues related to bugs or desired new features. 17 | * Contribute to the examples or to the documentation. 18 | 19 | ### Style guide 20 | 21 | EasyR1 follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details. 22 | 23 | ### Create a Pull Request 24 | 25 | 1. Fork the [repository](https://github.com/hiyouga/EasyR1) by clicking on the [Fork](https://github.com/hiyouga/EasyR1/fork) button on the repository's page. This creates a copy of the code under your GitHub user account. 26 | 27 | 2. 
Clone your fork to your local disk, and add the base repository as a remote: 28 | 29 | ```bash 30 | git clone git@github.com:[username]/EasyR1.git 31 | cd EasyR1 32 | git remote add upstream https://github.com/hiyouga/EasyR1.git 33 | ``` 34 | 35 | 3. Create a new branch to hold your development changes: 36 | 37 | ```bash 38 | git checkout -b dev_your_branch 39 | ``` 40 | 41 | 4. Set up a development environment by running the following command in a virtual environment: 42 | 43 | ```bash 44 | pip install -e ".[dev]" 45 | ``` 46 | 47 | 5. Check code before commit: 48 | 49 | ```bash 50 | make commit 51 | make style && make quality 52 | ``` 53 | 54 | 6. Submit changes: 55 | 56 | ```bash 57 | git add . 58 | git commit -m "commit message" 59 | git fetch upstream 60 | git rebase upstream/main 61 | git push -u origin dev_your_branch 62 | ``` 63 | 64 | 7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/EasyR1). 65 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | paths: 8 | - "**.py" 9 | - "requirements.txt" 10 | - ".github/workflows/*.yml" 11 | pull_request: 12 | branches: 13 | - "main" 14 | paths: 15 | - "**.py" 16 | - "requirements.txt" 17 | - ".github/workflows/*.yml" 18 | 19 | jobs: 20 | tests: 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | python-version: 25 | - "3.11" 26 | os: 27 | - "ubuntu-latest" 28 | 29 | runs-on: ${{ matrix.os }} 30 | 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v4 34 | 35 | - name: Set up Python 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | cache: "pip" 40 | cache-dependency-path: "setup.py" 41 | 42 | - name: Install dependencies 43 | run: | 44 | python -m pip install --upgrade pip 45 | python -m pip install ruff 46 | 47 | - name: Check quality 48 | run: | 49 | make style && make quality 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # outputs 174 | outputs/ 175 | checkpoints/ 176 | wandb/ 177 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-added-large-files 7 | args: ['--maxkb=25000'] 8 | - id: check-merge-conflict 9 | - id: check-yaml 10 | - id: debug-statements 11 | - id: end-of-file-fixer 12 | - id: requirements-txt-fixer 13 | - id: trailing-whitespace 14 | args: [--markdown-linebreak-ext=md] 15 | - id: no-commit-to-branch 16 | args: ['--branch', 'main'] 17 | 18 | - repo: https://github.com/asottile/pyupgrade 19 | rev: v3.17.0 20 | hooks: 21 | - id: pyupgrade 22 | args: [--py38-plus] 23 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) 2 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html 3 | FROM nvcr.io/nvidia/pytorch:24.08-py3 4 | 5 | # Define environments 6 | ENV MAX_JOBS=32 7 | ENV VLLM_WORKER_MULTIPROC_METHOD=spawn 8 | ENV DEBIAN_FRONTEND=noninteractive 9 | ENV NODE_OPTIONS="" 10 | ENV PIP_ROOT_USER_ACTION=ignore 11 | ENV HF_HUB_ENABLE_HF_TRANSFER="1" 12 | 13 | # Define installation arguments 14 | ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ 15 | ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 16 | 17 | # Set apt source 18 | RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ 19 | { \ 20 | echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ 21 | echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ 22 | echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ 23 | echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ 24 | } > /etc/apt/sources.list 25 | 26 | # Install systemctl 27 | RUN apt-get update && \ 28 | apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ 29 | apt-get clean 30 | 31 | # Install tini 32 | RUN apt-get update && \ 33 | apt-get install -y tini && \ 34 | apt-get clean 35 | 36 | # Change pip source 37 | RUN pip config set global.index-url "${PIP_INDEX}" && \ 38 | pip config set global.extra-index-url "${PIP_INDEX}" && \ 39 | python -m pip install --upgrade pip 40 | 41 | # Uninstall nv-pytorch fork 42 | RUN pip uninstall -y torch torchvision torchaudio \ 43 | pytorch-quantization pytorch-triton torch-tensorrt \ 44 | transformer_engine flash_attn apex megatron-core \ 45 | xgboost opencv grpcio 46 | 47 | # Fix cv2 48 | RUN rm -rf /usr/local/lib/python3.10/dist-packages/cv2 49 | 50 | # Install torch-2.6.0+cu124 + vllm-0.8.4 51 | # torch-2.6.0+cu124: cxx11abi=False 52 | # torch-2.6.0+cu126: cxx11abi=True 53 | # see https://github.com/flashinfer-ai/flashinfer/issues/911 54 | RUN pip install --no-cache-dir "vllm==0.8.4" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" tensordict torchdata \ 55 | "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ 56 | "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ 57 | ray[default] codetiming 
hydra-core pylatexenc qwen-vl-utils wandb liger-kernel mathruler \ 58 | pytest yapf py-spy pyext pre-commit ruff 59 | 60 | # Install flash-attn-2.7.4.post1 (cxx11abi=False) 61 | RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ 62 | pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 63 | 64 | # Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False) 65 | # vllm-0.8.3 does not support flashinfer>=0.2.3 66 | # see https://github.com/vllm-project/vllm/pull/15777 67 | RUN wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ 68 | pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl 69 | 70 | # Fix packages 71 | RUN pip uninstall -y pynvml nvidia-ml-py && \ 72 | pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" 73 | 74 | # Reset pip config 75 | RUN pip config unset global.index-url && \ 76 | pip config unset global.extra-index-url 77 | -------------------------------------------------------------------------------- /Dockerfile.legacy: -------------------------------------------------------------------------------- 1 | # Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) 2 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html 3 | FROM nvcr.io/nvidia/pytorch:24.08-py3 4 | 5 | # Define environments 6 | ENV MAX_JOBS=32 7 | ENV VLLM_WORKER_MULTIPROC_METHOD=spawn 8 | ENV DEBIAN_FRONTEND=noninteractive 9 | ENV NODE_OPTIONS="" 10 | ENV PIP_ROOT_USER_ACTION=ignore 11 | ENV HF_HUB_ENABLE_HF_TRANSFER="1" 12 | 13 | # Define installation arguments 14 | ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ 15 | ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple 16 | ARG VLLM_COMMIT=227578480d71fc94ef46ca77fb69496412158d68 17 | 18 | # Set apt source 19 | RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ 20 | { \ 21 | echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ 22 | echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ 23 | echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ 24 | echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ 25 | } > /etc/apt/sources.list 26 | 27 | # Install systemctl 28 | RUN apt-get update && \ 29 | apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ 30 | apt-get clean 31 | 32 | # Install tini 33 | RUN apt-get update && \ 34 | apt-get install -y tini && \ 35 | apt-get clean 36 | 37 | # Change pip source 38 | RUN pip config set global.index-url "${PIP_INDEX}" && \ 39 | pip config set global.extra-index-url "${PIP_INDEX}" && \ 40 | python -m pip install --upgrade pip 41 | 42 | # Uninstall nv-pytorch fork 43 | RUN pip uninstall -y torch torchvision torchaudio \ 44 | pytorch-quantization pytorch-triton torch-tensorrt \ 45 | transformer_engine flash_attn apex megatron-core \ 46 | xgboost opencv grpcio 47 | 48 | # Fix cv2 49 | RUN rm -rf /usr/local/lib/python3.10/dist-packages/cv2 50 | 51 | # Install vllm-0.7.4-nightly 52 | RUN pip install --no-cache-dir vllm --pre --extra-index-url "https://wheels.vllm.ai/${VLLM_COMMIT}" && \ 53 | git clone -b verl_v1 
https://github.com/hiyouga/vllm.git && \ 54 | cp -r vllm/vllm/ /usr/local/lib/python3.10/dist-packages/ 55 | 56 | # Install torch-2.5.1 57 | RUN pip install --no-cache-dir "torch==2.5.1" "torchvision==0.20.1" "torchaudio==2.5.1" tensordict torchdata \ 58 | "transformers>=4.49.0" accelerate datasets peft hf-transfer \ 59 | ray[default] codetiming hydra-core pandas "pyarrow>=15.0.0" pylatexenc qwen-vl-utils wandb liger-kernel mathruler \ 60 | pytest yapf py-spy pyext pre-commit ruff 61 | 62 | # Install flash_attn-2.7.4.post1 63 | RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ 64 | pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 65 | 66 | # Fix packages 67 | RUN pip uninstall -y pynvml nvidia-ml-py && \ 68 | pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" 69 | 70 | # Reset pip config 71 | RUN pip config unset global.index-url && \ 72 | pip config unset global.extra-index-url 73 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build commit quality style 2 | 3 | check_dirs := examples scripts verl setup.py 4 | 5 | build: 6 | python3 setup.py sdist bdist_wheel 7 | 8 | commit: 9 | pre-commit install 10 | pre-commit run --all-files 11 | 12 | quality: 13 | ruff check $(check_dirs) 14 | ruff format --check $(check_dirs) 15 | 16 | style: 17 | ruff check $(check_dirs) --fix 18 | ruff format $(check_dirs) 19 | -------------------------------------------------------------------------------- /assets/baselines.md: -------------------------------------------------------------------------------- 1 | # Baselines 2 | 3 | Environment: [hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0](https://hub.docker.com/layers/hiyouga/verl/ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0/images/sha256-335ed6cd1fe73090e458409cfa4394d6abf4cd0503ca44dbafdc28ff72e5ed20) 4 | 5 | EasyR1 version: [v0.3.0](https://github.com/hiyouga/EasyR1/tree/v0.3.0) 6 | 7 | Welcome to contribute new data points! 8 | 9 | ## Algorithm Baselines 10 | 11 | ### [Qwen2.5-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) on [Math12k](https://huggingface.co/datasets/hiyouga/math12k) 12 | 13 | | Size | Algorithm | Bits | LR | KL | Test Score | 14 | | ---- | ----------- | ---- | ---- | ---- | ---------- | 15 | | 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.73->0.79 | 16 | 17 | ### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k) 18 | 19 | | Size | Algorithm | Bits | LR | KL | Test Score | 20 | | ---- | ----------- | ---- | ---- | ---- | ---------- | 21 | | 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.39->0.52 | 22 | | 7B | GRPO | BF16 | 1e-6 | 1e-2 | 0.39->0.52 | 23 | | 7B | GRPO | AMP | 1e-6 | 1e-3 | 0.39->0.52 | 24 | | 7B | RLOO | AMP | 1e-6 | 1e-2 | 0.39->0.53 | 25 | | 3B | GRPO | AMP | 1e-6 | 1e-2 | 0.27->0.44 | 26 | | 32B | GRPO | BF16 | 1e-6 | 1e-2 | 0.46->0.61 | 27 | 28 | > [!NOTE] 29 | > The hyper-parameters not listed are all the same as the default values. 
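For reference, each row of the table maps directly onto CLI overrides of `examples/config.yaml`. Below is a minimal, hedged sketch (assuming the Geometry3k GRPO setup from `examples/qwen2_5_vl_7b_geo3k_grpo.sh`; only the LR and KL columns are overridden and the experiment name is made up) for reproducing the KL = 1e-3 row:

```bash
# Sketch only: keys are taken from examples/config.yaml and the examples/*.sh scripts;
# swap the override values to match the table row you want to reproduce.
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/geometry3k@train \
    data.val_files=hiyouga/geometry3k@test \
    worker.actor.model.model_path=Qwen/Qwen2.5-VL-7B-Instruct \
    worker.actor.optim.lr=1.0e-6 \
    algorithm.adv_estimator=grpo \
    algorithm.kl_coef=1.0e-3 \
    trainer.experiment_name=qwen2_5_vl_7b_geo_grpo_kl1e-3
```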
30 | 31 | ## Performance Baselines 32 | 33 | ### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k) 34 | 35 | | Size | GPU Type | Bits | Batch Size | vLLM Util | vLLM TP | Peak Mem | Peak VRAM | Throughput | Sec per step | Actor MFU | 36 | | ---- | ------------- | ---- | ---------- | --------- | ------- | -------- | --------- | ---------- | ------------ | --------- | 37 | | 3B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 120GB | 35GB | 1200 | 180s | 6.3% | 38 | | 7B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 140GB | 60GB | 1200 | 180s | 13.6% | 39 | | 7B | 8 * H100 80GB | AMP | 10 / 20 | 0.6 | 2 | 150GB | 75GB | 1400 | 170s | 19.2% | 40 | | 7B | 8 * L20 48GB | AMP | 4 / 16 | 0.6 | 2 | 150GB | 44GB | 410 | 580s | 26.5% | 41 | | 7B | 8 * H100 80GB | BF16 | 4 / 16 | 0.6 | 2 | 150GB | 50GB | 1280 | 190s | 13.9% | 42 | | 32B | 8 * H100 80GB | BF16 | 1 / 8 | 0.6 | 8 | 240GB | 68GB | 360 | 860s | 11.2% | 43 | 44 | - Batch Size: micro_batch_size_per_device_for_update / micro_batch_size_per_device_for_experience 45 | - vLLM Util: rollout.gpu_memory_utilization 46 | - vLLM TP: rollout.tensor_parallel_size 47 | - Peak Mem: Peak CPU memory usage 48 | - Peak VRAM: Peak GPU memory usage 49 | - Throughput: Number of tokens per second per GPU by one training step 50 | - Sec per step: Average time per step in seconds 51 | 52 | > [!NOTE] 53 | > The hyper-parameters not listed are all the same as the default values. 54 | -------------------------------------------------------------------------------- /assets/easyr1_grpo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiyouga/EasyR1/c835ae63c2f302b1491529527fc39127c6e2379b/assets/easyr1_grpo.png -------------------------------------------------------------------------------- /assets/qwen2_5_vl_7b_geo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiyouga/EasyR1/c835ae63c2f302b1491529527fc39127c6e2379b/assets/qwen2_5_vl_7b_geo.png -------------------------------------------------------------------------------- /assets/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiyouga/EasyR1/c835ae63c2f302b1491529527fc39127c6e2379b/assets/wechat.jpg -------------------------------------------------------------------------------- /examples/baselines/qwen2_5_vl_3b_clevr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=BUAADreamer/clevr_count_70k@train \ 12 | data.val_files=BUAADreamer/clevr_count_70k@test \ 13 | data.format_prompt=./examples/format_prompt/r1v_format.jinja \ 14 | worker.actor.model.model_path=${MODEL_PATH} \ 15 | worker.rollout.tensor_parallel_size=1 \ 16 | worker.reward.reward_type=sequential \ 17 | worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \ 18 | trainer.experiment_name=qwen2_5_vl_3b_clevr \ 19 | trainer.n_gpus_per_node=2 20 | -------------------------------------------------------------------------------- /examples/baselines/qwen2_5_vl_3b_geoqa8k.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=leonardPKU/GEOQA_8K_R1V@train \ 12 | data.val_files=leonardPKU/GEOQA_8K_R1V@test \ 13 | data.format_prompt=./examples/format_prompt/r1v_format.jinja \ 14 | worker.actor.model.model_path=${MODEL_PATH} \ 15 | worker.rollout.tensor_parallel_size=1 \ 16 | worker.reward.reward_type=sequential \ 17 | worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \ 18 | trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \ 19 | trainer.n_gpus_per_node=8 20 | -------------------------------------------------------------------------------- /examples/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_files: hiyouga/math12k@train 3 | val_files: hiyouga/math12k@test 4 | prompt_key: problem 5 | answer_key: answer 6 | image_key: images 7 | max_prompt_length: 2048 8 | max_response_length: 2048 9 | rollout_batch_size: 512 10 | val_batch_size: 1024 11 | format_prompt: ./examples/format_prompt/math_format.jinja 12 | override_chat_template: null 13 | shuffle: true 14 | seed: 1 15 | max_pixels: 4194304 16 | min_pixels: 262144 17 | filter_overlong_prompts: true 18 | 19 | algorithm: 20 | adv_estimator: grpo 21 | disable_kl: false 22 | use_kl_loss: true 23 | kl_penalty: low_var_kl 24 | kl_coef: 1.0e-2 25 | 26 | worker: 27 | actor: 28 | global_batch_size: 128 29 | micro_batch_size_per_device_for_update: 4 30 | micro_batch_size_per_device_for_experience: 16 31 | max_grad_norm: 1.0 32 | padding_free: true 33 | ulysses_sequence_parallel_size: 1 34 | model: 35 | model_path: Qwen/Qwen2.5-7B-Instruct 36 | enable_gradient_checkpointing: true 37 | trust_remote_code: false 38 | freeze_vision_tower: false 39 | optim: 40 | lr: 1.0e-6 41 | weight_decay: 1.0e-2 42 | strategy: adamw # {adamw, adamw_bf16} 43 | lr_warmup_ratio: 0.0 44 | fsdp: 45 | enable_full_shard: true 46 | enable_cpu_offload: false 47 | enable_rank0_init: true 48 | offload: 49 | offload_params: true # true: more CPU memory; false: more GPU memory 50 | offload_optimizer: true # true: more CPU memory; false: more GPU memory 51 | 52 | rollout: 53 | n: 5 54 | temperature: 1.0 55 | top_p: 0.99 56 | gpu_memory_utilization: 0.6 57 | enforce_eager: false 58 | enable_chunked_prefill: false 59 | tensor_parallel_size: 2 60 | limit_images: 0 61 | val_override_config: 62 | temperature: 0.5 63 | n: 1 64 | 65 | ref: 66 | fsdp: 67 | enable_full_shard: true 68 | enable_cpu_offload: true # true: more CPU memory; false: more GPU memory 69 | enable_rank0_init: true 70 | offload: 71 | offload_params: false 72 | 73 | reward: 74 | reward_type: batch 75 | reward_function: ./examples/reward_function/math.py:compute_score 76 | 77 | trainer: 78 | total_epochs: 15 79 | max_steps: null 80 | project_name: easy_r1 81 | experiment_name: qwen2_5_7b_math_grpo 82 | logger: ["console", "wandb"] 83 | nnodes: 1 84 | n_gpus_per_node: 8 85 | val_freq: 5 # -1 to disable 86 | val_before_train: true 87 | val_only: false 88 | val_generations_to_log: 3 89 | save_freq: 5 # -1 to disable 90 | save_limit: 3 # -1 to disable 91 | save_checkpoint_path: null 92 | load_checkpoint_path: null 93 | -------------------------------------------------------------------------------- /examples/format_prompt/math_format.jinja: -------------------------------------------------------------------------------- 1 | 
{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}. 2 | -------------------------------------------------------------------------------- /examples/format_prompt/r1v_format.jinja: -------------------------------------------------------------------------------- 1 | {{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer> 2 | -------------------------------------------------------------------------------- /examples/qwen2_5_7b_math_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-7B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | worker.actor.model.model_path=${MODEL_PATH} 12 | -------------------------------------------------------------------------------- /examples/qwen2_5_vl_32b_geo3k_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-32B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=hiyouga/geometry3k@train \ 12 | data.val_files=hiyouga/geometry3k@test \ 13 | worker.actor.model.model_path=${MODEL_PATH} \ 14 | worker.actor.micro_batch_size_per_device_for_update=1 \ 15 | worker.actor.micro_batch_size_per_device_for_experience=8 \ 16 | worker.actor.fsdp.torch_dtype=bf16 \ 17 | worker.actor.optim.strategy=adamw_bf16 \ 18 | worker.rollout.tensor_parallel_size=8 \ 19 | trainer.experiment_name=qwen2_5_vl_32b_geo_grpo \ 20 | trainer.n_gpus_per_node=8 21 | -------------------------------------------------------------------------------- /examples/qwen2_5_vl_3b_geo3k_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=hiyouga/geometry3k@train \ 12 | data.val_files=hiyouga/geometry3k@test \ 13 | worker.actor.model.model_path=${MODEL_PATH} \ 14 | worker.rollout.tensor_parallel_size=1 \ 15 | trainer.experiment_name=qwen2_5_vl_3b_geo_grpo \ 16 | trainer.n_gpus_per_node=2 17 | -------------------------------------------------------------------------------- /examples/qwen2_5_vl_7b_geo3k_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=hiyouga/geometry3k@train \ 12 | data.val_files=hiyouga/geometry3k@test \ 13 | worker.actor.model.model_path=${MODEL_PATH} \ 14 | trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \ 15 | trainer.n_gpus_per_node=8 16 |
-------------------------------------------------------------------------------- /examples/qwen2_5_vl_7b_geo3k_reinforce.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=hiyouga/geometry3k@train \ 12 | data.val_files=hiyouga/geometry3k@test \ 13 | worker.actor.model.model_path=${MODEL_PATH} \ 14 | algorithm.adv_estimator=reinforce_plus_plus \ 15 | algorithm.use_kl_loss=false \ 16 | algorithm.kl_penalty=kl \ 17 | algorithm.kl_coef=1.0e-3 \ 18 | trainer.experiment_name=qwen2_5_vl_7b_geo_reinforce_pp \ 19 | trainer.n_gpus_per_node=8 20 | -------------------------------------------------------------------------------- /examples/qwen2_5_vl_7b_geo3k_swanlab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.train_files=hiyouga/geometry3k@train \ 12 | data.val_files=hiyouga/geometry3k@test \ 13 | worker.actor.model.model_path=${MODEL_PATH} \ 14 | trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \ 15 | trainer.logger=['console','swanlab'] \ 16 | trainer.n_gpus_per_node=8 17 | -------------------------------------------------------------------------------- /examples/qwen2_5_vl_7b_multi_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # REMINDER: this script uses test data split and should ONLY be used for debugging. DO NOT use for training. 3 | 4 | set -x 5 | 6 | export PYTHONUNBUFFERED=1 7 | 8 | MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path 9 | 10 | python3 -m verl.trainer.main \ 11 | config=examples/config.yaml \ 12 | data.train_files=hiyouga/journeybench-multi-image-vqa@train \ 13 | data.val_files=hiyouga/journeybench-multi-image-vqa@test \ 14 | data.rollout_batch_size=256 \ 15 | worker.actor.model.model_path=${MODEL_PATH} \ 16 | worker.rollout.limit_images=2 \ 17 | trainer.experiment_name=qwen2_5_vl_7b_multi_image \ 18 | trainer.n_gpus_per_node=8 19 | -------------------------------------------------------------------------------- /examples/qwen3_4b_math_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | export PYTHONUNBUFFERED=1 6 | 7 | MODEL_PATH=Qwen/Qwen3-4B # replace it with your local file path 8 | 9 | python3 -m verl.trainer.main \ 10 | config=examples/config.yaml \ 11 | data.max_response_length=4096 \ 12 | worker.actor.model.model_path=${MODEL_PATH} \ 13 | trainer.experiment_name=qwen3_4b_math_grpo 14 | -------------------------------------------------------------------------------- /examples/reward_function/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from typing import Dict, List 17 | 18 | from mathruler.grader import extract_boxed_content, grade_answer 19 | 20 | 21 | def format_reward(predict: str) -> float: 22 | pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL) 23 | format_match = re.fullmatch(pattern, predict) 24 | return 1.0 if format_match else 0.0 25 | 26 | 27 | def accuracy_reward(predict: str, ground_truth: str) -> float: 28 | answer = extract_boxed_content(predict) 29 | return 1.0 if grade_answer(answer, ground_truth) else 0.0 30 | 31 | 32 | def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]: 33 | scores = [] 34 | for predict, ground_truth in zip(predicts, ground_truths): 35 | predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict) # handle qwen2.5vl-32b format 36 | format_score = format_reward(predict) 37 | accuracy_score = accuracy_reward(predict, ground_truth) 38 | scores.append( 39 | { 40 | "overall": (1 - format_weight) * accuracy_score + format_weight * format_score, 41 | "format": format_score, 42 | "accuracy": accuracy_score, 43 | } 44 | ) 45 | 46 | return scores 47 | -------------------------------------------------------------------------------- /examples/reward_function/r1v.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import re 16 | from typing import Dict 17 | 18 | from mathruler.grader import grade_answer 19 | 20 | 21 | def format_reward(predict: str) -> float: 22 | pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL) 23 | format_match = re.fullmatch(pattern, predict) 24 | return 1.0 if format_match else 0.0 25 | 26 | 27 | def accuracy_reward(predict: str, ground_truth: str) -> float: 28 | try: 29 | content_match = re.search(r"<answer>(.*?)</answer>", predict) 30 | given_answer = content_match.group(1).strip() if content_match else predict.strip() 31 | if grade_answer(given_answer, ground_truth.strip()): 32 | return 1.0 33 | 34 | except Exception: 35 | pass 36 | 37 | return 0.0 38 | 39 | 40 | def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]: 41 | format_score = format_reward(predict) 42 | accuracy_score = accuracy_reward(predict, ground_truth) 43 | return { 44 | "overall": (1 - format_weight) * accuracy_score + format_weight * format_score, 45 | "format": format_score, 46 | "accuracy": accuracy_score, 47 | } 48 | -------------------------------------------------------------------------------- /examples/runtime_env.yaml: -------------------------------------------------------------------------------- 1 | working_dir: ./ 2 | excludes: ["/.git/"] 3 | env_vars: 4 | TOKENIZERS_PARALLELISM: "true" 5 | NCCL_DEBUG: "WARN" 6 | VLLM_LOGGING_LEVEL: "WARN" 7 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1" 8 | PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False" 9 | PYTHONUNBUFFERED: "1" 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "verl" 7 | dynamic = [ 8 | "version", 9 | "dependencies", 10 | "optional-dependencies", 11 | "requires-python", 12 | "authors", 13 | "description", 14 | "readme", 15 | "license" 16 | ] 17 | 18 | [tool.ruff] 19 | target-version = "py39" 20 | line-length = 119 21 | indent-width = 4 22 | 23 | [tool.ruff.lint] 24 | ignore = ["C901", "E501", "E741", "W605", "C408"] 25 | select = ["C", "E", "F", "I", "W", "RUF022"] 26 | 27 | [tool.ruff.lint.per-file-ignores] 28 | "__init__.py" = ["E402", "F401", "F403", "F811"] 29 | 30 | [tool.ruff.lint.isort] 31 | lines-after-imports = 2 32 | known-first-party = ["verl"] 33 | known-third-party = ["torch", "transformers", "wandb"] 34 | 35 | [tool.ruff.format] 36 | quote-style = "double" 37 | indent-style = "space" 38 | skip-magic-trailing-comma = false 39 | line-ending = "auto" 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | codetiming 3 | datasets 4 | flash-attn>=2.4.3 5 | liger-kernel 6 | mathruler 7 | numpy 8 | omegaconf 9 | pandas 10 | peft 11 | pillow 12 | pyarrow>=15.0.0 13 | pylatexenc 14 | qwen-vl-utils 15 | ray[default] 16 | tensordict 17 | torchdata 18 | transformers>=4.51.0 19 | vllm>=0.7.3 20 | wandb 21 | -------------------------------------------------------------------------------- /scripts/model_merger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | import re 18 | from concurrent.futures import ThreadPoolExecutor 19 | from typing import Dict, List, Tuple 20 | 21 | import numpy as np 22 | import torch 23 | from torch.distributed._tensor import DTensor, Placement, Shard 24 | from transformers import ( 25 | AutoConfig, 26 | AutoModelForCausalLM, 27 | AutoModelForTokenClassification, 28 | AutoModelForVision2Seq, 29 | PretrainedConfig, 30 | PreTrainedModel, 31 | ) 32 | 33 | 34 | def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): 35 | if placement.is_replicate(): 36 | return tensors[0] 37 | elif placement.is_partial(): 38 | raise NotImplementedError("Partial placement is not supported yet") 39 | elif placement.is_shard(): 40 | return torch.cat(tensors, dim=placement.dim).contiguous() 41 | else: 42 | raise ValueError(f"Unsupported placement: {placement}") 43 | 44 | 45 | def upload_model_to_huggingface(local_path: str, remote_path: str): 46 | # Push to hugging face 47 | from huggingface_hub import HfApi 48 | 49 | api = HfApi() 50 | api.create_repo(repo_id=remote_path, private=False, exist_ok=True) 51 | api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model") 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model") 57 | parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") 58 | args = parser.parse_args() 59 | local_dir: str = args.local_dir 60 | 61 | assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface." 62 | 63 | # copy rank zero to find the shape of (dp, fsdp) 64 | rank = 0 65 | world_size = 0 66 | for filename in os.listdir(local_dir): 67 | match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) 68 | if match: 69 | world_size = match.group(1) 70 | break 71 | 72 | assert world_size, "No model file with the proper format." 73 | 74 | rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") 75 | state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False) 76 | pivot_key = sorted(state_dict.keys())[0] 77 | weight = state_dict[pivot_key] 78 | if isinstance(weight, DTensor): 79 | # get sharding info 80 | device_mesh = weight.device_mesh 81 | mesh = device_mesh.mesh 82 | mesh_dim_names = device_mesh.mesh_dim_names 83 | else: 84 | # for non-DTensor 85 | mesh = np.array([int(world_size)], dtype=np.int64) 86 | mesh_dim_names = ("fsdp",) 87 | 88 | print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") 89 | 90 | assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}." 
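# NOTE: the assertion above only admits ("fsdp",) and ("ddp", "fsdp") meshes, so the "tp" branch just below never triggers for these checkpoints; merging FSDP + TP shards is explicitly unsupported (see the NotImplementedError further down).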
91 | 92 | if "tp" in mesh_dim_names: 93 | # fsdp * tp 94 | total_shards = mesh.shape[-1] * mesh.shape[-2] 95 | mesh_shape = (mesh.shape[-2], mesh.shape[-1]) 96 | else: 97 | # fsdp 98 | total_shards = mesh.shape[-1] 99 | mesh_shape = (mesh.shape[-1],) 100 | 101 | print(f"Processing {total_shards} model shards in total.") 102 | model_state_dict_lst = [] 103 | model_state_dict_lst.append(state_dict) 104 | model_state_dict_lst.extend([""] * (total_shards - 1)) 105 | 106 | def process_one_shard(rank, model_state_dict_lst): 107 | model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") 108 | state_dict = torch.load(model_path, map_location="cpu", weights_only=False) 109 | model_state_dict_lst[rank] = state_dict 110 | return state_dict 111 | 112 | with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: 113 | for rank in range(1, total_shards): 114 | executor.submit(process_one_shard, rank, model_state_dict_lst) 115 | 116 | state_dict: Dict[str, List[torch.Tensor]] = {} 117 | param_placements: Dict[str, List[Placement]] = {} 118 | keys = set(model_state_dict_lst[0].keys()) 119 | for key in keys: 120 | state_dict[key] = [] 121 | for model_state_dict in model_state_dict_lst: 122 | try: 123 | tensor = model_state_dict.pop(key) 124 | except Exception: 125 | print(f"Cannot find key {key} in rank {rank}.") 126 | 127 | if isinstance(tensor, DTensor): 128 | state_dict[key].append(tensor._local_tensor.bfloat16()) 129 | placements = tuple(tensor.placements) 130 | # replicated placement at ddp dimension can be discarded 131 | if mesh_dim_names[0] == "ddp": 132 | placements = placements[1:] 133 | 134 | if key not in param_placements: 135 | param_placements[key] = placements 136 | else: 137 | assert param_placements[key] == placements 138 | else: 139 | state_dict[key].append(tensor.bfloat16()) 140 | 141 | del model_state_dict_lst 142 | 143 | for key in sorted(state_dict): 144 | if not isinstance(state_dict[key], list): 145 | print(f"No need to merge key {key}") 146 | continue 147 | 148 | if key in param_placements: 149 | # merge shards 150 | placements: Tuple[Shard] = param_placements[key] 151 | if len(mesh_shape) == 1: 152 | # 1-D list, FSDP without TP 153 | assert len(placements) == 1 154 | shards = state_dict[key] 155 | state_dict[key] = merge_by_placement(shards, placements[0]) 156 | else: 157 | # 2-D list, FSDP + TP 158 | raise NotImplementedError("FSDP + TP is not supported yet.") 159 | else: 160 | state_dict[key] = torch.cat(state_dict[key], dim=0) 161 | 162 | print("Merge completed.") 163 | hf_path = os.path.join(local_dir, "huggingface") 164 | config: PretrainedConfig = AutoConfig.from_pretrained(hf_path) 165 | architectures: List[str] = getattr(config, "architectures", ["Unknown"]) 166 | 167 | if "ForTokenClassification" in architectures[0]: 168 | AutoClass = AutoModelForTokenClassification 169 | elif "ForCausalLM" in architectures[0]: 170 | AutoClass = AutoModelForCausalLM 171 | elif "ForConditionalGeneration" in architectures[0]: 172 | AutoClass = AutoModelForVision2Seq 173 | else: 174 | raise NotImplementedError(f"Unknown architecture {architectures}.") 175 | 176 | with torch.device("meta"): 177 | model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16) 178 | 179 | assert isinstance(model, PreTrainedModel) 180 | model.to_empty(device="cpu") 181 | 182 | print(f"Saving model to {hf_path}...") 183 | model.save_pretrained(hf_path, state_dict=state_dict) 184 | del state_dict, model 185 | 186 | if args.hf_upload_path: 187 | 
upload_model_to_huggingface(hf_path, args.hf_upload_path) 188 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import re 17 | 18 | from setuptools import find_packages, setup 19 | 20 | 21 | def get_version() -> str: 22 | with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f: 23 | file_content = f.read() 24 | pattern = r"__version__\W*=\W*\"([^\"]+)\"" 25 | (version,) = re.findall(pattern, file_content) 26 | return version 27 | 28 | 29 | def get_requires() -> list[str]: 30 | with open("requirements.txt", encoding="utf-8") as f: 31 | file_content = f.read() 32 | lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] 33 | return lines 34 | 35 | 36 | extra_require = { 37 | "dev": ["pre-commit", "ruff"], 38 | } 39 | 40 | 41 | def main(): 42 | setup( 43 | name="verl", 44 | version=get_version(), 45 | description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL", 46 | long_description=open("README.md", encoding="utf-8").read(), 47 | long_description_content_type="text/markdown", 48 | author="verl", 49 | author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn", 50 | license="Apache 2.0 License", 51 | url="https://github.com/volcengine/verl", 52 | package_dir={"": "."}, 53 | packages=find_packages(where="."), 54 | python_requires=">=3.9.0", 55 | install_requires=get_requires(), 56 | extras_require=extra_require, 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /verl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | from .utils.py_functional import is_package_available 18 | 19 | 20 | if is_package_available("modelscope"): 21 | from modelscope.utils.hf_util import patch_hub # type: ignore 22 | 23 | 24 | __version__ = "0.3.1.dev0" 25 | 26 | 27 | if os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "y", "1"]: 28 | # Patch hub to download models from modelscope to speed up. 
29 | if not is_package_available("modelscope"): 30 | raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope`.") 31 | 32 | patch_hub() 33 | -------------------------------------------------------------------------------- /verl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS 17 | 18 | from .transformers.flash_attention_utils import flash_attention_forward 19 | from .transformers.qwen2_vl import qwen2_vl_attn_forward 20 | 21 | 22 | def apply_ulysses_patch(model_type: str) -> None: 23 | if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"): 24 | ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward 25 | elif model_type in ("qwen2_vl", "qwen2_5_vl"): 26 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 27 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 28 | 29 | Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward 30 | Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward 31 | else: 32 | raise NotImplementedError(f"Model architecture {model_type} is not supported yet.") 33 | -------------------------------------------------------------------------------- /verl/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/transformers/flash_attention_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team 2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 3 | # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import inspect 18 | import os 19 | from typing import Optional, Tuple 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check 24 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 25 | 26 | from ...utils.ulysses import ( 27 | gather_heads_scatter_seq, 28 | gather_seq_scatter_heads, 29 | get_ulysses_sequence_parallel_group, 30 | get_ulysses_sequence_parallel_world_size, 31 | ) 32 | 33 | 34 | if is_flash_attn_2_available(): 35 | from flash_attn import flash_attn_func, flash_attn_varlen_func 36 | 37 | _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters 38 | _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters 39 | _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" 40 | _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 41 | 42 | 43 | def prepare_fa2_from_position_ids( 44 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor 45 | ): 46 | query = query.view(-1, query.size(-2), query.size(-1)) 47 | key = key.contiguous().view(-1, key.size(-2), key.size(-1)) 48 | value = value.contiguous().view(-1, value.size(-2), value.size(-1)) 49 | position_ids = position_ids.flatten() 50 | indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) 51 | cu_seqlens = torch.cat( 52 | ( 53 | indices_q[position_ids == 0], 54 | torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), 55 | ) 56 | ) 57 | max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope 58 | return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) 59 | 60 | 61 | def _custom_flash_attention_forward( 62 | query_states: torch.Tensor, 63 | key_states: torch.Tensor, 64 | value_states: torch.Tensor, 65 | attention_mask: Optional[torch.Tensor], 66 | query_length: int, 67 | is_causal: bool = True, 68 | position_ids: Optional[torch.Tensor] = None, 69 | sliding_window: Optional[int] = None, 70 | use_top_left_mask: bool = False, 71 | deterministic: Optional[bool] = None, 72 | **kwargs, 73 | ): 74 | """ 75 | Patches flash 
attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) 76 | """ 77 | if not use_top_left_mask: 78 | causal = is_causal 79 | else: 80 | causal = is_causal and query_length != 1 81 | 82 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 83 | use_sliding_windows = ( 84 | _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window 85 | ) 86 | flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} 87 | 88 | if _flash_supports_deterministic: 89 | flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled 90 | 91 | if kwargs.get("softcap") is not None: 92 | flash_kwargs["softcap"] = kwargs.pop("softcap") 93 | 94 | query_states, key_states, value_states = fa_peft_integration_check( 95 | query_states, key_states, value_states, target_dtype=torch.bfloat16 96 | ) 97 | 98 | sp_size = get_ulysses_sequence_parallel_world_size() 99 | if sp_size > 1: 100 | # (batch_size, seq_length, num_head, head_size) 101 | query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) 102 | key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) 103 | value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) 104 | position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)] 105 | position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group()) 106 | position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length) 107 | 108 | if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope 109 | position_ids = position_ids[0] 110 | 111 | if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): 112 | batch_size = query_states.size(0) 113 | query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( 114 | query_states, key_states, value_states, position_ids 115 | ) 116 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 117 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 118 | attn_output = flash_attn_varlen_func( 119 | query_states, 120 | key_states, 121 | value_states, 122 | cu_seqlens_q=cu_seqlens_q, 123 | cu_seqlens_k=cu_seqlens_k, 124 | max_seqlen_q=max_seqlen_in_batch_q, 125 | max_seqlen_k=max_seqlen_in_batch_k, 126 | dropout_p=kwargs.pop("dropout", 0.0), 127 | softmax_scale=kwargs.pop("softmax_scale", None), 128 | causal=causal, 129 | **flash_kwargs, 130 | ) 131 | attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) 132 | else: 133 | attn_output = _flash_attention_forward( 134 | query_states, 135 | key_states, 136 | value_states, 137 | attention_mask, 138 | query_length, 139 | is_causal=is_causal, 140 | sliding_window=sliding_window, 141 | use_top_left_mask=use_top_left_mask, 142 | deterministic=deterministic, 143 | **kwargs, 144 | ) # do not pass position_ids to old flash_attention_forward 145 | 146 | if sp_size > 1: 147 | # (batch_size, seq_length, num_head, head_size) 148 | attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) 149 | 150 | return attn_output 151 | 152 | 153 | def flash_attention_forward( 154 | module: torch.nn.Module, 155 | query: torch.Tensor, 156 | key: torch.Tensor, 157 | value: torch.Tensor, 158 | attention_mask: Optional[torch.Tensor], 159 | dropout: float = 0.0, 160 | scaling: Optional[float] = None, 161 | 
sliding_window: Optional[int] = None, 162 | softcap: Optional[float] = None, 163 | **kwargs, 164 | ) -> Tuple[torch.Tensor, None]: 165 | # This is before the transpose 166 | q_len = query.shape[2] 167 | 168 | # FA2 uses non-transposed inputs 169 | query = query.transpose(1, 2) 170 | key = key.transpose(1, 2) 171 | value = value.transpose(1, 2) 172 | 173 | # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice 174 | kwargs.pop("is_causal", None) 175 | 176 | attn_output = _custom_flash_attention_forward( 177 | query, 178 | key, 179 | value, 180 | attention_mask, 181 | query_length=q_len, 182 | is_causal=True, 183 | dropout=dropout, 184 | softmax_scale=scaling, 185 | sliding_window=sliding_window, 186 | softcap=softcap, 187 | use_top_left_mask=_flash_use_top_left_mask, 188 | **kwargs, 189 | ) 190 | 191 | return attn_output, None 192 | -------------------------------------------------------------------------------- /verl/models/transformers/qwen2_vl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team 2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 3 | # Based on: 4 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from typing import Optional, Tuple 19 | 20 | import torch 21 | 22 | from .flash_attention_utils import flash_attention_forward 23 | 24 | 25 | try: 26 | from transformers.models.qwen2_vl.modeling_qwen2_vl import ( 27 | Qwen2VLAttention, 28 | apply_multimodal_rotary_pos_emb, 29 | repeat_kv, 30 | ) 31 | from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor 32 | except ImportError: 33 | pass 34 | 35 | 36 | def get_rope_index( 37 | processor: "Qwen2VLProcessor", 38 | input_ids: torch.Tensor, 39 | image_grid_thw: Optional[torch.Tensor] = None, 40 | video_grid_thw: Optional[torch.Tensor] = None, 41 | second_per_grid_ts: Optional[torch.Tensor] = None, 42 | attention_mask: Optional[torch.Tensor] = None, 43 | ) -> torch.Tensor: 44 | """ 45 | Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. 46 | The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. 
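    For a text-only sequence the three returned rows are identical (ordinary cumulative positions over
    valid tokens); for image/video inputs the temporal/height/width rows diverge over the vision-token spans.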
47 | https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546 48 | """ 49 | spatial_merge_size = processor.image_processor.merge_size 50 | tokens_per_second = 2 51 | image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") 52 | video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") 53 | vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") 54 | if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): 55 | if attention_mask is None: 56 | attention_mask = torch.ones_like(input_ids) 57 | 58 | position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) 59 | image_index, video_index = 0, 0 60 | input_ids = input_ids[attention_mask == 1] 61 | image_nums, video_nums = 0, 0 62 | vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) 63 | vision_tokens = input_ids[vision_start_indices + 1] 64 | image_nums = (vision_tokens == image_token_id).sum() 65 | video_nums = (vision_tokens == video_token_id).sum() 66 | input_tokens = input_ids.tolist() 67 | llm_pos_ids_list: list = [] 68 | st = 0 69 | remain_images, remain_videos = image_nums, video_nums 70 | for _ in range(image_nums + video_nums): 71 | if image_token_id in input_tokens and remain_images > 0: 72 | ed_image = input_tokens.index(image_token_id, st) 73 | else: 74 | ed_image = len(input_tokens) + 1 75 | if video_token_id in input_tokens and remain_videos > 0: 76 | ed_video = input_tokens.index(video_token_id, st) 77 | else: 78 | ed_video = len(input_tokens) + 1 79 | if ed_image < ed_video: 80 | t, h, w = ( 81 | image_grid_thw[image_index][0], 82 | image_grid_thw[image_index][1], 83 | image_grid_thw[image_index][2], 84 | ) 85 | second_per_grid_t = 0 86 | image_index += 1 87 | remain_images -= 1 88 | ed = ed_image 89 | else: 90 | t, h, w = ( 91 | video_grid_thw[video_index][0], 92 | video_grid_thw[video_index][1], 93 | video_grid_thw[video_index][2], 94 | ) 95 | if second_per_grid_ts is not None: 96 | second_per_grid_t = second_per_grid_ts[video_index] 97 | else: 98 | second_per_grid_t = 1.0 99 | 100 | video_index += 1 101 | remain_videos -= 1 102 | ed = ed_video 103 | 104 | llm_grid_t, llm_grid_h, llm_grid_w = ( 105 | t.item(), 106 | h.item() // spatial_merge_size, 107 | w.item() // spatial_merge_size, 108 | ) 109 | text_len = ed - st 110 | 111 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 112 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) 113 | 114 | t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) 115 | t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() 116 | h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() 117 | w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() 118 | llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) 119 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 120 | 121 | if st < len(input_tokens): 122 | st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 123 | text_len = len(input_tokens) - st 124 | llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) 125 | 126 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 127 | position_ids[..., attention_mask == 1] = 
llm_positions.to(position_ids.device) 128 | else: 129 | if attention_mask is not None: 130 | position_ids = attention_mask.long().cumsum(-1) - 1 131 | position_ids.masked_fill_(attention_mask == 0, 1) 132 | position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) 133 | else: 134 | position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) 135 | 136 | return position_ids 137 | 138 | 139 | def qwen2_vl_attn_forward( 140 | self: "Qwen2VLAttention", 141 | hidden_states: torch.Tensor, 142 | attention_mask: Optional[torch.Tensor] = None, 143 | position_ids: Optional[torch.LongTensor] = None, 144 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 145 | **kwargs, 146 | ) -> Tuple[torch.Tensor, None, None]: 147 | bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size 148 | query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) 149 | key_states = self.k_proj(hidden_states) 150 | value_states = self.v_proj(hidden_states) 151 | 152 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 153 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 154 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 155 | 156 | # Because the input can be padded, the absolute sequence length depends on the max position id. 157 | if position_embeddings is None: 158 | cos, sin = self.rotary_emb(value_states, position_ids) 159 | else: 160 | cos, sin = position_embeddings 161 | 162 | query_states, key_states = apply_multimodal_rotary_pos_emb( 163 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] 164 | ) 165 | key_states = repeat_kv(key_states, self.num_key_value_groups) 166 | value_states = repeat_kv(value_states, self.num_key_value_groups) 167 | dropout_rate = 0.0 if not self.training else self.attention_dropout 168 | 169 | sliding_window = None 170 | if ( 171 | self.config.use_sliding_window 172 | and getattr(self.config, "sliding_window", None) is not None 173 | and self.layer_idx >= self.config.max_window_layers 174 | ): 175 | sliding_window = self.config.sliding_window 176 | 177 | attn_output, _ = flash_attention_forward( 178 | self, 179 | query_states, 180 | key_states, 181 | value_states, 182 | attention_mask, 183 | dropout=dropout_rate, 184 | sliding_window=sliding_window, 185 | position_ids=position_ids, # important: pass position ids 186 | ) # (batch_size, seq_length, num_head / sp_size, head_size) 187 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 188 | attn_output = self.o_proj(attn_output) 189 | return attn_output, None, None 190 | -------------------------------------------------------------------------------- /verl/single_controller/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/single_controller/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .worker import Worker 16 | from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup 17 | 18 | 19 | __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"] 20 | -------------------------------------------------------------------------------- /verl/single_controller/base/decorator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
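# Usage sketch (hedged; ``MyWorker`` and ``compute_something`` are illustrative names, not part of this module):
#
#     class MyWorker(Worker):
#         @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
#         def compute_something(self, data: DataProto) -> DataProto:
#             ...
#
# The WorkerGroup reads the attributes attached by register() to pick a dispatch_fn (how inputs are
# split across workers) and a collect_fn (how per-worker outputs are merged back together).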
14 | 15 | from enum import Enum, auto 16 | from functools import wraps 17 | from types import FunctionType 18 | from typing import TYPE_CHECKING, Dict, List, Literal, Union 19 | 20 | import ray 21 | 22 | from ...protocol import DataProto, DataProtoFuture 23 | 24 | 25 | if TYPE_CHECKING: 26 | from .worker_group import WorkerGroup 27 | 28 | 29 | # here we add a magic number of avoid user-defined function already have this attribute 30 | MAGIC_ATTR = "attrs_3141562937" 31 | 32 | 33 | class Dispatch(Enum): 34 | RANK_ZERO = auto() 35 | ONE_TO_ALL = auto() 36 | ALL_TO_ALL = auto() 37 | DP_COMPUTE = auto() 38 | DP_COMPUTE_PROTO = auto() 39 | DP_COMPUTE_PROTO_WITH_FUNC = auto() 40 | DP_COMPUTE_METRIC = auto() 41 | 42 | 43 | class Execute(Enum): 44 | ALL = 0 45 | RANK_ZERO = 1 46 | 47 | 48 | def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs): 49 | splitted_args = [] 50 | for arg in args: 51 | assert isinstance(arg, (DataProto, DataProtoFuture)) 52 | splitted_args.append(arg.chunk(chunks=chunks)) 53 | 54 | splitted_kwargs = {} 55 | for key, value in kwargs.items(): 56 | assert isinstance(value, (DataProto, DataProtoFuture)) 57 | splitted_kwargs[key] = value.chunk(chunks=chunks) 58 | 59 | return splitted_args, splitted_kwargs 60 | 61 | 62 | def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs): 63 | args = tuple([arg] * worker_group.world_size for arg in args) 64 | kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} 65 | return args, kwargs 66 | 67 | 68 | def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs): 69 | return args, kwargs 70 | 71 | 72 | def collect_all_to_all(worker_group: "WorkerGroup", output): 73 | return output 74 | 75 | 76 | def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto: 77 | # make sure all the elements in output has the same type 78 | for output in outputs: 79 | assert type(output) is type(outputs[0]) 80 | 81 | output = outputs[0] 82 | 83 | if isinstance(output, DataProto): 84 | return DataProto.concat(outputs) 85 | elif isinstance(output, ray.ObjectRef): 86 | return DataProtoFuture.concat(outputs) 87 | else: 88 | raise NotImplementedError 89 | 90 | 91 | def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs): 92 | for arg in args: 93 | assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size 94 | 95 | for value in kwargs.values(): 96 | assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size 97 | 98 | return args, kwargs 99 | 100 | 101 | def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]: 102 | assert len(outputs) == worker_group.world_size 103 | return outputs 104 | 105 | 106 | def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs): 107 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) 108 | return splitted_args, splitted_kwargs 109 | 110 | 111 | def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs): 112 | assert type(args[0]) is FunctionType # NOTE: The first one args is a function! 
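    # The leading function argument is replicated to every worker unchanged; only the remaining
    # DataProto / DataProtoFuture arguments are chunked into world_size pieces below.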
113 | splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) 114 | splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args 115 | return splitted_args_with_func, splitted_kwargs 116 | 117 | 118 | def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto: 119 | for output in outputs: 120 | assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}" 121 | 122 | outputs = collect_dp_compute(worker_group, outputs) 123 | return _concat_data_proto_or_future(outputs) 124 | 125 | 126 | def get_predefined_dispatch_fn(dispatch_mode: Dispatch): 127 | predefined_dispatch_mode_fn = { 128 | Dispatch.ONE_TO_ALL: { 129 | "dispatch_fn": dispatch_one_to_all, 130 | "collect_fn": collect_all_to_all, 131 | }, 132 | Dispatch.ALL_TO_ALL: { 133 | "dispatch_fn": dispatch_all_to_all, 134 | "collect_fn": collect_all_to_all, 135 | }, 136 | Dispatch.DP_COMPUTE: { 137 | "dispatch_fn": dispatch_dp_compute, 138 | "collect_fn": collect_dp_compute, 139 | }, 140 | Dispatch.DP_COMPUTE_PROTO: { 141 | "dispatch_fn": dispatch_dp_compute_data_proto, 142 | "collect_fn": collect_dp_compute_data_proto, 143 | }, 144 | Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { 145 | "dispatch_fn": dispatch_dp_compute_data_proto_with_func, 146 | "collect_fn": collect_dp_compute_data_proto, 147 | }, 148 | Dispatch.DP_COMPUTE_METRIC: { 149 | "dispatch_fn": dispatch_dp_compute_data_proto, 150 | "collect_fn": collect_dp_compute, 151 | }, 152 | } 153 | return predefined_dispatch_mode_fn[dispatch_mode] 154 | 155 | 156 | def get_predefined_execute_fn(execute_mode: Execute): 157 | """ 158 | Note that here we only asks execute_all and execute_rank_zero to be implemented 159 | Leave the choice of how these two functions handle argument 'blocking' to users 160 | """ 161 | predefined_execute_mode_fn = { 162 | Execute.ALL: {"execute_fn_name": "execute_all"}, 163 | Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, 164 | } 165 | return predefined_execute_mode_fn[execute_mode] 166 | 167 | 168 | def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]): 169 | assert isinstance(dispatch_mode, (Dispatch, dict)), ( 170 | f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" 171 | ) 172 | if isinstance(dispatch_mode, dict): 173 | necessary_keys = ["dispatch_fn", "collect_fn"] 174 | for key in necessary_keys: 175 | assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" 176 | 177 | 178 | def _check_execute_mode(execute_mode: Execute): 179 | assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. 
Got {execute_mode}" 180 | 181 | 182 | def _materialize_futures(*args, **kwargs): 183 | new_args = [] 184 | for arg in args: 185 | if isinstance(arg, DataProtoFuture): 186 | arg = arg.get() 187 | # add more type to materialize 188 | new_args.append(arg) 189 | 190 | for key, value in kwargs.items(): 191 | if isinstance(value, DataProtoFuture): 192 | kwargs[key] = value.get() 193 | 194 | new_args = tuple(new_args) 195 | return new_args, kwargs 196 | 197 | 198 | def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): 199 | _check_dispatch_mode(dispatch_mode=dispatch_mode) 200 | _check_execute_mode(execute_mode=execute_mode) 201 | 202 | def decorator(func): 203 | @wraps(func) 204 | def inner(*args, **kwargs): 205 | if materialize_futures: 206 | args, kwargs = _materialize_futures(*args, **kwargs) 207 | return func(*args, **kwargs) 208 | 209 | attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} 210 | setattr(inner, MAGIC_ATTR, attrs) 211 | return inner 212 | 213 | return decorator 214 | -------------------------------------------------------------------------------- /verl/single_controller/base/register_center/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/single_controller/base/register_center/ray.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ray 16 | 17 | 18 | @ray.remote 19 | class WorkerGroupRegisterCenter: 20 | def __init__(self, rank_zero_info): 21 | self.rank_zero_info = rank_zero_info 22 | 23 | def get_rank_zero_info(self): 24 | return self.rank_zero_info 25 | 26 | 27 | def create_worker_group_register_center(name, info): 28 | return WorkerGroupRegisterCenter.options(name=name).remote(info) 29 | -------------------------------------------------------------------------------- /verl/single_controller/base/worker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | the class for Worker 16 | """ 17 | 18 | import os 19 | import socket 20 | from dataclasses import dataclass 21 | from typing import Tuple 22 | 23 | import ray 24 | import torch 25 | 26 | from .decorator import Dispatch, Execute, register 27 | from .register_center.ray import create_worker_group_register_center 28 | 29 | 30 | @dataclass 31 | class DistRankInfo: 32 | tp_rank: int 33 | dp_rank: int 34 | pp_rank: int 35 | 36 | 37 | @dataclass 38 | class DistGlobalInfo: 39 | tp_size: int 40 | dp_size: int 41 | pp_size: int 42 | 43 | 44 | class WorkerHelper: 45 | def _get_node_ip(self) -> str: 46 | host_ipv4 = os.getenv("MY_HOST_IP", None) 47 | host_ipv6 = os.getenv("MY_HOST_IPV6", None) 48 | host_ip_by_env = host_ipv4 or host_ipv6 49 | host_ip_by_sdk = ray._private.services.get_node_ip_address() 50 | 51 | host_ip = host_ip_by_env or host_ip_by_sdk 52 | return host_ip 53 | 54 | def _get_free_port(self) -> int: 55 | with socket.socket() as sock: 56 | sock.bind(("", 0)) 57 | return sock.getsockname()[1] 58 | 59 | def get_availale_master_addr_port(self) -> Tuple[str, str]: 60 | return self._get_node_ip(), str(self._get_free_port()) 61 | 62 | def _get_pid(self): 63 | return 64 | 65 | 66 | class WorkerMeta: 67 | keys = [ 68 | "WORLD_SIZE", 69 | "RANK", 70 | "LOCAL_WORLD_SIZE", 71 | "LOCAL_RANK", 72 | "MASTER_ADDR", 73 | "MASTER_PORT", 74 | "CUDA_VISIBLE_DEVICES", 75 | ] 76 | 77 | def __init__(self, store) -> None: 78 | self._store = store 79 | 80 | def to_dict(self): 81 | return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys} 82 | 83 | 84 | # we assume that in each WorkerGroup, there is a Master Worker 85 | class Worker(WorkerHelper): 86 | """A (distributed) worker.""" 87 | 88 | _world_size: int 89 | _rank: int 90 | _local_world_size: int 91 | _local_rank: int 92 | _master_addr: str 93 | _master_port: str 94 | _cuda_visible_devices: str 95 | 96 | def __new__(cls, *args, **kwargs): 97 | instance = super().__new__(cls) 98 | 99 | # note that here we use int to distinguish 100 | disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0)) 101 | if disable_worker_init: 102 | return instance 103 | 104 | rank = os.getenv("RANK", None) 105 | worker_group_prefix = os.getenv("WG_PREFIX", None) 106 | 107 | # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init 108 | if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: 109 | instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) 110 | 111 | return instance 112 | 113 | def _configure_before_init(self, register_center_name: str, rank: int): 114 | assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" 115 | 116 | if rank == 0: 117 | master_addr, master_port = self.get_availale_master_addr_port() 118 | rank_zero_info = { 119 | "MASTER_ADDR": master_addr, 120 | 
"MASTER_PORT": master_port, 121 | } 122 | self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) 123 | os.environ.update(rank_zero_info) 124 | 125 | def __init__(self, cuda_visible_devices=None) -> None: 126 | # construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely 127 | world_size = int(os.getenv("WORLD_SIZE")) 128 | rank = int(os.getenv("RANK")) 129 | self._rank = rank 130 | self._world_size = world_size 131 | 132 | if "AMD" in torch.cuda.get_device_name(): 133 | os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES") 134 | os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK") 135 | cuda_visible_devices = os.getenv("LOCAL_RANK", "0") 136 | torch.cuda.set_device(int(cuda_visible_devices)) 137 | 138 | master_addr = os.getenv("MASTER_ADDR") 139 | master_port = os.getenv("MASTER_PORT") 140 | 141 | local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) 142 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 143 | 144 | store = { 145 | "_world_size": world_size, 146 | "_rank": rank, 147 | "_local_world_size": local_world_size, 148 | "_local_rank": local_rank, 149 | "_master_addr": master_addr, 150 | "_master_port": master_port, 151 | } 152 | if cuda_visible_devices is not None: 153 | store["_cuda_visible_devices"] = cuda_visible_devices 154 | 155 | meta = WorkerMeta(store=store) 156 | self._configure_with_meta(meta=meta) 157 | 158 | def _configure_with_meta(self, meta: WorkerMeta): 159 | """ 160 | This function should only be called inside by WorkerGroup 161 | """ 162 | assert isinstance(meta, WorkerMeta) 163 | self.__dict__.update(meta.to_dict()) # this is hacky 164 | # print(f"__dict__: {self.__dict__}") 165 | for key in WorkerMeta.keys: 166 | val = self.__dict__.get(f"_{key.lower()}", None) 167 | if val is not None: 168 | # print(f"set {key} to {val}") 169 | os.environ[key] = str(val) 170 | 171 | os.environ["REDIS_STORE_SERVER_HOST"] = ( 172 | str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" 173 | ) 174 | 175 | def get_master_addr_port(self): 176 | return self._master_addr, self._master_port 177 | 178 | def get_cuda_visible_devices(self): 179 | cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set") 180 | return cuda_visible_devices 181 | 182 | def print_rank0(self, *args, **kwargs): 183 | if self.rank == 0: 184 | print(*args, **kwargs) 185 | 186 | @property 187 | def world_size(self): 188 | return self._world_size 189 | 190 | @property 191 | def rank(self): 192 | return self._rank 193 | 194 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) 195 | def execute_with_func_generator(self, func, *args, **kwargs): 196 | ret_proto = func(self, *args, **kwargs) 197 | return ret_proto 198 | 199 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) 200 | def execute_func_rank_zero(self, func, *args, **kwargs): 201 | result = func(*args, **kwargs) 202 | return result 203 | -------------------------------------------------------------------------------- /verl/single_controller/base/worker_group.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | the class of WorkerGroup 16 | """ 17 | 18 | import logging 19 | import signal 20 | import threading 21 | import time 22 | from typing import Any, Callable, Dict, List, Optional 23 | 24 | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn 25 | 26 | 27 | class ResourcePool: 28 | """The resource pool with meta info such as world size.""" 29 | 30 | def __init__( 31 | self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8 32 | ) -> None: 33 | if process_on_nodes is None: 34 | process_on_nodes = [] 35 | 36 | self._store = process_on_nodes 37 | self.max_colocate_count = max_colocate_count 38 | self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node 39 | 40 | def add_node(self, process_count): 41 | self._store.append(process_count) 42 | 43 | @property 44 | def world_size(self): 45 | return sum(self._store) 46 | 47 | def __call__(self) -> Any: 48 | return self._store 49 | 50 | @property 51 | def store(self): 52 | return self._store 53 | 54 | def local_world_size_list(self) -> List[int]: 55 | nested_local_world_size_list = [ 56 | [local_world_size for _ in range(local_world_size)] for local_world_size in self._store 57 | ] 58 | return [item for row in nested_local_world_size_list for item in row] 59 | 60 | def local_rank_list(self) -> List[int]: 61 | nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] # noqa: C416 62 | return [item for row in nested_local_rank_list for item in row] 63 | 64 | 65 | class ClassWithInitArgs: 66 | """ 67 | This class stores a class constructor and the args/kwargs to construct the class. 68 | It is used to instantiate the remote class. 
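    Example (a hedged sketch; ``SomeWorker`` and its arguments are illustrative):

        cls_with_args = ClassWithInitArgs(SomeWorker, config=some_config)
        worker = cls_with_args()  # equivalent to SomeWorker(config=some_config)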
69 | """ 70 | 71 | def __init__(self, cls, *args, **kwargs) -> None: 72 | self.cls = cls 73 | self.args = args 74 | self.kwargs = kwargs 75 | 76 | def __call__(self) -> Any: 77 | return self.cls(*self.args, **self.kwargs) 78 | 79 | 80 | def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: 81 | while True: 82 | for worker in workers: 83 | if not is_alive(worker): 84 | logging.warning(f"Worker {worker} is not alive, sending signal to main thread") 85 | signal.raise_signal(signal.SIGABRT) 86 | 87 | time.sleep(gap_time) 88 | 89 | 90 | class WorkerGroup: 91 | """A group of workers""" 92 | 93 | def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: 94 | self._is_init_with_detached_workers = True if resource_pool is None else False 95 | 96 | if resource_pool is not None: 97 | # handle the case when WorkGroup is attached to an existing one 98 | self._procecss_dispatch_config = resource_pool() 99 | else: 100 | self._procecss_dispatch_config = None 101 | 102 | self._workers = [] 103 | self._worker_names = [] 104 | 105 | self._master_addr = None 106 | self._master_port = None 107 | 108 | self._checker_thread: threading.Thread = None 109 | 110 | def _is_worker_alive(self, worker): 111 | raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") 112 | 113 | def _block_until_all_workers_alive(self) -> None: 114 | while True: 115 | all_state = [self._is_worker_alive(worker) for worker in self._workers] 116 | if False in all_state: 117 | time.sleep(1) 118 | else: 119 | break 120 | 121 | def start_worker_aliveness_check(self, every_n_seconds=1) -> None: 122 | # before starting checking worker aliveness, make sure all workers are already alive 123 | self._block_until_all_workers_alive() 124 | 125 | self._checker_thread = threading.Thread( 126 | target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) 127 | ) 128 | self._checker_thread.start() 129 | 130 | @property 131 | def world_size(self): 132 | return len(self._workers) 133 | 134 | def _bind_worker_method(self, user_defined_cls, func_generator): 135 | """ 136 | Bind the worker method to the WorkerGroup 137 | """ 138 | for method_name in dir(user_defined_cls): 139 | try: 140 | method = getattr(user_defined_cls, method_name) 141 | assert callable(method), f"{method_name} in {user_defined_cls} is not callable" 142 | except Exception: 143 | # if it is a property, it will fail because Class doesn't have instance property 144 | continue 145 | 146 | if hasattr(method, MAGIC_ATTR): 147 | # this method is decorated by register 148 | attribute = getattr(method, MAGIC_ATTR) 149 | assert isinstance(attribute, Dict), f"attribute must be a dictionary. 
Got {type(attribute)}" 150 | assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" 151 | 152 | dispatch_mode = attribute["dispatch_mode"] 153 | execute_mode = attribute["execute_mode"] 154 | blocking = attribute["blocking"] 155 | 156 | # get dispatch fn 157 | if isinstance(dispatch_mode, Dispatch): 158 | # get default dispatch fn 159 | fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) 160 | dispatch_fn = fn["dispatch_fn"] 161 | collect_fn = fn["collect_fn"] 162 | else: 163 | assert isinstance(dispatch_mode, dict) 164 | assert "dispatch_fn" in dispatch_mode 165 | assert "collect_fn" in dispatch_mode 166 | dispatch_fn = dispatch_mode["dispatch_fn"] 167 | collect_fn = dispatch_mode["collect_fn"] 168 | 169 | # get execute_fn_name 170 | execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) 171 | wg_execute_fn_name = execute_mode["execute_fn_name"] 172 | 173 | # get execute_fn from string 174 | try: 175 | execute_fn = getattr(self, wg_execute_fn_name) 176 | assert callable(execute_fn), "execute_fn must be callable" 177 | except Exception: 178 | print(f"execute_fn {wg_execute_fn_name} is invalid") 179 | raise 180 | 181 | # bind a new method to the RayWorkerGroup 182 | func = func_generator( 183 | self, 184 | method_name, 185 | dispatch_fn=dispatch_fn, 186 | collect_fn=collect_fn, 187 | execute_fn=execute_fn, 188 | blocking=blocking, 189 | ) 190 | 191 | try: 192 | setattr(self, method_name, func) 193 | except Exception: 194 | raise ValueError(f"Fail to set method_name {method_name}") 195 | -------------------------------------------------------------------------------- /verl/single_controller/ray/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls 16 | 17 | 18 | __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"] 19 | -------------------------------------------------------------------------------- /verl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /verl/trainer/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | PPO config 16 | """ 17 | 18 | import os 19 | from dataclasses import asdict, dataclass, field, fields, is_dataclass 20 | from typing import Optional, Tuple 21 | 22 | from ..workers.config import WorkerConfig 23 | 24 | 25 | def recursive_post_init(dataclass_obj): 26 | if hasattr(dataclass_obj, "post_init"): 27 | dataclass_obj.post_init() 28 | 29 | for attr in fields(dataclass_obj): 30 | if is_dataclass(getattr(dataclass_obj, attr.name)): 31 | recursive_post_init(getattr(dataclass_obj, attr.name)) 32 | 33 | 34 | @dataclass 35 | class DataConfig: 36 | train_files: str = "" 37 | val_files: str = "" 38 | prompt_key: str = "prompt" 39 | answer_key: str = "answer" 40 | image_key: str = "images" 41 | max_prompt_length: int = 512 42 | max_response_length: int = 512 43 | rollout_batch_size: int = 512 44 | val_batch_size: int = -1 45 | format_prompt: Optional[str] = None 46 | override_chat_template: Optional[str] = None 47 | shuffle: bool = True 48 | seed: int = 1 49 | max_pixels: int = 4194304 50 | min_pixels: int = 262144 51 | filter_overlong_prompts: bool = True 52 | 53 | def post_init(self): 54 | if self.format_prompt is not None: 55 | if os.path.exists(self.format_prompt): # ray job uses absolute path 56 | self.format_prompt = os.path.abspath(self.format_prompt) 57 | else: 58 | self.format_prompt = None 59 | 60 | 61 | @dataclass 62 | class AlgorithmConfig: 63 | gamma: float = 1.0 64 | lam: float = 1.0 65 | adv_estimator: str = "grpo" 66 | disable_kl: bool = False 67 | use_kl_loss: bool = False 68 | kl_penalty: str = "kl" 69 | kl_coef: float = 1e-3 70 | kl_type: str = "fixed" 71 | kl_horizon: float = 0.0 72 | kl_target: float = 0.0 73 | 74 | 75 | @dataclass 76 | class TrainerConfig: 77 | total_epochs: int = 10 78 | max_steps: Optional[int] = None 79 | project_name: str = "easy_r1" 80 | experiment_name: str = "demo" 81 | logger: Tuple[str] = ("console", "wandb") 82 | nnodes: int = 1 83 | n_gpus_per_node: int = 8 84 | critic_warmup: int = 0 85 | val_freq: int = -1 86 | val_before_train: bool = True 87 | val_only: bool = False 88 | val_generations_to_log: int = 0 89 | save_freq: int = -1 90 | save_limit: int = -1 91 | save_checkpoint_path: Optional[str] = None 92 | load_checkpoint_path: Optional[str] = None 93 | 94 | def post_init(self): 95 | if self.save_checkpoint_path is None: 96 | self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name) 97 | 98 | self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job uses absolute path 99 | if self.load_checkpoint_path is not None: 100 | self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path) 101 | 102 | 103 | @dataclass 104 | class 
PPOConfig: 105 | data: DataConfig = field(default_factory=DataConfig) 106 | worker: WorkerConfig = field(default_factory=WorkerConfig) 107 | algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig) 108 | trainer: TrainerConfig = field(default_factory=TrainerConfig) 109 | 110 | def post_init(self): 111 | self.worker.rollout.prompt_length = self.data.max_prompt_length 112 | self.worker.rollout.response_length = self.data.max_response_length 113 | self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code 114 | self.worker.actor.disable_kl = self.algorithm.disable_kl 115 | self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss 116 | self.worker.actor.kl_penalty = self.algorithm.kl_penalty 117 | self.worker.actor.kl_coef = self.algorithm.kl_coef 118 | 119 | def deep_post_init(self): 120 | recursive_post_init(self) 121 | 122 | def to_dict(self): 123 | return asdict(self) 124 | -------------------------------------------------------------------------------- /verl/trainer/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional 16 | 17 | import torch 18 | from torch.utils.data import RandomSampler, SequentialSampler 19 | from torchdata.stateful_dataloader import StatefulDataLoader 20 | from transformers import PreTrainedTokenizer, ProcessorMixin 21 | 22 | from ..utils.dataset import RLHFDataset, collate_fn 23 | from .config import DataConfig 24 | 25 | 26 | def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]) -> None: 27 | train_dataset = RLHFDataset( 28 | data_path=config.train_files, 29 | tokenizer=tokenizer, 30 | processor=processor, 31 | prompt_key=config.prompt_key, 32 | answer_key=config.answer_key, 33 | image_key=config.image_key, 34 | max_prompt_length=config.max_prompt_length, 35 | truncation="right", 36 | format_prompt=config.format_prompt, 37 | min_pixels=config.min_pixels, 38 | max_pixels=config.max_pixels, 39 | filter_overlong_prompts=config.filter_overlong_prompts, 40 | ) 41 | # use sampler for better ckpt resume 42 | if config.shuffle: 43 | train_dataloader_generator = torch.Generator() 44 | train_dataloader_generator.manual_seed(config.seed) 45 | sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator) 46 | else: 47 | sampler = SequentialSampler(data_source=train_dataset) 48 | 49 | train_dataloader = StatefulDataLoader( 50 | dataset=train_dataset, 51 | batch_size=config.rollout_batch_size, 52 | sampler=sampler, 53 | num_workers=8, 54 | collate_fn=collate_fn, 55 | pin_memory=False, 56 | drop_last=True, 57 | ) 58 | 59 | val_dataset = RLHFDataset( 60 | data_path=config.val_files, 61 | tokenizer=tokenizer, 62 | processor=processor, 63 | prompt_key=config.prompt_key, 64 | answer_key=config.answer_key, 65 | image_key=config.image_key, 66 | 
max_prompt_length=config.max_prompt_length, 67 | truncation="right", 68 | format_prompt=config.format_prompt, 69 | min_pixels=config.min_pixels, 70 | max_pixels=config.max_pixels, 71 | filter_overlong_prompts=config.filter_overlong_prompts, 72 | ) 73 | val_dataloader = StatefulDataLoader( 74 | dataset=val_dataset, 75 | batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size, 76 | shuffle=False, 77 | num_workers=8, 78 | collate_fn=collate_fn, 79 | pin_memory=False, 80 | drop_last=False, 81 | ) 82 | 83 | assert len(train_dataloader) >= 1 84 | assert len(val_dataloader) >= 1 85 | print(f"Size of train dataloader: {len(train_dataloader)}") 86 | print(f"Size of val dataloader: {len(val_dataloader)}") 87 | return train_dataloader, val_dataloader 88 | -------------------------------------------------------------------------------- /verl/trainer/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | 17 | import ray 18 | from omegaconf import OmegaConf 19 | 20 | from ..single_controller.ray import RayWorkerGroup 21 | from ..utils.tokenizer import get_processor, get_tokenizer 22 | from ..workers.fsdp_workers import FSDPWorker 23 | from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager 24 | from .config import PPOConfig 25 | from .data_loader import create_dataloader 26 | from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role 27 | 28 | 29 | # please make sure main_task is not scheduled on head 30 | @ray.remote(num_cpus=1) 31 | class Runner: 32 | """A runner for RL training.""" 33 | 34 | def run(self, config: PPOConfig): 35 | # print config 36 | print(json.dumps(config.to_dict(), indent=2)) 37 | 38 | # instantiate tokenizer 39 | tokenizer = get_tokenizer( 40 | config.worker.actor.model.model_path, 41 | override_chat_template=config.data.override_chat_template, 42 | trust_remote_code=config.worker.actor.model.trust_remote_code, 43 | use_fast=True, 44 | ) 45 | processor = get_processor( 46 | config.worker.actor.model.model_path, 47 | override_chat_template=config.data.override_chat_template, 48 | trust_remote_code=config.worker.actor.model.trust_remote_code, 49 | use_fast=True, 50 | ) 51 | 52 | # define worker classes 53 | ray_worker_group_cls = RayWorkerGroup 54 | role_worker_mapping = { 55 | Role.ActorRollout: ray.remote(FSDPWorker), 56 | Role.Critic: ray.remote(FSDPWorker), 57 | Role.RefPolicy: ray.remote(FSDPWorker), 58 | } 59 | global_pool_id = "global_pool" 60 | resource_pool_spec = { 61 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, 62 | } 63 | mapping = { 64 | Role.ActorRollout: global_pool_id, 65 | Role.Critic: global_pool_id, 66 | Role.RefPolicy: global_pool_id, 67 | } 68 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) 69 | 70 | if 
config.worker.reward.reward_type == "sequential": 71 | RewardManager = SequentialFunctionRewardManager 72 | elif config.worker.reward.reward_type == "batch": 73 | RewardManager = BatchFunctionRewardManager 74 | else: 75 | raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.") 76 | 77 | RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus) 78 | reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer) 79 | val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer) 80 | 81 | train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor) 82 | 83 | trainer = RayPPOTrainer( 84 | config=config, 85 | tokenizer=tokenizer, 86 | processor=processor, 87 | train_dataloader=train_dataloader, 88 | val_dataloader=val_dataloader, 89 | role_worker_mapping=role_worker_mapping, 90 | resource_pool_manager=resource_pool_manager, 91 | ray_worker_group_cls=ray_worker_group_cls, 92 | reward_fn=reward_fn, 93 | val_reward_fn=val_reward_fn, 94 | ) 95 | trainer.init_workers() 96 | trainer.fit() 97 | 98 | 99 | def main(): 100 | cli_args = OmegaConf.from_cli() 101 | default_config = OmegaConf.structured(PPOConfig()) 102 | 103 | if hasattr(cli_args, "config"): 104 | config_path = cli_args.pop("config", None) 105 | file_config = OmegaConf.load(config_path) 106 | default_config = OmegaConf.merge(default_config, file_config) 107 | 108 | ppo_config = OmegaConf.merge(default_config, cli_args) 109 | ppo_config: PPOConfig = OmegaConf.to_object(ppo_config) 110 | ppo_config.deep_post_init() 111 | 112 | if not ray.is_initialized(): 113 | runtime_env = { 114 | "env_vars": { 115 | "TOKENIZERS_PARALLELISM": "true", 116 | "NCCL_DEBUG": "WARN", 117 | "VLLM_LOGGING_LEVEL": "WARN", 118 | "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", 119 | "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False", 120 | "PYTHONUNBUFFERED": "1", 121 | } 122 | } 123 | ray.init(runtime_env=runtime_env) 124 | 125 | runner = Runner.remote() 126 | ray.get(runner.run.remote(ppo_config)) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /verl/trainer/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
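# A quick illustration of reduce_metrics defined below (hypothetical metric names):
# each list of per-micro-batch values is averaged into a single scalar per key, e.g.
#   reduce_metrics({"actor/loss": [1.0, 3.0], "actor/grad_norm": [2.0]})
# returns {"actor/loss": 2.0, "actor/grad_norm": 2.0}.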
14 | 15 | from typing import Any, Dict, List 16 | 17 | import numpy as np 18 | import torch 19 | 20 | from ..protocol import DataProto 21 | 22 | 23 | def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: 24 | return {key: np.mean(value) for key, value in metrics.items()} 25 | 26 | 27 | def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]: 28 | sequence_score = batch.batch["token_level_scores"].sum(-1) 29 | sequence_reward = batch.batch["token_level_rewards"].sum(-1) 30 | 31 | advantages = batch.batch["advantages"] 32 | returns = batch.batch["returns"] 33 | 34 | max_response_length = batch.batch["responses"].size(-1) 35 | 36 | prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() 37 | response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() 38 | 39 | max_prompt_length = prompt_mask.size(-1) 40 | prompt_length = prompt_mask.sum(-1).float() 41 | response_length = response_mask.sum(-1).float() 42 | 43 | valid_adv = torch.masked_select(advantages, response_mask) 44 | valid_returns = torch.masked_select(returns, response_mask) 45 | 46 | if use_critic: 47 | values = batch.batch["values"] 48 | valid_values = torch.masked_select(values, response_mask) 49 | return_diff_var = torch.var(valid_returns - valid_values) 50 | return_var = torch.var(valid_returns) 51 | 52 | metrics = { 53 | # score 54 | "critic/score/mean": torch.mean(sequence_score).detach().item(), 55 | "critic/score/max": torch.max(sequence_score).detach().item(), 56 | "critic/score/min": torch.min(sequence_score).detach().item(), 57 | # reward 58 | "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), 59 | "critic/rewards/max": torch.max(sequence_reward).detach().item(), 60 | "critic/rewards/min": torch.min(sequence_reward).detach().item(), 61 | # adv 62 | "critic/advantages/mean": torch.mean(valid_adv).detach().item(), 63 | "critic/advantages/max": torch.max(valid_adv).detach().item(), 64 | "critic/advantages/min": torch.min(valid_adv).detach().item(), 65 | # returns 66 | "critic/returns/mean": torch.mean(valid_returns).detach().item(), 67 | "critic/returns/max": torch.max(valid_returns).detach().item(), 68 | "critic/returns/min": torch.min(valid_returns).detach().item(), 69 | **( 70 | { 71 | # values 72 | "critic/values/mean": torch.mean(valid_values).detach().item(), 73 | "critic/values/max": torch.max(valid_values).detach().item(), 74 | "critic/values/min": torch.min(valid_values).detach().item(), 75 | # vf explained var 76 | "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), 77 | } 78 | if use_critic 79 | else {} 80 | ), 81 | # response length 82 | "response_length/mean": torch.mean(response_length).detach().item(), 83 | "response_length/max": torch.max(response_length).detach().item(), 84 | "response_length/min": torch.min(response_length).detach().item(), 85 | "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) 86 | .detach() 87 | .item(), 88 | # prompt length 89 | "prompt_length/mean": torch.mean(prompt_length).detach().item(), 90 | "prompt_length/max": torch.max(prompt_length).detach().item(), 91 | "prompt_length/min": torch.min(prompt_length).detach().item(), 92 | "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), 93 | } 94 | return metrics 95 | 96 | 97 | def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: 98 | num_response_tokens = 
torch.sum(batch.batch["response_mask"]).item() 99 | num_overall_tokens = sum(batch.meta_info["global_token_num"]) 100 | num_tokens_of_section = { 101 | **dict.fromkeys(["gen", "reward"], num_response_tokens), 102 | **dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens), 103 | } 104 | return { 105 | **{f"timing_s/{name}": value for name, value in timing_raw.items()}, 106 | **{ 107 | f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] 108 | for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) 109 | }, 110 | } 111 | 112 | 113 | def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]: 114 | total_num_tokens = sum(batch.meta_info["global_token_num"]) 115 | time = timing_raw["step"] 116 | return { 117 | "perf/total_num_tokens": total_num_tokens, 118 | "perf/time_per_step": time, 119 | "perf/throughput": total_num_tokens / (time * num_gpus), 120 | } 121 | -------------------------------------------------------------------------------- /verl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt 16 | 17 | 18 | __all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"] 19 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import random 17 | import re 18 | import shutil 19 | import tempfile 20 | from abc import ABC, abstractmethod 21 | from typing import Any, Dict, Optional, Union 22 | 23 | import numpy as np 24 | import torch 25 | import torch.distributed as dist 26 | from filelock import FileLock 27 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 28 | from transformers import PreTrainedTokenizer, ProcessorMixin 29 | 30 | 31 | CHECKPOINT_TRACKER = "latest_global_step.txt" 32 | 33 | 34 | class BaseCheckpointManager(ABC): 35 | """ 36 | A checkpoint manager that saves and loads 37 | - model 38 | - optimizer 39 | - lr_scheduler 40 | - extra_states 41 | in a SPMD way. 42 | 43 | We save 44 | - sharded model states and optimizer states 45 | - full lr_scheduler states 46 | - huggingface tokenizer and config for ckpt merge 47 | """ 48 | 49 | def __init__( 50 | self, 51 | model: FSDP, 52 | optimizer: torch.optim.Optimizer, 53 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler, 54 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin], 55 | ): 56 | self.model = model 57 | self.optimizer = optimizer 58 | self.lr_scheduler = lr_scheduler 59 | self.processing_class = processing_class 60 | 61 | assert isinstance(self.model, FSDP) 62 | self.rank = dist.get_rank() 63 | self.world_size = dist.get_world_size() 64 | 65 | @abstractmethod 66 | def load_checkpoint(self, *args, **kwargs): 67 | raise NotImplementedError 68 | 69 | @abstractmethod 70 | def save_checkpoint(self, *args, **kwargs): 71 | raise NotImplementedError 72 | 73 | @staticmethod 74 | def local_mkdir(path: str) -> str: 75 | if not os.path.isabs(path): 76 | working_dir = os.getcwd() 77 | path = os.path.join(working_dir, path) 78 | 79 | # Using hash value of path as lock file name to avoid long file name 80 | lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" 81 | lock_path = os.path.join(tempfile.gettempdir(), lock_filename) 82 | 83 | try: 84 | with FileLock(lock_path, timeout=60): 85 | os.makedirs(path, exist_ok=True) 86 | except Exception as e: 87 | print(f"Warning: Failed to acquire lock for {path}: {e}") 88 | os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory 89 | 90 | return path 91 | 92 | @staticmethod 93 | def get_rng_state() -> Dict[str, Any]: 94 | rng_state = { 95 | "cpu": torch.get_rng_state(), 96 | "cuda": torch.cuda.get_rng_state(), 97 | "numpy": np.random.get_state(), 98 | "random": random.getstate(), 99 | } 100 | return rng_state 101 | 102 | @staticmethod 103 | def load_rng_state(rng_state: Dict[str, Any]): 104 | torch.set_rng_state(rng_state["cpu"]) 105 | torch.cuda.set_rng_state(rng_state["cuda"]) 106 | np.random.set_state(rng_state["numpy"]) 107 | random.setstate(rng_state["random"]) 108 | 109 | 110 | def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]: 111 | if path is None: 112 | return None 113 | 114 | tracker_file = get_checkpoint_tracker_filename(path) 115 | if not os.path.exists(tracker_file): 116 | print("Checkpoint tracker file does 
not exist:", tracker_file) 117 | return None 118 | 119 | with open(tracker_file, "rb") as f: 120 | iteration = int(f.read().decode()) 121 | 122 | ckpt_path = os.path.join(path, directory_format.format(iteration)) 123 | if not os.path.exists(ckpt_path): 124 | print(f"Checkpoint does not exist: {ckpt_path}") 125 | return None 126 | 127 | print(f"Found checkpoint: {ckpt_path}") 128 | return ckpt_path 129 | 130 | 131 | def get_checkpoint_tracker_filename(root_path: str) -> str: 132 | """ 133 | The tracker file records the latest checkpoint during training to restart from. 134 | """ 135 | return os.path.join(root_path, CHECKPOINT_TRACKER) 136 | 137 | 138 | def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"): 139 | """ 140 | Remove the obsolete checkpoints that exceed the save_limit. 141 | """ 142 | if save_limit <= 0: 143 | return 144 | 145 | if not os.path.exists(path): 146 | return 147 | 148 | pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)") 149 | ckpt_folders = [] 150 | for folder in os.listdir(path): 151 | if match := re.match(pattern, folder): 152 | step = int(match.group(1)) 153 | if step < global_step: 154 | ckpt_folders.append((step, folder)) 155 | 156 | ckpt_folders.sort(reverse=True) 157 | for _, folder in ckpt_folders[save_limit - 1 :]: 158 | folder_path = os.path.join(path, folder) 159 | shutil.rmtree(folder_path, ignore_errors=True) 160 | print(f"Removed obsolete checkpoint: {folder_path}") 161 | -------------------------------------------------------------------------------- /verl/utils/checkpoint/fsdp_checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from typing import Optional, Union 17 | 18 | import torch 19 | import torch.distributed as dist 20 | from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict 21 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 22 | from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin 23 | 24 | from .checkpoint_manager import BaseCheckpointManager 25 | 26 | 27 | class FSDPCheckpointManager(BaseCheckpointManager): 28 | """ 29 | A checkpoint manager that saves and loads 30 | - model 31 | - optimizer 32 | - lr_scheduler 33 | - extra_states 34 | in a SPMD way. 
35 | 36 | We save 37 | - sharded model states and optimizer states 38 | - full lr_scheduler states 39 | - huggingface tokenizer and config for ckpt merge 40 | """ 41 | 42 | def __init__( 43 | self, 44 | model: FSDP, 45 | optimizer: torch.optim.Optimizer, 46 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler, 47 | processing_class: Union[PreTrainedTokenizer, ProcessorMixin], 48 | ): 49 | super().__init__(model, optimizer, lr_scheduler, processing_class) 50 | 51 | def load_checkpoint(self, path: Optional[str] = None): 52 | if path is None: 53 | return 54 | 55 | # every rank download its own checkpoint 56 | model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") 57 | optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") 58 | extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") 59 | print(f"[rank-{self.rank}]: Loading model from {os.path.abspath(model_path)}.") 60 | print(f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}.") 61 | print(f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}.") 62 | model_state_dict = torch.load(model_path, weights_only=False) 63 | optim_state_dict = torch.load(optim_path, weights_only=False) 64 | extra_state_dict = torch.load(extra_path, weights_only=False) 65 | 66 | state_dict_options = StateDictOptions(cpu_offload=True) 67 | set_state_dict( 68 | model=self.model, 69 | optimizers=self.optimizer, 70 | model_state_dict=model_state_dict, 71 | optim_state_dict=optim_state_dict, 72 | options=state_dict_options, 73 | ) 74 | self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"]) 75 | 76 | # recover random state 77 | if "rng" in extra_state_dict: 78 | self.load_rng_state(extra_state_dict["rng"]) 79 | 80 | def save_checkpoint(self, path: str): 81 | path = self.local_mkdir(path) 82 | dist.barrier() 83 | 84 | # every rank will save its own model and optim shard 85 | state_dict_options = StateDictOptions(cpu_offload=True) 86 | model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options) 87 | extra_state_dict = { 88 | "lr_scheduler": self.lr_scheduler.state_dict(), 89 | "rng": self.get_rng_state(), 90 | } 91 | model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") 92 | optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") 93 | extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") 94 | 95 | print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.") 96 | print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.") 97 | print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.") 98 | torch.save(model_state_dict, model_path) 99 | torch.save(optim_state_dict, optim_path) 100 | torch.save(extra_state_dict, extra_path) 101 | 102 | # wait for everyone to dump to local 103 | dist.barrier() 104 | 105 | if self.rank == 0: 106 | hf_path = os.path.join(path, "huggingface") 107 | os.makedirs(hf_path, exist_ok=True) 108 | assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel) 109 | self.model._fsdp_wrapped_module.config.save_pretrained(hf_path) 110 | self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path) 111 | self.processing_class.save_pretrained(hf_path) 112 | 113 | dist.barrier() 114 | 
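A minimal usage sketch of the two checkpoint modules above (not part of the repository; `model`, `optimizer`, `lr_scheduler`, `tokenizer`, `save_dir` and `global_step` are assumed to already exist, with `model` wrapped in FSDP):

import os

from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, remove_obsolete_ckpt
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager

ckpt_manager = FSDPCheckpointManager(model, optimizer, lr_scheduler, processing_class=tokenizer)
ckpt_manager.load_checkpoint(find_latest_ckpt_path(save_dir))  # no-op if no tracker file exists yet

# ... training ...

# each rank writes its own shard, rank 0 also dumps the huggingface config/tokenizer
ckpt_manager.save_checkpoint(os.path.join(save_dir, f"global_step_{global_step}"))
remove_obsolete_ckpt(save_dir, global_step=global_step, save_limit=3)  # keep only the most recent checkpoints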
-------------------------------------------------------------------------------- /verl/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import os 17 | from collections import defaultdict 18 | from io import BytesIO 19 | from typing import Any, Dict, List, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from datasets import load_dataset 24 | from jinja2 import Template 25 | from PIL import Image 26 | from PIL.Image import Image as ImageObject 27 | from torch.utils.data import Dataset 28 | from transformers import PreTrainedTokenizer, ProcessorMixin 29 | 30 | from ..models.transformers.qwen2_vl import get_rope_index 31 | from . import torch_functional as VF 32 | 33 | 34 | def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]: 35 | tensors = defaultdict(list) 36 | non_tensors = defaultdict(list) 37 | for feature in features: 38 | for key, value in feature.items(): 39 | if isinstance(value, torch.Tensor): 40 | tensors[key].append(value) 41 | else: 42 | non_tensors[key].append(value) 43 | 44 | for key, value in tensors.items(): 45 | tensors[key] = torch.stack(value, dim=0) 46 | 47 | for key, value in non_tensors.items(): 48 | non_tensors[key] = np.array(value, dtype=object) 49 | 50 | return {**tensors, **non_tensors} 51 | 52 | 53 | class ImageProcessMixin: 54 | max_pixels: int 55 | min_pixels: int 56 | 57 | def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject: 58 | if isinstance(image, dict): 59 | image = Image.open(BytesIO(image["bytes"])) 60 | elif isinstance(image, bytes): 61 | image = Image.open(BytesIO(image)) 62 | 63 | if (image.width * image.height) > self.max_pixels: 64 | resize_factor = math.sqrt(self.max_pixels / (image.width * image.height)) 65 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 66 | image = image.resize((width, height)) 67 | 68 | if (image.width * image.height) < self.min_pixels: 69 | resize_factor = math.sqrt(self.min_pixels / (image.width * image.height)) 70 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 71 | image = image.resize((width, height)) 72 | 73 | if image.mode != "RGB": 74 | image = image.convert("RGB") 75 | 76 | return image 77 | 78 | 79 | class RLHFDataset(Dataset, ImageProcessMixin): 80 | """ 81 | We assume the dataset contains a column that contains prompts and other information 82 | """ 83 | 84 | def __init__( 85 | self, 86 | data_path: str, 87 | tokenizer: PreTrainedTokenizer, 88 | processor: Optional[ProcessorMixin], 89 | prompt_key: str = "prompt", 90 | answer_key: str = "answer", 91 | image_key: str = "images", 92 | max_prompt_length: int = 1024, 93 | truncation: str = "error", 94 | format_prompt: Optional[str] = None, 95 | max_pixels: Optional[int] = None, 96 | min_pixels: Optional[int] = None, 97 | 
filter_overlong_prompts: bool = True, 98 | ): 99 | self.tokenizer = tokenizer 100 | self.processor = processor 101 | self.prompt_key = prompt_key 102 | self.answer_key = answer_key 103 | self.image_key = image_key 104 | self.max_prompt_length = max_prompt_length 105 | self.truncation = truncation 106 | self.max_pixels = max_pixels 107 | self.min_pixels = min_pixels 108 | self.filter_overlong_prompts = filter_overlong_prompts 109 | 110 | if "@" in data_path: 111 | data_path, data_split = data_path.split("@") 112 | else: 113 | data_split = "train" 114 | 115 | if os.path.isdir(data_path): 116 | # when we use dataset builder, we should always refer to the train split 117 | self.dataset = load_dataset("parquet", data_dir=data_path, split="train") 118 | elif os.path.isfile(data_path): 119 | self.dataset = load_dataset("parquet", data_files=data_path, split="train") 120 | else: 121 | # load remote dataset from huggingface hub 122 | self.dataset = load_dataset(data_path, split=data_split) 123 | 124 | self.format_prompt = None 125 | if format_prompt: 126 | with open(format_prompt, encoding="utf-8") as f: 127 | self.format_prompt = f.read() 128 | 129 | if self.filter_overlong_prompts: 130 | self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts") 131 | 132 | def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: 133 | prompt_str: str = example[self.prompt_key] 134 | if self.format_prompt: 135 | format_prompt = Template(self.format_prompt.strip()) 136 | prompt_str = format_prompt.render(content=prompt_str) 137 | 138 | if self.image_key in example: 139 | # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text 140 | content_list = [] 141 | for i, content in enumerate(prompt_str.split("<image>")): 142 | if i != 0: 143 | content_list.append({"type": "image"}) 144 | 145 | if content: 146 | content_list.append({"type": "text", "text": content}) 147 | 148 | return [{"role": "user", "content": content_list}] 149 | else: 150 | return [{"role": "user", "content": prompt_str}] 151 | 152 | def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool: 153 | messages = self._build_messages(example) 154 | processing_class = self.processor if self.processor is not None else self.tokenizer 155 | return ( 156 | len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length 157 | ) 158 | 159 | def __len__(self): 160 | return len(self.dataset) 161 | 162 | def __getitem__(self, index): 163 | example: dict = self.dataset[index] 164 | messages = self._build_messages(example) 165 | 166 | if self.image_key in example: 167 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 168 | images = [self.process_image(image) for image in example.pop(self.image_key)] 169 | model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt") 170 | input_ids = model_inputs.pop("input_ids")[0] 171 | attention_mask = model_inputs.pop("attention_mask")[0] 172 | example["multi_modal_data"] = {"image": images} 173 | example["multi_modal_inputs"] = dict(model_inputs) 174 | else: 175 | prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 176 | model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt") 177 | input_ids = model_inputs.pop("input_ids")[0] 178 | attention_mask = model_inputs.pop("attention_mask")[0] 179 | 180 | if self.processor is not None and 
self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor": 181 | # qwen2vl mrope 182 | position_ids = get_rope_index( 183 | self.processor, 184 | input_ids=input_ids, 185 | image_grid_thw=model_inputs.get("image_grid_thw"), 186 | attention_mask=attention_mask, 187 | ) # (3, seq_length) 188 | else: 189 | position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,) 190 | 191 | input_ids, attention_mask, position_ids = VF.postprocess_data( 192 | input_ids=input_ids, 193 | attention_mask=attention_mask, 194 | position_ids=position_ids, 195 | max_length=self.max_prompt_length, 196 | pad_token_id=self.tokenizer.pad_token_id, 197 | left_pad=True, 198 | truncation=self.truncation, 199 | ) 200 | raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) 201 | if len(raw_prompt_ids) > self.max_prompt_length: 202 | if self.truncation == "left": 203 | raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] 204 | elif self.truncation == "right": 205 | raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] 206 | elif self.truncation == "error": 207 | raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") 208 | 209 | example["input_ids"] = input_ids 210 | example["attention_mask"] = attention_mask 211 | example["position_ids"] = position_ids 212 | example["raw_prompt_ids"] = raw_prompt_ids 213 | example["ground_truth"] = example.pop(self.answer_key) 214 | return example 215 | -------------------------------------------------------------------------------- /verl/utils/flops_counter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
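# Illustrative sketch (not from the original file): how the FlopsCounter defined below
# might be exercised with a hand-built, hypothetical Llama-style config. Note that
# estimate_flops() calls torch.cuda.get_device_name(), so a CUDA device must be visible.
#
#   from transformers import LlamaConfig
#
#   config = LlamaConfig(hidden_size=2048, intermediate_size=5504, num_hidden_layers=24,
#                        num_attention_heads=16, num_key_value_heads=16, vocab_size=32000)
#   counter = FlopsCounter(config)
#   # returns (estimated TFLOP/s for this batch, peak TFLOP/s of the current GPU)
#   achieved_tflops, promised_tflops = counter.estimate_flops(batch_seqlens=[1024, 2048], delta_time=1.0)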
14 | 15 | from typing import TYPE_CHECKING, List, Tuple 16 | 17 | import torch 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers.models.llama.configuration_llama import LlamaConfig 22 | 23 | 24 | VALID_MODEL_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"} 25 | 26 | 27 | def get_device_flops(unit: str = "T") -> float: 28 | def unit_convert(number: float, level: str): 29 | units = ["B", "K", "M", "G", "T", "P"] 30 | if number <= 0: 31 | return number 32 | 33 | ptr = 0 34 | while ptr < len(units) and units[ptr] != level: 35 | number /= 1000 36 | ptr += 1 37 | 38 | return number 39 | 40 | device_name = torch.cuda.get_device_name() 41 | flops = float("inf") # INF flops for unknown GPU types 42 | if "H100" in device_name or "H800" in device_name: 43 | flops = 989e12 44 | elif "A100" in device_name or "A800" in device_name: 45 | flops = 312e12 46 | elif "L40" in device_name: 47 | flops = 181.05e12 48 | elif "L20" in device_name: 49 | flops = 119.5e12 50 | elif "H20" in device_name: 51 | flops = 148e12 52 | elif "910B" in device_name: 53 | flops = 354e12 54 | flops_unit = unit_convert(flops, unit) 55 | return flops_unit 56 | 57 | 58 | class FlopsCounter: 59 | """ 60 | Used to count MFU during the training loop 61 | 62 | Example: 63 | flops_counter = FlopsCounter(config) 64 | flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time) 65 | """ 66 | 67 | def __init__(self, config: "LlamaConfig"): 68 | if config.model_type not in VALID_MODEL_TYPE: 69 | print(f"Only support {VALID_MODEL_TYPE}, but got {config.model_type}. MFU will always be zero.") 70 | 71 | self.estimate_func = { 72 | "llama": self._estimate_llama_flops, 73 | "qwen2": self._estimate_llama_flops, 74 | "qwen2_vl": self._estimate_llama_flops, 75 | "qwen2_5_vl": self._estimate_llama_flops, 76 | } 77 | self.config = config 78 | 79 | def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: 80 | return 0 81 | 82 | def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: 83 | hidden_size = self.config.hidden_size 84 | vocab_size = self.config.vocab_size 85 | num_hidden_layers = self.config.num_hidden_layers 86 | num_key_value_heads = self.config.num_key_value_heads 87 | num_attention_heads = self.config.num_attention_heads 88 | intermediate_size = self.config.intermediate_size 89 | 90 | head_dim = hidden_size // num_attention_heads 91 | q_size = num_attention_heads * head_dim 92 | k_size = num_key_value_heads * head_dim 93 | v_size = num_key_value_heads * head_dim 94 | 95 | # non-attention per-layer params 96 | # Qwen2/Llama use SwiGLU, with gate, up and down linear layers in the MLP 97 | mlp_N = hidden_size * intermediate_size * 3 98 | attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) 99 | embed_and_lm_head_N = vocab_size * hidden_size * 2 100 | # non-attention all-layer params 101 | dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + embed_and_lm_head_N 102 | # non-attn all_layer & all_token fwd & bwd flops 103 | dense_N_flops = 6 * dense_N * tokens_sum 104 | 105 | # attn all_layer & all_token fwd & bwd flops 106 | seqlen_square_sum = 0 107 | for seqlen in batch_seqlens: 108 | seqlen_square_sum += seqlen * seqlen 109 | 110 | attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers 111 | 112 | # all_layer & all_token fwd & bwd flops 113 | flops_all_token = dense_N_flops + attn_qkv_flops 114 | flops_achieved = flops_all_token * (1.0 / delta_time) / 
1e12 115 | return flops_achieved 116 | 117 | def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]: 118 | """ 119 | Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. 120 | 121 | Args: 122 | batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch. 123 | delta_time (float): The time taken to process the batch, in seconds. 124 | 125 | Returns: 126 | estimated_flops (float): The estimated FLOPS based on the input tokens and time. 127 | promised_flops (float): The expected FLOPS of the current device. 128 | """ 129 | tokens_sum = sum(batch_seqlens) 130 | func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) 131 | estimated_flops = func(tokens_sum, batch_seqlens, delta_time) 132 | promised_flops = get_device_flops() 133 | return estimated_flops, promised_flops 134 | -------------------------------------------------------------------------------- /verl/utils/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import gc 16 | from collections import defaultdict 17 | from functools import partial 18 | from typing import Callable, Union 19 | 20 | import torch 21 | from torch import nn 22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 23 | from torch.distributed.fsdp._runtime_utils import _lazy_init 24 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 25 | from torch.optim import Optimizer 26 | from transformers import PreTrainedModel 27 | from transformers.trainer_pt_utils import get_module_class_from_name 28 | 29 | 30 | def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]: 31 | param_occurrence = defaultdict(int) 32 | for _, param in model.named_parameters(remove_duplicate=False): 33 | param_occurrence[param] += 1 34 | 35 | duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1} 36 | materialized_params = {} 37 | 38 | def init_fn(module: nn.Module): 39 | for name, param in module.named_parameters(recurse=False): 40 | if param in duplicated_params: 41 | module._parameters[name] = materialized_params.setdefault( 42 | param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad) 43 | ) 44 | else: 45 | module._parameters[name] = nn.Parameter( 46 | torch.empty_like(param.data, device=device), requires_grad=param.requires_grad 47 | ) 48 | 49 | return init_fn 50 | 51 | 52 | def get_fsdp_wrap_policy(model: PreTrainedModel): 53 | """Get FSDP wrap policy for the model. 
54 | 55 | Args: 56 | module: The module to get wrap policy for 57 | """ 58 | transformer_cls_to_wrap = set() 59 | for module in model._no_split_modules: 60 | transformer_cls = get_module_class_from_name(model, module) 61 | if transformer_cls is None: 62 | raise Exception(f"Cannot find {module} in pretrained model.") 63 | else: 64 | transformer_cls_to_wrap.add(transformer_cls) 65 | 66 | return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap) 67 | 68 | 69 | @torch.no_grad() 70 | def offload_fsdp_model(model: FSDP, empty_cache: bool = True): 71 | # lazy init FSDP model 72 | _lazy_init(model, model) 73 | assert model._is_root, "Only support root model offloading to CPU" 74 | for handle in model._all_handles: 75 | if handle._offload_params: 76 | continue 77 | 78 | flat_param = handle.flat_param 79 | assert ( 80 | flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() 81 | and id(flat_param.data) != id(flat_param._local_shard) 82 | and flat_param.data.size() == flat_param._local_shard.size() 83 | ) 84 | handle.flat_param_to("cpu", non_blocking=True) 85 | # the following still keeps id(._local_shard) != id(.data) 86 | flat_param._local_shard = flat_param.data 87 | assert id(flat_param._local_shard) != id(flat_param.data) 88 | 89 | if empty_cache: 90 | torch.cuda.empty_cache() 91 | 92 | 93 | @torch.no_grad() 94 | def load_fsdp_model(model: FSDP, empty_cache: bool = True): 95 | # lazy init FSDP model 96 | _lazy_init(model, model) 97 | assert model._is_root, "Only support root model loading to GPU" 98 | for handle in model._all_handles: 99 | if handle._offload_params: 100 | continue 101 | 102 | flat_param = handle.flat_param 103 | handle.flat_param_to("cuda", non_blocking=True) 104 | # the following still keeps id(._local_shard) != id(.data) 105 | flat_param._local_shard = flat_param.data 106 | 107 | if empty_cache: 108 | gc.collect() 109 | 110 | 111 | @torch.no_grad() 112 | def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True): 113 | if not optimizer.state: 114 | return 115 | 116 | for param_group in optimizer.param_groups: 117 | for param in param_group["params"]: 118 | state = optimizer.state[param] 119 | for key, value in state.items(): 120 | if isinstance(value, torch.Tensor): 121 | state[key] = value.to("cpu", non_blocking=True) 122 | 123 | if empty_cache: 124 | torch.cuda.empty_cache() 125 | 126 | 127 | @torch.no_grad() 128 | def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True): 129 | if not optimizer.state: 130 | return 131 | 132 | for param_group in optimizer.param_groups: 133 | for param in param_group["params"]: 134 | state = optimizer.state[param] 135 | for key, value in state.items(): 136 | if isinstance(value, torch.Tensor): 137 | state[key] = value.to("cuda", non_blocking=True) 138 | 139 | if empty_cache: 140 | gc.collect() 141 | -------------------------------------------------------------------------------- /verl/utils/logger/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .logger import Tracker 17 | 18 | 19 | __all__ = ["Tracker"] 20 | -------------------------------------------------------------------------------- /verl/utils/logger/gen_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | from dataclasses import dataclass 18 | from typing import List, Tuple 19 | 20 | from ..py_functional import is_package_available 21 | 22 | 23 | if is_package_available("wandb"): 24 | import wandb # type: ignore 25 | 26 | 27 | if is_package_available("swanlab"): 28 | import swanlab # type: ignore 29 | 30 | 31 | @dataclass 32 | class GenerationLogger(ABC): 33 | @abstractmethod 34 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: ... 
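# Each sample logged by these backends is a (prompt, model output, ground truth, score)
# tuple; for example, with the console backend defined just below (values made up):
#
#   ConsoleGenerationLogger().log([("1+1=?", "2", "2", 1.0)], step=0)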
35 | 36 | 37 | @dataclass 38 | class ConsoleGenerationLogger(GenerationLogger): 39 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 40 | for inp, out, lab, score in samples: 41 | print(f"[prompt] {inp}\n[output] {out}\n[ground_truth] {lab}\n[score] {score}\n") 42 | 43 | 44 | @dataclass 45 | class WandbGenerationLogger(GenerationLogger): 46 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 47 | # Create column names for all samples 48 | columns = ["step"] + sum( 49 | [[f"input_{i + 1}", f"output_{i + 1}", f"label_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], 50 | [], 51 | ) 52 | 53 | if not hasattr(self, "validation_table"): 54 | # Initialize the table on first call 55 | self.validation_table = wandb.Table(columns=columns) 56 | 57 | # Create a new table with same columns and existing data 58 | # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 59 | new_table = wandb.Table(columns=columns, data=self.validation_table.data) 60 | 61 | # Add new row with all data 62 | row_data = [step] 63 | for sample in samples: 64 | row_data.extend(sample) 65 | 66 | new_table.add_data(*row_data) 67 | wandb.log({"val/generations": new_table}, step=step) 68 | self.validation_table = new_table 69 | 70 | 71 | @dataclass 72 | class SwanlabGenerationLogger(GenerationLogger): 73 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 74 | swanlab_text_list = [] 75 | for i, sample in enumerate(samples): 76 | row_text = "\n\n---\n\n".join( 77 | (f"input: {sample[0]}", f"output: {sample[1]}", f"label: {sample[2]}", f"score: {sample[3]}") 78 | ) 79 | swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) 80 | 81 | swanlab.log({"val/generations": swanlab_text_list}, step=step) 82 | 83 | 84 | GEN_LOGGERS = { 85 | "console": ConsoleGenerationLogger, 86 | "wandb": WandbGenerationLogger, 87 | "swanlab": SwanlabGenerationLogger, 88 | } 89 | 90 | 91 | @dataclass 92 | class AggregateGenerationsLogger: 93 | def __init__(self, loggers: List[str]): 94 | self.loggers: List[GenerationLogger] = [] 95 | 96 | for logger in loggers: 97 | if logger in GEN_LOGGERS: 98 | self.loggers.append(GEN_LOGGERS[logger]()) 99 | 100 | def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 101 | for logger in self.loggers: 102 | logger.log(samples, step) 103 | -------------------------------------------------------------------------------- /verl/utils/logger/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
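# Typical use of the Tracker class defined at the end of this file (backend names and
# variables are illustrative): it fans metrics out to every configured backend and
# forwards generation samples to the matching generation loggers.
#
#   tracker = Tracker(loggers=["console", "wandb"], config=ppo_config.to_dict())
#   tracker.log({"actor/loss": 0.42}, step=global_step)
#   tracker.log_generation(samples, step=global_step)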
14 | """ 15 | A unified tracking interface that supports logging data to different backend 16 | """ 17 | 18 | import os 19 | from abc import ABC, abstractmethod 20 | from typing import Any, Dict, List, Optional, Tuple, Union 21 | 22 | from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict 23 | from .gen_logger import AggregateGenerationsLogger 24 | 25 | 26 | if is_package_available("mlflow"): 27 | import mlflow # type: ignore 28 | 29 | 30 | if is_package_available("tensorboard"): 31 | from torch.utils.tensorboard import SummaryWriter 32 | 33 | 34 | if is_package_available("wandb"): 35 | import wandb # type: ignore 36 | 37 | 38 | if is_package_available("swanlab"): 39 | import swanlab # type: ignore 40 | 41 | 42 | class Logger(ABC): 43 | @abstractmethod 44 | def __init__(self, config: Dict[str, Any]) -> None: ... 45 | 46 | @abstractmethod 47 | def log(self, data: Dict[str, Any], step: int) -> None: ... 48 | 49 | def finish(self) -> None: 50 | pass 51 | 52 | 53 | class ConsoleLogger(Logger): 54 | def __init__(self, config: Dict[str, Any]) -> None: 55 | print("Config\n" + convert_dict_to_str(config)) 56 | 57 | def log(self, data: Dict[str, Any], step: int) -> None: 58 | print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data))) 59 | 60 | 61 | class MlflowLogger(Logger): 62 | def __init__(self, config: Dict[str, Any]) -> None: 63 | mlflow.start_run(run_name=config["trainer"]["experiment_name"]) 64 | mlflow.log_params(flatten_dict(config)) 65 | 66 | def log(self, data: Dict[str, Any], step: int) -> None: 67 | mlflow.log_metrics(metrics=data, step=step) 68 | 69 | 70 | class SwanlabLogger(Logger): 71 | def __init__(self, config: Dict[str, Any]) -> None: 72 | swanlab_key = os.getenv("SWANLAB_API_KEY") 73 | swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log") 74 | swanlab_mode = os.getenv("SWANLAB_MODE", "cloud") 75 | if swanlab_key: 76 | swanlab.login(swanlab_key) 77 | 78 | swanlab.init( 79 | project=config["trainer"]["project_name"], 80 | experiment_name=config["trainer"]["experiment_name"], 81 | config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config}, 82 | logdir=swanlab_dir, 83 | mode=swanlab_mode, 84 | ) 85 | 86 | def log(self, data: Dict[str, Any], step: int) -> None: 87 | swanlab.log(data=data, step=step) 88 | 89 | def finish(self) -> None: 90 | swanlab.finish() 91 | 92 | 93 | class TensorBoardLogger(Logger): 94 | def __init__(self, config: Dict[str, Any]) -> None: 95 | tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log") 96 | os.makedirs(tensorboard_dir, exist_ok=True) 97 | print(f"Saving tensorboard log to {tensorboard_dir}.") 98 | self.writer = SummaryWriter(tensorboard_dir) 99 | # self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={"placeholder": 0}) 100 | 101 | def log(self, data: Dict[str, Any], step: int) -> None: 102 | for key, value in data.items(): 103 | self.writer.add_scalar(key, value, step) 104 | 105 | def finish(self): 106 | self.writer.close() 107 | 108 | 109 | class WandbLogger(Logger): 110 | def __init__(self, config: Dict[str, Any]) -> None: 111 | wandb.init( 112 | project=config["trainer"]["project_name"], 113 | name=config["trainer"]["experiment_name"], 114 | config=config, 115 | ) 116 | 117 | def log(self, data: Dict[str, Any], step: int) -> None: 118 | wandb.log(data=data, step=step) 119 | 120 | def finish(self) -> None: 121 | wandb.finish() 122 | 123 | 124 | LOGGERS = { 125 | "console": ConsoleLogger, 126 | "mlflow": MlflowLogger, 127 | "swanlab": SwanlabLogger, 128 | 
"tensorboard": TensorBoardLogger, 129 | "wandb": WandbLogger, 130 | } 131 | 132 | 133 | class Tracker: 134 | def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None): 135 | if isinstance(loggers, str): 136 | loggers = [loggers] 137 | 138 | self.loggers: List[Logger] = [] 139 | for logger in loggers: 140 | if logger not in LOGGERS: 141 | raise ValueError(f"{logger} is not supported.") 142 | 143 | self.loggers.append(LOGGERS[logger](config)) 144 | 145 | self.gen_logger = AggregateGenerationsLogger(loggers) 146 | 147 | def log(self, data: Dict[str, Any], step: int) -> None: 148 | for logger in self.loggers: 149 | logger.log(data=data, step=step) 150 | 151 | def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: 152 | self.gen_logger.log(samples, step) 153 | 154 | def __del__(self): 155 | for logger in self.loggers: 156 | logger.finish() 157 | -------------------------------------------------------------------------------- /verl/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utilities to create common models 16 | """ 17 | 18 | from functools import lru_cache 19 | from typing import Optional, Tuple 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch import nn 24 | 25 | 26 | @lru_cache 27 | def is_rank0() -> int: 28 | return (not dist.is_initialized()) or (dist.get_rank() == 0) 29 | 30 | 31 | def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None: 32 | """Report the current GPU VRAM usage.""" 33 | if is_rank0(): 34 | free_mem, total_mem = torch.cuda.mem_get_info() 35 | print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.") 36 | 37 | 38 | def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]: 39 | """Compute the model size.""" 40 | n_params = sum(p.numel() for p in model.parameters()) 41 | 42 | if scale == "auto": 43 | if n_params > 1e9: 44 | scale = "B" 45 | elif n_params > 1e6: 46 | scale = "M" 47 | elif n_params > 1e3: 48 | scale = "K" 49 | else: 50 | scale = "" 51 | 52 | if scale == "B": 53 | n_params = n_params / 1e9 54 | elif scale == "M": 55 | n_params = n_params / 1e6 56 | elif scale == "K": 57 | n_params = n_params / 1e3 58 | elif scale == "": 59 | pass 60 | else: 61 | raise NotImplementedError(f"Unknown scale {scale}.") 62 | 63 | return n_params, scale 64 | 65 | 66 | def print_model_size(model: nn.Module, name: Optional[str] = None) -> None: 67 | """Print the model size.""" 68 | if is_rank0(): 69 | n_params, scale = _get_model_size(model, scale="auto") 70 | if name is None: 71 | name = model.__class__.__name__ 72 | 73 | print(f"{name} contains {n_params:.2f}{scale} parameters.") 74 | -------------------------------------------------------------------------------- /verl/utils/py_functional.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contain small python utility functions 16 | """ 17 | 18 | import importlib.util 19 | import re 20 | from contextlib import contextmanager 21 | from functools import lru_cache 22 | from typing import Any, Dict, List, Union 23 | 24 | import numpy as np 25 | import yaml 26 | from codetiming import Timer 27 | from yaml import Dumper 28 | 29 | 30 | def is_sci_notation(number: float) -> bool: 31 | pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$") 32 | return bool(pattern.match(str(number))) 33 | 34 | 35 | def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]): 36 | if is_sci_notation(number): 37 | value = str(number) 38 | if "." not in value and "e" in value: 39 | value = value.replace("e", ".0e", 1) 40 | else: 41 | value = str(round(number, 3)) 42 | 43 | return dumper.represent_scalar("tag:yaml.org,2002:float", value) 44 | 45 | 46 | yaml.add_representer(float, float_representer) 47 | yaml.add_representer(np.float32, float_representer) 48 | yaml.add_representer(np.float64, float_representer) 49 | 50 | 51 | @lru_cache 52 | def is_package_available(name: str) -> bool: 53 | return importlib.util.find_spec(name) is not None 54 | 55 | 56 | def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]: 57 | """Union two dict. 
Raises an error if the two dicts map the same key to different values.""" 58 | for key in dict2.keys(): 59 | if key in dict1: 60 | assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object" 61 | 62 | dict1[key] = dict2[key] 63 | 64 | return dict1 65 | 66 | 67 | def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None: 68 | """Append dict to a dict of list.""" 69 | for key, val in new_data.items(): 70 | if key not in data: 71 | data[key] = [] 72 | 73 | data[key].append(val) 74 | 75 | 76 | def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]: 77 | unflattened = {} 78 | for key, value in data.items(): 79 | pieces = key.split(sep) 80 | pointer = unflattened 81 | for piece in pieces[:-1]: 82 | if piece not in pointer: 83 | pointer[piece] = {} 84 | 85 | pointer = pointer[piece] 86 | 87 | pointer[pieces[-1]] = value 88 | 89 | return unflattened 90 | 91 | 92 | def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]: 93 | flattened = {} 94 | for key, value in data.items(): 95 | new_key = parent_key + sep + key if parent_key else key 96 | if isinstance(value, dict): 97 | flattened.update(flatten_dict(value, new_key, sep=sep)) 98 | else: 99 | flattened[new_key] = value 100 | 101 | return flattened 102 | 103 | 104 | def convert_dict_to_str(data: Dict[str, Any]) -> str: 105 | return yaml.dump(data, indent=2) 106 | 107 | 108 | @contextmanager 109 | def timer(name: str, timing_raw: Dict[str, float]): 110 | with Timer(name=name, logger=None) as timer: 111 | yield 112 | 113 | timing_raw[name] = timer.last 114 | -------------------------------------------------------------------------------- /verl/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for tokenization.""" 15 | 16 | from typing import Optional 17 | 18 | from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin 19 | 20 | 21 | def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer: 22 | """Create a huggingface pretrained tokenizer.""" 23 | tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs) 24 | if override_chat_template is not None: 25 | tokenizer.chat_template = override_chat_template 26 | 27 | if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>": 28 | # the EOS token in gemma2 & gemma3 is ambiguous, which may worsen RL performance. 29 | # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a 30 | print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.") 31 | tokenizer.eos_token = "<end_of_turn>" 32 | 33 | if tokenizer.pad_token_id is None: 34 | print("Pad token is None. 
Set it to eos_token.") 35 | tokenizer.pad_token = tokenizer.eos_token 36 | 37 | return tokenizer 38 | 39 | 40 | def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]: 41 | """Create a huggingface pretrained processor.""" 42 | processor = AutoProcessor.from_pretrained(model_path, **kwargs) 43 | if override_chat_template is not None: 44 | processor.chat_template = override_chat_template 45 | 46 | # Avoid load tokenizer, see: 47 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 48 | if processor is not None and "Processor" not in processor.__class__.__name__: 49 | processor = None 50 | 51 | return processor 52 | -------------------------------------------------------------------------------- /verl/utils/torch_dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | HALF_LIST = ["fp16", "float16"] 19 | FLOAT_LIST = ["fp32", "float32"] 20 | BFLOAT_LIST = ["bf16", "bfloat16"] 21 | 22 | 23 | class PrecisionType: 24 | """Type of precision used.""" 25 | 26 | @staticmethod 27 | def is_fp16(precision: str) -> bool: 28 | return precision in HALF_LIST 29 | 30 | @staticmethod 31 | def is_fp32(precision: str) -> bool: 32 | return precision in FLOAT_LIST 33 | 34 | @staticmethod 35 | def is_bf16(precision: str) -> bool: 36 | return precision in BFLOAT_LIST 37 | 38 | @staticmethod 39 | def to_dtype(precision: str) -> torch.dtype: 40 | if precision in HALF_LIST: 41 | return torch.float16 42 | elif precision in FLOAT_LIST: 43 | return torch.float32 44 | elif precision in BFLOAT_LIST: 45 | return torch.bfloat16 46 | else: 47 | raise RuntimeError(f"Unexpected precision: {precision}") 48 | 49 | @staticmethod 50 | def to_str(precision: torch.dtype) -> str: 51 | if precision == torch.float16: 52 | return "float16" 53 | elif precision == torch.float32: 54 | return "float32" 55 | elif precision == torch.bfloat16: 56 | return "bfloat16" 57 | else: 58 | raise RuntimeError(f"Unexpected precision: {precision}") 59 | -------------------------------------------------------------------------------- /verl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/workers/actor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig 16 | 17 | 18 | __all__ = [ 19 | "ActorConfig", 20 | "FSDPConfig", 21 | "ModelConfig", 22 | "OptimConfig", 23 | "RefConfig", 24 | ] 25 | -------------------------------------------------------------------------------- /verl/workers/actor/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The base class for Actor 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict 20 | 21 | import torch 22 | 23 | from ...protocol import DataProto 24 | from .config import ActorConfig 25 | 26 | 27 | __all__ = ["BasePPOActor"] 28 | 29 | 30 | class BasePPOActor(ABC): 31 | def __init__(self, config: ActorConfig): 32 | """The base class for PPO actor 33 | 34 | Args: 35 | config (ActorConfig): a config passed to the PPOActor. 36 | """ 37 | self.config = config 38 | 39 | @abstractmethod 40 | def compute_log_prob(self, data: DataProto) -> torch.Tensor: 41 | """Compute log probabilities given a batch of data. 42 | 43 | Args: 44 | data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```, 45 | ```attention_mask``` and ```position_ids```. 46 | 47 | Returns: 48 | torch.Tensor: the log probabilities of the ```responses``` 49 | """ 50 | pass 51 | 52 | @abstractmethod 53 | def update_policy(self, data: DataProto) -> Dict[str, Any]: 54 | """Update the policy with an iterator of DataProto 55 | 56 | Args: 57 | data (DataProto): an iterator over the DataProto returned by 58 | ```make_minibatch_iterator``` 59 | 60 | Returns: 61 | Dict: a dictionary containing anything. Typically, it contains the statistics collected while updating the model, 62 | such as ```loss```, ```grad_norm```, etc. 
63 | """ 64 | pass 65 | -------------------------------------------------------------------------------- /verl/workers/actor/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Actor config 16 | """ 17 | 18 | import os 19 | from dataclasses import dataclass, field 20 | from typing import Any, Dict, Optional, Tuple 21 | 22 | 23 | @dataclass 24 | class ModelConfig: 25 | model_path: Optional[str] = None 26 | tokenizer_path: Optional[str] = None 27 | override_config: Dict[str, Any] = field(default_factory=dict) 28 | enable_gradient_checkpointing: bool = True 29 | trust_remote_code: bool = True 30 | freeze_vision_tower: bool = False 31 | 32 | def post_init(self): 33 | if self.tokenizer_path is None: 34 | self.tokenizer_path = self.model_path 35 | 36 | if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path 37 | self.model_path = os.path.abspath(self.model_path) 38 | 39 | if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path): 40 | self.tokenizer_path = os.path.abspath(self.tokenizer_path) 41 | 42 | 43 | @dataclass 44 | class OptimConfig: 45 | lr: float = 1e-6 46 | betas: Tuple[float, float] = (0.9, 0.999) 47 | weight_decay: float = 1e-2 48 | strategy: str = "adamw" 49 | lr_warmup_ratio: float = 0.0 50 | min_lr_ratio: Optional[float] = None 51 | warmup_style: str = "constant" 52 | """auto keys""" 53 | training_steps: int = field(default=-1, init=False) 54 | 55 | 56 | @dataclass 57 | class FSDPConfig: 58 | enable_full_shard: bool = True 59 | enable_cpu_offload: bool = False 60 | enable_rank0_init: bool = False 61 | use_orig_params: bool = False 62 | torch_dtype: Optional[str] = None 63 | fsdp_size: int = -1 64 | mp_param_dtype: str = "bf16" 65 | mp_reduce_dtype: str = "fp32" 66 | mp_buffer_dtype: str = "fp32" 67 | 68 | 69 | @dataclass 70 | class OffloadConfig: 71 | offload_params: bool = False 72 | offload_optimizer: bool = False 73 | 74 | 75 | @dataclass 76 | class ActorConfig: 77 | strategy: str = "fsdp" 78 | global_batch_size: int = 256 79 | micro_batch_size_per_device_for_update: int = 4 80 | micro_batch_size_per_device_for_experience: int = 16 81 | max_grad_norm: float = 1.0 82 | clip_ratio_low: float = 0.2 83 | clip_ratio_high: float = 0.3 84 | clip_ratio_dual: float = 3.0 85 | ppo_epochs: int = 1 86 | padding_free: bool = False 87 | ulysses_sequence_parallel_size: int = 1 88 | use_torch_compile: bool = True 89 | model: ModelConfig = field(default_factory=ModelConfig) 90 | optim: OptimConfig = field(default_factory=OptimConfig) 91 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 92 | offload: OffloadConfig = field(default_factory=OffloadConfig) 93 | """auto keys""" 94 | global_batch_size_per_device: int = field(default=-1, init=False) 95 | disable_kl: bool = field(default=False, init=False) 96 | use_kl_loss: bool = 
field(default=False, init=False) 97 | kl_penalty: str = field(default="kl", init=False) 98 | kl_coef: float = field(default=0.0, init=False) 99 | 100 | 101 | @dataclass 102 | class RefConfig: 103 | strategy: str = "fsdp" 104 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 105 | offload: OffloadConfig = field(default_factory=OffloadConfig) 106 | """auto keys""" 107 | micro_batch_size_per_device_for_experience: int = field(default=-1, init=False) 108 | padding_free: bool = field(default=False, init=False) 109 | ulysses_sequence_parallel_size: int = field(default=1, init=False) 110 | use_torch_compile: bool = field(default=True, init=False) 111 | -------------------------------------------------------------------------------- /verl/workers/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | ActorRolloutRef config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | 20 | from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig 21 | from .critic import CriticConfig 22 | from .reward import RewardConfig 23 | from .rollout import RolloutConfig 24 | 25 | 26 | __all__ = [ 27 | "ActorConfig", 28 | "CriticConfig", 29 | "FSDPConfig", 30 | "ModelConfig", 31 | "OptimConfig", 32 | "RefConfig", 33 | "RewardConfig", 34 | "RolloutConfig", 35 | "WorkerConfig", 36 | ] 37 | 38 | 39 | @dataclass 40 | class WorkerConfig: 41 | hybrid_engine: bool = True 42 | actor: ActorConfig = field(default_factory=ActorConfig) 43 | critic: CriticConfig = field(default_factory=CriticConfig) 44 | ref: RefConfig = field(default_factory=RefConfig) 45 | reward: RewardConfig = field(default_factory=RewardConfig) 46 | rollout: RolloutConfig = field(default_factory=RolloutConfig) 47 | 48 | def post_init(self): 49 | self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience 50 | self.ref.padding_free = self.actor.padding_free 51 | self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size 52 | self.ref.use_torch_compile = self.actor.use_torch_compile 53 | -------------------------------------------------------------------------------- /verl/workers/critic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .config import CriticConfig 16 | 17 | 18 | __all__ = ["CriticConfig"] 19 | -------------------------------------------------------------------------------- /verl/workers/critic/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Base class for Critic 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict 20 | 21 | import torch 22 | 23 | from ...protocol import DataProto 24 | from .config import CriticConfig 25 | 26 | 27 | __all__ = ["BasePPOCritic"] 28 | 29 | 30 | class BasePPOCritic(ABC): 31 | def __init__(self, config: CriticConfig): 32 | self.config = config 33 | 34 | @abstractmethod 35 | def compute_values(self, data: DataProto) -> torch.Tensor: 36 | """Compute values""" 37 | pass 38 | 39 | @abstractmethod 40 | def update_critic(self, data: DataProto) -> Dict[str, Any]: 41 | """Update the critic""" 42 | pass 43 | -------------------------------------------------------------------------------- /verl/workers/critic/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | Critic config 16 | """ 17 | 18 | from dataclasses import dataclass, field 19 | 20 | from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig 21 | 22 | 23 | @dataclass 24 | class CriticConfig: 25 | strategy: str = "fsdp" 26 | global_batch_size: int = 256 27 | micro_batch_size_per_device_for_update: int = 4 28 | micro_batch_size_per_device_for_experience: int = 16 29 | max_grad_norm: float = 1.0 30 | cliprange_value: float = 0.5 31 | ppo_epochs: int = 1 32 | padding_free: bool = False 33 | ulysses_sequence_parallel_size: int = 1 34 | model: ModelConfig = field(default_factory=ModelConfig) 35 | optim: OptimConfig = field(default_factory=OptimConfig) 36 | fsdp: FSDPConfig = field(default_factory=FSDPConfig) 37 | offload: OffloadConfig = field(default_factory=OffloadConfig) 38 | """auto keys""" 39 | global_batch_size_per_device: int = field(default=-1, init=False) 40 | -------------------------------------------------------------------------------- /verl/workers/reward/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 PRIME team and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .config import RewardConfig 16 | from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager 17 | 18 | 19 | __all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"] 20 | -------------------------------------------------------------------------------- /verl/workers/reward/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | Reward config 16 | """ 17 | 18 | import os 19 | from dataclasses import dataclass, field 20 | from typing import Optional 21 | 22 | 23 | @dataclass 24 | class RewardConfig: 25 | reward_type: str = "batch" 26 | reward_function: Optional[str] = None 27 | reward_function_kwargs: dict = field(default_factory=dict) 28 | skip_special_tokens: bool = True 29 | num_cpus: int = 1 30 | """auto keys""" 31 | reward_function_name: Optional[str] = field(default=None, init=False) 32 | 33 | def post_init(self): 34 | if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main 35 | if ":" not in self.reward_function: 36 | self.reward_function_name = "main" 37 | else: 38 | self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1) 39 | 40 | if os.path.exists(self.reward_function): # ray job uses absolute path 41 | self.reward_function = os.path.abspath(self.reward_function) 42 | else: 43 | self.reward_function = None 44 | -------------------------------------------------------------------------------- /verl/workers/reward/function.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import importlib.util 16 | import os 17 | import sys 18 | from abc import ABC, abstractmethod 19 | from collections import defaultdict 20 | from functools import partial 21 | from typing import Callable, Dict, List, Optional, Tuple, TypedDict 22 | 23 | import torch 24 | from transformers import PreTrainedTokenizer 25 | 26 | from ...protocol import DataProto 27 | from .config import RewardConfig 28 | 29 | 30 | class RewardScore(TypedDict): 31 | overall: float 32 | format: Optional[float] 33 | accuracy: Optional[float] 34 | 35 | 36 | SequentialRewardFunction = Callable[[str, str], RewardScore] 37 | 38 | BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]] 39 | 40 | 41 | class FunctionRewardManager(ABC): 42 | """Reward manager for rule-based reward.""" 43 | 44 | def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer): 45 | if config.reward_function is None: 46 | raise ValueError("Reward function is not provided.") 47 | 48 | if not os.path.exists(config.reward_function): 49 | raise FileNotFoundError(f"Reward function file {config.reward_function} not found.") 50 | 51 | spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function) 52 | module = importlib.util.module_from_spec(spec) 53 | try: 54 | sys.modules["custom_reward_fn"] = module 55 | spec.loader.exec_module(module) 56 | except Exception as e: 57 | raise RuntimeError(f"Failed to load reward function: {e}") 58 | 59 | if not hasattr(module, config.reward_function_name): 60 | raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.") 61 | 62 | reward_fn = getattr(module, config.reward_function_name) 63 | print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.") 64 | self.reward_fn = partial(reward_fn, **config.reward_function_kwargs) 65 | self.config = config 66 | self.tokenizer = tokenizer 67 | 68 | @abstractmethod 69 | def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: 70 | """Compute reward for a batch of data.""" 71 | ... 
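# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# A custom reward function plugged in via `RewardConfig.reward_function`
# (e.g. "./math.py:main") is expected to match the `SequentialRewardFunction`
# signature above: (response_str, ground_truth) -> RewardScore. The exact-match
# scoring below is a hypothetical example; the bundled
# examples/reward_function/math.py may implement different logic.
def _example_exact_match_reward(response: str, ground_truth: str) -> RewardScore:
    accuracy = 1.0 if response.strip() == ground_truth.strip() else 0.0
    return {"overall": accuracy, "format": None, "accuracy": accuracy}
# ---------------------------------------------------------------------------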
72 | 73 | 74 | class SequentialFunctionRewardManager(FunctionRewardManager): 75 | reward_fn: SequentialRewardFunction 76 | 77 | def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: 78 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) 79 | reward_metrics = defaultdict(list) 80 | response_ids = data.batch["responses"] 81 | response_length = data.batch["response_mask"].sum(dim=-1) 82 | for i in range(len(data)): 83 | valid_response_ids = response_ids[i][: response_length[i]] 84 | response_str = self.tokenizer.decode( 85 | valid_response_ids, skip_special_tokens=self.config.skip_special_tokens 86 | ) 87 | ground_truth = data.non_tensor_batch["ground_truth"][i] 88 | 89 | score = self.reward_fn(response_str, ground_truth) 90 | reward_tensor[i, response_length[i] - 1] = score["overall"] 91 | for key, value in score.items(): 92 | reward_metrics[key].append(value) 93 | 94 | return reward_tensor, reward_metrics 95 | 96 | 97 | class BatchFunctionRewardManager(FunctionRewardManager): 98 | reward_fn: BatchRewardFunction 99 | 100 | def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: 101 | response_str, ground_truth = [], [] 102 | response_ids = data.batch["responses"] 103 | response_length = data.batch["response_mask"].sum(dim=-1) 104 | for i in range(len(data)): 105 | valid_response_ids = response_ids[i][: response_length[i]] 106 | response_str.append( 107 | self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens) 108 | ) 109 | ground_truth.append(data.non_tensor_batch["ground_truth"][i]) 110 | 111 | scores = self.reward_fn(response_str, ground_truth) 112 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) 113 | reward_metrics = defaultdict(list) 114 | for i, score in enumerate(scores): 115 | reward_tensor[i, response_length[i] - 1] = score["overall"] 116 | for key, value in score.items(): 117 | reward_metrics[key].append(value) 118 | 119 | return reward_tensor, reward_metrics 120 | -------------------------------------------------------------------------------- /verl/workers/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .config import RolloutConfig 17 | from .vllm_rollout_spmd import vLLMRollout 18 | 19 | 20 | __all__ = ["RolloutConfig", "vLLMRollout"] 21 | -------------------------------------------------------------------------------- /verl/workers/rollout/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | 17 | from ...protocol import DataProto 18 | 19 | 20 | __all__ = ["BaseRollout"] 21 | 22 | 23 | class BaseRollout(ABC): 24 | @abstractmethod 25 | def generate_sequences(self, prompts: DataProto) -> DataProto: 26 | """Generate sequences""" 27 | pass 28 | -------------------------------------------------------------------------------- /verl/workers/rollout/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Rollout config 16 | """ 17 | 18 | from dataclasses import asdict, dataclass, field 19 | from typing import Any, Dict, Optional 20 | 21 | 22 | @dataclass 23 | class RolloutConfig: 24 | name: str = "vllm" 25 | n: int = 1 26 | temperature: float = 1.0 27 | top_p: float = 1.0 28 | top_k: int = -1 29 | seed: int = 1 30 | limit_images: int = 0 31 | dtype: str = "bf16" 32 | gpu_memory_utilization: float = 0.6 33 | ignore_eos: bool = False 34 | enforce_eager: bool = False 35 | enable_chunked_prefill: bool = False # only for v0 engine 36 | tensor_parallel_size: int = 2 37 | max_model_len: Optional[int] = None 38 | max_num_batched_tokens: int = 8192 39 | disable_log_stats: bool = True 40 | val_override_config: Dict[str, Any] = field(default_factory=dict) 41 | """auto keys""" 42 | prompt_length: int = field(default=-1, init=False) 43 | response_length: int = field(default=-1, init=False) 44 | trust_remote_code: bool = field(default=False, init=False) 45 | 46 | def to_dict(self): 47 | return asdict(self) 48 | -------------------------------------------------------------------------------- /verl/workers/rollout/vllm_rollout_spmd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | from contextlib import contextmanager 17 | from typing import Any, Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | import torch 21 | import torch.distributed 22 | from tensordict import TensorDict 23 | from transformers import PreTrainedTokenizer 24 | from vllm import LLM, RequestOutput, SamplingParams 25 | 26 | from ...protocol import DataProto 27 | from ...utils import torch_functional as VF 28 | from ...utils.tokenizer import get_processor 29 | from ...utils.torch_dtypes import PrecisionType 30 | from .base import BaseRollout 31 | from .config import RolloutConfig 32 | 33 | 34 | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]: 35 | if isinstance(value, torch.Tensor): 36 | return value.repeat_interleave(repeats, dim=0) 37 | else: 38 | return np.repeat(value, repeats, axis=0) 39 | 40 | 41 | def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]: 42 | processor = get_processor(model_path, trust_remote_code=trust_remote_code) 43 | if processor is not None and hasattr(processor, "image_token"): 44 | image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) 45 | return {image_token_id: -100} 46 | else: 47 | return None 48 | 49 | 50 | class vLLMRollout(BaseRollout): 51 | def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer): 52 | """A vLLM rollout worker. It requires the model to be supported by vLLM. 53 | 54 | Args: 55 | model_path: path to the huggingface model 56 | config: the rollout config (RolloutConfig) 57 | tokenizer: the task/model tokenizer 58 | """ 59 | super().__init__() 60 | self.rank = int(os.getenv("RANK", "0")) 61 | self.config = config 62 | self.pad_token_id = tokenizer.pad_token_id 63 | if config.tensor_parallel_size > torch.distributed.get_world_size(): 64 | raise ValueError("Tensor parallelism size should not be greater than world size.") 65 | 66 | if config.max_num_batched_tokens < config.prompt_length + config.response_length: 67 | raise ValueError("max_num_batched_tokens should be no less than prompt_length + response_length.") 68 | 69 | self.inference_engine = LLM( 70 | model=model_path, 71 | skip_tokenizer_init=False, 72 | trust_remote_code=config.trust_remote_code, 73 | load_format="dummy", 74 | dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)), 75 | seed=config.seed, 76 | max_model_len=config.max_model_len or config.prompt_length + config.response_length, 77 | distributed_executor_backend="external_launcher", 78 | tensor_parallel_size=config.tensor_parallel_size, 79 | gpu_memory_utilization=config.gpu_memory_utilization, 80 | max_num_batched_tokens=config.max_num_batched_tokens, 81 | disable_log_stats=config.disable_log_stats, 82 | enforce_eager=config.enforce_eager, 83 | disable_custom_all_reduce=True, 84 | limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None, 85 | disable_mm_preprocessor_cache=True, 86 | enable_chunked_prefill=config.enable_chunked_prefill, 87 | enable_sleep_mode=True, 88 | ) 89 | 90 | # Offload vllm model to reduce peak memory usage 91 | self.inference_engine.sleep(level=1) 92 | 93 | sampling_kwargs = { 94 | "max_tokens": config.response_length, 95 | "detokenize": False, 96 | "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code), 97 | } 98 | default_sampling_params = SamplingParams() 99 | for key in config.to_dict().keys(): 100 | if hasattr(default_sampling_params, key): 101 | sampling_kwargs[key] = 
getattr(config, key) 102 | 103 | print(f"Sampling params: {sampling_kwargs}.") 104 | self.sampling_params = SamplingParams(**sampling_kwargs) 105 | 106 | @contextmanager 107 | def update_sampling_params(self, **kwargs): 108 | # update sampling params 109 | old_sampling_params_args = {} 110 | if kwargs: 111 | for key, value in kwargs.items(): 112 | if hasattr(self.sampling_params, key): 113 | old_value = getattr(self.sampling_params, key) 114 | old_sampling_params_args[key] = old_value 115 | setattr(self.sampling_params, key, value) 116 | 117 | yield 118 | # roll back to previous sampling params 119 | for key, value in old_sampling_params_args.items(): 120 | setattr(self.sampling_params, key, value) 121 | 122 | @torch.no_grad() 123 | def generate_sequences(self, prompts: DataProto) -> DataProto: 124 | # left-padded attention_mask 125 | input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length) 126 | attention_mask: torch.Tensor = prompts.batch["attention_mask"] 127 | position_ids: torch.Tensor = prompts.batch["position_ids"] 128 | eos_token_id: int = prompts.meta_info["eos_token_id"] 129 | batch_size = input_ids.size(0) 130 | 131 | non_tensor_batch = prompts.non_tensor_batch 132 | if batch_size != len(non_tensor_batch["raw_prompt_ids"]): 133 | raise RuntimeError("vllm sharding manager is not working properly.") 134 | 135 | if "multi_modal_data" in non_tensor_batch: 136 | vllm_inputs = [] 137 | for raw_prompt_ids, multi_modal_data in zip( 138 | non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data") 139 | ): 140 | vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data}) 141 | else: 142 | vllm_inputs = [ 143 | {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") 144 | ] 145 | 146 | # users can customize different sampling_params for different runs 147 | with self.update_sampling_params(**prompts.meta_info): 148 | completions: List[RequestOutput] = self.inference_engine.generate( 149 | prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0) 150 | ) 151 | response_ids = [output.token_ids for completion in completions for output in completion.outputs] 152 | response_ids = VF.pad_2d_list_to_length( 153 | response_ids, self.pad_token_id, max_length=self.config.response_length 154 | ).to(input_ids.device) 155 | 156 | if self.sampling_params.n > 1: 157 | batch_size = batch_size * self.sampling_params.n 158 | input_ids = _repeat_interleave(input_ids, self.sampling_params.n) 159 | attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n) 160 | position_ids = _repeat_interleave(position_ids, self.sampling_params.n) 161 | 162 | sequence_ids = torch.cat([input_ids, response_ids], dim=-1) 163 | response_length = response_ids.size(1) 164 | delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) 165 | delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1) 166 | if position_ids.dim() == 3: # qwen2vl mrope 167 | delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) 168 | 169 | # prompt: left pad + response: right pad 170 | # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0] 171 | # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11] 172 | response_position_ids = position_ids[..., -1:] + delta_position_id 173 | position_ids = torch.cat([position_ids, response_position_ids], dim=-1) 174 | response_mask = VF.get_response_mask( 175 | 
response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype 176 | ) 177 | attention_mask = torch.cat((attention_mask, response_mask), dim=-1) 178 | 179 | # all the tp ranks should contain the same data here. data in all ranks are valid 180 | batch = TensorDict( 181 | { 182 | "prompts": input_ids, 183 | "responses": response_ids, 184 | "input_ids": sequence_ids, # here input_ids become the whole sentences 185 | "attention_mask": attention_mask, 186 | "response_mask": response_mask, 187 | "position_ids": position_ids, 188 | }, 189 | batch_size=batch_size, 190 | ) 191 | return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) 192 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .base import BaseShardingManager 17 | from .fsdp_ulysses import FSDPUlyssesShardingManager 18 | from .fsdp_vllm import FSDPVLLMShardingManager 19 | 20 | 21 | __all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"] 22 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Sharding manager to implement HybridEngine 16 | """ 17 | 18 | from ...protocol import DataProto 19 | 20 | 21 | class BaseShardingManager: 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, exc_type, exc_value, traceback): 26 | pass 27 | 28 | def preprocess_data(self, data: DataProto) -> DataProto: 29 | return data 30 | 31 | def postprocess_data(self, data: DataProto) -> DataProto: 32 | return data 33 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/fsdp_ulysses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contains a sharding manager that reshards data when using FSDP together with Ulysses sequence parallelism 16 | """ 17 | 18 | from torch.distributed.device_mesh import DeviceMesh 19 | 20 | from ...protocol import DataProto, all_gather_data_proto 21 | from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group 22 | from .base import BaseShardingManager 23 | 24 | 25 | class FSDPUlyssesShardingManager(BaseShardingManager): 26 | """ 27 | Sharding manager to support data resharding when using FSDP + Ulysses 28 | """ 29 | 30 | def __init__(self, device_mesh: DeviceMesh): 31 | super().__init__() 32 | self.device_mesh = device_mesh 33 | 34 | def __enter__(self): 35 | if self.device_mesh is not None: 36 | self.prev_sp_group = get_ulysses_sequence_parallel_group() 37 | set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group()) 38 | 39 | def __exit__(self, exc_type, exc_value, traceback): 40 | if self.device_mesh is not None: 41 | set_ulysses_sequence_parallel_group(self.prev_sp_group) 42 | 43 | def preprocess_data(self, data: DataProto) -> DataProto: 44 | """ 45 | AllGather data from the SP region. 46 | The data is first sharded along the FSDP dimension because we utilize DP_COMPUTE, 47 | so in Ulysses we need to make sure the same data is used across an SP group. 48 | """ 49 | if self.device_mesh is not None: 50 | sp_size = self.device_mesh["sp"].size() 51 | sp_group = self.device_mesh["sp"].get_group() 52 | all_gather_data_proto(data, size=sp_size, group=sp_group) 53 | 54 | return data 55 | 56 | def postprocess_data(self, data: DataProto) -> DataProto: 57 | """ 58 | Split the data to follow the FSDP partition 59 | """ 60 | if self.device_mesh is not None: 61 | sp_size = self.device_mesh["sp"].size() 62 | sp_rank = self.device_mesh["sp"].get_local_rank() 63 | data = data.chunk(chunks=sp_size)[sp_rank] 64 | 65 | return data 66 | -------------------------------------------------------------------------------- /verl/workers/sharding_manager/fsdp_vllm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | from typing import Dict, Iterable, Tuple, Union 17 | 18 | import torch 19 | import torch.distributed as dist 20 | from torch.distributed._tensor import DTensor 21 | from torch.distributed.checkpoint.state_dict import get_model_state_dict 22 | from torch.distributed.device_mesh import DeviceMesh 23 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP 24 | from vllm import LLM 25 | from vllm.distributed import parallel_state as vllm_ps 26 | 27 | from ...protocol import DataProto, all_gather_data_proto 28 | from ...utils.model_utils import print_gpu_memory_usage 29 | from .base import BaseShardingManager 30 | 31 | 32 | class FSDPVLLMShardingManager(BaseShardingManager): 33 | def __init__( 34 | self, 35 | module: FSDP, 36 | inference_engine: LLM, 37 | device_mesh: DeviceMesh, 38 | ): 39 | self.module = module 40 | self.inference_engine = inference_engine 41 | self.device_mesh = device_mesh 42 | 43 | self.world_size = dist.get_world_size() 44 | self.tp_size = vllm_ps.get_tensor_model_parallel_world_size() 45 | self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() 46 | self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group 47 | 48 | # Record freed bytes to estimate memory usage correctly 49 | # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 50 | self.freed_bytes = 0 51 | 52 | # Note that torch_random_states may be different on each dp rank 53 | self.torch_random_states = torch.cuda.get_rng_state() 54 | # get a dedicated rng state for generation 55 | gen_dp_rank = self.device_mesh["dp"].get_local_rank() 56 | torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states 57 | self.gen_random_states = torch.cuda.get_rng_state() 58 | torch.cuda.set_rng_state(self.torch_random_states) 59 | 60 | def _make_weight_iterator( 61 | self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]] 62 | ) -> Iterable[Tuple[str, torch.Tensor]]: 63 | for name, tensor in actor_weights.items(): 64 | yield name, tensor.full_tensor() if self.world_size != 1 else tensor 65 | 66 | def __enter__(self): 67 | # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and 68 | # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. 69 | # Out of vllm scope, we should avoid emptying the cache to let pytorch use its caching memory 70 | # to speed up memory allocations. 
71 | # 72 | # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management 73 | # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 74 | torch.cuda.empty_cache() 75 | print_gpu_memory_usage("Before state_dict() in sharding manager") 76 | actor_weights = get_model_state_dict(self.module) 77 | print_gpu_memory_usage("After state_dict() in sharding manager") 78 | 79 | if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: 80 | self.inference_engine.wake_up(tags=["weights"]) 81 | else: 82 | self.inference_engine.wake_up() 83 | 84 | model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model 85 | model.load_weights(self._make_weight_iterator(actor_weights)) 86 | print_gpu_memory_usage("After sync model weights in sharding manager") 87 | 88 | del actor_weights 89 | torch.cuda.empty_cache() 90 | 91 | if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: 92 | self.inference_engine.wake_up(tags=["kv_cache"]) 93 | 94 | print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager") 95 | # important: need to manually set the random states of each tp to be identical. 96 | if self.device_mesh is not None: 97 | self.torch_random_states = torch.cuda.get_rng_state() 98 | torch.cuda.set_rng_state(self.gen_random_states) 99 | 100 | def __exit__(self, exc_type, exc_value, traceback): 101 | print_gpu_memory_usage("Before vllm offload in sharding manager") 102 | free_bytes_before_sleep = torch.cuda.mem_get_info()[0] 103 | self.inference_engine.sleep(level=1) 104 | free_bytes_after_sleep = torch.cuda.mem_get_info()[0] 105 | self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep 106 | print_gpu_memory_usage("After vllm offload in sharding manager") 107 | 108 | self.module.train() 109 | torch.cuda.empty_cache() # add empty cache after each compute 110 | 111 | # restore random states 112 | if self.device_mesh is not None: 113 | self.gen_random_states = torch.cuda.get_rng_state() 114 | torch.cuda.set_rng_state(self.torch_random_states) 115 | 116 | def preprocess_data(self, data: DataProto) -> DataProto: 117 | """All gather across tp group to make each rank has identical input.""" 118 | all_gather_data_proto(data, size=self.tp_size, group=self.tp_group) 119 | return data 120 | 121 | def postprocess_data(self, data: DataProto) -> DataProto: 122 | """Get chunk data of this tp rank since we do all gather in preprocess.""" 123 | if self.tp_size > 1: 124 | data = data.chunk(chunks=self.tp_size)[self.tp_rank] 125 | 126 | return data 127 | --------------------------------------------------------------------------------
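Editor's note (illustrative only, not part of the repository): the sharding managers above are meant to wrap rollout generation as context managers, syncing actor weights into vLLM on __enter__ and putting vLLM back to sleep on __exit__, while preprocess_data/postprocess_data reshape the batch across the tensor-parallel group. A minimal sketch of the intended flow, assuming a worker already holds an FSDP `module`, a vLLM `inference_engine`, a `device_mesh`, a `rollout` (vLLMRollout), and a `batch` (DataProto); all of these names are assumptions for illustration:

    sharding_manager = FSDPVLLMShardingManager(module, inference_engine, device_mesh)
    with sharding_manager:
        batch = sharding_manager.preprocess_data(batch)    # all-gather across the tp group
        batch = rollout.generate_sequences(batch)          # vLLM generation
        batch = sharding_manager.postprocess_data(batch)   # keep this tp rank's chunk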