├── .flake8 ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── custom_mcp_tool.md ├── install.md └── license_header.txt ├── examples └── data_preprocess │ └── geo3k.py ├── mirorl ├── __init__.py ├── monitor │ ├── __init__.py │ └── wandb_alert.py ├── recipe │ ├── launch.sh │ └── mcp │ │ ├── config │ │ ├── mcp_trainer.yaml │ │ └── tool_config │ │ │ └── mcp_tool_config.yaml │ │ ├── main_ppo.py │ │ └── run_mirorl_14b_8xgpu.sh ├── tools │ ├── __init__.py │ ├── jina_scrape_llm_summary.py │ ├── mcp_tool.py │ ├── python_server.py │ └── serper_search.py ├── trainer │ ├── __init__.py │ └── ppo │ │ ├── __init__.py │ │ ├── core_algos.py │ │ ├── ray_trainer.py │ │ └── reward.py ├── utils │ ├── __init__.py │ ├── debug │ │ ├── __init__.py │ │ └── exception_helper.py │ ├── reward_score │ │ ├── __init__.py │ │ ├── llm_judge.py │ │ └── simpleqa.py │ └── tracking.py └── workers │ ├── __init__.py │ ├── fsdp_workers.py │ ├── reward_manager │ ├── __init__.py │ ├── tests │ │ ├── testData.jsonl │ │ ├── test_mcp_format.py │ │ └── test_others.py │ └── tool.py │ └── rollout │ ├── __init__.py │ ├── schemas.py │ └── sglang_rollout │ ├── __init__.py │ └── sglang_rollout.py ├── scripts ├── README_GRPO_Visualizer.md ├── converter_hf_to_mcore.py ├── grpo_visualizer.py ├── install.sh └── visualizer_requirements.txt └── tests ├── distributed ├── run_all.sh └── test_tensor_dict.py ├── kernels └── test_linear_cross_entropy.py ├── models ├── test_transformer.py └── test_transformers_ulysses.py ├── ray_gpu ├── detached_worker │ ├── README.md │ ├── client.py │ ├── run.sh │ └── server.py ├── test_colocated_workers.py ├── test_colocated_workers_fused.py ├── test_data_transfer.py ├── test_driverfunc_to_worker.py ├── test_high_level_scheduling_api.py ├── test_rvdz.py ├── test_worker_group_basics.py └── test_worker_group_torch.py ├── sanity ├── check_license.py ├── check_pr_title.py ├── test_config_docs.py └── test_import.py ├── test_protocol.py ├── trainer ├── __init__.py └── ppo │ ├── __init__.py │ ├── test_core_algos.py │ └── test_metric_utils.py ├── utils ├── cpu_tests │ ├── _test_module.py │ ├── test_fs.py │ ├── test_import_utils.py │ ├── test_model.py │ └── test_timeout_decorator.py └── gpu_tests │ ├── checkpoint │ └── test_fsdp_ckpt.py │ ├── dataset │ ├── test_multiturn_sft_dataset.py │ ├── test_rl_dataset.py │ ├── test_rm_dataset.py │ └── test_sft_dataset.py │ ├── megatron │ └── test_pipeline_parallel.py │ ├── test_activation_offload.py │ ├── test_flops_counter.py │ ├── test_seqlen_balancing.py │ └── test_torch_functional.py └── workers ├── reward_manager └── test_registry.py └── rollout ├── resource └── tool_configs │ ├── sandbox_fusion_tool_config │ └── search_tool_config ├── test_sglang_async_rollout_search_tools.py ├── test_sglang_async_rollout_sf_tools.py ├── test_sglang_async_rollout_w_tools.py ├── test_sglang_spmd.py └── utils_sglang.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Suggested config from pytorch that we can adapt 3 | select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2 4 | max-line-length = 120 5 | # C408 ignored because we like the dict keyword argument syntax 6 | # E501 is not flexible enough, we're using B950 instead 7 | # N812 ignored because import torch.nn.functional as F is PyTorch convention 8 | # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP) 9 | # E731 allow usage of assigning lambda expressions 10 | ignore = 11 | 
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731 12 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 13 | # to line this up with executable bit 14 | EXE001, 15 | # these ignores are from flake8-bugbear; please fix! 16 | B007,B008, 17 | optional-ascii-coding = True 18 | exclude = 19 | ./.git, 20 | ./.github, 21 | ./docs 22 | ./build 23 | ./scripts, 24 | ./venv, 25 | *.pyi 26 | .pre-commit-config.yaml 27 | *.md 28 | .flake8 29 | tests/torchtune/models/llama2/scripts/*.py -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | **/*.pt 3 | **/checkpoints 4 | **/wget-log 5 | **/_build/ 6 | **/*.ckpt 7 | **/outputs 8 | **/*.tar.gz 9 | **/playground 10 | **/wandb 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | dataset/* 17 | tensorflow/my_graph/* 18 | .idea/ 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | tmp/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *,cover 60 | .hypothesis/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # IPython Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # dotenv 93 | .env 94 | 95 | # virtualenv 96 | venv/ 97 | .venv/ 98 | ENV/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # vscode 107 | .vscode 108 | 109 | # Mac 110 | .DS_Store 111 | 112 | # vim 113 | *.swp 114 | 115 | # ckpt 116 | *.lock 117 | 118 | # data 119 | *.parquet 120 | 121 | 122 | # local logs 123 | logs 124 | log 125 | outputs 126 | .history 127 | 128 | train_rollout_data 129 | val_rollout_data -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "verl"] 2 | path = verl 3 | url = https://github.com/volcengine/verl.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'build,verl' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/astral-sh/ruff-pre-commit 8 | rev: "v0.11.4" 9 | hooks: 10 | - id: ruff 11 | args: ["--fix", "--show-fixes", "--output-format=full"] 12 | exclude: ^.*\.(ipynb)$ 13 | - id: ruff-format 14 | 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 
| rev: v5.0.0 17 | hooks: 18 | - id: trailing-whitespace 19 | - id: check-ast 20 | - id: check-merge-conflict 21 | - id: no-commit-to-branch 22 | args: ['--branch=main'] 23 | - id: check-added-large-files 24 | args: ['--maxkb=1000'] 25 | - id: end-of-file-fixer 26 | exclude: '^(.*\.svg)$' 27 | 28 | - repo: https://github.com/Lucas-C/pre-commit-hooks 29 | rev: v1.5.5 30 | hooks: 31 | - id: insert-license 32 | files: \.py$|\.sh$ 33 | args: 34 | - --license-filepath 35 | - docs/license_header.txt 36 | 37 | - repo: https://github.com/pycqa/flake8 38 | rev: 7.1.1 39 | hooks: 40 | - id: flake8 41 | additional_dependencies: 42 | - flake8-bugbear == 22.4.25 43 | - pep8-naming == 0.12.1 44 | - torchfix 45 | args: ['--config=.flake8'] 46 | 47 | - repo: https://github.com/omnilib/ufmt 48 | rev: v2.8.0 49 | hooks: 50 | - id: ufmt 51 | additional_dependencies: 52 | - black == 22.12.0 53 | - usort == 1.0.5 54 | 55 | - repo: https://github.com/jsh9/pydoclint 56 | rev: 0.5.12 57 | hooks: 58 | - id: pydoclint 59 | args: [--config=pyproject.toml] 60 | -------------------------------------------------------------------------------- /docs/custom_mcp_tool.md: -------------------------------------------------------------------------------- 1 | ## Define Your Custom MCP Tool 2 | 3 | MiroRL supports two main approaches for adding custom MCP tools: 4 | 5 | ### 1. Using npx/uvx (Node.js-based MCP tools) 6 | 7 | For Node.js-based MCP tools that can be installed via npm and executed with `npx` or `uvx`, add your tool configuration file: 8 | 9 | ```yaml 10 | tools: 11 | - class_name: "mirorl.tools.mcp_tool.MCPTool" 12 | config: 13 | command: "npx" # or "uvx" 14 | args: ["your-mcp-tool-name"] 15 | env: ["YOUR_API_KEY", "HTTPS_PROXY"] # Required environment variables 16 | server_name: "your_tool_name" 17 | max_retries: 5 18 | delay_between_retries: 1 19 | connection_timeout: 5 20 | execution_timeout: 60 21 | tool_schema: 22 | type: "mcp" 23 | function: {} 24 | ``` 25 | 26 | ### 2. Using Python-based MCP tools 27 | 28 | For Python-based MCP tools using the FastMCP framework, add your tool configuration file: 29 | 30 | ```yaml 31 | tools: 32 | - class_name: "mirorl.tools.mcp_tool.MCPTool" 33 | config: 34 | command: "python" 35 | args: ["path/to/your_tool.py"] 36 | env: ["YOUR_API_KEY", "HTTPS_PROXY"] # Required environment variables 37 | server_name: "your_tool_name" 38 | max_retries: 5 39 | delay_between_retries: 1 40 | connection_timeout: 5 41 | execution_timeout: 60 42 | tool_schema: 43 | type: "mcp" 44 | function: {} 45 | ``` 46 | 47 | ### 3. Configuration Parameters 48 | 49 | All MCP tools support the following configuration parameters: 50 | 51 | | Parameter | Description | Default | Example | 52 | |-----------|-------------|---------|---------| 53 | | `command` | MCP server command to execute | - | `"npx"`, `"python"` | 54 | | `args` | Command line arguments | `[]` | `["your-tool"]`, `["tool.py"]` | 55 | | `env` | Environment variables | `[]` | `["API_KEY", "PROXY"]` | 56 | | `server_name` | MCP server identifier | - | `"your_tool_name"` | 57 | | `blacklist` | Forbidden operations | `[]` | `["scrape"]` | 58 | | `max_retries` | Maximum retry attempts | `5` | `3` | 59 | | `delay_between_retries` | Retry delay in seconds | `1` | `2` | 60 | | `connection_timeout` | Connection timeout in seconds | `5` | `10` | 61 | | `execution_timeout` | Execution timeout in seconds | `60` | `120` | 62 | 63 | ### 4. Integration with Training 64 | 65 | To use your custom MCP tool in training: 66 | 67 | 1. 
**Create your tool configuration file** in `mirorl/recipe/mcp/config/tool_config/`
68 | 2. **Set environment variables** required by your tool
69 | 3. **Update training command** to use your tool config:
70 |
71 | ```bash
72 | python3 -m mirorl.recipe.mcp.main_ppo \
73 |     --config-path="mirorl/recipe/mcp/config" \
74 |     --config-name='mcp_trainer' \
75 |     actor_rollout_ref.rollout.multi_turn.tool_config_path="path/to/your_tool_config.yaml" \
76 |     data.train_batch_size=256 \
77 |     trainer.n_gpus_per_node=8
78 | ```
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | ### Requirements
4 |
5 | First, we recommend using conda to manage the environment:
6 |
7 | ```bash
8 | conda create -n mirorl python==3.10.4
9 | conda activate mirorl
10 | ```
11 |
12 | To install the prerequisites (CUDA, cuDNN, Apex), we recommend following the [verl prerequisites guide](https://verl.readthedocs.io/en/latest/start/install.html#pre-requisites), which provides detailed instructions for:
13 |
14 | - CUDA: Version >= 12.4
15 | - cuDNN: Version >= 9.8.0
16 | - Apex
17 |
18 | ### Install Dependencies
19 |
20 | #### Install python dependencies
21 |
22 | To install the Python dependencies (sglang, flash-attn, flashinfer), run the `install.sh` script provided in `scripts/`:
23 |
24 | ```bash
25 | # Make sure you have activated the mirorl conda env
26 | # only for FSDP backend and sglang engine
27 | bash scripts/install.sh
28 | ```
29 |
30 | If you encounter errors in this step, open the script and run its steps manually.
31 |
32 | #### Install Node.js for MCP Support
33 |
34 | MCP (Model Context Protocol) requires Node.js to run MCP servers. Node.js version 18+ is recommended for optimal compatibility with MCP tools.
35 |
36 | ```bash
37 | # Download Node.js binary (example for Linux x64)
38 | wget https://nodejs.org/dist/v24.2.0/node-v24.2.0-linux-x64.tar.xz
39 |
40 | # Extract to your target path
41 | tar -xf node-v24.2.0-linux-x64.tar.xz -C /your/target/path
42 |
43 | # Add to PATH
44 | export NODEJS_HOME=/your/target/path/node-v24.2.0-linux-x64
45 | export PATH=$NODEJS_HOME/bin:$PATH
46 | export NODE_SHARED=/your/target/path/node-shared/node_modules
47 | export PATH=$NODE_SHARED/.bin:$PATH
48 |
49 | # Verify installation
50 | node --version
51 | npm --version
52 |
53 | # Install serper mcp server
54 | mkdir -p /your/target/path/node-shared
55 | cd /your/target/path/node-shared
56 | npm init -y
57 | npm install serper-search-scrape-mcp-server
58 | ```
59 |
60 | ### Install verl
61 |
62 | To install the required version of verl, clone the MiroRL repository recursively and install the bundled verl submodule from source:
63 |
64 | ```bash
65 | git clone https://github.com/MiroMindAsia/mirorl.git --recursive
66 | cd mirorl/verl
67 | pip install --no-deps -v .
68 | ```
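### Optional: Verify the MCP Setup

The following snippet is not a file shipped with MiroRL; it is a minimal, illustrative sketch for checking that an MCP server can be launched and queried over stdio before starting training. It assumes you run it from the repository root, and it reuses the same `mcp` client API (`StdioServerParameters`, `stdio_client`, `ClientSession`) that `mirorl/tools/mcp_tool.py` uses; the file name `check_mcp.py` is hypothetical.

```python
# check_mcp.py -- hypothetical sanity check, not part of the MiroRL repository
import asyncio
import os

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def main():
    # Launch the bundled Serper search server over stdio, mirroring the entry in
    # mirorl/recipe/mcp/config/tool_config/mcp_tool_config.yaml.
    params = StdioServerParameters(
        command="python",
        args=["mirorl/tools/serper_search.py"],
        env=dict(os.environ),  # pass through SERPER_API_KEY / proxy settings
    )
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # List the tools the server exposes; no search request is issued.
            tools = await session.list_tools()
            for tool in tools.tools:
                print(f"{tool.name}: {tool.description}")


if __name__ == "__main__":
    asyncio.run(main())
```

If the server starts correctly, this prints the `google_search` tool exposed by `serper_search.py`. The same check applies to Node.js-based servers by swapping in `command="npx"` and the server's package name.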
--------------------------------------------------------------------------------
/docs/license_header.txt:
--------------------------------------------------------------------------------
1 | Copyright 2025 MiroMind Team
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/examples/data_preprocess/geo3k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Preprocess the Geometry3k dataset to parquet format
16 | """
17 |
18 | import argparse
19 | import os
20 |
21 | import datasets
22 |
23 | from verl.utils.hdfs_io import copy, makedirs
24 |
25 | if __name__ == "__main__":
26 |     parser = argparse.ArgumentParser()
27 |     parser.add_argument("--local_dir", default="~/data/geo3k")
28 |     parser.add_argument("--hdfs_dir", default=None)
29 |
30 |     args = parser.parse_args()
31 |
32 |     data_source = "hiyouga/geometry3k"
33 |
34 |     dataset = datasets.load_dataset(data_source)
35 |
36 |     train_dataset = dataset["train"]
37 |     test_dataset = dataset["test"]
38 |
39 |     instruction_following = (
40 |         r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. "
41 |         r"The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."
42 | ) 43 | 44 | # add a row to each data item that represents a unique id 45 | def make_map_fn(split): 46 | def process_fn(example, idx): 47 | problem = example.pop("problem") 48 | prompt = problem + " " + instruction_following 49 | answer = example.pop("answer") 50 | images = example.pop("images") 51 | 52 | data = { 53 | "data_source": data_source, 54 | "prompt": [ 55 | { 56 | "role": "user", 57 | "content": prompt, 58 | } 59 | ], 60 | "images": images, 61 | "ability": "math", 62 | "reward_model": {"style": "rule", "ground_truth": answer}, 63 | "extra_info": { 64 | "split": split, 65 | "index": idx, 66 | "answer": answer, 67 | "question": problem, 68 | }, 69 | } 70 | return data 71 | 72 | return process_fn 73 | 74 | train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8) 75 | test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8) 76 | 77 | local_dir = args.local_dir 78 | hdfs_dir = args.hdfs_dir 79 | 80 | train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) 81 | test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) 82 | 83 | if hdfs_dir is not None: 84 | makedirs(hdfs_dir) 85 | copy(src=local_dir, dst=hdfs_dir) 86 | -------------------------------------------------------------------------------- /mirorl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mirorl/monitor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .wandb_alert import get_or_init_wandb, HangChecker, MCPToolChecker 16 | 17 | __all__ = ["get_or_init_wandb", "MCPToolChecker", "HangChecker"] 18 | -------------------------------------------------------------------------------- /mirorl/recipe/launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 MiroMind Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # set -ex 17 | 18 | export JOB_NAME=mirorl-odr-mcp 19 | 20 | # Setup Node.js for MCP 21 | export NODEJS_HOME=/your/path/to/nodejs 22 | export PATH=$NODEJS_HOME/bin:$PATH 23 | export NODE_SHARED=$NODEJS_HOME/node-shared/node_modules 24 | export PATH=$NODE_SHARED/.bin:$PATH 25 | 26 | # Setup proxies 27 | export HTTP_PROXY=xxx 28 | export HTTPS_PROXY=xxx 29 | export NO_PROXY=localhost,127.0.0.1 30 | 31 | # Check for singlenode flag 32 | SCRIPT=$1 33 | SINGLENODE=${SINGLENODE:-false} 34 | 35 | export LOG_DIR=$(pwd)/outputs/$MLP_TASK_ID/logs 36 | mkdir -p $LOG_DIR 37 | 38 | if [ "$SINGLENODE" == "true" ]; then 39 | bash $SCRIPT 2>&1 | tee $LOG_DIR/main_ppo.log 40 | exit 0 41 | fi 42 | 43 | #==============================================================================# 44 | 45 | export NCCL_IB_TIMEOUT=80 46 | export NCCL_IB_RETRY_CNT=20 47 | export NCCL_IB_AR_THRESHOLD=0 48 | 49 | export MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 50 | export MASTER_PORT=${MASTER_PORT:-10086} 51 | export NNODES=${NNODES:-1} 52 | export NODE_RANK=${NODE_RANK:-0} 53 | export GPUS_PER_NODE=${GPUS_PER_NODE:-8} 54 | 55 | # Compute total world size (number of processes) 56 | export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) 57 | 58 | echo "NNODES: $NNODES" 59 | echo "NODE_RANK: $NODE_RANK" 60 | echo "GPUS_PER_NODE: $GPUS_PER_NODE" 61 | echo "MASTER_PORT: $MASTER_PORT" 62 | echo "MASTER_ADDR: $MASTER_ADDR" 63 | echo "WORLD_SIZE: $WORLD_SIZE" 64 | echo "HTTP_PROXY: $HTTP_PROXY, HTTPS_PROXY: $HTTPS_PROXY" 65 | 66 | export NCCL_P2P_LEVEL=NVL 67 | export PYTHONPATH=$PWD:$PYTHONPATH 68 | 69 | export LOG_DIR=$(pwd)/outputs/$JOB_NAME/logs 70 | mkdir -p $LOG_DIR 71 | 72 | # num_nodes has to be at least 1 73 | if [ $NNODES -lt 1 ]; then 74 | echo "Number of nodes must be at least 1" 75 | exit 1 76 | fi 77 | 78 | # if HOST contains "master", then this is the head node 79 | if [[ $NODE_RANK -eq 0 ]]; then 80 | node_role="master" 81 | else 82 | node_role="worker" 83 | fi 84 | head_node_ip=${MASTER_ADDR:-127.0.0.1} 85 | 86 | wait_time=30 87 | if [ "$node_role" == "master" ]; then 88 | echo "Starting Ray head node..." 89 | # Start Ray on this node as the head node and extract its address 90 | # The `ray start --head` command outputs information that includes the address, 91 | # but here we're assuming it's known or statically assigned for simplicity. 92 | ray start --head --node-ip-address=$head_node_ip --include-dashboard=True --dashboard-host $head_node_ip --port=6379 --min-worker-port 15000 --max-worker-port 19999 93 | sleep $wait_time 94 | elif [ "$node_role" == "worker" ]; then 95 | sleep $wait_time 96 | attempt=1 97 | echo "Starting Ray worker node and attempting to connect to the head node at $head_node_ip:6379" 98 | while true; do 99 | # Attempt to start Ray and connect to the head node 100 | ray start --address="$head_node_ip:6379" --min-worker-port 15000 --max-worker-port 19999 && break || { 101 | if [ $attempt -le 5 ]; then 102 | echo "Ray worker start attempt $attempt failed. Retrying in $wait_time seconds..." 
103 | ((attempt++)) 104 | sleep $wait_time 105 | else 106 | echo "Failed to connect to the head node after $wait_time attempts. Exiting." 107 | exit 1 108 | fi 109 | } 110 | done 111 | fi 112 | 113 | # run the training script once Ray has been started on all nodes 114 | sleep $wait_time 115 | if [ "$node_role" == "master" ]; then 116 | num_active_ray_nodes=$(ray list nodes | grep ALIVE | wc -l) 117 | echo "Number of active Ray nodes: $num_active_ray_nodes" 118 | if [ $num_active_ray_nodes -lt $NNODES ]; then 119 | echo "Waiting for all Ray nodes to start..." 120 | attempt=1 121 | while true; do 122 | num_active_ray_nodes=$(ray list nodes | grep ALIVE | wc -l) 123 | if [ $num_active_ray_nodes -eq $NNODES ]; then 124 | break 125 | elif [ $attempt -le 5 ]; then 126 | echo "python command attempt $attempt failed. Retrying in $wait_time seconds..." 127 | ((attempt++)) 128 | sleep $wait_time 129 | else 130 | echo "Failed to connect to the head node after $wait_time attempts. Exiting." 131 | exit 1 132 | fi 133 | done 134 | fi 135 | echo "End starting" 136 | bash ${SCRIPT} 2>&1 | tee $LOG_DIR/main_ppo.log 137 | else 138 | echo "End starting" 139 | # Continuously check the health of the Ray cluster by pinging the head node. 140 | # If the health check fails, break the loop and proceed. 141 | while true; do 142 | ray health-check --address $head_node_ip:6379 &>/dev/null 143 | if [ $? -ne 0 ]; then 144 | break 145 | fi 146 | sleep 60 147 | done 148 | fi 149 | -------------------------------------------------------------------------------- /mirorl/recipe/mcp/config/mcp_trainer.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | searchpath: 3 | - file://verl/verl/trainer/config 4 | 5 | defaults: 6 | - ppo_trainer 7 | - _self_ 8 | 9 | custom_reward_function: 10 | name: browsecomp_compute_score # available scorers: gaia_naive_compute_score, simpleqa_compute_score, browsecomp_compute_score, gpt41_compute_score, cascade_compute_score 11 | path: "./mirorl/utils/reward_score/llm_judge.py" 12 | reward_model: 13 | reward_manager: tool 14 | reward_kwargs: 15 | accuracy_reward_weight: 0.9 16 | tool_format_reward_weight: 0.1 17 | gate_tool_format_reward: False # If set to True, the tool format reward will not be given unless the final answer is correct, even if the tool format is correct. If set to False, the tool format reward will be given as long as the tool format is correct, regardless of whether the final answer is correct. 18 | async_process: True 19 | batch_size: 300 20 | max_retry: 3 21 | 22 | data: 23 | max_prompt_length: 1024 24 | max_response_length: 1024 25 | train_batch_size: 256 26 | return_raw_chat: True 27 | 28 | actor_rollout_ref: 29 | hybrid_engine: True 30 | rollout: 31 | name: sglang 32 | multi_turn: 33 | enable: True 34 | max_turns: 20 35 | 36 | # ... 
style tool call for multi-turn rollout 37 | use_mcp_tool_call: True 38 | 39 | # tool response cut off length(general web page should have a info density of 4 chars per token) 40 | tool_response_cut_off_length: 20000 41 | 42 | # this means both self.input_ids and self.messages will not contain the think text for intermediate rounds 43 | keep_think_text_for_last_round_only: False 44 | 45 | # think block close tag, this works for both Qwen2.5/3 and DeepSeek-R1 46 | think_block_close_tag: "" 47 | 48 | # this flag is used to ignore loss calculation for failed rollouts 49 | ignore_failed_rollouts_in_loss: False 50 | 51 | trainer: 52 | # Monitor config 53 | # Whether to enable wandb alert 54 | enable_wandb_alert: False 55 | 56 | # Tool check alert threshold 57 | # current step's tool calling error count is x times of the previous step's tool calling error count 58 | # then will trigger alert 59 | tool_check_alert_threshold: 10 60 | 61 | # Training hang check interval (in seconds) 62 | # if global step is not updated for hang_check_interval seconds, will trigger alert 63 | hang_check_interval: 3600 64 | -------------------------------------------------------------------------------- /mirorl/recipe/mcp/config/tool_config/mcp_tool_config.yaml: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: "mirorl.tools.mcp_tool.MCPTool" 3 | config: 4 | command: "python" 5 | args: ["mirorl/tools/serper_search.py"] 6 | env: ["SERPER_API_KEY", "HTTPS_PROXY"] 7 | server_name: "search_and_scrape_webpage" 8 | tool_schema: 9 | type: "mcp" 10 | function: {} 11 | - class_name: "mirorl.tools.mcp_tool.MCPTool" 12 | config: 13 | command: "python" 14 | args: ["mirorl/tools/jina_scrape_llm_summary.py"] 15 | env: ["JINA_API_KEY", "HTTPS_PROXY", "NO_PROXY", "SUMMARY_LLM_URL", "SUMMARY_LLM_NAME"] 16 | server_name: "jina_scrape_llm_summary" 17 | execution_timeout: 600 18 | tool_schema: 19 | type: "mcp" 20 | function: {} 21 | -------------------------------------------------------------------------------- /mirorl/recipe/mcp/run_mirorl_14b_8xgpu.sh: -------------------------------------------------------------------------------- 1 | # run on 8xH100 2 | # make sure your current working directory is the root of the project 3 | # Training with Musique + GAIA text-103 for train/test 4 | 5 | set -x 6 | 7 | ulimit -n 65535 8 | 9 | PROJECT_DIR="$(pwd)" 10 | CONFIG_PATH="$PROJECT_DIR/mirorl/recipe/mcp/config" 11 | TOOL_CONFIG_PATH="$PROJECT_DIR/mirorl/recipe/mcp/config/tool_config/mcp_tool_config.yaml" 12 | MODEL_PATH=/your/model/path 13 | DATA_HOME=/your/data/path 14 | EXPERIMENT_NAME="mirorl-14b_genqa_64k" 15 | 16 | 17 | python -m mirorl.recipe.mcp.main_ppo \ 18 | --config-path="$CONFIG_PATH" \ 19 | --config-name='mcp_trainer' \ 20 | algorithm.adv_estimator=grpo \ 21 | data.train_batch_size=64 \ 22 | data.max_prompt_length=3072 \ 23 | data.max_response_length=62464 \ 24 | data.filter_overlong_prompts=True \ 25 | data.truncation='error' \ 26 | data.return_raw_chat=True \ 27 | actor_rollout_ref.model.path=$MODEL_PATH \ 28 | actor_rollout_ref.actor.optim.lr=2e-5 \ 29 | actor_rollout_ref.model.use_remove_padding=True \ 30 | actor_rollout_ref.model.use_liger=True \ 31 | actor_rollout_ref.model.use_fused_kernels=True \ 32 | actor_rollout_ref.actor.ppo_mini_batch_size=64 \ 33 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ 34 | actor_rollout_ref.actor.use_kl_loss=False \ 35 | actor_rollout_ref.actor.kl_loss_coef=0.001 \ 36 | actor_rollout_ref.actor.kl_loss_type=low_var_kl \ 
37 | actor_rollout_ref.actor.entropy_coeff=0 \ 38 | actor_rollout_ref.actor.use_dynamic_bsz=True \ 39 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \ 40 | actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ 41 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 42 | actor_rollout_ref.actor.fsdp_config.param_offload=False \ 43 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 44 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ 45 | actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ 46 | actor_rollout_ref.rollout.name=sglang \ 47 | actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ 48 | actor_rollout_ref.rollout.n=8 \ 49 | actor_rollout_ref.rollout.multi_turn.enable_tokenization_sanity_check=False \ 50 | actor_rollout_ref.rollout.val_kwargs.do_sample=True \ 51 | actor_rollout_ref.rollout.val_kwargs.temperature=0.3 \ 52 | actor_rollout_ref.rollout.multi_turn.tool_config_path=$TOOL_CONFIG_PATH \ 53 | actor_rollout_ref.rollout.multi_turn.max_turns=40 \ 54 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ 55 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 56 | algorithm.use_kl_in_reward=False \ 57 | trainer.critic_warmup=0 \ 58 | trainer.logger=['console','wandb'] \ 59 | trainer.project_name='odr-mcp' \ 60 | trainer.experiment_name=$EXPERIMENT_NAME \ 61 | trainer.n_gpus_per_node=8 \ 62 | trainer.nnodes=1 \ 63 | trainer.save_freq=30 \ 64 | trainer.test_freq=5 \ 65 | trainer.val_before_train=True \ 66 | trainer.rollout_data_dir=$PROJECT_DIR/train_rollout_data/$EXPERIMENT_NAME \ 67 | trainer.validation_data_dir=$PROJECT_DIR/val_rollout_data/$EXPERIMENT_NAME \ 68 | data.train_files=$DATA_HOME/genqa/train.parquet \ 69 | data.val_files=$DATA_HOME/genqa/val.parquet \ 70 | trainer.total_epochs=15 $@ 71 | -------------------------------------------------------------------------------- /mirorl/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mirorl/tools/mcp_tool.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import asyncio 16 | import json 17 | import logging 18 | import os 19 | from datetime import timedelta 20 | from typing import Any, Optional, Tuple 21 | from uuid import uuid4 22 | 23 | import exceptiongroup 24 | 25 | import ray 26 | from mcp import ClientSession, StdioServerParameters 27 | from mcp.client.stdio import stdio_client 28 | 29 | from verl.tools.base_tool import BaseTool 30 | from verl.tools.schemas import OpenAIFunctionToolSchema 31 | 32 | from mirorl.utils.debug.exception_helper import extract_exception_details 33 | 34 | logger = logging.getLogger(__name__) 35 | logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) 36 | 37 | 38 | def is_timeout_error(error: Exception) -> bool: 39 | """Check if an error is a timeout-related error.""" 40 | error_str = str(error) 41 | return any( 42 | keyword in error_str 43 | for keyword in ["ETIMEDOUT", "ECONNRESET", "Timeout", "Timed out"] 44 | ) 45 | 46 | 47 | class MCPTool(BaseTool): 48 | """A tool for calling MCP tools. 49 | 50 | - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. 51 | - `create`: create a tool instance for a trajectory. 52 | - `execute`: execute the tool. 53 | - `calc_reward`: calculate the reward respect to tool state. 54 | - `release`: release the tool instance. 55 | """ 56 | 57 | def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): 58 | super().__init__(config, tool_schema) 59 | self._instance_dict = {} 60 | 61 | self.max_retries = config.get("max_retries", 5) 62 | self.delay_between_retries = config.get( 63 | "delay_between_retries", 1 64 | ) # in seconds 65 | self.connection_timeout = config.get("connection_timeout", 5) 66 | self.execution_timeout = config.get("execution_timeout", 60) 67 | 68 | # Get command, args, and env from config 69 | self.params = StdioServerParameters( 70 | command=config["command"], 71 | args=config.get("args", []), 72 | env={e: os.environ.get(e) for e in config["env"]} 73 | if "env" in config.keys() 74 | else {}, 75 | ) 76 | 77 | def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: 78 | return self.tool_schema 79 | 80 | async def create( 81 | self, 82 | instance_id: Optional[str] = None, 83 | ground_truth: Optional[str] = None, 84 | **kwargs, 85 | ) -> str: 86 | if instance_id is None: 87 | instance_id = str(uuid4()) 88 | # TODO: Add all create_kwargs to dict 89 | self._instance_dict[instance_id] = { 90 | "response": "", 91 | "ground_truth": ground_truth, 92 | "reward": 0.0, 93 | } 94 | return instance_id 95 | 96 | async def _execute(self, parameters: dict[str, Any]): 97 | response = "" 98 | 99 | async with stdio_client(self.params) as (read, write): 100 | async with ClientSession( 101 | read, 102 | write, 103 | read_timeout_seconds=timedelta(seconds=self.connection_timeout), 104 | ) as session: 105 | for attempt in range(self.max_retries): 106 | try: 107 | await session.initialize() 108 | 109 | result = await session.call_tool( 110 | self.name, 111 | arguments=parameters, 112 | read_timeout_seconds=timedelta( 113 | seconds=self.execution_timeout 114 | ), 115 | ) 116 | response = result.content[0].text if result.content else "" 117 | if attempt > 0: 118 | logger.error(f"Attempt {attempt + 1} success") 119 | break # Exit loop if successful 120 | except Exception as e: 121 | # The error type is McpError, consistently having an error code of -32603. 122 | # To determine if the failed connection is network-related, we check the message. 
123 |                         if is_timeout_error(e) and attempt < self.max_retries - 1:
124 |                             logger.error(f"Attempt {attempt + 1} failed: {e}")
125 |                             await asyncio.sleep(self.delay_between_retries)
126 |                         else:
127 |                             response = f"Tool execution failed: {e}"
128 |
129 |                             try:
130 |                                 mcp_tool_checker = ray.get_actor(
131 |                                     name="mcp_tool_checker", namespace="monitor"
132 |                                 )
133 |                                 ray.get(
134 |                                     mcp_tool_checker.record_error.remote(
135 |                                         str(e), self.name
136 |                                     )
137 |                                 )
138 |                             except Exception as e:
139 |                                 print(
140 |                                     f"MCP tool checker not found, skip recording error: {e}"
141 |                                 )
142 |
143 |                             break  # Exit the retry loop if the error is not timeout-related
144 |
145 |         return response
146 |
147 |     async def execute(
148 |         self, instance_id: str, parameters: dict[str, Any], **kwargs
149 |     ) -> Tuple[str, float, dict]:
150 |         # Call MCP tool with retry logic
151 |         response = ""
152 |
153 |         for attempt in range(self.max_retries):
154 |             try:
155 |                 response = await self._execute(parameters)
156 |                 break
157 |             except Exception as e:
158 |                 if isinstance(e, exceptiongroup.ExceptionGroup):
159 |                     details = extract_exception_details(e)
160 |                     logger.error(
161 |                         f"MCPTool.execute attempt {attempt + 1} failed: {e}\n"
162 |                         f"exception group details: {json.dumps(details, indent=4)}"
163 |                     )
164 |                 else:
165 |                     logger.error(f"MCPTool.execute attempt {attempt + 1} failed: {e}")
166 |
167 |                 if attempt == self.max_retries - 1:  # range() is zero-based, so this is the last attempt
168 |                     logger.error(
169 |                         f"MCPTool.execute failed after {self.max_retries} attempts, returning empty response"
170 |                     )
171 |
172 |         self._instance_dict[instance_id]["response"] = response
173 |
174 |         # NOTE: tool_reward is not used anywhere
175 |         return response, 0.0, {}
176 |
177 |     async def calc_reward(self, instance_id: str, **kwargs) -> float:
178 |         # NOTE: tool_reward is not used anywhere
179 |         return self._instance_dict[instance_id]["reward"]
180 |
181 |     async def release(self, instance_id: str, **kwargs) -> None:
182 |         if instance_id in self._instance_dict:
183 |             del self._instance_dict[instance_id]
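For reference, `python_server.py` below is an e2b-backed Python-interpreter MCP server, but the repository's `mcp_tool_config.yaml` does not register it. Following the pattern described in `docs/custom_mcp_tool.md`, a hypothetical configuration entry for it could look like the sketch below; the `server_name` mirrors the `FastMCP("e2b-python-interpreter")` name in the file, while the `execution_timeout` value is an assumption rather than a project default.

```yaml
# Illustrative entry (not shipped with MiroRL) for registering python_server.py as an MCP tool
tools:
  - class_name: "mirorl.tools.mcp_tool.MCPTool"
    config:
      command: "python"
      args: ["mirorl/tools/python_server.py"]
      env: ["E2B_API_KEY"]                    # required by the e2b sandbox
      server_name: "e2b-python-interpreter"   # matches FastMCP("e2b-python-interpreter")
      execution_timeout: 600                  # assumed; long-running code execution
      tool_schema:
        type: "mcp"
        function: {}
```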
--------------------------------------------------------------------------------
/mirorl/tools/python_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from e2b_code_interpreter import Sandbox
17 | from mcp.server.fastmcp import FastMCP
18 |
19 | # Initialize FastMCP server
20 | mcp = FastMCP("e2b-python-interpreter")
21 |
22 | # API keys
23 | E2B_API_KEY = os.environ.get("E2B_API_KEY")
24 |
25 | # DEFAULT CONFS
26 | DEFAULT_TIMEOUT = 300  # seconds
27 |
28 |
29 | @mcp.tool()
30 | async def run_python_code(
31 |     code_block, timeout: int = DEFAULT_TIMEOUT, sandbox_id=None
32 | ) -> str:
33 |     """Run python code in an interpreter and return the execution result.
34 |
35 |     Args:
36 |         code_block: The python code to run.
37 |         timeout: Time in seconds before the sandbox is automatically shut down. The default is 300 seconds.
38 |         sandbox_id: The id of the sandbox to run the code in. Reuse existing sandboxes whenever possible.
39 |             Only create new ones if this is the first time running code in this sandbox.
40 |
41 |     Returns:
42 |         An object containing the sandbox id and the execution result object, including results, logs and errors.
43 |     """
44 |     if sandbox_id:
45 |         sandbox = Sandbox.connect(sandbox_id, api_key=E2B_API_KEY)
46 |     else:
47 |         sandbox = Sandbox(timeout=timeout, api_key=E2B_API_KEY)
48 |
49 |     execution = sandbox.run_code(code_block)
50 |
51 |     sandbox_id = sandbox.get_info().sandbox_id
52 |
53 |     return dict(execution_result=execution, sandbox_id=sandbox_id)
54 |
55 |
56 | @mcp.tool()
57 | async def upload_local_file_to_python_interpreter(
58 |     local_file_path: str, sandbox_id=None
59 | ) -> str:
60 |     """Upload a local file to the `/home/user` dir of the remote python interpreter.
61 |
62 |     Args:
63 |         local_file_path: The path of the file on the local machine to upload.
64 |         sandbox_id: The id of the sandbox to run the code in. Reuse existing sandboxes whenever possible.
65 |             Only create new ones if this is the first time running code in this sandbox.
66 |
67 |     Returns:
68 |         The path of the uploaded file in the remote python interpreter.
69 |     """
70 |     if sandbox_id:
71 |         sandbox = Sandbox.connect(sandbox_id, api_key=E2B_API_KEY)
72 |     else:
73 |         sandbox = Sandbox(api_key=E2B_API_KEY)
74 |
75 |     if not os.path.exists(local_file_path):
76 |         raise FileNotFoundError(f"File {local_file_path} does not exist.")
77 |
78 |     # Get the uploaded file path
79 |     uploaded_file_path = os.path.join("/home/user", os.path.basename(local_file_path))
80 |
81 |     # Upload the file
82 |     with open(local_file_path, "rb") as f:
83 |         sandbox.files.write(uploaded_file_path, f)
84 |
85 |     sandbox_id = sandbox.get_info().sandbox_id
86 |
87 |     return dict(uploaded_file_path=uploaded_file_path, sandbox_id=sandbox_id)
88 |
89 |
90 | @mcp.tool()
91 | async def download_internet_file_to_python_interpreter(
92 |     url: str, sandbox_id=None
93 | ) -> str:
94 |     """Download a file from the internet to the `/home/user` dir of the remote python interpreter.
95 |
96 |     Args:
97 |         url: The URL of the file to download.
98 |         sandbox_id: The id of the sandbox to run the code in. Reuse existing sandboxes whenever possible.
99 |             Only create new ones if this is the first time running code in this sandbox.
100 |
101 |     Returns:
102 |         The path of the downloaded file in the python interpreter.
103 |     """
104 |     if sandbox_id:
105 |         sandbox = Sandbox.connect(sandbox_id, api_key=E2B_API_KEY)
106 |     else:
107 |         sandbox = Sandbox(api_key=E2B_API_KEY)
108 |
109 |     downloaded_file_path = os.path.join("/home/user", os.path.basename(url))
110 |
111 |     # Download the file
112 |     sandbox.commands.run(f"wget {url} -O {downloaded_file_path}")
113 |
114 |     sandbox_id = sandbox.get_info().sandbox_id
115 |
116 |     return dict(downloaded_file_path=downloaded_file_path, sandbox_id=sandbox_id)
117 |
118 |
119 | if __name__ == "__main__":
120 |     mcp.run(transport="stdio")
121 |
--------------------------------------------------------------------------------
/mirorl/tools/serper_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | from typing import Any, Dict 18 | 19 | import requests 20 | from mcp.server.fastmcp import FastMCP 21 | from tenacity import ( 22 | retry, 23 | retry_if_exception_type, 24 | stop_after_attempt, 25 | wait_exponential, 26 | ) 27 | 28 | 29 | # Configure logging 30 | logging.basicConfig(level=logging.INFO) 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | # Initialize FastMCP server 35 | mcp = FastMCP("search_and_scrape_webpage") 36 | 37 | 38 | @retry( 39 | stop=stop_after_attempt(3), 40 | wait=wait_exponential(multiplier=1, min=4, max=10), 41 | retry=retry_if_exception_type( 42 | (requests.ConnectionError, requests.Timeout, requests.HTTPError) 43 | ), 44 | ) 45 | def make_serper_request( 46 | payload: Dict[str, Any], headers: Dict[str, str] 47 | ) -> requests.Response: 48 | """Make HTTP request to Serper API with retry logic.""" 49 | response = requests.post( 50 | "https://google.serper.dev/search", json=payload, headers=headers 51 | ) 52 | response.raise_for_status() 53 | return response 54 | 55 | 56 | def _is_huggingface_dataset_or_space_url(url): 57 | """ 58 | Check if the URL is a Hugging Face dataset or space URL. 59 | :param url: The URL to check 60 | :return: True if it's a HuggingFace dataset or space URL, False otherwise 61 | """ 62 | if not url: 63 | return False 64 | return "huggingface.co/datasets" in url or "huggingface.co/spaces" in url 65 | 66 | 67 | @mcp.tool() 68 | def google_search( 69 | q: str, 70 | gl: str = "us", 71 | hl: str = "en", 72 | location: str = None, 73 | num: int = None, 74 | tbs: str = None, 75 | page: int = None, 76 | autocorrect: bool = None, 77 | ) -> Dict[str, Any]: 78 | """ 79 | Tool to perform web searches via Serper API and retrieve rich results. 80 | 81 | It is able to retrieve organic search results, people also ask, 82 | related searches, and knowledge graph. 
83 | 84 | Args: 85 | q: Search query string 86 | gl: Optional region code for search results in ISO 3166-1 alpha-2 format (e.g., 'us') 87 | hl: Optional language code for search results in ISO 639-1 format (e.g., 'en') 88 | location: Optional location for search results (e.g., 'SoHo, New York, United States', 'California, United States') 89 | num: Number of results to return (default: 10) 90 | tbs: Time-based search filter ('qdr:h' for past hour, 'qdr:d' for past day, 'qdr:w' for past week, 91 | 'qdr:m' for past month, 'qdr:y' for past year) 92 | page: Page number of results to return (default: 1) 93 | autocorrect: Whether to autocorrect spelling in query 94 | 95 | Returns: 96 | Dictionary containing search results and metadata 97 | """ 98 | 99 | # Check for API key 100 | serper_api_key = os.getenv("SERPER_API_KEY") 101 | if not serper_api_key: 102 | return { 103 | "success": False, 104 | "error": "SERPER_API_KEY environment variable not set", 105 | "results": [], 106 | } 107 | 108 | # Validate required parameter 109 | if not q or not q.strip(): 110 | return { 111 | "success": False, 112 | "error": "Search query 'q' is required and cannot be empty", 113 | "results": [], 114 | } 115 | 116 | try: 117 | # Build payload with all supported parameters 118 | payload = { 119 | "q": q.strip(), 120 | "gl": gl, 121 | "hl": hl, 122 | } 123 | 124 | # Add optional parameters if provided 125 | if location: 126 | payload["location"] = location 127 | if num is not None: 128 | payload["num"] = num 129 | else: 130 | payload["num"] = 10 # Default 131 | if tbs: 132 | payload["tbs"] = tbs 133 | if page is not None: 134 | payload["page"] = page 135 | if autocorrect is not None: 136 | payload["autocorrect"] = autocorrect 137 | 138 | # Set up headers 139 | headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"} 140 | 141 | # Make the API request 142 | response = make_serper_request(payload, headers) 143 | data = response.json() 144 | 145 | # filter out huggingface dataset or space urls 146 | organic_results = [] 147 | if "organic" in data: 148 | for item in data["organic"]: 149 | if _is_huggingface_dataset_or_space_url(item.get("link", "")): 150 | continue 151 | organic_results.append(item) 152 | 153 | # Build comprehensive response 154 | response_data = { 155 | "organic": organic_results, 156 | "searchParameters": data.get("searchParameters", {}), 157 | } 158 | 159 | return response_data 160 | 161 | except Exception as e: 162 | return {"success": False, "error": f"Unexpected error: {str(e)}", "results": []} 163 | 164 | 165 | if __name__ == "__main__": 166 | mcp.run() 167 | -------------------------------------------------------------------------------- /mirorl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /mirorl/trainer/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mirorl/trainer/ppo/core_algos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | def ignore_failed_rollouts_in_loss(data, rollouts_mask: torch.Tensor): 19 | """ 20 | Mask the failed rollouts in the loss calculation. 21 | 22 | Args: 23 | data: DataProto object, containing batch, non_tensor_batch and meta_info 24 | rollouts_mask: torch.Tensor, shape (rollouts_num,) 25 | 26 | """ 27 | data.batch["loss_mask"][rollouts_mask] = 0 28 | -------------------------------------------------------------------------------- /mirorl/trainer/ppo/reward.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Individual Contributor: Thibaut Barroyer 2 | # Copyright 2025 MiroMind Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import multiprocessing 17 | import os 18 | from functools import partial 19 | 20 | import ray 21 | 22 | from verl import DataProto 23 | 24 | # Adapted from verl/trainer/ppo/reward.py 25 | # Updated by MiroMind Team 26 | # import default_compute_score from mirorl.utils.reward_score 27 | from mirorl.utils.reward_score import default_compute_score 28 | 29 | 30 | def get_custom_reward_fn(config): 31 | import importlib.util 32 | import sys 33 | 34 | reward_fn_config = config.get("custom_reward_function") or {} 35 | file_path = reward_fn_config.get("path") 36 | if not file_path: 37 | return None 38 | 39 | if not os.path.exists(file_path): 40 | raise FileNotFoundError(f"Reward function file '{file_path}' not found.") 41 | 42 | spec = importlib.util.spec_from_file_location("custom_module", file_path) 43 | module = importlib.util.module_from_spec(spec) 44 | try: 45 | sys.modules["custom_module"] = module 46 | spec.loader.exec_module(module) 47 | except Exception as e: 48 | raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e 49 | 50 | function_name = reward_fn_config.get("name") 51 | if not hasattr(module, function_name): 52 | raise AttributeError( 53 | f"Reward function '{function_name}' not found in '{file_path}'." 54 | ) 55 | 56 | print(f"using customized reward function '{function_name}' from '{file_path}'") 57 | raw_fn = getattr(module, function_name) 58 | 59 | reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) 60 | 61 | def wrapped_fn(*args, **kwargs): 62 | return raw_fn(*args, **kwargs, **reward_kwargs) 63 | 64 | return wrapped_fn 65 | 66 | 67 | def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): 68 | """ 69 | Load and initialize a reward manager based on the configuration. 70 | 71 | Args: 72 | config: PPO trainer configuration object containing reward_model fields. 73 | tokenizer: Tokenizer object used for processing text. 74 | num_examine: Number of samples to examine. 75 | **reward_kwargs: Additional keyword arguments for the reward manager. 76 | 77 | Returns: 78 | An instance of the specified reward manager class. 
79 | """ 80 | from verl.workers.reward_manager import get_reward_manager_cls 81 | 82 | # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: 83 | # naive: NaiveRewardManager 84 | # prime: PrimeRewardManager 85 | # batch: BatchRewardManager 86 | # dapo: DAPORewardManager 87 | # Note(haibin.lin): For custom reward managers, please make sure they are imported and 88 | # registered via `verl.workers.reward_manager.register` 89 | # By default reward_manager is set to naive (NaiveRewardManager) 90 | reward_manager_name = config.reward_model.get("reward_manager", "naive") 91 | reward_manager_cls = get_reward_manager_cls(reward_manager_name) 92 | 93 | # Try to get a custom reward function based on the configuration 94 | compute_score = get_custom_reward_fn(config) 95 | final_compute_score = compute_score 96 | 97 | if compute_score is None: 98 | sandbox_config = config.reward_model.get("sandbox_fusion") 99 | sandbox_url = sandbox_config.get("url") if sandbox_config else None 100 | if sandbox_url: 101 | sandbox_manager = multiprocessing.Manager() 102 | # Create a semaphore to control concurrent access to the sandbox 103 | _concurrent_semaphore = sandbox_manager.Semaphore( 104 | sandbox_config.get("max_concurrent", 64) 105 | ) 106 | final_compute_score = partial( 107 | default_compute_score, 108 | sandbox_fusion_url=sandbox_url, 109 | concurrent_semaphore=_concurrent_semaphore, 110 | ) 111 | else: 112 | final_compute_score = default_compute_score 113 | 114 | # Instantiate and return the reward manager with the specified parameters 115 | return reward_manager_cls( 116 | tokenizer=tokenizer, 117 | num_examine=num_examine, 118 | compute_score=final_compute_score, 119 | reward_fn_key=config.data.reward_fn_key, 120 | **reward_kwargs, 121 | ) 122 | 123 | 124 | def compute_reward(data: DataProto, reward_fn): 125 | """ 126 | Compute reward for a batch of data. 127 | Args: 128 | data: DataProto object containing the input data. 129 | reward_fn: Reward function to compute the reward. 130 | Returns: 131 | Tuple of reward tensor and extra info dictionary. 132 | """ 133 | try: 134 | reward_result = reward_fn(data, return_dict=True) 135 | reward_tensor = reward_result["reward_tensor"] 136 | reward_extra_infos_dict = reward_result["reward_extra_info"] 137 | except Exception as e: 138 | print(f"Error in reward_fn: {e}") 139 | reward_tensor = reward_fn(data) 140 | reward_extra_infos_dict = {} 141 | 142 | return reward_tensor, reward_extra_infos_dict 143 | 144 | 145 | @ray.remote(num_cpus=1) 146 | def compute_reward_async(data: DataProto, config, tokenizer): 147 | """ 148 | Load the reward manager and compute the reward for a batch of data. 149 | This is meant to be run in a separate Ray worker. 150 | """ 151 | reward_fn = load_reward_manager( 152 | config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) 153 | ) 154 | return compute_reward(data, reward_fn) 155 | -------------------------------------------------------------------------------- /mirorl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mirorl/utils/debug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mirorl/utils/debug/exception_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import traceback 16 | from typing import Any, Dict 17 | 18 | 19 | def extract_exception_details(exception_group) -> Dict[str, Any]: 20 | """ 21 | Extract detailed exception information from ExceptionGroup 22 | 23 | Args: 24 | exception_group: ExceptionGroup instance 25 | 26 | Returns: 27 | Dictionary containing detailed exception information 28 | """ 29 | details = { 30 | "message": str(exception_group), 31 | "type": type(exception_group).__name__, 32 | "exceptions": [], 33 | "traceback": traceback.format_exc(), 34 | } 35 | 36 | # Extract all sub-exceptions 37 | for i, exc in enumerate(exception_group.exceptions): 38 | exc_info = { 39 | "index": i, 40 | "type": type(exc).__name__, 41 | "message": str(exc), 42 | "traceback": "".join( 43 | traceback.format_exception(type(exc), exc, exc.__traceback__) 44 | ), 45 | } 46 | details["exceptions"].append(exc_info) 47 | 48 | return details 49 | -------------------------------------------------------------------------------- /mirorl/utils/reward_score/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2025 MiroMind Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
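A rough usage sketch for `extract_exception_details` above (assumes Python 3.11+ for the built-in `ExceptionGroup`, or the `exceptiongroup` backport):

```python
from mirorl.utils.debug.exception_helper import extract_exception_details

try:
    raise ExceptionGroup(
        "rollout failures",
        [ValueError("bad tool call"), TimeoutError("MCP server timed out")],
    )
except ExceptionGroup as eg:
    details = extract_exception_details(eg)
    # Each sub-exception is flattened into its own dict with type, message, traceback.
    print(details["type"], len(details["exceptions"]))  # ExceptionGroup 2
```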
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # from . import gsm8k, math, prime_math, prime_code 16 | 17 | from verl.utils.import_utils import deprecated 18 | 19 | 20 | # Adapted from verl/utils/reward_score/__init__.py 21 | # Updated by MiroMind Team 22 | # 1. Add custom reward function support 23 | def default_compute_score( 24 | data_source, 25 | solution_str, 26 | ground_truth, 27 | extra_info=None, 28 | sandbox_fusion_url=None, 29 | concurrent_semaphore=None, 30 | ): 31 | """Compute the score for a given solution based on the data source. 32 | 33 | Args: 34 | data_source (str): The source dataset identifier which determines the scoring method. 35 | solution_str (str): The solution string to be evaluated. 36 | ground_truth (str): The ground truth answer for comparison. 37 | extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None. 38 | 39 | Returns: 40 | float: The computed score as a floating point number. If the result is a dictionary, 41 | it returns the dictionary instead. 42 | 43 | Raises: 44 | NotImplementedError: If the reward function is not implemented for the given data source. 45 | """ 46 | if data_source == "openai/gsm8k": 47 | from verl.utils.reward_score import gsm8k 48 | 49 | res = gsm8k.compute_score(solution_str, ground_truth) 50 | elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]: 51 | from verl.utils.reward_score import math 52 | 53 | res = math.compute_score(solution_str, ground_truth) 54 | # [Optional] Math-Verify Integration 55 | # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify). 56 | # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`. 57 | # To use it, override the `compute_score` function with the following implementation: 58 | 59 | # from . 
import math_verify 60 | # res = math_verify.compute_score(solution_str, ground_truth) 61 | elif data_source == "math_dapo" or data_source.startswith("aime"): 62 | from verl.utils.reward_score import math_dapo 63 | 64 | res = math_dapo.compute_score(solution_str, ground_truth) 65 | elif data_source in [ 66 | "numina_aops_forum", 67 | "numina_synthetic_math", 68 | "numina_amc_aime", 69 | "numina_synthetic_amc", 70 | "numina_cn_k12", 71 | "numina_olympiads", 72 | ]: 73 | from verl.utils.reward_score import prime_math 74 | 75 | res = prime_math.compute_score(solution_str, ground_truth) 76 | elif data_source in ["codecontests", "apps", "codeforces", "taco"]: 77 | # Use the passed sandbox_fusion_url if available 78 | if sandbox_fusion_url: 79 | from verl.utils.reward_score import sandbox_fusion 80 | 81 | # Pass the URL directly, ground_truth likely contains test cases here 82 | res = sandbox_fusion.compute_score( 83 | sandbox_fusion_url, 84 | concurrent_semaphore, 85 | solution_str, 86 | ground_truth, 87 | continuous=True, 88 | ) 89 | else: 90 | # If no sandbox URL is provided, fall back to prime_code or raise error 91 | from verl.utils.reward_score import prime_code 92 | 93 | # Assuming prime_code doesn't need the URL 94 | res = prime_code.compute_score(solution_str, ground_truth, continuous=True) 95 | elif data_source in ["hiyouga/geometry3k"]: 96 | from verl.utils.reward_score import geo3k 97 | 98 | res = geo3k.compute_score(solution_str, ground_truth) 99 | elif data_source in [ 100 | "searchR1_nq", 101 | "searchR1_triviaqa", 102 | "searchR1_popqa", 103 | "searchR1_hotpotqa", 104 | "searchR1_2wikimultihopqa", 105 | "searchR1_musique", 106 | "searchR1_bamboogle", 107 | ]: 108 | from verl.utils.reward_score import search_r1_like_qa_em 109 | 110 | res = search_r1_like_qa_em.compute_score(solution_str, ground_truth) 111 | elif data_source in [ 112 | "basicv8vc/SimpleQA", 113 | "yixuantt/MultiHopRAG", 114 | "callanwu/WebWalkerQA", 115 | ]: 116 | from . import simpleqa 117 | 118 | res = simpleqa.compute_score(solution_str, ground_truth) 119 | else: 120 | raise NotImplementedError( 121 | f"Reward function is not implemented for {data_source=}" 122 | ) 123 | 124 | if isinstance(res, dict): 125 | return res 126 | elif isinstance(res, (int, float, bool)): 127 | return float(res) 128 | else: 129 | return float(res[0]) 130 | 131 | 132 | @deprecated("verl.utils.reward_score.default_compute_score") 133 | def _default_compute_score( 134 | data_source, 135 | solution_str, 136 | ground_truth, 137 | extra_info=None, 138 | sandbox_fusion_url=None, 139 | concurrent_semaphore=None, 140 | ): 141 | """ 142 | Legacy function API to be deprecated. Please use `default_compute_score` instead. 143 | """ 144 | return default_compute_score( 145 | data_source, 146 | solution_str, 147 | ground_truth, 148 | extra_info, 149 | sandbox_fusion_url, 150 | concurrent_semaphore, 151 | ) 152 | 153 | 154 | __all__ = ["default_compute_score"] 155 | -------------------------------------------------------------------------------- /mirorl/utils/reward_score/simpleqa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
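A quick sanity check of the `default_compute_score` dispatcher above (illustrative; no sandbox URL is configured here, so the SimpleQA branch is taken):

```python
from mirorl.utils.reward_score import default_compute_score

# "basicv8vc/SimpleQA" routes to the local simpleqa scorer defined below.
score = default_compute_score(
    data_source="basicv8vc/SimpleQA",
    solution_str=r"After searching, the answer is \boxed{Paris}.",
    ground_truth="Paris",
)
print(score)  # 1.0: the boxed answer exactly matches the ground truth
```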
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py 15 | 16 | import regex 17 | 18 | 19 | def extract_solution(solution_str): 20 | """Extract the equation from the solution string.""" 21 | try: 22 | matches = regex.findall( 23 | r"\\boxed\{((?:[^{}]+|\{(?1)*\})*)\}", 24 | solution_str, 25 | regex.DOTALL, 26 | timeout=10, 27 | ) 28 | except TimeoutError as _: 29 | matches = [] 30 | 31 | # If there are 0 matches, return None 32 | if len(matches) < 1: 33 | return None 34 | 35 | # If there are 2 or more matches, return the last one 36 | return matches[-1].strip() 37 | 38 | 39 | def compute_score( 40 | solution_str, ground_truth, method="strict", format_score=0.0, score=1.0 41 | ): 42 | """The scoring function for exact match (EM). 43 | 44 | Args: 45 | solution_str: the solution text 46 | ground_truth: the ground truth 47 | method: the method to extract the solution, choices are 'strict' and 'flexible' 48 | format_score: the score for the format 49 | score: the score for the correct answer 50 | """ 51 | answer = extract_solution(solution_str=solution_str) 52 | do_print = False 53 | 54 | if do_print: 55 | print("--------------------------------") 56 | print(f"Golden answers: {ground_truth}") 57 | if answer is not None: 58 | print(f"Extracted answer is not None: {answer}") 59 | else: 60 | print("Extracted answer: None!") 61 | print(f"Solution string: {solution_str}") 62 | 63 | if answer is None: 64 | return 0 65 | else: 66 | if answer == ground_truth: 67 | return score 68 | else: 69 | return format_score 70 | -------------------------------------------------------------------------------- /mirorl/utils/tracking.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2025 MiroMind Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | A unified tracking interface that supports logging data to different backend 17 | """ 18 | 19 | from typing import List, Union 20 | 21 | from verl.utils.tracking import Tracking 22 | 23 | from wandb import AlertLevel 24 | 25 | 26 | # Adapted from verl/utils/tracking.py 27 | # Updated by MiroMind Team 28 | # 1. Add alert function 29 | class MonitorTracking(Tracking): 30 | """A unified tracking interface for logging experiment data to multiple backends. 
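The `(?1)` recursion in the boxed-answer pattern above comes from the third-party `regex` package (not the stdlib `re`) and is what lets the extractor survive nested braces; a small sketch:

```python
import regex

pattern = r"\\boxed\{((?:[^{}]+|\{(?1)*\})*)\}"
text = r"Final answer: \boxed{\frac{1}{2}} after simplification."
print(regex.findall(pattern, text, regex.DOTALL))  # ['\\frac{1}{2}']
```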
31 | 32 | This class provides a centralized way to log experiment metrics, parameters, and artifacts 33 | to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console. 34 | 35 | Attributes: 36 | supported_backend: List of supported tracking backends. 37 | logger: Dictionary of initialized logger instances for each backend. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | project_name, 43 | experiment_name, 44 | default_backend: Union[str, List[str]] = "console", 45 | config=None, 46 | ): 47 | super().__init__(project_name, experiment_name, default_backend, config) 48 | 49 | def alert(self, title, text, level=AlertLevel.WARN, wait_duration=180): 50 | assert "wandb" in self.logger, "wandb is not in the tracking" 51 | self.logger["wandb"].alert( 52 | title=title, 53 | text=text, 54 | level=level, 55 | wait_duration=wait_duration, 56 | ) 57 | -------------------------------------------------------------------------------- /mirorl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .reward_manager import ToolRewardManager 16 | from .rollout import MCPSGLangRollout 17 | 18 | __all__ = ["ToolRewardManager", "MCPSGLangRollout"] 19 | -------------------------------------------------------------------------------- /mirorl/workers/reward_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .tool import ToolRewardManager 16 | 17 | __all__ = ["ToolRewardManager"] 18 | -------------------------------------------------------------------------------- /mirorl/workers/reward_manager/tests/test_others.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
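A hypothetical call pattern for the `alert` helper in `MonitorTracking` above; the project and experiment names are placeholders, and wandb must be among the active backends (and authenticated) for the alert to be delivered:

```python
from mirorl.utils.tracking import MonitorTracking

tracker = MonitorTracking(
    project_name="mirorl-demo",        # placeholder
    experiment_name="qwen-14b-run-0",  # placeholder
    default_backend=["console", "wandb"],
)
tracker.alert(
    title="Reward collapse",
    text="Mean training reward dropped below 0.05 at step 120.",
)
```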
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import json 15 | import time 16 | from mirorl.workers.reward_manager.tool import run_async_computation 17 | from mirorl.utils.reward_score.llm_judge import extract_boxed_answer, gaia_naive_compute_score 18 | from openai import AsyncOpenAI 19 | import os 20 | from unittest.mock import Mock 21 | 22 | def build_data(num_test_samples): 23 | # read file from testData.jsonl 24 | data = [] 25 | with open("tests/testData.jsonl", "r") as f: 26 | for line in f: 27 | data.append(json.loads(line)) 28 | 29 | prompt_str_list = [item["prompt_str"] for item in data[:num_test_samples]] 30 | response_str_list = [item["response_str"] for item in data[:num_test_samples]] 31 | ground_truth_list = [item["ground_truth"] for item in data[:num_test_samples]] 32 | data_source_list = [item["data_source"] for item in data[:num_test_samples]] 33 | extra_info_list = [item["extra_info"] for item in data[:num_test_samples]] 34 | valid_response_length_list = [item["valid_response_length"] for item in data[:num_test_samples]] 35 | 36 | return prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list 37 | 38 | 39 | # test the async process function 40 | def async_process_function(prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list, openai_client, batch_size=100, max_retry=3): 41 | start_time = time.time() 42 | print("Start computing accuracy reward with async api, this may take a while...") 43 | print(f"Async configuration: batch_size={batch_size}, max_retry={max_retry}") 44 | try: 45 | # Run async computation using the new async functions 46 | accuracy_reward_list = run_async_computation( 47 | openai_client, 48 | data_source_list, 49 | response_str_list, 50 | ground_truth_list, 51 | extra_info_list, 52 | batch_size, 53 | max_retry 54 | ) 55 | 56 | end_time = time.time() 57 | print(f"Async batch processing completed in {end_time - start_time:.2f} seconds") 58 | 59 | except Exception as e: 60 | # if error, we use gaia_naive_compute_score to compute accuracy reward 61 | print(f"Error in batch api: {e}") 62 | print("Falling back to gaia_naive_compute_score...") 63 | accuracy_reward_list = [] 64 | for i in range(len(response_str_list)): 65 | accuracy_reward = gaia_naive_compute_score( 66 | data_source=data_source_list[i], 67 | solution_str=response_str_list[i], 68 | ground_truth=ground_truth_list[i], 69 | extra_info=extra_info_list[i], 70 | ) 71 | accuracy_reward_list.append(accuracy_reward) 72 | 73 | for i in range(len(accuracy_reward_list)): 74 | extracted = extract_boxed_answer(response_str_list[i]) 75 | ground_truth = ground_truth_list[i] 76 | reward = accuracy_reward_list[i] 77 | print(f"{extracted:<40} | {ground_truth:<40} | {reward}") 78 | 79 | 80 | def test_exception_handling(num_test_samples): 81 | """Test the exception handling and fallback mechanism""" 82 | print("=== Testing Exception Handling ===") 83 | 84 | # Load test data 85 | prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list = build_data(num_test_samples) 86 | 87 | # Create a mock OpenAI client that always raises an exception 88 | mock_client = Mock() 89 | mock_client.chat.completions.create.side_effect = Exception("API Error: Rate limit exceeded") 90 | 91 | # Call the function with the mock client that will fail 92 | async_process_function( 93 | 
prompt_str_list, 94 | response_str_list, 95 | ground_truth_list, 96 | data_source_list, 97 | extra_info_list, 98 | valid_response_length_list, 99 | mock_client, 100 | batch_size=100, 101 | max_retry=1 102 | ) 103 | 104 | print("=== Exception handling test completed ===") 105 | 106 | 107 | 108 | 109 | if __name__ == "__main__": 110 | num_test_samples = 128 # max is 128 111 | test_mode = "exception" # "normal" or "exception" 112 | 113 | if test_mode == "normal": 114 | print("=== Normal Operation Test ===") 115 | prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list = build_data(num_test_samples) 116 | 117 | openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 118 | async_process_function(prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list, openai_client, batch_size=100, max_retry=3) 119 | print("=== Normal Operation Test Completed ===") 120 | 121 | elif test_mode == "exception": 122 | test_exception_handling(num_test_samples) 123 | 124 | else: 125 | raise ValueError(f"Invalid test mode: {test_mode}") -------------------------------------------------------------------------------- /mirorl/workers/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sglang_rollout.sglang_rollout import MCPSGLangRollout 16 | 17 | __all__ = ["MCPSGLangRollout"] 18 | -------------------------------------------------------------------------------- /mirorl/workers/rollout/sglang_rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MiroMind Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sglang_rollout import MCPSGLangRollout 16 | 17 | __all__ = ["MCPSGLangRollout"] -------------------------------------------------------------------------------- /scripts/README_GRPO_Visualizer.md: -------------------------------------------------------------------------------- 1 | # 🎯 GRPO Rollout Visualizer 2 | 3 | A beautiful and interactive visualizer for GRPO (Group Relative Policy Optimization) rollouts. This tool provides a hierarchical navigation system to explore your training data across projects, epochs, and individual rollouts with stunning visualizations.
4 | 5 | ## ✨ Features 6 | 7 | ### 🏗️ Hierarchical Navigation 8 | - **Projects View**: Browse different training projects 9 | - **Epochs View**: Explore training epochs with performance metrics 10 | - **Sample View**: Analyze input groups and their statistics 11 | - **Rollouts View**: Compare individual rollouts side-by-side 12 | 13 | ### 📊 Rich Visualizations 14 | - **Interactive Charts**: Plotly-powered charts for metrics visualization 15 | - **Performance Trends**: Track improvements across epochs 16 | - **Comparison Views**: Split rollouts by performance (better vs worse) 17 | - **Statistical Insights**: Average scores, accuracy metrics, and standard deviations 18 | 19 | ### 🎨 Beautiful UI 20 | - **Modern Design**: Gradient cards and smooth animations 21 | - **Responsive Layout**: Works on different screen sizes 22 | - **Intuitive Navigation**: Breadcrumb navigation and clear visual hierarchy 23 | - **Color-coded Performance**: Green for better, red for worse rollouts 24 | 25 | ## 📦 Installation 26 | 27 | ### Install Dependencies 28 | ```bash 29 | pip install -r visualizer_requirements.txt 30 | ``` 31 | 32 | ### Dependencies 33 | - **Streamlit**: Web app framework 34 | - **Pandas**: Data manipulation 35 | - **NumPy**: Numerical computations 36 | - **Plotly**: Interactive visualizations 37 | 38 | ## 🚀 Usage 39 | 40 | ### Directory Structure 41 | Your work folder should be organized as follows: 42 | ``` 43 | work_folder/ 44 | ├── project1/ 45 | │ ├── epoch1.jsonl 46 | │ ├── epoch2.jsonl 47 | │ └── epoch3.jsonl 48 | ├── project2/ 49 | │ ├── epoch1.jsonl 50 | │ └── epoch2.jsonl 51 | └── project3/ 52 | └── epoch1.jsonl 53 | ``` 54 | 55 | ### JSONL File Format 56 | Each line in the JSONL file should contain a rollout dictionary with the following structure: 57 | ```json 58 | { 59 | "input": "Text input with \\nuser\\n marker", 60 | "output": "Model output/response", 61 | "ground_truth": "Expected answer (optional)", 62 | "score": 0.85, 63 | "accuracy_reward": 0.9, 64 | "tool_format_reward": 0.1, 65 | "combined_reward": 0.87, 66 | "step": 0 67 | } 68 | ``` 69 | 70 | ### Running the Visualizer 71 | 72 | 1. **Start the application**: 73 | ```bash 74 | streamlit run grpo_visualizer.py 75 | ``` 76 | 77 | 2. **Configure the work folder**: 78 | - Enter the path to your work folder in the sidebar 79 | - Click "🔄 Refresh Data" to load the data 80 | 81 | 3. 
**Navigate through your data**: 82 | - **Projects**: Click on a project to explore its epochs 83 | - **Epochs**: View epoch performance and click to see input groups 84 | - **Sample**: Examine grouped inputs and their statistics 85 | - **Rollouts**: Compare individual rollouts split by performance 86 | 87 | ## 🎮 Interface Guide 88 | 89 | ### 📁 Projects View 90 | - **Project Cards**: Shows project statistics (epochs, rollouts, average score) 91 | - **Quick Access**: Click any project to dive into its epochs 92 | - **Overview**: Get a high-level view of all your training projects 93 | 94 | ### 📈 Epochs View 95 | - **Performance Charts**: Interactive plots showing score and accuracy trends 96 | - **Epoch Details**: Click on any epoch to explore its input groups 97 | - **Metrics Dashboard**: Comprehensive overview of training progress 98 | 99 | ### 📝 Sample View 100 | - **Input Grouping**: All rollouts with the same input are grouped together 101 | - **User Prompts**: Displays the text after `\nuser\n` in the input 102 | - **Sample Statistics**: Average scores, accuracy, and standard deviations 103 | - **Performance Ranking**: Sample sorted by accuracy (best first) 104 | 105 | ### 🎯 Rollouts View 106 | - **Side-by-Side Comparison**: Lower accuracy rollouts on the left, higher accuracy on the right 107 | - **Detailed Metrics**: Score, accuracy, reward, and combined reward for each rollout 108 | - **Output Inspection**: Expandable views to examine model outputs 109 | - **Ground Truth**: Compare outputs with expected answers (when available) 110 | 111 | ## 🔧 Customization 112 | 113 | ### Styling 114 | The visualizer uses custom CSS for beautiful styling. You can modify the `load_custom_css()` function to adjust: 115 | - Color schemes 116 | - Card designs 117 | - Layout spacing 118 | - Font sizes 119 | 120 | ### Data Processing 121 | The `GRPOVisualizer` class can be extended to: 122 | - Support different data formats 123 | - Add custom metrics 124 | - Implement additional visualizations 125 | - Handle special data preprocessing 126 | 127 | ## 📊 Key Metrics 128 | 129 | ### Score Metrics 130 | - **Score**: Overall performance score 131 | - **Accuracy Reward**: Accuracy-based reward 132 | - **Reward**: Base reward value 133 | - **Combined Reward**: Composite reward metric 134 | 135 | ### Statistics 136 | - **Average**: Mean values across rollouts 137 | - **Standard Deviation**: Measure of variance 138 | - **Sample Count**: Number of rollouts per input group 139 | - **Performance Split**: Median-based classification 140 | 141 | ## 🎨 UI Components 142 | 143 | ### Visual Elements 144 | - **Gradient Cards**: Beautiful project and epoch cards 145 | - **Interactive Charts**: Plotly-powered visualizations 146 | - **Responsive Grid**: Adaptive layout system 147 | - **Color Coding**: Performance-based visual cues 148 | 149 | ### Navigation 150 | - **Breadcrumb Navigation**: Always know where you are 151 | - **Back Buttons**: Easy navigation between levels 152 | - **Session State**: Maintains your position during exploration 153 | - **Sidebar Controls**: Configuration and navigation info 154 | 155 | ## 🔍 Data Insights 156 | 157 | ### Performance Analysis 158 | - **Trend Tracking**: Monitor improvements across epochs 159 | - **Comparative Analysis**: Better vs worse rollout comparison 160 | - **Statistical Overview**: Mean, standard deviation, and distributions 161 | - **Input Grouping**: Understand performance per input type 162 | 163 | ### Quality Assessment 164 | - **Accuracy Metrics**: Detailed accuracy 
breakdowns 165 | - **Score Distributions**: Understand performance variations 166 | - **Output Inspection**: Manual quality assessment 167 | - **Ground Truth Comparison**: Validate model outputs 168 | 169 | ## 🛠️ Technical Details 170 | 171 | ### Architecture 172 | - **Modular Design**: Clean separation of concerns 173 | - **Session Management**: Streamlit session state for navigation 174 | - **Efficient Data Loading**: Optimized JSONL parsing 175 | - **Responsive UI**: Modern web interface 176 | 177 | ### Performance 178 | - **Lazy Loading**: Data loaded on demand 179 | - **Caching**: Efficient data retrieval 180 | - **Memory Management**: Optimized for large datasets 181 | - **Interactive Updates**: Real-time UI updates 182 | 183 | ## 🤝 Contributing 184 | 185 | Feel free to contribute improvements: 186 | - **Bug Reports**: Submit issues with detailed descriptions 187 | - **Feature Requests**: Suggest new visualization features 188 | - **Code Improvements**: Submit pull requests with enhancements 189 | - **Documentation**: Help improve this README 190 | 191 | ## 📝 License 192 | 193 | This project follows the same license as the main VERL project. 194 | 195 | --- 196 | 197 | **Happy Visualizing!** 🎉 Explore your GRPO rollouts with style and gain insights into your model's training progress. -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export MAX_JOBS=32 4 | 5 | echo "1. install inference frameworks and pytorch they need" 6 | pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir 7 | pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata 8 | 9 | 10 | echo "2. install basic packages" 11 | pip install "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ 12 | "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ 13 | ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ 14 | pytest py-spy pyext pre-commit ruff mcp==1.10.1 tenacity 15 | 16 | pip install "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" 17 | 18 | 19 | echo "3. install FlashAttention" 20 | # Install flash-attn-2.7.4.post1 (cxx11abi=False) 21 | wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ 22 | pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 23 | 24 | 25 | echo "4. install FlashInfer" 26 | # Install flashinfer-0.2.7.post1+cu124 (cxx11abi=False) 27 | # 1. Clone the FlashInfer repository: 28 | git clone https://github.com/flashinfer-ai/flashinfer.git --recursive 29 | # 2. Make sure you have installed PyTorch with CUDA support. You can check the PyTorch version and CUDA version by running: 30 | python -c "import torch; print(torch.__version__, torch.version.cuda)" 31 | # 3. Install Ninja build system: 32 | pip install ninja 33 | # 4. 
Install FlashInfer(AOT mode): 34 | cd flashinfer 35 | git checkout v0.2.7.post1 36 | # Set the TORCH_CUDA_ARCH_LIST environment variable to the supported architectures: 37 | # export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a" 38 | # if you are using a100/a800, you can use the following command: 39 | export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" 40 | # Produces AOT kernels in aot-ops/ 41 | python -m flashinfer.aot 42 | python -m pip install --no-build-isolation --verbose . 43 | 44 | 45 | echo "5. May need to fix opencv" 46 | pip install opencv-python 47 | pip install opencv-fixer && \ 48 | python -c "from opencv_fixer import AutoFix; AutoFix()" 49 | 50 | echo "Successfully installed all packages" 51 | -------------------------------------------------------------------------------- /scripts/visualizer_requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit>=1.28.0 2 | pandas>=1.5.0 3 | numpy>=1.21.0 4 | plotly>=5.15.0 -------------------------------------------------------------------------------- /tests/distributed/run_all.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/usr/bin/env bash 16 | 17 | set -e -x 18 | torchrun --nproc-per-node=4 --standalone tests/distributed/test_tensor_dict.py -------------------------------------------------------------------------------- /tests/distributed/test_tensor_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
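After a build like the one scripted above, a quick smoke test can confirm the pinned stack imports cleanly (illustrative; versions follow the pins in `scripts/install.sh`):

```python
import torch
import flashinfer  # fails here if the AOT build did not install correctly

print(torch.__version__, torch.version.cuda)  # expect 2.6.0 / 12.4 with the pins above
print(torch.cuda.is_available())
```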
14 | 15 | import os 16 | 17 | os.environ["NCCL_DEBUG"] = "WARN" 18 | 19 | import numpy as np 20 | import torch 21 | import torch.distributed 22 | 23 | from verl.protocol import DataProto, all_gather_data_proto 24 | from verl.utils.distributed import initialize_global_process_group 25 | 26 | 27 | def test_all_gather_data_proto(): 28 | device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"]) 29 | 30 | global_rank = torch.distributed.get_rank() 31 | 32 | obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]]) 33 | 34 | labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"] 35 | labels = np.array(labels, dtype=object) 36 | data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) 37 | 38 | all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp")) 39 | 40 | if global_rank == 0: 41 | expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda") 42 | expected_labels = ["a", "b", "a", "b"] 43 | elif global_rank == 1: 44 | expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda") 45 | expected_labels = ["b", "a", "b", "a"] 46 | elif global_rank == 2: 47 | expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda") 48 | expected_labels = ["a", "b", "a", "b"] 49 | elif global_rank == 3: 50 | expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda") 51 | expected_labels = ["b", "a", "b", "a"] 52 | 53 | torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0) 54 | assert (data.non_tensor_batch["labels"] == expected_labels).all() 55 | assert data.meta_info == {"info": "test_info"} 56 | 57 | 58 | def test_vocab_parallel_entropy(): 59 | from megatron.core import parallel_state as mpu 60 | 61 | from verl.utils.debug import log_gpu_memory_usage 62 | from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy 63 | from verl.utils.torch_functional import entropy_from_logits 64 | 65 | mpu.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None) 66 | 67 | batch_size = 2 68 | seqlen = 128 69 | vocab_size = 155136 70 | 71 | logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True) 72 | target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64) 73 | 74 | # broadcast across tp 75 | torch.distributed.broadcast(logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) 76 | torch.distributed.broadcast(target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) 77 | 78 | tp_rank = mpu.get_tensor_model_parallel_rank() 79 | vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size() 80 | 81 | # get the local logits of each tp 82 | vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_() 83 | logits.grad = None 84 | vocab_parallel_logits.grad = None 85 | 86 | log_gpu_memory_usage("begin") 87 | output_entropy = vocab_parallel_entropy(vocab_parallel_logits) 88 | log_gpu_memory_usage("after forward") 89 | grad_output = torch.randn_like(output_entropy) 90 | output_entropy.backward(grad_output) 91 | log_gpu_memory_usage("after backward") 92 | 93 | target_entropy = entropy_from_logits(logits) 94 | 
torch.testing.assert_close(output_entropy, target_entropy) 95 | target_entropy.backward(grad_output) 96 | torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad) 97 | # make sure logits is not altered 98 | torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits) 99 | 100 | if mpu.get_tensor_model_parallel_rank() == 0: 101 | print("test_vocab_parallel_entropy passes") 102 | 103 | mpu.destroy_model_parallel() 104 | 105 | 106 | if __name__ == "__main__": 107 | local_rank, rank, world_size = initialize_global_process_group() 108 | test_all_gather_data_proto() 109 | test_vocab_parallel_entropy() 110 | -------------------------------------------------------------------------------- /tests/models/test_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input 17 | from transformers import ( 18 | AutoModelForCausalLM, 19 | AutoModelForTokenClassification, 20 | GemmaConfig, 21 | LlamaConfig, 22 | MistralConfig, 23 | Qwen2Config, 24 | ) 25 | 26 | from verl.utils.model import compute_position_id_with_mask, create_random_mask 27 | from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean 28 | 29 | # TODO(sgm): add more models for test 30 | # we only need one scale for each model 31 | test_configs = [ 32 | LlamaConfig(num_hidden_layers=1), 33 | MistralConfig(num_hidden_layers=1), 34 | GemmaConfig(num_hidden_layers=1), 35 | Qwen2Config(num_hidden_layers=1), 36 | ] 37 | 38 | 39 | def test_hf_casual_models(): 40 | batch_size = 4 41 | seqlen = 128 42 | response_length = 127 43 | 44 | for config in test_configs: 45 | # config = AutoConfig.from_pretrained(test_case) 46 | with torch.device("cuda"): 47 | model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 48 | model = model.to(device="cuda") 49 | input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") 50 | attention_mask = create_random_mask( 51 | input_ids=input_ids, 52 | max_ratio_of_left_padding=0.1, 53 | max_ratio_of_valid_token=0.8, 54 | min_ratio_of_valid_token=0.5, 55 | ) 56 | position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here 57 | 58 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) 59 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) 60 | 61 | # unpad the position_ids to align the rotary 62 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices).transpose(0, 1) 63 | 64 | # input with input_ids_rmpad and postition_ids to enable flash attention varlen 65 | logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, vocab_size) 66 | 67 | origin_logits = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False).logits 68 | origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask) 69 | 70 | logits_rmpad = logits_rmpad.squeeze(0) 71 | log_probs = log_probs_from_logits_all_rmpad( 72 | input_ids_rmpad=input_ids_rmpad, 73 | logits_rmpad=logits_rmpad, 74 | indices=indices, 75 | batch_size=batch_size, 76 | seqlen=seqlen, 77 | response_length=response_length, 78 | ) # (batch, seqlen) 79 | origin_log_probs = log_probs_from_logits_all_rmpad( 80 | input_ids_rmpad=input_ids_rmpad, 81 | logits_rmpad=origin_logits_rmpad, 82 | indices=origin_logits_indices, 83 | batch_size=batch_size, 84 | seqlen=seqlen, 85 | response_length=response_length, 86 | ) # (batch, seqlen) 87 | 88 | torch.testing.assert_close( 89 | masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]), 90 | masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]), 91 | atol=1e-2, 92 | rtol=1e-5, 93 | ) 94 | print("Check pass") 95 | 96 | 97 | def test_hf_value_models(): 98 | batch_size = 4 99 | seqlen = 128 100 | 101 | for config in test_configs: 102 | # config = AutoConfig.from_pretrained(test_case) 103 | config.num_labels = 1 104 | config.classifier_dropout = 0 105 | config.hidden_dropout = 0 106 | with torch.device("cuda"): 107 | model = AutoModelForTokenClassification.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 108 | model = model.to(device="cuda") 109 | input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") 110 | attention_mask = create_random_mask( 111 | input_ids=input_ids, 112 | max_ratio_of_left_padding=0.1, 113 | max_ratio_of_valid_token=0.8, 114 | min_ratio_of_valid_token=0.5, 115 | ) 116 | position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here 117 | 118 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) 119 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) 120 | 121 | # unpad the position_ids to align the rotary 122 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices).transpose(0, 1) 123 | 124 | origin_logits = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False).logits 125 | 126 | # input with input_ids_rmpad and postition_ids to enable flash attention varlen 127 | rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, 1) 128 | rmpad_logits = rmpad_logits.squeeze(0) 129 | pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen) 130 | 131 | torch.testing.assert_close( 132 | masked_mean(pad_logits, attention_mask[:, :, None]), 133 | masked_mean(origin_logits, attention_mask[:, :, None]), 134 | atol=1e-2, 135 | rtol=1e-5, 136 | ) 137 | print("Value model check pass") 138 | 139 | 140 | if __name__ == "__main__": 141 | test_hf_casual_models() 142 | test_hf_value_models() 143 | -------------------------------------------------------------------------------- /tests/ray_gpu/detached_worker/README.md: -------------------------------------------------------------------------------- 1 | # Detached Worker 2 | ## How to run (Only on a single node) 3 | - Start a local ray cluster: 4 | ```bash 5 | ray start --head --port=6379 6 | ``` 7 | - Run the server 8 | ```bash 9 | python3 server.py 10 | ``` 11 | - On another terminal, Run the client 12 | ```bash 13 | python3 client.py 14 | ``` 15 | -------------------------------------------------------------------------------- /tests/ray_gpu/detached_worker/client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
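Returning to the rmpad tests in `tests/models/test_transformer.py` above: a minimal, shape-only sketch of the unpad/pad round trip they rely on (assumes flash-attn is installed; the padding helpers themselves run on CPU):

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen, hidden = 2, 4, 8
x = torch.randn(batch_size, seqlen, hidden)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding

x_unpad, indices, *_ = unpad_input(x, mask)        # drop the padded positions
print(x_unpad.shape)                               # torch.Size([5, 8])
x_repad = pad_input(x_unpad, indices, batch_size, seqlen)
print(torch.equal(x_repad[mask.bool()], x_unpad))  # True: lossless round trip
```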
14 | """ 15 | In client, we can get the server handler and send RPC request 16 | """ 17 | 18 | import ray 19 | import torch 20 | from server import Trainer 21 | from tensordict import TensorDict 22 | 23 | from verl import DataProto 24 | from verl.single_controller.ray import RayClassWithInitArgs 25 | from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup 26 | 27 | 28 | def compute_position_id_with_mask(mask): 29 | return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) 30 | 31 | 32 | if __name__ == "__main__": 33 | ray.init(address="auto", namespace="verl") 34 | # get the worker group using names 35 | worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"] 36 | cls_with_init_args = RayClassWithInitArgs(cls=Trainer) 37 | worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args) 38 | 39 | batch_size = 16 40 | sequence_length = 1024 41 | 42 | # give Trainer some data to train 43 | input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda") 44 | attention_mask = torch.ones_like(input_ids) 45 | position_ids = compute_position_id_with_mask(attention_mask) 46 | 47 | data = DataProto( 48 | batch=TensorDict( 49 | {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}, 50 | batch_size=batch_size, 51 | ), 52 | meta_info={}, 53 | ) 54 | 55 | output = worker_group.train_model(data) 56 | 57 | print(output) 58 | -------------------------------------------------------------------------------- /tests/ray_gpu/detached_worker/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ray start --head --port=6379 3 | python3 server.py 4 | python3 client.py 5 | ray stop --force -------------------------------------------------------------------------------- /tests/ray_gpu/detached_worker/server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Server starts a Trainer. Client sends data to the server to train. 
16 | """ 17 | 18 | import os 19 | 20 | os.environ["MEGATRON_USE_CUDA_TIMER"] = "0" 21 | os.environ["MEGATRON_START_PROCESS_TIMER"] = "False" 22 | os.environ["NCCL_DEBUG"] = "WARN" 23 | 24 | import ray 25 | import torch 26 | from megatron.core import parallel_state as mpu 27 | from megatron.core import tensor_parallel 28 | from megatron.core.models.gpt.gpt_model import ModelType 29 | from omegaconf import OmegaConf 30 | from tensordict import TensorDict 31 | from torch import nn 32 | from transformers import LlamaConfig 33 | 34 | from verl import DataProto 35 | from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP 36 | from verl.single_controller.base.decorator import Dispatch, register 37 | from verl.single_controller.base.megatron.worker import MegatronWorker 38 | from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool 39 | from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup 40 | from verl.utils.megatron.optimizer import get_megatron_optimizer 41 | from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config 42 | 43 | 44 | @ray.remote 45 | class Trainer(MegatronWorker): 46 | def __init__(self): 47 | super().__init__() 48 | 49 | if not torch.distributed.is_initialized(): 50 | rank = int(os.environ["LOCAL_RANK"]) 51 | torch.distributed.init_process_group(backend="nccl") 52 | torch.cuda.set_device(rank) 53 | 54 | os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" 55 | mpu.initialize_model_parallel( 56 | tensor_model_parallel_size=2, 57 | pipeline_model_parallel_size=1, 58 | virtual_pipeline_model_parallel_size=None, 59 | pipeline_model_parallel_split_rank=None, 60 | use_sharp=False, 61 | context_parallel_size=1, 62 | expert_model_parallel_size=1, 63 | nccl_communicator_config_path=None, 64 | ) 65 | tensor_parallel.model_parallel_cuda_manual_seed(10) 66 | 67 | @register(dispatch_mode=Dispatch.ONE_TO_ALL) 68 | def init_model(self): 69 | actor_model_config = LlamaConfig( 70 | vocab_size=256, 71 | hidden_size=2048, 72 | intermediate_size=5504, 73 | num_hidden_layers=24, 74 | num_attention_heads=16, 75 | num_key_value_heads=16, 76 | ) 77 | 78 | megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16) 79 | self.megatron_config = megatron_config 80 | 81 | def megatron_actor_model_provider(pre_process, post_process): 82 | # vpp is not supported yet because it will hang for some reason. 
Need debugging 83 | # this_megatron_config = copy.deepcopy(megatron_config) 84 | # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank 85 | parallel_model = ParallelLlamaForCausalLMRmPadPP( 86 | config=actor_model_config, 87 | megatron_config=megatron_config, 88 | pre_process=pre_process, 89 | post_process=post_process, 90 | ) 91 | parallel_model.cuda() 92 | return parallel_model 93 | 94 | actor_module = get_model( 95 | model_provider_func=megatron_actor_model_provider, 96 | model_type=ModelType.encoder_or_decoder, 97 | wrap_with_ddp=True, 98 | ) 99 | actor_module = nn.ModuleList(actor_module) 100 | 101 | optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0}) 102 | 103 | optim_config = init_megatron_optim_config(optim_config) 104 | self.optimizer_config = optim_config 105 | actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config) 106 | 107 | self.model = actor_module[0] 108 | self.optimizer = actor_optimizer 109 | 110 | @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) 111 | def train_model(self, data: DataProto) -> DataProto: 112 | input_ids = data.batch["input_ids"] 113 | attention_mask = data.batch["attention_mask"] 114 | position_ids = data.batch["position_ids"] 115 | 116 | self.optimizer.zero_grad() 117 | self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm 118 | # update for 1 iteration 119 | output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits 120 | output.mean().backward() 121 | 122 | update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config, self.megatron_config.timers) 123 | 124 | return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0])) 125 | 126 | 127 | if __name__ == "__main__": 128 | ray.init(address="auto", namespace="verl") 129 | 130 | resource_pool = RayResourcePool(process_on_nodes=[2], detached=True) 131 | cls_with_init_args = RayClassWithInitArgs(cls=Trainer) 132 | worker_group = NVMegatronRayWorkerGroup( 133 | resource_pool=resource_pool, 134 | ray_cls_with_init=cls_with_init_args, 135 | name_prefix="trainer", 136 | detached=True, 137 | ) 138 | 139 | worker_group.init_model() 140 | 141 | worker_names = worker_group.worker_names 142 | print(worker_names) 143 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_colocated_workers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import ray 16 | 17 | from verl import DataProto 18 | from verl.single_controller.base import Worker 19 | from verl.single_controller.base.decorator import Dispatch, register 20 | from verl.single_controller.ray.base import ( 21 | RayClassWithInitArgs, 22 | RayResourcePool, 23 | RayWorkerGroup, 24 | create_colocated_worker_cls, 25 | ) 26 | 27 | 28 | @ray.remote 29 | class Actor(Worker): 30 | def __init__(self) -> None: 31 | super().__init__() 32 | 33 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) 34 | def add(self, data: DataProto): 35 | data.batch["a"] += self.rank 36 | return data 37 | 38 | 39 | @ray.remote 40 | class Critic(Worker): 41 | def __init__(self, config) -> None: 42 | super().__init__() 43 | self.config = config 44 | 45 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) 46 | async def sub(self, data: DataProto): 47 | data.batch["a"] -= self.config["b"] 48 | return data 49 | 50 | 51 | def test_colocated_workers(): 52 | ray.init() 53 | 54 | import torch 55 | 56 | data = DataProto.from_dict({"a": torch.zeros(10)}) 57 | # create separate workers on the same resource pool 58 | actor_cls = RayClassWithInitArgs(cls=Actor) 59 | critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10}) 60 | resource_pool = RayResourcePool(process_on_nodes=[2]) 61 | 62 | actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) 63 | critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls) 64 | 65 | expected_actor_output = actor_wg.add(data) 66 | expected_critic_output = critic_wg.sub(data) 67 | 68 | # create colocated workers 69 | cls_dict = {"actor": actor_cls, "critic": critic_cls} 70 | ray_cls_with_init = create_colocated_worker_cls(cls_dict) 71 | wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) 72 | spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) 73 | 74 | colocated_actor_wg = spawn_wg["actor"] 75 | colocated_critic_wg = spawn_wg["critic"] 76 | 77 | actor_output = colocated_actor_wg.add(data) 78 | critic_output = colocated_critic_wg.sub(data) 79 | 80 | torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0) 81 | torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0) 82 | 83 | ray.shutdown() 84 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_colocated_workers_fused.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import ray 16 | 17 | from verl import DataProto 18 | from verl.single_controller.base import Worker 19 | from verl.single_controller.base.decorator import Dispatch, register 20 | from verl.single_controller.ray.base import ( 21 | RayClassWithInitArgs, 22 | RayResourcePool, 23 | RayWorkerGroup, 24 | create_colocated_worker_cls_fused, 25 | ) 26 | 27 | 28 | @ray.remote 29 | class Actor(Worker): 30 | def __init__(self) -> None: 31 | super().__init__() 32 | 33 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) 34 | def add(self, data: DataProto): 35 | data.batch["a"] += self.rank 36 | return data 37 | 38 | 39 | @ray.remote 40 | class Critic(Worker): 41 | def __init__(self, config) -> None: 42 | super().__init__() 43 | self.config = config 44 | 45 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) 46 | def sub(self, data: DataProto): 47 | data.batch["a"] -= self.config["b"] 48 | return data 49 | 50 | 51 | def test_colocated_workers_fused(): 52 | ray.init() 53 | 54 | import torch 55 | 56 | data = DataProto.from_dict({"a": torch.zeros(10)}) 57 | # create separate workers on the same resource pool 58 | actor_cls = RayClassWithInitArgs(cls=Actor) 59 | critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10}) 60 | resource_pool = RayResourcePool(process_on_nodes=[2]) 61 | 62 | actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) 63 | critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls) 64 | 65 | expected_actor_output = actor_wg.add(data) 66 | expected_critic_output = critic_wg.sub(data) 67 | 68 | # create colocated workers 69 | cls_dict = {"actor": actor_cls, "critic": critic_cls} 70 | ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict) 71 | wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) 72 | spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) 73 | 74 | colocated_actor_wg = spawn_wg["actor"] 75 | colocated_critic_wg = spawn_wg["critic"] 76 | 77 | actor_output = colocated_actor_wg.add(data) 78 | critic_output = colocated_critic_wg.sub(data) 79 | 80 | torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0) 81 | torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0) 82 | 83 | ray.shutdown() 84 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_data_transfer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
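# Benchmark-style data-transfer test: two large int64 tensors (4096 x 32768)
# are chunked across 8 workers, and serialization cost is timed three ways:
# ray.cloudpickle, plain pickle, and parallel_put followed by the non-blocking
# do_nothing dispatch. With tensordict >= 0.5.0 each chunk is consolidate()d
# first (the inline comments mark this as necessary), and the driver verifies
# that every returned tensor equals the input plus one.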
14 | """ 15 | In this test, we instantiate a data parallel worker with 8 GPUs 16 | """ 17 | 18 | import ray 19 | import tensordict 20 | import torch 21 | from codetiming import Timer 22 | from torch import distributed as dist 23 | 24 | from verl import DataProto 25 | from verl.single_controller.base import Worker 26 | from verl.single_controller.base.decorator import Dispatch, register 27 | from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup 28 | from verl.utils.ray_utils import parallel_put 29 | 30 | 31 | @ray.remote 32 | class DummyWorker(Worker): 33 | def __init__(self): 34 | super().__init__() 35 | dist.init_process_group() 36 | 37 | @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) 38 | def do_nothing(self, data): 39 | for key in data.batch.keys(): 40 | data.batch[key] += 1 41 | if tensordict.__version__ >= "0.5.0": 42 | data.batch = data.batch.consolidate() 43 | return data 44 | 45 | 46 | def test_data_transfer(): 47 | ray.init() 48 | # construct resource pool 49 | resource_pool = RayResourcePool([8]) 50 | cls_with_init = RayClassWithInitArgs(cls=DummyWorker) 51 | # construct worker group 52 | wg = RayWorkerGroup(resource_pool, cls_with_init) 53 | 54 | # this is real dataset size 55 | batch_size = 4096 56 | seqlen = 32768 57 | 58 | data_dict = {} 59 | 60 | for i in range(2): 61 | data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen)) 62 | 63 | data = DataProto.from_dict(tensors=data_dict) 64 | 65 | print(data) 66 | 67 | # we manually split data here and send to each worker 68 | data_list = data.chunk(wg.world_size) 69 | 70 | for i in range(wg.world_size): 71 | # consolidate is necessary 72 | if tensordict.__version__ >= "0.5.0": 73 | data_list[i].batch = data_list[i].batch.consolidate() 74 | 75 | with Timer(name="ray.pickle", initial_text=True): 76 | for i in range(wg.world_size): 77 | ray.cloudpickle.pickle.dumps(data_list[i]) 78 | 79 | with Timer(name="raw.pickle", initial_text=True): 80 | import pickle 81 | 82 | for i in range(wg.world_size): 83 | pickle.dumps(data_list[i]) 84 | 85 | # we put in advance 86 | with Timer(name="put", initial_text=True): 87 | # takes around 40 seconds 88 | data_list_ref = parallel_put(data_list) 89 | # for i in range(wg.world_size): 90 | # data_list[i] = ray.put(data_list[i]) 91 | 92 | with Timer(name="launch", initial_text=True): 93 | output_ref = wg.do_nothing(data_list_ref) 94 | 95 | with Timer(name="get", initial_text=True): 96 | # takes around 40 seconds 97 | output_lst = ray.get(output_ref) 98 | 99 | for input_data, output_data in zip(data_list, output_lst): 100 | for key in input_data.batch.keys(): 101 | assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), ( 102 | input_data.batch[key], 103 | output_data.batch[key], 104 | key, 105 | ) 106 | 107 | ray.shutdown() 108 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_driverfunc_to_worker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import ray 18 | import torch 19 | from tensordict import TensorDict 20 | 21 | from verl import DataProto 22 | from verl.single_controller.base.worker import Worker 23 | from verl.single_controller.ray import RayWorkerGroup 24 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool 25 | 26 | os.environ["RAY_DEDUP_LOGS"] = "0" 27 | os.environ["NCCL_DEBUG"] = "WARN" 28 | 29 | 30 | @ray.remote 31 | class ModelActor(Worker): 32 | def __init__(self): 33 | pass 34 | 35 | 36 | class HackSelf: 37 | def __init__(self): 38 | pass 39 | 40 | 41 | def get_aux_metrics(self, test_proto): 42 | sequence_ids = test_proto.batch["sequence_ids"] 43 | decode_count = [] 44 | for i in range(sequence_ids.size(0)): 45 | decode_count.append(len(sequence_ids[i].tolist())) 46 | ret_proto = DataProto(batch=TensorDict({"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0))) 47 | return ret_proto 48 | 49 | 50 | def test(): 51 | # construct model 52 | ray.init() 53 | 54 | # create 2 workers, each hold a GPU 55 | resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a") 56 | 57 | class_with_args = RayClassWithInitArgs(cls=ModelActor) 58 | shard_wg = RayWorkerGroup(resource_pool, class_with_args) 59 | 60 | test_bs = 8 61 | test_proto = DataProto( 62 | TensorDict( 63 | { 64 | "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64), 65 | }, 66 | batch_size=test_bs, 67 | ), 68 | meta_info={"query_length": 1536}, 69 | ) 70 | 71 | # Sharding among different ranks 72 | ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto) 73 | 74 | # compare execute on driver 75 | hs = HackSelf() 76 | ret_proto2 = get_aux_metrics(hs, test_proto) 77 | 78 | torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"]) 79 | 80 | ray.shutdown() 81 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_high_level_scheduling_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
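# High-level scheduling test with two placements on a single node:
#   1) no partition: one 8-GPU RayResourcePool shared by actor/critic/rm/ref
#      worker groups, each expected to see CUDA devices 0-7;
#   2) multi partition: two 4-GPU pools (rm, ref) merged into an 8-GPU pool
#      for actor/critic, while rm keeps devices 0-3 and ref keeps devices 4-7.
# The placement groups from the first phase are removed (with a short sleep)
# before the second phase starts.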
14 | 15 | import time 16 | 17 | import ray 18 | 19 | from verl.single_controller.base.worker import Worker 20 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool 21 | 22 | 23 | @ray.remote 24 | class TestActor(Worker): 25 | # TODO: pass *args and **kwargs is bug prone and not very convincing 26 | def __init__(self, cuda_visible_devices=None) -> None: 27 | super().__init__(cuda_visible_devices) 28 | 29 | def get_node_id(self): 30 | return ray.get_runtime_context().get_node_id() 31 | 32 | 33 | def test(): 34 | ray.init() 35 | 36 | # test single-node-no-partition 37 | print("test single-node-no-partition") 38 | resource_pool = RayResourcePool([8], use_gpu=True) 39 | 40 | class_with_args = RayClassWithInitArgs(cls=TestActor) 41 | 42 | print("create actor worker group") 43 | actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor") 44 | print("create critic worker group") 45 | critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="hight_level_api_critic") 46 | print("create rm worker group") 47 | rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm") 48 | print("create ref worker group") 49 | ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref") 50 | 51 | assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] 52 | assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] 53 | assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] 54 | assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] 55 | 56 | del actor_wg 57 | del critic_wg 58 | del rm_wg 59 | del ref_wg 60 | 61 | [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()] 62 | print("wait 5s to remove placemeng_group") 63 | time.sleep(5) 64 | # test single-node-multi-partition 65 | 66 | print("test single-node-multi-partition") 67 | rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm") 68 | ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref") 69 | total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool) 70 | 71 | assert rm_resource_pool.world_size == 4 72 | assert ref_resource_pool.world_size == 4 73 | assert total_resource_pool.world_size == 8 74 | 75 | actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor") 76 | critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic") 77 | rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm") 78 | ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref") 79 | 80 | assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] 81 | assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] 82 | assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)] 83 | assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)] 84 | 85 | ray.shutdown() 86 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_rvdz.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ray 16 | 17 | 18 | @ray.remote 19 | class TestWorker: 20 | def __init__(self, rank, world_size, group_name): 21 | self.rank = rank 22 | self.world_size = world_size 23 | self.group_name = group_name 24 | self.communicator = None 25 | 26 | def init(self): 27 | from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray 28 | 29 | self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name) 30 | 31 | def test(self): 32 | if self.communicator is None: 33 | return None 34 | return self.communicator.rank_id() 35 | 36 | 37 | def test_rvdz(): 38 | ray.init() 39 | 40 | group_name = "test_group" 41 | world_size = 2 42 | 43 | workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)] 44 | 45 | ray.get([worker.init.remote() for worker in workers]) 46 | 47 | ranks = ray.get([worker.test.remote() for worker in workers]) 48 | 49 | assert ranks == [0, 1], f"expecting [0, 1], got {ranks}" 50 | 51 | ray.shutdown() 52 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_worker_group_basics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | e2e test verl.single_controller.ray 16 | """ 17 | 18 | import ray 19 | import torch 20 | 21 | from verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register 22 | from verl.single_controller.base.worker import Worker 23 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup 24 | 25 | 26 | def two_to_all_dispatch_fn(worker_group, *args, **kwargs): 27 | """ 28 | Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker. 
29 | """ 30 | for arg in args: 31 | assert len(arg) == 2 32 | for i in range(worker_group.world_size - 2): 33 | arg.append(arg[i % 2]) 34 | for k, v in kwargs.items(): 35 | assert len(v) == 2 36 | for i in range(worker_group.world_size - 2): 37 | v.append(v[i % 2]) 38 | return args, kwargs 39 | 40 | 41 | @ray.remote 42 | class TestActor(Worker): 43 | # TODO: pass *args and **kwargs is bug prone and not very convincing 44 | def __init__(self, x) -> None: 45 | super().__init__() 46 | self._x = x 47 | 48 | def foo(self, y): 49 | return self._x + y 50 | 51 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) 52 | def foo_rank_zero(self, x, y): 53 | return self._x + y + x 54 | 55 | @register(Dispatch.ONE_TO_ALL, blocking=False) 56 | def foo_one_to_all(self, x, y): 57 | return self._x + y + x 58 | 59 | @register(Dispatch.ALL_TO_ALL, blocking=False) 60 | def foo_all_to_all(self, x, y): 61 | return self._x + y + x 62 | 63 | @register(dispatch_mode={"dispatch_fn": two_to_all_dispatch_fn, "collect_fn": collect_all_to_all}) 64 | def foo_custom(self, x, y): 65 | return self._x + y + x 66 | 67 | 68 | @ray.remote(num_gpus=0.1) 69 | def remote_call_wg(worker_names): 70 | class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) 71 | worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None) 72 | print(worker_group.worker_names) 73 | 74 | output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6]) 75 | assert output_ref == [8, 10, 8, 10] 76 | 77 | output_ref = worker_group.foo_rank_zero(x=1, y=2) 78 | assert output_ref == 5 79 | 80 | return worker_group.worker_names 81 | 82 | 83 | def add_one(data): 84 | data = data.to("cuda") 85 | data += 1 86 | data = data.to("cpu") 87 | return data 88 | 89 | 90 | def test_basics(): 91 | ray.init(num_cpus=100) 92 | 93 | # create 4 workers, each hold a GPU 94 | resource_pool = RayResourcePool([4], use_gpu=True) 95 | class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) 96 | 97 | worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic") 98 | 99 | print(worker_group.worker_names) 100 | 101 | # this will wait for all the results 102 | output = worker_group.execute_all_sync("foo", y=3) 103 | assert output == [5, 5, 5, 5] 104 | 105 | # this is a list of object reference. It won't block. 106 | output_ref = worker_group.execute_all_async("foo", y=4) 107 | print(output_ref) 108 | 109 | assert ray.get(output_ref) == [6, 6, 6, 6] 110 | 111 | output_ref = worker_group.foo_one_to_all(x=1, y=2) 112 | assert ray.get(output_ref) == [5, 5, 5, 5] 113 | 114 | output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8]) 115 | assert ray.get(output_ref) == [8, 10, 12, 14] 116 | 117 | print(ray.get(remote_call_wg.remote(worker_group.worker_names))) 118 | 119 | output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2)) 120 | torch.testing.assert_close(output, torch.ones(2, 2) + 1) 121 | 122 | ray.shutdown() 123 | 124 | 125 | if __name__ == "__main__": 126 | test_basics() 127 | -------------------------------------------------------------------------------- /tests/ray_gpu/test_worker_group_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | os.environ["RAY_DEDUP_LOGS"] = "0" 18 | os.environ["NCCL_DEBUG"] = "WARN" 19 | 20 | import ray 21 | import torch 22 | import torch.distributed 23 | 24 | from verl.single_controller.base.worker import Worker 25 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup 26 | 27 | 28 | @ray.remote 29 | class TestAllGatherActor(Worker): 30 | def __init__(self, size) -> None: 31 | super().__init__() 32 | self.size = size 33 | 34 | def init(self): 35 | torch.distributed.init_process_group() 36 | self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda") 37 | self.tensor += self.rank 38 | 39 | def all_gather(self): 40 | world_size = self._world_size 41 | output = torch.zeros(size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device) 42 | torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) 43 | return output 44 | 45 | 46 | @ray.remote 47 | class TestAllGatherActorV2(Worker): 48 | def __init__(self, size) -> None: 49 | super().__init__() 50 | self.size = size 51 | 52 | torch.distributed.init_process_group() 53 | self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda") 54 | self.tensor += self.rank 55 | 56 | def all_gather(self): 57 | world_size = self._world_size 58 | output = torch.zeros(size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device) 59 | torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) 60 | return output 61 | 62 | 63 | def test_all_gather_torch(): 64 | """ 65 | In this test, we instantiate 4 GPUs in a group and test the all_gather 66 | """ 67 | ray.init() 68 | 69 | # create 4 workers, each hold a GPU 70 | resource_pool = RayResourcePool([4], use_gpu=True) 71 | class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2) 72 | 73 | worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") 74 | 75 | worker_group.execute_all_sync("init") 76 | output = worker_group.execute_all_sync("all_gather") 77 | for i in range(1, len(output)): 78 | assert torch.all(output[i] == output[0]) 79 | 80 | output = output[0].cpu() 81 | print(output) 82 | assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64)) 83 | 84 | ray.shutdown() 85 | 86 | 87 | def test_all_gather_torch_v2(): 88 | """ 89 | In this test, we instantiate 4 GPUs in a group and test the all_gather 90 | """ 91 | ray.init() 92 | 93 | # create 4 workers, each hold a GPU 94 | resource_pool = RayResourcePool([4], use_gpu=True) 95 | class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2) 96 | 97 | worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") 98 | 99 | output = worker_group.execute_all_sync("all_gather") 100 | for i in range(1, len(output)): 101 | assert torch.all(output[i] == output[0]) 102 | 103 | output = output[0].cpu() 104 | print(output) 105 | assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64)) 
106 | 107 | ray.shutdown() 108 | -------------------------------------------------------------------------------- /tests/sanity/check_license.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | from pathlib import Path 16 | 17 | license_head_miro = "Copyright 2025 MiroMind Team" 18 | license_head_bytedance = "Copyright 2024 Bytedance Ltd. and/or its affiliates" 19 | license_head_bytedance_25 = "Copyright 2025 Bytedance Ltd. and/or its affiliates" 20 | # Add custom license headers below 21 | license_head_prime = "Copyright 2024 PRIME team and/or its affiliates" 22 | license_head_individual = "Copyright 2025 Individual Contributor:" 23 | license_head_sglang = "Copyright 2023-2024 SGLang Team" 24 | license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates" 25 | license_headers = [ 26 | license_head_miro, 27 | license_head_bytedance, 28 | license_head_bytedance_25, 29 | license_head_prime, 30 | license_head_individual, 31 | license_head_sglang, 32 | license_head_modelbest, 33 | ] 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = ArgumentParser() 38 | parser.add_argument("--directory", "-d", required=True, type=str) 39 | args = parser.parse_args() 40 | directory_in_str = args.directory 41 | 42 | pathlist = Path(directory_in_str).glob("**/*.py") 43 | for path in pathlist: 44 | # because path is object not string 45 | path_in_str = str(path.absolute()) 46 | print(path_in_str) 47 | with open(path_in_str, encoding="utf-8") as f: 48 | file_content = f.read() 49 | 50 | has_license = False 51 | for lh in license_headers: 52 | if lh in file_content: 53 | has_license = True 54 | break 55 | assert has_license, f"file {path_in_str} does not contain license" 56 | -------------------------------------------------------------------------------- /tests/sanity/check_pr_title.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
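# PR titles must look like "[module(s)] type: description", with every module
# drawn from allowed_modules and the type from allowed_types below; e.g. a
# title such as "[fsdp, trainer] feat: support foo" would pass (the
# description here is only an illustrative placeholder). The title is read
# from the PR_TITLE environment variable, so the check can be run locally as:
#   PR_TITLE="[ci] chore: update workflow" python tests/sanity/check_pr_title.py
# The script exits non-zero and prints the allowed values when the format
# does not match.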
14 | 15 | import os 16 | import re 17 | import sys 18 | 19 | # Get PR title from environment 20 | pr_title = os.environ.get("PR_TITLE", "").strip() 21 | 22 | # Define rules 23 | allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"] 24 | allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] 25 | allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] 26 | allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt"] 27 | allowed_types = ["feat", "fix", "doc", "refactor", "chore"] 28 | 29 | # Build dynamic regex pattern 30 | re_modules_pattern = re.compile(r"^\[([a-z_,\s]+)\]", re.IGNORECASE) 31 | re_modules = re_modules_pattern.match(pr_title) 32 | if not re_modules: 33 | print(f"❌ Invalid PR title: '{pr_title}'") 34 | print("Expected format: [module] type: description") 35 | print(f"Allowed modules: {', '.join(allowed_modules)}") 36 | sys.exit(1) 37 | else: 38 | modules = re.findall(r"[a-z]+", re_modules.group(1).lower()) 39 | if not all(module in allowed_modules for module in modules): 40 | invalid_modules = [module for module in modules if module not in allowed_modules] 41 | print(f"❌ Invalid modules: {', '.join(invalid_modules)}") 42 | print(f"Allowed modules: {', '.join(allowed_modules)}") 43 | sys.exit(1) 44 | 45 | types_pattern = "|".join(re.escape(t) for t in allowed_types) 46 | re_types_pattern = re.compile(rf"^\[[a-z_,\s]+\]\s+({types_pattern}):\s+.+$", re.IGNORECASE) 47 | match = re_types_pattern.match(pr_title) 48 | 49 | if not match: 50 | print(f"❌ Invalid PR title: '{pr_title}'") 51 | print("Expected format: [module] type: description") 52 | print(f"Allowed types: {', '.join(allowed_types)}") 53 | sys.exit(1) 54 | 55 | change_type = match.group(1).lower() 56 | 57 | print(f"✅ PR title is valid: {pr_title}, modules: {modules}, type: {change_type}") 58 | -------------------------------------------------------------------------------- /tests/sanity/test_config_docs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
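# The validator below enforces three documentation rules on the trainer YAML:
# every key line needs a comment on the line directly above it, a key may not
# carry an inline comment after the colon, and each key line must be followed
# by a blank line. A conforming entry would therefore look like (illustrative
# key name only):
#
#   # learning rate for the actor optimizer
#   lr: 1e-6
#
# Violations are collected and printed rather than asserted.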
14 | 15 | import re 16 | from pathlib import Path 17 | 18 | def validate_yaml_format(yaml_lines): 19 | errors = [] 20 | i = 0 21 | prev_key_indent = None 22 | 23 | while i < len(yaml_lines): 24 | line = yaml_lines[i] 25 | stripped = line.strip() 26 | 27 | # Skip empty lines 28 | if stripped == "": 29 | i += 1 30 | continue 31 | 32 | # Match YAML keys like "field:" or "field: value" 33 | key_match = re.match(r'^(\s*)([a-zA-Z0-9_]+):', line) 34 | if key_match: 35 | indent = key_match.group(1) 36 | is_section_header = (i + 1 < len(yaml_lines) and yaml_lines[i + 1].strip() == "") 37 | 38 | # Check if there's a comment above 39 | if i == 0 or not yaml_lines[i - 1].strip().startswith("#"): 40 | errors.append(f"Missing comment above line {i+1}: {line.strip()}") 41 | 42 | # Check for inline comment 43 | if "#" in line and not stripped.startswith("#"): 44 | comment_index = line.index("#") 45 | colon_index = line.index(":") 46 | if comment_index > colon_index: 47 | errors.append(f"Inline comment found on line {i+1}: {line.strip()}") 48 | 49 | # Check for blank line after this key line (unless next is a deeper indent) 50 | if i + 1 < len(yaml_lines): 51 | next_line = yaml_lines[i + 1] 52 | next_stripped = next_line.strip() 53 | 54 | # If next is not empty and not a deeper nested line, enforce blank line 55 | if next_stripped != "": 56 | errors.append(f"Missing blank line after line {i+1}: {line.strip()}") 57 | 58 | i += 1 59 | 60 | return errors 61 | 62 | 63 | def test_trainer_config_doc(): 64 | yaml_path = Path("mirorl/trainer/config/ppo_trainer.yaml") # path to your YAML file 65 | with open(yaml_path, "r") as f: 66 | lines = f.readlines() 67 | 68 | validation_errors = validate_yaml_format(lines) 69 | if validation_errors: 70 | print("YAML documentation format check failed:") 71 | print("Please read the top block of `verl/trainer/config/ppo_trainer.yaml` to see format rules:\n") 72 | for err in validation_errors: 73 | print(" -", err) 74 | else: 75 | print("YAML format check passed ✅") 76 | -------------------------------------------------------------------------------- /tests/sanity/test_import.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | def test_import(): 17 | import verl 18 | 19 | print(verl.__version__) 20 | 21 | 22 | def test_single_controller_import(): 23 | import verl.single_controller 24 | 25 | print(verl.single_controller.__version__) 26 | -------------------------------------------------------------------------------- /tests/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Tests for the trainer module. 16 | """ -------------------------------------------------------------------------------- /tests/trainer/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Tests for the PPO trainer module. 16 | """ 17 | -------------------------------------------------------------------------------- /tests/trainer/ppo/test_core_algos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
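# These tests exercise the advantage-estimator registry in
# verl.trainer.ppo.core_algos: register_adv_est adds a function under a string
# name (or an enum member's value), re-registering the same function is a
# no-op, registering a different function under an existing name raises
# ValueError, and get_adv_estimator_fn resolves names case-sensitively. The
# usage pattern mirrors the decorators in the tests below (illustrative name):
#
#   @register_adv_est("my_estimator")
#   def my_estimator(...): ...
#   fn = get_adv_estimator_fn("my_estimator")
#
# setUp swaps in a throwaway registry ("gae"/"vtrace" stubs) so the real
# estimators are never modified.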
14 | 15 | import unittest 16 | 17 | import pytest 18 | 19 | import verl.trainer.ppo.core_algos 20 | from verl.trainer.ppo.core_algos import get_adv_estimator_fn, register_adv_est 21 | 22 | 23 | def mock_test_fn(): 24 | pass 25 | 26 | class TestRegisterAdvEst(unittest.TestCase): 27 | def setUp(self): 28 | """Clear the registry before each test""" 29 | verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() 30 | verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = { 31 | "gae": lambda x: x * 2, 32 | "vtrace": lambda x: x + 1, 33 | } 34 | self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY 35 | 36 | def tearDown(self) -> None: 37 | verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() 38 | return super().tearDown() 39 | 40 | def test_register_new_function(self): 41 | """Test registering a new function with a string name""" 42 | @register_adv_est("test_estimator") 43 | def test_fn(): 44 | pass 45 | 46 | self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY) 47 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn) 48 | 49 | def test_register_with_enum(self): 50 | """Test registering with an enum value (assuming AdvantageEstimator exists)""" 51 | from enum import Enum 52 | class AdvantageEstimator(Enum): 53 | TEST = "test_enum_estimator" 54 | 55 | @register_adv_est(AdvantageEstimator.TEST) 56 | def test_fn(): 57 | pass 58 | 59 | self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY) 60 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn) 61 | 62 | def test_duplicate_registration_same_function(self): 63 | """Test that registering the same function twice doesn't raise an error""" 64 | register_adv_est("duplicate_test")(mock_test_fn) 65 | register_adv_est("duplicate_test")(mock_test_fn) 66 | 67 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn) 68 | 69 | def test_duplicate_registration_different_function(self): 70 | """Test that registering different functions with same name raises ValueError""" 71 | @register_adv_est("conflict_test") 72 | def test_fn1(): 73 | pass 74 | 75 | with self.assertRaises(ValueError): 76 | @register_adv_est("conflict_test") 77 | def test_fn2(): 78 | pass 79 | 80 | def test_decorator_preserves_function(self): 81 | """Test that the decorator returns the original function""" 82 | def test_fn(): 83 | return "original" 84 | 85 | decorated = register_adv_est("preserve_test")(test_fn) 86 | self.assertEqual(decorated(), "original") 87 | 88 | def test_multiple_registrations(self): 89 | """Test registering multiple different functions""" 90 | init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY) 91 | @register_adv_est("estimator1") 92 | def fn1(): 93 | pass 94 | 95 | @register_adv_est("estimator2") 96 | def fn2(): 97 | pass 98 | 99 | self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count) 100 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1) 101 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2) 102 | 103 | def test_get_adv_estimator_fn_valid_names(self): 104 | """Test that valid names return the correct function from registry.""" 105 | # Test GAE 106 | gae_fn = get_adv_estimator_fn("gae") 107 | assert gae_fn(5) == 10 # 5 * 2 = 10 108 | 109 | # Test Vtrace 110 | vtrace_fn = get_adv_estimator_fn("vtrace") 111 | assert vtrace_fn(5) == 6 # 5 + 1 = 6 112 | 113 | def test_get_adv_estimator_fn_invalid_name(self): 114 | """Test that invalid names raise ValueError.""" 115 | with pytest.raises(ValueError) as excinfo: 
116 | get_adv_estimator_fn("invalid_name") 117 | assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value) 118 | 119 | def test_get_adv_estimator_fn_case_sensitive(self): 120 | """Test that name lookup is case-sensitive.""" 121 | with pytest.raises(ValueError): 122 | get_adv_estimator_fn("GAE") # Different case 123 | 124 | 125 | if __name__ == "__main__": 126 | unittest.main() -------------------------------------------------------------------------------- /tests/utils/cpu_tests/_test_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Test module for import_utils.load_extern_type testing 17 | class TestClass: 18 | """A test class to be imported by load_extern_type""" 19 | 20 | def __init__(self, value=None): 21 | self.value = value or "default" 22 | 23 | def get_value(self): 24 | return self.value 25 | 26 | 27 | TEST_CONSTANT = "test_constant_value" 28 | 29 | 30 | def test_function(): 31 | return "test_function_result" 32 | -------------------------------------------------------------------------------- /tests/utils/cpu_tests/test_fs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
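# File-system utility tests: a recorded directory structure must be detected
# as changed after adding a file, and copy_to_local (with the HDFS copy mocked
# out) caches a remote file under cache_dir/md5_encode(remote_path)/basename,
# re-downloading only when always_recopy=True.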
14 | 15 | import os 16 | from pathlib import Path 17 | 18 | import verl.utils.fs as fs 19 | 20 | 21 | def test_record_and_check_directory_structure(tmp_path): 22 | # Create test directory structure 23 | test_dir = tmp_path / "test_dir" 24 | test_dir.mkdir() 25 | (test_dir / "file1.txt").write_text("test") 26 | (test_dir / "subdir").mkdir() 27 | (test_dir / "subdir" / "file2.txt").write_text("test") 28 | 29 | # Create structure record 30 | record_file = fs._record_directory_structure(test_dir) 31 | 32 | # Verify record file exists 33 | assert os.path.exists(record_file) 34 | 35 | # Initial check should pass 36 | assert fs._check_directory_structure(test_dir, record_file) is True 37 | 38 | # Modify structure and verify check fails 39 | (test_dir / "new_file.txt").write_text("test") 40 | assert fs._check_directory_structure(test_dir, record_file) is False 41 | 42 | 43 | def test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch): 44 | # Mock HDFS dependencies 45 | monkeypatch.setattr(fs, "is_non_local", lambda path: True) 46 | 47 | # side_effect will simulate the copy by creating parent dirs + empty file 48 | def fake_copy(src: str, dst: str, *args, **kwargs): 49 | dst_path = Path(dst) 50 | dst_path.parent.mkdir(parents=True, exist_ok=True) 51 | dst_path.write_bytes(b"") # touch an empty file 52 | 53 | monkeypatch.setattr(fs, "copy", fake_copy) # Mock actual HDFS copy 54 | 55 | # Test parameters 56 | test_cache = tmp_path / "cache" 57 | hdfs_path = "hdfs://test/path/file.txt" 58 | 59 | # Test initial copy 60 | local_path = fs.copy_to_local(hdfs_path, cache_dir=test_cache) 61 | expected_path = os.path.join(test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path)) 62 | assert local_path == expected_path 63 | assert os.path.exists(local_path) 64 | 65 | 66 | def test_always_recopy_flag(tmp_path, monkeypatch): 67 | # Mock HDFS dependencies 68 | monkeypatch.setattr(fs, "is_non_local", lambda path: True) 69 | 70 | copy_call_count = 0 71 | 72 | def fake_copy(src: str, dst: str, *args, **kwargs): 73 | nonlocal copy_call_count 74 | copy_call_count += 1 75 | dst_path = Path(dst) 76 | dst_path.parent.mkdir(parents=True, exist_ok=True) 77 | dst_path.write_bytes(b"") 78 | 79 | monkeypatch.setattr(fs, "copy", fake_copy) # Mock actual HDFS copy 80 | 81 | test_cache = tmp_path / "cache" 82 | hdfs_path = "hdfs://test/path/file.txt" 83 | 84 | # Initial copy (always_recopy=False) 85 | fs.copy_to_local(hdfs_path, cache_dir=test_cache) 86 | assert copy_call_count == 1 87 | 88 | # Force recopy (always_recopy=True) 89 | fs.copy_to_local(hdfs_path, cache_dir=test_cache, always_recopy=True) 90 | assert copy_call_count == 2 91 | 92 | # Subsequent normal call (always_recopy=False) 93 | fs.copy_to_local(hdfs_path, cache_dir=test_cache) 94 | assert copy_call_count == 2 # Should not increment 95 | -------------------------------------------------------------------------------- /tests/utils/cpu_tests/test_import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import pytest 18 | 19 | from verl.utils.import_utils import load_extern_type 20 | 21 | # Path to the test module 22 | TEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), "_test_module.py") 23 | 24 | 25 | def test_load_extern_type_class(): 26 | """Test loading a class from an external file""" 27 | TestClass = load_extern_type(TEST_MODULE_PATH, "TestClass") 28 | 29 | # Verify the class was loaded correctly 30 | assert TestClass is not None 31 | assert TestClass.__name__ == "TestClass" 32 | 33 | # Test instantiation and functionality 34 | instance = TestClass() 35 | assert instance.value == "default" 36 | 37 | # Test with a custom value 38 | custom_instance = TestClass("custom") 39 | assert custom_instance.get_value() == "custom" 40 | 41 | 42 | def test_load_extern_type_function(): 43 | """Test loading a function from an external file""" 44 | test_function = load_extern_type(TEST_MODULE_PATH, "test_function") 45 | 46 | # Verify the function was loaded correctly 47 | assert test_function is not None 48 | assert callable(test_function) 49 | 50 | # Test function execution 51 | result = test_function() 52 | assert result == "test_function_result" 53 | 54 | 55 | def test_load_extern_type_constant(): 56 | """Test loading a constant from an external file""" 57 | constant = load_extern_type(TEST_MODULE_PATH, "TEST_CONSTANT") 58 | 59 | # Verify the constant was loaded correctly 60 | assert constant is not None 61 | assert constant == "test_constant_value" 62 | 63 | 64 | def test_load_extern_type_nonexistent_file(): 65 | """Test behavior when file doesn't exist""" 66 | with pytest.raises(FileNotFoundError): 67 | load_extern_type("/nonexistent/path.py", "SomeType") 68 | 69 | 70 | def test_load_extern_type_nonexistent_type(): 71 | """Test behavior when type doesn't exist in the file""" 72 | with pytest.raises(AttributeError): 73 | load_extern_type(TEST_MODULE_PATH, "NonExistentType") 74 | 75 | 76 | def test_load_extern_type_none_path(): 77 | """Test behavior when file path is None""" 78 | result = load_extern_type(None, "SomeType") 79 | assert result is None 80 | 81 | 82 | def test_load_extern_type_invalid_module(): 83 | """Test behavior when module has syntax errors""" 84 | # Create a temporary file with syntax errors 85 | import tempfile 86 | 87 | with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp_file: 88 | temp_file.write("This is not valid Python syntax :") 89 | temp_path = temp_file.name 90 | 91 | try: 92 | with pytest.raises(RuntimeError): 93 | load_extern_type(temp_path, "SomeType") 94 | finally: 95 | # Clean up the temporary file 96 | if os.path.exists(temp_path): 97 | os.remove(temp_path) 98 | -------------------------------------------------------------------------------- /tests/utils/cpu_tests/test_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from types import SimpleNamespace # Or use a mock object library 16 | 17 | import pytest 18 | 19 | from verl.utils.model import update_model_config 20 | 21 | 22 | # Parametrize with different override scenarios 23 | @pytest.mark.parametrize( 24 | "override_kwargs", 25 | [ 26 | {"param_a": 5, "new_param": "plain_added"}, 27 | {"param_a": 2, "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}}, 28 | ], 29 | ) 30 | def test_update_model_config(override_kwargs): 31 | """ 32 | Tests that update_model_config correctly updates attributes, 33 | handling both plain and nested overrides via parametrization. 34 | """ 35 | # Create a fresh mock config object for each test case 36 | mock_config = SimpleNamespace(param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me") 37 | # Apply the updates using the parametrized override_kwargs 38 | update_model_config(mock_config, override_kwargs) 39 | 40 | # Assertions to check if the config was updated correctly 41 | if "nested_params" in override_kwargs: # Case 2: Nested override 42 | override_nested = override_kwargs["nested_params"] 43 | assert mock_config.nested_params.sub_param_x == override_nested["sub_param_x"], "Nested sub_param_x mismatch" 44 | assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged" 45 | assert hasattr(mock_config.nested_params, "sub_param_z"), "Expected nested sub_param_z to be added" 46 | assert mock_config.nested_params.sub_param_z == override_nested["sub_param_z"], "Value of sub_param_z mismatch" 47 | else: # Case 1: Plain override (nested params untouched) 48 | assert mock_config.nested_params.sub_param_x == "original_x", "Nested sub_param_x should be unchanged" 49 | assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged" 50 | assert not hasattr(mock_config.nested_params, "sub_param_z"), "Nested sub_param_z should not exist" 51 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
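# Checkpoint round-trip test (requires at least 2 GPUs): a one-layer Qwen2
# model is wrapped with FSDP, or with FSDP2 via apply_fsdp2 when the STRATEGY
# env var is set to anything other than "fsdp". The model is trained for one
# step, checkpointed, trained for a second step (logits recorded), then
# restored from the checkpoint and the second step is replayed; the replayed
# logits must match the recorded ones with zero tolerance.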
14 | import os 15 | import shutil 16 | 17 | import torch 18 | import torch.distributed 19 | from torch.distributed import init_device_mesh 20 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 21 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 22 | from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config 23 | 24 | from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager 25 | from verl.utils.distributed import initialize_global_process_group 26 | from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2 27 | 28 | ckpt_path = "./ci_checkpoints" 29 | 30 | 31 | def test_fsdp_ckpt(strategy="fsdp"): 32 | assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" 33 | local_rank, rank, world_size = initialize_global_process_group() 34 | device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) 35 | 36 | model_name = "Qwen/Qwen2.5-0.5B-Instruct" 37 | config = Qwen2Config(num_hidden_layers=1) 38 | 39 | with torch.device("cuda"): 40 | model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 41 | model = model.to(device="cuda") 42 | 43 | # Wrap model with FSDP 44 | if strategy == "fsdp": 45 | mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) 46 | 47 | model = FSDP( 48 | model, 49 | use_orig_params=False, 50 | device_id=torch.cuda.current_device(), 51 | sharding_strategy=ShardingStrategy.FULL_SHARD, 52 | mixed_precision=mixed_precision, 53 | device_mesh=device_mesh, 54 | ) 55 | else: 56 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) 57 | fsdp_kwargs = { 58 | "mesh": device_mesh, 59 | "mp_policy": mp_policy, 60 | } 61 | apply_fsdp2(model, fsdp_kwargs, {}) 62 | 63 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 64 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) 65 | 66 | # Create checkpoint manager 67 | tokenizer = AutoTokenizer.from_pretrained(model_name) 68 | checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer) 69 | 70 | # Generate sample input 71 | batch_size = 2 72 | seq_len = 32 73 | vocab_size = 32000 74 | # First input for initial update 75 | input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") 76 | attention_mask1 = torch.ones_like(input_ids1) 77 | 78 | # Second input for verification 79 | input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") 80 | attention_mask2 = torch.ones_like(input_ids2) 81 | 82 | # Step 1: Initial update and save checkpoint 83 | outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1) 84 | loss1 = outputs1.logits.mean() 85 | loss1.backward() 86 | optimizer.step() 87 | lr_scheduler.step() 88 | optimizer.zero_grad() 89 | 90 | # Save checkpoint after first update 91 | checkpoint_path = os.path.join(ckpt_path, "checkpoint") 92 | checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) 93 | 94 | # Step 2: Second update and forward pass 95 | outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) 96 | loss2 = outputs2.logits.mean() 97 | loss2.backward() 98 | optimizer.step() 99 | lr_scheduler.step() 100 | optimizer.zero_grad() 101 | 102 | # Record logits after second update 103 | with torch.no_grad(): 104 | 
logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits 105 | 106 | # Step 3: Load checkpoint and repeat second update 107 | checkpoint_manager.load_checkpoint(checkpoint_path) 108 | 109 | # Repeat the second update with same input 110 | outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2) 111 | loss3 = outputs3.logits.mean() 112 | loss3.backward() 113 | optimizer.step() 114 | lr_scheduler.step() 115 | optimizer.zero_grad() 116 | 117 | # Record logits after loaded checkpoint and update 118 | with torch.no_grad(): 119 | logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits 120 | 121 | # Step 4: Verify outputs match 122 | torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0) 123 | print("Checkpoint save/load test passed!") 124 | 125 | # Cleanup 126 | shutil.rmtree(ckpt_path, ignore_errors=True) 127 | torch.distributed.barrier() 128 | torch.distributed.destroy_process_group() 129 | 130 | 131 | if __name__ == "__main__": 132 | strategy = os.environ.get("STRATEGY", "fsdp") 133 | test_fsdp_ckpt(strategy=strategy) 134 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/dataset/test_rl_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
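# RLHFDataset smoke tests: a GSM8K parquet (text-only) and a geo3k parquet
# (multi-modal) are loaded with overlong-prompt filtering enabled, batched via
# the provided collate_fn, and the batch is converted to a DataProto by
# routing torch.Tensor values into `tensors` and everything else into
# `non_tensors`. The multi-modal case must expose multi_modal_data and
# multi_modal_inputs in the non-tensor batch.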
14 | import os 15 | 16 | import torch 17 | from omegaconf import OmegaConf 18 | from torch.utils.data import DataLoader 19 | 20 | 21 | def get_gsm8k_data(): 22 | # prepare test dataset 23 | local_folder = os.path.expanduser("~/verl-data/gsm8k/") 24 | local_path = os.path.join(local_folder, "train.parquet") 25 | os.makedirs(local_folder, exist_ok=True) 26 | return local_path 27 | 28 | 29 | def test_rl_dataset(): 30 | from verl.utils import hf_tokenizer 31 | from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn 32 | 33 | tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct") 34 | local_path = get_gsm8k_data() 35 | config = OmegaConf.create( 36 | { 37 | "prompt_key": "prompt", 38 | "max_prompt_length": 256, 39 | "filter_overlong_prompts": True, 40 | "filter_overlong_prompts_workers": 2, 41 | } 42 | ) 43 | dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config) 44 | 45 | dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) 46 | 47 | a = next(iter(dataloader)) 48 | 49 | from verl import DataProto 50 | 51 | tensors = {} 52 | non_tensors = {} 53 | 54 | for key, val in a.items(): 55 | if isinstance(val, torch.Tensor): 56 | tensors[key] = val 57 | else: 58 | non_tensors[key] = val 59 | 60 | data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) 61 | assert "input_ids" in data_proto.batch 62 | 63 | data = dataset[0]["input_ids"] 64 | output = tokenizer.batch_decode([data])[0] 65 | print(f"type: type{output}") 66 | print(f"\n\noutput: {output}") 67 | 68 | 69 | def test_image_rl_data(): 70 | from verl.utils import hf_processor, hf_tokenizer 71 | from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn 72 | 73 | tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct") 74 | processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct") 75 | config = OmegaConf.create( 76 | { 77 | "prompt_key": "prompt", 78 | "max_prompt_length": 1024, 79 | "filter_overlong_prompts": True, 80 | "filter_overlong_prompts_workers": 2, 81 | } 82 | ) 83 | dataset = RLHFDataset( 84 | data_files=os.path.expanduser("~/data/geo3k/train.parquet"), 85 | tokenizer=tokenizer, 86 | config=config, 87 | processor=processor, 88 | ) 89 | 90 | dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) 91 | 92 | a = next(iter(dataloader)) 93 | 94 | from verl import DataProto 95 | 96 | tensors = {} 97 | non_tensors = {} 98 | 99 | for key, val in a.items(): 100 | if isinstance(val, torch.Tensor): 101 | tensors[key] = val 102 | else: 103 | non_tensors[key] = val 104 | 105 | data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) 106 | 107 | assert "multi_modal_data" in data_proto.non_tensor_batch 108 | assert "multi_modal_inputs" in data_proto.non_tensor_batch 109 | 110 | data = dataset[0]["input_ids"] 111 | output = tokenizer.batch_decode([data])[0] 112 | print(f"type: type{output}") 113 | print(f"\n\noutput: {output}") 114 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/dataset/test_rm_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | 16 | from verl.utils import hf_tokenizer 17 | from verl.utils.dataset.rm_dataset import RMDataset 18 | 19 | 20 | def get_rm_data(): 21 | # prepare test dataset 22 | local_folder = os.path.expanduser("~/verl-data/full_hh_rlhf/rm/") 23 | local_path = os.path.join(local_folder, "test.parquet") 24 | os.makedirs(local_folder, exist_ok=True) 25 | return local_path 26 | 27 | 28 | def test_rm_dataset(): 29 | tokenizer = hf_tokenizer("facebook/opt-1.3b") 30 | local_path = get_rm_data() 31 | dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512) 32 | data = dataset[0]["input_ids"] 33 | output = tokenizer.batch_decode(data) 34 | assert len(output) > 1 35 | assert isinstance(output[0], str) 36 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/dataset/test_sft_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | 16 | from verl.utils import hf_tokenizer 17 | from verl.utils.dataset.sft_dataset import SFTDataset 18 | 19 | 20 | def get_gsm8k_data(): 21 | # prepare test dataset 22 | local_folder = os.path.expanduser("~/verl-data/gsm8k/") 23 | local_path = os.path.join(local_folder, "train.parquet") 24 | return local_path 25 | 26 | 27 | def test_sft_cot_dataset(): 28 | tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct") 29 | local_path = get_gsm8k_data() 30 | from omegaconf import OmegaConf 31 | 32 | dataset = SFTDataset( 33 | parquet_files=local_path, 34 | tokenizer=tokenizer, 35 | config=OmegaConf.create( 36 | { 37 | "prompt_key": "prompt", 38 | "prompt_dict_keys": ["content"], 39 | "response_key": "extra_info", 40 | "response_dict_keys": ["answer"], 41 | "max_length": 512, 42 | } 43 | ), 44 | ) 45 | 46 | data = dataset[0]["input_ids"] 47 | output = tokenizer.batch_decode([data])[0] 48 | assert len(output) > 1 49 | assert isinstance(output, str) 50 | 51 | 52 | def test_sft_dataset(): 53 | tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct") 54 | local_path = get_gsm8k_data() 55 | from omegaconf import OmegaConf 56 | 57 | dataset = SFTDataset( 58 | parquet_files=local_path, 59 | tokenizer=tokenizer, 60 | config=OmegaConf.create( 61 | { 62 | "prompt_key": "extra_info", 63 | "prompt_dict_keys": ["question"], 64 | "response_key": "extra_info", 65 | "response_dict_keys": ["answer"], 66 | "max_length": 512, 67 | } 68 | ), 69 | ) 70 | 71 | data = dataset[0]["input_ids"] 72 | output = tokenizer.batch_decode([data])[0] 73 | assert len(output) > 1 74 | assert isinstance(output, str) 75 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/megatron/test_pipeline_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from verl.utils.megatron.pipeline_parallel import make_batch_generator 16 | 17 | 18 | def test_make_batch_generator_no_vpp(): 19 | batches = [1, 2, 3] 20 | vpp_size = 1 21 | generator = make_batch_generator(batches, vpp_size) 22 | assert list(generator) == batches 23 | 24 | 25 | def test_make_batch_generator_with_vpp(): 26 | batches = [{"data": 1}, {"data": 2}] 27 | vpp_size = 2 28 | generators = make_batch_generator(batches, vpp_size) 29 | assert isinstance(generators, list) 30 | assert len(generators) == vpp_size 31 | 32 | # Check each generator yields the original batches 33 | for gen in generators: 34 | assert list(gen) == batches 35 | 36 | 37 | def test_make_batch_generator_empty(): 38 | batches = [] 39 | vpp_size = 1 40 | generator = make_batch_generator(batches, vpp_size) 41 | assert list(generator) == [] 42 | 43 | vpp_size = 3 44 | generators = make_batch_generator(batches, vpp_size) 45 | assert len(generators) == vpp_size 46 | for gen in generators: 47 | assert list(gen) == [] 48 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/test_activation_offload.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import shutil 16 | 17 | import pytest 18 | import torch 19 | import torch.distributed 20 | import torch.multiprocessing as mp 21 | from torch.distributed import init_device_mesh 22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 23 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 24 | from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config 25 | 26 | from verl.utils.activation_offload import enable_activation_offloading 27 | from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager 28 | from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy 29 | 30 | ckpt_path = "./ci_checkpoints" 31 | 32 | def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"): 33 | torch.cuda.set_device(rank) 34 | torch.distributed.init_process_group( 35 | backend="nccl", 36 | init_method=f"file://{rendezvous_file}", 37 | rank=rank, 38 | world_size=world_size, 39 | ) 40 | device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) 41 | 42 | model_name = "Qwen/Qwen2.5-0.5B-Instruct" 43 | config = Qwen2Config(num_hidden_layers=4) 44 | 45 | with torch.device("cuda"): 46 | model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 47 | model = model.to(device="cuda") 48 | 49 | # Wrap model with FSDP 50 | mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) 51 | 52 | if strategy == "fsdp": 53 | model = FSDP(model, use_orig_params=False, device_id=torch.cuda.current_device(), sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mixed_precision, device_mesh=device_mesh, auto_wrap_policy=get_fsdp_wrap_policy(module=model)) 54 | else: 55 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) 56 | fsdp_kwargs = { 57 | "mesh": device_mesh, 58 | "mp_policy": mp_policy, 59 | } 60 | apply_fsdp2(model, fsdp_kwargs, {}) 61 | 62 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 63 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) 64 | 65 | # Create checkpoint manager 66 | tokenizer = AutoTokenizer.from_pretrained(model_name) 67 | checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer) 68 | 69 | # Generate sample input 70 | batch_size = 2 71 | seq_len = 32 72 | vocab_size = 32000 73 | # First input for initial update 74 | input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") 75 | attention_mask1 = torch.ones_like(input_ids1) 76 | 77 | # Second input for verification 78 | input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") 79 | attention_mask2 = torch.ones_like(input_ids2) 80 | 81 | # Step 1: Initial update and save checkpoint 82 | outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1) 83 | loss1 = outputs1.logits.mean() 84 | loss1.backward() 85 | optimizer.step() 86 | lr_scheduler.step() 87 | optimizer.zero_grad() 88 | 89 | # Save checkpoint after first update 90 | checkpoint_path = os.path.join(ckpt_path, "checkpoint") 91 | checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) 92 | 93 | # Step 2: Second update and forward pass 94 | outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) 95 | loss2 = 
outputs2.logits.mean() 96 | loss2.backward() 97 | optimizer.step() 98 | lr_scheduler.step() 99 | optimizer.zero_grad() 100 | 101 | # Record logits after second update 102 | with torch.no_grad(): 103 | logits_without_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits 104 | 105 | # Step 3: wrap module with activation offloading and load checkpoint 106 | enable_activation_offloading(model, "fsdp") 107 | checkpoint_manager.load_checkpoint(checkpoint_path) 108 | 109 | # Step 4: Repeat the second update with same input 110 | outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2) 111 | loss3 = outputs3.logits.mean() 112 | loss3.backward() 113 | optimizer.step() 114 | lr_scheduler.step() 115 | optimizer.zero_grad() 116 | 117 | # Record logits after loaded checkpoint and update 118 | with torch.no_grad(): 119 | logits_with_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits 120 | 121 | # Step 5: Verify outputs match 122 | torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0) 123 | print(f"Activation offloading for {strategy} test passed on {world_size} GPUs!") 124 | 125 | # Cleanup 126 | shutil.rmtree(ckpt_path, ignore_errors=True) 127 | torch.distributed.barrier() 128 | torch.distributed.destroy_process_group() 129 | 130 | 131 | @pytest.mark.parametrize("world_size", (2, 4)) 132 | @pytest.mark.parametrize("strategy", ("fsdp", "fsdp2")) 133 | def test_activation_offloading(world_size, strategy, tmp_path): 134 | rendezvous_file = str(tmp_path / "rdzv_file") 135 | os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) 136 | 137 | mp.spawn( 138 | fn=_fsdp_activation_offloading_test, 139 | args=(world_size, rendezvous_file, strategy), 140 | nprocs=world_size, 141 | join=True, 142 | ) 143 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/test_flops_counter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import math 16 | 17 | import pytest 18 | 19 | from verl.utils.flops_counter import FlopsCounter 20 | 21 | VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"} 22 | 23 | 24 | class Config: 25 | def __init__(self, config_dict): 26 | for key, value in config_dict.items(): 27 | setattr(self, key, value) 28 | 29 | 30 | CONFIG = { 31 | "llama": { 32 | "config": { # llama2-7B 33 | "model_type": "llama", 34 | "vocab_size": 32000, 35 | "hidden_size": 4096, 36 | "intermediate_size": 11008, 37 | "num_hidden_layers": 32, 38 | "num_attention_heads": 32, 39 | "num_key_value_heads": 32, 40 | }, 41 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), 42 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim 43 | # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*32*4096 44 | # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*32*4096 45 | "expected_flops_tuple": (153555818250240 / 1e12, 575955114393600 / 1e12), 46 | }, 47 | "qwen2": { 48 | "config": { # Qwen/Qwen2.5-7B-Instruct 49 | "model_type": "qwen2", 50 | "vocab_size": 152064, 51 | "hidden_size": 3584, 52 | "intermediate_size": 18944, 53 | "num_hidden_layers": 28, 54 | "num_attention_heads": 28, 55 | "num_key_value_heads": 4, 56 | }, 57 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), 58 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim 59 | # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*28*3584 60 | # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*28*3584 61 | "expected_flops_tuple": (170388331954176 / 1e12, 622070178250752 / 1e12), 62 | }, 63 | "qwen3": { 64 | "config": { # Qwen/Qwen3-8B 65 | "model_type": "qwen3", 66 | "vocab_size": 151936, 67 | "hidden_size": 4096, 68 | "intermediate_size": 12288, 69 | "num_hidden_layers": 36, 70 | "num_attention_heads": 32, 71 | "num_key_value_heads": 8, 72 | "head_dim": 128, 73 | }, 74 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), 75 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim 76 | # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*36*128*32 77 | # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*36*128*32 78 | "expected_flops_tuple": (185867930959872 / 1e12, 692924253732864 / 1e12), 79 | }, 80 | "qwen3_moe": { 81 | "config": { # Qwen/Qwen3-30B-A3B-Base 82 | "model_type": "qwen3_moe", 83 | "hidden_size": 2048, 84 | "vocab_size": 151936, 85 | "num_hidden_layers": 48, 86 | "num_key_value_heads": 4, 87 | "num_attention_heads": 32, 88 | "head_dim": 128, 89 | "moe_intermediate_size": 768, 90 | "num_experts_per_tok": 8, 91 | "num_experts": 128, 92 | }, 93 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), 94 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 + hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim 95 | # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*48*128*32 96 | # 
6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*48*128*32 97 | "expected_flops_tuple": (85087060230144 / 1e12, 365944098521088 / 1e12), 98 | }, 99 | "deepseek_v3": { 100 | "config": { # deepseek-ai/DeepSeek-Prover-V2-671B 101 | "model_type": "deepseek_v3", 102 | "hidden_size": 7168, 103 | "vocab_size": 129280, 104 | "moe_intermediate_size": 2048, 105 | "num_hidden_layers": 61, 106 | "first_k_dense_replace": 3, 107 | "num_attention_heads": 128, 108 | "n_routed_experts": 256, 109 | "num_experts_per_tok": 8, 110 | "n_shared_experts": 1, 111 | "kv_lora_rank": 512, 112 | "qk_rope_head_dim": 64, 113 | "v_head_dim": 128, 114 | "intermediate_size": 18432, 115 | "qk_nope_head_dim": 128, 116 | "q_lora_rank": 1536, 117 | }, 118 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), 119 | # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280 120 | # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*61*192*128 121 | # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*61*192*128 122 | "expected_flops_tuple": (906535995703296 / 1e12, 3674028304760832 / 1e12), 123 | }, 124 | } 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "config_type", 129 | ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"], 130 | ) 131 | def test_flops_counter(config_type: str): 132 | test_config = CONFIG[config_type] 133 | config = Config(test_config["config"]) 134 | flops_counter = FlopsCounter(config) 135 | for batch_seqlens, expected_flops in zip(test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"]): 136 | # set delta time to 1 to get the flops 137 | counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1) 138 | print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}") 139 | assert math.isclose(counted_flops, expected_flops), f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" 140 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/test_seqlen_balancing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import torch.distributed as dist 17 | import torch.multiprocessing as mp 18 | 19 | from verl import DataProto 20 | from verl.utils.model import create_random_mask 21 | from verl.utils.seqlen_balancing import ceildiv, get_reverse_idx, rearrange_micro_batches 22 | 23 | 24 | def test_seqlen_balancing(): 25 | input_ids = torch.randint(low=0, high=10, size=(20, 100)) 26 | 27 | attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5) 28 | data = {"input_ids": input_ids, "attention_mask": attention_mask} 29 | dataproto = DataProto.from_single_dict(data) 30 | micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300) 31 | batch = torch.cat(micro_batches) 32 | micro_bsz_idx = [] 33 | for idx in micro_bsz_idx_lst: 34 | micro_bsz_idx.extend(idx) 35 | reverse_idx_map = get_reverse_idx(micro_bsz_idx) 36 | reverse_idx_map = torch.tensor(reverse_idx_map) 37 | new_batch = batch[reverse_idx_map] 38 | torch.testing.assert_close(new_batch, dataproto.batch) 39 | 40 | 41 | def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb): 42 | # 1) init process group & CUDA 43 | torch.cuda.set_device(rank) 44 | dist.init_process_group( 45 | backend="nccl", 46 | init_method=init_method, 47 | world_size=world_size, 48 | rank=rank, 49 | ) 50 | 51 | # 2) build a small random batch (each rank different length to force mismatch) 52 | torch.manual_seed(42 + rank) 53 | input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f"cuda:{rank}") 54 | attention_mask = create_random_mask( 55 | input_ids=input_ids, 56 | max_ratio_of_left_padding=0.1, 57 | max_ratio_of_valid_token=0.9, 58 | min_ratio_of_valid_token=0.5, 59 | ) 60 | dp = {"input_ids": input_ids, "attention_mask": attention_mask} 61 | proto = DataProto.from_single_dict(dp) 62 | batch = proto.batch 63 | 64 | # 3) call rearrange_micro_batches with one of the two params under test 65 | micros, idx_lst = rearrange_micro_batches( 66 | batch, 67 | max_token_len=max_token_len, 68 | dp_group=dist.group.WORLD, 69 | same_micro_num_in_dp=use_same_dp, 70 | min_num_micro_batch=min_mb, 71 | ) 72 | 73 | # 4) check the enforced counts 74 | seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) 75 | total_seqlen = seq_len_effective.sum().item() 76 | local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len)) 77 | 78 | if min_mb is not None: 79 | expected = max(local, min_mb) 80 | assert len(micros) == expected 81 | if use_same_dp: 82 | # gather all local_counts 83 | counts = [torch.zeros(1, device=f"cuda:{rank}") for _ in range(world_size)] 84 | counts[rank].fill_(local) 85 | dist.all_gather(counts, counts[rank]) 86 | expected = max(int(c.item()) for c in counts) 87 | assert len(micros) == expected 88 | else: 89 | # if neither, we get the local natural count 90 | assert len(micros) == local 91 | 92 | # 5) reconstruction sanity: concat→reverse_idx→orig 93 | flat = torch.cat(micros, dim=0) 94 | idx = [] 95 | for sub in idx_lst: 96 | idx.extend(sub) 97 | inv = get_reverse_idx(idx) 98 | inv = torch.tensor(inv, device=flat.device) 99 | reconstructed = flat[inv] 100 | torch.testing.assert_close(reconstructed, batch) 101 | 102 | dist.destroy_process_group() 103 | 104 | 105 | def test_seqlen_balancing_distributed_params(tmp_path): 106 | world_size = 2 107 | init_file = tmp_path / "dist_init" 108 | init_file.write_text("") # empty file 109 | init_method = f"file://{init_file}" 110 | 111 | # 
test min_num_micro_batch only 112 | mp.spawn( 113 | _worker, 114 | args=(world_size, init_method, 300, False, 4), 115 | nprocs=world_size, 116 | join=True, 117 | ) 118 | 119 | # test same_micro_num_in_dp only 120 | mp.spawn( 121 | _worker, 122 | args=(world_size, init_method, 300, True, None), 123 | nprocs=world_size, 124 | join=True, 125 | ) 126 | -------------------------------------------------------------------------------- /tests/utils/gpu_tests/test_torch_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import pytest 18 | import torch 19 | import torch.distributed as dist 20 | import torch.multiprocessing as mp 21 | 22 | from verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std 23 | 24 | 25 | def _worker_mean(rank: int, world_size: int, rendezvous_file: str): 26 | # 1) set GPU and init NCCL 27 | torch.cuda.set_device(rank) 28 | dist.init_process_group( 29 | backend="nccl", 30 | init_method=f"file://{rendezvous_file}", 31 | rank=rank, 32 | world_size=world_size, 33 | ) 34 | 35 | # each rank holds tensor [rank+1] 36 | local = torch.tensor([float(rank + 1)], device=f"cuda:{rank}") 37 | mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True) 38 | 39 | values = [float(i + 1) for i in range(world_size)] 40 | exp_mean = sum(values) / len(values) 41 | exp_max = max(values) 42 | exp_min = min(values) 43 | var = sum((x - exp_mean) ** 2 for x in values) / (len(values) - 1) 44 | exp_std = var**0.5 45 | 46 | # all ranks should see the same result 47 | assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f"mean@{rank}" 48 | assert torch.allclose(gmax.cpu(), torch.tensor(exp_max)), f"max@{rank}" 49 | assert torch.allclose(gmin.cpu(), torch.tensor(exp_min)), f"min@{rank}" 50 | assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f"std@{rank}" 51 | 52 | dist.destroy_process_group() 53 | 54 | 55 | @pytest.mark.parametrize("world_size", [2, 4]) 56 | def test_distributed_mean_max_min_std(world_size, tmp_path): 57 | rendezvous_file = str(tmp_path / "rdzv_mean") 58 | os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) 59 | 60 | mp.spawn( 61 | fn=_worker_mean, 62 | args=(world_size, rendezvous_file), 63 | nprocs=world_size, 64 | join=True, 65 | ) 66 | 67 | 68 | def _worker_mask(rank: int, world_size: int, rendezvous_file: str): 69 | torch.cuda.set_device(rank) 70 | dist.init_process_group( 71 | backend="nccl", 72 | init_method=f"file://{rendezvous_file}", 73 | rank=rank, 74 | world_size=world_size, 75 | ) 76 | 77 | # build per‐rank tensor and mask 78 | local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=f"cuda:{rank}") 79 | if rank == 0: 80 | mask = torch.tensor([1, 0], device=f"cuda:{rank}", dtype=torch.float32) 81 | else: 82 | mask = torch.tensor([0, 1], device=f"cuda:{rank}", 
dtype=torch.float32) 83 | 84 | gmean = distributed_masked_mean(local_tensor, mask) 85 | 86 | valid_values = [1.0] + [2 * i + 2.0 for i in range(1, world_size)] 87 | expected_mean = sum(valid_values) / len(valid_values) 88 | assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f"masked_mean@{rank}" 89 | 90 | dist.destroy_process_group() 91 | 92 | 93 | @pytest.mark.parametrize("world_size", [2, 4]) 94 | def test_distributed_masked_mean(world_size, tmp_path): 95 | rendezvous_file = str(tmp_path / "rdzv_mask") 96 | os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) 97 | 98 | mp.spawn( 99 | fn=_worker_mask, 100 | args=(world_size, rendezvous_file), 101 | nprocs=world_size, 102 | join=True, 103 | ) 104 | -------------------------------------------------------------------------------- /tests/workers/reward_manager/test_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | 17 | # Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module 18 | from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register 19 | 20 | 21 | @pytest.fixture 22 | def setup(): 23 | """Setup test cases with a mock registry.""" 24 | REWARD_MANAGER_REGISTRY.clear() 25 | REWARD_MANAGER_REGISTRY.update({ 26 | "manager1": "Manager1Class", 27 | "manager2": "Manager2Class" 28 | }) 29 | return REWARD_MANAGER_REGISTRY 30 | 31 | def test_get_existing_manager(setup): 32 | """Test getting an existing reward manager class.""" 33 | assert get_reward_manager_cls("manager1") == "Manager1Class" 34 | assert get_reward_manager_cls("manager2") == "Manager2Class" 35 | 36 | def test_get_nonexistent_manager(setup): 37 | """Test getting a non-existent reward manager raises ValueError.""" 38 | with pytest.raises(ValueError) as excinfo: 39 | get_reward_manager_cls("unknown_manager") 40 | assert "Unknown reward manager: unknown_manager" in str(excinfo.value) 41 | 42 | def test_case_sensitivity(setup): 43 | """Test that manager names are case-sensitive.""" 44 | with pytest.raises(ValueError): 45 | get_reward_manager_cls("MANAGER1") 46 | with pytest.raises(ValueError): 47 | get_reward_manager_cls("Manager1") 48 | 49 | def test_empty_registry(setup): 50 | """Test behavior when registry is empty.""" 51 | REWARD_MANAGER_REGISTRY.clear() 52 | with pytest.raises(ValueError) as excinfo: 53 | get_reward_manager_cls("any_manager") 54 | assert "Unknown reward manager: any_manager" in str(excinfo.value) 55 | 56 | def test_register_new_class(setup): 57 | """Test registering a new class with the decorator.""" 58 | @register("test_manager") 59 | class TestManager: 60 | pass 61 | 62 | assert "test_manager" in REWARD_MANAGER_REGISTRY 63 | assert REWARD_MANAGER_REGISTRY["test_manager"] == TestManager 64 | 65 | def test_register_different_classes_same_name(setup): 66 | """Test that 
registering different classes with same name raises ValueError.""" 67 | @register("conflict_manager") 68 | class Manager1: 69 | pass 70 | 71 | with pytest.raises(ValueError) as context: 72 | @register("conflict_manager") 73 | class Manager2: 74 | pass 75 | 76 | assert REWARD_MANAGER_REGISTRY["conflict_manager"] == Manager1 77 | 78 | def test_decorator_returns_original_class(setup): 79 | """Test that the decorator returns the original class unchanged.""" 80 | @register("return_test") 81 | class OriginalClass: 82 | def method(setup): 83 | return 42 84 | 85 | assert OriginalClass().method() == 42 86 | assert REWARD_MANAGER_REGISTRY["return_test"] == OriginalClass 87 | -------------------------------------------------------------------------------- /tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool" 3 | config: 4 | sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code" 5 | tool_schema: 6 | type: "function" 7 | function: 8 | name: "code_interpreter" 9 | description: "A tool for executing code." 10 | parameters: 11 | type: "object" 12 | properties: 13 | code: 14 | type: "string" 15 | description: "The code to execute." 16 | required: ["code"] -------------------------------------------------------------------------------- /tests/workers/rollout/resource/tool_configs/search_tool_config: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: verl.tools.search_tool.SearchTool 3 | config: 4 | retrieval_service_url: http://127.0.0.1:8000/retrieve 5 | num_workers: 120 6 | rate_limit: 120 7 | timeout: 30 8 | tool_schema: 9 | type: function 10 | function: 11 | name: search 12 | description: Searches the web for relevant information based on the given query. 13 | parameters: 14 | type: object 15 | properties: 16 | query_list: 17 | type: array 18 | item: 19 | type: string 20 | description: A list of fully-formed semantic queries. The tool will return search results for each query. 21 | required: 22 | - query_list -------------------------------------------------------------------------------- /tests/workers/rollout/test_sglang_async_rollout_w_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 SGLang Team 2 | # Copyright 2025 ModelBest Inc. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ 16 | usage: torchrun --standalone --nnodes=1 \ 17 | --nproc_per_node=2 $(which pytest) \ 18 | -s test_sglang_async_rollout_w_tools.py 19 | """ 20 | 21 | import numpy as np 22 | import torch 23 | from tensordict import TensorDict 24 | from torch.distributed.device_mesh import init_device_mesh 25 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 26 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 27 | from utils_sglang import ( 28 | are_lists_similar, 29 | clean_torchelastic_env, 30 | generate_hf_output, 31 | get_rollout_config, 32 | initialize_global_process_group, 33 | load_tokenizer_and_model, 34 | prepare_inputs, 35 | ) 36 | 37 | from verl import DataProto 38 | from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout 39 | from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager 40 | 41 | 42 | def test_async_sglang_rollout_w_tool(): 43 | assert torch.cuda.device_count() >= 2 44 | initialize_global_process_group() 45 | clean_torchelastic_env() 46 | 47 | max_prompt_length = 32 48 | max_response_length = 16 49 | dtype = "bfloat16" 50 | tensor_parallel_size = 2 51 | local_model_path = "Qwen/Qwen2.5-0.5B" 52 | 53 | tokenizer, actor_model = load_tokenizer_and_model(local_model_path) 54 | 55 | preencode_prompts = [ 56 | [{"role": "user", "content": prompt, "tool_calls": None}] 57 | for prompt in [ 58 | "Who won the Champions League in 2019?", 59 | "The founder of Apple is", 60 | "What's the best way to learn python?", 61 | ] 62 | ] 63 | prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] 64 | input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) 65 | 66 | hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) 67 | 68 | fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) 69 | inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp")) 70 | 71 | fsdp_model = FSDP( 72 | actor_model, 73 | use_orig_params=True, 74 | device_id=fsdp_device_mesh["fsdp"].get_local_rank(), 75 | mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)), 76 | sharding_strategy=ShardingStrategy.FULL_SHARD, 77 | device_mesh=fsdp_device_mesh, 78 | ) 79 | 80 | rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, "./resource/tool_configs/sandbox_fusion_tool_config") 81 | rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config) 82 | 83 | rollout_sharding_manager = FSDPSGLangShardingManager( 84 | module=fsdp_model, 85 | inference_engine=rollout._engine, 86 | model_config=actor_model.config, 87 | full_params=True, 88 | device_mesh=inference_device_mesh_cpu, 89 | ) 90 | 91 | with rollout_sharding_manager: 92 | prompt_dict = TensorDict( 93 | { 94 | "input_ids": input_ids, 95 | "attention_mask": attention_mask, 96 | "position_ids": position_ids, 97 | }, 98 | batch_size=input_ids.shape[0], 99 | ) 100 | print(f"preprocessed {input_ids.shape=}") 101 | 102 | messages = np.asarray(preencode_prompts) 103 | prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": np.array([{}] * input_ids.shape[0], dtype=object)}) 104 | 105 | prompts.meta_info.update( 106 | { 107 | 
"eos_token_id": tokenizer.eos_token_id, 108 | "pad_token_id": tokenizer.pad_token_id, 109 | } 110 | ) 111 | 112 | prompts = rollout_sharding_manager.preprocess_data(prompts) 113 | # log_gpu_memory_usage("Before generating sequences", logger=None) 114 | output = rollout.generate_sequences(prompts=prompts) 115 | print(f"generated {output.batch['responses'].shape=}") 116 | # log_gpu_memory_usage("After generating sequences", logger=None) 117 | output = rollout_sharding_manager.postprocess_data(output) 118 | print(f"postprocessed {output.batch['responses'].shape=}") 119 | sglang_output = output.to("cpu") 120 | 121 | sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"]) 122 | 123 | print(f"hf response: {hf_response_tokens}") 124 | print(f"sglang response: {sglang_response_tokens}") 125 | assert are_lists_similar(hf_response_tokens, sglang_response_tokens) 126 | print("SGLang w tool Test Passed!") 127 | 128 | torch.distributed.barrier() 129 | torch.distributed.destroy_process_group() 130 | 131 | 132 | if __name__ == "__main__": 133 | test_async_sglang_rollout_w_tool() 134 | -------------------------------------------------------------------------------- /tests/workers/rollout/test_sglang_spmd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 SGLang Team 2 | # Copyright 2025 ModelBest Inc. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ 16 | usage: torchrun --standalone --nnodes=1 \ 17 | --nproc_per_node=2 $(which pytest) \ 18 | -s test_sglang_spmd.py 19 | """ 20 | 21 | import asyncio 22 | 23 | import torch 24 | from sglang.srt.entrypoints.engine import Engine 25 | from sglang.srt.utils import broadcast_pyobj 26 | from torch.distributed.device_mesh import init_device_mesh 27 | from utils_sglang import ( 28 | are_lists_similar, 29 | clean_torchelastic_env, 30 | generate_hf_output, 31 | initialize_global_process_group, 32 | load_tokenizer_and_model, 33 | prepare_inputs, 34 | ) 35 | 36 | 37 | def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): 38 | non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] 39 | token_ids = prompt_token_ids[non_pad_index:].tolist() 40 | return token_ids 41 | 42 | 43 | def test_sglang_spmd(): 44 | assert torch.cuda.device_count() >= 2 45 | initialize_global_process_group(spmd=True) 46 | clean_torchelastic_env() 47 | 48 | max_prompt_length = 16 49 | max_response_length = 16 50 | 51 | local_model_path = "Qwen/Qwen2.5-0.5B" 52 | tokenizer, actor_model = load_tokenizer_and_model(local_model_path) 53 | 54 | preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"] 55 | input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length) 56 | 57 | hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) 58 | 59 | tensor_parallel_size = 2 60 | inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) 61 | tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() 62 | 63 | if tp_rank == 0: 64 | llm = Engine( 65 | model_path=local_model_path, 66 | dtype="bfloat16", 67 | mem_fraction_static=0.5, 68 | enable_memory_saver=True, 69 | tp_size=inference_device_mesh_cpu["tp"].size(), 70 | ) 71 | 72 | input_ids = input_ids.cuda() 73 | idx_list = [] 74 | 75 | pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id 76 | for i in range(input_ids.shape[0]): 77 | idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) 78 | 79 | sampling_params = dict( 80 | n=1, 81 | temperature=0, 82 | top_p=1, 83 | top_k=-1, 84 | max_new_tokens=max_response_length, 85 | presence_penalty=0.0, 86 | frequency_penalty=0.0, 87 | repetition_penalty=1.0, 88 | skip_special_tokens=True, 89 | spaces_between_special_tokens=True, 90 | ignore_eos=False, 91 | ) 92 | 93 | loop = asyncio.get_event_loop() 94 | outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params)) 95 | else: 96 | outputs = None 97 | 98 | [outputs] = broadcast_pyobj( 99 | [outputs], 100 | rank=inference_device_mesh_cpu["tp"].get_local_rank(), 101 | src=inference_device_mesh_cpu["tp"].mesh[0].item(), 102 | dist_group=inference_device_mesh_cpu["tp"].get_group(), 103 | force_cpu_device=False, 104 | ) 105 | 106 | sglang_response_tokens = [output["text"] for output in outputs] 107 | 108 | print(f"sglang response: {sglang_response_tokens}") 109 | assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n" 110 | print("SPMD Test Passed!") 111 | 112 | torch.distributed.barrier() 113 | torch.distributed.destroy_process_group() 114 | -------------------------------------------------------------------------------- /tests/workers/rollout/utils_sglang.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 SGLang Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | from datetime import timedelta 16 | 17 | import torch 18 | from omegaconf import OmegaConf 19 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 20 | 21 | from verl.utils.model import compute_position_id_with_mask 22 | from verl.utils.torch_functional import pad_sequence_to_length 23 | 24 | 25 | # ====================== utils ====================== 26 | def levenshtein(s1, s2): 27 | m, n = len(s1), len(s2) 28 | dp = [[0] * (n + 1) for _ in range(m + 1)] 29 | for i in range(m + 1): 30 | dp[i][0] = i 31 | for j in range(n + 1): 32 | dp[0][j] = j 33 | for i in range(1, m + 1): 34 | for j in range(1, n + 1): 35 | cost = 0 if s1[i - 1] == s2[j - 1] else 1 36 | dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) 37 | return dp[m][n] 38 | 39 | 40 | def are_lists_similar(a, b, threshold=10): 41 | if len(a) != len(b): 42 | print("The lists are of different lengths.") 43 | return False 44 | total_length = 0 45 | total_diff = 0 46 | for s1, s2 in zip(a, b): 47 | max_len = max(len(s1), len(s2)) 48 | total_length += max_len 49 | total_diff += levenshtein(s1, s2) 50 | percentage_difference = (total_diff / total_length) * 100 51 | print(f"Total difference: {percentage_difference:.2f}%") 52 | return percentage_difference <= threshold 53 | 54 | 55 | def initialize_global_process_group(timeout_second=36000, spmd=False): 56 | import torch.distributed 57 | 58 | if not torch.distributed.is_initialized(): # Check if already initialized 59 | print("Initializing process group...") 60 | torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second)) 61 | else: 62 | print("Process group already initialized.") 63 | 64 | local_rank = int(os.environ["LOCAL_RANK"]) 65 | rank = int(os.environ["RANK"]) 66 | world_size = int(os.environ["WORLD_SIZE"]) 67 | torch.cuda.set_device(local_rank) 68 | 69 | CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "") 70 | if not CUDA_VISIBLE_DEVICES: 71 | if spmd: 72 | # CUDA_VISIBLE_DEVICES = ','.join(str(i) for i in range(tensor_parallel_size)) 73 | CUDA_VISIBLE_DEVICES = ",".join(str(i) for i in range(world_size)) 74 | else: 75 | CUDA_VISIBLE_DEVICES = str(local_rank) 76 | os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES 77 | print(f"CUDA_VISIBLE_DEVICES is not set, set to {CUDA_VISIBLE_DEVICES}") 78 | 79 | return local_rank, rank, world_size 80 | 81 | 82 | def clean_torchelastic_env(): 83 | for k in ["TORCHELASTIC_USE_AGENT_STORE"]: 84 | if k in os.environ: 85 | del os.environ[k] 86 | 87 | 88 | def load_tokenizer_and_model(local_model_path, dtype="bfloat16"): 89 | tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") 90 | tokenizer.pad_token = tokenizer.eos_token 91 | model = AutoModelForCausalLM.from_pretrained(local_model_path, 
torch_dtype=getattr(torch, dtype), device_map="cuda") 92 | return tokenizer, model 93 | 94 | 95 | def prepare_inputs(tokenizer, prompts, max_prompt_length): 96 | pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id 97 | tokenized = tokenizer(prompts, return_tensors="pt", padding=True) 98 | input_ids = pad_sequence_to_length(tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True) 99 | attention_mask = pad_sequence_to_length(tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True) 100 | position_ids = compute_position_id_with_mask(attention_mask) 101 | position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True) 102 | return input_ids, attention_mask, position_ids 103 | 104 | 105 | def generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length): 106 | generation_config = GenerationConfig(do_sample=False) 107 | output = model.generate( 108 | input_ids=input_ids.cuda(), 109 | attention_mask=attention_mask.cuda(), 110 | max_new_tokens=max_response_length, 111 | eos_token_id=tokenizer.eos_token_id, 112 | pad_token_id=tokenizer.pad_token_id, 113 | generation_config=generation_config, 114 | output_scores=False, 115 | return_dict_in_generate=True, 116 | use_cache=False, 117 | ) 118 | seq = output.sequences 119 | response = seq[:, input_ids.shape[1] :] 120 | return tokenizer.batch_decode(response) 121 | 122 | 123 | def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_config_path): 124 | sampling_params = dict( 125 | n=1, 126 | temperature=0, 127 | top_p=1, 128 | top_k=-1, 129 | max_new_tokens=max_response_length, 130 | presence_penalty=0.0, 131 | frequency_penalty=0.0, 132 | repetition_penalty=1.0, 133 | skip_special_tokens=True, 134 | spaces_between_special_tokens=True, 135 | ignore_eos=False, 136 | ) 137 | 138 | rollout_config = OmegaConf.create( 139 | { 140 | "name": "sglang", 141 | "load_format": "dummy_dtensor", 142 | "enforce_eager": False, 143 | "free_cache_engine": False, 144 | "dtype": dtype, 145 | "gpu_memory_utilization": 0.5, 146 | "ignore_eos": False, 147 | "max_num_batched_tokens": 8192, 148 | "prompt_length": max_prompt_length, 149 | "response_length": max_response_length, 150 | "tensor_model_parallel_size": tensor_parallel_size, 151 | "multi_turn": { 152 | "max_turns": 4, 153 | "enable": True, 154 | "tool_config_path": tool_config_path, 155 | "use_inference_chat_template": False, 156 | "enable_tokenization_sanity_check": True, 157 | }, 158 | "max_model_len": None, 159 | **sampling_params, 160 | } 161 | ) 162 | 163 | return rollout_config 164 | --------------------------------------------------------------------------------