├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── README_zh.md ├── docker ├── Dockerfile ├── Dockerfile.rocm └── patch │ ├── megatron-sandwich_norm.patch │ └── sglang.patch ├── docs ├── en │ ├── agent_training.md │ ├── amd_tutorial.md │ ├── build.md │ ├── debug.md │ ├── models │ │ ├── glm4-9B.md │ │ ├── qwen3-30B-A3B.md │ │ └── qwen3-4B.md │ ├── qa.md │ ├── rollout_buffer_usage.md │ ├── sft.md │ └── usage.md └── zh │ ├── agent_training.md │ ├── build.md │ ├── debug.md │ ├── models │ ├── glm4-9B.md │ ├── qwen3-30B-A3B.md │ └── qwen3-4B.md │ ├── qa.md │ ├── rollout_buffer_usage.md │ ├── sft.md │ └── usage.md ├── examples └── search-r1 │ ├── README.md │ ├── README_zh.md │ ├── generate_with_search.py │ ├── google_search_server.py │ ├── qa_em_format.py │ └── run_qwen2.5_3B.sh ├── imgs └── arch.png ├── pyproject.toml ├── requirements.txt ├── scripts ├── agent-example.sh ├── models │ ├── glm4-32B.sh │ ├── glm4-9B.sh │ ├── moonlight.sh │ ├── qwen2.5-1.5B.sh │ ├── qwen2.5-32B.sh │ ├── qwen2.5-3B.sh │ ├── qwen2.5-7B.sh │ ├── qwen3-0.6B.sh │ ├── qwen3-235B-A22B.sh │ ├── qwen3-30B-A3B.sh │ ├── qwen3-4B.sh │ └── qwen3-8B.sh ├── run-glm4-9B.sh ├── run-qwen3-235B-A22B-sft.sh ├── run-qwen3-235B-A22B.sh ├── run-qwen3-30B-A3B.sh ├── run-qwen3-4B-amd.sh ├── run-qwen3-4B-base-sft.sh ├── run-qwen3-4B-rf-baseline.sh ├── run-qwen3-4B-rf.sh ├── run-qwen3-4B.sh └── run_agent.sh ├── setup.py ├── slime ├── __init__.py ├── backends │ ├── megatron_utils │ │ ├── __init__.py │ │ ├── arguments.py │ │ ├── checkpoint.py │ │ ├── cp_utils.py │ │ ├── data.py │ │ ├── initialize.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── models │ │ │ └── __init__.py │ │ └── update_weight_utils.py │ └── sglang_utils │ │ ├── __init__.py │ │ ├── arguments.py │ │ ├── http_server_engine.py │ │ └── sglang_engine.py ├── ray │ ├── __init__.py │ ├── buffer.py │ ├── placement_group.py │ ├── ppo_actor.py │ ├── ray_actor.py │ ├── rollout.py │ └── utils.py ├── rollout │ ├── agent_rollout.py │ ├── filter_hub │ │ ├── __init__.py │ │ ├── dynamic_sampling_filters.py │ │ └── over_sampling_filters.py │ ├── rm_hub │ │ ├── __init__.py │ │ ├── deepscaler.py │ │ ├── f1.py │ │ ├── math_dapo_utils.py │ │ └── math_utils.py │ ├── sft_example.py │ └── sglang_example.py └── utils │ ├── __init__.py │ ├── arguments.py │ ├── async_utils.py │ ├── data.py │ ├── distributed_utils.py │ ├── flops_utils.py │ ├── http_utils.py │ ├── mask_utils.py │ ├── memory_utils.py │ ├── misc.py │ ├── ppo_utils.py │ ├── seqlen_balancing.py │ ├── timer.py │ └── types.py ├── slime_plugins ├── __init__.py ├── mbridge │ ├── __init__.py │ └── glm4.py ├── models │ ├── __init__.py │ └── glm4.py └── rollout_buffer │ ├── buffer.py │ ├── generator │ ├── __init__.py │ ├── base_generator.py │ ├── reward_utils │ │ ├── __init__.py │ │ └── math_utils.py │ └── utils │ │ ├── arguments.py │ │ └── default_func.py │ └── tools │ ├── assign_instance_id.py │ └── visualizer.py ├── tests └── test_qwen3_0.6B.sh ├── tools ├── convert_hf_to_torch_dist.py ├── convert_to_hf.py └── convert_torch_dist_to_hf.py ├── train.py └── train_async.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | 
.installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | wandb/ 177 | outputs/ 178 | local/ 179 | **/rollout_data/ 180 | **/buffer_stats/ 181 | *.out 182 | *.pkl 183 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: quarterly 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.5.0 12 | hooks: 13 | - id: check-yaml 14 | - id: check-case-conflict 15 | - id: detect-private-key 16 | - id: check-added-large-files 17 | args: ['--maxkb=1000'] 18 | - id: requirements-txt-fixer 19 | 20 | - repo: https://github.com/PyCQA/autoflake 21 | rev: v2.0.2 22 | hooks: 23 | - id: autoflake 24 | args: [--remove-all-unused-imports, --in-place] 25 | 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 5.13.2 28 | hooks: 29 | - id: isort 30 | name: Format imports 31 | exclude: docs/ 32 | 33 | - repo: https://github.com/psf/black 34 | rev: 24.3.0 35 | hooks: 36 | - id: black 37 | name: Format code 38 | additional_dependencies: ['click==8.0.2'] 39 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # slime 2 | 3 | [English](./README.md) 4 | 5 | **slime** 是为 RL scaling 设计的 LLM post‑training 框架,提供两大核心能力: 6 | 7 | 1. **高性能训练**:通过连接 Megatron 与 SGLang,支持各种模式的高效训练; 8 | 2. 
**灵活的数据生成**:通过自定义数据生成接口以及 server based engine,实现任意的数据训练数据生成流程。 9 | 10 | ## 目录 11 | 12 | - [架构总览](#架构总览) 13 | - [快速开始](#快速开始) 14 | - [环境准备](#环境准备) 15 | - [示例](#示例) 16 | - [Dense 模型示例:GLM-4-9B 与 Qwen3-4B](#Dense-模型示例GLM-4-9B-与-Qwen3-4B) 17 | - [MoE 模型示例:Qwen3-30B-A3B](#MoE-模型示例Qwen3-30B-A3B) 18 | - [多轮对话 + 工具调用示例:Search-R1 lite](#多轮对话--工具调用示例Search-R1-lite) 19 | - [SFT 示例:Qwen3-4B-Base + OpenHermes-2.5](#SFT-示例Qwen3-4B-Base--OpenHermes-25) 20 | - [Checkpoint 格式转换](#checkpoint-格式转换) 21 | - [启动训练流程](#启动训练流程) 22 | - [参数说明](#参数说明) 23 | - [开发指南](#开发指南) 24 | - [常见 Q&A 与致谢](#常见-qa-与致谢) 25 | 26 | ## 架构总览 27 | 28 |  29 | 30 | **模块说明**: 31 | 32 | - **training (Megatron)**:负责主训练流程,从 Data Buffer 读取数据,训练完后将参数同步至 rollout 模块; 33 | - **rollout (SGLang + router)**:生成新数据(含 reward/verifier),存储至 Data Buffer; 34 | - **data buffer**:桥梁模块,管理 prompt 初始化、自定义数据与 rollout 生成方法。 35 | 36 | ## 快速开始 37 | 38 | ### 环境准备 39 | 40 | 基于镜像 zhuzilin/slime:latest(已预装 SGLang 0.4.7 和 Megatron): 41 | 42 | ```bash 43 | docker run --rm --gpus all --ipc=host --shm-size=16g \ 44 | --ulimit memlock=-1 --ulimit stack=67108864 \ 45 | -it zhuzilin/slime:latest /bin/bash 46 | 47 | git clone https://github.com/THUDM/slime.git 48 | cd slime 49 | pip install -e . 50 | ``` 51 | 52 | - 对于不方便使用 docker 的场景,请参考 [从零搭建环境](./docs/zh/build.md); 53 | - 对于 AMD 支持,请参考 [AMD 使用教程](./docs/en/amd_tutorial.md)。 54 | 55 | ### 示例 56 | 57 | #### Dense 模型示例:GLM-4-9B 与 Qwen3-4B 58 | 59 | 我们提供了 [GLM-4-9B](https://huggingface.co/THUDM/GLM-Z1-9B-0414) 和 [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) 的使用示例,可以通过他们对 slime 的使用方法有个基本的了解: 60 | 61 | - [示例:GLM-4-9B 模型](docs/zh/models/glm4-9B.md) 62 | - [示例:Qwen3-4B 模型](docs/zh/models/qwen3-4B.md) 63 | 64 | #### MoE 模型示例:Qwen3-30B-A3B 65 | 66 | 我们也提供了 MoE 模型的示例,请查看: 67 | 68 | - [示例:Qwen3-30B-A3B 模型](docs/zh/models/qwen3-30B-A3B.md) 69 | 70 | #### 多轮对话 + 工具调用示例:Search-R1 lite 71 | 72 | 针对多轮对话和工具调用场景,我们提供了一个简化版的 Search-R1 复现,请查看: 73 | 74 | - [示例:Search-R1 lite](examples/search-r1/README_zh.md) 75 | 76 | #### SFT 示例:Qwen3-4B-Base + OpenHermes-2.5 77 | 78 | slime is not just a RL framework, we support a diverse set of post-training setups. For an SFT example, please refer to: 79 | 80 | slime 不仅仅是一个 RL 框架,我们还支持了各种后训练流程。如果想使用 SFT,请参看: 81 | 82 | - [示例: Qwen3-4B-Base + OpenHermes-2.5](docs/zh/sft.md). 
83 | 84 | ### Checkpoint 格式转换 85 | 86 | 由于 slime 使用 megatron,而 megatron 不支持加载 huggingface checkpoint,我们需要将模型转换至 megatron 可以支持的 torch_dist 格式。 87 | 88 | #### HF → Megatron torch_dist ckpt 89 | 90 | 使用 [mbridge](https://github.com/ISEEKYAN/mbridge.git) 转换: 91 | 92 | ```bash 93 | cd slime/ 94 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 95 | --hf-checkpoint /root/GLM-Z1-9B-0414 \ 96 | --save /root/GLM-Z1-9B-0414_torch_dist 97 | ``` 98 | 99 | 在遇到 mbridge 暂时不支持的模型的时候,可以考虑使用 [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch) 进行转换。 100 | 101 | ⚠️ 如果出现找不到 slime 的问题,请在 slime 目录下 `pip install -e .`。 102 | 103 | #### Megatron torch_dist → HF ckpt 104 | 105 | 将训练过程中的存储的 torch_dist ckpt 转为 hf ckpt: 106 | 107 | ```bash 108 | cd slime/ 109 | PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \ 110 | --input-dir /path/to/torch_dist_ckpt/iter_xxx/ \ 111 | --output-dir /root/GLM-Z1-9B-0414-iter_xxx \ 112 | --origin-hf-dir /root/GLM-Z1-9B-0414 113 | ``` 114 | 115 | ⚠️ 由于 mbridge 转换的 torch_dist ckpt 目前不保存 args,不能基于上一步的 torch_dist ckpt 反转回 HF。 116 | 117 | #### 任意 Megatron ckpt → HF 118 | 119 | 适用于自定义保存格式(如 `--ckpt-format torch`)。 120 | 121 | 转化方式的原理是直接复用训练中,从 megatron 向 sglang 更新参数的函数,也就是直接复用一下训练脚本,将原先的: 122 | 123 | ```bash 124 | ray job submit --address="http://127.0.0.1:8265" \ 125 | --runtime-env-json='{ 126 | "env_vars": { ...} 127 | }' \ 128 | -- python3 train.py \ 129 | ... # 其他训练 args 130 | ``` 131 | 132 | 改成: 133 | 134 | ```bash 135 | torchrun --nproc_per_node ${NUM_GPU} tools/convert_to_hf.py \ 136 | --load /your/saved/megatron_ckpt \ 137 | --output-dir /your/converted/hf_ckpt \ 138 | ... # 其他训练 args 139 | ``` 140 | 141 | 即,保持所有的参数不变,将: 142 | 143 | 1. 任务启动从 ray 变成 torchrun,把 gpu 数量保存为 megatron 并行的不带 dp 的最小 gpu 数,例如如果是 tp4,就设成 4; 144 | 2. 确认把 `--load` 改成了需要 load 的路径; 145 | 3. 增加 `--output-dir` 对应要保存的 hf_ckpt。 146 | 147 | ## 启动训练流程 148 | 149 | 整个程序需要使用 ray 进行启动,首先需要启动一个 ray 集群,即在 node 0 运行: 150 | 151 | ```bash 152 | # Node0(HEAD) 153 | ray start --head --node-ip-address ${MASTER_ADDR} \ 154 | --num-gpus 8 --disable-usage-stats 155 | 156 | # 其他 Node 157 | ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 158 | ``` 159 | 160 | 在 ray 集群启动后,可以在 node 0 提交任务,例如: 161 | 162 | ```bash 163 | ray job submit --address="http://127.0.0.1:8265" \ 164 | --runtime-env-json='{ 165 | "env_vars": { 166 | "PYTHONPATH": "/root/Megatron-LM/", 167 | ... # e.g. no_proxy、接口变量等 168 | } 169 | }' \ 170 | -- python3 train.py \ 171 | --...(其他 Megatron/SGLang/slime 参数) 172 | ``` 173 | 174 | #### 参数说明 175 | 176 | 参数分为三类: 177 | 178 | 1. **megatron 参数**:slime 会读取 `PYTHONPATH` 中的 megatron 里设置的所有参数,可以通过传入如 `--tensor-model-parallel-size 2` 的方式配置 megatron; 179 | 2. **sglang 参数**:支持环境中安装的 sglang 的所有参数,这些参数需要以 `--sglang` 起始,例如 `--mem-fraction-static` 需要通过 `--sglang-mem-fraction-static` 传入。 180 | 3. 
**slime 自身的参数**:请见:[slime/utils/arguments.py](slime/utils/arguments.py) 181 | 182 | 完整使用说明请查阅 [使用文档](docs/zh/usage.md)。 183 | 184 | ## 开发指南 185 | 186 | - **欢迎贡献!** 若有功能建议、性能调优或使用体验反馈,欢迎提交 Issue / PR 😊 187 | 188 | - 使用 [pre-commit](https://pre-commit.com/) 保证提交代码风格: 189 | 190 | ```bash 191 | apt install pre-commit -y 192 | pre-commit install 193 | ``` 194 | 195 | - 调试技巧请参考 [debug 指南](docs/zh/debug.md) 196 | 197 | ## 常见 Q&A 与致谢 198 | 199 | - 常见问题请见 [Q&A](docs/zh/qa.md) 200 | - 特别感谢以下项目 & 社区:SGLang、Megatron‑LM、mbridge、OpenRLHF、veRL 等。 201 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM lmsysorg/sglang:dev AS base 2 | 3 | # TODO: change to pip install sglang-router after it has a new release 4 | RUN pip install sglang-router --force-reinstall 5 | RUN pip install ray[default] 6 | RUN pip install httpx[http2] wandb pylatexenc blobfile accelerate "mcp[cli]" 7 | RUN pip install git+https://github.com/zhuzilin/cumem_allocator.git 8 | 9 | # mbridge 10 | RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps 11 | 12 | RUN TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0;9.0a" pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 13 | # apex 14 | RUN NVCC_APPEND_FLAGS="--threads 4" \ 15 | pip -v install --disable-pip-version-check --no-cache-dir \ 16 | --no-build-isolation \ 17 | --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git 18 | # transformer engine 19 | RUN pip -v install transformer_engine[pytorch] 20 | # flash attn 21 | # the newest version megatron supports is v2.7.4.post1 22 | RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl 23 | 24 | WORKDIR /root/ 25 | RUN git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ 26 | cd Megatron-LM && \ 27 | pip install -e . 28 | 29 | # sandwitch norm for GLM models 30 | COPY patch/megatron-sandwich_norm.patch /root/Megatron-LM/ 31 | RUN cd Megatron-LM && \ 32 | git apply megatron-sandwich_norm.patch --3way && \ 33 | if grep -R -n '^<<<<<<< ' .; then \ 34 | echo "Patch failed to apply cleanly. Please resolve conflicts." && \ 35 | exit 1; \ 36 | fi && \ 37 | rm megatron-sandwich_norm.patch 38 | 39 | # sglang patch 40 | COPY patch/sglang.patch /sgl-workspace/sglang/ 41 | RUN cd /sgl-workspace/sglang && \ 42 | git apply sglang.patch && \ 43 | if grep -R -n '^<<<<<<< ' .; then \ 44 | echo "Patch failed to apply cleanly. Please resolve conflicts." 
&& \ 45 | exit 1; \ 46 | fi && \ 47 | rm sglang.patch 48 | -------------------------------------------------------------------------------- /docker/patch/sglang.patch: -------------------------------------------------------------------------------- 1 | diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py 2 | index 25104bd..6853784 100644 3 | --- a/python/sglang/srt/configs/model_config.py 4 | +++ b/python/sglang/srt/configs/model_config.py 5 | @@ -405,14 +405,14 @@ class ModelConfig: 6 | quant_method = quant_cfg.get("quant_method", "").lower() 7 | 8 | # Detect which checkpoint is it 9 | - for _, method in QUANTIZATION_METHODS.items(): 10 | - quantization_override = method.override_quantization_method( 11 | - quant_cfg, self.quantization 12 | - ) 13 | - if quantization_override: 14 | - quant_method = quantization_override 15 | - self.quantization = quantization_override 16 | - break 17 | + #for _, method in QUANTIZATION_METHODS.items(): 18 | + # quantization_override = method.override_quantization_method( 19 | + # quant_cfg, self.quantization 20 | + # ) 21 | + # if quantization_override: 22 | + # quant_method = quantization_override 23 | + # self.quantization = quantization_override 24 | + # break 25 | 26 | # Verify quantization configurations. 27 | if self.quantization is None: 28 | diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py 29 | index f2c0d61..e548f39 100644 30 | --- a/python/sglang/srt/layers/quantization/fp8.py 31 | +++ b/python/sglang/srt/layers/quantization/fp8.py 32 | @@ -340,10 +340,10 @@ class Fp8LinearMethod(LinearMethodBase): 33 | return 34 | else: 35 | weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data 36 | - layer.weight = torch.nn.Parameter(weight, requires_grad=False) 37 | - layer.weight_scale_inv = torch.nn.Parameter( 38 | - weight_scale, requires_grad=False 39 | - ) 40 | + # layer.weight = torch.nn.Parameter(weight, requires_grad=False) 41 | + # layer.weight_scale_inv = torch.nn.Parameter( 42 | + # weight_scale, requires_grad=False 43 | + # ) 44 | return 45 | 46 | layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) 47 | -------------------------------------------------------------------------------- /docs/en/build.md: -------------------------------------------------------------------------------- 1 | # Setting up the Environment from Scratch 2 | 3 | [中文版](../zh/build.md) 4 | 5 | If it is inconvenient to directly use our pre-built image, we provide the following solution for setting up the environment: 6 | 7 | ## Setting up the environment based on anaconda / mamba 8 | 9 | Here, we take micromamba as an example to build a conda environment named `slime` within the official sglang image `lmsysorg/sglang:latest`: 10 | 11 | ```bash 12 | #################### 13 | # create conda 14 | #################### 15 | yes '' | "${SHELL}" <(curl -L micro.mamba.pm/install.sh) 16 | source ~/.bashrc 17 | micromamba self-update 18 | 19 | micromamba create -n slime python=3.10 pip -c conda-forge -y 20 | # install cuda-12.6.0 as this is the default cuda version for pytorch 21 | # and apex need this alignment. 
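# note (assumptions): if your torch wheel targets a different CUDA release, adjust the
# nvidia channel label below to match; once the sglang step has pulled in torch you can
# confirm the expected version with:
#   micromamba run -n slime python -c "import torch; print(torch.version.cuda)"
# likewise, the TORCH_CUDA_ARCH_LIST="9.0;9.0a" used for the megatron deps below only
# targets Hopper; adjust it (e.g. "8.0" for A100, "8.9" for Ada GPUs) on other hardware,
# as docker/Dockerfile does with "8.0;8.9;9.0;9.0a".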
22 | micromamba install -n slime cuda cuda-nvtx cuda-nvtx-dev -c nvidia/label/cuda-12.6.0 -y 23 | micromamba install -n slime -c conda-forge cudnn -y 24 | micromamba run -n slime pip install cmake ninja 25 | 26 | #################### 27 | # sglang deps 28 | #################### 29 | cd /root/ 30 | git clone https://github.com/sgl-project/sglang.git --branch v0.4.9 --depth 1 31 | cd /root/sglang/ 32 | micromamba run -n slime pip -v install -e "python[all]" 33 | # TODO: change to pip install sglang-router after it has a new release 34 | micromamba run -n slime pip install https://github.com/zhuzilin/sgl-router/releases/download/dev/sglang_router-0.1.4-cp310-cp310-linux_x86_64.whl --force-reinstall 35 | 36 | #################### 37 | # megatron deps 38 | #################### 39 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \ 40 | pip -v install --no-build-isolation \ 41 | git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 42 | # apex 43 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" NVCC_APPEND_FLAGS="--threads 4" \ 44 | micromamba run -n slime \ 45 | pip -v install --disable-pip-version-check --no-cache-dir \ 46 | --no-build-isolation \ 47 | --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git 48 | # transformer engine 49 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \ 50 | pip -v install transformer_engine[pytorch] 51 | # flash attn 52 | # the newest version megatron supports is v2.7.4.post1 53 | micromamba run -n slime pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl 54 | # megatron 55 | cd /root/ 56 | git clone https://github.com/NVIDIA/Megatron-LM.git 57 | cd Megatron-LM/ 58 | micromamba run -n slime pip install -e . 59 | 60 | #################### 61 | # other deps 62 | #################### 63 | micromamba run -n slime pip install git+https://github.com/zhuzilin/cumem_allocator.git --no-build-isolation 64 | 65 | #################### 66 | # slime 67 | #################### 68 | cd /root/ 69 | git clone https://github.com/THUDM/slime.git 70 | cd slime/ 71 | micromamba run -n slime pip install -e . 72 | # apply patch 73 | cd /root/sglang 74 | git apply /root/slime/docker/patch/sglang.patch 75 | ``` 76 | -------------------------------------------------------------------------------- /docs/en/debug.md: -------------------------------------------------------------------------------- 1 | # Debugging Guide 2 | 3 | [中文版](../zh/debug.md) 4 | 5 | ## Aligning Precision 6 | 7 | During the development of slime, it is often necessary to check if the model's precision is correct. This can be verified in the following ways: 8 | 9 | 1. **First Training Step** 10 | 1. Check if the generated `rollout` is coherent. If not, there are two possible reasons: 11 | * Parameters were not loaded correctly. You need to check the logs for a confirmation that Megatron successfully loaded the checkpoint (ckpt). 12 | * There was an error in updating the parameters. You can check if all parameters were converted and mapped correctly, or if the parameter names were converted according to the parallelization strategy (e.g., when `pp_size > 1`, check if the layer IDs for the parameters provided by the second stage are correct). A thorough method is to save all parameters in the `load_weights` implementation of the corresponding model in SGLang and verify that they are consistent with the loaded checkpoint. 
13 | * If all parameters are updated correctly and the problem persists, it's possible that some special buffers in SGLang were released during the release process. 14 | * If you are testing with a pretrained model, you can switch to an instruct version of a model with the same architecture to see if this garbled output is specific to the pretrained model. 15 | 16 | 2. Check the printed rollout stats to see if `log_probs` and `ref_log_probs` are exactly equal (meaning KL divergence is 0 in the first step) and their values are small. 17 | * If they are not exactly equal, it is usually caused by certain non-deterministic kernels in the Transformer Engine, for example: 18 | * In some versions of Transformer Engine (TE), Megatron requires `--attention-backend flash` to enforce the use of Flash Attention, thereby avoiding numerical instability from the fused attention under Context Parallelism (CP). 19 | * If the values are large (e.g., > 1), there are generally two possibilities: 20 | * If the value is extremely large, there is likely a problem with the training configuration. 21 | * If the value is only slightly larger than the SFT loss, for example, if the log probability of an instruct model reaches 0.8, it might be because the data does not conform to the trained chat template or does not match the cold-start distribution. 22 | 23 | 3. When running one inference step per training step (`num_steps_per_rollout == 1`), check if the KL divergence is 0 and if the `grad_norm` is small. 24 | * This is basically due to some Megatron / TE related bugs, for example: 25 | * Mixture of Experts (MoE) requires enabling `--moe-permute-fusion`. 26 | 27 | 2. **Second Training Step** 28 | 1. For integrated training and inference, check if the second step can be loaded correctly and whether it results in an Out of Memory (OOM) error. 29 | 30 | ## Separate Debugging for Training and Inference 31 | 32 | slime supports debugging the training and inference parts separately, which allows for the following: 33 | 34 | * When tuning/debugging the inference part, you can start the task with only a few GPUs. 35 | * When tuning/debugging the training part, you can ensure the model input is fixed, removing the randomness of rollouts. 36 | 37 | Specifically, slime currently provides the following parameters for separate debugging: 38 | 39 | 1. `--debug-rollout-only` 40 | 41 | When enabled, slime will not load Megatron and will only initialize SGLang. You can use this method to debug the inference part. 42 | 43 | 2. `--debug-train-only` 44 | 45 | When enabled, slime will not load SGLang and will only initialize Megatron. You can use this method to debug the training part. 46 | 47 | 3. `--save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt` 48 | 49 | When enabled, the results of each rollout will be saved. This can be used in conjunction with `--debug-rollout-only`. Note that the data is saved using the format: `args.save_debug_rollout_data.format(rollout_id=rollout_id)`. 50 | 51 | 4. `--load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt` 52 | 53 | When enabled, data will be loaded from `args.load_debug_rollout_data.format(rollout_id=rollout_id)`, and SGLang will not be initialized (automatically setting `debug_train_only=True`). This method allows you to fix the input for the training part to tune it, for example, by switching between different parallelization strategies. 
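Below is a rough sketch of how these flags can be combined: record rollouts once with only the inference side active, then replay them while iterating on the training side. The `ray job submit` wrapper and the trailing `...` are placeholders for whatever arguments your training script already passes.

```bash
# Step 1: debug only the rollout side and record its output.
# Megatron is not loaded, so a small number of GPUs is enough.
ray job submit --address="http://127.0.0.1:8265" \
   -- python3 train.py \
   --debug-rollout-only \
   --save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt \
   ... # the usual rollout / SGLang arguments

# Step 2: replay the recorded rollouts as a fixed input for the training side.
# SGLang is not initialized in this mode (debug_train_only is set automatically).
ray job submit --address="http://127.0.0.1:8265" \
   -- python3 train.py \
   --load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt \
   ... # the usual Megatron / parallelism arguments
```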
-------------------------------------------------------------------------------- /docs/en/models/qwen3-30B-A3B.md: -------------------------------------------------------------------------------- 1 | # Example: Qwen3-30B-A3B Model 2 | 3 | [中文版](../../zh/models/qwen3-30B-A3B.md) 4 | 5 | ## Environment Preparation 6 | 7 | The environment setup, model download, data, and checkpoint conversion are the same as for the Qwen3-4B model. You can refer to [Example: Qwen3-4B Model](./qwen3-4B.md), replacing mentions of Qwen3-4B with Qwen3-30B-A3B. 8 | 9 | ## Run Training 10 | 11 | Execute the training script: 12 | 13 | ```bash 14 | cd /root/slime 15 | bash scripts/run-qwen3-30B-A3B.sh 16 | ``` 17 | 18 | ### Parameter Introduction 19 | 20 | Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) script. 21 | 22 | 1. To support running Qwen3-30B-A3B in an 8xH800 environment, we need to enable Megatron's CPU Adam to save GPU memory. The corresponding configuration is: 23 | 24 | ```bash 25 | OPTIMIZER_ARGS=( 26 | ... 27 | --optimizer-cpu-offload 28 | --overlap-cpu-optimizer-d2h-h2d 29 | --use-precision-aware-optimizer 30 | ) 31 | ``` 32 | 33 | 2. Enable MoE optimization supported by Megatron. The current configuration is tp4, ep8: 34 | 35 | ```bash 36 | PERF_ARGS=( 37 | --tensor-model-parallel-size 4 38 | --sequence-parallel 39 | --pipeline-model-parallel-size 1 40 | --context-parallel-size 1 41 | --expert-model-parallel-size 8 42 | --expert-tensor-parallel-size 1 43 | ... 44 | ) 45 | ``` 46 | 47 | 3. Enable MoE optimization supported by SGLang. The current configuration is ep8: 48 | 49 | ```bash 50 | SGLANG_ARGS=( 51 | --rollout-num-gpus-per-engine 8 52 | --sglang-mem-fraction-static 0.5 53 | --sglang-enable-ep-moe 54 | --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) 55 | ) 56 | ``` 57 | 58 | Similarly, you can also add DP attention, for example, by configuring: 59 | 60 | ```bash 61 | --sglang-enable-dp-attention 62 | --sglang-dp-size 8 63 | ``` 64 | -------------------------------------------------------------------------------- /docs/en/qa.md: -------------------------------------------------------------------------------- 1 | # slime FAQ 2 | 3 | [中文版](../zh/qa.md) 4 | 5 | 1. **Why do I see garbled text during training?** 6 | 7 | This situation generally occurs because Megatron is not loaded correctly. Please check if there is a corresponding checkpoint in the directory specified by `--load` or `--ref-load`. Note that Megatron can only load a directory that contains a `latest_checkpointed_iteration.txt` file. 8 | 9 | If you need to specify a particular iteration, you can refer to the current Megatron usage instructions. Generally, you can specify the step number using `--ckpt-step`. 10 | 11 | 2. **Why is my task stuck on the Ray submission page?** 12 | 13 | Please check whether your task is set up for co-located training and inference or decoupled training and inference. 14 | 15 | If it's **co-located** (training and inference share the same GPUs), please check: 16 | 17 | * Whether the `--colocate` parameter is set to enable co-located mode. 18 | * Whether the total number of GPUs for the current task is greater than or equal to `actor_num_nodes * actor_num_gpus_per_node`. 19 | 20 | If it's **decoupled**, please check: 21 | 22 | * Whether the total number of GPUs for the current task is greater than or equal to `actor_num_nodes * actor_num_gpus_per_node + rollout_num_gpus`. 23 | 24 | 3. 
**Why did I encounter an Out-of-Memory (OOM) error during training? What is `max_tokens_per_gpu` for?** 25 | 26 | OOM errors often happen because `max_tokens_per_gpu` is set too high. This parameter defines the maximum number of tokens that can be processed on each GPU during training. If you are concerned about OOM, you can initially set this value to `rollout_max_response_len / cp_size` and then increase it later to improve training efficiency. Note that `--max-tokens-per-gpu` is only active when `--use-dynamic-batch-size` is enabled. 27 | 28 | If you still experience OOM with a small `max_tokens_per_gpu`, check if the data generated in a single pass is too long. You may need to enable context parallelism (CP) with `--context-parallel-size`. If you are using custom data generation, check if the total length of multi-turn generations is much longer than expected. 29 | 30 | 4. **During multi-node training, what should I do if the `transformers` library reports it cannot find a model?** 31 | 32 | This usually happens when multiple processes try to read local files simultaneously using methods like `AutoConfig.from_pretrained` or `AutoModelForCausalLM.from_pretrained`, causing file system write conflicts. You can mitigate this issue by setting the `--model-name` argument. 33 | 34 | 5. **How do I resume training?** 35 | 36 | Simply set the `--load` directory to your `--save` directory. 37 | 38 | 6. **How is the batch size calculated?** 39 | 40 | A single rollout uses `rollout_batch_size` prompts. For each prompt, `n_samples_per_prompt` samples are generated. Therefore, one rollout contains a total of `rollout_batch_size * n_samples_per_prompt` data entries. 41 | 42 | You can use `--num-steps-per-rollout` to determine how many steps to run per rollout. This is equivalent to setting the `global_batch_size` to `rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`. 43 | 44 | 7. **Does slime perform data packing / variable-length (varlen) processing?** 45 | 46 | Yes. Data packing refers to the process of concatenating samples of varying lengths during training to improve GPU utilization. slime performs this operation by default. 47 | 48 | 8. **What should I do if the sglang component shows a `Max retries exceeded with url: /get_model_info (Caused by NewConnectionError)` error?** 49 | 50 | This issue primarily stems from port conflicts caused by multiple sglang servers running on a single machine. We are currently working with the sglang team to resolve this. A temporary workaround is to minimize the number of sglang servers on a single machine, for example, by setting `tp=8`. 51 | 52 | 9. **My gradient norm is very high and the training crashes. What should I do?** 53 | 54 | First, ensure that your data and model are compatible. For example, if your data already uses a chat template, check if this template matches the one used by the original model. If the data is correct, please refer to our [Debug Guide](./debug.md) for a more in-depth analysis. 55 | 56 | 10. **My sglang generation takes an extremely long time, GPU power is maxed out, and there's no output for a long while. Why?** 57 | 58 | Please verify that the model corresponding to `--hf-checkpoint` has its stop tokens configured correctly. If not, you can set them using the `--rollout-stop` or `--rollout-stop-token-ids` arguments. 59 | 60 | 11. 
**Sglang shows an `an illegal memory access was encountered` error.** 61 | 62 | According to the sglang documentation ([https://docs.sglang.ai/references/troubleshooting.html](https://docs.sglang.ai/references/troubleshooting.html)), this could be an OOM error. Consider reducing the value of `--sglang-mem-fraction-static`. 63 | 64 | 12. **A `JSONDecodeError` occurs related to torch compile/inductor.** 65 | 66 | This is generally an issue with the torch compiler's cache read/write operations. You can try adding `"TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"` to the `env_vars` in your Ray configuration. 67 | 68 | 13. **Gradient becomes NaN or Inf during training.** 69 | 70 | You can try setting the `--no-check-for-nan-in-loss-and-grad` flag to skip the corresponding training steps. -------------------------------------------------------------------------------- /docs/en/sft.md: -------------------------------------------------------------------------------- 1 | # Example: Qwen3-4B-Base with OpenHermes-2.5 2 | 3 | [中文版](../zh/sft.md) 4 | 5 | ## Environment Preparation 6 | 7 | First, we need to create a mirror environment and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](./models/qwen3-4B.md). 8 | 9 | After that, we will process the SFT data. Here, we use the classic [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) as an example. First, we process the data into a format suitable for `slime` to load. You can use the following script to add a column that conforms to the OpenAI message format and save it to `/root/openhermes2_5.parquet`. 10 | 11 | ```python 12 | from datasets import load_dataset 13 | 14 | ds = load_dataset("teknium/OpenHermes-2.5")["train"] 15 | 16 | def convert(sample): 17 | conversations = sample["conversations"] 18 | 19 | def convert_role(role): 20 | if role == "human": 21 | return "user" 22 | elif role == "gpt": 23 | return "assistant" 24 | elif role == "system": 25 | return "system" 26 | else: 27 | raise ValueError(f"Unknown role: {role}") 28 | 29 | messages = [ 30 | { 31 | "role": convert_role(turn["from"]), 32 | "content": turn["value"], 33 | } 34 | for turn in conversations 35 | ] 36 | 37 | return {"messages": messages} 38 | 39 | ds = ds.map(convert) 40 | ds.to_parquet("/root/openhermes2_5.parquet") 41 | ``` 42 | 43 | ## Execute Training 44 | 45 | Execute the training: 46 | 47 | ```bash 48 | cd /root/slime 49 | bash script/run-qwen3-4B-base-sft.sh 50 | ``` 51 | 52 | ### Parameter Introduction 53 | 54 | You can compare [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B.sh) with [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows: 55 | 56 | 1. Removed `SGLANG_ARGS` and `GRPO_ARGS`. This is because it is not necessary to start SGLang or configure GRPO-related settings during the SFT process. 57 | 58 | 2. Renamed `ROLLOUT_ARGS` to `SFT_ARGS` and configured it as follows: 59 | 60 | ```bash 61 | SFT_ARGS=( 62 | --rollout-function-path slime.rollout.sft_example.generate_rollout 63 | --prompt-data /root/openhermes2_5.parquet 64 | --input-key messages 65 | --rollout-shuffle 66 | --num-epoch 3 67 | --rollout-batch-size 128 68 | --global-batch-size 128 69 | 70 | --loss-type sft_loss 71 | --calculate-per-token-loss 72 | --disable-compute-advantages-and-returns 73 | --debug-train-only 74 | ) 75 | ``` 76 | 77 | SFT actually reuses the custom rollout functionality of slime. 
By using `--rollout-function-path`, the data generation part is switched from the RL rollout that uses `sglang` to the SFT version that reads data from a file, which is `slime.rollout.sft_example.generate_rollout`. 78 | 79 | For SFT, it is recommended to set `rollout_batch_size` and `global_batch_size` to the same value and not to configure `n_samples_per_prompt`. This is equivalent to training one batch right after reading one batch. 80 | 81 | `slime` also supports different loss types, and we configure the SFT loss using `--loss-type sft_loss`. 82 | 83 | As for `--calculate-per-token-loss`, this is because `slime` defaults to calculating the per-sample mean for GRPO. In general SFT training, the average is taken over all unmasked tokens in a batch, so it is recommended to configure this. 84 | 85 | Finally, `--disable-compute-advantages-and-returns` indicates that there is no need to pre-calculate log probabilities during the SFT process, and `--debug-train-only` means that `sglang` does not need to be initialized. 86 | 87 | 3. Used `train_async.py` instead of `train.py`. This is to leverage the asynchronous training process to implement data prefetching. 88 | -------------------------------------------------------------------------------- /docs/zh/build.md: -------------------------------------------------------------------------------- 1 | # 从零搭建环境 2 | 3 | [English](../en/build.md) 4 | 5 | 在不方便直接使用我们预先准备的镜像的情况下,我们提供了如下的搭建环境的方案: 6 | 7 | ## 基于 anaconda / mamba 搭建环境 8 | 9 | 这里我们以 micromamba 为例,在 sglang 的官方镜像 `lmsysorg/sglang:latest` 中搭建一个名为 slime 的 conda 环境: 10 | 11 | 12 | ```bash 13 | #################### 14 | # create conda 15 | #################### 16 | yes '' | "${SHELL}" <(curl -L micro.mamba.pm/install.sh) 17 | source ~/.bashrc 18 | micromamba self-update 19 | 20 | micromamba create -n slime python=3.10 pip -c conda-forge -y 21 | # install cuda-12.6.0 as this is the default cuda version for pytorch 22 | # and apex need this alignment. 
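# note (assumptions): if your torch wheel targets a different CUDA release, adjust the
# nvidia channel label below to match; once the sglang step has pulled in torch you can
# confirm the expected version with:
#   micromamba run -n slime python -c "import torch; print(torch.version.cuda)"
# likewise, the TORCH_CUDA_ARCH_LIST="9.0;9.0a" used for the megatron deps below only
# targets Hopper; adjust it (e.g. "8.0" for A100, "8.9" for Ada GPUs) on other hardware,
# as docker/Dockerfile does with "8.0;8.9;9.0;9.0a".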
23 | micromamba install -n slime cuda cuda-nvtx cuda-nvtx-dev -c nvidia/label/cuda-12.6.0 -y 24 | micromamba install -n slime -c conda-forge cudnn -y 25 | micromamba run -n slime pip install cmake ninja 26 | 27 | #################### 28 | # sglang deps 29 | #################### 30 | cd /root/ 31 | git clone https://github.com/sgl-project/sglang.git --branch v0.4.9 --depth 1 32 | cd /root/sglang/ 33 | micromamba run -n slime pip -v install -e "python[all]" 34 | # TODO: change to pip install sglang-router after it has a new release 35 | micromamba run -n slime pip install https://github.com/zhuzilin/sgl-router/releases/download/dev/sglang_router-0.1.4-cp310-cp310-linux_x86_64.whl --force-reinstall 36 | 37 | #################### 38 | # megatron deps 39 | #################### 40 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \ 41 | pip -v install --no-build-isolation \ 42 | git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 43 | # apex 44 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" NVCC_APPEND_FLAGS="--threads 4" \ 45 | micromamba run -n slime \ 46 | pip -v install --disable-pip-version-check --no-cache-dir \ 47 | --no-build-isolation \ 48 | --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git 49 | # transformer engine 50 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \ 51 | pip -v install transformer_engine[pytorch] 52 | # flash attn 53 | # the newest version megatron supports is v2.7.4.post1 54 | micromamba run -n slime pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl 55 | # megatron 56 | cd /root/ 57 | git clone https://github.com/NVIDIA/Megatron-LM.git 58 | cd Megatron-LM/ 59 | micromamba run -n slime pip install -e . 60 | 61 | #################### 62 | # other deps 63 | #################### 64 | micromamba run -n slime pip install git+https://github.com/zhuzilin/cumem_allocator.git --no-build-isolation 65 | 66 | #################### 67 | # slime 68 | #################### 69 | cd /root/ 70 | git clone https://github.com/THUDM/slime.git 71 | cd slime/ 72 | micromamba run -n slime pip install -e . 73 | # apply patch 74 | cd /root/sglang 75 | git apply /root/slime/docker/patch/sglang.patch 76 | ``` 77 | -------------------------------------------------------------------------------- /docs/zh/debug.md: -------------------------------------------------------------------------------- 1 | # Debug 指南 2 | 3 | [English](../en/debug.md) 4 | 5 | ## 对齐精度 6 | 7 | 在开发 slime 的过程中,经常会需要检查模型的精度是否正确,可以通过以下方式检查: 8 | 9 | 1. 训练第一步 10 | 1. rollout 的生成是否是人话,如果不是,有以下 2 种可能: 11 | - 参数没有正常加载。需要查看是否有 megatron 成功加载 ckpt 的日志; 12 | - 更新参数有误。可以查看是不是所有的参数都做了转换和参数对应,或者参数名是不是根据并行做了转换(例如 pp_size > 1 时,第二个 stage 提供的参数的 layer id 是不是正确的)。一个比较彻底的方法是在对应模型的 sglang 实现的 `load_weights` 中保存所有的参数,查看和加载的 ckpt 中是否一致; 13 | - 如果所有参数更新都正确,还出现问题,有可能是 sglang 里有一些特殊的 buffer 在 release 的时候被释放了; 14 | - 如果是用 pretrain 模型进行的测试,可以换成同结构模型的 instruct 版本,查看这种乱码是不是 pretrain 模型特有的。 15 | 2. 查看打印的 rollout stats 的 `log_probs` 和 `ref_log_probs` 是否完全相等(即第一步 kl=0),且值较小 16 | - 如果不是完全相等的,一般是 transformer engine 中的某些 non-deterministic kernel 导致的,例如: 17 | - 在某些版本的 te 里,megatron 需要 `--attention-backend flash`,来强制使用 flash attention,从而避免 CP 下 fused attention 的数值不稳定; 18 | - 如果数值较大(例如 >1),一般有 2 种可能: 19 | - 如果值非常大,应该是训练配置有问题; 20 | - 如果值只是比 sft loss 的状态略大,例如 instruct 模型的 logprob 到了 0.8,有可能是数据不符合训练的 chat template,或者不符合冷启动的分布。 21 | 3. 
查看在推一训一(`num_steps_per_rollout == 1`),kl 是否为 0,grad_norm 是否较小 22 | - 基本上就是一些 megatron / te 相关的 bug,例如: 23 | - moe 需要开启 `--moe-permute-fusion`。 24 | 25 | 2. 训练第二步 26 | 1. 对于训推一体,查看是否能正确加载第二步,是否会 OOM; 27 | 28 | ## 训练推理单独 debug 29 | 30 | slime 支持将训练部分和推理部分分开进行调试,从而实现: 31 | 32 | - 在调优/debug 推理部分时,只用少量卡就可以启动任务; 33 | - 在调优/debug 训练部分时,可以保证模型输入固定,去除 rollout 的随机性。 34 | 35 | 具体来说,目前 slime 提供了如下的参数来进行分离调试: 36 | 37 | 1. `--debug-rollout-only` 38 | 39 | 开启后,slime 将不会加载 megatron,只初始化 sglang ,可以用这个方法来进行推理部分的调试。 40 | 41 | 1. `--debug-train-only` 42 | 43 | 开启后,slime 将不会加载 sglang,只初始化 megatron ,可以用这个方法来进行训练部分的调试。 44 | 45 | 2. `--save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt` 46 | 47 | 开启后,会保存每次 rollout 的结果,可以和 `--debug-rollout-only` 配合使用。注意保存的方式为 `args.save_debug_rollout_data.format(rollout_id=rollout_id)`。 48 | 49 | 3. `--load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt` 50 | 51 | 开启后,会从 `args.load_debug_rollout_data.format(rollout_id=rollout_id)` 来加载数据,并且不会初始化 sglang(自动设置 `debug_train_only=True`)。可以以这种方式来固定训练部分的输入,对训练部分进行调优,例如切换各种并行。 52 | -------------------------------------------------------------------------------- /docs/zh/models/qwen3-30B-A3B.md: -------------------------------------------------------------------------------- 1 | # 示例:Qwen3-30B-A3B 模型 2 | 3 | [English](../../en/models/qwen3-30B-A3B.md) 4 | 5 | ## 环境准备 6 | 7 | 搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B 模型](./qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 Qwen3-30B-A3B 即可。 8 | 9 | ## 执行训练 10 | 11 | 执行训练: 12 | 13 | ```bash 14 | cd /root/slime 15 | bash scripts/run-qwen3-30B-A3B.sh 16 | ``` 17 | 18 | ### 参数简介 19 | 20 | 这里我们简单介绍一下脚本 [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) 中与 MoE 相关的部分。 21 | 22 | 1. 为了支持在 8xH800 环境中运行 Qwen3-30B-A3B,我们需要开启 megatron 的 CPU Adam 以节省显存,对应配置为: 23 | 24 | ```bash 25 | OPTIMIZER_ARGS=( 26 | ... 27 | --optimizer-cpu-offload 28 | --overlap-cpu-optimizer-d2h-h2d 29 | --use-precision-aware-optimizer 30 | ) 31 | ``` 32 | 33 | 2. 开启 megatron 支持的 moe 优化,当前配置为 tp4, ep8: 34 | 35 | ```bash 36 | PERF_ARGS=( 37 | --tensor-model-parallel-size 4 38 | --sequence-parallel 39 | --pipeline-model-parallel-size 1 40 | --context-parallel-size 1 41 | --expert-model-parallel-size 8 42 | --expert-tensor-parallel-size 1 43 | ... 44 | ) 45 | ``` 46 | 47 | 3. 开启 sglang 支持的 moe 优化,当前配置为 ep8: 48 | 49 | ```bash 50 | SGLANG_ARGS=( 51 | --rollout-num-gpus-per-engine 8 52 | --sglang-mem-fraction-static 0.5 53 | --sglang-enable-ep-moe 54 | --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) 55 | ) 56 | ``` 57 | 58 | 类似地,也可以加入 dp attention,例如配置上: 59 | 60 | ```bash 61 | --sglang-enable-dp-attention 62 | --sglang-dp-size 8 63 | ``` 64 | -------------------------------------------------------------------------------- /docs/zh/qa.md: -------------------------------------------------------------------------------- 1 | # slime 常见 Q&A 2 | 3 | [English](../en/qa.md) 4 | 5 | 1. **训练过程中为什么会出现乱码?** 6 | 7 | 一般来说这种情况是 megatron 没有被正确加载。请检查 `--load` 或 `--ref-load` 是否有对应的 ckpt。注意 megatron 只能加载其中有 `latest_checkpointed_iteration.txt` 的目录。 8 | 9 | 如果需要指定某个特定的 iter,可以查看当前 megatron 的使用方法,一般是可以通过 `--ckpt-step` 来指定步数。 10 | 11 | 1. **为什么我的任务一直卡在 ray 提交的页面上?** 12 | 13 | 请先检查你需要跑的任务是训推一体的,还是训推分离的。 14 | 15 | 如果是训推一体,即训练和推理共用 GPU,请检查 16 | 17 | - 是否设置了 `--colocate` 参数开启训推一体; 18 | - 当前任务的总卡数是否大于等于 `actor_num_nodes * actor_num_gpus_per_node` 19 | 20 | 如果是训推分离,请检查: 21 | 22 | - 当前任务的总卡数是否大于等于 `actor_num_nodes * actor_num_gpus_per_node + rollout_num_gpus` 23 | 24 | 1. 
**为什么训着训着 OOM 了?`max_tokens_per_gpu` 是干什么用的?** 25 | 26 | OOM 往往是因为 `max_tokens_per_gpu` 设置过高了。 `max_tokens_per_gpu` 是指在训练过程中,每张 GPU 上最多可以放多少 token。如果担心 OOM 的话,可以先把这个值设成 `rollout_max_response_len / cp_size`,之后再为了提升训练效率来增大这个值。`--max-tokens-per-gpu` 只有在开启 `--use-dynamic-batch-size` 的情况下才会启用。 27 | 28 | 如果 `max_tokens_per_gpu` 很小,还会 oom,可以检查一下是否单次生成的数据太长了,需要开启 cp(`--context-parallel-size`)。如果进行了自定义的数据生成,可以看一下是否在多轮生成的情况下,生成的总长度比预期的长很多。 29 | 30 | 1. **多机训练的时候,遇到了 transformers 库找不到某个模型的错误该怎么办?** 31 | 32 | 这种情况一般是因为多个进程都在通过类似于 `AutoConfig.from_pretrained` 或者 `AutoModelForCausalLM.from_pretrained` 的方式读取本地文件,出现了文件系统的写冲突。可以通过设置 `--model-name` 缓解这一问题。 33 | 34 | 1. **如何续训?** 35 | 36 | 直接将 `--load` 设置为 `--save` 的目录即可。 37 | 38 | 1. **batch size 是如何计算的?** 39 | 40 | 一个 rollout 会用 `rollout_batch_size` 条 prompt,每一条会采 `n_samples_per_prompt` 条,所以一个 rollout 共 `rollout_batch_size * n_samples_per_prompt` 条数据。 41 | 42 | 可以用 `--num-steps-per-rollout` 来决定每一个 rollout 跑多少步。这相当于是把 `global_batch_size` 设置成 `rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`。 43 | 44 | 1. **slime 是否进行了 data packing / varlen 处理?** 45 | 46 | data packing 是指在训练过程中,将长短不一的 sample 拼接到一起,从而提升训练的利用率。slime 默认会进行这样的操作。 47 | 48 | 1. **sglang 部分出现 `Max retries exceeded with url: /get_model_info (Caused by NewConnectionError` 的问题怎么办?** 49 | 50 | 这个问题主要来源于单机内多个 sglang server 导致的端口冲突,目前我们仍在和 sglang 团队一起解决这个问题。一个临时的缓解方案是尽可能减少单机内的 sglang server 数量,例如设置 tp=8。 51 | 52 | 1. **grad norm 好高,训练训崩了怎么办?** 53 | 54 | 首先请确保数据和模型是匹配的,例如说,如果数据是实现已经做好 chat template 的了,这个 chat template 是否和原模型一致。如果数据正确的话,可以参考 [debug 指南](./debug.md) 进行更深入的分析。 55 | 56 | 1. **我的 sglang 生成时间特别特别久,gpu 功率都打满了,跑了好久好没有输出是为什么?** 57 | 58 | 请确认一下 `--hf-checkpoint` 对应的模型是否正确设置了 stop token,如果没有,可以通过 `--rollout-stop` 或者 `--rollout-stop-token-ids` 来进行设置。 59 | 60 | 1. **sglang 出现 an illegal memory access was encountered** 61 | 62 | 根据 sglang 的文档(https://docs.sglang.ai/references/troubleshooting.html),有可能是 OOM 了,可以考虑缩小 `--sglang-mem-fraction-static`。 63 | 64 | 1. **出现 torch compile/inducer 的 `JSONDecodeError`** 65 | 66 | 一般是 torch compile 读写 cache 出现的问题。可以考虑在 ray 的 env_var 里加上 `"TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"`。 67 | 68 | 1. 
**训练出现 grad NaN 或者 Inf 的情况** 69 | 70 | 可以通过设置 `--no-check-for-nan-in-loss-and-grad` 来尝试跳过对应的训练步。 71 | -------------------------------------------------------------------------------- /docs/zh/sft.md: -------------------------------------------------------------------------------- 1 | # 示例:Qwen3-4B-Base + OpenHermes-2.5 2 | 3 | [English](../en/sft.md) 4 | 5 | ## 环境准备 6 | 7 | 首先需要我们仿照 [示例:Qwen3-4B 模型](./models/qwen3-4B.md) 创建镜像环境与转换 `Qwen3-4B-Base` 模型。 8 | 9 | 之后,我们处理 sft 数据。这里我们以经典的 [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) 为例,首先把数据处理成适合 slime 加载的格式,可以用如下的脚本进行处理,增加一个符合 openai message 格式的列,并保存在 `/root/openhermes2_5.parquet`。 10 | 11 | ```python 12 | from datasets import load_dataset 13 | 14 | ds = load_dataset("teknium/OpenHermes-2.5")["train"] 15 | 16 | def convert(sample): 17 | conversations = sample["conversations"] 18 | 19 | def convert_role(role): 20 | if role == "human": 21 | return "user" 22 | elif role == "gpt": 23 | return "assistant" 24 | elif role == "system": 25 | return "system" 26 | else: 27 | raise ValueError(f"Unknown role: {role}") 28 | 29 | messages = [ 30 | { 31 | "role": convert_role(turn["from"]), 32 | "content": turn["value"], 33 | } 34 | for turn in conversations 35 | ] 36 | 37 | return {"messages": messages} 38 | 39 | ds = ds.map(convert) 40 | ds.to_parquet("/root/openhermes2_5.parquet") 41 | ``` 42 | 43 | ## 执行训练 44 | 45 | 执行训练: 46 | 47 | ```bash 48 | cd /root/slime 49 | bash script/run-qwen3-4B-base-sft.sh 50 | ``` 51 | 52 | ### 参数简介 53 | 54 | 可以将 [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B-base-sft.sh) 与 [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh) 进行对比。会发现除了我们将模型由 instruct 模型换为了 base 模型之外,主要进行了如下的几个调整: 55 | 56 | 1. 移除了 `SGLANG_ARGS` 和 `GRPO_ARGS`。这是因为 sft 的过程中不需要启动 sglang 或者做 grpo 相关的配置; 57 | 58 | 2. 将 `ROLLOUT_ARGS` 改名为了 `SFT_ARGS`,并配置为: 59 | 60 | ```bash 61 | SFT_ARGS=( 62 | --rollout-function-path slime.rollout.sft_example.generate_rollout 63 | --prompt-data /root/openhermes2_5.parquet 64 | --input-key messages 65 | --rollout-shuffle 66 | --num-epoch 3 67 | --rollout-batch-size 128 68 | --global-batch-size 128 69 | 70 | --loss-type sft_loss 71 | --calculate-per-token-loss 72 | --disable-compute-advantages-and-returns 73 | --debug-train-only 74 | ) 75 | ``` 76 | 77 | slime 中的 sft 实际上是复用了 slime 的 custom rollout 功能,通过 `--rollout-function-path` 将数据生成部分从使用 sglang 的 RL rollout,切换成了从文件中读取数据的 sft 版本,即 `slime.rollout.sft_example.generate_rollout`。 78 | 79 | 对于 sft 来说,建议将 `rollout_batch_size` 与 `global_batch_size` 设置成相同的,并不要配置 `n_samples_per_prompt`,这样相当于是读一个 batch 就训一个 batch。 80 | 81 | slime 还支持不同的 loss 类型,我们就是通过 `--loss-type sft_loss` 配置上 sft loss 的。 82 | 83 | 至于 `--calculate-per-token-loss`,这是因为 slime 默认是以 GRPO 的 per sample mean 进行计算的,而一般 sft 训练都是按一个 batch 的所有不被 mask 的 token 取平均,所以建议配置上。 84 | 85 | 最后 `--disable-compute-advantages-and-returns` 表示 sft 的过程中不需要预先计算 log prob,`--debug-train-only` 表示不需要初始化 sglang。 86 | 87 | 3. 使用了 `train_async.py` 而不是 `train.py`。这是为了利用异步训练的流程,来实现数据 prefetch。 88 | -------------------------------------------------------------------------------- /examples/search-r1/README.md: -------------------------------------------------------------------------------- 1 | # Example: Search-R1 lite 2 | 3 | [中文版](./README_zh.md) 4 | 5 | This is a minimal reproduction of [Search-R1](https://github.com/PeterGriffinJin/Search-R1) and an example of using multi-turn conversation and tool-calling in slime. 
6 | 7 | ## Environment Setup 8 | 9 | Use the `zhuzilin/slime:latest` image and initialize the environment required for Search-R1: 10 | 11 | ```bash 12 | cd /root/ 13 | git clone https://github.com/THUDM/slime.git 14 | cd slime/ && pip install -e . 15 | # for Search R1 16 | pip install chardet 17 | ``` 18 | 19 | Please refer to the script provided in Search-R1 to download the data: 20 | 21 | ```bash 22 | git clone https://github.com/PeterGriffinJin/Search-R1.git 23 | cd Search-R1/ 24 | python scripts/data_process/nq_search.py --local_dir /root/nq_search/ 25 | ``` 26 | 27 | Initialize the Qwen2.5-3B model: 28 | 29 | ```bash 30 | # hf checkpoint 31 | huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B 32 | 33 | # mcore checkpoint 34 | cd /root/slime 35 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 36 | --hf-checkpoint /root/Qwen2.5-3B \ 37 | --save /root/Qwen2.5-3B_torch_dist 38 | ``` 39 | 40 | ## Running the Script 41 | 42 | You need to configure your serper.dev API key in `generate_with_search.py`: 43 | 44 | ```python 45 | SEARCH_R1_CONFIGS = { 46 | "max_turns": 3, 47 | "topk": 3, 48 | "google_api_key": "YOUR_API_KEY", # Replace with your actual API key 49 | "snippet_only": True, # Set to True to only return snippets 50 | "proxy": None, # Set to your proxy if needed 51 | "search_concurrency": 256, 52 | # rm 53 | "format_score": 0.2, 54 | } 55 | ``` 56 | 57 | And run: 58 | 59 | ```bash 60 | cd slime/ 61 | bash examples/search-r1/run_qwen2.5_3B.sh 62 | ``` 63 | 64 | ## Code Structure 65 | 66 | To implement multi-turn conversation + tool-calling in slime, you only need to implement a custom data generation function and a reward model for the task. These correspond to the following 2 configuration items in the startup script: 67 | 68 | ```bash 69 | CUSTOM_ARGS=( 70 | --custom-generate-function-path generate_with_search.generate 71 | --custom-rm-path generate_with_search.reward_func 72 | ) 73 | ``` 74 | 75 | These are the `generate` and `reward_func` functions in `generate_with_search.py`. 76 | -------------------------------------------------------------------------------- /examples/search-r1/README_zh.md: -------------------------------------------------------------------------------- 1 | # 示例:Search-R1 lite 2 | 3 | [English](./README.md) 4 | 5 | 这里是一个对 [Search-R1](https://github.com/PeterGriffinJin/Search-R1) 的简单复现,也是一个在 slime 中使用多轮对话和工具调用的样例。 6 | 7 | ## 配置环境 8 | 9 | 使用 `zhuzilin/slime:latest` 镜像,并初始化 Search-R1 需要的环境: 10 | 11 | ```bash 12 | cd /root/ 13 | git clone https://github.com/THUDM/slime.git 14 | cd slime/ && pip install -e .
15 | # for Search R1 16 | pip install chardet 17 | ``` 18 | 19 | 请参照 Search-R1 中提供的脚本下载数据: 20 | 21 | ```bash 22 | git clone https://github.com/PeterGriffinJin/Search-R1.git 23 | cd Search-R1/ 24 | python scripts/data_process/nq_search.py --local_dir /root/nq_search/ 25 | ``` 26 | 27 | 初始化 Qwen2.5-3B 模型: 28 | 29 | ```bash 30 | # hf checkpoint 31 | huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B 32 | 33 | # mcore checkpoint 34 | cd /root/slime 35 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 36 | --hf-checkpoint /root/Qwen2.5-3B \ 37 | --save /root/Qwen2.5-3B_torch_dist 38 | ``` 39 | 40 | ## 运行脚本 41 | 42 | 需要将你的 serper.dev API 配置在 `generate_with_search.py` 中: 43 | 44 | ```python 45 | SEARCH_R1_CONFIGS = { 46 | "max_turns": 3, 47 | "topk": 3, 48 | "google_api_key": "YOUR_API_KEY", # Replace with your actual API key 49 | "snippet_only": True, # Set to True to only return snippets 50 | "proxy": None, # Set to your proxy if needed 51 | "search_concurrency": 256, 52 | # rm 53 | "format_score": 0.2, 54 | } 55 | ``` 56 | 57 | 并运行: 58 | 59 | ```bash 60 | cd slime/ 61 | bash examples/search-r1/run_qwen2.5_3B.sh 62 | ``` 63 | 64 | ## 代码结构 65 | 66 | 为了实现多轮 + 工具调用,在 slime 中只需要实现一个自定义的数据生成函数,以及一个任务所需的 reward model,对应启动脚本中的这 2 个配置项: 67 | 68 | ```bash 69 | CUSTOM_ARGS=( 70 | --custom-generate-function-path generate_with_search.generate 71 | --custom-rm-path generate_with_search.reward_func 72 | ) 73 | ``` 74 | 75 | 也就是 `generate_with_search.py` 中的 `generate` 和 `reward_func` 两个函数。 76 | -------------------------------------------------------------------------------- /examples/search-r1/generate_with_search.py: -------------------------------------------------------------------------------- 1 | # Adapted form https://github.com/PeterGriffinJin/Search-R1/blob/ceee7b89655ed52f205b9beb98e1190c3eedcfb0/search_r1/llm_agent/generation.py 2 | import asyncio 3 | import re 4 | 5 | from google_search_server import google_search 6 | from qa_em_format import compute_score_em 7 | 8 | from slime.rollout.sglang_example import GenerateState 9 | from slime.utils.http_utils import post 10 | from slime.utils.types import Sample 11 | 12 | SEARCH_R1_CONFIGS = { 13 | "max_turns": 3, 14 | "topk": 3, 15 | "google_api_key": "YOUR_API_KEY", # Replace with your actual API key 16 | "snippet_only": True, # Set to True to only return snippets 17 | "proxy": None, # Set to your proxy if needed 18 | "search_concurrency": 256, 19 | # rm 20 | "format_score": 0.2, 21 | } 22 | 23 | 24 | SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"]) 25 | 26 | 27 | def _passages2string(retrieval_result): 28 | format_reference = "" 29 | for idx, doc_item in enumerate(retrieval_result): 30 | 31 | content = doc_item["document"]["contents"] 32 | title = content.split("\n")[0] 33 | text = "\n".join(content.split("\n")[1:]) 34 | format_reference += f"Doc {idx+1}(Title: {title}) {text}\n" 35 | 36 | return format_reference 37 | 38 | 39 | async def search(query: str) -> str: 40 | result = await google_search( 41 | SEARCH_R1_CONFIGS["google_api_key"], 42 | query, 43 | SEARCH_R1_CONFIGS["topk"], 44 | snippet_only=SEARCH_R1_CONFIGS["snippet_only"], 45 | proxy=SEARCH_R1_CONFIGS["proxy"], 46 | ) 47 | return _passages2string(result) 48 | 49 | 50 | def postprocess_responses(resp: str) -> str: 51 | return ( 52 | resp.split("</search>")[0] + "</search>" 53 | if "</search>" in resp 54 | else resp.split("</answer>")[0] + "</answer>" if "</answer>" in resp else resp 55 | ) 56 | 57 | 58 | def 
postprocess_predictions(prediction: str): 59 | pattern = r"<(search|answer)>(.*?)</\1>" 60 | match = re.search(pattern, prediction, re.DOTALL) 61 | if match: 62 | content = match.group(2).strip() # Return only the content inside the tags 63 | action = match.group(1) 64 | else: 65 | content = "" 66 | action = None 67 | 68 | return action, content 69 | 70 | 71 | async def execute_predictions(prediction: str) -> str: 72 | action, content = postprocess_predictions(prediction) 73 | 74 | if action == "search": 75 | search_query = content 76 | async with SEMAPHORE: 77 | search_results = await search(search_query) 78 | next_obs = f"\n\n<information>{search_results.strip()}</information>\n\n" 79 | done = False 80 | elif action == "answer": 81 | next_obs = "" 82 | done = True 83 | else: 84 | next_obs = f"\nMy previous action is invalid. \ 85 | If I want to search, I should put the query between <search> and </search>. \ 86 | If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n" 87 | done = False 88 | 89 | return next_obs, done 90 | 91 | 92 | async def generate(args, sample: Sample, sampling_params) -> Sample: 93 | assert not args.partial_rollout, f"Partial rollout is not supported for this function at the moment." 94 | 95 | state = GenerateState(args) 96 | 97 | url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" 98 | 99 | # Handle partial rollout samples: continue generation from existing response 100 | prompt = sample.prompt 101 | prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"] 102 | response = "" 103 | response_token_ids = [] 104 | loss_masks = [] 105 | for _ in range(SEARCH_R1_CONFIGS["max_turns"]): 106 | payload = { 107 | "text": prompt + response, 108 | "sampling_params": sampling_params, 109 | } 110 | output = await post(url, payload, use_http2=args.use_http2) 111 | 112 | # abort 113 | if output["meta_info"]["finish_reason"]["type"] == "abort": 114 | sample.status = Sample.Status.ABORTED 115 | return sample 116 | 117 | cur_response = output["text"] 118 | cur_response = postprocess_responses(cur_response) 119 | 120 | cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"] 121 | response += cur_response 122 | response_token_ids += cur_response_token_ids 123 | loss_masks += [1] * len(cur_response_token_ids) 124 | 125 | if output["meta_info"]["finish_reason"]["type"] == "length": 126 | break 127 | 128 | next_obs, done = await execute_predictions(cur_response) 129 | if done: 130 | break 131 | 132 | assert next_obs != "", "Next observation should not be empty." 133 | obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"] 134 | response += next_obs 135 | response_token_ids += obs_tokens_ids 136 | loss_masks += [0] * len(obs_tokens_ids) 137 | 138 | sample.tokens = prompt_tokens_ids + response_token_ids 139 | sample.response_length = len(response_token_ids) 140 | sample.response = response 141 | sample.loss_masks = loss_masks 142 | match output["meta_info"]["finish_reason"]["type"]: 143 | case "length": 144 | sample.status = Sample.Status.TRUNCATED 145 | case "abort": 146 | sample.status = Sample.Status.ABORTED 147 | case "stop": 148 | sample.status = Sample.Status.COMPLETED 149 | 150 | return sample 151 | 152 | 153 | async def reward_func(args, sample, **kwargs): 154 | """The reward function for retrieval-based question answering. 
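
    Scoring is delegated to `compute_score_em` from `qa_em_format.py`: the full
    prompt + response string is checked against `sample.label["ground_truth"]`
    by exact match, and `SEARCH_R1_CONFIGS["format_score"]` is passed through as
    the partial credit typically given when only the <answer> format is
    satisfied (see `qa_em_format.compute_score_em` for the exact rule).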
155 | 156 | Args: 157 | args: the arguments 158 | sample: the sample to evaluate 159 | """ 160 | if not isinstance(sample, Sample): 161 | raise TypeError("Sample must be an instance of Sample class.") 162 | 163 | score = compute_score_em( 164 | solution_str=sample.prompt + sample.response, 165 | ground_truth=sample.label["ground_truth"], 166 | format_score=SEARCH_R1_CONFIGS["format_score"], 167 | ) 168 | 169 | return score 170 | -------------------------------------------------------------------------------- /examples/search-r1/google_search_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import random 4 | import re 5 | from typing import Dict, List 6 | 7 | import aiohttp 8 | import chardet 9 | 10 | 11 | # --- Utilities --- 12 | def parse_snippet(snippet: str) -> List[str]: 13 | segments = snippet.split("...") 14 | return [s.strip() for s in segments if len(s.strip().split()) > 5] 15 | 16 | 17 | def sanitize_search_query(query: str) -> str: 18 | # Remove or replace special characters that might cause issues. 19 | # This is a basic example; you might need to add more characters or patterns. 20 | sanitized_query = re.sub(r"[^\w\s]", " ", query) # Replace non-alphanumeric and non-whitespace with spaces. 21 | sanitized_query = re.sub( 22 | r"[\t\r\f\v\n]", " ", sanitized_query 23 | ) # replace tab, return, formfeed, vertical tab with spaces. 24 | sanitized_query = re.sub( 25 | r"\s+", " ", sanitized_query 26 | ).strip() # remove duplicate spaces, and trailing/leading spaces. 27 | 28 | return sanitized_query 29 | 30 | 31 | def filter_links(search_results: List[Dict]) -> List[str]: 32 | links = [] 33 | for result in search_results: 34 | for item in result.get("items", []): 35 | if "mime" in item: 36 | continue 37 | ext = os.path.splitext(item["link"])[1] 38 | if ext in ["", ".html", ".htm", ".shtml"]: 39 | links.append(item["link"]) 40 | return links 41 | 42 | 43 | async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str: 44 | if url == "": 45 | return "" 46 | user_agents = [ 47 | "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...", 48 | "Mozilla/5.0 AppleWebKit/537.36...", 49 | "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)", 50 | ] 51 | headers = {"User-Agent": random.choice(user_agents)} 52 | 53 | async with semaphore: 54 | try: 55 | async with session.get(url, headers=headers) as response: 56 | raw = await response.read() 57 | detected = chardet.detect(raw) 58 | encoding = detected["encoding"] or "utf-8" 59 | return raw.decode(encoding, errors="ignore") 60 | except (aiohttp.ClientError, asyncio.TimeoutError): 61 | return "" 62 | 63 | 64 | async def fetch_all(urls: List[str], limit: int = 8) -> List[str]: 65 | semaphore = asyncio.Semaphore(limit) 66 | timeout = aiohttp.ClientTimeout(total=5) 67 | connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True) 68 | 69 | async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: 70 | tasks = [fetch(session, url, semaphore) for url in urls] 71 | return await asyncio.gather(*tasks) 72 | 73 | 74 | def collect_context(snippet: str, doc: str) -> str: 75 | snippets = parse_snippet(snippet) 76 | ctx_paras = [] 77 | 78 | for s in snippets: 79 | pos = doc.replace("\n", " ").find(s) 80 | if pos == -1: 81 | continue 82 | sta = pos 83 | while sta > 0 and doc[sta] != "\n": 84 | sta -= 1 85 | end = pos + len(s) 86 | while end < len(doc) and doc[end] != "\n": 87 | end += 
1 88 | para = doc[sta:end].strip() 89 | if para not in ctx_paras: 90 | ctx_paras.append(para) 91 | 92 | return "\n".join(ctx_paras) 93 | 94 | 95 | async def google_search(api_key, query, top_k=5, timeout: int = 60, proxy=None, snippet_only=False) -> List[Dict]: 96 | timeout_obj = aiohttp.ClientTimeout(total=timeout) 97 | session_kwargs = {} 98 | if proxy: 99 | session_kwargs["proxy"] = proxy 100 | async with aiohttp.ClientSession(**session_kwargs) as session: 101 | async with session.post( 102 | "https://google.serper.dev/search", 103 | json={ 104 | "q": query, 105 | "num": top_k, 106 | "gl": "us", 107 | "hl": "en", 108 | }, 109 | headers={ 110 | "Content-Type": "application/json", 111 | "X-API-KEY": api_key, 112 | }, 113 | timeout=timeout_obj, 114 | ) as resp: 115 | resp.raise_for_status() 116 | response = await resp.json() 117 | items = response.get("organic", []) 118 | 119 | contexts = [] 120 | if snippet_only: 121 | for item in items: 122 | title = item.get("title", "") 123 | context = " ".join(parse_snippet(item.get("snippet", ""))) 124 | if title != "" or context != "": 125 | title = "No title." if not title else title 126 | context = "No snippet available." if not context else context 127 | contexts.append( 128 | { 129 | "document": {"contents": f'"{title}"\n{context}'}, 130 | } 131 | ) 132 | else: 133 | links = [item.get("link", "") for item in items if "link" in item] 134 | web_contents = await fetch_all(links) 135 | contexts = [] 136 | for i, item in enumerate(items): 137 | title = item.get("title", "") 138 | snippet = item.get("snippet", "") 139 | 140 | context = collect_context(snippet, web_contents[i]) 141 | if title != "" or context != "": 142 | title = "No title." if not title else title 143 | context = "No snippet available." if not context else context 144 | contexts.append( 145 | { 146 | "document": {"contents": f'"{title}"\n{context}'}, 147 | } 148 | ) 149 | 150 | return contexts 151 | -------------------------------------------------------------------------------- /examples/search-r1/run_qwen2.5_3B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 19 | source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-3B.sh" 20 | 21 | CKPT_ARGS=( 22 | --hf-checkpoint /root/Qwen2.5-3B/ 23 | --ref-load /root/Qwen2.5-3B_torch_dist/ 24 | --load /root/Qwen2.5-3B_slime/ 25 | --save /root/Qwen2.5-3B_slime/ 26 | --save-interval 20 27 | ) 28 | 29 | ROLLOUT_ARGS=( 30 | --prompt-data /root/nq_search/train.parquet 31 | --input-key prompt 32 | --label-key reward_model 33 | --apply-chat-template 34 | --rollout-shuffle 35 | --num-rollout 3000 36 | --rollout-batch-size 32 37 | --n-samples-per-prompt 8 38 | --rollout-max-response-len 512 39 | --rollout-temperature 0.8 40 | 41 | --global-batch-size 256 42 | --balance-data 43 | ) 44 | 45 | PERF_ARGS=( 46 | --tensor-model-parallel-size 2 47 | --sequence-parallel 48 | --pipeline-model-parallel-size 1 49 | --context-parallel-size 1 50 | --expert-model-parallel-size 1 51 | --expert-tensor-parallel-size 1 52 | 53 | --recompute-granularity full 54 | --recompute-method uniform 55 | --recompute-num-layers 1 56 | 57 | # --micro-batch-size 1 58 | 
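    # with --use-dynamic-batch-size, micro batches are packed by token count up to
    # --max-tokens-per-gpu instead of using a fixed --micro-batch-size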
--use-dynamic-batch-size 59 | --max-tokens-per-gpu 9216 60 | ) 61 | 62 | GRPO_ARGS=( 63 | --advantage-estimator grpo 64 | --use-kl-loss 65 | --kl-loss-coef 0.00 66 | --kl-loss-type low_var_kl 67 | --kl-coef 0.00 68 | --entropy-coef 0.00 69 | --eps-clip 0.2 70 | --eps-clip-high 0.28 71 | ) 72 | 73 | OPTIMIZER_ARGS=( 74 | --optimizer adam 75 | --lr 1e-6 76 | --lr-decay-style constant 77 | --weight-decay 0.1 78 | --adam-beta1 0.9 79 | --adam-beta2 0.98 80 | ) 81 | 82 | WANDB_ARGS=( 83 | # --use-wandb 84 | # --wandb-project slime-dev 85 | # --wandb-group search-r1_qwen2.5-3B-test 86 | # --wandb-key ${WANDB_KEY} 87 | ) 88 | 89 | SGLANG_ARGS=( 90 | --rollout-num-gpus-per-engine 2 91 | --sglang-mem-fraction-static 0.7 92 | ) 93 | 94 | MISC_ARGS=( 95 | # default dropout in megatron is 0.1 96 | --attention-dropout 0.0 97 | --hidden-dropout 0.0 98 | # should be good for model performance 99 | --accumulate-allreduce-grads-in-fp32 100 | --attention-softmax-in-fp32 101 | # need to comment this when using model with MLA 102 | --attention-backend flash 103 | ) 104 | 105 | CUSTOM_ARGS=( 106 | --custom-generate-function-path generate_with_search.generate 107 | --custom-rm-path generate_with_search.reward_func 108 | ) 109 | 110 | # launch the master node of ray in container 111 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 112 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 113 | 114 | RUNTIME_ENV_JSON="{ 115 | \"env_vars\": { 116 | \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", 117 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" 118 | } 119 | }" 120 | 121 | ray job submit --address="http://127.0.0.1:8265" \ 122 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 123 | -- python3 train.py \ 124 | --actor-num-nodes 1 \ 125 | --actor-num-gpus-per-node 4 \ 126 | --rollout-num-gpus 4 \ 127 | --colocate \ 128 | ${MODEL_ARGS[@]} \ 129 | ${CKPT_ARGS[@]} \ 130 | ${ROLLOUT_ARGS[@]} \ 131 | ${OPTIMIZER_ARGS[@]} \ 132 | ${GRPO_ARGS[@]} \ 133 | ${DISTRIBUTED_ARGS[@]} \ 134 | ${WANDB_ARGS[@]} \ 135 | ${PERF_ARGS[@]} \ 136 | ${SGLANG_ARGS[@]} \ 137 | ${MISC_ARGS[@]} \ 138 | ${CUSTOM_ARGS[@]} 139 | -------------------------------------------------------------------------------- /imgs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/imgs/arch.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "packaging", 4 | "setuptools >= 49.4.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.isort] 10 | profile = "black" # black-compatible 11 | line_length = 119 # should match black parameters 12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style 13 | py_version = 310 # python 3.10 as a target version 14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 15 | default_section = "THIRDPARTY" 16 | extend_skip = ["setup.py", "docs/source/conf.py"] 17 | 18 | 19 | [tool.black] 20 | line_length = 119 21 | 22 | [tool.ruff] 23 | line-length = 119 24 | 25 | [tool.pytest.ini_options] 26 | # durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. 
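# e.g. `--durations=10` would limit the report to the 10 slowest tests; 0 means "report all"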
27 | # -vv will also display tests with durration = 0.00s 28 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest 29 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module 30 | # directories to ignore when discovering tests 31 | norecursedirs = [ 32 | "external", 33 | "examples", 34 | "docs", 35 | "scripts", 36 | "tools", 37 | "tutorials", 38 | "*.egg", 39 | ".*", 40 | "_darcs", 41 | "build", 42 | "CVS", 43 | "dist", 44 | "venv", 45 | "{arch}", 46 | ] 47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m "<marker>"` to select tests 48 | markers = [ 49 | "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')", 50 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')", 51 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')", 52 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')", 53 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')", 54 | "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups", 55 | "pleasefixme: marks tests that are broken and need fixing", 56 | ] 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | httpx[http2] 3 | mcp[cli] 4 | pylatexenc 5 | ray[default] 6 | sglang 7 | torch 8 | wandb 9 | -------------------------------------------------------------------------------- /scripts/agent-example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | export PYTHONBUFFERED=16 16 | 17 | export TP_SIZE=2 18 | export PP_SIZE=1 19 | export CP_SIZE=1 20 | 21 | export HF_MODEL_PATH=/root/hf_models/deepseek-ai--DeepSeek-R1-Distill-Qwen-7B 22 | export MCORE_MODEL_PATH=/root/megatron_model/DeepSeek-R1-Distill-Qwen-7B-25.02 23 | export PROMPT_DATA=/root/dapo-math-17k/dapo-math-17k_processed.jsonl 24 | export MCORE_MODEL_PATH_SAVE=/root/megatron_model/DeepSeek-R1-Distill-Qwen-7B-25.02_save 25 | 26 | # DeepSeek-R1-Distill-Qwen-7B 27 | MODEL_ARGS=( 28 | --swiglu 29 | --num-layers 28 30 | --hidden-size 3584 31 | --ffn-hidden-size 18944 32 | --num-attention-heads 28 33 | --group-query-attention 34 | --num-query-groups 4 35 | --max-position-embeddings 131072 36 | --seq-length 4096 37 | --use-rotary-position-embeddings 38 | --disable-bias-linear 39 | --add-qkv-bias 40 | --normalization "RMSNorm" 41 | --norm-epsilon 1e-06 42 | --rotary-base 10000 43 | --vocab-size 152064 44 | --accumulate-allreduce-grads-in-fp32 45 | --attention-softmax-in-fp32 46 | --attention-backend flash 47 | --moe-token-dispatcher-type alltoall 48 | --untie-embeddings-and-output-weights 49 | --attention-dropout 0.0 50 | --hidden-dropout 0.0 51 | ) 52 | 53 | CKPT_ARGS=( 54 | --hf-checkpoint ${HF_MODEL_PATH} 55 | --ref-load ${MCORE_MODEL_PATH} 56 | --save-interval 100 57 | --save ${MCORE_MODEL_PATH_SAVE} 58 | ) 59 | 60 | ROLLOUT_ARGS=( 61 | 
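    # override the default sglang rollout with the agent rollout, which (per the
    # --agent-rollout-buffer-url set further down) pulls trajectories from the
    # external rollout buffer service instead of generating them in-process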
--rollout-function-path slime.rollout.agent_rollout.generate_rollout 62 | --rm-type deepscaler 63 | --prompt-data ${PROMPT_DATA} 64 | --label-key label 65 | --num-rollout 3000 66 | --rollout-batch-size 128 67 | --rollout-max-response-len 8192 68 | --rollout-temperature 0.8 69 | --rollout-shuffle 70 | --n-samples-per-prompt 8 71 | --global-batch-size 1024 72 | --micro-batch-size 8 73 | --ref-micro-batch-size 8 74 | --use-dynamic-batch-size 75 | --max-tokens-per-gpu 9216 76 | --balance-data 77 | ) 78 | 79 | DISTRIBUTED_ARGS=( 80 | --tensor-model-parallel-size ${TP_SIZE} 81 | --pipeline-model-parallel-size ${PP_SIZE} 82 | --context-parallel-size ${CP_SIZE} 83 | --sequence-parallel 84 | ) 85 | 86 | PERF_ARGS=( 87 | --recompute-granularity full 88 | --recompute-method uniform 89 | --recompute-num-layers 1 90 | ) 91 | 92 | GRPO_ARGS=( 93 | --advantage-estimator grpo 94 | --use-kl-loss 95 | --kl-loss-coef 0.001 96 | --kl-loss-type low_var_kl 97 | --kl-coef 0.00 98 | --entropy-coef 0.00 99 | ) 100 | 101 | OPTIMIZER_ARGS=( 102 | --lr 1e-6 103 | --lr-decay-style constant 104 | --weight-decay 0.1 105 | --adam-beta1 0.9 106 | --adam-beta2 0.98 107 | ) 108 | 109 | WANDB_ARGS=( 110 | # --use-wandb \ 111 | ) 112 | 113 | # launch the master node of ray in container 114 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 115 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 116 | 117 | ray job submit --address="http://127.0.0.1:8265" \ 118 | --runtime-env-json='{ 119 | "env_vars": { 120 | "PYTHONPATH": "/root/Megatron-LM/", 121 | "CUDA_DEVICE_MAX_CONNECTIONS": "1", 122 | "NCCL_CUMEM_ENABLE": "0" 123 | } 124 | }' \ 125 | -- python3 train_async.py \ 126 | --actor-num-nodes 1 \ 127 | --actor-num-gpus-per-node 4 \ 128 | --rollout-num-gpus 4 \ 129 | --rollout-num-gpus-per-engine 1 \ 130 | --sglang-mem-fraction-static 0.8 \ 131 | ${MODEL_ARGS[@]} \ 132 | ${CKPT_ARGS[@]} \ 133 | ${ROLLOUT_ARGS[@]} \ 134 | ${OPTIMIZER_ARGS[@]} \ 135 | ${GRPO_ARGS[@]} \ 136 | ${DISTRIBUTED_ARGS[@]} \ 137 | ${WANDB_ARGS[@]} \ 138 | ${PERF_ARGS[@]} \ 139 | --agent-rollout-buffer-url http://${MASTER_ADDR}:8889 \ 140 | --keep-old-actor \ 141 | --disable-rewards-normalization \ 142 | --offload-old-actor \ 143 | --offload-ref \ 144 | --loss-mask-type distill_qwen \ 145 | --sglang-log-level error \ 146 | --input-key prompt \ 147 | --log-passrate 148 | -------------------------------------------------------------------------------- /scripts/models/glm4-32B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --spec "slime_plugins.models.glm4" "get_glm_spec" 3 | --swiglu 4 | --num-layers 64 5 | --hidden-size 6144 6 | --ffn-hidden-size 23040 7 | --num-attention-heads 48 8 | --max-position-embeddings 32768 9 | --seq-length 32768 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 10000 15 | --group-query-attention 16 | --num-query-groups 8 17 | --vocab-size 151552 18 | --post-self-attn-layernorm 19 | --post-mlp-layernorm 20 | --rotary-interleaved 21 | --rotary-percent 0.5 22 | --no-rope-fusion 23 | --untie-embeddings-and-output-weights 24 | ) -------------------------------------------------------------------------------- /scripts/models/glm4-9B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --spec "slime_plugins.models.glm4" "get_glm_spec" 3 | --swiglu 4 | --num-layers 40 5 | --hidden-size 4096 6 | 
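    # dimensions below should match the HF config of the checkpoint this file is
    # paired with (run-glm4-9B.sh uses it with GLM-Z1-9B-0414)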
--ffn-hidden-size 13696 7 | --num-attention-heads 32 8 | --group-query-attention 9 | --num-query-groups 2 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --add-qkv-bias 13 | --normalization "RMSNorm" 14 | --norm-epsilon 1e-5 15 | --rotary-base 10000 16 | --vocab-size 151552 17 | --post-self-attn-layernorm 18 | --post-mlp-layernorm 19 | --rotary-interleaved 20 | --rotary-percent 0.5 21 | --no-rope-fusion 22 | --untie-embeddings-and-output-weights 23 | ) -------------------------------------------------------------------------------- /scripts/models/moonlight.sh: -------------------------------------------------------------------------------- 1 | MOE_SHARED_EXPERTS=2 2 | MOE_FFN_HIDDEN=1408 3 | MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS)) 4 | MOE_ROUTER_TOPK_SCALING_FACTOR=2.446 5 | NLAYERS=27 6 | FIRST_K_DENSE_REPLACE=1 7 | 8 | arr=() 9 | for ((i=0; i<NLAYERS; i++)); do 10 | if (( i < FIRST_K_DENSE_REPLACE )); then 11 | arr+=(0) 12 | else 13 | arr+=(1) 14 | fi 15 | done 16 | 17 | printf -v MOE_LAYER_FREQ "[%s]" "$(IFS=', '; echo "${arr[*]}")" 18 | 19 | # moonlight 20 | MODEL_ARGS=( 21 | --disable-bias-linear 22 | --num-layers 27 23 | --hidden-size 2048 24 | --ffn-hidden-size 11264 25 | --num-attention-heads 16 26 | --kv-channels 128 27 | --normalization RMSNorm 28 | --position-embedding-type rope 29 | --norm-epsilon 1e-5 30 | --rotary-percent 1.0 31 | --swiglu 32 | --untie-embeddings-and-output-weights 33 | --no-masked-softmax-fusion 34 | --vocab-size 163840 35 | 36 | --multi-latent-attention 37 | --kv-lora-rank 512 38 | --qk-head-dim 128 39 | --qk-pos-emb-head-dim 64 40 | --v-head-dim 128 41 | --qk-layernorm 42 | --rotary-scaling-factor 1 43 | --rotary-base 50000 44 | --mscale 1.0 45 | --mscale-all-dim 1.0 46 | --attention-softmax-in-fp32 47 | --no-rope-fusion 48 | 49 | # moe 50 | --num-experts 64 51 | --moe-layer-freq $MOE_LAYER_FREQ 52 | --moe-ffn-hidden-size $MOE_FFN_HIDDEN 53 | --moe-router-topk 6 54 | --moe-shared-expert-intermediate-size $MOE_SHARED_EXPERT_INTERMEDIATE_SIZE 55 | --moe-router-pre-softmax 56 | --moe-router-score-function sigmoid 57 | --moe-router-enable-expert-bias 58 | --moe-router-load-balancing-type seq_aux_loss 59 | --moe-token-dispatcher-type alltoall 60 | --moe-aux-loss-coeff 0 61 | --moe-router-bias-update-rate 0 62 | --moe-router-group-topk 1 63 | --moe-router-num-groups 1 64 | --moe-grouped-gemm 65 | --moe-router-topk-scaling-factor $MOE_ROUTER_TOPK_SCALING_FACTOR 66 | --moe-token-drop-policy probs 67 | --moe-router-dtype fp32 68 | --moe-permute-fusion 69 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-1.5B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 1536 5 | --ffn-hidden-size 8960 6 | --num-attention-heads 12 7 | --use-rotary-position-embeddings 8 | --disable-bias-linear 9 | --add-qkv-bias 10 | --normalization "RMSNorm" 11 | --norm-epsilon 1e-6 12 | --rotary-base 10000 13 | --group-query-attention 14 | --num-query-groups 2 15 | --vocab-size 151936 16 | --untie-embeddings-and-output-weights 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-32B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 64 4 | --hidden-size 5120 5 | --ffn-hidden-size 27648 6 | --num-attention-heads 40 7 | 
--group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --add-qkv-bias 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 1000000 15 | --vocab-size 152064 16 | --untie-embeddings-and-output-weights 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-3B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 2048 5 | --ffn-hidden-size 11008 6 | --num-attention-heads 16 7 | --use-rotary-position-embeddings 8 | --disable-bias-linear 9 | --add-qkv-bias 10 | --normalization "RMSNorm" 11 | --norm-epsilon 1e-6 12 | --rotary-base 10000 13 | --group-query-attention 14 | --num-query-groups 2 15 | --vocab-size 151936 16 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-7B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 3584 5 | --ffn-hidden-size 18944 6 | --num-attention-heads 28 7 | --group-query-attention 8 | --num-query-groups 4 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --add-qkv-bias 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-06 14 | --rotary-base 1000000 15 | --vocab-size 152064 16 | --untie-embeddings-and-output-weights 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-0.6B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 1024 5 | --ffn-hidden-size 3072 6 | --num-attention-heads 16 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-235B-A22B.sh: -------------------------------------------------------------------------------- 1 | # qwen3-235B-a22B 2 | MODEL_ARGS=( 3 | --disable-bias-linear 4 | --qk-layernorm 5 | --group-query-attention 6 | --num-attention-heads 64 7 | --num-query-groups 4 8 | --kv-channels 128 9 | --num-layers 94 10 | --hidden-size 4096 11 | --ffn-hidden-size 12288 12 | 13 | --normalization RMSNorm 14 | --position-embedding-type rope 15 | --norm-epsilon 1e-6 16 | --rotary-percent 1.0 17 | --swiglu 18 | --untie-embeddings-and-output-weights 19 | --vocab-size 151936 20 | 21 | --rotary-base 1000000 22 | 23 | # moe 24 | --moe-ffn-hidden-size 1536 25 | --moe-router-score-function softmax 26 | --moe-token-dispatcher-type alltoall 27 | --moe-router-topk 8 28 | --moe-layer-freq "'([1]*94)'" 29 | --num-experts 128 30 | --moe-grouped-gemm 31 | --moe-token-drop-policy probs 32 | --moe-router-dtype fp32 33 | --moe-permute-fusion 34 | --moe-aux-loss-coeff 0 35 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-30B-A3B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --disable-bias-linear 3 | --qk-layernorm 4 | --group-query-attention 5 | --num-attention-heads 32 6 | --num-query-groups 4 7 | --kv-channels 128 8 | --num-layers 48 9 | --hidden-size 2048 10 | 
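    # --ffn-hidden-size is the dense-MLP width Megatron still requires; the
    # per-expert width for this MoE model is set by --moe-ffn-hidden-size 768 below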
--ffn-hidden-size 6144 11 | 12 | --normalization RMSNorm 13 | --position-embedding-type rope 14 | --norm-epsilon 1e-6 15 | --rotary-percent 1.0 16 | --swiglu 17 | --untie-embeddings-and-output-weights 18 | --vocab-size 151936 19 | 20 | --rotary-base 1000000 21 | 22 | # moe 23 | --moe-ffn-hidden-size 768 24 | --moe-router-score-function softmax 25 | --moe-token-dispatcher-type alltoall 26 | --moe-router-topk 8 27 | --moe-layer-freq "'([1]*48)'" 28 | --num-experts 128 29 | --moe-grouped-gemm 30 | --moe-token-drop-policy probs 31 | --moe-router-dtype fp32 32 | --moe-permute-fusion 33 | --moe-aux-loss-coeff 0 34 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-4B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 2560 5 | --ffn-hidden-size 9728 6 | --num-attention-heads 32 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-8B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 4096 5 | --ffn-hidden-size 12288 6 | --num-attention-heads 32 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | --untie-embeddings-and-output-weights 18 | ) -------------------------------------------------------------------------------- /scripts/run-glm4-9B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then 20 | HAS_NVLINK=1 21 | else 22 | HAS_NVLINK=0 23 | fi 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 25 | 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 27 | source "${SCRIPT_DIR}/models/glm4-9B.sh" 28 | 29 | CKPT_ARGS=( 30 | --hf-checkpoint /root/GLM-Z1-9B-0414/ 31 | --ref-load /root/GLM-Z1-9B-0414_torch_dist 32 | --load /root/GLM-Z1-9B-0414_slime/ 33 | --save /root/GLM-Z1-9B-0414_slime/ 34 | --save-interval 20 35 | ) 36 | 37 | ROLLOUT_ARGS=( 38 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 39 | --input-key prompt 40 | --label-key label 41 | --apply-chat-template 42 | --rollout-shuffle 43 | 44 | --rm-type deepscaler 45 | 46 | --num-rollout 3000 47 | --rollout-batch-size 32 48 | --n-samples-per-prompt 8 49 | --rollout-max-response-len 8192 50 | --rollout-temperature 0.8 51 | 52 | --global-batch-size 256 53 | --balance-data 54 | ) 55 | 56 | EVAL_ARGS=( 57 | --eval-interval 20 58 | --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 59 | --n-samples-per-eval-prompt 16 60 | --eval-max-response-len 16384 61 | --eval-top-p 
0.7 62 | ) 63 | 64 | PERF_ARGS=( 65 | --tensor-model-parallel-size 2 66 | --sequence-parallel 67 | --pipeline-model-parallel-size 1 68 | --context-parallel-size 2 69 | --expert-model-parallel-size 1 70 | --expert-tensor-parallel-size 1 71 | 72 | --recompute-granularity full 73 | --recompute-method uniform 74 | --recompute-num-layers 1 75 | 76 | # --micro-batch-size 1 77 | --use-dynamic-batch-size 78 | --max-tokens-per-gpu 4608 79 | ) 80 | 81 | GRPO_ARGS=( 82 | --advantage-estimator grpo 83 | --use-kl-loss 84 | --kl-loss-coef 0.00 85 | --kl-loss-type low_var_kl 86 | --kl-coef 0.00 87 | --entropy-coef 0.00 88 | --eps-clip 0.2 89 | --eps-clip-high 0.28 90 | ) 91 | 92 | OPTIMIZER_ARGS=( 93 | --optimizer adam 94 | --lr 1e-6 95 | --lr-decay-style constant 96 | --weight-decay 0.1 97 | --adam-beta1 0.9 98 | --adam-beta2 0.98 99 | ) 100 | 101 | WANDB_ARGS=( 102 | #--use-wandb 103 | # --wandb-project slime-dev 104 | # --wandb-group qwen3-4B-test 105 | # --wandb-key ${WANDB_KEY} 106 | ) 107 | 108 | SGLANG_ARGS=( 109 | --rollout-num-gpus-per-engine 2 110 | ) 111 | 112 | MISC_ARGS=( 113 | # default dropout in megatron is 0.1 114 | --attention-dropout 0.0 115 | --hidden-dropout 0.0 116 | # should be good for model performance 117 | --accumulate-allreduce-grads-in-fp32 118 | --attention-softmax-in-fp32 119 | # need to comment this when using model with MLA 120 | --attention-backend flash 121 | ) 122 | 123 | # launch the master node of ray in container 124 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 125 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 126 | 127 | # Build the runtime environment JSON with proper variable substitution 128 | RUNTIME_ENV_JSON="{ 129 | \"env_vars\": { 130 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 131 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 132 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 133 | } 134 | }" 135 | 136 | ray job submit --address="http://127.0.0.1:8265" \ 137 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 138 | -- python3 train.py \ 139 | --actor-num-nodes 1 \ 140 | --actor-num-gpus-per-node 4 \ 141 | --rollout-num-gpus 4 \ 142 | ${MODEL_ARGS[@]} \ 143 | ${CKPT_ARGS[@]} \ 144 | ${ROLLOUT_ARGS[@]} \ 145 | ${OPTIMIZER_ARGS[@]} \ 146 | ${GRPO_ARGS[@]} \ 147 | ${DISTRIBUTED_ARGS[@]} \ 148 | ${WANDB_ARGS[@]} \ 149 | ${PERF_ARGS[@]} \ 150 | ${EVAL_ARGS[@]} \ 151 | ${SGLANG_ARGS[@]} \ 152 | ${MISC_ARGS[@]} 153 | -------------------------------------------------------------------------------- /scripts/run-qwen3-235B-A22B-sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # if base folder not set raise error 16 | if [ -z "${BASE_FOLDER}" ]; then 17 | echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints." 18 | exit 1 19 | fi 20 | 21 | if [ -z "${MASTER_ADDR}" ]; then 22 | echo "MASTER_ADDR is not set. Please set it to the master node address." 
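    # e.g. `export MASTER_ADDR=<head-node-ip>` (placeholder) before launching this script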
23 | exit 1 24 | fi 25 | 26 | # will prevent ray from buffering stdout/stderr 27 | export PYTHONBUFFERED=16 28 | 29 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 30 | if [ "$NVLINK_COUNT" -gt 0 ]; then 31 | HAS_NVLINK=1 32 | else 33 | HAS_NVLINK=0 34 | fi 35 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 36 | 37 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 38 | source "${SCRIPT_DIR}/models/qwen3-235B-A22B.sh" 39 | 40 | CKPT_ARGS=( 41 | --hf-checkpoint ${BASE_FOLDER}/Qwen3-235B-A22B 42 | --ref-load ${BASE_FOLDER}/Qwen3-235B-A22B_torch_dist 43 | --load ${BASE_FOLDER}/Qwen3-235B-A22B_slime/ 44 | --save ${BASE_FOLDER}/Qwen3-235B-A22B_slime/ 45 | --save-interval 1000 46 | ) 47 | 48 | SFT_ARGS=( 49 | --rollout-function-path slime.rollout.sft_example.generate_rollout 50 | --prompt-data ${BASE_FOLDER}/openhermes2_5.parquet 51 | --input-key messages 52 | --rollout-shuffle 53 | --num-epoch 3 54 | --rollout-batch-size 128 55 | --global-batch-size 128 56 | 57 | --loss-type sft_loss 58 | --calculate-per-token-loss 59 | --disable-compute-advantages-and-returns 60 | --debug-train-only 61 | ) 62 | 63 | PERF_ARGS=( 64 | --tensor-model-parallel-size 4 65 | --sequence-parallel 66 | --pipeline-model-parallel-size 1 67 | --context-parallel-size 1 68 | --expert-model-parallel-size 32 69 | --expert-tensor-parallel-size 1 70 | 71 | --recompute-granularity full 72 | --recompute-method uniform 73 | --recompute-num-layers 1 74 | 75 | # --micro-batch-size 1 76 | --use-dynamic-batch-size 77 | --max-tokens-per-gpu 9216 78 | ) 79 | 80 | OPTIMIZER_ARGS=( 81 | --optimizer adam 82 | --lr 1e-5 83 | --lr-warmup-iters 128 84 | --lr-decay-style cosine 85 | --min-lr 1e-6 86 | --lr-warmup-fraction 0.9 87 | --weight-decay 0.1 88 | --adam-beta1 0.9 89 | --adam-beta2 0.98 90 | 91 | --optimizer-cpu-offload 92 | --overlap-cpu-optimizer-d2h-h2d 93 | --use-precision-aware-optimizer 94 | ) 95 | 96 | WANDB_ARGS=( 97 | # --use-wandb 98 | # --wandb-project slime-dev 99 | # --wandb-group qwen3-235B-sft 100 | ) 101 | 102 | MISC_ARGS=( 103 | # default dropout in megatron is 0.1 104 | --attention-dropout 0.0 105 | --hidden-dropout 0.0 106 | # should be good for model performance 107 | --accumulate-allreduce-grads-in-fp32 108 | --attention-softmax-in-fp32 109 | # need to comment this when using model with MLA 110 | --attention-backend flash 111 | ) 112 | 113 | # launch the master node of ray in container 114 | export no_proxy="127.0.0.1,${MASTER_ADDR}" 115 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 116 | for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do 117 | if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then 118 | continue 119 | fi 120 | echo "Starting Ray worker on ${WORKER_IP}" 121 | ssh root@"${WORKER_IP}" \ 122 | "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats" & 123 | done 124 | wait 125 | 126 | 127 | # Build the runtime environment JSON with proper variable substitution 128 | RUNTIME_ENV_JSON="{ 129 | \"env_vars\": { 130 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 131 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 132 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 133 | \"no_proxy\": \"${no_proxy}\", 134 | \"MASTER_ADDR\": \"${MASTER_ADDR}\" 135 | } 136 | }" 137 | 138 | ray job submit --address="http://127.0.0.1:8265" \ 139 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 140 | -- python3 train_async.py \ 141 | 
--actor-num-nodes 4 \ 142 | --actor-num-gpus-per-node 8 \ 143 | ${MODEL_ARGS[@]} \ 144 | ${CKPT_ARGS[@]} \ 145 | ${SFT_ARGS[@]} \ 146 | ${OPTIMIZER_ARGS[@]} \ 147 | ${DISTRIBUTED_ARGS[@]} \ 148 | ${WANDB_ARGS[@]} \ 149 | ${PERF_ARGS[@]} \ 150 | ${EVAL_ARGS[@]} \ 151 | ${MISC_ARGS[@]} 152 | -------------------------------------------------------------------------------- /scripts/run-qwen3-235B-A22B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # if base folder not set raise error 16 | if [ -z "${BASE_FOLDER}" ]; then 17 | echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints." 18 | exit 1 19 | fi 20 | 21 | if [ -z "${MASTER_ADDR}" ]; then 22 | echo "MASTER_ADDR is not set. Please set it to the master node address." 23 | exit 1 24 | fi 25 | 26 | # will prevent ray from buffering stdout/stderr 27 | export PYTHONBUFFERED=16 28 | 29 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 30 | if [ "$NVLINK_COUNT" -gt 0 ]; then 31 | HAS_NVLINK=1 32 | else 33 | HAS_NVLINK=0 34 | fi 35 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 36 | 37 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 38 | source "${SCRIPT_DIR}/models/qwen3-235B-A22B.sh" 39 | 40 | CKPT_ARGS=( 41 | --hf-checkpoint ${BASE_FOLDER}/Qwen3-235B-A22B 42 | #--hf-checkpoint ${BASE_FOLDER}/Qwen3-235B-A22B-FP8 43 | --ref-load ${BASE_FOLDER}/Qwen3-235B-A22B_torch_dist 44 | --load ${BASE_FOLDER}/Qwen3-235B-A22B_slime/ 45 | --save ${BASE_FOLDER}/Qwen3-235B-A22B_slime/ 46 | --save-interval 20 47 | ) 48 | 49 | ROLLOUT_ARGS=( 50 | --prompt-data ${BASE_FOLDER}/dapo-math-17k/dapo-math-17k.jsonl 51 | --input-key prompt 52 | --label-key label 53 | --apply-chat-template 54 | --rollout-shuffle 55 | 56 | --rm-type deepscaler 57 | 58 | --num-rollout 3000 59 | --rollout-batch-size 8 60 | --n-samples-per-prompt 8 61 | --rollout-max-response-len 8192 62 | --rollout-temperature 0.8 63 | 64 | --global-batch-size 64 65 | --balance-data 66 | ) 67 | 68 | EVAL_ARGS=( 69 | #--eval-interval 20 70 | --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl 71 | --n-samples-per-eval-prompt 16 72 | --eval-max-response-len 16384 73 | --eval-top-p 0.7 74 | ) 75 | 76 | PERF_ARGS=( 77 | --tensor-model-parallel-size 4 78 | --sequence-parallel 79 | --pipeline-model-parallel-size 4 80 | --context-parallel-size 2 81 | --expert-model-parallel-size 8 82 | --expert-tensor-parallel-size 1 83 | --decoder-last-pipeline-num-layers 22 84 | 85 | --recompute-granularity full 86 | --recompute-method uniform 87 | --recompute-num-layers 1 88 | 89 | # --micro-batch-size 1 90 | --use-dynamic-batch-size 91 | --max-tokens-per-gpu 4096 92 | ) 93 | 94 | GRPO_ARGS=( 95 | --advantage-estimator grpo 96 | #--use-kl-loss 97 | --kl-loss-coef 0.00 98 | --kl-loss-type low_var_kl 99 | --kl-coef 0.00 100 | --entropy-coef 0.00 101 | --eps-clip 0.2 102 | --eps-clip-high 0.28 103 | ) 104 | 105 | OPTIMIZER_ARGS=( 106 | --optimizer adam 107 | --lr 1e-6 108 | --lr-decay-style constant 109 | --weight-decay 0.1 110 | --adam-beta1 0.9 111 | --adam-beta2 0.98 112 | 113 | --optimizer-cpu-offload 114 | --overlap-cpu-optimizer-d2h-h2d 115 | --use-precision-aware-optimizer 116 | ) 117 | 118 | WANDB_ARGS=( 119 | # --use-wandb 120 | # --wandb-project slime-dev 121 | # 
--wandb-group qwen3-235B-sft 122 | ) 123 | 124 | SGLANG_ARGS=( 125 | --rollout-num-gpus-per-engine 32 126 | --sglang-mem-fraction-static 0.5 127 | --sglang-enable-ep-moe 128 | --sglang-enable-dp-attention 129 | --sglang-dp-size 4 130 | --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) 131 | ) 132 | 133 | MISC_ARGS=( 134 | # default dropout in megatron is 0.1 135 | --attention-dropout 0.0 136 | --hidden-dropout 0.0 137 | # should be good for model performance 138 | --accumulate-allreduce-grads-in-fp32 139 | --attention-softmax-in-fp32 140 | # need to comment this when using model with MLA 141 | --attention-backend flash 142 | ) 143 | 144 | # launch the master node of ray in container 145 | export no_proxy="127.0.0.1,${MASTER_ADDR}" 146 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 147 | for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do 148 | if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then 149 | continue 150 | fi 151 | echo "Starting Ray worker on ${WORKER_IP}" 152 | ssh root@"${WORKER_IP}" \ 153 | "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats" & 154 | done 155 | wait 156 | 157 | 158 | # Build the runtime environment JSON with proper variable substitution 159 | RUNTIME_ENV_JSON="{ 160 | \"env_vars\": { 161 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 162 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 163 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 164 | \"no_proxy\": \"${no_proxy}\", 165 | \"MASTER_ADDR\": \"${MASTER_ADDR}\" 166 | } 167 | }" 168 | 169 | ray job submit --address="http://127.0.0.1:8265" \ 170 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 171 | -- python3 train.py \ 172 | --actor-num-nodes 4 \ 173 | --actor-num-gpus-per-node 8 \ 174 | --colocate \ 175 | ${MODEL_ARGS[@]} \ 176 | ${CKPT_ARGS[@]} \ 177 | ${ROLLOUT_ARGS[@]} \ 178 | ${OPTIMIZER_ARGS[@]} \ 179 | ${GRPO_ARGS[@]} \ 180 | ${DISTRIBUTED_ARGS[@]} \ 181 | ${WANDB_ARGS[@]} \ 182 | ${PERF_ARGS[@]} \ 183 | ${EVAL_ARGS[@]} \ 184 | ${SGLANG_ARGS[@]} \ 185 | ${MISC_ARGS[@]} 186 | -------------------------------------------------------------------------------- /scripts/run-qwen3-30B-A3B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then 20 | HAS_NVLINK=1 21 | else 22 | HAS_NVLINK=0 23 | fi 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 25 | 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 27 | source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" 28 | 29 | CKPT_ARGS=( 30 | --hf-checkpoint /root/Qwen3-30B-A3B 31 | #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 32 | --ref-load /root/Qwen3-30B-A3B_torch_dist 33 | --load /root/Qwen3-4B_slime/ 34 | --save /root/Qwen3-4B_slime/ 35 | --save-interval 20 36 | ) 37 | 38 | ROLLOUT_ARGS=( 39 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 40 | --input-key prompt 41 | --label-key label 42 | --apply-chat-template 43 | --rollout-shuffle 44 | 45 | --rm-type deepscaler 46 | 47 | --num-rollout 3000 48 | --rollout-batch-size 32 49 | --n-samples-per-prompt 8 
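    # 32 prompts x 8 samples = 256 responses per rollout step, which matches the
    # --global-batch-size 256 set below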
50 | --rollout-max-response-len 8192 51 | --rollout-temperature 0.8 52 | 53 | --global-batch-size 256 54 | --balance-data 55 | ) 56 | 57 | EVAL_ARGS=( 58 | --eval-interval 20 59 | --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 60 | --n-samples-per-eval-prompt 16 61 | --eval-max-response-len 16384 62 | --eval-top-p 0.7 63 | ) 64 | 65 | PERF_ARGS=( 66 | --tensor-model-parallel-size 4 67 | --sequence-parallel 68 | --pipeline-model-parallel-size 1 69 | --context-parallel-size 1 70 | --expert-model-parallel-size 8 71 | --expert-tensor-parallel-size 1 72 | 73 | --recompute-granularity full 74 | --recompute-method uniform 75 | --recompute-num-layers 1 76 | 77 | # --micro-batch-size 1 78 | --use-dynamic-batch-size 79 | --max-tokens-per-gpu 20480 80 | ) 81 | 82 | GRPO_ARGS=( 83 | --advantage-estimator grpo 84 | --use-kl-loss 85 | --kl-loss-coef 0.00 86 | --kl-loss-type low_var_kl 87 | --kl-coef 0.00 88 | --entropy-coef 0.00 89 | --eps-clip 0.2 90 | --eps-clip-high 0.28 91 | ) 92 | 93 | OPTIMIZER_ARGS=( 94 | --optimizer adam 95 | --lr 1e-6 96 | --lr-decay-style constant 97 | --weight-decay 0.1 98 | --adam-beta1 0.9 99 | --adam-beta2 0.98 100 | 101 | --optimizer-cpu-offload 102 | --overlap-cpu-optimizer-d2h-h2d 103 | --use-precision-aware-optimizer 104 | ) 105 | 106 | WANDB_ARGS=( 107 | #--use-wandb 108 | # --wandb-project slime-dev 109 | # --wandb-group qwen3-4B-test 110 | # --wandb-key ${WANDB_KEY} 111 | ) 112 | 113 | SGLANG_ARGS=( 114 | --rollout-num-gpus-per-engine 8 115 | --sglang-mem-fraction-static 0.5 116 | --sglang-enable-ep-moe 117 | --sglang-enable-dp-attention 118 | --sglang-dp-size 8 119 | --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) 120 | ) 121 | 122 | MISC_ARGS=( 123 | # default dropout in megatron is 0.1 124 | --attention-dropout 0.0 125 | --hidden-dropout 0.0 126 | # should be good for model performance 127 | --accumulate-allreduce-grads-in-fp32 128 | --attention-softmax-in-fp32 129 | # need to comment this when using model with MLA 130 | --attention-backend flash 131 | ) 132 | 133 | # launch the master node of ray in container 134 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 135 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 136 | 137 | # Build the runtime environment JSON with proper variable substitution 138 | RUNTIME_ENV_JSON="{ 139 | \"env_vars\": { 140 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 141 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 142 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 143 | } 144 | }" 145 | 146 | ray job submit --address="http://127.0.0.1:8265" \ 147 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 148 | -- python3 train.py \ 149 | --actor-num-nodes 1 \ 150 | --actor-num-gpus-per-node 8 \ 151 | --colocate \ 152 | ${MODEL_ARGS[@]} \ 153 | ${CKPT_ARGS[@]} \ 154 | ${ROLLOUT_ARGS[@]} \ 155 | ${OPTIMIZER_ARGS[@]} \ 156 | ${GRPO_ARGS[@]} \ 157 | ${DISTRIBUTED_ARGS[@]} \ 158 | ${WANDB_ARGS[@]} \ 159 | ${PERF_ARGS[@]} \ 160 | ${EVAL_ARGS[@]} \ 161 | ${SGLANG_ARGS[@]} \ 162 | ${MISC_ARGS[@]} 163 | -------------------------------------------------------------------------------- /scripts/run-qwen3-4B-amd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # bash scripts/run-qwen3-4B-amd.sh 5 | 6 | 7 | ####clear before training 8 | pkill -9 sglang 9 | sleep 3 10 | ray stop --force 11 | pkill -9 ray 12 | pkill -9 python 13 | sleep 3 14 | pkill -9 ray 15 | pkill -9 python 16 | 17 | 18 | set -euxo pipefail 19 | 20 | 21 | ### AMD Support ### 22 | 
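# The paths below are machine specific. RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1
# (set further down) keeps Ray from overwriting HIP_VISIBLE_DEVICES on its workers,
# so the GPU list chosen here is the one actually used.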
SLIME_DIR="/home/yushensu/projects/slime" # Need to change to your own path 23 | export SLIME_DIR=$SLIME_DIR 24 | 25 | MODEL_DIR="/home/yushensu/projects/model" # Need to change to your own path 26 | export MODEL_DIR=$MODEL_DIR 27 | 28 | DATA_DIR="/home/yushensu/projects/data" # Need to change to your own path 29 | export DATA_DIR=$DATA_DIR 30 | 31 | # For AMD GPU 32 | export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 33 | export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use 34 | #################### 35 | 36 | 37 | # ### AMD Support ### (If you do not istall, please install them) 38 | # # # Clone and install Megatron-LMi-amd_version 39 | # export MAX_JOBS=512 40 | # cd $SLIME_DIR 41 | # pip uninstall megatron-core -y 42 | # if [ ! -d "Megatron-LM-amd_version" ]; then 43 | # git clone git@github.com:yushengsu-thu/Megatron-LM-amd_version.git 44 | # else 45 | # echo "Megatron-LM-amd_version directory already exists, skipping clone" 46 | # fi 47 | # cd Megatron-LM-amd_version 48 | # pip install -vvv -e . 49 | # cd $SLIME_DIR 50 | 51 | # # Install slime 52 | # pip install -e . 53 | # #################### 54 | 55 | 56 | 57 | # will prevent ray from buffering stdout/stderr 58 | export PYTHONBUFFERED=16 59 | 60 | 61 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 62 | source "${SCRIPT_DIR}/models/qwen3-4B.sh" 63 | 64 | CKPT_ARGS=( 65 | --hf-checkpoint ${MODEL_DIR}/Qwen3-4B 66 | #--hf-checkpoint /root/Qwen3-4B-FP8 67 | --ref-load ${MODEL_DIR}/Qwen3-4B_torch 68 | --load ${MODEL_DIR}/Qwen3-4B_slime/ 69 | --save ${MODEL_DIR}/Qwen3-4B_slime/ 70 | --save-interval 20 71 | ) 72 | 73 | ROLLOUT_ARGS=( 74 | --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl 75 | --input-key prompt 76 | --label-key label 77 | --apply-chat-template 78 | --rollout-shuffle 79 | 80 | --rm-type deepscaler 81 | 82 | --num-rollout 3000 83 | --rollout-batch-size 32 84 | --n-samples-per-prompt 8 85 | --rollout-max-response-len 8192 86 | --rollout-temperature 0.8 87 | 88 | --global-batch-size 256 89 | --balance-data 90 | ) 91 | 92 | EVAL_ARGS=( 93 | --eval-interval 20 94 | --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl 95 | --n-samples-per-eval-prompt 16 96 | --eval-max-response-len 16384 97 | --eval-top-p 0.7 98 | ) 99 | 100 | PERF_ARGS=( 101 | --tensor-model-parallel-size 2 102 | --sequence-parallel 103 | --pipeline-model-parallel-size 1 104 | --context-parallel-size 1 105 | --expert-model-parallel-size 1 106 | --expert-tensor-parallel-size 1 107 | 108 | --recompute-granularity full 109 | --recompute-method uniform 110 | --recompute-num-layers 1 111 | 112 | # --micro-batch-size 1 113 | --use-dynamic-batch-size 114 | --max-tokens-per-gpu 9216 115 | ) 116 | 117 | GRPO_ARGS=( 118 | --advantage-estimator grpo 119 | --use-kl-loss 120 | --kl-loss-coef 0.00 121 | --kl-loss-type low_var_kl 122 | --kl-coef 0.00 123 | --entropy-coef 0.00 124 | --eps-clip 0.2 125 | --eps-clip-high 0.28 126 | ) 127 | 128 | OPTIMIZER_ARGS=( 129 | --optimizer adam 130 | --lr 1e-6 131 | --lr-decay-style constant 132 | --weight-decay 0.1 133 | --adam-beta1 0.9 134 | --adam-beta2 0.98 135 | ) 136 | 137 | WANDB_ARGS=( 138 | #--use-wandb 139 | # --wandb-project slime-dev 140 | # --wandb-group qwen3-4B-test 141 | # --wandb-key ${WANDB_KEY} 142 | ) 143 | 144 | ### AMD Support ### 145 | # Need to fix some issue with torch_memory_saver in rocm to support larger --sglang-mem-fraction-static 146 | # 
SGLANG_ARGS=( 147 | # --rollout-num-gpus-per-engine 2 148 | # --sglang-mem-fraction-static 0.7 149 | # ) 150 | SGLANG_ARGS=( 151 | --rollout-num-gpus-per-engine 2 152 | --sglang-mem-fraction-static 0.4 153 | ) 154 | #################### 155 | 156 | 157 | MISC_ARGS=( 158 | # default dropout in megatron is 0.1 159 | --attention-dropout 0.0 160 | --hidden-dropout 0.0 161 | # should be good for model performance 162 | --accumulate-allreduce-grads-in-fp32 163 | --attention-softmax-in-fp32 164 | # need to comment this when using model with MLA 165 | --attention-backend flash 166 | ### AMD Support ### 167 | # disable gradient accumulation fusion: Need to add apex to enable this 168 | --no-gradient-accumulation-fusion 169 | ################### 170 | ) 171 | 172 | # launch the master node of ray in container 173 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 174 | 175 | NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) 176 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats 177 | 178 | 179 | # "PYTHONPATH": "/workspace/Megatron-LM-amd_version/", 180 | MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') 181 | 182 | ray job submit --address="http://127.0.0.1:8265" \ 183 | --runtime-env-json='{ 184 | "env_vars": { 185 | "PYTHONPATH": "/workspace/Megatron-LM-amd_version/", 186 | "CUDA_DEVICE_MAX_CONNECTIONS": "1" 187 | } 188 | }' \ 189 | -- python3 train.py \ 190 | --actor-num-nodes 1 \ 191 | --actor-num-gpus-per-node 8 \ 192 | --colocate \ 193 | ${MODEL_ARGS[@]} \ 194 | ${CKPT_ARGS[@]} \ 195 | ${ROLLOUT_ARGS[@]} \ 196 | ${OPTIMIZER_ARGS[@]} \ 197 | ${GRPO_ARGS[@]} \ 198 | ${DISTRIBUTED_ARGS[@]} \ 199 | ${WANDB_ARGS[@]} \ 200 | ${PERF_ARGS[@]} \ 201 | ${EVAL_ARGS[@]} \ 202 | ${SGLANG_ARGS[@]} \ 203 | ${MISC_ARGS[@]} 204 | 205 | 206 | 207 | ####clear after training 208 | 209 | pkill -9 sglang 210 | sleep 3 211 | ray stop --force 212 | pkill -9 ray 213 | pkill -9 python 214 | sleep 3 215 | pkill -9 ray 216 | pkill -9 python 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /scripts/run-qwen3-4B-base-sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then 20 | HAS_NVLINK=1 21 | else 22 | HAS_NVLINK=0 23 | fi 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 25 | 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh" 28 | 29 | CKPT_ARGS=( 30 | --hf-checkpoint /root/Qwen3-4B-Base/ 31 | --ref-load /root/Qwen3-4B-Base_torch_dist 32 | --load /root/Qwen3-4B-Base_slime/ 33 | --save /root/Qwen3-4B-Base_slime/ 34 | --save-interval 1000 35 | ) 36 | 37 | SFT_ARGS=( 38 | --rollout-function-path slime.rollout.sft_example.generate_rollout 39 | --prompt-data /root/openhermes2_5.parquet 40 | --input-key messages 41 | --rollout-shuffle 42 | --num-epoch 3 43 | --rollout-batch-size 128 44 | --global-batch-size 128 45 | 46 | --loss-type sft_loss 47 | --calculate-per-token-loss 48 | --disable-compute-advantages-and-returns 49 | 
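    # rough intent of this block: sft_example reads the `messages` column directly
    # (no generation), sft_loss swaps in a supervised objective, and
    # --debug-train-only below skips launching rollout engines entirely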
--debug-train-only 50 | ) 51 | 52 | PERF_ARGS=( 53 | --tensor-model-parallel-size 1 54 | --sequence-parallel 55 | --pipeline-model-parallel-size 1 56 | --context-parallel-size 1 57 | --expert-model-parallel-size 1 58 | --expert-tensor-parallel-size 1 59 | 60 | --recompute-granularity full 61 | --recompute-method uniform 62 | --recompute-num-layers 1 63 | 64 | # --micro-batch-size 1 65 | --use-dynamic-batch-size 66 | --max-tokens-per-gpu 9216 67 | ) 68 | 69 | OPTIMIZER_ARGS=( 70 | --optimizer adam 71 | --lr 1e-5 72 | --lr-warmup-iters 128 73 | --lr-decay-style cosine 74 | --min-lr 1e-6 75 | --lr-warmup-fraction 0.9 76 | --weight-decay 0.1 77 | --adam-beta1 0.9 78 | --adam-beta2 0.95 79 | ) 80 | 81 | WANDB_ARGS=( 82 | # --use-wandb 83 | # --wandb-project slime-dev 84 | # --wandb-group qwen3-4B-base-sft 85 | # --wandb-key ${WANDB_KEY} 86 | ) 87 | 88 | MISC_ARGS=( 89 | # default dropout in megatron is 0.1 90 | --attention-dropout 0.0 91 | --hidden-dropout 0.0 92 | # should be good for model performance 93 | --accumulate-allreduce-grads-in-fp32 94 | --attention-softmax-in-fp32 95 | # need to comment this when using model with MLA 96 | --attention-backend flash 97 | ) 98 | 99 | # launch the master node of ray in container 100 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 101 | export no_proxy="127.0.0.1,${MASTER_ADDR}" 102 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 103 | 104 | 105 | # Build the runtime environment JSON with proper variable substitution 106 | RUNTIME_ENV_JSON="{ 107 | \"env_vars\": { 108 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 109 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 110 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 111 | } 112 | }" 113 | 114 | ray job submit --address="http://127.0.0.1:8265" \ 115 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 116 | -- python3 train_async.py \ 117 | --actor-num-nodes 1 \ 118 | --actor-num-gpus-per-node 8 \ 119 | ${MODEL_ARGS[@]} \ 120 | ${CKPT_ARGS[@]} \ 121 | ${SFT_ARGS[@]} \ 122 | ${OPTIMIZER_ARGS[@]} \ 123 | ${DISTRIBUTED_ARGS[@]} \ 124 | ${WANDB_ARGS[@]} \ 125 | ${PERF_ARGS[@]} \ 126 | ${EVAL_ARGS[@]} \ 127 | ${MISC_ARGS[@]} 128 | -------------------------------------------------------------------------------- /scripts/run-qwen3-4B-rf-baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then 20 | HAS_NVLINK=1 21 | else 22 | HAS_NVLINK=0 23 | fi 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 25 | 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh" 28 | 29 | CKPT_ARGS=( 30 | --hf-checkpoint /root/Qwen3-4B 31 | #--hf-checkpoint /root/Qwen3-4B-FP8 32 | --ref-load /root/Qwen3-4B_torch_dist 33 | --load /root/Qwen3-4B_slime/ 34 | --save /root/Qwen3-4B_slime/ 35 | --save-interval 20 36 | ) 37 | 38 | ROLLOUT_ARGS=( 39 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 40 | --input-key prompt 41 | --label-key label 42 | --apply-chat-template 43 | --rollout-shuffle 44 | --rm-type deepscaler 45 | --num-rollout 3000 46 | --rollout-batch-size 32 47 | --n-samples-per-prompt 
8 48 | --rollout-max-response-len 8192 49 | --rollout-temperature 0.8 50 | 51 | --global-batch-size 256 52 | --balance-data 53 | ) 54 | 55 | EVAL_ARGS=( 56 | --eval-interval 20 57 | --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 58 | --n-samples-per-eval-prompt 16 59 | --eval-max-response-len 16384 60 | --eval-top-p 0.7 61 | ) 62 | 63 | PERF_ARGS=( 64 | --tensor-model-parallel-size 1 65 | --sequence-parallel 66 | --pipeline-model-parallel-size 1 67 | --context-parallel-size 1 68 | --expert-model-parallel-size 1 69 | --expert-tensor-parallel-size 1 70 | 71 | --recompute-granularity full 72 | --recompute-method uniform 73 | --recompute-num-layers 1 74 | 75 | # --micro-batch-size 1 76 | --use-dynamic-batch-size 77 | --max-tokens-per-gpu 9216 78 | ) 79 | 80 | GRPO_ARGS=( 81 | --advantage-estimator reinforce_plus_plus_baseline 82 | --use-kl-loss 83 | --kl-loss-coef 0.00 84 | --kl-loss-type low_var_kl 85 | --kl-coef 0.00 86 | --entropy-coef 0.00 87 | --eps-clip 0.2 88 | --eps-clip-high 0.28 89 | --normalize-advantages 90 | ) 91 | 92 | OPTIMIZER_ARGS=( 93 | --optimizer adam 94 | --lr 1e-6 95 | --lr-decay-style constant 96 | --weight-decay 0.1 97 | --adam-beta1 0.9 98 | --adam-beta2 0.98 99 | ) 100 | 101 | WANDB_ARGS=( 102 | # --use-wandb 103 | # --wandb-project slime-dev 104 | # --wandb-group qwen3-4B-test-reinforce_plus_plus-baseline 105 | # --wandb-key ${WANDB_KEY} 106 | ) 107 | 108 | SGLANG_ARGS=( 109 | --rollout-num-gpus-per-engine 2 110 | --sglang-mem-fraction-static 0.7 111 | ) 112 | 113 | MISC_ARGS=( 114 | # default dropout in megatron is 0.1 115 | --attention-dropout 0.0 116 | --hidden-dropout 0.0 117 | # should be good for model performance 118 | --accumulate-allreduce-grads-in-fp32 119 | --attention-softmax-in-fp32 120 | # need to comment this when using model with MLA 121 | --attention-backend flash 122 | ) 123 | 124 | # launch the master node of ray in container 125 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 126 | export MASTER_PORT=${MASTER_PORT:-"12345"} 127 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 128 | 129 | # Build the runtime environment JSON with proper variable substitution 130 | RUNTIME_ENV_JSON="{ 131 | \"env_vars\": { 132 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 133 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 134 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 135 | } 136 | }" 137 | 138 | ray job submit --address="http://127.0.0.1:8265" \ 139 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 140 | -- python3 train.py \ 141 | --actor-num-nodes 1 \ 142 | --actor-num-gpus-per-node 8 \ 143 | --colocate \ 144 | ${MODEL_ARGS[@]} \ 145 | ${CKPT_ARGS[@]} \ 146 | ${ROLLOUT_ARGS[@]} \ 147 | ${OPTIMIZER_ARGS[@]} \ 148 | ${GRPO_ARGS[@]} \ 149 | ${DISTRIBUTED_ARGS[@]} \ 150 | ${WANDB_ARGS[@]} \ 151 | ${PERF_ARGS[@]} \ 152 | ${EVAL_ARGS[@]} \ 153 | ${SGLANG_ARGS[@]} \ 154 | ${MISC_ARGS[@]} -------------------------------------------------------------------------------- /scripts/run-qwen3-4B-rf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then 20 | HAS_NVLINK=1 21 | else 22 | HAS_NVLINK=0 23 | fi 24 | echo 
"HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 25 | 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh" 28 | 29 | CKPT_ARGS=( 30 | --hf-checkpoint /root/Qwen3-4B 31 | #--hf-checkpoint /root/Qwen3-4B-FP8 32 | --ref-load /root/Qwen3-4B_torch_dist 33 | --load /root/Qwen3-4B_slime/ 34 | --save /root/Qwen3-4B_slime/ 35 | --save-interval 20 36 | ) 37 | 38 | ROLLOUT_ARGS=( 39 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 40 | --input-key prompt 41 | --label-key label 42 | --apply-chat-template 43 | --rollout-shuffle 44 | --rm-type deepscaler 45 | --num-rollout 3000 46 | --rollout-batch-size 32 47 | --n-samples-per-prompt 8 48 | --rollout-max-response-len 8192 49 | --rollout-temperature 0.8 50 | 51 | --global-batch-size 256 52 | --balance-data 53 | ) 54 | 55 | EVAL_ARGS=( 56 | --eval-interval 20 57 | --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 58 | --n-samples-per-eval-prompt 16 59 | --eval-max-response-len 16384 60 | --eval-top-p 0.7 61 | ) 62 | 63 | PERF_ARGS=( 64 | --tensor-model-parallel-size 1 65 | --sequence-parallel 66 | --pipeline-model-parallel-size 1 67 | --context-parallel-size 1 68 | --expert-model-parallel-size 1 69 | --expert-tensor-parallel-size 1 70 | 71 | --recompute-granularity full 72 | --recompute-method uniform 73 | --recompute-num-layers 1 74 | 75 | # --micro-batch-size 1 76 | --use-dynamic-batch-size 77 | --max-tokens-per-gpu 9216 78 | ) 79 | 80 | GRPO_ARGS=( 81 | --advantage-estimator reinforce_plus_plus 82 | --use-kl-loss 83 | --kl-loss-coef 0.00 84 | --kl-loss-type low_var_kl 85 | --kl-coef 0.005 86 | --entropy-coef 0.00 87 | --eps-clip 0.2 88 | --eps-clip-high 0.28 89 | --gamma 1.0 90 | --normalize-advantages 91 | ) 92 | 93 | OPTIMIZER_ARGS=( 94 | --optimizer adam 95 | --lr 5e-7 96 | --lr-decay-style constant 97 | --weight-decay 0.1 98 | --adam-beta1 0.9 99 | --adam-beta2 0.98 100 | ) 101 | 102 | WANDB_ARGS=( 103 | # --use-wandb 104 | # --wandb-project slime-dev 105 | # --wandb-group qwen3-4B-test-reinforce_plus_plus 106 | # --wandb-key ${WANDB_KEY} 107 | ) 108 | 109 | SGLANG_ARGS=( 110 | --rollout-num-gpus-per-engine 2 111 | --sglang-mem-fraction-static 0.7 112 | ) 113 | 114 | MISC_ARGS=( 115 | # default dropout in megatron is 0.1 116 | --attention-dropout 0.0 117 | --hidden-dropout 0.0 118 | # should be good for model performance 119 | --accumulate-allreduce-grads-in-fp32 120 | --attention-softmax-in-fp32 121 | # need to comment this when using model with MLA 122 | --attention-backend flash 123 | ) 124 | 125 | # launch the master node of ray in container 126 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 127 | export MASTER_PORT=${MASTER_PORT:-"12345"} 128 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 129 | 130 | # Build the runtime environment JSON with proper variable substitution 131 | RUNTIME_ENV_JSON="{ 132 | \"env_vars\": { 133 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 134 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 135 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 136 | } 137 | }" 138 | 139 | ray job submit --address="http://127.0.0.1:8265" \ 140 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 141 | -- python3 train.py \ 142 | --actor-num-nodes 1 \ 143 | --actor-num-gpus-per-node 8 \ 144 | --colocate \ 145 | ${MODEL_ARGS[@]} \ 146 | ${CKPT_ARGS[@]} \ 147 | ${ROLLOUT_ARGS[@]} \ 148 | ${OPTIMIZER_ARGS[@]} \ 149 | ${GRPO_ARGS[@]} \ 150 | ${DISTRIBUTED_ARGS[@]} \ 151 | ${WANDB_ARGS[@]} \ 152 | ${PERF_ARGS[@]} \ 
153 | ${EVAL_ARGS[@]} \ 154 | ${SGLANG_ARGS[@]} \ 155 | ${MISC_ARGS[@]} -------------------------------------------------------------------------------- /scripts/run-qwen3-4B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then 20 | HAS_NVLINK=1 21 | else 22 | HAS_NVLINK=0 23 | fi 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" 25 | 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh" 28 | 29 | CKPT_ARGS=( 30 | --hf-checkpoint /root/Qwen3-4B 31 | #--hf-checkpoint /root/Qwen3-4B-FP8 32 | --ref-load /root/Qwen3-4B_torch_dist 33 | --load /root/Qwen3-4B_slime/ 34 | --save /root/Qwen3-4B_slime/ 35 | --save-interval 20 36 | ) 37 | 38 | ROLLOUT_ARGS=( 39 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 40 | --input-key prompt 41 | --label-key label 42 | --apply-chat-template 43 | --rollout-shuffle 44 | --rm-type deepscaler 45 | --num-rollout 3000 46 | --rollout-batch-size 32 47 | --n-samples-per-prompt 8 48 | --rollout-max-response-len 8192 49 | --rollout-temperature 0.8 50 | 51 | --global-batch-size 256 52 | --balance-data 53 | ) 54 | 55 | EVAL_ARGS=( 56 | --eval-interval 20 57 | --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 58 | --n-samples-per-eval-prompt 16 59 | --eval-max-response-len 16384 60 | --eval-top-p 0.7 61 | ) 62 | 63 | PERF_ARGS=( 64 | --tensor-model-parallel-size 2 65 | --sequence-parallel 66 | --pipeline-model-parallel-size 1 67 | --context-parallel-size 1 68 | --expert-model-parallel-size 1 69 | --expert-tensor-parallel-size 1 70 | 71 | --recompute-granularity full 72 | --recompute-method uniform 73 | --recompute-num-layers 1 74 | 75 | # --micro-batch-size 1 76 | --use-dynamic-batch-size 77 | --max-tokens-per-gpu 9216 78 | ) 79 | 80 | GRPO_ARGS=( 81 | --advantage-estimator grpo 82 | --use-kl-loss 83 | --kl-loss-coef 0.00 84 | --kl-loss-type low_var_kl 85 | --kl-coef 0.00 86 | --entropy-coef 0.00 87 | --eps-clip 0.2 88 | --eps-clip-high 0.28 89 | ) 90 | 91 | OPTIMIZER_ARGS=( 92 | --optimizer adam 93 | --lr 1e-6 94 | --lr-decay-style constant 95 | --weight-decay 0.1 96 | --adam-beta1 0.9 97 | --adam-beta2 0.98 98 | ) 99 | 100 | WANDB_ARGS=( 101 | # --use-wandb 102 | # --wandb-project slime-dev 103 | # --wandb-group qwen3-4B-test 104 | # --wandb-key ${WANDB_KEY} 105 | ) 106 | 107 | SGLANG_ARGS=( 108 | --rollout-num-gpus-per-engine 2 109 | --sglang-mem-fraction-static 0.7 110 | ) 111 | 112 | MISC_ARGS=( 113 | # default dropout in megatron is 0.1 114 | --attention-dropout 0.0 115 | --hidden-dropout 0.0 116 | # should be good for model performance 117 | --accumulate-allreduce-grads-in-fp32 118 | --attention-softmax-in-fp32 119 | # need to comment this when using model with MLA 120 | --attention-backend flash 121 | ) 122 | 123 | # launch the master node of ray in container 124 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 125 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 126 | 127 | # Build the runtime environment JSON with proper variable substitution 128 | RUNTIME_ENV_JSON="{ 129 | \"env_vars\": { 
130 | \"PYTHONPATH\": \"/root/Megatron-LM/\", 131 | \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", 132 | \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" 133 | } 134 | }" 135 | 136 | ray job submit --address="http://127.0.0.1:8265" \ 137 | --runtime-env-json="${RUNTIME_ENV_JSON}" \ 138 | -- python3 train.py \ 139 | --actor-num-nodes 1 \ 140 | --actor-num-gpus-per-node 8 \ 141 | --colocate \ 142 | ${MODEL_ARGS[@]} \ 143 | ${CKPT_ARGS[@]} \ 144 | ${ROLLOUT_ARGS[@]} \ 145 | ${OPTIMIZER_ARGS[@]} \ 146 | ${GRPO_ARGS[@]} \ 147 | ${DISTRIBUTED_ARGS[@]} \ 148 | ${WANDB_ARGS[@]} \ 149 | ${PERF_ARGS[@]} \ 150 | ${EVAL_ARGS[@]} \ 151 | ${SGLANG_ARGS[@]} \ 152 | ${MISC_ARGS[@]} -------------------------------------------------------------------------------- /scripts/run_agent.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | set -e 5 | 6 | SESSION_NAME="slime_run" 7 | WINDOW_1="slime" 8 | WINDOW_2="buffer" 9 | 10 | if tmux has-session -t $SESSION_NAME 2>/dev/null; then 11 | echo "Killing existing tmux session: $SESSION_NAME" 12 | tmux kill-session -t $SESSION_NAME 13 | fi 14 | 15 | tmux new-session -d -s $SESSION_NAME -n $WINDOW_1 16 | tmux send-keys -t ${SESSION_NAME}:${WINDOW_1} "cd $(pwd)" C-m 17 | tmux send-keys -t ${SESSION_NAME}:${WINDOW_1} "bash ./scripts/agent-example.sh" C-m 18 | 19 | tmux new-window -t $SESSION_NAME -n $WINDOW_2 20 | tmux send-keys -t ${SESSION_NAME}:${WINDOW_2} "sleep 30 && cd slime_plugins/rollout_buffer && python buffer.py" C-m 21 | 22 | tmux attach-session -t $SESSION_NAME -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import platform 3 | 4 | from setuptools import find_packages, setup 5 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 6 | 7 | 8 | def _fetch_requirements(path): 9 | with open(path, "r") as fd: 10 | return [r.strip() for r in fd.readlines()] 11 | 12 | 13 | # Custom wheel class to modify the wheel name 14 | class bdist_wheel(_bdist_wheel): 15 | def finalize_options(self): 16 | _bdist_wheel.finalize_options(self) 17 | self.root_is_pure = False 18 | 19 | def get_tag(self): 20 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 21 | abi_tag = f"{python_version}" 22 | 23 | if platform.system() == "Linux": 24 | platform_tag = "manylinux1_x86_64" 25 | else: 26 | platform_tag = platform.system().lower() 27 | 28 | return python_version, abi_tag, platform_tag 29 | 30 | 31 | # Setup configuration 32 | setup( 33 | author="slime Team", 34 | name="slime", 35 | version="0.0.1", 36 | packages=find_packages(include=["slime*", "slime_plugins*"]), 37 | include_package_data=True, 38 | install_requires=_fetch_requirements("requirements.txt"), 39 | python_requires=">=3.10", 40 | classifiers=[ 41 | "Programming Language :: Python :: 3.10", 42 | "Programming Language :: Python :: 3.11", 43 | "Programming Language :: Python :: 3.12", 44 | "Environment :: GPU :: NVIDIA CUDA", 45 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 46 | "Topic :: System :: Distributed Computing", 47 | ], 48 | cmdclass={"bdist_wheel": bdist_wheel}, 49 | ) 50 | -------------------------------------------------------------------------------- /slime/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/__init__.py 
-------------------------------------------------------------------------------- /slime/backends/megatron_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .arguments import _vocab_size_with_padding, parse_args, validate_args 4 | from .checkpoint import load_checkpoint, save_checkpoint 5 | from .data import ( 6 | get_batch, 7 | get_data_iterator, 8 | log_eval_data, 9 | log_multi_turn_data, 10 | log_passrate, 11 | log_perf_data, 12 | log_rollout_data, 13 | process_rollout_data, 14 | set_metadata, 15 | ) 16 | from .initialize import get_gloo_group, init 17 | from .loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function 18 | from .model import forward_only, initialize_model_and_optimizer, save, train 19 | 20 | logging.getLogger().setLevel(logging.WARNING) 21 | 22 | 23 | __all__ = [ 24 | "parse_args", 25 | "validate_args", 26 | "load_checkpoint", 27 | "save_checkpoint", 28 | "get_batch", 29 | "get_data_iterator", 30 | "get_gloo_group", 31 | "process_rollout_data", 32 | "init", 33 | "set_metadata", 34 | "get_log_probs_and_entropy", 35 | "log_rollout_data", 36 | "log_passrate", 37 | "log_multi_turn_data", 38 | "log_eval_data", 39 | "log_perf_data", 40 | "compute_advantages_and_returns", 41 | "loss_function", 42 | "forward_only", 43 | "train", 44 | "save", 45 | "initialize_model_and_optimizer", 46 | "_vocab_size_with_padding", 47 | ] 48 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/arguments.py: -------------------------------------------------------------------------------- 1 | from megatron.training.arguments import parse_args, validate_args 2 | from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding 3 | 4 | __all__ = ["validate_args", "parse_args", "_vocab_size_with_padding"] 5 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # TODO: may need to copy those 2 functions and do refactoring. 2 | from megatron.training.checkpointing import load_checkpoint, save_checkpoint 3 | 4 | __all__ = ["load_checkpoint", "save_checkpoint"] 5 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/cp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from megatron.core import mpu 3 | 4 | 5 | def get_logits_and_tokens_offset_with_cp( 6 | total_length: int, 7 | response_length: int, 8 | ): 9 | """ 10 | All offsets start from the begining of the prompt. 11 | """ 12 | cp_rank = mpu.get_context_parallel_rank() 13 | cp_size = mpu.get_context_parallel_world_size() 14 | assert cp_size > 1 15 | 16 | prompt_length = total_length - response_length 17 | chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) 18 | 19 | # the offset of 2 chunks 20 | chunk_0 = (cp_rank * chunk_size, (cp_rank + 1) * chunk_size) 21 | chunk_1 = ((2 * cp_size - cp_rank - 1) * chunk_size, (2 * cp_size - cp_rank) * chunk_size) 22 | 23 | # the offset of 2 logits, note that the logits need a "-1". 24 | logits_0 = (max(chunk_0[0], prompt_length - 1), min(chunk_0[1], total_length - 1)) 25 | logits_1 = (max(chunk_1[0], prompt_length - 1), min(chunk_1[1], total_length - 1)) 26 | 27 | # when the sequence is empty, make an empty slice to continue the gradient flow. 
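# Illustrative example (assumed values): with total_length=16, response_length=8 and cp_size=2,
# chunk_size is 4; rank 0 owns chunks (0, 4) and (12, 16) while rank 1 owns (4, 8) and (8, 12).
# Since prompt_length=8, rank 0's first logits slice is empty (the case handled below) and its
# second slice is logits (12, 15) / tokens (13, 16); rank 1 gets logits (7, 8) and (8, 12).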
28 | if logits_0[0] < logits_0[1]: 29 | token_0 = (logits_0[0] + 1, logits_0[1] + 1) 30 | else: 31 | logits_0 = (0, 0) 32 | token_0 = (0, 0) 33 | 34 | if logits_1[0] < logits_1[1]: 35 | token_1 = (logits_1[0] + 1, logits_1[1] + 1) 36 | else: 37 | logits_1 = (0, 0) 38 | token_1 = (0, 0) 39 | 40 | return chunk_size, (chunk_0, chunk_1), (logits_0, logits_1), (token_0, token_1) 41 | 42 | 43 | def get_sum_of_sample_mean( 44 | total_lengths, 45 | response_lengths, 46 | loss_masks, 47 | calculate_per_token_loss: bool = False, 48 | ): 49 | """ 50 | Calculate correct sample mean for CP 51 | """ 52 | cp_size = mpu.get_context_parallel_world_size() 53 | if cp_size == 1: 54 | 55 | def sum_of_sample_mean(x: torch.Tensor): 56 | return sum( 57 | [ 58 | (x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) 59 | for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks) 60 | ] 61 | ) 62 | 63 | def sum_of_token(x: torch.Tensor): 64 | return sum( 65 | [(x_i * loss_mask_i).sum() for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks)] 66 | ) 67 | 68 | else: 69 | cp_chunk_lengths = [] 70 | chunked_loss_masks = [] 71 | for i, (total_length, response_length, loss_mask) in enumerate( 72 | zip(total_lengths, response_lengths, loss_masks) 73 | ): 74 | prompt_length = total_length - response_length 75 | _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(total_length, response_length) 76 | loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] 77 | loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] 78 | chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0)) 79 | cp_chunk_lengths.append(chunked_loss_masks[i].size(0)) 80 | 81 | def sum_of_sample_mean(x): 82 | return sum( 83 | [ 84 | (x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) 85 | for x_i, chunked_loss_mask, loss_mask in zip( 86 | x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks 87 | ) 88 | ] 89 | ) 90 | 91 | def sum_of_token(x: torch.Tensor): 92 | return sum( 93 | [ 94 | (x_i * chunked_loss_mask).sum() 95 | for x_i, chunked_loss_mask in zip(x.split(cp_chunk_lengths, dim=0), chunked_loss_masks) 96 | ] 97 | ) 98 | 99 | return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token 100 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/NVIDIA/Megatron-LM/blob/b1efb3c7126ef7615e8c333432d76e08038e17ff/pretrain_gpt.py 2 | import inspect 3 | from contextlib import nullcontext 4 | 5 | from megatron.core.enums import ModelType 6 | from megatron.core.models.gpt import GPTModel 7 | from megatron.core.models.gpt.gpt_layer_specs import ( 8 | get_gpt_decoder_block_spec, 9 | get_gpt_layer_local_spec, 10 | get_gpt_layer_with_transformer_engine_spec, 11 | ) 12 | from megatron.core.transformer.spec_utils import import_module 13 | from megatron.training import get_args 14 | from megatron.training.arguments import core_transformer_config_from_args 15 | 16 | 17 | def model_provider(pre_process=True, post_process=True) -> GPTModel: 18 | """Builds the model. 19 | 20 | If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. 21 | 22 | Args: 23 | pre_process (bool, optional): Set to true if you need to compute embedings. 
Defaults to True. 24 | post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. 25 | 26 | 27 | Returns: 28 | Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model 29 | """ 30 | args = get_args() 31 | use_te = args.transformer_impl == "transformer_engine" 32 | 33 | # Experimental loading arguments from yaml 34 | config = core_transformer_config_from_args(args) 35 | 36 | if args.spec is not None: 37 | transformer_layer_spec = import_module(args.spec) 38 | # Allow the spec to be a function so that user can use customized Megatron easier. 39 | if callable(transformer_layer_spec): 40 | transformer_layer_spec = transformer_layer_spec(args) 41 | else: 42 | if args.num_experts: 43 | # Define the decoder block spec 44 | transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te) 45 | else: 46 | # Define the decoder layer spec 47 | if use_te: 48 | transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( 49 | args.num_experts, 50 | args.moe_grouped_gemm, 51 | args.qk_layernorm, 52 | args.multi_latent_attention, 53 | args.moe_use_legacy_grouped_gemm, 54 | ) 55 | else: 56 | transformer_layer_spec = get_gpt_layer_local_spec( 57 | args.num_experts, 58 | args.moe_grouped_gemm, 59 | args.qk_layernorm, 60 | args.multi_latent_attention, 61 | args.moe_use_legacy_grouped_gemm, 62 | ) 63 | 64 | build_model_context = nullcontext 65 | build_model_context_args = {} 66 | if args.fp8_param_gather: 67 | try: 68 | from transformer_engine.pytorch import fp8_model_init 69 | 70 | build_model_context = fp8_model_init 71 | build_model_context_args["enabled"] = True 72 | 73 | # Check if fp8_model_init supports preserve_high_precision_init_val 74 | if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters: 75 | build_model_context_args["preserve_high_precision_init_val"] = True 76 | except: 77 | raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.") 78 | 79 | kwargs = { 80 | "config": config, 81 | "transformer_layer_spec": transformer_layer_spec, 82 | "vocab_size": args.padded_vocab_size, 83 | "max_sequence_length": args.max_position_embeddings, 84 | "pre_process": pre_process, 85 | "post_process": post_process, 86 | "fp16_lm_cross_entropy": args.fp16_lm_cross_entropy, 87 | "parallel_output": True, 88 | "share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights, 89 | "position_embedding_type": args.position_embedding_type, 90 | "rotary_percent": args.rotary_percent, 91 | "rotary_base": args.rotary_base, 92 | "rope_scaling": args.use_rope_scaling, 93 | } 94 | 95 | if getattr(args, "mtp_num_layers", None): 96 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec 97 | 98 | mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te) 99 | kwargs["mtp_block_spec"] = mtp_block_spec 100 | 101 | with build_model_context(**build_model_context_args): 102 | model = GPTModel(**kwargs) 103 | 104 | return model 105 | 106 | 107 | def get_model_provider_and_type(): 108 | return model_provider, ModelType.encoder_or_decoder 109 | -------------------------------------------------------------------------------- /slime/backends/sglang_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/backends/sglang_utils/__init__.py 
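A minimal usage sketch of the model_provider defined in slime/backends/megatron_utils/models/__init__.py above, assuming Megatron-LM has already been initialized so that get_args() is populated (as the training entry points are expected to do before building the model):

from megatron.core.enums import ModelType
from slime.backends.megatron_utils.models import get_model_provider_and_type

provider, model_type = get_model_provider_and_type()
assert model_type == ModelType.encoder_or_decoder
# Builds one (pipeline-stage) GPTModel; pre_process/post_process control embedding and LM-head placement.
model = provider(pre_process=True, post_process=True)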
-------------------------------------------------------------------------------- /slime/backends/sglang_utils/arguments.py: -------------------------------------------------------------------------------- 1 | from sglang.srt.server_args import ServerArgs 2 | 3 | 4 | def add_sglang_router_arguments(parser): 5 | """ 6 | Add arguments to the parser for the SGLang router. 7 | """ 8 | parser.add_argument( 9 | "--sglang-router-ip", 10 | type=str, 11 | default=None, 12 | help="IP address of the SGLang router", 13 | ) 14 | parser.add_argument( 15 | "--sglang-router-port", 16 | type=int, 17 | default=None, 18 | help="Port of the SGLang router", 19 | ) 20 | return parser 21 | 22 | 23 | def add_sglang_arguments(parser): 24 | """ 25 | Add arguments to the parser for the SGLang server. 26 | """ 27 | parser = add_sglang_router_arguments(parser) 28 | parser.add_argument("--sglang-server-concurrency", type=int, default=512) 29 | 30 | old_add_argument = parser.add_argument 31 | 32 | skipped_args = [ 33 | "model_path", 34 | "dtype", 35 | "trust_remote_code", 36 | "random_seed", 37 | # memory 38 | "enable_memory_saver", 39 | # distributed 40 | "tp_size", 41 | "port", 42 | "nnodes", 43 | "node_rank", 44 | "dist_init_addr", 45 | "gpu_id_step", 46 | "base_gpu_id", 47 | "nccl_port", 48 | "skip_server_warmup", 49 | ] 50 | 51 | def new_add_argument_wrapper(*name_or_flags, **kwargs): 52 | """ 53 | Add arguments to the parser, ensuring that the server arguments are prefixed and skippable. 54 | """ 55 | # Determine the canonical name for skip check (e.g., "model_path") 56 | canonical_name_for_skip_check = None 57 | if "dest" in kwargs: 58 | canonical_name_for_skip_check = kwargs["dest"] 59 | else: 60 | for flag_name_candidate in name_or_flags: 61 | if isinstance(flag_name_candidate, str) and flag_name_candidate.startswith("--"): 62 | # Derive from first long flag: --foo-bar -> foo_bar 63 | stem = flag_name_candidate[2:] 64 | canonical_name_for_skip_check = stem.replace("-", "_") 65 | break 66 | # If no long flag and no dest, skip logic might not catch it unless short flags imply a dest. 67 | 68 | if canonical_name_for_skip_check and canonical_name_for_skip_check in skipped_args: 69 | return # Skip this entire argument definition 70 | 71 | # If not skipped, proceed to prefix flags and dest 72 | new_name_or_flags_list = [] 73 | for item_flag in name_or_flags: 74 | if isinstance(item_flag, str) and item_flag.startswith("-"): 75 | original_flag_stem = item_flag.lstrip("-") # "foo-bar" from "--foo-bar", or "f" from "-f" 76 | prefixed_item = f"--sglang-{original_flag_stem}" 77 | new_name_or_flags_list.append(prefixed_item) 78 | else: 79 | # Positional arguments or non-string items 80 | new_name_or_flags_list.append(item_flag) 81 | 82 | # Prepare kwargs for the actual add_argument call. 83 | # Make a copy to avoid modifying the original kwargs dict. 84 | final_kwargs = kwargs.copy() 85 | 86 | # If 'dest' is explicitly provided and is a string, prefix it. 87 | # This ensures the attribute on the args namespace becomes, e.g., args.sglang_dest_name. 88 | if "dest" in final_kwargs and isinstance(final_kwargs["dest"], str): 89 | original_dest = final_kwargs["dest"] 90 | # Avoid double prefixing if dest somehow already starts with sglang_ 91 | if not original_dest.startswith("sglang_"): 92 | final_kwargs["dest"] = f"sglang_{original_dest}" 93 | # If 'dest' is not explicitly provided (or is None/not a string), 94 | # argparse will derive 'dest' from the (now prefixed) flag names. 
95 | # E.g., if the first flag is "--sglang-foo-bar", argparse sets dest to "sglang_foo_bar". 96 | 97 | old_add_argument(*new_name_or_flags_list, **final_kwargs) 98 | 99 | parser.add_argument = new_add_argument_wrapper 100 | ServerArgs.add_cli_args(parser) 101 | parser.add_argument = old_add_argument 102 | 103 | return parser 104 | 105 | 106 | def validate_args(args): 107 | # sglang 108 | args.sglang_tp_size = args.rollout_num_gpus_per_engine 109 | args.sglang_dp_size = args.sglang_data_parallel_size 110 | args.sglang_pp_size = args.sglang_pipeline_parallel_size 111 | args.sglang_ep_size = args.sglang_expert_parallel_size 112 | 113 | if args.sglang_dp_size > 1: 114 | assert args.sglang_enable_dp_attention 115 | -------------------------------------------------------------------------------- /slime/backends/sglang_utils/sglang_engine.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import os 4 | from typing import TYPE_CHECKING 5 | 6 | from sglang.srt.server_args import ServerArgs 7 | from slime.utils.http_utils import get_host_info 8 | from .http_server_engine import HttpServerEngineAdapter 9 | 10 | if TYPE_CHECKING: 11 | pass 12 | 13 | 14 | def get_base_gpu_id(args, rank): 15 | num_gpus = min(8, args.rollout_num_gpus_per_engine) 16 | if args.colocate: 17 | start_index = (rank * num_gpus) % 8 18 | else: 19 | num_actor_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes 20 | start_index = (num_actor_gpus + rank * num_gpus) % 8 21 | return start_index 22 | 23 | 24 | class SglangEngine: 25 | 26 | def __init__(self, args, rank, dist_init_addr, port, nccl_port): 27 | self.args = args 28 | 29 | # remove the CUDA_VISIBLE_DEVICES set by ray and use base_gpu_id 30 | os.environ.pop("CUDA_VISIBLE_DEVICES", None) 31 | 32 | nnodes = max(1, args.rollout_num_gpus_per_engine // 8) 33 | node_rank = rank % nnodes 34 | kwargs = { 35 | "model_path": args.hf_checkpoint, 36 | "trust_remote_code": True, 37 | "random_seed": args.seed + rank, 38 | # memory 39 | "enable_memory_saver": args.offload, 40 | # distributed 41 | "host": get_host_info()[1], 42 | "port": port, 43 | "nccl_port": nccl_port, 44 | "nnodes": nnodes, 45 | "node_rank": node_rank, 46 | "dist_init_addr": dist_init_addr, 47 | "gpu_id_step": 1, 48 | "base_gpu_id": get_base_gpu_id(args, rank), 49 | # parallel 50 | "tp_size": args.rollout_num_gpus_per_engine, 51 | "dp_size": args.sglang_dp_size, 52 | "pp_size": args.sglang_pp_size, 53 | "ep_size": args.sglang_ep_size, 54 | # always skip warmup to prevent warmup timeout. 
55 | "skip_server_warmup": True, 56 | } 57 | 58 | unused_keys = set(kwargs.keys()) 59 | for attr in dataclasses.fields(ServerArgs): 60 | if hasattr(args, f"sglang_{attr.name}") and attr.name not in kwargs: 61 | kwargs[attr.name] = getattr(args, f"sglang_{attr.name}") 62 | unused_keys.discard(attr.name) 63 | 64 | # for compatibility with old args 65 | if len(unused_keys) > 0: 66 | print(f"Warning: The following arguments is not supported in the current sglang: {unused_keys}.") 67 | for key in unused_keys: 68 | kwargs.pop(key) 69 | 70 | self.llm = HttpServerEngineAdapter( 71 | router_ip=args.sglang_router_ip, router_port=args.sglang_router_port, **kwargs 72 | ) 73 | 74 | def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): 75 | return self.llm.init_weights_update_group( 76 | master_address, master_port, rank_offset, world_size, group_name, backend 77 | ) 78 | 79 | def update_weights_from_distributed(self, names, dtypes, shapes, group_name): 80 | self.llm.update_weights_from_distributed(names, dtypes, shapes, group_name) 81 | return 82 | 83 | def update_weights_from_tensor(self, ipc_handles): 84 | self.llm.update_weights_from_tensor(ipc_handles) 85 | return 86 | 87 | def reset_prefix_cache(self): 88 | self.llm.flush_cache() 89 | 90 | def sleep(self, level=1): 91 | # Adhoc solution to ensure no running requests 92 | self.llm.flush_cache() 93 | self.llm.release_memory_occupation() 94 | 95 | def wake_up(self): 96 | self.llm.resume_memory_occupation() 97 | 98 | def pause_generation(self): 99 | self.llm.pause_generation() 100 | 101 | def continue_generation(self): 102 | self.llm.continue_generation() 103 | -------------------------------------------------------------------------------- /slime/ray/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/ray/__init__.py -------------------------------------------------------------------------------- /slime/ray/placement_group.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import ray 3 | from ray.util.placement_group import placement_group 4 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 5 | 6 | from .ppo_actor import RayTrainGroup 7 | from .rollout import RolloutGroup 8 | 9 | 10 | @ray.remote(num_gpus=1) 11 | class InfoActor: 12 | def get_ip_and_gpu_id(self): 13 | return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0] 14 | 15 | 16 | def sort_key(x): 17 | index, node_identifier, gpu_id = x 18 | # Sort by node IP number and then by GPU ID 19 | try: 20 | # try to parse it as an IP address. 21 | ip_address = node_identifier 22 | node_ip_parts = list(map(int, ip_address.split("."))) 23 | except ValueError: 24 | # Try to resolve the hostname to an IP address. 25 | try: 26 | ip_address = socket.gethostbyname(node_identifier) 27 | node_ip_parts = list(map(int, ip_address.split("."))) 28 | except (socket.gaierror, TypeError): 29 | # Instead, we convert each character of the original identifier string 30 | # to its ASCII value. This provides a stable and consistent numerical 31 | # representation that allows for sorting. 
32 | node_ip_parts = [ord(c) for c in node_identifier] 33 | 34 | return (node_ip_parts, gpu_id) 35 | 36 | 37 | def _create_placement_group(num_gpus): 38 | """Create a placement group with the specified number of GPUs.""" 39 | bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] 40 | pg = placement_group(bundles, strategy="PACK") 41 | num_bundles = len(bundles) 42 | 43 | ray.get(pg.ready()) 44 | # use info actor to get the GPU id 45 | info_actors = [] 46 | for i in range(num_bundles): 47 | info_actors.append( 48 | InfoActor.options( 49 | scheduling_strategy=PlacementGroupSchedulingStrategy( 50 | placement_group=pg, 51 | placement_group_bundle_index=i, 52 | ) 53 | ).remote() 54 | ) 55 | gpu_ids = ray.get([actor.get_ip_and_gpu_id.remote() for actor in info_actors]) 56 | for actor in info_actors: 57 | ray.kill(actor) 58 | 59 | bundle_infos = [(i, gpu_ids[i][0], gpu_ids[i][1]) for i in range(num_bundles)] 60 | pg_reordered_bundle_indices = [bundle_info[0] for bundle_info in sorted(bundle_infos, key=sort_key)] 61 | for i in range(num_bundles): 62 | actual_bundle_index = pg_reordered_bundle_indices[i] 63 | print( 64 | f" bundle {i:4}, actual_bundle_index: {actual_bundle_index:4}, " 65 | f"node: {gpu_ids[actual_bundle_index][0]}, gpu: {gpu_ids[actual_bundle_index][1]}" 66 | ) 67 | 68 | return pg, pg_reordered_bundle_indices 69 | 70 | 71 | def create_placement_groups(args): 72 | """Create placement groups for actor and rollout engines.""" 73 | 74 | num_gpus = 0 75 | if args.debug_train_only: 76 | num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node 77 | rollout_offset = 0 78 | elif args.debug_rollout_only: 79 | num_gpus = args.rollout_num_gpus 80 | rollout_offset = 0 81 | elif args.colocate: 82 | num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node 83 | rollout_offset = 0 84 | else: 85 | num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus 86 | rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node 87 | 88 | print(f"Creating placement group with {num_gpus} GPUs...") 89 | pg, actor_pg_reordered_bundle_indices = _create_placement_group(num_gpus) 90 | 91 | rollout_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[rollout_offset:] 92 | 93 | return { 94 | "actor": (pg, actor_pg_reordered_bundle_indices), 95 | "rollout": (pg, rollout_pg_reordered_bundle_indices), 96 | } 97 | 98 | 99 | def allocate_train_group(num_nodes, num_gpus_per_node, pg): 100 | return RayTrainGroup( 101 | num_nodes=num_nodes, 102 | num_gpus_per_node=num_gpus_per_node, 103 | pg=pg, 104 | num_gpus_per_actor=0.8, 105 | ) 106 | 107 | 108 | def create_actor_group(args, pg): 109 | actor_model = allocate_train_group( 110 | num_nodes=args.actor_num_nodes, 111 | num_gpus_per_node=args.actor_num_gpus_per_node, 112 | pg=pg, 113 | ) 114 | return actor_model 115 | 116 | 117 | def create_rollout_group(args, pg): 118 | return RolloutGroup(args, pg) 119 | -------------------------------------------------------------------------------- /slime/ray/ray_actor.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from slime.utils.http_utils import is_port_available 3 | 4 | 5 | class RayActor: 6 | @staticmethod 7 | def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): 8 | address = ray._private.services.get_node_ip_address() 9 | # strip ipv6 address 10 | address = address.strip("[]") 11 | 12 | # find the port where port, port + 1, port + 2, ... 
port + consecutive - 1 are all available 13 | port = start_port 14 | while not all(is_port_available(port + i) for i in range(consecutive)): 15 | port += 1 16 | 17 | return address, port 18 | 19 | def get_master_addr_and_port(self): 20 | return self.master_addr, self.master_port 21 | -------------------------------------------------------------------------------- /slime/ray/utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ray/utils.py#L1 2 | import os 3 | 4 | import ray 5 | import torch 6 | 7 | 8 | def ray_noset_visible_devices(env_vars=os.environ): 9 | # Refer to 10 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 11 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 12 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 13 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 14 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 15 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 16 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 17 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ 18 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", 19 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", 20 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", 21 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", 22 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", 23 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", 24 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", 25 | ] 26 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) 27 | 28 | 29 | def get_physical_gpu_id(): 30 | device = torch.cuda.current_device() 31 | props = torch.cuda.get_device_properties(device) 32 | return str(props.uuid) 33 | 34 | 35 | @ray.remote 36 | class Lock: 37 | def __init__(self): 38 | self._locked = False # False: unlocked, True: locked 39 | 40 | def acquire(self): 41 | """ 42 | Try to acquire the lock. Returns True if acquired, False otherwise. 43 | Caller should retry until it returns True. 44 | """ 45 | if not self._locked: 46 | self._locked = True 47 | return True 48 | return False 49 | 50 | def release(self): 51 | """Release the lock, allowing others to acquire.""" 52 | assert self._locked, "Lock is not acquired, cannot release." 
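# Callers are expected to spin on acquire.remote() until it returns True and to pair each
# successful acquire with exactly one release.remote() once the critical section is done.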
53 | self._locked = False 54 | -------------------------------------------------------------------------------- /slime/rollout/filter_hub/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/rollout/filter_hub/__init__.py -------------------------------------------------------------------------------- /slime/rollout/filter_hub/dynamic_sampling_filters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from slime.utils.types import Sample 3 | 4 | 5 | __all__ = ["check_reward_nonzero_std"] 6 | 7 | 8 | def check_reward_nonzero_std(args, samples: list[Sample], **kwargs): 9 | rewards = [sample.reward for sample in samples] 10 | return torch.tensor(rewards, dtype=torch.float).std() > 0.0 11 | -------------------------------------------------------------------------------- /slime/rollout/filter_hub/over_sampling_filters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from slime.utils.types import Sample 3 | 4 | 5 | __all__ = ["sort_by_reward_std"] 6 | 7 | 8 | def sort_by_reward_std(args, samples: list[list[Sample]], **kwargs) -> list[list[Sample]]: 9 | samples_with_std = [] 10 | for group in samples: 11 | rewards = [item.reward for item in group] 12 | std = torch.tensor(rewards, dtype=torch.float).std() 13 | samples_with_std.append((group, std)) 14 | # python sort is stable, so the order of samples with the same std is preserved 15 | samples_with_std.sort(key=lambda x: x[1], reverse=True) 16 | return [item[0] for item in samples_with_std] 17 | -------------------------------------------------------------------------------- /slime/rollout/rm_hub/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Union 3 | 4 | import aiohttp 5 | 6 | from slime.utils.misc import load_function 7 | from slime.utils.types import Sample 8 | 9 | from .deepscaler import get_deepscaler_rule_based_reward 10 | from .f1 import f1_score 11 | from .math_dapo_utils import compute_score as compute_score_dapo 12 | from .math_utils import extract_answer as extract_boxed_answer 13 | from .math_utils import grade_answer_verl 14 | 15 | 16 | async def remote_rm(args, sample: Sample): 17 | payload = { 18 | "prompt": sample.prompt, 19 | "response": sample.response, 20 | "label": sample.label, 21 | } 22 | session_kwargs = {} 23 | async with aiohttp.ClientSession(**session_kwargs) as session: 24 | async with session.post(args.rm_url, json=payload) as resp: 25 | resp.raise_for_status() 26 | return await resp.json() 27 | 28 | 29 | async def async_rm(args, sample: Sample, **kwargs): 30 | if args.custom_rm_path is not None: 31 | rm_function = load_function(args.custom_rm_path) 32 | return await rm_function(args, sample, **kwargs) 33 | 34 | rm_type = args.rm_type 35 | response = sample.response 36 | label = sample.label 37 | if rm_type.startswith("boxed_"): 38 | response = extract_boxed_answer(response) 39 | rm_type = rm_type[len("boxed_") :] 40 | 41 | # This function is intended for remote or time-consuming reward model evaluation. 42 | # Implement the actual logic as needed. 
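# e.g. rm_type "boxed_math" extracts the \boxed{...} answer above and then falls through to the
# "math" branch below, while plain "math" grades the raw response text directly.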
43 | if rm_type == "remote_rm": 44 | return await remote_rm(args, sample) 45 | elif rm_type == "deepscaler": 46 | return get_deepscaler_rule_based_reward(response, label) 47 | elif rm_type == "dapo": 48 | return compute_score_dapo(response, label) 49 | elif rm_type == "math": 50 | return 1 if grade_answer_verl(response, label) else 0 51 | elif rm_type == "f1": 52 | return f1_score(response, label)[0] 53 | else: 54 | raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.") 55 | 56 | 57 | async def batched_async_rm( 58 | args, 59 | samples: list[Sample], 60 | **kwargs, 61 | ) -> list[Union[int, float]]: 62 | if args.custom_rm_path is not None: 63 | rm_function = load_function(args.custom_rm_path) 64 | return await rm_function(args, samples, **kwargs) 65 | 66 | # async_rm expects the full args namespace and a Sample, so fan out one call per sample. 67 | tasks = [async_rm(args, sample, **kwargs) for sample in samples] 68 | rewards = await asyncio.gather(*tasks) 69 | return rewards 70 | -------------------------------------------------------------------------------- /slime/rollout/rm_hub/deepscaler.py: -------------------------------------------------------------------------------- 1 | from .math_utils import extract_answer, grade_answer_mathd, grade_answer_sympy 2 | 3 | 4 | def get_deepscaler_rule_based_reward(response, label): 5 | if "</think>" in response: 6 | model_solution = response.split("</think>")[1] 7 | elif "###Response" in response: 8 | model_solution = response.split("###Response")[1] 9 | else: 10 | return 0 11 | 12 | model_answer = extract_answer(model_solution) 13 | if model_answer is None: 14 | return 0 15 | if label == "": 16 | return 0 17 | 18 | # Convert single answer to list for uniform processing 19 | assert isinstance(label, (str, float, int)) 20 | ground_truths = [label] 21 | 22 | # Process each ground truth 23 | processed_ground_truths = [] 24 | for truth in ground_truths: 25 | truth = str(truth) 26 | if "\\boxed" in truth: 27 | processed_truth = extract_answer(truth) 28 | if processed_truth is not None: 29 | processed_ground_truths.append(processed_truth) 30 | else: 31 | processed_ground_truths.append(truth) 32 | 33 | if not processed_ground_truths: 34 | return 0 35 | 36 | # Check against all possible correct answers 37 | for ground_truth in processed_ground_truths: 38 | is_correct = grade_answer_mathd(model_answer, ground_truth) or grade_answer_sympy(model_answer, ground_truth) 39 | if is_correct: 40 | return 1 41 | 42 | return 0 43 | -------------------------------------------------------------------------------- /slime/rollout/rm_hub/f1.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from collections import Counter 4 | 5 | 6 | def normalize_answer(s): 7 | 8 | def remove_articles(text): 9 | return re.sub(r"\b(a|an|the)\b", " ", text) 10 | 11 | def white_space_fix(text): 12 | return " ".join(text.split()) 13 | 14 | def remove_punc(text): 15 | exclude = set(string.punctuation) 16 | return "".join(ch for ch in text if ch not in exclude) 17 | 18 | def lower(text): 19 | return text.lower() 20 | 21 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 22 | 23 | 24 | def f1_score(prediction, ground_truth): 25 | ZERO_METRIC = (0, 0, 0) 26 | 27 |
if prediction is None: 28 | return ZERO_METRIC 29 | 30 | normalized_prediction = normalize_answer(prediction) 31 | normalized_ground_truth = normalize_answer(ground_truth) 32 | 33 | if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: 34 | return ZERO_METRIC 35 | if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: 36 | return ZERO_METRIC 37 | 38 | prediction_tokens = normalized_prediction.split() 39 | ground_truth_tokens = normalized_ground_truth.split() 40 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 41 | num_same = sum(common.values()) 42 | if num_same == 0: 43 | return ZERO_METRIC 44 | precision = 1.0 * num_same / len(prediction_tokens) 45 | recall = 1.0 * num_same / len(ground_truth_tokens) 46 | f1 = (2 * precision * recall) / (precision + recall) 47 | return f1, precision, recall 48 | -------------------------------------------------------------------------------- /slime/rollout/sft_example.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from slime.utils.mask_utils import MultiTurnLossMaskGenerator 4 | 5 | __all__ = ["generate_rollout"] 6 | 7 | 8 | TOKENIZER = None 9 | MASK_GENERATOR = None 10 | 11 | 12 | def generate_rollout(args, rollout_id, data_buffer, evaluation=False): 13 | """An example to implement the generate_rollout function for an rule based rm rollout generation. 14 | 15 | Args: 16 | args: the whole args 17 | rollout_id: int, the id of the rollout, used for deterministic data generation 18 | data_buffer: the data buffer to store the generated samples 19 | evaluation: bool, whether the rollout is for evaluation or not 20 | 21 | Returns: 22 | list[Sample]: a list of samples generated by the rollout 23 | """ 24 | assert not evaluation 25 | assert args.rollout_global_dataset 26 | 27 | global TOKENIZER, MASK_GENERATOR 28 | if TOKENIZER is None: 29 | TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) 30 | 31 | if MASK_GENERATOR is None: 32 | MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type) 33 | 34 | samples = data_buffer.get_samples(args.rollout_batch_size) 35 | 36 | for sample in samples: 37 | (sample,) = sample 38 | messages = sample.prompt 39 | token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages) 40 | response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0] 41 | 42 | sample.tokens = token_ids 43 | sample.response_length = response_length 44 | sample.reward = 0 45 | sample.loss_mask = loss_mask[-response_length:] 46 | 47 | return samples 48 | -------------------------------------------------------------------------------- /slime/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .arguments import parse_args 2 | 3 | __all__ = ["parse_args"] 4 | -------------------------------------------------------------------------------- /slime/utils/async_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | 4 | __all__ = ["get_async_loop", "run"] 5 | 6 | 7 | # Create a background event loop thread 8 | class AsyncLoopThread: 9 | def __init__(self): 10 | self.loop = asyncio.new_event_loop() 11 | self._thread = threading.Thread(target=self._start_loop, daemon=True) 12 | self._thread.start() 13 | 14 | def _start_loop(self): 15 | 
asyncio.set_event_loop(self.loop) 16 | self.loop.run_forever() 17 | 18 | def run(self, coro): 19 | # Schedule a coroutine onto the loop and block until it's done 20 | return asyncio.run_coroutine_threadsafe(coro, self.loop).result() 21 | 22 | 23 | # Create one global instance 24 | async_loop = None 25 | 26 | 27 | def get_async_loop(): 28 | global async_loop 29 | if async_loop is None: 30 | async_loop = AsyncLoopThread() 31 | return async_loop 32 | 33 | 34 | def run(coro): 35 | """Run a coroutine in the background event loop.""" 36 | return get_async_loop().run(coro) 37 | -------------------------------------------------------------------------------- /slime/utils/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pandas as pd 4 | 5 | from slime.utils.types import Sample 6 | 7 | __all__ = ["Dataset"] 8 | 9 | 10 | # TODO: don't read the whole file into memory. 11 | def read_file(path): 12 | if path.endswith(".jsonl"): 13 | df = pd.read_json(path, lines=True) 14 | elif path.endswith(".parquet"): 15 | df = pd.read_parquet(path) 16 | else: 17 | raise ValueError(f"Unsupported file format: {path}. Supported formats are .jsonl and .parquet.") 18 | 19 | for _, row in df.iterrows(): 20 | yield row.to_dict() 21 | 22 | 23 | class Dataset: 24 | def __init__( 25 | self, 26 | path, 27 | tokenizer, 28 | max_length, 29 | *, 30 | prompt_key="text", 31 | label_key=None, 32 | tool_key=None, 33 | metadata_key="metadata", 34 | seed=42, 35 | apply_chat_template=False, 36 | ): 37 | self.origin_samples = [] 38 | for data in read_file(path): 39 | prompt = data[prompt_key] 40 | if apply_chat_template: 41 | if tool_key is not None: 42 | tools = data[tool_key] 43 | else: 44 | tools = None 45 | prompt = tokenizer.apply_chat_template(prompt, tools, tokenize=False, add_generation_prompt=True) 46 | 47 | # TODO: this is slow. 48 | if max_length is not None: 49 | if len(tokenizer(prompt)["input_ids"]) > max_length: 50 | continue 51 | 52 | self.origin_samples.append( 53 | Sample( 54 | prompt=prompt, 55 | label=data[label_key] if label_key is not None else None, 56 | metadata=data.get(metadata_key) or {}, 57 | ) 58 | ) 59 | 60 | self.epoch_id = -1 61 | self.seed = seed 62 | self.samples = self.origin_samples 63 | 64 | def shuffle(self, new_epoch_id): 65 | if self.epoch_id == new_epoch_id: 66 | return 67 | 68 | random.seed(self.seed + new_epoch_id) 69 | permutation = list(range(len(self.samples))) 70 | random.shuffle(permutation) 71 | self.samples = [self.origin_samples[i] for i in permutation] 72 | self.epoch_id = new_epoch_id 73 | 74 | def __getitem__(self, idx): 75 | return self.samples[idx] 76 | 77 | def __len__(self): 78 | return len(self.samples) 79 | -------------------------------------------------------------------------------- /slime/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Optional, Union 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.distributed.distributed_c10d import ( 7 | Backend, 8 | PrefixStore, 9 | Store, 10 | _new_process_group_helper, 11 | _world, 12 | default_pg_timeout, 13 | rendezvous, 14 | ) 15 | from megatron.core import mpu 16 | 17 | 18 | # Copy from pytorch to allow creating multiple main groups. 
19 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py 20 | def init_process_group( 21 | backend: Union[str, Backend] = None, 22 | init_method: Optional[str] = None, 23 | timeout: Optional[timedelta] = None, 24 | world_size: int = -1, 25 | rank: int = -1, 26 | store: Optional[Store] = None, 27 | group_name: str = None, 28 | pg_options: Optional[Any] = None, 29 | ): 30 | assert (store is None) or (init_method is None), "Cannot specify both init_method and store." 31 | 32 | if store is not None: 33 | assert world_size > 0, "world_size must be positive if using store" 34 | assert rank >= 0, "rank must be non-negative if using store" 35 | elif init_method is None: 36 | init_method = "env://" 37 | 38 | if backend: 39 | backend = Backend(backend) 40 | else: 41 | backend = Backend("undefined") 42 | 43 | if timeout is None: 44 | timeout = default_pg_timeout 45 | 46 | # backward compatible API 47 | if store is None: 48 | rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) 49 | store, rank, world_size = next(rendezvous_iterator) 50 | store.set_timeout(timeout) 51 | 52 | # Use a PrefixStore to avoid accidental overrides of keys used by 53 | # different systems (e.g. RPC) in case the store is multi-tenant. 54 | store = PrefixStore(group_name, store) 55 | 56 | # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 57 | # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 58 | # We need to determine the appropriate parameter name based on PyTorch version 59 | pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" 60 | pg, _ = _new_process_group_helper( 61 | world_size, 62 | rank, 63 | [], 64 | backend, 65 | store, 66 | group_name=group_name, 67 | **{pg_options_param_name: pg_options}, 68 | timeout=timeout, 69 | ) 70 | 71 | _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} 72 | 73 | return pg 74 | 75 | 76 | def distributed_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True, epsilon: float = 1e-8): 77 | """ 78 | Performs whitening on a tensor using global statistics from all participating GPUs. 79 | 80 | It calculates the global mean and variance across all ranks in the default 81 | process group (the WORLD) and uses these global statistics to normalize the 82 | local data on each rank. 83 | 84 | Args: 85 | values (torch.Tensor): The local tensor of values to whiten. 86 | mask (torch.Tensor): The local mask corresponding to the values. 87 | shift_mean (bool): If True, the output is zero-mean. Defaults to True. 88 | epsilon (float): A small value for numerical stability. 89 | 90 | Returns: 91 | torch.Tensor: The locally whitened tensor using global statistics. 
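    Example (illustrative): with values=[1., 2., 3.] and mask=[1., 1., 0.] on a single rank, only the
    first two entries contribute to the global statistics; the masked-out third entry is still
    transformed but does not influence the mean or variance.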
92 | """ 93 | # Calculate local intermediate statistics 94 | local_sum = (values * mask).sum() 95 | local_sum_sq = ((values**2) * mask).sum() 96 | local_mask_sum = mask.sum() 97 | 98 | stats_tensor = torch.tensor( 99 | [local_sum, local_sum_sq, local_mask_sum], 100 | device=values.device, 101 | dtype=torch.float32, 102 | ) 103 | 104 | # Aggregate via all_reduce within the DP group 105 | dist.all_reduce(stats_tensor) 106 | 107 | # Calculate global stats from aggregated results 108 | global_sum, global_sum_sq, global_mask_sum = stats_tensor 109 | 110 | if global_mask_sum.item() == 0: 111 | raise ValueError("The global mask sum across all participating GPUs is zero.") 112 | 113 | global_mean = global_sum / global_mask_sum 114 | global_mean_sq = global_sum_sq / global_mask_sum 115 | global_var = global_mean_sq - global_mean**2 116 | 117 | # Bessel's correction for unbiased estimate 118 | if global_mask_sum.item() >= 2: 119 | bessel_correction = global_mask_sum / (global_mask_sum - 1) 120 | global_var = global_var * bessel_correction 121 | 122 | # Whiten local data using global stats 123 | whitened_values = (values - global_mean) * torch.rsqrt(global_var + epsilon) 124 | 125 | if not shift_mean: 126 | whitened_values += global_mean 127 | 128 | return whitened_values -------------------------------------------------------------------------------- /slime/utils/flops_utils.py: -------------------------------------------------------------------------------- 1 | def calculate_embedding_flops(seqlen, hidden_size): 2 | return 2 * seqlen * hidden_size 3 | 4 | 5 | def calculate_lm_head_flops(seqlen, hidden_size, vocab_size): 6 | return 2 * seqlen * hidden_size * vocab_size 7 | 8 | 9 | def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups): 10 | head_dim = hidden_size // num_attention_heads 11 | n_q_heads = num_attention_heads 12 | n_kv_heads = num_query_groups 13 | q_flops = 2 * seqlen * hidden_size * n_q_heads * head_dim 14 | kv_flops = 2 * seqlen * hidden_size * n_kv_heads * head_dim * 2 15 | return q_flops + kv_flops 16 | 17 | 18 | def calculate_attention_flops(seqlen, num_attention_heads, head_dim): 19 | # QK^T 20 | flops = 2 * num_attention_heads * seqlen * seqlen * head_dim 21 | # A*V 22 | flops += 2 * num_attention_heads * seqlen * seqlen * head_dim 23 | return flops 24 | 25 | 26 | def calculate_output_flops(seqlen, hidden_size): 27 | return 2 * seqlen * hidden_size * hidden_size 28 | 29 | 30 | def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size): 31 | return 2 * seqlen * hidden_size * ffn_hidden_size * 3 32 | 33 | 34 | def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size): 35 | head_dim = hidden_size // num_attention_heads 36 | return ( 37 | calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups) 38 | + calculate_attention_flops(seqlen, num_attention_heads, head_dim) 39 | + calculate_output_flops(seqlen, hidden_size) 40 | + calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size) 41 | ) 42 | 43 | 44 | def calculate_fwd_flops( 45 | seqlens, 46 | args, 47 | ): 48 | hidden_size = args.hidden_size 49 | num_attention_heads = args.num_attention_heads 50 | num_query_groups = args.num_query_groups 51 | vocab_size = args.vocab_size 52 | 53 | total_flops = 0 54 | 55 | dense_ffn = args.ffn_hidden_size 56 | if args.num_experts is None: 57 | num_dense_layers = args.num_layers 58 | num_moe_layers = 0 59 | else: 60 | shared_expert_ffn = getattr(args, 
"moe_shared_expert_intermediate_size", None) 61 | if shared_expert_ffn is None: 62 | shared_expert_ffn = 0 63 | 64 | moe_ffn = args.moe_ffn_hidden_size * args.moe_router_topk + shared_expert_ffn 65 | if hasattr(args, "moe_layer_freq"): 66 | if isinstance(args.moe_layer_freq, list): 67 | num_dense_layers = sum(1 for freq in args.moe_layer_freq if freq == 0) 68 | num_moe_layers = sum(1 for freq in args.moe_layer_freq if freq > 0) 69 | else: 70 | num_dense_layers = sum(1 for i in range(args.num_layers) if i % args.moe_layer_freq != 0) 71 | num_moe_layers = sum(1 for i in range(args.num_layers) if i % args.moe_layer_freq == 0) 72 | else: 73 | num_dense_layers = 0 74 | num_moe_layers = args.num_layers 75 | 76 | for seqlen in seqlens: 77 | if num_dense_layers > 0: 78 | total_flops += ( 79 | calculate_layer_flops( 80 | seqlen, 81 | hidden_size, 82 | num_attention_heads, 83 | num_query_groups, 84 | dense_ffn, 85 | ) 86 | * num_dense_layers 87 | ) 88 | 89 | if num_moe_layers > 0: 90 | total_flops += ( 91 | calculate_layer_flops( 92 | seqlen, 93 | hidden_size, 94 | num_attention_heads, 95 | num_query_groups, 96 | moe_ffn, 97 | ) 98 | * num_moe_layers 99 | ) 100 | 101 | total_flops += calculate_lm_head_flops(seqlen, hidden_size, vocab_size) 102 | 103 | return total_flops 104 | -------------------------------------------------------------------------------- /slime/utils/http_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import multiprocessing 3 | import random 4 | import socket 5 | 6 | import httpx 7 | 8 | 9 | def find_available_port(base_port: int): 10 | port = base_port + random.randint(100, 1000) 11 | while True: 12 | if is_port_available(port): 13 | return port 14 | if port < 60000: 15 | port += 42 16 | else: 17 | port -= 43 18 | 19 | 20 | def is_port_available(port): 21 | """Return whether a port is available.""" 22 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 23 | try: 24 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 25 | s.bind(("", port)) 26 | s.listen(1) 27 | return True 28 | except socket.error: 29 | return False 30 | except OverflowError: 31 | return False 32 | 33 | 34 | def get_host_info(): 35 | hostname = socket.gethostname() 36 | 37 | local_ip = socket.gethostbyname(hostname) 38 | 39 | return hostname, local_ip 40 | 41 | 42 | def run_router(args): 43 | try: 44 | from sglang_router.launch_router import launch_router 45 | 46 | router = launch_router(args) 47 | if router is None: 48 | return 1 49 | return 0 50 | except Exception as e: 51 | print(e) 52 | return 1 53 | 54 | 55 | def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None: 56 | """Terminate a process gracefully, with forced kill as fallback. 
57 | 58 | Args: 59 | process: The process to terminate 60 | timeout: Seconds to wait for graceful termination before forcing kill 61 | """ 62 | if not process.is_alive(): 63 | return 64 | 65 | process.terminate() 66 | process.join(timeout=timeout) 67 | if process.is_alive(): 68 | process.kill() 69 | process.join() 70 | 71 | 72 | async def post(url, payload, use_http2=False, max_retries=60): 73 | # never timeout 74 | timeout = httpx.Timeout(None) 75 | max_retries = 60 76 | retry_count = 0 77 | while retry_count < max_retries: 78 | try: 79 | async with httpx.AsyncClient(http1=not use_http2, http2=use_http2, timeout=timeout) as client: 80 | response = await client.post(url, json=payload or {}) 81 | response.raise_for_status() 82 | try: 83 | output = response.json() 84 | except: 85 | output = response.text 86 | except Exception as e: 87 | retry_count += 1 88 | print(f"Error: {e}, retrying... (attempt {retry_count}/{max_retries})") 89 | if retry_count >= max_retries: 90 | print(f"Max retries ({max_retries}) reached, failing...") 91 | raise e 92 | await asyncio.sleep(1) 93 | continue 94 | break 95 | 96 | return output 97 | 98 | 99 | async def get(url, use_http2=False): 100 | # never timeout 101 | timeout = httpx.Timeout(None) 102 | async with httpx.AsyncClient(http1=not use_http2, http2=use_http2, timeout=timeout) as client: 103 | response = await client.get(url) 104 | response.raise_for_status() 105 | output = response.json() 106 | return output 107 | -------------------------------------------------------------------------------- /slime/utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from transformers import AutoTokenizer 4 | 5 | 6 | class MultiTurnLossMaskGenerator: 7 | def __init__(self, tokenizer: AutoTokenizer, tokenizer_type: str = "qwen"): 8 | self.tokenizer = tokenizer 9 | self.system_message_length, self.gen_token_length = self.get_system_message_length() 10 | self.tokenizer_type = tokenizer_type 11 | 12 | def get_response_lengths(self, loss_masks: List[List[int]]) -> List[int]: 13 | return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks] 14 | 15 | def find_all_sublist_indices(self, main_list, sublist): 16 | sublist_len = len(sublist) 17 | indices = [] 18 | for i in range(len(main_list) - sublist_len + 1): 19 | if main_list[i : i + sublist_len] == sublist: 20 | indices.append(i) 21 | return indices 22 | 23 | def get_system_message_length(self) -> Tuple[int, int]: 24 | test_string = "FOR TESTING ONLY" 25 | test_messages = [ 26 | {"role": "user", "content": test_string}, 27 | {"role": "user", "content": test_string}, 28 | ] 29 | raw_token_ids = self.tokenizer(test_string, add_special_tokens=False)["input_ids"] 30 | chat_template_token = self.tokenizer.apply_chat_template( 31 | test_messages, add_special_tokens=False, tokenize=False 32 | ) 33 | chat_template_token_ids = self.tokenizer(chat_template_token, add_special_tokens=False)["input_ids"] 34 | idx_1, idx_2 = self.find_all_sublist_indices(chat_template_token_ids, raw_token_ids) 35 | end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2 36 | gen_token_length = len( 37 | self.tokenizer.apply_chat_template( 38 | test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True 39 | ) 40 | ) - len(chat_template_token_ids) 41 | 42 | system_message_length = idx_1 - ((idx_2 - idx_1) - end_interval - len(raw_token_ids)) 43 | return system_message_length, gen_token_length 44 | 45 | 
def gen_multi_turn_loss_mask_qwen(self, messages: List[Dict]) -> Tuple[List[int], List[int]]: 46 | all_loss_masks = [] 47 | all_token_ids = [] 48 | 49 | for i, message in enumerate(messages): 50 | message_ids = self.tokenizer.apply_chat_template([message], tokenize=True) 51 | 52 | if message["role"] != "system" and i > 0: 53 | message_ids = message_ids[self.system_message_length :] 54 | 55 | if message["role"] == "assistant": 56 | loss_mask = [0] * self.gen_token_length + [1] * (len(message_ids) - self.gen_token_length) 57 | else: 58 | loss_mask = [0] * len(message_ids) 59 | 60 | all_loss_masks.extend(loss_mask) 61 | all_token_ids.extend(message_ids) 62 | 63 | return all_token_ids, all_loss_masks 64 | 65 | def gen_multi_turn_loss_mask_distill_qwen(self, messages: List[Dict]) -> Tuple[List[int], List[int]]: 66 | prompt = self.tokenizer.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True) 67 | response = messages[-1]["content"] 68 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] 69 | response_tokens = self.tokenizer(response, add_special_tokens=False)["input_ids"] 70 | 71 | response_length = len(response_tokens) 72 | token_ids = prompt_tokens + response_tokens 73 | loss_mask = [0] * len(prompt_tokens) + [1] * response_length 74 | return token_ids, loss_mask 75 | 76 | def get_loss_mask(self, messages: List[Dict]) -> List[int]: 77 | if self.tokenizer_type == "qwen": 78 | if "<|Assistant|>" in self.tokenizer.get_added_vocab(): 79 | return self.gen_multi_turn_loss_mask_distill_qwen(messages) 80 | 81 | return self.gen_multi_turn_loss_mask_qwen(messages) 82 | elif self.tokenizer_type == "distill_qwen": 83 | return self.gen_multi_turn_loss_mask_distill_qwen(messages) 84 | else: 85 | raise ValueError(f"Unsupported tokenizer type: {self.tokenizer_type}") 86 | 87 | def get_text_from_loss_mask(self, token_ids: List[int], loss_masks: List[int]) -> List[str]: 88 | selected_texts = [] 89 | current_tokens = [] 90 | 91 | for idx, mask in enumerate(loss_masks): 92 | if mask == 1: 93 | current_tokens.append(token_ids[idx]) 94 | elif current_tokens: 95 | selected_texts.append(self.tokenizer.decode(current_tokens)) 96 | current_tokens = [] 97 | 98 | if current_tokens: 99 | selected_texts.append(self.tokenizer.decode(current_tokens)) 100 | 101 | return selected_texts 102 | -------------------------------------------------------------------------------- /slime/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | def clear_memory(): 7 | torch.cuda.synchronize() 8 | gc.collect() 9 | torch.cuda.empty_cache() 10 | 11 | 12 | def available_memory(): 13 | free, total = torch.cuda.mem_get_info(torch.cuda.current_device()) 14 | return { 15 | "gpu": str(torch.cuda.current_device()), 16 | "total_GB": round(total / (1024**3), 2), 17 | "free_GB": round(free / (1024**3), 2), 18 | "used_GB": round((total - free) / (1024**3), 2), 19 | } 20 | 21 | 22 | def print_memory(msg): 23 | if dist.get_rank() == 0: 24 | print(f"Memory-Usage {msg}:", available_memory()) 25 | -------------------------------------------------------------------------------- /slime/utils/misc.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def load_function(path): 5 | """ 6 | Load a function from a module. 7 | :param path: The path to the function, e.g. "module.submodule.function". 
8 | :return: The function object. 9 | """ 10 | module_path, _, attr = path.rpartition(".") 11 | module = importlib.import_module(module_path) 12 | return getattr(module, attr) 13 | 14 | 15 | class SingletonMeta(type): 16 | """ 17 | A metaclass for creating singleton classes. 18 | """ 19 | 20 | _instances = {} 21 | 22 | def __call__(cls, *args, **kwargs): 23 | if cls not in cls._instances: 24 | instance = super().__call__(*args, **kwargs) 25 | cls._instances[cls] = instance 26 | return cls._instances[cls] 27 | -------------------------------------------------------------------------------- /slime/utils/timer.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from functools import wraps 3 | from time import time 4 | 5 | from .misc import SingletonMeta 6 | 7 | __all__ = ["Timer", "timer"] 8 | 9 | 10 | class Timer(metaclass=SingletonMeta): 11 | def __init__(self): 12 | self.timers = {} 13 | self.start_time = {} 14 | 15 | def start(self, name): 16 | assert name not in self.timers, f"Timer {name} already started." 17 | self.start_time[name] = time() 18 | 19 | def end(self, name): 20 | assert name in self.start_time, f"Timer {name} not started." 21 | elapsed_time = time() - self.start_time[name] 22 | self.add(name, elapsed_time) 23 | del self.start_time[name] 24 | 25 | def reset(self, name=None): 26 | if name is None: 27 | self.timers = {} 28 | elif name in self.timers: 29 | del self.timers[name] 30 | 31 | def add(self, name, elapsed_time): 32 | if name not in self.timers: 33 | self.timers[name] = elapsed_time 34 | else: 35 | self.timers[name] += elapsed_time 36 | 37 | def log_dict(self): 38 | return self.timers 39 | 40 | @contextmanager 41 | def context(self, name): 42 | self.start(name) 43 | try: 44 | yield 45 | finally: 46 | self.end(name) 47 | 48 | 49 | def timer(name_or_func): 50 | """ 51 | Can be used either as a decorator or a context manager: 52 | 53 | @timer 54 | def func(): 55 | ... 56 | 57 | or 58 | 59 | with timer("block_name"): 60 | ... 
61 | """ 62 | # When used as a context manager 63 | if isinstance(name_or_func, str): 64 | name = name_or_func 65 | return Timer().context(name) 66 | 67 | func = name_or_func 68 | 69 | @wraps(func) 70 | def wrapper(*args, **kwargs): 71 | with Timer().context(func.__name__): 72 | return func(*args, **kwargs) 73 | 74 | return wrapper 75 | -------------------------------------------------------------------------------- /slime/utils/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum 3 | from typing import Optional, Union 4 | 5 | import torch 6 | 7 | 8 | @dataclass 9 | class Sample: 10 | """The sample generated""" 11 | 12 | index: Optional[int] = None 13 | # prompt 14 | prompt: Union[str, list[dict[str, str]]] = "" 15 | tokens: list[int] = field(default_factory=list) 16 | # response 17 | response: str = "" 18 | response_length: int = 0 19 | label: Optional[str] = None 20 | reward: Optional[Union[float, dict[str, float]]] = None 21 | loss_mask: Optional[list[int]] = None 22 | 23 | class Status(Enum): 24 | PENDING = "pending" 25 | COMPLETED = "completed" 26 | TRUNCATED = "truncated" 27 | ABORTED = "aborted" 28 | 29 | status: Status = Status.PENDING 30 | metadata: dict = field(default_factory=dict) 31 | 32 | def to_dict(self): 33 | value = self.__dict__.copy() 34 | value["status"] = self.status.value 35 | return value 36 | 37 | @staticmethod 38 | def from_dict(data: dict): 39 | data["status"] = Sample.Status(data["status"]) 40 | return Sample(**data) 41 | 42 | 43 | @dataclass 44 | class ParamInfo: 45 | name: str 46 | dtype: torch.dtype 47 | shape: torch.Size 48 | attrs: dict 49 | size: int 50 | src_rank: int 51 | -------------------------------------------------------------------------------- /slime_plugins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime_plugins/__init__.py -------------------------------------------------------------------------------- /slime_plugins/mbridge/__init__.py: -------------------------------------------------------------------------------- 1 | from .glm4 import GLM4Bridge -------------------------------------------------------------------------------- /slime_plugins/mbridge/glm4.py: -------------------------------------------------------------------------------- 1 | from mbridge.core import LLMBridge, register_model 2 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec 3 | 4 | 5 | @register_model("glm4") 6 | class GLM4Bridge(LLMBridge): 7 | """ 8 | Bridge implementation for Qwen2 models. 9 | 10 | This class extends LLMBridge to provide specific configurations and 11 | optimizations for Qwen2 models, handling the conversion between 12 | Hugging Face Qwen2 format and Megatron-Core. 
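
    Note: despite the Qwen2 wording above, this bridge is registered for GLM-4
    (see the @register_model("glm4") decorator). The mapping mirrors the Qwen2
    layout but additionally covers GLM-4's post_self_attn_layernorm and
    post_mlp_layernorm weights.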
13 | """ 14 | 15 | _DIRECT_MAPPING = { 16 | "embedding.word_embeddings.weight": "model.embed_tokens.weight", 17 | "decoder.final_layernorm.weight": "model.norm.weight", 18 | "output_layer.weight": "lm_head.weight", 19 | } 20 | _ATTENTION_MAPPING = { 21 | "self_attention.linear_proj.weight": [ 22 | "model.layers.{layer_number}.self_attn.o_proj.weight" 23 | ], 24 | "self_attention.linear_qkv.layer_norm_weight": [ 25 | "model.layers.{layer_number}.input_layernorm.weight" 26 | ], 27 | "self_attention.q_layernorm.weight": [ 28 | "model.layers.{layer_number}.self_attn.q_norm.weight" 29 | ], 30 | "self_attention.k_layernorm.weight": [ 31 | "model.layers.{layer_number}.self_attn.k_norm.weight" 32 | ], 33 | "self_attention.linear_qkv.weight": [ 34 | "model.layers.{layer_number}.self_attn.q_proj.weight", 35 | "model.layers.{layer_number}.self_attn.k_proj.weight", 36 | "model.layers.{layer_number}.self_attn.v_proj.weight", 37 | ], 38 | "self_attention.linear_qkv.bias": [ 39 | "model.layers.{layer_number}.self_attn.q_proj.bias", 40 | "model.layers.{layer_number}.self_attn.k_proj.bias", 41 | "model.layers.{layer_number}.self_attn.v_proj.bias", 42 | ], 43 | } 44 | _MLP_MAPPING = { 45 | "mlp.linear_fc1.weight": [ 46 | "model.layers.{layer_number}.mlp.gate_up_proj.weight", 47 | ], 48 | "mlp.linear_fc1.layer_norm_weight": [ 49 | "model.layers.{layer_number}.post_attention_layernorm.weight" 50 | ], 51 | "mlp.linear_fc2.weight": ["model.layers.{layer_number}.mlp.down_proj.weight"], 52 | } 53 | 54 | def _build_config(self): 55 | """ 56 | Build the configuration for Qwen2 models. 57 | 58 | Configures Qwen2-specific parameters such as QKV bias settings and 59 | layer normalization options. 60 | 61 | Returns: 62 | TransformerConfig: Configuration object for Qwen2 models 63 | """ 64 | return self._build_base_config( 65 | # qwen2 66 | add_qkv_bias=True, 67 | qk_layernorm=False, 68 | post_mlp_layernorm=True, 69 | post_self_attn_layernorm=True, 70 | rotary_interleaved=True, 71 | ) 72 | 73 | def _get_transformer_layer_spec(self): 74 | """ 75 | Gets the transformer layer specification. 76 | 77 | Creates and returns a specification for the transformer layers based on 78 | the current configuration. 79 | 80 | Returns: 81 | TransformerLayerSpec: Specification for transformer layers 82 | 83 | Raises: 84 | AssertionError: If normalization is not RMSNorm 85 | """ 86 | transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( 87 | post_self_attn_layernorm=True, 88 | post_mlp_layernorm=True, 89 | ) 90 | return transformer_layer_spec 91 | 92 | def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str]: 93 | """ 94 | Map MCore weight names to Hugging Face weight names. 
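
        For example (following the class mappings above),
        "decoder.layers.0.post_self_attn_layernorm.weight" maps one-to-one to
        "model.layers.0.post_self_attn_layernorm.weight", while
        "decoder.layers.0.self_attention.linear_qkv.weight" expands to the
        layer-0 q_proj/k_proj/v_proj weights.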
95 | 96 | Args: 97 | mcore_weights_name: MCore weight name 98 | 99 | Returns: 100 | list: Corresponding Hugging Face weight names 101 | """ 102 | assert ( 103 | "_extra_state" not in mcore_weights_name 104 | ), "extra_state should not be loaded" 105 | 106 | if mcore_weights_name in self._DIRECT_MAPPING: 107 | return [self._DIRECT_MAPPING[mcore_weights_name]] 108 | 109 | if "post_self_attn_layernorm" in mcore_weights_name: 110 | layer_number = mcore_weights_name.split(".")[2] 111 | return [ 112 | f"model.layers.{layer_number}.post_self_attn_layernorm.weight" 113 | ] 114 | elif "post_mlp_layernorm" in mcore_weights_name: 115 | layer_number = mcore_weights_name.split(".")[2] 116 | return [ 117 | f"model.layers.{layer_number}.post_mlp_layernorm.weight" 118 | ] 119 | elif "self_attention" in mcore_weights_name: 120 | return self._weight_name_mapping_attention(mcore_weights_name) 121 | elif "mlp" in mcore_weights_name: 122 | return self._weight_name_mapping_mlp(mcore_weights_name) 123 | else: 124 | raise NotImplementedError( 125 | f"Unsupported parameter name: {mcore_weights_name}" 126 | ) 127 | -------------------------------------------------------------------------------- /slime_plugins/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime_plugins/models/__init__.py -------------------------------------------------------------------------------- /slime_plugins/models/glm4.py: -------------------------------------------------------------------------------- 1 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec 2 | 3 | 4 | def get_glm_spec(args): 5 | transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( 6 | args.num_experts, 7 | args.moe_grouped_gemm, 8 | args.qk_layernorm, 9 | args.multi_latent_attention, 10 | args.moe_use_legacy_grouped_gemm, 11 | post_self_attn_layernorm=args.post_self_attn_layernorm, 12 | post_mlp_layernorm=args.post_mlp_layernorm, 13 | ) 14 | return transformer_layer_spec 15 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_generator import BaseGenerator, query_single_turn 2 | from .reward_utils import get_rule_based_math_reward 3 | from .utils.arguments import add_arguments 4 | 5 | __all__ = [ 6 | "BaseGenerator", 7 | "query_single_turn", 8 | "get_rule_based_math_reward", 9 | "add_arguments", 10 | ] 11 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/generator/reward_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .math_utils import get_rule_based_math_reward 2 | 3 | __all__ = ["get_rule_based_math_reward"] 4 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/generator/utils/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def add_arguments(add_task_arguments=None): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--task_type", type=str, default="math") 7 | parser.add_argument("--input_file", type=str, default="None") 8 | parser.add_argument("--num_epoch", type=int, default=1) 9 | 
parser.add_argument("--num_repeat_per_sample", type=int, default=4) 10 | parser.add_argument("--num_process", type=int, default=4) 11 | parser.add_argument("--remote_engine_url", type=str, default="http://0.0.0.0:8000/v1") 12 | parser.add_argument("--remote_buffer_url", type=str, default="http://localhost:8888") 13 | parser.add_argument("--max_tokens", type=int, default=4096) 14 | parser.add_argument("--num_repeats", type=int, default=20) 15 | 16 | if add_task_arguments is not None: 17 | parser = add_task_arguments(parser) 18 | 19 | args = parser.parse_args() 20 | return args 21 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/tools/assign_instance_id.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | 6 | def main(input_path, task_type="math", output_path=None): 7 | input_path = Path(input_path) 8 | if output_path is None: 9 | output_path = str(input_path).replace(".jsonl", "_processed.jsonl") 10 | used_ids = set() 11 | processed = [] 12 | 13 | # First pass: load all lines and collect existing instance_ids 14 | with open(input_path, "r", encoding="utf-8") as f: 15 | for line in f: 16 | item = json.loads(line) 17 | if "instance_id" in item: 18 | used_ids.add(item["instance_id"]) 19 | processed.append(item) 20 | 21 | # Second pass: assign missing instance_ids 22 | counter = 0 23 | for item in processed: 24 | if "instance_id" not in item: 25 | # Find unused id 26 | while True: 27 | candidate_id = f"{task_type}_{counter}" 28 | counter += 1 29 | if candidate_id not in used_ids: 30 | item["instance_id"] = candidate_id 31 | used_ids.add(candidate_id) 32 | break 33 | 34 | # Save to new jsonl file 35 | with open(output_path, "w", encoding="utf-8") as f: 36 | for item in processed: 37 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 38 | 39 | print(f"✅ Processed {len(processed)} items. 
Saved to {output_path}") 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--input_path", type=str, help="Path to input JSONL file") 45 | parser.add_argument( 46 | "--task_type", 47 | type=str, 48 | default="math", 49 | help="Task type prefix for new instance_id", 50 | ) 51 | parser.add_argument("--output_path", type=str, default=None, help="Optional path to output file") 52 | args = parser.parse_args() 53 | 54 | main(args.input_path, args.task_type, args.output_path) 55 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/tools/visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class BufferStatsVisualizer: 9 | def __init__(self, time_window=60): 10 | """ 11 | Initialize the buffer statistics visualizer 12 | 13 | Args: 14 | time_window (int): Time window in seconds for each data point (default: 60s) 15 | """ 16 | self.time_window = time_window 17 | self.data_points = [] # List to store data points 18 | self.timestamps = [] # List to store timestamps 19 | self.start_time = time.time() 20 | self.last_window_start = self.start_time 21 | self.window_count = 0 # Counter for current window 22 | self.args_dict = None # Store args for filename 23 | 24 | # Initialize the plot 25 | plt.ion() # Enable interactive mode 26 | self.fig, self.ax = plt.subplots(figsize=(12, 6)) 27 | (self.line,) = self.ax.plot([], [], "b-", label="Data Points per Window") 28 | 29 | # Set up the plot 30 | self.ax.set_xlabel("Time (minutes)") 31 | self.ax.set_ylabel("Data Points per 60s Window") 32 | self.ax.set_title("Buffer Statistics - Data Points per 60s Window") 33 | self.ax.grid(True) 34 | self.ax.legend() 35 | 36 | # Start the update thread 37 | self.running = True 38 | self.update_thread = threading.Thread(target=self._update_plot) 39 | self.update_thread.daemon = True 40 | self.update_thread.start() 41 | 42 | def set_args(self, args_dict): 43 | """Set the args dictionary for filename generation""" 44 | self.args_dict = args_dict 45 | 46 | def add_data_point(self, _): 47 | """Add a new data point to the statistics""" 48 | current_time = time.time() 49 | self.window_count += 1 # Increment counter for current window 50 | 51 | # Check if we've reached the end of a time window 52 | if current_time - self.last_window_start >= self.time_window: 53 | # Calculate the time in minutes since start 54 | time_in_minutes = (current_time - self.start_time) / 60 55 | 56 | # Add the data point and timestamp 57 | self.data_points.append(self.window_count) 58 | self.timestamps.append(time_in_minutes) 59 | 60 | # Save the plot after adding new data point 61 | self.save_plot() 62 | 63 | # Reset for next window 64 | self.last_window_start = current_time 65 | self.window_count = 0 66 | 67 | def _update_plot(self): 68 | """Update the plot periodically""" 69 | while self.running: 70 | if self.data_points: # Only update if we have data 71 | self.line.set_data(self.timestamps, self.data_points) 72 | self.ax.relim() 73 | self.ax.autoscale_view() 74 | self.fig.canvas.draw() 75 | self.fig.canvas.flush_events() 76 | time.sleep(1) # Update every second 77 | 78 | def save_plot(self): 79 | """Save the current plot to a file""" 80 | timestamp = self.start_time 81 | 82 | # Create filename based on args and timestamp 83 | if self.args_dict: 84 | # Extract key parameters from args 85 | 
key_params = [] 86 | for key in ["task_type", "num_repeat_per_sample", "group_size"]: 87 | if key in self.args_dict: 88 | key_params.append(f"{key}_{self.args_dict[key]}") 89 | 90 | filename = f"buffer_stats_{'_'.join(key_params)}_{timestamp}.png" 91 | else: 92 | filename = f"buffer_stats_{timestamp}.png" 93 | 94 | # Create directory if it doesn't exist 95 | os.makedirs("buffer_stats", exist_ok=True) 96 | filepath = os.path.join("buffer_stats", filename) 97 | 98 | # Save the plot 99 | plt.savefig(filepath, dpi=300, bbox_inches="tight") 100 | print(f"Plot saved to {filepath}") 101 | 102 | def close(self): 103 | """Close the visualizer and clean up""" 104 | self.running = False 105 | if self.update_thread.is_alive(): 106 | self.update_thread.join() 107 | plt.close(self.fig) 108 | 109 | 110 | # Example usage: 111 | if __name__ == "__main__": 112 | visualizer = BufferStatsVisualizer(time_window=60) 113 | visualizer.set_args({"task_type": "test", "num_repeat_per_sample": 16}) 114 | 115 | # Simulate some data 116 | try: 117 | for i in range(1000): 118 | visualizer.add_data_point(1) # Just increment the counter 119 | time.sleep(0.1) 120 | except KeyboardInterrupt: 121 | pass 122 | finally: 123 | visualizer.close() 124 | -------------------------------------------------------------------------------- /tests/test_qwen3_0.6B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for rerun the task 4 | pkill -9 sglang 5 | sleep 3 6 | ray stop --force 7 | pkill -9 ray 8 | pkill -9 python 9 | sleep 3 10 | pkill -9 ray 11 | pkill -9 python 12 | 13 | set -ex 14 | 15 | # will prevent ray from buffering stdout/stderr 16 | export PYTHONBUFFERED=16 17 | 18 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" 19 | source "${SCRIPT_DIR}/../scripts/models/qwen3-0.6B.sh" 20 | 21 | CKPT_ARGS=( 22 | --hf-checkpoint /root/Qwen3-0.6B 23 | --ref-load /root/Qwen3-0.6B_torch_dist 24 | ) 25 | 26 | ROLLOUT_ARGS=( 27 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 28 | --input-key prompt 29 | --label-key label 30 | --apply-chat-template 31 | --rollout-shuffle 32 | --rm-type deepscaler 33 | --num-rollout 3000 34 | --rollout-batch-size 32 35 | --n-samples-per-prompt 8 36 | --rollout-max-response-len 8192 37 | --rollout-temperature 0.8 38 | 39 | --over-sampling-batch-size 64 40 | --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std 41 | #--partial-rollout 42 | 43 | --global-batch-size 256 44 | #--balance-data 45 | ) 46 | 47 | EVAL_ARGS=( 48 | --eval-interval 20 49 | --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 50 | --n-samples-per-eval-prompt 1 51 | --eval-max-response-len 16384 52 | --eval-temperature 0 53 | ) 54 | 55 | PERF_ARGS=( 56 | --tensor-model-parallel-size 1 57 | --sequence-parallel 58 | --pipeline-model-parallel-size 1 59 | --context-parallel-size 1 60 | --expert-model-parallel-size 1 61 | --expert-tensor-parallel-size 1 62 | 63 | --recompute-granularity full 64 | --recompute-method uniform 65 | --recompute-num-layers 1 66 | 67 | # --micro-batch-size 1 68 | --use-dynamic-batch-size 69 | --max-tokens-per-gpu 9216 70 | ) 71 | 72 | GRPO_ARGS=( 73 | --advantage-estimator grpo 74 | --use-kl-loss 75 | --kl-loss-coef 0.00 76 | --kl-loss-type low_var_kl 77 | --kl-coef 0.00 78 | --entropy-coef 0.00 79 | --eps-clip 0.2 80 | --eps-clip-high 0.28 81 | ) 82 | 83 | OPTIMIZER_ARGS=( 84 | --optimizer adam 85 | --lr 1e-6 86 | --lr-decay-style constant 87 | --weight-decay 0.1 88 | 
--adam-beta1 0.9 89 | --adam-beta2 0.98 90 | ) 91 | 92 | WANDB_ARGS=( 93 | #--use-wandb 94 | --wandb-project slime-test 95 | --wandb-group test-qwen-3-0.6B 96 | ) 97 | 98 | SGLANG_ARGS=( 99 | --rollout-num-gpus-per-engine 1 100 | --sglang-mem-fraction-static 0.7 101 | ) 102 | 103 | MISC_ARGS=( 104 | # default dropout in megatron is 0.1 105 | --attention-dropout 0.0 106 | --hidden-dropout 0.0 107 | # should be good for model performance 108 | --accumulate-allreduce-grads-in-fp32 109 | --attention-softmax-in-fp32 110 | # need to comment this when using model with MLA 111 | --attention-backend flash 112 | ) 113 | 114 | # launch the master node of ray in container 115 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats 116 | 117 | ray job submit --address="http://127.0.0.1:8265" \ 118 | --runtime-env-json='{ 119 | "env_vars": { 120 | "PYTHONPATH": "/root/Megatron-LM", 121 | "CUDA_DEVICE_MAX_CONNECTIONS": "1" 122 | } 123 | }' \ 124 | -- python3 train.py \ 125 | --actor-num-nodes 1 \ 126 | --actor-num-gpus-per-node 1 \ 127 | --colocate \ 128 | ${MODEL_ARGS[@]} \ 129 | ${CKPT_ARGS[@]} \ 130 | ${ROLLOUT_ARGS[@]} \ 131 | ${OPTIMIZER_ARGS[@]} \ 132 | ${GRPO_ARGS[@]} \ 133 | ${DISTRIBUTED_ARGS[@]} \ 134 | ${WANDB_ARGS[@]} \ 135 | ${PERF_ARGS[@]} \ 136 | ${EVAL_ARGS[@]} \ 137 | ${SGLANG_ARGS[@]} \ 138 | ${MISC_ARGS[@]} -------------------------------------------------------------------------------- /tools/convert_hf_to_torch_dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import inspect 4 | 5 | import torch 6 | import mbridge 7 | from mbridge import AutoBridge 8 | 9 | import slime_plugins.mbridge 10 | from megatron.core import parallel_state as mpu 11 | from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed 12 | from megatron.training.arguments import parse_args 13 | from megatron.training.checkpointing import get_checkpoint_name, get_checkpoint_tracker_filename, save_checkpoint 14 | from megatron.training.global_vars import set_args 15 | 16 | 17 | def init_distributed(): 18 | """Initialize distributed environment""" 19 | os.environ["RANK"] = "0" 20 | os.environ["WORLD_SIZE"] = "1" 21 | os.environ["MASTER_ADDR"] = "localhost" 22 | os.environ["MASTER_PORT"] = "12355" 23 | torch.distributed.init_process_group("nccl") 24 | mpu.initialize_model_parallel( 25 | tensor_model_parallel_size=1, 26 | virtual_pipeline_model_parallel_size=None, 27 | context_parallel_size=1, 28 | expert_model_parallel_size=1, 29 | ) 30 | model_parallel_cuda_manual_seed(0) 31 | 32 | 33 | def add_convertion_args(parser): 34 | """Add conversion arguments to the parser""" 35 | parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path") 36 | return parser 37 | 38 | 39 | def main(): 40 | # Parse command line arguments 41 | args = parse_args(add_convertion_args) 42 | args.use_dist_ckpt = args.ckpt_format != "torch" 43 | set_args(args) 44 | 45 | 46 | # Initialize distributed environment 47 | init_distributed() 48 | 49 | # Load model 50 | hf_model_path = args.hf_checkpoint 51 | bridge = AutoBridge.from_pretrained(hf_model_path) 52 | model = bridge.get_model() 53 | bridge.load_weights(model, hf_model_path) 54 | print(f"Model loaded: {hf_model_path}") 55 | 56 | save_checkpoint(1, model, None, None, 0) 57 | # change to release ckpt 58 | tracker_filename = get_checkpoint_tracker_filename(args.save) 59 | with open(tracker_filename, "w") as f: 60 | f.write("release") 61 | 
source_dir = get_checkpoint_name(args.save, 1, False, return_base_dir=True) 62 | target_dir = get_checkpoint_name(args.save, -1, True, return_base_dir=True) 63 | shutil.move(source_dir, target_dir) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /tools/convert_to_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from megatron.core import mpu 4 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 5 | 6 | import slime.backends.megatron_utils as megatron_utils 7 | from slime.backends.megatron_utils import update_weight_utils 8 | from slime.utils.arguments import parse_args 9 | 10 | 11 | def add_checkpoint_args(parser): 12 | parser.add_argument( 13 | "--output-dir", 14 | type=str, 15 | default=None, 16 | help="Directory to save the converted HF model.", 17 | ) 18 | parser.add_argument( 19 | "--check-same", 20 | action="store_true", 21 | default=False, 22 | help="Check if the converted model is the same as the original model.", 23 | ) 24 | return parser 25 | 26 | 27 | def main(args): 28 | megatron_utils.init(args) 29 | 30 | pp_size = mpu.get_pipeline_model_parallel_world_size() 31 | ep_size = mpu.get_expert_model_parallel_world_size() 32 | 33 | is_save_rank = ( 34 | mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 35 | ) 36 | 37 | # Setup the model and optimizer 38 | args.no_load_optim = True 39 | args.no_load_rng = True 40 | model, _, _, _ = megatron_utils.initialize_model_and_optimizer(args) 41 | 42 | hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) 43 | model_name = type(hf_config).__name__.lower() 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) 46 | 47 | vocab_size = tokenizer.vocab_size if args.vocab_size is None else args.vocab_size 48 | 49 | param_infos = update_weight_utils.get_param_infos(args, model) 50 | 51 | state_dict = {} 52 | rank = dist.get_rank() 53 | for info in param_infos: 54 | if dist.get_rank() == info.src_rank: 55 | for name_, param_ in update_weight_utils.named_parameters(args, model): 56 | if name_ == info.name: 57 | param = param_ 58 | break 59 | else: 60 | param = torch.empty(info.shape, dtype=info.dtype, device=torch.cuda.current_device()) 61 | 62 | if pp_size > 1: 63 | if info.src_rank in dist.get_process_group_ranks(mpu.get_pipeline_model_parallel_group()): 64 | torch.distributed.broadcast(param, src=info.src_rank, group=mpu.get_pipeline_model_parallel_group()) 65 | 66 | # broadcast params across ep ranks 67 | if ep_size > 1: 68 | if ".experts." 
in info.name: 69 | src_rank = ( 70 | info.src_rank 71 | if info.src_rank in dist.get_process_group_ranks(mpu.get_expert_model_parallel_group()) 72 | else rank 73 | ) 74 | torch.distributed.broadcast(param, src=src_rank, group=mpu.get_expert_model_parallel_group()) 75 | 76 | for key, value in info.attrs.items(): 77 | setattr(param, key, value) 78 | 79 | param = update_weight_utils.all_gather_param(info.name, param) 80 | param = update_weight_utils.remove_padding(info.name, param, vocab_size) 81 | # use torch.distributed 82 | if is_save_rank: 83 | converted_named_tensors = update_weight_utils.convert_to_hf(args, model_name, info.name, param) 84 | for name, param in converted_named_tensors: 85 | state_dict[name] = param.cpu() 86 | del param 87 | 88 | if is_save_rank: 89 | hf_model = AutoModelForCausalLM.from_pretrained( 90 | args.hf_checkpoint, torch_dtype="auto", device_map="cpu", trust_remote_code=True 91 | ) 92 | 93 | if args.check_same: 94 | for name, param in hf_model.named_parameters(): 95 | if name in state_dict: 96 | assert ( 97 | param.shape == state_dict[name].shape 98 | ), f"Shape mismatch for {name}: {param.shape} vs {state_dict[name].shape}" 99 | assert torch.all(param == state_dict[name]), f"Value mismatch for {name}" 100 | else: 101 | print(f"Warning: {name} not found in state_dict") 102 | 103 | if args.output_dir: 104 | tokenizer.save_pretrained(args.output_dir) 105 | print(hf_model.load_state_dict(state_dict, strict=False)) 106 | hf_model.save_pretrained(args.output_dir) 107 | 108 | dist.barrier() 109 | 110 | 111 | if __name__ == "__main__": 112 | args = parse_args(add_custom_arguments=add_checkpoint_args) 113 | main(args) 114 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import ray 2 | 3 | from slime.ray.placement_group import create_actor_group, create_placement_groups, create_rollout_group 4 | from slime.utils.arguments import parse_args 5 | 6 | 7 | def train(args): 8 | # allocate the GPUs 9 | pgs = create_placement_groups(args) 10 | 11 | actor_model = create_actor_group(args, pgs["actor"]) 12 | 13 | # create the rollout generator, with sglang engines inside. 14 | rollout_generator = create_rollout_group(args, pgs["rollout"]) 15 | 16 | # calculate num_rollout from num_epoch 17 | num_rollout_per_epoch = None 18 | if args.num_rollout is None: 19 | num_rollout_per_epoch = ray.get(rollout_generator.data_buffer.get_num_rollout_per_epoch.remote()) 20 | args.num_rollout = num_rollout_per_epoch * args.num_epoch 21 | assert args.num_rollout > 0 22 | 23 | # sync the initialization (model initalization, load checkpoint, etc.) 24 | # Note that we initialize it earlier as megatron ckpt loading may have really large peak memory usage. 25 | start_rollout_ids = ray.get( 26 | actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss) 27 | ) 28 | assert len(set(start_rollout_ids)) == 1 29 | if args.start_rollout_id is None: 30 | args.start_rollout_id = start_rollout_ids[0] 31 | 32 | if args.rollout_global_dataset: 33 | ray.get(rollout_generator.data_buffer.load.remote(args.start_rollout_id - 1)) 34 | 35 | # initialize the connection for weight update during training 36 | ray.get(actor_model.async_init_weight_update_connections(rollout_generator)) 37 | 38 | if args.offload: 39 | ray.get(rollout_generator.async_onload()) 40 | 41 | # always update weight first so that sglang has the loaded weights from training. 
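    # (The transfer itself goes over the connection created by
    # async_init_weight_update_connections above; the periodic updates inside
    # the training loop below reuse the same path.)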
42 | ray.get(actor_model.async_update_weights()) 43 | 44 | # train loop. 45 | # note that for async training, one can change the position of the sync operation(ray.get). 46 | for rollout_id in range(args.start_rollout_id, args.num_rollout): 47 | if args.eval_interval is not None and rollout_id == 0: 48 | ray.get(rollout_generator.async_generate(rollout_id, evaluation=True)) 49 | ray.get(actor_model.async_eval(rollout_id)) 50 | 51 | ray.get(rollout_generator.async_generate(rollout_id)) 52 | 53 | if args.offload: 54 | ray.get(rollout_generator.async_offload()) 55 | 56 | ray.get(actor_model.async_train(rollout_id)) 57 | 58 | if args.save_interval is not None and ( 59 | (rollout_id + 1) % args.save_interval == 0 60 | or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0) 61 | ): 62 | ray.get(actor_model.async_save_model(rollout_id)) 63 | if args.rollout_global_dataset: 64 | ray.get(rollout_generator.data_buffer.save.remote(rollout_id)) 65 | 66 | if args.offload: 67 | ray.get(actor_model.async_offload()) 68 | ray.get(rollout_generator.async_onload()) 69 | 70 | ray.get(actor_model.async_update_weights()) 71 | 72 | if args.eval_interval is not None and ( 73 | (rollout_id + 1) % args.eval_interval == 0 74 | or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0) 75 | ): 76 | ray.get(rollout_generator.async_generate(rollout_id, evaluation=True)) 77 | ray.get(actor_model.async_eval(rollout_id)) 78 | 79 | 80 | if __name__ == "__main__": 81 | args = parse_args() 82 | train(args) 83 | -------------------------------------------------------------------------------- /train_async.py: -------------------------------------------------------------------------------- 1 | import ray 2 | 3 | from slime.ray.placement_group import create_actor_group, create_placement_groups, create_rollout_group 4 | from slime.utils.arguments import parse_args 5 | 6 | 7 | def train(args): 8 | assert not args.colocate, "Colocation is not supported for async training." 9 | # allocate the GPUs 10 | pgs = create_placement_groups(args) 11 | 12 | actor_model = create_actor_group(args, pgs["actor"]) 13 | 14 | # create the rollout generator, with sglang engines inside. 15 | rollout_generator = create_rollout_group(args, pgs["rollout"]) 16 | 17 | # calculate num_rollout from num_epoch 18 | num_rollout_per_epoch = None 19 | if args.num_rollout is None: 20 | num_rollout_per_epoch = ray.get(rollout_generator.data_buffer.get_num_rollout_per_epoch.remote()) 21 | args.num_rollout = num_rollout_per_epoch * args.num_epoch 22 | assert args.num_rollout > 0 23 | 24 | # sync the initialization (model initalization, load checkpoint, etc.) 25 | # Note that we initialize it earlier as megatron ckpt loading may have really large peak memory usage. 26 | start_rollout_ids = ray.get( 27 | actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss) 28 | ) 29 | assert len(set(start_rollout_ids)) == 1 30 | if args.start_rollout_id is None: 31 | args.start_rollout_id = start_rollout_ids[0] 32 | 33 | if args.rollout_global_dataset: 34 | ray.get(rollout_generator.data_buffer.load.remote(args.start_rollout_id - 1)) 35 | 36 | # initialize the connection for weight update during training 37 | ray.get(actor_model.async_init_weight_update_connections(rollout_generator)) 38 | 39 | # always update weight first so that sglang has the loaded weights from training. 40 | ray.get(actor_model.async_update_weights()) 41 | 42 | # async train loop. 
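    # One-step overlap (sketch of the schedule implemented below):
    #   kick off generate(start_rollout_id)
    #   for each rollout_id:
    #       wait for generate(rollout_id) and fetch its data
    #       kick off generate(rollout_id + 1)   # overlaps with training
    #       train(rollout_id), then save / update weights / eval as configured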
43 | generation_handles = rollout_generator.async_generate(args.start_rollout_id) 44 | for rollout_id in range(args.start_rollout_id, args.num_rollout): 45 | # Sync the last generation 46 | ray.get(generation_handles) 47 | 48 | # This is a synchronous call to ensure that the rollout data is ready 49 | actor_model.get_rollout_data(rollout_id) 50 | 51 | # Start the next rollout early. 52 | if rollout_id + 1 < args.num_rollout: 53 | generation_handles = rollout_generator.async_generate(rollout_id + 1) 54 | 55 | ray.get(actor_model.async_train(rollout_id, with_data_fetching=False)) 56 | 57 | if args.save_interval is not None and ( 58 | (rollout_id + 1) % args.save_interval == 0 59 | or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0) 60 | ): 61 | ray.get(actor_model.async_save_model(rollout_id)) 62 | if args.rollout_global_dataset: 63 | ray.get(rollout_generator.data_buffer.save.remote(rollout_id)) 64 | 65 | if (rollout_id + 1) % args.update_weights_interval == 0: 66 | # sync generate before update weights to prevent update weight in the middle of generation 67 | ray.get(generation_handles) 68 | ray.get(actor_model.async_update_weights()) 69 | 70 | if args.eval_interval is not None and ( 71 | (rollout_id + 1) % args.eval_interval == 0 72 | or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0) 73 | ): 74 | ray.get(rollout_generator.async_generate(rollout_id, evaluation=True)) 75 | ray.get(actor_model.async_eval(rollout_id)) 76 | 77 | 78 | if __name__ == "__main__": 79 | args = parse_args() 80 | train(args) 81 | --------------------------------------------------------------------------------