├── .flake8
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs
│   ├── custom_mcp_tool.md
│   ├── install.md
│   └── license_header.txt
├── examples
│   └── data_preprocess
│       └── geo3k.py
├── mirorl
│   ├── __init__.py
│   ├── monitor
│   │   ├── __init__.py
│   │   └── wandb_alert.py
│   ├── recipe
│   │   ├── launch.sh
│   │   └── mcp
│   │       ├── config
│   │       │   ├── mcp_trainer.yaml
│   │       │   └── tool_config
│   │       │       └── mcp_tool_config.yaml
│   │       ├── main_ppo.py
│   │       └── run_mirorl_14b_8xgpu.sh
│   ├── tools
│   │   ├── __init__.py
│   │   ├── jina_scrape_llm_summary.py
│   │   ├── mcp_tool.py
│   │   ├── python_server.py
│   │   └── serper_search.py
│   ├── trainer
│   │   ├── __init__.py
│   │   └── ppo
│   │       ├── __init__.py
│   │       ├── core_algos.py
│   │       ├── ray_trainer.py
│   │       └── reward.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── debug
│   │   │   ├── __init__.py
│   │   │   └── exception_helper.py
│   │   ├── reward_score
│   │   │   ├── __init__.py
│   │   │   ├── llm_judge.py
│   │   │   └── simpleqa.py
│   │   └── tracking.py
│   └── workers
│       ├── __init__.py
│       ├── fsdp_workers.py
│       ├── reward_manager
│       │   ├── __init__.py
│       │   ├── tests
│       │   │   ├── testData.jsonl
│       │   │   ├── test_mcp_format.py
│       │   │   └── test_others.py
│       │   └── tool.py
│       └── rollout
│           ├── __init__.py
│           ├── schemas.py
│           └── sglang_rollout
│               ├── __init__.py
│               └── sglang_rollout.py
├── scripts
│   ├── README_GRPO_Visualizer.md
│   ├── converter_hf_to_mcore.py
│   ├── grpo_visualizer.py
│   ├── install.sh
│   └── visualizer_requirements.txt
└── tests
    ├── distributed
    │   ├── run_all.sh
    │   └── test_tensor_dict.py
    ├── kernels
    │   └── test_linear_cross_entropy.py
    ├── models
    │   ├── test_transformer.py
    │   └── test_transformers_ulysses.py
    ├── ray_gpu
    │   ├── detached_worker
    │   │   ├── README.md
    │   │   ├── client.py
    │   │   ├── run.sh
    │   │   └── server.py
    │   ├── test_colocated_workers.py
    │   ├── test_colocated_workers_fused.py
    │   ├── test_data_transfer.py
    │   ├── test_driverfunc_to_worker.py
    │   ├── test_high_level_scheduling_api.py
    │   ├── test_rvdz.py
    │   ├── test_worker_group_basics.py
    │   └── test_worker_group_torch.py
    ├── sanity
    │   ├── check_license.py
    │   ├── check_pr_title.py
    │   ├── test_config_docs.py
    │   └── test_import.py
    ├── test_protocol.py
    ├── trainer
    │   ├── __init__.py
    │   └── ppo
    │       ├── __init__.py
    │       ├── test_core_algos.py
    │       └── test_metric_utils.py
    ├── utils
    │   ├── cpu_tests
    │   │   ├── _test_module.py
    │   │   ├── test_fs.py
    │   │   ├── test_import_utils.py
    │   │   ├── test_model.py
    │   │   └── test_timeout_decorator.py
    │   └── gpu_tests
    │       ├── checkpoint
    │       │   └── test_fsdp_ckpt.py
    │       ├── dataset
    │       │   ├── test_multiturn_sft_dataset.py
    │       │   ├── test_rl_dataset.py
    │       │   ├── test_rm_dataset.py
    │       │   └── test_sft_dataset.py
    │       ├── megatron
    │       │   └── test_pipeline_parallel.py
    │       ├── test_activation_offload.py
    │       ├── test_flops_counter.py
    │       ├── test_seqlen_balancing.py
    │       └── test_torch_functional.py
    └── workers
        ├── reward_manager
        │   └── test_registry.py
        └── rollout
            ├── resource
            │   └── tool_configs
            │       ├── sandbox_fusion_tool_config
            │       └── search_tool_config
            ├── test_sglang_async_rollout_search_tools.py
            ├── test_sglang_async_rollout_sf_tools.py
            ├── test_sglang_async_rollout_w_tools.py
            ├── test_sglang_spmd.py
            └── utils_sglang.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | # Suggested config from pytorch that we can adapt
3 | select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
4 | max-line-length = 120
5 | # C408 ignored because we like the dict keyword argument syntax
6 | # E501 is not flexible enough, we're using B950 instead
7 | # N812 ignored because import torch.nn.functional as F is PyTorch convention
8 | # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
9 | # E731 allow usage of assigning lambda expressions
10 | ignore =
11 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731
12 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying
13 | # to line this up with executable bit
14 | EXE001,
15 | # these ignores are from flake8-bugbear; please fix!
16 | B007,B008,
17 | optional-ascii-coding = True
18 | exclude =
19 | ./.git,
20 | ./.github,
21 | ./docs
22 | ./build
23 | ./scripts,
24 | ./venv,
25 | *.pyi
26 | .pre-commit-config.yaml
27 | *.md
28 | .flake8
29 | tests/torchtune/models/llama2/scripts/*.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | **/*.pt
3 | **/checkpoints
4 | **/wget-log
5 | **/_build/
6 | **/*.ckpt
7 | **/outputs
8 | **/*.tar.gz
9 | **/playground
10 | **/wandb
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 | dataset/*
17 | tensorflow/my_graph/*
18 | .idea/
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | tmp/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *,cover
60 | .hypothesis/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # IPython Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # dotenv
93 | .env
94 |
95 | # virtualenv
96 | venv/
97 | .venv/
98 | ENV/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # vscode
107 | .vscode
108 |
109 | # Mac
110 | .DS_Store
111 |
112 | # vim
113 | *.swp
114 |
115 | # ckpt
116 | *.lock
117 |
118 | # data
119 | *.parquet
120 |
121 |
122 | # local logs
123 | logs
124 | log
125 | outputs
126 | .history
127 |
128 | train_rollout_data
129 | val_rollout_data
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "verl"]
2 | path = verl
3 | url = https://github.com/volcengine/verl.git
4 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: 'build,verl'
2 |
3 | default_language_version:
4 | python: python3
5 |
6 | repos:
7 | - repo: https://github.com/astral-sh/ruff-pre-commit
8 | rev: "v0.11.4"
9 | hooks:
10 | - id: ruff
11 | args: ["--fix", "--show-fixes", "--output-format=full"]
12 | exclude: ^.*\.(ipynb)$
13 | - id: ruff-format
14 |
15 | - repo: https://github.com/pre-commit/pre-commit-hooks
16 | rev: v5.0.0
17 | hooks:
18 | - id: trailing-whitespace
19 | - id: check-ast
20 | - id: check-merge-conflict
21 | - id: no-commit-to-branch
22 | args: ['--branch=main']
23 | - id: check-added-large-files
24 | args: ['--maxkb=1000']
25 | - id: end-of-file-fixer
26 | exclude: '^(.*\.svg)$'
27 |
28 | - repo: https://github.com/Lucas-C/pre-commit-hooks
29 | rev: v1.5.5
30 | hooks:
31 | - id: insert-license
32 | files: \.py$|\.sh$
33 | args:
34 | - --license-filepath
35 | - docs/license_header.txt
36 |
37 | - repo: https://github.com/pycqa/flake8
38 | rev: 7.1.1
39 | hooks:
40 | - id: flake8
41 | additional_dependencies:
42 | - flake8-bugbear == 22.4.25
43 | - pep8-naming == 0.12.1
44 | - torchfix
45 | args: ['--config=.flake8']
46 |
47 | - repo: https://github.com/omnilib/ufmt
48 | rev: v2.8.0
49 | hooks:
50 | - id: ufmt
51 | additional_dependencies:
52 | - black == 22.12.0
53 | - usort == 1.0.5
54 |
55 | - repo: https://github.com/jsh9/pydoclint
56 | rev: 0.5.12
57 | hooks:
58 | - id: pydoclint
59 | args: [--config=pyproject.toml]
60 |
--------------------------------------------------------------------------------
/docs/custom_mcp_tool.md:
--------------------------------------------------------------------------------
1 | ## Define Your Custom MCP Tool
2 |
3 | MiroRL supports two main approaches for adding custom MCP tools:
4 |
5 | ### 1. Using npx/uvx (Node.js-based MCP tools)
6 |
7 | For Node.js-based MCP tools installed from npm and executed with `npx` (or Python-based servers run with `uvx`), add a tool configuration file like the following:
8 |
9 | ```yaml
10 | tools:
11 | - class_name: "mirorl.tools.mcp_tool.MCPTool"
12 | config:
13 | command: "npx" # or "uvx"
14 | args: ["your-mcp-tool-name"]
15 | env: ["YOUR_API_KEY", "HTTPS_PROXY"] # Required environment variables
16 | server_name: "your_tool_name"
17 | max_retries: 5
18 | delay_between_retries: 1
19 | connection_timeout: 5
20 | execution_timeout: 60
21 | tool_schema:
22 | type: "mcp"
23 | function: {}
24 | ```
25 |
26 | ### 2. Using Python-based MCP tools
27 |
28 | For Python-based MCP tools using the FastMCP framework, add your tool configuration file:
29 |
30 | ```yaml
31 | tools:
32 | - class_name: "mirorl.tools.mcp_tool.MCPTool"
33 | config:
34 | command: "python"
35 | args: ["path/to/your_tool.py"]
36 | env: ["YOUR_API_KEY", "HTTPS_PROXY"] # Required environment variables
37 | server_name: "your_tool_name"
38 | max_retries: 5
39 | delay_between_retries: 1
40 | connection_timeout: 5
41 | execution_timeout: 60
42 | tool_schema:
43 | type: "mcp"
44 | function: {}
45 | ```
46 |
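A server launched this way is just a FastMCP script. Below is a minimal, illustrative sketch of what `path/to/your_tool.py` could look like, following the same pattern as the bundled `mirorl/tools/serper_search.py`; the tool name and logic are placeholders.

```python
# your_tool.py -- minimal FastMCP server sketch (hypothetical placeholder names)
from mcp.server.fastmcp import FastMCP

# FastMCP server identifier; the bundled tools keep it identical to the
# `server_name` used in the tool config above.
mcp = FastMCP("your_tool_name")


@mcp.tool()
def echo(text: str) -> str:
    """Return the input text unchanged (stand-in for real tool logic)."""
    return text


if __name__ == "__main__":
    # MiroRL launches this script as a subprocess and talks to it over stdio.
    mcp.run(transport="stdio")
```
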
47 | ### 3. Configuration Parameters
48 |
49 | All MCP tools support the following configuration parameters:
50 |
51 | | Parameter | Description | Default | Example |
52 | |-----------|-------------|---------|---------|
53 | | `command` | MCP server command to execute | - | `"npx"`, `"python"` |
54 | | `args` | Command line arguments | `[]` | `["your-tool"]`, `["tool.py"]` |
55 | | `env` | Environment variables | `[]` | `["API_KEY", "PROXY"]` |
56 | | `server_name` | MCP server identifier | - | `"your_tool_name"` |
57 | | `blacklist` | Forbidden operations | `[]` | `["scrape"]` |
58 | | `max_retries` | Maximum retry attempts | `5` | `3` |
59 | | `delay_between_retries` | Retry delay in seconds | `1` | `2` |
60 | | `connection_timeout` | Connection timeout in seconds | `5` | `10` |
61 | | `execution_timeout` | Execution timeout in seconds | `60` | `120` |
62 |
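For reference, the retry and timeout values are read with `config.get(...)` fall-backs inside `MCPTool.__init__`, and the `env` list controls which environment variables are forwarded to the MCP server process. A condensed, illustrative sketch of that logic, adapted from `mirorl/tools/mcp_tool.py`:

```python
# Adapted from mirorl/tools/mcp_tool.py -- illustrative excerpt, not the full class
import os

from mcp import StdioServerParameters


def build_server_params(config: dict) -> StdioServerParameters:
    """Turn a tool config entry into stdio server parameters."""
    # Only the variables named under `env` are passed to the MCP server process.
    env = {name: os.environ.get(name) for name in config.get("env", [])}
    return StdioServerParameters(
        command=config["command"],    # e.g. "npx" or "python"
        args=config.get("args", []),  # e.g. ["path/to/your_tool.py"]
        env=env,
    )
```
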
63 | ### 4. Integration with Training
64 |
65 | To use your custom MCP tool in training:
66 |
67 | 1. **Create your tool configuration file** in `mirorl/recipe/mcp/config/tool_config/`
68 | 2. **Set environment variables** required by your tool
69 | 3. **Update training command** to use your tool config:
70 |
71 | ```bash
72 | python3 -m mirorl.recipe.mcp.main_ppo \
73 | --config-path="mirorl/recipe/mcp/config" \
74 | --config-name='mcp_trainer' \
75 | actor_rollout_ref.rollout.multi_turn.tool_config_path="path/to/your_tool_config.yaml" \
76 | data.train_batch_size=256 \
77 | trainer.n_gpus_per_node=8
78 | ```
79 |
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | ### Requirements
4 |
5 | First, we recommend using conda to manage your environment:
6 |
7 | ```bash
8 | conda create -n mirorl python==3.10.4
9 | conda activate mirorl
10 | ```
11 |
12 | For prerequisites installation (CUDA, cuDNN, Apex), we recommend following the [verl prerequisites guide](https://verl.readthedocs.io/en/latest/start/install.html#pre-requisites) which provides detailed instructions for:
13 |
14 | - CUDA: Version >= 12.4
15 | - cuDNN: Version >= 9.8.0
16 | - Apex
17 |
18 | ### Install Dependencies
19 |
20 | #### Install python dependencies
21 |
22 | To install the Python dependencies (sglang, flash-attn, flashinfer), run the `install.sh` script provided in `mirorl/scripts`:
23 |
24 | ```bash
25 | # Make sure you have activated the mirorl conda env
26 | # only for FSDP backend and sglang engine
27 | bash scripts/install.sh
28 | ```
29 |
30 | If you encounter errors in this step, open the script and follow its steps manually.
31 |
32 | #### Install Node.js for MCP Support
33 |
34 | MCP (Model Context Protocol) requires Node.js to run MCP servers. Node.js version 18+ is recommended for optimal compatibility with MCP tools.
35 |
36 | ```bash
37 | # Download Node.js binary (example for Linux x64)
38 | wget https://nodejs.org/dist/v24.2.0/node-v24.2.0-linux-x64.tar.xz
39 |
40 | # Extract to your target path
41 | tar -xf node-v24.2.0-linux-x64.tar.xz -C /your/target/path
42 |
43 | # Add to PATH
44 | export NODEJS_HOME=/your/target/path/node-v24.2.0-linux-x64
45 | export PATH=$NODEJS_HOME/bin:$PATH
46 | export NODE_SHARED=/your/target/path/node-shared/node_modules
47 | export PATH=$NODE_SHARED/.bin:$PATH
48 |
49 | # Verify installation
50 | node --version
51 | npm --version
52 |
53 | # Install serper mcp server
54 | mkdir -p /your/target/path/node-shared
55 | cd /your/target/path/node-shared
56 | npm init -y
57 | npm install serper-search-scrape-mcp-server
58 | ```
59 |
60 | ### Install verl
61 |
62 | To install the required version of verl, clone the repository with its submodules and install verl from source:
63 |
64 | ```bash
65 | git clone https://github.com/MiroMindAsia/mirorl.git --recursive
66 | cd mirorl/verl
67 | pip install --no-deps -v .
68 | ```
69 |
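As an optional sanity check (illustrative; run from the repository root so that `mirorl` resolves, e.g. with `PYTHONPATH=$PWD` as in `mirorl/recipe/launch.sh`):

```python
# Optional post-install check (illustrative)
import verl    # installed from the verl submodule above
import mirorl  # resolved from the repository root, e.g. via PYTHONPATH=$PWD

print("verl:", verl.__file__)
print("mirorl:", mirorl.__file__)
```
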
--------------------------------------------------------------------------------
/docs/license_header.txt:
--------------------------------------------------------------------------------
1 | Copyright 2025 MiroMind Team
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/examples/data_preprocess/geo3k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Preprocess the Geometry3k dataset to parquet format
16 | """
17 |
18 | import argparse
19 | import os
20 |
21 | import datasets
22 |
23 | from verl.utils.hdfs_io import copy, makedirs
24 |
25 | if __name__ == "__main__":
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--local_dir", default="~/data/geo3k")
28 | parser.add_argument("--hdfs_dir", default=None)
29 |
30 | args = parser.parse_args()
31 |
32 | data_source = "hiyouga/geometry3k"
33 |
34 | dataset = datasets.load_dataset(data_source)
35 |
36 | train_dataset = dataset["train"]
37 | test_dataset = dataset["test"]
38 |
39 | instruction_following = (
40 | r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. "
41 | r"The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}."
42 | )
43 |
44 | # add a row to each data item that represents a unique id
45 | def make_map_fn(split):
46 | def process_fn(example, idx):
47 | problem = example.pop("problem")
48 | prompt = problem + " " + instruction_following
49 | answer = example.pop("answer")
50 | images = example.pop("images")
51 |
52 | data = {
53 | "data_source": data_source,
54 | "prompt": [
55 | {
56 | "role": "user",
57 | "content": prompt,
58 | }
59 | ],
60 | "images": images,
61 | "ability": "math",
62 | "reward_model": {"style": "rule", "ground_truth": answer},
63 | "extra_info": {
64 | "split": split,
65 | "index": idx,
66 | "answer": answer,
67 | "question": problem,
68 | },
69 | }
70 | return data
71 |
72 | return process_fn
73 |
74 | train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8)
75 | test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8)
76 |
77 | local_dir = args.local_dir
78 | hdfs_dir = args.hdfs_dir
79 |
80 | train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
81 | test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
82 |
83 | if hdfs_dir is not None:
84 | makedirs(hdfs_dir)
85 | copy(src=local_dir, dst=hdfs_dir)
86 |
--------------------------------------------------------------------------------
/mirorl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mirorl/monitor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .wandb_alert import get_or_init_wandb, HangChecker, MCPToolChecker
16 |
17 | __all__ = ["get_or_init_wandb", "MCPToolChecker", "HangChecker"]
18 |
--------------------------------------------------------------------------------
/mirorl/recipe/launch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 MiroMind Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # set -ex
17 |
18 | export JOB_NAME=mirorl-odr-mcp
19 |
20 | # Setup Node.js for MCP
21 | export NODEJS_HOME=/your/path/to/nodejs
22 | export PATH=$NODEJS_HOME/bin:$PATH
23 | export NODE_SHARED=$NODEJS_HOME/node-shared/node_modules
24 | export PATH=$NODE_SHARED/.bin:$PATH
25 |
26 | # Setup proxies
27 | export HTTP_PROXY=xxx
28 | export HTTPS_PROXY=xxx
29 | export NO_PROXY=localhost,127.0.0.1
30 |
31 | # Check for singlenode flag
32 | SCRIPT=$1
33 | SINGLENODE=${SINGLENODE:-false}
34 |
35 | export LOG_DIR=$(pwd)/outputs/$MLP_TASK_ID/logs
36 | mkdir -p $LOG_DIR
37 |
38 | if [ "$SINGLENODE" == "true" ]; then
39 | bash $SCRIPT 2>&1 | tee $LOG_DIR/main_ppo.log
40 | exit 0
41 | fi
42 |
43 | #==============================================================================#
44 |
45 | export NCCL_IB_TIMEOUT=80
46 | export NCCL_IB_RETRY_CNT=20
47 | export NCCL_IB_AR_THRESHOLD=0
48 |
49 | export MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
50 | export MASTER_PORT=${MASTER_PORT:-10086}
51 | export NNODES=${NNODES:-1}
52 | export NODE_RANK=${NODE_RANK:-0}
53 | export GPUS_PER_NODE=${GPUS_PER_NODE:-8}
54 |
55 | # Compute total world size (number of processes)
56 | export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))
57 |
58 | echo "NNODES: $NNODES"
59 | echo "NODE_RANK: $NODE_RANK"
60 | echo "GPUS_PER_NODE: $GPUS_PER_NODE"
61 | echo "MASTER_PORT: $MASTER_PORT"
62 | echo "MASTER_ADDR: $MASTER_ADDR"
63 | echo "WORLD_SIZE: $WORLD_SIZE"
64 | echo "HTTP_PROXY: $HTTP_PROXY, HTTPS_PROXY: $HTTPS_PROXY"
65 |
66 | export NCCL_P2P_LEVEL=NVL
67 | export PYTHONPATH=$PWD:$PYTHONPATH
68 |
69 | export LOG_DIR=$(pwd)/outputs/$JOB_NAME/logs
70 | mkdir -p $LOG_DIR
71 |
72 | # num_nodes has to be at least 1
73 | if [ $NNODES -lt 1 ]; then
74 | echo "Number of nodes must be at least 1"
75 | exit 1
76 | fi
77 |
78 | # the node with NODE_RANK 0 acts as the head (master) node
79 | if [[ $NODE_RANK -eq 0 ]]; then
80 | node_role="master"
81 | else
82 | node_role="worker"
83 | fi
84 | head_node_ip=${MASTER_ADDR:-127.0.0.1}
85 |
86 | wait_time=30
87 | if [ "$node_role" == "master" ]; then
88 | echo "Starting Ray head node..."
89 | # Start Ray on this node as the head node and extract its address
90 | # The `ray start --head` command outputs information that includes the address,
91 | # but here we're assuming it's known or statically assigned for simplicity.
92 | ray start --head --node-ip-address=$head_node_ip --include-dashboard=True --dashboard-host $head_node_ip --port=6379 --min-worker-port 15000 --max-worker-port 19999
93 | sleep $wait_time
94 | elif [ "$node_role" == "worker" ]; then
95 | sleep $wait_time
96 | attempt=1
97 | echo "Starting Ray worker node and attempting to connect to the head node at $head_node_ip:6379"
98 | while true; do
99 | # Attempt to start Ray and connect to the head node
100 | ray start --address="$head_node_ip:6379" --min-worker-port 15000 --max-worker-port 19999 && break || {
101 | if [ $attempt -le 5 ]; then
102 | echo "Ray worker start attempt $attempt failed. Retrying in $wait_time seconds..."
103 | ((attempt++))
104 | sleep $wait_time
105 | else
106 | echo "Failed to connect to the head node after $wait_time attempts. Exiting."
107 | exit 1
108 | fi
109 | }
110 | done
111 | fi
112 |
113 | # run the training script once Ray has been started on all nodes
114 | sleep $wait_time
115 | if [ "$node_role" == "master" ]; then
116 | num_active_ray_nodes=$(ray list nodes | grep ALIVE | wc -l)
117 | echo "Number of active Ray nodes: $num_active_ray_nodes"
118 | if [ $num_active_ray_nodes -lt $NNODES ]; then
119 | echo "Waiting for all Ray nodes to start..."
120 | attempt=1
121 | while true; do
122 | num_active_ray_nodes=$(ray list nodes | grep ALIVE | wc -l)
123 | if [ $num_active_ray_nodes -eq $NNODES ]; then
124 | break
125 | elif [ $attempt -le 5 ]; then
126 | echo "python command attempt $attempt failed. Retrying in $wait_time seconds..."
127 | ((attempt++))
128 | sleep $wait_time
129 | else
130 | echo "Failed to connect to the head node after $wait_time attempts. Exiting."
131 | exit 1
132 | fi
133 | done
134 | fi
135 | echo "End starting"
136 | bash ${SCRIPT} 2>&1 | tee $LOG_DIR/main_ppo.log
137 | else
138 | echo "End starting"
139 | # Continuously check the health of the Ray cluster by pinging the head node.
140 | # If the health check fails, break the loop and proceed.
141 | while true; do
142 | ray health-check --address $head_node_ip:6379 &>/dev/null
143 | if [ $? -ne 0 ]; then
144 | break
145 | fi
146 | sleep 60
147 | done
148 | fi
149 |
--------------------------------------------------------------------------------
/mirorl/recipe/mcp/config/mcp_trainer.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | searchpath:
3 | - file://verl/verl/trainer/config
4 |
5 | defaults:
6 | - ppo_trainer
7 | - _self_
8 |
9 | custom_reward_function:
10 | name: browsecomp_compute_score # available scorers: gaia_naive_compute_score, simpleqa_compute_score, browsecomp_compute_score, gpt41_compute_score, cascade_compute_score
11 | path: "./mirorl/utils/reward_score/llm_judge.py"
12 | reward_model:
13 | reward_manager: tool
14 | reward_kwargs:
15 | accuracy_reward_weight: 0.9
16 | tool_format_reward_weight: 0.1
17 | gate_tool_format_reward: False # If True, the tool format reward is granted only when the final answer is also correct; if False, it is granted whenever the tool format is correct, regardless of answer correctness.
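# Editorial worked example (assuming the final reward is the weighted sum of the
# two components above): a correct answer with a valid tool format scores
# 0.9 * 1 + 0.1 * 1 = 1.0; a wrong answer with a valid tool format scores 0.1
# when gate_tool_format_reward is False and 0.0 when it is True.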
18 | async_process: True
19 | batch_size: 300
20 | max_retry: 3
21 |
22 | data:
23 | max_prompt_length: 1024
24 | max_response_length: 1024
25 | train_batch_size: 256
26 | return_raw_chat: True
27 |
28 | actor_rollout_ref:
29 | hybrid_engine: True
30 | rollout:
31 | name: sglang
32 | multi_turn:
33 | enable: True
34 | max_turns: 20
35 |
36 | # ... style tool call for multi-turn rollout
37 | use_mcp_tool_call: True
38 |
39 | # tool response cut-off length (a general web page has an information density of about 4 chars per token)
40 | tool_response_cut_off_length: 20000
41 |
42 | # this means both self.input_ids and self.messages will not contain the think text for intermediate rounds
43 | keep_think_text_for_last_round_only: False
44 |
45 | # think block close tag, this works for both Qwen2.5/3 and DeepSeek-R1
46 | think_block_close_tag: "</think>"
47 |
48 | # this flag is used to ignore loss calculation for failed rollouts
49 | ignore_failed_rollouts_in_loss: False
50 |
51 | trainer:
52 | # Monitor config
53 | # Whether to enable wandb alert
54 | enable_wandb_alert: False
55 |
56 | # Tool check alert threshold
57 | # If the current step's tool-calling error count is more than this many times the previous step's count,
58 | # an alert will be triggered.
59 | tool_check_alert_threshold: 10
60 |
61 | # Training hang check interval (in seconds)
62 | # if the global step has not been updated for hang_check_interval seconds, an alert will be triggered
63 | hang_check_interval: 3600
64 |
--------------------------------------------------------------------------------
/mirorl/recipe/mcp/config/tool_config/mcp_tool_config.yaml:
--------------------------------------------------------------------------------
1 | tools:
2 | - class_name: "mirorl.tools.mcp_tool.MCPTool"
3 | config:
4 | command: "python"
5 | args: ["mirorl/tools/serper_search.py"]
6 | env: ["SERPER_API_KEY", "HTTPS_PROXY"]
7 | server_name: "search_and_scrape_webpage"
8 | tool_schema:
9 | type: "mcp"
10 | function: {}
11 | - class_name: "mirorl.tools.mcp_tool.MCPTool"
12 | config:
13 | command: "python"
14 | args: ["mirorl/tools/jina_scrape_llm_summary.py"]
15 | env: ["JINA_API_KEY", "HTTPS_PROXY", "NO_PROXY", "SUMMARY_LLM_URL", "SUMMARY_LLM_NAME"]
16 | server_name: "jina_scrape_llm_summary"
17 | execution_timeout: 600
18 | tool_schema:
19 | type: "mcp"
20 | function: {}
21 |
--------------------------------------------------------------------------------
/mirorl/recipe/mcp/run_mirorl_14b_8xgpu.sh:
--------------------------------------------------------------------------------
1 | # run on 8xH100
2 | # make sure your current working directory is the root of the project
3 | # Training with Musique + GAIA text-103 for train/test
4 |
5 | set -x
6 |
7 | ulimit -n 65535
8 |
9 | PROJECT_DIR="$(pwd)"
10 | CONFIG_PATH="$PROJECT_DIR/mirorl/recipe/mcp/config"
11 | TOOL_CONFIG_PATH="$PROJECT_DIR/mirorl/recipe/mcp/config/tool_config/mcp_tool_config.yaml"
12 | MODEL_PATH=/your/model/path
13 | DATA_HOME=/your/data/path
14 | EXPERIMENT_NAME="mirorl-14b_genqa_64k"
15 |
16 |
17 | python -m mirorl.recipe.mcp.main_ppo \
18 | --config-path="$CONFIG_PATH" \
19 | --config-name='mcp_trainer' \
20 | algorithm.adv_estimator=grpo \
21 | data.train_batch_size=64 \
22 | data.max_prompt_length=3072 \
23 | data.max_response_length=62464 \
24 | data.filter_overlong_prompts=True \
25 | data.truncation='error' \
26 | data.return_raw_chat=True \
27 | actor_rollout_ref.model.path=$MODEL_PATH \
28 | actor_rollout_ref.actor.optim.lr=2e-5 \
29 | actor_rollout_ref.model.use_remove_padding=True \
30 | actor_rollout_ref.model.use_liger=True \
31 | actor_rollout_ref.model.use_fused_kernels=True \
32 | actor_rollout_ref.actor.ppo_mini_batch_size=64 \
33 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
34 | actor_rollout_ref.actor.use_kl_loss=False \
35 | actor_rollout_ref.actor.kl_loss_coef=0.001 \
36 | actor_rollout_ref.actor.kl_loss_type=low_var_kl \
37 | actor_rollout_ref.actor.entropy_coeff=0 \
38 | actor_rollout_ref.actor.use_dynamic_bsz=True \
39 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
40 | actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
41 | actor_rollout_ref.model.enable_gradient_checkpointing=True \
42 | actor_rollout_ref.actor.fsdp_config.param_offload=False \
43 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
44 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
45 | actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
46 | actor_rollout_ref.rollout.name=sglang \
47 | actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
48 | actor_rollout_ref.rollout.n=8 \
49 | actor_rollout_ref.rollout.multi_turn.enable_tokenization_sanity_check=False \
50 | actor_rollout_ref.rollout.val_kwargs.do_sample=True \
51 | actor_rollout_ref.rollout.val_kwargs.temperature=0.3 \
52 | actor_rollout_ref.rollout.multi_turn.tool_config_path=$TOOL_CONFIG_PATH \
53 | actor_rollout_ref.rollout.multi_turn.max_turns=40 \
54 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
55 | actor_rollout_ref.ref.fsdp_config.param_offload=True \
56 | algorithm.use_kl_in_reward=False \
57 | trainer.critic_warmup=0 \
58 | trainer.logger=['console','wandb'] \
59 | trainer.project_name='odr-mcp' \
60 | trainer.experiment_name=$EXPERIMENT_NAME \
61 | trainer.n_gpus_per_node=8 \
62 | trainer.nnodes=1 \
63 | trainer.save_freq=30 \
64 | trainer.test_freq=5 \
65 | trainer.val_before_train=True \
66 | trainer.rollout_data_dir=$PROJECT_DIR/train_rollout_data/$EXPERIMENT_NAME \
67 | trainer.validation_data_dir=$PROJECT_DIR/val_rollout_data/$EXPERIMENT_NAME \
68 | data.train_files=$DATA_HOME/genqa/train.parquet \
69 | data.val_files=$DATA_HOME/genqa/val.parquet \
70 | trainer.total_epochs=15 $@
71 |
--------------------------------------------------------------------------------
/mirorl/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mirorl/tools/mcp_tool.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import asyncio
16 | import json
17 | import logging
18 | import os
19 | from datetime import timedelta
20 | from typing import Any, Optional, Tuple
21 | from uuid import uuid4
22 |
23 | import exceptiongroup
24 |
25 | import ray
26 | from mcp import ClientSession, StdioServerParameters
27 | from mcp.client.stdio import stdio_client
28 |
29 | from verl.tools.base_tool import BaseTool
30 | from verl.tools.schemas import OpenAIFunctionToolSchema
31 |
32 | from mirorl.utils.debug.exception_helper import extract_exception_details
33 |
34 | logger = logging.getLogger(__name__)
35 | logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
36 |
37 |
38 | def is_timeout_error(error: Exception) -> bool:
39 | """Check if an error is a timeout-related error."""
40 | error_str = str(error)
41 | return any(
42 | keyword in error_str
43 | for keyword in ["ETIMEDOUT", "ECONNRESET", "Timeout", "Timed out"]
44 | )
45 |
46 |
47 | class MCPTool(BaseTool):
48 | """A tool for calling MCP tools.
49 |
50 | - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.
51 | - `create`: create a tool instance for a trajectory.
52 | - `execute`: execute the tool.
53 | - `calc_reward`: calculate the reward with respect to the tool state.
54 | - `release`: release the tool instance.
55 | """
56 |
57 | def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
58 | super().__init__(config, tool_schema)
59 | self._instance_dict = {}
60 |
61 | self.max_retries = config.get("max_retries", 5)
62 | self.delay_between_retries = config.get(
63 | "delay_between_retries", 1
64 | ) # in seconds
65 | self.connection_timeout = config.get("connection_timeout", 5)
66 | self.execution_timeout = config.get("execution_timeout", 60)
67 |
68 | # Get command, args, and env from config
69 | self.params = StdioServerParameters(
70 | command=config["command"],
71 | args=config.get("args", []),
72 | env={e: os.environ.get(e) for e in config["env"]}
73 | if "env" in config.keys()
74 | else {},
75 | )
76 |
77 | def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
78 | return self.tool_schema
79 |
80 | async def create(
81 | self,
82 | instance_id: Optional[str] = None,
83 | ground_truth: Optional[str] = None,
84 | **kwargs,
85 | ) -> str:
86 | if instance_id is None:
87 | instance_id = str(uuid4())
88 | # TODO: Add all create_kwargs to dict
89 | self._instance_dict[instance_id] = {
90 | "response": "",
91 | "ground_truth": ground_truth,
92 | "reward": 0.0,
93 | }
94 | return instance_id
95 |
96 | async def _execute(self, parameters: dict[str, Any]):
97 | response = ""
98 |
99 | async with stdio_client(self.params) as (read, write):
100 | async with ClientSession(
101 | read,
102 | write,
103 | read_timeout_seconds=timedelta(seconds=self.connection_timeout),
104 | ) as session:
105 | for attempt in range(self.max_retries):
106 | try:
107 | await session.initialize()
108 |
109 | result = await session.call_tool(
110 | self.name,
111 | arguments=parameters,
112 | read_timeout_seconds=timedelta(
113 | seconds=self.execution_timeout
114 | ),
115 | )
116 | response = result.content[0].text if result.content else ""
117 | if attempt > 0:
118 | logger.error(f"Attempt {attempt + 1} success")
119 | break # Exit loop if successful
120 | except Exception as e:
121 | # The error type is McpError, consistently having an error code of -32603.
122 | # To determine if the failed connection is network-related, we check the message.
123 | if is_timeout_error(e) and attempt < self.max_retries - 1:
124 | logger.error(f"Attempt {attempt + 1} failed: {e}")
125 | await asyncio.sleep(self.delay_between_retries)
126 | else:
127 | response = f"Tool execution failed: {e}"
128 |
129 | try:
130 | mcp_tool_checker = ray.get_actor(
131 | name="mcp_tool_checker", namespace="monitor"
132 | )
133 | ray.get(
134 | mcp_tool_checker.record_error.remote(
135 | str(e), self.name
136 | )
137 | )
138 | except Exception as e:
139 | print(
140 | f"MCP tool checker not found, skip recording error: {e}"
141 | )
142 |
143 | break  # Exit loop if the error is not a timeout or retries are exhausted
144 |
145 | return response
146 |
147 | async def execute(
148 | self, instance_id: str, parameters: dict[str, Any], **kwargs
149 | ) -> Tuple[str, float, dict]:
150 | # Call MCP tool with retry logic
151 | response = ""
152 |
153 | for attempt in range(self.max_retries):
154 | try:
155 | response = await self._execute(parameters)
156 | break
157 | except Exception as e:
158 | if isinstance(e, exceptiongroup.ExceptionGroup):
159 | details = extract_exception_details(e)
160 | logger.error(
161 | f"MCPTool.execute attempt {attempt + 1} failed: {e}\n"
162 | f"exception group details: {json.dumps(details, indent=4)}"
163 | )
164 | else:
165 | logger.error(f"MCPTool.execute attempt {attempt + 1} failed: {e}")
166 |
167 | if attempt == self.max_retries - 1:
168 | logger.error(
169 | f"MCPTool.execute failed after {self.max_retries} attempts, returning empty response"
170 | )
171 |
172 | self._instance_dict[instance_id]["response"] = response
173 |
174 | # NOTE: tool_reward is not used anywhere
175 | return response, 0.0, {}
176 |
177 | async def calc_reward(self, instance_id: str, **kwargs) -> float:
178 | # NOTE: tool_reward is not used anywhere
179 | return self._instance_dict[instance_id]["reward"]
180 |
181 | async def release(self, instance_id: str, **kwargs) -> None:
182 | if instance_id in self._instance_dict:
183 | del self._instance_dict[instance_id]
184 |
--------------------------------------------------------------------------------
/mirorl/tools/python_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from e2b_code_interpreter import Sandbox
17 | from mcp.server.fastmcp import FastMCP
18 |
19 | # Initialize FastMCP server
20 | mcp = FastMCP("e2b-python-interpreter")
21 |
22 | # API keys
23 | E2B_API_KEY = os.environ.get("E2B_API_KEY")
24 |
25 | # DEFAULT CONFS
26 | DEFAULT_TIMEOUT = 300 # seconds
27 |
28 |
29 | @mcp.tool()
30 | async def run_python_code(
31 | code_block, timeout: int = DEFAULT_TIMEOUT, sandbox_id=None
32 | ) -> dict:
33 | """Run python code in an interpreter and return the execution result.
34 |
35 | Args:
36 | code_block: The python code to run.
37 | timeout: Time in seconds before the sandbox is automatically shutdown. The default is 300 seconds.
38 | sandbox_id: The id of the sandbox to run the code in. Reuse existing sandboxes whenever possible.
39 | Only create new ones if this is the first time running code in this sandbox.
40 |
41 | Returns:
42 | An object containing the sandbox id and the execution result object including results, logs and errors.
43 | """
44 | if sandbox_id:
45 | sandbox = Sandbox.connect(sandbox_id, api_key=E2B_API_KEY)
46 | else:
47 | sandbox = Sandbox(timeout=timeout, api_key=E2B_API_KEY)
48 |
49 | execution = sandbox.run_code(code_block)
50 |
51 | sandbox_id = sandbox.get_info().sandbox_id
52 |
53 | return dict(execution_result=execution, sandbox_id=sandbox_id)
54 |
55 |
56 | @mcp.tool()
57 | async def upload_local_file_to_python_interpreter(
58 | local_file_path: str, sandbox_id=None
59 | ) -> dict:
60 | """Upload a local file to the `/home/user` dir of the remote python interpreter.
61 |
62 | Args:
63 | local_file_path: The path of the file on the local machine to upload.
64 | sandbox_id: The id of the sandbox to run the code in. Reuse existing sandboxes whenever possible.
65 | Only create new ones if this is the first time running code in this sandbox.
66 |
67 | Returns:
68 | An object containing the sandbox id and the path of the uploaded file in the remote python interpreter.
69 | """
70 | if sandbox_id:
71 | sandbox = Sandbox.connect(sandbox_id, api_key=E2B_API_KEY)
72 | else:
73 | sandbox = Sandbox(api_key=E2B_API_KEY)
74 |
75 | if not os.path.exists(local_file_path):
76 | raise FileNotFoundError(f"File {local_file_path} does not exist.")
77 |
78 | # Get the uploaded file path
79 | uploaded_file_path = os.path.join("/home/user", os.path.basename(local_file_path))
80 |
81 | # Upload the file
82 | with open(local_file_path, "rb") as f:
83 | sandbox.files.write(uploaded_file_path, f)
84 |
85 | sandbox_id = sandbox.get_info().sandbox_id
86 |
87 | return dict(uploaded_file_path=uploaded_file_path, sandbox_id=sandbox_id)
88 |
89 |
90 | @mcp.tool()
91 | async def download_internet_file_to_python_interpreter(
92 | url: str, sandbox_id=None
93 | ) -> dict:
94 | """Download a file from the internet to the `/home/user` dir of the remote python interpreter.
95 |
96 | Args:
97 | url: The URL of the file to download.
98 | sandbox_id: The id of the sandbox to run the code in. Reuse existing sandboxes whenever possible.
99 | Only create new ones if this is the first time running code in this sandbox.
100 |
101 | Returns:
102 | An object containing the sandbox id and the path of the downloaded file in the python interpreter.
103 | """
104 | if sandbox_id:
105 | sandbox = Sandbox.connect(sandbox_id, api_key=E2B_API_KEY)
106 | else:
107 | sandbox = Sandbox(api_key=E2B_API_KEY)
108 |
109 | downloaded_file_path = os.path.join("/home/user", os.path.basename(url))
110 |
111 | # Download the file
112 | sandbox.commands.run(f"wget {url} -O {downloaded_file_path}")
113 |
114 | sandbox_id = sandbox.get_info().sandbox_id
115 |
116 | return dict(downloaded_file_path=downloaded_file_path, sandbox_id=sandbox_id)
117 |
118 |
119 | if __name__ == "__main__":
120 | mcp.run(transport="stdio")
121 |
--------------------------------------------------------------------------------
/mirorl/tools/serper_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import os
17 | from typing import Any, Dict
18 |
19 | import requests
20 | from mcp.server.fastmcp import FastMCP
21 | from tenacity import (
22 | retry,
23 | retry_if_exception_type,
24 | stop_after_attempt,
25 | wait_exponential,
26 | )
27 |
28 |
29 | # Configure logging
30 | logging.basicConfig(level=logging.INFO)
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | # Initialize FastMCP server
35 | mcp = FastMCP("search_and_scrape_webpage")
36 |
37 |
38 | @retry(
39 | stop=stop_after_attempt(3),
40 | wait=wait_exponential(multiplier=1, min=4, max=10),
41 | retry=retry_if_exception_type(
42 | (requests.ConnectionError, requests.Timeout, requests.HTTPError)
43 | ),
44 | )
45 | def make_serper_request(
46 | payload: Dict[str, Any], headers: Dict[str, str]
47 | ) -> requests.Response:
48 | """Make HTTP request to Serper API with retry logic."""
49 | response = requests.post(
50 | "https://google.serper.dev/search", json=payload, headers=headers
51 | )
52 | response.raise_for_status()
53 | return response
54 |
55 |
56 | def _is_huggingface_dataset_or_space_url(url):
57 | """
58 | Check if the URL is a Hugging Face dataset or space URL.
59 | :param url: The URL to check
60 | :return: True if it's a HuggingFace dataset or space URL, False otherwise
61 | """
62 | if not url:
63 | return False
64 | return "huggingface.co/datasets" in url or "huggingface.co/spaces" in url
65 |
66 |
67 | @mcp.tool()
68 | def google_search(
69 | q: str,
70 | gl: str = "us",
71 | hl: str = "en",
72 | location: str = None,
73 | num: int = None,
74 | tbs: str = None,
75 | page: int = None,
76 | autocorrect: bool = None,
77 | ) -> Dict[str, Any]:
78 | """
79 | Tool to perform web searches via Serper API and retrieve rich results.
80 |
81 | It is able to retrieve organic search results, people also ask,
82 | related searches, and knowledge graph.
83 |
84 | Args:
85 | q: Search query string
86 | gl: Optional region code for search results in ISO 3166-1 alpha-2 format (e.g., 'us')
87 | hl: Optional language code for search results in ISO 639-1 format (e.g., 'en')
88 | location: Optional location for search results (e.g., 'SoHo, New York, United States', 'California, United States')
89 | num: Number of results to return (default: 10)
90 | tbs: Time-based search filter ('qdr:h' for past hour, 'qdr:d' for past day, 'qdr:w' for past week,
91 | 'qdr:m' for past month, 'qdr:y' for past year)
92 | page: Page number of results to return (default: 1)
93 | autocorrect: Whether to autocorrect spelling in query
94 |
95 | Returns:
96 | Dictionary containing search results and metadata
97 | """
98 |
99 | # Check for API key
100 | serper_api_key = os.getenv("SERPER_API_KEY")
101 | if not serper_api_key:
102 | return {
103 | "success": False,
104 | "error": "SERPER_API_KEY environment variable not set",
105 | "results": [],
106 | }
107 |
108 | # Validate required parameter
109 | if not q or not q.strip():
110 | return {
111 | "success": False,
112 | "error": "Search query 'q' is required and cannot be empty",
113 | "results": [],
114 | }
115 |
116 | try:
117 | # Build payload with all supported parameters
118 | payload = {
119 | "q": q.strip(),
120 | "gl": gl,
121 | "hl": hl,
122 | }
123 |
124 | # Add optional parameters if provided
125 | if location:
126 | payload["location"] = location
127 | if num is not None:
128 | payload["num"] = num
129 | else:
130 | payload["num"] = 10 # Default
131 | if tbs:
132 | payload["tbs"] = tbs
133 | if page is not None:
134 | payload["page"] = page
135 | if autocorrect is not None:
136 | payload["autocorrect"] = autocorrect
137 |
138 | # Set up headers
139 | headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"}
140 |
141 | # Make the API request
142 | response = make_serper_request(payload, headers)
143 | data = response.json()
144 |
145 | # filter out huggingface dataset or space urls
146 | organic_results = []
147 | if "organic" in data:
148 | for item in data["organic"]:
149 | if _is_huggingface_dataset_or_space_url(item.get("link", "")):
150 | continue
151 | organic_results.append(item)
152 |
153 | # Build comprehensive response
154 | response_data = {
155 | "organic": organic_results,
156 | "searchParameters": data.get("searchParameters", {}),
157 | }
158 |
159 | return response_data
160 |
161 | except Exception as e:
162 | return {"success": False, "error": f"Unexpected error: {str(e)}", "results": []}
163 |
164 |
165 | if __name__ == "__main__":
166 | mcp.run()
167 |
--------------------------------------------------------------------------------
/mirorl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mirorl/trainer/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mirorl/trainer/ppo/core_algos.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
18 | def ignore_failed_rollouts_in_loss(data, rollouts_mask: torch.Tensor):
19 | """
20 | Mask the failed rollouts in the loss calculation.
21 |
22 | Args:
23 | data: DataProto object, containing batch, non_tensor_batch and meta_info
24 | rollouts_mask: torch.Tensor, shape (rollouts_num,)
25 |
26 | """
27 | data.batch["loss_mask"][rollouts_mask] = 0
28 |
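# Editorial usage sketch (not part of the original module): a toy demonstration
# of how `ignore_failed_rollouts_in_loss` zeroes the loss mask of failed
# rollouts. A SimpleNamespace stands in for the DataProto object here.
if __name__ == "__main__":
    from types import SimpleNamespace

    toy_data = SimpleNamespace(batch={"loss_mask": torch.ones(4, 6, dtype=torch.long)})
    failed = torch.tensor([False, True, False, True])  # rollouts 1 and 3 failed
    ignore_failed_rollouts_in_loss(toy_data, failed)
    # Rows 1 and 3 of the loss mask are now all zeros.
    print(toy_data.batch["loss_mask"])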
--------------------------------------------------------------------------------
/mirorl/trainer/ppo/reward.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Individual Contributor: Thibaut Barroyer
2 | # Copyright 2025 MiroMind Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import multiprocessing
17 | import os
18 | from functools import partial
19 |
20 | import ray
21 |
22 | from verl import DataProto
23 |
24 | # Adapted from verl/trainer/ppo/reward.py
25 | # Updated by MiroMind Team
26 | # import default_compute_score from mirorl.utils.reward_score
27 | from mirorl.utils.reward_score import default_compute_score
28 |
29 |
30 | def get_custom_reward_fn(config):
31 | import importlib.util
32 | import sys
33 |
34 | reward_fn_config = config.get("custom_reward_function") or {}
35 | file_path = reward_fn_config.get("path")
36 | if not file_path:
37 | return None
38 |
39 | if not os.path.exists(file_path):
40 | raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
41 |
42 | spec = importlib.util.spec_from_file_location("custom_module", file_path)
43 | module = importlib.util.module_from_spec(spec)
44 | try:
45 | sys.modules["custom_module"] = module
46 | spec.loader.exec_module(module)
47 | except Exception as e:
48 | raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
49 |
50 | function_name = reward_fn_config.get("name")
51 | if not hasattr(module, function_name):
52 | raise AttributeError(
53 | f"Reward function '{function_name}' not found in '{file_path}'."
54 | )
55 |
56 | print(f"using customized reward function '{function_name}' from '{file_path}'")
57 | raw_fn = getattr(module, function_name)
58 |
59 | reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
60 |
61 | def wrapped_fn(*args, **kwargs):
62 | return raw_fn(*args, **kwargs, **reward_kwargs)
63 |
64 | return wrapped_fn
65 |
66 |
67 | def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
68 | """
69 | Load and initialize a reward manager based on the configuration.
70 |
71 | Args:
72 | config: PPO trainer configuration object containing reward_model fields.
73 | tokenizer: Tokenizer object used for processing text.
74 | num_examine: Number of samples to examine.
75 | **reward_kwargs: Additional keyword arguments for the reward manager.
76 |
77 | Returns:
78 | An instance of the specified reward manager class.
79 | """
80 | from verl.workers.reward_manager import get_reward_manager_cls
81 |
82 | # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
83 | # naive: NaiveRewardManager
84 | # prime: PrimeRewardManager
85 | # batch: BatchRewardManager
86 | # dapo: DAPORewardManager
87 | # Note(haibin.lin): For custom reward managers, please make sure they are imported and
88 | # registered via `verl.workers.reward_manager.register`
89 | # By default reward_manager is set to naive (NaiveRewardManager)
90 | reward_manager_name = config.reward_model.get("reward_manager", "naive")
91 | reward_manager_cls = get_reward_manager_cls(reward_manager_name)
92 |
93 | # Try to get a custom reward function based on the configuration
94 | compute_score = get_custom_reward_fn(config)
95 | final_compute_score = compute_score
96 |
97 | if compute_score is None:
98 | sandbox_config = config.reward_model.get("sandbox_fusion")
99 | sandbox_url = sandbox_config.get("url") if sandbox_config else None
100 | if sandbox_url:
101 | sandbox_manager = multiprocessing.Manager()
102 | # Create a semaphore to control concurrent access to the sandbox
103 | _concurrent_semaphore = sandbox_manager.Semaphore(
104 | sandbox_config.get("max_concurrent", 64)
105 | )
106 | final_compute_score = partial(
107 | default_compute_score,
108 | sandbox_fusion_url=sandbox_url,
109 | concurrent_semaphore=_concurrent_semaphore,
110 | )
111 | else:
112 | final_compute_score = default_compute_score
113 |
114 | # Instantiate and return the reward manager with the specified parameters
115 | return reward_manager_cls(
116 | tokenizer=tokenizer,
117 | num_examine=num_examine,
118 | compute_score=final_compute_score,
119 | reward_fn_key=config.data.reward_fn_key,
120 | **reward_kwargs,
121 | )
122 |
123 |
124 | def compute_reward(data: DataProto, reward_fn):
125 | """
126 | Compute reward for a batch of data.
127 | Args:
128 | data: DataProto object containing the input data.
129 | reward_fn: Reward function to compute the reward.
130 | Returns:
131 | Tuple of reward tensor and extra info dictionary.
132 | """
133 | try:
134 | reward_result = reward_fn(data, return_dict=True)
135 | reward_tensor = reward_result["reward_tensor"]
136 | reward_extra_infos_dict = reward_result["reward_extra_info"]
137 | except Exception as e:
138 | print(f"Error in reward_fn: {e}")
139 | reward_tensor = reward_fn(data)
140 | reward_extra_infos_dict = {}
141 |
142 | return reward_tensor, reward_extra_infos_dict
143 |
144 |
145 | @ray.remote(num_cpus=1)
146 | def compute_reward_async(data: DataProto, config, tokenizer):
147 | """
148 | Load the reward manager and compute the reward for a batch of data.
149 | This is meant to be run in a separate Ray worker.
150 | """
151 | reward_fn = load_reward_manager(
152 | config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
153 | )
154 | return compute_reward(data, reward_fn)
155 |
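The note above says custom reward managers must be imported and registered via `verl.workers.reward_manager.register` before they can be selected by name in `config.reward_model.reward_manager`. A minimal sketch of what that might look like; the manager name, class, and the exact decorator signature are assumptions for illustration, and the constructor arguments simply mirror the call made in `load_reward_manager`:

```python
# Hypothetical sketch only: a custom reward manager registered under the name "my_custom",
# selectable via config.reward_model.reward_manager = "my_custom".
from verl.workers.reward_manager import register


@register("my_custom")  # assumed decorator usage: register(name)
class MyCustomRewardManager:
    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source", **kwargs):
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score
        self.reward_fn_key = reward_fn_key

    def __call__(self, data, return_dict=False):
        # Compute a per-token reward tensor for the batch (details omitted in this sketch).
        ...
```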
--------------------------------------------------------------------------------
/mirorl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mirorl/utils/debug/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mirorl/utils/debug/exception_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import traceback
16 | from typing import Any, Dict
17 |
18 |
19 | def extract_exception_details(exception_group) -> Dict[str, Any]:
20 | """
21 | Extract detailed exception information from ExceptionGroup
22 |
23 | Args:
24 | exception_group: ExceptionGroup instance
25 |
26 | Returns:
27 | Dictionary containing detailed exception information
28 | """
29 | details = {
30 | "message": str(exception_group),
31 | "type": type(exception_group).__name__,
32 | "exceptions": [],
33 | "traceback": traceback.format_exc(),
34 | }
35 |
36 | # Extract all sub-exceptions
37 | for i, exc in enumerate(exception_group.exceptions):
38 | exc_info = {
39 | "index": i,
40 | "type": type(exc).__name__,
41 | "message": str(exc),
42 | "traceback": "".join(
43 | traceback.format_exception(type(exc), exc, exc.__traceback__)
44 | ),
45 | }
46 | details["exceptions"].append(exc_info)
47 |
48 | return details
49 |
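A minimal usage sketch of `extract_exception_details` (requires Python 3.11+, where `ExceptionGroup` is built in; the exceptions below are made up for illustration):

```python
try:
    raise ExceptionGroup(
        "rollout failures",
        [ValueError("bad tool call"), TimeoutError("MCP server timed out")],
    )
except ExceptionGroup as eg:
    details = extract_exception_details(eg)
    print(details["type"], details["message"])
    for sub in details["exceptions"]:
        # Each entry carries the sub-exception's index, type, message, and formatted traceback.
        print(sub["index"], sub["type"], sub["message"])
```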
--------------------------------------------------------------------------------
/mirorl/utils/reward_score/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2025 MiroMind Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # from . import gsm8k, math, prime_math, prime_code
16 |
17 | from verl.utils.import_utils import deprecated
18 |
19 |
20 | # Adapted from verl/utils/reward_score/__init__.py
21 | # Updated by MiroMind Team
22 | # 1. Add custom reward function support
23 | def default_compute_score(
24 | data_source,
25 | solution_str,
26 | ground_truth,
27 | extra_info=None,
28 | sandbox_fusion_url=None,
29 | concurrent_semaphore=None,
30 | ):
31 | """Compute the score for a given solution based on the data source.
32 |
33 | Args:
34 | data_source (str): The source dataset identifier which determines the scoring method.
35 | solution_str (str): The solution string to be evaluated.
36 | ground_truth (str): The ground truth answer for comparison.
37 | extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None.
38 |
39 | Returns:
40 | float: The computed score as a floating point number. If the result is a dictionary,
41 | it returns the dictionary instead.
42 |
43 | Raises:
44 | NotImplementedError: If the reward function is not implemented for the given data source.
45 | """
46 | if data_source == "openai/gsm8k":
47 | from verl.utils.reward_score import gsm8k
48 |
49 | res = gsm8k.compute_score(solution_str, ground_truth)
50 | elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]:
51 | from verl.utils.reward_score import math
52 |
53 | res = math.compute_score(solution_str, ground_truth)
54 | # [Optional] Math-Verify Integration
55 | # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).
56 | # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.
57 | # To use it, override the `compute_score` function with the following implementation:
58 |
59 | # from . import math_verify
60 | # res = math_verify.compute_score(solution_str, ground_truth)
61 | elif data_source == "math_dapo" or data_source.startswith("aime"):
62 | from verl.utils.reward_score import math_dapo
63 |
64 | res = math_dapo.compute_score(solution_str, ground_truth)
65 | elif data_source in [
66 | "numina_aops_forum",
67 | "numina_synthetic_math",
68 | "numina_amc_aime",
69 | "numina_synthetic_amc",
70 | "numina_cn_k12",
71 | "numina_olympiads",
72 | ]:
73 | from verl.utils.reward_score import prime_math
74 |
75 | res = prime_math.compute_score(solution_str, ground_truth)
76 | elif data_source in ["codecontests", "apps", "codeforces", "taco"]:
77 | # Use the passed sandbox_fusion_url if available
78 | if sandbox_fusion_url:
79 | from verl.utils.reward_score import sandbox_fusion
80 |
81 | # Pass the URL directly, ground_truth likely contains test cases here
82 | res = sandbox_fusion.compute_score(
83 | sandbox_fusion_url,
84 | concurrent_semaphore,
85 | solution_str,
86 | ground_truth,
87 | continuous=True,
88 | )
89 | else:
90 | # If no sandbox URL is provided, fall back to prime_code or raise error
91 | from verl.utils.reward_score import prime_code
92 |
93 | # Assuming prime_code doesn't need the URL
94 | res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
95 | elif data_source in ["hiyouga/geometry3k"]:
96 | from verl.utils.reward_score import geo3k
97 |
98 | res = geo3k.compute_score(solution_str, ground_truth)
99 | elif data_source in [
100 | "searchR1_nq",
101 | "searchR1_triviaqa",
102 | "searchR1_popqa",
103 | "searchR1_hotpotqa",
104 | "searchR1_2wikimultihopqa",
105 | "searchR1_musique",
106 | "searchR1_bamboogle",
107 | ]:
108 | from verl.utils.reward_score import search_r1_like_qa_em
109 |
110 | res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
111 | elif data_source in [
112 | "basicv8vc/SimpleQA",
113 | "yixuantt/MultiHopRAG",
114 | "callanwu/WebWalkerQA",
115 | ]:
116 | from . import simpleqa
117 |
118 | res = simpleqa.compute_score(solution_str, ground_truth)
119 | else:
120 | raise NotImplementedError(
121 | f"Reward function is not implemented for {data_source=}"
122 | )
123 |
124 | if isinstance(res, dict):
125 | return res
126 | elif isinstance(res, (int, float, bool)):
127 | return float(res)
128 | else:
129 | return float(res[0])
130 |
131 |
132 | @deprecated("verl.utils.reward_score.default_compute_score")
133 | def _default_compute_score(
134 | data_source,
135 | solution_str,
136 | ground_truth,
137 | extra_info=None,
138 | sandbox_fusion_url=None,
139 | concurrent_semaphore=None,
140 | ):
141 | """
142 | Legacy function API to be deprecated. Please use `default_compute_score` instead.
143 | """
144 | return default_compute_score(
145 | data_source,
146 | solution_str,
147 | ground_truth,
148 | extra_info,
149 | sandbox_fusion_url,
150 | concurrent_semaphore,
151 | )
152 |
153 |
154 | __all__ = ["default_compute_score"]
155 |
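For reference, a small illustration of how the dispatch on `data_source` behaves: SimpleQA-style sources are routed to the local `simpleqa` scorer, which extracts a `\boxed{...}` answer and does exact match. The inputs below are made up:

```python
score = default_compute_score(
    data_source="basicv8vc/SimpleQA",
    solution_str=r"After searching, the answer is \boxed{Paris}.",
    ground_truth="Paris",
)
print(score)  # 1.0 on exact match, 0.0 otherwise
```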
--------------------------------------------------------------------------------
/mirorl/utils/reward_score/simpleqa.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py
15 |
16 | import regex
17 |
18 |
19 | def extract_solution(solution_str):
20 |     """Extract the final boxed answer from the solution string."""
21 | try:
22 | matches = regex.findall(
23 | r"\\boxed\{((?:[^{}]+|\{(?1)*\})*)\}",
24 | solution_str,
25 | regex.DOTALL,
26 | timeout=10,
27 | )
28 | except TimeoutError as _:
29 | matches = []
30 |
31 | # If there are 0 matches, return None
32 | if len(matches) < 1:
33 | return None
34 |
35 |     # Otherwise return the last match (the final \boxed{} answer wins if several appear)
36 | return matches[-1].strip()
37 |
38 |
39 | def compute_score(
40 | solution_str, ground_truth, method="strict", format_score=0.0, score=1.0
41 | ):
42 | """The scoring function for exact match (EM).
43 |
44 | Args:
45 | solution_str: the solution text
46 | ground_truth: the ground truth
47 |         method: the method used to extract the solution, choices are 'strict' and 'flexible' (currently unused)
48 | format_score: the score for the format
49 | score: the score for the correct answer
50 | """
51 | answer = extract_solution(solution_str=solution_str)
52 | do_print = False
53 |
54 | if do_print:
55 | print("--------------------------------")
56 | print(f"Golden answers: {ground_truth}")
57 | if answer is not None:
58 | print(f"Extracted answer is not None: {answer}")
59 | else:
60 | print("Extracted answer: None!")
61 | print(f"Solution string: {solution_str}")
62 |
63 | if answer is None:
64 | return 0
65 | else:
66 | if answer == ground_truth:
67 | return score
68 | else:
69 | return format_score
70 |
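A quick sanity check of the extraction and scoring behavior: the recursive regex handles nested braces, and when several `\boxed{}` spans appear the last one wins. The example strings are made up:

```python
print(extract_solution(r"First \boxed{draft}, final \boxed{Paris}"))  # -> "Paris" (last match)
print(extract_solution(r"\boxed{\frac{1}{2}}"))                       # -> "\frac{1}{2}" (nested braces OK)
print(extract_solution("no boxed answer"))                            # -> None

print(compute_score(r"Answer: \boxed{Paris}", "Paris"))   # 1.0 (exact match -> score)
print(compute_score(r"Answer: \boxed{London}", "Paris"))  # 0.0 (mismatch -> format_score)
print(compute_score("no boxed answer", "Paris"))          # 0   (nothing extracted)
```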
--------------------------------------------------------------------------------
/mirorl/utils/tracking.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2025 MiroMind Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | A unified tracking interface that supports logging data to different backends
17 | """
18 |
19 | from typing import List, Union
20 |
21 | from verl.utils.tracking import Tracking
22 |
23 | from wandb import AlertLevel
24 |
25 |
26 | # Adapted from verl/utils/tracking.py
27 | # Updated by MiroMind Team
28 | # 1. Add alert function
29 | class MonitorTracking(Tracking):
30 | """A unified tracking interface for logging experiment data to multiple backends.
31 |
32 | This class provides a centralized way to log experiment metrics, parameters, and artifacts
33 | to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console.
34 |
35 | Attributes:
36 | supported_backend: List of supported tracking backends.
37 | logger: Dictionary of initialized logger instances for each backend.
38 | """
39 |
40 | def __init__(
41 | self,
42 | project_name,
43 | experiment_name,
44 | default_backend: Union[str, List[str]] = "console",
45 | config=None,
46 | ):
47 | super().__init__(project_name, experiment_name, default_backend, config)
48 |
49 | def alert(self, title, text, level=AlertLevel.WARN, wait_duration=180):
50 | assert "wandb" in self.logger, "wandb is not in the tracking"
51 | self.logger["wandb"].alert(
52 | title=title,
53 | text=text,
54 | level=level,
55 | wait_duration=wait_duration,
56 | )
57 |
--------------------------------------------------------------------------------
/mirorl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .reward_manager import ToolRewardManager
16 | from .rollout import MCPSGLangRollout
17 |
18 | __all__ = ["ToolRewardManager", "MCPSGLangRollout"]
19 |
--------------------------------------------------------------------------------
/mirorl/workers/reward_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .tool import ToolRewardManager
16 |
17 | __all__ = ["ToolRewardManager"]
18 |
--------------------------------------------------------------------------------
/mirorl/workers/reward_manager/tests/test_others.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | import time
16 | from mirorl.workers.reward_manager.tool import run_async_computation
17 | from mirorl.utils.reward_score.llm_judge import extract_boxed_answer, gaia_naive_compute_score
18 | from openai import AsyncOpenAI
19 | import os
20 | from unittest.mock import Mock
21 |
22 | def build_data(num_test_samples):
23 | # read file from testData.jsonl
24 | data = []
25 | with open("tests/testData.jsonl", "r") as f:
26 | for line in f:
27 | data.append(json.loads(line))
28 |
29 | prompt_str_list = [item["prompt_str"] for item in data[:num_test_samples]]
30 | response_str_list = [item["response_str"] for item in data[:num_test_samples]]
31 | ground_truth_list = [item["ground_truth"] for item in data[:num_test_samples]]
32 | data_source_list = [item["data_source"] for item in data[:num_test_samples]]
33 | extra_info_list = [item["extra_info"] for item in data[:num_test_samples]]
34 | valid_response_length_list = [item["valid_response_length"] for item in data[:num_test_samples]]
35 |
36 | return prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list
37 |
38 |
39 | # test the async process function
40 | def async_process_function(prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list, openai_client, batch_size=100, max_retry=3):
41 | start_time = time.time()
42 | print("Start computing accuracy reward with async api, this may take a while...")
43 | print(f"Async configuration: batch_size={batch_size}, max_retry={max_retry}")
44 | try:
45 | # Run async computation using the new async functions
46 | accuracy_reward_list = run_async_computation(
47 | openai_client,
48 | data_source_list,
49 | response_str_list,
50 | ground_truth_list,
51 | extra_info_list,
52 | batch_size,
53 | max_retry
54 | )
55 |
56 | end_time = time.time()
57 | print(f"Async batch processing completed in {end_time - start_time:.2f} seconds")
58 |
59 | except Exception as e:
60 | # if error, we use gaia_naive_compute_score to compute accuracy reward
61 | print(f"Error in batch api: {e}")
62 | print("Falling back to gaia_naive_compute_score...")
63 | accuracy_reward_list = []
64 | for i in range(len(response_str_list)):
65 | accuracy_reward = gaia_naive_compute_score(
66 | data_source=data_source_list[i],
67 | solution_str=response_str_list[i],
68 | ground_truth=ground_truth_list[i],
69 | extra_info=extra_info_list[i],
70 | )
71 | accuracy_reward_list.append(accuracy_reward)
72 |
73 | for i in range(len(accuracy_reward_list)):
74 | extracted = extract_boxed_answer(response_str_list[i])
75 | ground_truth = ground_truth_list[i]
76 | reward = accuracy_reward_list[i]
77 | print(f"{extracted:<40} | {ground_truth:<40} | {reward}")
78 |
79 |
80 | def test_exception_handling(num_test_samples):
81 | """Test the exception handling and fallback mechanism"""
82 | print("=== Testing Exception Handling ===")
83 |
84 | # Load test data
85 | prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list = build_data(num_test_samples)
86 |
87 | # Create a mock OpenAI client that always raises an exception
88 | mock_client = Mock()
89 | mock_client.chat.completions.create.side_effect = Exception("API Error: Rate limit exceeded")
90 |
91 | # Call the function with the mock client that will fail
92 | async_process_function(
93 | prompt_str_list,
94 | response_str_list,
95 | ground_truth_list,
96 | data_source_list,
97 | extra_info_list,
98 | valid_response_length_list,
99 | mock_client,
100 | batch_size=100,
101 | max_retry=1
102 | )
103 |
104 | print("=== Exception handling test completed ===")
105 |
106 |
107 |
108 |
109 | if __name__ == "__main__":
110 | num_test_samples = 128 # max is 128
111 | test_mode = "exception" # "normal" or "exception"
112 |
113 | if test_mode == "normal":
114 | print("=== Normal Operation Test ===")
115 | prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list = build_data(num_test_samples)
116 |
117 | openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
118 | async_process_function(prompt_str_list, response_str_list, ground_truth_list, data_source_list, extra_info_list, valid_response_length_list, openai_client, batch_size=100, max_retry=3)
119 | print("=== Normal Operation Test Completed ===")
120 |
121 | elif test_mode == "exception":
122 | test_exception_handling(num_test_samples)
123 |
124 | else:
125 | raise ValueError(f"Invalid test mode: {test_mode}")
--------------------------------------------------------------------------------
/mirorl/workers/rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .sglang_rollout.sglang_rollout import MCPSGLangRollout
16 |
17 | __all__ = ["MCPSGLangRollout"]
18 |
--------------------------------------------------------------------------------
/mirorl/workers/rollout/sglang_rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 MiroMind Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .sglang_rollout import MCPSGLangRollout
16 |
17 | __all__ = ["MCPSGLangRollout"]
--------------------------------------------------------------------------------
/scripts/README_GRPO_Visualizer.md:
--------------------------------------------------------------------------------
1 | # 🎯 GRPO Rollout Visualizer
2 |
3 | A beautiful and interactive visualizer for GRPO (Group Relative Policy Optimization) rollouts. This tool provides a hierarchical navigation system to explore your training data across projects, epochs, and individual rollouts with rich, interactive visualizations.
4 |
5 | ## ✨ Features
6 |
7 | ### 🏗️ Hierarchical Navigation
8 | - **Projects View**: Browse different training projects
9 | - **Epochs View**: Explore training epochs with performance metrics
10 | - **Sample View**: Analyze input groups and their statistics
11 | - **Rollouts View**: Compare individual rollouts side-by-side
12 |
13 | ### 📊 Rich Visualizations
14 | - **Interactive Charts**: Plotly-powered charts for metrics visualization
15 | - **Performance Trends**: Track improvements across epochs
16 | - **Comparison Views**: Split rollouts by performance (better vs worse)
17 | - **Statistical Insights**: Average scores, accuracy metrics, and standard deviations
18 |
19 | ### 🎨 Beautiful UI
20 | - **Modern Design**: Gradient cards and smooth animations
21 | - **Responsive Layout**: Works on different screen sizes
22 | - **Intuitive Navigation**: Breadcrumb navigation and clear visual hierarchy
23 | - **Color-coded Performance**: Green for better, red for worse rollouts
24 |
25 | ## 📦 Installation
26 |
27 | ### Install Dependencies
28 | ```bash
29 | pip install -r visualizer_requirements.txt
30 | ```
31 |
32 | ### Dependencies
33 | - **Streamlit**: Web app framework
34 | - **Pandas**: Data manipulation
35 | - **NumPy**: Numerical computations
36 | - **Plotly**: Interactive visualizations
37 |
38 | ## 🚀 Usage
39 |
40 | ### Directory Structure
41 | Your work folder should be organized as follows:
42 | ```
43 | work_folder/
44 | ├── project1/
45 | │ ├── epoch1.jsonl
46 | │ ├── epoch2.jsonl
47 | │ └── epoch3.jsonl
48 | ├── project2/
49 | │ ├── epoch1.jsonl
50 | │ └── epoch2.jsonl
51 | └── project3/
52 | └── epoch1.jsonl
53 | ```
54 |
55 | ### JSONL File Format
56 | Each line in the JSONL file should contain a rollout dictionary with the following structure:
57 | ```json
58 | {
59 | "input": "Text input with \\nuser\\n marker",
60 | "output": "Model output/response",
61 | "ground_truth": "Expected answer (optional)",
62 | "score": 0.85,
63 | "accuracy_reward": 0.9,
64 | "tool_format_reward": 0.1,
65 | "combined_reward": 0.87,
66 | "step": 0
67 | }
68 | ```
69 |
70 | ### Running the Visualizer
71 |
72 | 1. **Start the application**:
73 | ```bash
74 | streamlit run grpo_visualizer.py
75 | ```
76 |
77 | 2. **Configure the work folder**:
78 | - Enter the path to your work folder in the sidebar
79 | - Click "🔄 Refresh Data" to load the data
80 |
81 | 3. **Navigate through your data**:
82 | - **Projects**: Click on a project to explore its epochs
83 | - **Epochs**: View epoch performance and click to see input groups
84 | - **Sample**: Examine grouped inputs and their statistics
85 | - **Rollouts**: Compare individual rollouts split by performance
86 |
87 | ## 🎮 Interface Guide
88 |
89 | ### 📁 Projects View
90 | - **Project Cards**: Shows project statistics (epochs, rollouts, average score)
91 | - **Quick Access**: Click any project to dive into its epochs
92 | - **Overview**: Get a high-level view of all your training projects
93 |
94 | ### 📈 Epochs View
95 | - **Performance Charts**: Interactive plots showing score and accuracy trends
96 | - **Epoch Details**: Click on any epoch to explore its input groups
97 | - **Metrics Dashboard**: Comprehensive overview of training progress
98 |
99 | ### 📝 Sample View
100 | - **Input Grouping**: All rollouts with the same input are grouped together
101 | - **User Prompts**: Displays the text after `\nuser\n` in the input
102 | - **Sample Statistics**: Average scores, accuracy, and standard deviations
103 | - **Performance Ranking**: Samples sorted by accuracy (best first)
104 |
105 | ### 🎯 Rollouts View
106 | - **Side-by-Side Comparison**: Lower accuracy rollouts on the left, higher accuracy on the right
107 | - **Detailed Metrics**: Score, accuracy, reward, and combined reward for each rollout
108 | - **Output Inspection**: Expandable views to examine model outputs
109 | - **Ground Truth**: Compare outputs with expected answers (when available)
110 |
111 | ## 🔧 Customization
112 |
113 | ### Styling
114 | The visualizer uses custom CSS for beautiful styling. You can modify the `load_custom_css()` function (see the sketch after this list) to adjust:
115 | - Color schemes
116 | - Card designs
117 | - Layout spacing
118 | - Font sizes
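
The sketch below illustrates the kind of override this enables; the selectors and colors are examples only, not the ones shipped in `grpo_visualizer.py`:

```python
import streamlit as st

def load_custom_css():
    # Example override — class names and colors are illustrative only.
    st.markdown(
        """
        <style>
        .project-card { background: linear-gradient(135deg, #667eea, #764ba2); border-radius: 12px; }
        .metric-label { font-size: 0.9rem; color: #555; }
        </style>
        """,
        unsafe_allow_html=True,
    )
```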
119 |
120 | ### Data Processing
121 | The `GRPOVisualizer` class can be extended (a rough sketch follows this list) to:
122 | - Support different data formats
123 | - Add custom metrics
124 | - Implement additional visualizations
125 | - Handle special data preprocessing
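
A rough sketch of such an extension. The import path, the `load_rollouts` hook, and the metric name are assumptions for illustration, not the actual `GRPOVisualizer` API:

```python
from grpo_visualizer import GRPOVisualizer  # assumed import path


class LengthAwareVisualizer(GRPOVisualizer):
    def load_rollouts(self, jsonl_path):  # hypothetical loader hook
        rollouts = super().load_rollouts(jsonl_path)
        for r in rollouts:
            # Custom metric: penalize very long outputs.
            r["length_penalty"] = -0.001 * len(r.get("output", ""))
        return rollouts
```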
126 |
127 | ## 📊 Key Metrics
128 |
129 | ### Score Metrics
130 | - **Score**: Overall performance score
131 | - **Accuracy Reward**: Accuracy-based reward
132 | - **Reward**: Base reward value
133 | - **Combined Reward**: Composite reward metric
134 |
135 | ### Statistics
136 | - **Average**: Mean values across rollouts
137 | - **Standard Deviation**: Measure of variance
138 | - **Sample Count**: Number of rollouts per input group
139 | - **Performance Split**: Median-based classification
140 |
141 | ## 🎨 UI Components
142 |
143 | ### Visual Elements
144 | - **Gradient Cards**: Beautiful project and epoch cards
145 | - **Interactive Charts**: Plotly-powered visualizations
146 | - **Responsive Grid**: Adaptive layout system
147 | - **Color Coding**: Performance-based visual cues
148 |
149 | ### Navigation
150 | - **Breadcrumb Navigation**: Always know where you are
151 | - **Back Buttons**: Easy navigation between levels
152 | - **Session State**: Maintains your position during exploration
153 | - **Sidebar Controls**: Configuration and navigation info
154 |
155 | ## 🔍 Data Insights
156 |
157 | ### Performance Analysis
158 | - **Trend Tracking**: Monitor improvements across epochs
159 | - **Comparative Analysis**: Better vs worse rollout comparison
160 | - **Statistical Overview**: Mean, standard deviation, and distributions
161 | - **Input Grouping**: Understand performance per input type
162 |
163 | ### Quality Assessment
164 | - **Accuracy Metrics**: Detailed accuracy breakdowns
165 | - **Score Distributions**: Understand performance variations
166 | - **Output Inspection**: Manual quality assessment
167 | - **Ground Truth Comparison**: Validate model outputs
168 |
169 | ## 🛠️ Technical Details
170 |
171 | ### Architecture
172 | - **Modular Design**: Clean separation of concerns
173 | - **Session Management**: Streamlit session state for navigation
174 | - **Efficient Data Loading**: Optimized JSONL parsing
175 | - **Responsive UI**: Modern web interface
176 |
177 | ### Performance
178 | - **Lazy Loading**: Data loaded on demand
179 | - **Caching**: Efficient data retrieval
180 | - **Memory Management**: Optimized for large datasets
181 | - **Interactive Updates**: Real-time UI updates
182 |
183 | ## 🤝 Contributing
184 |
185 | Feel free to contribute improvements:
186 | - **Bug Reports**: Submit issues with detailed descriptions
187 | - **Feature Requests**: Suggest new visualization features
188 | - **Code Improvements**: Submit pull requests with enhancements
189 | - **Documentation**: Help improve this README
190 |
191 | ## 📝 License
192 |
193 | This project follows the same license as the main VERL project.
194 |
195 | ---
196 |
197 | **Happy Visualizing!** 🎉 Explore your GRPO rollouts with style and gain insights into your model's training progress.
--------------------------------------------------------------------------------
/scripts/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export MAX_JOBS=32
4 |
5 | echo "1. install inference frameworks and the PyTorch version they need"
6 | pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
7 | pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata
8 |
9 |
10 | echo "2. install basic packages"
11 | pip install "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \
12 | "numpy<2.0.0" "pyarrow>=15.0.0" pandas \
13 | ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \
14 | pytest py-spy pyext pre-commit ruff mcp==1.10.1 tenacity
15 |
16 | pip install "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1"
17 |
18 |
19 | echo "3. install FlashAttention"
20 | # Install flash-attn-2.7.4.post1 (cxx11abi=False)
21 | wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
22 | pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
23 |
24 |
25 | echo "4. install FlashInfer"
26 | # Install flashinfer-0.2.7.post1+cu124 (cxx11abi=False)
27 | # 1. Clone the FlashInfer repository:
28 | git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
29 | # 2. Make sure you have installed PyTorch with CUDA support. You can check the PyTorch version and CUDA version by running:
30 | python -c "import torch; print(torch.__version__, torch.version.cuda)"
31 | # 3. Install Ninja build system:
32 | pip install ninja
33 | # 4. Install FlashInfer(AOT mode):
34 | cd flashinfer
35 | git checkout v0.2.7.post1
36 | # Set the TORCH_CUDA_ARCH_LIST environment variable to the supported architectures:
37 | # export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a"
38 | # if you are using a100/a800, you can use the following command:
39 | export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
40 | # Produces AOT kernels in aot-ops/
41 | python -m flashinfer.aot
42 | python -m pip install --no-build-isolation --verbose .
43 |
44 |
45 | echo "5. May need to fix opencv"
46 | pip install opencv-python
47 | pip install opencv-fixer && \
48 | python -c "from opencv_fixer import AutoFix; AutoFix()"
49 |
50 | echo "Successfully installed all packages"
51 |
--------------------------------------------------------------------------------
/scripts/visualizer_requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit>=1.28.0
2 | pandas>=1.5.0
3 | numpy>=1.21.0
4 | plotly>=5.15.0
--------------------------------------------------------------------------------
/tests/distributed/run_all.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | #!/usr/bin/env bash
16 |
17 | set -e -x
18 | torchrun --nproc-per-node=4 --standalone tests/distributed/test_tensor_dict.py
--------------------------------------------------------------------------------
/tests/distributed/test_tensor_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | os.environ["NCCL_DEBUG"] = "WARN"
18 |
19 | import numpy as np
20 | import torch
21 | import torch.distributed
22 |
23 | from verl.protocol import DataProto, all_gather_data_proto
24 | from verl.utils.distributed import initialize_global_process_group
25 |
26 |
27 | def test_all_gather_data_proto():
28 | device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"])
29 |
30 | global_rank = torch.distributed.get_rank()
31 |
32 | obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])
33 |
34 | labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"]
35 | labels = np.array(labels, dtype=object)
36 | data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})
37 |
38 | all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp"))
39 |
40 | if global_rank == 0:
41 | expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
42 | expected_labels = ["a", "b", "a", "b"]
43 | elif global_rank == 1:
44 | expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
45 | expected_labels = ["b", "a", "b", "a"]
46 | elif global_rank == 2:
47 | expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
48 | expected_labels = ["a", "b", "a", "b"]
49 | elif global_rank == 3:
50 | expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
51 | expected_labels = ["b", "a", "b", "a"]
52 |
53 | torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0)
54 | assert (data.non_tensor_batch["labels"] == expected_labels).all()
55 | assert data.meta_info == {"info": "test_info"}
56 |
57 |
58 | def test_vocab_parallel_entropy():
59 | from megatron.core import parallel_state as mpu
60 |
61 | from verl.utils.debug import log_gpu_memory_usage
62 | from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
63 | from verl.utils.torch_functional import entropy_from_logits
64 |
65 | mpu.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None)
66 |
67 | batch_size = 2
68 | seqlen = 128
69 | vocab_size = 155136
70 |
71 | logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True)
72 | target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64)
73 |
74 | # broadcast across tp
75 | torch.distributed.broadcast(logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
76 | torch.distributed.broadcast(target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
77 |
78 | tp_rank = mpu.get_tensor_model_parallel_rank()
79 | vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()
80 |
81 | # get the local logits of each tp
82 | vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_()
83 | logits.grad = None
84 | vocab_parallel_logits.grad = None
85 |
86 | log_gpu_memory_usage("begin")
87 | output_entropy = vocab_parallel_entropy(vocab_parallel_logits)
88 | log_gpu_memory_usage("after forward")
89 | grad_output = torch.randn_like(output_entropy)
90 | output_entropy.backward(grad_output)
91 | log_gpu_memory_usage("after backward")
92 |
93 | target_entropy = entropy_from_logits(logits)
94 | torch.testing.assert_close(output_entropy, target_entropy)
95 | target_entropy.backward(grad_output)
96 | torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad)
97 | # make sure logits is not altered
98 | torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits)
99 |
100 | if mpu.get_tensor_model_parallel_rank() == 0:
101 | print("test_vocab_parallel_entropy passes")
102 |
103 | mpu.destroy_model_parallel()
104 |
105 |
106 | if __name__ == "__main__":
107 | local_rank, rank, world_size = initialize_global_process_group()
108 | test_all_gather_data_proto()
109 | test_vocab_parallel_entropy()
110 |
--------------------------------------------------------------------------------
/tests/models/test_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
17 | from transformers import (
18 | AutoModelForCausalLM,
19 | AutoModelForTokenClassification,
20 | GemmaConfig,
21 | LlamaConfig,
22 | MistralConfig,
23 | Qwen2Config,
24 | )
25 |
26 | from verl.utils.model import compute_position_id_with_mask, create_random_mask
27 | from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean
28 |
29 | # TODO(sgm): add more models for test
30 | # we only need one scale for each model
31 | test_configs = [
32 | LlamaConfig(num_hidden_layers=1),
33 | MistralConfig(num_hidden_layers=1),
34 | GemmaConfig(num_hidden_layers=1),
35 | Qwen2Config(num_hidden_layers=1),
36 | ]
37 |
38 |
39 | def test_hf_causal_models():
40 | batch_size = 4
41 | seqlen = 128
42 | response_length = 127
43 |
44 | for config in test_configs:
45 | # config = AutoConfig.from_pretrained(test_case)
46 | with torch.device("cuda"):
47 | model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
48 | model = model.to(device="cuda")
49 | input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
50 | attention_mask = create_random_mask(
51 | input_ids=input_ids,
52 | max_ratio_of_left_padding=0.1,
53 | max_ratio_of_valid_token=0.8,
54 | min_ratio_of_valid_token=0.5,
55 | )
56 | position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
57 |
58 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
59 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
60 |
61 | # unpad the position_ids to align the rotary
62 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)
63 |
64 |         # forward with input_ids_rmpad and position_ids to enable flash attention varlen
65 | logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, vocab_size)
66 |
67 | origin_logits = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False).logits
68 | origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)
69 |
70 | logits_rmpad = logits_rmpad.squeeze(0)
71 | log_probs = log_probs_from_logits_all_rmpad(
72 | input_ids_rmpad=input_ids_rmpad,
73 | logits_rmpad=logits_rmpad,
74 | indices=indices,
75 | batch_size=batch_size,
76 | seqlen=seqlen,
77 | response_length=response_length,
78 | ) # (batch, seqlen)
79 | origin_log_probs = log_probs_from_logits_all_rmpad(
80 | input_ids_rmpad=input_ids_rmpad,
81 | logits_rmpad=origin_logits_rmpad,
82 | indices=origin_logits_indices,
83 | batch_size=batch_size,
84 | seqlen=seqlen,
85 | response_length=response_length,
86 | ) # (batch, seqlen)
87 |
88 | torch.testing.assert_close(
89 | masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]),
90 | masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]),
91 | atol=1e-2,
92 | rtol=1e-5,
93 | )
94 | print("Check pass")
95 |
96 |
97 | def test_hf_value_models():
98 | batch_size = 4
99 | seqlen = 128
100 |
101 | for config in test_configs:
102 | # config = AutoConfig.from_pretrained(test_case)
103 | config.num_labels = 1
104 | config.classifier_dropout = 0
105 | config.hidden_dropout = 0
106 | with torch.device("cuda"):
107 | model = AutoModelForTokenClassification.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
108 | model = model.to(device="cuda")
109 | input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
110 | attention_mask = create_random_mask(
111 | input_ids=input_ids,
112 | max_ratio_of_left_padding=0.1,
113 | max_ratio_of_valid_token=0.8,
114 | min_ratio_of_valid_token=0.5,
115 | )
116 | position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
117 |
118 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
119 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
120 |
121 | # unpad the position_ids to align the rotary
122 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)
123 |
124 | origin_logits = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False).logits
125 |
126 |         # forward with input_ids_rmpad and position_ids to enable flash attention varlen
127 | rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, 1)
128 | rmpad_logits = rmpad_logits.squeeze(0)
129 | pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)
130 |
131 | torch.testing.assert_close(
132 | masked_mean(pad_logits, attention_mask[:, :, None]),
133 | masked_mean(origin_logits, attention_mask[:, :, None]),
134 | atol=1e-2,
135 | rtol=1e-5,
136 | )
137 | print("Value model check pass")
138 |
139 |
140 | if __name__ == "__main__":
141 |     test_hf_causal_models()
142 | test_hf_value_models()
143 |
--------------------------------------------------------------------------------
/tests/ray_gpu/detached_worker/README.md:
--------------------------------------------------------------------------------
1 | # Detached Worker
2 | ## How to run (Only on a single node)
3 | - Start a local ray cluster:
4 | ```bash
5 | ray start --head --port=6379
6 | ```
7 | - Run the server
8 | ```bash
9 | python3 server.py
10 | ```
11 | - In another terminal, run the client
12 | ```bash
13 | python3 client.py
14 | ```
15 |
--------------------------------------------------------------------------------
/tests/ray_gpu/detached_worker/client.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | In client, we can get the server handler and send RPC request
16 | """
17 |
18 | import ray
19 | import torch
20 | from server import Trainer
21 | from tensordict import TensorDict
22 |
23 | from verl import DataProto
24 | from verl.single_controller.ray import RayClassWithInitArgs
25 | from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
26 |
27 |
28 | def compute_position_id_with_mask(mask):
29 | return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
30 |
31 |
32 | if __name__ == "__main__":
33 | ray.init(address="auto", namespace="verl")
34 | # get the worker group using names
35 | worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"]
36 | cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
37 | worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)
38 |
39 | batch_size = 16
40 | sequence_length = 1024
41 |
42 | # give Trainer some data to train
43 | input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
44 | attention_mask = torch.ones_like(input_ids)
45 | position_ids = compute_position_id_with_mask(attention_mask)
46 |
47 | data = DataProto(
48 | batch=TensorDict(
49 | {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids},
50 | batch_size=batch_size,
51 | ),
52 | meta_info={},
53 | )
54 |
55 | output = worker_group.train_model(data)
56 |
57 | print(output)
58 |
--------------------------------------------------------------------------------
/tests/ray_gpu/detached_worker/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ray start --head --port=6379
3 | python3 server.py
4 | python3 client.py
5 | ray stop --force
--------------------------------------------------------------------------------
/tests/ray_gpu/detached_worker/server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Server starts a Trainer. Client sends data to the server to train.
16 | """
17 |
18 | import os
19 |
20 | os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
21 | os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
22 | os.environ["NCCL_DEBUG"] = "WARN"
23 |
24 | import ray
25 | import torch
26 | from megatron.core import parallel_state as mpu
27 | from megatron.core import tensor_parallel
28 | from megatron.core.models.gpt.gpt_model import ModelType
29 | from omegaconf import OmegaConf
30 | from tensordict import TensorDict
31 | from torch import nn
32 | from transformers import LlamaConfig
33 |
34 | from verl import DataProto
35 | from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
36 | from verl.single_controller.base.decorator import Dispatch, register
37 | from verl.single_controller.base.megatron.worker import MegatronWorker
38 | from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
39 | from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
40 | from verl.utils.megatron.optimizer import get_megatron_optimizer
41 | from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config
42 |
43 |
44 | @ray.remote
45 | class Trainer(MegatronWorker):
46 | def __init__(self):
47 | super().__init__()
48 |
49 | if not torch.distributed.is_initialized():
50 | rank = int(os.environ["LOCAL_RANK"])
51 | torch.distributed.init_process_group(backend="nccl")
52 | torch.cuda.set_device(rank)
53 |
54 | os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
55 | mpu.initialize_model_parallel(
56 | tensor_model_parallel_size=2,
57 | pipeline_model_parallel_size=1,
58 | virtual_pipeline_model_parallel_size=None,
59 | pipeline_model_parallel_split_rank=None,
60 | use_sharp=False,
61 | context_parallel_size=1,
62 | expert_model_parallel_size=1,
63 | nccl_communicator_config_path=None,
64 | )
65 | tensor_parallel.model_parallel_cuda_manual_seed(10)
66 |
67 | @register(dispatch_mode=Dispatch.ONE_TO_ALL)
68 | def init_model(self):
69 | actor_model_config = LlamaConfig(
70 | vocab_size=256,
71 | hidden_size=2048,
72 | intermediate_size=5504,
73 | num_hidden_layers=24,
74 | num_attention_heads=16,
75 | num_key_value_heads=16,
76 | )
77 |
78 | megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
79 | self.megatron_config = megatron_config
80 |
81 | def megatron_actor_model_provider(pre_process, post_process):
82 | # vpp is not supported yet because it will hang for some reason. Need debugging
83 | # this_megatron_config = copy.deepcopy(megatron_config)
84 | # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
85 | parallel_model = ParallelLlamaForCausalLMRmPadPP(
86 | config=actor_model_config,
87 | megatron_config=megatron_config,
88 | pre_process=pre_process,
89 | post_process=post_process,
90 | )
91 | parallel_model.cuda()
92 | return parallel_model
93 |
94 | actor_module = get_model(
95 | model_provider_func=megatron_actor_model_provider,
96 | model_type=ModelType.encoder_or_decoder,
97 | wrap_with_ddp=True,
98 | )
99 | actor_module = nn.ModuleList(actor_module)
100 |
101 | optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
102 |
103 | optim_config = init_megatron_optim_config(optim_config)
104 | self.optimizer_config = optim_config
105 | actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
106 |
107 | self.model = actor_module[0]
108 | self.optimizer = actor_optimizer
109 |
110 | @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
111 | def train_model(self, data: DataProto) -> DataProto:
112 | input_ids = data.batch["input_ids"]
113 | attention_mask = data.batch["attention_mask"]
114 | position_ids = data.batch["position_ids"]
115 |
116 | self.optimizer.zero_grad()
117 | self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
118 | # update for 1 iteration
119 | output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
120 | output.mean().backward()
121 |
122 | update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config, self.megatron_config.timers)
123 |
124 | return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))
125 |
126 |
127 | if __name__ == "__main__":
128 | ray.init(address="auto", namespace="verl")
129 |
130 | resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
131 | cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
132 | worker_group = NVMegatronRayWorkerGroup(
133 | resource_pool=resource_pool,
134 | ray_cls_with_init=cls_with_init_args,
135 | name_prefix="trainer",
136 | detached=True,
137 | )
138 |
139 | worker_group.init_model()
140 |
141 | worker_names = worker_group.worker_names
142 | print(worker_names)
143 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_colocated_workers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 | from verl import DataProto
18 | from verl.single_controller.base import Worker
19 | from verl.single_controller.base.decorator import Dispatch, register
20 | from verl.single_controller.ray.base import (
21 | RayClassWithInitArgs,
22 | RayResourcePool,
23 | RayWorkerGroup,
24 | create_colocated_worker_cls,
25 | )
26 |
27 |
28 | @ray.remote
29 | class Actor(Worker):
30 | def __init__(self) -> None:
31 | super().__init__()
32 |
33 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
34 | def add(self, data: DataProto):
35 | data.batch["a"] += self.rank
36 | return data
37 |
38 |
39 | @ray.remote
40 | class Critic(Worker):
41 | def __init__(self, config) -> None:
42 | super().__init__()
43 | self.config = config
44 |
45 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
46 | async def sub(self, data: DataProto):
47 | data.batch["a"] -= self.config["b"]
48 | return data
49 |
50 |
51 | def test_colocated_workers():
52 | ray.init()
53 |
54 | import torch
55 |
56 | data = DataProto.from_dict({"a": torch.zeros(10)})
57 | # create separate workers on the same resource pool
58 | actor_cls = RayClassWithInitArgs(cls=Actor)
59 | critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
60 | resource_pool = RayResourcePool(process_on_nodes=[2])
61 |
62 | actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
63 | critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)
64 |
65 | expected_actor_output = actor_wg.add(data)
66 | expected_critic_output = critic_wg.sub(data)
67 |
68 | # create colocated workers
69 | cls_dict = {"actor": actor_cls, "critic": critic_cls}
70 | ray_cls_with_init = create_colocated_worker_cls(cls_dict)
71 | wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
72 | spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())
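   | # spawn() returns one worker-group view per prefix ("actor"/"critic"); both views dispatch to the same colocated workers.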
73 |
74 | colocated_actor_wg = spawn_wg["actor"]
75 | colocated_critic_wg = spawn_wg["critic"]
76 |
77 | actor_output = colocated_actor_wg.add(data)
78 | critic_output = colocated_critic_wg.sub(data)
79 |
80 | torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
81 | torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)
82 |
83 | ray.shutdown()
84 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_colocated_workers_fused.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 | from verl import DataProto
18 | from verl.single_controller.base import Worker
19 | from verl.single_controller.base.decorator import Dispatch, register
20 | from verl.single_controller.ray.base import (
21 | RayClassWithInitArgs,
22 | RayResourcePool,
23 | RayWorkerGroup,
24 | create_colocated_worker_cls_fused,
25 | )
26 |
27 |
28 | @ray.remote
29 | class Actor(Worker):
30 | def __init__(self) -> None:
31 | super().__init__()
32 |
33 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
34 | def add(self, data: DataProto):
35 | data.batch["a"] += self.rank
36 | return data
37 |
38 |
39 | @ray.remote
40 | class Critic(Worker):
41 | def __init__(self, config) -> None:
42 | super().__init__()
43 | self.config = config
44 |
45 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
46 | def sub(self, data: DataProto):
47 | data.batch["a"] -= self.config["b"]
48 | return data
49 |
50 |
51 | def test_colocated_workers_fused():
52 | ray.init()
53 |
54 | import torch
55 |
56 | data = DataProto.from_dict({"a": torch.zeros(10)})
57 | # create separate workers on the same resource pool
58 | actor_cls = RayClassWithInitArgs(cls=Actor)
59 | critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
60 | resource_pool = RayResourcePool(process_on_nodes=[2])
61 |
62 | actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
63 | critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)
64 |
65 | expected_actor_output = actor_wg.add(data)
66 | expected_critic_output = critic_wg.sub(data)
67 |
68 | # create colocated workers
69 | cls_dict = {"actor": actor_cls, "critic": critic_cls}
70 | ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict)
71 | wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
72 | spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())
73 |
74 | colocated_actor_wg = spawn_wg["actor"]
75 | colocated_critic_wg = spawn_wg["critic"]
76 |
77 | actor_output = colocated_actor_wg.add(data)
78 | critic_output = colocated_critic_wg.sub(data)
79 |
80 | torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
81 | torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)
82 |
83 | ray.shutdown()
84 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_data_transfer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | In this test, we instantiate a data-parallel worker group spanning 8 GPUs
16 | """
17 |
18 | import ray
19 | import tensordict
20 | import torch
21 | from codetiming import Timer
22 | from torch import distributed as dist
23 |
24 | from verl import DataProto
25 | from verl.single_controller.base import Worker
26 | from verl.single_controller.base.decorator import Dispatch, register
27 | from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
28 | from verl.utils.ray_utils import parallel_put
29 |
30 |
31 | @ray.remote
32 | class DummyWorker(Worker):
33 | def __init__(self):
34 | super().__init__()
35 | dist.init_process_group()
36 |
37 | @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
38 | def do_nothing(self, data):
39 | for key in data.batch.keys():
40 | data.batch[key] += 1
41 | if tensordict.__version__ >= "0.5.0":
42 | data.batch = data.batch.consolidate()
43 | return data
44 |
45 |
46 | def test_data_transfer():
47 | ray.init()
48 | # construct resource pool
49 | resource_pool = RayResourcePool([8])
50 | cls_with_init = RayClassWithInitArgs(cls=DummyWorker)
51 | # construct worker group
52 | wg = RayWorkerGroup(resource_pool, cls_with_init)
53 |
54 | # this mirrors a realistic dataset size
55 | batch_size = 4096
56 | seqlen = 32768
57 |
58 | data_dict = {}
59 |
60 | for i in range(2):
61 | data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen))
62 |
63 | data = DataProto.from_dict(tensors=data_dict)
64 |
65 | print(data)
66 |
67 | # we manually split data here and send to each worker
68 | data_list = data.chunk(wg.world_size)
69 |
70 | for i in range(wg.world_size):
71 | # consolidate is necessary
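   | # consolidate() merges the TensorDict's tensors into a single storage, which makes the (cloud)pickle calls below much cheaper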
72 | if tensordict.__version__ >= "0.5.0":
73 | data_list[i].batch = data_list[i].batch.consolidate()
74 |
75 | with Timer(name="ray.pickle", initial_text=True):
76 | for i in range(wg.world_size):
77 | ray.cloudpickle.pickle.dumps(data_list[i])
78 |
79 | with Timer(name="raw.pickle", initial_text=True):
80 | import pickle
81 |
82 | for i in range(wg.world_size):
83 | pickle.dumps(data_list[i])
84 |
85 | # put the chunks into the Ray object store in advance
86 | with Timer(name="put", initial_text=True):
87 | # takes around 40 seconds
88 | data_list_ref = parallel_put(data_list)
89 | # for i in range(wg.world_size):
90 | # data_list[i] = ray.put(data_list[i])
91 |
92 | with Timer(name="launch", initial_text=True):
93 | output_ref = wg.do_nothing(data_list_ref)
94 |
95 | with Timer(name="get", initial_text=True):
96 | # takes around 40 seconds
97 | output_lst = ray.get(output_ref)
98 |
99 | for input_data, output_data in zip(data_list, output_lst):
100 | for key in input_data.batch.keys():
101 | assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
102 | input_data.batch[key],
103 | output_data.batch[key],
104 | key,
105 | )
106 |
107 | ray.shutdown()
108 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_driverfunc_to_worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import ray
18 | import torch
19 | from tensordict import TensorDict
20 |
21 | from verl import DataProto
22 | from verl.single_controller.base.worker import Worker
23 | from verl.single_controller.ray import RayWorkerGroup
24 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool
25 |
26 | os.environ["RAY_DEDUP_LOGS"] = "0"
27 | os.environ["NCCL_DEBUG"] = "WARN"
28 |
29 |
30 | @ray.remote
31 | class ModelActor(Worker):
32 | def __init__(self):
33 | pass
34 |
35 |
36 | class HackSelf:
37 | def __init__(self):
38 | pass
39 |
40 |
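   | # get_aux_metrics takes an explicit `self` so it can run both on workers via
   | # execute_with_func_generator and directly on the driver with a dummy HackSelf instance.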
41 | def get_aux_metrics(self, test_proto):
42 | sequence_ids = test_proto.batch["sequence_ids"]
43 | decode_count = []
44 | for i in range(sequence_ids.size(0)):
45 | decode_count.append(len(sequence_ids[i].tolist()))
46 | ret_proto = DataProto(batch=TensorDict({"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0)))
47 | return ret_proto
48 |
49 |
50 | def test():
51 | # construct model
52 | ray.init()
53 |
54 | # create 2 workers, each hold a GPU
55 | resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")
56 |
57 | class_with_args = RayClassWithInitArgs(cls=ModelActor)
58 | shard_wg = RayWorkerGroup(resource_pool, class_with_args)
59 |
60 | test_bs = 8
61 | test_proto = DataProto(
62 | TensorDict(
63 | {
64 | "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
65 | },
66 | batch_size=test_bs,
67 | ),
68 | meta_info={"query_length": 1536},
69 | )
70 |
71 | # Sharding among different ranks
72 | ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)
73 |
74 | # compare against executing the same function on the driver
75 | hs = HackSelf()
76 | ret_proto2 = get_aux_metrics(hs, test_proto)
77 |
78 | torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])
79 |
80 | ray.shutdown()
81 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_high_level_scheduling_api.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import time
16 |
17 | import ray
18 |
19 | from verl.single_controller.base.worker import Worker
20 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool
21 |
22 |
23 | @ray.remote
24 | class TestActor(Worker):
25 | # TODO: passing *args and **kwargs is bug-prone and not very convincing
26 | def __init__(self, cuda_visible_devices=None) -> None:
27 | super().__init__(cuda_visible_devices)
28 |
29 | def get_node_id(self):
30 | return ray.get_runtime_context().get_node_id()
31 |
32 |
33 | def test():
34 | ray.init()
35 |
36 | # test single-node-no-partition
37 | print("test single-node-no-partition")
38 | resource_pool = RayResourcePool([8], use_gpu=True)
39 |
40 | class_with_args = RayClassWithInitArgs(cls=TestActor)
41 |
42 | print("create actor worker group")
43 | actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
44 | print("create critic worker group")
45 | critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
46 | print("create rm worker group")
47 | rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
48 | print("create ref worker group")
49 | ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")
50 |
51 | assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
52 | assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
53 | assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
54 | assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
55 |
56 | del actor_wg
57 | del critic_wg
58 | del rm_wg
59 | del ref_wg
60 |
61 | [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]
62 | print("wait 5s to remove placemeng_group")
63 | time.sleep(5)
64 | # test single-node-multi-partition
65 |
66 | print("test single-node-multi-partition")
67 | rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
68 | ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
69 | total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)
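   | # merging concatenates the two 4-GPU pools into a single 8-GPU pool (world sizes verified below)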
70 |
71 | assert rm_resource_pool.world_size == 4
72 | assert ref_resource_pool.world_size == 4
73 | assert total_resource_pool.world_size == 8
74 |
75 | actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
76 | critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
77 | rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
78 | ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")
79 |
80 | assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
81 | assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
82 | assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
83 | assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]
84 |
85 | ray.shutdown()
86 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_rvdz.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 |
18 | @ray.remote
19 | class TestWorker:
20 | def __init__(self, rank, world_size, group_name):
21 | self.rank = rank
22 | self.world_size = world_size
23 | self.group_name = group_name
24 | self.communicator = None
25 |
26 | def init(self):
27 | from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray
28 |
29 | self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)
30 |
31 | def test(self):
32 | if self.communicator is None:
33 | return None
34 | return self.communicator.rank_id()
35 |
36 |
37 | def test_rvdz():
38 | ray.init()
39 |
40 | group_name = "test_group"
41 | world_size = 2
42 |
43 | workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)]
44 |
45 | ray.get([worker.init.remote() for worker in workers])
46 |
47 | ranks = ray.get([worker.test.remote() for worker in workers])
48 |
49 | assert ranks == [0, 1], f"expecting [0, 1], got {ranks}"
50 |
51 | ray.shutdown()
52 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_worker_group_basics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | e2e test for verl.single_controller.ray
16 | """
17 |
18 | import ray
19 | import torch
20 |
21 | from verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register
22 | from verl.single_controller.base.worker import Worker
23 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
24 |
25 |
26 | def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
27 | """
28 | Assume each input is a list of length 2. Duplicate the inputs in an interleaved fashion so that every worker receives one element.
29 | """
30 | for arg in args:
31 | assert len(arg) == 2
32 | for i in range(worker_group.world_size - 2):
33 | arg.append(arg[i % 2])
34 | for k, v in kwargs.items():
35 | assert len(v) == 2
36 | for i in range(worker_group.world_size - 2):
37 | v.append(v[i % 2])
38 | return args, kwargs
39 |
40 |
41 | @ray.remote
42 | class TestActor(Worker):
43 | # TODO: passing *args and **kwargs is bug-prone and not very convincing
44 | def __init__(self, x) -> None:
45 | super().__init__()
46 | self._x = x
47 |
48 | def foo(self, y):
49 | return self._x + y
50 |
51 | @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
52 | def foo_rank_zero(self, x, y):
53 | return self._x + y + x
54 |
55 | @register(Dispatch.ONE_TO_ALL, blocking=False)
56 | def foo_one_to_all(self, x, y):
57 | return self._x + y + x
58 |
59 | @register(Dispatch.ALL_TO_ALL, blocking=False)
60 | def foo_all_to_all(self, x, y):
61 | return self._x + y + x
62 |
63 | @register(dispatch_mode={"dispatch_fn": two_to_all_dispatch_fn, "collect_fn": collect_all_to_all})
64 | def foo_custom(self, x, y):
65 | return self._x + y + x
66 |
67 |
68 | @ray.remote(num_gpus=0.1)
69 | def remote_call_wg(worker_names):
70 | class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
71 | worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None)
72 | print(worker_group.worker_names)
73 |
74 | output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])
75 | assert output_ref == [8, 10, 8, 10]
76 |
77 | output_ref = worker_group.foo_rank_zero(x=1, y=2)
78 | assert output_ref == 5
79 |
80 | return worker_group.worker_names
81 |
82 |
83 | def add_one(data):
84 | data = data.to("cuda")
85 | data += 1
86 | data = data.to("cpu")
87 | return data
88 |
89 |
90 | def test_basics():
91 | ray.init(num_cpus=100)
92 |
93 | # create 4 workers, each hold a GPU
94 | resource_pool = RayResourcePool([4], use_gpu=True)
95 | class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
96 |
97 | worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic")
98 |
99 | print(worker_group.worker_names)
100 |
101 | # this will wait for all the results
102 | output = worker_group.execute_all_sync("foo", y=3)
103 | assert output == [5, 5, 5, 5]
104 |
105 | # this is a list of object references. It won't block.
106 | output_ref = worker_group.execute_all_async("foo", y=4)
107 | print(output_ref)
108 |
109 | assert ray.get(output_ref) == [6, 6, 6, 6]
110 |
111 | output_ref = worker_group.foo_one_to_all(x=1, y=2)
112 | assert ray.get(output_ref) == [5, 5, 5, 5]
113 |
114 | output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8])
115 | assert ray.get(output_ref) == [8, 10, 12, 14]
116 |
117 | print(ray.get(remote_call_wg.remote(worker_group.worker_names)))
118 |
119 | output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2))
120 | torch.testing.assert_close(output, torch.ones(2, 2) + 1)
121 |
122 | ray.shutdown()
123 |
124 |
125 | if __name__ == "__main__":
126 | test_basics()
127 |
--------------------------------------------------------------------------------
/tests/ray_gpu/test_worker_group_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | os.environ["RAY_DEDUP_LOGS"] = "0"
18 | os.environ["NCCL_DEBUG"] = "WARN"
19 |
20 | import ray
21 | import torch
22 | import torch.distributed
23 |
24 | from verl.single_controller.base.worker import Worker
25 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
26 |
27 |
28 | @ray.remote
29 | class TestAllGatherActor(Worker):
30 | def __init__(self, size) -> None:
31 | super().__init__()
32 | self.size = size
33 |
34 | def init(self):
35 | torch.distributed.init_process_group()
36 | self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda")
37 | self.tensor += self.rank
38 |
39 | def all_gather(self):
40 | world_size = self._world_size
41 | output = torch.zeros(size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device)
42 | torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
43 | return output
44 |
45 |
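   | # V2 differs from the actor above only in that it sets up the process group and its tensor
   | # inside __init__, so the test does not need a separate execute_all_sync("init") call.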
46 | @ray.remote
47 | class TestAllGatherActorV2(Worker):
48 | def __init__(self, size) -> None:
49 | super().__init__()
50 | self.size = size
51 |
52 | torch.distributed.init_process_group()
53 | self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda")
54 | self.tensor += self.rank
55 |
56 | def all_gather(self):
57 | world_size = self._world_size
58 | output = torch.zeros(size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device)
59 | torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
60 | return output
61 |
62 |
63 | def test_all_gather_torch():
64 | """
65 | In this test, we instantiate 4 workers (one GPU each) in a group and test all_gather
66 | """
67 | ray.init()
68 |
69 | # create 4 workers, each hold a GPU
70 | resource_pool = RayResourcePool([4], use_gpu=True)
71 | class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)
72 |
73 | worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")
74 |
75 | worker_group.execute_all_sync("init")
76 | output = worker_group.execute_all_sync("all_gather")
77 | for i in range(1, len(output)):
78 | assert torch.all(output[i] == output[0])
79 |
80 | output = output[0].cpu()
81 | print(output)
82 | assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))
83 |
84 | ray.shutdown()
85 |
86 |
87 | def test_all_gather_torch_v2():
88 | """
89 | In this test, we instantiate 4 GPUs in a group and test the all_gather
90 | """
91 | ray.init()
92 |
93 | # create 4 workers, each hold a GPU
94 | resource_pool = RayResourcePool([4], use_gpu=True)
95 | class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)
96 |
97 | worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")
98 |
99 | output = worker_group.execute_all_sync("all_gather")
100 | for i in range(1, len(output)):
101 | assert torch.all(output[i] == output[0])
102 |
103 | output = output[0].cpu()
104 | print(output)
105 | assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))
106 |
107 | ray.shutdown()
108 |
--------------------------------------------------------------------------------
/tests/sanity/check_license.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from argparse import ArgumentParser
15 | from pathlib import Path
16 |
17 | license_head_miro = "Copyright 2025 MiroMind Team"
18 | license_head_bytedance = "Copyright 2024 Bytedance Ltd. and/or its affiliates"
19 | license_head_bytedance_25 = "Copyright 2025 Bytedance Ltd. and/or its affiliates"
20 | # Add custom license headers below
21 | license_head_prime = "Copyright 2024 PRIME team and/or its affiliates"
22 | license_head_individual = "Copyright 2025 Individual Contributor:"
23 | license_head_sglang = "Copyright 2023-2024 SGLang Team"
24 | license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates"
25 | license_headers = [
26 | license_head_miro,
27 | license_head_bytedance,
28 | license_head_bytedance_25,
29 | license_head_prime,
30 | license_head_individual,
31 | license_head_sglang,
32 | license_head_modelbest,
33 | ]
34 |
35 |
36 | if __name__ == "__main__":
37 | parser = ArgumentParser()
38 | parser.add_argument("--directory", "-d", required=True, type=str)
39 | args = parser.parse_args()
40 | directory_in_str = args.directory
41 |
42 | pathlist = Path(directory_in_str).glob("**/*.py")
43 | for path in pathlist:
44 | # path is a Path object; convert it to a string
45 | path_in_str = str(path.absolute())
46 | print(path_in_str)
47 | with open(path_in_str, encoding="utf-8") as f:
48 | file_content = f.read()
49 |
50 | has_license = False
51 | for lh in license_headers:
52 | if lh in file_content:
53 | has_license = True
54 | break
55 | assert has_license, f"file {path_in_str} does not contain a license header"
56 |
--------------------------------------------------------------------------------
/tests/sanity/check_pr_title.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 | import sys
18 |
19 | # Get PR title from environment
20 | pr_title = os.environ.get("PR_TITLE", "").strip()
21 |
22 | # Define rules
23 | allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"]
24 | allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"]
25 | allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"]
26 | allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt"]
27 | allowed_types = ["feat", "fix", "doc", "refactor", "chore"]
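   | # Example of a valid title: "[fsdp, trainer] feat: support new rollout backend"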
28 |
29 | # Build dynamic regex pattern
30 | re_modules_pattern = re.compile(r"^\[([a-z_,\s]+)\]", re.IGNORECASE)
31 | re_modules = re_modules_pattern.match(pr_title)
32 | if not re_modules:
33 | print(f"❌ Invalid PR title: '{pr_title}'")
34 | print("Expected format: [module] type: description")
35 | print(f"Allowed modules: {', '.join(allowed_modules)}")
36 | sys.exit(1)
37 | else:
38 | modules = re.findall(r"[a-z_]+", re_modules.group(1).lower())
39 | if not all(module in allowed_modules for module in modules):
40 | invalid_modules = [module for module in modules if module not in allowed_modules]
41 | print(f"❌ Invalid modules: {', '.join(invalid_modules)}")
42 | print(f"Allowed modules: {', '.join(allowed_modules)}")
43 | sys.exit(1)
44 |
45 | types_pattern = "|".join(re.escape(t) for t in allowed_types)
46 | re_types_pattern = re.compile(rf"^\[[a-z_,\s]+\]\s+({types_pattern}):\s+.+$", re.IGNORECASE)
47 | match = re_types_pattern.match(pr_title)
48 |
49 | if not match:
50 | print(f"❌ Invalid PR title: '{pr_title}'")
51 | print("Expected format: [module] type: description")
52 | print(f"Allowed types: {', '.join(allowed_types)}")
53 | sys.exit(1)
54 |
55 | change_type = match.group(1).lower()
56 |
57 | print(f"✅ PR title is valid: {pr_title}, modules: {modules}, type: {change_type}")
58 |
--------------------------------------------------------------------------------
/tests/sanity/test_config_docs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from pathlib import Path
17 |
18 | def validate_yaml_format(yaml_lines):
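   | # Enforced rules: every key needs a comment on the line directly above it, must not carry an
   | # inline comment after its value, and must be followed by a blank line.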
19 | errors = []
20 | i = 0
21 | prev_key_indent = None
22 |
23 | while i < len(yaml_lines):
24 | line = yaml_lines[i]
25 | stripped = line.strip()
26 |
27 | # Skip empty lines
28 | if stripped == "":
29 | i += 1
30 | continue
31 |
32 | # Match YAML keys like "field:" or "field: value"
33 | key_match = re.match(r'^(\s*)([a-zA-Z0-9_]+):', line)
34 | if key_match:
35 | indent = key_match.group(1)
36 | is_section_header = (i + 1 < len(yaml_lines) and yaml_lines[i + 1].strip() == "")
37 |
38 | # Check if there's a comment above
39 | if i == 0 or not yaml_lines[i - 1].strip().startswith("#"):
40 | errors.append(f"Missing comment above line {i+1}: {line.strip()}")
41 |
42 | # Check for inline comment
43 | if "#" in line and not stripped.startswith("#"):
44 | comment_index = line.index("#")
45 | colon_index = line.index(":")
46 | if comment_index > colon_index:
47 | errors.append(f"Inline comment found on line {i+1}: {line.strip()}")
48 |
49 | # Check for blank line after this key line (unless next is a deeper indent)
50 | if i + 1 < len(yaml_lines):
51 | next_line = yaml_lines[i + 1]
52 | next_stripped = next_line.strip()
53 |
54 | # If the next line is not blank, flag a missing blank line after this key
55 | if next_stripped != "":
56 | errors.append(f"Missing blank line after line {i+1}: {line.strip()}")
57 |
58 | i += 1
59 |
60 | return errors
61 |
62 |
63 | def test_trainer_config_doc():
64 | yaml_path = Path("mirorl/trainer/config/ppo_trainer.yaml") # path to your YAML file
65 | with open(yaml_path, "r") as f:
66 | lines = f.readlines()
67 |
68 | validation_errors = validate_yaml_format(lines)
69 | if validation_errors:
70 | print("YAML documentation format check failed:")
71 | print("Please read the top block of `verl/trainer/config/ppo_trainer.yaml` to see format rules:\n")
72 | for err in validation_errors:
73 | print(" -", err)
74 | else:
75 | print("YAML format check passed ✅")
76 |
--------------------------------------------------------------------------------
/tests/sanity/test_import.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | def test_import():
17 | import verl
18 |
19 | print(verl.__version__)
20 |
21 |
22 | def test_single_controller_import():
23 | import verl.single_controller
24 |
25 | print(verl.single_controller.__version__)
26 |
--------------------------------------------------------------------------------
/tests/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Tests for the trainer module.
16 | """
--------------------------------------------------------------------------------
/tests/trainer/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Tests for the PPO trainer module.
16 | """
17 |
--------------------------------------------------------------------------------
/tests/trainer/ppo/test_core_algos.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import unittest
16 |
17 | import pytest
18 |
19 | import verl.trainer.ppo.core_algos
20 | from verl.trainer.ppo.core_algos import get_adv_estimator_fn, register_adv_est
21 |
22 |
23 | def mock_test_fn():
24 | pass
25 |
26 | class TestRegisterAdvEst(unittest.TestCase):
27 | def setUp(self):
28 | """Clear the registry before each test"""
29 | verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()
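   | # Replace the registry with two mock estimators so that lookups can be checked numerically below.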
30 | verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = {
31 | "gae": lambda x: x * 2,
32 | "vtrace": lambda x: x + 1,
33 | }
34 | self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY
35 |
36 | def tearDown(self) -> None:
37 | verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()
38 | return super().tearDown()
39 |
40 | def test_register_new_function(self):
41 | """Test registering a new function with a string name"""
42 | @register_adv_est("test_estimator")
43 | def test_fn():
44 | pass
45 |
46 | self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY)
47 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn)
48 |
49 | def test_register_with_enum(self):
50 | """Test registering with an enum value (assuming AdvantageEstimator exists)"""
51 | from enum import Enum
52 | class AdvantageEstimator(Enum):
53 | TEST = "test_enum_estimator"
54 |
55 | @register_adv_est(AdvantageEstimator.TEST)
56 | def test_fn():
57 | pass
58 |
59 | self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY)
60 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn)
61 |
62 | def test_duplicate_registration_same_function(self):
63 | """Test that registering the same function twice doesn't raise an error"""
64 | register_adv_est("duplicate_test")(mock_test_fn)
65 | register_adv_est("duplicate_test")(mock_test_fn)
66 |
67 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn)
68 |
69 | def test_duplicate_registration_different_function(self):
70 | """Test that registering different functions with same name raises ValueError"""
71 | @register_adv_est("conflict_test")
72 | def test_fn1():
73 | pass
74 |
75 | with self.assertRaises(ValueError):
76 | @register_adv_est("conflict_test")
77 | def test_fn2():
78 | pass
79 |
80 | def test_decorator_preserves_function(self):
81 | """Test that the decorator returns the original function"""
82 | def test_fn():
83 | return "original"
84 |
85 | decorated = register_adv_est("preserve_test")(test_fn)
86 | self.assertEqual(decorated(), "original")
87 |
88 | def test_multiple_registrations(self):
89 | """Test registering multiple different functions"""
90 | init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY)
91 | @register_adv_est("estimator1")
92 | def fn1():
93 | pass
94 |
95 | @register_adv_est("estimator2")
96 | def fn2():
97 | pass
98 |
99 | self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count)
100 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1)
101 | self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2)
102 |
103 | def test_get_adv_estimator_fn_valid_names(self):
104 | """Test that valid names return the correct function from registry."""
105 | # Test GAE
106 | gae_fn = get_adv_estimator_fn("gae")
107 | assert gae_fn(5) == 10 # 5 * 2 = 10
108 |
109 | # Test Vtrace
110 | vtrace_fn = get_adv_estimator_fn("vtrace")
111 | assert vtrace_fn(5) == 6 # 5 + 1 = 6
112 |
113 | def test_get_adv_estimator_fn_invalid_name(self):
114 | """Test that invalid names raise ValueError."""
115 | with pytest.raises(ValueError) as excinfo:
116 | get_adv_estimator_fn("invalid_name")
117 | assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value)
118 |
119 | def test_get_adv_estimator_fn_case_sensitive(self):
120 | """Test that name lookup is case-sensitive."""
121 | with pytest.raises(ValueError):
122 | get_adv_estimator_fn("GAE") # Different case
123 |
124 |
125 | if __name__ == "__main__":
126 | unittest.main()
--------------------------------------------------------------------------------
/tests/utils/cpu_tests/_test_module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # Test module for import_utils.load_extern_type testing
17 | class TestClass:
18 | """A test class to be imported by load_extern_type"""
19 |
20 | def __init__(self, value=None):
21 | self.value = value or "default"
22 |
23 | def get_value(self):
24 | return self.value
25 |
26 |
27 | TEST_CONSTANT = "test_constant_value"
28 |
29 |
30 | def test_function():
31 | return "test_function_result"
32 |
--------------------------------------------------------------------------------
/tests/utils/cpu_tests/test_fs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from pathlib import Path
17 |
18 | import verl.utils.fs as fs
19 |
20 |
21 | def test_record_and_check_directory_structure(tmp_path):
22 | # Create test directory structure
23 | test_dir = tmp_path / "test_dir"
24 | test_dir.mkdir()
25 | (test_dir / "file1.txt").write_text("test")
26 | (test_dir / "subdir").mkdir()
27 | (test_dir / "subdir" / "file2.txt").write_text("test")
28 |
29 | # Create structure record
30 | record_file = fs._record_directory_structure(test_dir)
31 |
32 | # Verify record file exists
33 | assert os.path.exists(record_file)
34 |
35 | # Initial check should pass
36 | assert fs._check_directory_structure(test_dir, record_file) is True
37 |
38 | # Modify structure and verify check fails
39 | (test_dir / "new_file.txt").write_text("test")
40 | assert fs._check_directory_structure(test_dir, record_file) is False
41 |
42 |
43 | def test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch):
44 | # Mock HDFS dependencies
45 | monkeypatch.setattr(fs, "is_non_local", lambda path: True)
46 |
47 | # the fake copy simulates the transfer by creating parent dirs + an empty file
48 | def fake_copy(src: str, dst: str, *args, **kwargs):
49 | dst_path = Path(dst)
50 | dst_path.parent.mkdir(parents=True, exist_ok=True)
51 | dst_path.write_bytes(b"") # touch an empty file
52 |
53 | monkeypatch.setattr(fs, "copy", fake_copy) # Mock actual HDFS copy
54 |
55 | # Test parameters
56 | test_cache = tmp_path / "cache"
57 | hdfs_path = "hdfs://test/path/file.txt"
58 |
59 | # Test initial copy
60 | local_path = fs.copy_to_local(hdfs_path, cache_dir=test_cache)
61 | expected_path = os.path.join(test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path))
62 | assert local_path == expected_path
63 | assert os.path.exists(local_path)
64 |
65 |
66 | def test_always_recopy_flag(tmp_path, monkeypatch):
67 | # Mock HDFS dependencies
68 | monkeypatch.setattr(fs, "is_non_local", lambda path: True)
69 |
70 | copy_call_count = 0
71 |
72 | def fake_copy(src: str, dst: str, *args, **kwargs):
73 | nonlocal copy_call_count
74 | copy_call_count += 1
75 | dst_path = Path(dst)
76 | dst_path.parent.mkdir(parents=True, exist_ok=True)
77 | dst_path.write_bytes(b"")
78 |
79 | monkeypatch.setattr(fs, "copy", fake_copy) # Mock actual HDFS copy
80 |
81 | test_cache = tmp_path / "cache"
82 | hdfs_path = "hdfs://test/path/file.txt"
83 |
84 | # Initial copy (always_recopy=False)
85 | fs.copy_to_local(hdfs_path, cache_dir=test_cache)
86 | assert copy_call_count == 1
87 |
88 | # Force recopy (always_recopy=True)
89 | fs.copy_to_local(hdfs_path, cache_dir=test_cache, always_recopy=True)
90 | assert copy_call_count == 2
91 |
92 | # Subsequent normal call (always_recopy=False)
93 | fs.copy_to_local(hdfs_path, cache_dir=test_cache)
94 | assert copy_call_count == 2 # Should not increment
95 |
--------------------------------------------------------------------------------
/tests/utils/cpu_tests/test_import_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import pytest
18 |
19 | from verl.utils.import_utils import load_extern_type
20 |
21 | # Path to the test module
22 | TEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), "_test_module.py")
23 |
24 |
25 | def test_load_extern_type_class():
26 | """Test loading a class from an external file"""
27 | TestClass = load_extern_type(TEST_MODULE_PATH, "TestClass")
28 |
29 | # Verify the class was loaded correctly
30 | assert TestClass is not None
31 | assert TestClass.__name__ == "TestClass"
32 |
33 | # Test instantiation and functionality
34 | instance = TestClass()
35 | assert instance.value == "default"
36 |
37 | # Test with a custom value
38 | custom_instance = TestClass("custom")
39 | assert custom_instance.get_value() == "custom"
40 |
41 |
42 | def test_load_extern_type_function():
43 | """Test loading a function from an external file"""
44 | test_function = load_extern_type(TEST_MODULE_PATH, "test_function")
45 |
46 | # Verify the function was loaded correctly
47 | assert test_function is not None
48 | assert callable(test_function)
49 |
50 | # Test function execution
51 | result = test_function()
52 | assert result == "test_function_result"
53 |
54 |
55 | def test_load_extern_type_constant():
56 | """Test loading a constant from an external file"""
57 | constant = load_extern_type(TEST_MODULE_PATH, "TEST_CONSTANT")
58 |
59 | # Verify the constant was loaded correctly
60 | assert constant is not None
61 | assert constant == "test_constant_value"
62 |
63 |
64 | def test_load_extern_type_nonexistent_file():
65 | """Test behavior when file doesn't exist"""
66 | with pytest.raises(FileNotFoundError):
67 | load_extern_type("/nonexistent/path.py", "SomeType")
68 |
69 |
70 | def test_load_extern_type_nonexistent_type():
71 | """Test behavior when type doesn't exist in the file"""
72 | with pytest.raises(AttributeError):
73 | load_extern_type(TEST_MODULE_PATH, "NonExistentType")
74 |
75 |
76 | def test_load_extern_type_none_path():
77 | """Test behavior when file path is None"""
78 | result = load_extern_type(None, "SomeType")
79 | assert result is None
80 |
81 |
82 | def test_load_extern_type_invalid_module():
83 | """Test behavior when module has syntax errors"""
84 | # Create a temporary file with syntax errors
85 | import tempfile
86 |
87 | with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp_file:
88 | temp_file.write("This is not valid Python syntax :")
89 | temp_path = temp_file.name
90 |
91 | try:
92 | with pytest.raises(RuntimeError):
93 | load_extern_type(temp_path, "SomeType")
94 | finally:
95 | # Clean up the temporary file
96 | if os.path.exists(temp_path):
97 | os.remove(temp_path)
98 |
--------------------------------------------------------------------------------
/tests/utils/cpu_tests/test_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from types import SimpleNamespace # Or use a mock object library
16 |
17 | import pytest
18 |
19 | from verl.utils.model import update_model_config
20 |
21 |
22 | # Parametrize with different override scenarios
23 | @pytest.mark.parametrize(
24 | "override_kwargs",
25 | [
26 | {"param_a": 5, "new_param": "plain_added"},
27 | {"param_a": 2, "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}},
28 | ],
29 | )
30 | def test_update_model_config(override_kwargs):
31 | """
32 | Tests that update_model_config correctly updates attributes,
33 | handling both plain and nested overrides via parametrization.
34 | """
35 | # Create a fresh mock config object for each test case
36 | mock_config = SimpleNamespace(param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me")
37 | # Apply the updates using the parametrized override_kwargs
38 | update_model_config(mock_config, override_kwargs)
39 |
40 | # Assertions to check if the config was updated correctly
41 | if "nested_params" in override_kwargs: # Case 2: Nested override
42 | override_nested = override_kwargs["nested_params"]
43 | assert mock_config.nested_params.sub_param_x == override_nested["sub_param_x"], "Nested sub_param_x mismatch"
44 | assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged"
45 | assert hasattr(mock_config.nested_params, "sub_param_z"), "Expected nested sub_param_z to be added"
46 | assert mock_config.nested_params.sub_param_z == override_nested["sub_param_z"], "Value of sub_param_z mismatch"
47 | else: # Case 1: Plain override (nested params untouched)
48 | assert mock_config.nested_params.sub_param_x == "original_x", "Nested sub_param_x should be unchanged"
49 | assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged"
50 | assert not hasattr(mock_config.nested_params, "sub_param_z"), "Nested sub_param_z should not exist"
51 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import shutil
16 |
17 | import torch
18 | import torch.distributed
19 | from torch.distributed import init_device_mesh
20 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
21 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
22 | from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
23 |
24 | from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
25 | from verl.utils.distributed import initialize_global_process_group
26 | from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2
27 |
28 | ckpt_path = "./ci_checkpoints"
29 |
30 |
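   | # Strategy: run an update and save a checkpoint, run a second update, then reload the checkpoint
   | # and repeat the second update; the logits from both paths must match bit-exactly.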
31 | def test_fsdp_ckpt(strategy="fsdp"):
32 | assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
33 | local_rank, rank, world_size = initialize_global_process_group()
34 | device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))
35 |
36 | model_name = "Qwen/Qwen2.5-0.5B-Instruct"
37 | config = Qwen2Config(num_hidden_layers=1)
38 |
39 | with torch.device("cuda"):
40 | model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
41 | model = model.to(device="cuda")
42 |
43 | # Wrap model with FSDP
44 | if strategy == "fsdp":
45 | mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
46 |
47 | model = FSDP(
48 | model,
49 | use_orig_params=False,
50 | device_id=torch.cuda.current_device(),
51 | sharding_strategy=ShardingStrategy.FULL_SHARD,
52 | mixed_precision=mixed_precision,
53 | device_mesh=device_mesh,
54 | )
55 | else:
56 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True)
57 | fsdp_kwargs = {
58 | "mesh": device_mesh,
59 | "mp_policy": mp_policy,
60 | }
61 | apply_fsdp2(model, fsdp_kwargs, {})
62 |
63 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
64 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
65 |
66 | # Create checkpoint manager
67 | tokenizer = AutoTokenizer.from_pretrained(model_name)
68 | checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer)
69 |
70 | # Generate sample input
71 | batch_size = 2
72 | seq_len = 32
73 | vocab_size = 32000
74 | # First input for initial update
75 | input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
76 | attention_mask1 = torch.ones_like(input_ids1)
77 |
78 | # Second input for verification
79 | input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
80 | attention_mask2 = torch.ones_like(input_ids2)
81 |
82 | # Step 1: Initial update and save checkpoint
83 | outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)
84 | loss1 = outputs1.logits.mean()
85 | loss1.backward()
86 | optimizer.step()
87 | lr_scheduler.step()
88 | optimizer.zero_grad()
89 |
90 | # Save checkpoint after first update
91 | checkpoint_path = os.path.join(ckpt_path, "checkpoint")
92 | checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)
93 |
94 | # Step 2: Second update and forward pass
95 | outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)
96 | loss2 = outputs2.logits.mean()
97 | loss2.backward()
98 | optimizer.step()
99 | lr_scheduler.step()
100 | optimizer.zero_grad()
101 |
102 | # Record logits after second update
103 | with torch.no_grad():
104 | logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits
105 |
106 | # Step 3: Load checkpoint and repeat second update
107 | checkpoint_manager.load_checkpoint(checkpoint_path)
108 |
109 | # Repeat the second update with same input
110 | outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)
111 | loss3 = outputs3.logits.mean()
112 | loss3.backward()
113 | optimizer.step()
114 | lr_scheduler.step()
115 | optimizer.zero_grad()
116 |
117 | # Record logits after loading the checkpoint and repeating the update
118 | with torch.no_grad():
119 | logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits
120 |
121 | # Step 4: Verify outputs match
122 | torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0)
123 | print("Checkpoint save/load test passed!")
124 |
125 | # Cleanup
126 | shutil.rmtree(ckpt_path, ignore_errors=True)
127 | torch.distributed.barrier()
128 | torch.distributed.destroy_process_group()
129 |
130 |
131 | if __name__ == "__main__":
132 | strategy = os.environ.get("STRATEGY", "fsdp")
133 | test_fsdp_ckpt(strategy=strategy)
134 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/dataset/test_rl_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | import torch
17 | from omegaconf import OmegaConf
18 | from torch.utils.data import DataLoader
19 |
20 |
21 | def get_gsm8k_data():
22 | # prepare test dataset
23 | local_folder = os.path.expanduser("~/verl-data/gsm8k/")
24 | local_path = os.path.join(local_folder, "train.parquet")
25 | os.makedirs(local_folder, exist_ok=True)
26 | return local_path
27 |
28 |
29 | def test_rl_dataset():
30 | from verl.utils import hf_tokenizer
31 | from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
32 |
33 | tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
34 | local_path = get_gsm8k_data()
35 | config = OmegaConf.create(
36 | {
37 | "prompt_key": "prompt",
38 | "max_prompt_length": 256,
39 | "filter_overlong_prompts": True,
40 | "filter_overlong_prompts_workers": 2,
41 | }
42 | )
43 | dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config)
44 |
45 | dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
46 |
47 | a = next(iter(dataloader))
48 |
49 | from verl import DataProto
50 |
51 | tensors = {}
52 | non_tensors = {}
53 |
54 | for key, val in a.items():
55 | if isinstance(val, torch.Tensor):
56 | tensors[key] = val
57 | else:
58 | non_tensors[key] = val
59 |
60 | data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
61 | assert "input_ids" in data_proto.batch
62 |
63 | data = dataset[0]["input_ids"]
64 | output = tokenizer.batch_decode([data])[0]
65 | print(f"type: type{output}")
66 | print(f"\n\noutput: {output}")
67 |
68 |
69 | def test_image_rl_data():
70 | from verl.utils import hf_processor, hf_tokenizer
71 | from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
72 |
73 | tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct")
74 | processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct")
75 | config = OmegaConf.create(
76 | {
77 | "prompt_key": "prompt",
78 | "max_prompt_length": 1024,
79 | "filter_overlong_prompts": True,
80 | "filter_overlong_prompts_workers": 2,
81 | }
82 | )
83 | dataset = RLHFDataset(
84 | data_files=os.path.expanduser("~/data/geo3k/train.parquet"),
85 | tokenizer=tokenizer,
86 | config=config,
87 | processor=processor,
88 | )
89 |
90 | dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
91 |
92 | a = next(iter(dataloader))
93 |
94 | from verl import DataProto
95 |
96 | tensors = {}
97 | non_tensors = {}
98 |
99 | for key, val in a.items():
100 | if isinstance(val, torch.Tensor):
101 | tensors[key] = val
102 | else:
103 | non_tensors[key] = val
104 |
105 | data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
106 |
107 | assert "multi_modal_data" in data_proto.non_tensor_batch
108 | assert "multi_modal_inputs" in data_proto.non_tensor_batch
109 |
110 | data = dataset[0]["input_ids"]
111 | output = tokenizer.batch_decode([data])[0]
112 | print(f"type: type{output}")
113 | print(f"\n\noutput: {output}")
114 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/dataset/test_rm_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from verl.utils import hf_tokenizer
17 | from verl.utils.dataset.rm_dataset import RMDataset
18 |
19 |
20 | def get_rm_data():
21 | # prepare test dataset
22 | local_folder = os.path.expanduser("~/verl-data/full_hh_rlhf/rm/")
23 | local_path = os.path.join(local_folder, "test.parquet")
24 | os.makedirs(local_folder, exist_ok=True)
25 | return local_path
26 |
27 |
28 | def test_rm_dataset():
29 | tokenizer = hf_tokenizer("facebook/opt-1.3b")
30 | local_path = get_rm_data()
31 | dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512)
32 | data = dataset[0]["input_ids"]
33 | output = tokenizer.batch_decode(data)
34 | assert len(output) > 1
35 | assert isinstance(output[0], str)
36 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/dataset/test_sft_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from verl.utils import hf_tokenizer
17 | from verl.utils.dataset.sft_dataset import SFTDataset
18 |
19 |
20 | def get_gsm8k_data():
21 | # prepare test dataset
22 | local_folder = os.path.expanduser("~/verl-data/gsm8k/")
23 | local_path = os.path.join(local_folder, "train.parquet")
24 | return local_path
25 |
26 |
27 | def test_sft_cot_dataset():
28 | tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
29 | local_path = get_gsm8k_data()
30 | from omegaconf import OmegaConf
31 |
32 | dataset = SFTDataset(
33 | parquet_files=local_path,
34 | tokenizer=tokenizer,
35 | config=OmegaConf.create(
36 | {
37 | "prompt_key": "prompt",
38 | "prompt_dict_keys": ["content"],
39 | "response_key": "extra_info",
40 | "response_dict_keys": ["answer"],
41 | "max_length": 512,
42 | }
43 | ),
44 | )
45 |
46 | data = dataset[0]["input_ids"]
47 | output = tokenizer.batch_decode([data])[0]
48 | assert len(output) > 1
49 | assert isinstance(output, str)
50 |
51 |
52 | def test_sft_dataset():
53 | tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
54 | local_path = get_gsm8k_data()
55 | from omegaconf import OmegaConf
56 |
57 | dataset = SFTDataset(
58 | parquet_files=local_path,
59 | tokenizer=tokenizer,
60 | config=OmegaConf.create(
61 | {
62 | "prompt_key": "extra_info",
63 | "prompt_dict_keys": ["question"],
64 | "response_key": "extra_info",
65 | "response_dict_keys": ["answer"],
66 | "max_length": 512,
67 | }
68 | ),
69 | )
70 |
71 | data = dataset[0]["input_ids"]
72 | output = tokenizer.batch_decode([data])[0]
73 | assert len(output) > 1
74 | assert isinstance(output, str)
75 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/megatron/test_pipeline_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from verl.utils.megatron.pipeline_parallel import make_batch_generator
16 |
17 |
18 | def test_make_batch_generator_no_vpp():
19 | batches = [1, 2, 3]
20 | vpp_size = 1
21 | generator = make_batch_generator(batches, vpp_size)
22 | assert list(generator) == batches
23 |
24 |
25 | def test_make_batch_generator_with_vpp():
26 | batches = [{"data": 1}, {"data": 2}]
27 | vpp_size = 2
28 | generators = make_batch_generator(batches, vpp_size)
29 | assert isinstance(generators, list)
30 | assert len(generators) == vpp_size
31 |
32 | # Check each generator yields the original batches
33 | for gen in generators:
34 | assert list(gen) == batches
35 |
36 |
37 | def test_make_batch_generator_empty():
38 | batches = []
39 | vpp_size = 1
40 | generator = make_batch_generator(batches, vpp_size)
41 | assert list(generator) == []
42 |
43 | vpp_size = 3
44 | generators = make_batch_generator(batches, vpp_size)
45 | assert len(generators) == vpp_size
46 | for gen in generators:
47 | assert list(gen) == []
48 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/test_activation_offload.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import shutil
16 |
17 | import pytest
18 | import torch
19 | import torch.distributed
20 | import torch.multiprocessing as mp
21 | from torch.distributed import init_device_mesh
22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
23 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
24 | from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
25 |
26 | from verl.utils.activation_offload import enable_activation_offloading
27 | from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
28 | from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy
29 |
30 | ckpt_path = "./ci_checkpoints"
31 |
32 | def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"):
33 | torch.cuda.set_device(rank)
34 | torch.distributed.init_process_group(
35 | backend="nccl",
36 | init_method=f"file://{rendezvous_file}",
37 | rank=rank,
38 | world_size=world_size,
39 | )
40 | device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))
41 |
42 | model_name = "Qwen/Qwen2.5-0.5B-Instruct"
43 | config = Qwen2Config(num_hidden_layers=4)
44 |
45 | with torch.device("cuda"):
46 | model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
47 | model = model.to(device="cuda")
48 |
49 | # Wrap model with FSDP
50 | mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
51 |
52 | if strategy == "fsdp":
53 | model = FSDP(model, use_orig_params=False, device_id=torch.cuda.current_device(), sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mixed_precision, device_mesh=device_mesh, auto_wrap_policy=get_fsdp_wrap_policy(module=model))
54 | else:
55 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True)
56 | fsdp_kwargs = {
57 | "mesh": device_mesh,
58 | "mp_policy": mp_policy,
59 | }
60 | apply_fsdp2(model, fsdp_kwargs, {})
61 |
62 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
63 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
64 |
65 | # Create checkpoint manager
66 | tokenizer = AutoTokenizer.from_pretrained(model_name)
67 | checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer)
68 |
69 | # Generate sample input
70 | batch_size = 2
71 | seq_len = 32
72 | vocab_size = 32000
73 | # First input for initial update
74 | input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
75 | attention_mask1 = torch.ones_like(input_ids1)
76 |
77 | # Second input for verification
78 | input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
79 | attention_mask2 = torch.ones_like(input_ids2)
80 |
81 | # Step 1: Initial update and save checkpoint
82 | outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)
83 | loss1 = outputs1.logits.mean()
84 | loss1.backward()
85 | optimizer.step()
86 | lr_scheduler.step()
87 | optimizer.zero_grad()
88 |
89 | # Save checkpoint after first update
90 | checkpoint_path = os.path.join(ckpt_path, "checkpoint")
91 | checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)
92 |
93 | # Step 2: Second update and forward pass
94 | outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)
95 | loss2 = outputs2.logits.mean()
96 | loss2.backward()
97 | optimizer.step()
98 | lr_scheduler.step()
99 | optimizer.zero_grad()
100 |
101 | # Record logits after second update
102 | with torch.no_grad():
103 | logits_without_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits
104 |
105 | # Step 3: wrap module with activation offloading and load checkpoint
106 | enable_activation_offloading(model, "fsdp")
107 | checkpoint_manager.load_checkpoint(checkpoint_path)
108 |
109 | # Step 4: Repeat the second update with same input
110 | outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)
111 | loss3 = outputs3.logits.mean()
112 | loss3.backward()
113 | optimizer.step()
114 | lr_scheduler.step()
115 | optimizer.zero_grad()
116 |
118 | # Record logits after loading the checkpoint and repeating the update
118 | with torch.no_grad():
119 | logits_with_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits
120 |
121 | # Step 5: Verify outputs match
122 | torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0)
123 | print(f"Activaiton offloading for {strategy} test passed on {world_size} GPUs!")
124 |
125 | # Cleanup
126 | shutil.rmtree(ckpt_path, ignore_errors=True)
127 | torch.distributed.barrier()
128 | torch.distributed.destroy_process_group()
129 |
130 |
131 | @pytest.mark.parametrize("world_size", (2, 4))
132 | @pytest.mark.parametrize("strategy", ("fsdp", "fsdp2"))
133 | def test_activation_offloading(world_size, strategy, tmp_path):
134 | rendezvous_file = str(tmp_path / "rdzv_file")
135 | os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
136 |
137 | mp.spawn(
138 | fn=_fsdp_activation_offloading_test,
139 | args=(world_size, rendezvous_file, strategy),
140 | nprocs=world_size,
141 | join=True,
142 | )
143 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/test_flops_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import pytest
18 |
19 | from verl.utils.flops_counter import FlopsCounter
20 |
21 | VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"}
22 |
23 |
24 | class Config:
25 | def __init__(self, config_dict):
26 | for key, value in config_dict.items():
27 | setattr(self, key, value)
28 |
29 |
30 | CONFIG = {
31 | "llama": {
32 | "config": { # llama2-7B
33 | "model_type": "llama",
34 | "vocab_size": 32000,
35 | "hidden_size": 4096,
36 | "intermediate_size": 11008,
37 | "num_hidden_layers": 32,
38 | "num_attention_heads": 32,
39 | "num_key_value_heads": 32,
40 | },
41 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
42 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim
43 | # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*32*4096
44 | # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*32*4096
45 | "expected_flops_tuple": (153555818250240 / 1e12, 575955114393600 / 1e12),
46 | },
47 | "qwen2": {
48 | "config": { # Qwen/Qwen2.5-7B-Instruct
49 | "model_type": "qwen2",
50 | "vocab_size": 152064,
51 | "hidden_size": 3584,
52 | "intermediate_size": 18944,
53 | "num_hidden_layers": 28,
54 | "num_attention_heads": 28,
55 | "num_key_value_heads": 4,
56 | },
57 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
58 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim
59 | # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*28*3584
60 | # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*28*3584
61 | "expected_flops_tuple": (170388331954176 / 1e12, 622070178250752 / 1e12),
62 | },
63 | "qwen3": {
64 | "config": { # Qwen/Qwen3-8B
65 | "model_type": "qwen3",
66 | "vocab_size": 151936,
67 | "hidden_size": 4096,
68 | "intermediate_size": 12288,
69 | "num_hidden_layers": 36,
70 | "num_attention_heads": 32,
71 | "num_key_value_heads": 8,
72 | "head_dim": 128,
73 | },
74 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
75 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim
76 | # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*36*128*32
77 | # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*36*128*32
78 | "expected_flops_tuple": (185867930959872 / 1e12, 692924253732864 / 1e12),
79 | },
80 | "qwen3_moe": {
81 | "config": { # Qwen/Qwen3-30B-A3B-Base
82 | "model_type": "qwen3_moe",
83 | "hidden_size": 2048,
84 | "vocab_size": 151936,
85 | "num_hidden_layers": 48,
86 | "num_key_value_heads": 4,
87 | "num_attention_heads": 32,
88 | "head_dim": 128,
89 | "moe_intermediate_size": 768,
90 | "num_experts_per_tok": 8,
91 | "num_experts": 128,
92 | },
93 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
94 | # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 + hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim
95 | # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*48*128*32
96 | # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*48*128*32
97 | "expected_flops_tuple": (85087060230144 / 1e12, 365944098521088 / 1e12),
98 | },
99 | "deepseek_v3": {
100 | "config": { # deepseek-ai/DeepSeek-Prover-V2-671B
101 | "model_type": "deepseek_v3",
102 | "hidden_size": 7168,
103 | "vocab_size": 129280,
104 | "moe_intermediate_size": 2048,
105 | "num_hidden_layers": 61,
106 | "first_k_dense_replace": 3,
107 | "num_attention_heads": 128,
108 | "n_routed_experts": 256,
109 | "num_experts_per_tok": 8,
110 | "n_shared_experts": 1,
111 | "kv_lora_rank": 512,
112 | "qk_rope_head_dim": 64,
113 | "v_head_dim": 128,
114 | "intermediate_size": 18432,
115 | "qk_nope_head_dim": 128,
116 | "q_lora_rank": 1536,
117 | },
118 | "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
119 | # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280
120 | # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*61*192*128
121 | # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*61*192*128
122 | "expected_flops_tuple": (906535995703296 / 1e12, 3674028304760832 / 1e12),
123 | },
124 | }
125 |
126 |
127 | @pytest.mark.parametrize(
128 | "config_type",
129 | ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"],
130 | )
131 | def test_flops_counter(config_type: str):
132 | test_config = CONFIG[config_type]
133 | config = Config(test_config["config"])
134 | flops_counter = FlopsCounter(config)
135 | for batch_seqlens, expected_flops in zip(test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"]):
136 | # set delta time to 1 to get the flops
137 | counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1)
138 | print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}")
139 | assert math.isclose(counted_flops, expected_flops), f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}"
140 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/test_seqlen_balancing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.distributed as dist
17 | import torch.multiprocessing as mp
18 |
19 | from verl import DataProto
20 | from verl.utils.model import create_random_mask
21 | from verl.utils.seqlen_balancing import ceildiv, get_reverse_idx, rearrange_micro_batches
22 |
23 |
24 | def test_seqlen_balancing():
25 | input_ids = torch.randint(low=0, high=10, size=(20, 100))
26 |
27 | attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5)
28 | data = {"input_ids": input_ids, "attention_mask": attention_mask}
29 | dataproto = DataProto.from_single_dict(data)
30 | micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300)
31 | batch = torch.cat(micro_batches)
32 | micro_bsz_idx = []
33 | for idx in micro_bsz_idx_lst:
34 | micro_bsz_idx.extend(idx)
35 | reverse_idx_map = get_reverse_idx(micro_bsz_idx)
36 | reverse_idx_map = torch.tensor(reverse_idx_map)
37 | new_batch = batch[reverse_idx_map]
38 | torch.testing.assert_close(new_batch, dataproto.batch)
39 |
40 |
41 | def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):
42 | # 1) init process group & CUDA
43 | torch.cuda.set_device(rank)
44 | dist.init_process_group(
45 | backend="nccl",
46 | init_method=init_method,
47 | world_size=world_size,
48 | rank=rank,
49 | )
50 |
51 | # 2) build a small random batch (each rank different length to force mismatch)
52 | torch.manual_seed(42 + rank)
53 | input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f"cuda:{rank}")
54 | attention_mask = create_random_mask(
55 | input_ids=input_ids,
56 | max_ratio_of_left_padding=0.1,
57 | max_ratio_of_valid_token=0.9,
58 | min_ratio_of_valid_token=0.5,
59 | )
60 | dp = {"input_ids": input_ids, "attention_mask": attention_mask}
61 | proto = DataProto.from_single_dict(dp)
62 | batch = proto.batch
63 |
64 | # 3) call rearrange_micro_batches with one of the two params under test
65 | micros, idx_lst = rearrange_micro_batches(
66 | batch,
67 | max_token_len=max_token_len,
68 | dp_group=dist.group.WORLD,
69 | same_micro_num_in_dp=use_same_dp,
70 | min_num_micro_batch=min_mb,
71 | )
72 |
73 | # 4) check the enforced counts
74 | seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
75 | total_seqlen = seq_len_effective.sum().item()
76 | local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))
77 |
78 | if min_mb is not None:
79 | expected = max(local, min_mb)
80 | assert len(micros) == expected
81 | if use_same_dp:
82 | # gather all local_counts
83 | counts = [torch.zeros(1, device=f"cuda:{rank}") for _ in range(world_size)]
84 | counts[rank].fill_(local)
85 | dist.all_gather(counts, counts[rank])
86 | expected = max(int(c.item()) for c in counts)
87 | assert len(micros) == expected
88 | else:
89 | # if neither, we get the local natural count
90 | assert len(micros) == local
91 |
92 | # 5) reconstruction sanity: concat→reverse_idx→orig
93 | flat = torch.cat(micros, dim=0)
94 | idx = []
95 | for sub in idx_lst:
96 | idx.extend(sub)
97 | inv = get_reverse_idx(idx)
98 | inv = torch.tensor(inv, device=flat.device)
99 | reconstructed = flat[inv]
100 | torch.testing.assert_close(reconstructed, batch)
101 |
102 | dist.destroy_process_group()
103 |
104 |
105 | def test_seqlen_balancing_distributed_params(tmp_path):
106 | world_size = 2
107 | init_file = tmp_path / "dist_init"
108 | init_file.write_text("") # empty file
109 | init_method = f"file://{init_file}"
110 |
111 | # test min_num_micro_batch only
112 | mp.spawn(
113 | _worker,
114 | args=(world_size, init_method, 300, False, 4),
115 | nprocs=world_size,
116 | join=True,
117 | )
118 |
119 | # test same_micro_num_in_dp only
120 | mp.spawn(
121 | _worker,
122 | args=(world_size, init_method, 300, True, None),
123 | nprocs=world_size,
124 | join=True,
125 | )
126 |
--------------------------------------------------------------------------------
/tests/utils/gpu_tests/test_torch_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import pytest
18 | import torch
19 | import torch.distributed as dist
20 | import torch.multiprocessing as mp
21 |
22 | from verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std
23 |
24 |
25 | def _worker_mean(rank: int, world_size: int, rendezvous_file: str):
26 | # 1) set GPU and init NCCL
27 | torch.cuda.set_device(rank)
28 | dist.init_process_group(
29 | backend="nccl",
30 | init_method=f"file://{rendezvous_file}",
31 | rank=rank,
32 | world_size=world_size,
33 | )
34 |
35 | # each rank holds tensor [rank+1]
36 | local = torch.tensor([float(rank + 1)], device=f"cuda:{rank}")
37 | mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True)
38 |
39 | values = [float(i + 1) for i in range(world_size)]
40 | exp_mean = sum(values) / len(values)
41 | exp_max = max(values)
42 | exp_min = min(values)
43 | var = sum((x - exp_mean) ** 2 for x in values) / (len(values) - 1)
44 | exp_std = var**0.5
45 |
46 | # all ranks should see the same result
47 | assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f"mean@{rank}"
48 | assert torch.allclose(gmax.cpu(), torch.tensor(exp_max)), f"max@{rank}"
49 | assert torch.allclose(gmin.cpu(), torch.tensor(exp_min)), f"min@{rank}"
50 | assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f"std@{rank}"
51 |
52 | dist.destroy_process_group()
53 |
54 |
55 | @pytest.mark.parametrize("world_size", [2, 4])
56 | def test_distributed_mean_max_min_std(world_size, tmp_path):
57 | rendezvous_file = str(tmp_path / "rdzv_mean")
58 | os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
59 |
60 | mp.spawn(
61 | fn=_worker_mean,
62 | args=(world_size, rendezvous_file),
63 | nprocs=world_size,
64 | join=True,
65 | )
66 |
67 |
68 | def _worker_mask(rank: int, world_size: int, rendezvous_file: str):
69 | torch.cuda.set_device(rank)
70 | dist.init_process_group(
71 | backend="nccl",
72 | init_method=f"file://{rendezvous_file}",
73 | rank=rank,
74 | world_size=world_size,
75 | )
76 |
77 | # build per-rank tensor and mask
78 | local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=f"cuda:{rank}")
79 | if rank == 0:
80 | mask = torch.tensor([1, 0], device=f"cuda:{rank}", dtype=torch.float32)
81 | else:
82 | mask = torch.tensor([0, 1], device=f"cuda:{rank}", dtype=torch.float32)
83 |
84 | gmean = distributed_masked_mean(local_tensor, mask)
85 |
86 | valid_values = [1.0] + [2 * i + 2.0 for i in range(1, world_size)]
87 | expected_mean = sum(valid_values) / len(valid_values)
88 | assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f"masked_mean@{rank}"
89 |
90 | dist.destroy_process_group()
91 |
92 |
93 | @pytest.mark.parametrize("world_size", [2, 4])
94 | def test_distributed_masked_mean(world_size, tmp_path):
95 | rendezvous_file = str(tmp_path / "rdzv_mask")
96 | os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
97 |
98 | mp.spawn(
99 | fn=_worker_mask,
100 | args=(world_size, rendezvous_file),
101 | nprocs=world_size,
102 | join=True,
103 | )
104 |
--------------------------------------------------------------------------------
/tests/workers/reward_manager/test_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pytest
16 |
17 | # REWARD_MANAGER_REGISTRY maps registered names to reward manager classes
18 | from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register
19 |
20 |
21 | @pytest.fixture
22 | def setup():
23 | """Setup test cases with a mock registry."""
24 | REWARD_MANAGER_REGISTRY.clear()
25 | REWARD_MANAGER_REGISTRY.update({
26 | "manager1": "Manager1Class",
27 | "manager2": "Manager2Class"
28 | })
29 | return REWARD_MANAGER_REGISTRY
30 |
31 | def test_get_existing_manager(setup):
32 | """Test getting an existing reward manager class."""
33 | assert get_reward_manager_cls("manager1") == "Manager1Class"
34 | assert get_reward_manager_cls("manager2") == "Manager2Class"
35 |
36 | def test_get_nonexistent_manager(setup):
37 | """Test getting a non-existent reward manager raises ValueError."""
38 | with pytest.raises(ValueError) as excinfo:
39 | get_reward_manager_cls("unknown_manager")
40 | assert "Unknown reward manager: unknown_manager" in str(excinfo.value)
41 |
42 | def test_case_sensitivity(setup):
43 | """Test that manager names are case-sensitive."""
44 | with pytest.raises(ValueError):
45 | get_reward_manager_cls("MANAGER1")
46 | with pytest.raises(ValueError):
47 | get_reward_manager_cls("Manager1")
48 |
49 | def test_empty_registry(setup):
50 | """Test behavior when registry is empty."""
51 | REWARD_MANAGER_REGISTRY.clear()
52 | with pytest.raises(ValueError) as excinfo:
53 | get_reward_manager_cls("any_manager")
54 | assert "Unknown reward manager: any_manager" in str(excinfo.value)
55 |
56 | def test_register_new_class(setup):
57 | """Test registering a new class with the decorator."""
58 | @register("test_manager")
59 | class TestManager:
60 | pass
61 |
62 | assert "test_manager" in REWARD_MANAGER_REGISTRY
63 | assert REWARD_MANAGER_REGISTRY["test_manager"] == TestManager
64 |
65 | def test_register_different_classes_same_name(setup):
66 | """Test that registering different classes with same name raises ValueError."""
67 | @register("conflict_manager")
68 | class Manager1:
69 | pass
70 |
71 | with pytest.raises(ValueError):
72 | @register("conflict_manager")
73 | class Manager2:
74 | pass
75 |
76 | assert REWARD_MANAGER_REGISTRY["conflict_manager"] == Manager1
77 |
78 | def test_decorator_returns_original_class(setup):
79 | """Test that the decorator returns the original class unchanged."""
80 | @register("return_test")
81 | class OriginalClass:
82 | def method(self):
83 | return 42
84 |
85 | assert OriginalClass().method() == 42
86 | assert REWARD_MANAGER_REGISTRY["return_test"] == OriginalClass
87 |
--------------------------------------------------------------------------------
/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config:
--------------------------------------------------------------------------------
1 | tools:
2 | - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool"
3 | config:
4 | sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code"
5 | tool_schema:
6 | type: "function"
7 | function:
8 | name: "code_interpreter"
9 | description: "A tool for executing code."
10 | parameters:
11 | type: "object"
12 | properties:
13 | code:
14 | type: "string"
15 | description: "The code to execute."
16 | required: ["code"]
--------------------------------------------------------------------------------
/tests/workers/rollout/resource/tool_configs/search_tool_config:
--------------------------------------------------------------------------------
1 | tools:
2 | - class_name: verl.tools.search_tool.SearchTool
3 | config:
4 | retrieval_service_url: http://127.0.0.1:8000/retrieve
5 | num_workers: 120
6 | rate_limit: 120
7 | timeout: 30
8 | tool_schema:
9 | type: function
10 | function:
11 | name: search
12 | description: Searches the web for relevant information based on the given query.
13 | parameters:
14 | type: object
15 | properties:
16 | query_list:
17 | type: array
18 | items:
19 | type: string
20 | description: A list of fully-formed semantic queries. The tool will return search results for each query.
21 | required:
22 | - query_list
--------------------------------------------------------------------------------
/tests/workers/rollout/test_sglang_async_rollout_w_tools.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 SGLang Team
2 | # Copyright 2025 ModelBest Inc. and/or its affiliates
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | usage: torchrun --standalone --nnodes=1 \
17 | --nproc_per_node=2 $(which pytest) \
18 | -s test_sglang_async_rollout_w_tools.py
19 | """
20 |
21 | import numpy as np
22 | import torch
23 | from tensordict import TensorDict
24 | from torch.distributed.device_mesh import init_device_mesh
25 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
26 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
27 | from utils_sglang import (
28 | are_lists_similar,
29 | clean_torchelastic_env,
30 | generate_hf_output,
31 | get_rollout_config,
32 | initialize_global_process_group,
33 | load_tokenizer_and_model,
34 | prepare_inputs,
35 | )
36 |
37 | from verl import DataProto
38 | from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
39 | from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
40 |
41 |
42 | def test_async_sglang_rollout_w_tool():
43 | assert torch.cuda.device_count() >= 2
44 | initialize_global_process_group()
45 | clean_torchelastic_env()
46 |
47 | max_prompt_length = 32
48 | max_response_length = 16
49 | dtype = "bfloat16"
50 | tensor_parallel_size = 2
51 | local_model_path = "Qwen/Qwen2.5-0.5B"
52 |
53 | tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
54 |
55 | preencode_prompts = [
56 | [{"role": "user", "content": prompt, "tool_calls": None}]
57 | for prompt in [
58 | "Who won the Champions League in 2019?",
59 | "The founder of Apple is",
60 | "What's the best way to learn python?",
61 | ]
62 | ]
63 | prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts]
64 | input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)
65 |
66 | hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
67 |
68 | fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",))
69 | inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp"))
70 |
71 | fsdp_model = FSDP(
72 | actor_model,
73 | use_orig_params=True,
74 | device_id=fsdp_device_mesh["fsdp"].get_local_rank(),
75 | mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)),
76 | sharding_strategy=ShardingStrategy.FULL_SHARD,
77 | device_mesh=fsdp_device_mesh,
78 | )
79 |
80 | rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, "./resource/tool_configs/sandbox_fusion_tool_config")
81 | rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config)
82 |
83 | rollout_sharding_manager = FSDPSGLangShardingManager(
84 | module=fsdp_model,
85 | inference_engine=rollout._engine,
86 | model_config=actor_model.config,
87 | full_params=True,
88 | device_mesh=inference_device_mesh_cpu,
89 | )
90 |
91 | with rollout_sharding_manager:
92 | prompt_dict = TensorDict(
93 | {
94 | "input_ids": input_ids,
95 | "attention_mask": attention_mask,
96 | "position_ids": position_ids,
97 | },
98 | batch_size=input_ids.shape[0],
99 | )
100 | print(f"preprocessed {input_ids.shape=}")
101 |
102 | messages = np.asarray(preencode_prompts)
103 | prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": np.array([{}] * input_ids.shape[0], dtype=object)})
104 |
105 | prompts.meta_info.update(
106 | {
107 | "eos_token_id": tokenizer.eos_token_id,
108 | "pad_token_id": tokenizer.pad_token_id,
109 | }
110 | )
111 |
112 | prompts = rollout_sharding_manager.preprocess_data(prompts)
113 | # log_gpu_memory_usage("Before generating sequences", logger=None)
114 | output = rollout.generate_sequences(prompts=prompts)
115 | print(f"generated {output.batch['responses'].shape=}")
116 | # log_gpu_memory_usage("After generating sequences", logger=None)
117 | output = rollout_sharding_manager.postprocess_data(output)
118 | print(f"postprocessed {output.batch['responses'].shape=}")
119 | sglang_output = output.to("cpu")
120 |
121 | sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"])
122 |
123 | print(f"hf response: {hf_response_tokens}")
124 | print(f"sglang response: {sglang_response_tokens}")
125 | assert are_lists_similar(hf_response_tokens, sglang_response_tokens)
126 | print("SGLang w tool Test Passed!")
127 |
128 | torch.distributed.barrier()
129 | torch.distributed.destroy_process_group()
130 |
131 |
132 | if __name__ == "__main__":
133 | test_async_sglang_rollout_w_tool()
134 |
--------------------------------------------------------------------------------
/tests/workers/rollout/test_sglang_spmd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 SGLang Team
2 | # Copyright 2025 ModelBest Inc. and/or its affiliates
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | usage: torchrun --standalone --nnodes=1 \
17 | --nproc_per_node=2 $(which pytest) \
18 | -s test_sglang_spmd.py
19 | """
20 |
21 | import asyncio
22 |
23 | import torch
24 | from sglang.srt.entrypoints.engine import Engine
25 | from sglang.srt.utils import broadcast_pyobj
26 | from torch.distributed.device_mesh import init_device_mesh
27 | from utils_sglang import (
28 | are_lists_similar,
29 | clean_torchelastic_env,
30 | generate_hf_output,
31 | initialize_global_process_group,
32 | load_tokenizer_and_model,
33 | prepare_inputs,
34 | )
35 |
36 |
37 | def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
38 | non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
39 | token_ids = prompt_token_ids[non_pad_index:].tolist()
40 | return token_ids
41 |
42 |
43 | def test_sglang_spmd():
44 | assert torch.cuda.device_count() >= 2
45 | initialize_global_process_group(spmd=True)
46 | clean_torchelastic_env()
47 |
48 | max_prompt_length = 16
49 | max_response_length = 16
50 |
51 | local_model_path = "Qwen/Qwen2.5-0.5B"
52 | tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
53 |
54 | preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
55 | input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)
56 |
57 | hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
58 |
59 | tensor_parallel_size = 2
60 | inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
61 | tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()
62 |
63 | if tp_rank == 0:
64 | llm = Engine(
65 | model_path=local_model_path,
66 | dtype="bfloat16",
67 | mem_fraction_static=0.5,
68 | enable_memory_saver=True,
69 | tp_size=inference_device_mesh_cpu["tp"].size(),
70 | )
71 |
72 | input_ids = input_ids.cuda()
73 | idx_list = []
74 |
75 | pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
76 | for i in range(input_ids.shape[0]):
77 | idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
78 |
79 | sampling_params = dict(
80 | n=1,
81 | temperature=0,
82 | top_p=1,
83 | top_k=-1,
84 | max_new_tokens=max_response_length,
85 | presence_penalty=0.0,
86 | frequency_penalty=0.0,
87 | repetition_penalty=1.0,
88 | skip_special_tokens=True,
89 | spaces_between_special_tokens=True,
90 | ignore_eos=False,
91 | )
92 |
93 | loop = asyncio.get_event_loop()
94 | outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
95 | else:
96 | outputs = None
97 |
98 | [outputs] = broadcast_pyobj(
99 | [outputs],
100 | rank=inference_device_mesh_cpu["tp"].get_local_rank(),
101 | src=inference_device_mesh_cpu["tp"].mesh[0].item(),
102 | dist_group=inference_device_mesh_cpu["tp"].get_group(),
103 | force_cpu_device=False,
104 | )
105 |
106 | sglang_response_tokens = [output["text"] for output in outputs]
107 |
108 | print(f"sglang response: {sglang_response_tokens}")
109 | assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ by more than 10%"
110 | print("SPMD Test Passed!")
111 |
112 | torch.distributed.barrier()
113 | torch.distributed.destroy_process_group()
114 |
--------------------------------------------------------------------------------
/tests/workers/rollout/utils_sglang.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 SGLang Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | from datetime import timedelta
16 |
17 | import torch
18 | from omegaconf import OmegaConf
19 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
20 |
21 | from verl.utils.model import compute_position_id_with_mask
22 | from verl.utils.torch_functional import pad_sequence_to_length
23 |
24 |
25 | # ====================== utils ======================
26 | def levenshtein(s1, s2):
27 | m, n = len(s1), len(s2)
28 | dp = [[0] * (n + 1) for _ in range(m + 1)]
29 | for i in range(m + 1):
30 | dp[i][0] = i
31 | for j in range(n + 1):
32 | dp[0][j] = j
33 | for i in range(1, m + 1):
34 | for j in range(1, n + 1):
35 | cost = 0 if s1[i - 1] == s2[j - 1] else 1
36 | dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)
37 | return dp[m][n]
38 |
39 |
40 | def are_lists_similar(a, b, threshold=10):
41 | if len(a) != len(b):
42 | print("The lists are of different lengths.")
43 | return False
44 | total_length = 0
45 | total_diff = 0
46 | for s1, s2 in zip(a, b):
47 | max_len = max(len(s1), len(s2))
48 | total_length += max_len
49 | total_diff += levenshtein(s1, s2)
50 | percentage_difference = (total_diff / total_length) * 100
51 | print(f"Total difference: {percentage_difference:.2f}%")
52 | return percentage_difference <= threshold
53 |
54 |
55 | def initialize_global_process_group(timeout_second=36000, spmd=False):
56 | import torch.distributed
57 |
58 | if not torch.distributed.is_initialized(): # Check if already initialized
59 | print("Initializing process group...")
60 | torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))
61 | else:
62 | print("Process group already initialized.")
63 |
64 | local_rank = int(os.environ["LOCAL_RANK"])
65 | rank = int(os.environ["RANK"])
66 | world_size = int(os.environ["WORLD_SIZE"])
67 | torch.cuda.set_device(local_rank)
68 |
69 | CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "")
70 | if not CUDA_VISIBLE_DEVICES:
71 | if spmd:
72 | # CUDA_VISIBLE_DEVICES = ','.join(str(i) for i in range(tensor_parallel_size))
73 | CUDA_VISIBLE_DEVICES = ",".join(str(i) for i in range(world_size))
74 | else:
75 | CUDA_VISIBLE_DEVICES = str(local_rank)
76 | os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
77 | print(f"CUDA_VISIBLE_DEVICES is not set, set to {CUDA_VISIBLE_DEVICES}")
78 |
79 | return local_rank, rank, world_size
80 |
81 |
82 | def clean_torchelastic_env():
83 | for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
84 | if k in os.environ:
85 | del os.environ[k]
86 |
87 |
88 | def load_tokenizer_and_model(local_model_path, dtype="bfloat16"):
89 | tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
90 | tokenizer.pad_token = tokenizer.eos_token
91 | model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=getattr(torch, dtype), device_map="cuda")
92 | return tokenizer, model
93 |
94 |
95 | def prepare_inputs(tokenizer, prompts, max_prompt_length):
96 | pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
97 | tokenized = tokenizer(prompts, return_tensors="pt", padding=True)
98 | input_ids = pad_sequence_to_length(tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True)
99 | attention_mask = pad_sequence_to_length(tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True)
100 | position_ids = compute_position_id_with_mask(attention_mask)
101 | position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True)
102 | return input_ids, attention_mask, position_ids
103 |
104 |
105 | def generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length):
106 | generation_config = GenerationConfig(do_sample=False)
107 | output = model.generate(
108 | input_ids=input_ids.cuda(),
109 | attention_mask=attention_mask.cuda(),
110 | max_new_tokens=max_response_length,
111 | eos_token_id=tokenizer.eos_token_id,
112 | pad_token_id=tokenizer.pad_token_id,
113 | generation_config=generation_config,
114 | output_scores=False,
115 | return_dict_in_generate=True,
116 | use_cache=False,
117 | )
118 | seq = output.sequences
119 | response = seq[:, input_ids.shape[1] :]
120 | return tokenizer.batch_decode(response)
121 |
122 |
123 | def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_config_path):
124 | sampling_params = dict(
125 | n=1,
126 | temperature=0,
127 | top_p=1,
128 | top_k=-1,
129 | max_new_tokens=max_response_length,
130 | presence_penalty=0.0,
131 | frequency_penalty=0.0,
132 | repetition_penalty=1.0,
133 | skip_special_tokens=True,
134 | spaces_between_special_tokens=True,
135 | ignore_eos=False,
136 | )
137 |
138 | rollout_config = OmegaConf.create(
139 | {
140 | "name": "sglang",
141 | "load_format": "dummy_dtensor",
142 | "enforce_eager": False,
143 | "free_cache_engine": False,
144 | "dtype": dtype,
145 | "gpu_memory_utilization": 0.5,
146 | "ignore_eos": False,
147 | "max_num_batched_tokens": 8192,
148 | "prompt_length": max_prompt_length,
149 | "response_length": max_response_length,
150 | "tensor_model_parallel_size": tensor_parallel_size,
151 | "multi_turn": {
152 | "max_turns": 4,
153 | "enable": True,
154 | "tool_config_path": tool_config_path,
155 | "use_inference_chat_template": False,
156 | "enable_tokenization_sanity_check": True,
157 | },
158 | "max_model_len": None,
159 | **sampling_params,
160 | }
161 | )
162 |
163 | return rollout_config
164 |
--------------------------------------------------------------------------------